diff --git a/operators/common/test_utils.py b/operators/common/test_utils.py index dc598b9..dc19df5 100644 --- a/operators/common/test_utils.py +++ b/operators/common/test_utils.py @@ -34,7 +34,8 @@ def verify_buffer(operator, buf_name, reference, rel_tol=0.04, abs_tol=1e-6): expected_np = torch_to_numpy(reference).reshape((-1,)) buf_size = operator.buffers[buf_name] // 2 output = operator.read_buffer(buf_name, (buf_size,)) - if len(output) != len(expected_np): + if len(output) < len(expected_np): + # Allow larger buffers - binning may have allocated more space than needed print( f"Buffer size mismatch for {buf_name}: expected {len(expected_np)}, got {len(output)}" ) diff --git a/operators/gemm/op.py b/operators/gemm/op.py index 1c20569..9201eb2 100644 --- a/operators/gemm/op.py +++ b/operators/gemm/op.py @@ -46,7 +46,6 @@ def __init__( self.tile_k = tile_k self.tile_n = tile_n self.num_aie_columns = num_aie_columns - self.n_aie_rows = 4 # Number of AIE rows used in the design self.gemm_args = gemm_kwargs # Set frequently accessed gemm_args @@ -66,7 +65,9 @@ def __init__( assert ( N % partition_N == 0 ), f"N ({N}) must be divisible by partition_N ({partition_N})" - M_padded, K_padded, N_padded = self._get_padded_dims(M, K, N // partition_N) + M_padded, K_padded, N_padded = self._get_padded_dims( + M, K, N // partition_N, tile_m, tile_k, tile_n + ) self.M = M_padded self.K = K_padded self.N = N_padded @@ -109,8 +110,6 @@ def get_artifacts(self, prefix="gemm_"): assert tile_m >= min_tile_m, f"tile_m ({tile_m}) must be >= {min_tile_m}" assert tile_k >= min_tile_k, f"tile_k ({tile_k}) must be >= {min_tile_k}" assert tile_n >= min_tile_n, f"tile_n ({tile_n}) must be >= {min_tile_n}" - assert tile_k & (tile_k - 1) == 0, f"tile_k ({tile_k}) must be power of 2" - assert tile_n & (tile_n - 1) == 0, f"tile_n ({tile_n}) must be power of 2" file_name_tile_base = f"{prefix}{tile_m}x{tile_k}x{tile_n}" file_name_total_base = f"{prefix}{M}x{K}x{N}_{tile_m}x{tile_k}x{tile_n}_{int(b_col_maj)}_{int(c_col_maj)}" @@ -326,11 +325,11 @@ def forward(self, A, B=None): return result - def _get_padded_dims(self, M, K, N): - tile_m, tile_k, tile_n = self.tile_m, self.tile_n, self.tile_k + def _get_padded_dims(self, M, K, N, tile_m, tile_k, tile_n): num_aie_columns = self.num_aie_columns + num_aie_rows = 4 - min_M = tile_m * self.n_aie_rows + min_M = tile_m * num_aie_rows min_K = tile_k min_N = tile_n * num_aie_columns diff --git a/operators/gemm/reference.py b/operators/gemm/reference.py index da35f36..093f3ab 100644 --- a/operators/gemm/reference.py +++ b/operators/gemm/reference.py @@ -21,6 +21,18 @@ def generate_golden_reference( input_a = torch.randn(M, K, dtype=dtype_torch) * val_range input_b_full = torch.rand(K, N, dtype=dtype_torch) * val_range output_full = torch.matmul(input_a, input_b_full) + if False: + # The following inputs are useful for debugging; + # the A matrix becomes a matrix where each element encodes its row and column index, + # and the B matrix is an identity matrix. + col_digits = len(str(K - 1)) if K > 0 else 1 + factor = 10 ** (col_digits + 1) + row_indices = torch.arange(M, dtype=torch.int64).unsqueeze(1) + col_indices = torch.arange(K, dtype=torch.int64).unsqueeze(0) + input_a = (row_indices * factor + col_indices).to(dtype=dtype_torch) + input_b_full = torch.zeros(K, N, dtype=dtype_torch) + diag_dim = min(K, N) + input_b_full[:diag_dim, :diag_dim] = torch.eye(diag_dim, dtype=dtype_torch) if b_col_maj: input_b_full = input_b_full.T if c_col_maj: diff --git a/operators/gemm/test.py b/operators/gemm/test.py index 57ead59..4c4062a 100755 --- a/operators/gemm/test.py +++ b/operators/gemm/test.py @@ -14,42 +14,64 @@ def generate_test_params(extensive=False): - M_list = [2048] if not extensive else [2048] - K_list = [2048] if not extensive else [2048, 8192, 64] - N_list = [2048] if not extensive else [2048, 8192] - m, k, n = 64, 64, 64 - num_aie_columns = 8 - col_maj = [(False, False), (True, False), (False, True)] - trace_size = 0 - partition_N = 1 - - params = [] - names = [] + # fmt: off + params = [ + # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + (2048, 2048, 2048, 1, False, False, 64, 64, 64, 0, 1), + (2048, 2048, 2048, 2, True, False, 64, 64, 64, 0, 1), + (2048, 2048, 2048, 8, True, True, 64, 64, 64, 0, 1), + ( 384, 1536, 1792, 4, True, False, 32, 48, 64, 0, 1), + (1792, 896, 1152, 8, False, True, 64, 32, 48, 0, 1), + ( 896, 1792, 640, 8, False, True, 32, 64, 80, 0, 1), + ( 192, 384, 64, 4, False, False, 48, 96, 16, 0, 1), + ( 192, 384, 64, 4, True, True, 48, 96, 16, 0, 1), + ] + extensive_params = [ + (2048, 2048, 2048, 8, False, False, 32, 32, 128, 0, 1), + (2048, 2048, 8192, 2, False, False, 64, 64, 64, 0, 1), + (2048, 8192, 2048, 2, False, False, 64, 64, 64, 0, 1), + (2048, 64, 2048, 2, False, False, 64, 64, 64, 0, 1), + (2048, 64, 8192, 2, False, False, 64, 64, 64, 0, 1), + (2048, 2048, 2048, 8, True, False, 128, 32, 32, 0, 1), + (2048, 2048, 8192, 2, True, False, 64, 64, 64, 0, 1), + (2048, 8192, 2048, 2, True, False, 64, 64, 64, 0, 1), + (2048, 64, 2048, 2, True, False, 64, 64, 64, 0, 1), + (2048, 64, 8192, 2, True, False, 64, 64, 64, 0, 1), + (2048, 2048, 2048, 2, False, True, 8, 16, 32, 0, 1), + (2048, 2048, 8192, 2, False, True, 64, 64, 64, 0, 1), + (2048, 8192, 2048, 2, False, True, 64, 64, 64, 0, 1), + (2048, 64, 2048, 2, False, True, 64, 64, 64, 0, 1), + (2048, 64, 8192, 2, False, True, 64, 64, 64, 0, 1), + ] + # fmt: on + + if extensive: + params = extensive_params - for b_col_maj, c_col_maj in col_maj: - for M in M_list: - for K in K_list: - for N in N_list: - if N == 8192 and K == 8192: - continue # Untested combination because huge & slow, unused in our application - params.append( - ( - M, - K, - N, - num_aie_columns, - b_col_maj, - c_col_maj, - m, - k, - n, - trace_size, - partition_N, - ) - ) - names.append( - f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}_cols_{int(b_col_maj)}_bcolmaj_{int(c_col_maj)}_ccolmaj_{trace_size}{f"_{partition_N}" if partition_N > 1 else ""}" - ) + names = [] + for ( + M, + K, + N, + num_aie_columns, + b_col_maj, + c_col_maj, + m, + k, + n, + trace_size, + partition_N, + ) in params: + name = f"gemm_{M}x{K}x{N}_{m}x{k}x{n}_{num_aie_columns}cols" + if b_col_maj: + name += "_bcolmaj" + if c_col_maj: + name += "_ccolmaj" + if partition_N > 1: + name += f"_{partition_N}npart" + if trace_size > 0: + name += f"_{trace_size}trace" + names.append(name) return params, names @@ -103,6 +125,9 @@ def test_gemm( M=M, K=K, N=N, + tile_m=m, + tile_k=k, + tile_n=n, num_aie_columns=num_aie_columns, prio_accuracy=True, emulate_bf16_mmul_with_bfp16=False, @@ -131,4 +156,4 @@ def test_gemm( print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s") print(f"Throughput: {gflops:.6e} GFLOP/s\n") - assert not errors, f"Test failed with errors: {errors}" + assert not errors, f"Test failed"