Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,21 @@ def test_read_write_equal_sizes():
 # CHECK-NEXT: placeholder(_name=a
 # CHECK-NEXT: placeholder(_name=c
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:0_N:0
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:0_N:1
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:1_N:0
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:1_N:1
-# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
+# CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1}
 # CHECK-NEXT: output

 # CHECK: -----
Expand Down Expand Up @@ -187,17 +187,17 @@ def test_read_write():
 # CHECK-NEXT: placeholder(_name=a
 # CHECK-NEXT: placeholder(_name=c
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: read(memory=a
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:0_N:0_K:0
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:0_N:0_K:0
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:1_N:0_K:0
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1}
 # CHECK-NEXT: write(register_=read_M:1_N:0_K:0
-# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1}
+# CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1}
 # CHECK-NEXT: output

 # CHECK: -----
Expand Down Expand Up @@ -1126,18 +1126,18 @@ def py_arithmetic_different_dims():
 # CHECK: Custom format:
 # CHECK-NEXT: placeholder(_name=a
 # CHECK-NEXT: placeholder(_name=c
-# CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: add(lhs=read_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: add(lhs=read_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: sub(lhs=add_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: sub(lhs=add_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: neg(arg=sub_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: neg(arg=sub_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
-# CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $WG2*BLOCK_K : 4 : 1}
-# CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $WG2*BLOCK_K + 16 : 4 : 1}
-# CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $WG2*BLOCK_K : 4 : 1}
-# CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $WG2*BLOCK_K + 16 : 4 : 1}
+# CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: add(lhs=read_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: add(lhs=read_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: sub(lhs=add_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: sub(lhs=add_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: neg(arg=sub_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: neg(arg=sub_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1}
+# CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $WG2*BLOCK_K : 4 : 1}
+# CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $WG2*BLOCK_K + 16 : 4 : 1}
+# CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $WG2*BLOCK_K : 4 : 1}
+# CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $WG2*BLOCK_K + 16 : 4 : 1}

 # CHECK: -----

Expand Down
2 changes: 1 addition & 1 deletion wave_lang/kernel/wave/analysis/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def compute_stride(
 stride = int(stride)
 except Exception as e:
 logger.error(e)
-return stride
+return 1


def is_contiguous_dim(
Expand Down