Skip to content

Commit 7fd432f

Browse files
fix vllmsampler client (#122)
1 parent bea8a11 commit 7fd432f

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

client_tools/client_generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,14 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> S
805805
)
806806
response.raise_for_status()
807807
return SetTemplateResponse(**response.json())
808+
809+
def apply_patch(self, patch_cls: str, **kwargs) -> None:
810+
"""Apply a patch to the model."""
811+
response = http_post(
812+
url=f'{self.server_url}/apply_patch',
813+
json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs}
814+
)
815+
response.raise_for_status()
808816
'''
809817

810818
# Write the sampler client file

src/twinkle/server/sampler/twinkle_handlers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
import traceback
1010
from fastapi import Depends, FastAPI, HTTPException, Request
11-
from typing import TYPE_CHECKING, Callable, List, Optional
11+
from typing import TYPE_CHECKING, Callable
12+
13+
from twinkle.server.common.serialize import deserialize_object
1214

1315
if TYPE_CHECKING:
1416
from .app import SamplerManagement
@@ -162,3 +164,13 @@ def add_adapter_to_sampler(
162164
self.sampler.add_adapter_to_sampler(full_adapter_name, config)
163165

164166
return types.AddAdapterResponse(adapter_name=full_adapter_name)
167+
168+
# POST /twinkle/apply_patch: deserialize the requested patch class and apply it
# to this deployment's sampler. Returns None, so no response payload is built here.
@app.post('/twinkle/apply_patch')
169+
async def apply_patch(
170+
request: Request,
171+
body: types.ApplyPatchRequest,
172+
self: SamplerManagement = Depends(self_fn),
173+
) -> None:
# `request` is not used in the visible body.
# `body.model_extra` is presumably the set of request fields not declared on
# ApplyPatchRequest (pydantic "extra" fields) -- TODO confirm the model allows
# extras. Fall back to an empty dict when it is None/empty so `**` is safe.
# NOTE(review): the client added in this commit also sends `adapter_name` in
# the payload; unless ApplyPatchRequest declares that field it would land in
# these extras and be forwarded to `sampler.apply_patch` -- confirm intended.
174+
extra_kwargs = body.model_extra or {}
# NOTE(review): `body.patch_cls` is request-supplied and is passed to
# `deserialize_object`; confirm that helper restricts what it can construct.
175+
patch_cls = deserialize_object(body.patch_cls)
# Forward the deserialized patch plus any extra fields as keyword arguments.
176+
self.sampler.apply_patch(patch_cls, **extra_kwargs)

src/twinkle_client/sampler/vllm_sampler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,11 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> S
9494
)
9595
response.raise_for_status()
9696
return SetTemplateResponse(**response.json())
97+
98+
def apply_patch(self, patch_cls: str, **kwargs) -> None:
99+
"""Apply a patch to the model.

POSTs to ``{self.server_url}/apply_patch`` with a JSON payload holding
``patch_cls``, this client's ``adapter_name``, and any extra ``kwargs``.

Args:
    patch_cls: Patch class identifier, sent to the server as-is.
    **kwargs: Extra fields merged into the payload. They are merged last,
        so an ``adapter_name`` keyword replaces ``self.adapter_name``.

Returns:
    None. The response body is not read; only ``raise_for_status()`` is
    called on the response (presumably raising on an HTTP error status --
    confirm against ``http_post``'s return type).
"""
100+
response = http_post(
101+
url=f'{self.server_url}/apply_patch',
102+
json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs}
103+
)
104+
response.raise_for_status()

0 commit comments

Comments
 (0)