@@ -85,7 +85,7 @@ async def forward(request: Request, body: types.ForwardRequest,
8585 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
8686
8787 async def _task ():
88- self .assert_adapter_exists ( adapter_name = adapter_name )
88+ self .assert_resource_exists ( adapter_name )
8989 extra_kwargs = body .model_extra or {}
9090 inputs = _parse_inputs (body .inputs )
9191 ret = self .model .forward (inputs = inputs , adapter_name = adapter_name , ** extra_kwargs )
@@ -103,7 +103,7 @@ async def forward_only(
103103 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
104104
105105 async def _task ():
106- self .assert_adapter_exists ( adapter_name = adapter_name )
106+ self .assert_resource_exists ( adapter_name )
107107 extra_kwargs = body .model_extra or {}
108108 inputs = _parse_inputs (body .inputs )
109109 ret = self .model .forward_only (inputs = inputs , adapter_name = adapter_name , ** extra_kwargs )
@@ -121,7 +121,7 @@ async def calculate_loss(
121121 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
122122
123123 async def _task ():
124- self .assert_adapter_exists ( adapter_name = adapter_name )
124+ self .assert_resource_exists ( adapter_name )
125125 extra_kwargs = body .model_extra or {}
126126 ret = self .model .calculate_loss (adapter_name = adapter_name , ** extra_kwargs )
127127 return {'result' : ret }
@@ -134,7 +134,7 @@ async def backward(request: Request, body: types.AdapterRequest, self: ModelMana
134134 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
135135
136136 async def _task ():
137- self .assert_adapter_exists ( adapter_name = adapter_name )
137+ self .assert_resource_exists ( adapter_name )
138138 extra_kwargs = body .model_extra or {}
139139 self .model .backward (adapter_name = adapter_name , ** extra_kwargs )
140140
@@ -157,7 +157,7 @@ def first_element(data):
157157 return data
158158
159159 async def _task ():
160- self .assert_adapter_exists ( adapter_name = adapter_name )
160+ self .assert_resource_exists ( adapter_name )
161161 extra_kwargs = body .model_extra or {}
162162 all_inputs = _parse_inputs (body .inputs )
163163 for inputs in all_inputs :
@@ -179,7 +179,7 @@ async def clip_grad_norm(
179179 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
180180
181181 async def _task ():
182- self .assert_adapter_exists ( adapter_name = adapter_name )
182+ self .assert_resource_exists ( adapter_name )
183183 extra_kwargs = body .model_extra or {}
184184 ret = self .model .clip_grad_norm (adapter_name = adapter_name , ** extra_kwargs )
185185 return {'result' : str (ret )}
@@ -192,7 +192,7 @@ async def step(request: Request, body: types.AdapterRequest, self: ModelManageme
192192 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
193193
194194 async def _task ():
195- self .assert_adapter_exists ( adapter_name = adapter_name )
195+ self .assert_resource_exists ( adapter_name )
196196 extra_kwargs = body .model_extra or {}
197197 self .model .step (adapter_name = adapter_name , ** extra_kwargs )
198198
@@ -204,7 +204,7 @@ async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelMan
204204 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
205205
206206 async def _task ():
207- self .assert_adapter_exists ( adapter_name = adapter_name )
207+ self .assert_resource_exists ( adapter_name )
208208 extra_kwargs = body .model_extra or {}
209209 self .model .zero_grad (adapter_name = adapter_name , ** extra_kwargs )
210210
@@ -216,7 +216,7 @@ async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManag
216216 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
217217
218218 async def _task ():
219- self .assert_adapter_exists ( adapter_name = adapter_name )
219+ self .assert_resource_exists ( adapter_name )
220220 extra_kwargs = body .model_extra or {}
221221 self .model .lr_step (adapter_name = adapter_name , ** extra_kwargs )
222222
@@ -232,7 +232,7 @@ async def clip_grad_and_step(
232232 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
233233
234234 async def _task ():
235- self .assert_adapter_exists ( adapter_name = adapter_name )
235+ self .assert_resource_exists ( adapter_name )
236236 extra_kwargs = body .model_extra or {}
237237 self .model .clip_grad_and_step (
238238 max_grad_norm = body .max_grad_norm ,
@@ -253,7 +253,7 @@ async def get_train_configs(
253253 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
254254
255255 async def _task ():
256- self .assert_adapter_exists ( adapter_name = adapter_name )
256+ self .assert_resource_exists ( adapter_name )
257257 extra_kwargs = body .model_extra or {}
258258 ret = self .model .get_train_configs (adapter_name = adapter_name , ** extra_kwargs )
259259 return {'result' : ret }
@@ -266,7 +266,7 @@ async def set_loss(request: Request, body: types.SetLossRequest, self: ModelMana
266266 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
267267
268268 async def _task ():
269- self .assert_adapter_exists ( adapter_name = adapter_name )
269+ self .assert_resource_exists ( adapter_name )
270270 extra_kwargs = body .model_extra or {}
271271 self .model .set_loss (body .loss_cls , adapter_name = adapter_name , ** extra_kwargs )
272272
@@ -282,7 +282,7 @@ async def set_optimizer(
282282 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
283283
284284 async def _task ():
285- self .assert_adapter_exists ( adapter_name = adapter_name )
285+ self .assert_resource_exists ( adapter_name )
286286 extra_kwargs = body .model_extra or {}
287287 self .model .set_optimizer (body .optimizer_cls , adapter_name = adapter_name , ** extra_kwargs )
288288
@@ -298,7 +298,7 @@ async def set_lr_scheduler(
298298 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
299299
300300 async def _task ():
301- self .assert_adapter_exists ( adapter_name = adapter_name )
301+ self .assert_resource_exists ( adapter_name )
302302 extra_kwargs = body .model_extra or {}
303303 self .model .set_lr_scheduler (body .scheduler_cls , adapter_name = adapter_name , ** extra_kwargs )
304304
@@ -311,7 +311,7 @@ async def save(request: Request, body: types.SaveRequest,
311311 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
312312
313313 async def _task ():
314- self .assert_adapter_exists ( adapter_name = adapter_name )
314+ self .assert_resource_exists ( adapter_name )
315315 extra_kwargs = body .model_extra or {}
316316 checkpoint_manager = create_checkpoint_manager (token , client_type = 'twinkle' )
317317 checkpoint_name = checkpoint_manager .get_ckpt_name (body .name )
@@ -333,7 +333,7 @@ async def load(request: Request, body: types.LoadRequest, self: ModelManagement
333333 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
334334
335335 async def _task ():
336- self .assert_adapter_exists ( adapter_name = adapter_name )
336+ self .assert_resource_exists ( adapter_name )
337337 extra_kwargs = body .model_extra or {}
338338 checkpoint_manager = create_checkpoint_manager (token , client_type = 'twinkle' )
339339 resolved = checkpoint_manager .resolve_load_path (body .name )
@@ -393,7 +393,7 @@ async def _task():
393393 config = deserialize_object (body .config )
394394 extra_kwargs = body .model_extra or {}
395395 training_run_manager = create_training_run_manager (token , client_type = 'twinkle' )
396- self .register_adapter (adapter_name , token , session_id = session_id )
396+ self .register_resource (adapter_name , token , session_id )
397397 self .model .add_adapter_to_model (adapter_name , config , ** extra_kwargs )
398398
399399 lora_config = None
@@ -416,7 +416,7 @@ async def apply_patch(
416416 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
417417
418418 async def _task ():
419- self .assert_adapter_exists ( adapter_name = adapter_name )
419+ self .assert_resource_exists ( adapter_name )
420420 extra_kwargs = body .model_extra or {}
421421 patch_cls = deserialize_object (body .patch_cls )
422422 self .model .apply_patch (patch_cls , adapter_name = adapter_name , ** extra_kwargs )
@@ -433,7 +433,7 @@ async def add_metric(
433433 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
434434
435435 async def _task ():
436- self .assert_adapter_exists ( adapter_name = adapter_name )
436+ self .assert_resource_exists ( adapter_name )
437437 extra_kwargs = body .model_extra or {}
438438 metric_cls = deserialize_object (body .metric_cls )
439439 self .model .add_metric (metric_cls , is_training = body .is_training , adapter_name = adapter_name , ** extra_kwargs )
@@ -450,7 +450,7 @@ async def set_template(
450450 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
451451
452452 async def _task ():
453- self .assert_adapter_exists ( adapter_name = adapter_name )
453+ self .assert_resource_exists ( adapter_name )
454454 extra_kwargs = body .model_extra or {}
455455 self .model .set_template (body .template_cls , adapter_name = adapter_name , ** extra_kwargs )
456456
@@ -466,7 +466,7 @@ async def set_processor(
466466 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
467467
468468 async def _task ():
469- self .assert_adapter_exists ( adapter_name = adapter_name )
469+ self .assert_resource_exists ( adapter_name )
470470 extra_kwargs = body .model_extra or {}
471471 self .model .set_processor (body .processor_cls , adapter_name = adapter_name , ** extra_kwargs )
472472
@@ -482,7 +482,7 @@ async def calculate_metric(
482482 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
483483
484484 async def _task ():
485- self .assert_adapter_exists ( adapter_name = adapter_name )
485+ self .assert_resource_exists ( adapter_name )
486486 extra_kwargs = body .model_extra or {}
487487 ret = self .model .calculate_metric (is_training = body .is_training , adapter_name = adapter_name , ** extra_kwargs )
488488 return {'result' : ret }
@@ -499,7 +499,7 @@ async def get_state_dict(
499499 adapter_name = _get_twinkle_adapter_name (request , body .adapter_name )
500500
501501 async def _task ():
502- self .assert_adapter_exists ( adapter_name = adapter_name )
502+ self .assert_resource_exists ( adapter_name )
503503 extra_kwargs = body .model_extra or {}
504504 ret = self .model .get_state_dict (adapter_name = adapter_name , ** extra_kwargs )
505505 return {'result' : ret }
0 commit comments