Skip to content

Commit 987d89e

Browse files
committed
update
1 parent 7b5412b commit 987d89e

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

cookbook/client/tinker/megatron/server_config_7b.yaml

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -71,37 +71,37 @@ applications:
7171

7272
# 3. Sampler Service - Runs inference / sampling using vLLM engine
7373
# Used for generating text from the model (e.g., evaluating LoRA results).
74-
# - name: sampler-Qwen2.5-7B-Instruct
75-
# route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
76-
# import_path: sampler
77-
# args:
78-
# model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
79-
# nproc_per_node: 2 # Number of GPU processes per node
80-
# sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
81-
# engine_args: # vLLM engine-specific settings
82-
# max_model_len: 4096 # Maximum sequence length the engine supports
83-
# gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
84-
# enable_lora: true # Allow loading LoRA adapters during inference
85-
# logprobs_mode: processed_logprobs # Logprobs mode for sampling results
86-
# device_group: # Logical device group for the sampler
87-
# name: sampler
88-
# ranks: [2] # GPU rank indices to use
89-
# device_type: cuda
90-
# device_mesh:
91-
# device_type: cuda
92-
# dp_size: 1
93-
# queue_config:
94-
# rps_limit: 100 # Max requests per second
95-
# tps_limit: 100000 # Max tokens per second
96-
# deployments:
97-
# - name: SamplerManagement
98-
# autoscaling_config:
99-
# min_replicas: 1
100-
# max_replicas: 1
101-
# target_ongoing_requests: 16
102-
# ray_actor_options:
103-
# num_cpus: 0.1
104-
# runtime_env:
105-
# env_vars:
106-
# TWINKLE_TRUST_REMOTE_CODE: "0"
107-
# DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
74+
- name: sampler-Qwen2.5-7B-Instruct
75+
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
76+
import_path: sampler
77+
args:
78+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
79+
nproc_per_node: 2 # Number of GPU processes per node
80+
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
81+
engine_args: # vLLM engine-specific settings
82+
max_model_len: 4096 # Maximum sequence length the engine supports
83+
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
84+
enable_lora: true # Allow loading LoRA adapters during inference
85+
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
86+
device_group: # Logical device group for the sampler
87+
name: sampler
88+
ranks: [2] # GPU rank indices to use
89+
device_type: cuda
90+
device_mesh:
91+
device_type: cuda
92+
dp_size: 1
93+
queue_config:
94+
rps_limit: 100 # Max requests per second
95+
tps_limit: 100000 # Max tokens per second
96+
deployments:
97+
- name: SamplerManagement
98+
autoscaling_config:
99+
min_replicas: 1
100+
max_replicas: 1
101+
target_ongoing_requests: 16
102+
ray_actor_options:
103+
num_cpus: 0.1
104+
runtime_env:
105+
env_vars:
106+
TWINKLE_TRUST_REMOTE_CODE: "0"
107+
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

cookbook/client/tinker/sample.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# for text generation (sampling) via the Tinker-compatible client API.
55
# The server must be running first (see server.py and server_config.yaml).
66

7-
from modelscope import AutoTokenizer
87
from tinker import types
98

109
from twinkle.data_format import Message, Trajectory

cookbook/client/tinker/short_math_grpo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@
2121
import numpy as np
2222
import os
2323
import re
24-
from modelscope import AutoTokenizer
2524
from tinker import types
2625
from typing import List, Tuple
2726

27+
from twinkle_client import init_tinker_compat_client
2828
from twinkle import get_logger
2929
from twinkle.advantage import GRPOAdvantage
3030
from twinkle.data_format import Message, Trajectory
3131
from twinkle.dataloader import DataLoader
3232
from twinkle.dataset import Dataset, DatasetMeta
33+
from twinkle.preprocessor import Preprocessor
34+
from twinkle.reward.base import Reward
3335
from twinkle.metric import CompletionRewardMetric
3436
from twinkle.template import Template
3537

@@ -332,6 +334,10 @@ def main():
332334
).tolist()
333335

334336
frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
337+
if frac_zero_std == 1.0:
338+
logger.info(f'Step {step}: All advantages are zero, skipping training')
339+
step += 1
340+
continue
335341

336342
# ========== 6. Train the policies with GRPO loss ==========
337343
# Train the policies with the Advantage-Regularized policy

0 commit comments

Comments
 (0)