Skip to content

Commit 317f9b8

Browse files
Refactor code (#72)
1 parent c5a08c2 commit 317f9b8

File tree

24 files changed

+477
-454
lines changed

24 files changed

+477
-454
lines changed

src/twinkle/advantage/__init__.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,8 @@
33
from .grpo import GRPOAdvantage
44
from .rloo import RLOOAdvantage
55

6-
7-
# TODO: Temporary helpers added to unblock cookbook/grpo examples.
8-
# Each call creates a new Advantage instance, not suitable for production.
9-
# Remove once the framework provides a proper advantage computation API.
10-
def compute_advantages(rewards, num_generations=1, scale='group', **kwargs):
11-
"""Backward-compatible helper for GRPO advantage computation."""
12-
return GRPOAdvantage()(
13-
rewards=rewards,
14-
num_generations=num_generations,
15-
scale=scale,
16-
**kwargs,
17-
)
18-
19-
20-
def compute_advantages_rloo(rewards, num_generations=1, scale='group', **kwargs):
21-
"""Backward-compatible helper for RLOO advantage computation."""
22-
return RLOOAdvantage()(
23-
rewards=rewards,
24-
num_generations=num_generations,
25-
scale=scale,
26-
**kwargs,
27-
)
28-
29-
306
__all__ = [
317
'Advantage',
328
'GRPOAdvantage',
339
'RLOOAdvantage',
34-
'compute_advantages',
35-
'compute_advantages_rloo',
3610
]

src/twinkle/checkpoint_engine/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
3-
import torch
43
from abc import ABC, abstractmethod
5-
from typing import Any, AsyncGenerator, Generator, TypedDict
4+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict
5+
6+
if TYPE_CHECKING:
7+
import torch
68

79

810
class TensorMeta(TypedDict):
911
"""Metadata for a tensor in the weight bucket."""
1012
name: str
11-
shape: torch.Size
12-
dtype: torch.dtype
13+
shape: 'torch.Size'
14+
dtype: 'torch.dtype'
1315
offset: int
1416

1517

@@ -99,7 +101,7 @@ def finalize(self):
99101
raise NotImplementedError
100102

101103
@abstractmethod
102-
async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
104+
async def send_weights(self, weights: Generator[tuple[str, 'torch.Tensor'], None, None]):
103105
"""Send model weights to rollout workers.
104106
105107
This method streams weights in buckets to avoid memory issues with
@@ -112,7 +114,7 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
112114
raise NotImplementedError
113115

114116
@abstractmethod
115-
async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
117+
async def receive_weights(self) -> AsyncGenerator[tuple[str, 'torch.Tensor'], None]:
116118
"""Receive model weights from trainer.
117119
118120
This method receives weights in buckets and yields them as they

src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, AsyncGenerator, Generator
1818

1919
from twinkle import get_logger
20-
from twinkle.utils.network import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
20+
from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
2121
from .base import CheckpointEngine, TensorMeta
2222

2323
logger = get_logger()

src/twinkle/infra/_ray/resource_manager.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ def get_visible_devices():
173173

174174
self.device_groups = {}
175175
ray_address = str(ray.get_runtime_context().gcs_address)
176-
assert len(groups) == len(visible_devices)
177-
for group, visible_device_list in zip(groups, self.visible_devices):
176+
for group in groups:
178177
if group.device_type != 'CPU':
179178
ranks = group.ranks
180179
gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
@@ -195,8 +194,12 @@ def get_visible_devices():
195194
worker_ranks = normalized_ranks[start_idx:start_idx + gpus_per_worker]
196195

197196
# All GPUs for a worker should be on the same node
198-
node_ranks = [r // nproc_per_node for r in worker_ranks]
199-
gpu_ranks_local = [visible_device_list[r % nproc_per_node] for r in worker_ranks]
197+
gpu_ranks_local = []
198+
for r in worker_ranks:
199+
node_rank = r // nproc_per_node
200+
node_ranks.append(node_rank)
201+
gpu_ranks = self.visible_devices[node_rank][r % nproc_per_node]
202+
gpu_ranks_local.append(gpu_ranks)
200203

201204
if len(set(node_ranks)) > 1:
202205
raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. "
@@ -211,7 +214,7 @@ def get_visible_devices():
211214
else:
212215
for alloc_rank in normalized_ranks:
213216
node_rank = alloc_rank // nproc_per_node
214-
gpu_rank = visible_device_list[alloc_rank % nproc_per_node]
217+
gpu_rank = self.visible_devices[node_rank][alloc_rank % nproc_per_node]
215218
local_device_groups.append(
216219
dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address))
217220

src/twinkle/model/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ def _try_init_process_group(self):
145145
# fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
146146
# Keep training-side HCCL sockets on a per-job port layout to
147147
# avoid collisions with other jobs on the same host.
148-
from twinkle.utils.network import _ensure_hccl_socket_env
148+
from twinkle.utils.platforms import ensure_hccl_socket_env
149149
master_port = int(os.environ.get('MASTER_PORT', '29500'))
150-
_ensure_hccl_socket_env(master_port)
150+
ensure_hccl_socket_env(master_port)
151151
init_kwargs = {
152152
'backend': backend,
153153
'init_method': 'env://',

src/twinkle/model/megatron/megatron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,7 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert
12131213
torch.save(cpu_state_dict, checkpoint_path)
12141214

12151215
def _save_tokenizer(self, output_dir: str, **kwargs):
1216-
from twinkle.utils.platform import is_last_rank
1216+
from twinkle.utils import is_last_rank
12171217
if not is_last_rank():
12181218
return
12191219

@@ -1344,7 +1344,7 @@ def add_adapter_to_model(
13441344

13451345
@remote_function()
13461346
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
1347-
apply_patch(self, patch_cls, **kwargs)
1347+
apply_patch(self.model, patch_cls, **kwargs)
13481348

13491349
@remote_function(dispatch='all')
13501350
def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):

src/twinkle/model/megatron/model/gpt_bridge.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from twinkle.hub import HubOperation
1919
from twinkle.model.megatron.args import get_args # Use twinkle's get_args
2020
from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger,
21-
get_modules_to_not_convert, get_multimodal_target_regex, requires)
22-
from twinkle.utils.platform import is_last_rank
21+
get_modules_to_not_convert, get_multimodal_target_regex, is_last_rank, requires)
2322

