feat: NEON-vectorized flash attention with hardware FP16 on AArch64 #24
Open
siddiquifaras wants to merge 1 commit into RightNow-AI:main from
Conversation
The flash attention inner loops were scalar with software FP16 conversion (~12 instructions per fp16_to_fp32 call). At 100 tokens of context, this costs ~117M scalar instructions per token across 32 heads x 22 layers, nearly half the cost of the NEON-optimized matmuls. Hardware vcvt_f32_f16 does 4 conversions in 1 instruction (vs 48 scalar), collapsing the attention overhead to <2% of matmul cost.

Changes:
- quant.h: add fp16x4_to_f32 / f32x4_to_fp16 helpers, guarded by PICOLM_NEON && __aarch64__. Software FP16 untouched as fallback.
- model.c: NEON paths for Q.K dot product, V accumulation (both branches), normalization, and KV cache FP16 writes. All under #ifdef PICOLM_FP16_HW with original scalar code in #else.
- tensor.c: NEON path for RoPE K heads (mirrors existing Q-head pattern, for code symmetry).

Precondition: head_dim and kv_dim must be multiples of 4. True for all LLaMA-architecture models (head_dim is always 64 or 128).

3 files changed, 62 insertions(+). Binary size unchanged (87736 bytes with -O3 -ffast-math).

Tested on Apple M4 Pro, TinyLlama 1.1B Q4_K_M, -t 0 greedy:
- -n 20: 23.9 -> 29.6 tok/s (+24% vs baseline, +11% vs flags-only)
- -n 100: 20.9 -> 27.2 tok/s (+30% vs baseline, +23% vs flags-only)

Output character-identical to baseline at all context lengths. --json mode and --cache round-trip verified.
What does this PR do?
NEON-vectorizes the flash attention inner loops and KV cache FP16 writes on AArch64, using hardware `vcvt_f32_f16`/`vcvt_f16_f32` instead of the software `fp16_to_fp32`/`fp32_to_fp16` functions. Also adds a NEON path for RoPE K heads to match the existing Q-head pattern. All scalar fallback code is preserved in `#else` blocks. x86, ARM32, RISC-V, and Windows builds are completely untouched.

Type of change
Why this matters
The flash attention inner loops call `fp16_to_fp32` once per dimension per position per head. That function is ~12 scalar instructions (bit shifts, branches, masking). At 100 tokens of context across 32 heads and 22 layers, this costs ~117M scalar instructions per token, nearly half the cost of the already-NEON-optimized matmuls.

The AArch64 `vcvt_f32_f16` instruction converts 4 FP16 values in a single cycle (vs ~48 scalar instructions). Combined with NEON FMA for the dot products and accumulations, the attention overhead drops from ~48% of matmul cost to <2%.

Changes
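For illustration, the `fp16x4_to_f32` helper described below could be sketched like this (hypothetical code, not the actual quant.h contents; `f32x4_to_fp16` would be the mirror image using `vcvt_f16_f32` and `vst1_u16`):

```c
#include <stdint.h>
#include <string.h>
#if defined(__aarch64__)
#include <arm_neon.h>
#endif

/* Portable half->float used by the scalar fallback (handles zero,
 * subnormal, normal, and inf/NaN encodings). */
static float half_to_float(uint16_t h) {
    uint32_t sign = (uint32_t)(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;
    uint32_t bits;
    if (exp == 0x1F)      bits = sign | 0x7F800000u | (mant << 13);          /* inf/NaN */
    else if (exp != 0)    bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    else if (mant == 0)   bits = sign;                                       /* +-0 */
    else {                                                                   /* subnormal */
        int shift = 0;
        while (!(mant & 0x400)) { mant <<= 1; shift++; }
        bits = sign | ((uint32_t)(113 - shift) << 23) | ((mant & 0x3FF) << 13);
    }
    float f; memcpy(&f, &bits, sizeof f);
    return f;
}

/* fp16x4_to_f32: 4 packed FP16 values -> 4 floats. On AArch64 this is a
 * single vcvt_f32_f16; elsewhere it falls back to the scalar routine. */
static void fp16x4_to_f32(const uint16_t *src, float *dst) {
#if defined(__aarch64__)
    vst1q_f32(dst, vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(src))));
#else
    for (int i = 0; i < 4; i++) dst[i] = half_to_float(src[i]);
#endif
}
```

The hardware and fallback paths produce identical bit patterns for all finite inputs, which is consistent with the character-identical output reported in Testing.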
quant.h (+9 lines): Two inline helpers guarded by `PICOLM_NEON && __aarch64__`:

- `fp16x4_to_f32`: loads 4 uint16 FP16 values, converts to float32x4 via `vcvt_f32_f16`
- `f32x4_to_fp16`: converts float32x4 to FP16, stores as 4 uint16 via `vcvt_f16_f32`
- `PICOLM_FP16_HW` macro for use in model.c

model.c (+43 lines): Five `#ifdef PICOLM_FP16_HW` blocks with scalar code preserved in `#else`:

- Q.K dot product: `vmlaq_f32` accumulator + `vaddvq_f32` horizontal sum
- V accumulation: `vmlaq_f32(fp16_val, acc, correction)` and `vmlaq_f32(acc, fp16_val, weight)`
- Normalization: `vmulq_f32(acc, inv_sum)`
- KV cache writes: `f32x4_to_fp16` for FP32-to-FP16 conversion

tensor.c (+19 lines): NEON path for RoPE K heads, identical pattern to the existing Q-head NEON code using `vld2q_f32`/`vst2q_f32` interleaved loads with scalar tail. Included for code symmetry (4 KV heads, negligible performance impact).

Precondition: `head_dim` and `kv_dim` must be multiples of 4. True for all LLaMA-architecture models (head_dim is always 64 or 128).

Testing
Test command:
Output:
Output is character-identical to the scalar baseline at all context lengths tested (-n 20, -n 100, -n 256).
Results (with -O3 -ffast-math from #16)
The context-dependent slowdown is effectively eliminated. Generation at 100 tokens of context (27.2 tok/s) now exceeds the original speed at 20 tokens (23.9 tok/s).
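To make the measured speedup concrete, the vectorized Q.K inner loop pattern described under Changes can be sketched as follows (illustrative only; `qk_dot` is not the repo's actual code, and the fallback's half-to-float handles normal values only, for brevity):

```c
#include <stdint.h>
#include <string.h>
#if defined(__aarch64__)
#include <arm_neon.h>
#endif

/* Minimal half->float for the portable fallback (normals and zero only). */
static float half_to_float(uint16_t h) {
    uint32_t exp = (h >> 10) & 0x1F;
    uint32_t bits = ((uint32_t)(h >> 15) << 31)
                  | (exp ? ((exp + 112u) << 23) : 0)
                  | ((uint32_t)(h & 0x3FF) << 13);
    float f; memcpy(&f, &bits, sizeof f);
    return f;
}

/* Q.K dot product over one head, with keys stored as FP16 in the KV cache.
 * Requires head_dim % 4 == 0, as the PR's precondition states. */
static float qk_dot(const float *q, const uint16_t *k, int head_dim) {
#if defined(__aarch64__)
    float32x4_t acc = vdupq_n_f32(0.0f);
    for (int d = 0; d < head_dim; d += 4) {
        /* vcvt_f32_f16: 4 FP16->FP32 conversions in one instruction */
        float32x4_t kf = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(k + d)));
        acc = vmlaq_f32(acc, vld1q_f32(q + d), kf);  /* fused multiply-add */
    }
    return vaddvq_f32(acc);                          /* horizontal lane sum */
#else
    float s = 0.0f;
    for (int d = 0; d < head_dim; d++) s += q[d] * half_to_float(k[d]);
    return s;
#endif
}
```

The hot loop is 3 NEON instructions per 4 dimensions, versus roughly 50 scalar instructions for the same work on the fallback path, which is where the context-dependent overhead goes.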
Checklist
- `make native`
- `--json` mode
- `--cache` round-trip