Skip to content

Commit a8450de

Browse files
committed
fix npu grpo
1 parent b941858 commit a8450de

File tree

7 files changed

+155
-14
lines changed

7 files changed

+155
-14
lines changed

src/twinkle/dataset/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from typing import Callable, Type, Union, Dict, Any
66

7-
from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset
7+
from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset, DatasetDict
88
from torch.utils.data import Dataset as TorchDataset
99

1010
import twinkle
@@ -132,6 +132,21 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
132132
dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
133133
else:
134134
dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs)
135+
136+
# fix: Some dataset sources return DatasetDict instead of Dataset, which breaks downstream select/map calls.
137+
# fix: Normalize split resolution here (target split first, then train) and fail early with a clear error.
138+
if isinstance(dataset, DatasetDict):
139+
if split in dataset:
140+
dataset = dataset[split]
141+
elif 'train' in dataset:
142+
dataset = dataset['train']
143+
else:
144+
available_splits = list(dataset.keys())
145+
raise KeyError(
146+
f"Split '{split}' not found for dataset '{dataset_id}'. "
147+
f'Available splits: {available_splits}'
148+
)
149+
135150
if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'):
136151

137152
iter_list = []

src/twinkle/model/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,20 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
138138
)
139139

140140
def _try_init_process_group(self):
141+
import os
141142
import torch
142143
import torch.distributed as dist
143144
if not dist.is_initialized() and Platform.get_world_size() > 1:
144145
torch_util.set_device()
145146
backend = Platform.device_backend()
147+
if backend == "hccl":
148+
# fix: In multi-job NPU runs, HCCL default ports may collide (bind/listen failures).
149+
# fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
150+
# Keep training-side HCCL sockets on a per-job port layout to
151+
# avoid collisions with other jobs on the same host.
152+
from twinkle.utils.network import _ensure_hccl_socket_env
153+
master_port = int(os.environ.get("MASTER_PORT", "29500"))
154+
_ensure_hccl_socket_env(master_port)
146155
init_kwargs = {
147156
"backend": backend,
148157
"init_method": "env://",

src/twinkle/sampler/vllm_sampler/vllm_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from twinkle import get_logger
88
from twinkle.sampler.base_engine import BaseSamplerEngine
99
from twinkle.data_format.sampling import StopReason, SamplingParams, SampleResponse, SampledSequence
10+
from twinkle.utils.platform import get_vllm_device_uuid
1011

1112
import inspect
1213
logger = get_logger()
@@ -569,8 +570,11 @@ async def _sync_iter():
569570
use_gpu_ipc = first_tensor.is_cuda
570571
use_shm = not use_gpu_ipc
571572

572-
# Get device UUID for ZMQ handle
573-
device_uuid = current_platform.get_device_uuid(0)
573+
# fix: On NPU, current_platform.get_device_uuid may be unimplemented and break receive_weights flow.
574+
# fix: Route through platform-level fallback so IPC socket name remains stable.
575+
# Get device UUID for ZMQ handle.
576+
# For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
577+
device_uuid = get_vllm_device_uuid(0)
574578
zmq_handle = f"ipc:///tmp/twinkle-ipc-{device_uuid}.sock"
575579

576580
bucket_size = bucket_size_mb << 20

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,11 @@ def __init__(
138138
self.engine: VLLMEngine = self._run_in_loop(
139139
self._create_engine_async(VLLMEngine, model_id, engine_kwargs)
140140
)
141-
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))
141+
# fix: On NPU, monkey_patch_model can trigger Triton compatibility errors and abort sampler init.
142+
# fix: Explicitly skip this patch on NPU and keep it for non-NPU paths only.
143+
# NPU platform may trigger triton errors with monkey_patch_model
144+
if not Platform.get_platform().device_prefix().upper() == 'NPU':
145+
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))
142146

143147
VLLMLoraWeights()(self)
144148

@@ -559,4 +563,3 @@ def shutdown(self):
559563
self._async_thread.join(timeout=5)
560564
except Exception as e:
561565
logger.warning(f"vLLMSampler event loop shutdown error: {e}")
562-

src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from twinkle import get_logger
2222
import torch
2323
from twinkle.utils.framework import Torch
24+
from twinkle.utils.platform import get_vllm_device_uuid
2425

2526
logger = get_logger()
2627

