Conversation
@harsh-nod This has split-K with preshuffle_scales functional with the 4x vector load. I've done some basic cleanup, but as mentioned there are still parts of it that I haven't fully reviewed or understood.
Force-pushed 88b0c99 to 8de6506
@harsh-nod this is now rebased on top of main, which now has the
wave_lang/kernel/wave/constraints.py
Outdated
```
Returns ``False`` (no overshoot) when we can prove that the tiled work
never exceeds the tensor dimension. In particular this handles the
split-K pattern where ``work_bound = tile * ceiling(Min(dim, f(wg)) / tile)``
and ``dim`` is tile-aligned: ``ceiling(Min(dim, x) / tile) * tile <= dim``.
```
Can you explain what is happening here?
The logic is now tightened; it improves the analysis of whether a read might overshoot its bounds. By detecting the pattern of tiles where reads are bounded by min(X, bound), we know the read is still within the bound, so we don't need to emit the bound guard rails. This allows the read merge to take effect and load 4xi8 vectors instead of individual bytes.
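The no-overshoot claim for the aligned case can be spot-checked numerically. This is just a sketch of the invariant, not the actual analysis code; the concrete `dim` and `tile` values are illustrative:

```python
import math

# when dim is tile-aligned, tile * ceil(min(dim, x) / tile) <= dim for any x,
# because min(dim, x) <= dim and dim / tile is already an integer
dim, tile = 256, 64
assert dim % tile == 0
for x in range(0, 2 * dim + 1, 7):
    work = tile * math.ceil(min(dim, x) / tile)
    assert work <= dim
```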
Force-pushed 6bbc407 to f02dc5b
The core things added are split-K gemm, and it is tested for (1) generation of the `buffer_atomic_pk_add_bf16` instruction that we wanted to use, and (2) gemm correctness. Overview of some of the major changes:

- `remove_global_indexing` in `general_utils.py`: zeroes out tiling constraint starts (e.g. `K_SPLIT_OFF`) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).
- Fixing spurious bounds on split-K tiling that prevented scale vector merging: `TilingConstraint.get_index_bound` was conservatively generating bounds for the split-K case because sympy could not prove that `ceiling(Min(K, f(wg)) / tile) * tile <= K`. These bounds prevented `merge_contiguous_reads` from combining scalar scale reads into `vector<4xi8>` loads (it skips reads that already have bounds). Added `_work_may_exceed_dim()` to structurally detect the aligned split-K pattern and prove no overshoot, avoiding the spurious bound. (This was necessary to get scale_preshuffle to have 4x vector loads when combined with split-K.)

Signed-off-by: William G Hatch <william@hatch.uno>
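The start-offset zeroing can be sketched with sympy. The index expression below is illustrative (the `WG * 64` term and the helper shape are assumptions, not the kernel's actual index), but it shows why the start must be zeroed before dividing by the scale factor:

```python
from sympy import symbols

# sketch of the remove_global_indexing idea: zero out the tiling-constraint
# start (e.g. K_SPLIT_OFF) together with the workgroup id *before* dividing
# by the scale factor, so scaled (K/32) and unscaled (K) units never mix
K, WG, K_SPLIT_OFF = symbols("K WG K_SPLIT_OFF")
SCALE = 32  # MXFP4 packs 32 K-elements per scale byte

global_index = WG * 64 + K_SPLIT_OFF + K
local_scaled = global_index.subs({WG: 0, K_SPLIT_OFF: 0}) / SCALE
assert local_scaled == K / 32
```

Subtracting `K_SPLIT_OFF` *after* the division would instead yield `K / 32 - K_SPLIT_OFF`, mixing the two units.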
Force-pushed f02dc5b to 65575f0
Signed-off-by: William G Hatch <william@hatch.uno>
Force-pushed 65575f0 to ca5f8e8
```python
)
for _p in [str(_EXAMPLES_DIR), str(_WAVE_ROOT), str(_E2E_DIR)]:
    if _p not in sys.path:
        sys.path.insert(0, _p)
```
Instead of this could you modify the imports so we can do something like `import WaveASMCompiler, capture_wave_kernel_info`?
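One way to get there is to re-export the symbols from a package `__init__.py` so call sites import them directly instead of prepending several directories to `sys.path`. The sketch below builds a throwaway package to illustrate the pattern; all names (`wave_asm_pkg`, `compiler`, `WaveASMCompiler` as a stub) are hypothetical stand-ins, not the repo's actual layout:

```python
import importlib
import os
import sys
import tempfile

# build a tiny package on disk: wave_asm_pkg/compiler.py defines the class,
# and wave_asm_pkg/__init__.py re-exports it for direct import
root = tempfile.mkdtemp()
pkg = os.path.join(root, "wave_asm_pkg")
os.makedirs(pkg)
with open(os.path.join(pkg, "compiler.py"), "w") as f:
    f.write("class WaveASMCompiler:\n    pass\n")
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("from .compiler import WaveASMCompiler\n")

sys.path.insert(0, root)  # one package root instead of three directories

from wave_asm_pkg import WaveASMCompiler  # call sites stay this simple
assert WaveASMCompiler.__name__ == "WaveASMCompiler"
```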
```python
torch.cuda.synchronize()

bf16_eps = 2**-7
atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0)
```
Why does atol depend on num_splits? Shouldn't it be independent of num_splits?
No, the number of splits increases the accumulation of errors. I.e., we get error from casting to BF16, then we accumulate in BF16, which means we accumulate rounding error num_splits times.
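The tolerance from the test can be captured in a small helper. The function name `splitk_atol` is hypothetical; the formula itself is the one from the test above, with the `num_splits` factor modeling one worst-case BF16 rounding per partial accumulation:

```python
def splitk_atol(num_splits: int, ref_abs_max: float,
                bf16_eps: float = 2**-7) -> float:
    # each of the num_splits partial sums is rounded to BF16, so the
    # worst-case accumulated rounding error grows linearly with num_splits;
    # the max(..., 1.0) floor keeps atol sane for tiny reference values
    return num_splits * bf16_eps * max(ref_abs_max, 1.0)

# doubling the splits doubles the tolerance
assert splitk_atol(8, 2.0) == 2 * splitk_atol(4, 2.0)
```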
```python
torch.cuda.synchronize()

bf16_eps = 2**-7
atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0)
```
Same here regarding num_splits.
```python
w_scales_gpu = w_scales.cuda()
c_gpu = device_zeros(m, n, dtype=torch.bfloat16)

splitk_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu)
```
Can we have the bitcast to fp16 controllable through a flag so that for the correctness tests, we disable the bitcast?
```python
dim_int = int(dim_bound)
tile, ceil_expr = _extract_tile_and_ceiling(work_bound)
if tile is not None and dim_int % tile == 0 and ceil_expr is not None:
    numerator = (ceil_expr.args[0] * tile).simplify()
```
Is the use of .simplify important here? Are you relying on sympy to transform this to a canonical form?
And if so, will this work?

```python
numer, _ = ceil_expr.args[0].as_numer_denom()
if isinstance(numer, Min) and any(a == dim_int for a in numer.args):
    return False
```
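A runnable sketch of this structural check, using `as_numer_denom` as suggested instead of `.simplify`. The helper name `min_bounded_by` and the way it scans for `ceiling` nodes are illustrative, not the PR's actual `_work_may_exceed_dim()` implementation:

```python
from sympy import Min, ceiling, symbols

K, x, tile = symbols("K x tile", positive=True, integer=True)

def min_bounded_by(work_bound, dim):
    """Structurally detect tile * ceiling(Min(dim, ...) / tile): if the
    ceiling's numerator is a Min containing dim, work never exceeds dim
    (given dim is tile-aligned), with no call to .simplify needed."""
    for ceil_expr in work_bound.atoms(ceiling):
        numer, _ = ceil_expr.args[0].as_numer_denom()
        if isinstance(numer, Min) and any(a == dim for a in numer.args):
            return True
    return False

work = tile * ceiling(Min(K, x) / tile)
assert min_bounded_by(work, K)            # split-K pattern detected
assert not min_bounded_by(tile * ceiling(x / tile), K)  # no Min: not proven
```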
remove_global_indexingingeneral_utils.py: Zeroes out tiling constraint starts (e.g.K_SPLIT_OFF) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).Fixing spurious bounds on split-K tiling that prevented scale vector merging: TilingConstraint.get_index_bound was conservatively generating bounds for the split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile)