diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 59a0e512d..9c971f4e6 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -108,21 +108,21 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: write(register_=read_M:0_N:0 - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: write(register_=read_M:0_N:1 - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: write(register_=read_M:1_N:0 - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: write(register_=read_M:1_N:1 - # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-SAME: index={M: $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: output # CHECK: ----- @@ -187,17 +187,17 @@ def test_read_write(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: write(register_=read_M:0_N:0_K:0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1} # CHECK-NEXT: write(register_=read_M:0_N:0_K:0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1} # CHECK-NEXT: write(register_=read_M:1_N:0_K:0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K : 4 : 1} # CHECK-NEXT: write(register_=read_M:1_N:0_K:0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $T2*BLOCK_K + $WG2*BLOCK_K + 16 : 4 : 1} # CHECK-NEXT: output # CHECK: ----- @@ -1126,18 +1126,18 @@ def py_arithmetic_different_dims(): # CHECK: Custom format: # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: add(lhs=read_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: add(lhs=read_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: sub(lhs=add_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: sub(lhs=add_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: neg(arg=sub_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: neg(arg=sub_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 16, K: $WG2*BLOCK_K + 16 : 4 : 1} - # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 16, K: $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: add(lhs=read_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: add(lhs=read_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: sub(lhs=add_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: sub(lhs=add_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: neg(arg=sub_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: neg(arg=sub_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $WG2*BLOCK_K : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) : 1 : 1, K: $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $WG2*BLOCK_K : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), flags=MemoryAccessFlags.NONE, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 64) + 16 : 1 : 1, K: $WG2*BLOCK_K + 16 : 4 : 1} # CHECK: ----- diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index cb515481f..40eab96f3 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -446,7 +446,7 @@ def compute_stride( stride = int(stride) except Exception as e: logger.error(e) - return stride + return 1 def is_contiguous_dim(