flash-attention

This is an implementation of FlashAttention-2 in raw CUDA, without any libraries (e.g., Cutlass). It started as a fork of flash-attention-minimal, and I wanted to see how far I could take it (see the benchmark below).

The code is organized as multiple kernels, each representing a separate optimization step, so it might serve as an educational resource (see Implemented Kernels below).

Benchmark

This benchmark uses the transformer parameters of GPT-3 and was measured on a single A100 GPU.

[Figure: scalability benchmark]

Build Instructions

  • Install CUDA Toolkit (I used 12.8).
  • Install PyTorch (I used 2.6.0).
  • Clone flash-attention.
git clone https://github.com/ishovkun/flash-attention.git
cd flash-attention
mkdir build && cd build
  • Determine the Compute Capability of your GPU:
COMPUTE_CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv | sed -n '2p' | sed 's/\.//g')
  • Determine the PyTorch path:
TORCH_PATH=$(python -c "import torch; print(torch.__path__[0])")
  • Configure the build:
cmake .. -DTorch_DIR=${TORCH_PATH}/share/cmake/Torch \
        -DCMAKE_CUDA_ARCHITECTURES=${COMPUTE_CAP} \
        -DCMAKE_BUILD_TYPE=Release
  • Compile the code:
make -j

Run Instructions

  • First, run the C++ binary to make sure the application built correctly; the runner executes the correctness tests.
./runner
  • Next, run the Python benchmark:
python ../benchmarks/bench.py

Implemented Kernels

  • flash_naive: This is the implementation from flash-attention-minimal. It does many things wrong, e.g., non-coalesced memory loads.
  • flash_2d: This kernel uses 2D grid blocks; the memory loads are coalesced.
  • warp_wmma: This kernel is taken from flash-attention-minimal. It uses CUDA's wmma tensor core instructions and launches single-warp CTAs.
  • block_wmma: This kernel builds on warp_wmma but uses multiple warps in a CTA.
  • scalar2d_row_tile: This kernel is still scalar, but unlike the previous versions it processes a tile of rows of an attention head rather than a full head per CTA.
  • wmma_row_block: This kernel is a block version of scalar2d_row_tile and uses wmma instructions.
  • kernel_mma: This kernel is a block version of scalar2d_row_tile and uses mma (inline assembly) instructions.
  • kernel_mma_swizzle: Same as the previous version but uses skewing (a form of swizzling) to improve shared memory access patterns.
  • kernel_mma_qreg: This kernel uses the algorithm from FlashAttention-2. It keeps the Q tile in registers, minimizing the number of thread synchronizations, and also uses skewing and vectorized stores to global memory.
  • kernel_mma_qreg_async: This kernel builds on the previous version but uses asynchronous memory copies for the K and V matrices.
