Skip to content

klae01/ma2mba

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Ma2mba

Mamba

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752

Mamba-2

Transformers are SSMs: Generalized Models and Efficient Algorithms
Through Structured State Space Duality
Tri Dao*, Albert Gu*
Paper: https://arxiv.org/abs/2405.21060

Ma2mba

Look Every Frame All at Once: Video-Ma2mba for Efficient Long-form Video Understanding
with Multi-Axis Gradient Checkpointing
Hosu Lee*, Junho Kim*, Hyunjun Kim, Yong Man Ro
Paper: https://arxiv.org/abs/2411.19460

About

Ma2mba optimizes memory usage by leveraging multi-axis gradient checkpointing, ensuring that memory scales independently of the number of layers. This advancement makes Ma2mba a powerful extension to Mamba, particularly for long-sequence tasks such as language modeling, long video processing, and genomic sequence analysis.

Mamba itself is a state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.

Installation

Due to partial CUDA compilation requirements, you need to install the original mamba_ssm package and include this repository (ma2mba) at the front of your Python path:

  1. Clone this repository:

    git clone https://github.com/klae01/ma2mba.git
  2. Install the original packages from PyPI:

    pip install mamba-ssm causal-conv1d
  3. Add ma2mba to the front of your PYTHONPATH:

    export PYTHONPATH=/path/to/ma2mba:$PYTHONPATH

    (Replace /path/to/ma2mba with your local repository path.)

    Alternatively, in Python scripts:

    import sys
    sys.path.insert(0, "/path/to/ma2mba")

Usage

We expose several levels of interface with the Mamba model.

Selective SSM

Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).

Source: ops/selective_scan_interface.py.

Mamba Block

The main module of this repository is the Mamba architecture block wrapping the selective SSM.

Source: modules/mamba_simple.py.

Usage:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

Mamba-2

The Mamba-2 block is implemented at modules/mamba2.py.

A simpler version is at modules/mamba2_simple.py

The usage is similar to Mamba(-1):

from mamba_ssm import Mamba2
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

SSD

A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions is at modules/ssd_minimal.py.

Mamba Language Model with Gradient Checkpointing

The Mamba Language Model demonstrates how to integrate a deep sequence model backbone (repeating Mamba blocks) with a language model head. To optimize memory usage for long sequences, Multi-Axis Gradient Checkpointing is supported, ensuring efficient training without modifying the original architecture.

Source: models/mixer_seq_simple.py.

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load the pretrained model
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba2-2.7b", device="cuda", dtype=torch.bfloat16)

# Enable multi-axis gradient checkpointing
model.gradient_checkpointing_enable()

vocab_size = model.config.vocab_size
sequence_length = 65536
batch_size = 2

input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length), device="cuda")
output = model(input_ids=input_ids, num_last_tokens=1)

assert output.logits.shape == (batch_size, 1, vocab_size)

Text Generation

The MambaLMHeadModel supports token-by-token generation using the step function, allowing efficient text generation for long sequences. Below is a minimal example using a pretrained Mamba-2 model:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba2-2.7b", device="cuda", dtype=torch.bfloat16)
prompt_text = "Artificial intelligence has made significant strides in the past decade."
input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to("cuda")

generated = []
max_length = 100
temperature = 0.7

with torch.no_grad():
    hidden_states, conv_states, ssm_states = model.backbone.step(input_ids=input_ids)

for _ in range(max_length):
    logits = model.lm_head(hidden_states[:, -1:, :])
    probs = F.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs.squeeze(1), num_samples=1)
    generated.append(next_token)
    with torch.no_grad():
        hidden_states, conv_states, ssm_states = model.backbone.step(
            input_ids=next_token, conv_states=conv_states, ssm_states=ssm_states
        )

generated_tokens = torch.cat(generated, dim=-1)
generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(generated_text)

Pretrained Models

Pretrained models are uploaded to Hugging Face: mamba-130m, mamba-370m, mamba-790m, mamba-1.4b, mamba-2.8b, mamba2-130m, mamba2-370m, mamba2-780m, mamba2-1.3b, mamba2-2.7b, transformerpp-2.7b, mamba2attn-2.7b, trained on 300B tokens on the Pile, as well as mamba-2.8b-slimpj (trained on 600B tokens on the SlimPajama dataset).

The models will be autodownloaded by the generation script below.

These models were trained on the Pile, and follow the standard model dimensions described by GPT-3 and followed by many open source models:

Parameters Layers Model dim.
130M 24 768
370M 48 1024
790M 48 1536
1.4B 48 2048
2.8B 64 2560

(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)

Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.

Evaluations

To run zero-shot evaluations of models (corresponding to Table 3 of the paper), we use the lm-evaluation-harness library.

  1. Install lm-evaluation-harness by pip install lm-eval==0.4.2.
  2. Run evaluation with (more documentation at the lm-evaluation-harness repo):
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64

To reproduce the results on the mamba-2.8b-slimpj model reported in the blogposts:

lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256

To run evaluations on Mamba-2 models, simply replace the model names:

lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256

Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.

Inference

The script benchmarks/benchmark_generation_mamba_simple.py

  1. autoloads a model from the Hugging Face Hub,
  2. generates completions of a user-specified prompt,
  3. benchmarks the inference speed of this generation.

Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.

Examples

To test generation latency (e.g. batch size = 1) with different sampling strategies:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2

To test generation throughput with random prompts (e.g. large batch size):

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64

With Mamba-2, you just need to change the model name:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}

@inproceedings{mamba2,
  title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
  author={Dao, Tri and Gu, Albert},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2024}
}

@article{video_ma2mba,
  title={Look Every Frame All at Once: Video-Ma$^2$mba for Efficient Long-form Video Understanding with Multi-Axis Gradient Checkpointing},
  author={Lee, Hosu and Kim, Junho and Kim, Hyunjun and Ro, Yong Man},
  journal={arXiv preprint arXiv:2411.19460},
  year={2024}
}

About

Mamba SSM architecture with Multi-Axis Gradient Checkpointing

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 81.4%
  • Cuda 13.0%
  • C++ 5.2%
  • C 0.4%