Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion operators/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def verify_buffer(operator, buf_name, reference, rel_tol=0.04, abs_tol=1e-6):
expected_np = torch_to_numpy(reference).reshape((-1,))
buf_size = operator.buffers[buf_name] // 2
output = operator.read_buffer(buf_name, (buf_size,))
if len(output) != len(expected_np):
if len(output) < len(expected_np):
# Allow larger buffers - binning may have allocated more space than needed
print(
f"Buffer size mismatch for {buf_name}: expected {len(expected_np)}, got {len(output)}"
)
Expand Down
13 changes: 6 additions & 7 deletions operators/gemm/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(
self.tile_k = tile_k
self.tile_n = tile_n
self.num_aie_columns = num_aie_columns
self.n_aie_rows = 4 # Number of AIE rows used in the design
self.gemm_args = gemm_kwargs

# Set frequently accessed gemm_args
Expand All @@ -66,7 +65,9 @@ def __init__(
assert (
N % partition_N == 0
), f"N ({N}) must be divisible by partition_N ({partition_N})"
M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N // partition_N)
M_padded, K_padded, N_padded = self._get_padded_dims(
M, K, N // partition_N, tile_m, tile_k, tile_n
)
self.M = M_padded
self.K = K_padded
self.N = N_padded
Expand Down Expand Up @@ -109,8 +110,6 @@ def get_artifacts(self, prefix="gemm_"):
assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}"
assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}"
assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}"
assert tile_k & (tile_k - 1) == 0, f"tile_k ({tile_k}) must be power of 2"
assert tile_n & (tile_n - 1) == 0, f"tile_n ({tile_n}) must be power of 2"

file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}"
file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}"
Expand Down Expand Up @@ -326,11 +325,11 @@ def forward(self, A, B=None):

return result

def _get_padded_dims(self, M, K, N):
tile_m, tile_k, tile_n = self.tile_m, self.tile_n, self.tile_k
def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n):
num_aie_columns = self.num_aie_columns
num_aie_rows = 4

min_M = tile_m * self.n_aie_rows
min_M = tile_m * num_aie_rows
min_K = tile_k
min_N = tile_n * num_aie_columns

Expand Down
12 changes: 12 additions & 0 deletions operators/gemm/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def generate_golden_reference(
input_a = torch.randn(M, K, dtype=dtype_torch) * val_range
input_b_full = torch.rand(K, N, dtype=dtype_torch) * val_range
output_full = torch.matmul(input_a, input_b_full)
if False:
# The following inputs are useful for debugging;
# the A matrix becomes a matrix where each element encodes its row and column index,
# and the B matrix is an identity matrix.
col_digits = len(str(K - 1)) if K > 0 else 1
factor = 10 ** (col_digits + 1)
row_indices = torch.arange(M, dtype=torch.int64).unsqueeze(1)
col_indices = torch.arange(K, dtype=torch.int64).unsqueeze(0)
input_a = (row_indices * factor + col_indices).to(dtype=dtype_torch)
input_b_full = torch.zeros(K, N, dtype=dtype_torch)
diag_dim = min(K, N)
input_b_full[:diag_dim, :diag_dim] = torch.eye(diag_dim, dtype=dtype_torch)
if b_col_maj:
input_b_full = input_b_full.T
if c_col_maj:
Expand Down
97 changes: 61 additions & 36 deletions operators/gemm/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,64 @@


def generate_test_params(extensive=False):
    """Build the GEMM test parameter tuples and their matching test names.

    Each parameter tuple has the layout
    ``(M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size,
    partition_N)``, where ``m``, ``k`` and ``n`` are the tile sizes handed to
    the operator as ``tile_m``, ``tile_k`` and ``tile_n``.

    Args:
        extensive: When true, return the extensive table *instead of* the
            default table (the two tables are not merged).

    Returns:
        A ``(params, names)`` pair of equal-length lists; ``names[i]`` is the
        identifier generated for ``params[i]``.
    """
    # The cases are listed explicitly (rather than generated from a cartesian
    # product) so that each one can pick its own tile sizes and column count.
    # fmt: off
    params = [
        # M,   K,    N,    num_aie_columns, b_col_maj, c_col_maj, m,  k,  n,  trace_size, partition_N
        (2048, 2048, 2048, 1,               False,     False,     64, 64, 64, 0,          1),
        (2048, 2048, 2048, 2,               True,      False,     64, 64, 64, 0,          1),
        (2048, 2048, 2048, 8,               True,      True,      64, 64, 64, 0,          1),
        ( 384, 1536, 1792, 4,               True,      False,     32, 48, 64, 0,          1),
        (1792,  896, 1152, 8,               False,     True,      64, 32, 48, 0,          1),
        ( 896, 1792,  640, 8,               False,     True,      32, 64, 80, 0,          1),
        ( 192,  384,   64, 4,               False,     False,     48, 96, 16, 0,          1),
        ( 192,  384,   64, 4,               True,      True,      48, 96, 16, 0,          1),
    ]
    extensive_params = [
        (2048, 2048, 2048, 8,               False,     False,     32,  32, 128, 0,        1),
        (2048, 2048, 8192, 2,               False,     False,     64,  64,  64, 0,        1),
        (2048, 8192, 2048, 2,               False,     False,     64,  64,  64, 0,        1),
        (2048,   64, 2048, 2,               False,     False,     64,  64,  64, 0,        1),
        (2048,   64, 8192, 2,               False,     False,     64,  64,  64, 0,        1),
        (2048, 2048, 2048, 8,               True,      False,     128, 32,  32, 0,        1),
        (2048, 2048, 8192, 2,               True,      False,     64,  64,  64, 0,        1),
        (2048, 8192, 2048, 2,               True,      False,     64,  64,  64, 0,        1),
        (2048,   64, 2048, 2,               True,      False,     64,  64,  64, 0,        1),
        (2048,   64, 8192, 2,               True,      False,     64,  64,  64, 0,        1),
        (2048, 2048, 2048, 2,               False,     True,      8,   16,  32, 0,        1),
        (2048, 2048, 8192, 2,               False,     True,      64,  64,  64, 0,        1),
        (2048, 8192, 2048, 2,               False,     True,      64,  64,  64, 0,        1),
        (2048,   64, 2048, 2,               False,     True,      64,  64,  64, 0,        1),
        (2048,   64, 8192, 2,               False,     True,      64,  64,  64, 0,        1),
    ]
    # fmt: on

    if extensive:
        params = extensive_params

    names = []
    for (
        M,
        K,
        N,
        num_aie_columns,
        b_col_maj,
        c_col_maj,
        m,
        k,
        n,
        trace_size,
        partition_N,
    ) in params:
        # Optional suffixes are only emitted for non-default settings so the
        # common cases keep short names.
        name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols"
        if b_col_maj:
            name += "_bcolmaj"
        if c_col_maj:
            name += "_ccolmaj"
        if partition_N > 1:
            name += f"_{partition_N}npart"
        if trace_size > 0:
            name += f"_{trace_size}trace"
        names.append(name)

    return params, names

Expand Down Expand Up @@ -103,6 +125,9 @@ def test_gemm(
M=M,
K=K,
N=N,
tile_m=m,
tile_k=k,
tile_n=n,
num_aie_columns=num_aie_columns,
prio_accuracy=True,
emulate_bf16_mmul_with_bfp16=False,
Expand Down Expand Up @@ -131,4 +156,4 @@ def test_gemm(
print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s")
print(f"Throughput: {gflops:.6e} GFLOP/s\n")

assert not errors, f"Test failed with errors: {errors}"
assert not errors, f"Test failed"