
[draft, low prio] infra: Fix flash-attn#55

Closed
tomtseng wants to merge 3 commits into main from tomtseng/flash-attn

Conversation


@tomtseng tomtseng commented Jan 16, 2026

Changes

Our implementation of GCG uses flash attention. PR #53 added a flash attention package, but it's the wrong one: we want flash-attn, not flash_attention. However, flash-attn is tricky to install because it has to be compiled, so the Docker image needs to be updated.

It's not clear we care about GCG, and we can probably turn off flash attention in GCG, so this PR is low priority. I'll put it aside for now.

TODO: I'm trying to build the Dockerfile, but compilation is currently OOMing locally, even when I max out the Docker Desktop memory cap setting and set MAX_JOBS=1. Maybe compile on flamingo instead?
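For reference, flash-attn's README documents capping compile parallelism through the MAX_JOBS environment variable, which its setup.py reads. The basic low-memory install invocation looks like this (a sketch; it still needs nvcc and a matching torch/CUDA toolchain to succeed):

```shell
# MAX_JOBS caps parallel nvcc jobs to reduce peak memory during compilation.
# --no-build-isolation makes pip build against the torch already installed,
# rather than pulling a fresh one into an isolated build environment.
MAX_JOBS=1 pip install flash-attn --no-build-isolation
```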

Testing

TODO: build the image, launch a new devbox, and get this script to work:

"""Test script to verify flash-attn works with transformers."""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

print(f"\nLoading {MODEL_ID} with Flash Attention 2...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

print("Model loaded successfully!")
print(f"Attention implementation: {model.config._attn_implementation}")

# Quick inference test
input_text = "Hello, world!"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=10)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("\nInference test:")
print(f"  Input: {input_text}")
print(f"  Output: {result}")
print("\nFlash Attention 2 is working!")

@tomtseng tomtseng force-pushed the tomtseng/flash-attn branch 2 times, most recently from 1a6141a to b3d34b4 on January 16, 2026 06:35
@tomtseng tomtseng mentioned this pull request Jan 16, 2026
@@ -1,5 +1,6 @@
ARG PYTORCH_CUDA_VERSION=2.0.1-cuda11.7-cudnn8
FROM pytorch/pytorch:${PYTORCH_CUDA_VERSION}-runtime
ARG PYTORCH_CUDA_VERSION=2.9.0-cuda12.8-cudnn9
Needed to bump the CUDA version, or else flash-attn would hit a compile error. Decided to bump the PyTorch version to the latest compatible version too.

@tomtseng tomtseng changed the title infra: Fix flash-attn [draft, low prio] infra: Fix flash-attn Jan 16, 2026
@tomtseng tomtseng changed the base branch from sh/arxiv_techdebt to main January 23, 2026 00:08
- Alphabetized all dependencies (case-insensitive)
- Changed `flash-attention>=1.0.0` to `flash-attn>=2.8.3; platform_system == 'Linux'`
- Changed `platform_system != 'Darwin'` to `platform_system == 'Linux'` for vllm and bitsandbytes
- Docker image needs to be the `devel` image, since flash-attn needs nvcc
- The PyTorch image version needs to match the PyTorch library version, and the CUDA versions also need to match (12.8). The old torch 2.6.0 doesn't support CUDA 12.8, so bumped torch to the latest compatible version (2.9.0; 2.9.1 doesn't work with the current version of vllm) and then updated the image version
- Fiddling with the Dockerfile
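Taken together, the Dockerfile changes described above might look roughly like this (a sketch, not the actual diff; the version pins mirror the commit messages):

```dockerfile
# flash-attn needs nvcc at build time, so use the -devel image, not -runtime.
# The torch/CUDA versions baked into the image must match the torch the
# project installs (torch 2.9.0, CUDA 12.8).
ARG PYTORCH_CUDA_VERSION=2.9.0-cuda12.8-cudnn9
FROM pytorch/pytorch:${PYTORCH_CUDA_VERSION}-devel

# MAX_JOBS=1 limits parallel nvcc jobs to keep peak build memory down;
# --no-build-isolation lets the build see the torch already in the image.
RUN MAX_JOBS=1 pip install --no-build-isolation "flash-attn>=2.8.3"
```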
tomtseng commented Feb 7, 2026

The GCG refactor in #82 removed the reference to flash-attn, so this PR is no longer relevant. Though I may take some of the Dockerfile refactors and open a separate PR for those.

@tomtseng tomtseng closed this Feb 7, 2026
@tomtseng tomtseng deleted the tomtseng/flash-attn branch February 7, 2026 09:17