@@ -64,12 +65,6 @@ def _rebuild_shared_memory(name: str, size: int):
6465
return tensor, shm
6566

6667

67-
def _get_device_uuid(device_id: int) -> str:
68-
"""Get unique device identifier."""
69-
from vllm.platforms import current_platform
70-
return current_platform.get_device_uuid(device_id)
71-
72-
7368
class TwinkleWorkerExtension:
7469
"""Extension class for vLLM workers to support weight synchronization.
7570
@@ -122,7 +117,10 @@ def update_weights_from_ipc(
122117
import torch.distributed as dist
123118

124119
if self.device is None:
125-
self.device = torch.device(Torch.get_device())
120+
# fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues.
121+
# fix: Pass local_rank when available so each worker binds to the expected local device.
122+
print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}")
123+
self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))
126124

127125
if peft_config and base_sync_done:
128126
self.remove_lora(VLLM_LORA_INT_ID)
@@ -257,7 +255,8 @@ def load_synced_weights(
257255
base_sync_done: If True with peft_config, load as LoRA adapter.
258256
"""
259257
if self.device is None:
260-
self.device = torch.device(Torch.get_device())
258+
# fix: Keep device resolution consistent with update_weights_from_ipc to avoid path divergence.
259+
self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))
261260

262261
weight_list = list(weights.items())
263262
self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done)
@@ -374,5 +373,6 @@ def _load_weights(
374373
def _get_zmq_handle(self) -> str:
375374
"""Get ZMQ handle for IPC communication."""
376375
if not hasattr(self, '_device_uuid') or not self._device_uuid:
377-
self._device_uuid = _get_device_uuid(self.device.index)
376+
# fix: Always use platform fallback to avoid worker-side crashes when NPU get_device_uuid is unimplemented.
377+
self._device_uuid = get_vllm_device_uuid(self.device.index)
378378
return f"ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock"

src/twinkle/utils/network.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,48 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import os
23
import socket
34
from datetime import timedelta
45
from typing import Optional
56

67
import torch
78

9+
# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html
# Anchor port from which HCCL derives its internal listen/connect ports.
_HCCL_IF_BASE_PORT_ENV = "HCCL_IF_BASE_PORT"
# Host-side socket port pool HCCL draws from for multi-process communication.
_HCCL_HOST_SOCKET_PORT_RANGE_ENV = "HCCL_HOST_SOCKET_PORT_RANGE"
# Device-side (NPU) socket port pool for HCCL communication channels.
_HCCL_NPU_SOCKET_PORT_RANGE_ENV = "HCCL_NPU_SOCKET_PORT_RANGE"


def _derive_hccl_socket_env_defaults(master_port: int) -> dict:
    """Derive deterministic default HCCL socket env values from master_port.

    Every rank of a job shares the same ``master_port``, so all ranks compute
    the same layout, while jobs with different master ports land on distinct
    port windows and are less likely to collide on one host.
    """
    base_port = 20000 + ((master_port + 997) % 20000)
    offset = master_port % 8000
    host_floor = 40000 + offset
    npu_floor = 50000 + offset
    span = 511  # 512 ports per pool
    return {
        _HCCL_IF_BASE_PORT_ENV: str(base_port),
        _HCCL_HOST_SOCKET_PORT_RANGE_ENV: f"{host_floor}-{host_floor + span}",
        _HCCL_NPU_SOCKET_PORT_RANGE_ENV: f"{npu_floor}-{npu_floor + span}",
    }
27+
28+
29+
def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None:
    """Set deterministic HCCL socket env defaults to avoid port collisions.

    In multi-job environments, HCCL's default base port (60000) can collide
    across concurrent jobs and lead to:
    `ra_hdc_socket_listen_start ... ret(-98)`.

    We derive a per-job port layout from `master_port` so all ranks use the
    same values while reducing cross-job conflicts. Explicit user settings are
    preserved and never overwritten.

    Args:
        master_port: The job's rendezvous port; seeds the derived layout.
        environ: Mapping to mutate. Defaults to ``os.environ``.
    """
    target = environ if environ is not None else os.environ
    defaults = _derive_hccl_socket_env_defaults(master_port)
    for name in defaults:
        # setdefault keeps any value the user exported explicitly.
        target.setdefault(name, defaults[name])
45+
846

947
def is_valid_ipv6_address(ip: str) -> bool:
1048
"""Check if the given string is a valid IPv6 address."""
@@ -87,6 +125,8 @@ def stateless_init_process_group(
87125
from vllm.distributed.utils import StatelessProcessGroup
88126

89127
if backend == "hccl":
128+
# fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide.
129+
_ensure_hccl_socket_env(master_port)
90130
from vllm_ascend.distributed.device_communicators.pyhccl import (
91131
PyHcclCommunicator as Communicator,
92132
)

src/twinkle/utils/platform.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
import os
33
import platform
4+
import hashlib
5+
import re
46
import shutil
7+
import socket
58
import subprocess
69
from abc import ABC
710
from dataclasses import dataclass, field
@@ -641,5 +644,72 @@ def is_last_rank():
641644
return True
642645
return dist.get_rank() == dist.get_world_size() - 1
643646

647+
648+
def _resolve_ascend_physical_device_id(device_id: int) -> int:
    """Map a local NPU device index to its physical device id.

    Honors ``ASCEND_RT_VISIBLE_DEVICES`` the same way CUDA remaps indices
    through ``CUDA_VISIBLE_DEVICES``. Falls back to the index itself when
    the variable is unset, empty, or does not cover ``device_id``.
    """
    mapping = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "").strip()
    if mapping:
        entries = [item.strip() for item in mapping.split(",") if item.strip()]
        if 0 <= device_id < len(entries):
            return int(entries[device_id])
    return device_id
657+
658+
659+
def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]:
    """Look up an NPU's PCIe Bus-Id by parsing ``npu-smi info`` output.

    Returns the lower-cased Bus-Id string, or ``None`` when npu-smi is
    unavailable, fails, times out, or the device is not in the table.
    """
    try:
        target_id = _resolve_ascend_physical_device_id(device_id)
    except Exception:
        # Any mapping problem degrades to treating the index as physical.
        target_id = device_id

    try:
        table = subprocess.check_output(
            ["npu-smi", "info"],
            text=True,
            stderr=subprocess.STDOUT,
            timeout=5,
        )
    except Exception:
        # npu-smi missing, erroring, or hanging: caller falls back elsewhere.
        return None

    # Typical line:
    # | 0 0 | 0000:9D:00.0 | ...
    row_re = re.compile(
        r"^\|\s*\d+\s+(\d+)\s*\|\s*"
        r"([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|",
        re.MULTILINE,
    )
    for row in row_re.finditer(table):
        if int(row.group(1)) == target_id:
            return row.group(2).lower()
    return None
