Conversation
a942e22 to
f146ad5
Compare
f146ad5 to
b1a1c3a
Compare
|
You will need to fix the following to get this working on main: |
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
51804ec to
9a6d8f5
Compare
Signed-off-by: xintin <gaurav.verma@amd.com>
3faf480 to
951d391
Compare
da2591c to
0f53afe
Compare
Signed-off-by: xintin <gaurav.verma@amd.com>
|
@xintin - does this work with dynamic shapes where there is loop peeling? |
Yes, it works with the dynamic shape, llvm backend. |
| print("MXFP GEMM preshuffle-B 4-wave test passed!") | ||
|
|
||
|
|
||
| def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( |
There was a problem hiding this comment.
once, reviewed, I will just add flag to the main test function (whichever it might be at that point).
This test will be removed.
| epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( | ||
| split_by_iteration(epilogue_bitcast_b_scale) | ||
| ) | ||
| epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) |
There was a problem hiding this comment.
seems to be an additional tab here?
There was a problem hiding this comment.
That's how pre-commit is formatting it but i can check once.
| returned by ``_valid_bytes_buffer``. | ||
| """ | ||
| if use_real_bounds and symbolic_shape is not None: | ||
| elem_bytes = elem_type.width // 8 |
There was a problem hiding this comment.
For mxfp4, this returns a 0. Would that be a problem?
| self.iterate = iterate | ||
| self.kernel_trace = kernel_trace | ||
| self.constraints = constraints | ||
| self.eliminate_epilogue = eliminate_epilogue |
There was a problem hiding this comment.
Not directly related but does the handle() method of Pipeline need to support taking in multi-buffer count?
| return arith_d.select(cond_val, real_valid, zero_valid) | ||
|
|
||
|
|
||
| def _cast_buffer_and_encode_stride( |
There was a problem hiding this comment.
Could you remove is_read from this as the functionality of this function is going beyond its original definition to something like
mem = _cast_buffer_and_encode_stride(
mem, strides, element_type, emitter,
valid_bytes_override=_compute_valid_bytes(
mem, element_type,
symbolic_shape if is_read else None,
emitter,
),
)
and then compute valid bytes becomes
def _compute_valid_bytes(ptr, elem_type, symbolic_shape, emitter) -> Value:
use_real_bounds = (
emitter.options.eliminate_epilogue and symbolic_shape is not None
)
total_bytes = _compute_total_valid_bytes(elem_type, symbolic_shape, use_real_bounds)
uint64 = IntegerType.get_signless(64)
if use_real_bounds:
metadata = memref_d.extract_strided_metadata(ptr)
offset_elements = metadata[1]
offset_bytes = arith_d.index_cast(uint64, offset_elements)
elem_bytes_val = arith_d.constant(uint64, get_constant_attr(elem_type.width // 8, uint64))
offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val)
total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
return arith_d.subi(total_val, offset_bytes)
return arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
With epilogue elimination, the loop simply runs for N extra iterations instead. Those extra iterations load OOB data, but
validBytes = select(iv + prefetch_offset < max_iv, real_validBytes, 0). When OOB,validBytes=0makesNUM_RECORDS=0in the SRD, so the hardware DMA becomes a no-op without any branch.Since the OOB loads produce zeros, the extra MFMAs are
0 x something + accumulator = accumulator, no-ops that don't change the result. The epilogue code is eliminated.