66- TwinkleCompatTransformersModel: handles both tinker (Datum-based I/O) via /tinker/*
77 endpoints and twinkle-native (InputFeature/Trajectory-based I/O) via /twinkle/* endpoints.
88"""
9- import numpy as np
10- import torch
11- from collections .abc import Mapping
129from tinker import types
13- from typing import Any , List , Union
10+ from typing import List , Union
1411
1512from twinkle import remote_class , remote_function
1613from twinkle .data_format import InputFeature , Trajectory
1714from twinkle .model import MultiLoraTransformersModel
1815from twinkle .server .common .datum import datum_to_input_feature , extract_rl_feature
19- from twinkle .server .model .backends .common import TwinkleCompatModelBase , clean_metrics , collect_forward_backward_results
16+ from twinkle .server .model .backends .common import (TwinkleCompatModelBase , clean_metrics ,
17+ collect_forward_backward_results , to_cpu_safe_output )
2018
2119
2220@remote_class ()
@@ -28,32 +26,6 @@ class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatMo
2826 - Twinkle-native I/O (InputFeature / Trajectory) via /twinkle/* endpoints.
2927 """
3028
31- # ------------------------------------------------------------------
32- # Shared helper: CPU-safe serialisation for HTTP transport
33- # ------------------------------------------------------------------
34-
35- @staticmethod
36- def _to_cpu_safe_output (obj : Any ) -> Any :
37- """Convert nested outputs into CPU-safe Python objects for HTTP transport."""
38- from twinkle .utils import torch_util
39-
40- if isinstance (obj , torch .Tensor ):
41- tensor = torch_util .to_local_tensor (obj ).detach ().cpu ()
42- if tensor .numel () == 1 :
43- return tensor .item ()
44- return tensor .tolist ()
45- if isinstance (obj , np .ndarray ):
46- if obj .size == 1 :
47- return obj .item ()
48- return obj .tolist ()
49- if isinstance (obj , np .generic ):
50- return obj .item ()
51- if isinstance (obj , Mapping ):
52- return {key : TwinkleCompatTransformersModel ._to_cpu_safe_output (value ) for key , value in obj .items ()}
53- if isinstance (obj , (list , tuple )):
54- return [TwinkleCompatTransformersModel ._to_cpu_safe_output (value ) for value in obj ]
55- return obj
56-
5729 # ------------------------------------------------------------------
5830 # Tinker-compat methods (Datum-based I/O)
5931 # ------------------------------------------------------------------
@@ -135,4 +107,4 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
135107 ** kwargs ):
136108 """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
137109 output = super ().forward_backward (inputs = inputs , ** kwargs )
138- return self . _to_cpu_safe_output (output )
110+ return to_cpu_safe_output (output )
0 commit comments