Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ class IpDescType(TypedDict):
"common/rtl/unpacked_register_slice.sv",
"common/rtl/split2.sv",
"common/rtl/join2.sv",
"common/rtl/join_n.sv",
"common/rtl/register_slice.sv",
"memory/rtl/fifo.sv",
"memory/rtl/blk_mem_gen_0.sv",
"memory/rtl/simple_dual_port_ram.sv",
"memory/rtl/unpacked_skid_buffer.sv",
"memory/rtl/skid_buffer.sv",
"memory/rtl/ultraram_fifo.sv",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ module mxint_accumulator #(
input logic rst,

// Input Data
input logic signed [DATA_IN_0_PRECISION_0-1:0] mdata_in_0 [BLOCK_SIZE - 1:0],
input logic [DATA_IN_0_PRECISION_1-1:0] edata_in_0,
input logic data_in_0_valid,
output logic data_in_0_ready,
input logic [DATA_IN_0_PRECISION_0-1:0] mdata_in_0 [BLOCK_SIZE - 1:0],
input logic [DATA_IN_0_PRECISION_1-1:0] edata_in_0,
input logic data_in_0_valid,
output logic data_in_0_ready,

// Output Data
output logic signed [DATA_OUT_0_PRECISION_0-1:0] mdata_out_0 [BLOCK_SIZE - 1:0],
output logic [DATA_OUT_0_PRECISION_1-1:0] edata_out_0,
output logic data_out_0_valid,
input logic data_out_0_ready,
output logic [ COUNTER_WIDTH:0] accum_count
output logic [DATA_OUT_0_PRECISION_0-1:0] mdata_out_0 [BLOCK_SIZE - 1:0],
output logic [DATA_OUT_0_PRECISION_1-1:0] edata_out_0,
output logic data_out_0_valid,
input logic data_out_0_ready,
output logic [ COUNTER_WIDTH:0] accum_count
);

localparam RIGHT_PADDING = 2 ** DATA_IN_0_PRECISION_1;
Expand Down Expand Up @@ -98,7 +98,7 @@ module mxint_accumulator #(
shifted_mdata_out_0[i] = mdata_out_0[i];
end else begin
shifted_mdata_in_0[i] = padded_mdata_in_0[i];
shifted_mdata_out_0[i] = mdata_out_0[i] >>> -shift;
shifted_mdata_out_0[i] = $signed(mdata_out_0[i]) >>> -shift;
end

end
Expand Down
29 changes: 19 additions & 10 deletions src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ module mxint_cast #(
input logic rst,

// Input Data
input logic signed [IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0],
input logic [IN_EXP_WIDTH-1:0] edata_in,
input logic data_in_valid,
output logic data_in_ready,
input logic [IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0],
input logic [IN_EXP_WIDTH-1:0] edata_in,
input logic data_in_valid,
output logic data_in_ready,

// Output Data
output logic signed [OUT_MAN_WIDTH-1:0] mdata_out [BLOCK_SIZE-1:0],
output logic [OUT_EXP_WIDTH-1:0] edata_out,
output logic data_out_valid,
input logic data_out_ready
output logic [OUT_MAN_WIDTH-1:0] mdata_out [BLOCK_SIZE-1:0],
output logic [OUT_EXP_WIDTH-1:0] edata_out,
output logic data_out_valid,
input logic data_out_ready
);

// =============================
Expand All @@ -50,6 +50,7 @@ module mxint_cast #(

logic data_for_max_valid, data_for_max_ready, data_for_out_valid, data_for_out_ready;
logic signed [IN_MAN_WIDTH-1:0] mbuffer_data_for_out[BLOCK_SIZE-1:0];
logic [IN_MAN_WIDTH-1:0] fifo_out[BLOCK_SIZE-1:0];
logic [IN_EXP_WIDTH-1:0] ebuffer_data_for_out;
logic buffer_data_for_out_valid, buffer_data_for_out_ready;

Expand Down Expand Up @@ -103,7 +104,9 @@ module mxint_cast #(
if (FIFO_DEPTH == 0) begin

