Merged
61 commits
ce5a4a2: wip (tastelikefeet, Mar 10, 2026)
89b96b4: wip (tastelikefeet, Mar 10, 2026)
85c5afb: wip (tastelikefeet, Mar 10, 2026)
75d0377: fix (tastelikefeet, Mar 10, 2026)
79c22fb: fix (tastelikefeet, Mar 10, 2026)
b50f565: Merge commit '85e4f7df0a5bf3346868bca77080d8be80aa27fe' into feat/gkd (tastelikefeet, Mar 11, 2026)
1082035: wip (tastelikefeet, Mar 11, 2026)
e9c6590: Merge branch 'feat/gkd' of https://github.com/tastelikefeet/twinkle i… (tastelikefeet, Mar 11, 2026)
eb6b5be: fix (tastelikefeet, Mar 11, 2026)
e382a26: Merge commit 'e9c6590efa11d89b705f91bed9f9f4cdafa637a6' into feat/gkd (tastelikefeet, Mar 11, 2026)
09c3c0f: wip (tastelikefeet, Mar 11, 2026)
1c9be4c: Merge commit 'd69a864530a25909a743ba51e74e32bffb624132' into feat/gkd (tastelikefeet, Mar 13, 2026)
e7677f5: fix (tastelikefeet, Mar 13, 2026)
a4ff6c5: fix (tastelikefeet, Mar 13, 2026)
7c726f7: fix (tastelikefeet, Mar 13, 2026)
4e6ac60: fix (tastelikefeet, Mar 13, 2026)
5ced908: fix (tastelikefeet, Mar 14, 2026)
1524dbf: wip (tastelikefeet, Mar 14, 2026)
30df960: fix (tastelikefeet, Mar 14, 2026)
43be0f8: wip (tastelikefeet, Mar 14, 2026)
0449340: no message (tastelikefeet, Mar 15, 2026)
4296d62: wip (tastelikefeet, Mar 15, 2026)
39f9449: fix (tastelikefeet, Mar 15, 2026)
a903cb9: wip (tastelikefeet, Mar 15, 2026)
926210c: wip (tastelikefeet, Mar 15, 2026)
c49fccd: wip (tastelikefeet, Mar 16, 2026)
1e7240f: fix (tastelikefeet, Mar 16, 2026)
29dc7ac: wip (tastelikefeet, Mar 16, 2026)
36a0eb2: wip (tastelikefeet, Mar 17, 2026)
488ea43: wip (tastelikefeet, Mar 17, 2026)
e4b931a: wip (tastelikefeet, Mar 17, 2026)
17329d3: wip (tastelikefeet, Mar 17, 2026)
1ebac31: Merge commit 'cb52a6c6108c8227034648dff917b32d5cab84c5' into feat/gkd (tastelikefeet, Mar 17, 2026)
fcb163b: wip (tastelikefeet, Mar 17, 2026)
b6332d9: wip (tastelikefeet, Mar 17, 2026)
37d38c1: lint code (tastelikefeet, Mar 17, 2026)
a01c524: wip (tastelikefeet, Mar 17, 2026)
45c09a1: wip (tastelikefeet, Mar 17, 2026)
1c12fff: wip (tastelikefeet, Mar 17, 2026)
f2a1fc7: fix (tastelikefeet, Mar 17, 2026)
519dba9: wip (tastelikefeet, Mar 17, 2026)
2576f18: fix (tastelikefeet, Mar 18, 2026)
e23ee41: fix (tastelikefeet, Mar 18, 2026)
2370699: Revert "fix" (tastelikefeet, Mar 18, 2026)
7e83bdc: fix (tastelikefeet, Mar 18, 2026)
ff37789: fix (tastelikefeet, Mar 18, 2026)
a6205ce: fix (tastelikefeet, Mar 18, 2026)
781514e: wip (tastelikefeet, Mar 18, 2026)
7653caf: lint code (tastelikefeet, Mar 18, 2026)
5c36715: fix (tastelikefeet, Mar 18, 2026)
5dc401d: fix docs (tastelikefeet, Mar 18, 2026)
2a0bfe0: fix (tastelikefeet, Mar 18, 2026)
a46bd59: fix (tastelikefeet, Mar 18, 2026)
758ab1b: fix docs (tastelikefeet, Mar 18, 2026)
aeb0ec1: fix (tastelikefeet, Mar 18, 2026)
c26a708: fix (tastelikefeet, Mar 18, 2026)
e952bde: fix (tastelikefeet, Mar 18, 2026)
4facc4f: fix (tastelikefeet, Mar 18, 2026)
29db904: fix (tastelikefeet, Mar 19, 2026)
05be92f: lint (tastelikefeet, Mar 19, 2026)
9fdd627: fix (tastelikefeet, Mar 19, 2026)
3 changes: 1 addition & 2 deletions README.md
@@ -101,9 +101,8 @@ Or use ModelScope's [official image](https://www.modelscope.cn/docs/intro/enviro

## Changelog

- 🎉2026-03-19 Support GKD training. Please refer to this [cookbook](cookbook/rl/gkd_on_policy.py).
- 🎉2026-02-13 Initial version of Twinkle✨ released, including SFT/PT/RL support for text models.
We also made available serverless training capabilities on [ModelScope](https://modelscope.cn) via
Tinker-compatible APIs.

## Training as a Service on ModelScope

1 change: 1 addition & 0 deletions README_ZH.md
@@ -91,6 +91,7 @@ Twinkle✨ supports the same algorithm interfaces on a single GPU, torchrun multi-node, Ray, Cl

## Changelog

🎉2026-03-19 Support for GKD distillation; see the [cookbook](cookbook/rl/gkd_on_policy.py).
🎉2026-02-13 Initial release of Twinkle✨, supporting SFT/PT/RL training for text models. We also provide serverless training on the ModelScope community via Tinker-compatible APIs.

## Training Services on ModelScope
4 changes: 2 additions & 2 deletions client_tools/client_generator.py
@@ -768,7 +768,7 @@ def sample(
adapter_name: str = '',
adapter_uri: Optional[str] = None,
num_samples: int = 1,
) -> SampleResponseModel:
) -> List[SampleResponseModel]:
"""Sample from the model.

Args:
@@ -795,7 +795,7 @@ def sample(
json_data=json_data
)
response.raise_for_status()
return SampleResponseModel(**response.json())
return [SampleResponseModel(**r) for r in response.json()['samples']]

def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse:
"""Set the template for encoding trajectories."""
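The `sample()` signature change in client_generator.py means each call now returns a list of `SampleResponseModel` objects instead of a single one. A minimal sketch of updating a call site, using hypothetical stand-in classes that carry only the fields this PR's diffs touch (the real models live in the twinkle client library):

```python
from dataclasses import dataclass
from typing import List

# Hypothetical stand-ins for illustration only; field names mirror
# what the diffs in this PR access (resp.sequences, seq.tokens).
@dataclass
class Sequence:
    tokens: List[int]

@dataclass
class SampleResponseModel:
    sequences: List[Sequence]

def flatten_sequences(responses: List[SampleResponseModel]) -> List[Sequence]:
    """Callers that previously used one response now iterate the list."""
    return [seq for resp in responses for seq in resp.sequences]

responses = [
    SampleResponseModel(sequences=[Sequence([1, 2]), Sequence([3])]),
    SampleResponseModel(sequences=[Sequence([4, 5, 6])]),
]
flat = flatten_sequences(responses)
assert len(flat) == 3  # three sequences total across the two responses
```

This matches the pattern the updated grpo.py and sample.py cookbooks adopt: an outer loop over responses, an inner loop over each response's sequences.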
6 changes: 6 additions & 0 deletions cookbook/client/server/megatron/run.sh
@@ -0,0 +1,6 @@
export RAY_ROTATION_MAX_BYTES=1024
export RAY_ROTATION_BACKUP_COUNT=1
CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
python server.py
6 changes: 6 additions & 0 deletions cookbook/client/server/transformer/run.sh
@@ -0,0 +1,6 @@
export RAY_ROTATION_MAX_BYTES=1024
export RAY_ROTATION_BACKUP_COUNT=1
CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
python server.py
15 changes: 8 additions & 7 deletions cookbook/client/twinkle/self_host/grpo.py
@@ -38,7 +38,6 @@
from twinkle_client.dataset import Dataset
from twinkle_client.model import MultiLoraTransformersModel
from twinkle_client.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

@@ -127,6 +126,8 @@ def train():
'max_tokens': MAX_NEW_TOKENS,
'temperature': TEMPERATURE,
'top_p': 0.95,
'num_samples': NUM_GENERATIONS,
'logprobs': 1,
}

# Track the current adapter path for sampling
@@ -153,21 +154,21 @@
logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')

# ========== 2. Sample completions ==========
sample_response = sampler.sample(
sample_responses = sampler.sample(
inputs=prompts,
sampling_params=sampling_params,
adapter_uri=current_adapter_uri,
num_samples=NUM_GENERATIONS,
)

all_input_data: List[Dict[str, Any]] = []
all_old_logps: List[List[float]] = []
all_completion_lengths: List[int] = []

for sequence in sample_response.sequences:
all_input_data.append(sequence.new_input_feature)
all_old_logps.append(sequence.logprobs)
all_completion_lengths.append(len(sequence.tokens))
for sample_response in sample_responses:
for sequence in sample_response.sequences:
all_input_data.append(sequence.new_input_feature)
all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
all_completion_lengths.append(len(sequence.tokens))

# ========== 3. Compute rewards ==========

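With `'logprobs': 1` added to the sampling params, each position in `sequence.logprobs` appears (judging from this diff) to be a list of `(token_id, logprob)` pairs rather than a bare float, so `logprob[0][1]` selects the logprob of the top entry. A small self-contained illustration with made-up values:

```python
# Made-up per-position logprobs in the shape the grpo.py diff assumes
# when sampling with 'logprobs': 1 — each position holds a list of
# (token_id, logprob) pairs, top entry first.
sequence_logprobs = [
    [(101, -0.12)],
    [(7, -1.05)],
    [(42, -0.33)],
]

# The extraction used in the GRPO cookbook: take the top entry's logprob.
old_logps = [logprob[0][1] for logprob in sequence_logprobs]
assert old_logps == [-0.12, -1.05, -0.33]
```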
15 changes: 8 additions & 7 deletions cookbook/client/twinkle/self_host/sample.py
@@ -66,29 +66,30 @@ def sample():
sampling_params = {
'max_tokens': 128,
'temperature': 1.0,
'num_samples': num_samples,
}

# Step 7: Call the sampler
# - inputs: list of Trajectory dicts (will be encoded server-side using the template)
# - sampling_params: controls generation behavior
# - adapter_uri: optional LoRA adapter path for fine-tuned inference
# - num_samples: number of completions per prompt
response = sampler.sample(
responses = sampler.sample(
inputs=[trajectory] * num_prompts,
sampling_params=sampling_params,
adapter_uri=ADAPTER_URI,
num_samples=num_samples,
)

# Step 8: Decode and print the results
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

logger.info(f'Generated {len(response.sequences)} sequences '
f'({num_prompts} prompts x {num_samples} samples)')
for response in responses:
logger.info(f'Generated {len(response.sequences)} sequences '
f'({num_prompts} prompts x {num_samples} samples)')

for i, seq in enumerate(response.sequences):
text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
logger.info(f'Sequence {i}:\n {text}\n')
for i, seq in enumerate(response.sequences):
text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
logger.info(f'Sequence {i}:\n {text}\n')


if __name__ == '__main__':
226 changes: 226 additions & 0 deletions cookbook/rl/gkd_off_policy.py
@@ -0,0 +1,226 @@
"""GKD Off-Policy Distillation via Ray.

Off-policy knowledge distillation: the student learns to match the teacher's
token distribution on pre-existing reference responses from the dataset.

Pipeline:
1. DataLoader supplies full-text batches (prompt + reference answer).
2. Teacher vLLM sampler computes top-k prompt logprobs on the sequences.
3. Student MegatronModel runs forward_backward() with GKDLoss.

Key difference from on-policy:
- No student sampler needed (responses already in the dataset).
- No weight sync needed (student doesn't sample).
- Faster per-step (no generation latency), but less exploration.

Architecture (Ray):
┌─────────────────────────────────────────────────────────────────┐
│ Driver (CPU) │
│ dataloader ──► full-text batch (prompt + reference answer) │
│ teacher_sampler.sample(prompt_logprobs=k) ──► teacher lps │
│ student_model.forward_backward(teacher_output=...) ──► GKD │
└─────────────────────────────────────────────────────────────────┘
vLLMSampler + MegatronModel
(teacher) (student)

Environment variables (all optional):
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B)
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B)
MODEL_GPUS – GPUs for student model (default: 8)
SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 8)
BATCH_SIZE – global batch size (default: 16)
MAX_STEPS – total optimisation steps (default: 1000)
LR – learning rate (default: 5e-5)
GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5)
GKD_TEMPERATURE – distillation temperature (default: 1.0)
GKD_TOPK – top-k vocab for teacher logprobs (default: 64)
"""

import os
from typing import List, Optional

import torch
from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import GKDLoss
from twinkle.model import MegatronModel
from twinkle.preprocessor import GSM8KProcessor
from twinkle.sampler import vLLMSampler
from twinkle.template import Template

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))

GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
ADAPTER_NAME = 'default'
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put '
'your final answer within #### <number>')


# ── Dataset ───────────────────────────────────────────────────────────────────

def create_dataset():
"""Full-text dataset with prompt + reference answer for off-policy distillation."""
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT, add_assistant=True))
return dataset


# ── Utility ───────────────────────────────────────────────────────────────────

def convert_topk_prompt_logprobs(
topk_prompt_logprobs_batch: List[Optional[List[List[tuple]]]],
) -> dict:
"""Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format.

Args:
topk_prompt_logprobs_batch: [batch] each is topk_prompt_logprobs for one request.
Shape: [seq_len, topk] per request, where each position is List[(token_id, logprob)].

Returns:
Dict with teacher logprobs/indices tensors.
"""
batch_logprobs = []
batch_indices = []

for seq_topk in topk_prompt_logprobs_batch:
seq_logprobs = []
seq_indices = []
if seq_topk is not None:
for pos_topk in seq_topk:
if pos_topk is None:
# First position is None, fill with placeholder
seq_logprobs.append([0.0] * GKD_TOPK)
seq_indices.append([0] * GKD_TOPK)
else:
seq_logprobs.append([lp for _, lp in pos_topk])
seq_indices.append([tid for tid, _ in pos_topk])
batch_logprobs.append(seq_logprobs)
batch_indices.append(seq_indices)

# Pad to same seq_len within batch
max_len = max(len(seq) for seq in batch_logprobs) if batch_logprobs else 1

for i in range(len(batch_logprobs)):
pad_len = max_len - len(batch_logprobs[i])
if pad_len > 0:
batch_logprobs[i].extend([[0.0] * GKD_TOPK] * pad_len)
batch_indices[i].extend([[0] * GKD_TOPK] * pad_len)

# Roll to align with labels (first position has no valid logprobs)
return {
'teacher_topk_logprobs': torch.roll(torch.tensor(batch_logprobs, dtype=torch.float32), shifts=-1, dims=1),
'teacher_topk_indices': torch.roll(torch.tensor(batch_indices, dtype=torch.long), shifts=-1, dims=1),
}


# ── Training ──────────────────────────────────────────────────────────────────

def main():
device_groups = [
DeviceGroup(name='student_model', ranks=MODEL_GPUS, device_type='cuda'),
DeviceGroup(name='teacher_sampler', ranks=SAMPLER_GPUS, device_type='cuda'),
]
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)

twinkle.initialize(
mode='ray',
nproc_per_node=NUM_GPUS,
groups=device_groups,
)
logger.info(get_device_placement())

# ── Student model (trainable) ──────────────────────────────────────────────
student_model = MegatronModel(
model_id=STUDENT_MODEL_ID,
device_mesh=model_mesh,
remote_group='student_model',
)
student_model.add_adapter_to_model(
ADAPTER_NAME,
LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
)
student_model.set_optimizer('default', lr=LEARNING_RATE)
student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

# ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────
teacher_sampler = vLLMSampler(
model_id=TEACHER_MODEL_ID,
engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 10240, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64},
device_mesh=sampler_mesh,
remote_group='teacher_sampler',
)
teacher_sampler.set_template('Template', model_id=TEACHER_MODEL_ID)

# ── DataLoader (full-text: prompt + reference answer) ──────────────────────
dataloader = DataLoader(
dataset=create_dataset,
batch_size=BATCH_SIZE,
min_batch_size=BATCH_SIZE,
remote_group='student_model',
)

logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

optim_step = 0
for batch in dataloader:
if optim_step >= MAX_STEPS:
break

# Teacher vLLM computes top-k prompt logprobs on the reference sequences
# max_tokens=0: don't generate new content, just compute logprobs on input
teacher_response = teacher_sampler.sample(
batch,
SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=GKD_TOPK, num_samples=1),
)

input_data = [seq.new_input_feature for response in teacher_response for seq in response.sequences]

# Convert teacher logprobs to tensor format for GKDLoss
teacher_output = convert_topk_prompt_logprobs(
[resp.topk_prompt_logprobs for resp in teacher_response],
)

# Student forward + GKD backward
student_model.forward_backward(inputs=input_data, **teacher_output)
student_model.clip_grad_and_step()

if optim_step > 0 and optim_step % 10 == 0:
metric = student_model.calculate_metric(is_training=True)
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

if optim_step > 0 and optim_step % 50 == 0:
student_model.save(f'gkd-offpolicy-ckpt-{optim_step}')

optim_step += 1

student_model.save('gkd-offpolicy-final')
logger.info('GKD off-policy training completed.')


if __name__ == '__main__':
main()
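The extract-pad-shift logic inside `convert_topk_prompt_logprobs` can be exercised without torch. This pure-Python sketch (a small `TOPK` and made-up logprob values, both assumptions for illustration) mirrors the helper's behaviour, including the `torch.roll(shifts=-1, dims=1)` alignment of teacher logprobs with next-token labels:

```python
TOPK = 2  # tiny top-k for illustration; the script defaults to 64

def convert(batch):
    """Mirror convert_topk_prompt_logprobs without torch: extract the
    (token_id, logprob) pairs per position, pad sequences to equal
    length, then shift every row left by one so position i holds the
    teacher's distribution over the token at position i+1."""
    logprobs, indices = [], []
    for seq in batch:
        lp, idx = [], []
        for pos in seq:
            if pos is None:  # the first prompt position carries no logprobs
                lp.append([0.0] * TOPK)
                idx.append([0] * TOPK)
            else:
                lp.append([l for _, l in pos])
                idx.append([t for t, _ in pos])
        logprobs.append(lp)
        indices.append(idx)
    max_len = max(len(s) for s in logprobs)
    for lp, idx in zip(logprobs, indices):
        while len(lp) < max_len:  # pad shorter sequences with placeholders
            lp.append([0.0] * TOPK)
            idx.append([0] * TOPK)
    # Equivalent of torch.roll(shifts=-1, dims=1) on each batch row.
    logprobs = [s[1:] + s[:1] for s in logprobs]
    indices = [s[1:] + s[:1] for s in indices]
    return logprobs, indices

batch = [[None, [(5, -0.1), (9, -2.0)], [(3, -0.4), (1, -1.2)]]]
lp, idx = convert(batch)
assert idx[0][0] == [5, 9]      # position 0 now holds position 1's top-k ids
assert lp[0][-1] == [0.0, 0.0]  # the None placeholder wrapped to the end
```

The roll means the placeholder from the first prompt position ends up at the final row, where GKDLoss is presumably masked anyway since there is no next-token label there.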