Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions src/twinkle/advantage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,8 @@
from .grpo import GRPOAdvantage
from .rloo import RLOOAdvantage


# TODO: Temporary helpers added to unblock cookbook/grpo examples.
# Each call creates a new Advantage instance, not suitable for production.
# Remove once the framework provides a proper advantage computation API.
def compute_advantages(rewards, num_generations=1, scale='group', **kwargs):
"""Backward-compatible helper for GRPO advantage computation."""
return GRPOAdvantage()(
rewards=rewards,
num_generations=num_generations,
scale=scale,
**kwargs,
)


def compute_advantages_rloo(rewards, num_generations=1, scale='group', **kwargs):
"""Backward-compatible helper for RLOO advantage computation."""
return RLOOAdvantage()(
rewards=rewards,
num_generations=num_generations,
scale=scale,
**kwargs,
)


__all__ = [
'Advantage',
'GRPOAdvantage',
'RLOOAdvantage',
'compute_advantages',
'compute_advantages_rloo',
]
14 changes: 8 additions & 6 deletions src/twinkle/checkpoint_engine/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
import torch
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Generator, TypedDict
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, TypedDict

if TYPE_CHECKING:
import torch


class TensorMeta(TypedDict):
"""Metadata for a tensor in the weight bucket."""
name: str
shape: torch.Size
dtype: torch.dtype
shape: 'torch.Size'
dtype: 'torch.dtype'
offset: int


Expand Down Expand Up @@ -99,7 +101,7 @@ def finalize(self):
raise NotImplementedError

@abstractmethod
async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
async def send_weights(self, weights: Generator[tuple[str, 'torch.Tensor'], None, None]):
"""Send model weights to rollout workers.

This method streams weights in buckets to avoid memory issues with
Expand All @@ -112,7 +114,7 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
raise NotImplementedError

@abstractmethod
async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
async def receive_weights(self) -> AsyncGenerator[tuple[str, 'torch.Tensor'], None]:
"""Receive model weights from trainer.

This method receives weights in buckets and yields them as they
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, AsyncGenerator, Generator

from twinkle import get_logger
from twinkle.utils.network import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
from twinkle.utils import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
from .base import CheckpointEngine, TensorMeta

logger = get_logger()
Expand Down
13 changes: 8 additions & 5 deletions src/twinkle/infra/_ray/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def get_visible_devices():

self.device_groups = {}
ray_address = str(ray.get_runtime_context().gcs_address)
assert len(groups) == len(visible_devices)
for group, visible_device_list in zip(groups, self.visible_devices):
for group in groups:
if group.device_type != 'CPU':
ranks = group.ranks
gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
Expand All @@ -195,8 +194,12 @@ def get_visible_devices():
worker_ranks = normalized_ranks[start_idx:start_idx + gpus_per_worker]

# All GPUs for a worker should be on the same node
node_ranks = [r // nproc_per_node for r in worker_ranks]
gpu_ranks_local = [visible_device_list[r % nproc_per_node] for r in worker_ranks]
gpu_ranks_local = []
for r in worker_ranks:
node_rank = r // nproc_per_node
node_ranks.append(node_rank)
gpu_ranks = self.visible_devices[node_rank][r % nproc_per_node]
gpu_ranks_local.append(gpu_ranks)

if len(set(node_ranks)) > 1:
raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. "
Expand All @@ -211,7 +214,7 @@ def get_visible_devices():
else:
for alloc_rank in normalized_ranks:
node_rank = alloc_rank // nproc_per_node
gpu_rank = visible_device_list[alloc_rank % nproc_per_node]
gpu_rank = self.visible_devices[node_rank][alloc_rank % nproc_per_node]
local_device_groups.append(
dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address))

Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def _try_init_process_group(self):
# fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
# Keep training-side HCCL sockets on a per-job port layout to
# avoid collisions with other jobs on the same host.
from twinkle.utils.network import _ensure_hccl_socket_env
from twinkle.utils.platforms import ensure_hccl_socket_env
master_port = int(os.environ.get('MASTER_PORT', '29500'))
_ensure_hccl_socket_env(master_port)
ensure_hccl_socket_env(master_port)
init_kwargs = {
'backend': backend,
'init_method': 'env://',
Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert
torch.save(cpu_state_dict, checkpoint_path)

def _save_tokenizer(self, output_dir: str, **kwargs):
from twinkle.utils.platform import is_last_rank
from twinkle.utils import is_last_rank
if not is_last_rank():
return

Expand Down Expand Up @@ -1344,7 +1344,7 @@ def add_adapter_to_model(

@remote_function()
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
apply_patch(self, patch_cls, **kwargs)
apply_patch(self.model, patch_cls, **kwargs)

@remote_function(dispatch='all')
def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions src/twinkle/model/megatron/model/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from twinkle.hub import HubOperation
from twinkle.model.megatron.args import get_args # Use twinkle's get_args
from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger,
get_modules_to_not_convert, get_multimodal_target_regex, requires)
from twinkle.utils.platform import is_last_rank
get_modules_to_not_convert, get_multimodal_target_regex, is_last_rank, requires)