always_comb begin
mbuffer_data_for_out = mdata_in;
for (int i = 0; i < BLOCK_SIZE; i++) begin
mbuffer_data_for_out[i] = $signed(mdata_in[i]);
end
ebuffer_data_for_out = edata_in;
buffer_data_for_out_valid = data_for_out_valid;
data_for_out_ready = buffer_data_for_out_ready;
Expand All @@ -123,12 +126,18 @@ module mxint_cast #(
.edata_in(edata_in),
.data_in_valid(data_for_out_valid),
.data_in_ready(data_for_out_ready),
.mdata_out(mbuffer_data_for_out),
.mdata_out(fifo_out),
.edata_out(ebuffer_data_for_out),
.data_out_valid(buffer_data_for_out_valid),
.data_out_ready(buffer_data_for_out_ready)
);

always_comb begin
for (int i = 0; i < BLOCK_SIZE; i++) begin
mbuffer_data_for_out[i] = $signed(fifo_out[i]);
end
end

end

// =============================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ module mxint_linear #(
assign acc_data_out_ready = cast_data_out_0_ready;
assign cast_data_out_0_valid = acc_data_out_valid;
assign cast_edata_out_0 = acc_edata_out;
assign bias_ready = 1;

mxint_cast #(
.IN_MAN_WIDTH(LOSSLESS_OUT_WIDTH),
Expand Down
20 changes: 20 additions & 0 deletions test/passes/graph/transforms/verilog/generate.tcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Batch synthesis script for one Optuna trial.
# config.tcl is written per trial by writeTrialNumber() and defines
# trial_number, top_dir (emitted RTL root) and mase_dir (output root).
source config.tcl

# Target device: Kintex UltraScale+ on the KCU116 evaluation board.
create_project -in_memory -part xcku5p-ffvb676-2-e
set_property board_part xilinx.com:kcu116:part0:1.5 [current_project]

# Pull in all RTL emitted by the MASE verilog passes.
add_files -fileset sources_1 "$top_dir/hardware/rtl/"

set_property top top [current_fileset]

puts "Trial: ${trial_number}"

# Out-of-context synthesis so no I/O buffers are inferred for the block.
eval "synth_design -mode out_of_context -top top -part xcku5p-ffvb676-2-e"

# NOTE(review): the design is synthesized twice — once in-memory above and
# again via the synth_1 run below. Confirm whether the first synth_design is
# needed, or whether launch_runs alone would suffice.
save_project_as -force my_project

launch_runs synth_1 -jobs 12
wait_on_run synth_1

# Utilization report consumed by extract_site_type_used_util() in the
# Python optimization driver.
open_run synth_1
report_utilization -file "$mase_dir/resources/util_${trial_number}.txt"
38 changes: 21 additions & 17 deletions test/passes/graph/transforms/verilog/test_emit_verilog_mxint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def forward(self, x):
shared_emit_verilog_mxint(linear, input_shape, params)


def shared_emit_verilog_mxint(model, input_shape, params: dict):
def shared_emit_verilog_mxint(model, input_shape, params: dict, sim: bool = True):
# Set seeds
torch.manual_seed(params["seed"])
random.seed(params["seed"])
Expand Down Expand Up @@ -176,26 +176,30 @@ def shared_emit_verilog_mxint(model, input_shape, params: dict):
mg, _ = passes.emit_verilog_top_transform_pass(mg)
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
mg,
pass_args={
"wait_time": 10 * block_size * batch_parallelism * num_batches,
"wait_unit": "us",
"num_batches": num_batches,
},
)

simulate(
skip_build=False,
skip_test=False,
simulator="verilator",
waves=True,
)
if sim:
mg, _ = passes.emit_cocotb_transform_pass(
mg,
pass_args={
"wait_time": 10 * block_size * batch_parallelism * num_batches,
"wait_unit": "us",
"num_batches": num_batches,
},
)

simulate(
skip_build=False,
skip_test=False,
simulator="verilator",
waves=True,
)

logger.info(
f"{block_size=}, {batch_parallelism=}, {m_width=}, {e_width=}, {batches=}"
)

return model, mg.model


if __name__ == "__main__":
seed = os.getenv("COCOTB_SEED")
Expand All @@ -205,6 +209,6 @@ def shared_emit_verilog_mxint(model, input_shape, params: dict):
else:
seed = int(seed)
logger.info(f"Using provided {seed=}")
test_emit_verilog_mxint_linear(seed)
# test_emit_verilog_mxint_mlp(seed)
# test_emit_verilog_mxint_linear(seed)
test_emit_verilog_mxint_mlp(seed)
logger.info(f"{seed=}")
189 changes: 189 additions & 0 deletions test/passes/graph/transforms/verilog/test_synthesize_mxint_vivado.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import os, re, random
import optuna
from optuna import study
from optuna.samplers import TPESampler, GridSampler
import json
from chop.tools.logger import set_logging_verbosity
from chop.tools import get_logger
from test_emit_verilog_mxint import MLP, shared_emit_verilog_mxint
import os, sys, logging, traceback, pdb
import pytest
import toml
import torch
import torch.nn as nn
import chop as chop
import chop.passes as passes
from pathlib import Path
from chop.actions import simulate
from chop.passes.graph.analysis.report.report_node import report_node_type_analysis_pass
from chop.tools.logger import set_logging_verbosity
from chop.tools import get_logger

