Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions operators/gemm/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
# Calls to forward() may supply matrices of different sizes; the Python
# code then performs the necessary padding or repeated application of
# the NPU operator.
M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N)
M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N, tile_m, tile_k, tile_n, num_aie_columns, self.n_aie_rows)
self.M = M_padded
self.K = K_padded
self.N = N_padded
Expand Down Expand Up @@ -96,8 +96,6 @@ def get_artifacts(self, prefix="gemm_"):
assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}"
assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}"
assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}"
assert tile_k & (tile_k - 1) == 0, f"tile_k ({tile_k}) must be power of 2"
assert tile_n & (tile_n - 1) == 0, f"tile_n ({tile_n}) must be power of 2"

file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}"
file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}"
Expand Down Expand Up @@ -277,15 +275,11 @@ def forward(self, A, B=None):

return result

def _get_padded_dims(self, M, K, N):
tile_m, tile_k, tile_n = self.tile_m, self.tile_n, self.tile_k
num_aie_columns = self.num_aie_columns

min_M = tile_m * self.n_aie_rows
def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n, aie_cols, aie_rows):
min_M = tile_m * aie_rows
min_K = tile_k
min_N = tile_n * num_aie_columns
min_N = tile_n * aie_cols

# Calculate padded dimensions
M_padded = ((M + min_M - 1) // min_M) * min_M
K_padded = ((K + min_K - 1) // min_K) * min_K
N_padded = ((N + min_N - 1) // min_N) * min_N
Expand Down
5 changes: 4 additions & 1 deletion operators/gemm/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def test_gemm(
M=M,
K=K,
N=N,
tile_m=m,
tile_k=k,
tile_n=n,
num_aie_columns=num_aie_columns,
prio_accuracy=True,
emulate_bf16_mmul_with_bfp16=False,
Expand All @@ -113,4 +116,4 @@ def test_gemm(
print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s")
print(f"Throughput: {gflops:.6e} GFLOP/s\n")

assert not errors, f"Test failed with errors: {errors}"
assert not errors, f"Test failed with errors: {errors[:10]}"
Loading