diff --git a/src/twinkle/advantage/__init__.py b/src/twinkle/advantage/__init__.py index 57912415..cbf5565d 100644 --- a/src/twinkle/advantage/__init__.py +++ b/src/twinkle/advantage/__init__.py @@ -3,34 +3,8 @@ from .grpo import GRPOAdvantage from .rloo import RLOOAdvantage - -# TODO: Temporary helpers added to unblock cookbook/grpo examples. -# Each call creates a new Advantage instance, not suitable for production. -# Remove once the framework provides a proper advantage computation API. -def compute_advantages(rewards, num_generations=1, scale='group', **kwargs): - """Backward-compatible helper for GRPO advantage computation.""" - return GRPOAdvantage()( - rewards=rewards, - num_generations=num_generations, - scale=scale, - **kwargs, - ) - - -def compute_advantages_rloo(rewards, num_generations=1, scale='group', **kwargs): - """Backward-compatible helper for RLOO advantage computation.""" - return RLOOAdvantage()( - rewards=rewards, - num_generations=num_generations, - scale=scale, - **kwargs, - ) - - __all__ = [ 'Advantage', 'GRPOAdvantage', 'RLOOAdvantage', - 'compute_advantages', - 'compute_advantages_rloo', ] diff --git a/src/twinkle/checkpoint_engine/base.py b/src/twinkle/checkpoint_engine/base.py index 346b5005..9cb95e08 100644 --- a/src/twinkle/checkpoint_engine/base.py +++ b/src/twinkle/checkpoint_engine/base.py @@ -1,15 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py -import torch from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Generator, TypedDict +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict + +if TYPE_CHECKING: + import torch class TensorMeta(TypedDict): """Metadata for a tensor in the weight bucket.""" name: str - shape: torch.Size - dtype: torch.dtype + shape: 'torch.Size' + dtype: 'torch.dtype' offset: int @@ -99,7 +101,7 @@ def finalize(self): raise NotImplementedError @abstractmethod - async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + async def send_weights(self, weights: Generator[tuple[str, 'torch.Tensor'], None, None]): """Send model weights to rollout workers. This method streams weights in buckets to avoid memory issues with @@ -112,7 +114,7 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, raise NotImplementedError @abstractmethod - async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + async def receive_weights(self) -> AsyncGenerator[tuple[str, 'torch.Tensor'], None]: """Receive model weights from trainer. 
This method receives weights in buckets and yields them as they
diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
index 16b4dd05..85c2f0a8 100644
--- a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
+++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
@@ -17,7 +17,7 @@ from typing import Any, AsyncGenerator, Generator
 
 from twinkle import get_logger
-from twinkle.utils.network import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
+from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
 
 from .base import CheckpointEngine, TensorMeta
 
 logger = get_logger()
diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py
index 7a45aa8e..d5e87e53 100644
--- a/src/twinkle/infra/_ray/resource_manager.py
+++ b/src/twinkle/infra/_ray/resource_manager.py
@@ -173,8 +173,7 @@ def get_visible_devices():
 
         self.device_groups = {}
         ray_address = str(ray.get_runtime_context().gcs_address)
-        assert len(groups) == len(visible_devices)
-        for group, visible_device_list in zip(groups, self.visible_devices):
+        for group in groups:
             if group.device_type != 'CPU':
                 ranks = group.ranks
                 gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
@@ -195,8 +194,13 @@ def get_visible_devices():
                     worker_ranks = normalized_ranks[start_idx:start_idx + gpus_per_worker]
 
                     # All GPUs for a worker should be on the same node
-                    node_ranks = [r // nproc_per_node for r in worker_ranks]
-                    gpu_ranks_local = [visible_device_list[r % nproc_per_node] for r in worker_ranks]
+                    node_ranks = []
+                    gpu_ranks_local = []
+                    for r in worker_ranks:
+                        node_rank = r // nproc_per_node
+                        node_ranks.append(node_rank)
+                        gpu_rank = self.visible_devices[node_rank][r % nproc_per_node]
+                        gpu_ranks_local.append(gpu_rank)
 
                     if len(set(node_ranks)) > 1:
                         raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. "
@@ -211,7 +215,7 @@ def get_visible_devices():
             else:
                 for alloc_rank in normalized_ranks:
                     node_rank = alloc_rank // nproc_per_node
-                    gpu_rank = visible_device_list[alloc_rank % nproc_per_node]
+                    gpu_rank = self.visible_devices[node_rank][alloc_rank % nproc_per_node]
                     local_device_groups.append(
                         dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address))
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index 62f430b4..b4550350 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -145,9 +145,9 @@ def _try_init_process_group(self):
             # fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
             # Keep training-side HCCL sockets on a per-job port layout to
             # avoid collisions with other jobs on the same host.
- from twinkle.utils.network import _ensure_hccl_socket_env + from twinkle.utils.platforms import ensure_hccl_socket_env master_port = int(os.environ.get('MASTER_PORT', '29500')) - _ensure_hccl_socket_env(master_port) + ensure_hccl_socket_env(master_port) init_kwargs = { 'backend': backend, 'init_method': 'env://', diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 3ac632d2..aa74e72e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1213,7 +1213,7 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert torch.save(cpu_state_dict, checkpoint_path) def _save_tokenizer(self, output_dir: str, **kwargs): - from twinkle.utils.platform import is_last_rank + from twinkle.utils import is_last_rank if not is_last_rank(): return @@ -1344,7 +1344,7 @@ def add_adapter_to_model( @remote_function() def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs): - apply_patch(self, patch_cls, **kwargs) + apply_patch(self.model, patch_cls, **kwargs) @remote_function(dispatch='all') def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs): diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index 58e40440..a4b59c9c 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -18,8 +18,7 @@ from twinkle.hub import HubOperation from twinkle.model.megatron.args import get_args # Use twinkle's get_args from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger, - get_modules_to_not_convert, get_multimodal_target_regex, requires) -from twinkle.utils.platform import is_last_rank + get_modules_to_not_convert, get_multimodal_target_regex, is_last_rank, requires) logger = get_logger() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6f80699b..45df3082 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -754,7 +754,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str, LRSchedu @remote_function() def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs): - apply_patch(self, patch_cls, **kwargs) + apply_patch(self.model, patch_cls, **kwargs) def __del__(self): HubOperation.wait_for() diff --git a/src/twinkle/patch/base.py b/src/twinkle/patch/base.py index af532ba3..c9387b0e 100644 --- a/src/twinkle/patch/base.py +++ b/src/twinkle/patch/base.py @@ -1,10 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from typing import Any, Type, Union +from typing import TYPE_CHECKING, List, Union -from twinkle.utils import construct_class +if TYPE_CHECKING: + import torch class Patch: - def __call__(self, module, *args, **kwargs): + def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module']], *args, **kwargs): ... 
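The checkpoint_engine/base.py and patch/base.py hunks above share one pattern: `torch` moves behind `typing.TYPE_CHECKING` and its annotations become string forward references, so importing these modules no longer pulls in torch at import time. A minimal self-contained sketch of that pattern (the `send` function here is illustrative only, not part of this patch):

```python
from typing import TYPE_CHECKING, Generator

if TYPE_CHECKING:
    import torch  # seen only by type checkers, never imported at runtime


def send(weights: Generator[tuple[str, 'torch.Tensor'], None, None]) -> None:
    # 'torch.Tensor' stays a plain string at runtime (a forward reference),
    # so this module imports cleanly even when torch is not installed.
    for name, tensor in weights:
        print(name, tuple(tensor.shape))
```

Nothing changes for callers that already have torch loaded; the payoff is that lightweight tooling can import these interfaces without the heavy dependency.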
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 2da3300c..c7b886fe 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -8,7 +8,7 @@ from twinkle import get_logger from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason from twinkle.sampler.base_engine import BaseSamplerEngine -from twinkle.utils.platform import get_vllm_device_uuid +from twinkle.utils import Platform logger = get_logger() @@ -524,7 +524,7 @@ async def _sync_iter(): # fix: Route through platform-level fallback so IPC socket name remains stable. # Get device UUID for ZMQ handle. # For NPU, this is resolved from `npu-smi info` Bus-Id when needed. - device_uuid = get_vllm_device_uuid(0) + device_uuid = Platform.get_vllm_device_uuid(0) zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock' bucket_size = bucket_size_mb << 20 diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index a5959248..b4d1c6fd 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -30,7 +30,7 @@ from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory from twinkle.patch.vllm_lora_weights import VLLMLoraWeights from twinkle.sampler.base import Sampler -from twinkle.utils.platform import Platform +from twinkle.utils import Platform logger = get_logger() diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index 7941ebf0..fa2fb748 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -21,8 +21,8 @@ from typing import Dict, List, Optional, Tuple from twinkle import get_logger +from twinkle.utils import Platform from twinkle.utils.framework import Torch -from twinkle.utils.platform import get_vllm_device_uuid logger = get_logger() @@ -376,5 +376,5 @@ def _get_zmq_handle(self) -> str: """Get ZMQ handle for IPC communication.""" if not hasattr(self, '_device_uuid') or not self._device_uuid: # fix: Always use platform fallback to avoid worker-side crashes when NPU get_device_uuid is unimplemented. - self._device_uuid = get_vllm_device_uuid(self.device.index) + self._device_uuid = Platform.get_vllm_device_uuid(self.device.index) return f'ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock' diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 0d84d6a6..edcefc34 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -1,15 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .dequantizer import Fp8Dequantizer, MxFp4Dequantizer +from .device_mesh import DeviceGroup, DeviceMesh, is_last_rank, is_master from .framework import Framework as framework_util from .framework import Torch as torch_util from .import_utils import exists, requires from .loader import Plugin, construct_class from .logger import get_logger -from .network import find_free_port, find_node_ip +from .network import find_free_port, find_node_ip, is_valid_ipv6_address from .parallel import processing_lock -from .platform import GPU, NPU, DeviceGroup, DeviceMesh, Platform +from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .torch_utils import to_device +from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert, get_multimodal_target_regex from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/device_mesh.py similarity index 66% rename from src/twinkle/utils/platform.py rename to src/twinkle/utils/device_mesh.py index 0e1d9c97..1aa8bc50 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/device_mesh.py @@ -1,17 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import hashlib import numpy as np import os -import platform -import re -import shutil -import socket -import subprocess -from abc import ABC from dataclasses import dataclass, field -from functools import lru_cache from itertools import product -from typing import Dict, List, Optional, Type, Union +from typing import Dict, List, Optional, Union + +from .platforms import Platform @dataclass @@ -493,204 +487,6 @@ class DeviceGroup: _device_mesh: Dict[str, DeviceMesh] = field(default_factory=dict) -class Platform(ABC): - - @staticmethod - def _ensure_npu_backend() -> None: - try: - import torch_npu # noqa: F401 - except Exception as exc: - raise RuntimeError('NPU backend is not available. 
Please install torch_npu/Ascend PyTorch.') from exc - - @staticmethod - def visible_device_env(platform: str = None) -> str: - return Platform.get_platform(platform).visible_device_env() - - @staticmethod - def device_prefix(platform: str = None) -> str: - return Platform.get_platform(platform).device_prefix() - - @staticmethod - def get_platform_names() -> List[str]: - return ['GPU', 'NPU', 'MPS'] - - @staticmethod - def get_platform(platform: str = None) -> Type['Platform']: - if platform is None: - if shutil.which('npu-smi'): - Platform._ensure_npu_backend() - return NPU - elif shutil.which('nvidia-smi'): - return GPU - elif MPS.is_mps_available(): - return MPS - else: - return GPU - elif platform.upper() in ('GPU', 'CUDA'): - return GPU - elif platform.upper() == 'NPU': - Platform._ensure_npu_backend() - return NPU - elif platform.upper() == 'MPS': - return MPS - else: - raise ValueError(f'Unsupported platform: {platform}.') - - @staticmethod - def get_rank() -> int: - """Get the global rank""" - return int(os.getenv('RANK', -1)) - - @staticmethod - def get_local_rank() -> int: - """Get the local rank""" - return int(os.getenv('LOCAL_RANK', -1)) - - @staticmethod - def get_world_size() -> int: - """Get the world size""" - return int(os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1) - - @staticmethod - def get_local_world_size() -> int: - """Get the local world size""" - return int(os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1)) - - @staticmethod - def get_nnodes() -> int: - """Get the node count""" - return int(os.getenv('NNODES', 1)) - - @staticmethod - def get_node_rank() -> int: - """Get the current node rank""" - return int(os.getenv('NODE_RANK', 0)) - - @staticmethod - def is_local_master() -> bool: - """Get if current is the local master""" - local_rank = Platform.get_local_rank() - return local_rank in {-1, 0} - - @staticmethod - def is_master() -> bool: - """Get if current is the global master""" - rank = Platform.get_rank() - return rank in {-1, 0} - - @staticmethod - def is_last_rank() -> bool: - """Get if current is the last rank""" - rank = Platform.get_rank() - world_size = Platform.get_world_size() - return rank in {-1, world_size - 1} - - @staticmethod - def get_peer_index(target_size, rank=None, world_size=None): - if rank is None: - rank = Platform.get_rank() - if rank < 0: - rank = 0 - if world_size is None: - world_size = Platform.get_world_size() - if world_size <= 0: - world_size = 1 - - k, m = divmod(target_size, world_size) - start_idx = rank * k + min(rank, m) - end_idx = (rank + 1) * k + min(rank + 1, m) - if target_size < world_size: - start_idx = rank % target_size - end_idx = start_idx + 1 - - return slice(start_idx, end_idx) - - @staticmethod - def get_local_device(idx: int = None, *, platform: str = None): - platform = Platform.get_platform(platform) - if idx is None: - idx = Platform.get_local_rank() - if idx < 0: - idx = 0 - return platform.get_local_device(idx) - - @staticmethod - def device_backend(platform: str = None): - platform = Platform.get_platform(platform) - return platform.device_backend() - - -class GPU(Platform): - - @staticmethod - def visible_device_env(): - return 'CUDA_VISIBLE_DEVICES' - - @staticmethod - def device_prefix(): - return 'cuda' - - @staticmethod - def get_local_device(idx, **kwargs) -> str: - return f'cuda:{idx}' - - @staticmethod - def device_backend(platform: str = None): - return 'nccl' - - -class NPU(Platform): - - @staticmethod - def visible_device_env(): - # Ascend runtime uses 
ASCEND_RT_VISIBLE_DEVICES. - return 'ASCEND_RT_VISIBLE_DEVICES' - - @staticmethod - def device_prefix(): - return 'npu' - - @staticmethod - def get_local_device(idx, **kwargs) -> str: - return f'npu:{idx}' - - @staticmethod - def device_backend(platform: str = None): - return 'hccl' - - -class MPS(Platform): - - @staticmethod - def visible_device_env(): - return None - - @staticmethod - def device_prefix(): - return 'mps' - - @staticmethod - def get_local_device(idx, **kwargs) -> str: - return 'mps' - - @staticmethod - def device_backend(platform: str = None): - return 'gloo' - - @lru_cache - @staticmethod - def is_mps_available(): - if platform.system() != 'Darwin': - return False - try: - output = subprocess.check_output(['system_profiler', 'SPDisplaysDataType'], - stderr=subprocess.DEVNULL, - text=True) - return 'Metal Support' in output - except Exception: # noqa - return False - - def is_last_rank(): import torch.distributed as dist if not dist.is_initialized(): @@ -698,69 +494,5 @@ def is_last_rank(): return dist.get_rank() == dist.get_world_size() - 1 -def _resolve_ascend_physical_device_id(device_id: int) -> int: - """Map local NPU device index to physical device id via visible devices.""" - visible = os.environ.get('ASCEND_RT_VISIBLE_DEVICES', '').strip() - if not visible: - return device_id - parts = [p.strip() for p in visible.split(',') if p.strip()] - if device_id < 0 or device_id >= len(parts): - return device_id - return int(parts[device_id]) - - -def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]: - """Get NPU Bus-Id from `npu-smi info` output.""" - try: - physical_id = _resolve_ascend_physical_device_id(device_id) - except Exception: - physical_id = device_id - - try: - output = subprocess.check_output( - ['npu-smi', 'info'], - text=True, - stderr=subprocess.STDOUT, - timeout=5, - ) - except Exception: - return None - - # fix: vllm-ascend may not implement get_device_uuid, but we still need a reproducible cross-process device id. - # fix: Prefer physical Bus-Id parsed from npu-smi instead of unstable/random identifiers. - # Typical line: - # | 0 0 | 0000:9D:00.0 | ... - pattern = re.compile( - r'^\|\s*\d+\s+(\d+)\s*\|\s*' - r'([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|', - re.MULTILINE, - ) - for match in pattern.finditer(output): - phy_id = int(match.group(1)) - if phy_id == physical_id: - return match.group(2).lower() - return None - - -def get_vllm_device_uuid(device_id: int = 0) -> str: - """Get vLLM device uuid with NPU Bus-Id special handling.""" - from vllm.platforms import current_platform - - try: - return current_platform.get_device_uuid(device_id) - except NotImplementedError: - # fix: Root cause was NPU platform calling vLLM base placeholder and raising NotImplementedError. - # fix: Use Bus-Id fallback first so sender/receiver compute the same IPC endpoint. - # NPU special case: prefer stable PCIe Bus-Id from npu-smi. - bus_id = _get_npu_bus_id_from_npu_smi(device_id) - if bus_id: - return bus_id - # fix: If npu-smi is unavailable, fall back to deterministic hash instead of failing hard. - # Generic deterministic fallback to keep sender/receiver socket names aligned. 
- visible = os.environ.get('ASCEND_RT_VISIBLE_DEVICES') or os.environ.get('CUDA_VISIBLE_DEVICES', '') - raw = f'{socket.gethostname()}:{visible}:{device_id}' - return hashlib.sha1(raw.encode('utf-8')).hexdigest()[:16] - - def is_master(): return Platform.is_master() diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index d7472563..7d4f7bb6 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -7,7 +7,7 @@ from functools import lru_cache from typing import TYPE_CHECKING, Any, Optional, Union -from .platform import DeviceMesh, Platform +from .device_mesh import DeviceMesh, Platform if TYPE_CHECKING: import torch diff --git a/src/twinkle/utils/logger.py b/src/twinkle/utils/logger.py index 7f4564f2..26a74b9a 100644 --- a/src/twinkle/utils/logger.py +++ b/src/twinkle/utils/logger.py @@ -6,7 +6,7 @@ from types import MethodType from typing import Optional -from .platform import Platform +from .platforms import Platform # Avoid circular reference diff --git a/src/twinkle/utils/network.py b/src/twinkle/utils/network.py index 582a6cdc..b4821f39 100644 --- a/src/twinkle/utils/network.py +++ b/src/twinkle/utils/network.py @@ -1,47 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os import socket -import torch -from datetime import timedelta from typing import Optional -# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html -# HCCL base port anchor. HCCL derives internal listen/connect ports from this base. -_HCCL_IF_BASE_PORT_ENV = 'HCCL_IF_BASE_PORT' -# Host-side socket port pool used by HCCL in multi-process communication. -_HCCL_HOST_SOCKET_PORT_RANGE_ENV = 'HCCL_HOST_SOCKET_PORT_RANGE' -# NPU-side socket port pool used by HCCL for device communication channels. -_HCCL_NPU_SOCKET_PORT_RANGE_ENV = 'HCCL_NPU_SOCKET_PORT_RANGE' - - -def _derive_hccl_socket_env_defaults(master_port: int) -> dict: - """Derive deterministic default HCCL socket env values from master_port.""" - # Keep values stable per job and spread jobs across non-overlapping ranges. - host_offset = master_port % 8000 - return { - _HCCL_IF_BASE_PORT_ENV: str(20000 + ((master_port + 997) % 20000)), - _HCCL_HOST_SOCKET_PORT_RANGE_ENV: f'{40000 + host_offset}-{40000 + host_offset + 511}', - _HCCL_NPU_SOCKET_PORT_RANGE_ENV: f'{50000 + host_offset}-{50000 + host_offset + 511}', - } - - -def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None: - """Set deterministic HCCL socket env defaults to avoid port collisions. - - In multi-job environments, HCCL's default base port (60000) can collide - across concurrent jobs and lead to: - `ra_hdc_socket_listen_start ... ret(-98)`. - - We derive a per-job port layout from `master_port` so all ranks use the - same values while reducing cross-job conflicts. Explicit user settings are - preserved and never overwritten. - """ - # fix: We hit `ra_hdc_socket_listen_start ... ret(-98)` due to HCCL port collisions. - # fix: Derive stable ranges from master_port and preserve explicit user overrides. 
- env = os.environ if environ is None else environ - for key, value in _derive_hccl_socket_env_defaults(master_port).items(): - env.setdefault(key, value) - def is_valid_ipv6_address(ip: str) -> bool: """Check if the given string is a valid IPv6 address.""" @@ -85,86 +45,3 @@ def find_free_port(address: str = '', start_port: Optional[int] = None, retry: i except OSError: pass return port - - -def stateless_init_process_group( - master_address: str, - master_port: int, - rank: int, - world_size: int, - device: int | torch.device = None, - backend: str = 'nccl', - listen_socket: socket.socket = None, - listen_fd: int = None, -): - """Create a stateless process group using vLLM's StatelessProcessGroup. - - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. - It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL/HCCL) between external (train processes) - and vLLM workers. - - Args: - master_address: The IP address of the master (rank 0). - master_port: The port of the master. - rank: The rank of this process. - world_size: Total number of processes. - device: The CUDA device to use. If None, uses current device. - backend: The communication backend ("nccl" or "hccl"). - listen_socket: Optional pre-created listening socket for master (rank 0). - If provided, this socket will be reused instead of creating a new one. - listen_fd: Optional file descriptor of the listening socket. - - Returns: - PyNcclCommunicator or PyHcclCommunicator instance. - """ - from torch.distributed import TCPStore - from vllm.distributed.utils import StatelessProcessGroup - - if backend == 'hccl': - # fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide. 
- _ensure_hccl_socket_env(master_port) - from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as Communicator - else: - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator as Communicator - - if device is None: - device = torch.cuda.current_device() if backend == 'nccl' else torch.npu.current_device() - - # Create the stateless process group - launch_server = rank == 0 - - if launch_server and listen_socket is None: - # For master, create a listening socket if not provided - if is_valid_ipv6_address(master_address): - listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - else: - listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - listen_socket.bind((master_address, master_port)) - listen_socket.listen() - listen_fd = listen_socket.fileno() - elif launch_server and listen_fd is None: - listen_fd = listen_socket.fileno() - - store = TCPStore( - host_name=master_address, - port=master_port, - world_size=world_size, - is_master=launch_server, - timeout=timedelta(seconds=300), - use_libuv=False, # for compatibility - master_listen_fd=listen_fd, - ) - - pg = StatelessProcessGroup( - rank=rank, - world_size=world_size, - store=store, - socket=listen_socket, - data_expiration_seconds=3600, - ) - - communicator = Communicator(pg, device=device) - return communicator diff --git a/src/twinkle/utils/platforms/__init__.py b/src/twinkle/utils/platforms/__init__.py new file mode 100644 index 00000000..ea327f84 --- /dev/null +++ b/src/twinkle/utils/platforms/__init__.py @@ -0,0 +1,4 @@ +from .base import Platform +from .gpu import GPU +from .mps import MPS, is_mps_available +from .npu import NPU, ensure_hccl_socket_env, ensure_npu_backend diff --git a/src/twinkle/utils/platforms/base.py b/src/twinkle/utils/platforms/base.py new file mode 100644 index 00000000..71c2bd18 --- /dev/null +++ b/src/twinkle/utils/platforms/base.py @@ -0,0 +1,138 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import os
+import shutil
+from abc import ABC
+from typing import List, Type
+
+
+class Platform(ABC):
+
+    @staticmethod
+    def visible_device_env(platform: str = None) -> str:
+        return Platform.get_platform(platform).visible_device_env()
+
+    @staticmethod
+    def device_prefix(platform: str = None) -> str:
+        return Platform.get_platform(platform).device_prefix()
+
+    @staticmethod
+    def get_platform_names() -> List[str]:
+        return ['GPU', 'NPU', 'MPS']
+
+    @staticmethod
+    def get_platform(platform: str = None) -> Type['Platform']:
+        if platform is None:
+            from .mps import is_mps_available
+            if shutil.which('npu-smi'):
+                from .npu import NPU, ensure_npu_backend
+                ensure_npu_backend()
+                return NPU
+            elif shutil.which('nvidia-smi'):
+                from .gpu import GPU
+                return GPU
+            elif is_mps_available():
+                from .mps import MPS
+                return MPS
+            else:
+                from .gpu import GPU
+                return GPU
+        elif platform.upper() in ('GPU', 'CUDA'):
+            from .gpu import GPU
+            return GPU
+        elif platform.upper() == 'NPU':
+            from .npu import NPU, ensure_npu_backend
+            ensure_npu_backend()
+            return NPU
+        elif platform.upper() == 'MPS':
+            from .mps import MPS
+            return MPS
+        else:
+            raise ValueError(f'Unsupported platform: {platform}.')
+
+    @staticmethod
+    def get_rank() -> int:
+        """Get the global rank"""
+        return int(os.getenv('RANK', -1))
+
+    @staticmethod
+    def get_local_rank() -> int:
+        """Get the local rank"""
+        return int(os.getenv('LOCAL_RANK', -1))
+
+    @staticmethod
+    def get_world_size() -> int:
+        """Get the world size"""
+        return int(os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1)
+
+    @staticmethod
+    def get_local_world_size() -> int:
+        """Get the local world size"""
+        return int(os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1))
+
+    @staticmethod
+    def get_nnodes() -> int:
+        """Get the node count"""
+        return int(os.getenv('NNODES', 1))
+
+    @staticmethod
+    def get_node_rank() -> int:
+        """Get the current node rank"""
+        return int(os.getenv('NODE_RANK', 0))
+
+    @staticmethod
+    def is_local_master() -> bool:
+        """Return whether the current process is the local master"""
+        local_rank = Platform.get_local_rank()
+        return local_rank in {-1, 0}
+
+    @staticmethod
+    def is_master() -> bool:
+        """Return whether the current process is the global master"""
+        rank = Platform.get_rank()
+        return rank in {-1, 0}
+
+    @staticmethod
+    def is_last_rank() -> bool:
+        """Return whether the current process is the last rank"""
+        rank = Platform.get_rank()
+        world_size = Platform.get_world_size()
+        return rank in {-1, world_size - 1}
+
+    @staticmethod
+    def get_peer_index(target_size, rank=None, world_size=None):
+        if rank is None:
+            rank = Platform.get_rank()
+        if rank < 0:
+            rank = 0
+        if world_size is None:
+            world_size = Platform.get_world_size()
+        if world_size <= 0:
+            world_size = 1
+
+        k, m = divmod(target_size, world_size)
+        start_idx = rank * k + min(rank, m)
+        end_idx = (rank + 1) * k + min(rank + 1, m)
+        if target_size < world_size:
+            start_idx = rank % target_size
+            end_idx = start_idx + 1
+
+        return slice(start_idx, end_idx)
+
+    @staticmethod
+    def get_local_device(idx: int = None, *, platform: str = None):
+        platform = Platform.get_platform(platform)
+        if idx is None:
+            idx = Platform.get_local_rank()
+        if idx < 0:
+            idx = 0
+        return platform.get_local_device(idx)
+
+    @staticmethod
+    def device_backend(platform: str = None):
+        platform = Platform.get_platform(platform)
+        return platform.device_backend()
+
+    @staticmethod
+    def get_vllm_device_uuid(device_id: int = 0, platform=None) -> str:
+        platform = Platform.get_platform(platform)
+        return platform.get_vllm_device_uuid(device_id)
diff --git a/src/twinkle/utils/platforms/gpu.py b/src/twinkle/utils/platforms/gpu.py
new file mode 100644
index 00000000..0b99f885
--- /dev/null
+++ b/src/twinkle/utils/platforms/gpu.py
@@ -0,0 +1,26 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Platform
+
+
+class GPU(Platform):
+
+    @staticmethod
+    def visible_device_env():
+        return 'CUDA_VISIBLE_DEVICES'
+
+    @staticmethod
+    def device_prefix():
+        return 'cuda'
+
+    @staticmethod
+    def get_local_device(idx, **kwargs) -> str:
+        return f'cuda:{idx}'
+
+    @staticmethod
+    def device_backend(platform: str = None):
+        return 'nccl'
+
+    @staticmethod
+    def get_vllm_device_uuid(device_id: int = 0) -> str:
+        from vllm.platforms import current_platform
+        return current_platform.get_device_uuid(device_id)
diff --git a/src/twinkle/utils/platforms/mps.py b/src/twinkle/utils/platforms/mps.py
new file mode 100644
index 00000000..e99abb0e
--- /dev/null
+++ b/src/twinkle/utils/platforms/mps.py
@@ -0,0 +1,42 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import platform
+import subprocess
+from functools import lru_cache
+
+from .base import Platform
+
+
+@lru_cache
+def is_mps_available():
+    if platform.system() != 'Darwin':
+        return False
+    try:
+        output = subprocess.check_output(['system_profiler', 'SPDisplaysDataType'],
+                                         stderr=subprocess.DEVNULL,
+                                         text=True)
+        return 'Metal Support' in output
+    except Exception:  # noqa
+        return False
+
+
+class MPS(Platform):
+
+    @staticmethod
+    def visible_device_env():
+        return None
+
+    @staticmethod
+    def device_prefix():
+        return 'mps'
+
+    @staticmethod
+    def get_local_device(idx, **kwargs) -> str:
+        return 'mps'
+
+    @staticmethod
+    def device_backend(platform: str = None):
+        return 'gloo'
+
+    @staticmethod
+    def get_vllm_device_uuid(device_id: int = 0) -> str:
+        raise NotImplementedError
diff --git a/src/twinkle/utils/platforms/npu.py b/src/twinkle/utils/platforms/npu.py
new file mode 100644
index 00000000..89066b28
--- /dev/null
+++ b/src/twinkle/utils/platforms/npu.py
@@ -0,0 +1,135 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import hashlib
+import os
+import re
+import socket
+import subprocess
+from typing import Optional
+
+from .base import Platform
+
+# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html
+# HCCL base port anchor. HCCL derives internal listen/connect ports from this base.
+_HCCL_IF_BASE_PORT_ENV = 'HCCL_IF_BASE_PORT'
+# Host-side socket port pool used by HCCL in multi-process communication.
+_HCCL_HOST_SOCKET_PORT_RANGE_ENV = 'HCCL_HOST_SOCKET_PORT_RANGE'
+# NPU-side socket port pool used by HCCL for device communication channels.
+_HCCL_NPU_SOCKET_PORT_RANGE_ENV = 'HCCL_NPU_SOCKET_PORT_RANGE'
+
+
+def _derive_hccl_socket_env_defaults(master_port: int) -> dict:
+    """Derive deterministic default HCCL socket env values from master_port."""
+    # Keep values stable per job and spread jobs across non-overlapping ranges.
+    host_offset = master_port % 8000
+    return {
+        _HCCL_IF_BASE_PORT_ENV: str(20000 + ((master_port + 997) % 20000)),
+        _HCCL_HOST_SOCKET_PORT_RANGE_ENV: f'{40000 + host_offset}-{40000 + host_offset + 511}',
+        _HCCL_NPU_SOCKET_PORT_RANGE_ENV: f'{50000 + host_offset}-{50000 + host_offset + 511}',
+    }
+
+
+def ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None:
+    """Set deterministic HCCL socket env defaults to avoid port collisions.
+ + In multi-job environments, HCCL's default base port (60000) can collide + across concurrent jobs and lead to: + `ra_hdc_socket_listen_start ... ret(-98)`. + + We derive a per-job port layout from `master_port` so all ranks use the + same values while reducing cross-job conflicts. Explicit user settings are + preserved and never overwritten. + """ + # fix: We hit `ra_hdc_socket_listen_start ... ret(-98)` due to HCCL port collisions. + # fix: Derive stable ranges from master_port and preserve explicit user overrides. + env = os.environ if environ is None else environ + for key, value in _derive_hccl_socket_env_defaults(master_port).items(): + env.setdefault(key, value) + + +def _resolve_ascend_physical_device_id(device_id: int) -> int: + """Map local NPU device index to physical device id via visible devices.""" + visible = os.environ.get('ASCEND_RT_VISIBLE_DEVICES', '').strip() + if not visible: + return device_id + parts = [p.strip() for p in visible.split(',') if p.strip()] + if device_id < 0 or device_id >= len(parts): + return device_id + return int(parts[device_id]) + + +def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]: + """Get NPU Bus-Id from `npu-smi info` output.""" + try: + physical_id = _resolve_ascend_physical_device_id(device_id) + except Exception: + physical_id = device_id + + try: + output = subprocess.check_output( + ['npu-smi', 'info'], + text=True, + stderr=subprocess.STDOUT, + timeout=5, + ) + except Exception: + return None + + # fix: vllm-ascend may not implement get_device_uuid, but we still need a reproducible cross-process device id. + # fix: Prefer physical Bus-Id parsed from npu-smi instead of unstable/random identifiers. + # Typical line: + # | 0 0 | 0000:9D:00.0 | ... + pattern = re.compile( + r'^\|\s*\d+\s+(\d+)\s*\|\s*' + r'([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|', + re.MULTILINE, + ) + for match in pattern.finditer(output): + phy_id = int(match.group(1)) + if phy_id == physical_id: + return match.group(2).lower() + return None + + +def ensure_npu_backend() -> None: + try: + import torch_npu # noqa: F401 + except Exception as exc: + raise RuntimeError('NPU backend is not available. Please install torch_npu/Ascend PyTorch.') from exc + + +class NPU(Platform): + + @staticmethod + def visible_device_env(): + # Ascend runtime uses ASCEND_RT_VISIBLE_DEVICES. + return 'ASCEND_RT_VISIBLE_DEVICES' + + @staticmethod + def device_prefix(): + return 'npu' + + @staticmethod + def get_local_device(idx, **kwargs) -> str: + return f'npu:{idx}' + + @staticmethod + def device_backend(platform: str = None): + return 'hccl' + + @staticmethod + def get_vllm_device_uuid(device_id: int = 0) -> str: + from vllm.platforms import current_platform + try: + return current_platform.get_device_uuid(device_id) + except NotImplementedError: + # fix: Root cause was NPU platform calling vLLM base placeholder and raising NotImplementedError. + # fix: Use Bus-Id fallback first so sender/receiver compute the same IPC endpoint. + # NPU special case: prefer stable PCIe Bus-Id from npu-smi. + bus_id = _get_npu_bus_id_from_npu_smi(device_id) + if bus_id: + return bus_id + # fix: If npu-smi is unavailable, fall back to deterministic hash instead of failing hard. + # Generic deterministic fallback to keep sender/receiver socket names aligned. 
+ visible = os.environ.get(Platform.visible_device_env()) + raw = f'{socket.gethostname()}:{visible}:{device_id}' + return hashlib.sha1(raw.encode('utf-8')).hexdigest()[:16] diff --git a/src/twinkle/utils/safetensors.py b/src/twinkle/utils/safetensors.py index e1fa62e9..42619c54 100644 --- a/src/twinkle/utils/safetensors.py +++ b/src/twinkle/utils/safetensors.py @@ -3,7 +3,7 @@ from functools import partial from typing import Literal -from .platform import is_last_rank, is_master +from .device_mesh import is_last_rank, is_master class LazyTensor: diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index c5b45047..91bf3569 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,4 +1,8 @@ -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union +import socket +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Mapping, Union + +from .network import is_valid_ipv6_address if TYPE_CHECKING: import torch @@ -103,3 +107,88 @@ def _vocab_parallel_selective_log_softmax( tp_group = mpu.get_tensor_model_parallel_group() return -fused_vocab_parallel_cross_entropy(logits, index, tp_group) + + +def stateless_init_process_group( + master_address: str, + master_port: int, + rank: int, + world_size: int, + device: Union[int, 'torch.device'] = None, + backend: str = 'nccl', + listen_socket: socket.socket = None, + listen_fd: int = None, +): + """Create a stateless process group using vLLM's StatelessProcessGroup. + + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL/HCCL) between external (train processes) + and vLLM workers. + + Args: + master_address: The IP address of the master (rank 0). + master_port: The port of the master. + rank: The rank of this process. + world_size: Total number of processes. + device: The CUDA device to use. If None, uses current device. + backend: The communication backend ("nccl" or "hccl"). + listen_socket: Optional pre-created listening socket for master (rank 0). + If provided, this socket will be reused instead of creating a new one. + listen_fd: Optional file descriptor of the listening socket. + + Returns: + PyNcclCommunicator or PyHcclCommunicator instance. + """ + import torch + from torch.distributed import TCPStore + from vllm.distributed.utils import StatelessProcessGroup + + if backend == 'hccl': + # fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide. 
+ from .platforms import ensure_hccl_socket_env + ensure_hccl_socket_env(master_port) + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as Communicator + else: + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator as Communicator + + if device is None: + device = torch.cuda.current_device() if backend == 'nccl' else torch.npu.current_device() + + # Create the stateless process group + launch_server = rank == 0 + + if launch_server and listen_socket is None: + # For master, create a listening socket if not provided + if is_valid_ipv6_address(master_address): + listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + else: + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((master_address, master_port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + elif launch_server and listen_fd is None: + listen_fd = listen_socket.fileno() + + store = TCPStore( + host_name=master_address, + port=master_port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=300), + use_libuv=False, # for compatibility + master_listen_fd=listen_fd, + ) + + pg = StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + socket=listen_socket, + data_expiration_seconds=3600, + ) + + communicator = Communicator(pg, device=device) + return communicator
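Since the derived HCCL port layout is the heart of the `ret(-98)` fix, here is a quick worked check of `ensure_hccl_socket_env` under an assumed `master_port` of 29500 (the port value is hypothetical; the expected strings follow directly from `_derive_hccl_socket_env_defaults` in platforms/npu.py):

```python
from twinkle.utils.platforms import ensure_hccl_socket_env

env = {}  # pass an explicit mapping so os.environ is left untouched
ensure_hccl_socket_env(29500, environ=env)

# host_offset = 29500 % 8000 = 5500
assert env['HCCL_IF_BASE_PORT'] == '30497'                  # 20000 + ((29500 + 997) % 20000)
assert env['HCCL_HOST_SOCKET_PORT_RANGE'] == '45500-46011'  # 40000 + 5500, span of 512 ports
assert env['HCCL_NPU_SOCKET_PORT_RANGE'] == '55500-56011'   # 50000 + 5500, span of 512 ports

# Explicit user settings win: setdefault never overwrites them.
env = {'HCCL_IF_BASE_PORT': '61000'}
ensure_hccl_socket_env(29500, environ=env)
assert env['HCCL_IF_BASE_PORT'] == '61000'
```

Jobs with sufficiently different master ports land in different 512-port windows, which reduces the chance that two concurrent jobs contend for the same HCCL listen sockets while keeping every rank of one job on identical values.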