@@ -120,25 +120,25 @@ def __init__(self,
 
     def _cleanup_adapter(self, adapter_name: str) -> None:
         """Common adapter cleanup logic used by both manual unload and automatic expiration.
-
+
         This method handles:
         1. Clearing adapter state
         2. Removing adapter from model
         3. Unregistering from adapter manager
         4. Removing from server state
-
+
         Args:
             adapter_name: Name of the adapter to clean up
         """
         # Remove from model if it exists
         if self.get_adapter_info(adapter_name):
             # Clear adapter state
             self.clear_adapter_state(adapter_name)
-
+
             self.model.remove_adapter(adapter_name)
             # Unregister from adapter manager
             self.unregister_adapter(adapter_name)
-
+
         # Remove from server state
         self.state.unload_model(adapter_name)
 
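Reviewer note: the value of funneling both paths through `_cleanup_adapter` is easiest to see in isolation. Below is a minimal, self-contained sketch of the pattern; `AdapterHost`, its dict registry, and the TTL logic are illustrative stand-ins, not the real server classes.

```python
import time


class AdapterHost:
    """Illustrative stand-in for the server; not the real class."""

    def __init__(self, ttl_seconds: float = 3600.0):
        self.ttl_seconds = ttl_seconds
        # adapter_name -> {'last_used': monotonic timestamp}
        self.adapters: dict[str, dict] = {}

    def _cleanup_adapter(self, adapter_name: str) -> None:
        # Single cleanup path shared by both callers below, so neither
        # can forget a step (state, model, registry, server state).
        self.adapters.pop(adapter_name, None)

    def unload(self, adapter_name: str) -> None:
        self._cleanup_adapter(adapter_name)  # manual unload path

    def expire_idle(self) -> None:
        now = time.monotonic()
        for name, info in list(self.adapters.items()):
            if now - info['last_used'] > self.ttl_seconds:
                self._cleanup_adapter(name)  # automatic expiration path
```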
@@ -175,16 +175,13 @@ async def _create_adapter():
             # TODO: support more lora config parameters, train_unembed, etc.
             lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
 
-            adapter_name = self.get_adapter_name(
-                adapter_name=model_id)
-
+            adapter_name = self.get_adapter_name(adapter_name=model_id)
+
             # Register adapter FIRST (limit check happens inside register_adapter)
-            self.register_adapter(
-                adapter_name, request.state.token, session_id=body.session_id)
-
+            self.register_adapter(adapter_name, request.state.token, session_id=body.session_id)
+
             # Create adapter AFTER successful registration
-            self.model.add_adapter_to_model(
-                adapter_name=adapter_name, config_or_dir=lora_cfg)
+            self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
 
             self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model)
             self.model.set_processor('InputProcessor', adapter_name=adapter_name)
@@ -193,8 +190,7 @@ async def _create_adapter():
             # Fresh adapter has no accumulated gradients.
             self.set_adapter_state(adapter_name, 'grad_ready', False)
 
-            training_run_manager = create_training_run_manager(
-                request.state.token)
+            training_run_manager = create_training_run_manager(request.state.token)
             training_run_manager.save(model_id, body)
 
             return types.CreateModelResponse(model_id=model_id)
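The "Register adapter FIRST" comment carries the key invariant: `register_adapter` performs the capacity check, so running it before `add_adapter_to_model` means a rejected request never leaves a half-created adapter attached to the model. A hedged sketch of that ordering, with hypothetical names (`MAX_ADAPTERS`, `registry`, the stub attach function) standing in for the real manager:

```python
MAX_ADAPTERS = 4                      # hypothetical capacity limit
registry: dict[str, str] = {}         # adapter_name -> owner token


def register_adapter(adapter_name: str, token: str) -> None:
    # Capacity check lives here, mirroring the diff's ordering rationale.
    if len(registry) >= MAX_ADAPTERS:
        raise RuntimeError(f'adapter limit ({MAX_ADAPTERS}) reached')
    registry[adapter_name] = token


def add_adapter_to_model(adapter_name: str) -> None:
    print(f'attaching {adapter_name}')  # stand-in for the real model call


def create_adapter(adapter_name: str, token: str) -> None:
    register_adapter(adapter_name, token)    # may raise; model untouched
    try:
        add_adapter_to_model(adapter_name)   # only after registration
    except Exception:
        registry.pop(adapter_name, None)     # roll back registration
        raise
```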
@@ -261,8 +257,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -
 
         async def _do_unload():
             # Only remove adapter, not the base model
-            adapter_name = self.get_adapter_name(
-                adapter_name=body.model_id)
+            adapter_name = self.get_adapter_name(adapter_name=body.model_id)
             # Use common cleanup logic
             self._cleanup_adapter(adapter_name)
             return types.UnloadModelResponse(model_id=body.model_id)
@@ -315,9 +310,7 @@ async def _do_forward():
 
         # Calculate input tokens and batch size for validation
         datum_list = body.forward_input.data
-        input_tokens = sum(
-            len(d.model_input.to_ints()) for d in datum_list
-        )
+        input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
         batch_size = len(datum_list)
         return await self.schedule_task(
             _do_forward,
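Both `forward` and `forward_backward` compute `input_tokens` and `batch_size` before handing work to `schedule_task`, presumably so the scheduler can validate the batch against limits. A self-contained sketch of that accounting, assuming hypothetical `MAX_INPUT_TOKENS`/`MAX_BATCH_SIZE` limits and a simplified `Datum` in place of the real request types:

```python
from dataclasses import dataclass

MAX_INPUT_TOKENS = 32768   # hypothetical scheduler limit
MAX_BATCH_SIZE = 64        # hypothetical scheduler limit


@dataclass
class Datum:
    token_ids: list[int]   # simplified stand-in for model_input.to_ints()


def validate_batch(datum_list: list[Datum]) -> tuple[int, int]:
    # Same accounting as the diff: total tokens across the batch, batch size.
    input_tokens = sum(len(d.token_ids) for d in datum_list)
    batch_size = len(datum_list)
    if input_tokens > MAX_INPUT_TOKENS:
        raise ValueError(f'{input_tokens} input tokens exceeds {MAX_INPUT_TOKENS}')
    if batch_size > MAX_BATCH_SIZE:
        raise ValueError(f'batch size {batch_size} exceeds {MAX_BATCH_SIZE}')
    return input_tokens, batch_size
```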
@@ -360,11 +353,12 @@ async def _do_forward_backward():
             loss_fn_config = body.forward_backward_input.loss_fn_config or {}
 
             # Unified forward_backward for both Megatron and Transformers
-            output, loss = self.model.forward_backward(inputs=datum_list,
-                                                       adapter_name=adapter_name,
-                                                       loss_fn=loss_fn,
-                                                       **loss_fn_config)
-            output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn'
+            output, loss = self.model.forward_backward(
+                inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config)
+            if loss_fn == 'importance_sampling':
+                output_type = 'ImportanceSamplingLossReturn'
+            else:
+                output_type = 'CrossEntropyLossReturn'
             # Mark gradients as ready after a successful forward_backward.
             self.set_adapter_state(adapter_name, 'grad_ready', True)
             return types.ForwardBackwardOutput(
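The `loss_fn` switch selects between two return types. For orientation, here is a hedged sketch of what the two loss shapes typically look like; the actual implementations live behind `self.model.forward_backward` and may differ in detail:

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab]; targets: [batch, seq] of token ids.
    return F.cross_entropy(logits.flatten(0, 1), targets.flatten())


def importance_sampling_loss(logprobs_new: torch.Tensor,
                             logprobs_old: torch.Tensor,
                             advantages: torch.Tensor) -> torch.Tensor:
    # Per-token ratio between the current policy and the policy that
    # sampled the data, weighted by advantages (REINFORCE-style).
    ratio = torch.exp(logprobs_new - logprobs_old)
    return -(ratio * advantages).mean()
```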
@@ -381,9 +375,7 @@ async def _do_forward_backward():
 
         # Calculate input tokens and batch size for validation
         datum_list = body.forward_backward_input.data
-        input_tokens = sum(
-            len(d.model_input.to_ints()) for d in datum_list
-        )
+        input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
         batch_size = len(datum_list)
         return await self.schedule_task(
             _do_forward_backward,
@@ -417,14 +409,13 @@ async def _do_optim():
             # Disallow empty step (must have at least one forward_backward since last step)
             if not self.get_adapter_state(adapter_name, 'grad_ready', False):
                 raise RuntimeError(
-                    f"No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step"
+                    f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step'  # noqa: E501
                 )
 
             # Touch adapter to reset inactivity counter
             self.touch_adapter(adapter_name)
 
-            self.model.step(adam_params=body.adam_params,
-                            adapter_name=adapter_name)
+            self.model.step(adam_params=body.adam_params, adapter_name=adapter_name)
             # Clear grad-ready after a successful step.
             self.set_adapter_state(adapter_name, 'grad_ready', False)
             metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name)
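Taken together, the `grad_ready` writes across these hunks form a small state machine: fresh or checkpoint-loaded adapters start at False, `forward_backward` sets True, and `optim_step` both requires True and resets it. A minimal sketch of that lifecycle (the real store is the adapter-state API, not this class):

```python
class GradReadyStates:
    """Illustrative stand-in for the adapter-state store."""

    def __init__(self):
        self._grad_ready: dict[str, bool] = {}

    def on_create(self, name: str) -> None:
        self._grad_ready[name] = False       # fresh adapter: no gradients

    def on_forward_backward(self, name: str) -> None:
        self._grad_ready[name] = True        # gradients accumulated

    def on_optim_step(self, name: str) -> None:
        if not self._grad_ready.get(name, False):
            raise RuntimeError(
                f'No accumulated gradients for adapter={name}; '
                'call forward_backward before optim_step')
        self._grad_ready[name] = False       # consumed by the step

    def on_load_checkpoint(self, name: str) -> None:
        self._grad_ready[name] = False       # loaded weights invalidate grads
```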
@@ -590,15 +581,15 @@ async def _do_load():
                 weight_path = body.path
                 load_optimizer = body.optimizer
 
-                self.model.load(checkpoint_dir=weight_path,
-                                load_optimizer=load_optimizer,
-                                adapter_name=adapter_name,
-                                token=token)
+                self.model.load(
+                    checkpoint_dir=weight_path,
+                    load_optimizer=load_optimizer,
+                    adapter_name=adapter_name,
+                    token=token)
 
                 # Loading a checkpoint should reset step readiness.
                 self.set_adapter_state(adapter_name, 'grad_ready', False)
-                return types.LoadWeightsResponse(path=body.path,
-                                                 type='load_weights')
+                return types.LoadWeightsResponse(path=body.path, type='load_weights')
             except Exception:
                 logger.error(traceback.format_exc())
                 return types.RequestFailedResponse(
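As elsewhere in this handler, failures inside `_do_load` are logged with a full traceback and converted to a `RequestFailedResponse` rather than propagated. A self-contained sketch of that error-envelope pattern, with the response types reduced to illustrative dataclasses:

```python
import logging
import traceback
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class LoadWeightsResponse:
    path: str
    type: str


@dataclass
class RequestFailedResponse:
    error: str


def load_weights(model, path: str) -> LoadWeightsResponse | RequestFailedResponse:
    try:
        model.load(checkpoint_dir=path)
        return LoadWeightsResponse(path=path, type='load_weights')
    except Exception:
        # Log the full traceback server-side; return an error envelope
        # instead of letting the exception propagate to the caller.
        logger.error(traceback.format_exc())
        return RequestFailedResponse(error='load_weights failed')
```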