Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,21 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm(


def test_dbuf_4wave_mxfp_preshuffle_b_gemm(
is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
is_debug=False,
shape=(1024, 1024, 8192),
block=(128, 256, 256),
eliminate_epilogue=False,
):
"""Asymmetric MXFP4 GEMM with preshuffled B data and B scales."""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
options.minimize_shared_allocs = True
options.linearize_shared_access = True
options.use_buffer_ops = True
options.eliminate_epilogue = eliminate_epilogue
options.dump_intermediates = "build/intermediates"
schedule = get_mxfp4_asymmetric_schedule(is_bscale_shuffled=True)
schedule = get_mxfp4_asymmetric_schedule(
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
)

options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
Expand All @@ -232,6 +238,18 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm(
print("MXFP GEMM preshuffle-B 4-wave test passed!")


def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once reviewed, I will just add a flag to the main test function (whichever it might be at that point).
This test will be removed.

is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
):
"""Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0."""
test_dbuf_4wave_mxfp_preshuffle_b_gemm(
is_debug=is_debug,
shape=shape,
block=block,
eliminate_epilogue=True,
)


def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp(
is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256)
):
Expand Down
1 change: 1 addition & 0 deletions lit_tests/kernel/wave/mlir_roundtrip_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def attention_progressive_roundtrip():
"compute_shared_memory_usage",
"partition_gather_like_ops",
"generate_bounds_exprs",
"guard_g2s_with_bounds_check",
"merge_contiguous_reads",
"location_check_pass",
"simplify_indices",
Expand Down
142 changes: 129 additions & 13 deletions wave_lang/kernel/compiler/wave_codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ def _valid_bytes_buffer(elem_type: IrType) -> int:
return ans


def _compute_total_valid_bytes(
elem_type: IrType,
symbolic_shape: tuple,
use_real_bounds: bool,
) -> int:
"""Return the total valid byte count for a buffer SRD.

When *use_real_bounds* is True and *symbolic_shape* resolves to a concrete
number, the result is clamped to the actual tensor size (still bounded by
the hardware maximum). Otherwise falls back to the hardware maximum
returned by ``_valid_bytes_buffer``.
"""
if use_real_bounds and symbolic_shape is not None:
elem_bytes = elem_type.width // 8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For mxfp4, this returns a 0. Would that be a problem?

total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape))
if isinstance(total_elements, (int, float)) or (
hasattr(total_elements, "is_number") and total_elements.is_number
):
total_bytes = int(total_elements) * elem_bytes
return min(total_bytes, _valid_bytes_buffer(elem_type))
return _valid_bytes_buffer(elem_type)


