FlashFlow: A GPU-Optimized Transformer Inference Runtime with Fused CUDA/Triton Kernels, IR Fusion Passes, and FX-Based Graph Lowering
FlashFlow is an end-to-end optimized inference engine for GPT-style Transformers. It integrates:
- A compiler-style IR with graph-level optimizations
- Fusion passes for attention + MLP
- Fused CUDA and Triton kernels for high-throughput inference
- Quantization-aware training (INT8) and quantized runtime
- Autotuning for kernel tile sizes
- A clean PyTorch-to-FlashFlow export pipeline using FX
- Profiling support via Nsight Compute and Nsight Systems
- A fully working reference PyTorch GPTSmall model + training scripts
- Extensive unit tests for both C++ and Python components
FlashFlow serves as a mini TensorRT / Inductor specialized for GPT inference, built from scratch.
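To give a flavor of the export pipeline, here is a minimal sketch of the FX capture step that export_fx_graph.py builds on. TinyBlock is a hypothetical stand-in for the real blocks in python/models/gpt_small.py; the torch.fx calls are standard PyTorch:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyBlock(nn.Module):
    """Hypothetical stand-in block: LayerNorm -> Linear -> GELU -> Linear."""
    def __init__(self, dim: int = 64):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fc2(nn.functional.gelu(self.fc1(self.norm(x))))

traced = symbolic_trace(TinyBlock())
# Each FX node is a lowering candidate: fusion passes pattern-match runs of
# call_module / call_function nodes (e.g., Linear -> GELU -> Linear => fused MLP).
for node in traced.graph.nodes:
    print(node.op, node.target)
```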
Repository layout:

```
FLASHFLOW/
│
├── benchmarks/
│ ├── bench_kernel_micro.py
│ └── bench_throughput.py
│
├── cpp/
│ ├── include/flashflow/
│ │ ├── autotune.hpp
│ │ ├── graph.hpp
│ │ ├── ir.hpp
│ │ ├── kernels.hpp
│ │ ├── quantization.hpp
│ │ └── runtime.hpp
│ │
│ ├── kernels/
│ │ ├── cuda/
│ │ │ ├── fused_attention.cu
│ │ │ ├── fused_mlp.cu
│ │ │ ├── layernorm.cu
│ │ │ └── softmax.cu
│ │ └── triton/
│ │ ├── fused_attention_triton.py
│ │ └── fused_mlp_triton.py
│ │
│ ├── src/
│ │ ├── fusion_passes.cpp
│ │ ├── graph_lowering_fx.cpp
│ │ ├── ir.cpp
│ │ ├── kernel_registry.cpp
│ │ ├── logging.cpp
│ │ ├── memory_planner.cpp
│ │ ├── quant_runtime.cpp
│ │ └── runtime.cpp
│ │
│ └── CMakeLists.txt
│
├── python/
│ ├── eval/
│ │ ├── eval_latency.py
│ │ └── eval_perplexity.py
│ ├── export/
│ │ ├── export_checkpoint.py
│ │ └── export_fx_graph.py
│ ├── models/
│ │ ├── gpt_small.py
│ │ └── transformer_blocks.py
│ └── training/
│ │ ├── train_baseline.py
│ │ ├── train_mixed_precision.py
│ │ └── train_qat_int8.py
│
├── scripts/
│ ├── build_cpp.sh
│ ├── run_profiler_ncu.sh
│ └── run_profiler_nsys.sh
│
├── tests/
│ ├── cpp/
│ │ ├── test_fusion.cpp
│ │ ├── test_ir.cpp
│ │ └── test_runtime.cpp
│ └── python/
│ ├── test_export_fx.py
│ └── test_training_equivalence.py
│
└── README.md
```
FlashFlow requires:
- Python ≥ 3.9
- CUDA ≥ 11.7
- PyTorch ≥ 2.1
- A GPU with Compute Capability ≥ 7.0 (V100/A100/RTX30xx/RTX40xx)
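A quick sanity check for the GPU requirement (plain PyTorch calls, not a FlashFlow API):

```python
import torch

assert torch.cuda.is_available(), "FlashFlow needs a CUDA-capable GPU"
major, minor = torch.cuda.get_device_capability(0)
print(f"SM {major}.{minor}, CUDA {torch.version.cuda}, PyTorch {torch.__version__}")
assert (major, minor) >= (7, 0), "FlashFlow requires compute capability >= 7.0"
```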
Set up the environment:

```bash
conda create -n flashflow python=3.10 -y
conda activate flashflow
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```

Build the C++ runtime and kernels with the helper script:

```bash
./scripts/build_cpp.sh
```

Or build manually:

```bash
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . -j$(nproc)
```

Run the Python and C++ test suites:

```bash
pytest -q tests/python
./build/test_ir
./build/test_fusion
./build/test_runtime
```

Train the reference GPTSmall model (baseline FP32, mixed-precision, or INT8 QAT):

```bash
python python/training/train_baseline.py --data data/train.txt --save-dir checkpoints/baseline
python python/training/train_mixed_precision.py --data data/train.txt --save-dir checkpoints/amp --amp-dtype bf16
python python/training/train_qat_int8.py --data data/train.txt --save-dir checkpoints/qat
```

Export a trained checkpoint to the FlashFlow IR:

```bash
python python/export/export_fx_graph.py --model-name checkpoints/baseline/final_model.pt --out graph.json
python python/export/export_checkpoint.py --model checkpoints/baseline/final_model.pt --out-dir export/
```

Evaluate latency and perplexity:

```bash
python python/eval/eval_latency.py --engine flashflow_engine.pt
python python/eval/eval_perplexity.py --model checkpoints/baseline/final_model.pt
```

Profile kernels with Nsight Compute and full runs with Nsight Systems:

```bash
./scripts/run_profiler_ncu.sh ./build/bench_mlp --kernel-name fused_mlp
./scripts/run_profiler_nsys.sh ./build/bench_attention
```

FlashFlow is a full-stack optimized Transformer inference engine with real fused kernels, IR fusion, FX lowering, quantization, autotuning, and runtime execution.
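For a taste of the kernel style under cpp/kernels/triton/, here is a minimal row-wise softmax in Triton. This is a generic sketch of the fusion idea (max, exp, sum, and divide in one launch), not the repository's fused_attention_triton.py, which fuses the full attention pattern:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(out_ptr, in_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program instance per row; the row stays in registers, so the
    # max / exp / sum / divide steps are fused into a single kernel.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * n_cols + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)  # subtract row max for numerical stability
    num = tl.exp(x)
    den = tl.sum(num, axis=0)
    tl.store(out_ptr + row * n_cols + cols, num / den, mask=mask)

def softmax(x: torch.Tensor) -> torch.Tensor:
    # Assumes a contiguous 2-D CUDA tensor whose rows fit in one block.
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    softmax_kernel[(n_rows,)](out, x, n_cols,
                              BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return out

x = torch.randn(4, 1000, device="cuda")
torch.testing.assert_close(softmax(x), torch.softmax(x, dim=-1))
```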