Skip to content

Commit 1c12fff

Browse files
committed
wip: raise GKD_TOPK default 20→64 (with matching teacher `max_logprobs: 64`), simplify GSM8K system prompt to \boxed{} answers, remove megatron vocab-parallel log-softmax fallback and the student loader's `device_mesh` arg
1 parent 45c09a1 commit 1c12fff

File tree

3 files changed

+3
-17
lines changed

3 files changed

+3
-17
lines changed

cookbook/rl/gkd_off_policy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
7171
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
72-
GKD_TOPK = int(os.environ.get('GKD_TOPK', 20))
72+
GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
7373
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048))
7474
N_SAMPLES = int(os.environ.get('N_SAMPLES', 1))
7575
ADAPTER_NAME = 'default'
@@ -188,7 +188,7 @@ def main():
188188
# ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────
189189
teacher_sampler = vLLMSampler(
190190
model_id=TEACHER_MODEL_ID,
191-
engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs'},
191+
engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64},
192192
device_mesh=sampler_mesh,
193193
remote_group='teacher_sampler',
194194
)
@@ -199,7 +199,6 @@ def main():
199199
dataset=create_dataset,
200200
batch_size=BATCH_SIZE,
201201
min_batch_size=BATCH_SIZE,
202-
device_mesh=model_mesh,
203202
remote_group='student_model',
204203
)
205204

src/twinkle/preprocessor/llm.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,7 @@ class GSM8KProcessor(Preprocessor):
122122
Extracts the ground truth number and stores it in user_data for reward.
123123
"""
124124

125-
system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
126-
'Show your reasoning in <think> </think> tags, then give the final '
127-
'numerical answer after ####.\n'
128-
'For example:\n<think> ... reasoning ... </think>\n#### 42')
125+
system_prompt = ('You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.')
129126

130127
def extract_ground_truth(self, answer_str: str) -> str:
131128
"""Extract the number after '####' from GSM8K answer."""

src/twinkle/utils/torch_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,6 @@ def selective_log_softmax(logits, index) -> 'torch.Tensor':
7272
import torch
7373
import torch.nn.functional as F
7474

75-
try:
76-
from megatron.core import parallel_state as mpu
77-
if mpu.get_tensor_model_parallel_world_size() >= 1:
78-
try:
79-
return _vocab_parallel_selective_log_softmax(logits, index)
80-
except Exception: # noqa
81-
import traceback
82-
print(traceback.format_exc())
83-
except Exception: # noqa
84-
pass
8575
if logits.dtype in [torch.float32, torch.float64]:
8676
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
8777
# loop to reduce peak mem consumption

0 commit comments

Comments (0)