What's Changed
- Enable MLA in TE JAX Extension
- Allocate dkv expanded buffer according to max_tokens_kv
- Support aiter builds across multiple targets/dockerfiles
- Enable the JAX-side fused-attn pytests with sequence packing + SWA
- Update hipify_torch submodule to fix v2 mappings
- Ensure weight transpose is valid for FP8 training
- Add MXFP8 hipblasLt GEMM support
- Add normalization kernels for MXFP8
- Enable FP8 GEMM gelu_aux_bias
- Install 'ninja' via pip if it is not found
- Add transpose cache to LayerNorm kernel
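The ninja fallback in the list above can be sketched as a small build-time check; the function name and injectable parameters here are illustrative, not the project's actual build code:

```python
import shutil
import subprocess
import sys


def ensure_ninja(which=shutil.which, installer=None):
    """Return True if `ninja` is on PATH, installing it via pip when missing.

    `which` and `installer` are injectable for testing; by default the
    installer shells out to `pip install ninja` (a sketch of the behavior
    described in the changelog, not the real build script).
    """
    if which("ninja") is not None:
        return True
    if installer is None:
        installer = lambda: subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "ninja"]
        )
    installer()
    # Re-check so a failed install is reported rather than assumed to work.
    return which("ninja") is not None
```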
Fixes
- Fix accumulate int overflow in workspace memory calculation
- Fix dropout when using a new-style RNG
- Fix intra-seq padding detection in the CK Fused Attention backend
- Fix NCCL error in test_torch_fsdp2
- Fix: ensure ln_out is not cached if wgrads
- Fix: increase tolerance and use FP32 compute for the unpermute kernel
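The workspace int-overflow fix above is the classic pattern of multiplying tensor dimensions in 32-bit arithmetic. A minimal sketch of how the product wraps; the dimensions and helper are hypothetical, chosen so the byte count reaches exactly 2**31, and are not taken from the codebase:

```python
def mul_i32(a: int, b: int) -> int:
    """Multiply with 32-bit signed wraparound (simulates C `int` math)."""
    r = (a * b) & 0xFFFFFFFF
    return r - 2**32 if r >= 2**31 else r


# Hypothetical attention workspace dimensions (batch, sequence length,
# heads, head dim, bytes per element) whose product is 2**31 bytes.
batch, seqlen, heads, head_dim, elem_bytes = 32, 8192, 32, 128, 2

# Accumulating the product in 32-bit ints wraps past INT32_MAX to a
# negative size -- the kind of bug the fix guards against.
wrapped = 1
for dim in (batch, seqlen, heads, head_dim, elem_bytes):
    wrapped = mul_i32(wrapped, dim)

# The same arithmetic in 64-bit (or Python's arbitrary-precision ints)
# yields the correct workspace size.
correct = batch * seqlen * heads * head_dim * elem_bytes
```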
Upstream release notes: https://github.com/NVIDIA/TransformerEngine/releases/tag/v2.4
Full Changelog: v2.2_rocm...v2.4_rocm