Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/twinkle/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Callable, Type, Union, Dict, Any

from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset
from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset, DatasetDict
from torch.utils.data import Dataset as TorchDataset

import twinkle
Expand Down Expand Up @@ -132,6 +132,21 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
else:
dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs)

# fix: Some dataset sources return DatasetDict instead of Dataset, which breaks downstream select/map calls.
# fix: Normalize split resolution here (target split first, then train) and fail early with a clear error.
if isinstance(dataset, DatasetDict):
if split in dataset:
dataset = dataset[split]
elif 'train' in dataset:
dataset = dataset['train']
else:
available_splits = list(dataset.keys())
raise KeyError(
f"Split '{split}' not found for dataset '{dataset_id}'. "
f'Available splits: {available_splits}'
)

if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'):

iter_list = []
Expand Down
9 changes: 9 additions & 0 deletions src/twinkle/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,20 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
)

def _try_init_process_group(self):
import os
import torch
import torch.distributed as dist
if not dist.is_initialized() and Platform.get_world_size() > 1:
torch_util.set_device()
backend = Platform.device_backend()
if backend == "hccl":
# fix: In multi-job NPU runs, HCCL default ports may collide (bind/listen failures).
# fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
# Keep training-side HCCL sockets on a per-job port layout to
# avoid collisions with other jobs on the same host.
from twinkle.utils.network import _ensure_hccl_socket_env
master_port = int(os.environ.get("MASTER_PORT", "29500"))
_ensure_hccl_socket_env(master_port)
init_kwargs = {
"backend": backend,
"init_method": "env://",
Expand Down
8 changes: 6 additions & 2 deletions src/twinkle/sampler/vllm_sampler/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from twinkle import get_logger
from twinkle.sampler.base_engine import BaseSamplerEngine
from twinkle.data_format.sampling import StopReason, SamplingParams, SampleResponse, SampledSequence
from twinkle.utils.platform import get_vllm_device_uuid

import inspect
logger = get_logger()
Expand Down Expand Up @@ -569,8 +570,11 @@ async def _sync_iter():
use_gpu_ipc = first_tensor.is_cuda
use_shm = not use_gpu_ipc

# Get device UUID for ZMQ handle
device_uuid = current_platform.get_device_uuid(0)
# fix: On NPU, current_platform.get_device_uuid may be unimplemented and break receive_weights flow.
# fix: Route through platform-level fallback so IPC socket name remains stable.
# Get device UUID for ZMQ handle.
# For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
device_uuid = get_vllm_device_uuid(0)
zmq_handle = f"ipc:///tmp/twinkle-ipc-{device_uuid}.sock"

bucket_size = bucket_size_mb << 20
Expand Down
7 changes: 5 additions & 2 deletions src/twinkle/sampler/vllm_sampler/vllm_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ def __init__(
self.engine: VLLMEngine = self._run_in_loop(
self._create_engine_async(VLLMEngine, model_id, engine_kwargs)
)
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))
# fix: On NPU, monkey_patch_model can trigger Triton compatibility errors and abort sampler init.
# fix: Explicitly skip this patch on NPU and keep it for non-NPU paths only.
# NPU platform may trigger triton errors with monkey_patch_model
if Platform.get_platform().device_prefix() != 'npu':
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))

VLLMLoraWeights()(self)

Expand Down Expand Up @@ -559,4 +563,3 @@ def shutdown(self):
self._async_thread.join(timeout=5)
except Exception as e:
logger.warning(f"vLLMSampler event loop shutdown error: {e}")

18 changes: 9 additions & 9 deletions src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from twinkle import get_logger
import torch
from twinkle.utils.framework import Torch
from twinkle.utils.platform import get_vllm_device_uuid

logger = get_logger()

Expand Down Expand Up @@ -64,12 +65,6 @@ def _rebuild_shared_memory(name: str, size: int):
return tensor, shm


def _get_device_uuid(device_id: int) -> str:
"""Get unique device identifier."""
from vllm.platforms import current_platform
return current_platform.get_device_uuid(device_id)


class TwinkleWorkerExtension:
"""Extension class for vLLM workers to support weight synchronization.

Expand Down Expand Up @@ -122,7 +117,10 @@ def update_weights_from_ipc(
import torch.distributed as dist

if self.device is None:
self.device = torch.device(Torch.get_device())
# fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues.
# fix: Pass local_rank when available so each worker binds to the expected local device.
print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This print statement appears to be for debugging purposes. It should be removed from the production code to avoid cluttering the logs.

self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))

if peft_config and base_sync_done:
self.remove_lora(VLLM_LORA_INT_ID)
Expand Down Expand Up @@ -257,7 +255,8 @@ def load_synced_weights(
base_sync_done: If True with peft_config, load as LoRA adapter.
"""
if self.device is None:
self.device = torch.device(Torch.get_device())
# fix: Keep device resolution consistent with update_weights_from_ipc to avoid path divergence.
self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))

