Coalesce B-scale reads for dynamic-dim MXFP4 preshuffle kernels#1058
Coalesce B-scale reads for dynamic-dim MXFP4 preshuffle kernels#1058Hardcode84 wants to merge 15 commits intoiree-org:mainfrom
Conversation
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
…explosion The old _pairwise_merge used O(n²) symbolic diff resolution via sympy.lambdify, which hangs on huge preshuffle index expressions (postorder_traversal of the diff tree never completes). Replace with xreplace-based numeric evaluation of each offset independently, dict lookup for O(1) candidate matching, and verification across multiple probe value sets. Fixes dynamic preshuffle MXFP4 GEMM compilation hanging in merge_contiguous_reads (128 reads now merge in ~1s). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Previously, _group_reads_by_memory skipped reads with
precomputed_mask_expr, preventing merged ept=2 reads from being
further merged to ept=4/8/16. Fix by removing the skip and remapping
the sub-read's iota symbol ($IOTA{old_size} -> $IOTA{wide_ept} -
offset) when composing masks in _build_wide_mask_expr.
Result: dynamic preshuffle MXFP4 b-tensor loads go from 332
vector<2xi8> + 84 vector<16xi8> to 8 vector<2xi8> + 120
vector<16xi8>.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
…scing The pairwise merge uses numeric probing to verify that adjacent reads have consistent per-dim offset diffs across multiple probe points. With symbolic K, the 2D decomposition (row = offset floordiv K/2, col = offset mod K/2) gives inconsistent diffs when probe values don't respect divisibility constraints — e.g. at K=137, K/2=68, adjacent bytes straddle a row boundary that doesn't exist at K=256. Fix by applying divisibility forward subs (K -> 256*K') before probing, so floordiv/Mod evaluate consistently across all probes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Tests use actual B-scale preshuffle index expressions (row = floor(offset / (K/2)), col = offset mod (K/2)) to verify that: - Flat offset diffs are always correct regardless of probe values. - Per-dim diffs are inconsistent without divisibility subs (the bug). - Per-dim diffs become consistent after K -> 256*K' substitution. - _find_merge_dim_from_diffs correctly identifies the merge dimension. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Verifies that divisibility substitutions (K % 256) enable the read coalescer to produce clean vector<16xi8> B-scale and vector<4xi8> A-scale loads from fat_raw_buffer, with no vector.from_elements fragmentation, when M, N, K are dynamic. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
| # CHECK: amdgpu.fat_raw_buffer_cast | ||
|
|
||
| # B-scale reads are clean vector<16xi8> from fat_raw_buffer — no | ||
| # fragmentation into mixed-width loads glued by from_elements. |
| return groups | ||
|
|
||
|
|
||
| def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None): |
| elems_per_reg = 32 // elem_bits | ||
|
|
||
| reg_offset = offset_val // elems_per_reg | ||
| reg_count = max(1, (size_val * elem_bits + 31) // 32) |
There was a problem hiding this comment.
we have a ceildiv function if that is what this doing?
|
|
||
| def _pairwise_merge( | ||
| read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint, divisibility_fwd=None | ||
| ): |
There was a problem hiding this comment.
There is a lot of overlap between this and multiway_merge. Can you refactor these 2 functions? In theory they could both use a common ProbeEvaluator + ReadInfo class, where ProbeEvaluator implements verify_diff, offset_map, etc.
| hw_constraint, | ||
| divisibility_fwd=None, | ||
| ): | ||
| """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window. |
There was a problem hiding this comment.
The probe-based approach computes num_offs once (O(n)), but then for each anchor, iterates over all probes:
for anchor_idx in range(len(unmerged_infos)):
...
for probe_idx in range(len(unmerged_infos)):With the same dict-lookup pattern used for _pairwise_merge, the inner loop could check offset_map[target] for each window position instead of scanning all probes.
| consistent = True | ||
| for ep in extra_probes: | ||
| try: | ||
| va = _eval_expr(resolved_offs[anchor_idx], ep) |
There was a problem hiding this comment.
Also could do with more expressive variable names.
| break | ||
| if not consistent: | ||
| continue | ||
| _, custom_p, node_p = ( |
There was a problem hiding this comment.
Why not just unpack unmerged_infos directly?
Problem
When M, N, K are dynamic, the read coalescer in
partition_strided_operators.pyfails to merge the 16 per-thread B-scale byte reads into a single contiguousvector<16xi8>load. Instead it produces fragmented loads of mixed widths ({2, 16, 8, 4}) stitched together withvector.from_elements.Root cause: the pairwise merge verifies candidate pairs by computing per-dim offset diffs at multiple numeric probe points. For B-scale buffers with shape
[N, K/2], the 2D decompositionrow = offset floordiv K/2, col = offset mod K/2gives inconsistent diffs when probe values (e.g. K=137 → K/2=68) don't respect the kernel's divisibility constraints (K % 256 == 0). The verification then rejects perfectly valid merge candidates.With static dims this doesn't happen because K is concrete and the decomposition is trivially consistent.
Solution
Divisibility substitutions before probing: plumb
get_divisibility_subs(constraints)into the merge pipeline. Before numeric probing, apply forward subs likeK → 256*K'so thatfloordiv/Modevaluate consistently across all probe sets. This is the key fix — with it, B-scale reads coalesce into 8 cleanvector<16xi8>loads identical to the static-dims case.Numeric probing for pairwise merge: replace the old symbolic diff approach (which exploded on complex preshuffle index expressions) with concrete numeric evaluation using three linear generators that avoid pathological values.
Re-merging across mask levels: allow reads that already carry precomputed masks (from prior merge rounds) to participate in further merging. The mask is extended to cover the wider result and the precomputed condition is preserved.
Bounds pre-flattening: pre-compute each read's bounds check as a flat sympy boolean before merging, so the coalescer can reason about mask compatibility without re-deriving bounds at each merge step.