Quick SSM: Efficient Triton-based Scan for RNNs/SSMs

Quick SSM provides an optimized implementation of the associative scan operation, built on Triton kernels for high performance. It includes a PyTorch interface to the scan function for easy integration into existing SSM codebases, as well as a layer that makes it simple to drop an SSM into your model.

The base model used in this library is the baseline Gated SSM from Birdie.

This implementation is inspired by code and techniques found in:

Relevant Papers:





Key Features

  • Block Scan Algorithm: Implements a parallel prefix sum (scan) over blocks, enabling efficient processing of very long sequences (see the sketch after this list).
  • Triton Kernels: The forward and backward scan passes are implemented as pure Triton kernels for efficiency.
  • PyTorch Integration (Interface and Layer): quick_ssm.scan_interface.scan and quick_ssm.layers.SSM are included for easy integration into existing PyTorch models and codebases.
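
The block scan exploits the fact that the recurrence h(t) = a(t) * h(t-1) + b(t) * x(t) is associative: each step can be represented as a pair (A, B) meaning h -> A * h + B, and two consecutive steps compose into a single pair. Below is a minimal sketch of that combine operator (illustrative only; this is not the library's kernel, and the names are made up for this example):

import torch

def combine(step1, step2):
    # step1 is applied first, step2 second; each is (A, B) meaning h -> A * h + B.
    A1, B1 = step1
    A2, B2 = step2
    return (A2 * A1, A2 * B1 + B2)

# Folding the combine operator reproduces the sequential recurrence.
T = 8
a = torch.rand(T) * 0.1 + 0.9
u = torch.randn(T)                 # u(t) = b(t) * x(t)

h = torch.tensor(0.0)
for t in range(T):                 # sequential recurrence
    h = a[t] * h + u[t]

acc = (torch.tensor(1.0), torch.tensor(0.0))   # identity: h -> h
for t in range(T):
    acc = combine(acc, (a[t], u[t]))
A, B = acc
assert torch.allclose(h, A * 0.0 + B)          # h(0) = 0, so h(T) = B

Because combine is associative, this fold can be regrouped into per-block scans plus a cross-block combine, which is what the Triton kernels parallelize.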

Installation

You can install the package directly from GitHub:

pip install git+https://github.com/samblouir/quick_ssm

Alternatively, install it in editable mode for development:

git clone https://github.com/samblouir/quick_ssm
cd quick_ssm
pip install -e .

Note: Quick SSM requires a recent PyTorch build with CUDA support and Triton 3.1+ (which in practice requires an NVIDIA GPU).

Quick SSM has been tested on the following NVIDIA setups:

  • PyTorch 2.7, 1x NVIDIA RTX 3090 with CUDA 12.4 and 12.6
  • PyTorch 2.7, 4x NVIDIA A100 with CUDA 12.4
  • PyTorch 2.7, 4x NVIDIA H100 with CUDA 12.4

Usage

Here's a basic example using the scan_interface:

import torch
from quick_ssm.scan_interface import scan

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Example dimensions (Batch, Sequence Length, Hidden Dimension)
# Note: Sequence length L must be a power of 2
B, L, D = 4, 2048, 16

x = torch.randn(B, L, D, device=device, dtype=dtype, requires_grad=True)
a = torch.rand(B, L, D, device=device, dtype=dtype, requires_grad=True) * 0.1 + 0.9
b = torch.randn(B, L, D, device=device, dtype=dtype, requires_grad=True) * 0.1
c = torch.sigmoid(torch.randn(B, L, D, device=device, dtype=dtype))

# The scan computes:
#   h(t) = a(t) * h(t-1) + b(t) * x(t)
#   y(t) = c(t) * h(t)
# checkpoint=True saves VRAM by recomputing h during backward.
# tile_b / tile_d let you chunk batch or feature dims to fit in small VRAM.
y = scan(x, a, b, c, block_l=256, checkpoint=True, tile_b=1, tile_d=512)
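
Since scan is exposed through a torch.autograd.Function, gradients flow back to the inputs like any other op. A minimal follow-up sketch (the loss here is arbitrary and only for illustration):

# Backpropagate through the scan; x is a leaf tensor, so x.grad is populated.
loss = y.float().mean()
loss.backward()
print(x.grad.shape)  # (B, L, D)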

Layer Example

import torch
import torch.nn as nn
from quick_ssm.layers import SSM

# Basic Torch Model
class AnyTorchModel(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ssm = SSM(
            hidden_size=hidden_size,
            state_size_mult=(hidden_size * 4),
            dtype=torch.float32,          # Parameter dtype
            compute_dtype=torch.float16,  # Computation dtype
        )

    def forward(self, x):
        return self.ssm(x)
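
A usage sketch for reference (this assumes the layer maps a (B, L, D) input, with D equal to hidden_size, to an output of the same shape, and that, as with the raw scan, the sequence length is a power of 2):

# Hypothetical sizes chosen for illustration.
model = AnyTorchModel(hidden_size=256).to("cuda")
x = torch.randn(4, 1024, 256, device="cuda")   # (B, L, D); L is a power of 2
y = model(x)                                   # same shape as x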

Copy-Task Toy Training Run

The repo includes copy_train.py, a minimal script that trains an SSM to copy a random string of alphanumeric characters (a-z, A-Z, 0-9) rather than doing next-token prediction. It is a fast sanity check that the recurrence can store and decode information.

python copy_train.py --steps 200 --seq-len 64 --batch-size 16
# add --cpu to force CPU; seq_len must be a power of 2

Tune VRAM usage via the tile_b, tile_d, and block_l settings inside the script or in quick_ssm.layers.SSM.

Precision Benchmark (teacher-forced copy, seq=8)

Setup: 12-layer stack, hidden size 256, state size 1024, batch size 32, 500 warmup steps with a cosine LR schedule, teacher forcing. fp32 converges fastest; fp16 and bf16 follow closely.

[Figures: loss vs. step and accuracy vs. step]

Core Concept: SSM Scan

This library efficiently computes the following core SSM recurrence relation:

  1. Hidden State Update: h(t) = a(t) * h(t-1) + b(t) * x(t)
  2. Output Calculation: y(t) = c(t) * h(t)

Where:

  • x(t): Input sequence tensor at time t.
  • h(t): Hidden state tensor at time t.
  • a(t): State transition factor.
  • b(t): Input gate/projection factor.
  • c(t): Output gate/projection factor (sometimes called a side gate).
  • y(t): Output sequence tensor at time t.

All tensors are of shape (B, L, D), where:

  • B: Batch size
  • L: Sequence length (must be a power of 2)
  • D: Hidden dimension
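
For reference, here is a sequential version of this recurrence in plain PyTorch (a sketch only, assuming a zero initial state; the repo's naive_baseline.py serves the same purpose and is what the tests compare against):

import torch

def sequential_scan(x, a, b, c):
    # h(t) = a(t) * h(t-1) + b(t) * x(t);  y(t) = c(t) * h(t), with h(0) = 0.
    B, L, D = x.shape
    h = torch.zeros(B, D, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(L):
        h = a[:, t] * h + b[:, t] * x[:, t]
        ys.append(c[:, t] * h)
    return torch.stack(ys, dim=1)   # (B, L, D)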

Repository Structure

  • example_interface.py: Minimal scan function usage example.
  • example_layer.py: Minimal SSM layer example.
  • src/: Contains the core library code.
    • triton_scan.py: Triton kernels for the forward and backward scan passes.
    • scan_interface.py: The main torch.autograd.Function interface (scan).
    • layers.py: An example nn.Module (SSM) demonstrating usage.
    • naive_baseline.py: Pure PyTorch implementations of the scan for testing.
    • test_forwards.py: Correctness tests for the forward pass.
    • test_backwards.py: Correctness tests for the backward pass (gradient calculation).

TODO / Future Work

  • Reduce unnecessary memory usage by avoiding additional materializations.
  • Add tensor-parallel support for the scan.
  • Add automatic padding to support non-power-of-2 sequence lengths.
  • Verify torch.compile compatibility with distributed training.
  • Complete Gradient Checkpointing support to reduce VRAM usage during training.
  • Explore additional VRAM optimization strategies.
  • Implement a fast inference/generation mode (e.g., for autoregressive sampling).
  • Add support for Hawk, which was used in the original Birdie paper.

Not currently planned, but possible

  • Pipeline scan, splitting the sequence time-wise across devices.

Contributing

Contributions are welcome! Please feel free to open an issue to report bugs or suggest features.

License

Apache 2.0
