Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Bender.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ sources:
- src/ita_inp2_mux.sv
- src/ita_input_sampler.sv
- src/ita_output_controller.sv
- src/ita_register_file_1w_1r_double_width_write.sv
- src/ita_register_file_1w_1r_multiwidth.sv
- src/ita_register_file_1w_multi_port_read.sv
- src/ita_register_file_1w_multi_port_read_we.sv
- src/ita_requantizer.sv
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ else ifeq ($(activation), relu)
else
activation_int = 0
endif
vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)
ITA_N ?= 16
vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) -DITA_N=$(ITA_N)

ifeq ($(target), sim_ita_hwpe_tb)
BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test
Expand Down
8 changes: 5 additions & 3 deletions PyITA/ITA.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Transformer:
WI = 8

def __init__(self,
ITA_N: int,
S: int,
P: int,
E: int,
Expand All @@ -60,7 +61,7 @@ def __init__(self,
Bff: ArrayLike = None,
Bff2: ArrayLike = None):

self.ITA_N = 16
self.ITA_N = ITA_N
self.ITA_M = 64

# WIESEP: Set numpy print options
Expand Down Expand Up @@ -546,7 +547,7 @@ def soft(self, no_partial_softmax = False):
if no_partial_softmax:
self.A_partial_softmax = fastSoftmax(self.A_requant)
else:
self.A_partial_softmax = streamingPartialSoftmax(self.A_requant)
self.A_partial_softmax = streamingPartialSoftmax(self.ITA_N, self.A_requant)

if self.H == 1:
A_save = [np.tile(self.A_partial_softmax[i], [self.split, 1]) for i in range(self.H)]
Expand Down Expand Up @@ -974,8 +975,9 @@ def generateTestVectors(path, **kwargs):
bias = int(not kwargs['no_bias'])
export_snitch_cluster = kwargs['export_snitch_cluster']
export_mempool = kwargs['export_mempool']
ita_n = kwargs['ITA_N']

acc1 = Transformer(s, p, e, f, h, bias = bias, path = path, activation = activation)
acc1 = Transformer(ita_n, s, p, e, f, h, bias = bias, path = path, activation = activation)

if kwargs['verbose']:
print("=> Generating test vectors...")
Expand Down
3 changes: 1 addition & 2 deletions PyITA/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ def fastSoftmax(x, integerize = True):
return np.repeat(exp_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift


def streamingPartialSoftmax(x, integerize = True):
def streamingPartialSoftmax(width, x, integerize = True):
if not integerize:
x = x.astype(np.float32)

seq_length = x.shape[-1]
n_heads = x.shape[-3]
width = 16 # 16 PE (processing units)
groups = seq_length // width

assert seq_length % width == 0, f"Sequence length must be a multiple of width ({width})"
Expand Down
5 changes: 3 additions & 2 deletions src/hwpe/ita_hwpe_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ module ita_hwpe_ctrl
);

localparam int unsigned TOT_LEN = M*M/N;
localparam int unsigned WEIGHT_DIV = (ITA_TCDM_DW / 8) / N;

logic slave_clear, stream_clear;
ctrl_slave_t slave_ctrl;
Expand Down Expand Up @@ -146,15 +147,15 @@ module ita_hwpe_ctrl
restart_weight_d = 1'b1;
ctrl_streamer_o.bias_source_ctrl.req_start = !ctrl_stream_o.bias_disable;
ctrl_streamer_o.output_sink_ctrl.req_start = 1'b1 & !ctrl_stream_o.output_disable;
weight_len_d = (TOT_LEN / 8) - (!ctrl_stream_o.weight_preload * M / 8);
weight_len_d = (TOT_LEN / WEIGHT_DIV) - (!ctrl_stream_o.weight_preload * M / WEIGHT_DIV);
weight_base_addr_d = weight_addr[0] + !ctrl_stream_o.weight_preload * N * M;
end
end
NextLoad: begin
if (flags_streamer_i.weight_source_flags.done) begin
state_d = Done;
restart_weight_d = 1'b1;
weight_len_d = M / 8;
weight_len_d = M / WEIGHT_DIV;
weight_base_addr_d = weight_addr[1];
end
end
Expand Down
27 changes: 16 additions & 11 deletions src/hwpe/ita_hwpe_input_bias_buffer.sv
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@ module ita_hwpe_input_bias_buffer #(
hwpe_stream_intf_stream.source data_o
);

localparam int unsigned STRB_WIDTH = OUTPUT_DATA_WIDTH / 8;
localparam int unsigned WDATA_WIDTH = (N == 8) ? 4 * OUTPUT_DATA_WIDTH :
(N == 16) ? 2 * OUTPUT_DATA_WIDTH :
OUTPUT_DATA_WIDTH;

typedef enum logic { Read, Write } bias_state_t;

bias_state_t state_d, state_q;

logic [8-1:0] read_cnt_d, read_cnt_q;
logic [4-1:0] write_cnt_d, write_cnt_q;
logic [$clog2(M*M/N)-1:0] read_cnt_d, read_cnt_q;
logic write_cnt_d, write_cnt_q;

