diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index bb41b609..8eb08c5b 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Callable, Type, Union, Dict, Any -from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset +from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset, DatasetDict from torch.utils.data import Dataset as TorchDataset import twinkle @@ -132,6 +132,21 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs): dataset = load_dataset(file_type, data_files=dataset_id, **kwargs) else: dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs) + + # fix: Some dataset sources return DatasetDict instead of Dataset, which breaks downstream select/map calls. + # fix: Normalize split resolution here (target split first, then train) and fail early with a clear error. + if isinstance(dataset, DatasetDict): + if split in dataset: + dataset = dataset[split] + elif 'train' in dataset: + dataset = dataset['train'] + else: + available_splits = list(dataset.keys()) + raise KeyError( + f"Split '{split}' not found for dataset '{dataset_id}'. 
" + f'Available splits: {available_splits}' + ) + if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'): iter_list = [] diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 2d3ee364..a04e66bd 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -138,11 +138,20 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio ) def _try_init_process_group(self): + import os import torch import torch.distributed as dist if not dist.is_initialized() and Platform.get_world_size() > 1: torch_util.set_device() backend = Platform.device_backend() + if backend == "hccl": + # fix: In multi-job NPU runs, HCCL default ports may collide (bind/listen failures). + # fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts. + # Keep training-side HCCL sockets on a per-job port layout to + # avoid collisions with other jobs on the same host. + from twinkle.utils.network import _ensure_hccl_socket_env + master_port = int(os.environ.get("MASTER_PORT", "29500")) + _ensure_hccl_socket_env(master_port) init_kwargs = { "backend": backend, "init_method": "env://", diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 2561adae..8a5a491a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -7,6 +7,7 @@ from twinkle import get_logger from twinkle.sampler.base_engine import BaseSamplerEngine from twinkle.data_format.sampling import StopReason, SamplingParams, SampleResponse, SampledSequence +from twinkle.utils.platform import get_vllm_device_uuid import inspect logger = get_logger() @@ -569,8 +570,11 @@ async def _sync_iter(): use_gpu_ipc = first_tensor.is_cuda use_shm = not use_gpu_ipc - # Get device UUID for ZMQ handle - device_uuid = current_platform.get_device_uuid(0) + # fix: On NPU, current_platform.get_device_uuid may be unimplemented and 
break receive_weights flow. + # fix: Route through platform-level fallback so IPC socket name remains stable. + # Get device UUID for ZMQ handle. + # For NPU, this is resolved from `npu-smi info` Bus-Id when needed. + device_uuid = get_vllm_device_uuid(0) zmq_handle = f"ipc:///tmp/twinkle-ipc-{device_uuid}.sock" bucket_size = bucket_size_mb << 20 diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index e0c9e6da..2573fc38 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -138,7 +138,11 @@ def __init__( self.engine: VLLMEngine = self._run_in_loop( self._create_engine_async(VLLMEngine, model_id, engine_kwargs) ) - self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model")) + # fix: On NPU, monkey_patch_model can trigger Triton compatibility errors and abort sampler init. + # fix: Explicitly skip this patch on NPU and keep it for non-NPU paths only. 
+ # NPU platform may trigger triton errors with monkey_patch_model + if Platform.get_platform().device_prefix() != 'npu': + self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model")) VLLMLoraWeights()(self) @@ -559,4 +563,3 @@ def shutdown(self): self._async_thread.join(timeout=5) except Exception as e: logger.warning(f"vLLMSampler event loop shutdown error: {e}") - diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index b5a824d0..00b234b2 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -21,6 +21,7 @@ from twinkle import get_logger import torch from twinkle.utils.framework import Torch +from twinkle.utils.platform import get_vllm_device_uuid logger = get_logger() @@ -64,12 +65,6 @@ def _rebuild_shared_memory(name: str, size: int): return tensor, shm -def _get_device_uuid(device_id: int) -> str: - """Get unique device identifier.""" - from vllm.platforms import current_platform - return current_platform.get_device_uuid(device_id) - - class TwinkleWorkerExtension: """Extension class for vLLM workers to support weight synchronization. @@ -122,7 +117,10 @@ def update_weights_from_ipc( import torch.distributed as dist if self.device is None: - self.device = torch.device(Torch.get_device()) + # fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues. + # fix: Pass local_rank when available so each worker binds to the expected local device. + print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}") + self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None))) if peft_config and base_sync_done: self.remove_lora(VLLM_LORA_INT_ID) @@ -257,7 +255,8 @@ def load_synced_weights( base_sync_done: If True with peft_config, load as LoRA adapter. 
""" if self.device is None: - self.device = torch.device(Torch.get_device()) + # fix: Keep device resolution consistent with update_weights_from_ipc to avoid path divergence. + self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None))) weight_list = list(weights.items()) self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done) @@ -374,5 +373,6 @@ def _load_weights( def _get_zmq_handle(self) -> str: """Get ZMQ handle for IPC communication.""" if not hasattr(self, '_device_uuid') or not self._device_uuid: - self._device_uuid = _get_device_uuid(self.device.index) + # fix: Always use platform fallback to avoid worker-side crashes when NPU get_device_uuid is unimplemented. + self._device_uuid = get_vllm_device_uuid(self.device.index) return f"ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock" diff --git a/src/twinkle/utils/network.py b/src/twinkle/utils/network.py index 83d686fb..ef37be3f 100644 --- a/src/twinkle/utils/network.py +++ b/src/twinkle/utils/network.py @@ -1,10 +1,48 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os import socket from datetime import timedelta from typing import Optional import torch +# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html +# HCCL base port anchor. HCCL derives internal listen/connect ports from this base. +_HCCL_IF_BASE_PORT_ENV = "HCCL_IF_BASE_PORT" +# Host-side socket port pool used by HCCL in multi-process communication. +_HCCL_HOST_SOCKET_PORT_RANGE_ENV = "HCCL_HOST_SOCKET_PORT_RANGE" +# NPU-side socket port pool used by HCCL for device communication channels. +_HCCL_NPU_SOCKET_PORT_RANGE_ENV = "HCCL_NPU_SOCKET_PORT_RANGE" + + +def _derive_hccl_socket_env_defaults(master_port: int) -> dict: + """Derive deterministic default HCCL socket env values from master_port.""" + # Keep values stable per job and spread jobs across non-overlapping ranges. 
+ host_offset = master_port % 8000 + return { + _HCCL_IF_BASE_PORT_ENV: str(20000 + ((master_port + 997) % 20000)), + _HCCL_HOST_SOCKET_PORT_RANGE_ENV: f"{40000 + host_offset}-{40000 + host_offset + 511}", + _HCCL_NPU_SOCKET_PORT_RANGE_ENV: f"{50000 + host_offset}-{50000 + host_offset + 511}", + } + + +def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None: + """Set deterministic HCCL socket env defaults to avoid port collisions. + + In multi-job environments, HCCL's default base port (60000) can collide + across concurrent jobs and lead to: + `ra_hdc_socket_listen_start ... ret(-98)`. + + We derive a per-job port layout from `master_port` so all ranks use the + same values while reducing cross-job conflicts. Explicit user settings are + preserved and never overwritten. + """ + # fix: We hit `ra_hdc_socket_listen_start ... ret(-98)` due to HCCL port collisions. + # fix: Derive stable ranges from master_port and preserve explicit user overrides. + env = os.environ if environ is None else environ + for key, value in _derive_hccl_socket_env_defaults(master_port).items(): + env.setdefault(key, value) + def is_valid_ipv6_address(ip: str) -> bool: """Check if the given string is a valid IPv6 address.""" @@ -87,6 +125,8 @@ def stateless_init_process_group( from vllm.distributed.utils import StatelessProcessGroup if backend == "hccl": + # fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide. + _ensure_hccl_socket_env(master_port) from vllm_ascend.distributed.device_communicators.pyhccl import ( PyHcclCommunicator as Communicator, ) diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 382e7169..0fe97c26 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -1,7 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
def _resolve_ascend_physical_device_id(device_id: int) -> int:
    """Translate a process-local NPU index into a physical device id.

    ``ASCEND_RT_VISIBLE_DEVICES`` is read as a comma-separated list (blank
    entries are dropped) and the entry at position ``device_id`` is returned
    as an int.

    Args:
        device_id: Index of the device as seen by this process.

    Returns:
        The integer found at that position in the visible-device list, or
        ``device_id`` unchanged when the variable is unset/blank or the index
        falls outside the list.

    Raises:
        ValueError: If the selected entry is not an integer literal.
    """
    visible_raw = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "").strip()
    if not visible_raw:
        return device_id

    entries = [token for token in (chunk.strip() for chunk in visible_raw.split(",")) if token]
    if 0 <= device_id < len(entries):
        return int(entries[device_id])
    return device_id


