Skip to content

Commit 55c2183

Browse files
committed
npu grpo fix
1 parent b941858 commit 55c2183

File tree

7 files changed

+124
-13
lines changed

7 files changed

+124
-13
lines changed

src/twinkle/dataset/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from typing import Callable, Type, Union, Dict, Any
66

7-
from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset
7+
from datasets import interleave_datasets, concatenate_datasets, load_dataset, IterableDataset, DatasetDict
88
from torch.utils.data import Dataset as TorchDataset
99

1010
import twinkle
@@ -132,6 +132,19 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
132132
dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
133133
else:
134134
dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs)
135+
136+
if isinstance(dataset, DatasetDict):
137+
if split in dataset:
138+
dataset = dataset[split]
139+
elif 'train' in dataset:
140+
dataset = dataset['train']
141+
else:
142+
available_splits = list(dataset.keys())
143+
raise KeyError(
144+
f"Split '{split}' not found for dataset '{dataset_id}'. "
145+
f'Available splits: {available_splits}'
146+
)
147+
135148
if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'):
136149

137150
iter_list = []

src/twinkle/model/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,18 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
138138
)
139139

140140
def _try_init_process_group(self):
141+
import os
141142
import torch
142143
import torch.distributed as dist
143144
if not dist.is_initialized() and Platform.get_world_size() > 1:
144145
torch_util.set_device()
145146
backend = Platform.device_backend()
147+
if backend == "hccl":
148+
# Keep training-side HCCL sockets on a per-job port layout to
149+
# avoid collisions with other jobs on the same host.
150+
from twinkle.utils.network import _ensure_hccl_socket_env
151+
master_port = int(os.environ.get("MASTER_PORT", "29500"))
152+
_ensure_hccl_socket_env(master_port)
146153
init_kwargs = {
147154
"backend": backend,
148155
"init_method": "env://",

src/twinkle/sampler/vllm_sampler/vllm_engine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from twinkle import get_logger
88
from twinkle.sampler.base_engine import BaseSamplerEngine
99
from twinkle.data_format.sampling import StopReason, SamplingParams, SampleResponse, SampledSequence
10+
from twinkle.utils.platform import get_vllm_device_uuid
1011

1112
import inspect
1213
logger = get_logger()
@@ -569,8 +570,9 @@ async def _sync_iter():
569570
use_gpu_ipc = first_tensor.is_cuda
570571
use_shm = not use_gpu_ipc
571572

572-
# Get device UUID for ZMQ handle
573-
device_uuid = current_platform.get_device_uuid(0)
573+
# Get device UUID for ZMQ handle.
574+
# For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
575+
device_uuid = get_vllm_device_uuid(0)
574576
zmq_handle = f"ipc:///tmp/twinkle-ipc-{device_uuid}.sock"
575577

576578
bucket_size = bucket_size_mb << 20

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def __init__(
138138
self.engine: VLLMEngine = self._run_in_loop(
139139
self._create_engine_async(VLLMEngine, model_id, engine_kwargs)
140140
)
141-
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))
141+
# NPU platform may trigger triton errors with monkey_patch_model
142+
if not Platform.get_platform().device_prefix().upper() == 'NPU':
143+
self._run_in_loop(self.engine.engine.collective_rpc("monkey_patch_model"))
142144

143145
VLLMLoraWeights()(self)
144146

src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from twinkle import get_logger
2222
import torch
2323
from twinkle.utils.framework import Torch
24+
from twinkle.utils.platform import get_vllm_device_uuid
2425

2526
logger = get_logger()
2627

@@ -64,12 +65,6 @@ def _rebuild_shared_memory(name: str, size: int):
6465
return tensor, shm
6566

6667

67-
def _get_device_uuid(device_id: int) -> str:
68-
"""Get unique device identifier."""
69-
from vllm.platforms import current_platform
70-
return current_platform.get_device_uuid(device_id)
71-
72-
7368
class TwinkleWorkerExtension:
7469
"""Extension class for vLLM workers to support weight synchronization.
7570
@@ -122,7 +117,7 @@ def update_weights_from_ipc(
122117
import torch.distributed as dist
123118

124119
if self.device is None:
125-
self.device = torch.device(Torch.get_device())
120+
self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))
126121

127122
if peft_config and base_sync_done:
128123
self.remove_lora(VLLM_LORA_INT_ID)
@@ -257,7 +252,7 @@ def load_synced_weights(
257252
base_sync_done: If True with peft_config, load as LoRA adapter.
258253
"""
259254
if self.device is None:
260-
self.device = torch.device(Torch.get_device())
255+
self.device = torch.device(Torch.get_device(getattr(self, "local_rank", None)))
261256

