Skip to content

Commit c7b235b

Browse files
committed
fix lint
1 parent 4e3c68b commit c7b235b

File tree

12 files changed

+208
-235
lines changed

cookbook/client/tinker/sample.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,25 @@
1414
from twinkle.template import Template
1515

1616
# Step 1: Define the base model and connect to the server
17-
base_model = "Qwen/Qwen2.5-7B-Instruct"
17+
base_model = 'Qwen/Qwen2.5-7B-Instruct'
1818
service_client = init_tinker_compat_client(base_url='http://localhost:8000')
1919

2020
# Step 2: Create a sampling client by loading weights from a saved checkpoint.
2121
# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
2222
# The server will load the base model and apply the LoRA adapter weights.
2323
sampling_client = service_client.create_sampling_client(
24-
model_path="twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2",
24+
model_path='twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2',
2525
base_model=base_model)
2626

2727
# Step 3: Load the tokenizer locally to encode the prompt and decode the results
28-
print(f"Using model {base_model}")
28+
print(f'Using model {base_model}')
2929

3030
template = Template(model_id=f'ms://{base_model}')
3131

3232
trajectory = Trajectory(
3333
messages=[
3434
Message(role='system', content='You are a helpful assistant'),
35-
Message(role='user', content="你是谁?"),
35+
Message(role='user', content='你是谁?'),
3636
]
3737
)
3838

@@ -44,8 +44,8 @@
4444
prompt = types.ModelInput.from_ints(input_ids)
4545
params = types.SamplingParams(
4646
max_tokens=128, # Maximum number of tokens to generate
47-
temperature=0.7,
48-
stop=["\n"] # Stop generation when a newline character is produced
47+
temperature=0.7,
48+
stop=['\n'] # Stop generation when a newline character is produced
4949
)
5050

5151
# Step 5: Send the sampling request to the server.
@@ -57,4 +57,4 @@
5757
# Step 6: Decode and print the generated responses
5858
print('Responses:')
5959
for i, seq in enumerate(result.sequences):
60-
print(f"{i}: {repr(template.decode(seq.tokens))}")
60+
print(f'{i}: {repr(template.decode(seq.tokens))}')

cookbook/client/tinker/self_congnition.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def eval():
8282
# Step 1: Load the trained LoRA checkpoint for inference
8383

8484
# Path to a previously saved LoRA checkpoint (twinkle:// URI)
85-
weight_path = "twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2"
85+
weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2'
8686

8787
# Connect to the server and create a sampling client with the trained weights
8888
service_client = init_tinker_compat_client(base_url='http://localhost:8000')
@@ -96,7 +96,7 @@ def eval():
9696
trajectory = Trajectory(
9797
messages=[
9898
Message(role='system', content='You are a helpful assistant'),
99-
Message(role='user', content="你是谁?"),
99+
Message(role='user', content='你是谁?'),
100100
]
101101
)
102102

@@ -121,9 +121,9 @@ def eval():
121121
# Decode and print each response
122122
print('Responses:')
123123
for i, seq in enumerate(result.sequences):
124-
print(f"{i}: {repr(template.decode(seq.tokens))}")
124+
print(f'{i}: {repr(template.decode(seq.tokens))}')
125125

126126

127-
if __name__ == "__main__":
127+
if __name__ == '__main__':
128128
# train() # Uncomment to run training
129129
eval() # Run evaluation / inference

cookbook/client/tinker/short_math_grpo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ def main():
208208
dataset = create_Math_dataset()
209209
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
210210
template = Template(model_id=f'ms://{BASE_MODEL}')
211-
212-
logger.info("Dataset and template initialized")
211+
212+
logger.info('Dataset and template initialized')
213213

214214
# Step 2: Initialize the Tinker-compatible client
215215
logger.info('Connecting to Tinker server...')

src/twinkle/server/tinker/common/compat_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _to_float(v):
101101
head, unit = s.split() # ignore unit/tail
102102
cleaned[f'{key}/{unit}'] = float(head)
103103
except Exception:
104-
m = re.match(r"^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)", s)
104+
m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s)
105105
if m:
106106
cleaned[key] = float(m.group(1))
107107

