Skip to content

Commit 9bee81d

Browse files
committed
update init
1 parent 7178b06 commit 9bee81d

File tree

8 files changed

+46
-27
lines changed

8 files changed

+46
-27
lines changed

src/twinkle/server/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
from .launcher import ServerLauncher, launch_server
3-
from .twinkle.model import build_model_app
4-
from .twinkle.processor import build_processor_app
5-
from .twinkle.sampler import build_sampler_app
6-
from .twinkle.server import build_server_app
73

84
__all__ = [
9-
'build_model_app',
10-
'build_processor_app',
11-
'build_sampler_app',
12-
'build_server_app',
135
'ServerLauncher',
146
'launch_server',
157
]

src/twinkle/server/launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _get_builders(self) -> dict[str, Callable]:
101101
'build_sampler_app': build_sampler_app,
102102
}
103103
else: # twinkle
104-
from twinkle.server import build_model_app, build_processor_app, build_sampler_app, build_server_app
104+
from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app
105105
self._builders = {
106106
'build_server_app': build_server_app,
107107
'build_model_app': build_model_app,
Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import sys
3+
from typing import TYPE_CHECKING
24

3-
from ..utils import wrap_builder_with_device_group_env
4-
from .model import build_model_app as _build_model_app
5-
from .sampler import build_sampler_app as _build_sampler_app
6-
from .server import build_server_app
5+
from twinkle.utils.import_utils import _LazyModule
76

8-
build_model_app = wrap_builder_with_device_group_env(_build_model_app)
9-
build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app)
7+
_import_structure = {
8+
'model': ['build_model_app'],
9+
'sampler': ['build_sampler_app'],
10+
'server': ['build_server_app'],
11+
}
1012

11-
__all__ = [
12-
'build_model_app',
13-
'build_sampler_app',
14-
'build_server_app',
15-
]
13+
if TYPE_CHECKING:
14+
from .model import build_model_app
15+
from .sampler import build_sampler_app
16+
from .server import build_server_app
17+
else:
18+
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)

src/twinkle/server/tinker/common/datum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
from collections import defaultdict
55
from tinker import types
6-
from typing import List, Union
76

87
from twinkle.data_format.input_feature import InputFeature
98
from twinkle.template import Template
@@ -92,6 +91,8 @@ def input_feature_to_datum(input_feature: InputFeature) -> types.Datum:
9291
labels_raw = input_feature['labels']
9392
if isinstance(labels_raw, np.ndarray):
9493
labels_arr = labels_raw.astype(np.int64)
94+
elif isinstance(labels_raw, list):
95+
labels_arr = np.asarray(labels_raw, dtype=np.int64)
9596
else:
9697
labels_arr = np.asarray(labels_raw.cpu(), dtype=np.int64)
9798

src/twinkle/server/tinker/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
2525
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
2626
from twinkle.utils.logger import get_logger
27+
from ..utils import wrap_builder_with_device_group_env
2728
from .common.io_utils import create_checkpoint_manager, create_training_run_manager
2829
from .common.router import StickyLoraRequestRouter
2930

@@ -653,3 +654,6 @@ async def _do_load():
653654

654655
return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron,
655656
queue_config, **kwargs)
657+
658+
659+
build_model_app = wrap_builder_with_device_group_env(build_model_app)

src/twinkle/server/tinker/sampler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
2424
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
2525
from twinkle.utils.logger import get_logger
26+
from ..utils import wrap_builder_with_device_group_env
2627
from .common.io_utils import create_checkpoint_manager
2728

2829
logger = get_logger()
@@ -245,3 +246,6 @@ async def _do_sample():
245246

246247
return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type,
247248
engine_args, queue_config, **kwargs)
249+
250+
251+
build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app)

src/twinkle/server/tinker/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import asyncio
1515
import httpx
16-
import logging
1716
import os
1817
from fastapi import FastAPI, HTTPException, Request, Response
1918
from ray import serve
@@ -24,9 +23,10 @@
2423
from twinkle.server.utils.state import get_server_state
2524
from twinkle.server.utils.task_queue import QueueState
2625
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
26+
from twinkle.utils.logger import get_logger
2727
from .common.io_utils import create_checkpoint_manager, create_training_run_manager
2828

29-
logger = logging.getLogger(__name__)
29+
logger = get_logger()
3030

3131

3232
def build_server_app(deploy_options: dict[str, Any],
Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
from .model import build_model_app
3-
from .processor import build_processor_app
4-
from .sampler import build_sampler_app
5-
from .server import build_server_app
2+
import sys
3+
from typing import TYPE_CHECKING
4+
5+
from twinkle.utils.import_utils import _LazyModule
6+
7+
_import_structure = {
8+
'model': ['build_model_app'],
9+
'processor': ['build_processor_app'],
10+
'sampler': ['build_sampler_app'],
11+
'server': ['build_server_app'],
12+
}
13+
14+
if TYPE_CHECKING:
15+
from .model import build_model_app
16+
from .processor import build_processor_app
17+
from .sampler import build_sampler_app
18+
from .server import build_server_app
19+
else:
20+
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)

0 commit comments

Comments
 (0)