Skip to content

Commit 37fcf17

Browse files
committed
Merge branch 'fix_moe' of https://github.com/modelscope/twinkle into fix_moe
2 parents d6c274d + 01f88f7 commit 37fcf17

File tree

5 files changed

+14
-13
lines changed

5 files changed

+14
-13
lines changed

cookbook/client/tinker/self_congnition.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# The server must be running first (see server.py and server_config.yaml).
99
import numpy as np
1010
import os
11-
from modelscope import AutoTokenizer
11+
from tqdm import tqdm
1212
from tinker import types
1313
from twinkle_client import init_tinker_compat_client
1414
from twinkle.data_format import Message, Trajectory
@@ -125,5 +125,5 @@ def eval():
125125

126126

127127
if __name__ == '__main__':
128-
# train() # Uncomment to run training
129-
eval() # Run evaluation / inference
128+
train() # Uncomment to run training
129+
# eval() # Run evaluation / inference

src/twinkle/model/megatron/megatron.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,13 +848,13 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
848848
Args:
849849
name: Checkpoint name or HuggingFace Hub model id.
850850
output_dir: Parent directory that contains the checkpoint folder.
851-
If None **and** ``resume`` is False, downloads from Hub.
852-
resume: If True, restore optimizer, lr_scheduler and RNG state
851+
If None **and** ``load_optimizer`` is False, downloads from Hub.
852+
load_optimizer: If True, restore optimizer, lr_scheduler and RNG state
853853
from the mcore sub-checkpoint for training resumption.
854854
**kwargs: Additional arguments (``adapter_name``, ``no_load_optim``,
855855
``no_load_rng``, etc.).
856856
"""
857-
resume = kwargs.pop('resume', False)
857+
resume = kwargs.pop('load_optimizer', False)
858858
if output_dir is None and not resume:
859859
# Load from hub
860860
token = kwargs.pop('token', None)

src/twinkle/server/utils/state.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class ServerState:
3131
def __init__(
3232
self,
3333
expiration_timeout: float = 86400.0, # 24 hours in seconds
34-
cleanup_interval: float = 3600.0) -> None: # 1 hour in seconds
34+
cleanup_interval: float = 3600.0,
35+
**kwargs) -> None: # 1 hour in seconds
3536
# Session tracking
3637
self.sessions: dict[str, dict[str, Any]] = {}
3738
# Model registration

src/twinkle/server/utils/task_queue.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def _queue_worker(self) -> None:
222222
Selection policy: round-robin across queue keys. If a task is rate-limited
223223
at execution time, it is requeued and the worker tries other queues.
224224
"""
225-
print('[TaskQueue] Worker started')
225+
logger.debug('[TaskQueue] Worker started')
226226
while True:
227227
try:
228228
# Wait until there is at least one queue with a task
@@ -470,7 +470,7 @@ async def schedule_task(
470470
if self._event_loop is None:
471471
self._event_loop = asyncio.get_running_loop()
472472

473-
print(
473+
logger.debug(
474474
f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501
475475
)
476476

@@ -487,7 +487,7 @@ async def schedule_task(
487487

488488
# 5. Put task in queue and update status
489489
q = self._task_queues[queue_key]
490-
print(
490+
logger.debug(
491491
f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501
492492
)
493493
await q.put(
@@ -502,7 +502,7 @@ async def schedule_task(
502502
))
503503
self.state.store_future_status(
504504
request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value)
505-
print(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}')
505+
logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}')
506506

507507
self._new_task_event.set()
508508

@@ -567,4 +567,4 @@ async def shutdown_task_queue(self) -> None:
567567
self._task_queues.clear()
568568
self._queue_order.clear()
569569

570-
print('[TaskQueue] Task queue shutdown complete')
570+
logger.debug('[TaskQueue] Task queue shutdown complete')

src/twinkle_client/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def init_tinker_compat_client(base_url: str | None = None, api_key: str | None =
2121
# Apply patch to bypass tinker:// prefix validation
2222
patch_tinker()
2323

24-
if api_key is None:
24+
if not api_key:
2525
api_key = get_api_key()
2626

2727
if base_url and not base_url.startswith(('http://', 'https://')):

0 commit comments

Comments
 (0)