Skip to content

Commit 1315d44

Browse files
committed
fix
1 parent fe6eb14 commit 1315d44

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

cookbook/client/tinker/self_host/sample.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tinker import types
99

1010
from twinkle.data_format import Message, Trajectory
11-
from twinkle.template import Template
11+
from twinkle.template import Template, Qwen3_5Template
1212
from twinkle import init_tinker_client
1313

1414
# Step 1: Initialize Tinker client
@@ -27,14 +27,14 @@
2727
# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
2828
# The server will load the base model and apply the LoRA adapter weights.
2929
sampling_client = service_client.create_sampling_client(
30-
model_path='twinkle://xxx-Qwen_Qwen3.5-4B-xxx/weights/twinkle-lora-1',
30+
# model_path='twinkle://xxx-Qwen_Qwen3.5-4B-xxx/weights/twinkle-lora-1',
3131
base_model=base_model
3232
)
3333

3434
# Step 4: Load the tokenizer locally to encode the prompt and decode the results
3535
print(f'Using model {base_model}')
3636

37-
template = Template(model_id=f'ms://{base_model}')
37+
template = Qwen3_5Template(model_id=f'ms://{base_model}')
3838

3939
trajectory = Trajectory(
4040
messages=[
@@ -43,7 +43,7 @@
4343
]
4444
)
4545

46-
input_feature = template.encode(trajectory, add_generation_prompt=True)
46+
input_feature = template.batch_encode([trajectory], add_generation_prompt=True)[0]
4747

4848
input_ids = input_feature['input_ids'].tolist()
4949

cookbook/client/tinker/self_host/short_math_grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
# ========== Configuration ==========
4141
BASE_MODEL = 'Qwen/Qwen3.5-4B'
42-
NUM_GENERATIONS = 8
42+
NUM_GENERATIONS = 4
4343
MAX_NEW_TOKENS = 4096
4444
LEARNING_RATE = 1e-5
4545
MAX_STEPS = 1000

src/twinkle/model/megatron/megatron.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,11 +1169,14 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
11691169
# Save the HF config (and the PEFT adapter config, if the model is a PeftModel) on data-parallel rank 0 only
11701170
if dp_rank == 0:
11711171
self.hf_config.save_pretrained(output_dir)
1172+
if isinstance(model[0], PeftModel):
1173+
model[0].peft_config[adapter_name].save_pretrained(output_dir)
11721174

11731175
def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
11741176
"""Save in Megatron checkpoint format."""
11751177
os.makedirs(output_dir, exist_ok=True)
1176-
1178+
from megatron.core import parallel_state as mpu
1179+
dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0
11771180
state_dict = self._get_trainable_parameters(adapter_name)
11781181
cpu_state_dict = {}
11791182
for k, v in state_dict.items():
@@ -1189,6 +1192,12 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert
11891192
rank = dist.get_rank() if dist.is_initialized() else 0
11901193
checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
11911194
torch.save(cpu_state_dict, checkpoint_path)
1195+
# Save the HF config (and the PEFT adapter config, if the model is a PeftModel) on data-parallel rank 0 only
1196+
model = self.strategy.unwrap_model(self.model)
1197+
if dp_rank == 0:
1198+
self.hf_config.save_pretrained(output_dir)
1199+
if isinstance(model[0], PeftModel):
1200+
model[0].peft_config[adapter_name].save_pretrained(output_dir)
11921201

11931202
def _save_tokenizer(self, output_dir: str, **kwargs):
11941203
from twinkle.utils import is_last_rank

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,15 @@ async def _sample_single(
235235
"""
236236
multi_modal_data = self._extract_multi_modal_data(feat)
237237
response = await self.engine.sample(
238-
prompt=feat['prompt'],
238+
prompt=feat['prompt'] if 'prompt' in feat else feat['input_ids'],
239239
sampling_params=sampling_params,
240240
lora_request=lora_request,
241241
multi_modal_data=multi_modal_data,
242242
mm_processor_kwargs=feat.get('mm_processor_kwargs'),
243243
)
244-
feat['input_ids'] = response.prompt_token_ids
245-
feat['labels'] = [-100] * len(response.prompt_token_ids)
244+
if 'input_ids' not in feat:
245+
feat['input_ids'] = response.prompt_token_ids
246+
feat['labels'] = [-100] * len(response.prompt_token_ids)
246247
if not logprobs_only:
247248
# response.sequences contains num_samples sequences for this prompt
248249
sequences = []
@@ -318,7 +319,7 @@ def sample(
318319
inputs_list = self._normalize_inputs(inputs)
319320

320321
# Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic
321-
is_trajectory = 'prompt' not in inputs_list[0] or 'input_ids' not in inputs_list[0]
322+
is_trajectory = 'prompt' not in inputs_list[0] and 'input_ids' not in inputs_list[0]
322323
logprobs_only = False
323324
if sampling_params.max_tokens == 0:
324325
sampling_params.max_tokens = 1

0 commit comments

Comments
 (0)