logger = get_logger()

Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str, LRSchedu

@remote_function()
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
apply_patch(self, patch_cls, **kwargs)
apply_patch(self.model, patch_cls, **kwargs)

def __del__(self):
HubOperation.wait_for()
Expand Down
7 changes: 4 additions & 3 deletions src/twinkle/patch/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from typing import Any, Type, Union
from typing import TYPE_CHECKING, List, Union

from twinkle.utils import construct_class
if TYPE_CHECKING:
import torch


class Patch:

def __call__(self, module, *args, **kwargs):
def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module']], *args, **kwargs):
...
4 changes: 2 additions & 2 deletions src/twinkle/sampler/vllm_sampler/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from twinkle import get_logger
from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason
from twinkle.sampler.base_engine import BaseSamplerEngine
from twinkle.utils.platform import get_vllm_device_uuid
from twinkle.utils import Platform

logger = get_logger()

Expand Down Expand Up @@ -524,7 +524,7 @@ async def _sync_iter():
# fix: Route through platform-level fallback so IPC socket name remains stable.
# Get device UUID for ZMQ handle.
# For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
device_uuid = get_vllm_device_uuid(0)
device_uuid = Platform.get_vllm_device_uuid(0)
zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock'

bucket_size = bucket_size_mb << 20
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/sampler/vllm_sampler/vllm_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory
from twinkle.patch.vllm_lora_weights import VLLMLoraWeights
from twinkle.sampler.base import Sampler
from twinkle.utils.platform import Platform
from twinkle.utils import Platform

logger = get_logger()

Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from typing import Dict, List, Optional, Tuple

from twinkle import get_logger
from twinkle.utils import Platform
from twinkle.utils.framework import Torch
from twinkle.utils.platform import get_vllm_device_uuid

logger = get_logger()

Expand Down Expand Up @@ -376,5 +376,5 @@ def _get_zmq_handle(self) -> str:
"""Get ZMQ handle for IPC communication."""
if not hasattr(self, '_device_uuid') or not self._device_uuid:
# fix: Always use platform fallback to avoid worker-side crashes when NPU get_device_uuid is unimplemented.
self._device_uuid = get_vllm_device_uuid(self.device.index)
self._device_uuid = Platform.get_vllm_device_uuid(self.device.index)
return f'ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock'
7 changes: 4 additions & 3 deletions src/twinkle/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .dequantizer import Fp8Dequantizer, MxFp4Dequantizer
from .device_mesh import DeviceGroup, DeviceMesh, is_last_rank, is_master
from .framework import Framework as framework_util
from .framework import Torch as torch_util
from .import_utils import exists, requires
from .loader import Plugin, construct_class
from .logger import get_logger
from .network import find_free_port, find_node_ip
from .network import find_free_port, find_node_ip, is_valid_ipv6_address
from .parallel import processing_lock
from .platform import GPU, NPU, DeviceGroup, DeviceMesh, Platform
from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend
from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver
from .torch_utils import to_device
from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device
from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert, get_multimodal_target_regex
from .unsafe import check_unsafe, trust_remote_code
from .utils import copy_files_by_pattern, deep_getattr
Loading
Loading