From 78185eed439efa5a7fa0a5995e2b6ad67fc4b520 Mon Sep 17 00:00:00 2001 From: tandonmitul27 Date: Sat, 29 Nov 2025 12:03:36 +0530 Subject: [PATCH 1/6] vegeta load-store tested --- kernel/include/vx_intrinsics.h | 14 +- sim/simx/core.cpp | 3 + sim/simx/core.h | 6 + sim/simx/emulator.cpp | 3 + sim/simx/sparse_unit.cpp | 56 ++++--- tests/regression/Makefile | 4 + tests/regression/veg_ls/Makefile | 13 ++ tests/regression/veg_ls/common.h | 37 +++++ tests/regression/veg_ls/kernel.cpp | 97 ++++++++++++ tests/regression/veg_ls/main.cpp | 241 +++++++++++++++++++++++++++++ 10 files changed, 448 insertions(+), 26 deletions(-) create mode 100644 tests/regression/veg_ls/Makefile create mode 100644 tests/regression/veg_ls/common.h create mode 100644 tests/regression/veg_ls/kernel.cpp create mode 100644 tests/regression/veg_ls/main.cpp diff --git a/kernel/include/vx_intrinsics.h b/kernel/include/vx_intrinsics.h index 0d5cda880f..7dfa1fd10f 100644 --- a/kernel/include/vx_intrinsics.h +++ b/kernel/include/vx_intrinsics.h @@ -287,34 +287,34 @@ inline __attribute__((const)) int vx_shfl_idx(size_t value, int bval, int cval, // TILE LOAD T: Load 1KB from ptr[TILE] to tile register index 'dst_treg' // Each load uses I-type encoding: rd=dst tile index, rs1=src_gpr, imm=ptr immediate -inline void vx_lt(int dst_treg, int src_gpr, size_t ptr_imm) { +inline void vx_lt(int dst_treg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 0, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_treg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD U: Load 1KB from ptr[TILE] to ureg index 'dst_ureg' -inline void vx_lu(int dst_ureg, int src_gpr, size_t ptr_imm) { +inline void vx_lu(int dst_ureg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 1, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_ureg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD V: Load 1KB from ptr[TILE] to vreg index 'dst_vreg' -inline void vx_lv(int dst_vreg, int src_gpr, size_t ptr_imm) { +inline void vx_lv(int dst_vreg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 2, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_vreg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD M: Load 1KB from ptr[TILE] to mreg index 'dst_mreg' -inline void vx_lm(int dst_mreg, int src_gpr, size_t ptr_imm) { +inline void vx_lm(int dst_mreg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 3, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_mreg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE STORE T: Store 1KB from treg index 'src_treg' to ptr[TILE] // Store uses S-type encoding: rs1=src_gpr, rs2=src_treg index, imm=ptr immediate -inline void vx_st(int src_gpr, size_t ptr_imm, int src_treg) { - __asm__ volatile (".insn s %0, 0, %1, x%2, %3" - :: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(src_treg), "i"(ptr_imm) : "memory"); +inline void vx_st(size_t src_gpr, size_t ptr_imm, int src_treg) { + __asm__ volatile (".insn s %0, 0, x%3, %2(%1)" + :: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(ptr_imm), "i"(src_treg) : "memory"); } // ----------------------------------------------------------------------------- diff --git a/sim/simx/core.cpp b/sim/simx/core.cpp index bb5b48fe88..49617318e2 100644 --- a/sim/simx/core.cpp +++ b/sim/simx/core.cpp @@ -372,6 +372,9 @@ void Core::issue() { #endif #ifdef EXT_TCU_ENABLE case FUType::TCU: ++perf_stats_.scrb_tcu; break; + #endif + #ifdef EXT_VEGETA_ENABLE + case FUType::VEGETA: ++perf_stats_.scrb_vegeta; break; #endif default: assert(false); } diff --git a/sim/simx/core.h b/sim/simx/core.h 
index 859491e53a..dcc675a51c 100644 --- a/sim/simx/core.h +++ b/sim/simx/core.h @@ -66,6 +66,9 @@ class Core : public SimObject { #endif #ifdef EXT_TCU_ENABLE uint64_t scrb_tcu; + #endif + #ifdef EXT_VEGETA_ENABLE + uint64_t scrb_vegeta; #endif uint64_t ifetches; uint64_t loads; @@ -93,6 +96,9 @@ class Core : public SimObject { #endif #ifdef EXT_TCU_ENABLE , scrb_tcu(0) + #endif + #ifdef EXT_VEGETA_ENABLE + , scrb_vegeta(0) #endif , ifetches(0) , loads(0) diff --git a/sim/simx/emulator.cpp b/sim/simx/emulator.cpp index 8c77a40328..d9f026c410 100644 --- a/sim/simx/emulator.cpp +++ b/sim/simx/emulator.cpp @@ -501,6 +501,9 @@ Word Emulator::get_csr(uint32_t addr, uint32_t wid, uint32_t tid) { #endif #ifdef EXT_VPU_ENABLE CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vpu); + #endif + #ifdef EXT_VEGETA_ENABLE + CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vegeta); #endif CSR_READ_64(VX_CSR_MPM_SCRB_CSRS, core_perf.scrb_csrs); CSR_READ_64(VX_CSR_MPM_SCRB_WCTL, core_perf.scrb_wctl); diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp index ac75b1343d..093a86b56f 100644 --- a/sim/simx/sparse_unit.cpp +++ b/sim/simx/sparse_unit.cpp @@ -238,18 +238,36 @@ class SparseUnit::Impl { auto trace = input.front(); int delay = 0; #ifdef EXT_VEGETA_ENABLE - auto tcu_type = std::get<VegetaTcuType>(trace->op_type); - switch (tcu_type) { - case VegetaTcuType::TILE_GEMM_T: - case VegetaTcuType::TILE_GEMM_U: - case VegetaTcuType::TILE_GEMM_V: - case VegetaTcuType::TILE_GEMM_R: - delay = 4; - break; - default: + if (std::holds_alternative<VegetaTcuType>(trace->op_type)) { + auto tcu_type = std::get<VegetaTcuType>(trace->op_type); + switch (tcu_type) { + case VegetaTcuType::TILE_GEMM_T: + case VegetaTcuType::TILE_GEMM_U: + case VegetaTcuType::TILE_GEMM_V: + case VegetaTcuType::TILE_GEMM_R: + delay = 4; + break; + default: + std::abort(); + } + DT(3, simobject_->name() << ": op=" << tcu_type << ", " << *trace); + } else if (std::holds_alternative<VegetaLsuType>(trace->op_type)) { + auto lsu_type = std::get<VegetaLsuType>(trace->op_type); + switch (lsu_type) { + case VegetaLsuType::TILE_LOAD_T: + case VegetaLsuType::TILE_LOAD_U: + case VegetaLsuType::TILE_LOAD_V: + case VegetaLsuType::TILE_LOAD_M: + case VegetaLsuType::TILE_STORE_T: + delay = 2; + break; + default: + std::abort(); + } + DT(3, simobject_->name() << ": op=" << lsu_type << ", " << *trace); + } else { std::abort(); } - DT(3, simobject_->name() << ": op=" << tcu_type << ", " << *trace); #else auto tcu_type = std::get<VegetaTcuType>(trace->op_type); switch (tcu_type) { @@ -444,19 +462,19 @@ class SparseUnit::Impl { uint32_t meta_reg_idx = vd; assert(meta_reg_idx < metadata_reg_file_.size() && "Metadata register index out of bounds"); auto &metadata_reg = metadata_reg_file_[meta_reg_idx]; - constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::uint4::dtype); // 1 byte for uint8_t (stores one uint4) - // Load metadata from memory: 16 rows x 32 columns = 512 uint4 elements = 512 bytes - // Note: Each uint4 is stored in the lower 4 bits of a byte + // Load metadata from memory: 16 rows x 32 columns = 512 uint4 elements = 256 bytes + // Each byte stores two uint4 values: upper nibble for col N, lower nibble for col N+1
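+ // Example: a metadata byte of 0x3A unpacks to 0x3 for col N (upper nibble) and 0xA for col N+1 (lower nibble).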
for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + for (uint32_t col = 0; col < TILE_COLS; col += 2) { + uint64_t mem_addr = base_addr + (row * (TILE_COLS / 2) + col / 2); uint8_t mem_data = 0; - core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); - trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); + core_->dcache_read(&mem_data, mem_addr, 1); + trace_data->mem_addrs.at(tid).push_back({mem_addr, 1}); - // Store only lower 4 bits (uint4 value) - metadata_reg[row][col] = mem_data & 0x0F; + // Upper nibble for col N, lower nibble for col N+1 + metadata_reg[row][col] = (mem_data >> 4) & 0x0F; + metadata_reg[row][col + 1] = mem_data & 0x0F; } } diff --git a/tests/regression/Makefile b/tests/regression/Makefile index be3ccc9636..8aac0a48e2 100644 --- a/tests/regression/Makefile +++ b/tests/regression/Makefile @@ -21,6 +21,7 @@ all: $(MAKE) -C sgemm2 $(MAKE) -C madmax $(MAKE) -C stencil3d + $(MAKE) -C veg_ls run-simx: $(MAKE) -C basic run-simx @@ -42,6 +43,7 @@ run-simx: $(MAKE) -C sgemm2 run-simx $(MAKE) -C madmax run-simx $(MAKE) -C stencil3d run-simx + $(MAKE) -C veg_ls run-simx run-rtlsim: $(MAKE) -C basic run-rtlsim @@ -63,6 +65,7 @@ run-rtlsim: $(MAKE) -C sgemm2 run-rtlsim $(MAKE) -C madmax run-rtlsim $(MAKE) -C stencil3d run-rtlsim + $(MAKE) -C veg_ls run-rtlsim clean: $(MAKE) -C basic clean @@ -84,3 +87,4 @@ clean: $(MAKE) -C sgemm2 clean $(MAKE) -C madmax clean $(MAKE) -C stencil3d clean + $(MAKE) -C veg_ls clean diff --git a/tests/regression/veg_ls/Makefile b/tests/regression/veg_ls/Makefile new file mode 100644 index 0000000000..43c67b7460 --- /dev/null +++ b/tests/regression/veg_ls/Makefile @@ -0,0 +1,13 @@ +ROOT_DIR := $(realpath ../../..) +include $(ROOT_DIR)/config.mk + +PROJECT := veg_ls + +SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT) + +SRCS := $(SRC_DIR)/main.cpp + +VX_SRCS := $(SRC_DIR)/kernel.cpp + + +include ../common.mk diff --git a/tests/regression/veg_ls/common.h b/tests/regression/veg_ls/common.h new file mode 100644 index 0000000000..46cce37589 --- /dev/null +++ b/tests/regression/veg_ls/common.h @@ -0,0 +1,37 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#ifndef TYPE +#define TYPE float +#endif + +// T-reg: 2KB (16x32 4-byte elements) +#define T_TILE_SIZE 2048 + +// U-reg: 4KB (2 x T-reg) +#define U_TILE_SIZE 4096 + +// V-reg: 8KB (4 x T-reg) +#define V_TILE_SIZE 8192 + +// M-reg: 256B (16x32 4-bit elements) +#define M_TILE_SIZE 256 + +// Number of tiles to test for each register type +#define NUM_T_TILES 8 // Test all 8 T-regs +#define NUM_U_TILES 4 // Test all 4 U-regs (covers all 8 T-regs) +#define NUM_V_TILES 2 // Test all 2 V-regs (covers all 8 T-regs) +#define NUM_M_TILES 8 // Test all 8 M-regs + +typedef struct { + uint64_t src_t_addr; // Source address for T tiles + uint64_t dst_t_addr; // Destination address for T tiles + uint64_t src_u_addr; // Source address for U tiles + uint64_t dst_u_addr; // Destination address for U tiles + uint64_t src_v_addr; // Source address for V tiles + uint64_t dst_v_addr; // Destination address for V tiles + uint64_t src_m_addr; // Source address for M tiles + uint64_t dst_m_addr; // Destination address for M tiles +} kernel_arg_t; + +#endif diff --git a/tests/regression/veg_ls/kernel.cpp b/tests/regression/veg_ls/kernel.cpp new file mode 100644 index 0000000000..2d0dfc57db --- /dev/null +++ b/tests/regression/veg_ls/kernel.cpp @@ -0,0 +1,97 @@ +#include <vx_intrinsics.h> +#include <vx_spawn.h> +#include "common.h" + +void kernel_body(kernel_arg_t* __UNIFORM__ arg) { + auto src_t_ptr = reinterpret_cast<uint8_t*>(arg->src_t_addr); + auto dst_t_ptr = reinterpret_cast<uint8_t*>(arg->dst_t_addr); + auto src_u_ptr = reinterpret_cast<uint8_t*>(arg->src_u_addr); + auto dst_u_ptr = reinterpret_cast<uint8_t*>(arg->dst_u_addr); + auto src_v_ptr = reinterpret_cast<uint8_t*>(arg->src_v_addr); + auto dst_v_ptr = reinterpret_cast<uint8_t*>(arg->dst_v_addr); + auto src_m_ptr = reinterpret_cast<uint8_t*>(arg->src_m_addr); + auto dst_m_ptr = reinterpret_cast<uint8_t*>(arg->dst_m_addr);
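+ + // Note: the load/store sequences below are fully unrolled (conceptually a loop: for t in 0..7, vx_lt(t, src_t + t*T_TILE_SIZE) then vx_st(dst_t + t*T_TILE_SIZE, t)) because the tile-register index is encoded as an instruction immediate ("i" constraint in vx_intrinsics.h) and must be a compile-time constant.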
+ + // ===== Test 1: TILE_LOAD_T - Load/Store all 8 T-regs individually ===== + vx_lt(0, (size_t)(src_t_ptr + 0 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 0 * T_TILE_SIZE), 0, 0); + + vx_lt(1, (size_t)(src_t_ptr + 1 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 1 * T_TILE_SIZE), 0, 1); + + vx_lt(2, (size_t)(src_t_ptr + 2 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 2 * T_TILE_SIZE), 0, 2); + + vx_lt(3, (size_t)(src_t_ptr + 3 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 3 * T_TILE_SIZE), 0, 3); + + vx_lt(4, (size_t)(src_t_ptr + 4 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 4 * T_TILE_SIZE), 0, 4); + + vx_lt(5, (size_t)(src_t_ptr + 5 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 5 * T_TILE_SIZE), 0, 5); + + vx_lt(6, (size_t)(src_t_ptr + 6 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 6 * T_TILE_SIZE), 0, 6); + + vx_lt(7, (size_t)(src_t_ptr + 7 * T_TILE_SIZE), 0); + vx_st((size_t)(dst_t_ptr + 7 * T_TILE_SIZE), 0, 7); + + vx_barrier(0, 1); + + // ===== Test 2: TILE_LOAD_U - Load/Store all 4 U-regs (covers all 8 T-regs) ===== + // U-reg 0 maps to T-regs [0, 1] + vx_lu(0, (size_t)(src_u_ptr + 0 * U_TILE_SIZE), 0); + vx_st((size_t)(dst_u_ptr + 0 * U_TILE_SIZE), 0, 0); + vx_st((size_t)(dst_u_ptr + 0 * U_TILE_SIZE + T_TILE_SIZE), 0, 1); + + // U-reg 1 maps to T-regs [2, 3] + vx_lu(1, (size_t)(src_u_ptr + 1 * U_TILE_SIZE), 0); + vx_st((size_t)(dst_u_ptr + 1 * U_TILE_SIZE), 0, 2); + vx_st((size_t)(dst_u_ptr + 1 * U_TILE_SIZE + T_TILE_SIZE), 0, 3); + + // U-reg 2 maps to T-regs [4, 5] + vx_lu(2, (size_t)(src_u_ptr + 2 * U_TILE_SIZE), 0); + vx_st((size_t)(dst_u_ptr + 2 * U_TILE_SIZE), 0, 4); + vx_st((size_t)(dst_u_ptr + 2 * U_TILE_SIZE + T_TILE_SIZE), 0, 5); + + // U-reg 3 maps to T-regs [6, 7] + vx_lu(3, (size_t)(src_u_ptr + 3 * U_TILE_SIZE), 0); + vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE), 0, 6); + vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE + T_TILE_SIZE), 0, 7); + + vx_barrier(0, 1); + + // ===== Test 3: TILE_LOAD_V - Load/Store all 2 V-regs (covers all 8 T-regs) ===== + // V-reg 0 maps to T-regs [0, 1, 2, 3] + vx_lv(0, (size_t)(src_v_ptr + 0 * V_TILE_SIZE), 0); + vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE), 0, 0); + vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 1 * T_TILE_SIZE), 0, 1); + vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 2 * T_TILE_SIZE), 0, 2); + vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 3 * T_TILE_SIZE), 0, 3); + + // V-reg 1 maps to T-regs [4, 5, 6, 7] + vx_lv(1, (size_t)(src_v_ptr + 1 * V_TILE_SIZE), 0); + vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE), 0, 4); + vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 1 * T_TILE_SIZE), 0, 5); + vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 2 * T_TILE_SIZE), 0, 6); + vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 3 * T_TILE_SIZE), 0, 7); + + vx_barrier(0, 1); + + // ===== Test 4: TILE_LOAD_M - Load all 8 M-regs ===== + // M-registers store metadata (sparsity patterns/masks) +// vx_lm(0, (size_t)(src_m_ptr + 0 * M_TILE_SIZE), 0); +// vx_lm(1, (size_t)(src_m_ptr + 1 * M_TILE_SIZE), 0); +// vx_lm(2, (size_t)(src_m_ptr + 2 * M_TILE_SIZE), 0); +// vx_lm(3, (size_t)(src_m_ptr + 3 * M_TILE_SIZE), 0); +// vx_lm(4, (size_t)(src_m_ptr + 4 * M_TILE_SIZE), 0); +// vx_lm(5, (size_t)(src_m_ptr + 5 * M_TILE_SIZE), 0); +// vx_lm(6, (size_t)(src_m_ptr + 6 * M_TILE_SIZE), 0); +// vx_lm(7, (size_t)(src_m_ptr + 7 * M_TILE_SIZE), 0); +}
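+ +// Note: there is no TILE_STORE_M counterpart to vx_lm in this extension, so M-register contents cannot be written back to memory; dst_m_addr is allocated by the host but left unused here.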
+ +int main() { + kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH); + return vx_spawn_threads(1, nullptr, nullptr, (vx_kernel_func_cb)kernel_body, arg); +} diff --git a/tests/regression/veg_ls/main.cpp b/tests/regression/veg_ls/main.cpp new file mode 100644 index 0000000000..5efc4366c6 --- /dev/null +++ b/tests/regression/veg_ls/main.cpp @@ -0,0 +1,241 @@ +#include <iostream> +#include <unistd.h> +#include <cstdlib> +#include <vector> +#include <vortex.h> +#include "common.h" + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +/////////////////////////////////////////////////////////////////////////////// + +const char* kernel_file = "kernel.vxbin"; + +vx_device_h device = nullptr; +vx_buffer_h src_t_buffer = nullptr; +vx_buffer_h dst_t_buffer = nullptr; +vx_buffer_h src_u_buffer = nullptr; +vx_buffer_h dst_u_buffer = nullptr; +vx_buffer_h src_v_buffer = nullptr; +vx_buffer_h dst_v_buffer = nullptr; +vx_buffer_h src_m_buffer = nullptr; +vx_buffer_h dst_m_buffer = nullptr; +vx_buffer_h krnl_buffer = nullptr; +vx_buffer_h args_buffer = nullptr; +kernel_arg_t kernel_arg = {}; + +static void show_usage() { + std::cout << "Vortex TILE Operations Test." << std::endl; + std::cout << "Usage: [-k: kernel] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "k:h")) != -1) { + switch (c) { + case 'k': + kernel_file = optarg; + break; + case 'h': + show_usage(); + exit(0); + break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(src_t_buffer); + vx_mem_free(dst_t_buffer); + vx_mem_free(src_u_buffer); + vx_mem_free(dst_u_buffer); + vx_mem_free(src_v_buffer); + vx_mem_free(dst_v_buffer); + vx_mem_free(src_m_buffer); + vx_mem_free(dst_m_buffer); + vx_mem_free(krnl_buffer); + vx_mem_free(args_buffer); + vx_dev_close(device); + } +}
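+ +// Note: cleanup() relies on the buffer handles above being value-initialized to nullptr and assumes vx_mem_free tolerates handles that were never allocated, since an early RT_CHECK failure invokes it mid-setup.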
+ +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + uint32_t t_buf_size = NUM_T_TILES * T_TILE_SIZE; + uint32_t u_buf_size = NUM_U_TILES * U_TILE_SIZE; + uint32_t v_buf_size = NUM_V_TILES * V_TILE_SIZE; + uint32_t m_buf_size = NUM_M_TILES * M_TILE_SIZE; + + std::cout << "Testing all physical registers:" << std::endl; + std::cout << "T-regs: " << NUM_T_TILES << " tiles, " << T_TILE_SIZE << " bytes each, buffer: " << t_buf_size << " bytes" << std::endl; + std::cout << "U-regs: " << NUM_U_TILES << " tiles, " << U_TILE_SIZE << " bytes each, buffer: " << u_buf_size << " bytes" << std::endl; + std::cout << "V-regs: " << NUM_V_TILES << " tiles, " << V_TILE_SIZE << " bytes each, buffer: " << v_buf_size << " bytes" << std::endl; + std::cout << "M-regs: " << NUM_M_TILES << " tiles, " << M_TILE_SIZE << " bytes each, buffer: " << m_buf_size << " bytes" << std::endl; + + // allocate device memory for T tiles + std::cout << "allocate device memory for T tiles" << std::endl; + RT_CHECK(vx_mem_alloc(device, t_buf_size, VX_MEM_READ_WRITE, &src_t_buffer)); + RT_CHECK(vx_mem_address(src_t_buffer, &kernel_arg.src_t_addr)); + RT_CHECK(vx_mem_alloc(device, t_buf_size, VX_MEM_READ_WRITE, &dst_t_buffer)); + RT_CHECK(vx_mem_address(dst_t_buffer, &kernel_arg.dst_t_addr)); + + // allocate device memory for U tiles + std::cout << "allocate device memory for U tiles" << std::endl; + RT_CHECK(vx_mem_alloc(device, u_buf_size, VX_MEM_READ_WRITE, &src_u_buffer)); + RT_CHECK(vx_mem_address(src_u_buffer, &kernel_arg.src_u_addr)); + RT_CHECK(vx_mem_alloc(device, u_buf_size, VX_MEM_READ_WRITE, &dst_u_buffer)); + RT_CHECK(vx_mem_address(dst_u_buffer, &kernel_arg.dst_u_addr)); + + // allocate device memory for V tiles + std::cout << "allocate device memory for V tiles" << std::endl; + RT_CHECK(vx_mem_alloc(device, v_buf_size, VX_MEM_READ_WRITE, &src_v_buffer)); + RT_CHECK(vx_mem_address(src_v_buffer, &kernel_arg.src_v_addr)); + RT_CHECK(vx_mem_alloc(device, v_buf_size, VX_MEM_READ_WRITE, &dst_v_buffer)); + RT_CHECK(vx_mem_address(dst_v_buffer, &kernel_arg.dst_v_addr)); + + // allocate device memory for M tiles + std::cout << "allocate device memory for M tiles" << std::endl; + RT_CHECK(vx_mem_alloc(device, m_buf_size, VX_MEM_READ_WRITE, &src_m_buffer)); + RT_CHECK(vx_mem_address(src_m_buffer, &kernel_arg.src_m_addr)); + RT_CHECK(vx_mem_alloc(device, m_buf_size, VX_MEM_READ_WRITE, &dst_m_buffer)); + RT_CHECK(vx_mem_address(dst_m_buffer, &kernel_arg.dst_m_addr)); + + std::cout << "dev_src_t=0x" << std::hex << kernel_arg.src_t_addr << std::endl; + std::cout << "dev_dst_t=0x" << std::hex << kernel_arg.dst_t_addr << std::endl; + std::cout << "dev_src_u=0x" << std::hex << kernel_arg.src_u_addr << std::endl; + std::cout << "dev_dst_u=0x" << std::hex << kernel_arg.dst_u_addr << std::endl; + std::cout << "dev_src_v=0x" << std::hex << kernel_arg.src_v_addr << std::endl; + std::cout << "dev_dst_v=0x" << std::hex << kernel_arg.dst_v_addr << std::endl; + std::cout << "dev_src_m=0x" << std::hex << kernel_arg.src_m_addr << std::endl; + std::cout << "dev_dst_m=0x" << std::hex << kernel_arg.dst_m_addr << std::endl; + + // allocate host buffers + std::cout << "allocate host buffers" << std::endl; + std::vector<uint8_t> h_src_t(t_buf_size); + std::vector<uint8_t> h_dst_t(t_buf_size); + std::vector<uint8_t> h_src_u(u_buf_size); + std::vector<uint8_t> h_dst_u(u_buf_size); + std::vector<uint8_t> h_src_v(v_buf_size); + std::vector<uint8_t> h_dst_v(v_buf_size); + std::vector<uint8_t> h_src_m(m_buf_size); + std::vector<uint8_t> h_dst_m(m_buf_size); + + // Initialize source buffers with different patterns for each tile type + for (uint32_t i = 0; i < t_buf_size; ++i) { + h_src_t[i] = (uint8_t)(i & 0xFF); // Pattern: 0,1,2,...,255,0,1,... + } + for (uint32_t i = 0; i < u_buf_size; ++i) { + h_src_u[i] = (uint8_t)((i * 2) & 0xFF); // Pattern: 0,2,4,... + } + for (uint32_t i = 0; i < v_buf_size; ++i) { + h_src_v[i] = (uint8_t)((i * 3) & 0xFF); // Pattern: 0,3,6,... + }
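+ // Note: each byte pattern above (and the XOR pattern below) repeats with a period of at most 256 bytes, which divides every tile size, so all tiles within a buffer receive identical contents; a load/store that hit the wrong tile index would therefore still verify. Seeding a per-tile offset into the values would catch such cross-tile mixups.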
+ for (uint32_t i = 0; i < m_buf_size; ++i) { + h_src_m[i] = (uint8_t)((i ^ 0xAA) & 0xFF); // Pattern: XOR with 0xAA + } + + // upload source buffers + std::cout << "upload source buffers" << std::endl; + RT_CHECK(vx_copy_to_dev(src_t_buffer, h_src_t.data(), 0, t_buf_size)); + RT_CHECK(vx_copy_to_dev(src_u_buffer, h_src_u.data(), 0, u_buf_size)); + RT_CHECK(vx_copy_to_dev(src_v_buffer, h_src_v.data(), 0, v_buf_size)); + RT_CHECK(vx_copy_to_dev(src_m_buffer, h_src_m.data(), 0, m_buf_size)); + + // Upload kernel binary + std::cout << "Upload kernel binary" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); + + // upload kernel argument + std::cout << "upload kernel argument" << std::endl; + RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer)); + + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device, krnl_buffer, args_buffer)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + // download destination buffers + std::cout << "download destination buffers" << std::endl; + RT_CHECK(vx_copy_from_dev(h_dst_t.data(), dst_t_buffer, 0, t_buf_size)); + RT_CHECK(vx_copy_from_dev(h_dst_u.data(), dst_u_buffer, 0, u_buf_size)); + RT_CHECK(vx_copy_from_dev(h_dst_v.data(), dst_v_buffer, 0, v_buf_size)); + RT_CHECK(vx_copy_from_dev(h_dst_m.data(), dst_m_buffer, 0, m_buf_size)); + + // verify result + std::cout << "verify result" << std::endl; + int errors = 0; + + // Verify T tiles + for (uint32_t i = 0; i < t_buf_size; ++i) { + if (h_dst_t[i] != h_src_t[i]) { + if (errors < 100) { + printf("*** error: T[%d] expected=%d, actual=%d\n", i, h_src_t[i], h_dst_t[i]); + } + ++errors; + } + } + + // Verify U tiles + for (uint32_t i = 0; i < u_buf_size; ++i) { + if (h_dst_u[i] != h_src_u[i]) { + if (errors < 100) { + printf("*** error: U[%d] expected=%d, actual=%d\n", i, h_src_u[i], h_dst_u[i]); + } + ++errors; + } + } + + // Verify V tiles + for (uint32_t i = 0; i < v_buf_size; ++i) { + if (h_dst_v[i] != h_src_v[i]) { + if (errors < 100) { + printf("*** error: V[%d] expected=%d, actual=%d\n", i, h_src_v[i], h_dst_v[i]); + } + ++errors; + } + } + + // M tiles cannot be checked functionally (no metadata store path); we only confirm the vx_lm loads executed without error + std::cout << "M tiles loaded successfully (verified by error-free execution)" << std::endl; + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return 1; + } + + std::cout << "PASSED!"
<< std::endl; + + return 0; +} From 887bbf773020c877bd5db1b899319d74ec3864e9 Mon Sep 17 00:00:00 2001 From: Mitul Tandon <112450996+tandonmitul27@users.noreply.github.com> Date: Tue, 2 Dec 2025 06:27:38 +0900 Subject: [PATCH 2/6] Tested TILE_LOAD_M --- tests/regression/veg_ls/kernel.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/regression/veg_ls/kernel.cpp b/tests/regression/veg_ls/kernel.cpp index 2d0dfc57db..887277344e 100644 --- a/tests/regression/veg_ls/kernel.cpp +++ b/tests/regression/veg_ls/kernel.cpp @@ -81,14 +81,14 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) { // ===== Test 4: TILE_LOAD_M - Load all 8 M-regs ===== // M-registers store metadata (sparsity patterns/masks) -// vx_lm(0, (size_t)(src_m_ptr + 0 * M_TILE_SIZE), 0); -// vx_lm(1, (size_t)(src_m_ptr + 1 * M_TILE_SIZE), 0); -// vx_lm(2, (size_t)(src_m_ptr + 2 * M_TILE_SIZE), 0); -// vx_lm(3, (size_t)(src_m_ptr + 3 * M_TILE_SIZE), 0); -// vx_lm(4, (size_t)(src_m_ptr + 4 * M_TILE_SIZE), 0); -// vx_lm(5, (size_t)(src_m_ptr + 5 * M_TILE_SIZE), 0); -// vx_lm(6, (size_t)(src_m_ptr + 6 * M_TILE_SIZE), 0); -// vx_lm(7, (size_t)(src_m_ptr + 7 * M_TILE_SIZE), 0); + vx_lm(0, (size_t)(src_m_ptr + 0 * M_TILE_SIZE), 0); + vx_lm(1, (size_t)(src_m_ptr + 1 * M_TILE_SIZE), 0); + vx_lm(2, (size_t)(src_m_ptr + 2 * M_TILE_SIZE), 0); + vx_lm(3, (size_t)(src_m_ptr + 3 * M_TILE_SIZE), 0); + vx_lm(4, (size_t)(src_m_ptr + 4 * M_TILE_SIZE), 0); + vx_lm(5, (size_t)(src_m_ptr + 5 * M_TILE_SIZE), 0); + vx_lm(6, (size_t)(src_m_ptr + 6 * M_TILE_SIZE), 0); + vx_lm(7, (size_t)(src_m_ptr + 7 * M_TILE_SIZE), 0); } int main() { From 4653469d86a91eb6468e163dbd288da49ba6cf2e Mon Sep 17 00:00:00 2001 From: tandonmitul27 Date: Fri, 5 Dec 2025 09:20:06 +0530 Subject: [PATCH 3/6] Add sgemm_tile test and update tile register dimensions to 16x16 --- sim/simx/execute.cpp | 30 ++- sim/simx/sparse_unit.cpp | 258 ++++++++++++++++++++----- sim/simx/sparse_unit.h | 5 + tests/regression/Makefile | 4 + tests/regression/sgemm_tile/Makefile | 12 ++ tests/regression/sgemm_tile/common.h | 16 ++ tests/regression/sgemm_tile/kernel.cpp | 29 +++ tests/regression/sgemm_tile/main.cpp | 221 +++++++++++++++++++++ tests/regression/veg_ls/common.h | 16 +- tests/regression/veg_ls/kernel.cpp | 42 ++-- 10 files changed, 544 insertions(+), 89 deletions(-) create mode 100644 tests/regression/sgemm_tile/Makefile create mode 100644 tests/regression/sgemm_tile/common.h create mode 100644 tests/regression/sgemm_tile/kernel.cpp create mode 100644 tests/regression/sgemm_tile/main.cpp diff --git a/sim/simx/execute.cpp b/sim/simx/execute.cpp index 4b150d0f6b..75af3d494a 100644 --- a/sim/simx/execute.cpp +++ b/sim/simx/execute.cpp @@ -1549,19 +1549,33 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) { auto trace_data = std::make_shared<SparseUnit::ExeTraceData>(); trace->data = trace_data; assert(warp.tmask.count() == num_threads); - // For now, use default values for fmt_s, fmt_d, step_m, step_n - // These may need to be extracted from the instruction in the future - uint32_t fmt_s = 0; - uint32_t fmt_d = 0; - uint32_t step_m = 0; - uint32_t step_n = 0; + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + switch (tcu_type) { case VegetaTcuType::TILE_GEMM_T: + // Dense tile × Dense tile → Tile (T × T → T) + sparse_unit_->tile_gemm_t(dst_reg, src1_reg, src2_reg); + rd_write = false; // Writes to tile registers, not scalar registers
+ break; + case VegetaTcuType::TILE_GEMM_U: + // Sparse tile (2:4) × Dense tile → Tile (T × U → T) + // Metadata assumed to be in corresponding m-register (same index as src1) + sparse_unit_->tile_gemm_u(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; + break; + case VegetaTcuType::TILE_GEMM_V: + // Sparse tile (1:4) × Dense tile → Tile (T × V → T) + sparse_unit_->tile_gemm_v(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; + break; + case VegetaTcuType::TILE_GEMM_R: - sparse_unit_->wmma(wid, fmt_s, fmt_d, step_m, step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get()); - rd_write = true; + // Row-wise sparse tile × Dense tile → Tile (T × U → U) + sparse_unit_->tile_gemm_r(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; break; default: std::abort(); diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp index 093a86b56f..0811379760 100644 --- a/sim/simx/sparse_unit.cpp +++ b/sim/simx/sparse_unit.cpp @@ -204,10 +204,10 @@ class SparseUnit::Impl { , core_(core) , arch_(arch) , perf_stats_() - , tile_reg_file_(8, std::vector<std::vector<float>>(16, std::vector<float>(32, 0.0f))) - , metadata_reg_file_(8, std::vector<std::vector<uint8_t>>(16, std::vector<uint8_t>(32, 0))) + , tile_reg_file_(8, std::vector<std::vector<float>>(16, std::vector<float>(16, 0.0f))) + , metadata_reg_file_(8, std::vector<std::vector<uint8_t>>(16, std::vector<uint8_t>(16, 0))) { - // Register file initialized: 8 registers, each 16x32 fp32 elements + // Register file initialized: 8 registers, each 16x16 fp32 elements } ~Impl() { @@ -296,33 +296,177 @@ class SparseUnit::Impl { ExeTraceData* trace_data) { __unused(wid); __unused(trace_data); + __unused(fmt_s); + __unused(fmt_d); + __unused(step_m); + __unused(step_n); + __unused(rs1_data); + __unused(rs2_data); + __unused(rs3_data); + __unused(rd_data); + + // This function is now a placeholder for TILE_GEMM operations + // The actual implementation is handled via tile registers directly + // See tile_gemm_t, tile_gemm_u, tile_gemm_v, tile_gemm_r functions + } + + // TILE_GEMM_T: Dense tile × Dense tile = Tile (T × T → T) + // Tiles are 16×16, so this computes: C[16×16] = A[16×16] × B[16×16] + void tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(src2_treg < tile_reg_file_.size() && "Source2 tile register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; + const auto& tile_b = tile_reg_file_[src2_treg]; + + // Matrix multiplication: C[16×16] = A[16×16] × B[16×16] + // C += A × B (accumulate to existing value) + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + float sum = tile_dst[i][j]; // Accumulate to existing value + for (uint32_t k = 0; k < TILE_DIM; ++k) { + sum += tile_a[i][k] * tile_b[k][j]; + } + tile_dst[i][j] = sum; + } + } - auto fedp = select_FEDP(fmt_s, fmt_d); + DP(2, "TILE_GEMM_T: dst_t" << dst_treg << " = t" << src1_treg << " × t" << src2_treg); - uint32_t a_off = (step_m % cfg::a_sub_blocks) * cfg::a_block_size; - uint32_t b_off = (step_n % cfg::b_sub_blocks) * cfg::b_block_size;
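+ // Note: every tile_gemm_* variant here accumulates into the existing destination contents (C += A × B). A standalone product therefore relies on the destination tile starting at zero; the register file is zero-initialized at construction, but repeated launches that reuse a destination tile would accumulate stale results.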
+ // TILE_GEMM_U: Sparse tile (2:4) × Dense tile = Tile (T × U → T) + void tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); - for (uint32_t i = 0; i < cfg::tcM; ++i) { - for (uint32_t j = 0; j < cfg::tcN; ++j) { - auto a_row = rs1_data.data() + a_off + i * cfg::tcK; - auto b_col = rs2_data.data() + b_off + j * cfg::tcK; - auto c_val = rs3_data.at(i * cfg::tcN + j).u32; - auto d_val = fedp(a_row, b_col, c_val); - rd_data.at(i * cfg::tcN + j).u64 = nan_box(d_val); + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile + const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // U-register maps to 2 T-registers + std::vector<uint32_t> src2_tregs = map_ureg_to_treg(src2_ureg); + + // For 2:4 sparsity, each 4-element block has 2 non-zero values + // Metadata byte indicates which 2 positions are non-zero + // We process 2 T-registers as one U-register + + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + float sum = tile_dst[i][j]; // Accumulate + + // Process sparse A row with dense B column + // Every 4 elements in A row, check metadata + for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) { + uint8_t mask = meta_a[i][k_blk / 4]; // Get metadata for this 4-element block + + // Process only non-zero elements indicated by metadata + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + uint32_t k = k_blk + offset; + // Determine which T-register to access for B + uint32_t treg_idx = (k < TILE_DIM) ? src2_tregs[0] : src2_tregs[1]; + uint32_t k_local = k % TILE_DIM; + sum += tile_a[i][k] * tile_reg_file_[treg_idx][k_local][j]; + } + } + } + tile_dst[i][j] = sum; + } + } + + DP(2, "TILE_GEMM_U: dst_t" << dst_treg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × u" << src2_ureg); + }
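+ + // Example: with 2:4 sparsity a metadata nibble of 0b0101 keeps logical positions 0 and 2 of its 4-element group; each nibble is expected to have exactly two bits set, otherwise more (or fewer) than two products are accumulated per group.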
+ + // TILE_GEMM_V: Sparse tile (1:4) × Dense tile = Tile (T × V → T) + void tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); - DTH(3, "FEDP: wid=" << wid << ", i=" << i << ", j=" << j << ", m=" << step_m << ", n=" << step_n << ", a_row={" << std::hex); - for (uint32_t q = 0; q < cfg::tcK; ++q) { - if (q) DTN(3, ", "); - DTN(3, "0x" << a_row[q].u32); + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile + const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // V-register maps to 4 T-registers + std::vector<uint32_t> src2_tregs = map_vreg_to_treg(src2_vreg); + + // For 1:4 sparsity, each 4-element block has 1 non-zero value + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + float sum = tile_dst[i][j]; // Accumulate + + // Process sparse A row with dense B column + for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) { + uint8_t mask = meta_a[i][k_blk / 4]; // Get metadata for this 4-element block + + // Process only non-zero elements indicated by metadata + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + uint32_t k = k_blk + offset; + // Determine which T-register to access for B + uint32_t treg_idx = src2_tregs[k / TILE_DIM]; + uint32_t k_local = k % TILE_DIM; + sum += tile_a[i][k] * tile_reg_file_[treg_idx][k_local][j]; + } + } } - DTN(3, "}, b_col={"); - for (uint32_t q = 0; q < cfg::tcK; ++q) { - if (q) DTN(3, ", "); - DTN(3, "0x" << b_col[q].u32); + tile_dst[i][j] = sum; + } + } + + DP(2, "TILE_GEMM_V: dst_t" << dst_treg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × v" << src2_vreg); + } + + // TILE_GEMM_R: Row-wise sparse tile × Dense tile = Tile (T × U → U) + void tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + + const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile + const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // Both dst and src2 are U-registers (map to 2 T-registers each) + std::vector<uint32_t> dst_tregs = map_ureg_to_treg(dst_ureg); + std::vector<uint32_t> src2_tregs = map_ureg_to_treg(src2_ureg); + + // Row-wise sparsity: metadata can vary per row + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + // Determine which destination T-register + uint32_t dst_treg_idx = dst_tregs[j / TILE_DIM]; + uint32_t j_local = j % TILE_DIM; + + float sum = tile_reg_file_[dst_treg_idx][i][j_local]; // Accumulate + + // Process sparse A row with dense B column + for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) { + uint8_t mask = meta_a[i][k_blk / 4]; // Row-wise metadata + + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + uint32_t k = k_blk + offset; + uint32_t src2_treg_idx = src2_tregs[k / TILE_DIM]; + uint32_t k_local = k % TILE_DIM; + sum += tile_a[i][k] * tile_reg_file_[src2_treg_idx][k_local][j]; + } + } } - DTN(3, "}, c_val=0x" << c_val << ", d_val=0x" << d_val << std::dec << std::endl); + tile_reg_file_[dst_treg_idx][i][j_local] = sum; } } + + DP(2, "TILE_GEMM_R: dst_u" << dst_ureg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × u" << src2_ureg); } // Map ureg index to tile register indices @@ -363,8 +507,7 @@ class SparseUnit::Impl { // Calculate base address: rs1_data + immediate offset uint64_t base_addr = rs1_data.at(tid).i + lsuArgs.offset; - constexpr uint32_t TILE_ROWS = 16; - constexpr uint32_t TILE_COLS = 32; + constexpr uint32_t TILE_DIM = 16; switch (lsu_type) { case VegetaLsuType::TILE_LOAD_T: { @@ -375,10 +518,10 @@ class SparseUnit::Impl { constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32 base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -401,14 +544,15 @@ class SparseUnit::Impl { base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads constexpr uint32_t ELEMENT_SIZE = sizeof(typename
vt::fp32::dtype); + uint64_t current_addr = base_addr; for (uint32_t treg_idx : target_tregs) { assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[treg_idx]; - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -418,6 +562,7 @@ class SparseUnit::Impl { tile_reg[row][col] = value; } } + current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB) } DP(2, "TILE_LOAD_U: wid=" << wid << ", tid=" << tid @@ -432,14 +577,15 @@ class SparseUnit::Impl { base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); + uint64_t current_addr = base_addr; for (uint32_t treg_idx : target_tregs) { assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[treg_idx]; - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -449,6 +595,7 @@ class SparseUnit::Impl { tile_reg[row][col] = value; } } + current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB) } DP(2, "TILE_LOAD_V: wid=" << wid << ", tid=" << tid @@ -463,11 +610,11 @@ class SparseUnit::Impl { assert(meta_reg_idx < metadata_reg_file_.size() && "Metadata register index out of bounds"); auto &metadata_reg = metadata_reg_file_[meta_reg_idx]; - // Load metadata from memory: 16 rows x 32 columns = 512 uint4 elements = 256 bytes + // Load metadata from memory: 16 rows x 16 columns = 256 uint4 elements = 128 bytes // Each byte stores two uint4 values: upper nibble for col N, lower nibble for col N+1 - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; col += 2) { - uint64_t mem_addr = base_addr + (row * (TILE_COLS / 2) + col / 2); + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; col += 2) { + uint64_t mem_addr = base_addr + (row * (TILE_DIM / 2) + col / 2); uint8_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, 1); trace_data->mem_addrs.at(tid).push_back({mem_addr, 1}); @@ -503,14 +650,13 @@ class SparseUnit::Impl { assert(vs3 < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[vs3]; - constexpr uint32_t TILE_ROWS = 16; - constexpr uint32_t TILE_COLS = 32; + constexpr uint32_t TILE_DIM = 16; constexpr 
uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32 - // Store tile to memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Store tile to memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; float value = tile_reg[row][col]; uint32_t mem_data = 0; std::memcpy(&mem_data, &value, ELEMENT_SIZE); @@ -603,8 +749,8 @@ class SparseUnit::Impl { Core* core_; Arch arch_; PerfStats perf_stats_; - SparseRegFile_t tile_reg_file_; // 8 registers, each 16x32 fp32 elements - std::vector<std::vector<std::vector<uint8_t>>> metadata_reg_file_; // 8 registers, each 16x32 uint4 elements + SparseRegFile_t tile_reg_file_; // 8 registers, each 16x16 fp32 elements + std::vector<std::vector<std::vector<uint8_t>>> metadata_reg_file_; // 8 registers, each 16x16 uint4 elements }; /////////////////////////////////////////////////////////////////////////////// @@ -664,3 +810,19 @@ void SparseUnit::wmma(uint32_t wid, ExeTraceData* trace_data) { impl_->wmma(wid, fmt_s, fmt_d, step_m, step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data); } + +void SparseUnit::tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg) { + impl_->tile_gemm_t(dst_treg, src1_treg, src2_treg); +} + +void SparseUnit::tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + impl_->tile_gemm_u(dst_treg, src1_treg, src2_ureg, meta_reg); +} + +void SparseUnit::tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg) { + impl_->tile_gemm_v(dst_treg, src1_treg, src2_vreg, meta_reg); +} + +void SparseUnit::tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + impl_->tile_gemm_r(dst_ureg, src1_treg, src2_ureg, meta_reg); +} diff --git a/sim/simx/sparse_unit.h b/sim/simx/sparse_unit.h index b1822aa53a..125a749659 100644 --- a/sim/simx/sparse_unit.h +++ b/sim/simx/sparse_unit.h @@ -79,6 +79,11 @@ class SparseUnit : public SimObject { std::vector<reg_data_t>& rd_data, ExeTraceData* trace_data); + void tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg); + void tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg); + void tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg); + void tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg); + const PerfStats& perf_stats() const; private: diff --git a/tests/regression/Makefile b/tests/regression/Makefile index 8aac0a48e2..4a6fc7c024 100644 --- a/tests/regression/Makefile +++ b/tests/regression/Makefile @@ -22,6 +22,7 @@ all: $(MAKE) -C madmax $(MAKE) -C stencil3d $(MAKE) -C veg_ls + $(MAKE) -C sgemm_tile run-simx: $(MAKE) -C basic run-simx @@ -44,6 +45,7 @@ run-simx: $(MAKE) -C madmax run-simx $(MAKE) -C stencil3d run-simx $(MAKE) -C veg_ls run-simx + $(MAKE) -C sgemm_tile run-simx run-rtlsim: $(MAKE) -C basic run-rtlsim @@ -66,6 +68,7 @@ run-rtlsim: $(MAKE) -C madmax run-rtlsim $(MAKE) -C stencil3d run-rtlsim $(MAKE) -C veg_ls run-rtlsim + $(MAKE) -C sgemm_tile run-rtlsim clean: $(MAKE) -C basic clean @@ -88,3 +91,4 @@ clean: $(MAKE) -C madmax clean $(MAKE) -C stencil3d clean $(MAKE) -C veg_ls clean + $(MAKE) -C sgemm_tile clean diff --git
a/tests/regression/sgemm_tile/Makefile b/tests/regression/sgemm_tile/Makefile new file mode 100644 index 0000000000..181f404e11 --- /dev/null +++ b/tests/regression/sgemm_tile/Makefile @@ -0,0 +1,12 @@ +ROOT_DIR := $(realpath ../../..) +include $(ROOT_DIR)/config.mk + +PROJECT := sgemm_tile + +SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT) + +SRCS := $(SRC_DIR)/main.cpp + +VX_SRCS := $(SRC_DIR)/kernel.cpp + +include ../common.mk diff --git a/tests/regression/sgemm_tile/common.h b/tests/regression/sgemm_tile/common.h new file mode 100644 index 0000000000..48164c324b --- /dev/null +++ b/tests/regression/sgemm_tile/common.h @@ -0,0 +1,16 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include <stdint.h> + +// T Tile dimensions: 16x16 fp32 = 1KB per tile register +#define TILE_SIZE 16 +#define T_TILE_BYTES (TILE_SIZE * TILE_SIZE * sizeof(float)) // 1KB + +typedef struct { + uint64_t A_addr; // Matrix A (16x16 fp32) + uint64_t B_addr; // Matrix B (16x16 fp32) + uint64_t C_addr; // Matrix C (16x16 fp32) +} kernel_arg_t; + +#endif // _COMMON_H_ diff --git a/tests/regression/sgemm_tile/kernel.cpp b/tests/regression/sgemm_tile/kernel.cpp new file mode 100644 index 0000000000..6add5275e4 --- /dev/null +++ b/tests/regression/sgemm_tile/kernel.cpp @@ -0,0 +1,29 @@ +#include <vx_intrinsics.h> +#include <vx_spawn.h> +#include "common.h" + +// Simple TGEMM test: C[16x16] = A[16x16] × B[16x16] +// T Tile registers are 16x16 fp32 = 1KB each + +void kernel_body(kernel_arg_t* __UNIFORM__ arg) { + auto A_ptr = reinterpret_cast<float*>(arg->A_addr); + auto B_ptr = reinterpret_cast<float*>(arg->B_addr); + auto C_ptr = reinterpret_cast<float*>(arg->C_addr); + + // Load A tile into T-reg 1 + vx_lt(1, (size_t)A_ptr, 0); + + // Load B tile into T-reg 2 + vx_lt(2, (size_t)B_ptr, 0); + + // TGEMM: T0 = T1 × T2 (accumulate into T0) + vx_tgemm(0, 1, 2);
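+ // Note: vx_tgemm (like the vx_ugemm/vx_vgemm used in a later patch) is assumed to be provided by vx_intrinsics.h; it is not among the intrinsics added in patch 1 of this series.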
+ + // Store result from T-reg 0 to C + vx_st((size_t)C_ptr, 0, 0); +} + +int main() { + kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH); + return vx_spawn_threads(1, nullptr, nullptr, (vx_kernel_func_cb)kernel_body, arg); +} diff --git a/tests/regression/sgemm_tile/main.cpp b/tests/regression/sgemm_tile/main.cpp new file mode 100644 index 0000000000..6f1fe1a049 --- /dev/null +++ b/tests/regression/sgemm_tile/main.cpp @@ -0,0 +1,221 @@ +#include <iostream> +#include <unistd.h> +#include <cstdlib> +#include <vector> +#include <fstream> +#include <vortex.h> +#include "common.h" + +#define FLOAT_ULP 6 + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +/////////////////////////////////////////////////////////////////////////////// + +const char* kernel_file = "kernel.vxbin"; + +vx_device_h device = nullptr; +vx_buffer_h A_buffer = nullptr; +vx_buffer_h B_buffer = nullptr; +vx_buffer_h C_buffer = nullptr; +vx_buffer_h krnl_buffer = nullptr; +vx_buffer_h args_buffer = nullptr; +kernel_arg_t kernel_arg = {}; + +static void show_usage() { + std::cout << "Vortex SGEMM TILE Test (16x16 TGEMM)." << std::endl; + std::cout << "Usage: [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "h")) != -1) { + switch (c) { + case 'h': + show_usage(); + exit(0); + break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(A_buffer); + vx_mem_free(B_buffer); + vx_mem_free(C_buffer); + vx_mem_free(krnl_buffer); + vx_mem_free(args_buffer); + vx_dev_close(device); + } +} + +// CPU reference: C = A × B for 16x16 matrices +static void matmul_cpu(float* C, const float* A, const float* B) { + for (int m = 0; m < TILE_SIZE; ++m) { + for (int n = 0; n < TILE_SIZE; ++n) { + float sum = 0.0f; + for (int k = 0; k < TILE_SIZE; ++k) { + sum += A[m * TILE_SIZE + k] * B[k * TILE_SIZE + n]; + } + C[m * TILE_SIZE + n] = sum; + } + } +} + +// Compare floats with ULP tolerance +static bool compare_float(float a, float b, int index, int& errors) { + union { float f; int32_t i; } fa, fb; + fa.f = a; + fb.f = b; + auto d = std::abs(fa.i - fb.i); + if (d > FLOAT_ULP) { + if (errors < 100) { + printf("*** error: [%d] expected=%.6f, actual=%.6f\n", index, b, a); + } + ++errors; + return false; + } + return true; +} + +int main(int argc, char *argv[]) { + parse_args(argc, argv); + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + uint32_t num_elements = TILE_SIZE * TILE_SIZE; // 256 elements + uint32_t buf_size = T_TILE_BYTES; // 1KB + + std::cout << "SGEMM TILE Test: " << TILE_SIZE << "x" << TILE_SIZE << " matrices" << std::endl; + std::cout << "Buffer size: " << buf_size << " bytes (" << num_elements << " fp32 elements)" << std::endl; + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &A_buffer)); + RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); + RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &B_buffer)); + RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); + RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_WRITE, &C_buffer)); + RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + + std::cout << "dev_A=0x" << std::hex << kernel_arg.A_addr << std::endl; + std::cout << "dev_B=0x" << std::hex << kernel_arg.B_addr << std::endl; + std::cout << "dev_C=0x" << std::hex << kernel_arg.C_addr << std::dec << std::endl; + + // allocate host buffers + std::cout << "allocate host buffers" << std::endl; + std::vector<float> h_A(num_elements); + std::vector<float> h_B(num_elements); + std::vector<float> h_C(num_elements); + std::vector<float> h_ref(num_elements); + + // Initialize with random values + for (uint32_t i = 0; i < num_elements; ++i) { + h_A[i] = static_cast<float>(rand()) / RAND_MAX; + h_B[i] = static_cast<float>(rand()) / RAND_MAX; + } + + // upload source buffers + std::cout << "upload source buffers" << std::endl; + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, buf_size)); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, buf_size)); + + // upload kernel binary + std::cout << "upload kernel binary" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); + + // upload kernel argument + std::cout << "upload kernel argument" << std::endl; + RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer)); + + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device, krnl_buffer, args_buffer)); + + // wait for completion + std::cout << "wait for completion" << std::endl; +
RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + // download result + std::cout << "download result" << std::endl; + RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, buf_size)); + + // compute CPU reference + std::cout << "verify result" << std::endl; + matmul_cpu(h_ref.data(), h_A.data(), h_B.data()); + + // verify result + int errors = 0; + for (uint32_t i = 0; i < num_elements; ++i) { + compare_float(h_C[i], h_ref[i], i, errors); + } + + // write matrices to output file + std::cout << "writing matrices to output file" << std::endl; + std::ofstream output_file("matrices_output.txt"); + if (output_file.is_open()) { + output_file << "Matrix A (" << TILE_SIZE << "x" << TILE_SIZE << "):\n"; + for (int i = 0; i < TILE_SIZE; ++i) { + for (int j = 0; j < TILE_SIZE; ++j) { + output_file << h_A[i * TILE_SIZE + j]; + if (j < TILE_SIZE - 1) output_file << " "; + } + output_file << "\n"; + } + output_file << "\n"; + + output_file << "Matrix B (" << TILE_SIZE << "x" << TILE_SIZE << "):\n"; + for (int i = 0; i < TILE_SIZE; ++i) { + for (int j = 0; j < TILE_SIZE; ++j) { + output_file << h_B[i * TILE_SIZE + j]; + if (j < TILE_SIZE - 1) output_file << " "; + } + output_file << "\n"; + } + output_file << "\n"; + + output_file << "Matrix C (Result, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; + for (int i = 0; i < TILE_SIZE; ++i) { + for (int j = 0; j < TILE_SIZE; ++j) { + output_file << h_C[i * TILE_SIZE + j]; + if (j < TILE_SIZE - 1) output_file << " "; + } + output_file << "\n"; + } + + output_file.close(); + std::cout << "Matrices written to 'matrices_output.txt'" << std::endl; + } else { + std::cerr << "Error: Unable to open output file" << std::endl; + } + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + if (errors != 0) { + std::cout << "Found " << errors << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return 1; + } + + std::cout << "PASSED!" 
<< std::endl; + + return 0; +} diff --git a/tests/regression/veg_ls/common.h b/tests/regression/veg_ls/common.h index 46cce37589..8fbbb79d18 100644 --- a/tests/regression/veg_ls/common.h +++ b/tests/regression/veg_ls/common.h @@ -5,17 +5,17 @@ #ifndef TYPE #define TYPE float #endif -// T-reg: 2KB (16x32 4-byte elements) -#define T_TILE_SIZE 2048 +// T-reg: 1KB (16x16 fp32 elements) +#define T_TILE_SIZE 1024 -// U-reg: 4KB (2 x T-reg) -#define U_TILE_SIZE 4096 +// U-reg: 2KB (2 x T-reg) +#define U_TILE_SIZE 2048 -// V-reg: 8KB (4 x T-reg) -#define V_TILE_SIZE 8192 +// V-reg: 4KB (4 x T-reg) +#define V_TILE_SIZE 4096 -// M-reg: 256B (16x32 4-bit elements) -#define M_TILE_SIZE 256 +// M-reg: 128B (16x16 4-bit elements, 2 per byte) +#define M_TILE_SIZE 128 // Number of tiles to test for each register type #define NUM_T_TILES 8 // Test all 8 T-regs diff --git a/tests/regression/veg_ls/kernel.cpp b/tests/regression/veg_ls/kernel.cpp index 887277344e..e992786a5d 100644 --- a/tests/regression/veg_ls/kernel.cpp +++ b/tests/regression/veg_ls/kernel.cpp @@ -3,7 +3,7 @@ #include "common.h" void kernel_body(kernel_arg_t* __UNIFORM__ arg) { - auto src_t_ptr = reinterpret_cast<uint8_t*>(arg->src_t_addr); + auto src_t_ptr =reinterpret_cast<uint8_t*>(arg->src_t_addr); auto dst_t_ptr = reinterpret_cast<uint8_t*>(arg->dst_t_addr); auto src_u_ptr = reinterpret_cast<uint8_t*>(arg->src_u_addr); auto dst_u_ptr = reinterpret_cast<uint8_t*>(arg->dst_u_addr); @@ -12,34 +12,30 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) { auto src_m_ptr = reinterpret_cast<uint8_t*>(arg->src_m_addr); auto dst_m_ptr = reinterpret_cast<uint8_t*>(arg->dst_m_addr); - // ===== Test 1: TILE_LOAD_T - Load/Store all 8 T-regs individually ===== - vx_lt(0, (size_t)(src_t_ptr + 0 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 0 * T_TILE_SIZE), 0, 0); + // ===== LOAD ALL TILES FIRST ===== + // This prevents later loads from overwriting T-regs before earlier data is stored + // Test 1: TILE_LOAD_T - Load all 8 T-regs individually + vx_lt(0, (size_t)(src_t_ptr + 0 * T_TILE_SIZE), 0); vx_lt(1, (size_t)(src_t_ptr + 1 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 1 * T_TILE_SIZE), 0, 1); vx_lt(2, (size_t)(src_t_ptr + 2 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 2 * T_TILE_SIZE), 0, 2); vx_lt(3, (size_t)(src_t_ptr + 3 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 3 * T_TILE_SIZE), 0, 3); vx_lt(4, (size_t)(src_t_ptr + 4 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 4 * T_TILE_SIZE), 0, 4); vx_lt(5, (size_t)(src_t_ptr + 5 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 5 * T_TILE_SIZE), 0, 5); vx_lt(6, (size_t)(src_t_ptr + 6 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 6 * T_TILE_SIZE), 0, 6); vx_lt(7, (size_t)(src_t_ptr + 7 * T_TILE_SIZE), 0); - vx_st((size_t)(dst_t_ptr + 7 * T_TILE_SIZE), 0, 7); - vx_barrier(0, 1); + // Store T-tiles immediately while data is still in registers + vx_st((size_t)(dst_t_ptr + 0 * T_TILE_SIZE), 0, 0); + vx_st((size_t)(dst_t_ptr + 1 * T_TILE_SIZE), 0, 1); + vx_st((size_t)(dst_t_ptr + 2 * T_TILE_SIZE), 0, 2); + vx_st((size_t)(dst_t_ptr + 3 * T_TILE_SIZE), 0, 3); + vx_st((size_t)(dst_t_ptr + 4 * T_TILE_SIZE), 0, 4); + vx_st((size_t)(dst_t_ptr + 5 * T_TILE_SIZE), 0, 5); + vx_st((size_t)(dst_t_ptr + 6 * T_TILE_SIZE), 0, 6); + vx_st((size_t)(dst_t_ptr + 7 * T_TILE_SIZE), 0, 7); - // ===== Test 2: TILE_LOAD_U - Load/Store all 4 U-regs (covers all 8 T-regs) ===== + // Test 2: TILE_LOAD_U - Load all 4 U-regs (covers all 8 T-regs) // U-reg 0 maps to T-regs [0, 1] vx_lu(0, (size_t)(src_u_ptr + 0 * U_TILE_SIZE), 0); vx_st((size_t)(dst_u_ptr + 0 * U_TILE_SIZE), 0, 0); @@
-60,9 +56,7 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) { vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE), 0, 6); vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE + T_TILE_SIZE), 0, 7); - vx_barrier(0, 1); - - // ===== Test 3: TILE_LOAD_V - Load/Store all 2 V-regs (covers all 8 T-regs) ===== + // Test 3: TILE_LOAD_V - Load all 2 V-regs (covers all 8 T-regs) // V-reg 0 maps to T-regs [0, 1, 2, 3] vx_lv(0, (size_t)(src_v_ptr + 0 * V_TILE_SIZE), 0); vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE), 0, 0); @@ -77,9 +71,7 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) { vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 2 * T_TILE_SIZE), 0, 6); vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 3 * T_TILE_SIZE), 0, 7); - vx_barrier(0, 1); - - // ===== Test 4: TILE_LOAD_M - Load all 8 M-regs ===== + // Test 4: TILE_LOAD_M - Load all 8 M-regs // M-registers store metadata (sparsity patterns/masks) vx_lm(0, (size_t)(src_m_ptr + 0 * M_TILE_SIZE), 0); vx_lm(1, (size_t)(src_m_ptr + 1 * M_TILE_SIZE), 0); From eee389f10dddaaf32854350773111c7fddc43c3f Mon Sep 17 00:00:00 2001 From: tandonmitul27 Date: Fri, 5 Dec 2025 14:22:01 +0530 Subject: [PATCH 4/6] tested vx_tgemm, vx_ugemm and vx_vgemm --- sim/simx/sparse_unit.cpp | 69 ++++-- tests/regression/sgemm_tile/common.h | 18 +- tests/regression/sgemm_tile/kernel.cpp | 48 +++- tests/regression/sgemm_tile/main.cpp | 328 ++++++++++++++++++++++--- 4 files changed, 396 insertions(+), 67 deletions(-) diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp index 0811379760..921e0e427a 100644 --- a/sim/simx/sparse_unit.cpp +++ b/sim/simx/sparse_unit.cpp @@ -357,23 +357,38 @@ class SparseUnit::Impl { // Metadata byte indicates which 2 positions are non-zero // We process 2 T-registers as one U-register + // U-register spans 2 T-registers, giving K dimension of 2*TILE_DIM = 32 + // A is stored in compressed form: 16 values per row representing 32 logical positions + // Metadata indicates which 2 out of every 4 logical positions are stored + for (uint32_t i = 0; i < TILE_DIM; ++i) { for (uint32_t j = 0; j < TILE_DIM; ++j) { float sum = tile_dst[i][j]; // Accumulate - // Process sparse A row with dense B column - // Every 4 elements in A row, check metadata - for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) { - uint8_t mask = meta_a[i][k_blk / 4]; // Get metadata for this 4-element block + // Iterate through compressed A values and map to logical K positions + uint32_t compressed_idx = 0; // Index into compressed storage (tile_a) + + // Process 8 groups of 4 logical K positions (covering K=0..31) + for (uint32_t k_grp = 0; k_grp < 8; ++k_grp) { + uint8_t mask = meta_a[i][k_grp]; // Metadata for this 4-element group + uint32_t k_base = k_grp * 4; // Base logical K position for this group - // Process only non-zero elements indicated by metadata + // Check each of the 4 positions in this group for (uint32_t offset = 0; offset < 4; ++offset) { if (mask & (1u << offset)) { - uint32_t k = k_blk + offset; - // Determine which T-register to access for B - uint32_t treg_idx = (k < TILE_DIM) ? src2_tregs[0] : src2_tregs[1]; - uint32_t k_local = k % TILE_DIM; - sum += tile_a[i][k] * tile_reg_file_[treg_idx][k_local][j]; + // This position is non-zero + uint32_t k_logical = k_base + offset; // Logical K position (0-31) + + // Access compressed value from tile_a + float a_val = tile_a[i][compressed_idx]; + + // Determine which T-register of B to access + uint32_t treg_idx = (k_logical < TILE_DIM) ? 
src2_tregs[0] : src2_tregs[1]; + uint32_t k_local = k_logical % TILE_DIM; + + sum += a_val * tile_reg_file_[treg_idx][k_local][j]; + + compressed_idx++; // Move to next compressed value } } } @@ -400,22 +415,38 @@ class SparseUnit::Impl { std::vector src2_tregs = map_vreg_to_treg(src2_vreg); // For 1:4 sparsity, each 4-element block has 1 non-zero value + // V-register spans 4 T-registers, giving K dimension of 4*TILE_DIM = 64 + // A is stored in compressed form: 16 values per row representing 64 logical positions + // Metadata indicates which 1 out of every 4 logical positions is stored + for (uint32_t i = 0; i < TILE_DIM; ++i) { for (uint32_t j = 0; j < TILE_DIM; ++j) { float sum = tile_dst[i][j]; // Accumulate - // Process sparse A row with dense B column - for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) { - uint8_t mask = meta_a[i][k_blk / 4]; // Get metadata for this 4-element block + // Iterate through compressed A values and map to logical K positions + uint32_t compressed_idx = 0; // Index into compressed storage (tile_a) + + // Process 16 groups of 4 logical K positions (covering K=0..63) + for (uint32_t k_grp = 0; k_grp < 16; ++k_grp) { + uint8_t mask = meta_a[i][k_grp]; // Metadata for this 4-element group + uint32_t k_base = k_grp * 4; // Base logical K position for this group - // Process only non-zero elements indicated by metadata + // Check each of the 4 positions in this group for (uint32_t offset = 0; offset < 4; ++offset) { if (mask & (1u << offset)) { - uint32_t k = k_blk + offset; - // Determine which T-register to access for B - uint32_t treg_idx = src2_tregs[k / TILE_DIM]; - uint32_t k_local = k % TILE_DIM; - sum += tile_a[i][k] * tile_reg_file_[treg_idx][k_local][j]; + // This position is non-zero + uint32_t k_logical = k_base + offset; // Logical K position (0-63) + + // Access compressed value from tile_a + float a_val = tile_a[i][compressed_idx]; + + // Determine which T-register of B to access + uint32_t treg_idx = src2_tregs[k_logical / TILE_DIM]; + uint32_t k_local = k_logical % TILE_DIM; + + sum += a_val * tile_reg_file_[treg_idx][k_local][j]; + + compressed_idx++; // Move to next compressed value } } } diff --git a/tests/regression/sgemm_tile/common.h b/tests/regression/sgemm_tile/common.h index 48164c324b..fca0f88223 100644 --- a/tests/regression/sgemm_tile/common.h +++ b/tests/regression/sgemm_tile/common.h @@ -6,11 +6,23 @@ // T Tile dimensions: 16x16 fp32 = 1KB per tile register #define TILE_SIZE 16 #define T_TILE_BYTES (TILE_SIZE * TILE_SIZE * sizeof(float)) // 1KB +#define U_TILE_BYTES (2 * T_TILE_BYTES) // 2KB (U-reg = 2 T-regs for dense) +#define V_TILE_BYTES (4 * T_TILE_BYTES) // 4KB (V-reg = 4 T-regs for dense) +#define M_TILE_BYTES (TILE_SIZE * TILE_SIZE / 2) // 128 bytes (metadata: 2 nibbles per byte) + +// GEMM modes +typedef enum { + GEMM_MODE_TGEMM = 0, // T x T -> T (dense x dense) + GEMM_MODE_UGEMM = 1, // T x U -> T (sparse 2:4 packed x dense 2x) + GEMM_MODE_VGEMM = 2 // T x V -> T (sparse 1:4 packed x dense 4x) +} gemm_mode_t; typedef struct { - uint64_t A_addr; // Matrix A (16x16 fp32) - uint64_t B_addr; // Matrix B (16x16 fp32) - uint64_t C_addr; // Matrix C (16x16 fp32) + uint64_t A_addr; // Matrix A (1KB T-tile, sparse for UGEMM/VGEMM) + uint64_t B_addr; // Matrix B (1KB/2KB/4KB depending on mode, always dense) + uint64_t M_addr; // Metadata for sparse A (128 bytes, only for UGEMM/VGEMM) + uint64_t C_addr; // Matrix C result (1KB T-tile) + uint32_t mode; // GEMM mode (TGEMM=0, UGEMM=1, VGEMM=2) } kernel_arg_t; #endif // 
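The compressed-index walk in tile_gemm_u/tile_gemm_v above is easiest to sanity-check as a plain decompression: the 16 stored values of a 2:4 row scatter back to their 32 logical K positions under the row's 8 group masks. A reference sketch under that assumption (fp32, masks already unpacked to one byte per group):

#include <cstdint>
#include <cstring>

// Expand one 2:4-compressed row (16 values, 8 group masks) back to its 32
// logical K positions; pruned positions stay zero. This mirrors the
// compressed_idx bookkeeping in tile_gemm_u above.
static void expand_2_4_row(const float packed[16], const uint8_t masks[8],
                           float logical[32]) {
  std::memset(logical, 0, 32 * sizeof(float));
  uint32_t compressed_idx = 0;
  for (uint32_t k_grp = 0; k_grp < 8; ++k_grp) {
    for (uint32_t offset = 0; offset < 4; ++offset) {
      if (masks[k_grp] & (1u << offset)) {
        logical[k_grp * 4 + offset] = packed[compressed_idx++];
      }
    }
  }
}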
_COMMON_H_ diff --git a/tests/regression/sgemm_tile/kernel.cpp b/tests/regression/sgemm_tile/kernel.cpp index 6add5275e4..ef5df44ac5 100644 --- a/tests/regression/sgemm_tile/kernel.cpp +++ b/tests/regression/sgemm_tile/kernel.cpp @@ -2,24 +2,54 @@ #include #include "common.h" -// Simple TGEMM test: C[16x16] = A[16x16] × B[16x16] -// T Tile registers are 16x16 fp32 = 1KB each +// GEMM kernel supporting three modes: +// - TGEMM: C[16x16] = A[16x16] × B[16x16] (dense × dense) +// - UGEMM: C[16x16] = A[16x32 logical, 2:4-compressed to 16x16] × B[32x16] (sparse × dense) +// - VGEMM: C[16x16] = A[16x64 logical, 1:4-compressed to 16x16] × B[64x16] (sparse × dense) void kernel_body(kernel_arg_t* __UNIFORM__ arg) { auto A_ptr = reinterpret_cast(arg->A_addr); auto B_ptr = reinterpret_cast(arg->B_addr); + auto M_ptr = reinterpret_cast(arg->M_addr); auto C_ptr = reinterpret_cast(arg->C_addr); + uint32_t mode = arg->mode; - // Load A tile into T-reg 1 + // Load A tile into T-reg 1 (always a 1KB T-tile; holds the compressed values for UGEMM/VGEMM) vx_lt(1, (size_t)A_ptr, 0); - // Load B tile into T-reg 2 - vx_lt(2, (size_t)B_ptr, 0); + if (mode == GEMM_MODE_TGEMM) { + // TGEMM: T × T -> T + // Load B tile into T-reg 2 (1KB dense) + vx_lt(2, (size_t)B_ptr, 0); + + // TGEMM: T0 = T1 × T2 (accumulate into T0) + vx_tgemm(0, 1, 2); + } + else if (mode == GEMM_MODE_UGEMM) { + // UGEMM: T × U -> T (2:4 sparse A, dense B) + // Load metadata into M-reg 1 (128B metadata) + vx_lm(1, (size_t)M_ptr, 0); + + // Load B tile into U-reg 2 (2KB dense) + vx_lu(2, (size_t)B_ptr, 0); + + // UGEMM: T0 = T1 (sparse with M1 metadata) × U2 (dense), accumulate into T0 + vx_ugemm(0, 1, 2); + } + else if (mode == GEMM_MODE_VGEMM) { + // VGEMM: T × V -> T (1:4 sparse A, dense B) + // Load metadata into M-reg 1 (128B metadata) + vx_lm(1, (size_t)M_ptr, 0); + + // Load B tile into V-reg 1 (4KB dense) + // Note: V-reg 1 maps to T-regs 4-7, staying within the 8 T-reg limit + vx_lv(1, (size_t)B_ptr, 0); + + // VGEMM: T0 = T1 (sparse with M1 metadata) × V1 (dense) + vx_vgemm(0, 1, 1); + } - // TGEMM: T0 = T1 × T2 (accumulate into T0) - vx_tgemm(0, 1, 2); - - // Store result from T-reg 0 to C + // Store result from T-reg 0 to C (always 1KB) vx_st((size_t)C_ptr, 0, 0); } diff --git a/tests/regression/sgemm_tile/main.cpp b/tests/regression/sgemm_tile/main.cpp index 6f1fe1a049..c77d54b404 100644 --- a/tests/regression/sgemm_tile/main.cpp +++ b/tests/regression/sgemm_tile/main.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "common.h" @@ -25,20 +26,35 @@ const char* kernel_file = "kernel.vxbin"; vx_device_h device = nullptr; vx_buffer_h A_buffer = nullptr; vx_buffer_h B_buffer = nullptr; +vx_buffer_h M_buffer = nullptr; vx_buffer_h C_buffer = nullptr; vx_buffer_h krnl_buffer = nullptr; vx_buffer_h args_buffer = nullptr; kernel_arg_t kernel_arg = {}; +static gemm_mode_t gemm_mode = GEMM_MODE_TGEMM; + static void show_usage() { - std::cout << "Vortex SGEMM TILE Test (16x16 TGEMM)." << std::endl; - std::cout << "Usage: [-h: help]" << std::endl; + std::cout << "Vortex SGEMM TILE Test (16x16 matrix operations)."
<< std::endl; + std::cout << "Usage: [-m mode] [-h: help]" << std::endl; + std::cout << " -m mode: GEMM mode (0=TGEMM, 1=UGEMM, 2=VGEMM) [default: 0]" << std::endl; + std::cout << " TGEMM (0): T × T -> T (dense × dense)" << std::endl; + std::cout << " UGEMM (1): T × U -> T (2:4 sparse × dense)" << std::endl; + std::cout << " VGEMM (2): T × V -> T (1:4 sparse × dense)" << std::endl; } static void parse_args(int argc, char **argv) { int c; - while ((c = getopt(argc, argv, "h")) != -1) { + while ((c = getopt(argc, argv, "m:h")) != -1) { switch (c) { + case 'm': + gemm_mode = static_cast(atoi(optarg)); + if (gemm_mode < GEMM_MODE_TGEMM || gemm_mode > GEMM_MODE_VGEMM) { + std::cerr << "Error: Invalid mode " << gemm_mode << std::endl; + show_usage(); + exit(-1); + } + break; case 'h': show_usage(); exit(0); @@ -54,6 +70,7 @@ void cleanup() { if (device) { vx_mem_free(A_buffer); vx_mem_free(B_buffer); + vx_mem_free(M_buffer); vx_mem_free(C_buffer); vx_mem_free(krnl_buffer); vx_mem_free(args_buffer); @@ -61,15 +78,118 @@ } } -// CPU reference: C = A × B for 16x16 matrices -static void matmul_cpu(float* C, const float* A, const float* B) { - for (int m = 0; m < TILE_SIZE; ++m) { - for (int n = 0; n < TILE_SIZE; ++n) { +// Generate compressed 2:4 sparse tile and metadata from full logical matrix +// Input: logical_tile is M×K (e.g., 16×32), output: compressed_tile is M×(K/2) (e.g., 16×16) +// Metadata format: 16×16 nibbles stored as 128 bytes (8 bytes per row, 2 nibbles per byte) +static void compress_2_4_sparse(const std::vector& logical_tile, int M, int K, + std::vector& compressed_tile, std::vector& metadata) { + compressed_tile.resize(M * (K / 2)); + metadata.resize(128); // Fixed size: 16 rows × 8 bytes per row + std::fill(metadata.begin(), metadata.end(), 0); + + for (int row = 0; row < M; ++row) { + int compressed_col = 0; + + // Process K/4 groups of 4 elements + for (int k_grp = 0; k_grp < K / 4; ++k_grp) { + int k_base = k_grp * 4; + + // Find the 2 largest magnitude values in this group of 4 + std::pair vals[4]; + for (int offset = 0; offset < 4; ++offset) { + vals[offset] = {offset, logical_tile[row * K + k_base + offset]}; + } + + // Sort by magnitude to find top 2 + std::sort(vals, vals + 4, [](const auto& a, const auto& b) { + return std::abs(a.second) > std::abs(b.second); + }); + + // Create bitmask for top 2 values + uint8_t mask = 0; + for (int i = 0; i < 2; ++i) { + int offset = vals[i].first; + mask |= (1u << offset); + } + + // Store compressed values in POSITION ORDER (not magnitude order) + // Hardware iterates through bit positions 0-3 and expects values in that order + for (int offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + compressed_tile[row * (K / 2) + compressed_col++] = logical_tile[row * K + k_base + offset]; + } + } + + // Store metadata in 16×16 nibble format (128 bytes) + // Each row has 8 bytes, each byte has 2 nibbles + // Byte layout per row: byte 0 = cols 0,1; byte 1 = cols 2,3; ...; byte 7 = cols 14,15 + int byte_idx = row * 8 + k_grp / 2; + if (k_grp % 2 == 0) { + metadata[byte_idx] = (mask << 4); // Upper nibble + } else { + metadata[byte_idx] |= mask; // Lower nibble + } + + } + } +} + +// Generate compressed 1:4 sparse tile and metadata from full logical matrix +// Input: logical_tile is M×K (e.g., 16×64), output: compressed_tile is M×(K/4) (e.g., 16×16) +// Metadata format: 16×16 nibbles stored as 128 bytes (8 bytes per row, 2 nibbles per byte) +static void compress_1_4_sparse(const std::vector& logical_tile, int M,
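A worked example of the nibble packing above, with hypothetical masks: if group 0 of a row keeps positions {0, 2} (mask 0b0101) and group 1 keeps positions {2, 3} (mask 0b1100), the row's first metadata byte combines them with the even group in the upper nibble:

// byte 0 = (mask of group 0 << 4) | mask of group 1
static_assert(((0x5u << 4) | 0xCu) == 0x5C, "nibble packing example");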
int K, + std::vector& compressed_tile, std::vector& metadata) { + compressed_tile.resize(M * (K / 4)); + metadata.resize(128); // Fixed size: 16 rows × 8 bytes per row + std::fill(metadata.begin(), metadata.end(), 0); + + for (int row = 0; row < M; ++row) { + int compressed_col = 0; + + // Process K/4 groups of 4 elements + for (int k_grp = 0; k_grp < K / 4; ++k_grp) { + int k_base = k_grp * 4; + + // Find the largest magnitude value in this group of 4 + int max_offset = 0; + float max_val = std::abs(logical_tile[row * K + k_base]); + for (int offset = 1; offset < 4; ++offset) { + float val = std::abs(logical_tile[row * K + k_base + offset]); + if (val > max_val) { + max_val = val; + max_offset = offset; + } + } + + // Create bitmask and store compressed value + uint8_t mask = (1u << max_offset); + compressed_tile[row * (K / 4) + compressed_col++] = logical_tile[row * K + k_base + max_offset]; + + // Store metadata in 16×16 nibble format (128 bytes) + // Each row has 8 bytes, each byte has 2 nibbles + int byte_idx = row * 8 + k_grp / 2; + if (k_grp % 2 == 0) { + metadata[byte_idx] = (mask << 4); // Upper nibble + } else { + metadata[byte_idx] |= mask; // Lower nibble + } + } + } +} + +// CPU reference: C = A × B +// A is MxK, B is KxN, C is MxN +// For TGEMM: A is 16x16, B is 16x16 +// For UGEMM: A is 16x16 (but with 2:4 sparsity, effectively 16x32 positions), B is 32x16 +// For VGEMM: A is 16x16 (but with 1:4 sparsity, effectively 16x64 positions), B is 64x16 +static void matmul_cpu(float* C, const float* A, const float* B, int M, int K, int N) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { float sum = 0.0f; - for (int k = 0; k < TILE_SIZE; ++k) { - sum += A[m * TILE_SIZE + k] * B[k * TILE_SIZE + n]; + for (int k = 0; k < K; ++k) { + sum += A[m * K + k] * B[k * N + n]; } - C[m * TILE_SIZE + n] = sum; + C[m * N + n] = sum; } } } @@ -100,41 +220,134 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); uint32_t num_elements = TILE_SIZE * TILE_SIZE; // 256 elements - uint32_t buf_size = T_TILE_BYTES; // 1KB + uint32_t A_buf_size = T_TILE_BYTES; // Always 1KB for A + uint32_t C_buf_size = T_TILE_BYTES; // Always 1KB for C + uint32_t B_buf_size, M_buf_size = 0; + + const char* mode_name; + switch (gemm_mode) { + case GEMM_MODE_TGEMM: + mode_name = "TGEMM (T × T)"; + B_buf_size = T_TILE_BYTES; // 1KB + break; + case GEMM_MODE_UGEMM: + mode_name = "UGEMM (T × U, 2:4 sparse)"; + B_buf_size = U_TILE_BYTES; // 2KB + M_buf_size = M_TILE_BYTES; // 128B metadata + break; + case GEMM_MODE_VGEMM: + mode_name = "VGEMM (T × V, 1:4 sparse)"; + B_buf_size = V_TILE_BYTES; // 4KB + M_buf_size = M_TILE_BYTES; // 128B metadata + break; + default: + std::cerr << "Invalid GEMM mode!"
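Because both the simulator and the CPU reference walk mask bits low-to-high and consume compressed values sequentially, a malformed mask skews every later value in its row. A cheap host-side sanity check, as a sketch — check_nibble is a hypothetical helper, not part of the patch:

#include <cassert>
#include <cstdint>

// Each 2:4 nibble must keep exactly 2 positions, each 1:4 nibble exactly 1.
static void check_nibble(uint8_t nibble, int expected_nonzeros) {
  int bits = 0;
  for (int b = 0; b < 4; ++b) bits += (nibble >> b) & 1;
  assert(bits == expected_nonzeros);
}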
<< std::endl; + return -1; + } - std::cout << "SGEMM TILE Test: " << TILE_SIZE << "x" << TILE_SIZE << " matrices" << std::endl; - std::cout << "Buffer size: " << buf_size << " bytes (" << num_elements << " fp32 elements)" << std::endl; + std::cout << "SGEMM TILE Test: " << mode_name << std::endl; + std::cout << "Matrix size: " << TILE_SIZE << "x" << TILE_SIZE << std::endl; + std::cout << "A buffer: " << A_buf_size << " bytes" << std::endl; + std::cout << "B buffer: " << B_buf_size << " bytes" << std::endl; + if (M_buf_size > 0) { + std::cout << "M buffer: " << M_buf_size << " bytes (metadata)" << std::endl; + } + std::cout << "C buffer: " << C_buf_size << " bytes" << std::endl; // allocate device memory std::cout << "allocate device memory" << std::endl; - RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &A_buffer)); + RT_CHECK(vx_mem_alloc(device, A_buf_size, VX_MEM_READ, &A_buffer)); RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); - RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &B_buffer)); + RT_CHECK(vx_mem_alloc(device, B_buf_size, VX_MEM_READ, &B_buffer)); RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); - RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_WRITE, &C_buffer)); + + kernel_arg.M_addr = 0; + if (M_buf_size > 0) { + RT_CHECK(vx_mem_alloc(device, M_buf_size, VX_MEM_READ, &M_buffer)); + RT_CHECK(vx_mem_address(M_buffer, &kernel_arg.M_addr)); + } + + RT_CHECK(vx_mem_alloc(device, C_buf_size, VX_MEM_WRITE, &C_buffer)); RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + + kernel_arg.mode = gemm_mode; std::cout << "dev_A=0x" << std::hex << kernel_arg.A_addr << std::endl; std::cout << "dev_B=0x" << std::hex << kernel_arg.B_addr << std::endl; + if (kernel_arg.M_addr) { + std::cout << "dev_M=0x" << std::hex << kernel_arg.M_addr << std::endl; + } std::cout << "dev_C=0x" << std::hex << kernel_arg.C_addr << std::dec << std::endl; // allocate host buffers std::cout << "allocate host buffers" << std::endl; - std::vector h_A(num_elements); - std::vector h_B(num_elements); + + // A's logical size depends on mode: 16x16 for TGEMM, 16x32 for UGEMM, 16x64 for VGEMM + uint32_t A_cols_logical = TILE_SIZE; + if (gemm_mode == GEMM_MODE_UGEMM) A_cols_logical = 2 * TILE_SIZE; // 32 logical cols + else if (gemm_mode == GEMM_MODE_VGEMM) A_cols_logical = 4 * TILE_SIZE; // 64 logical cols + + // B size matches A's logical K dimension + uint32_t B_cols = TILE_SIZE; // B is always 16 cols wide (output is 16x16) + + std::vector h_A_logical(TILE_SIZE * A_cols_logical); // Logical A before compression + std::vector h_A(num_elements); // Compressed A (always 16x16 = 1KB for storage) + std::vector h_B(A_cols_logical * B_cols); // B is K×N where K matches A's logical K std::vector h_C(num_elements); std::vector h_ref(num_elements); - // Initialize with random values - for (uint32_t i = 0; i < num_elements; ++i) { - h_A[i] = static_cast(rand()) / RAND_MAX; + // Initialize logical matrix A + for (uint32_t i = 0; i < TILE_SIZE * A_cols_logical; ++i) { + h_A_logical[i] = static_cast(rand()) / RAND_MAX; + } + + // Initialize matrix B (K×N where K = A's logical cols) + for (uint32_t i = 0; i < A_cols_logical * B_cols; ++i) { h_B[i] = static_cast(rand()) / RAND_MAX; } // upload source buffers std::cout << "upload source buffers" << std::endl; - RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, buf_size)); - RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, buf_size)); + + std::vector h_M; // Metadata + + if (gemm_mode == GEMM_MODE_TGEMM) { + // TGEMM: A (16x16) × B (16x16) = C (16x16) + // 
Both dense in T-registers, no metadata + h_A = h_A_logical; // No compression needed + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size)); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size)); + } + else if (gemm_mode == GEMM_MODE_UGEMM) { + // UGEMM: A (16x32 logical, compressed to 16x16 with 2:4 sparsity) × B (32x16) = C (16x16) + // A: logical 16x32 -> compressed 16x16 (1KB T-tile) with metadata + // B: full 32x16 stored in U-register (2KB = 2 T-regs) + + // Compress A from 16x32 logical to 16x16 compressed + compress_2_4_sparse(h_A_logical, TILE_SIZE, 2 * TILE_SIZE, h_A, h_M); + + std::cout << "2:4 sparse A: logical 16x32 -> compressed 16x16, metadata " << h_M.size() << " bytes" << std::endl; + + // Upload compressed A (1KB), metadata, and full B (2KB for U-reg) + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size)); + RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size)); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size)); + } + else if (gemm_mode == GEMM_MODE_VGEMM) { + // VGEMM: A (16x64 logical, compressed to 16x16 with 1:4 sparsity) × B (64x16) = C (16x16) + // A: logical 16x64 -> compressed 16x16 (1KB T-tile) with metadata + // B: full 64x16 stored in V-register (4KB = 4 T-regs) + + // Compress A from 16x64 logical to 16x16 compressed + compress_1_4_sparse(h_A_logical, TILE_SIZE, 4 * TILE_SIZE, h_A, h_M); + + std::cout << "1:4 sparse A: logical 16x64 -> compressed 16x16, metadata " << h_M.size() << " bytes" << std::endl; + + // Upload compressed A (1KB), metadata, and full B (4KB for V-reg) + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size)); + RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size)); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size)); + } // upload kernel binary std::cout << "upload kernel binary" << std::endl; @@ -154,11 +367,34 @@ int main(int argc, char *argv[]) { // download result std::cout << "download result" << std::endl; - RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, buf_size)); - - // compute CPU reference + RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, C_buf_size)); + + // Zero out pruned values in h_A_logical based on metadata for CPU reference + // This ensures CPU computes the same result as GPU (which only uses non-zero values) + if (gemm_mode == GEMM_MODE_UGEMM || gemm_mode == GEMM_MODE_VGEMM) { + for (uint32_t row = 0; row < TILE_SIZE; ++row) { + uint32_t k_groups = A_cols_logical / 4; + for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) { + int k_base = k_grp * 4; + // Get metadata nibble for this group + int byte_idx = row * 8 + k_grp / 2; + uint8_t nibble = (k_grp % 2 == 0) ? 
(h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F); + + // Zero out positions not in metadata mask + for (int offset = 0; offset < 4; ++offset) { + if (!(nibble & (1u << offset))) { + h_A_logical[row * A_cols_logical + k_base + offset] = 0.0f; + } + } + } + } + } + + // compute CPU reference using logical A matrix (now with zeros in pruned positions) std::cout << "verify result" << std::endl; - matmul_cpu(h_ref.data(), h_A.data(), h_B.data()); + // C = A × B where A is MxK, B is KxN, C is MxN + // M = TILE_SIZE (16), K = A_cols_logical, N = B_cols (16) + matmul_cpu(h_ref.data(), h_A_logical.data(), h_B.data(), TILE_SIZE, A_cols_logical, B_cols); // verify result int errors = 0; @@ -170,30 +406,50 @@ int main(int argc, char *argv[]) { std::cout << "writing matrices to output file" << std::endl; std::ofstream output_file("matrices_output.txt"); if (output_file.is_open()) { - output_file << "Matrix A (" << TILE_SIZE << "x" << TILE_SIZE << "):\n"; - for (int i = 0; i < TILE_SIZE; ++i) { - for (int j = 0; j < TILE_SIZE; ++j) { - output_file << h_A[i * TILE_SIZE + j]; - if (j < TILE_SIZE - 1) output_file << " "; + output_file << "GEMM Mode: " << mode_name << "\n\n"; + + output_file << "Matrix A Logical ("; + if (gemm_mode == GEMM_MODE_TGEMM) { + output_file << "Dense"; + } else if (gemm_mode == GEMM_MODE_UGEMM) { + output_file << "2:4 Sparse"; + } else { + output_file << "1:4 Sparse"; + } + output_file << ", " << TILE_SIZE << "x" << A_cols_logical << "):\n"; + for (uint32_t i = 0; i < TILE_SIZE; ++i) { + for (uint32_t j = 0; j < A_cols_logical; ++j) { + output_file << h_A_logical[i * A_cols_logical + j]; + if (j < A_cols_logical - 1) output_file << " "; } output_file << "\n"; } output_file << "\n"; - output_file << "Matrix B (" << TILE_SIZE << "x" << TILE_SIZE << "):\n"; + output_file << "Matrix B (Dense, " << A_cols_logical << "x" << B_cols << "):\n"; + for (uint32_t i = 0; i < A_cols_logical; ++i) { + for (uint32_t j = 0; j < B_cols; ++j) { + output_file << h_B[i * B_cols + j]; + if (j < B_cols - 1) output_file << " "; + } + output_file << "\n"; + } + output_file << "\n"; + + output_file << "Matrix C (GPU Result, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; for (int i = 0; i < TILE_SIZE; ++i) { for (int j = 0; j < TILE_SIZE; ++j) { - output_file << h_B[i * TILE_SIZE + j]; + output_file << h_C[i * TILE_SIZE + j]; if (j < TILE_SIZE - 1) output_file << " "; } output_file << "\n"; } output_file << "\n"; - output_file << "Matrix C (Result, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; + output_file << "Matrix C (CPU Reference, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; for (int i = 0; i < TILE_SIZE; ++i) { for (int j = 0; j < TILE_SIZE; ++j) { - output_file << h_C[i * TILE_SIZE + j]; + output_file << h_ref[i * TILE_SIZE + j]; if (j < TILE_SIZE - 1) output_file << " "; } output_file << "\n"; From 06eb8586d0661ca51cbcf9bbed5fd9455573971c Mon Sep 17 00:00:00 2001 From: tandonmitul27 Date: Mon, 8 Dec 2025 09:57:36 +0530 Subject: [PATCH 5/6] sgemm_sparse kernel now working correctly --- kernel/include/vx_sparse.h | 182 +++++++++++++------- sim/simx/decode.cpp | 53 ++++-- sim/simx/execute.cpp | 108 ++++++++++-- sim/simx/sparse_unit.cpp | 174 +++++++++++++++++-- sim/simx/sparse_unit.h | 3 +- sim/simx/types.h | 12 +- tests/regression/sgemm_sparse/common.h | 2 +- tests/regression/sgemm_sparse/kernel.cpp | 40 +++-- tests/regression/sgemm_sparse/main.cpp | 203 ++++++++++++++++++++--- 9 files changed, 630 insertions(+), 147 deletions(-) diff --git a/kernel/include/vx_sparse.h b/kernel/include/vx_sparse.h index 
36e11d1b6a..da6300d5da 100644 --- a/kernel/include/vx_sparse.h +++ b/kernel/include/vx_sparse.h @@ -14,7 +14,11 @@ #pragma once #include +#include #include +#ifdef VX_SPARSE_DEBUG +#include +#endif namespace vortex { namespace sparse { @@ -177,7 +181,12 @@ struct wmma_context { } template - static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm, const void *meta_src = nullptr) { + static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, + const void *src, + size_t ldm, + const void *meta_src = nullptr, + uint32_t meta_row_base = 0, + uint32_t meta_col_base = 0) { uint32_t lane = vx_thread_id(); if constexpr (Frag::Use == matrix_a) { // Load row-major matrix A @@ -185,55 +194,108 @@ struct wmma_context { uint32_t lane_in_blk = (cfg::a_block_size == NT) ? lane : (lane % cfg::a_block_size); uint32_t block_row = (lane_in_blk / cfg::tcK) + (block_idx * cfg::tcM); uint32_t block_col = (lane_in_blk % cfg::tcK) * i_ratio; + uint32_t block_col_offset = block_col; // preserve original column stride for metadata lookup uint32_t m_stride = cfg::a_sub_blocks * cfg::tcM; uint32_t k_stride = cfg::tcK * i_ratio; if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } - // For sparse format: when meta_src is provided, data stride is K/2 (not K) - // because each row has K/2 values (2 per block of 4) - size_t data_ldm = (meta_src != nullptr) ? (ldm / 2) : ldm; - auto base = reinterpret_cast(src) + block_row * data_ldm + block_col; - const uint8_t* meta_base = meta_src ? reinterpret_cast(meta_src) : nullptr; - uint32_t meta_ldm = meta_src ? (ldm / 4) : 0; // Number of metadata bytes per row (K/4 blocks) - detail::unroll_for([&](auto r) { - uint32_t block_m = r / cfg::k_steps; - uint32_t block_k = r % cfg::k_steps; - uint32_t elem_row = block_m * m_stride; - uint32_t elem_col = block_k * k_stride; - uint32_t meta_value = 0; - - if (meta_base) { - uint32_t matrix_row = block_row + elem_row; - uint32_t k_elem_idx = elem_col / i_ratio; - uint32_t meta_block_k = k_elem_idx / 4; - if (meta_block_k < meta_ldm) { - uint32_t meta_offset = matrix_row * meta_ldm + meta_block_k; - meta_value = static_cast(meta_base[meta_offset]); - } - } - - if constexpr (Frag::Use == matrix_a) { + // Metadata pointer is pre-offset to tile position (like data pointer) + const uint32_t* meta_base = meta_src ? reinterpret_cast(meta_src) : nullptr; + uint32_t meta_ldm = meta_src ? 
(ldm / 4) : 0; + + if (meta_src != nullptr) { + // SPARSE LOADING: Use metadata to place values in correct k_step registers + // data_ldm is K/2 for sparse (compressed values) + size_t data_ldm = ldm / 2; + // For sparse, don't add block_col to base - we compute sparse_idx separately + auto data_base = reinterpret_cast(src) + block_row * data_ldm; + + // First, load metadata for each M row that this thread handles + // and distribute sparse values to the correct k_step registers + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; + + // Get metadata for this row (absolute position in matrix) + uint32_t abs_row = meta_row_base + block_row + elem_row; + uint32_t abs_k_block = (meta_col_base / 4); // K-block index for this tile + const uint32_t *meta_ptr = meta_base + static_cast(abs_row) * meta_ldm + abs_k_block; + uint32_t meta_value = *meta_ptr; dst.metadata[r] = meta_value; - } - if constexpr (src_layout == col_major) { - static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); - std::swap(elem_row, elem_col); - auto ptr = base + elem_row * data_ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + + // meta_value is a bitmask: bits 0-3 indicate which of 4 K positions have values + // block_k indicates which pair of K positions this register is for: + // block_k=0 -> K positions 0,1 (bits 0,1) + // block_k=1 -> K positions 2,3 (bits 2,3) + uint8_t meta_byte = meta_value & 0xFF; + uint32_t k_start = block_k * cfg::tcK; // Start K position for this register + uint32_t k_end = k_start + cfg::tcK; // End K position + + // Count how many sparse values come BEFORE this k_step for this row + uint32_t sparse_offset = 0; + for (uint32_t pos = 0; pos < k_start; ++pos) { + if (meta_byte & (1u << pos)) { + sparse_offset++; + } + } + + // For fp32 with tcK=2, each register holds 1 fp32 value + // block_col determines which position within the tcK pair: 0 or 1 + // So the target K position is: k_start + block_col + uint32_t target_pos = k_start + block_col; + + vreg_t loaded_val = 0.0f; + + if (target_pos < 4) { + // Count sparse values before target_pos to get the sparse index + uint32_t sparse_idx = sparse_offset; + for (uint32_t pos = k_start; pos < target_pos; ++pos) { + if (meta_byte & (1u << pos)) { + sparse_idx++; + } + } + + // Check if target position has a sparse value + if (meta_byte & (1u << target_pos)) { + auto ptr = data_base + elem_row * data_ldm + sparse_idx; + loaded_val = *ptr; + } + // else: loaded_val stays 0.0f (position was pruned) + } + + dst.data[r] = loaded_val; + }); + } else { + // DENSE LOADING: Original non-sparse path + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + + dst.metadata[r] = 0; + + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } } else { - dst.data[r] = 
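The two counting loops above (sparse_offset, then the scan from k_start up to target_pos) amount to a masked popcount: the compressed index of a kept position is the number of mask bits set strictly below it. An equivalent standalone formulation:

#include <cstdint>

// sparse_offset + the per-group scan == popcount of mask bits below target_pos.
static inline uint32_t sparse_index(uint8_t meta_byte, uint32_t target_pos) {
  uint32_t below = meta_byte & ((1u << target_pos) - 1u);
  uint32_t n = 0;
  while (below) { below &= below - 1u; ++n; } // Kernighan popcount
  return n;
}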
input_acessor_t::pack_row(ptr, data_ldm); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); } - } else { - // row_major layout - // For sparse format, use data_ldm (K/2) instead of ldm (K) - auto ptr = base + elem_row * data_ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); - } - }); + }); + } } else if constexpr (Frag::Use == matrix_b) { // Load column-major matrix B uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); @@ -334,23 +396,27 @@ struct wmma_context { static_assert(FragC::Use == accumulator, "C must be accumulator"); static_assert(FragD::Use == accumulator, "D must be accumulator"); - auto meta_value = [&](uint32_t idx) -> uint32_t { - if constexpr (FragA::Use == matrix_a) { - if (idx < FragA::NR) { - return fragA.metadata[idx]; - } - } - return 0u; - }; - - register uint32_t ma0 __asm__("a0") = meta_value(0); - register uint32_t ma1 __asm__("a1") = meta_value(1); - register uint32_t ma2 __asm__("a2") = meta_value(2); - register uint32_t ma3 __asm__("a3") = meta_value(3); - register uint32_t ma4 __asm__("a4") = meta_value(4); - register uint32_t ma5 __asm__("a5") = meta_value(5); - register uint32_t ma6 __asm__("a6") = meta_value(6); - register uint32_t ma7 __asm__("a7") = meta_value(7); + // Load metadata values into local variables first to avoid stack offset issues + uint32_t m0 = 0, m1 = 0, m2 = 0, m3 = 0, m4 = 0, m5 = 0, m6 = 0, m7 = 0; + if constexpr (FragA::Use == matrix_a) { + if constexpr (FragA::NR > 0) m0 = fragA.metadata[0]; + if constexpr (FragA::NR > 1) m1 = fragA.metadata[1]; + if constexpr (FragA::NR > 2) m2 = fragA.metadata[2]; + if constexpr (FragA::NR > 3) m3 = fragA.metadata[3]; + if constexpr (FragA::NR > 4) m4 = fragA.metadata[4]; + if constexpr (FragA::NR > 5) m5 = fragA.metadata[5]; + if constexpr (FragA::NR > 6) m6 = fragA.metadata[6]; + if constexpr (FragA::NR > 7) m7 = fragA.metadata[7]; + } + + register uint32_t ma0 __asm__("a0") = m0; + register uint32_t ma1 __asm__("a1") = m1; + register uint32_t ma2 __asm__("a2") = m2; + register uint32_t ma3 __asm__("a3") = m3; + register uint32_t ma4 __asm__("a4") = m4; + register uint32_t ma5 __asm__("a5") = m5; + register uint32_t ma6 __asm__("a6") = m6; + register uint32_t ma7 __asm__("a7") = m7; // fragA: caller-saved registers (f0-f7) register float fa0 __asm__("f0") = fragA.data[0]; diff --git a/sim/simx/decode.cpp b/sim/simx/decode.cpp index f720dcd388..eb5df47089 100644 --- a/sim/simx/decode.cpp +++ b/sim/simx/decode.cpp @@ -26,7 +26,7 @@ #include "arch.h" #include "instr.h" -#ifdef EXT_TCU_ENABLE +#if defined(EXT_TCU_ENABLE) || defined(EXT_VEGETA_ENABLE) #include "tensor_cfg.h" #endif @@ -492,6 +492,12 @@ static op_string_t op_string(const Instr &instr) { case VegetaTcuType::TILE_GEMM_U: return {"TILE_GEMM_U", ""}; case VegetaTcuType::TILE_GEMM_V: return {"TILE_GEMM_V", ""}; case VegetaTcuType::TILE_GEMM_R: return {"TILE_GEMM_R", ""}; + case VegetaTcuType::WMMA: { + auto tpuArgs = std::get(instrArgs); + namespace vt = vortex::tensor; + return {"WMMA." + std::string(vt::fmt_string(tpuArgs.fmt_s)) + "." + std::string(vt::fmt_string(tpuArgs.fmt_d)) + + "." + std::to_string(tpuArgs.step_m) + "." 
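The register pinning above fixes an informal contract for the sparse WMMA sequence: fragment data rides in the floating-point banks (fragA in f0-f7) while the eight per-register metadata words ride in a0-a7 (x10-x17). The simulator's execute path reads those same integer registers back, as the WMMA case in execute.cpp below shows; a consumer-side sketch of that contract:

#include <array>
#include <cstdint>

// Metadata for fragA registers i = 0..7 is read from integer register
// a_i = x(10 + i); thread 0's lane suffices since all threads carry the
// same masks.
static void read_wmma_metadata(const std::array<uint32_t, 32>& iregs,
                               uint32_t metadata[8]) {
  for (uint32_t reg = 0; reg < 8; ++reg) {
    metadata[reg] = iregs[10 + reg]; // a0 = x10 ... a7 = x17
  }
}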
+ std::to_string(tpuArgs.step_n), ""}; + } default: std::abort(); } @@ -1143,25 +1149,46 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) { } break; #endif #ifdef EXT_VEGETA_ENABLE - /* case 3: { switch (funct3) { - case 0: { // DOT8 - auto instr = std::allocate_shared(instr_pool_, uuid, FUType::ALU); - instr->setOpType(AluType::DOT8); - instr->setArgs(IntrAluArgs{0, 0, 0}); - // TODO: set destination register - // TODO: set source registers - instr->setDestReg(rd, RegType::Integer); - instr->setSrcReg(0, rs1, RegType::Integer); - instr->setSrcReg(1, rs2, RegType::Integer); - ibuffer.push_back(instr); + case 0: { // WMMA + namespace vt = vortex::tensor; + using cfg = vt::wmma_config_t; + uint32_t ra_base = 0; + uint32_t rb_base = (cfg::NRB == 4) ? 28 : 10; + uint32_t rc_base = (cfg::NRB == 4) ? 10 : 24; + uint32_t fmt_d = rd; + uint32_t fmt_s = rs1; + uint32_t steps = 0; + uint32_t steps_count = cfg::m_steps * cfg::n_steps * cfg::k_steps; + uint32_t steps_shift = 32 - log2ceil(steps_count); + uint32_t uuid_hi = (uuid >> 32) & 0xffffffff; + uint32_t uuid_lo = uuid & 0xffffffff; + for (uint32_t k = 0; k < cfg::k_steps; ++k) { + for (uint32_t m = 0; m < cfg::m_steps; ++m) { + for (uint32_t n = 0; n < cfg::n_steps; ++n) { + uint32_t rs1 = ra_base + (m / cfg::a_sub_blocks) * cfg::k_steps + k; + uint32_t rs2 = rb_base + (k * cfg::n_steps + n) / cfg::b_sub_blocks; + uint32_t rs3 = rc_base + m * cfg::n_steps + n; + uint32_t uuid_lo_x = (steps << steps_shift) | uuid_lo; + uint64_t uuid_x = (static_cast(uuid_hi) << 32) | uuid_lo_x; + ++steps; + auto instr = std::allocate_shared(instr_pool_, uuid_x, FUType::VEGETA); + instr->setOpType(VegetaTcuType::WMMA); + instr->setArgs(IntrVegetaTcuArgs{fmt_s, fmt_d, m, n}); + instr->setDestReg(rs3, RegType::Float); + instr->setSrcReg(0, rs1, RegType::Float); + instr->setSrcReg(1, rs2, RegType::Float); + instr->setSrcReg(2, rs3, RegType::Float); + ibuffer.push_back(instr); + } + } + } } break; default: std::abort(); } } break; - */ #endif default: std::abort(); diff --git a/sim/simx/execute.cpp b/sim/simx/execute.cpp index 75af3d494a..027232dcb1 100644 --- a/sim/simx/execute.cpp +++ b/sim/simx/execute.cpp @@ -149,9 +149,15 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) { << ", PC=0x" << std::hex << warp.PC << std::dec << " (#" << instr.getUUID() << ")"); // fetch register values +#ifdef EXT_VEGETA_ENABLE if (rsrc0.type != RegType::None && rsrc0.type != RegType::Tile) fetch_registers(rs1_data, wid, 0, rsrc0); if (rsrc1.type != RegType::None && rsrc1.type != RegType::Tile) fetch_registers(rs2_data, wid, 1, rsrc1); if (rsrc2.type != RegType::None && rsrc2.type != RegType::Tile) fetch_registers(rs3_data, wid, 2, rsrc2); +#else + if (rsrc0.type != RegType::None) fetch_registers(rs1_data, wid, 0, rsrc0); + if (rsrc1.type != RegType::None) fetch_registers(rs2_data, wid, 1, rsrc1); + if (rsrc2.type != RegType::None) fetch_registers(rs3_data, wid, 2, rsrc2); +#endif uint32_t thread_start = 0; for (; thread_start < num_threads; ++thread_start) { @@ -1546,37 +1552,105 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) { } }, [&](VegetaTcuType tcu_type) { - auto trace_data = std::make_shared(); - trace->data = trace_data; - assert(warp.tmask.count() == num_threads); - - // Extract tile register indices from instruction - uint32_t dst_reg = rdest.idx; - uint32_t src1_reg = instr.getSrcReg(0).idx; - uint32_t src2_reg = instr.getSrcReg(1).idx; - switch (tcu_type) { - case VegetaTcuType::TILE_GEMM_T: + case 
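Each WMMA macro-instruction thus decodes into m_steps × n_steps × k_steps micro-ops, each tagged by packing its step index into the top bits of the UUID's low word (steps_shift = 32 - log2ceil(steps_count)), so per-step traces stay distinguishable. The packing in isolation, as a sketch:

#include <cstdint>

static inline uint64_t pack_step_uuid(uint64_t uuid, uint32_t step,
                                      uint32_t steps_shift) {
  uint32_t uuid_hi = (uuid >> 32) & 0xffffffff;
  uint32_t uuid_lo = uuid & 0xffffffff;
  uint32_t uuid_lo_x = (step << steps_shift) | uuid_lo; // step in the top bits
  return (static_cast<uint64_t>(uuid_hi) << 32) | uuid_lo_x;
}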
VegetaTcuType::TILE_GEMM_T: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + // Dense tile × Dense tile → Tile (T × T → T) sparse_unit_->tile_gemm_t(dst_reg, src1_reg, src2_reg); rd_write = false; // Writes to tile registers, not scalar registers - break; - case VegetaTcuType::TILE_GEMM_U: + } break; + case VegetaTcuType::TILE_GEMM_U: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + // Sparse tile (2:4) × Dense tile → Tile (T × U → T) // Metadata assumed to be in corresponding m-register (same index as src1) sparse_unit_->tile_gemm_u(dst_reg, src1_reg, src2_reg, src1_reg); rd_write = false; - break; - case VegetaTcuType::TILE_GEMM_V: + } break; + case VegetaTcuType::TILE_GEMM_V: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + // Sparse tile (1:4) × Dense tile → Tile (T × V → T) sparse_unit_->tile_gemm_v(dst_reg, src1_reg, src2_reg, src1_reg); rd_write = false; - break; - case VegetaTcuType::TILE_GEMM_R: + } break; + case VegetaTcuType::TILE_GEMM_R: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + // Row-wise sparse tile × Dense tile → Tile (T × U → U) sparse_unit_->tile_gemm_r(dst_reg, src1_reg, src2_reg, src1_reg); rd_write = false; - break; + } break; + case VegetaTcuType::WMMA: { + auto tpuArgs = std::get(instrArgs); + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Get metadata from integer registers a0-a7 (x10-x17) for sparse fragA + // These contain metadata values loaded by mma_sync into a0-a7 + DTH(3, "WMMA: current regfile values:" << std::hex << std::endl); + for (uint32_t i = 0; i < 32; ++i) { + DTN(3, " x" << std::setfill('0') << std::setw(2) << i << ": 0x" << warp.ireg_file.at(i).at(0) << std::dec << std::endl); + } + uint32_t metadata[8] = {0}; + for (uint32_t reg = 0; reg < 8; ++reg) { + // a0-a7 correspond to x10-x17 in RISC-V + uint32_t a_reg = 10 + reg; // a0=10, a1=11, ..., a7=17 + + // Get value from integer register a_reg for thread 0 (all threads should have same metadata) + if (warp.tmask.test(0) && a_reg < warp.ireg_file.size()) { + metadata[reg] = warp.ireg_file.at(a_reg).at(0); + } + } + DTH(3, "WMMA: metadata values:" << std::hex << std::endl); + for (uint32_t i = 0; i < 8; ++i) { + DTN(3, " a" << std::setfill('0') << std::setw(1) << i << ": 0x" << metadata[i] << std::dec << std::endl); + } + // Resize rd_data and rs3_data to accommodate WMMA output (tcM * tcN elements) + // For sparse WMMA, we need at least tcM * tcN elements + namespace vt = vortex::sparse; + using cfg = vt::wmma_config_t; + uint32_t wmma_size = cfg::tcM * 
cfg::tcN; + if (rd_data.size() < wmma_size) { + rd_data.resize(wmma_size); + } + if (rs3_data.size() < wmma_size) { + rs3_data.resize(wmma_size); + } + + sparse_unit_->wmma(wid, tpuArgs.fmt_s, tpuArgs.fmt_d, tpuArgs.step_m, tpuArgs.step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get(), metadata); + rd_write = true; + } break; default: std::abort(); } diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp index 921e0e427a..24608e6154 100644 --- a/sim/simx/sparse_unit.cpp +++ b/sim/simx/sparse_unit.cpp @@ -143,7 +143,99 @@ struct FEDP{ } }; +// Sparse FEDP: uses metadata to select which values from fragB to use +// fragA is sparse (2:4), fragB is dense +// metadata contains bitmasks indicating which 2 of 4 positions are non-zero +template +struct SparseFEDP { + using itype = typename It::dtype; + using otype = typename Ot::dtype; + static uint32_t eval(const reg_data_t *a_row, const reg_data_t *b_col, uint32_t c_val, const uint32_t* metadata) { + constexpr uint32_t i_ratio = sizeof(uint32_t) / sizeof(itype); + static_assert(i_ratio * sizeof(itype) == sizeof(uint32_t), "SparseFEDP: tcK * i_ratio must be <= 32"); + auto acc = bit_cast(c_val); + + constexpr uint32_t regs_per_block = (i_ratio == 2) ? 2 : 4; + + for (uint32_t z = 0; z < cfg::tcK; z += regs_per_block) { + uint32_t block_idx = z / regs_per_block; + uint32_t meta = (block_idx < 8) ? metadata[block_idx] : 0; + uint8_t meta_byte = meta & 0xFF; + + for (uint32_t pos = 0; pos < 4; ++pos) { + if (meta_byte & (1u << pos)) { + uint32_t reg_idx = z + (pos / i_ratio); + uint32_t elem_idx = pos % i_ratio; + + if (reg_idx < cfg::tcK) { + auto a = reinterpret_cast(&a_row[reg_idx].u32); + auto b = reinterpret_cast(&b_col[reg_idx].u32); + acc = FMA::eval(a[elem_idx], b[elem_idx], acc); + } + } + } + } + return bit_cast(acc); + } +}; + +template <> +struct SparseFEDP { + static uint32_t eval(const reg_data_t *a_row, const reg_data_t *b_col, uint32_t c_val, const uint32_t* metadata) { + __unused(metadata); + auto acc = bit_cast(c_val); + + for (uint32_t z = 0; z < cfg::tcK; ++z) { + auto a_val = bit_cast(a_row[z].u32); + auto b_val = bit_cast(b_col[z].u32); + acc = FMA::eval(a_val, b_val, acc); + } + + return bit_cast(acc); + } +}; + using PFN_FEDP = uint32_t (*)(const reg_data_t*, const reg_data_t*, uint32_t); +using PFN_SparseFEDP = uint32_t (*)(const reg_data_t*, const reg_data_t*, uint32_t, const uint32_t*); + +static PFN_SparseFEDP select_SparseFEDP(uint32_t IT, uint32_t OT) { + switch (OT) { + case vt::fp32::id: + switch (IT) { + case vt::fp32::id: + return SparseFEDP::eval; + case vt::fp16::id: + return SparseFEDP::eval; + case vt::bf16::id: + return SparseFEDP::eval; + default: + std::cout << "Error: unsupported sparse mma format: " << IT << " -> " << OT << "!" << std::endl; + std::abort(); + } + break; + case vt::fp16::id: + switch (IT) { + case vt::fp16::id: + return SparseFEDP::eval; + default: + std::cout << "Error: unsupported sparse mma format: " << IT << " -> " << OT << "!" << std::endl; + std::abort(); + } + break; + case vt::bf16::id: + switch (IT) { + case vt::bf16::id: + return SparseFEDP::eval; + default: + std::cout << "Error: unsupported sparse mma format: " << IT << " -> " << OT << "!" << std::endl; + std::abort(); + } + break; + default: + std::cout << "Error: unsupported sparse output type: " << OT << "!" 
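The fp32 specialization above can ignore its metadata argument because, for full-width types, load_matrix_sync has already scattered the compressed values into their logical K slots and zero-filled the pruned ones; a plain FMA reduction then matches the metadata-guided sum. A tiny numeric illustration with hypothetical values:

#include <cassert>

int main() {
  float a[4] = {2.0f, 0.0f, 3.0f, 0.0f}; // expanded 2:4 row, mask 0b0101
  float b[4] = {1.0f, 5.0f, 7.0f, 9.0f}; // dense column
  float dense = 0.0f, masked = 0.0f;
  for (int k = 0; k < 4; ++k) dense += a[k] * b[k];
  for (int k = 0; k < 4; ++k)
    if (k == 0 || k == 2) masked += a[k] * b[k]; // metadata-guided sum
  assert(dense == masked); // both 2*1 + 3*7 = 23; the zero lanes add nothing
  return 0;
}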
<< std::endl; + std::abort(); + } +} static PFN_FEDP select_FEDP(uint32_t IT, uint32_t OT) { switch (OT) { @@ -245,6 +337,7 @@ class SparseUnit::Impl { case VegetaTcuType::TILE_GEMM_U: case VegetaTcuType::TILE_GEMM_V: case VegetaTcuType::TILE_GEMM_R: + case VegetaTcuType::WMMA: delay = 4; break; default: @@ -293,21 +386,71 @@ class SparseUnit::Impl { const std::vector& rs2_data, const std::vector& rs3_data, std::vector& rd_data, - ExeTraceData* trace_data) { - __unused(wid); + ExeTraceData* trace_data, + const uint32_t* metadata) { __unused(trace_data); - __unused(fmt_s); - __unused(fmt_d); - __unused(step_m); - __unused(step_n); - __unused(rs1_data); - __unused(rs2_data); - __unused(rs3_data); - __unused(rd_data); - // This function is now a placeholder for TILE_GEMM operations - // The actual implementation is handled via tile registers directly - // See tile_gemm_t, tile_gemm_u, tile_gemm_v, tile_gemm_r functions + // Use provided metadata from integer registers 0-7 for sparse fragA + // If metadata is null, use zeros (dense mode fallback) + uint32_t meta[8] = {0}; + if (metadata != nullptr) { + for (uint32_t i = 0; i < 8; ++i) { + meta[i] = metadata[i]; + } + } + + // Use sparse FEDP for sparse-dense GEMM + auto sparse_fedp = select_SparseFEDP(fmt_s, fmt_d); + + uint32_t a_off = (step_m % cfg::a_sub_blocks) * cfg::a_block_size; + uint32_t b_off = (step_n % cfg::b_sub_blocks) * cfg::b_block_size; + + for (uint32_t i = 0; i < cfg::tcM; ++i) { + for (uint32_t j = 0; j < cfg::tcN; ++j) { + auto a_row = rs1_data.data() + a_off + i * cfg::tcK; + auto b_col = rs2_data.data() + b_off + j * cfg::tcK; + + uint32_t idx = i * cfg::tcN + j; + if (idx >= rs3_data.size() || idx >= rd_data.size()) { + std::cout << "Error: index out of bounds in sparse_unit wmma: idx=" << idx + << ", rs3_data.size()=" << rs3_data.size() + << ", rd_data.size()=" << rd_data.size() << std::endl; + std::abort(); + } + + auto c_val = rs3_data.at(idx).u32; + + // Map metadata from fragment registers to K dimension registers + uint32_t meta_for_k[8] = {0}; + for (uint32_t z = 0; z < cfg::tcK && z < 8; ++z) { + uint32_t frag_reg_idx = a_off + i * cfg::tcK + z; + if (frag_reg_idx < 8) { + meta_for_k[z] = meta[frag_reg_idx]; + } + } + + // Perform sparse-dense FEDP: fragA is sparse, fragB is dense + auto d_val = sparse_fedp(a_row, b_col, c_val, meta_for_k); + rd_data.at(idx).u64 = nan_box(d_val); + + DTH(3, "SparseFEDP: wid=" << wid << ", i=" << i << ", j=" << j << ", m=" << step_m << ", n=" << step_n << ", a_row={" << std::hex); + for (uint32_t q = 0; q < cfg::tcK; ++q) { + if (q) DTN(3, ", "); + DTN(3, "0x" << a_row[q].u32); + } + DTN(3, "}, b_col={"); + for (uint32_t q = 0; q < cfg::tcK; ++q) { + if (q) DTN(3, ", "); + DTN(3, "0x" << b_col[q].u32); + } + DTN(3, "}, metadata={"); + for (uint32_t q = 0; q < 8 && q < cfg::tcK; ++q) { + if (q) DTN(3, ", "); + DTN(3, "0x" << meta_for_k[q]); + } + DTN(3, "}, c_val=0x" << c_val << ", d_val=0x" << d_val << std::dec << std::endl); + } + } } // TILE_GEMM_T: Dense tile × Dense tile = Tile (T × T → T) @@ -838,8 +981,9 @@ void SparseUnit::wmma(uint32_t wid, const std::vector& rs2_data, const std::vector& rs3_data, std::vector& rd_data, - ExeTraceData* trace_data) { - impl_->wmma(wid, fmt_s, fmt_d, step_m, step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data); + ExeTraceData* trace_data, + const uint32_t* metadata) { + impl_->wmma(wid, fmt_s, fmt_d, step_m, step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data, metadata); } void SparseUnit::tile_gemm_t(uint32_t dst_treg, uint32_t 
src1_treg, uint32_t src2_treg) { diff --git a/sim/simx/sparse_unit.h b/sim/simx/sparse_unit.h index 125a749659..7d98340fd1 100644 --- a/sim/simx/sparse_unit.h +++ b/sim/simx/sparse_unit.h @@ -77,7 +77,8 @@ class SparseUnit : public SimObject { const std::vector& rs2_data, const std::vector& rs3_data, std::vector& rd_data, - ExeTraceData* trace_data); + ExeTraceData* trace_data, + const uint32_t* metadata = nullptr); void tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg); void tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg); diff --git a/sim/simx/types.h b/sim/simx/types.h index 40d527ebdf..3be1808950 100644 --- a/sim/simx/types.h +++ b/sim/simx/types.h @@ -699,11 +699,19 @@ struct IntrVegetaLsuArgs { /////////////////////////////////////////////////////////////////////////////// +struct IntrVegetaTcuArgs { + uint32_t fmt_s : 4; + uint32_t fmt_d : 4; + uint32_t step_m : 4; + uint32_t step_n : 4; +}; + enum class VegetaTcuType { TILE_GEMM_T, TILE_GEMM_U, TILE_GEMM_V, - TILE_GEMM_R + TILE_GEMM_R, + WMMA }; inline std::ostream &operator<<(std::ostream &os, const VegetaTcuType& type) { @@ -712,6 +720,7 @@ inline std::ostream &operator<<(std::ostream &os, const VegetaTcuType& type) { case VegetaTcuType::TILE_GEMM_U: os << "TILE_GEMM_U"; break; case VegetaTcuType::TILE_GEMM_V: os << "TILE_GEMM_V"; break; case VegetaTcuType::TILE_GEMM_R: os << "TILE_GEMM_R"; break; + case VegetaTcuType::WMMA: os << "WMMA"; break; default: assert(false); } return os; @@ -762,6 +771,7 @@ using IntrArgs = std::variant< #endif #ifdef EXT_VEGETA_ENABLE , IntrVegetaLsuArgs +, IntrVegetaTcuArgs #endif >; diff --git a/tests/regression/sgemm_sparse/common.h b/tests/regression/sgemm_sparse/common.h index a762a4fb2e..b1d29235c3 100644 --- a/tests/regression/sgemm_sparse/common.h +++ b/tests/regression/sgemm_sparse/common.h @@ -8,7 +8,7 @@ #endif #ifndef ITYPE -#define ITYPE fp16 +#define ITYPE fp32 #endif #ifndef OTYPE diff --git a/tests/regression/sgemm_sparse/kernel.cpp b/tests/regression/sgemm_sparse/kernel.cpp index b10072b286..727daf24a6 100644 --- a/tests/regression/sgemm_sparse/kernel.cpp +++ b/tests/regression/sgemm_sparse/kernel.cpp @@ -5,6 +5,12 @@ namespace vt = vortex::sparse; using ctx = vt::wmma_context; +static inline size_t align_up_size(size_t value, size_t alignment) { + if (!alignment) + return value; + return (value + alignment - 1) & ~(alignment - 1); +} + void kernel_body(kernel_arg_t *__UNIFORM__ arg) { auto pA_values = reinterpret_cast(arg->A_addr); auto pB = reinterpret_cast(arg->B_addr); @@ -25,42 +31,46 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { if (tile_row >= M || tile_col >= N) return; - // Sparse A layout: values first (M * K / 2 entries), then metadata (M * K / 4 bytes) + // Sparse A layout: data (values) first, then metadata (padded to 4-byte alignment) + // Values size: (M * K / 2) * sizeof(input_t) bytes + // Metadata size: (M * K / 4) * sizeof(uint32_t) bytes + constexpr size_t meta_entry_bytes = sizeof(uint32_t); size_t values_per_row = K / 2; - size_t meta_per_row = K / 4; - size_t total_values = static_cast(M) * values_per_row; - const uint8_t *meta_base = reinterpret_cast(pA_values + total_values); + size_t values_size = static_cast(M) * values_per_row * sizeof(ctx::input_t); + size_t meta_offset = align_up_size(values_size, meta_entry_bytes); + const uint8_t *base_ptr = reinterpret_cast(pA_values); + const uint32_t *meta_base = reinterpret_cast(base_ptr + meta_offset); // Initialize accumulator 
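The values-then-metadata layout above works out as follows for illustrative dimensions M = 16, K = 32 with fp32 inputs; the values region happens to be 4-byte aligned already, so align_up_size adds no padding:

#include <cstdint>

static_assert(16 * (32 / 2) * sizeof(float) == 1024, "values region: 16 rows x 16 values");
static_assert((1024 + 3) / 4 * 4 == 1024, "already 4-byte aligned, zero padding");
static_assert(16 * (32 / 4) * sizeof(uint32_t) == 512, "metadata: 128 uint32 words");
// total sparse-A buffer: 1024 + 512 = 1536 bytes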
ctx::fill_fragment(fragC, 0); for (uint32_t k_tile = 0; k_tile < K; k_tile += ctx::tileK) { // Keep fragB resident while we iterate over sparse A tiles that consume it. - /*if constexpr (vt::ITYPE::bits < 8) { + if constexpr (vt::ITYPE::bits < 8) { auto pTileB = pB + tile_col * K + k_tile; ctx::load_matrix_sync(fragB, pTileB, K); } else { auto pTileB = pB + k_tile * N + tile_col; ctx::load_matrix_sync(fragB, pTileB, N); - }*/ + } // Base pointers for sparse A data/metadata corresponding to this tile size_t row_offset_vals = static_cast(tile_row) * values_per_row; - size_t row_offset_meta = static_cast(tile_row) * meta_per_row; - size_t col_offset_vals = k_tile / 2; - size_t col_offset_meta = k_tile / 4; - + size_t col_offset_vals = (k_tile / 4) * 2; // two stored values per 4-column block auto pTileA = pA_values + row_offset_vals + col_offset_vals; - const uint8_t *pTileMeta = meta_base + row_offset_meta + col_offset_meta; + + // Pass metadata base pointer (not offset) along with tile position + // load_matrix_sync will calculate absolute positions + const void *pTileMeta = reinterpret_cast(meta_base); - ctx::load_matrix_sync(fragA, pTileA, K, pTileMeta); + ctx::load_matrix_sync(fragA, pTileA, K, pTileMeta, tile_row, k_tile); // Matrix multiply-accumulate while fragB stays in registers - // ctx::mma_sync(fragC, fragA, fragB, fragC); + ctx::mma_sync(fragC, fragA, fragB, fragC); } - //auto pTileC = pC + tile_row * N + tile_col; - //ctx::store_matrix_sync(pTileC, fragC, N); + auto pTileC = pC + tile_row * N + tile_col; + ctx::store_matrix_sync(pTileC, fragC, N); } int main() { diff --git a/tests/regression/sgemm_sparse/main.cpp b/tests/regression/sgemm_sparse/main.cpp index d7bca2130c..bb547bd9e7 100644 --- a/tests/regression/sgemm_sparse/main.cpp +++ b/tests/regression/sgemm_sparse/main.cpp @@ -1,4 +1,5 @@ #include "common.h" +#include #include #include #include @@ -6,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -28,8 +30,13 @@ using namespace vortex; namespace vt = sparse; static bool g_enable_sparse = true; + +static size_t align_up(size_t value, size_t alignment) { + if (alignment == 0) + return value; + return (value + alignment - 1) & ~(alignment - 1); +} /////////////////////////////////////////////////////////////////////////////// -/* static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) { // Calculate output size and stride uint32_t out_bytes = (width * height + 1) / 2; @@ -60,7 +67,6 @@ static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t } } } -*/ /////////////////////////////////////////////////////////////////////////////// template @@ -351,6 +357,68 @@ struct SparseMat { uint32_t rows, cols; // original A dims (M × K) }; + +static void matmul_cpu_sparseA( + otype_t* C, // [M × N] output + const SparseMat& A, // sparse-A + const itype_t* B, // [K × N] dense-B (row major) + uint32_t N) // number of columns of B/C +{ + const uint32_t M = A.rows; + const uint32_t K = A.cols; + const uint32_t values_per_row = K / 2; + const uint32_t meta_per_row = K / 4; + const uint32_t subbytes = 8 / vt::ITYPE::bits; + + for (uint32_t m = 0; m < M; ++m) { + const itype_t* row_vals = A.values.data() + static_cast(m) * values_per_row; + const uint8_t* row_meta = A.meta.data() + static_cast(m) * meta_per_row; + otype_t* crow = C + static_cast(m) * N; + + for (uint32_t n = 0; n < N; ++n) { + size_t v_idx = 0; + otype_t sum(0); + + for (uint32_t blk = 0; blk < K; blk += 4) { + uint8_t mask 
= row_meta[blk / 4]; + if (!mask) + continue; + + for (uint32_t i = 0; i < 4; ++i) { + if (!(mask & (1u << i))) + continue; + + itype_t a_val = row_vals[v_idx++]; + uint32_t k = blk + i; + uint32_t kk = subbytes ? k * subbytes : k; + itype_t b_val = data_accessor_t::read(B, static_cast(kk) * N + n); + sum = muladd_t::eval(a_val, b_val, sum); + } + } + + crow[n] = sum; + } + } +} + +static int verify_sparse_gemm(const SparseMat& A, + const std::vector& B, + const std::vector& C, + uint32_t N) { + std::vector reference(static_cast(A.rows) * N); + matmul_cpu_sparseA(reference.data(), A, B.data(), N); + + int errors = 0; + for (size_t i = 0, e = reference.size(); i < e; ++i) { + if (!Comparator::compare(C[i], reference[i], static_cast(i), errors)) { + ++errors; + if (errors >= MAX_ERRORS) { + break; + } + } + } + return errors; +} /* static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = 8 / vt::ITYPE::bits; @@ -420,9 +488,9 @@ static void matmul_cpu_sparseA( const char *kernel_file = "kernel.vxbin"; -uint32_t xm = 4; -uint32_t xn = 8; -uint32_t xk = 2; +uint32_t xm = 1; +uint32_t xn = 1; +uint32_t xk = 1; vx_device_h device = nullptr; vx_buffer_h A_buffer = nullptr; @@ -497,17 +565,25 @@ static SparseMat pruneAndCompressMatrixA(const std::vector& denseA, src[r * K + c + 2], src[r * K + c + 3]}; + // Randomly select 2 out of 4 positions to keep uint32_t idx[4] = {0, 1, 2, 3}; - std::sort(idx, idx + 4, - [&](uint32_t a, uint32_t b) { - return std::abs((int)blk[a]) < std::abs((int)blk[b]); - }); //Sort the 4 elements by absolute value, ascending order - - uint8_t keep0 = idx[3]; - uint8_t keep1 = idx[2]; //idx of largest 2 elements - - out.values.push_back(blk[keep0]); - out.values.push_back(blk[keep1]); + // Shuffle the indices + for (uint32_t i = 3; i > 0; --i) { + uint32_t j = rand() % (i + 1); + std::swap(idx[i], idx[j]); + } + // Select first 2 shuffled indices + uint8_t keep0 = idx[0]; + uint8_t keep1 = idx[1]; + + // Store values in original order (smaller index first) + if (keep0 < keep1) { + out.values.push_back(blk[keep0]); + out.values.push_back(blk[keep1]); + } else { + out.values.push_back(blk[keep1]); + out.values.push_back(blk[keep0]); + } uint8_t m = (1u << keep0) | (1u << keep1); // e.g. 0b0101 out.meta.push_back(m); @@ -567,6 +643,7 @@ int main(int argc, char *argv[]) { } uint32_t M = xm * cfg::tileM; + uint32_t N = xn * cfg::tileN; uint32_t K = xk * cfg::tileK; if ((M % cfg::tileM) != 0) { @@ -579,20 +656,27 @@ int main(int argc, char *argv[]) { return -1; } + if ((N % cfg::tileN) != 0) { + std::cout << "Error: N must be a multiple of tensor tileN!" 
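The ascending-position store above is load-bearing: matmul_cpu_sparseA (and the hardware model) walk mask bits 0 through 3 and pull stored values sequentially, so a pair stored in magnitude order would come back transposed. A round-trip sketch with hypothetical values:

#include <cassert>
#include <cstdint>

int main() {
  float blk[4] = {0.0f, 7.0f, 0.0f, 5.0f}; // keep positions {1, 3}
  float vals[2] = {7.0f, 5.0f};            // stored in ascending position order
  uint8_t mask = (1u << 1) | (1u << 3);    // 0b1010
  float out[4] = {0, 0, 0, 0};
  int v = 0;
  for (int i = 0; i < 4; ++i)
    if (mask & (1u << i)) out[i] = vals[v++]; // consumer walks bits low-to-high
  for (int i = 0; i < 4; ++i) assert(out[i] == blk[i]);
  return 0;
}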
<< std::endl; + return -1; + } + std::cout << "input data type: " << vt::ITYPE::name << " (id=" << vt::ITYPE::id << ")" << std::endl; std::cout << "WMMA Core Dimension: M=" << cfg::tcM << ", N=" << cfg::tcN << ", K=" << cfg::tcK << std::endl; std::cout << "WMMA Tile Dimension: M=" << cfg::tileM << ", N=" << cfg::tileN << ", K=" << cfg::tileK << std::endl; std::cout << "matrix A: " << M << "x" << K << " (sparse format)" << std::endl; + std::cout << "matrix B: " << K << "x" << N << std::endl; + std::cout << "matrix C: " << M << "x" << N << std::endl; // set block size to warp size - kernel_arg.grid_dim[0] = 1; // Only need 1 block for testing + kernel_arg.grid_dim[0] = N / cfg::tileN; kernel_arg.grid_dim[1] = M / cfg::tileM; kernel_arg.block_dim[0] = NT; // warp size kernel_arg.block_dim[1] = 1; // set matrix dimensions kernel_arg.M = M; - kernel_arg.N = 0; // Not used for loading test + kernel_arg.N = N; kernel_arg.K = K; // allocate device memory for sparse matrix A @@ -603,17 +687,29 @@ int main(int argc, char *argv[]) { // Values size: (M * K / 2) * sizeof(itype_t) bytes // Metadata size: (M * K / 4) * sizeof(uint8_t) bytes size_t values_size = (M * K / 2) * sizeof(itype_t); - size_t meta_size = (M * K / 4) * sizeof(uint8_t); - sizeA_sparse = values_size + meta_size; + constexpr size_t meta_entry_bytes = sizeof(uint32_t); + size_t meta_entries = (M * K / 4); + size_t meta_size = meta_entries * meta_entry_bytes; + size_t meta_offset = align_up(values_size, meta_entry_bytes); + size_t padding_bytes = meta_offset - values_size; + sizeA_sparse = meta_offset + meta_size; + + size_t sizeB_elems = static_cast(K) * N; + size_t sizeC_elems = static_cast(M) * N; + size_t sizeB_bytes = sizeB_elems * sizeof(itype_t); + size_t sizeC_bytes = sizeC_elems * sizeof(otype_t); RT_CHECK(vx_mem_alloc(device, sizeA_sparse, VX_MEM_READ, &A_buffer)); RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); - RT_CHECK(vx_mem_alloc(device, 1, VX_MEM_READ, &B_buffer)); // Dummy buffer + RT_CHECK(vx_mem_alloc(device, sizeB_bytes, VX_MEM_READ, &B_buffer)); RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); - RT_CHECK(vx_mem_alloc(device, 1, VX_MEM_WRITE, &C_buffer)); // Dummy buffer + RT_CHECK(vx_mem_alloc(device, sizeC_bytes, VX_MEM_WRITE, &C_buffer)); RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; + std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; + std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; + std::cout << std::dec; // generate source data and convert to sparse format std::vector h_A_dense(M * K); @@ -621,6 +717,11 @@ int main(int argc, char *argv[]) { h_A_dense[i] = Comparator::generate(); } + std::vector h_B_dense(sizeB_elems); + for (size_t i = 0; i < sizeB_elems; ++i) { + h_B_dense[i] = Comparator::generate(); + } + // Convert to sparse format auto sparseA = pruneAndCompressMatrixA(h_A_dense, M, K); @@ -639,6 +740,21 @@ int main(int argc, char *argv[]) { } std::cout << std::dec; + // Print dense matrix B + std::cout << "\n=== Dense Matrix B (" << K << "x" << N << ") ===" << std::endl; + for (uint32_t k = 0; k < K; ++k) { + std::cout << "Row " << k << ": "; + for (uint32_t n = 0; n < N; ++n) { + if (vt::ITYPE::id == vt::fp32::id) { + std::cout << std::fixed << std::setprecision(3) << h_B_dense[k * N + n] << " "; + } else { + std::cout << std::hex << "0x" << (uint32_t)h_B_dense[k * N + n] << " "; + } + } + std::cout << std::endl; + } + std::cout << std::dec; + // Print sparse 
@@ -639,6 +740,21 @@ int main(int argc, char *argv[]) {
 }
 std::cout << std::dec;
+ // Print dense matrix B
+ std::cout << "\n=== Dense Matrix B (" << K << "x" << N << ") ===" << std::endl;
+ for (uint32_t k = 0; k < K; ++k) {
+ std::cout << "Row " << k << ": ";
+ for (uint32_t n = 0; n < N; ++n) {
+ if (vt::ITYPE::id == vt::fp32::id) {
+ std::cout << std::fixed << std::setprecision(3) << h_B_dense[k * N + n] << " ";
+ } else {
+ std::cout << std::hex << "0x" << (uint32_t)h_B_dense[k * N + n] << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::dec;
+
+ // Print sparse values
 std::cout << "\n=== Sparse Matrix Values (" << sparseA.values.size() << " elements) ===" << std::endl;
 size_t val_idx = 0;
@@ -656,7 +772,8 @@ int main(int argc, char *argv[]) {
 std::cout << std::dec;
 // Print metadata
- std::cout << "\n=== Metadata (" << sparseA.meta.size() << " bytes) ===" << std::endl;
+ std::cout << "\n=== Metadata (" << sparseA.meta.size() << " entries; padded size "
+ << meta_size << " bytes) ===" << std::endl;
 size_t meta_idx = 0;
 for (uint32_t m = 0; m < M; ++m) {
 std::cout << "Row " << m << " metadata: ";
@@ -694,13 +811,19 @@ int main(int argc, char *argv[]) {
 }
 // Create buffer: data first, then metadata
- std::vector<uint8_t> sparse_buffer(sizeA_sparse);
+ std::vector<uint8_t> sparse_buffer(sizeA_sparse, 0);
 memcpy(sparse_buffer.data(), sparseA.values.data(), values_size);
- memcpy(sparse_buffer.data() + values_size, sparseA.meta.data(), meta_size);
+ if (padding_bytes)
+ memset(sparse_buffer.data() + values_size, 0, padding_bytes);
+ auto meta_words = reinterpret_cast<uint32_t*>(sparse_buffer.data() + meta_offset);
+ for (size_t i = 0; i < sparseA.meta.size(); ++i) {
+ meta_words[i] = static_cast<uint32_t>(sparseA.meta[i]);
+ }
 std::cout << "\nSparse A: values=" << sparseA.values.size()
 << " elements (" << values_size << " bytes), "
- << "metadata=" << sparseA.meta.size()
+ << "metadata=" << sparseA.meta.size()
+ << " entries padded to " << meta_size
 << " bytes, total=" << sizeA_sparse << " bytes" << std::endl;
 // upload sparse matrix A buffer
@@ -709,6 +832,21 @@ int main(int argc, char *argv[]) {
 RT_CHECK(vx_copy_to_dev(A_buffer, sparse_buffer.data(), 0, sizeA_sparse));
 }
+ // upload dense matrix B buffer (convert layout if needed)
+ {
+ std::cout << "upload matrix B buffer" << std::endl;
+ if constexpr (std::is_same_v<vt::ITYPE, vt::int4> || std::is_same_v<vt::ITYPE, vt::uint4>) {
+ std::vector<uint8_t> h_B_col(sizeB_elems);
+ convert_row_to_col_major_4bit(h_B_col.data(),
+ N,
+ 2 * K,
+ reinterpret_cast<const uint8_t*>(h_B_dense.data()));
+ RT_CHECK(vx_copy_to_dev(B_buffer, h_B_col.data(), 0, sizeB_elems));
+ } else {
+ RT_CHECK(vx_copy_to_dev(B_buffer, h_B_dense.data(), 0, sizeB_bytes));
+ }
+ }
+
 // upload program
 std::cout << "upload program" << std::endl;
 RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer));
@@ -731,11 +869,24 @@ int main(int argc, char *argv[]) {
 double elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_start).count();
 printf("Elapsed time: %lg ms\n", elapsed);
+ std::vector<otype_t> h_C(sizeC_elems);
+ std::cout << "download result buffer" << std::endl;
+ RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC_bytes));
+
+ std::cout << "verify sparse GEMM result" << std::endl;
+ int errors = verify_sparse_gemm(sparseA, h_B_dense, h_C, N);
+
 // cleanup
 std::cout << "cleanup" << std::endl;
 cleanup();
- std::cout << "Sparse matrix loading test completed successfully!" << std::endl;
+ if (errors != 0) {
+ std::cout << "Found " << std::dec << errors << " errors!" << std::endl;
+ std::cout << "FAILED!" << std::endl;
+ return errors;
+ }
+
+ std::cout << "Sparse GEMM completed successfully!" << std::endl;
 std::cout << "PASSED!" << std::endl;
 return 0;
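Taken together, the host-side A layout this patch settles on is: values, zero padding to a word boundary, then one 32-bit word per metadata byte. A compact sketch of the packing (pack_sparse_A is a hypothetical name; align_up as sketched earlier):

  static std::vector<uint8_t> pack_sparse_A(const SparseMat& A) {
    size_t values_size = A.values.size() * sizeof(itype_t);
    size_t meta_offset = align_up(values_size, sizeof(uint32_t));
    std::vector<uint8_t> buf(meta_offset + A.meta.size() * sizeof(uint32_t), 0);
    memcpy(buf.data(), A.values.data(), values_size);
    auto words = reinterpret_cast<uint32_t*>(buf.data() + meta_offset);
    for (size_t i = 0; i < A.meta.size(); ++i)
      words[i] = A.meta[i];   // widen each mask byte to a 32-bit word
    return buf;
  }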
From c757d32a9f7cd9f2f40c78ab83e01a21ddcf4c5c Mon Sep 17 00:00:00 2001
From: tandonmitul27 
Date: Tue, 9 Dec 2025 19:20:25 +0530
Subject: [PATCH 6/6] added support and test for TILE_SPMM_R
---
 sim/simx/sparse_unit.cpp | 42 ++++--
 tests/regression/sgemm_tile/common.h | 3 +-
 tests/regression/sgemm_tile/kernel.cpp | 18 ++-
 tests/regression/sgemm_tile/main.cpp | 194 ++++++++++++++++++++++---
 4 files changed, 225 insertions(+), 32 deletions(-)
diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp
index 24608e6154..845488b980 100644
--- a/sim/simx/sparse_unit.cpp
+++ b/sim/simx/sparse_unit.cpp
@@ -601,42 +601,60 @@ class SparseUnit::Impl {
 }
 // TILE_GEMM_R: Row-wise sparse tile × Dense tile = Tile (T × U → U)
+ // ISA: A is 16×32 logical (compressed to 16×16 padded T-tile)
+ // B is 32×16 dense (stored in U-reg = 2 T-regs)
+ // Output is 16×16 (first T-reg of destination U-reg)
+ // Metadata: 8 blocks per row × 4 bits/block = 32 bits = 4 bytes per row
+ // Total: 64 bytes mask data + 64 bytes reserved = 128 bytes
 void tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) {
 assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds");
 assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds");
 constexpr uint32_t TILE_DIM = 16;
+ constexpr uint32_t LOGICAL_K = 32; // A is 16×32 logical
- const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile
+ const auto& tile_a = tile_reg_file_[src1_treg]; // Compressed 16×16 tile
 const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile
 // Both dst and src2 are U-registers (map to 2 T-registers each)
 std::vector<uint32_t> dst_tregs = map_ureg_to_treg(dst_ureg);
 std::vector<uint32_t> src2_tregs = map_ureg_to_treg(src2_ureg);
- // Row-wise sparsity: metadata can vary per row
+ // Row-wise sparsity: each row of A has 8 blocks of 4 elements (32 total)
+ // compressed to 16 values using 2-of-4 sparsity
 for (uint32_t i = 0; i < TILE_DIM; ++i) {
 for (uint32_t j = 0; j < TILE_DIM; ++j) {
- // Determine which destination T-register
- uint32_t dst_treg_idx = dst_tregs[j / TILE_DIM];
- uint32_t j_local = j % TILE_DIM;
+ // Destination is first T-reg of U-reg (16×16 output)
+ uint32_t dst_treg_idx = dst_tregs[0];
+
+ float sum = tile_reg_file_[dst_treg_idx][i][j]; // Accumulate
- float sum = tile_reg_file_[dst_treg_idx][i][j_local]; // Accumulate
+ // Track position in compressed A tile for this row
+ uint32_t a_col = 0;
- // Process sparse A row with dense B column
- for (uint32_t k_blk = 0; k_blk < TILE_DIM; k_blk += 4) {
- uint8_t mask = meta_a[i][k_blk / 4]; // Row-wise metadata
+ // Process 8 blocks of 4 elements each (K=32 logical)
+ for (uint32_t k_blk = 0; k_blk < LOGICAL_K; k_blk += 4) {
+ // Metadata layout: meta_a[row][col] stores individual nibbles (uint4)
+ // nibble_idx = k_blk / 4 (0..7) directly indexes the metadata column
+ uint32_t nibble_idx = k_blk / 4;
+ uint8_t mask = meta_a[i][nibble_idx]; // Direct nibble access
 for (uint32_t offset = 0; offset < 4; ++offset) {
 if (mask & (1u << offset)) {
- uint32_t k = k_blk + offset;
+ // This position is non-zero in the logical A
+ uint32_t k = k_blk + offset; // Logical K index (0..31)
+
+ // B is stored in U-reg (32×16), split into 2 T-regs (rows 0-15 and 16-31)
 uint32_t src2_treg_idx = src2_tregs[k / TILE_DIM];
 uint32_t k_local = k % TILE_DIM;
- sum += tile_a[i][k] * tile_reg_file_[src2_treg_idx][k_local][j];
+
+ // Get value from compressed A tile
+ sum += tile_a[i][a_col] *
+ tile_reg_file_[src2_treg_idx][k_local][j];
+ a_col++; // Move to next compressed value
 }
 }
 }
- tile_reg_file_[dst_treg_idx][i][j_local] = sum;
+ tile_reg_file_[dst_treg_idx][i][j] = sum;
 }
 }
 }
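As a cross-check, the semantics of tile_gemm_r reduce to the following on plain arrays (a sketch only, not simulator code; rgemm_ref is a hypothetical name, with fixed 16×16 compressed A over logical K=32 and one metadata nibble per 4-element block):

  // Reference semantics of TILE_GEMM_R on plain arrays.
  static void rgemm_ref(const float A_comp[16][16], const uint8_t meta[16][8],
                        const float B[32][16], float C[16][16]) {
    for (int i = 0; i < 16; ++i) {
      for (int j = 0; j < 16; ++j) {
        float sum = C[i][j];
        int a_col = 0;                      // position within the compressed row
        for (int blk = 0; blk < 32; blk += 4) {
          uint8_t mask = meta[i][blk / 4];  // one 4-bit mask per block
          for (int off = 0; off < 4; ++off)
            if (mask & (1u << off))
              sum += A_comp[i][a_col++] * B[blk + off][j];
        }
        C[i][j] = sum;
      }
    }
  }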
diff --git a/tests/regression/sgemm_tile/common.h b/tests/regression/sgemm_tile/common.h
index fca0f88223..036a7bde05 100644
--- a/tests/regression/sgemm_tile/common.h
+++ b/tests/regression/sgemm_tile/common.h
@@ -14,7 +14,8 @@
 typedef enum {
 GEMM_MODE_TGEMM = 0, // T x T -> T (dense x dense)
 GEMM_MODE_UGEMM = 1, // T x U -> T (sparse 2:4 packed x dense 2x)
- GEMM_MODE_VGEMM = 2 // T x V -> T (sparse 1:4 packed x dense 4x)
+ GEMM_MODE_VGEMM = 2, // T x V -> T (sparse 1:4 packed x dense 4x)
+ GEMM_MODE_RGEMM = 3 // T x U -> U (row-wise N:4 sparse x dense 2x)
} gemm_mode_t;
 typedef struct {
diff --git a/tests/regression/sgemm_tile/kernel.cpp b/tests/regression/sgemm_tile/kernel.cpp
index ef5df44ac5..86a45d5343 100644
--- a/tests/regression/sgemm_tile/kernel.cpp
+++ b/tests/regression/sgemm_tile/kernel.cpp
@@ -27,7 +27,7 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
 }
 else if (mode == GEMM_MODE_UGEMM) {
 // UGEMM: T × U -> T (2:4 sparse)
- // Load metadata into M-reg 1 (1KB metadata)
+ // Load metadata into M-reg 1 (128 bytes)
 vx_lm(1, (size_t)M_ptr, 0);
 // Load B tile into U-reg 2 (2KB sparse 2:4)
@@ -38,7 +38,7 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
 }
 else if (mode == GEMM_MODE_VGEMM) {
 // VGEMM: T × V -> T (1:4 sparse)
- // Load metadata into M-reg 1 (1KB metadata)
+ // Load metadata into M-reg 1 (128 bytes)
 vx_lm(1, (size_t)M_ptr, 0);
 // Load B tile into V-reg 1 (4KB sparse 1:4)
@@ -48,8 +48,22 @@ void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
 // VGEMM: T0 = T1 (sparse with M1 metadata) × V1 (dense)
 vx_vgemm(0, 1, 1);
 }
+ else if (mode == GEMM_MODE_RGEMM) {
+ // RGEMM: T × U -> U (row-wise N:4 sparse)
+ // Load metadata into M-reg 1 (128 bytes)
+ vx_lm(1, (size_t)M_ptr, 0);
+
+ // Load B tile into U-reg 2 (2KB dense)
+ vx_lu(2, (size_t)B_ptr, 0);
+
+ // RGEMM: U0 = T1 (row-wise sparse with M1 metadata) × U2 (dense)
+ // Output is stored in U-reg 0 = T-reg 0 + T-reg 1 (2KB total)
+ // ISA: vx_rgemm nominally produces a full U-reg result; the simulator
+ // currently writes only the first T-reg (see sparse_unit.cpp above)
+ vx_rgemm(0, 1, 2);
+ }
 // Store result from T-reg 0 to C (always 1KB)
+ // For RGEMM: we only validate first T-reg of U0 (top 16 rows)
 vx_st((size_t)C_ptr, 0, 0);
}
diff --git a/tests/regression/sgemm_tile/main.cpp b/tests/regression/sgemm_tile/main.cpp
index c77d54b404..5c706addc2 100644
--- a/tests/regression/sgemm_tile/main.cpp
+++ b/tests/regression/sgemm_tile/main.cpp
@@ -37,10 +37,11 @@ static gemm_mode_t gemm_mode = GEMM_MODE_TGEMM;
 static void show_usage() {
 std::cout << "Vortex SGEMM TILE Test (16x16 matrix operations)." << std::endl;
 std::cout << "Usage: [-m mode] [-h: help]" << std::endl;
- std::cout << " -m mode: GEMM mode (0=TGEMM, 1=UGEMM, 2=VGEMM) [default: 0]" << std::endl;
+ std::cout << " -m mode: GEMM mode (0=TGEMM, 1=UGEMM, 2=VGEMM, 3=RGEMM) [default: 0]" << std::endl;
 std::cout << " TGEMM (0): T × T -> T (dense × dense)" << std::endl;
 std::cout << " UGEMM (1): T × U -> T (dense × 2:4 sparse)" << std::endl;
 std::cout << " VGEMM (2): T × V -> T (dense × 1:4 sparse)" << std::endl;
+ std::cout << " RGEMM (3): T × U -> U (row-wise N:4 sparse × dense)" << std::endl;
}
 static void parse_args(int argc, char **argv) {
@@ -49,7 +50,7 @@ static void parse_args(int argc, char **argv) {
 switch (c) {
 case 'm':
 gemm_mode = static_cast<gemm_mode_t>(atoi(optarg));
- if (gemm_mode < GEMM_MODE_TGEMM || gemm_mode > GEMM_MODE_VGEMM) {
+ if (gemm_mode < GEMM_MODE_TGEMM || gemm_mode > GEMM_MODE_RGEMM) {
 std::cerr << "Error: Invalid mode " << gemm_mode << std::endl;
 show_usage();
 exit(-1);
@@ -177,6 +178,75 @@ static void compress_1_4_sparse(const std::vector<float>& logical_tile, int M, i
 }
}
+// Generate compressed row-wise N:4 sparse tile and metadata from the full logical matrix
+// Input: logical_tile is M×K (16×32)
+// Output: padded_tile is M×(K/2) (16×16), metadata is exactly 128 bytes
+// Compression: For each 4-element block, keep the top-2 values by magnitude (deterministic)
+// Metadata layout: Must match TILE_LOAD_M format (8 bytes per row)
+// - 16 rows × 8 bytes/row = 128 bytes total
+// - Each byte stores 2 nibbles: upper nibble for col N, lower for col N+1
+// - For RGEMM: only the first 8 nibbles (cols 0-7) are used, the rest are zero
+static void compress_rowwise_n4_sparse(const std::vector<float>& logical_tile, int M, int K,
+ std::vector<float>& padded_tile, std::vector<uint8_t>& metadata) {
+ // Output sizes: padded tile is M×(K/2), metadata is exactly 128 bytes
+ padded_tile.resize(M * (K / 2));
+ metadata.resize(128); // 8 bytes per row × 16 rows = 128 bytes
+ std::fill(metadata.begin(), metadata.end(), 0);
+
+ for (int row = 0; row < M; ++row) {
+ int padded_col = 0;
+
+ // Process K/4 groups of 4 elements (8 groups for K=32)
+ for (int k_grp = 0; k_grp < K / 4; ++k_grp) {
+ int k_base = k_grp * 4;
+
+ // Find the 2 largest-magnitude values in this group of 4
+ // Use index-value pairs for deterministic selection
+ std::pair<int, float> vals[4];
+ for (int offset = 0; offset < 4; ++offset) {
+ vals[offset] = {offset, logical_tile[row * K + k_base + offset]};
+ }
+
+ // Sort by magnitude (descending) to find the top 2
+ // For equal magnitudes, the lower index wins (stable, deterministic)
+ std::sort(vals, vals + 4, [](const auto& a, const auto& b) {
+ float abs_a = std::abs(a.second);
+ float abs_b = std::abs(b.second);
+ if (abs_a != abs_b) return abs_a > abs_b;
+ return a.first < b.first; // Tie-breaker: lower index first
+ });
+
+ // Create 4-bit bitmask for the top 2 values
+ uint8_t mask = 0;
+ for (int i = 0; i < 2; ++i) {
+ int offset = vals[i].first;
+ mask |= (1u << offset);
+ }
+
+ // Store values in POSITION ORDER (not magnitude order)
+ // Hardware iterates through bit positions 0-3 sequentially
+ for (int offset = 0; offset < 4; ++offset) {
+ if (mask & (1u << offset)) {
+ padded_tile[row * (K / 2) + padded_col++] = logical_tile[row * K + k_base + offset];
+ }
+ }
+
+ // Store metadata: 4 bits per block
+ // Layout: 8 bytes per row (matching TILE_LOAD_M format)
+ // Each byte stores 2 nibbles: upper for even k_grp, lower for odd k_grp
+ // k_grp 0,1 -> byte 0 (cols 0,1), k_grp 2,3 -> byte 1 (cols 2,3), etc.
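+ // Worked example: k_grp = 4 maps to byte_idx = row * 8 + 2 (upper nibble);
+ // k_grp = 5 shares that byte and lands in the lower nibble.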
+ int byte_idx = row * 8 + k_grp / 2; // 8 bytes per row
+ if (k_grp % 2 == 0) {
+ metadata[byte_idx] = (mask << 4); // Upper nibble (col N)
+ } else {
+ metadata[byte_idx] |= mask; // Lower nibble (col N+1)
+ }
+ }
+ }
+
+ // Remaining bytes in each row (cols 8-15) are zero, already initialized
+}
+
 // CPU reference: C = A × B
 // A is MxK, B is KxN, C is MxN
 // For TGEMM: A is 16x16, B is 16x16
@@ -219,9 +289,9 @@ int main(int argc, char *argv[]) {
 std::cout << "open device connection" << std::endl;
 RT_CHECK(vx_dev_open(&device));
- uint32_t num_elements = TILE_SIZE * TILE_SIZE; // 256 elements
+ uint32_t num_elements = TILE_SIZE * TILE_SIZE; // 256 elements for T-reg
 uint32_t A_buf_size = T_TILE_BYTES; // Always 1KB for A
- uint32_t C_buf_size = T_TILE_BYTES; // Always 1KB for C
+ uint32_t C_buf_size = T_TILE_BYTES; // Always 1KB for C (first T-reg of result)
 uint32_t B_buf_size, M_buf_size = 0;
 const char* mode_name;
@@ -238,7 +308,12 @@ int main(int argc, char *argv[]) {
 case GEMM_MODE_VGEMM:
 mode_name = "VGEMM (T × V, 1:4 sparse)";
 B_buf_size = V_TILE_BYTES; // 4KB
- M_buf_size = M_TILE_BYTES; // 1KB metadata
+ M_buf_size = M_TILE_BYTES; // 128 bytes metadata
+ break;
+ case GEMM_MODE_RGEMM:
+ mode_name = "RGEMM (T × U -> U, row-wise N:4 sparse)";
+ B_buf_size = U_TILE_BYTES; // 2KB (B is dense U-reg)
+ M_buf_size = M_TILE_BYTES; // 128 bytes metadata
 break;
 default:
 std::cerr << "Invalid GEMM mode!" << std::endl;
@@ -282,9 +357,12 @@ int main(int argc, char *argv[]) {
 // allocate host buffers
 std::cout << "allocate host buffers" << std::endl;
- // A's logical size depends on mode: 16x16 for TGEMM, 16x32 for UGEMM, 16x64 for VGEMM
+ // A's logical size depends on mode:
+ // - TGEMM: 16x16 (dense)
+ // - UGEMM/RGEMM: 16x32 (sparse compressed to 16x16)
+ // - VGEMM: 16x64 (sparse 1:4 compressed to 16x16)
 uint32_t A_cols_logical = TILE_SIZE;
- if (gemm_mode == GEMM_MODE_UGEMM) A_cols_logical = 2 * TILE_SIZE; // 32 logical cols
+ if (gemm_mode == GEMM_MODE_UGEMM || gemm_mode == GEMM_MODE_RGEMM) A_cols_logical = 2 * TILE_SIZE; // 32 logical cols
 else if (gemm_mode == GEMM_MODE_VGEMM) A_cols_logical = 4 * TILE_SIZE; // 64 logical cols
 // B size matches A's logical K dimension
@@ -293,7 +371,7 @@ int main(int argc, char *argv[]) {
 std::vector<float> h_A_logical(TILE_SIZE * A_cols_logical); // Logical A before compression
 std::vector<float> h_A(num_elements); // Compressed A (always 16x16 = 1KB for storage)
 std::vector<float> h_B(A_cols_logical * B_cols); // B is K×N where K matches A's logical K
- std::vector<float> h_C(num_elements);
+ std::vector<float> h_C(num_elements); // Output is always 16×16 for tested modes
 std::vector<float> h_ref(num_elements);
 // Initialize logical matrix A
@@ -348,6 +426,21 @@ int main(int argc, char *argv[]) {
 RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size));
 RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
 }
+ else if (gemm_mode == GEMM_MODE_RGEMM) {
+ // RGEMM: A (16x32 logical, compressed to 16x16 via row-wise N:4) × B (32x16) = C (16x16)
+ // A: logical 16x32 -> padded 16x16 (1KB T-tile) with metadata (128 bytes)
+ // B: full 32x16 stored in U-register (2KB = 2 T-regs)
+
+ // Compress A from 16x32 logical to 16x16 padded using row-wise N:4 compression
+ compress_rowwise_n4_sparse(h_A_logical, TILE_SIZE, 2 * TILE_SIZE, h_A, h_M);
+
+ std::cout << "Row-wise N:4 sparse A: logical 16x32 -> padded 16x16, metadata " << h_M.size() << " bytes" << std::endl;
+
+ // Upload padded A (1KB), metadata (128B), and full B (2KB for U-reg)
+ RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size));
+ RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size));
+ RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
+ }
 // upload kernel binary
 std::cout << "upload kernel binary" << std::endl;
 RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer));
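The zero-out loops in the next hunk re-derive each mask nibble inline; factored out, the lookup is just the following (a sketch; meta_nibble is a hypothetical helper):

  // Fetch the 4-bit mask for (row, k_grp) from the 8-byte-per-row
  // metadata layout produced by compress_rowwise_n4_sparse.
  static inline uint8_t meta_nibble(const std::vector<uint8_t>& m,
                                    int row, int k_grp) {
    uint8_t byte = m[row * 8 + k_grp / 2];
    return (k_grp % 2 == 0) ? (byte >> 4) : (byte & 0x0F);
  }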
@@ -371,12 +464,40 @@ int main(int argc, char *argv[]) {
 // Zero out pruned values in h_A_logical based on metadata for CPU reference
 // This ensures CPU computes the same result as GPU (which only uses non-zero values)
+ //
+ // METADATA LAYOUT DIFFERENCES:
+ // - UGEMM/VGEMM metadata: 8 bytes per row (up to 16 nibbles: 8 4-element
+ // blocks for K=32, 16 for K=64). Total: 128 bytes (16 rows × 8 bytes/row)
+ // - RGEMM metadata: 4 bytes of mask data per row (8 nibbles = 8 4-element blocks),
+ // in the same 8-byte-per-row layout: 64 bytes mask data + 64 bytes reserved = 128 bytes
+ //
 if (gemm_mode == GEMM_MODE_UGEMM || gemm_mode == GEMM_MODE_VGEMM) {
+ // UGEMM/VGEMM: 8 bytes per row (8 or 16 blocks of 4 elements for K=32/64)
 for (uint32_t row = 0; row < TILE_SIZE; ++row) {
 uint32_t k_groups = A_cols_logical / 4;
 for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) {
 int k_base = k_grp * 4;
- // Get metadata nibble for this group
+ // Get metadata nibble for this group (8 bytes per row layout)
 int byte_idx = row * 8 + k_grp / 2;
 uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F);
+
+ // Zero out positions not in metadata mask
+ for (int offset = 0; offset < 4; ++offset) {
+ if (!(nibble & (1u << offset))) {
+ h_A_logical[row * A_cols_logical + k_base + offset] = 0.0f;
+ }
+ }
+ }
+ }
+ }
+ else if (gemm_mode == GEMM_MODE_RGEMM) {
+ // RGEMM: 8 bytes per row (matching TILE_LOAD_M format)
+ // Only the first 8 nibbles (cols 0-7) are used for 8 blocks of 4 elements (K=32)
+ for (uint32_t row = 0; row < TILE_SIZE; ++row) {
+ uint32_t k_groups = A_cols_logical / 4; // 8 groups for K=32
+ for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) {
+ int k_base = k_grp * 4;
+ // Get metadata nibble for this group (8 bytes per row layout)
 int byte_idx = row * 8 + k_grp / 2;
 uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F);
@@ -390,13 +511,15 @@ int main(int argc, char *argv[]) {
 }
 }
- // compute CPU reference using logical A matrix (now with zeros in pruned positions)
+ // compute CPU reference
 std::cout << "verify result" << std::endl;
- // C = A × B where A is MxK, B is KxN, C is MxN
+
+ // For all modes: C = A_logical (with zeros in pruned positions) × B
+ // For RGEMM: A_logical is 16×32 with zeros, B is 32×16, result is 16×16
 // M = TILE_SIZE (16), K = A_cols_logical, N = B_cols (16)
 matmul_cpu(h_ref.data(), h_A_logical.data(), h_B.data(), TILE_SIZE, A_cols_logical, B_cols);
- // verify result
+ // verify result (always 256 elements = 16×16)
 int errors = 0;
 for (uint32_t i = 0; i < num_elements; ++i) {
 compare_float(h_C[i], h_ref[i], i, errors);
@@ -408,13 +531,48 @@ int main(int argc, char *argv[]) {
 if (output_file.is_open()) {
 output_file << "GEMM Mode: " << mode_name << "\n\n";
+ // 1. Print compressed/padded A matrix (what's actually sent to hardware)
+ if (gemm_mode != GEMM_MODE_TGEMM) {
+ output_file << "Matrix A Padded (Compressed " << TILE_SIZE << "x" << TILE_SIZE << "):\n";
+ for (uint32_t i = 0; i < TILE_SIZE; ++i) {
+ for (uint32_t j = 0; j < TILE_SIZE; ++j) {
+ output_file << h_A[i * TILE_SIZE + j];
+ if (j < TILE_SIZE - 1) output_file << " ";
+ }
+ output_file << "\n";
+ }
+ output_file << "\n";
+
+ // 2. 
Print metadata as 0/1 pattern + output_file << "Metadata (" << TILE_SIZE << "x" << A_cols_logical << " sparsity pattern, 1=kept, 0=pruned):\n"; + for (uint32_t row = 0; row < TILE_SIZE; ++row) { + uint32_t k_groups = A_cols_logical / 4; + for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) { + // Get metadata nibble for this group (8 bytes per row layout) + int byte_idx = row * 8 + k_grp / 2; + uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F); + + // Print 4 bits as 0/1 + for (int offset = 0; offset < 4; ++offset) { + output_file << ((nibble & (1u << offset)) ? "1" : "0"); + if (k_grp < k_groups - 1 || offset < 3) output_file << " "; + } + } + output_file << "\n"; + } + output_file << "\n"; + } + + // 3. Print logical A matrix output_file << "Matrix A Logical ("; if (gemm_mode == GEMM_MODE_TGEMM) { output_file << "Dense"; } else if (gemm_mode == GEMM_MODE_UGEMM) { output_file << "2:4 Sparse"; - } else { + } else if (gemm_mode == GEMM_MODE_VGEMM) { output_file << "1:4 Sparse"; + } else if (gemm_mode == GEMM_MODE_RGEMM) { + output_file << "Row-wise N:4 Sparse"; } output_file << ", " << TILE_SIZE << "x" << A_cols_logical << "):\n"; for (uint32_t i = 0; i < TILE_SIZE; ++i) { @@ -426,6 +584,7 @@ int main(int argc, char *argv[]) { } output_file << "\n"; + // 4. Print B matrix output_file << "Matrix B (Dense, " << A_cols_logical << "x" << B_cols << "):\n"; for (uint32_t i = 0; i < A_cols_logical; ++i) { for (uint32_t j = 0; j < B_cols; ++j) { @@ -436,9 +595,10 @@ int main(int argc, char *argv[]) { } output_file << "\n"; + // 5. Print C matrices (GPU and CPU reference) output_file << "Matrix C (GPU Result, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; - for (int i = 0; i < TILE_SIZE; ++i) { - for (int j = 0; j < TILE_SIZE; ++j) { + for (uint32_t i = 0; i < TILE_SIZE; ++i) { + for (uint32_t j = 0; j < TILE_SIZE; ++j) { output_file << h_C[i * TILE_SIZE + j]; if (j < TILE_SIZE - 1) output_file << " "; } @@ -447,8 +607,8 @@ int main(int argc, char *argv[]) { output_file << "\n"; output_file << "Matrix C (CPU Reference, " << TILE_SIZE << "x" << TILE_SIZE << "):\n"; - for (int i = 0; i < TILE_SIZE; ++i) { - for (int j = 0; j < TILE_SIZE; ++j) { + for (uint32_t i = 0; i < TILE_SIZE; ++i) { + for (uint32_t j = 0; j < TILE_SIZE; ++j) { output_file << h_ref[i * TILE_SIZE + j]; if (j < TILE_SIZE - 1) output_file << " "; }