def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]:
    """Look up the Bus-Id of an NPU in the output of ``npu-smi info``.

    Args:
        device_id: Process-local device index; it is first mapped through
            ``_resolve_ascend_physical_device_id``.

    Returns:
        The lower-cased Bus-Id of the first table row whose second integer
        column equals the resolved id, or ``None`` when the command cannot be
        run (missing binary, non-zero exit, timeout, ...) or no row matches.
    """
    try:
        wanted_id = _resolve_ascend_physical_device_id(device_id)
    except Exception:
        # A malformed visible-device list must not abort the lookup; use the
        # index exactly as given.
        wanted_id = device_id

    smi_command = ["npu-smi", "info"]
    try:
        smi_text = subprocess.check_output(
            smi_command,
            text=True,
            stderr=subprocess.STDOUT,
            timeout=5,
        )
    except Exception:
        # Best effort only: the caller has a further fallback.
        return None

    # Rows of interest look like:
    #   | 0 0 | 0000:9D:00.0 | ...
    # Group 1 is the second integer of the first cell, group 2 the Bus-Id.
    row_pattern = re.compile(
        r"^\|\s*\d+\s+(\d+)\s*\|\s*"
        r"([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|",
        re.MULTILINE,
    )
    matching_bus_ids = (
        row.group(2).lower()
        for row in row_pattern.finditer(smi_text)
        if int(row.group(1)) == wanted_id
    )
    return next(matching_bus_ids, None)


def get_vllm_device_uuid(device_id: int = 0) -> str:
    """Return an identifier for a device, with a fallback chain for NPU.

    The vLLM platform's ``get_device_uuid`` is tried first. Only when it
    raises ``NotImplementedError`` are the fallbacks used, in order:

    1. the Bus-Id parsed from ``npu-smi info``;
    2. the first 16 hex characters of a SHA-1 over hostname, visible-device
       list and ``device_id``.

    Both fallbacks are deterministic for a given host/environment, so separate
    processes calling this with the same arguments build the same string.

    Args:
        device_id: Device index passed to the platform and to the fallbacks.

    Returns:
        The device identifier string.
    """
    from vllm.platforms import current_platform

    try:
        return current_platform.get_device_uuid(device_id)
    except NotImplementedError:
        identifier = _get_npu_bus_id_from_npu_smi(device_id)
        if not identifier:
            # No usable npu-smi output: hash stable host/device facts instead.
            ascend_visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES")
            if ascend_visible:
                visible = ascend_visible
            else:
                visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
            seed = f"{socket.gethostname()}:{visible}:{device_id}"
            digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
            identifier = digest[:16]
        return identifier


def is_master():
    """Return the result of ``Platform.is_master()``."""
    return Platform.is_master()