def _get_out_of_bounds_index(element_type: IrType) -> int:
"""
returns the first index that's out of bounds of a buffer based on the element type and maximum bytes
Expand All @@ -362,25 +385,99 @@ def _get_constant_value(candidate: Value):
return candidate.owner.opview.value.value


def _compute_branchless_valid_bytes(
    emitter: WaveEmitter,
    symbolic_shape: tuple,
    elem_type: IrType,
    guard_condition: sympy.Basic,
) -> Value:
    """Compute a dynamic validBytes that becomes 0 when OOB.

    The guard_condition is a sympy expression like:
    iv + prefetch_offset < max_iv

    We emit:
    cond = gen_sympy_index(guard_condition)  # index type, nonzero=true
    real_valid = compute_static_validBytes()
    validBytes = select(cond != 0, real_valid, 0)

    When the condition is false (last iterations), validBytes=0 makes the
    SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op.
    """
    uint64 = IntegerType.get_signless(64)
    # Static upper bound, clamped to the real tensor size when the shape is
    # concrete (use_real_bounds=True).
    # NOTE(review): _compute_total_valid_bytes derives elem_bytes via
    # elem_type.width // 8, which truncates to 0 for sub-byte element types
    # (e.g. mxfp4) — confirm total_bytes is correct for those types.
    total_bytes = _compute_total_valid_bytes(
        elem_type, symbolic_shape, use_real_bounds=True
    )

    # Both select arms must be materialized as SSA values before the select op.
    real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
    zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64))

    # Lower the sympy guard to an IR value; gen_sympy_index typically yields
    # an `index`-typed value, so normalize it to i1 with a `!= 0` compare
    # before feeding it to arith.select (which requires an i1 condition).
    cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition)
    i1 = IntegerType.get_signless(1)
    if cond_val.type != i1:
        zero_idx = arith_d.constant(cond_val.type, 0)
        cond_val = arith_d.cmpi(arith_d.CmpIPredicate.ne, cond_val, zero_idx)

    # Branchless: in-bounds iterations see the real byte count, OOB
    # iterations see 0 (SRD NUM_RECORDS=0 → the DMA becomes a no-op).
    return arith_d.select(cond_val, real_valid, zero_valid)


def _cast_buffer_and_encode_stride(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove is_read from this as the functionality of this function is going beyond its original definition to something like

mem = _cast_buffer_and_encode_stride(
      mem, strides, element_type, emitter,
      valid_bytes_override=_compute_valid_bytes(
          mem, element_type,
          symbolic_shape if is_read else None,
          emitter,
      ),
  )

and then compute valid bytes becomes

  def _compute_valid_bytes(ptr, elem_type, symbolic_shape, emitter) -> Value:
      use_real_bounds = (
          emitter.options.eliminate_epilogue and symbolic_shape is not None
      )
      total_bytes = _compute_total_valid_bytes(elem_type, symbolic_shape, use_real_bounds)
      uint64 = IntegerType.get_signless(64)

      if use_real_bounds:
          metadata = memref_d.extract_strided_metadata(ptr)
          offset_elements = metadata[1]
          offset_bytes = arith_d.index_cast(uint64, offset_elements)
          elem_bytes_val = arith_d.constant(uint64, get_constant_attr(elem_type.width // 8, uint64))
          offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val)
          total_val = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))
          return arith_d.subi(total_val, offset_bytes)

      return arith_d.constant(uint64, get_constant_attr(total_bytes, uint64))

ptr: Value, strides: tuple[Value], elem_type: IrType, emitter: WaveEmitter
ptr: Value,
strides: tuple[Value],
elem_type: IrType,
emitter: WaveEmitter,
symbolic_shape: tuple = None,
is_read: bool = True,
valid_bytes_override: Value = None,
) -> Value:
uint64 = IntegerType.get_signless(64)
uint14 = IntegerType.get_signless(14)
elem_bytes = elem_type.width // 8

if valid_bytes_override is not None:
valid_bytes_val = valid_bytes_override
else:
use_real_bounds = (
emitter.options.eliminate_epilogue and symbolic_shape is not None
)
total_bytes = _compute_total_valid_bytes(
elem_type, symbolic_shape, use_real_bounds
)

# With resetOffset, the SRD base is adjusted forward by the memref's
# offset, so validBytes must be total_bytes - offset_bytes to avoid
# over-reporting the valid range past the actual allocation.
#
# For writes, skip the offset subtraction: real buffer bounds are only
# useful for reads (OOB loads return 0), and subtracting the offset
# from a clamped total_bytes can produce values that overflow the
# 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB.
if use_real_bounds and is_read:
metadata = memref_d.extract_strided_metadata(ptr)
offset_elements = metadata[1]
offset_bytes = arith_d.index_cast(uint64, offset_elements)
elem_bytes_val = arith_d.constant(
uint64, get_constant_attr(elem_bytes, uint64)
)
offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val)
total_bytes_val = arith_d.constant(
uint64, get_constant_attr(total_bytes, uint64)
)
valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes)
else:
valid_bytes_val = arith_d.constant(
uint64, get_constant_attr(total_bytes, uint64)
)

valid_bytes = _valid_bytes_buffer(
elem_type
) # max bytes that are in range to be addressed from a buffer
valid_bytes_constant = get_constant_attr(valid_bytes, uint64)
valid_bytes_constant = arith_d.constant(uint64, valid_bytes_constant)
stride_rank = len(strides)
swizzle_stride = None

if stride_rank >= 2:
# fastest_dim_bound == second to last stride.
stride_candidate = strides[-2]
stride_int = _get_constant_value(stride_candidate)
# Only swizzle if stride is static and <= 8192(the useful case).
# Only swizzle if stride is static and fits in signed i14
# (max representable positive value is 8191, but 8192 wraps to
# -8192 which the hardware accepts).
if stride_int and stride_int <= 8192:
swizzle_stride = arith_d.index_cast(uint14, stride_candidate)

Expand All @@ -390,14 +487,14 @@ def _cast_buffer_and_encode_stride(
cache_swizzle_stride=swizzle_stride,
bounds_check=True,
reset_offset=True,
valid_bytes=valid_bytes_constant,
valid_bytes=valid_bytes_val,
)
else:
ptr = amdgpu_d.fat_raw_buffer_cast(
ptr,
bounds_check=True,
reset_offset=True,
valid_bytes=valid_bytes_constant,
valid_bytes=valid_bytes_val,
)

return ptr
Expand Down Expand Up @@ -508,7 +605,9 @@ def extract(vec, ind):
mem, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter)
mem = _cast_buffer_and_encode_stride(
mem, strides, element_type, emitter, symbolic_shape, is_read
)
elif is_global_mem and not is_read:
mem, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
Expand Down Expand Up @@ -553,7 +652,9 @@ def extract(vec, ind):
mem, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter)
mem = _cast_buffer_and_encode_stride(
mem, strides, element_type, emitter, symbolic_shape, is_read
)

indices = [offset_th] if buffer_ops_enabled else start_indices

Expand Down Expand Up @@ -1109,7 +1210,22 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
]

src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides)
src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter)

valid_bytes_override = None
guard_condition = node.meta.get("g2s_guard", None)
if guard_condition is not None:
valid_bytes_override = _compute_branchless_valid_bytes(
emitter, src_symbolic_shape, element_type, guard_condition
)

src = _cast_buffer_and_encode_stride(
src,
strides,
element_type,
emitter,
src_symbolic_shape,
valid_bytes_override=valid_bytes_override,
)

# We previously checked mask is same for all elements, so we can use
# elements_per_thread=1 to build the mask.
Expand Down
20 changes: 17 additions & 3 deletions wave_lang/kernel/ops/wave_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def reorder_graph(loop: Any, clusters: Any): ...


@define_schedule_op
def pipeline(iterate: Sequence[fx.Node], multi_buffer_count: Optional[int] = None): ...
def pipeline(
iterate: Sequence[fx.Node],
eliminate_epilogue: bool = False,
multi_buffer_count: Optional[int] = None,
): ...


@define_schedule_op
Expand Down Expand Up @@ -804,11 +808,13 @@ def __init__(
iterate: Sequence[fx.Node],
kernel_trace: "CapturedTrace",
constraints: list[Constraint],
eliminate_epilogue: bool = False,
multi_buffer_count: Optional[int] = None,
):
self.iterate = iterate
self.kernel_trace = kernel_trace
self.constraints = constraints
self.eliminate_epilogue = eliminate_epilogue
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related but does the handle() method of Pipeline need to support taking in multi-buffer count?

self.multi_buffer_count = multi_buffer_count

# Access options from the current ScheduleContext
Expand Down Expand Up @@ -875,6 +881,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
scheduling_type=SchedulingType.MANUAL,
visualize=False,
multi_buffer_count=self.multi_buffer_count,
eliminate_epilogue=self.eliminate_epilogue,
)

# Store the pipelined iterate node and node mapping, then create proxies for the stages
Expand Down Expand Up @@ -937,7 +944,11 @@ def PROLOGUE(self):

@property
def EPILOGUE(self):
"""Get a reference to the EPILOGUE stage (nodes after pipelined iterate)."""
"""Get a reference to the EPILOGUE stage (nodes after pipelined iterate).
Returns None when eliminate_epilogue=True (epilogue was not generated).
"""
if self.eliminate_epilogue:
return None
return self._EPILOGUE

def _update_kernel_node_mapping(self):
Expand Down Expand Up @@ -1110,8 +1121,11 @@ def handle(
kernel_trace,
constraints: list[Constraint],
iterate: Sequence[fx.Node],
eliminate_epilogue: bool = False,
):
real_pipelined_loop = PipelinedLoop(iterate, kernel_trace, constraints)
real_pipelined_loop = PipelinedLoop(
iterate, kernel_trace, constraints, eliminate_epilogue=eliminate_epilogue
)

# Return the real object directly (no proxy needed)
return real_pipelined_loop
Expand Down
5 changes: 5 additions & 0 deletions wave_lang/kernel/wave/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from .multicast import multicast
from .promotion import compute_shared_memory_usage, promote_placeholders
from .schedule_reordering import schedule_reordering
from .scheduling.loop_reconstruction import guard_g2s_with_bounds_check
from .scheduling.schedule import schedule_graph
from .scheduling.schedule_enums import SchedulingType
from .shared_memory_indexing import apply_shared_memory_indexing_corrections
Expand Down Expand Up @@ -554,6 +555,10 @@ def build_graph_passes(
)
)

graph_passes.append(
partial(guard_g2s_with_bounds_check, trace, launchable.constraints)
)

if options.optimization_level:
graph_passes += [
partial(
Expand Down
1 change: 1 addition & 0 deletions wave_lang/kernel/wave/compile_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class WaveCompileOptions:
dump_schedule: Optional[str] = None
use_bound_check: bool = False
specialize: bool = False
eliminate_epilogue: bool = False

# Cluster barrier signal/wait delay in number of loop iterations
# None - no barriers inside the loop
Expand Down
Loading