[attention-int8] Multiple build issues and numerical accuracy problem #3
Description
Summary
I audited the attention-int8 kernel and found several issues preventing it from building and running correctly outside of a Codespaces environment. After fixing the build issues, the kernel runs but produces numerically inaccurate results.
Environment
- GPU: NVIDIA RTX 5070 (sm_120)
- CUDA: 12.9
- PyTorch: 2.9.1+cu128
- OS: Ubuntu (x86_64)
Issues Found
1. Makefile: hardcoded Codespaces path
The NIX_ENV variable in the Makefile points to /home/codespace/.nix-profile/etc/profile.d/nix.sh, which breaks on any non-Codespaces machine.
Fix:

```makefile
# Before
NIX_ENV = . /home/codespace/.nix-profile/etc/profile.d/nix.sh

# After
NIX_ENV = . $(HOME)/.nix-profile/etc/profile.d/nix.sh
```

2. launch_int8_attention — inline prevents linking + type mismatch
In attention_int8.cu, the function launch_int8_attention is declared inline, which prevents it from being visible to the linker when torch_binding.cpp declares it as extern.
Additionally, the timestep parameter is int64_t in the .cu file but int in the extern declaration in torch_binding.cpp, causing a symbol mismatch.
Fix in attention_int8.cu:

```cpp
// Before
inline cudaError_t launch_int8_attention(..., int64_t timestep, ...)

// After
cudaError_t launch_int8_attention(..., int timestep, ...)
```

3. static vs extern conflict in validation functions
Several functions in torch_binding.cpp are defined static but are also declared (non-static) in torch_binding.h, giving them conflicting linkage, which is a compilation error.
Fix: Remove static from validate_tensor, validate_shapes, validate_head_dim, validate_kv_constraint, validate_timestep_scales, and int8_attention_cuda in torch_binding.cpp.
4. example.py does not test the actual kernel
The example script creates a 1D tensor [1.0, 2.0, 3.0] and checks if the result is x + 1.0. This has nothing to do with an attention kernel that expects 4D tensors [B, H, N, D].
5. Numerical accuracy — results diverge from reference
After fixing all build issues, I compared the kernel output against torch.nn.functional.scaled_dot_product_attention:
```
Output shape: torch.Size([1, 8, 64, 64])
Mean difference: 0.168081
Max difference: 1.693359
```
For INT8 quantized attention, the expected mean error should be below ~0.01-0.02. A mean error of 0.168 suggests a problem in the quantization, the softmax, or the dequantization logic.
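As a rough sanity check on that threshold, the round-trip error of a single symmetric per-tensor INT8 quantization of standard-normal data is about scale/4, well under 0.01 (a sketch in NumPy; the kernel's actual quantization scheme is not known here, so per-tensor symmetric is an assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)

# Symmetric per-tensor INT8 round-trip
scale = np.abs(x).max() / 127.0
q = np.clip(np.round(x / scale), -127, 127)
xq = q * scale

# Rounding noise is uniform in [-scale/2, scale/2], so the mean
# absolute round-trip error is about scale/4 (~0.007 here)
print(f"scale: {scale:.5f}, mean abs error: {np.abs(xq - x).mean():.5f}")
```

A mean error of 0.168 is more than an order of magnitude above this quantization floor, pointing at a logic bug rather than inherent INT8 precision loss.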
Test script used

```python
import torch

torch.ops.load_library("./torch-ext/attention_int8/_attention_int8_cuda_dba582b_dirty.abi3.so")

Q = torch.randn(1, 8, 64, 64, dtype=torch.float16, device="cuda")
K = torch.randn(1, 8, 64, 64, dtype=torch.float16, device="cuda")
V = torch.randn(1, 8, 64, 64, dtype=torch.float16, device="cuda")

O = torch.ops.int8_attn.int8_attention_forward(Q, K, V)
ref = torch.nn.functional.scaled_dot_product_attention(Q, K, V)

diff_mean = (O.float() - ref.float()).abs().mean().item()
diff_max = (O.float() - ref.float()).abs().max().item()

print(f"Output shape: {O.shape}")
print(f"Mean difference: {diff_mean:.6f}")
print(f"Max difference: {diff_max:.6f}")
```

Suggestions
- Fix the build issues (1-3) so the kernel compiles out of the box
- Replace example.py with a real test that exercises the attention kernel
- Investigate the numerical accuracy issue — likely in the online softmax or the INT8 quantization/dequantization path
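To help localize the accuracy bug, a pure-NumPy INT8 attention reference shows what error the quantized path itself should produce (a sketch: symmetric per-tensor quantization, fp32 softmax, and single-head shapes are all assumptions, not the kernel's actual design):

```python
import numpy as np

def quantize_int8(x):
    # Symmetric per-tensor quantization (granularity is an assumption;
    # the kernel may quantize per-head or per-token instead)
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, D = 64, 64
Q = rng.standard_normal((N, D)).astype(np.float32)
K = rng.standard_normal((N, D)).astype(np.float32)
V = rng.standard_normal((N, D)).astype(np.float32)

# Float reference
ref = softmax(Q @ K.T / np.sqrt(D)) @ V

# INT8 path: integer QK^T, dequantize, fp32 softmax, dequantized V matmul
qQ, sQ = quantize_int8(Q)
qK, sK = quantize_int8(K)
qV, sV = quantize_int8(V)
scores = (qQ.astype(np.int32) @ qK.T.astype(np.int32)) * (sQ * sK) / np.sqrt(D)
out = (softmax(scores) @ qV.astype(np.float32)) * sV

print(f"mean abs diff: {np.abs(out - ref).mean():.6f}")
```

In this configuration the mean error lands orders of magnitude below the observed 0.168, so comparing the kernel's intermediate values (dequantized scores, softmax output) against this reference should narrow down which stage diverges.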