11# Copyright (c) ModelScope Contributors. All rights reserved.
22"""
3- Twinkle-specific IO managers for training runs and checkpoints .
3+ Twinkle-specific checkpoint and training-run managers .
44
55Uses ``twinkle_client.types.training`` models for all serialization and response construction.
66"""
77from datetime import datetime
88from typing import Any , Dict , List , Optional
99
10- from twinkle .server .utils .io_utils import (TRAIN_RUN_INFO_FILENAME , BaseCheckpointManager , BaseTrainingRunManager ,
11- validate_ownership )
12- from twinkle_client .types .training import Checkpoint as TwinkleCheckpoint
13- from twinkle_client .types .training import (CheckpointsListResponse , CreateModelRequest , Cursor ,
14- ParsedCheckpointTwinklePath )
15- from twinkle_client .types .training import TrainingRun as TwinkleTrainingRun
16- from twinkle_client .types .training import TrainingRunsResponse , WeightsInfoResponse
10+ from twinkle .server .utils .checkpoint_base import (TRAIN_RUN_INFO_FILENAME , BaseCheckpointManager ,
11+ BaseTrainingRunManager , validate_ownership )
12+ from twinkle_client .types .training import (Checkpoint , CheckpointsListResponse , CreateModelRequest , Cursor ,
13+ ParsedCheckpointTwinklePath , TrainingRun , TrainingRunsResponse ,
14+ WeightsInfoResponse )
1715
1816
1917class TwinkleTrainingRunManager (BaseTrainingRunManager ):
@@ -25,7 +23,7 @@ def train_run_info_filename(self) -> str:
2523
2624 def _create_training_run (self , model_id : str , run_config : CreateModelRequest ) -> Dict [str , Any ]:
2725 lora_config = run_config .lora_config
28- train_run_data = TwinkleTrainingRun (
26+ train_run_data = TrainingRun (
2927 training_run_id = model_id ,
3028 base_model = run_config .base_model ,
3129 model_owner = self .token ,
@@ -44,14 +42,14 @@ def _create_training_run(self, model_id: str, run_config: CreateModelRequest) ->
4442 new_data ['train_attn' ] = lora_config .train_attn
4543 return new_data
4644
47- def _parse_training_run (self , data : Dict [str , Any ]) -> TwinkleTrainingRun :
48- return TwinkleTrainingRun (** data )
45+ def _parse_training_run (self , data : Dict [str , Any ]) -> TrainingRun :
46+ return TrainingRun (** data )
4947
50- def _create_training_runs_response (self , runs : List [TwinkleTrainingRun ], limit : int , offset : int ,
48+ def _create_training_runs_response (self , runs : List [TrainingRun ], limit : int , offset : int ,
5149 total : int ) -> TrainingRunsResponse :
5250 return TrainingRunsResponse (training_runs = runs , cursor = Cursor (limit = limit , offset = offset , total_count = total ))
5351
54- def get_with_permission (self , model_id : str ) -> Optional [TwinkleTrainingRun ]:
52+ def get_with_permission (self , model_id : str ) -> Optional [TrainingRun ]:
5553 run = self .get (model_id )
5654 if run and validate_ownership (self .token , run .model_owner ):
5755 return run
@@ -82,7 +80,7 @@ def _create_checkpoint(self,
8280 train_mlp = None ,
8381 train_attn = None ,
8482 user_metadata = None ) -> Dict [str , Any ]:
85- checkpoint = TwinkleCheckpoint (
83+ checkpoint = Checkpoint (
8684 checkpoint_id = checkpoint_id ,
8785 checkpoint_type = checkpoint_type ,
8886 time = datetime .now (),
@@ -98,15 +96,15 @@ def _create_checkpoint(self,
9896 user_metadata = user_metadata )
9997 return checkpoint .model_dump (mode = 'json' )
10098
101- def _parse_checkpoint (self , data : Dict [str , Any ]) -> TwinkleCheckpoint :
99+ def _parse_checkpoint (self , data : Dict [str , Any ]) -> Checkpoint :
102100 data = data .copy ()
103101 if 'tinker_path' in data and 'twinkle_path' not in data :
104102 data ['twinkle_path' ] = data .pop ('tinker_path' )
105103 elif 'twinkle_path' not in data and 'path' in data :
106104 data ['twinkle_path' ] = data .pop ('path' )
107- return TwinkleCheckpoint (** data )
105+ return Checkpoint (** data )
108106
109- def get (self , model_id : str , checkpoint_id : str ) -> Optional [TwinkleCheckpoint ]:
107+ def get (self , model_id : str , checkpoint_id : str ) -> Optional [Checkpoint ]:
110108 data = self ._read_ckpt_info (model_id , checkpoint_id )
111109 if not data :
112110 return None
@@ -116,7 +114,7 @@ def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]:
116114 data ['twinkle_path' ] = f"{ self .path_prefix } { model_id } /{ data ['checkpoint_id' ]} "
117115 return self ._parse_checkpoint (data )
118116
119- def _create_checkpoints_response (self , checkpoints : List [TwinkleCheckpoint ]) -> CheckpointsListResponse :
117+ def _create_checkpoints_response (self , checkpoints : List [Checkpoint ]) -> CheckpointsListResponse :
120118 return CheckpointsListResponse (checkpoints = checkpoints , cursor = None )
121119
122120 def _create_parsed_path (self , path , training_run_id , checkpoint_type , checkpoint_id ) -> ParsedCheckpointTwinklePath :
0 commit comments