Skip to content

Commit 502354a

Browse files
authored
[Fix] Prevent client from importing ray via twinkle.server.common.serialize (#123)
1 parent 7fd432f commit 502354a

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import json
3+
from numbers import Number
4+
from peft import LoraConfig
5+
from typing import Mapping
6+
7+
from twinkle.dataset import DatasetMeta
8+
9+
primitive_types = (str, Number, bool, bytes, type(None))
10+
container_types = (Mapping, list, tuple, set, frozenset)
11+
basic_types = (*primitive_types, *container_types)
12+
13+
14+
def _serialize_data_slice(data_slice):
15+
if data_slice is None:
16+
return None
17+
if isinstance(data_slice, range):
18+
return {'_slice_type_': 'range', 'start': data_slice.start, 'stop': data_slice.stop, 'step': data_slice.step}
19+
if isinstance(data_slice, (list, tuple)):
20+
return {'_slice_type_': 'list', 'values': list(data_slice)}
21+
raise ValueError(f'Http mode does not support data_slice of type {type(data_slice).__name__}. '
22+
'Supported types: range, list, tuple.')
23+
24+
def serialize_object(obj) -> str:
25+
if isinstance(obj, DatasetMeta):
26+
data = obj.__dict__.copy()
27+
data['data_slice'] = _serialize_data_slice(data.get('data_slice'))
28+
data['_TWINKLE_TYPE_'] = 'DatasetMeta'
29+
return json.dumps(data, ensure_ascii=False)
30+
elif isinstance(obj, LoraConfig):
31+
filtered_dict = {
32+
_subkey: _subvalue
33+
for _subkey, _subvalue in obj.__dict__.items()
34+
if isinstance(_subvalue, basic_types) and not _subkey.startswith('_')
35+
}
36+
filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
37+
return json.dumps(filtered_dict, ensure_ascii=False)
38+
elif isinstance(obj, Mapping):
39+
return json.dumps(obj, ensure_ascii=False)
40+
elif isinstance(obj, basic_types):
41+
return obj
42+
else:
43+
raise ValueError(f'Unsupported object: {obj}')

src/twinkle_client/http/http_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _serialize_params(params: Dict[str, Any]) -> Dict[str, Any]:
4545
if hasattr(value, 'processor_id'):
4646
serialized[key] = value.processor_id
4747
elif hasattr(value, '__dict__'):
48-
from twinkle.server.common.serialize import serialize_object
48+
from twinkle_client.common.serialize import serialize_object
4949
serialized[key] = serialize_object(value)
5050
else:
5151
serialized[key] = value

0 commit comments

Comments
 (0)