Skip to content

Commit 51ebbbf

Browse files
committed
update
1 parent 2ca4e60 commit 51ebbbf

File tree

5 files changed

+56
-42
lines changed

5 files changed

+56
-42
lines changed

cookbook/client/tinker/self_host/short_math_grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def main():
217217
from tinker import ServiceClient
218218
service_client = ServiceClient(
219219
base_url='http://localhost:8000',
220-
api_key=os.environ.get('MODELSCOPE_TOKEN')
220+
api_key='EMPTY_TOKEN'
221221
)
222222

223223
logger.info('Creating LoRA training client...')

cookbook/client/twinkle/self_host/grpo.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,13 @@ def train():
103103
model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0)
104104

105105
# Set optimizer and LR scheduler
106-
model.set_optimizer('AdamW', lr=LEARNING_RATE)
107-
model.set_lr_scheduler(
108-
'CosineWarmupScheduler',
109-
num_warmup_steps=500,
110-
num_training_steps=MAX_STEPS,
111-
)
106+
model.set_optimizer('Adam', lr=LEARNING_RATE)
107+
# Set LR scheduler (not supported via set_lr_scheduler when the server uses the Megatron backend)
108+
# model.set_lr_scheduler(
109+
# 'CosineWarmupScheduler',
110+
# num_warmup_steps=500,
111+
# num_training_steps=MAX_STEPS,
112+
# )
112113

113114
# Set processor and template for encoding inputs
114115
model.set_processor('InputProcessor')

src/twinkle/server/model/backends/common.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import numpy as np
66
import re
77
import torch
8+
from collections.abc import Mapping
89
from numbers import Number
910
from tinker import types
10-
from typing import List
11+
from typing import Any, List
1112

1213
from twinkle import DeviceMesh
1314
from twinkle.template import Template
@@ -58,6 +59,33 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh):
5859
return [all_outputs, avg_loss]
5960

6061

62+
def to_cpu_safe_output(obj: Any) -> Any:
    """Convert nested model outputs into CPU-safe Python objects for HTTP transport.

    Recursively walks tensors, numpy arrays, mappings and sequences,
    converting each tensor/array to a plain Python scalar or list so
    Ray can serialise the result without requiring CUDA on the driver.

    Args:
        obj: An arbitrarily nested structure of ``torch.Tensor``, numpy
            arrays/scalars, mappings, lists/tuples and plain Python values.

    Returns:
        The same structure with every tensor/array replaced by a Python
        scalar (when it holds a single element) or a (nested) list.
        Tuples are converted to lists so the result is JSON-friendly;
        unrecognised leaf objects are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        # Import lazily: only tensor inputs need twinkle's torch helpers,
        # so pure numpy/Python payloads avoid the per-call import entirely.
        from twinkle.utils import torch_util

        tensor = torch_util.to_local_tensor(obj).detach().cpu()
        if tensor.numel() == 1:
            return tensor.item()
        return tensor.tolist()
    if isinstance(obj, np.ndarray):
        if obj.size == 1:
            return obj.item()
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, Mapping):
        return {key: to_cpu_safe_output(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_cpu_safe_output(value) for value in obj]
    return obj
87+
88+
6189
def clean_metrics(metrics: dict) -> dict:
6290

6391
def _to_float(v):

src/twinkle/server/model/backends/megatron_model.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
"""
55
import torch
66
from tinker import types
7-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
7+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
88

99
from twinkle import remote_class, remote_function
10+
from twinkle.data_format import InputFeature, Trajectory
1011
from twinkle.model.megatron import MultiLoraMegatronModel
1112
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
12-
from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results
13+
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
14+
collect_forward_backward_results, to_cpu_safe_output)
1315

1416

1517
@remote_class(execute='all')
@@ -112,3 +114,14 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
112114
return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs)
113115
else:
114116
return super().load(name=resolved.checkpoint_name, **kwargs)
117+
118+
# ------------------------------------------------------------------
119+
# Twinkle-native methods (InputFeature/Trajectory-based I/O)
120+
# ------------------------------------------------------------------
121+
122+
@remote_function(dispatch='slice_dp', collect='mean')
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
                     **kwargs):
    """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).

    Delegates the actual pass to the parent implementation, then converts
    the result into CPU-safe Python objects before it crosses the HTTP
    boundary.
    """
    raw_output = super().forward_backward(inputs=inputs, **kwargs)
    return to_cpu_safe_output(raw_output)

src/twinkle/server/model/backends/transformers_model.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66
- TwinkleCompatTransformersModel: handles both tinker (Datum-based I/O) via /tinker/*
77
endpoints and twinkle-native (InputFeature/Trajectory-based I/O) via /twinkle/* endpoints.
88
"""
9-
import numpy as np
10-
import torch
11-
from collections.abc import Mapping
129
from tinker import types
13-
from typing import Any, List, Union
10+
from typing import List, Union
1411

1512
from twinkle import remote_class, remote_function
1613
from twinkle.data_format import InputFeature, Trajectory
1714
from twinkle.model import MultiLoraTransformersModel
1815
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
19-
from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results
16+
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
17+
collect_forward_backward_results, to_cpu_safe_output)
2018

2119

2220
@remote_class()
@@ -28,32 +26,6 @@ class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatMo
2826
- Twinkle-native I/O (InputFeature / Trajectory) via /twinkle/* endpoints.
2927
"""
3028

31-
# ------------------------------------------------------------------
32-
# Shared helper: CPU-safe serialisation for HTTP transport
33-
# ------------------------------------------------------------------
34-
35-
@staticmethod
36-
def _to_cpu_safe_output(obj: Any) -> Any:
37-
"""Convert nested outputs into CPU-safe Python objects for HTTP transport."""
38-
from twinkle.utils import torch_util
39-
40-
if isinstance(obj, torch.Tensor):
41-
tensor = torch_util.to_local_tensor(obj).detach().cpu()
42-
if tensor.numel() == 1:
43-
return tensor.item()
44-
return tensor.tolist()
45-
if isinstance(obj, np.ndarray):
46-
if obj.size == 1:
47-
return obj.item()
48-
return obj.tolist()
49-
if isinstance(obj, np.generic):
50-
return obj.item()
51-
if isinstance(obj, Mapping):
52-
return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()}
53-
if isinstance(obj, (list, tuple)):
54-
return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj]
55-
return obj
56-
5729
# ------------------------------------------------------------------
5830
# Tinker-compat methods (Datum-based I/O)
5931
# ------------------------------------------------------------------
@@ -135,4 +107,4 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
135107
**kwargs):
136108
"""Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
137109
output = super().forward_backward(inputs=inputs, **kwargs)
138-
return self._to_cpu_safe_output(output)
110+
return to_cpu_safe_output(output)

0 commit comments

Comments
 (0)