weight_list = list(weights.items())
self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done)
Expand Down Expand Up @@ -374,5 +373,6 @@ def _load_weights(
def _get_zmq_handle(self) -> str:
    """Return the ZMQ IPC socket path bound to this worker's device uuid.

    The uuid is resolved lazily once and cached on the instance; resolution
    goes through the platform-level helper so NPU builds lacking
    ``get_device_uuid`` still yield a stable identifier.
    """
    cached = getattr(self, '_device_uuid', None)
    if not cached:
        cached = get_vllm_device_uuid(self.device.index)
        self._device_uuid = cached
    return f"ipc:///tmp/twinkle-ipc-{cached}.sock"
40 changes: 40 additions & 0 deletions src/twinkle/utils/network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,48 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import socket
from datetime import timedelta
from typing import Optional

import torch

# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html
# HCCL base port anchor. HCCL derives internal listen/connect ports from this base.
_HCCL_IF_BASE_PORT_ENV = "HCCL_IF_BASE_PORT"
# Host-side socket port pool used by HCCL in multi-process communication.
_HCCL_HOST_SOCKET_PORT_RANGE_ENV = "HCCL_HOST_SOCKET_PORT_RANGE"
# NPU-side socket port pool used by HCCL for device communication channels.
_HCCL_NPU_SOCKET_PORT_RANGE_ENV = "HCCL_NPU_SOCKET_PORT_RANGE"


def _derive_hccl_socket_env_defaults(master_port: int) -> dict:
    """Build deterministic default HCCL socket env values from *master_port*.

    Every rank of a job shares the same master port, so all ranks derive
    identical settings, while different jobs land on mostly non-overlapping
    port ranges on the same host.
    """
    offset = master_port % 8000
    base_port = 20000 + ((master_port + 997) % 20000)
    host_range = f"{40000 + offset}-{40511 + offset}"
    npu_range = f"{50000 + offset}-{50511 + offset}"
    return {
        _HCCL_IF_BASE_PORT_ENV: str(base_port),
        _HCCL_HOST_SOCKET_PORT_RANGE_ENV: host_range,
        _HCCL_NPU_SOCKET_PORT_RANGE_ENV: npu_range,
    }


def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None:
    """Populate missing HCCL socket env vars with deterministic defaults.

    HCCL's default base port (60000) can collide across concurrent jobs on
    one host and surface as ``ra_hdc_socket_listen_start ... ret(-98)``.
    Deriving a per-job port layout from ``master_port`` keeps all ranks of a
    job consistent while spreading jobs apart. Values that the user already
    set explicitly are left untouched.

    Args:
        master_port: The job's rendezvous port; seeds the derived layout.
        environ: Mapping to mutate; defaults to ``os.environ``.
    """
    target = environ if environ is not None else os.environ
    defaults = _derive_hccl_socket_env_defaults(master_port)
    for key in defaults:
        # setdefault preserves any explicit user-provided override.
        target.setdefault(key, defaults[key])


def is_valid_ipv6_address(ip: str) -> bool:
"""Check if the given string is a valid IPv6 address."""
Expand Down Expand Up @@ -87,6 +125,8 @@ def stateless_init_process_group(
from vllm.distributed.utils import StatelessProcessGroup

if backend == "hccl":
# fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide.
_ensure_hccl_socket_env(master_port)
from vllm_ascend.distributed.device_communicators.pyhccl import (
PyHcclCommunicator as Communicator,
)
Expand Down
70 changes: 70 additions & 0 deletions src/twinkle/utils/platform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import platform
import hashlib
import re
import shutil
import socket
import subprocess
from abc import ABC
from dataclasses import dataclass, field
Expand Down Expand Up @@ -641,5 +644,72 @@ def is_last_rank():
return True
return dist.get_rank() == dist.get_world_size() - 1


def _resolve_ascend_physical_device_id(device_id: int) -> int:
"""Map local NPU device index to physical device id via visible devices."""
visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "").strip()
if not visible:
return device_id
parts = [p.strip() for p in visible.split(",") if p.strip()]
if device_id < 0 or device_id >= len(parts):
return device_id
return int(parts[device_id])


def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]:
"""Get NPU Bus-Id from `npu-smi info` output."""
try:
physical_id = _resolve_ascend_physical_device_id(device_id)
except Exception:
physical_id = device_id

try:
output = subprocess.check_output(
["npu-smi", "info"],
text=True,
stderr=subprocess.STDOUT,
timeout=5,
)
except Exception:
return None

# fix: vllm-ascend may not implement get_device_uuid, but we still need a reproducible cross-process device id.
# fix: Prefer physical Bus-Id parsed from npu-smi instead of unstable/random identifiers.
# Typical line:
# | 0 0 | 0000:9D:00.0 | ...
pattern = re.compile(
r"^\|\s*\d+\s+(\d+)\s*\|\s*"
r"([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|",
re.MULTILINE,
)
for match in pattern.finditer(output):
phy_id = int(match.group(1))
if phy_id == physical_id:
return match.group(2).lower()
return None


def get_vllm_device_uuid(device_id: int = 0) -> str:
    """Return a stable per-device identifier for vLLM IPC socket naming.

    Prefers ``current_platform.get_device_uuid``. On NPU builds where that
    call is an unimplemented placeholder (raises ``NotImplementedError``),
    falls back first to the PCIe Bus-Id reported by ``npu-smi`` and finally
    to a deterministic host/visible-devices hash, so sender and receiver
    always compute the same socket name.
    """
    from vllm.platforms import current_platform

    try:
        return current_platform.get_device_uuid(device_id)
    except NotImplementedError:
        pass

    # NPU special case: the physical PCIe Bus-Id is stable across processes.
    bus_id = _get_npu_bus_id_from_npu_smi(device_id)
    if bus_id:
        return bus_id

    # Last resort: a deterministic hash keeps both ends of the IPC channel
    # aligned even without npu-smi available.
    visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES") or os.environ.get(
        "CUDA_VISIBLE_DEVICES", ""
    )
    seed = f"{socket.gethostname()}:{visible}:{device_id}"
    return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:16]


def is_master():
    # Module-level convenience wrapper delegating to the Platform helper.
    return Platform.is_master()
Loading