src/twinkle/server/tinker/model.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,25 @@ def __init__(self,
120120

121121
def _cleanup_adapter(self, adapter_name: str) -> None:
122122
"""Common adapter cleanup logic used by both manual unload and automatic expiration.
123-
123+
124124
This method handles:
125125
1. Clearing adapter state
126126
2. Removing adapter from model
127127
3. Unregistering from adapter manager
128128
4. Removing from server state
129-
129+
130130
Args:
131131
adapter_name: Name of the adapter to clean up
132132
"""
133133
# Remove from model if it exists
134134
if self.get_adapter_info(adapter_name):
135135
# Clear adapter state
136136
self.clear_adapter_state(adapter_name)
137-
137+
138138
self.model.remove_adapter(adapter_name)
139139
# Unregister from adapter manager
140140
self.unregister_adapter(adapter_name)
141-
141+
142142
# Remove from server state
143143
self.state.unload_model(adapter_name)
144144

@@ -175,16 +175,13 @@ async def _create_adapter():
175175
# TODO: support more lora config parameters, train_unembed, etc.
176176
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
177177

178-
adapter_name = self.get_adapter_name(
179-
adapter_name=model_id)
180-
178+
adapter_name = self.get_adapter_name(adapter_name=model_id)
179+
181180
# Register adapter FIRST (limit check happens inside register_adapter)
182-
self.register_adapter(
183-
adapter_name, request.state.token, session_id=body.session_id)
184-
181+
self.register_adapter(adapter_name, request.state.token, session_id=body.session_id)
182+
185183
# Create adapter AFTER successful registration
186-
self.model.add_adapter_to_model(
187-
adapter_name=adapter_name, config_or_dir=lora_cfg)
184+
self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
188185

189186
self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model)
190187
self.model.set_processor('InputProcessor', adapter_name=adapter_name)
@@ -193,8 +190,7 @@ async def _create_adapter():
193190
# Fresh adapter has no accumulated gradients.
194191
self.set_adapter_state(adapter_name, 'grad_ready', False)
195192

196-
training_run_manager = create_training_run_manager(
197-
request.state.token)
193+
training_run_manager = create_training_run_manager(request.state.token)
198194
training_run_manager.save(model_id, body)
199195

200196
return types.CreateModelResponse(model_id=model_id)
@@ -261,8 +257,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -
261257

262258
async def _do_unload():
263259
# Only remove adapter, not the base model
264-
adapter_name = self.get_adapter_name(
265-
adapter_name=body.model_id)
260+
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
266261
# Use common cleanup logic
267262
self._cleanup_adapter(adapter_name)
268263
return types.UnloadModelResponse(model_id=body.model_id)
@@ -315,9 +310,7 @@ async def _do_forward():
315310

