|
| 1 | +# Copyright (c) ModelScope Contributors. All rights reserved. |
| 2 | +import numpy as np |
| 3 | +import torch |
| 4 | +from collections.abc import Mapping |
| 5 | +from typing import Any, List, Union |
| 6 | + |
| 7 | +from twinkle import remote_class, remote_function |
| 8 | +from twinkle.data_format import InputFeature, Trajectory |
| 9 | +from twinkle.model import MultiLoraTransformersModel |
| 10 | + |
| 11 | + |
| 12 | +@remote_class() |
| 13 | +class TwinkleCompatTransformersModel(MultiLoraTransformersModel): |
| 14 | + |
| 15 | + @staticmethod |
| 16 | + def _to_cpu_safe_output(obj: Any) -> Any: |
| 17 | + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" |
| 18 | + from twinkle.utils import torch_util |
| 19 | + |
| 20 | + if isinstance(obj, torch.Tensor): |
| 21 | + tensor = torch_util.to_local_tensor(obj).detach().cpu() |
| 22 | + if tensor.numel() == 1: |
| 23 | + return tensor.item() |
| 24 | + return tensor.tolist() |
| 25 | + if isinstance(obj, np.ndarray): |
| 26 | + if obj.size == 1: |
| 27 | + return obj.item() |
| 28 | + return obj.tolist() |
| 29 | + if isinstance(obj, np.generic): |
| 30 | + return obj.item() |
| 31 | + if isinstance(obj, Mapping): |
| 32 | + return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} |
| 33 | + if isinstance(obj, (list, tuple)): |
| 34 | + return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] |
| 35 | + return obj |
| 36 | + |
| 37 | + @remote_function(dispatch='slice_dp', collect='mean') |
| 38 | + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], |
| 39 | + **kwargs): |
| 40 | + output = super().forward_backward(inputs=inputs, **kwargs) |
| 41 | + return self._to_cpu_safe_output(output) |
0 commit comments