# Path of the TCL config consumed by generate.tcl; rewritten for every trial
# by writeTrialNumber().
config_file = "config.tcl"

set_logging_verbosity("debug")

logger = get_logger(__name__)


def dump_param(trial_number, quan_args, filename="output.json"):
    """Record the full parameter dict for one trial in the JSON results file.

    Loads the existing file (treating a missing or corrupt file as empty),
    stores *quan_args* under the stringified trial number, and writes the
    merged mapping back with 4-space indentation.
    """
    results_path = Path(filename)
    try:
        results = json.loads(results_path.read_text())
    except (FileNotFoundError, json.JSONDecodeError):
        # First trial, or a previous run left a truncated file: start fresh.
        results = {}

    results[str(trial_number)] = quan_args

    results_path.write_text(json.dumps(results, indent=4))


def write_value(trial_number, name, value, filename="output.json"):
    """Merge one named metric into a trial's record in the JSON results file.

    Loads the existing file (a missing or corrupt file counts as empty),
    sets ``data[str(trial_number)][name] = value`` — creating the per-trial
    dict if needed — and writes the result back with 4-space indentation.
    """
    try:
        with open(filename, "r") as file:
            data = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
        data = {}

    # setdefault replaces the manual `if key in data.keys()` branch:
    # one lookup, same behavior for both new and existing trials.
    data.setdefault(str(trial_number), {})[name] = value

    with open(filename, "w") as file:
        json.dump(data, file, indent=4)


def get_params(trial):
    """Sample MXInt quantization hyper-parameters for one Optuna trial and
    emit the corresponding Verilog, skipping RTL simulation.

    Block size and batch parallelism are sampled as powers of two (2..16);
    the exponent width is constrained below the mantissa width.

    Returns (params, mg, mlp) — the sampled parameter dict plus the two
    values returned by shared_emit_verilog_mxint.
    """
    block_size = 2 ** trial.suggest_int("block_size", 1, 4)
    batch_parallelism = 2 ** trial.suggest_int("batch_parallelism", 1, 4)
    mlp_depth = 3
    mlp_features = [128 for _ in range(mlp_depth + 1)]

    params = {
        "seed": trial.number,
        "block_size": block_size,
        "batch_parallelism": batch_parallelism,
        "m_width": (m_width := trial.suggest_int("m_width", 4, 10)),
        # Exponent must stay narrower than the mantissa.
        "e_width": trial.suggest_int("e_width", 3, min(m_width - 1, 10)),
        "batches": 128,
        "num_batches": 10,
    }

    mlp = MLP(mlp_features)
    input_shape = (mlp_features[0],)

    logger.info(
        f"{block_size=}, {batch_parallelism=}, {params['e_width']=}, {params['m_width']=}, {params['batches']=}"
    )

    # BUG FIX: the keyword parameter of shared_emit_verilog_mxint is `sim`,
    # not `simulate` — the old call raised TypeError. Simulation is skipped
    # here because only synthesis results are needed for this sweep.
    # NOTE(review): shared_emit_verilog_mxint returns (model, mg.model), so
    # the names `mg`/`mlp` below may be swapped — confirm against its callers.
    mg, mlp = shared_emit_verilog_mxint(mlp, input_shape, params, sim=False)

    return params, mg, mlp


def writeTrialNumber(trial_number):
    """Write the per-trial TCL config consumed by generate.tcl.

    Defines `trial_number`, the emitted-RTL directory (`top_dir`) and the
    report output root (`mase_dir`) as TCL variables.
    """
    content = (
        f"set trial_number {trial_number}\n"
        f"set top_dir {Path.home()}/.mase/top/\n"
        f"set mase_dir {Path.cwd()}/"
    )
    with open(config_file, "w") as cfg:
        cfg.write(content)