logic [1:0] read_addr;
logic [1 + $clog2(WDATA_WIDTH/OUTPUT_DATA_WIDTH)-1:0] read_addr;
logic write_addr;
logic read_enable, read_enable_q;
logic write_enable;
Expand All @@ -43,7 +48,7 @@ module ita_hwpe_input_bias_buffer #(
write_enable = 0;
data_i.ready = 0;
data_o.valid = 0;
data_o.strb = 48'hFFFFFFFFFFFF;
data_o.strb = {STRB_WIDTH{1'b1}};
data_o.data = '0;
bias_reshape = '0;

Expand All @@ -64,22 +69,22 @@ module ita_hwpe_input_bias_buffer #(
if (read_enable_q) begin
data_o.data = read_data;
if (bias_dir_i) begin
bias_reshape = read_data >> read_cnt_q[3:0] * 24;
bias_reshape = read_data >> read_cnt_q[$clog2(N)-1:0] * 24;
data_o.data = {N {bias_reshape[0]}};
end
end
read_enable = 1;
if(data_o.valid && data_o.ready) begin
read_cnt_d = read_cnt_q + 1;
if(read_cnt_q == 255) begin
if(read_cnt_q == M*M/N-1) begin
state_d = Write;
read_cnt_d = 0;
end
end
if (bias_dir_i) begin
read_addr = read_cnt_d[5:4];
read_addr = read_cnt_d[$clog2(M)-1:$clog2(N)];
end else begin
read_addr = read_cnt_d[7:6];
read_addr = read_cnt_d[$clog2(M*M/N)-1:$clog2(M)];
end
end
endcase
Expand All @@ -99,9 +104,9 @@ module ita_hwpe_input_bias_buffer #(
end
end

ita_register_file_1w_1r_double_width_write #(
ita_register_file_1w_1r_multiwidth #(
.WADDR_WIDTH(1),
.WDATA_WIDTH(2*OUTPUT_DATA_WIDTH),
.WDATA_WIDTH(WDATA_WIDTH),
.RDATA_WIDTH(OUTPUT_DATA_WIDTH )
) i_register_file (
.clk (clk_i),
Expand All @@ -111,7 +116,7 @@ module ita_hwpe_input_bias_buffer #(
.ReadData (read_data),
.WriteEnable (write_enable),
.WriteAddr (write_addr),
.WriteData (data_i.data[2*OUTPUT_DATA_WIDTH-1:0])
.WriteData (data_i.data[WDATA_WIDTH-1:0])
);

endmodule
69 changes: 29 additions & 40 deletions src/ita_max_finder.sv
Original file line number Diff line number Diff line change
@@ -1,61 +1,50 @@
// Copyright 2020 ETH Zurich and University of Bologna.
// Copyright 2025 ETH Zurich and University of Bologna.
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51

module ita_max_finder
import ita_package::*;
(
input logic clk_i ,
input logic rst_ni ,
// Input
input requant_oup_t x_i ,
input logic clk_i,
input logic rst_ni,
input requant_oup_t x_i,
input requant_t prev_max_i,
output requant_t max_o ,
output requant_t max_o,
output requant_t max_diff_o
);

// find maximum
requant_t [N/2-1:0] max_tmp;
requant_t [N/4-1:0] max_tmp2;
requant_t [N/8-1:0] max_tmp3;
requant_t max_tmp4;

always_comb begin

for (int i = 0; i < N/2; i++) begin
if (x_i[2*i]>x_i[2*i+1])
max_tmp[i] = x_i[2*i];
else
max_tmp[i] = x_i[2*i+1];
function automatic requant_t reduce_max(input requant_oup_t vec);
requant_oup_t stage;
int size = N;
int idx;
begin
stage = vec;
while (size > 1) begin
for (int i = 0; i < size/2; i++) begin
if (stage[2*i] > stage[2*i+1])
stage[i] = stage[2*i];
else
stage[i] = stage[2*i+1];
end
size = size / 2;
end
return stage[0];
end
endfunction

for (int i = 0; i < N/4; i++) begin
if (max_tmp[2*i]>max_tmp[2*i+1])
max_tmp2[i] = max_tmp[2*i];
else
max_tmp2[i] = max_tmp[2*i+1];
end
requant_t max_val;

for (int i = 0; i < N/8; i++) begin
if (max_tmp2[2*i]>max_tmp2[2*i+1])
max_tmp3[i] = max_tmp2[2*i];
else
max_tmp3[i] = max_tmp2[2*i+1];
end

if (max_tmp3[0]>max_tmp3[1])
max_tmp4 = max_tmp3[0];
else
max_tmp4 = max_tmp3[1];
always_comb begin
max_val = reduce_max(x_i);

if (prev_max_i>max_tmp4) begin
if (prev_max_i > max_val) begin
max_o = prev_max_i;
max_diff_o = '0;
end else begin
max_o = max_tmp4;
max_diff_o = max_o-prev_max_i;
max_o = max_val;
max_diff_o = max_o - prev_max_i;
end

end

endmodule
endmodule
2 changes: 1 addition & 1 deletion src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ package ita_package;
parameter int unsigned MNumReadPorts = N ;
parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 12 `endif;
localparam int unsigned SplitFactor = 4 ;
parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif;
parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE ( M * N / 128) `else M `endif; // ITA_TCDM_DW=1024 / 8 = 128

// Feedforward
typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2, SingleAttention=3} layer_e;
Expand Down
Loading