
Epilogue elimination in MXFP4#975

Open
xintin wants to merge 12 commits intomainfrom
xintin/epilogue_elimination

Conversation

@xintin (Contributor) commented Feb 25, 2026

With epilogue elimination, the loop simply runs for N extra iterations instead. Those extra iterations load OOB data, but

  • VGPR buffer loads return 0 (hardware guarantee via SRD validBytes)
  • GatherToLDS uses branchless SRD zeroing: validBytes = select(iv + prefetch_offset < max_iv, real_validBytes, 0). When OOB, validBytes=0 makes NUM_RECORDS=0 in the SRD, so the hardware DMA becomes a no-op without any branch.

Since the OOB loads produce zeros, the extra MFMAs compute 0 × something + accumulator = accumulator — no-ops that don't change the result. The epilogue code is eliminated.
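The guard and the resulting no-op MFMAs can be illustrated with a minimal pure-Python sketch (hypothetical function names; the real mechanism is the SRD NUM_RECORDS field in hardware, not Python):

```python
# Sketch of the branchless SRD zeroing described above: validBytes selects
# between the real byte count and 0. With validBytes=0, NUM_RECORDS=0 in the
# SRD and the hardware DMA / buffer load becomes a no-op returning zeros,
# without any branch in the instruction stream.
def srd_valid_bytes(iv, prefetch_offset, max_iv, real_valid_bytes):
    # select(cond, a, b): both operands exist; only the selection differs.
    in_bounds = iv + prefetch_offset < max_iv
    return real_valid_bytes if in_bounds else 0

# The extra MFMAs then reduce to acc = 0 * x + acc, leaving the result unchanged.
def mfma_step(a, b, acc):
    return a * b + acc
```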

  • Add a parameter eliminate_epilogue to get_mxfp4_asymmetric_schedule
  • Pass it through to tkw.pipeline(k_loop, eliminate_epilogue=eliminate_epilogue)
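The two bullets above amount to threading one flag through the schedule into the pipeline transform. A hypothetical sketch (stand-in for the real tkw API, whose signatures may differ):

```python
# Stand-in for tkw.pipeline: records the flag on the loop it transforms.
def pipeline(k_loop, eliminate_epilogue=False):
    k_loop["eliminate_epilogue"] = eliminate_epilogue
    return k_loop

# The schedule only forwards the caller's choice; it does not decide itself.
def get_mxfp4_asymmetric_schedule(k_loop, eliminate_epilogue=False):
    return pipeline(k_loop, eliminate_epilogue=eliminate_epilogue)
```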

@xintin xintin force-pushed the xintin/epilogue_elimination branch 2 times, most recently from a942e22 to f146ad5 Compare February 25, 2026 18:56
@xintin xintin changed the title [wip] Epilogue elimination in MXFP4 Epilogue elimination in MXFP4 Feb 25, 2026
@xintin xintin changed the title Epilogue elimination in MXFP4 [wip] Epilogue elimination in MXFP4 Feb 25, 2026
@xintin xintin force-pushed the xintin/epilogue_elimination branch from f146ad5 to b1a1c3a Compare February 25, 2026 22:18
@xintin xintin changed the title [wip] Epilogue elimination in MXFP4 Epilogue elimination in MXFP4 Feb 26, 2026
@xintin xintin requested review from harsh-nod and panditsa February 26, 2026 00:39
@harsh-nod (Collaborator) commented:

You will need to fix the following to get this working on main:

1. wave_schedule_ops.py: PipelinedLoop.__exit__ now uses
   getattr(self, 'multi_buffer_count', None) instead of a hardcoded None,
   so pipeline_loop.multi_buffer_count = 2 in the asymmetric schedule
   is respected (restoring triple-buffering for the A-matrix LDS).

2. gemm_mxfp4_double_buffer.py: Restore the missing unroll(KERNEL, 2),
   insert_at_start(MemoryCounterWaitBarrier), and insert_after calls
   in the else-branch of get_mxfp4_asymmetric_schedule. The PR removed
   these, but they are needed for correct kernel behavior.

3. 7.1_schedule.py: Fix the NameError in test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm:
   call _run_mxfp_gemm_preshuffle(all=True) instead of the non-existent
   _run_mxfp_gemm_preshuffle_b.
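Item 1 can be sketched in isolation. A minimal stand-in for PipelinedLoop (the real class lives in wave_schedule_ops.py and does far more), showing how `getattr` lets a schedule-set attribute survive into `__exit__`:

```python
class PipelinedLoop:
    """Toy context manager mirroring the getattr fix from review item 1."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Respect a count the schedule may have set (e.g.
        # pipeline_loop.multi_buffer_count = 2); fall back to None
        # when no schedule ever set one, instead of hardcoding None.
        count = getattr(self, "multi_buffer_count", None)
        self._finalize(multi_buffer_count=count)
        return False

    def _finalize(self, multi_buffer_count):
        # Stand-in for the real finalization that configures buffering.
        self.finalized_count = multi_buffer_count
```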

xintin added 6 commits March 3, 2026 18:14
Signed-off-by: xintin <gaurav.verma@amd.com>
@xintin xintin force-pushed the xintin/epilogue_elimination branch from 51804ec to 9a6d8f5 Compare March 3, 2026 18:36
@xintin xintin force-pushed the xintin/epilogue_elimination branch from 3faf480 to 951d391 Compare March 3, 2026 18:58
@xintin xintin force-pushed the xintin/epilogue_elimination branch from da2591c to 0f53afe Compare March 3, 2026 19:34
@xintin xintin requested a review from harsh-nod March 3, 2026 21:58
@harsh-nod (Collaborator) commented:

@xintin - does this work with dynamic shapes where there is loop peeling?

@xintin (Contributor, Author) commented Mar 4, 2026

> @xintin - does this work with dynamic shapes where there is loop peeling?

It is updated based on yesterday morning's changes. Nothing related to dynamic shapes had been tested for this PR until now.

Yes, it works with dynamic shapes (LLVM backend).

test_block_without_epilogue_64_192_256.mlir (2878 lines) eliminate_epilogue=True:

if (K >= num_stages):
    scf.for #1: pipelined loop, runs full trip count
        - guard (validBytes=0) handles OOB in last iterations
        - no epilogue drain inside
else:
    return zeros

scf.for #2: remainder loop (leftover iterations when K not multiple of num_stages)

No epilogue. The pipelined loop runs all iterations. OOB prefetches are validBytes=0.

test_block_with_epilogue_64_192_256.mlir (4465 lines) eliminate_epilogue=False:

if (K >= num_stages):
    scf.for #1: pipelined loop, runs reduced trip count
        - stops early, leaving (num_stages-1) iterations
        - has epilogue drain inside (prologue + kernel + epilogue blocks)
else:
    return zeros

scf.for #2: remainder loop (same as above)

Has epilogue. The pipelined loop stops early, and the drain iterations are generated as an epilogue block inside construct_pipelined_loop. 
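The trip-count difference between the two variants above can be summarized in a small sketch (hypothetical helper; the real logic lives in construct_pipelined_loop):

```python
# Trip count of the pipelined scf.for, as described for the two MLIR dumps:
# with eliminate_epilogue the loop runs the full count and the validBytes=0
# guard absorbs the OOB tail; without it, the loop stops (num_stages - 1)
# iterations early and those iterations are drained by an epilogue block.
def pipelined_trip_count(k_iters, num_stages, eliminate_epilogue):
    if eliminate_epilogue:
        return k_iters                      # guard handles OOB prefetches
    return k_iters - (num_stages - 1)       # drain moved into the epilogue
```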

print("MXFP GEMM preshuffle-B 4-wave test passed!")


def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm(
@xintin (Contributor, Author) commented:

Once this is reviewed, I will just add a flag to the main test function (whichever it is at that point). This test will be removed.

epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = (
split_by_iteration(epilogue_bitcast_b_scale)
)
epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE)
@harsh-nod (Collaborator) commented:
seems to be an additional tab here?

@xintin (Contributor, Author) replied:
That's how pre-commit is formatting it, but I can check once more.

returned by ``_valid_bytes_buffer``.
"""
if use_real_bounds and symbolic_shape is not None:
elem_bytes = elem_type.width // 8
@harsh-nod (Collaborator) commented:
For mxfp4, this returns a 0. Would that be a problem?
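The concern follows directly from the arithmetic on the quoted line: integer division truncates, so a sub-byte element width yields zero. A one-function illustration (plain Python, mirroring `elem_type.width // 8`):

```python
# For mxfp4 the element width is 4 bits, so width // 8 truncates to 0 bytes --
# any byte count derived from it (strides, validBytes) would collapse to zero.
def elem_bytes(width_bits):
    return width_bits // 8
```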

self.iterate = iterate
self.kernel_trace = kernel_trace
self.constraints = constraints
self.eliminate_epilogue = eliminate_epilogue
@harsh-nod (Collaborator) commented:
Not directly related but does the handle() method of Pipeline need to support taking in multi-buffer count?

return arith_d.select(cond_val, real_valid, zero_valid)


def _cast_buffer_and_encode_stride(
@harsh-nod (Collaborator) commented:
Could you remove is_read from this? The functionality of this function is going beyond its original definition. Something like:

mem = _cast_buffer_and_encode_stride(
    mem, strides, element_type, emitter,
    valid_bytes_override=_compute_valid_bytes(
        mem, element_type,
        symbolic_shape if is_read else None,
        emitter,
    ),
)

and then _compute_valid_bytes becomes

def _compute_valid_bytes(ptr, elem_type, symbolic_shape, emitter) -> Value:
    use_real_bounds = (
        emitter.options.eliminate_epilogue and symbolic_shape is not None
    )
    total_bytes = _compute_total_valid_bytes(elem_type, symbolic_shape, use_real_bounds)
    uint64 = IntegerType.get_signless(64)

    if use_real_bounds:
        metadata = memref_d.extract_strided_metadata(ptr)
        offset_elements = metadata[1]
        offset_bytes = arith_d.index_cast(uint64, offset_elements)
        elem_bytes_val = arith_d.constant(uint64, get_constant_attr(elem_type.width // 8, uint64))
        offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val)
        total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
        return arith_d.subi(total_val, offset_bytes)

    return arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