690+
691+
692+
def get_vllm_device_uuid(device_id: int = 0) -> str:
    """Return a stable per-device identifier for vLLM IPC socket naming.

    Prefers ``current_platform.get_device_uuid``; on platforms that do not
    implement it (notably NPU via vllm-ascend) falls back to the PCIe
    Bus-Id parsed from ``npu-smi info``, and finally to a deterministic
    hash, so the sender and receiver of a weight-sync channel always derive
    the same socket name.

    Args:
        device_id: Local device index (default 0).

    Returns:
        A string unique to this device on this host and stable across the
        processes of one job.
    """
    from vllm.platforms import current_platform

    try:
        return current_platform.get_device_uuid(device_id)
    except (NotImplementedError, AttributeError):
        # NotImplementedError: vLLM's Platform base raises it when a backend
        # (e.g. vllm-ascend) does not override get_device_uuid.
        # AttributeError: NOTE(review) — some platform objects may not define
        # the method at all; without this, the NPU fallback below is skipped
        # and the caller crashes. Confirm against the deployed vllm version.
        pass

    # NPU special case: the PCIe Bus-Id from npu-smi is stable across
    # processes, so both ends of the IPC channel agree on the socket name.
    bus_id = _get_npu_bus_id_from_npu_smi(device_id)
    if bus_id:
        return bus_id

    # Last resort: deterministic hash of host + visible-device mask + index.
    # Not globally unique, but identical for every process of this job.
    visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES") or os.environ.get(
        "CUDA_VISIBLE_DEVICES", ""
    )
    raw = f"{socket.gethostname()}:{visible}:{device_id}"
    return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16]
712+
713+
644714
def is_master():
645715
return Platform.is_master()

0 commit comments

Comments
 (0)