Skip to content

Commit 7bcbfe5

Browse files
authored
[fix] http_options leaking to model init & NPU tensor serialization failure over HTTP (#109)
* fix * fix * fix * fix
1 parent f7a0c3b commit 7bcbfe5

File tree

5 files changed

+50
-8
lines changed

5 files changed

+50
-8
lines changed

cookbook/client/twinkle/self_congnition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,11 @@ def train():
102102
for step, batch in enumerate(dataloader):
103103
# Forward pass + backward pass (computes gradients)
104104
output = model.forward_backward(inputs=batch)
105+
loss=output.get('loss', 'N/A')
105106

106107
# Log the loss every 2 steps (aligned with gradient accumulation)
107108
if step % 2 == 0:
108-
logger.info(f'Current is step {step // 2}, loss: {output}')
109+
logger.info(f'Current is step {step // 2}, loss: {loss}')
109110

110111
# Clip gradients to prevent exploding gradients (max norm = 1.0)
111112
model.clip_grad_norm(1.0)

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ def get_train_configs(self, **kwargs) -> str:
11081108
return expr
11091109

11101110
# =========================================================================
1111-
# Checkpoint Engine — Weight Sync (from CheckpointEngineMixin)
1111+
# Checkpoint Engine weight sync (from CheckpointEngineMixin)
11121112
# =========================================================================
11131113
# prepare_checkpoint_engine, init_checkpoint_process_group, and
11141114
# finalize_checkpoint_engine are inherited from CheckpointEngineMixin.
@@ -1145,7 +1145,7 @@ def weight_generator():
11451145
if isinstance(model, PeftModel):
11461146
model.unmerge_adapter()
11471147
else:
1148-
# ── LoRA-only mode: send only adapter weights ────────────────
1148+
# LoRA-only mode: send only adapter weights.
11491149
# Use PEFT's get_peft_model_state_dict for clean LoRA extraction
11501150
from peft.utils import get_peft_model_state_dict
11511151
lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
@@ -1156,7 +1156,7 @@ def weight_generator():
11561156
yield name, tensor
11571157

11581158
else:
1159-
# ── Full model mode: send all weights (base model sync) ──────
1159+
# Full model mode: send all weights (base model sync).
11601160
state_dict = model.state_dict()
11611161

11621162
def weight_generator():

src/twinkle/server/launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:
216216

217217
# Pass http_options to server apps for internal proxy routing
218218
http_options = self.config.get('http_options', {})
219-
if http_options:
219+
if import_path == 'server' and http_options:
220220
args['http_options'] = http_options
221221

222222
# Build and deploy the application
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import numpy as np
3+
import torch
4+
from collections.abc import Mapping
5+
from typing import Any, List, Union
6+
7+
from twinkle import remote_class, remote_function
8+
from twinkle.data_format import InputFeature, Trajectory
9+
from twinkle.model import MultiLoraTransformersModel
10+
11+
12+
@remote_class()
class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
    """HTTP-transport-friendly variant of ``MultiLoraTransformersModel``.

    Outputs of ``forward_backward`` may contain device-resident tensors
    (e.g. on NPU) that cannot be serialized over HTTP; this subclass
    post-processes them into plain Python objects before returning.
    """

    @staticmethod
    def _to_cpu_safe_output(obj: Any) -> Any:
        """Convert nested outputs into CPU-safe Python objects for HTTP transport."""
        from twinkle.utils import torch_util

        recurse = TwinkleCompatTransformersModel._to_cpu_safe_output

        if isinstance(obj, torch.Tensor):
            # Materialize on CPU first (handles DTensor-style wrappers via
            # to_local_tensor), then unwrap scalars / nest-listify arrays.
            local = torch_util.to_local_tensor(obj).detach().cpu()
            return local.item() if local.numel() == 1 else local.tolist()
        if isinstance(obj, np.ndarray):
            return obj.item() if obj.size == 1 else obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalar types (np.float32, np.int64, ...) unwrap directly.
            return obj.item()
        if isinstance(obj, Mapping):
            return {key: recurse(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            # Tuples intentionally collapse to lists for JSON compatibility.
            return [recurse(value) for value in obj]
        return obj

    @remote_function(dispatch='slice_dp', collect='mean')
    def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
                         **kwargs):
        """Run the parent forward/backward pass, then sanitize the result for HTTP."""
        result = super().forward_backward(inputs=inputs, **kwargs)
        return self._to_cpu_safe_output(result)

src/twinkle/server/twinkle/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
182182
instance_id=replica_id,
183183
**kwargs)
184184
else:
185-
from twinkle.model import MultiLoraTransformersModel
186-
self.model = MultiLoraTransformersModel(
185+
from .common.transformers_model import TwinkleCompatTransformersModel
186+
self.model = TwinkleCompatTransformersModel(
187187
model_id=model_id,
188188
device_mesh=self.device_mesh,
189189
remote_group=self.device_group.name,
@@ -296,7 +296,7 @@ def forward_backward(self, request: Request, body: ForwardRequest):
296296
assert isinstance(inputs, dict)
297297
inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs)
298298
ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
299-
return {'result': str(ret)}
299+
return {'result': ret}
300300

301301
@app.post('/get_train_configs')
302302
def get_train_configs(self, request: Request, body: AdapterRequest):

0 commit comments

Comments
 (0)