diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 434857c..aab3e97 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -54,7 +54,7 @@ def __init__( # Calls to forward() may supply matrices of different sizes, and the # Python code will perform necessary padding/repeated application of # the NPU operator. - M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N) + M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N, tile_m, tile_k, tile_n, num_aie_columns, self.n_aie_rows) self.M = M_padded self.K = K_padded self.N = N_padded @@ -96,8 +96,6 @@ def get_artifacts(self, prefix="gemm_"): assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - assert tile_k & (tile_k - 1) == 0, f"tile_k ({tile_k}) must be power of 2" - assert tile_n & (tile_n - 1) == 0, f"tile_n ({tile_n}) must be power of 2" file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}" file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}" @@ -277,15 +275,11 @@ def forward(self, A, B=None): return result - def _get_padded_dims(self, M, K, N): - tile_m, tile_k, tile_n = self.tile_m, self.tile_n, self.tile_k - num_aie_columns = self.num_aie_columns - - min_M = tile_m * self.n_aie_rows + def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n, aie_cols, aie_rows): + min_M = tile_m * aie_rows min_K = tile_k - min_N = tile_n * num_aie_columns + min_N = tile_n * aie_cols - # Calculate padded dimensions M_padded = ((M + min_M - 1) // min_M) * min_M K_padded = ((K + min_K - 1) // min_K) * min_K N_padded = ((N + min_N - 1) // min_N) * min_N diff --git a/operators/gemm/test.py b/operators/gemm/test.py index 519dcb9..6529831 100755 --- a/operators/gemm/test.py +++ b/operators/gemm/test.py @@ -89,6 +89,9 @@ def test_gemm( M=M, K=K, N=N, + tile_m=m, + tile_k=k, + tile_n=n, num_aie_columns=num_aie_columns, prio_accuracy=True, emulate_bf16_mmul_with_bfp16=False, @@ -113,4 +116,4 @@ def test_gemm( print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s") print(f"Throughput: {gflops:.6e} GFLOP/s\n") - assert not errors, f"Test failed with errors: {errors}" + assert not errors, f"Test failed with errors: {errors[:10]}"