Commit 57e4482

Merge branch 'add_rl_example' of https://github.com/modelscope/twinkle into add_rl_example
2 parents 546ae09 + 994d5d8

File tree

155 files changed, +5085 −515 lines


.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -57,7 +57,7 @@ These instructions help AI agents work productively in this repo. Focus on concr
 ## Examples
 - **Visualize a custom mesh:** create `DeviceMesh` and call `get_device_placement()`; example in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
 - **Add LoRA adapter via HTTP:** POST to `/add_adapter_to_model` with serialized `LoraConfig`; see server routes in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
-- **Sample with vLLM:** Configure `VLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
+- **Sample with vLLM:** Configure `vLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
 
 ---
 Questions or gaps? Tell us where guidance is unclear (e.g., missing run scripts, Ray cluster setup), and we'll refine this document.
```
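The "Add LoRA adapter via HTTP" bullet in the diff above describes POSTing a serialized `LoraConfig` to `/add_adapter_to_model`. A minimal sketch of building such a request body — the field names (`adapter_name`, `lora_config`) and the hyperparameter values here are assumptions for illustration, not the server's actual schema:

```python
import json

# Hypothetical payload for POST /add_adapter_to_model; field names and
# LoRA hyperparameters below are assumptions, not the real wire format.
payload = {
    "adapter_name": "my_adapter",
    "lora_config": {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]},
}
body = json.dumps(payload).encode("utf-8")
```

The actual route and accepted fields are defined in `src/twinkle/server/twinkle/model.py`.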

README.md

Lines changed: 227 additions & 5 deletions
Large diffs are not rendered by default.

README_ZH.md

Lines changed: 50 additions & 7 deletions
```diff
@@ -21,7 +21,7 @@
 </p>
 
 <p align="center">
-<a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
+<a href="https://twinkle-kit.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://twinkle-kit.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
 </p>
 
 <div align="center">
@@ -191,14 +191,57 @@ twinkle's architecture consists of a client and a server; the client contains two
 
 This allows developers to use the Tinker API directly to call backend training services deployed with twinkle.
 
+## Multi-Tenant Support
+
+Twinkle supports multiple tenants training on a single base model at the same time. This is currently limited to [LoRA](https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py#L323).
+Twinkle adopts a LoRA-pool-plus-tenant-request design. Up to N tenants can train in parallel without interfering with one another; from the model's perspective each tenant's training workflow may differ, and the data padding scheme, optimizer, and loss type used in the base model can differ as well.
+
+<img src="assets/multi_lora.png" style="max-width: 500px; width: 100%;" />
+
+For example:
+
+- Tenant A: loads a private local dataset, LoRA rank=8, runs SFT on the base model
+- Tenant B: remotely loads an open-source dataset from the Hub, LoRA rank=32, runs PT on the base model
+- Tenant C: computes the GRPO loss on the base model, sampling with the Sampler
+- Tenant D: runs logps inference on the base model
+
+All of this can happen on one base model simultaneously, because the model and the Sampler are themselves twinkle components and are task-agnostic. After training, checkpoints can be pushed to HuggingFace/ModelScope model repositories, private by default. twinkle provides a complete multi-tenant training solution: the server side supports cluster management and dynamic scaling, and with light customization it can run as an enterprise-grade service.
+
+> As a modular framework, twinkle itself also supports temporary exclusive remote training, i.e., full-parameter training.
+
 ## Supported Components
 
-| | | | | |
-| :--------------------------------------------------------: | :-------------------------------------------------------: | :----------------------------------------------------------: | :--------------------------------------------------------: | :-------------------------------------------------------------: |
-| **Dataset**`<br><sub>`Data loading and preprocessing`</sub>` | **Template**`<br><sub>`Encoding and decoding`</sub>` | **DataLoader**`<br><sub>`Data distribution and batching`</sub>` | **Preprocessor**`<br><sub>`Data ETL`</sub>` | **InputProcessor**`<br><sub>`Task-specific input handling`</sub>` |
-| **Model**`<br><sub>`Large models, multiple frameworks`</sub>` | **Sampler**`<br><sub>`Sampling`</sub>` | **Loss**`<br><sub>`Residuals`</sub>` | **Metric**`<br><sub>`Training metrics`</sub>` | **Reward**`<br><sub>`Reward functions`</sub>` |
-| **Advantage**`<br><sub>`Advantage functions`</sub>` | **CheckpointEngine**`<br><sub>`Weight synchronization`</sub>` | **Patch**`<br><sub>`Patches for model fixes`</sub>` | **Module**`<br><sub>`Components such as Optimizer`</sub>` | **Kernel**`<br><sub>`Operators`</sub>` |
-| **Server**`<br><sub>`Launches the backend cluster`</sub>` | **Client**`<br><sub>`Client-side code`</sub>` | **Infra**`<br><sub>`Abstracts ray/torchrun differences`</sub>` | **Plugin**`<br><sub>`Uses hub-side components`</sub>` | **Hub**`<br><sub>`Connects to the HF/MS hubs`</sub>` |
+<table>
+<tr>
+<td align="center"><b>Dataset</b><br><sub>Data loading and preprocessing</sub></td>
+<td align="center"><b>Template</b><br><sub>Encoding and decoding</sub></td>
+<td align="center"><b>DataLoader</b><br><sub>Data distribution and batching</sub></td>
+<td align="center"><b>Preprocessor</b><br><sub>Data ETL</sub></td>
+<td align="center"><b>InputProcessor</b><br><sub>Task-specific input handling</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Model</b><br><sub>Large models, multiple frameworks</sub></td>
+<td align="center"><b>Sampler</b><br><sub>Sampling</sub></td>
+<td align="center"><b>Loss</b><br><sub>Residuals</sub></td>
+<td align="center"><b>Metric</b><br><sub>Training metrics</sub></td>
+<td align="center"><b>Reward</b><br><sub>Reward functions</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Advantage</b><br><sub>Advantage functions</sub></td>
+<td align="center"><b>CheckpointEngine</b><br><sub>Weight synchronization</sub></td>
+<td align="center"><b>Patch</b><br><sub>Patches for model fixes</sub></td>
+<td align="center"><b>Module</b><br><sub>Components such as Optimizer</sub></td>
+<td align="center"><b>Kernel</b><br><sub>Operators</sub></td>
+</tr>
+<tr>
+<td align="center"><b>Server</b><br><sub>Launches the backend cluster</sub></td>
+<td align="center"><b>Client</b><br><sub>Client-side code</sub></td>
+<td align="center"><b>Infra</b><br><sub>Abstracts ray/torchrun differences</sub></td>
+<td align="center"><b>Plugin</b><br><sub>Uses hub-side components</sub></td>
+<td align="center"><b>Hub</b><br><sub>Connects to the HF/MS hubs</sub></td>
+</tr>
+</table>
 
 ## Community Components
 
```
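The LoRA-pool-plus-tenant-request scheme added in the README diff above can be illustrated with a toy allocator: a fixed number of LoRA slots on one base model, granted to tenants on request. This is a schematic sketch of the idea only — the class and method names are hypothetical, not twinkle's actual implementation:

```python
class LoraPool:
    """Toy illustration of a fixed-size LoRA slot pool shared by tenants.

    Names and structure are hypothetical; twinkle's real pool also tracks
    per-tenant padding, optimizer, and loss settings.
    """

    def __init__(self, max_slots: int):
        self.max_slots = max_slots      # the "N" in "up to N tenants in parallel"
        self.slots = {}                 # tenant_id -> adapter config

    def acquire(self, tenant_id: str, rank: int) -> bool:
        """Grant a LoRA slot to a tenant if capacity remains."""
        if tenant_id in self.slots:
            return True
        if len(self.slots) >= self.max_slots:
            return False                # pool exhausted; tenant must wait
        self.slots[tenant_id] = {"r": rank}
        return True

    def release(self, tenant_id: str) -> None:
        """Free the tenant's slot, e.g. after training finishes."""
        self.slots.pop(tenant_id, None)


pool = LoraPool(max_slots=2)
assert pool.acquire("tenant_a", rank=8)      # SFT tenant
assert pool.acquire("tenant_b", rank=32)     # PT tenant, different rank
assert not pool.acquire("tenant_c", rank=8)  # pool full, request denied
pool.release("tenant_a")
assert pool.acquire("tenant_c", rank=8)      # freed slot is reusable
```

Each tenant's adapter is independent, which is what lets rank, optimizer, and loss type differ per tenant while the base model weights are shared.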

ROADMAP.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -64,6 +64,7 @@
 - [ ] Support distillation algorithms such as GKD and on-policy-distill
 - [ ] Support DPO alignment training
 - [ ] Support colocated RL training
+- [ ] Support batched preprocessing
 
 ### Networking Capabilities
 
@@ -82,6 +83,7 @@
 - [ ] Support for distillation algorithms such as GKD and on-policy distillation
 - [ ] Support for DPO alignment training
 - [ ] Support for colocate RL training
+- [ ] Support for batched preprocessing
 
 ### Networking Capabilities
 
```
assets/multi_lora.png

178 KB

client_tools/client_generator.py

Lines changed: 2 additions & 10 deletions
```diff
@@ -240,7 +240,6 @@ def build_imports() -> Tuple[List[str], str]:
     if typing_imports:
         lines.append(f"from typing import {', '.join(sorted(typing_imports))}")
     lines.extend([
-        "from twinkle_client.http import TWINKLE_SERVER_URL",
         "from twinkle_client.http import http_post, heartbeat_manager",
     ])
     lines.extend(sorted(twinkle_imports))
@@ -447,7 +446,6 @@ def generate_models():
 
     model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
 import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature, Trajectory
@@ -724,18 +722,13 @@ def generate_samplers():
     client_module_path.mkdir(parents=True, exist_ok=True)
 
     sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
-import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.sampler.base import Sampler
-from twinkle.sampler.types import SamplingParams, SampleResponse
-from twinkle import DeviceMesh
 from peft import PeftConfig
 from twinkle.data_format import Trajectory, InputFeature
-import json
 
 
-class VLLMSampler(Sampler):
+class vLLMSampler(Sampler):
     """Client wrapper for Sampler that calls server HTTP endpoints.
 
     This client manages sampling operations and adapter synchronization with the sampler server.
@@ -756,7 +749,6 @@ def __init__(self, model_id: str, **kwargs):
             json_data=kwargs
         )
         response.raise_for_status()
-        return response.json()
 
     def _send_adapter_heartbeat(self):
         """Internal method to send adapter heartbeat."""
@@ -859,7 +851,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
 
     # Create/overwrite __init__.py
     init_file = client_module_path / '__init__.py'
-    init_content = AUTO_GEN_WARNING + "from .vllm_sampler import VLLMSampler\n"
+    init_content = AUTO_GEN_WARNING + "from .vllm_sampler import vLLMSampler\n"
     print(f"Writing {init_file}...")
     with open(init_file, 'w', encoding='utf-8') as f:
         f.write(init_content)
```
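The rename from `VLLMSampler` to `vLLMSampler` flows through both the generated class and the generated `__init__.py`. A small self-contained sketch of the export string the generator now emits — the `AUTO_GEN_WARNING` text below is a stand-in, not the repo's actual constant:

```python
# Stand-in for the generator's AUTO_GEN_WARNING constant (actual text differs).
AUTO_GEN_WARNING = "# This file is auto-generated. Do not edit.\n"

# Mirrors the updated line in generate_samplers(): the exported name is now
# vLLMSampler, matching the renamed client class.
init_content = AUTO_GEN_WARNING + "from .vllm_sampler import vLLMSampler\n"
```

Keeping the class name and the `__init__.py` export in sync is what makes `from <client package> import vLLMSampler` resolve after regeneration.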

cookbook/client/tinker/megatron/lora.py

Lines changed: 0 additions & 158 deletions
This file was deleted.

cookbook/client/tinker/transformer/grpo.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -35,25 +35,25 @@
 logger = get_logger()
 
 # ========== Configuration ==========
-BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
+MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
 NUM_GENERATIONS = 8
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
-MAX_STEPS = 2000
+MAX_STEPS = 10
 BATCH_SIZE = 4
 TEMPERATURE = 1.0
-SYNC_INTERVAL = 1  # Save weights for sampler every N steps
-LORA_RANK = 8
+SYNC_INTERVAL = 5  # Save weights for sampler every N steps
+GRADIENT_ACCUMULATION_STEPS = 4
 
 
 def create_countdown_dataset():
     """Create Countdown Game dataset for GRPO training."""
-    from twinkle.preprocessor import CountdownProcessor
+
     dataset = Dataset(DatasetMeta(
-        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
+        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(500)))
     dataset.set_template(
         "Template", model_id=f'ms://{BASE_MODEL}', max_length=8192)
-    dataset.map(CountdownProcessor())
+    dataset.map('CountdownProcessor')
     dataset.encode(add_generation_prompt=True)
     return dataset
```
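Under the revised configuration, the effective batch per optimizer step is `BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS`, and sampler weights are refreshed every `SYNC_INTERVAL` steps. A quick check of what the new values imply over the shortened 10-step run:

```python
# Values from the updated configuration block above.
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
MAX_STEPS = 10
SYNC_INTERVAL = 5  # save weights for the sampler every N steps

effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS  # samples per optimizer step
weight_syncs = MAX_STEPS // SYNC_INTERVAL                   # sampler refreshes in the run
```

So the demo run sees an effective batch of 16 and syncs sampler weights twice, versus every step at the old `SYNC_INTERVAL = 1`.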

cookbook/client/twinkle/transformer/grpo.py

Lines changed: 5 additions & 6 deletions
```diff
@@ -41,26 +41,25 @@
 logger = get_logger()
 
 # ========== Configuration ==========
-MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'
+MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
 NUM_GENERATIONS = 8
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
-MAX_STEPS = 2000
+MAX_STEPS = 10
 BATCH_SIZE = 4
 TEMPERATURE = 1.0
-SYNC_INTERVAL = 1  # Save weights for sampler every N steps
+SYNC_INTERVAL = 5  # Save weights for sampler every N steps
 GRADIENT_ACCUMULATION_STEPS = 4
 
 
 def create_countdown_dataset():
     """Create Countdown Game dataset for GRPO training."""
-    from twinkle.preprocessor import CountdownProcessor
 
     dataset = Dataset(dataset_meta=DatasetMeta(
-        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
+        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(500)))
     dataset.set_template(
         'Template', model_id=MODEL_ID, max_length=8192)
-    dataset.map(CountdownProcessor())
+    dataset.map('CountdownProcessor')
     dataset.encode(add_generation_prompt=True, batched=True)
     return dataset
```
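Both cookbook scripts train with GRPO: each prompt is sampled `NUM_GENERATIONS` times and each completion's advantage is its reward normalized within that sampling group. A minimal sketch of the group-relative computation — a standalone illustration of the technique, not twinkle's `Advantage` component:

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """GRPO-style advantages: z-score each reward within its sampling group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)           # population std over the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt sampled NUM_GENERATIONS=8 times; rewards could come from a
# Countdown answer checker (1.0 = correct expression, 0.0 = incorrect).
rewards = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
advantages = group_relative_advantages(rewards)
```

Correct completions get positive advantages and incorrect ones negative, and the advantages of each group sum to zero, so the policy gradient pushes probability mass toward the better completions within the group.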
