feat: NEON-vectorized flash attention with hardware FP16 on AArch64#24

Open
siddiquifaras wants to merge 1 commit into RightNow-AI:main from siddiquifaras:feat/neon-flash-attention

Conversation

@siddiquifaras
What does this PR do?

NEON-vectorizes the flash attention inner loops and KV cache FP16 writes on AArch64, using hardware vcvt_f32_f16 / vcvt_f16_f32 instead of the software fp16_to_fp32 / fp32_to_fp16 functions. Also adds a NEON path for RoPE K heads to match the existing Q-head pattern.

All scalar fallback code is preserved in #else blocks. x86, ARM32, RISC-V, and Windows builds are completely untouched.

Type of change

  • Performance improvement

Why this matters

The flash attention inner loops call fp16_to_fp32 once per dimension per position per head. That function is ~12 scalar instructions (bit shifts, branches, masking). At 100 tokens of context across 32 heads and 22 layers, this costs ~117M scalar instructions per token, nearly half the cost of the already-NEON-optimized matmuls.

The AArch64 vcvt_f32_f16 instruction converts 4 FP16 values in a single instruction (vs ~48 scalar instructions). Combined with NEON FMA for the dot products and accumulations, the attention overhead drops from ~48% of matmul cost to <2%.

Changes

quant.h (+9 lines): Two inline helpers guarded by PICOLM_NEON && __aarch64__:

  • fp16x4_to_f32: loads 4 uint16 FP16 values, converts to float32x4 via vcvt_f32_f16
  • f32x4_to_fp16: converts float32x4 to FP16, stores as 4 uint16 via vcvt_f16_f32
  • Defines PICOLM_FP16_HW macro for use in model.c
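A minimal sketch of what these two helpers might look like. The function names follow the PR description, but the bodies here are illustrative: the scalar fallbacks are simplified round-trip conversions that handle only normal values and zero (no subnormals, inf, or NaN), unlike the full software fp16_to_fp32 / fp32_to_fp16 in quant.h.

```c
#include <stdint.h>
#if defined(__ARM_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif

/* Simplified scalar FP16 -> FP32: normal values and zero only (illustrative). */
static float fp16_to_fp32_sketch(uint16_t h) {
    if ((h & 0x7FFF) == 0) return (h & 0x8000) ? -0.0f : 0.0f;
    uint32_t bits = ((uint32_t)(h & 0x8000) << 16)                 /* sign    */
                  | ((((uint32_t)(h >> 10) & 0x1F) + 112u) << 23)  /* exponent */
                  | ((uint32_t)(h & 0x3FF) << 13);                 /* mantissa */
    float f;
    __builtin_memcpy(&f, &bits, sizeof f);
    return f;
}

/* Simplified scalar FP32 -> FP16: truncating, normal values and zero only. */
static uint16_t fp32_to_fp16_sketch(float f) {
    uint32_t bits;
    __builtin_memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7FFFFFFF) == 0) return (uint16_t)(bits >> 16);
    uint32_t exp = (bits >> 23) & 0xFF;
    return (uint16_t)(((bits >> 16) & 0x8000)
                    | (((exp - 112u) & 0x1F) << 10)
                    | ((bits >> 13) & 0x3FF));
}

/* Load 4 packed FP16 values and convert to 4 floats; hardware vcvt on AArch64. */
static void fp16x4_to_f32(const uint16_t *src, float *dst) {
#if defined(__ARM_NEON) && defined(__aarch64__)
    float16x4_t h = vreinterpret_f16_u16(vld1_u16(src));
    vst1q_f32(dst, vcvt_f32_f16(h));       /* one instruction for 4 values */
#else
    for (int i = 0; i < 4; i++) dst[i] = fp16_to_fp32_sketch(src[i]);
#endif
}

/* Convert 4 floats to FP16 and store as 4 packed uint16 values. */
static void f32x4_to_fp16(const float *src, uint16_t *dst) {
#if defined(__ARM_NEON) && defined(__aarch64__)
    vst1_u16(dst, vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(src))));
#else
    for (int i = 0; i < 4; i++) dst[i] = fp32_to_fp16_sketch(src[i]);
#endif
}
```

Guarding both the fast path and the fallback behind the same two helper names is what lets model.c use a single #ifdef PICOLM_FP16_HW without duplicating call sites.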

model.c (+43 lines): Five #ifdef PICOLM_FP16_HW blocks with scalar code preserved in #else:

  • Q.K dot product: NEON vmlaq_f32 accumulator + vaddvq_f32 horizontal sum
  • V accumulation (new-max branch): vmlaq_f32(fp16_val, acc, correction)
  • V accumulation (normal branch): vmlaq_f32(acc, fp16_val, weight)
  • Normalization: vmulq_f32(acc, inv_sum)
  • KV cache writes (K and V): f32x4_to_fp16 for FP32-to-FP16 conversion
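The Q.K dot product pattern above can be sketched as follows. This is not the PR's code, just an illustration of the vmlaq_f32 accumulator plus vaddvq_f32 horizontal-sum idiom, with a scalar fallback (using a simplified FP16 conversion, normals only) so it runs on any target. It assumes head_dim is a multiple of 4, per the PR's precondition.

```c
#include <stdint.h>
#if defined(__ARM_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif

/* Simplified scalar FP16 -> FP32: normal values and zero only (illustrative). */
static float fp16_to_fp32_sketch(uint16_t h) {
    if ((h & 0x7FFF) == 0) return (h & 0x8000) ? -0.0f : 0.0f;
    uint32_t bits = ((uint32_t)(h & 0x8000) << 16)
                  | ((((uint32_t)(h >> 10) & 0x1F) + 112u) << 23)
                  | ((uint32_t)(h & 0x3FF) << 13);
    float f;
    __builtin_memcpy(&f, &bits, sizeof f);
    return f;
}

/* Dot product of an FP32 query head against an FP16 key row in the KV cache.
 * head_dim must be a multiple of 4. */
static float qk_dot(const float *q, const uint16_t *k_fp16, int head_dim) {
#if defined(__ARM_NEON) && defined(__aarch64__)
    float32x4_t acc = vdupq_n_f32(0.0f);
    for (int d = 0; d < head_dim; d += 4) {
        /* Hardware FP16 -> FP32 conversion, 4 lanes at once. */
        float32x4_t kv = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(k_fp16 + d)));
        acc = vmlaq_f32(acc, vld1q_f32(q + d), kv);  /* fused multiply-accumulate */
    }
    return vaddvq_f32(acc);                          /* horizontal sum of 4 lanes */
#else
    float s = 0.0f;
    for (int d = 0; d < head_dim; d++)
        s += q[d] * fp16_to_fp32_sketch(k_fp16[d]);
    return s;
#endif
}
```

Keeping the conversion and the FMA fused in one loop avoids a separate materialized FP32 copy of the key row, which is where the ~48% attention overhead came from.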

tensor.c (+19 lines): NEON path for RoPE K heads, identical pattern to the existing Q-head NEON code using vld2q_f32 / vst2q_f32 interleaved loads with scalar tail. Included for code symmetry (4 KV heads, negligible performance impact).
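The vld2q_f32 / vst2q_f32 interleaved-load pattern for RoPE can be sketched like this. The function name and layout here are illustrative (not the PR's code): it assumes rotation pairs are adjacent elements (x[2i], x[2i+1]) with one precomputed cos/sin angle per pair, and includes the scalar tail the PR mentions.

```c
#if defined(__ARM_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif

/* Apply RoPE rotation to a head: pair (x[2i], x[2i+1]) is rotated by the
 * angle whose cosine/sine are cos_t[i] / sin_t[i]. dim must be even. */
static void rope_rotate(float *x, const float *cos_t, const float *sin_t, int dim) {
    int i = 0;
#if defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 8 <= dim; i += 8) {
        float32x4x2_t v = vld2q_f32(x + i);     /* de-interleave even/odd lanes */
        float32x4_t c = vld1q_f32(cos_t + i / 2);
        float32x4_t s = vld1q_f32(sin_t + i / 2);
        float32x4_t re = vsubq_f32(vmulq_f32(v.val[0], c), vmulq_f32(v.val[1], s));
        float32x4_t im = vaddq_f32(vmulq_f32(v.val[0], s), vmulq_f32(v.val[1], c));
        v.val[0] = re;
        v.val[1] = im;
        vst2q_f32(x + i, v);                     /* re-interleave and store */
    }
#endif
    for (; i < dim; i += 2) {                    /* scalar tail / fallback */
        float x0 = x[i], x1 = x[i + 1];
        float c = cos_t[i / 2], s = sin_t[i / 2];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}
```

With only 4 KV heads the win here is small, as the PR notes; the value is that the K path now reads identically to the Q path.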

Precondition: head_dim and kv_dim must be multiples of 4. True for all LLaMA-architecture models (head_dim is always 64 or 128).

Testing

  • Tested on ARM64 (Apple M4 Pro, macOS)
  • Tested with TinyLlama 1.1B Q4_K_M

Test command:

make clean && make native
./picolm model.gguf -p "The capital of France is" -n 100 -t 0 -j 8

Output:

Paris.

2. B.C. The capital of ancient Rome was Rome.

3. A.D. The capital of ancient Rome was Rome.

4. P.R. The capital of ancient Rome was Rome.

5. A.D. The capital of ancient Rome was Rome.

6. P.R. The capital of ancient Rome was Rome.

7. A.D. The capital of ancient Rome was Rome.

8

Output is character-identical to the scalar baseline at all context lengths tested (-n 20, -n 100, -n 256).

Results (with -O3 -ffast-math from #16)

| Metric              | Baseline (-O2, scalar) | With NEON attention |
|---------------------|------------------------|---------------------|
| Binary size         | 87,784 bytes           | 87,736 bytes        |
| Generation (-n 20)  | 23.9 tok/s             | 29.6 tok/s (+24%)   |
| Generation (-n 100) | 20.9 tok/s             | 27.2 tok/s (+30%)   |

The context-dependent slowdown is effectively eliminated. Generation at 100 tokens of context (27.2 tok/s) now exceeds the original speed at 20 tokens (23.9 tok/s).

Checklist

  • Code compiles without warnings (make native)
  • No new dependencies added
  • Memory usage not increased (45.17 MB, unchanged)
  • Works with --json mode
  • Works with --cache round-trip

The flash attention inner loops were scalar with software FP16
conversion (~12 instructions per fp16_to_fp32 call). At 100 tokens
of context, this costs 117M scalar instructions per token across
32 heads x 22 layers, nearly half the cost of the NEON-optimized
matmuls.

Hardware vcvt_f32_f16 does 4 conversions in 1 instruction (vs 48
scalar), collapsing the attention overhead to <2% of matmul cost.

Changes:
- quant.h: add fp16x4_to_f32 / f32x4_to_fp16 helpers, guarded by
  PICOLM_NEON && __aarch64__. Software FP16 untouched as fallback.
- model.c: NEON paths for Q.K dot product, V accumulation (both
  branches), normalization, and KV cache FP16 writes. All under
  #ifdef PICOLM_FP16_HW with original scalar code in #else.
- tensor.c: NEON path for RoPE K heads (mirrors existing Q-head
  pattern, for code symmetry).

Precondition: head_dim and kv_dim must be multiples of 4. True for
all LLaMA-architecture models (head_dim is always 64 or 128).

3 files changed, 62 insertions(+).
Binary size unchanged (87736 bytes with -O3 -ffast-math).

Tested on Apple M4 Pro, TinyLlama 1.1B Q4_K_M, -t 0 greedy:
  -n 20:  23.9 -> 29.6 tok/s (+24% vs baseline, +11% vs flags-only)
  -n 100: 20.9 -> 27.2 tok/s (+30% vs baseline, +23% vs flags-only)
  Output character-identical to baseline at all context lengths.
  --json mode and --cache round-trip verified.