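A standalone Python package for the Selective Scan CUDA operator, extracted from the original Mamba repository.
📢 Important notice: This project extracts the Selective Scan operator code from Tri Dao's original Mamba repository and packages it on its own. We did not create or modify the core algorithm; every CUDA implementation comes from the original Mamba project. The value of this project is a lightweight packaging that installs independently.
Core work of this project: extract the CUDA code for the Selective Scan algorithm from Tri Dao's original Mamba implementation and repackage it as an independently installable Python dependency, so that other projects can use it directly.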
Extraction details:
- Source repository: state-spaces/mamba
- Extracted from commit: d49d7c9
- Extraction date: 2024-08-30
- Original version: Mamba v2.2.6
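Technical note: Selective Scan is the core operator of the Mamba architecture. Using a parallel scan, it turns the conventional recurrence, which takes O(n) sequential steps, into an operation with O(log n) parallel depth. The original implementation is deeply integrated into the Mamba project; this project extracts it as a standalone module and keeps the original high-performance CUDA implementation.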
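The idea behind the parallel scan can be sketched in plain Python. The recurrence h_t = a_t * h_{t-1} + b_t is a chain of affine maps, and composing affine maps is associative, which is what allows all prefixes to be evaluated in about log2(n) rounds. The function names below are illustrative and this is not the kernel's actual code; it only simulates the rounds that a GPU would run in parallel.

```python
from typing import List, Tuple

Pair = Tuple[float, float]  # (a, b) represents the affine map h -> a * h + b


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two affine maps: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a: List[float], b: List[float]) -> List[float]:
    """Reference recurrence h_t = a_t * h_{t-1} + b_t, starting from h = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def parallel_style_scan(a: List[float], b: List[float]) -> List[float]:
    """Inclusive scan in ceil(log2(n)) rounds (Hillis-Steele pattern).

    Updates inside one round do not depend on each other, so a GPU could
    run each round in parallel; here the rounds are simulated one by one.
    """
    pairs = list(zip(a, b))
    step = 1
    while step < len(pairs):
        prev = pairs[:]  # snapshot: a round only reads the previous round
        for i in range(step, len(pairs)):
            pairs[i] = combine(prev[i - step], prev[i])
        step *= 2
    # Starting from h = 0, the state equals the offset of the composed map.
    return [offset for _, offset in pairs]


a = [0.5, 0.9, 0.1, 0.7, 0.3]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
seq = sequential_scan(a, b)
par = parallel_style_scan(a, b)
print(all(abs(x - y) < 1e-9 for x, y in zip(seq, par)))
```

The two functions agree up to floating-point rounding; only the order in which the affine maps are composed differs.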
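Features of the extracted package:
- ✅ Preserves the original Mamba Selective Scan implementation in full
- ✅ Installs on its own, without the full Mamba dependency set
- ✅ Standard PyTorch extension interface
- ✅ Keeps the original performance optimizations (multiple precisions, parallel scan)
- ✅ Supports the original features, including real/complex inputs and variable parameters
- ✅ Compatible with both CUDA and ROCm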
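This is not an original implementation: all CUDA code in this project comes from the original Mamba repository. Our work is limited to:
- Extracting the Selective Scan code from the Mamba repository
- Reorganizing the directory structure so that it compiles on its own
- Writing setup.py and a Python interface to package it as a standard package
- Providing complete documentation for standalone use
Why extract it: the original Mamba repository contains the full model implementation, so it is large and has complex dependencies. If you only need the Selective Scan operator, installing all of Mamba pulls in unnecessary dependencies. This project offers a lightweight standalone installation.
Acknowledgements: all credit for the core algorithm and the CUDA implementation goes to Tri Dao and the Mamba team; this project only extracts and packages it.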
Requirements:
- Python >= 3.8
- PyTorch >= 2.0.0
- CUDA >= 11.0 or ROCm >= 5.0
- GCC >= 7.0
- NVIDIA GPU (compute capability >= 7.0)
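Before building, it can help to confirm that the interpreter and compilers from the list above are visible. The sketch below uses only the standard library; `check_build_environment` is a hypothetical helper, not part of this package, and it only checks that the tools are on PATH, not their versions.

```python
import shutil
import sys
from typing import Dict


def check_build_environment(min_python=(3, 8)) -> Dict[str, object]:
    """Report the Python version and where the build tools were found.

    Hypothetical helper, not part of selective_scan. A value of None means
    the tool was not found on PATH.
    """
    return {
        "python": ".".join(map(str, sys.version_info[:3])),
        "python_ok": sys.version_info[:2] >= min_python,
        "gcc": shutil.which("gcc"),
        "nvcc": shutil.which("nvcc"),    # CUDA toolkit compiler
        "hipcc": shutil.which("hipcc"),  # ROCm compiler
    }


if __name__ == "__main__":
    for name, value in check_build_environment().items():
        shown = "not found" if value is None else value
        print(f"{name:>10}: {shown}")
```

Only one of `nvcc` and `hipcc` is expected to be present, depending on whether the build targets CUDA or ROCm.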
git clone https://github.com/biubushy/selective_scan.git
cd selective_scan
# Using conda
conda create -n selective_scan python=3.10
conda activate selective_scan
# Or using venv
python -m venv venv
source venv/bin/activate # Linux/Mac
# venv\Scripts\activate # Windows
Visit the official PyTorch website and install the build that matches your CUDA version:
# Example: CUDA 12.1
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
# Or using pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
Option A: development mode (recommended for development and debugging)
python setup.py develop
Option B: regular install
pip install .
Option C: install directly from source (for use in other projects)
# In your project
pip install git+https://github.com/biubushy/selective_scan.git
Verify the installation:
python -c "import selective_scan; print(f'Version: {selective_scan.__version__}')"

# Example: basic forward pass
import torch
import selective_scan
# biubushy | 2025-10
batch, dim, seqlen, dstate = 2, 16, 128, 16
device = 'cuda'
u = torch.randn(batch, dim, seqlen, device=device)
delta = torch.randn(batch, dim, seqlen, device=device)
A = torch.randn(dim, dstate, device=device)
B = torch.randn(dim, dstate, device=device)
C = torch.randn(dim, dstate, device=device)
out = selective_scan.selective_scan_fn(u, delta, A, B, C)
print(f"Input: {u.shape} -> Output: {out.shape}")

# Example: backward pass (all inputs require gradients)
import torch
import selective_scan
u = torch.randn(2, 16, 128, device='cuda', requires_grad=True)
delta = torch.randn(2, 16, 128, device='cuda', requires_grad=True)
A = torch.randn(16, 16, device='cuda', requires_grad=True)
B = torch.randn(16, 16, device='cuda', requires_grad=True)
C = torch.randn(16, 16, device='cuda', requires_grad=True)
out = selective_scan.selective_scan_fn(u, delta, A, B, C)
loss = out.sum()
loss.backward()
print(f"Gradients computed: u.grad={u.grad is not None}")

# Example: all optional arguments (D, z, delta_bias, delta_softplus)
import torch
import selective_scan
batch, dim, seqlen, dstate = 2, 16, 128, 16
u = torch.randn(batch, dim, seqlen, device='cuda', dtype=torch.float32)
delta = torch.randn(batch, dim, seqlen, device='cuda', dtype=torch.float32)
A = torch.randn(dim, dstate, device='cuda', dtype=torch.float32)
B = torch.randn(dim, dstate, device='cuda', dtype=torch.float32)
C = torch.randn(dim, dstate, device='cuda', dtype=torch.float32)
D = torch.randn(dim, device='cuda', dtype=torch.float32)
z = torch.randn(batch, dim, seqlen, device='cuda', dtype=torch.float32)
delta_bias = torch.randn(dim, device='cuda', dtype=torch.float32)
out = selective_scan.selective_scan_fn(
    u, delta, A, B, C,
    D=D,
    z=z,
    delta_bias=delta_bias,
    delta_softplus=True
)
# With z supplied, the documented return value is the tuple (out, out_z).
if isinstance(out, tuple):
    out, out_z = out
    print(f"Output: {out.shape}, Gated output: {out_z.shape}")
else:
    print(f"Output: {out.shape}")

Add the following to your project's requirements.txt:
torch>=2.0.0
selective_scan @ git+https://github.com/biubushy/selective_scan.git
Then install:
pip install -r requirements.txt
If you have cloned the source locally:
pip install /path/to/selective_scan

# Example: a minimal block built on selective_scan_fn
import torch
import torch.nn as nn
import selective_scan
class MambaBlock(nn.Module):
    def __init__(self, dim, dstate=16):
        super().__init__()
        self.dim = dim
        self.dstate = dstate
        self.delta_proj = nn.Linear(dim, dim)
        self.A = nn.Parameter(torch.randn(dim, dstate))
        self.B = nn.Parameter(torch.randn(dim, dstate))
        self.C = nn.Parameter(torch.randn(dim, dstate))
        self.D = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        # x: [batch, seqlen, dim]; delta is projected in that layout.
        delta = self.delta_proj(x)
        # selective_scan_fn takes [batch, dim, seqlen]; transposed views are
        # not contiguous, so copy them into contiguous memory first.
        x = x.transpose(1, 2).contiguous()
        delta = delta.transpose(1, 2).contiguous()
        out = selective_scan.selective_scan_fn(
            x, delta, self.A, self.B, self.C, D=self.D
        )
        # Back to [batch, seqlen, dim].
        return out.transpose(1, 2)


model = MambaBlock(dim=64).cuda()
x = torch.randn(2, 100, 64).cuda()
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")

selective_scan_fn performs the selective scan operation.
Parameters:
- u (Tensor): input sequence, shape [batch, dim, seqlen]
- delta (Tensor): step size, shape [batch, dim, seqlen]
- A (Tensor): state transition matrix, shape [dim, dstate]; real or complex
- B (Tensor): input matrix
  - Fixed mode: [dim, dstate]
  - Variable mode: [batch, n_groups, dstate, seqlen]
- C (Tensor): output matrix
  - Fixed mode: [dim, dstate]
  - Variable mode: [batch, n_groups, dstate, seqlen]
- D (Tensor, optional): skip connection, shape [dim]
- z (Tensor, optional): gating input, shape [batch, dim, seqlen]
- delta_bias (Tensor, optional): bias added to delta, shape [dim]
- delta_softplus (bool): whether to apply softplus to delta
Returns:
- Without z: out (Tensor), shape [batch, dim, seqlen]
- With z: (out, out_z) (Tuple[Tensor, Tensor])
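The shape rules in the parameter list can be checked before a kernel launch. The sketch below works on plain shape tuples (for example `tuple(t.shape)`) and needs no GPU. `check_scan_shapes` is a hypothetical helper, not part of this package, and the requirement that dim be divisible by n_groups is an assumption about how channels are split across groups.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_scan_shapes(u: Shape, delta: Shape, A: Shape, B: Shape, C: Shape,
                      D: Optional[Shape] = None, z: Optional[Shape] = None,
                      delta_bias: Optional[Shape] = None) -> Shape:
    """Validate shapes against the parameter list; return the output shape.

    Hypothetical helper, not part of selective_scan.
    """
    if len(u) != 3:
        raise ValueError(f"u must be [batch, dim, seqlen], got {tuple(u)}")
    batch, dim, seqlen = u
    if len(A) != 2 or A[0] != dim:
        raise ValueError(f"A must be [dim, dstate], got {tuple(A)}")
    dstate = A[1]

    def expect(name: str, shape: Optional[Shape], wanted: Shape) -> None:
        if shape is not None and tuple(shape) != wanted:
            raise ValueError(f"{name} must be {wanted}, got {tuple(shape)}")

    expect("delta", delta, (batch, dim, seqlen))
    expect("z", z, (batch, dim, seqlen))
    expect("D", D, (dim,))
    expect("delta_bias", delta_bias, (dim,))

    for name, shape in (("B", B), ("C", C)):
        if len(shape) == 2:      # fixed mode
            expect(name, shape, (dim, dstate))
        elif len(shape) == 4:    # variable mode
            n_groups = shape[1]
            expect(name, shape, (batch, n_groups, dstate, seqlen))
            # Assumption: channels are split evenly across groups.
            if n_groups <= 0 or dim % n_groups != 0:
                raise ValueError(f"dim={dim} is not divisible by n_groups={n_groups}")
        else:
            raise ValueError(f"{name} must have 2 or 4 dimensions, got {tuple(shape)}")
    return (batch, dim, seqlen)


print(check_scan_shapes((2, 16, 128), (2, 16, 128), (16, 16),
                        (2, 4, 16, 128), (2, 4, 16, 128), D=(16,)))
```

Raising early with a readable message is usually easier to debug than an error surfaced from inside the extension.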
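Example:
out = selective_scan.selective_scan_fn(u, delta, A, B, C)
Low-level forward function: calls the forward CUDA kernel directly, without building an autograd graph.
Low-level backward function: calls the backward CUDA kernel directly.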
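To make the role of each selective_scan_fn parameter concrete, the sketch below computes a single (batch, channel) row sequentially in plain Python. It assumes the discretization described in the Mamba paper (state update exp(delta * A) * h + delta * B * u), fixed-mode real-valued B and C, and SiLU gating for z; `scan_one_channel` is a hypothetical name and this is a readable reference, not the CUDA implementation.

```python
import math
from typing import List, Optional, Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def scan_one_channel(
    u: Sequence[float],       # [seqlen] inputs of one channel
    delta: Sequence[float],   # [seqlen] step sizes of the same channel
    A: Sequence[float],       # [dstate] row of A for this channel
    B: Sequence[float],       # [dstate] row of B (fixed mode)
    C: Sequence[float],       # [dstate] row of C (fixed mode)
    D: float = 0.0,
    z: Optional[Sequence[float]] = None,
    delta_bias: float = 0.0,
    delta_softplus: bool = False,
) -> Tuple[List[float], Optional[List[float]]]:
    """Sequential reference for one channel; returns (out, out_z or None)."""
    h = [0.0] * len(A)        # hidden state, one entry per state dimension
    out: List[float] = []
    for t, u_t in enumerate(u):
        dt = delta[t] + delta_bias
        if delta_softplus:
            dt = softplus(dt)
        # Discretized update: h <- exp(dt * A) * h + dt * B * u_t
        h = [math.exp(dt * a) * h_n + dt * b * u_t for a, b, h_n in zip(A, B, h)]
        y_t = sum(c * h_n for c, h_n in zip(C, h))
        out.append(y_t + D * u_t)   # D acts as a skip connection
    # Assumed gating: multiply the output by SiLU(z).
    out_z = None if z is None else [o * silu(g) for o, g in zip(out, z)]
    return out, out_z


out, out_z = scan_one_channel(u=[1.0, 0.0], delta=[1.0, 1.0],
                              A=[0.0], B=[1.0], C=[1.0], D=2.0)
print(out, out_z)
```

A loop like this is far too slow for real workloads, but it can serve as a ground truth when comparing small inputs against the kernel output.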
import torch
import selective_scan
# torch.autocast(device_type='cuda') is the current spelling of
# torch.cuda.amp.autocast().
with torch.autocast(device_type='cuda'):
    u = torch.randn(2, 16, 128, device='cuda', dtype=torch.float16)
    delta = torch.randn(2, 16, 128, device='cuda', dtype=torch.float16)
    A = torch.randn(16, 16, device='cuda', dtype=torch.float32)
    B = torch.randn(16, 16, device='cuda', dtype=torch.float32)
    C = torch.randn(16, 16, device='cuda', dtype=torch.float32)
    out = selective_scan.selective_scan_fn(u, delta, A, B, C)

# Example: variable-mode B and C
batch, dim, seqlen, dstate = 2, 16, 128, 16
n_groups = 4
u = torch.randn(batch, dim, seqlen, device='cuda')
delta = torch.randn(batch, dim, seqlen, device='cuda')
A = torch.randn(dim, dstate, device='cuda')
B_var = torch.randn(batch, n_groups, dstate, seqlen, device='cuda')
C_var = torch.randn(batch, n_groups, dstate, seqlen, device='cuda')
out = selective_scan.selective_scan_fn(u, delta, A, B_var, C_var)

# Example: complex-valued A, B and C
u = torch.randn(2, 16, 128, device='cuda', dtype=torch.float32)
delta = torch.randn(2, 16, 128, device='cuda', dtype=torch.float32)
A = torch.randn(16, 16, device='cuda', dtype=torch.complex64)
B = torch.randn(16, 16, device='cuda', dtype=torch.complex64)
C = torch.randn(16, 16, device='cuda', dtype=torch.complex64)
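out = selective_scan.selective_scan_fn(u, delta, A, B, C)

Performance notes:
- Sequence length: the kernel automatically picks a thread configuration suited to the sequence length
- Batch size: a power-of-two batch size is recommended, since it helps GPU memory alignment
- State dimension: dstate <= 256; larger values overflow shared memory
- Precision:
  - FP32: highest precision, slower
  - FP16: balances precision and speed
  - BF16: good training stability, recommended for large models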
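The precision choice and the dstate limit can be turned into a quick back-of-the-envelope check. The sketch below estimates the bytes held by the main tensors of one call; `estimate_io_bytes` is a hypothetical helper, and it counts only u, delta, the output and fixed-mode float32 A, B and C, not workspace, saved activations or gradients.

```python
from typing import Dict

BYTES_PER_ELEMENT: Dict[str, int] = {"fp32": 4, "fp16": 2, "bf16": 2}
MAX_DSTATE = 256  # limit quoted in the notes above


def estimate_io_bytes(batch: int, dim: int, seqlen: int, dstate: int,
                      precision: str = "fp32") -> int:
    """Bytes held by u, delta and out at `precision`, plus fp32 A, B, C.

    Hypothetical helper, not part of selective_scan.
    """
    if dstate > MAX_DSTATE:
        raise ValueError(
            f"dstate={dstate} exceeds the supported maximum of {MAX_DSTATE}"
        )
    sequence_tensors = 3 * batch * dim * seqlen * BYTES_PER_ELEMENT[precision]
    parameter_tensors = 3 * dim * dstate * BYTES_PER_ELEMENT["fp32"]
    return sequence_tensors + parameter_tensors


for precision in ("fp32", "fp16", "bf16"):
    print(precision, estimate_io_bytes(2, 16, 128, 16, precision))
```

Halving the element size roughly halves the footprint of the sequence tensors, which dominate once seqlen is large.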
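Problem: No module named 'torch'
Solution: install PyTorch before building this project
pip install torch
python setup.py develop
Problem: CUDA version mismatch
Solution: make sure the CUDA version PyTorch was built with matches the system CUDA version
python -c "import torch; print(torch.version.cuda)"
nvcc --version
Problem: CUDA out of memory
Solution: reduce the batch size or the sequence length
Problem: Expected stride to be 1
Solution: make sure the input tensors are contiguous
u = u.contiguous()
delta = delta.contiguous()

selective_scan/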
├── __init__.py # Python API
├── setup.py # Build and install configuration
├── README.md # This document
├── .gitignore # Git ignore rules
└── core/ # Core CUDA implementation
    ├── selective_scan.cpp # PyTorch extension entry point
    ├── selective_scan.h # Parameter struct definitions
    ├── selective_scan_common.h # Shared utilities
    ├── selective_scan_fwd_kernel.cuh # Forward kernel
    ├── selective_scan_bwd_kernel.cuh # Backward kernel
    ├── selective_scan_fwd_*.cu # Forward instantiations
    ├── selective_scan_bwd_*.cu # Backward instantiations
    ├── reverse_scan.cuh # Reverse scan implementation
    ├── static_switch.h # Compile-time dispatch
    └── uninitialized_copy.cuh # Memory utilities
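Given this layout, a build script needs the .cpp entry point plus every forward and backward .cu instantiation file. The sketch below shows one way to gather them with glob patterns; `collect_extension_sources` is a hypothetical helper, the demo file names are made up, and the project's actual setup.py may list its sources differently.

```python
import tempfile
from pathlib import Path
from typing import List


def collect_extension_sources(core_dir: Path) -> List[str]:
    """Return the .cpp entry point plus every forward/backward .cu file.

    Hypothetical helper; headers (.h, .cuh) are not compiled directly and
    are therefore left out.
    """
    patterns = [
        "selective_scan.cpp",
        "selective_scan_fwd_*.cu",
        "selective_scan_bwd_*.cu",
    ]
    sources: List[str] = []
    for pattern in patterns:
        sources.extend(sorted(str(p) for p in core_dir.glob(pattern)))
    return sources


if __name__ == "__main__":
    # Demonstrate on a throwaway directory with made-up file names.
    with tempfile.TemporaryDirectory() as tmp:
        core = Path(tmp) / "core"
        core.mkdir()
        for name in ("selective_scan.cpp", "selective_scan_fwd_fp32.cu",
                     "selective_scan_bwd_fp32_real.cu", "static_switch.h"):
            (core / name).touch()
        print([Path(s).name for s in collect_extension_sources(core)])
```

A list like this would typically be passed as the `sources` argument of `torch.utils.cpp_extension.CUDAExtension` in setup.py.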

Citing the core algorithm: if you use the Selective Scan algorithm, please cite the original Mamba paper:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
Citing this extraction project (optional):
@software{selective_scan_extracted,
author = {biubushy},
title = {Selective Scan CUDA: Extracted from Mamba for Standalone Use},
year = {2025},
url = {https://github.com/biubushy/selective_scan},
note = {Extracted and packaged from the original Mamba implementation}
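}
Important notes:
- The core CUDA code in this project comes entirely from the original Mamba repository (Tri Dao © 2023)
- We only extracted and packaged the code; the algorithm implementation is unmodified
- All code follows the license of the original Mamba project
- Packaging and documentation by biubushy (2025-10)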
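Issues and pull requests are welcome!
Scope of contributions:
- ✅ Improvements to the packaging layer (setup.py, __init__.py)
- ✅ Documentation improvements and additional examples
- ✅ Install script optimizations
- ✅ Bug fixes and compatibility improvements
- ❌ Changes to the core CUDA algorithm (please submit those to the original Mamba repository)
Development workflow:
- Fork this repository
- Create a feature branch (git checkout -b feature/PackagingImprovement)
- Commit your changes (git commit -m 'Improve packaging')
- Push to the branch (git push origin feature/PackagingImprovement)
- Open a pull request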
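- v0.1.0 (2025-10) - Initial extracted release
  - Extracted the Selective Scan core code from the original Mamba repository
    - Source commit: d49d7c909b5a9d151b3f4e7e8477e7d20f156860 (2024-08-30)
    - Corresponds to Mamba v2.2.6
  - Reorganized as a standalone Python package
  - Preserved all original functionality: forward/backward passes, multiple precisions, real/complex support
  - Added a standard setup.py and __init__.py interface
  - Wrote complete usage documentation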
- GitHub: @biubushy
- Project homepage: https://github.com/biubushy/selective_scan
- Issues: https://github.com/biubushy/selective_scan/issues
biubushy | 2025-10