262257
weight_list = list(weights.items())
263258
self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done)
@@ -374,5 +369,5 @@ def _load_weights(
374369
def _get_zmq_handle(self) -> str:
375370
"""Get ZMQ handle for IPC communication."""
376371
if not hasattr(self, '_device_uuid') or not self._device_uuid:
377-
self._device_uuid = _get_device_uuid(self.device.index)
372+
self._device_uuid = get_vllm_device_uuid(self.device.index)
378373
return f"ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock"

src/twinkle/utils/network.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,37 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import os
23
import socket
34
from datetime import timedelta
45
from typing import Optional
56

67
import torch
78

89

10+
def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None:
11+
"""Set deterministic HCCL socket env defaults to avoid port collisions.
12+
13+
In multi-job environments, HCCL's default base port (60000) can collide
14+
across concurrent jobs and lead to:
15+
`ra_hdc_socket_listen_start ... ret(-98)`.
16+
17+
We derive a per-job port layout from `master_port` so all ranks use the
18+
same values while reducing cross-job conflicts. Explicit user settings are
19+
preserved and never overwritten.
20+
"""
21+
env = os.environ if environ is None else environ
22+
if "HCCL_IF_BASE_PORT" not in env:
23+
# 20000-39999, with an offset to avoid colliding with TCPStore port.
24+
env["HCCL_IF_BASE_PORT"] = str(20000 + ((master_port + 997) % 20000))
25+
if "HCCL_HOST_SOCKET_PORT_RANGE" not in env:
26+
# 40000-40511 ... 47999-48510
27+
start = 40000 + (master_port % 8000)
28+
env["HCCL_HOST_SOCKET_PORT_RANGE"] = f"{start}-{start + 511}"
29+
if "HCCL_NPU_SOCKET_PORT_RANGE" not in env:
30+
# 50000-50511 ... 57999-58510
31+
start = 50000 + (master_port % 8000)
32+
env["HCCL_NPU_SOCKET_PORT_RANGE"] = f"{start}-{start + 511}"
33+
34+
935
def is_valid_ipv6_address(ip: str) -> bool:
1036
"""Check if the given string is a valid IPv6 address."""
1137
try:
@@ -87,6 +113,7 @@ def stateless_init_process_group(
87113
from vllm.distributed.utils import StatelessProcessGroup
88114

89115
if backend == "hccl":
116+
_ensure_hccl_socket_env(master_port)
90117
from vllm_ascend.distributed.device_communicators.pyhccl import (
91118
PyHcclCommunicator as Communicator,
92119
)

src/twinkle/utils/platform.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
import os
33
import platform
4+
import hashlib
5+
import re
46
import shutil
7+
import socket
58
import subprocess
69
from abc import ABC
710
from dataclasses import dataclass, field
@@ -641,5 +644,67 @@ def is_last_rank():
641644
return True
642645
return dist.get_rank() == dist.get_world_size() - 1
643646

647+
648+
def _resolve_ascend_physical_device_id(device_id: int) -> int:
649+
"""Map local NPU device index to physical device id via visible devices."""
650+
visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "").strip()
651+
if not visible:
652+
return device_id
653+
parts = [p.strip() for p in visible.split(",") if p.strip()]
654+
if device_id < 0 or device_id >= len(parts):
655+
return device_id
656+
return int(parts[device_id])
657+
658+
659+
def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]:
660+
"""Get NPU Bus-Id from `npu-smi info` output."""
661+
try:
662+
physical_id = _resolve_ascend_physical_device_id(device_id)
663+
except Exception:
664+
physical_id = device_id
665+
666+
try:
667+
output = subprocess.check_output(
668+
["npu-smi", "info"],
669+
text=True,
670+
stderr=subprocess.STDOUT,
671+
timeout=5,
672+
)
673+
except Exception:
674+
return None
675+
676+
# Typical line:
677+
# | 0 0 | 0000:9D:00.0 | ...
678+
pattern = re.compile(
679+
r"^\|\s*\d+\s+(\d+)\s*\|\s*"
680+
r"([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|",
681+
re.MULTILINE,
682+
)
683+
for match in pattern.finditer(output):
684+
phy_id = int(match.group(1))
685+
if phy_id == physical_id:
686+
return match.group(2).lower()
687+
return None
688+
689+
690+
def get_vllm_device_uuid(device_id: int = 0) -> str:
691+
"""Get vLLM device uuid with NPU Bus-Id special handling."""
692+
from vllm.platforms import current_platform
693+
694+
try:
695+
return current_platform.get_device_uuid(device_id)
696+
except NotImplementedError:
697+
# NPU special case: prefer stable PCIe Bus-Id from npu-smi.
698+
bus_id = _get_npu_bus_id_from_npu_smi(device_id)
699+
if bus_id:
700+
return bus_id
701+
# Generic deterministic fallback to keep sender/receiver socket names aligned.
702+
visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES") or os.environ.get(
703+
"CUDA_VISIBLE_DEVICES", ""
704+
)
705+
raw = f"{socket.gethostname()}:{visible}:{device_id}"
706+
return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16]
707+
708+
644709
def is_master():
645710
return Platform.is_master()

0 commit comments

Comments
 (0)