def extract_site_type_used_util(filename):
    """Parse a Vivado utilization report into a dict of per-site-type stats.

    Matches table rows of the form ``| <site type> | <used> | ... | <util%> |``
    and returns ``{site_type: {"Used": int, "Util%": float}}``.
    """
    row_re = re.compile(r"\|\s*([^|]+?)\s*\|\s*(\d+)\s*\|.*?\|\s*(\d+\.\d+)\s*\|")
    site_data = {}

    with open(filename, "r") as report:
        for row in report:
            hit = row_re.match(row)
            if hit is None:
                continue
            site_type, used, util = hit.groups()
            site_data[site_type.strip()] = {
                "Used": int(used.strip()),
                "Util%": float(util.strip()),
            }

    return site_data


def get_bram_uram_util(filename):
    """Return BRAM and URAM utilization percentages from a Vivado report.

    A site type absent from the report contributes 0.0.
    """
    report = extract_site_type_used_util(filename)

    def util_of(site_type):
        return report.get(site_type, {}).get("Util%", 0.0)

    return {"bram": util_of("Block RAM Tile"), "uram": util_of("URAM")}


def getResources(trial):
    """Objective 1: synthesize this trial's design and score its resource use.

    Emits Verilog for the sampled parameters, runs Vivado in batch mode via
    generate.tcl, and returns the sum of the LUT, register, CARRY8, BRAM and
    URAM utilization percentages from the resulting report.

    Side effects: writes config.tcl and output.json, launches Vivado.
    """
    params, mg, mlp = get_params(trial)
    dump_param(trial.number, params)
    writeTrialNumber(trial.number)
    os.system(
        f"vivado -mode batch -nolog -nojou -source {Path.cwd()}/test/passes/graph/transforms/verilog/generate.tcl"
    )

    # Parse the utilization report once and derive every term from it
    # (previously the file was read and regex-parsed twice).
    report = f"{Path.cwd()}/resources/util_{trial.number}.txt"
    site_util = extract_site_type_used_util(report)

    out = (
        site_util["CLB LUTs*"]["Util%"]
        + site_util["CLB Registers"]["Util%"]
        + site_util["CARRY8"]["Util%"]
        # BRAM/URAM rows may be absent when the design uses none.
        + site_util.get("Block RAM Tile", {}).get("Util%", 0.0)
        + site_util.get("URAM", {}).get("Util%", 0.0)
    )

    write_value(trial.number, "resource_score", out)
    return out


def getAccuracy(trial):
    """Objective 2: average MSE between the quantized model and the float MLP
    over 100 random input batches.

    NOTE(review): this calls get_params(trial) again, re-running the whole
    Verilog emission a second time per trial — confirm this is intended.
    """
    params, mg, mlp = get_params(trial)
    # NOTE(review): shared_emit_verilog_mxint returns (model, mg.model), so
    # `mg` here is presumably the original model and `mlp` the quantized
    # graph module — the naming (and `mg.model` below) looks swapped; verify.
    quantized = mg.model

    criterion = nn.MSELoss()
    total_mse = 0.0

    for _ in range(100):
        # Fresh random batch each iteration; in_features taken from the
        # first layer of the model.
        x = torch.randn(params["batches"], mg.model[0].in_features)
        y1 = quantized(x)
        y2 = mlp(x)
        mse = criterion(y1, y2)
        total_mse += mse.item()

    avg_mse = total_mse / 100

    write_value(trial.number, "avg_mse", avg_mse)
    return avg_mse


def main():
    """Run a multi-objective Optuna sweep (10 trials, 24h timeout) that
    minimizes both the Vivado resource score and the quantization MSE."""
    study = optuna.create_study(
        directions=["minimize", "minimize"],
        study_name="resource_accuracy_optimiser",
        sampler=TPESampler(),
    )

    def objectives(trial):
        # Two objective values per trial: resource utilization, then MSE.
        return getResources(trial), getAccuracy(trial)

    study.optimize(
        objectives,
        n_trials=10,
        timeout=60 * 60 * 24,
        n_jobs=1,
    )

    print("Best trials:")
    for trial in study.best_trials:
        print(f"Trial {trial.number}: {trial.values}")


if __name__ == "__main__":

    # Ensure the directory Vivado writes utilization reports into exists.
    # Replaces a bare `except: pass` that also hid permission errors;
    # exist_ok covers the only expected failure (directory already present).
    os.makedirs(f"{Path.cwd()}/resources/", exist_ok=True)

    main()