316311
# Calculate input tokens and batch size for validation
317312
datum_list = body.forward_input.data
318-
input_tokens = sum(
319-
len(d.model_input.to_ints()) for d in datum_list
320-
)
313+
input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
321314
batch_size = len(datum_list)
322315
return await self.schedule_task(
323316
_do_forward,
@@ -360,11 +353,12 @@ async def _do_forward_backward():
360353
loss_fn_config = body.forward_backward_input.loss_fn_config or {}
361354

362355
# Unified forward_backward for both Megatron and Transformers
363-
output, loss = self.model.forward_backward(inputs=datum_list,
364-
adapter_name=adapter_name,
365-
loss_fn=loss_fn,
366-
**loss_fn_config)
367-
output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn'
356+
output, loss = self.model.forward_backward(
357+
inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config)
358+
if loss_fn == 'importance_sampling':
359+
output_type = 'ImportanceSamplingLossReturn'
360+
else:
361+
output_type = 'CrossEntropyLossReturn'
368362
# Mark gradients as ready after a successful forward_backward.
369363
self.set_adapter_state(adapter_name, 'grad_ready', True)
370364
return types.ForwardBackwardOutput(
@@ -381,9 +375,7 @@ async def _do_forward_backward():
381375

382376
# Calculate input tokens and batch size for validation
383377
datum_list = body.forward_backward_input.data
384-
input_tokens = sum(
385-
len(d.model_input.to_ints()) for d in datum_list
386-
)
378+
input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
387379
batch_size = len(datum_list)
388380
return await self.schedule_task(
389381
_do_forward_backward,
@@ -417,14 +409,13 @@ async def _do_optim():
417409
# Disallow empty step (must have at least one forward_backward since last step)
418410
if not self.get_adapter_state(adapter_name, 'grad_ready', False):
419411
raise RuntimeError(
420-
f"No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step"
412+
f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501
421413
)
422414

423415
# Touch adapter to reset inactivity counter
424416
self.touch_adapter(adapter_name)
425417

426-
self.model.step(adam_params=body.adam_params,
427-
adapter_name=adapter_name)
418+
self.model.step(adam_params=body.adam_params, adapter_name=adapter_name)
428419
# Clear grad-ready after a successful step.
429420
self.set_adapter_state(adapter_name, 'grad_ready', False)
430421
metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name)
@@ -590,15 +581,15 @@ async def _do_load():
590581
weight_path = body.path
591582
load_optimizer = body.optimizer
592583

593-
self.model.load(checkpoint_dir=weight_path,
594-
load_optimizer=load_optimizer,
595-
adapter_name=adapter_name,
596-
token=token)
584+
self.model.load(
585+
checkpoint_dir=weight_path,
586+
load_optimizer=load_optimizer,
587+
adapter_name=adapter_name,
588+
token=token)
597589

598590
# Loading a checkpoint should reset step readiness.
599591
self.set_adapter_state(adapter_name, 'grad_ready', False)
600-
return types.LoadWeightsResponse(path=body.path,
601-
type='load_weights')
592+
return types.LoadWeightsResponse(path=body.path, type='load_weights')
602593
except Exception:
603594
logger.error(traceback.format_exc())
604595
return types.RequestFailedResponse(

src/twinkle/server/tinker/sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,14 @@ async def _do_sample():
160160
token = request.state.token
161161
checkpoint_manager = create_checkpoint_manager(token)
162162
adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path)
163-
163+
164164
# Validate adapter URI existence if provided
165165
if not adapter_uri or not os.path.exists(adapter_uri):
166166
return types.RequestFailedResponse(
167-
error=f"Adapter URI {model_path} does not exist. Please check the model_path.",
167+
error=f'Adapter URI {model_path} does not exist. Please check the model_path.',
168168
category=types.RequestErrorCategory.User,
169169
)
170-
170+
171171
# Convert tinker SamplingParams to twinkle SamplingParams if needed
172172
sampling_params = None
173173
if body.sampling_params:

src/twinkle/server/tinker/server.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,11 @@
2828

2929
logger = logging.getLogger(__name__)
3030

31-
def build_server_app(
32-
deploy_options: Dict[str, Any],
33-
supported_models: Optional[List[types.SupportedModel]] = None,
34-
server_config: Dict[str, Any] = {},
35-
**kwargs
36-
):
31+
32+
def build_server_app(deploy_options: dict[str, Any],
33+
supported_models: list[types.SupportedModel] | None = None,
34+
server_config: dict[str, Any] = {},
35+
**kwargs):
3736
"""Build and configure the Tinker-compatible server application.
3837
3938
This factory function creates a FastAPI application with Ray Serve deployment
@@ -66,8 +65,11 @@ class TinkerCompatServer:
6665
- Proxying to model/sampler deployments
6766
- Training run and checkpoint CRUD operations
6867
"""
69-
70-
def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, server_config: Dict[str, Any] = {}, **kwargs) -> None:
68+
69+
def __init__(self,
70+
supported_models: list[types.SupportedModel] | None = None,
71+
server_config: dict[str, Any] = {},
72+
**kwargs) -> None:
7173
"""Initialize the Tinker-compatible server.
7274
7375
Args:
@@ -78,13 +80,13 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None
7880
self.state = get_server_state(**server_config)
7981
# Disable proxy for internal requests to avoid routing through external proxies
8082
self.client = httpx.AsyncClient(timeout=None, trust_env=False)
81-
self.route_prefix = kwargs.get("route_prefix", "/api/v1")
83+
self.route_prefix = kwargs.get('route_prefix', '/api/v1')
8284
self.supported_models = self.normalize_models(supported_models) or [
83-
types.SupportedModel(model_name="Qwen/Qwen2.5-0.5B-Instruct"),
84-
types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"),
85-
types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"),
86-
types.SupportedModel(model_name="Qwen/Qwen2.5-72B-Instruct"),
87-
types.SupportedModel(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507"),
85+
types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'),
86+
types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'),
87+
types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'),
88+
types.SupportedModel(model_name='Qwen/Qwen2.5-72B-Instruct'),
89+
types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'),
8890
]
8991
# Lock for ModelScope config file operations (login writes, get_user_info reads)
9092
self._modelscope_config_lock = asyncio.Lock()
@@ -682,7 +684,4 @@ async def save_weights_for_sampler(self, request: Request, body: types.SaveWeigh
682684
return await self._proxy_to_model(request, 'save_weights_for_sampler', base_model)
683685

684686
return TinkerCompatServer.options(**deploy_options).bind(
685-
supported_models=supported_models,
686-
server_config=server_config,
687-
**kwargs
688-
)
687+
supported_models=supported_models, server_config=server_config, **kwargs)

src/twinkle/server/twinkle/model.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,17 @@ def _on_adapter_expired(self, adapter_name: str) -> None:
200200
if self.get_adapter_info(adapter_name):
201201
# Clear adapter state
202202
self.clear_adapter_state(adapter_name)
203-
203+
204204
self.model.remove_adapter(adapter_name)
205205
# Unregister from adapter manager
206206
self.unregister_adapter(adapter_name)
207-
207+
208208
# Remove from server state
209209
self.state.unload_model(adapter_name)
210210
# Remove adapter from model
211211
self.model.remove_adapter(adapter_name)
212212

213-
214-
@app.post("/create")
213+
@app.post('/create')
215214
def create(self, request: Request, body: CreateRequest):
216215
return {'status': 'ok'}
217216

@@ -508,13 +507,13 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
508507
# Extract token for metadata storage
509508
token = request.state.token
510509
training_run_manager = create_training_run_manager(token)
511-
510+
512511
# Register adapter FIRST (limit check happens inside register_adapter)
513512
self.register_adapter(adapter_name, token)
514-
513+
515514
# Create adapter AFTER successful registration
516515
self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
517-
516+
518517
# Save training run metadata (similar to tinker's create_model)
519518
# Create a training run config from the adapter configuration
520519
lora_config = None

0 commit comments

Comments
 (0)