Commit 994d5d8

Merge branch 'dev' into add_rl_example

2 parents fe82fda + ef7ec14
139 files changed: +5027 −271 lines

.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ These instructions help AI agents work productively in this repo. Focus on concr
 ## Examples
 - **Visualize a custom mesh:** create `DeviceMesh` and call `get_device_placement()`; example in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
 - **Add LoRA adapter via HTTP:** POST to `/add_adapter_to_model` with serialized `LoraConfig`; see server routes in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
-- **Sample with vLLM:** Configure `VLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
+- **Sample with vLLM:** Configure `vLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).

 ---
 Questions or gaps? Tell us where guidance is unclear (e.g., missing run scripts, Ray cluster setup), and we’ll refine this document.

README.md

Lines changed: 227 additions & 5 deletions
Large diffs are not rendered by default.

README_ZH.md

Lines changed: 50 additions & 7 deletions

@@ -21,7 +21,7 @@
 </p>

 <p align="center">
-<a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Chinese Documentation</a> &nbsp
+<a href="https://twinkle-kit.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://twinkle-kit.readthedocs.io/zh-cn/latest/">Chinese Documentation</a> &nbsp
 </p>

 <div align="center">

@@ -191,14 +191,57 @@ twinkle's architecture consists of a client and a server; the client side contains two

 This lets developers use the Tinker API directly to call a backend training service deployed by twinkle.

+## Multi-Tenant Support
+
+Twinkle supports multiple tenants training on a single base model at the same time. This currently applies only to [LoRA](https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py#L323).
+Twinkle adopts a LoRA-pool plus tenant-request design. This scheme supports up to N tenants training in parallel without interfering with each other; from the model's perspective, each tenant's training pipeline may differ, and the data padding strategy, optimizer, and loss type used with the base model can differ as well.
+
+<img src="assets/multi_lora.png" style="max-width: 500px; width: 100%;" />
+
+For example:
+
+- Tenant A: loads a private local dataset, LoRA rank=8, runs SFT on the base model
+- Tenant B: loads an open-source dataset from a remote hub, LoRA rank=32, runs PT on the base model
+- Tenant C: computes the GRPO loss on the base model, sampling via the Sampler
+- Tenant D: runs logps inference on the base model
+
+All of this can happen on one base model simultaneously, because models and Samplers are themselves twinkle components and can be task-agnostic. After training, checkpoints can be pushed to HuggingFace/ModelScope model repositories, private by default. twinkle provides a complete multi-tenant training solution; the server side supports cluster management and dynamic scaling, and with light customization it can serve as an enterprise-grade service.
+
+> As a modular framework, twinkle itself also supports temporary exclusive remote training, i.e. full-parameter training.
+
 ## Supported Components

-| | | | | |
-| :---: | :---: | :---: | :---: | :---: |
-| **Dataset**`<br><sub>`data loading and preprocessing`</sub>` | **Template**`<br><sub>`encoding and decoding`</sub>` | **DataLoader**`<br><sub>`data distribution and batching`</sub>` | **Preprocessor**`<br><sub>`data ETL`</sub>` | **InputProcessor**`<br><sub>`task-specific input handling`</sub>` |
-| **Model**`<br><sub>`large models; multiple frameworks supported`</sub>` | **Sampler**`<br><sub>`sampler`</sub>` | **Loss**`<br><sub>`residuals`</sub>` | **Metric**`<br><sub>`training metric collection`</sub>` | **Reward**`<br><sub>`reward functions`</sub>` |
-| **Advantage**`<br><sub>`advantage functions`</sub>` | **CheckpointEngine**`<br><sub>`weight synchronization`</sub>` | **Patch**`<br><sub>`patches for model fixes`</sub>` | **Module**`<br><sub>`components such as Optimizer`</sub>` | **Kernel**`<br><sub>`operators`</sub>` |
-| **Server**`<br><sub>`launches the backend cluster`</sub>` | **Client**`<br><sub>`client-side code`</sub>` | **Infra**`<br><sub>`abstracts over ray and torchrun differences`</sub>` | **Plugin**`<br><sub>`uses hub-side components`</sub>` | **Hub**`<br><sub>`integrates with the HF/MS network libraries`</sub>` |
+<table>
+<tr>
+<td align="center"><b>Dataset</b><br><sub>data loading and preprocessing</sub></td>
+<td align="center"><b>Template</b><br><sub>encoding and decoding</sub></td>
+<td align="center"><b>DataLoader</b><br><sub>data distribution and batching</sub></td>
+<td align="center"><b>Preprocessor</b><br><sub>data ETL</sub></td>
+<td align="center"><b>InputProcessor</b><br><sub>task-specific input handling</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Model</b><br><sub>large models; multiple frameworks supported</sub></td>
+<td align="center"><b>Sampler</b><br><sub>sampler</sub></td>
+<td align="center"><b>Loss</b><br><sub>residuals</sub></td>
+<td align="center"><b>Metric</b><br><sub>training metric collection</sub></td>
+<td align="center"><b>Reward</b><br><sub>reward functions</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Advantage</b><br><sub>advantage functions</sub></td>
+<td align="center"><b>CheckpointEngine</b><br><sub>weight synchronization</sub></td>
+<td align="center"><b>Patch</b><br><sub>patches for model fixes</sub></td>
+<td align="center"><b>Module</b><br><sub>components such as Optimizer</sub></td>
+<td align="center"><b>Kernel</b><br><sub>operators</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Server</b><br><sub>launches the backend cluster</sub></td>
+<td align="center"><b>Client</b><br><sub>client-side code</sub></td>
+<td align="center"><b>Infra</b><br><sub>abstracts over ray and torchrun differences</sub></td>
+<td align="center"><b>Plugin</b><br><sub>uses hub-side components</sub></td>
+<td align="center"><b>Hub</b><br><sub>integrates with the HF/MS network libraries</sub></td>
+</tr>
+</table>

 ## Community Components

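The LoRA-pool plus tenant-request scheme described in the README diff above can be sketched as a small allocator: tenants request an adapter slot from a fixed-size pool, attach their own settings (rank, task, optimizer, loss), and release the slot when done. This is a hypothetical illustration of the idea, not twinkle's actual implementation; the `LoraPool` and `LoraSlot` names are invented.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class LoraSlot:
    """One adapter slot in the pool; holds per-tenant training settings."""
    slot_id: int
    tenant: Optional[str] = None
    config: dict = field(default_factory=dict)

class LoraPool:
    """Fixed-size pool of LoRA adapter slots shared over one base model."""
    def __init__(self, max_tenants: int):
        self.slots = [LoraSlot(i) for i in range(max_tenants)]

    def acquire(self, tenant: str, **config) -> LoraSlot:
        # Hand the first free slot to the requesting tenant.
        for slot in self.slots:
            if slot.tenant is None:
                slot.tenant, slot.config = tenant, config
                return slot
        raise RuntimeError("LoRA pool exhausted: all slots are in use")

    def release(self, slot: LoraSlot) -> None:
        # Free the slot so another tenant can claim it.
        slot.tenant, slot.config = None, {}

pool = LoraPool(max_tenants=4)
a = pool.acquire("tenant_a", rank=8, task="sft")   # private data, SFT
b = pool.acquire("tenant_b", rank=32, task="pt")   # hub data, PT
print(a.slot_id, a.config["rank"])  # 0 8
pool.release(a)
```

Because each slot carries its own config, tenants with different padding, optimizer, or loss choices never touch each other's state; only the frozen base weights are shared.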
ROADMAP.md

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,7 @@
 - [ ] Support distillation algorithms such as GKD and on-policy-distill
 - [ ] Support DPO alignment training
 - [ ] Support colocate RL training
+- [ ] Support batched preprocessing

 ### Networking Capabilities

@@ -82,6 +83,7 @@
 - [ ] Support for distillation algorithms such as GKD and on-policy distillation
 - [ ] Support for DPO alignment training
 - [ ] Support for colocate RL training
+- [ ] Support for batched preprocessing

 ### Networking Capabilities

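The new roadmap item, batched preprocessing, usually means letting a preprocessor map over a list of samples in one call rather than one sample at a time, which amortizes per-call overhead. A minimal sketch of the two paths (hypothetical helpers, not twinkle's `Preprocessor` API):

```python
def preprocess_one(sample: dict) -> dict:
    # Per-sample path: normalize a single record.
    return {"text": sample["text"].strip().lower()}

def preprocess_batched(samples: list) -> list:
    # Batched path: one call handles many records at once.
    return [{"text": s["text"].strip().lower()} for s in samples]

batch = [{"text": "  Hello "}, {"text": "World  "}]
assert preprocess_batched(batch) == [preprocess_one(s) for s in batch]
print(preprocess_batched(batch))  # [{'text': 'hello'}, {'text': 'world'}]
```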
assets/multi_lora.png

178 KB

client_tools/client_generator.py

Lines changed: 2 additions & 2 deletions

@@ -728,7 +728,7 @@ def generate_samplers():
 from twinkle.data_format import Trajectory, InputFeature


-class VLLMSampler(Sampler):
+class vLLMSampler(Sampler):
     """Client wrapper for Sampler that calls server HTTP endpoints.

     This client manages sampling operations and adapter synchronization with the sampler server.

@@ -851,7 +851,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):

     # Create/overwrite __init__.py
     init_file = client_module_path / '__init__.py'
-    init_content = AUTO_GEN_WARNING + "from .vllm_sampler import VLLMSampler\n"
+    init_content = AUTO_GEN_WARNING + "from .vllm_sampler import vLLMSampler\n"
     print(f"Writing {init_file}...")
     with open(init_file, 'w', encoding='utf-8') as f:
         f.write(init_content)
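The rename from `VLLMSampler` to `vLLMSampler` in the generated client is a breaking change for any downstream code importing the old name. A conventional mitigation, not part of this commit and shown here only as a sketch with a stand-in class, is a deprecated alias that forwards to the new name:

```python
import warnings

class vLLMSampler:
    """Stand-in for the real sampler class; only the aliasing is illustrated."""
    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs

def VLLMSampler(*args, **kwargs):
    # Deprecated alias kept for backward compatibility with the old name.
    warnings.warn(
        "VLLMSampler has been renamed to vLLMSampler",
        DeprecationWarning,
        stacklevel=2,
    )
    return vLLMSampler(*args, **kwargs)

old_style = VLLMSampler()  # still works, but emits a DeprecationWarning
assert isinstance(old_style, vLLMSampler)
```

Shipping such an alias for one release gives callers of the old spelling time to migrate before the name is removed.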

cookbook/legacy/grpo/lora.py

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.model import TransformersModel
 from twinkle.processor import InputProcessor
-from twinkle.sampler import VLLMSampler
+from twinkle.sampler import vLLMSampler
 from twinkle.template import Template
 from twinkle.metric import CompletionRewardMetric

@@ -126,7 +126,7 @@ def main():
     model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
     model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)

-    sampler = VLLMSampler(
+    sampler = vLLMSampler(
         model_id=MODEL_ID,
         engine_args={
             'load_format': 'dummy',

cookbook/legacy/grpo/lora_gpu.py

Lines changed: 10 additions & 10 deletions

@@ -3,7 +3,7 @@

 This script tests the twinkle RL training capabilities on GPU:
 1. TransformersModel backend
-2. VLLMSampler / TorchSampler integration
+2. vLLMSampler / TorchSampler integration
 3. GRPOLoss and advantage computation
 4. Weight synchronization between model and sampler

@@ -16,7 +16,7 @@
 # Test with multiple GPUs (Ray mode)
 CUDA_VISIBLE_DEVICES=0,1 TWINKLE_MODE=ray python lora_gpu.py

-# Use VLLMSampler (requires more GPU memory)
+# Use vLLMSampler (requires more GPU memory)
 TWINKLE_USE_TORCH_SAMPLER=0 python lora_gpu.py

 # Debug mode

@@ -27,14 +27,14 @@
 TWINKLE_MAX_LENGTH: Max sequence length (default: 2048)
 TWINKLE_MAX_STEPS: Max training steps (default: 3)
 TWINKLE_USE_REF_MODEL: Use reference model for KL (default: 0)
-TWINKLE_USE_TORCH_SAMPLER: Use TorchSampler instead of VLLMSampler (default: 1)
+TWINKLE_USE_TORCH_SAMPLER: Use TorchSampler instead of vLLMSampler (default: 1)
 TWINKLE_DEBUG: Enable debug logging (default: 0)
 TWINKLE_MODE: 'local' or 'ray' (default: local)

 Test Results (as of 2026-01-30):
 - TransformersModel + TorchSampler: PASS
-- VLLMSampler sampling: PASS
-- VLLMSampler LoRA weight sync: IN PROGRESS (needs more debugging)
+- vLLMSampler sampling: PASS
+- vLLMSampler LoRA weight sync: IN PROGRESS (needs more debugging)
 """
 import numpy as np
 from peft import LoraConfig

@@ -52,7 +52,7 @@
 from twinkle.infra import DeviceGroup, remote_function, remote_class
 from twinkle.model import TransformersModel
 from twinkle.reward import MathReward
-from twinkle.sampler import VLLMSampler, TorchSampler
+from twinkle.sampler import vLLMSampler, TorchSampler
 from twinkle.data_format.sampling import SamplingParams
 from twinkle.weight_loader import NativeLoader
 from twinkle.advantage import GRPOAdvantage

@@ -238,8 +238,8 @@ def __init__(self, engine_args=None, lora_config=None, adapter_name=None, **kwargs):
             )
         else:
             if engine_args is None:
-                raise ValueError("engine_args is required for VLLMSampler.")
-            self.sampler = VLLMSampler(
+                raise ValueError("engine_args is required for vLLMSampler.")
+            self.sampler = vLLMSampler(
                 model_path,
                 engine_args=engine_args,
                 device_mesh=actor_device_mesh,

@@ -403,7 +403,7 @@ def train_local():
             device_mesh=actor_device_mesh,
         )
     else:
-        from twinkle.sampler import VLLMSampler
+        from twinkle.sampler import vLLMSampler
         engine_args = {
             'model': model_path,
             'enable_lora': True,

@@ -413,7 +413,7 @@ def train_local():
             'gpu_memory_utilization': 0.5,
             'trust_remote_code': True,
         }
-        sampler = VLLMSampler(
+        sampler = vLLMSampler(
             model_path,
             engine_args=engine_args,
             device_mesh=actor_device_mesh,
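lora_gpu.py drives its configuration entirely through `TWINKLE_*` environment variables with the defaults documented in its docstring. That pattern can be sketched with two small helpers (the helper names are hypothetical; the variable names and defaults come from the docstring above):

```python
import os

def env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to a default."""
    return int(os.environ.get(name, default))

def env_flag(name: str, default: int) -> bool:
    """Read a 0/1 flag from the environment."""
    return bool(env_int(name, default))

max_length = env_int("TWINKLE_MAX_LENGTH", 2048)
max_steps = env_int("TWINKLE_MAX_STEPS", 3)
use_ref_model = env_flag("TWINKLE_USE_REF_MODEL", 0)
use_torch_sampler = env_flag("TWINKLE_USE_TORCH_SAMPLER", 1)  # 0 selects vLLMSampler
mode = os.environ.get("TWINKLE_MODE", "local")  # 'local' or 'ray'

print(max_length, max_steps, use_torch_sampler, mode)
```

Keeping all knobs in the environment lets the same script be re-run under different sampler and cluster configurations, as the usage examples in the docstring show, without editing code.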

cookbook/legacy/grpo/lora_npu.py

Lines changed: 3 additions & 3 deletions

@@ -8,7 +8,7 @@
 from twinkle.infra import DeviceGroup, remote_function, remote_class
 from twinkle.model import TransformersModel
 from twinkle.reward import MathReward
-from twinkle.sampler import VLLMSampler, TorchSampler
+from twinkle.sampler import vLLMSampler, TorchSampler
 from twinkle.data_format.sampling import SamplingParams, SampleResponse
 from twinkle.weight_loader import NativeLoader
 from twinkle.advantage import compute_advantages

@@ -230,8 +230,8 @@ def __init__(self, engine_args=None, lora_config=None, adapter_name=None, **kwargs):
             )
         else:
             if engine_args is None:
-                raise ValueError("engine_args is required for VLLMSampler.")
-            self.sampler = VLLMSampler(
+                raise ValueError("engine_args is required for vLLMSampler.")
+            self.sampler = vLLMSampler(
                 model_path,
                 engine_args=engine_args,
                 device_mesh=actor_device_mesh,

cookbook/legacy/grpo/lora_pr.py

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 from twinkle.metric import CompletionRewardMetric
 from twinkle.model import TransformersModel
 from twinkle.processor import InputProcessor
-from twinkle.sampler import VLLMSampler
+from twinkle.sampler import vLLMSampler
 from twinkle.template import Template
 from twinkle import torch_util

@@ -47,7 +47,7 @@ def main():
     lora_config = LoraConfig(target_modules="all-linear", r=8, lora_alpha=32, lora_dropout=0.05)
     model = TransformersModel(model_id='ms://Qwen/Qwen2.5-3B-Instruct', device_mesh=model_mesh, remote_group='model')
     model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=4,)
-    sampler = VLLMSampler(
+    sampler = vLLMSampler(
         model_id='ms://Qwen/Qwen2.5-3B-Instruct',
         engine_args={
             'load_format': 'dummy',
