Implement mxfp4 split-k gemm #958

Open
willghatch wants to merge 3 commits into main from users/willghatch/splitk-mxfp4

Conversation

@willghatch
Contributor

@willghatch willghatch commented Feb 24, 2026

The core addition is a split-K GEMM; it is tested for (1) generation of the buffer_atomic_pk_add_bf16 instruction that we wanted to use, and (2) GEMM correctness.

Overview of changes unrelated to wave_asm:

  • remove_global_indexing in general_utils.py: Zeroes out tiling constraint starts (e.g. K_SPLIT_OFF) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).

  • Fixing spurious bounds on split-K tiling that prevented scale vector merging: TilingConstraint.get_index_bound was conservatively generating bounds for the split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile) * tile <= K. These bounds prevented merge_contiguous_reads from combining scalar scale reads into vector<4xi8> loads (it skips reads that already have bounds). Add _work_may_exceed_dim() to structurally detect the aligned split-K pattern and prove no overshoot, avoiding the spurious bound. (This was necessary to get scale_preshuffle to have 4x vector loads when combined with split-k.)
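The no-overshoot property behind _work_may_exceed_dim() can be sanity-checked numerically. A minimal sketch (the names `dim`, `tile`, and `overshoots` are illustrative, not from the wave code):

```python
import math

def overshoots(dim: int, tile: int, x: int) -> bool:
    # work_bound = tile * ceiling(min(dim, x) / tile); the split-K claim
    # is that this never exceeds dim when dim is tile-aligned.
    work = math.ceil(min(dim, x) / tile) * tile
    return work > dim

dim, tile = 128, 32  # tile-aligned: 128 % 32 == 0
assert all(not overshoots(dim, tile, x) for x in range(4 * dim))

# With a misaligned dim the bound really is needed:
assert overshoots(130, 32, 129)  # ceil(129/32) * 32 = 160 > 130
```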

@willghatch
Contributor Author

@harsh-nod This has splitk with preshuffle_scales functional with the 4x vector load. I've done some basic cleanup, but as mentioned there are still parts of it that I haven't fully reviewed or understood.

@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 2 times, most recently from 88b0c99 to 8de6506 on February 24, 2026 18:44
@willghatch
Contributor Author

@harsh-nod this is now rebased on top of main, which now has the wave_asm backend commit that you carved out of this one. So it should be ready to go.

Returns ``False`` (no overshoot) when we can prove that the tiled work
never exceeds the tensor dimension. In particular this handles the
split-K pattern where ``work_bound = tile * ceiling(Min(dim, f(wg)) / tile)``
and ``dim`` is tile-aligned: ``ceiling(Min(dim, x) / tile) * tile <= dim``.
Collaborator

Can you explain what is happening here?

Contributor Author

The logic is now tightened; it improves the analysis of whether a read might overshoot its bounds. By matching the tile pattern where reads are bounded by min(X, bound), we know the access stays within the bound, so we don't need to emit the bound guard rails. This allows the read merge to take effect, loading 4xi8 vectors instead of individual bytes.

@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from 6bbc407 to f02dc5b on February 25, 2026 21:10
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from f02dc5b to 65575f0 on February 25, 2026 23:28
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from 65575f0 to ca5f8e8 on February 25, 2026 23:31
)
for _p in [str(_EXAMPLES_DIR), str(_WAVE_ROOT), str(_E2E_DIR)]:
    if _p not in sys.path:
        sys.path.insert(0, _p)
Collaborator

Instead of this, could you modify the imports so we can do something like `import WaveASMCompiler, capture_wave_kernel_info`?

torch.cuda.synchronize()

bf16_eps = 2**-7
atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0)
Collaborator

Why does atol depend on num_splits? Shouldn't it be independent of num_splits?

Contributor Author

No, the number of splits increases the accumulated error: i.e. each partial result incurs error from the cast to BF16, and the partials are then accumulated in BF16, so we accumulate that error num_splits times.
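The effect can be demonstrated without the kernel by emulating the BF16 rounding of each partial sum. A hedged sketch, pure stdlib (the `to_bf16` helper is an illustrative emulation of bfloat16 rounding, not code from this PR):

```python
import struct

def to_bf16(x: float) -> float:
    """Round an fp32 value to bfloat16 precision (keep top 7 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x8000) & 0xFFFF0000  # round to nearest bf16, then truncate
    return struct.unpack("<f", struct.pack("<I", bits))[0]

k, num_splits = 4096, 8
vals = [1.0 + (i % 7) / 7.0 for i in range(k)]  # deterministic positive data
ref = sum(vals)  # high-precision reference

# Emulate split-K: each split's partial sum is cast to BF16 and the partials
# are then accumulated in BF16, as the atomic packed BF16 add does.
chunk = k // num_splits
acc = 0.0
for s in range(num_splits):
    partial = sum(vals[s * chunk : (s + 1) * chunk])
    acc = to_bf16(acc + to_bf16(partial))

bf16_eps = 2**-7
err = abs(acc - ref)
# Each BF16 cast/add can contribute on the order of eps * |ref|, so the
# worst case scales with num_splits -- hence atol must too.
assert err <= num_splits * bf16_eps * max(abs(ref), 1.0)
```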

torch.cuda.synchronize()

bf16_eps = 2**-7
atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0)
Collaborator

same here regarding num_splits

w_scales_gpu = w_scales.cuda()
c_gpu = device_zeros(m, n, dtype=torch.bfloat16)

splitk_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu)
Collaborator

Can we have the bitcast to fp16 controllable through a flag, so that we can disable the bitcast for the correctness tests?

dim_int = int(dim_bound)
tile, ceil_expr = _extract_tile_and_ceiling(work_bound)
if tile is not None and dim_int % tile == 0 and ceil_expr is not None:
    numerator = (ceil_expr.args[0] * tile).simplify()
Collaborator

Is the use of .simplify important here? Are you relying on sympy to transform this to a canonical form?

Collaborator

And if so, will this work

numer, _ = ceil_expr.args[0].as_numer_denom()
if isinstance(numer, Min) and any(a == dim_int for a in numer.args):
    return False
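For reference, a quick check of that structural approach in sympy (symbol names are illustrative; this assumes sympy keeps `Min` unevaluated for a symbolic argument, which it does here):

```python
from sympy import Min, Symbol, ceiling

x = Symbol("x", positive=True, integer=True)
dim_int, tile = 256, 32

# work_bound / tile as it appears in the split-K case:
expr = ceiling(Min(dim_int, x) / tile)

# Peel off the denominator; the Min node survives structurally, so no
# reliance on simplify() reaching a canonical form is needed.
numer, denom = expr.args[0].as_numer_denom()
assert denom == tile
assert isinstance(numer, Min)
assert any(a == dim_int for a in numer.args)
```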