2423
logger = get_logger()
2524

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str, LRSchedu
754754

755755
@remote_function()
756756
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
757-
apply_patch(self, patch_cls, **kwargs)
757+
apply_patch(self.model, patch_cls, **kwargs)
758758

759759
def __del__(self):
760760
HubOperation.wait_for()

src/twinkle/patch/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
from typing import Any, Type, Union
2+
from typing import TYPE_CHECKING, List, Union
33

4-
from twinkle.utils import construct_class
4+
if TYPE_CHECKING:
5+
import torch
56

67

78
class Patch:
89

9-
def __call__(self, module, *args, **kwargs):
10+
def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module']], *args, **kwargs):
1011
...

src/twinkle/sampler/vllm_sampler/vllm_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from twinkle import get_logger
99
from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason
1010
from twinkle.sampler.base_engine import BaseSamplerEngine
11-
from twinkle.utils.platform import get_vllm_device_uuid
11+
from twinkle.utils import Platform
1212

1313
logger = get_logger()
1414

@@ -524,7 +524,7 @@ async def _sync_iter():
524524
# fix: Route through platform-level fallback so IPC socket name remains stable.
525525
# Get device UUID for ZMQ handle.
526526
# For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
527-
device_uuid = get_vllm_device_uuid(0)
527+
device_uuid = Platform.get_vllm_device_uuid(0)
528528
zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock'
529529

530530
bucket_size = bucket_size_mb << 20

0 commit comments

Comments (0)