From 2e17ca0210c662939e7956e2b827665d1a7e6cb6 Mon Sep 17 00:00:00 2001 From: gamzeisl Date: Sat, 2 Aug 2025 14:08:13 +0200 Subject: [PATCH 1/2] Support N = 8 --- Makefile | 3 +- PyITA/ITA.py | 8 +++-- PyITA/softmax.py | 3 +- src/hwpe/ita_hwpe_ctrl.sv | 5 +-- src/ita_max_finder.sv | 69 ++++++++++++++++----------------------- src/ita_package.sv | 2 +- testGenerator.py | 3 ++ 7 files changed, 44 insertions(+), 49 deletions(-) diff --git a/Makefile b/Makefile index 3359ca7..2ba4bf3 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,8 @@ else ifeq ($(activation), relu) else activation_int = 0 endif -vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) +ITA_N ?= 16 +vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) -DITA_N=$(ITA_N) ifeq ($(target), sim_ita_hwpe_tb) BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 0068723..2b118e6 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -35,6 +35,7 @@ class Transformer: WI = 8 def __init__(self, + ITA_N: int, S: int, P: int, E: int, @@ -60,7 +61,7 @@ def __init__(self, Bff: ArrayLike = None, Bff2: ArrayLike = None): - self.ITA_N = 16 + self.ITA_N = ITA_N self.ITA_M = 64 # WIESEP: Set numpy print options @@ -546,7 +547,7 @@ def soft(self, no_partial_softmax = False): if no_partial_softmax: self.A_partial_softmax = fastSoftmax(self.A_requant) else: - self.A_partial_softmax = streamingPartialSoftmax(self.A_requant) + self.A_partial_softmax = streamingPartialSoftmax(self.ITA_N, self.A_requant) if self.H == 1: A_save = [np.tile(self.A_partial_softmax[i], [self.split, 1]) for i in range(self.H)] @@ -974,8 +975,9 @@ def generateTestVectors(path, **kwargs): bias = int(not kwargs['no_bias']) export_snitch_cluster = kwargs['export_snitch_cluster'] export_mempool = kwargs['export_mempool'] + ita_n = kwargs['ITA_N'] - acc1 = Transformer(s, p, e, f, h, bias = bias, path = path, activation = activation) + acc1 = Transformer(ita_n, s, p, e, f, h, bias = bias, path = path, activation = activation) if kwargs['verbose']: print("=> Generating test vectors...") diff --git a/PyITA/softmax.py b/PyITA/softmax.py index 7545086..b4b9e08 100644 --- a/PyITA/softmax.py +++ b/PyITA/softmax.py @@ -66,13 +66,12 @@ def fastSoftmax(x, integerize = True): return np.repeat(exp_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift -def streamingPartialSoftmax(x, integerize = True): +def streamingPartialSoftmax(width, x, integerize = True): if not integerize: x = x.astype(np.float32) seq_length = x.shape[-1] n_heads = x.shape[-3] - width = 16 # 16 PE (processing units) groups = seq_length // width assert seq_length % width == 0, f"Sequence length must be a multiple of width ({width})" diff --git a/src/hwpe/ita_hwpe_ctrl.sv b/src/hwpe/ita_hwpe_ctrl.sv index 1edd454..6052b7a 100644 --- a/src/hwpe/ita_hwpe_ctrl.sv +++ b/src/hwpe/ita_hwpe_ctrl.sv @@ -34,6 +34,7 @@ module ita_hwpe_ctrl ); localparam int unsigned TOT_LEN = M*M/N; + localparam int unsigned WEIGHT_DIV = (ITA_TCDM_DW / 8) / N; logic slave_clear, stream_clear; ctrl_slave_t slave_ctrl; @@ -146,7 +147,7 @@ module ita_hwpe_ctrl restart_weight_d = 1'b1; ctrl_streamer_o.bias_source_ctrl.req_start = !ctrl_stream_o.bias_disable; ctrl_streamer_o.output_sink_ctrl.req_start = 1'b1 & !ctrl_stream_o.output_disable; - weight_len_d = (TOT_LEN / 8) - (!ctrl_stream_o.weight_preload * M / 8); + weight_len_d = (TOT_LEN / WEIGHT_DIV) - (!ctrl_stream_o.weight_preload * M / WEIGHT_DIV); weight_base_addr_d = weight_addr[0] + !ctrl_stream_o.weight_preload * N * M; end end @@ -154,7 +155,7 @@ module ita_hwpe_ctrl if (flags_streamer_i.weight_source_flags.done) begin state_d = Done; restart_weight_d = 1'b1; - weight_len_d = M / 8; + weight_len_d = M / WEIGHT_DIV; weight_base_addr_d = weight_addr[1]; end end diff --git a/src/ita_max_finder.sv b/src/ita_max_finder.sv index 5997979..7f0ae77 100644 --- a/src/ita_max_finder.sv +++ b/src/ita_max_finder.sv @@ -1,61 +1,50 @@ -// Copyright 2020 ETH Zurich and University of Bologna. +// Copyright 2025 ETH Zurich and University of Bologna. // Solderpad Hardware License, Version 0.51, see LICENSE for details. // SPDX-License-Identifier: SHL-0.51 module ita_max_finder import ita_package::*; ( - input logic clk_i , - input logic rst_ni , - // Input - input requant_oup_t x_i , + input logic clk_i, + input logic rst_ni, + input requant_oup_t x_i, input requant_t prev_max_i, - output requant_t max_o , + output requant_t max_o, output requant_t max_diff_o ); - // find maximum - requant_t [N/2-1:0] max_tmp; - requant_t [N/4-1:0] max_tmp2; - requant_t [N/8-1:0] max_tmp3; - requant_t max_tmp4; - always_comb begin - - for (int i = 0; i < N/2; i++) begin - if (x_i[2*i]>x_i[2*i+1]) - max_tmp[i] = x_i[2*i]; - else - max_tmp[i] = x_i[2*i+1]; + function automatic requant_t reduce_max(input requant_oup_t vec); + requant_oup_t stage; + int size = N; + int idx; + begin + stage = vec; + while (size > 1) begin + for (int i = 0; i < size/2; i++) begin + if (stage[2*i] > stage[2*i+1]) + stage[i] = stage[2*i]; + else + stage[i] = stage[2*i+1]; + end + size = size / 2; + end + return stage[0]; end + endfunction - for (int i = 0; i < N/4; i++) begin - if (max_tmp[2*i]>max_tmp[2*i+1]) - max_tmp2[i] = max_tmp[2*i]; - else - max_tmp2[i] = max_tmp[2*i+1]; - end + requant_t max_val; - for (int i = 0; i < N/8; i++) begin - if (max_tmp2[2*i]>max_tmp2[2*i+1]) - max_tmp3[i] = max_tmp2[2*i]; - else - max_tmp3[i] = max_tmp2[2*i+1]; - end - - if (max_tmp3[0]>max_tmp3[1]) - max_tmp4 = max_tmp3[0]; - else - max_tmp4 = max_tmp3[1]; + always_comb begin + max_val = reduce_max(x_i); - if (prev_max_i>max_tmp4) begin + if (prev_max_i > max_val) begin max_o = prev_max_i; max_diff_o = '0; end else begin - max_o = max_tmp4; - max_diff_o = max_o-prev_max_i; + max_o = max_val; + max_diff_o = max_o - prev_max_i; end - end -endmodule +endmodule \ No newline at end of file diff --git a/src/ita_package.sv b/src/ita_package.sv index c20ef71..0555049 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -32,7 +32,7 @@ package ita_package; parameter int unsigned MNumReadPorts = N ; parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 12 `endif; localparam int unsigned SplitFactor = 4 ; - parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif; + parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE ( M * N / 128) `else M `endif; // ITA_TCDM_DW=1024 / 8 = 128 // Feedforward typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2, SingleAttention=3} layer_e; diff --git a/testGenerator.py b/testGenerator.py index 0c94a55..a69d8f2 100644 --- a/testGenerator.py +++ b/testGenerator.py @@ -113,6 +113,9 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter self.group1.add_argument('--export-mempool', action = 'store_true', help = 'Export for mempool') self.group1.add_argument('--export-rom', action = 'store_true', help = 'Export ROM configuration') + self.group2 = self.add_argument_group('Engine Settings') + self.group2.add_argument('-ITA_N', default = 16, type = int, help = 'Number of ITA processing elements') + if __name__ == "__main__": parser = TestParser() From 0a77df9514a160a54426a71641987079486c9aa5 Mon Sep 17 00:00:00 2001 From: gamzeisl Date: Sat, 2 Aug 2025 16:33:52 +0200 Subject: [PATCH 2/2] Parametrize bias buffer --- Bender.yml | 2 +- src/hwpe/ita_hwpe_input_bias_buffer.sv | 27 ++-- ..._register_file_1w_1r_double_width_write.sv | 139 ------------------ src/ita_register_file_1w_1r_multiwidth.sv | 77 ++++++++++ 4 files changed, 94 insertions(+), 151 deletions(-) delete mode 100644 src/ita_register_file_1w_1r_double_width_write.sv create mode 100644 src/ita_register_file_1w_1r_multiwidth.sv diff --git a/Bender.yml b/Bender.yml index 6d2c5a6..3655d0d 100644 --- a/Bender.yml +++ b/Bender.yml @@ -39,7 +39,7 @@ sources: - src/ita_inp2_mux.sv - src/ita_input_sampler.sv - src/ita_output_controller.sv - - src/ita_register_file_1w_1r_double_width_write.sv + - src/ita_register_file_1w_1r_multiwidth.sv - src/ita_register_file_1w_multi_port_read.sv - src/ita_register_file_1w_multi_port_read_we.sv - src/ita_requantizer.sv diff --git a/src/hwpe/ita_hwpe_input_bias_buffer.sv b/src/hwpe/ita_hwpe_input_bias_buffer.sv index eaf9b2f..b415c95 100644 --- a/src/hwpe/ita_hwpe_input_bias_buffer.sv +++ b/src/hwpe/ita_hwpe_input_bias_buffer.sv @@ -17,14 +17,19 @@ module ita_hwpe_input_bias_buffer #( hwpe_stream_intf_stream.source data_o ); + localparam int unsigned STRB_WIDTH = OUTPUT_DATA_WIDTH / 8; + localparam int unsigned WDATA_WIDTH = (N == 8) ? 4 * OUTPUT_DATA_WIDTH : + (N == 16) ? 2 * OUTPUT_DATA_WIDTH : + OUTPUT_DATA_WIDTH; + typedef enum logic { Read, Write } bias_state_t; bias_state_t state_d, state_q; - logic [8-1:0] read_cnt_d, read_cnt_q; - logic [4-1:0] write_cnt_d, write_cnt_q; + logic [$clog2(M*M/N)-1:0] read_cnt_d, read_cnt_q; + logic write_cnt_d, write_cnt_q; - logic [1:0] read_addr; + logic [1 + $clog2(WDATA_WIDTH/OUTPUT_DATA_WIDTH)-1:0] read_addr; logic write_addr; logic read_enable, read_enable_q; logic write_enable; @@ -43,7 +48,7 @@ module ita_hwpe_input_bias_buffer #( write_enable = 0; data_i.ready = 0; data_o.valid = 0; - data_o.strb = 48'hFFFFFFFFFFFF; + data_o.strb = {STRB_WIDTH{1'b1}}; data_o.data = '0; bias_reshape = '0; @@ -64,22 +69,22 @@ module ita_hwpe_input_bias_buffer #( if (read_enable_q) begin data_o.data = read_data; if (bias_dir_i) begin - bias_reshape = read_data >> read_cnt_q[3:0] * 24; + bias_reshape = read_data >> read_cnt_q[$clog2(N)-1:0] * 24; data_o.data = {N {bias_reshape[0]}}; end end read_enable = 1; if(data_o.valid && data_o.ready) begin read_cnt_d = read_cnt_q + 1; - if(read_cnt_q == 255) begin + if(read_cnt_q == M*M/N-1) begin state_d = Write; read_cnt_d = 0; end end if (bias_dir_i) begin - read_addr = read_cnt_d[5:4]; + read_addr = read_cnt_d[$clog2(M)-1:$clog2(N)]; end else begin - read_addr = read_cnt_d[7:6]; + read_addr = read_cnt_d[$clog2(M*M/N)-1:$clog2(M)]; end end endcase @@ -99,9 +104,9 @@ module ita_hwpe_input_bias_buffer #( end end - ita_register_file_1w_1r_double_width_write #( + ita_register_file_1w_1r_multiwidth #( .WADDR_WIDTH(1), - .WDATA_WIDTH(2*OUTPUT_DATA_WIDTH), + .WDATA_WIDTH(WDATA_WIDTH), .RDATA_WIDTH(OUTPUT_DATA_WIDTH ) ) i_register_file ( .clk (clk_i), @@ -111,7 +116,7 @@ module ita_hwpe_input_bias_buffer #( .ReadData (read_data), .WriteEnable (write_enable), .WriteAddr (write_addr), - .WriteData (data_i.data[2*OUTPUT_DATA_WIDTH-1:0]) + .WriteData (data_i.data[WDATA_WIDTH-1:0]) ); endmodule diff --git a/src/ita_register_file_1w_1r_double_width_write.sv b/src/ita_register_file_1w_1r_double_width_write.sv deleted file mode 100644 index 8e573f6..0000000 --- a/src/ita_register_file_1w_1r_double_width_write.sv +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2014 ETH Zurich and University of Bologna. -// Solderpad Hardware License, Version 0.51, see LICENSE for details. -// SPDX-License-Identifier: SHL-0.51 - -module ita_register_file_1w_1r_double_width_write -#( - parameter WADDR_WIDTH = 2, - parameter WDATA_WIDTH = 768, - - parameter RDATA_WIDTH = 384, - parameter RADDR_WIDTH = WADDR_WIDTH+$clog2(WDATA_WIDTH/RDATA_WIDTH), - - parameter W_N_ROWS = 2**WADDR_WIDTH -) -( - input logic clk, - input logic rst_n, - - // Read port - input logic ReadEnable, - input logic [RADDR_WIDTH-1:0] ReadAddr, - output logic [RDATA_WIDTH-1:0] ReadData, - - // Write port - input logic WriteEnable, - input logic [WADDR_WIDTH-1:0] WriteAddr, - input logic [WDATA_WIDTH-1:0] WriteData -); - - -logic [RDATA_WIDTH-1:0] ReadData_lo; -logic [RDATA_WIDTH-1:0] ReadData_hi; - -logic DEST; - -int unsigned j; -genvar i; - - - - -always_ff @(posedge clk or negedge rst_n) -begin - if(~rst_n) - begin - DEST <= 0; - end - else - begin - DEST <= ReadAddr[0]; - end -end - - - -generate - - - assign ReadData = (DEST == 1'b0) ? ReadData_lo : ReadData_hi; - - - - if(W_N_ROWS == 1) - begin - register_file_1r_1w_1row - #( - .DATA_WIDTH ( RDATA_WIDTH ) - ) - bram_cut_lo - ( - .clk ( clk ), - - .ReadEnable ( ReadEnable ), - .ReadData ( ReadData_lo ), - - .WriteEnable ( WriteEnable ), - .WriteData ( WriteData[RDATA_WIDTH-1:0] ) - ); - - register_file_1r_1w_1row - #( - .DATA_WIDTH ( RDATA_WIDTH ) - ) - bram_cut_hi - ( - .clk ( clk ), - - .ReadEnable ( ReadEnable ), - .ReadData ( ReadData_hi ), - - .WriteEnable ( WriteEnable ), - .WriteData ( WriteData[2*RDATA_WIDTH-1:RDATA_WIDTH]) - ); - end - else - begin - register_file_1r_1w - #( - .ADDR_WIDTH ( WADDR_WIDTH ), - .DATA_WIDTH ( RDATA_WIDTH ) - ) - bram_cut_lo - ( - .clk ( clk ), - - .ReadEnable ( ReadEnable ), - .ReadAddr ( ReadAddr[RADDR_WIDTH-1:1] ), - .ReadData ( ReadData_lo ), - - .WriteAddr ( WriteAddr ), - .WriteEnable ( WriteEnable ), - .WriteData ( WriteData[RDATA_WIDTH-1:0]) - ); - - register_file_1r_1w - #( - .ADDR_WIDTH ( WADDR_WIDTH ), - .DATA_WIDTH ( RDATA_WIDTH ) - ) - bram_cut_hi - ( - .clk ( clk ), - - .ReadEnable ( ReadEnable ), - .ReadAddr ( ReadAddr[RADDR_WIDTH-1:1] ), - .ReadData ( ReadData_hi ), - - .WriteAddr ( WriteAddr ), - .WriteEnable ( WriteEnable ), - .WriteData ( WriteData[2*RDATA_WIDTH-1:RDATA_WIDTH]) - ); - end - -endgenerate - - - - -endmodule \ No newline at end of file diff --git a/src/ita_register_file_1w_1r_multiwidth.sv b/src/ita_register_file_1w_1r_multiwidth.sv new file mode 100644 index 0000000..84bdfc4 --- /dev/null +++ b/src/ita_register_file_1w_1r_multiwidth.sv @@ -0,0 +1,77 @@ +// Copyright 2025 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +module ita_register_file_1w_1r_multiwidth +#( + parameter WADDR_WIDTH = 2, + parameter WDATA_WIDTH = 768, + parameter RDATA_WIDTH = 384, + + localparam NUM_CUTS = WDATA_WIDTH / RDATA_WIDTH, + localparam RADDR_WIDTH = WADDR_WIDTH + $clog2(NUM_CUTS), + parameter W_N_ROWS = 2**WADDR_WIDTH +) +( + input logic clk, + input logic rst_n, + + input logic ReadEnable, + input logic [RADDR_WIDTH-1:0] ReadAddr, + output logic [RDATA_WIDTH-1:0] ReadData, + + input logic WriteEnable, + input logic [WADDR_WIDTH-1:0] WriteAddr, + input logic [WDATA_WIDTH-1:0] WriteData +); + + logic [$clog2(NUM_CUTS)-1:0] read_bank_sel_d, read_bank_sel_q; + logic [RDATA_WIDTH-1:0] ReadData_array [NUM_CUTS]; + + assign read_bank_sel_d = ReadAddr[$clog2(NUM_CUTS)-1:0]; + + always_ff @(posedge clk or negedge rst_n) + begin + if(~rst_n) + begin + read_bank_sel_q <= '0; + end + else + begin + read_bank_sel_q <= read_bank_sel_d; + end + end + + assign ReadData = ReadData_array[read_bank_sel_q]; + + genvar i; + generate + for (i = 0; i < NUM_CUTS; i++) begin : GEN_BANKS + if (W_N_ROWS == 1) begin + register_file_1r_1w_1row #( + .DATA_WIDTH(RDATA_WIDTH) + ) bank ( + .clk ( clk ), + .ReadEnable ( ReadEnable && (read_bank_sel_d == i) ), + .ReadData ( ReadData_array[i] ), + .WriteEnable ( WriteEnable ), + .WriteData ( WriteData[(i+1)*RDATA_WIDTH-1 : i*RDATA_WIDTH] ) + ); + end else begin + register_file_1r_1w #( + .ADDR_WIDTH(WADDR_WIDTH), + .DATA_WIDTH(RDATA_WIDTH) + ) bank ( + .clk ( clk ), + .ReadEnable ( ReadEnable && (read_bank_sel_d == i) ), + .ReadAddr ( ReadAddr[RADDR_WIDTH-1:$clog2(NUM_CUTS)] ), + .ReadData ( ReadData_array[i] ), + .WriteAddr ( WriteAddr ), + .WriteEnable ( WriteEnable ), + .WriteData ( WriteData[(i+1)*RDATA_WIDTH-1 : i*RDATA_WIDTH] ) + ); + end + end + endgenerate + +endmodule \ No newline at end of file