Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions lit_tests/kernel/wave/scaled_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: }

# Epilogue Local Read
# CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>

# Epilogue MFMA
Expand Down Expand Up @@ -471,8 +471,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: rocdl.s.waitcnt
# CHECK: amdgpu.lds_barrier

# Steady state local loads
# CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space<workgroup>>
# Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>)
# CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space<workgroup>>

# Steady State global load to lds
# CHECK-COUNT-34: amdgpu.gather_to_lds
Expand Down Expand Up @@ -637,9 +637,9 @@ def repeat(
# CHECK: }

# Epilogue Local Read
# CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>

# Epilogue MFMA
Expand Down
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/scaled_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8>
# CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8>
# CHECK: amdgpu.lds_barrier
# CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
# CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
# CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
# CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN>
# CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU>
Expand Down
50 changes: 32 additions & 18 deletions wave_lang/kernel/compiler/wave_codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def _create_vec_read_write(
memory: CustomOp,
mask: Optional[Value],
node_index: Optional[IndexSequence] = None,
use_wide_load_select: bool = False,
) -> Optional[Value]:
is_read = value is None
uint32 = IntegerType.get_signless(32)
Expand Down Expand Up @@ -558,7 +559,8 @@ def extract(vec, ind):
indices = [offset_th] if buffer_ops_enabled else start_indices

if no_masked_load_store_ops:
# find the index at which memory out of bounds of buffer
scalar_offset_th = offset_th

oob_index_value = _get_out_of_bounds_index(element_type)
oob_index = arith_d.constant(IndexType.get(), oob_index_value)

Expand All @@ -582,7 +584,6 @@ def extract(vec, ind):

# based on mask, select between the offsets_vec and out of bounds. In this case all 3 operands can be vectors
selected_index = arith_d.select(mask, offsets_vec, oob_index)
elems = list()

if splatted_mask:
# mask is same for all of them, can just pick the first index
Expand All @@ -595,27 +596,28 @@ def extract(vec, ind):
vector_d.store(value, mem, indices=[selected_index])
return

for i in range(elements_per_thread):
# mask is not same for all elements, need to unroll
this_index = extract(selected_index, i) # this element
if is_read and use_wide_load_select:
result = vector_d.load(vector_type, mem, indices=[scalar_offset_th])
zero_vec = vector_d.broadcast(vector_type, zero)
return arith_d.select(mask, result, zero_vec)

# Unmasked load, using selected_index
singlenumvec_type = VectorType.get([1], vector_type.element_type)
if is_read:
if is_read:
elems = []
for i in range(elements_per_thread):
this_index = extract(selected_index, i)
singlenumvec_type = VectorType.get([1], vector_type.element_type)
elem = vector_d.load(singlenumvec_type, mem, indices=[this_index])
elem = extract(elem, 0)
elems.append(elem)
else:
elem = extract(value, i)
single_num_vector = vector_d.broadcast(singlenumvec_type, elem)
vector_d.store(single_num_vector, mem, indices=[this_index])

if is_read:
# now make a vector from all the elements loaded
return vector_d.from_elements(vector_type, elems)

else: # it was a store, return
return
for i in range(elements_per_thread):
this_index = extract(selected_index, i)
elem = extract(value, i)
singlenumvec_type = VectorType.get([1], vector_type.element_type)
single_num_vector = vector_d.broadcast(singlenumvec_type, elem)
vector_d.store(single_num_vector, mem, indices=[this_index])
return

else:
# normal masked load/store
Expand Down Expand Up @@ -699,7 +701,18 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
)
dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals)

if mapping:
is_global_mem = kb_src.type.memory_space is None
buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem

precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None)
if precomputed_mask_expr is not None:
mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr)
mask_vec_type = VectorType.get(
[elements_per_thread], IntegerType.get_signless(1)
)
if mask.type != mask_vec_type:
mask = vector_d.broadcast(mask_vec_type, mask)
elif mapping:
transformed_index = transform_index_on_mapping(
mapping, input_shape, index, is_read=True
)
Expand Down Expand Up @@ -744,6 +757,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
get_custom(memory),
mask,
node_index=index,
use_wide_load_select=precomputed_mask_expr is not None,
)

emitter.bind_node_proxy(node, IRProxyValue(result))
Expand Down
Loading
Loading