Skip to content

Commit 3981b40

Browse files
committed
kernel doc
1 parent 5e1f465 commit 3981b40

File tree

6 files changed

+321
-0
lines changed

6 files changed

+321
-0
lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Twinkle DOCUMENTATION
3434
组件/LRScheduler/index.rst
3535
组件/补丁/index.rst
3636
组件/组件化/index.rst
37+
组件/Kernel/index.rst
3738
组件/训练中间件/index.rst
3839

3940
Indices and tables
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
# Twinkle Kernel 模块
2+
3+
Twinkle Kernel 模块提供了两条内核替换路径,用于加速训练和推理:
4+
5+
* **层级 Kernelize(Layer-level kernelize)**
6+
使用优化内核替换完整的 `nn.Module` 实现。
7+
* **函数级 Kernelize(Function-level kernelize)**
8+
对 Python 模块中的特定函数进行 monkey-patch。
9+
10+
这两种方式可以独立使用,也可以通过统一入口组合使用。
11+
12+
---
13+
14+
## 概览:两条 Kernelize 路径
15+
16+
| 路径 | 粒度 | 典型场景 |
17+
| --- | --- | --- |
18+
| 层级替换 | 整个 `nn.Module` | Linear / Conv / MLP / Attention |
19+
| 函数级替换 | 单个函数 | 热点路径、数学算子、激活函数 |
20+
21+
---
22+
23+
## 层级内核替换(Layer-Level)
24+
25+
### 适用场景
26+
27+
* 你已经有完整的层内核实现
28+
* 希望在模型中批量替换某类 `nn.Module`
29+
* 同时适用于训练与推理
30+
31+
---
32+
33+
### 示例 1:本地 Kernel 仓库
34+
35+
适用于:
36+
37+
* 内核实现位于本地仓库
38+
* 希望替换 HuggingFace 或自定义模型中的层
39+
40+
```python
41+
from twinkle.kernel import (
42+
kernelize_model,
43+
register_layer_kernel,
44+
register_external_layer,
45+
)
46+
from transformers import Qwen2Config, Qwen2ForCausalLM
47+
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
48+
49+
# 1) 从本地仓库注册层内核
50+
register_layer_kernel(
51+
kernel_name="MyAwesomeMLP",
52+
repo_path="/path/to/local/repo",
53+
package_name="my_kernels",
54+
layer_name="Qwen2MLPTrainingKernel",
55+
device="cuda",
56+
mode="train",
57+
)
58+
59+
# 2) 绑定外部层与内核名
60+
register_external_layer(Qwen2MLP, "MyAwesomeMLP")
61+
62+
# 3) 构建模型并应用内核替换
63+
config = Qwen2Config(
64+
hidden_size=128,
65+
num_hidden_layers=1,
66+
num_attention_heads=4,
67+
num_key_value_heads=4,
68+
intermediate_size=256,
69+
use_cache=False,
70+
)
71+
model = Qwen2ForCausalLM(config)
72+
model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
73+
```
74+
75+
---
76+
77+
### 示例 2:Hub Kernel 仓库
78+
79+
适用于:
80+
81+
* 内核托管在 Hub 上
82+
83+
```python
84+
import torch
85+
import torch.nn as nn
86+
from twinkle.kernel import (
87+
kernelize_model,
88+
register_layer_kernel,
89+
register_external_layer,
90+
)
91+
92+
# 1) 定义自定义层
93+
class SiluAndMul(nn.Module):
94+
def forward(self, x: torch.Tensor) -> torch.Tensor:
95+
x1, x2 = x.chunk(2, dim=-1)
96+
return nn.functional.silu(x1) * x2
97+
98+
# 2) 注册 Hub 内核并绑定层
99+
register_layer_kernel(
100+
kernel_name="SiluAndMulKernel",
101+
repo_id="kernels-community/activation",
102+
layer_name="SiluAndMul",
103+
device="cuda",
104+
mode="train",
105+
)
106+
register_external_layer(SiluAndMul, "SiluAndMulKernel")
107+
108+
# 3) 应用到模型
109+
class SimpleModel(nn.Module):
110+
def __init__(self):
111+
super().__init__()
112+
self.activation = SiluAndMul()
113+
114+
def forward(self, x: torch.Tensor) -> torch.Tensor:
115+
return self.activation(x)
116+
117+
model = SimpleModel()
118+
model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
119+
```
120+
121+
---
122+
123+
## 本地 Kernel 仓库(最小结构)
124+
125+
本地 kernel 仓库本质上是一个普通 Python 包。
126+
最少只需要一个 `layers.py` 来放层级内核实现。
127+
128+
```text
129+
# 仓库结构:
130+
my_kernels/ # 本地 kernel 仓库(Python 包)
131+
├── __init__.py # 包入口
132+
└── layers.py # 层级 kernel 实现
133+
```
134+
135+
```python
136+
# my_kernels/__init__.py
137+
from . import layers
138+
__all__ = ["layers"]
139+
140+
# my_kernels/layers.py
141+
import torch
142+
import torch.nn as nn
143+
144+
# 注:forward 中访问的 gate_proj / up_proj / down_proj / act_fn
# 并未在本类中定义,它们应由被替换的原始 Qwen2MLP 层提供
# (具体行为以 kernelize 实现为准)。
class Qwen2MLPTrainingKernel(nn.Module):
145+
def forward(self, x: torch.Tensor) -> torch.Tensor:
146+
gate = self.gate_proj(x)
147+
up = self.up_proj(x)
148+
return self.down_proj(self.act_fn(gate) * up)
149+
```
150+
151+
---
152+
153+
## 函数级内核替换(Function-Level)
154+
155+
### 适用场景
156+
157+
* 只需要加速少量热点函数
158+
* 不适合或不需要替换整个层
159+
* 常用于数学算子、激活函数、工具函数
160+
161+
---
162+
163+
### 示例 1:批量注册(简单场景)
164+
165+
```python
166+
from twinkle.kernel import register_kernels, kernelize_model
167+
168+
# 1) 注册函数内核
169+
config = {
170+
"functions": {
171+
"add": {
172+
"target_module": "my_pkg.math_ops",
173+
"func_impl": lambda x, y: x + y + 1,
174+
"device": "cuda",
175+
"mode": "inference",
176+
},
177+
},
178+
}
179+
register_kernels(config)
180+
181+
# 2) 应用(仅函数替换时 model 可为 None)
182+
kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True)
183+
```
184+
185+
---
186+
187+
### 示例 2:高级函数来源(完整控制)
188+
189+
适用于:
190+
191+
* 不同函数来自不同来源(impl / repo / hub),或需要 compile/backward 等标志。
192+
193+
```python
194+
from twinkle.kernel.function import (
195+
register_function_kernel,
196+
apply_function_kernel,
197+
)
198+
import torch.nn as nn
199+
from twinkle.kernel import kernelize_model
200+
201+
TARGET_MODULE = "my_pkg.math_ops"
202+
203+
# 1) 直接传入实现
204+
def fast_add(x, y):
205+
return x + y + 1
206+
207+
register_function_kernel(
208+
func_name="add",
209+
target_module=TARGET_MODULE,
210+
func_impl=fast_add,
211+
device="cuda",
212+
mode="inference",
213+
)
214+
215+
# 2) Repo 对象(FuncRepositoryProtocol)
216+
# 注:load() 返回的 MyKernelFunc 定义在本类下方;
# Python 在调用 load() 时才解析该名称,因此这是合法的。
class MyFuncRepo:
217+
def load(self):
218+
return MyKernelFunc
219+
220+
class MyKernelFunc(nn.Module):
221+
def forward(self, x, y):
222+
return x * y
223+
224+
register_function_kernel(
225+
func_name="mul",
226+
target_module=TARGET_MODULE,
227+
repo=MyFuncRepo(),
228+
device="cuda",
229+
mode="compile",
230+
)
231+
232+
# 3) Hub 仓库
233+
register_function_kernel(
234+
func_name="silu_and_mul",
235+
target_module="my_pkg.activations",
236+
repo_id="kernels-community/activation",
237+
revision="main", # 或 version="0.1.0"
238+
device="cuda",
239+
mode="inference",
240+
)
241+
242+
# 4) 应用函数内核
243+
applied = apply_function_kernel(
244+
target_module=TARGET_MODULE,
245+
device="cuda",
246+
mode="inference",
247+
strict=False,
248+
)
249+
print("patched:", applied)
250+
251+
# 5) 可选:通过 kernelize_model 统一应用
252+
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
253+
kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True)
254+
```
255+
256+
---
257+
258+
## 层级 + 函数级统一批量注册
259+
260+
### 适用场景
261+
262+
* 需要框架级统一集成
263+
* 希望通过单一配置入口管理
264+
* 同时管理层和函数两类内核
265+
266+
```python
267+
from twinkle.kernel import register_kernels, kernelize_model
268+
import torch.nn as nn
269+
270+
# 1) 注册层级 + 函数级内核
271+
config = {
272+
"layers": {
273+
"linear": {
274+
"repo_id": "kernels-community/linear",
275+
"layer_name": "Linear",
276+
"version": "0.1.0",
277+
"device": "cuda",
278+
"mode": "train",
279+
},
280+
"conv2d": {
281+
"repo_path": "/path/to/local/repo",
282+
"package_name": "my_kernels",
283+
"layer_name": "Conv2d",
284+
"device": "cuda",
285+
},
286+
},
287+
"functions": {
288+
"add": {
289+
"target_module": "my_pkg.math_ops",
290+
"func_impl": lambda x, y: x + y + 1,
291+
"device": "cuda",
292+
"mode": "inference",
293+
},
294+
"relu": {
295+
"target_module": "my_pkg.activations",
296+
"repo_id": "kernels-community/activation",
297+
"revision": "main",
298+
"device": "cuda",
299+
},
300+
},
301+
}
302+
register_kernels(config)
303+
304+
# 2) 通过 kernelize_model 应用
305+
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
306+
kernelize_model(model=model, mode="train", device="cuda", use_fallback=True)
307+
```
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Kernel
2+
===============
3+
.. toctree::
4+
:maxdepth: 1
5+
6+
Kernel.md
File renamed without changes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Kernel
2+
===============
3+
.. toctree::
4+
:maxdepth: 1
5+
6+
Kernel.md

docs/source_en/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Twinkle DOCUMENTATION
3434
Components/LRScheduler/index.rst
3535
Components/Patch/index.rst
3636
Components/Plugin/index.rst
37+
Components/Kernel/index.rst
3738
Components/Training Middleware/index.rst
3839

3940
Indices and tables

0 commit comments

Comments
 (0)