Skip to content

Commit 8220365

Browse files
committed
update handler
1 parent d65c2af commit 8220365

File tree

7 files changed

+41
-151
lines changed

7 files changed

+41
-151
lines changed

src/twinkle/server/model/app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def __del__(self):
118118
pass
119119

120120
async def _cleanup_adapter(self, adapter_name: str) -> None:
121-
if self.get_adapter_info(adapter_name):
122-
self.clear_adapter_state(adapter_name)
121+
if self.get_resource_info(adapter_name):
122+
self.clear_resource_state(adapter_name)
123123
self.model.remove_adapter(adapter_name)
124-
self.unregister_adapter(adapter_name)
124+
self.unregister_resource(adapter_name)
125125
await self.state.unload_model(adapter_name)
126126

127127
async def _on_adapter_expired(self, adapter_name: str) -> None:

src/twinkle/server/model/tinker_handlers.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ async def _create_adapter():
4444
if body.lora_config:
4545
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
4646
adapter_name = self.get_adapter_name(adapter_name=_model_id)
47-
self.register_adapter(adapter_name, token, session_id=body.session_id)
47+
self.register_resource(adapter_name, token, session_id=body.session_id)
4848
self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
4949
self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model)
5050
self.model.set_processor('InputProcessor', adapter_name=adapter_name)
5151
self.model.set_optimizer('Adam', adapter_name=adapter_name)
52-
self.set_adapter_state(adapter_name, 'grad_ready', False)
52+
self.set_resource_state(adapter_name, 'grad_ready', False)
5353
training_run_manager = create_training_run_manager(token, client_type='tinker')
5454
training_run_manager.save(_model_id, body)
5555
return types.CreateModelResponse(model_id=_model_id)
@@ -108,7 +108,7 @@ async def forward(request: Request, body: types.ForwardRequest,
108108
async def _do_forward():
109109
try:
110110
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
111-
self.assert_adapter_exists(adapter_name=adapter_name)
111+
self.assert_resource_exists(adapter_name)
112112
datum_list = body.forward_input.data
113113
loss_fn_config = body.forward_input.loss_fn_config or {}
114114
output = self.model.tinker_forward_only(inputs=datum_list, adapter_name=adapter_name)
@@ -149,15 +149,15 @@ async def forward_backward(
149149
async def _do_forward_backward():
150150
try:
151151
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
152-
self.assert_adapter_exists(adapter_name=adapter_name)
152+
self.assert_resource_exists(adapter_name)
153153
datum_list = body.forward_backward_input.data
154154
loss_fn = body.forward_backward_input.loss_fn
155155
loss_fn_config = body.forward_backward_input.loss_fn_config or {}
156156
output, loss = self.model.tinker_forward_backward(
157157
inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config)
158158
output_type = ('ImportanceSamplingLossReturn'
159159
if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn')
160-
self.set_adapter_state(adapter_name, 'grad_ready', True)
160+
self.set_resource_state(adapter_name, 'grad_ready', True)
161161
return types.ForwardBackwardOutput(
162162
loss_fn_output_type=output_type,
163163
loss_fn_outputs=output,
@@ -194,12 +194,12 @@ async def optim_step(
194194
async def _do_optim():
195195
try:
196196
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
197-
self.assert_adapter_exists(adapter_name=adapter_name)
198-
if not self.get_adapter_state(adapter_name, 'grad_ready', False):
197+
self.assert_resource_exists(adapter_name)
198+
if not self.get_resource_state(adapter_name, 'grad_ready', False):
199199
raise RuntimeError(f'No accumulated gradients for adapter={adapter_name}; '
200200
'call forward_backward before optim_step')
201201
self.model.tinker_step(adam_params=body.adam_params, adapter_name=adapter_name)
202-
self.set_adapter_state(adapter_name, 'grad_ready', False)
202+
self.set_resource_state(adapter_name, 'grad_ready', False)
203203
metrics = self.model.tinker_calculate_metric(is_training=True, adapter_name=adapter_name)
204204
return types.OptimStepResponse(metrics=metrics)
205205
except Exception:
@@ -222,7 +222,7 @@ async def save_weights(
222222
async def _do_save():
223223
try:
224224
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
225-
self.assert_adapter_exists(adapter_name=adapter_name)
225+
self.assert_resource_exists(adapter_name)
226226
checkpoint_manager = create_checkpoint_manager(token, client_type='tinker')
227227
checkpoint_name = checkpoint_manager.get_ckpt_name(body.path)
228228
save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False)
@@ -250,7 +250,7 @@ async def save_weights_for_sampler(
250250
async def _do_save_for_sampler():
251251
try:
252252
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
253-
self.assert_adapter_exists(adapter_name=adapter_name)
253+
self.assert_resource_exists(adapter_name)
254254
checkpoint_manager = create_checkpoint_manager(token, client_type='tinker')
255255
checkpoint_name = checkpoint_manager.get_ckpt_name(body.path)
256256
save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True)
@@ -287,10 +287,10 @@ async def _do_load():
287287
try:
288288
assert self.model is not None, 'Model not loaded, please load model first'
289289
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
290-
self.assert_adapter_exists(adapter_name=adapter_name)
290+
self.assert_resource_exists(adapter_name)
291291
self.model.tinker_load(
292292
checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token)
293-
self.set_adapter_state(adapter_name, 'grad_ready', False)
293+
self.set_resource_state(adapter_name, 'grad_ready', False)
294294
return types.LoadWeightsResponse(path=body.path, type='load_weights')
295295
except Exception:
296296
logger.error(traceback.format_exc())

src/twinkle/server/model/twinkle_handlers.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def forward(request: Request, body: types.ForwardRequest,
8585
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
8686

8787
async def _task():
88-
self.assert_adapter_exists(adapter_name=adapter_name)
88+
self.assert_resource_exists(adapter_name)
8989
extra_kwargs = body.model_extra or {}
9090
inputs = _parse_inputs(body.inputs)
9191
ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
@@ -103,7 +103,7 @@ async def forward_only(
103103
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
104104

105105
async def _task():
106-
self.assert_adapter_exists(adapter_name=adapter_name)
106+
self.assert_resource_exists(adapter_name)
107107
extra_kwargs = body.model_extra or {}
108108
inputs = _parse_inputs(body.inputs)
109109
ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
@@ -121,7 +121,7 @@ async def calculate_loss(
121121
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
122122

123123
async def _task():
124-
self.assert_adapter_exists(adapter_name=adapter_name)
124+
self.assert_resource_exists(adapter_name)
125125
extra_kwargs = body.model_extra or {}
126126
ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs)
127127
return {'result': ret}
@@ -134,7 +134,7 @@ async def backward(request: Request, body: types.AdapterRequest, self: ModelMana
134134
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
135135

136136
async def _task():
137-
self.assert_adapter_exists(adapter_name=adapter_name)
137+
self.assert_resource_exists(adapter_name)
138138
extra_kwargs = body.model_extra or {}
139139
self.model.backward(adapter_name=adapter_name, **extra_kwargs)
140140

@@ -157,7 +157,7 @@ def first_element(data):
157157
return data
158158

159159
async def _task():
160-
self.assert_adapter_exists(adapter_name=adapter_name)
160+
self.assert_resource_exists(adapter_name)
161161
extra_kwargs = body.model_extra or {}
162162
all_inputs = _parse_inputs(body.inputs)
163163
for inputs in all_inputs:
@@ -179,7 +179,7 @@ async def clip_grad_norm(
179179
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
180180

181181
async def _task():
182-
self.assert_adapter_exists(adapter_name=adapter_name)
182+
self.assert_resource_exists(adapter_name)
183183
extra_kwargs = body.model_extra or {}
184184
ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs)
185185
return {'result': str(ret)}
@@ -192,7 +192,7 @@ async def step(request: Request, body: types.AdapterRequest, self: ModelManageme
192192
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
193193

194194
async def _task():
195-
self.assert_adapter_exists(adapter_name=adapter_name)
195+
self.assert_resource_exists(adapter_name)
196196
extra_kwargs = body.model_extra or {}
197197
self.model.step(adapter_name=adapter_name, **extra_kwargs)
198198

@@ -204,7 +204,7 @@ async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelMan
204204
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
205205

206206
async def _task():
207-
self.assert_adapter_exists(adapter_name=adapter_name)
207+
self.assert_resource_exists(adapter_name)
208208
extra_kwargs = body.model_extra or {}
209209
self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs)
210210

@@ -216,7 +216,7 @@ async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManag
216216
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
217217

218218
async def _task():
219-
self.assert_adapter_exists(adapter_name=adapter_name)
219+
self.assert_resource_exists(adapter_name)
220220
extra_kwargs = body.model_extra or {}
221221
self.model.lr_step(adapter_name=adapter_name, **extra_kwargs)
222222

@@ -232,7 +232,7 @@ async def clip_grad_and_step(
232232
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
233233

234234
async def _task():
235-
self.assert_adapter_exists(adapter_name=adapter_name)
235+
self.assert_resource_exists(adapter_name)
236236
extra_kwargs = body.model_extra or {}
237237
self.model.clip_grad_and_step(
238238
max_grad_norm=body.max_grad_norm,
@@ -253,7 +253,7 @@ async def get_train_configs(
253253
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
254254

255255
async def _task():
256-
self.assert_adapter_exists(adapter_name=adapter_name)
256+
self.assert_resource_exists(adapter_name)
257257
extra_kwargs = body.model_extra or {}
258258
ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs)
259259
return {'result': ret}
@@ -266,7 +266,7 @@ async def set_loss(request: Request, body: types.SetLossRequest, self: ModelMana
266266
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
267267

268268
async def _task():
269-
self.assert_adapter_exists(adapter_name=adapter_name)
269+
self.assert_resource_exists(adapter_name)
270270
extra_kwargs = body.model_extra or {}
271271
self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs)
272272

@@ -282,7 +282,7 @@ async def set_optimizer(
282282
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
283283

284284
async def _task():
285-
self.assert_adapter_exists(adapter_name=adapter_name)
285+
self.assert_resource_exists(adapter_name)
286286
extra_kwargs = body.model_extra or {}
287287
self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs)
288288

@@ -298,7 +298,7 @@ async def set_lr_scheduler(
298298
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
299299

300300
async def _task():
301-
self.assert_adapter_exists(adapter_name=adapter_name)
301+
self.assert_resource_exists(adapter_name)
302302
extra_kwargs = body.model_extra or {}
303303
self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs)
304304

@@ -311,7 +311,7 @@ async def save(request: Request, body: types.SaveRequest,
311311
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
312312

313313
async def _task():
314-
self.assert_adapter_exists(adapter_name=adapter_name)
314+
self.assert_resource_exists(adapter_name)
315315
extra_kwargs = body.model_extra or {}
316316
checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
317317
checkpoint_name = checkpoint_manager.get_ckpt_name(body.name)
@@ -333,7 +333,7 @@ async def load(request: Request, body: types.LoadRequest, self: ModelManagement
333333
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
334334

335335
async def _task():
336-
self.assert_adapter_exists(adapter_name=adapter_name)
336+
self.assert_resource_exists(adapter_name)
337337
extra_kwargs = body.model_extra or {}
338338
checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
339339
resolved = checkpoint_manager.resolve_load_path(body.name)
@@ -393,7 +393,7 @@ async def _task():
393393
config = deserialize_object(body.config)
394394
extra_kwargs = body.model_extra or {}
395395
training_run_manager = create_training_run_manager(token, client_type='twinkle')
396-
self.register_adapter(adapter_name, token, session_id=session_id)
396+
self.register_resource(adapter_name, token, session_id)
397397
self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
398398

399399
lora_config = None
@@ -416,7 +416,7 @@ async def apply_patch(
416416
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
417417

418418
async def _task():
419-
self.assert_adapter_exists(adapter_name=adapter_name)
419+
self.assert_resource_exists(adapter_name)
420420
extra_kwargs = body.model_extra or {}
421421
patch_cls = deserialize_object(body.patch_cls)
422422
self.model.apply_patch(patch_cls, adapter_name=adapter_name, **extra_kwargs)
@@ -433,7 +433,7 @@ async def add_metric(
433433
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
434434

435435
async def _task():
436-
self.assert_adapter_exists(adapter_name=adapter_name)
436+
self.assert_resource_exists(adapter_name)
437437
extra_kwargs = body.model_extra or {}
438438
metric_cls = deserialize_object(body.metric_cls)
439439
self.model.add_metric(metric_cls, is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs)
@@ -450,7 +450,7 @@ async def set_template(
450450
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
451451

452452
async def _task():
453-
self.assert_adapter_exists(adapter_name=adapter_name)
453+
self.assert_resource_exists(adapter_name)
454454
extra_kwargs = body.model_extra or {}
455455
self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs)
456456

@@ -466,7 +466,7 @@ async def set_processor(
466466
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
467467

468468
async def _task():
469-
self.assert_adapter_exists(adapter_name=adapter_name)
469+
self.assert_resource_exists(adapter_name)
470470
extra_kwargs = body.model_extra or {}
471471
self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs)
472472

@@ -482,7 +482,7 @@ async def calculate_metric(
482482
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
483483

484484
async def _task():
485-
self.assert_adapter_exists(adapter_name=adapter_name)
485+
self.assert_resource_exists(adapter_name)
486486
extra_kwargs = body.model_extra or {}
487487
ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs)
488488
return {'result': ret}
@@ -499,7 +499,7 @@ async def get_state_dict(
499499
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
500500

501501
async def _task():
502-
self.assert_adapter_exists(adapter_name=adapter_name)
502+
self.assert_resource_exists(adapter_name)
503503
extra_kwargs = body.model_extra or {}
504504
ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs)
505505
return {'result': ret}

src/twinkle/server/processor/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def _ensure_sticky(self):
8484
def _on_processor_expired(self, processor_id: str) -> None:
8585
"""Called by the countdown thread when a processor's session expires."""
8686
self.resource_dict.pop(processor_id, None)
87-
self.unregister_processor(processor_id)
87+
self.unregister_resource(processor_id)
8888

8989

9090
def build_processor_app(ncpu_proc_per_node: int,

src/twinkle/server/processor/twinkle_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def create(
5353
processor_id = str(uuid.uuid4().hex)
5454

5555
# Register for lifecycle tracking (enforces per-user limit)
56-
self.register_processor(processor_id, token, session_id)
56+
self.register_resource(processor_id, token, session_id)
5757

5858
_kwargs.pop('remote_group', None)
5959
_kwargs.pop('device_mesh', None)
@@ -91,7 +91,7 @@ async def call(
9191
function_name = body.function
9292
_kwargs = body.model_extra or {}
9393
processor_id = processor_id[4:]
94-
self.assert_processor_exists(processor_id=processor_id)
94+
self.assert_resource_exists(processor_id)
9595
processor = self.resource_dict.get(processor_id)
9696
function = getattr(processor, function_name, None)
9797

0 commit comments

Comments
 (0)