Quick SSM provides an optimized implementation of the associative scan operation using Triton kernels for high performance. It includes a convenient PyTorch interface to the scan function for easy integration into existing SSM codebases, as well as a layer that makes it easy to add an SSM block to your model.
The base model used in this library is the baseline Gated SSM from Birdie.
- Paper: Birdie: Advancing State Space Models with Reward-Driven Objectives and Curricula
- Github: samblouir/birdie
This implementation is inspired by code and techniques found in:
- accelerated-scan by Volodymyr Kyrylov
- Mamba: The Hard Way by Sasha Rush
Relevant Papers:
- Gated State Spaces
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
- Block Scan Algorithm: Implements the parallel prefix sum (scan) algorithm over blocks, enabling efficient processing of long sequences (see the sketch after this list).
- Triton Kernels: The forward and backward passes are implemented as pure Triton kernels for GPU efficiency.
- PyTorch Integration with an Interface and Layer:
`quick_ssm.scan_interface.scan` and `quick_ssm.layers.SSM` are included for easy integration into existing Torch models or codebases.
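The update h(t) = a(t) * h(t-1) + b(t) * x(t) can be viewed as composing affine maps h -> a * h + s, and that composition is associative, which is what lets the sequence be processed as a parallel prefix scan over blocks. The following is a minimal, illustrative PyTorch sketch of that combine step; it is not the library's Triton kernel:

```python
import torch

def combine(left, right):
    """Compose two affine updates h -> a * h + s.

    `left` covers an earlier span of the sequence, `right` the next one;
    the result covers both spans at once.
    """
    a1, s1 = left
    a2, s2 = right
    return a1 * a2, a2 * s1 + s2

# Toy check: folding the per-step pairs (a(t), b(t) * x(t)) with `combine`
# reproduces the sequential recurrence (with h(0) = 0).
B, L, D = 2, 8, 4
a = torch.rand(B, L, D) * 0.1 + 0.9
bx = torch.randn(B, L, D) * 0.1

h = torch.zeros(B, D)
for t in range(L):
    h = a[:, t] * h + bx[:, t]

acc = (a[:, 0], bx[:, 0])
for t in range(1, L):
    acc = combine(acc, (a[:, t], bx[:, t]))

assert torch.allclose(h, acc[1], atol=1e-5)
```

The Triton kernels apply the same combine in a block-parallel fashion rather than the sequential fold shown here.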
You can install the package directly from GitHub:
pip install git+https://github.com/samblouir/quick_ssm

Alternatively, install it in editable mode for development:
git clone https://github.com/samblouir/quick_ssm
cd quick_ssm
pip install -e .

Note: Quick SSM requires a very recent PyTorch version with CUDA and Triton 3.1+ (which typically requires an NVIDIA GPU). A quick environment check is shown below the list of tested setups.
Quick SSM has been tested on the following NVIDIA setups:
- PyTorch 2.7, 1x NVIDIA 3090 with CUDA 12.4 and 12.6
- PyTorch 2.7, 4x NVIDIA A100 with CUDA 12.4
- PyTorch 2.7, 4x NVIDIA H100 with CUDA 12.4
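Before installing, it can help to confirm that a compatible PyTorch, CUDA, and Triton stack is present. A quick check along these lines (not part of the library) prints the relevant versions:

```python
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

try:
    import triton
    print("Triton:", triton.__version__)
except ImportError:
    print("Triton is not installed")
```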
Here's a basic example using the scan_interface:
import torch
from quick_ssm.scan_interface import scan
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Example dimensions (Batch, Sequence Length, Hidden Dimension)
# Note: Sequence length L must be a power of 2
B, L, D = 4, 2048, 16
x = torch.randn(B, L, D, device=device, dtype=dtype, requires_grad=True)
a = torch.rand(B, L, D, device=device, dtype=dtype, requires_grad=True) * 0.1 + 0.9
b = torch.randn(B, L, D, device=device, dtype=dtype, requires_grad=True) * 0.1
c = torch.sigmoid(torch.randn(B, L, D, device=device, dtype=dtype))
# The scan computes:
# h(t) = a(t) * h(t-1) + b(t) * x(t)
# y(t) = c(t) * h(t)
# checkpoint=True saves VRAM by recomputing h during backward.
# tile_b / tile_d let you chunk batch or feature dims to fit in small VRAM.
y = scan(x, a, b, c, block_l=256, checkpoint=True, tile_b=1, tile_d=512)

Here's a basic example using the SSM layer:

import torch
import torch.nn as nn
from quick_ssm.layers import SSM
# Basic Torch Model
class AnyTorchModel(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ssm = SSM(
            hidden_size=hidden_size,
            state_size_mult=(hidden_size * 4),
            dtype=torch.float32,          # Parameter dtype
            compute_dtype=torch.float16,  # Computation dtype
        )

    def forward(self, x):
        return self.ssm(x)

The repo includes copy_train.py, a minimal script that trains an SSM to copy a random string of a-zA-Z0-9 characters (not next-token prediction). It's a fast sanity check that the recurrence can store and decode information.
python copy_train.py --steps 200 --seq-len 64 --batch-size 16
# add --cpu to force CPU; seq-len must be a power of 2

Tune VRAM via tile_b, tile_d, and block_l inside the script or in quick_ssm.layers.SSM.
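For example, if the defaults do not fit on your GPU, smaller tiles and gradient checkpointing can be passed straight to scan. The shapes and tile sizes below are only illustrative starting points:

```python
import torch
from quick_ssm.scan_interface import scan

B, L, D = 2, 4096, 1024  # illustrative shapes; L must be a power of 2
dev, dt = "cuda", torch.float16

x = torch.randn(B, L, D, device=dev, dtype=dt, requires_grad=True)
a = torch.rand(B, L, D, device=dev, dtype=dt, requires_grad=True) * 0.1 + 0.9
b = torch.randn(B, L, D, device=dev, dtype=dt, requires_grad=True) * 0.1
c = torch.sigmoid(torch.randn(B, L, D, device=dev, dtype=dt))

# Smaller tiles and gradient checkpointing trade speed for lower peak VRAM.
y = scan(x, a, b, c, block_l=128, checkpoint=True, tile_b=1, tile_d=128)
```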
The default configuration is a 12-layer stack with hidden size 256, state size 1024, and batch size 32, trained with a 500-step warmup followed by a cosine learning-rate schedule and teacher forcing. fp32 converges fastest; fp16 and bf16 follow closely.
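The 500-step warmup followed by cosine decay can be reproduced with standard PyTorch schedulers. The sketch below is a minimal illustration; the optimizer choice, base learning rate, and total step count are assumptions rather than values taken from copy_train.py:

```python
import math
import torch

model = torch.nn.Linear(256, 256)  # stand-in for the SSM stack
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # illustrative optimizer/LR

warmup_steps, total_steps = 500, 10_000  # 500-step warmup; total is illustrative

def lr_lambda(step):
    # Linear warmup, then cosine decay to zero.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Call scheduler.step() once per optimizer step during training.
```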
This library efficiently computes the following core SSM recurrence relation:
- Hidden State Update: h(t) = a(t) * h(t-1) + b(t) * x(t)
- Output Calculation: y(t) = c(t) * h(t)
Where:
- x(t): Input sequence tensor at time t.
- h(t): Hidden state tensor at time t.
- a(t): State transition factor.
- b(t): Input gate/projection factor.
- c(t): Output gate/projection factor (sometimes called a side gate).
- y(t): Output sequence tensor at time t.
All tensors are of shape (B, L, D), where:
- B: Batch size
- L: Sequence length (must be a power of 2)
- D: Hidden dimension
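For reference, the recurrence above can be written as a plain sequential loop in PyTorch. The sketch below is illustrative; the repository's own reference lives in naive_baseline.py and may differ in detail:

```python
import torch

def naive_scan(x, a, b, c):
    """Sequential reference: h(t) = a(t)*h(t-1) + b(t)*x(t), y(t) = c(t)*h(t)."""
    B, L, D = x.shape
    h = torch.zeros(B, D, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(L):
        h = a[:, t] * h + b[:, t] * x[:, t]
        ys.append(c[:, t] * h)
    return torch.stack(ys, dim=1)  # (B, L, D)

# Handy for spot-checking the Triton kernel on small shapes, e.g.:
# torch.testing.assert_close(naive_scan(x, a, b, c), scan(x, a, b, c))
```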
- example_interface.py: Minimal `scan` function usage example.
- example_layer.py: Minimal `SSM` layer example.
- src/: Contains the core library code.
- triton_scan.py: Triton kernels for the forward and backward scan passes.
- scan_interface.py: The main `torch.autograd.Function` interface (`scan`).
- layers.py: An example `nn.Module` (`SSM`) demonstrating usage.
- naive_baseline.py: Pure PyTorch implementations of the scan for testing.
- test_forwards.py: Correctness tests for the forward pass.
- test_backwards.py: Correctness tests for the backward pass (gradient calculation).
- Reduce memory usage by avoiding unnecessary tensor materializations.
- Add tensor-parallel support for the scan.
- Add automatic padding to support non-power-of-2 sequence lengths.
- Verify torch.compile compatibility with distributed training.
- Complete Gradient Checkpointing support to reduce VRAM usage during training.
- Explore additional VRAM optimization strategies.
- Implement a fast inference/generation mode (e.g., for autoregressive sampling).
- Add support for Hawk, which was used in the original Birdie paper.
- Pipeline scan, splitting the sequence time-wise across devices.
Contributions are welcome! Please feel free to open an issue to report bugs or suggest features.
Apache 2.0


