diff --git a/configs/examples/nerf_tpe_search.toml b/configs/examples/nerf_tpe_search.toml new file mode 100644 index 000000000..35e0d021e --- /dev/null +++ b/configs/examples/nerf_tpe_search.toml @@ -0,0 +1,153 @@ +# basics +model = "nerfvision" +dataset = "nerf-lego" + +max_epochs = 5 +batch_size = 128 +learning_rate = 1e-2 +accelerator = "gpu" +project = "nerf_lego_quantizations" +seed = 42 + +[search.search_space] +name = "graph/quantize/mixed_precision_ptq" + +[search.search_space.setup] +by = "name" + +[search.search_space.seed.mxint.config] +name = ['mxint'] + +data_in_width = [4, 8] +data_in_exponent_width = [3, 4, 5, 6, 7, 8, 9] +data_in_block_size = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_exponent_width = [3, 4, 5, 6, 7, 8, 9] +weight_block_size = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_exponent_width = [3, 4, 5, 6, 7, 8, 9] +bias_block_size = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.integer.config] +name = ["integer"] + +data_in_width = [4, 8] +data_in_frac_width = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_frac_width = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_frac_width = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.log.config] +name = ["log"] + +data_in_width = [4, 8] +data_in_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.minifloat_denorm.config] +name = ["minifloat_denorm"] + +data_in_width = [4, 8] +data_in_exponent_width = [3, 4, 5, 6, 7, 8, 9] +data_in_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_exponent_width = [3, 4, 5, 6, 7, 8, 9] +weight_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_exponent_width = [3, 4, 5, 6, 7, 8, 9] +bias_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.minifloat_ieee.config] +name = ["minifloat_ieee"] + +data_in_width = [4, 8] +data_in_exponent_width = [3, 4, 5, 6, 7, 8, 9] +data_in_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_exponent_width = [3, 4, 5, 6, 7, 8, 9] +weight_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_exponent_width = [3, 4, 5, 6, 7, 8, 9] +bias_exponent_bias = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.block_fp.config] +name = ["block_fp"] + +data_in_width = [4, 8] +data_in_exponent_width = [3, 4, 5, 6, 7, 8, 9] +data_in_exponent_bias = [3, 4, 5, 6, 7, 8, 9] +data_in_block_size = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_exponent_width = [3, 4, 5, 6, 7, 8, 9] +weight_exponent_bias = [3, 4, 5, 6, 7, 8, 9] +weight_block_size = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_exponent_width = [3, 4, 5, 6, 7, 8, 9] +bias_exponent_bias = [3, 4, 5, 6, 7, 8, 9] +bias_block_size = [3, 4, 5, 6, 7, 8, 9] + + +[search.search_space.seed.binary.config] +name = ["binary"] + +data_in_width = [4, 8] +data_in_stochastic = [3, 4, 5, 6, 7, 8, 9] +data_in_bipolar = [3, 4, 5, 6, 7, 8, 9] + +weight_width = [2, 4, 8] +weight_stochastic = [3, 4, 5, 6, 7, 8, 9] +weight_bipolar = [3, 4, 5, 6, 7, 8, 9] + +bias_width = [2, 4, 8] +bias_stochastic = [3, 4, 5, 6, 7, 8, 9] +bias_bipolar = [3, 4, 5, 6, 7, 8, 9] + + +[search.strategy] +name = "optuna" +eval_mode = true + +[search.strategy.sw_runner.basic_evaluation] +data_loader = "val_dataloader" +num_samples = 256 + +[search.strategy.hw_runner.average_bitwidth] +compare_to = 32 # compare to FP32 + +[search.strategy.setup] +n_jobs = 1 +n_trials = 20 +timeout = 20000 +sampler = "tpe" +# sum_scaled_metrics = true # single objective +# direction = "maximize" +sum_scaled_metrics = false # multi objective + +[search.strategy.metrics] +loss.scale = 1.0 +loss.direction = "minimize" +psnr.scale = 1.0 +psnr.direction = "maximize" +average_bitwidth.scale = 0.2 +average_bitwidth.direction = "minimize" diff --git a/src/chop/actions/search/search.py b/src/chop/actions/search/search.py index fe82d3574..8b1ae2c2b 100644 --- a/src/chop/actions/search/search.py +++ b/src/chop/actions/search/search.py @@ -1,5 +1,6 @@ import logging from os import PathLike +import json import toml import torch @@ -68,44 +69,56 @@ def search( # search preparation accelerator = parse_accelerator(accelerator) - strategy_config, search_space_config = parse_search_config(search_config) + strategy_config, orig_search_space_config = parse_search_config(search_config) save_path.mkdir(parents=True, exist_ok=True) - # load model if the save_name is provided - if load_name is not None and load_type in ["pl", "mz", "pt"]: - model = load_model(load_name=load_name, load_type=load_type, model=model) - logger.info(f"Loaded model from {load_name}.") - model.to(accelerator) - # set up data module - data_module.prepare_data() - data_module.setup() - - # construct the search space - logger.info("Building search space...") - search_space_cls = get_search_space_cls(search_space_config["name"]) - search_space = search_space_cls( - model=model, - model_info=model_info, - config=search_space_config, - dummy_input=get_dummy_input(model_info, data_module, task, device=accelerator), - accelerator=accelerator, - data_module=data_module, - ) - search_space.build_search_space() - - # construct a search strategy - strategy_cls = get_search_strategy_cls(strategy_config["name"]) - strategy = strategy_cls( - model_info=model_info, - task=task, - dataset_info=dataset_info, - data_module=data_module, - config=strategy_config, - accelerator=accelerator, - save_dir=save_path, - visualizer=visualizer, - ) - - logger.info("Search started...") - # perform search and save the results - strategy.search(search_space) + for quantization_name in orig_search_space_config["seed"]: + search_space_config = json.loads(json.dumps(orig_search_space_config)) + search_space_config["seed"]["default"] = orig_search_space_config["seed"][ + quantization_name + ] + + current_save_path = save_path / quantization_name + current_save_path.mkdir(parents=True, exist_ok=True) + + # load model if the save_name is provided + if load_name is not None and load_type in ["pl", "mz", "pt"]: + model = load_model(load_name=load_name, load_type=load_type, model=model) + logger.info(f"Loaded model from {load_name}.") + model.to(accelerator) + # set up data module + data_module.prepare_data() + data_module.setup() + + # construct the search space + logger.info("Building search space...") + search_space_cls = get_search_space_cls(search_space_config["name"]) + search_space = search_space_cls( + model=model, + model_info=model_info, + config=search_space_config, + dummy_input=get_dummy_input( + model_info, data_module, task, device=accelerator + ), + accelerator=accelerator, + data_module=data_module, + ) + search_space.build_search_space() + + # construct a search strategy + strategy_cls = get_search_strategy_cls(strategy_config["name"]) + strategy = strategy_cls( + model_info=model_info, + task=task, + dataset_info=dataset_info, + data_module=data_module, + config=strategy_config, + accelerator=accelerator, + save_dir=current_save_path, + visualizer=visualizer, + quantization_name=quantization_name, + ) + + logger.info(f"Search started... ...{quantization_name}") + # perform search and save the results + strategy.search(search_space) diff --git a/src/chop/actions/search/strategies/base.py b/src/chop/actions/search/strategies/base.py index 46a313b28..41110b2c4 100644 --- a/src/chop/actions/search/strategies/base.py +++ b/src/chop/actions/search/strategies/base.py @@ -41,6 +41,7 @@ def __init__( accelerator, save_dir, visualizer, + quantization_name: str = "", ): self.dataset_info = dataset_info self.task = task @@ -50,6 +51,10 @@ def __init__( self.data_module = data_module self.visualizer = visualizer + if quantization_name != "": + quantization_name = f"{quantization_name}/" + self.quantization_name = quantization_name + self.sw_runner = [] self.hw_runner = [] # the software runner's __call__ will use the rebuilt model to calculate the software metrics like accuracy, loss, ... diff --git a/src/chop/actions/search/strategies/optuna.py b/src/chop/actions/search/strategies/optuna.py index f7ba65a2a..c08d7d7f3 100644 --- a/src/chop/actions/search/strategies/optuna.py +++ b/src/chop/actions/search/strategies/optuna.py @@ -110,7 +110,12 @@ def objective(self, trial: optuna.trial.Trial, search_space): trial.set_user_attr("scaled_metrics", scaled_metrics) trial.set_user_attr("sampled_config", sampled_config) - self.visualizer.log_metrics(metrics=scaled_metrics, step=trial.number) + log_scaled_metrics = { + self.quantization_name + k: v + for k, v in scaled_metrics.items() + if v is not None + } + self.visualizer.log_metrics(metrics=log_scaled_metrics, step=trial.number) if not self.sum_scaled_metrics: return list(scaled_metrics.values()) diff --git a/src/chop/actions/search/strategies/runners/software/eval.py b/src/chop/actions/search/strategies/runners/software/eval.py index 0292ae202..ceffd3f05 100644 --- a/src/chop/actions/search/strategies/runners/software/eval.py +++ b/src/chop/actions/search/strategies/runners/software/eval.py @@ -4,6 +4,10 @@ from torchmetrics.text import Perplexity from torchmetrics import MeanMetric +from chop.tools.plt_wrapper.nerf.losses import ( + ColorLoss, + NerfPsnr, +) from .base import SWRunnerBase @@ -42,6 +46,9 @@ def _setup_metric(self): ).to(self.accelerator) case _: raise ValueError(f"task {self.task} is not supported.") + elif self.model_info.is_nerf_vision_model: + self.metric = NerfPsnr().to(self.accelerator) + self.color_loss = ColorLoss().to(self.accelerator) elif self.model_info.is_nlp_model: match self.task: case "classification" | "cls": @@ -62,6 +69,8 @@ def forward(self, batch: dict[str, torch.Tensor], model): return self.vision_cls_forward(batch, model) case _: raise ValueError(f"task {self.task} is not supported.") + elif self.model_info.is_nerf_vision_model: + return self.nerf_vision_cls_forward(batch, model) elif self.model_info.is_nlp_model: match self.task: case "classification" | "cls": @@ -79,6 +88,19 @@ def vision_cls_forward(self, batch, model): self.loss(loss) return {"loss": loss, "accuracy": acc} + def nerf_vision_cls_forward(self, batch, model): + for k, v in batch.items(): + batch[k] = v.to(self.accelerator) + if len(batch["pts"].shape) == 4: + for k in batch: + batch[k] = batch[k].squeeze(0) + logits = model(batch["pts"], batch["viewdirs"]) + loss = self.color_loss(logits, batch) + psnr = self.metric(logits, batch) + self.loss(loss) + + return {"loss": loss, "psnr": psnr} + def nlp_cls_forward(self, batch, model): batch = { k: v.to(self.accelerator) if isinstance(v, torch.Tensor) else v @@ -109,6 +131,8 @@ def compute(self) -> dict[str, float]: reduced["perplexity"] = self.metric.compute().item() elif isinstance(self.metric, MulticlassAccuracy): reduced["accuracy"] = self.metric.compute().item() + elif isinstance(self.metric, NerfPsnr): + reduced["psnr"] = self.metric.compute().item() else: raise ValueError(f"metric {self.metric} is not supported.") return reduced diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py index e56a512d8..13201ac3e 100644 --- a/src/chop/actions/simulate.py +++ b/src/chop/actions/simulate.py @@ -60,6 +60,7 @@ def simulate( "-Wno-fatal", "-Wno-lint", "-Wno-style", + "--assert", "--trace-fst", "--trace-structs", "--trace-depth", diff --git a/src/chop/cli.py b/src/chop/cli.py index 1f1b671f2..d08bf54b6 100644 --- a/src/chop/cli.py +++ b/src/chop/cli.py @@ -863,7 +863,7 @@ def _setup_model_and_dataset(self): name=self.args.model, task=self.args.task, dataset_info=dataset_info, - checkpoint=checkpoint, + checkpoint=checkpoint or self.args.model, pretrained=self.args.is_pretrained, quant_config=quant_config, ) diff --git a/src/chop/dataset/nerf/blender.py b/src/chop/dataset/nerf/blender.py index 53865d621..00dcc65d2 100644 --- a/src/chop/dataset/nerf/blender.py +++ b/src/chop/dataset/nerf/blender.py @@ -181,9 +181,76 @@ def __getitem__(self, idx): sample = {"rays": rays, "rgbs": img, "c2w": c2w, "valid_mask": valid_mask} + inputs = pre_render_vision( + sample["rays"], + is_train=self.split == "train", + ) + + for k in inputs: + sample[k] = inputs[k] + return sample +def pre_render_vision( + rays, + N_samples=8, + use_disp=False, + perturb=0, + is_train=False, +): + # Decompose the inputs + if is_train: + rays = rays.unsqueeze(0) + N_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) + + # Sample depth points + z_steps = torch.linspace(0, 1, N_samples) # (N_samples) + if not use_disp: # use linear sampling in depth space + z_vals = near * (1 - z_steps) + far * z_steps + else: # use linear sampling in disparity space + z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) + + z_vals = z_vals.expand(N_rays, N_samples) + + # Perturb sampling time along each ray. + if perturb > 0.0: + # get intervals between samples + mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.concat([mids, z_vals[..., -1:]], -1) + lower = torch.concat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand_like(z_vals.shape) + z_vals = lower + (upper - lower) * t_rand + + viewdirs = rays_d + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).to(torch.float32) + + # Points in space to evaluate model at. + pts = ( + rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] + ) # [N_rays, N_samples, 3] + + inputs = {} + + inputs["pts"] = pts + inputs["viewdirs"] = viewdirs + inputs["rays_o"] = rays_o + inputs["rays_d"] = rays_d + inputs["z_vals"] = z_vals + inputs["near"] = near + inputs["far"] = far + + if is_train: + for k in inputs: + inputs[k] = inputs[k].squeeze(0) + + return inputs + + DEFAULT_NERF_CONFIG = { "img_wh": [800, 800], "N_emb_xyz": 10, diff --git a/src/chop/models/nerf/__init__.py b/src/chop/models/nerf/__init__.py index 707b87077..7023eaf30 100644 --- a/src/chop/models/nerf/__init__.py +++ b/src/chop/models/nerf/__init__.py @@ -1,3 +1,4 @@ from .layers import Embedding, NeRF from .rendering import render_rays from .nerf import NeRFModel +from .nerf_vision import NeRFVision diff --git a/src/chop/models/nerf/nerf_vision.py b/src/chop/models/nerf/nerf_vision.py index f8591f1a4..8fa6f11be 100644 --- a/src/chop/models/nerf/nerf_vision.py +++ b/src/chop/models/nerf/nerf_vision.py @@ -1,13 +1,28 @@ +from torch._tensor import Tensor + + from torch import Tensor import torch.nn as nn import torch from typing import Any +import pytorch_lightning as pl import numpy as np import torch.nn.functional as F +from chop.models.utils import register_mase_model, register_mase_checkpoint + +from .rendering import post_render_vision, pre_render_vision -# Model -class NeRFVision(nn.Module, output_ch=4): +# Model# Model +@register_mase_model( + "nerfvision", + checkpoints=["nerfvision"], + model_source="nerfvision", + task_type="nerfvision", + physical_data_point_classification=True, + is_fx_traceable=True, +) +class NeRFVision(nn.Module): def __init__( self, D=8, @@ -16,7 +31,7 @@ def __init__( input_ch_views=3, output_ch=4, skips=[4], - use_viewdirs=False, + use_viewdirs=True, ): """ This is the Nerf model from the Nerf Paper @@ -59,11 +74,11 @@ def __init__( else: self.output_linear = nn.Linear(W, output_ch) - def forward(self, x): - input_pts, input_views = torch.split( - x, [self.input_ch, self.input_ch_views], dim=-1 - ) + def forward(self, pts, viewdirs): + raw: Tensor | Any = self.apply_layers(pts, viewdirs) + return raw + def apply_layers(self, input_pts, input_views): h = input_pts for i, l in enumerate(self.pts_linears): h = self.pts_linears[i](h) @@ -72,9 +87,13 @@ def forward(self, x): h = torch.cat([input_pts, h], -1) if self.use_viewdirs: + # Add a new dimension and expand it + expanded_input_views = input_views.unsqueeze(1).expand( + -1, h.shape[1], -1 + ) # -1 means it will retain the size of that dimension alpha = self.alpha_linear(h) feature = self.feature_linear(h) - h = torch.cat([feature, input_views], -1) + h = torch.cat([feature, expanded_input_views], -1) for i, l in enumerate(self.views_linears): h = self.views_linears[i](h) @@ -139,11 +158,28 @@ def load_weights_from_keras(self, weights): # Getters ------------------------------------------------------------------------------ def get_nerf( - info, pretrained=False, **kwargs: Any, ): - # image_size = info["image_size"] - num_classes = info.num_classes - # TODO: number of channels - return NeRFVision(output_ch=num_classes) + return NeRFVision() + + +@register_mase_checkpoint("nerfvision") +def get_nerfvision( + pretrained: bool = False, + **kwargs: Any, +) -> NeRFVision: + model = get_nerf( + pretrained=pretrained, + **kwargs, + ) + if pretrained: + weights = np.load( + "/teamspace/studios/this_studio/mase-team-coursework/mase/nerf_vision/lego_example/model_200000.npy", + allow_pickle=True, + ) + model.set_weights(weights) + else: + pretrained_weight_cls = None + + return model diff --git a/src/chop/models/nerf/rendering.py b/src/chop/models/nerf/rendering.py index c18b3c98c..2f00dec9c 100644 --- a/src/chop/models/nerf/rendering.py +++ b/src/chop/models/nerf/rendering.py @@ -240,3 +240,297 @@ def inference(results, model, typ, xyz, z_vals, test_time=False, **kwargs): ) return results + + +def render_vision( + models, + rays, + N_samples=4, + use_disp=False, + perturb=0, + noise_std=1, +): + """ + Render rays by computing the output of @model applied on @rays + Inputs: + models: list of NeRF models (coarse and fine) defined in nerf.py + embeddings: list of embedding models of origin and direction defined in nerf.py + rays: (N_rays, 3+3+2), ray origins and directions, near and far depths + N_samples: number of coarse samples per ray + use_disp: whether to sample in disparity space (inverse depth) + perturb: factor to perturb the sampling position on the ray (for coarse model only) + noise_std: factor to perturb the model's prediction of sigma + N_importance: number of fine samples per ray + chunk: the chunk size in batched inference + white_back: whether the background is white (dataset dependent) + test_time: whether it is test (inference only) or not. If True, it will not do inference + on coarse rgb to save time + Outputs: + result: dictionary containing final rgb and depth maps for coarse and fine models + """ + + def raw2outputs(raw, z_vals, rays_d): + """Transforms model's predictions to semantically meaningful values. + + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + + # Function for computing density from model prediction. This value is + # strictly between [0, 1]. + def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu): + return 1.0 - torch.exp(-act_fn(raw) * dists) + + # Compute 'distance' (in time) between each integration time along a ray. + dists = z_vals[..., 1:] - z_vals[..., :-1] + + # The 'distance' from the last integration time is infinity. + dists = torch.concat( + [ + dists, + torch.broadcast_to( + torch.tensor([1e10]).to(dists.device), dists[..., :1].shape + ), + ], + dim=-1, + ) # [N_rays, N_samples] + + # Multiply each distance by the norm of its corresponding direction ray + # to convert to real world distance (accounts for non-unit directions). + dists = dists * torch.norm(rays_d[..., None, :], dim=-1) + + # Extract RGB of each sample position along each ray. + rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] + + # Add noise to model's predictions for density. Can be used to + # regularize network during training (prevents floater artifacts). + noise = 0.0 + if noise_std > 0.0: + noise = torch.rand_like(raw[..., 3]) * noise_std + + # Predict density of each sample along each ray. Higher values imply + # higher likelihood of being absorbed at this point. + alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] + + # Compute weight for RGB of each sample along each ray. A cumprod() is + # used to express the idea of the ray not having reflected up to this + # sample yet. + # [N_rays, N_samples] + weights = alpha * torch.cumprod(1.0 - alpha + 1e-10, dim=-1) + + # Computed weighted color of each sample along each ray. + rgb_map = torch.sum(weights[..., None] * rgb, dim=-2) # [N_rays, 3] + + # Estimated depth map is expected distance. + depth_map = torch.sum(weights * z_vals, dim=-1) + + # Disparity map is inverse depth. + disp_map = 1.0 / torch.maximum( + torch.tensor([1e-10]).to(dists.device), + depth_map / torch.sum(weights, dim=-1), + ) + + # Sum of weights along each ray. This value is in [0, 1] up to numerical error. + acc_map = torch.sum(weights, dim=-1) + + return rgb_map, disp_map, acc_map, weights, depth_map + + # Decompose the inputs + N_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) + + # Sample depth points + z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) + if not use_disp: # use linear sampling in depth space + z_vals = near * (1 - z_steps) + far * z_steps + else: # use linear sampling in disparity space + z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) + + z_vals = z_vals.expand(N_rays, N_samples) + + # Perturb sampling time along each ray. + if perturb > 0.0: + # get intervals between samples + mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.concat([mids, z_vals[..., -1:]], -1) + lower = torch.concat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand_like(z_vals.shape) + z_vals = lower + (upper - lower) * t_rand + + viewdirs = rays_d + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).to(torch.float32) + + # Points in space to evaluate model at. + pts = ( + rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] + ) # [N_rays, N_samples, 3] + + # Evaluate model at each point. + raw = models.apply_layers(pts, viewdirs) # [N_rays, N_samples, 4] + rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d) + + results = {} + + results["rgb"] = rgb_map + results["depth"] = depth_map + results["weights"] = weights + results["opacity"] = acc_map + results["z_vals"] = z_vals + results["disp"] = disp_map + + return results + + +def pre_render_vision( + rays, + N_samples=64, + use_disp=False, + perturb=0, +): + # Decompose the inputs + N_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) + + # Sample depth points + z_steps = torch.linspace(0, 1, N_samples, device="cuda") # (N_samples) + if not use_disp: # use linear sampling in depth space + z_vals = near * (1 - z_steps) + far * z_steps + else: # use linear sampling in disparity space + z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) + + z_vals = z_vals.expand(N_rays, N_samples) + + # Perturb sampling time along each ray. + if perturb > 0.0: + # get intervals between samples + mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.concat([mids, z_vals[..., -1:]], -1) + lower = torch.concat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand_like(z_vals.shape) + z_vals = lower + (upper - lower) * t_rand + + viewdirs = rays_d + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).to(torch.float32) + + # Points in space to evaluate model at. + pts = ( + rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] + ) # [N_rays, N_samples, 3] + + inputs = {} + + inputs["pts"] = pts + inputs["viewdirs"] = viewdirs + inputs["rays_d"] = rays_d + inputs["z_vals"] = z_vals + + return inputs + + +def post_render_vision( + x, + raw, + noise_std=1, +): + def raw2outputs(raw, z_vals, rays_d): + """Transforms model's predictions to semantically meaningful values. + + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + + # Function for computing density from model prediction. This value is + # strictly between [0, 1]. + def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu): + return 1.0 - torch.exp(-act_fn(raw) * dists) + + # Compute 'distance' (in time) between each integration time along a ray. + dists = z_vals[..., 1:] - z_vals[..., :-1] + + # The 'distance' from the last integration time is infinity. + dists = torch.concat( + [ + dists, + torch.ones_like(rays_d[:, :1]).to("cuda") * 1e10, + ], # ISSUE expand to [N_rays, 1] + dim=-1, + ) # [N_rays, N_samples] + + # Multiply each distance by the norm of its corresponding direction ray + # to convert to real world distance (accounts for non-unit directions). + dists = dists * torch.norm(rays_d[..., None, :], dim=-1) + + # Extract RGB of each sample position along each ray. + rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] + + # Add noise to model's predictions for density. Can be used to + # regularize network during training (prevents floater artifacts). + noise = 0.0 + if noise_std > 0.0: + noise = torch.rand_like(raw[..., 3]) * noise_std + + # Predict density of each sample along each ray. Higher values imply + # higher likelihood of being absorbed at this point. + alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] + + # Compute weight for RGB of each sample along each ray. A cumprod() is + # used to express the idea of the ray not having reflected up to this + # sample yet. + # [N_rays, N_samples] + weights = alpha * torch.cumprod(1.0 - alpha + 1e-10, dim=-1) + + # Computed weighted color of each sample along each ray. + rgb_map = torch.sum(weights[..., None] * rgb, dim=-2) # [N_rays, 3] + + # Estimated depth map is expected distance. + depth_map = torch.sum(weights * z_vals, dim=-1) + + # Disparity map is inverse depth. + disp_map = 1.0 / torch.maximum( + torch.tensor([1e-10]).to(dists.device), + depth_map / torch.sum(weights, dim=-1), + ) + + # Sum of weights along each ray. This value is in [0, 1] up to numerical error. + acc_map = torch.sum(weights, dim=-1) + + return rgb_map, disp_map, acc_map, weights, depth_map + + rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs( + raw, x["z_vals"], x["rays_d"] + ) + + results = {} + + results["rgb"] = rgb_map + results["depth"] = depth_map + results["weights"] = weights + results["opacity"] = acc_map + results["z_vals"] = x["z_vals"] + results["disp"] = disp_map + + return results diff --git a/src/chop/models/utils.py b/src/chop/models/utils.py index 67fee5a95..14f8aa33b 100644 --- a/src/chop/models/utils.py +++ b/src/chop/models/utils.py @@ -21,6 +21,7 @@ class ModelSource(Enum): VISION_OTHERS = "vision_others" PHYSICAL = "physical" NERF = "nerf" + NERF_VISION = "nerfvision" class ModelTaskType(Enum): @@ -36,6 +37,7 @@ class ModelTaskType(Enum): VISION = "vision" PHYSICAL = "physical" NERF = "nerf" + NERF_VISION = "nerfvision" @dataclass @@ -99,6 +101,10 @@ def __post_init__(self): # TODO: pass + if self.task_type == ModelTaskType.NERF_VISION: + # TODO: + pass + # manual models assert self.is_quantized + self.is_lora + self.is_sparse <= 1 if self.is_quantized or self.is_lora or self.is_sparse: @@ -120,6 +126,10 @@ def is_physical_model(self): def is_nerf_model(self): return self.task_type == ModelTaskType.NERF + @property + def is_nerf_vision_model(self): + return self.task_type == ModelTaskType.NERF_VISION + class ModelFactory: _model_info_dict: dict = {} diff --git a/src/chop/nerfvision b/src/chop/nerfvision new file mode 100644 index 000000000..e69de29bb diff --git a/src/chop/nn/backward/functional/linear.py b/src/chop/nn/backward/functional/linear.py index 56ea63a28..bb612caf4 100644 --- a/src/chop/nn/backward/functional/linear.py +++ b/src/chop/nn/backward/functional/linear.py @@ -15,7 +15,7 @@ minifloat_ieee_quantizer, binary_quantizer, ternary_quantizer, - mxint_hardware, + mxint_quantizer, ) diff --git a/src/chop/nn/quantized/functional/__init__.py b/src/chop/nn/quantized/functional/__init__.py index da30a9654..4b36f15b1 100644 --- a/src/chop/nn/quantized/functional/__init__.py +++ b/src/chop/nn/quantized/functional/__init__.py @@ -73,6 +73,7 @@ relu_minifloat_ieee, relu_binary, relu_ternary, + relu_mxint, ) from .selu import ( @@ -149,7 +150,7 @@ linearTernary, linearLUT, linearLogicNets, - linearMXIntHardware, + linearMXInt, ) quantized_func_map = { @@ -213,6 +214,7 @@ "relu_block_log": relu_block_log, "relu_binary": relu_binary, "relu_ternary": relu_ternary, + "relu_mxint": relu_mxint, "selu_block_minifloat": selu_block_minifloat, "selu_integer": selu_integer, "selu_fixed": selu_integer, @@ -267,7 +269,7 @@ "linear_integer": linearInteger, "linear_fixed": linearInteger, "linear_log": linearLog, - "linear_mxint_hardware": linearMXIntHardware, + "linear_mxint": linearMXInt, "linear_block_log": linearBlockLog, "linear_minifloat_ieee": linearMinifloatIEEE, "linear_minifloat_denorm": linearMinifloatDenorm, diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index 2c9f4caec..ed1ef275a 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -15,7 +15,7 @@ minifloat_ieee_quantizer, binary_quantizer, ternary_quantizer, - mxint_hardware, + mxint_quantizer, ) @@ -516,68 +516,69 @@ def linearLogicNets( raise NotImplementedError -def linearMXIntHardware( +def linearMXInt( x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, - out_config: dict = None, ): - w_width, w_exponent_width = ( + # establish quantizers + w_width, w_exponent_width, w_block_size = ( config["weight_width"], config["weight_exponent_width"], + config["weight_block_size"], ) - w_p1, w_p0 = ( - config["weight_parallelism"][0], - config["weight_parallelism"][1], - ) - x_width, x_exponent_width = ( + x_width, x_exponent_width, x_block_size = ( config["data_in_width"], config["data_in_exponent_width"], + config["data_in_block_size"], ) - x_p1, x_p0 = ( - config["data_in_parallelism"][0], - config["data_in_parallelism"][1], + x_skip_first_dim = config.get("data_in_skip_first_dim", True) + + out_width, out_exponent_width, out_block_size = ( + config.get("data_in_width", x_width), + config.get("data_in_exponent_width", x_exponent_width), + config.get("data_in_block_size", x_block_size), ) - # check bias quantizer, if not, use weight quantizer - b_width, b_exponent_width = config["bias_width"], config["bias_exponent_width"] - b_p1, b_p0 = config["bias_parallelism"][0], config["bias_parallelism"][1] - base_quantizer = mxint_hardware - if out_config is not None: - out_width, out_exponent_width = ( - config["data_out_width"], - config["data_out_exponent_width"], - ) - out_p1, out_p0 = ( - config["data_out_parallelism_dim_1"], - config["data_out_parallelism_dim_0"], - ) - out_quantizer = partial( - base_quantizer, - q_config={"width": out_width, "exponent_width": out_exponent_width}, - parallelism=[out_p1, out_p0], - ) + + b_width, b_exponent_width, b_block_size = ( + config["bias_width"], + config["bias_exponent_width"], + config["bias_block_size"], + ) + + # blocking/unblocking 4D kernel/feature map is not supported w_quantizer = partial( - base_quantizer, - q_config={"width": w_width, "exponent_width": w_exponent_width}, - parallelism=[w_p1, w_p0], + mxint_quantizer, + width=w_width, + exponent_width=w_exponent_width, + block_size=w_block_size, + skip_first_dim=False, ) x_quantizer = partial( - base_quantizer, - q_config={"width": x_width, "exponent_width": x_exponent_width}, - parallelism=[x_p1, x_p0], + mxint_quantizer, + width=x_width, + exponent_width=x_exponent_width, + block_size=x_block_size, + skip_first_dim=x_skip_first_dim, + ) + out_quantizer = partial( + mxint_quantizer, + width=out_width, + exponent_width=out_exponent_width, + block_size=out_block_size, + skip_first_dim=x_skip_first_dim, ) b_quantizer = partial( - base_quantizer, - q_config={"width": b_width, "exponent_width": b_exponent_width}, - parallelism=[b_p1, b_p0], + mxint_quantizer, + width=b_width, + exponent_width=b_exponent_width, + block_size=b_block_size, + skip_first_dim=False, ) x = x_quantizer(x) weight = w_quantizer(weight) bias = b_quantizer(bias) if bias is not None else None - out = F.linear(x, weight, bias) - if out_config is not None: - out = out_quantizer(out) - return out + return out_quantizer(F.linear(x, weight, bias)) diff --git a/src/chop/nn/quantized/functional/relu.py b/src/chop/nn/quantized/functional/relu.py index 57daed04a..5e588cca6 100644 --- a/src/chop/nn/quantized/functional/relu.py +++ b/src/chop/nn/quantized/functional/relu.py @@ -14,6 +14,7 @@ minifloat_ieee_quantizer, binary_quantizer, ternary_quantizer, + mxint_quantizer, ) @@ -199,3 +200,53 @@ def relu_block_log(x, inplace=False, config=None): x = x_quantizer(x) x = torch.reshape(x, x_shape) return F.relu(x, inplace=inplace) + + +def relu_mxint(x, inplace=False, config=None): + bypass = config.get("bypass", False) + if bypass or isinstance(x, torch.fx.proxy.Proxy): + return F.relu(x, inplace=inplace) + else: + x_width, x_exponent_width, x_block_size = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_block_size"], + ) + + out_width, out_exponent_width, out_block_size = ( + config.get("data_in_width", x_width), + config.get("data_in_exponent_width", x_exponent_width), + config.get("data_in_block_size", x_block_size), + ) + + x_more_than_2_dims = x.ndim > 2 + x_quantizer = partial( + mxint_quantizer, + width=x_width, + exponent_width=x_exponent_width, + block_size=x_block_size, + skip_first_dim=x_more_than_2_dims, + ) + + out_quantizer = partial( + mxint_quantizer, + width=out_width, + exponent_width=out_exponent_width, + block_size=out_block_size, + skip_first_dim=x_more_than_2_dims, + ) + + x_shape = [i for i in x.shape] + if x_more_than_2_dims: + x = torch.flatten(x, start_dim=0, end_dim=-3) + x = x_quantizer(x) + x = torch.reshape(x, x_shape) + relu_out = F.relu(x, inplace=inplace) + relu_out = ( + torch.flatten(relu_out, start_dim=0, end_dim=-3) + if x_more_than_2_dims + else relu_out + ) + relu_out_q = out_quantizer(relu_out) + relu_out_q = torch.reshape(relu_out_q, x_shape) + return relu_out_q diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index 9676b981e..6956e6ca3 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -52,7 +52,7 @@ LinearTernary, LinearLUT, LinearLogicNets, - LinearMXIntHardware, + LinearMXInt, ) from .pool2d import ( AdaptiveAvgPool2dInteger, @@ -70,6 +70,7 @@ ReLUMinifloatIEEE, ReLUBinary, ReLUTernary, + ReLUMXINT, ) from .batch_norm2d import ( BatchNorm2dInteger, @@ -163,7 +164,7 @@ quantized_basic_module_map = { "conv1d_block_minifloat": Conv1dBlockMinifloat, - "conv1d_integer": Conv1dInteger, + "conv1d_fixed": Conv1dInteger, "conv1d_binary": Conv1dBinary, "conv1d_ternary": Conv1dTernary, "conv1d_log": Conv1dLog, @@ -172,7 +173,7 @@ "conv1d_minifloat_denorm": Conv1dMinifloatDenorm, "conv1d_block_fp": Conv1dBlockFP, "conv2d_block_minifloat": Conv2dBlockMinifloat, - "conv2d_integer": Conv2dInteger, + "conv2d_fixed": Conv2dInteger, "conv2d_binary_residual": Conv2dBinaryResidualSign, "conv2d_binary": Conv2dBinaryScaling, "conv2d_ternary": Conv2dTernary, @@ -184,10 +185,9 @@ "conv2d_lutnet": Conv2dLUT, "conv2d_logicnets": Conv2DLogicNets, "linear_block_minifloat": LinearBlockMinifloat, - "linear_integer": LinearInteger, "linear_fixed": LinearInteger, "linear_log": LinearLog, - "linear_mxint_hardware": LinearMXIntHardware, + "linear_mxint": LinearMXInt, "linear_block_log": LinearBlockLog, "linear_minifloat_ieee": LinearMinifloatIEEE, "linear_minifloat_denorm": LinearMinifloatDenorm, @@ -197,12 +197,11 @@ "linear_ternary": LinearTernary, "linear_lutnet": LinearLUT, "linear_logicnets": LinearLogicNets, - "adaptive_avg_pool2d_integer": AdaptiveAvgPool2dInteger, - "avg_pool2d_integer": AvgPool2dInteger, + "adaptive_avg_pool2d_fixed": AdaptiveAvgPool2dInteger, + "avg_pool2d_fixed": AvgPool2dInteger, "avg_pool2d_binary": AvgPool2dBinary, "avg_pool2d_ternary": AvgPool2dTernary, "relu_block_minifloat": ReLUBlockMinifloat, - "relu_integer": ReLUInteger, "relu_fixed": ReLUInteger, "relu_log": ReLULog, "relu_block_log": ReLUBlockLog, @@ -211,14 +210,14 @@ "relu_block_fp": ReLUBlockFP, "relu_binary": ReLUBinary, "relu_ternary": ReLUTernary, - "batch_norm2d_integer": BatchNorm2dInteger, + "relu_mxint": ReLUMXINT, + "batch_norm2d_fixed": BatchNorm2dInteger, "batch_norm2d_binary": BatchNorm2dBinary, - "layer_norm_integer": LayerNormInteger, - "group_norm_integer": GroupNormInteger, - "instance_norm2d_integer": InstanceNorm2dInteger, - "rms_norm_integer": RMSNormInteger, + "layer_norm_fixed": LayerNormInteger, + "group_norm_fixed": GroupNormInteger, + "instance_norm2d_fixed": InstanceNorm2dInteger, + "rms_norm_fixed": RMSNormInteger, "selu_block_minifloat": SELUBlockMinifloat, - "selu_integer": SELUInteger, "selu_fixed": SELUInteger, "selu_log": SELULog, "selu_block_log": SELUBlockLog, @@ -228,7 +227,6 @@ "selu_binary": SELUBinary, "selu_ternary": SELUTernary, "silu_block_minifloat": SiLUBlockMinifloat, - "silu_integer": SiLUInteger, "silu_fixed": SiLUInteger, "silu_log": SiLULog, "silu_block_log": SiLUBlockLog, @@ -238,7 +236,6 @@ "silu_binary": SiLUBinary, "silu_ternary": SiLUTernary, "tanh_block_minifloat": TanhBlockMinifloat, - "tanh_integer": TanhInteger, "tanh_fixed": TanhInteger, "tanh_log": TanhLog, "tanh_block_log": TanhBlockLog, @@ -248,7 +245,6 @@ "tanh_binary": TanhBinary, "tanh_ternary": TanhTernary, "gelu_block_minifloat": GELUBlockMinifloat, - "gelu_integer": GELUInteger, "gelu_fixed": GELUInteger, "gelu_log": GELULog, "gelu_block_log": GELUBlockLog, @@ -258,7 +254,6 @@ "gelu_binary": GELUBinary, "gelu_ternary": GELUTernary, "softsign_block_minifloat": SoftsignBlockMinifloat, - "softsign_integer": SoftsignInteger, "softsign_fixed": SoftsignInteger, "softsign_log": SoftsignLog, "softsign_block_log": SoftsignBlockLog, @@ -268,7 +263,6 @@ "softsign_binary": SoftsignBinary, "softsign_ternary": SoftsignTernary, "softplus_block_minifloat": SoftplusBlockMinifloat, - "softplus_integer": SoftplusInteger, "softplus_fixed": SoftplusInteger, "softplus_log": SoftplusLog, "softplus_block_log": SoftplusBlockLog, @@ -282,9 +276,9 @@ } quantized_bert_module_map = { - "bert_self_attention_head_integer": BertSelfAttentionHeadInteger, - "bert_self_attention_integer": BertSelfAttentionInteger, - "grouped_query_attention_integer": GroupedQueryAttentionInteger, + "bert_self_attention_head_fixed": BertSelfAttentionHeadInteger, + "bert_self_attention_fixed": BertSelfAttentionInteger, + "grouped_query_attention_fixed": GroupedQueryAttentionInteger, } quantized_roberta_module_map = { diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index 5d8d389a5..dcb3e7892 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -8,7 +8,7 @@ linearBlockMinifloat, linearInteger, linearLog, - linearMXIntHardware, + linearMXInt, linearMinifloatDenorm, linearMinifloatIEEE, linearTernary, @@ -21,6 +21,7 @@ from ..utils import get_stats, quantiser_passthrough from chop.nn.quantizers import ( + mxint, residual_sign_quantizer, block_fp_quantizer, block_log_quantizer, @@ -32,7 +33,6 @@ minifloat_ieee_quantizer, binary_quantizer, ternary_quantizer, - mxint_hardware, ) # LUTNet @@ -785,7 +785,7 @@ def forward(self, x: Tensor) -> Tensor: return self.math_forward(x) -class LinearMXIntHardware(_LinearBase): +class LinearMXInt(_LinearBase): def __init__( self, in_features: int, @@ -794,12 +794,10 @@ def __init__( device=None, dtype=None, config=None, - out_config=None, ) -> None: super().__init__(in_features, out_features, bias, device, dtype) assert config is not None, "config is None!" self.config = config - self.out_config = out_config self.bypass = config.get("bypass", False) if self.bypass: return @@ -807,6 +805,4 @@ def __init__( def forward(self, x): if self.bypass: return F.linear(x, self.weight, self.bias) - return linearMXIntHardware( - x, self.weight, self.bias, self.config, self.out_config - ) + return linearMXInt(x, self.weight, self.bias, self.config) diff --git a/src/chop/nn/quantized/modules/relu.py b/src/chop/nn/quantized/modules/relu.py index 2bc527161..157a61058 100644 --- a/src/chop/nn/quantized/modules/relu.py +++ b/src/chop/nn/quantized/modules/relu.py @@ -14,6 +14,7 @@ log_quantizer, minifloat_denorm_quantizer, minifloat_ieee_quantizer, + mxint_quantizer, binary_quantizer, ternary_quantizer, ) @@ -300,3 +301,37 @@ def __init__(self, inplace: bool = False, config: dict = None): # mean=x_mean, # ) self.config = config + + +class ReLUMXINT(_ReLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_width, x_block_size = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_block_size"], + ) + self.x_quantizer = partial( + mxint_quantizer, + width=x_width, + exponent_width=x_exponent_width, + block_size=x_block_size, + skip_first_dim=True, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.relu(x) + else: + x_shape = [i for i in x.shape] + if x.ndim > 2: + x = torch.flatten(x, 0, -3) + x = self.x_quantizer(x) + x = torch.reshape(x, x_shape) + return F.relu(x, self.inplace) diff --git a/src/chop/nn/quantizers/__init__.py b/src/chop/nn/quantizers/__init__.py index c5b1f4c8d..8c5936736 100644 --- a/src/chop/nn/quantizers/__init__.py +++ b/src/chop/nn/quantizers/__init__.py @@ -6,8 +6,11 @@ from .ternary import ternary_quantizer from .log import log_quantizer from .minifloat import minifloat_denorm_quantizer, minifloat_ieee_quantizer -from .quantizers_for_hw import integer_quantizer_for_hw, integer_floor_quantizer_for_hw -from .mxint_hardware import mxint_hardware +from .quantizers_for_hw import ( + integer_quantizer_for_hw, + integer_floor_quantizer_for_hw, +) +from .mxint import mxint_quantizer quantizer_map = { "log": log_quantizer, @@ -19,5 +22,5 @@ "integer": integer_quantizer, "binary": binary_quantizer, "ternary": ternary_quantizer, - "mxint_hardware": mxint_hardware, + "mxint": mxint_quantizer, } diff --git a/src/chop/nn/quantizers/mxint.py b/src/chop/nn/quantizers/mxint.py new file mode 100644 index 000000000..cfedba5e6 --- /dev/null +++ b/src/chop/nn/quantizers/mxint.py @@ -0,0 +1,117 @@ +import torch +from torch import Tensor + +from .utils import block, my_clamp, my_round, unblock, my_floor + + +def _mxint_quantize( + x: Tensor, + width: int = 8, + exponent_width: int = 4, + block_size: list[int] = [16], + skip_first_dim: bool = True, + floor=True, +): + """ + - Convert IEEE FP32/64 to Microscaling Interger (MXINT), where an exponent is shared over all elements in a block. + - https://arxiv.org/pdf/2310.10537.pdf + - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + + --- + - forward: convert IEEE FP32/64 to MXINT + - backward: STE + + --- + - `width`: The number of mantissa bits + 1 (the sign bit) + - `exponent_width`: the number of exponent bits + - `block_size`: a list of integers where each integer is the block size on that dimension. See function `block`. + """ + + if isinstance(block_size, int): + block_size = [block_size] + + x_shape_before_blocking = [i for i in x.shape] + blocked_x, per_block_max, padded_x_shape, block_shape = block( + x, block_shape=block_size, skip_first_dim=skip_first_dim + ) + + if torch.all(per_block_max == 0): + per_block_max = torch.ones_like(per_block_max) + else: + per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min() + + exponent_bias = 2 ** (exponent_width - 1) - 1 + + per_block_exponent = torch.floor(torch.log2(per_block_max)) + exponent_bias + per_block_exponent = my_clamp(per_block_exponent, 0, 2**exponent_width - 1) + + scaled_value = blocked_x / 2 ** (per_block_exponent - exponent_bias) + + element_max = 2 ** (width - 1) - 1 + shift = 2 ** (width - 2) + + # To advoid introducing a negative bias + mantissas = scaled_value * shift + quantized_value = my_clamp( + my_floor(mantissas) if floor else my_round(mantissas), -element_max, element_max + ) + + element_value = quantized_value / shift + + mxint_value = element_value * 2 ** (per_block_exponent - exponent_bias) + + mxint_x = unblock( + mxint_value, + x_shape_before_blocking=x_shape_before_blocking, + padded_x_shape=padded_x_shape, + block_shape=block_shape, + skipped_first_dim_when_blocking=skip_first_dim, + ) + + # fmt: off + # this `is_close_to_0` helps the grad keeps 1 if input x is 0, or the zero-initialized value will be trapped in 0 + is_close_to_0 = torch.isclose(x, torch.tensor([0.0], dtype=x.dtype, device=x.device)) + mxint_x = (~is_close_to_0) * mxint_x + (is_close_to_0) * x + # fmt: on + + return mxint_x + + +class MXINTQuantize(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + width: int = 8, + exponent_width: int = 4, + block_size: list[int] = [16], + skip_first_dim: bool = True, + ): + return _mxint_quantize( + x, + width=width, + exponent_width=exponent_width, + block_size=block_size, + skip_first_dim=skip_first_dim, + ) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input, None, None, None, None + + +def mxint_quantizer( + x: Tensor, + width: int = 8, + exponent_width: int = 4, + block_size: list[int] = [16], + skip_first_dim: bool = True, +): + return MXINTQuantize.apply( + x, + width, + exponent_width, + block_size, + skip_first_dim, + ) diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py deleted file mode 100644 index 0c3e06130..000000000 --- a/src/chop/nn/quantizers/mxint_hardware.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from functools import partial -import torch.nn.functional as F -from torch import Tensor - - -def mxint_quant_block( - x, width: int = 12, exponent_width: int = 6, exponent: int = None -): - """ - - Idea from https://arxiv.org/pdf/2310.10537 - - Convert IEEE FP32/64 to Integer with sharing scale - - The main difference between is the sharing scale do not support NAN representation - --- - - `width`: The number of mantissa bits + 1 (the sign bit) - - `exponent_width`: the number of exponent bits, which is shared over a block - - `exponent_bias`: the exponent bias, if None, `2**(exponent_bits-1)-1` will be used - - """ - exponent_bias = 2 ** (exponent_width - 1) - - exponent_max = 2**exponent_width - 1 - exponent_bias - exponent_min = -exponent_bias - - # exponent - if exponent == None: - exponent = torch.ceil(torch.log2(x.abs().max())) - exponent_bias - exponent = torch.clamp(exponent, exponent_min, exponent_max) - # mantissa - int_min = -(2 ** (width - 1)) - int_max = 2 ** (width - 1) - 1 - mantissa = x / 2**exponent - mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - q_x = (2**exponent) * mantissa - return q_x - - -def mxint_hardware(tensor, q_config, parallelism): - """ - - For hardware efficiency, the block will be set based on parallelism - - This will reshape all the input to a 3D matrix (other dimension will be packed into the first dimension) - - Then will quantize every block of the 2D matrix in the reshaped input tensor. - - The block size will be [parallelism[0], parallelism[1]] - --- - - q_config: assume to be a dict, for example - { - "width": 8, - "exponent_width": 4, - } - - parallelism: assume to be [tensor.shape[-2],tensor.shape[-1]] - - `exponent_width`: the number of exponent bits, which is shared over a block - - `exponent_bias`: the exponent bias, if None, `2**(exponent_bits-1)-1` will be used - """ - original_shape = tensor.shape - if len(tensor.shape) == 1: - tensor = tensor.unsqueeze(0) - if len(parallelism) == 1: - parallelism = [1, parallelism[0]] - - p1 = parallelism[0] - p0 = parallelism[1] - t1 = tensor.shape[-2] - t0 = tensor.shape[-1] - assert ( - t1 % p1 == 0 and t0 % p0 == 0 - ), f"""The Block should be able to completely segment the tensor size, - t1 = {t1}, p1 = {p1}, t0 = {t0}, p0 = {p0}""" - reshaped_tensor = tensor.reshape(-1, t1 // p1, p1, t0 // p0, p0).permute( - 0, 1, 3, 2, 4 - ) - - # Quantize - quantizer = partial(mxint_quant_block, **q_config) - reshaped_tensor = torch.tensor(reshaped_tensor.reshape(-1, p1 * p0)) - for i in range(reshaped_tensor.shape[0]): - reshaped_tensor[i] = quantizer(reshaped_tensor[i]) - qtensor = ( - reshaped_tensor.reshape(-1, t1 // p1, t0 // p0, p1, p0) - .permute(0, 1, 3, 2, 4) - .reshape(original_shape) - ) - return qtensor diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py index d5ca3d8cf..f18bf927c 100644 --- a/src/chop/nn/quantizers/quantizers_for_hw.py +++ b/src/chop/nn/quantizers/quantizers_for_hw.py @@ -35,6 +35,3 @@ def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int): fixed_point_value = fixed_point_value.to(torch.int) fixed_point_value = fixed_point_value % (2**width) return fixed_point_value - - -# sw_quantizer_to_hw_quantizer = {integer_quantizer: integer_quantizer_for_hw} diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 3c4fccad0..9e71e2878 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -13,8 +13,8 @@ from chop.passes.graph.utils import get_mase_op, deepgetattr, get_module_by_name from torch import nn - -from .hardware_metadata_layers import INTERNAL_COMP +from typing import get_args +from .hardware_metadata_layers import INTERNAL_COMP, supported_hw_quantisations logger = logging.getLogger(__name__) @@ -29,7 +29,27 @@ def _cap(name): return str(name).upper() -def add_component_source(node): +def get_node_type(node: fx.Node) -> supported_hw_quantisations: + types: set[str] = set() + for _arg, type_info in node.meta["mase"]["common"]["args"].items(): + match type_info: + case dict(): + types.add(type_info["type"]) + case _: + pass + types = set(types) + # types = set([match type_info: case dict(): type_info['type'] case _: None for _arg, type_info in node.meta['mase']['common']['args'].items()]) + assert ( + len(types) == 1 + ), f"More than one type in node {node.name}, {types}, {node.meta['mase']['common']['args']}" + node_type = types.pop() + assert node_type in get_args( + supported_hw_quantisations + ), f"({node.name}) Unsupported hardware quantisation type: {node_type}" + return node_type + + +def add_component_source(node: fx.Node): if node.meta["mase"]["hardware"]["is_implicit"]: return @@ -50,10 +70,13 @@ def add_component_source(node): elif mase_op in INTERNAL_COMP.keys(): node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" # take the first ip in the component list by default - node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"] - node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][0][ - "dependence_files" + node_type = get_node_type(node) + node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][node_type][ + "name" ] + node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][ + node_type + ]["dependence_files"] else: node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_HLS" node.meta["mase"]["hardware"]["module"] = None @@ -76,7 +99,7 @@ def add_component_source(node): node.meta["mase"]["hardware"]["interface"][arg] = {} -def add_verilog_param(node): +def add_verilog_param(node: fx.Node): if node.meta["mase"]["hardware"]["is_implicit"]: return @@ -89,26 +112,39 @@ def add_verilog_param(node): if isinstance(arg_info, dict): for i, precision in enumerate(arg_info["precision"]): vp[_cap(arg + f"_precision_{i}")] = arg_info["precision"][i] - for dim in range(0, len(arg_info["shape"])): - vp[_cap(arg + f"_tensor_size_dim_{dim}")] = ( - arg_info["shape"][len(arg_info["shape"]) - 1 - dim] - if dim < len(arg_info["shape"]) - else 1 - ) - # Check if max parallelism is defined - if node.meta["mase"]["hardware"]["max_parallelism"] is not None: - # Take the minimum between... - vp[_cap(arg + f"_parallelism_dim_{dim}")] = min( - # The defined max parallelism for this dimension - node.meta["mase"]["hardware"]["max_parallelism"][::-1][dim], - # The size of this dimension - arg_info["shape"][::-1][dim], + match arg_info.get("type", None): + case "fixed": + for dim in range(0, len(arg_info["shape"])): + vp[_cap(arg + f"_tensor_size_dim_{dim}")] = ( + arg_info["shape"][len(arg_info["shape"]) - 1 - dim] + if dim < len(arg_info["shape"]) + else 1 + ) + # Check if max parallelism is defined + if node.meta["mase"]["hardware"]["max_parallelism"] is not None: + # Take the minimum between... + vp[_cap(arg + f"_parallelism_dim_{dim}")] = min( + # The defined max parallelism for this dimension + node.meta["mase"]["hardware"]["max_parallelism"][::-1][ + dim + ], + # The size of this dimension + arg_info["shape"][::-1][dim], + ) + # Otherwise, assign to tensor size by default + else: + vp[_cap(arg + f"_parallelism_dim_{dim}")] = arg_info[ + "shape" + ][::-1][dim] + case "mxint": + for dim, dim_value in enumerate(reversed(arg_info["shape"])): + vp[_cap(arg + f"_tensor_size_dim_{dim}")] = dim_value + if parallel := arg_info.get(f"parallelism_{dim}", None): + vp[_cap(arg + f"_parallelism_dim_{dim}")] = parallel + case t: + raise NotImplementedError( + f"Unsupported quantization type {t} for {node.name} {arg}" ) - # Otherwise, assign to tensor size by default - else: - vp[_cap(arg + f"_parallelism_dim_{dim}")] = arg_info["shape"][::-1][ - dim - ] elif type(arg_info) == bool: vp[_cap(arg)] = 1 if arg_info else 0 else: @@ -118,26 +154,45 @@ def add_verilog_param(node): if isinstance(result_info, dict): for i, precision in enumerate(result_info["precision"]): vp[_cap(result + f"_precision_{i}")] = result_info["precision"][i] - for dim in range(0, len(result_info["shape"])): - vp[_cap(result + f"_tensor_size_dim_{dim}")] = ( - result_info["shape"][len(result_info["shape"]) - 1 - dim] - if dim < len(result_info["shape"]) - else 1 - ) - # Check if max parallelism is defined - if node.meta["mase"]["hardware"]["max_parallelism"] is not None: - # Take the minimum between... - vp[_cap(result + f"_parallelism_dim_{dim}")] = min( - # The defined max parallelism for this dimension - node.meta["mase"]["hardware"]["max_parallelism"][::-1][dim], - # The size of this dimension - result_info["shape"][::-1][dim], - ) - # Otherwise, assign to tensor size by default - else: - vp[_cap(result + f"_parallelism_dim_{dim}")] = result_info["shape"][ - ::-1 - ][dim] + match result_info.get("type", None): + case "fixed": + for dim in range(0, len(result_info["shape"])): + vp[_cap(result + f"_tensor_size_dim_{dim}")] = ( + result_info["shape"][ + len(result_info["shape"]) - 1 - dim + ] + if dim < len(result_info["shape"]) + else 1 + ) + # Check if max parallelism is defined + if ( + node.meta["mase"]["hardware"]["max_parallelism"] + is not None + ): + # Take the minimum between... + vp[_cap(result + f"_parallelism_dim_{dim}")] = min( + # The defined max parallelism for this dimension + node.meta["mase"]["hardware"]["max_parallelism"][ + ::-1 + ][dim], + # The size of this dimension + result_info["shape"][::-1][dim], + ) + # Otherwise, assign to tensor size by default + else: + vp[_cap(result + f"_parallelism_dim_{dim}")] = ( + result_info["shape"][::-1][dim] + ) + case "mxint": + for dim, dim_value in enumerate(reversed(result_info["shape"])): + vp[_cap(result + f"_tensor_size_dim_{dim}")] = dim_value + if parallel := result_info.get(f"parallelism_{dim}", None): + vp[_cap(result + f"_parallelism_dim_{dim}")] = parallel + + case t: + raise NotImplementedError( + f"Unsupported quantization type {t} for {node.name} {result}" + ) else: vp[_cap(result)] = result_info @@ -170,7 +225,7 @@ def add_extra_verilog_param(node, graph: MaseGraph): ] -def add_hardware_metadata_analysis_pass(graph, pass_args={}): +def add_hardware_metadata_analysis_pass(graph: MaseGraph, pass_args={}): """add hardware metadata :param graph: a MaseGraph diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py index b084821ae..409f560d3 100644 --- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py @@ -1,6 +1,16 @@ +from typing import Dict, TypedDict, Literal + + +class IpDescType(TypedDict): + name: str + dependence_files: list[str] + + +supported_hw_quantisations = Literal["fixed", "mxint"] + # mase_op : the set of functional equivalent IPs with different design configurations. # The first IP in each list is used by default -norm = { +norm: IpDescType = { "name": "norm", "dependence_files": [ "common/rtl/join2.sv", @@ -33,9 +43,10 @@ ], } -INTERNAL_COMP = { - "linear": [ - { + +INTERNAL_COMP: Dict[str, Dict[supported_hw_quantisations, IpDescType]] = { + "linear": { + "fixed": { "name": "fixed_linear", "dependence_files": [ "cast/rtl/fixed_cast.sv", @@ -60,209 +71,201 @@ "linear_layers/matmul/rtl/matrix_stream_transpose.sv", ], }, - ], - "relu": [ - { + "mxint": { + "name": "mxint_linear", + "dependence_files": [ + "linear_layers/mxint_operators/rtl/mxint_linear.sv", + "linear_layers/mxint_operators/rtl/mxint_circular.sv", + "memory/rtl/input_buffer.sv", + "linear_layers/mxint_operators/rtl/mxint_dot_product.sv", + "linear_layers/mxint_operators/rtl/mxint_accumulator.sv", + "linear_layers/mxint_operators/rtl/mxint_cast.sv", + "linear_layers/mxint_operators/rtl/log2_max_abs.sv", + "linear_layers/mxint_operators/rtl/or_tree.sv", + "linear_layers/mxint_operators/rtl/or_tree_layer.sv", + "linear_layers/mxint_operators/rtl/mxint_register_slice.sv", + "linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv", + "common/rtl/unpacked_register_slice.sv", + "common/rtl/split2.sv", + "common/rtl/join2.sv", + "common/rtl/join_n.sv", + "common/rtl/register_slice.sv", + "memory/rtl/fifo.sv", + "memory/rtl/blk_mem_gen_0.sv", + "memory/rtl/simple_dual_port_ram.sv", + "memory/rtl/unpacked_skid_buffer.sv", + "memory/rtl/skid_buffer.sv", + "memory/rtl/ultraram_fifo.sv", + "memory/rtl/ultraram.v", + "linear_layers/fixed_operators/rtl/fixed_dot_product.sv", + "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv", + "linear_layers/fixed_operators/rtl/fixed_mult.sv", + "linear_layers/fixed_operators/rtl/fixed_adder_tree.sv", + "linear_layers/fixed_operators/rtl/fixed_adder_tree_layer.sv", + ], + }, + }, + "relu": { + "fixed": { "name": "fixed_relu", "dependence_files": [ "activation_layers/rtl/fixed_relu.sv", ], }, - ], - "hardshrink": [ - { + "mxint": { + "name": "mxint_relu", + "dependence_files": [ + "linear_layers/mxint_operators/rtl/mxint_relu.sv", + "linear_layers/mxint_operators/rtl/mxint_cast.sv", + "common/rtl/split2.sv", + "linear_layers/mxint_operators/rtl/log2_max_abs.sv", + "linear_layers/mxint_operators/rtl/or_tree.sv", + "linear_layers/mxint_operators/rtl/or_tree_layer.sv", + "common/rtl/register_slice.sv", + "linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv", + "memory/rtl/fifo.sv", + "memory/rtl/skid_buffer.sv", + "memory/rtl/simple_dual_port_ram.sv", + "common/rtl/join2.sv", + ], + }, + }, + "hardshrink": { + "fixed": { "name": "fixed_hardshrink", "dependence_files": [ "activation_layers/rtl/fixed_hardshrink.sv", ], }, - ], - "silu": [ - { + }, + "silu": { + "fixed": { "name": "fixed_silu", "dependence_files": [ "activation_layers/rtl/fixed_silu.sv", "activation_layers/rtl/silu_lut.sv", ], }, - ], - "elu": [ - { + }, + "elu": { + "fixed": { "name": "fixed_elu", "dependence_files": [ "activation_layers/rtl/fixed_elu.sv", "activation_layers/rtl/elu_lut.sv", ], }, - ], - "sigmoid": [ - { + }, + "sigmoid": { + "fixed": { "name": "fixed_sigmoid", "dependence_files": [ "activation_layers/rtl/fixed_sigmoid.sv", "activation_layers/rtl/sigmoid_lut.sv", ], }, - ], - "softshrink": [ - { + }, + "softshrink": { + "fixed": { "name": "fixed_softshrink", "dependence_files": [ "activation_layers/rtl/fixed_softshrink.sv", ], }, - ], - "logsigmoid": [ - { + }, + "logsigmoid": { + "fixed": { "name": "fixed_logsigmoid", "dependence_files": [ "activation_layers/rtl/fixed_logsigmoid.sv", "activation_layers/rtl/logsigmoid_lut.sv", ], }, - ], - "softmax": [ - { + }, + "softmax": { + "fixed": { "name": "fixed_softmax", "dependence_files": [ "activation_layers/rtl/fixed_softmax.sv", "activation_layers/rtl/exp_lut .sv", ], - } - ], - "batch_norm2d": [norm], - "layer_norm": [norm], - "group_norm": [norm], - "instance_norm2d": [norm], - "rms_norm": [norm], - "selu": [ - { + }, + }, + "batch_norm2d": {"fixed": norm}, + "layer_norm": {"fixed": norm}, + "group_norm": {"fixed": norm}, + "instance_norm2d": {"fixed": norm}, + "rms_norm": {"fixed": norm}, + "selu": { + "fixed": { "name": "fixed_selu", "dependence_files": [ "activation_layers/rtl/fixed_selu.sv", ], }, - ], - "tanh": [ - { + }, + "tanh": { + "fixed": { "name": "fixed_tanh", "dependence_files": [ "activation_layers/rtl/fixed_tanh.sv", ], }, - ], - "gelu": [ - { + }, + "gelu": { + "fixed": { "name": "fixed_gelu", "dependence_files": [ "activation_layers/rtl/fixed_gelu.sv", "activation_layers/rtl/gelu_lut.sv", ], }, - ], - "softsign": [ - { + }, + "softsign": { + "fixed": { "name": "fixed_softsign", "dependence_files": [ "activation_layers/rtl/fixed_softsign.sv", "linear_layers/fixed_operators/rtl/fixed_mult.sv", ], }, - ], - "softplus": [ - { + }, + "softplus": { + "fixed": { "name": "fixed_softplus", "dependence_files": [ "activation_layers/rtl/fixed_softplus.sv", ], }, - ], - "add": [ - { + }, + "add": { + "fixed": { "name": "fixed_adder", "dependence_files": [ "linear_layers/fixed_operators/rtl/fixed_adder.sv", ], - } - ], - "mul": [ - { + }, + }, + "mul": { + "fixed": { "name": "fixed_elementwise_multiplier", "dependence_files": [ "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv", ], - } - ], - "df_split": [ - { + }, + }, + "df_split": { + "fixed": { "name": "df_split", "dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"], - } - ], - "getitem": [ - { + }, + }, + "getitem": { + "fixed": { "name": "buffer", "dependence_files": [ "memory/rtl/buffer.sv", ], - } - ], - "grouped_query_attention": [ - { - "name": "fixed_gqa_wrapper", - "dependence_files": [ - "common/rtl/find_first_arbiter.sv", - "common/rtl/mux.sv", - "common/rtl/join2.sv", - "common/rtl/split2.sv", - "common/rtl/split_n.sv", - "common/rtl/join_n.sv", - "memory/rtl/repeat_circular_buffer.sv", - "common/rtl/single_element_repeat.sv", - "memory/rtl/skid_buffer.sv", - "memory/rtl/unpacked_skid_buffer.sv", - "common/rtl/register_slice.sv", - "memory/rtl/simple_dual_port_ram.sv", - "memory/rtl/fifo.sv", - "common/rtl/lut.sv", - "common/rtl/comparator_tree.sv", - "common/rtl/comparator_accumulator.sv", - "common/rtl/unpacked_register_slice.sv", - "cast/rtl/fixed_round.sv", - "cast/rtl/fixed_rounding.sv", - "cast/rtl/floor_round.sv", - "cast/rtl/signed_clamp.sv", - "cast/rtl/fixed_signed_cast.sv", - "linear_layers/fixed_operators/rtl/fixed_mult.sv", - "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv", - "linear_layers/fixed_operators/rtl/fixed_dot_product.sv", - "linear_layers/fixed_operators/rtl/fixed_accumulator.sv", - "linear_layers/fixed_operators/rtl/fixed_adder_tree_layer.sv", - "linear_layers/fixed_operators/rtl/fixed_adder_tree.sv", - "linear_layers/fixed_operators/rtl/fixed_range_reduction.sv", - "linear_layers/fixed_linear_layer/rtl/fixed_linear.sv", - "linear_layers/matmul/rtl/matrix_flatten.sv", - "linear_layers/matmul/rtl/matrix_unflatten.sv", - "linear_layers/matmul/rtl/matrix_fifo.sv", - "linear_layers/matmul/rtl/matrix_accumulator.sv", - "linear_layers/matmul/rtl/simple_matmul.sv", - "linear_layers/matmul/rtl/matmul.sv", - "linear_layers/matmul/rtl/transpose.sv", - "linear_layers/matmul/rtl/matrix_stream_transpose.sv", - "activation_layers/rtl/softermax_lpw_pow2.sv", - "activation_layers/rtl/softermax_lpw_reciprocal.sv", - "activation_layers/rtl/softermax_local_window.sv", - "activation_layers/rtl/softermax_global_norm.sv", - "activation_layers/rtl/fixed_softermax_1d.sv", - "activation_layers/rtl/fixed_softermax.sv", - "transformer_layers/rtl/fixed_gqa_projections.sv", - "transformer_layers/rtl/self_attention_head_single_scatter.sv", - "transformer_layers/rtl/gqa_head_scatter_control.sv", - "transformer_layers/rtl/self_attention_head_gather.sv", - "transformer_layers/rtl/fixed_self_attention_head.sv", - "transformer_layers/rtl/fixed_grouped_query_attention.sv", - "transformer_layers/rtl/fixed_gqa_wrapper.sv", - ], - } - ], + }, + }, } diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py index dea246fc2..1a9c62270 100644 --- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py +++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py @@ -251,21 +251,21 @@ "bias_block_size", ), }, - "mxint_hardware": { + "mxint": { "weight_entries": ( "weight_width", "weight_exponent_width", - "weight_parallelism", + "weight_block_size", ), "data_in_entries": ( "data_in_width", "data_in_exponent_width", - "data_in_parallelism", + "data_in_block_size", ), "bias_entries": ( "bias_width", "bias_exponent_width", - "bias_parallelism", + "bias_block_size", ), }, } @@ -294,8 +294,10 @@ def cp_data_in_entries( cp_multi_values(config, p_config, entries["data_in_entries"], strict=strict) -def cp_data_out_entries(config: dict, p_config: dict, entries: dict): - cp_multi_values(config, p_config, entries["data_out_entries"]) +def cp_data_out_entries( + config: dict, p_config: dict, entries: dict, strict: bool = True +): + cp_multi_values(config, p_config, entries["data_out_entries"], strict=strict) def cp_bias_entries(config: dict, p_config: dict, entries: dict, strict: bool = True): @@ -321,12 +323,6 @@ def cp_layer_entries(config: dict, p_config: dict, entries: dict, strict: bool = cp_multi_values(config, p_config, entries["additional_layers_entries"]) -def cp_data_out_entries( - config: dict, p_config: dict, entries: dict, strict: bool = True -): - cp_multi_values(config, p_config, entries["data_out_entries"], strict=strict) - - """QUANT_ARITH_TO_CP_FN a map from quant_arith to a collection of functions where each function copies a specific quant_arith_spec from a src config to a parsed config. @@ -436,6 +432,10 @@ def parse_node_config(config: dict, mase_op: str, strict: bool = True) -> dict: assert isinstance( op_optional_entries, tuple ), f"op_optional_entries must be a tuple: {op_optional_entries}" + + # replace any instance of fixed with integer since they are the same, but all hardware is named after fixed + config["name"] = "fixed" if config["name"] == "integer" else config["name"] + p_config = {} for entry in op_entries: entry_cp_fn = QUANT_ARITH_TO_CP_FN[config["name"]][entry] diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py index 0c580c4f2..6b1ca8513 100644 --- a/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py +++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py @@ -28,7 +28,7 @@ def entry_to_list(config: dict, entry: str, suffixes: tuple[str]): "block_fp": ("width", "exponent_width", "exponent_bias", "block_size"), "block_minifloat": ("width", "exponent_width", "exponent_bias_width", "block_size"), "block_log": ("width", "exponent_bias_width", "block_size"), - "mxint_hardware": ("width", "exponent_width"), + "mxint": ("width", "exponent_width"), } @@ -91,31 +91,6 @@ def update_result(node, output_name, dtype=None, precision=None, size=None): node.meta["mase"].parameters["common"]["results"][output_name]["size"] = size -MASE_OP_TO_OUTPUT_ENTRIES = { - # entry and arg corresponding to name in software and hardware mapping - "add": (("data_out",), ("data_out_0",)), - "bmm": (("data_out",), ("data_out_0",)), - "conv1d": (("data_out",), ("data_out_0",)), - "conv2d": (("data_out",), ("data_out_0",)), - "matmul": (("data_out",), ("data_out_0",)), - "mul": (("data_out",), ("data_out_0",)), - "linear": (("data_out",), ("data_out_0",)), - "relu": (("data_out",), ("data_out_0",)), - "selu": (("data_out",), ("data_out_0",)), - "tanh": (("data_out",), ("data_out_0",)), - "gelu": (("data_out",), ("data_out_0",)), - "softsign": (("data_out",), ("data_out_0",)), - "softplus": (("data_out",), ("data_out_0",)), - "sub": (("data_out",), ("data_out_0",)), - "batch_norm2d": (("data_out",), ("data_out_0",)), - "layer_norm": (("data_out",), ("data_out_0",)), - "group_norm": (("data_out",), ("data_out_0",)), - "instance_norm2d": (("data_out",), ("data_out_0",)), - "rms_norm": (("data_out",), ("data_out_0")), - "grouped_query_attention": (("data_out",), ("data_out_0")), -} - - def arg_exists(node, arg_name) -> bool: return arg_name in node.meta["mase"].parameters["common"]["args"] @@ -144,15 +119,25 @@ def update_quant_meta_param(node, config: dict, mase_op: str) -> None: precision=quant_arith_to_list_fn[quant_arith](config, entry), ) - for entry, arg in zip(*MASE_OP_TO_OUTPUT_ENTRIES[mase_op]): - # Quantise all the output to fixed point. TODO: Make this automatic. Hardware will need change too - if quant_arith == "binary" or quant_arith == "binary_residual": - update_result( - node, - output_name=arg, - dtype="binary", - precision=[32, 0, 1], # [bitwidth, stochastic, bipolar] - ) + if quant_arith == "binary" or quant_arith == "binary_residual": + update_result( + node, + output_name="data_out_0", + dtype="binary", + precision=[32, 0, 1], # [bitwidth, stochastic, bipolar] + ) + else: + try: + precision = quant_arith_to_list_fn[quant_arith](config, "data_out") + except KeyError: + # fallback to use data_in-config if data_out is not defined + precision = quant_arith_to_list_fn[quant_arith](config, "data_in") + update_result( + node, + output_name="data_out_0", + dtype=quant_arith, + precision=precision, + ) def relink_node_meta(node, model): diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py index a6a000c79..ae4b25963 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py @@ -4,10 +4,20 @@ import struct import time +from chop.nn.quantizers.utils import block +from chop.passes.graph.transforms.verilog.mxint_bram_template import mxint_template +from mase_components.linear_layers.mxint_operators.test.utils import ( + block_mxint_quant, + mxint_quantize, + pack_tensor_to_mx_listed_chunk, +) import torch from chop.passes.graph.utils import vf, v2p, get_module_by_name, init_project -from chop.nn.quantizers import integer_quantizer_for_hw, integer_floor_quantizer_for_hw +from chop.nn.quantizers import ( + integer_quantizer_for_hw, + integer_floor_quantizer_for_hw, +) logger = logging.getLogger(__name__) from pathlib import Path @@ -29,13 +39,76 @@ def _cap(name): def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): + verilog_param_name = param_name.replace(".", "_") + match node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"]: + case "fixed": + emit_parameters_in_mem_internal_fixed_point( + node, verilog_param_name, file_name, data_name + ) + case "mxint": + emit_parameters_in_mem_internal_mxint( + node, verilog_param_name, file_name, data_name + ) + case unsupported_type: + raise NotImplementedError(f"Unsupported BRAM data-type {unsupported_type}") + + +def emit_parameters_in_mem_internal_mxint( + node, verilog_param_name, file_name, data_name +): + node_type_info = node.meta["mase"].parameters["common"]["args"][verilog_param_name] + node_verilog_info = node.meta["mase"].parameters["hardware"]["verilog_param"] + + block_size = int(node_verilog_info[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"]) + # how many blocks are processed in parallel + parallelism = int( + node_verilog_info[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"] + ) + + _, _, shape, _ = block( + torch.ones(node_type_info["shape"]), + block_shape=[parallelism, block_size], + ) + + total_size = math.prod(node_type_info["shape"]) + + out_size = block_size * parallelism + out_depth = int(total_size / out_size) + + mantissa_width = int(node_type_info["precision"][0]) + exponent_width = int(node_type_info["precision"][1]) + + node_param_name = f"{vf(node.name)}_{verilog_param_name}" + + rom_str = mxint_template.format( + node_param_name=node_param_name, + date_time=time.strftime("%d/%m/%Y %H:%M:%S"), + e_width=exponent_width, + e_mem_size=out_depth, + m_width=mantissa_width * out_size, + m_mem_size=out_depth, + filename=data_name, + verilog_param_name=_cap(verilog_param_name), + ) + + with open(file_name, "w", encoding="utf-8") as outf: + outf.write(rom_str) + logger.debug( + f"ROM module {verilog_param_name} successfully written into {file_name}" + ) + assert os.path.isfile(file_name), "ROM Verilog generation failed." + # os.system(f"verible-verilog-format --inplace {file_name}") + + +def emit_parameters_in_mem_internal_fixed_point( + node, verilog_param_name, file_name, data_name +): """ Emit single-port ROM hardware components for each parameter (Mostly because Vivado does not support string type parameters...) """ # ! TO DO: currently emitting too many parameters - verilog_param_name = param_name.replace(".", "_") total_size = math.prod( node.meta["mase"].parameters["common"]["args"][verilog_param_name]["shape"] ) @@ -64,12 +137,12 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): // ===================================== // Mase Hardware // Parameter: {node_param_name} -// {time.strftime('%d/%m/%Y %H:%M:%S')} +// {time.strftime("%d/%m/%Y %H:%M:%S")} // ===================================== `timescale 1 ns / 1 ps module {node_param_name}_rom #( - parameter DWIDTH = {out_size*out_width}, + parameter DWIDTH = {out_size * out_width}, parameter MEM_SIZE = {out_depth}, parameter AWIDTH = $clog2(MEM_SIZE) + 1 ) ( @@ -96,7 +169,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): `timescale 1 ns / 1 ps module {node_param_name} #( - parameter DATA_WIDTH = 32'd{out_width*out_size}, + parameter DATA_WIDTH = 32'd{out_width * out_size}, parameter ADDR_RANGE = 32'd{out_depth}, parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 ) ( @@ -188,92 +261,135 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): def emit_parameters_in_dat_internal(node, param_name, file_name): """ - Emit initialised data for the ROM block. Each element must be in 8 HEX digits. + Emit initialized data for the ROM block. + Each element is represented in fixed-width hexadecimal format. """ verilog_param_name = param_name.replace(".", "_") - total_size = math.prod( - node.meta["mase"].parameters["common"]["args"][verilog_param_name]["shape"] - ) - # TO DO: change setting parallelism for weight in metadata - # node.meta["mase"].parameters["hardware"]["verilog_param"][f"{_cap(param_name)}_PARALLELISM_DIM_1"] + mase = node.meta["mase"] + common_args = mase.parameters["common"]["args"][verilog_param_name] + hw_verilog = mase.parameters["hardware"]["verilog_param"] + hw_interface = mase.parameters["hardware"]["interface"][verilog_param_name] + + total_size = math.prod(common_args["shape"]) out_size = int( - node.meta["mase"].parameters["hardware"]["verilog_param"][ - f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0" - ] - * node.meta["mase"].parameters["hardware"]["verilog_param"][ - f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1" - ] + hw_verilog[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"] + * hw_verilog[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"] ) - out_depth = int((total_size + out_size - 1) / out_size) + out_depth = int((total_size + out_size - 1) // out_size) - data_buff = "" - param_data = node.meta["mase"].module.get_parameter(param_name).data - if node.meta["mase"].parameters["hardware"]["interface"][verilog_param_name][ - "transpose" - ]: + param_data = mase.module.get_parameter(param_name).data + if hw_interface["transpose"]: param_data = torch.reshape( param_data, ( - node.meta["mase"].parameters["hardware"]["verilog_param"][ - "DATA_OUT_0_SIZE" - ], - node.meta["mase"].parameters["hardware"]["verilog_param"][ - "DATA_IN_0_DEPTH" - ], - node.meta["mase"].parameters["hardware"]["verilog_param"][ - "DATA_IN_0_SIZE" - ], + hw_verilog["DATA_OUT_0_SIZE"], + hw_verilog["DATA_IN_0_DEPTH"], + hw_verilog["DATA_IN_0_SIZE"], ), ) param_data = torch.transpose(param_data, 0, 1) - param_data = torch.flatten(param_data).tolist() - if ( - node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"] - == "fixed" - ): - width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][ - "precision" - ][0] - frac_width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][ - "precision" - ][1] + match common_args["type"]: + + case "fixed": + param_data = torch.flatten(param_data).tolist() + width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][ + "precision" + ][0] + frac_width = node.meta["mase"].parameters["common"]["args"][ + verilog_param_name + ]["precision"][1] + + if node.meta["mase"].module.config.get("floor", False): + base_quantizer = integer_floor_quantizer_for_hw + else: + base_quantizer = integer_quantizer_for_hw + + data_buff = "" + for i in range(0, out_depth): + line_buff = "" + for j in range(0, out_size): + if i * out_size + out_size - 1 - j >= len(param_data): + value = 0 + else: + value = param_data[i * out_size + out_size - 1 - j] + + # TODO: please clear this up later + value = base_quantizer( + torch.tensor(value), width, frac_width + ).item() + value = str(bin(value)) + value_bits = value[value.find("0b") + 2 :] + value_bits = "0" * (width - len(value_bits)) + value_bits + assert len(value_bits) == width + value_bits = hex(int(value_bits, 2)) + value_bits = value_bits[value_bits.find("0x") + 2 :] + value_bits = "0" * (width // 4 - len(value_bits)) + value_bits + line_buff = value_bits + line_buff + + data_buff += line_buff + "\n" + + with open(file_name, "w", encoding="utf-8") as outf: + outf.write(data_buff) + logger.debug( + f"Init data {param_name} successfully written into {file_name}" + ) - if node.meta["mase"].module.config.get("floor", False): - base_quantizer = integer_floor_quantizer_for_hw - else: - base_quantizer = integer_quantizer_for_hw + case "mxint": + data_width, exponent_width = common_args["precision"] + block_size = [ + hw_verilog[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"], + hw_verilog[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"], + ] - scale = 2**frac_width - thresh = 2**width - for i in range(0, out_depth): - line_buff = "" - for j in range(0, out_size): - if i * out_size + out_size - 1 - j >= len(param_data): - value = 0 - else: - value = param_data[i * out_size + out_size - 1 - j] - - # TODO: please clear this up later - value = base_quantizer(torch.tensor(value), width, frac_width).item() - value = str(bin(value)) - value_bits = value[value.find("0b") + 2 :] - value_bits = "0" * (width - len(value_bits)) + value_bits - assert len(value_bits) == width - value_bits = hex(int(value_bits, 2)) - value_bits = value_bits[value_bits.find("0x") + 2 :] - value_bits = "0" * (width // 4 - len(value_bits)) + value_bits - line_buff = value_bits + line_buff + _quant, mxint_blocks, mxint_exp = block_mxint_quant( + param_data, + {"width": data_width, "exponent_width": exponent_width}, + block_size, + ) - data_buff += line_buff + "\n" - else: - assert False, "Emitting non-fixed parameters is not supported." + data_list = pack_tensor_to_mx_listed_chunk( + mxint_blocks, mxint_exp, block_size + ) - with open(file_name, "w", encoding="utf-8") as outf: - outf.write(data_buff) - logger.debug(f"Init data {param_name} successfully written into {file_name}") - assert os.path.isfile(file_name), "ROM data generation failed." + block_buff = "" + exp_buff = "" + for ms, e in data_list: + line_values = [] + # binary formatting for the mantissa to more easily handle non-power of 2 mantissa sizes + for m in ms: + value = int(m) + # two's complement + mask = 2 ** (data_width - 1) - 1 + value = (value & mask) - (value & ~mask) + # convert to binary string with variable 0 padding + bin_str = f"{value:0{data_width}b}" + line_values.append(bin_str) + block_buff += "_".join(line_values) + "\n" + + # convert to padded hex value + hex_str = f"{int(e):0{(exponent_width // 4) + 1}X}" + exp_buff += hex_str + "\n" + + block_file = file_name + "_block.dat" + exp_file = file_name + "_exp.dat" + + with open(block_file, "w", encoding="utf-8") as outf: + outf.write(block_buff) + with open(exp_file, "w", encoding="utf-8") as outf: + outf.write(exp_buff) + + assert os.path.isfile(block_file), "ROM data generation failed." + logger.debug( + f"Init data {param_name} successfully written into {block_file}" + ) + + assert os.path.isfile(exp_file), "ROM data generation failed." + logger.debug(f"Init data {param_name} successfully written into {exp_file}") + + case _: + raise ValueError("Emitting non-fixed parameters is not supported.") def emit_parameters_in_dat_hls(node, param_name, file_name): @@ -360,9 +476,7 @@ def emit_bram_handshake(node, rtl_dir): verilog_name = os.path.join( rtl_dir, f"{node_name}_{param_verilog_name}_source.sv" ) - data_name = os.path.join( - rtl_dir, f"{node_name}_{param_verilog_name}_rom.dat" - ) + data_name = os.path.join(rtl_dir, f"{node_name}_{param_verilog_name}_rom") emit_parameters_in_mem_internal(node, param_name, verilog_name, data_name) emit_parameters_in_dat_internal(node, param_name, data_name) else: @@ -406,7 +520,7 @@ def emit_parameters_in_mem_hls(node, param_name, file_name, data_name): `timescale 1 ns / 1 ps module {node_param_name}_rom #( - parameter DWIDTH = {out_size*out_width}, + parameter DWIDTH = {out_size * out_width}, parameter MEM_SIZE = {out_depth}, parameter AWIDTH = $clog2(MEM_SIZE) + 1 ) ( @@ -433,7 +547,7 @@ def emit_parameters_in_mem_hls(node, param_name, file_name, data_name): `timescale 1 ns / 1 ps module {node_param_name}_source #( - parameter DATA_WIDTH = 32'd{out_width*out_size}, + parameter DATA_WIDTH = 32'd{out_width * out_size}, parameter ADDR_RANGE = 32'd{out_depth}, parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 ) ( diff --git a/src/chop/passes/graph/transforms/verilog/emit_tb.py b/src/chop/passes/graph/transforms/verilog/emit_tb.py index ad72954f8..c361d1a74 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_tb.py +++ b/src/chop/passes/graph/transforms/verilog/emit_tb.py @@ -1,3 +1,9 @@ +from typing import Dict, Literal, Tuple +from mase_components.linear_layers.mxint_operators.test.utils import ( + block_mxint_quant, + pack_tensor_to_mx_listed_chunk, +) +import numpy as np import logging, torch from pathlib import Path from textwrap import indent @@ -16,13 +22,156 @@ import cocotb from mase_cocotb.testbench import Testbench -from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.interfaces.streaming import ( + MultiSignalStreamDriver, + MultiSignalStreamMonitor, + StreamDriver, + StreamMonitor, +) +from cocotb.result import TestFailure import dill import inspect +class FixedDriver(StreamDriver): + def __init__( + self, clk, data, valid, ready, precision, parallelism, record_num_beats=False + ) -> None: + super().__init__(clk, data, valid, ready, record_num_beats) + self.precision = precision + self.parallelism = parallelism + + def quantize_and_load(self, tensor_batches): + from mase_cocotb.utils import fixed_preprocess_tensor + + in_data_blocks = fixed_preprocess_tensor( + tensor=tensor_batches, + q_config={ + "width": self.precision[0], + "frac_width": self.precision[1], + }, + parallelism=self.parallelism, + ) + + block_size = self.parallelism[0] * self.parallelism[1] + for block in in_data_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.append(block) + + +class MxIntDriver(MultiSignalStreamDriver): + def __init__(self, clk, data, valid, ready, config, parallelism) -> None: + super().__init__(clk, data, valid, ready) + self.config = config + self.parallelism = parallelism + + def quantize_and_load(self, tensor_batches): + (_qtensor, mtensor, etensor) = block_mxint_quant( + tensor_batches, self.config, self.parallelism + ) + driver_input = pack_tensor_to_mx_listed_chunk( + mtensor, etensor, self.parallelism + ) + self.load_driver(driver_input) + + +class FixedMonitor(StreamMonitor): + def __init__( + self, + clk, + data, + valid, + ready, + precision, + parallelism, + check=True, + name=None, + unsigned=False, + ): + super().__init__(clk, data, valid, ready, check, name, unsigned) + self.precision = precision + self.parallelism = parallelism + + def quantize_and_expect(self, tensor_expectation): + from mase_cocotb.utils import fixed_preprocess_tensor + + output_blocks = fixed_preprocess_tensor( + tensor=tensor_expectation, + q_config={ + "width": self.precision[0], + "frac_width": self.precision[1], + }, + parallelism=self.parallelism, + ) + + block_size = self.parallelism[0] * self.parallelism[1] + for block in output_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.expect(block) + self.in_flight = True + + +class MxIntMonitor(MultiSignalStreamMonitor): + def __init__( + self, clk, e_data, m_data, valid, ready, config, parallelism, off_by_value=0 + ): + self.off_by = off_by_value + self.config = config + self.parallelism = parallelism + super().__init__( + clk, + (m_data, e_data), + valid, + ready, + check=True, + signed=True, + off_by_one=False, + ) + + def quantize_and_expect(self, tensor_expectation): + (qtensor, mtensor, etensor) = block_mxint_quant( + tensor_expectation, self.config, self.parallelism + ) + tensor_output = pack_tensor_to_mx_listed_chunk( + mtensor, etensor, self.parallelism + ) + + exp_max_val = 2 ** self.config["exponent_width"] + for i, (tensor, exp) in enumerate(tensor_output): + exp_signed = (2 * exp) % exp_max_val - (exp % exp_max_val) + tensor_output[i] = (tensor, exp_signed) + + self.load_monitor(tensor_output) + self.in_flight = True + + def _check(self, got, exp): + got_m, got_e = got + exp_m, exp_e = exp + + def check_equality(got, exp): + if not np.equal(got, exp).all(): + diff = np.subtract(got, exp) + if np.isclose(got, exp, atol=self.off_by).all(): + self.log.warning( + f"Off-by-{max(abs(diff))} error: {diff=}\nGot {got}\nExp {exp}" + ) + else: + raise TestFailure( + "\nGot \n%s, \nExp \n%s,\nDiff \n%s" % (got, exp, diff) + ) + + if exp_e == got_e: + check_equality(got_m, exp_m) + elif abs(diff := (exp_e - got_e)) == 1: + adj_m = np.array(got_m) * 2 ** (-diff) + self.log.warning(f"Normalisation Error {exp_e=} {got_e=}") + check_equality(adj_m, exp_m) + + def _cap(name): """ capitalize a string @@ -31,11 +180,9 @@ def _cap(name): def _emit_cocotb_test(graph, pass_args={}): - - wait_time = pass_args.get("wait_time", 2) - wait_unit = pass_args.get("wait_units", "ms") - batch_size = pass_args.get("batch_size", 1) - + wait_time = pass_args.get("wait_time", 100) + wait_unit = pass_args.get("wait_units", "us") + num_batches = pass_args.get("num_batches", 1) test_template = f""" import cocotb @@ -51,7 +198,7 @@ async def test(dut): await tb.initialize() - in_tensors = tb.generate_inputs(batches={batch_size}) + in_tensors = tb.generate_inputs(batches={num_batches}) exp_out = tb.model(*list(in_tensors.values())) tb.load_drivers(in_tensors) @@ -71,127 +218,135 @@ class MaseGraphTB(Testbench): def __init__(self, dut, fail_on_checks=True): super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks) - # Instantiate as many drivers as required inputs to the model - self.input_drivers = {} - self.output_monitors = {} - + self.input_drivers: Dict[str, FixedDriver | MxIntDriver] = {} + self.output_monitors: Dict[str, FixedMonitor | MxIntMonitor] = {} for node in graph.nodes_in: - for arg in node.meta["mase"]["common"]["args"].keys(): + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): if "data_in" not in arg: continue - self.input_drivers[arg] = StreamDriver( - dut.clk, - getattr(dut, arg), - getattr(dut, f"{arg}_valid"), - getattr(dut, f"{arg}_ready"), - ) - self.input_drivers[arg].log.setLevel(logging.DEBUG) + match arg_info.get("type", None): + case "mxint": + config = { + "width": self.get_parameter(f"{_cap(arg)}_PRECISION_0"), + "exponent_width": self.get_parameter( + f"{_cap(arg)}_PRECISION_1" + ), + } + parallelism = [ + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"), + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"), + ] + self.input_drivers[arg] = MxIntDriver( + dut.clk, + ( + getattr(dut, f"m_{arg}"), + getattr(dut, f"e_{arg}"), + ), + getattr(dut, f"{arg}_valid"), + getattr(dut, f"{arg}_ready"), + config, + parallelism, + ) + case "fixed": + precision = [ + self.get_parameter(f"{_cap(arg)}_PRECISION_0"), + self.get_parameter(f"{_cap(arg)}_PRECISION_1"), + ] + parallelism = [ + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"), + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"), + ] + self.input_drivers[arg] = FixedDriver( + dut.clk, + getattr(dut, arg), + getattr(dut, f"{arg}_valid"), + getattr(dut, f"{arg}_ready"), + precision, + parallelism, + ) + case t: + raise NotImplementedError( + f"Unsupported type format {t} for {node} {arg}" + ) + self.input_drivers[arg].log.setLevel(logging.INFO) - # Instantiate as many monitors as required outputs for node in graph.nodes_out: - for result in node.meta["mase"]["common"]["results"].keys(): + for result, result_info in node.meta["mase"]["common"][ + "results" + ].items(): if "data_out" not in result: continue - self.output_monitors[result] = StreamMonitor( - dut.clk, - getattr(dut, result), - getattr(dut, f"{result}_valid"), - getattr(dut, f"{result}_ready"), - check=False, - ) - self.output_monitors[result].log.setLevel(logging.DEBUG) + match result_info.get("type", None): + case "mxint": + config = { + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "exponent_width": self.get_parameter( + "DATA_OUT_0_PRECISION_1" + ), + } + parallelism = [ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ] + self.output_monitors[result] = MxIntMonitor( + dut.clk, + getattr(dut, f"e_{result}"), + getattr(dut, f"m_{result}"), + getattr(dut, f"{result}_valid"), + getattr(dut, f"{result}_ready"), + config, + parallelism, + off_by_value=1, + ) + case "fixed": + precision = [ + self.get_parameter("DATA_OUT_0_PRECISION_0"), + self.get_parameter("DATA_OUT_0_PRECISION_1"), + ] + parallelism = [ + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), + ] + self.output_monitors[result] = FixedMonitor( + dut.clk, + getattr(dut, result), + getattr(dut, f"{result}_valid"), + getattr(dut, f"{result}_ready"), + precision, + parallelism, + check=False, + ) + case t: + raise NotImplementedError( + f"Unsupported type format {t} for {node} {result}" + ) + self.output_monitors[result].log.setLevel(logging.INFO) self.model = graph.model - - # To do: precision per input argument self.input_precision = graph.meta["mase"]["common"]["args"]["data_in_0"][ "precision" ] - def generate_inputs(self, batches): - """ - Generate inputs for the model by sampling a random tensor - for each input argument, according to its shape - - :param batches: number of batches to generate for each argument - :type batches: int - :return: a dictionary of input arguments and their corresponding tensors - :rtype: Dict - """ - # ! TO DO: iterate through graph.args instead to generalize + def generate_inputs(self, batches=1): inputs = {} for node in graph.nodes_in: for arg, arg_info in node.meta["mase"]["common"]["args"].items(): - # Batch dimension always set to 1 in metadata if "data_in" not in arg: continue - # print(f"Generating data for node {node}, arg {arg}: {arg_info}") - inputs[f"{arg}"] = torch.rand(([batches] + arg_info["shape"][1:])) + print( + f"Generating data for node {node}, arg {arg}: {arg_info} {arg_info['shape']}" + ) + inputs[f"{arg}"] = torch.randn(([batches] + arg_info["shape"])) return inputs def load_drivers(self, in_tensors): for arg, arg_batches in in_tensors.items(): - # Quantize input tensor according to precision - if len(self.input_precision) > 1: - from mase_cocotb.utils import fixed_preprocess_tensor - - in_data_blocks = fixed_preprocess_tensor( - tensor=arg_batches, - q_config={ - "width": self.get_parameter(f"{_cap(arg)}_PRECISION_0"), - "frac_width": self.get_parameter( - f"{_cap(arg)}_PRECISION_1" - ), - }, - parallelism=[ - self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"), - self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"), - ], - ) - - else: - # TO DO: convert to integer equivalent of floating point representation - pass - - # Append all input blocks to input driver - # ! TO DO: generalize - block_size = self.get_parameter( - "DATA_IN_0_PARALLELISM_DIM_0" - ) * self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1") - for block in in_data_blocks: - if len(block) < block_size: - block = block + [0] * (block_size - len(block)) - self.input_drivers[arg].append(block) + self.input_drivers[arg].quantize_and_load(arg_batches) def load_monitors(self, expectation): - from mase_cocotb.utils import fixed_preprocess_tensor - - # Process the expectation tensor - output_blocks = fixed_preprocess_tensor( - tensor=expectation, - q_config={ - "width": self.get_parameter(f"DATA_OUT_0_PRECISION_0"), - "frac_width": self.get_parameter(f"DATA_OUT_0_PRECISION_1"), - }, - parallelism=[ - self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"), - self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"), - ], - ) - - # Set expectation for each monitor - for block in output_blocks: - # ! TO DO: generalize to multi-output models - if len(block) < self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"): - block = block + [0] * ( - self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0") - len(block) - ) - self.output_monitors["data_out_0"].expect(block) - - # Drive the in-flight flag for each monitor - self.output_monitors["data_out_0"].in_flight = True + for result, monitor in self.output_monitors.items(): + monitor.quantize_and_expect(expectation) - # Serialize testbench object to be instantiated within test by cocotb runner cls_obj = MaseGraphTB tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb" tb_path.mkdir(parents=True, exist_ok=True) diff --git a/src/chop/passes/graph/transforms/verilog/emit_top.py b/src/chop/passes/graph/transforms/verilog/emit_top.py index bedff29f8..85260cb51 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_top.py +++ b/src/chop/passes/graph/transforms/verilog/emit_top.py @@ -1,5 +1,5 @@ import logging -from typing import Tuple, Dict +from typing import Literal, Tuple, Dict import math import os import time @@ -83,6 +83,114 @@ def is_real_input_arg(node, arg_idx): ) +def interface_template( + type: str | None, + var_name: str, + param_name: str, + parallelism_params: list[str], + node_name: str, + direction: Literal["input", "output", "logic"], +): + suffix = ";" if direction == "logic" else "," + match type: + case "fixed": + out = f""" + {direction} [{param_name}_PRECISION_0-1:0] {var_name} [{'*'.join(parallelism_params)}-1:0]{suffix}""" + case "mxint": + out = f""" + {direction} [{param_name}_PRECISION_0-1:0] m_{var_name} [{'*'.join(parallelism_params)}-1:0]{suffix} + {direction} [{param_name}_PRECISION_1-1:0] e_{var_name}{suffix}""" + case None: + raise ValueError( + f"Missing type information for {node_name} {param_name}, {var_name}" + ) + case t: + raise NotImplementedError( + f"Unsupported type format {t} for {node_name} {param_name}, {var_name}" + ) + return ( + out + + f""" + {'logic' if direction == 'logic' else 'input ' if direction == 'input' else 'output'} {var_name}_valid{suffix} + {'logic' if direction == 'logic' else 'output' if direction == 'input' else 'input '} {var_name}_ready{suffix}""" + ) + + +def module_interface_template( + type: str | None, + port_name: str, + signal_name: str, + node_name: str, +): + match type: + case "fixed": + out = f""" + .{port_name}({signal_name}),""" + case "mxint": + out = f""" + .m{port_name}(m_{signal_name}), + .e{port_name}(e_{signal_name}),""" + case None: + raise ValueError( + f"Missing type information for {node_name} {port_name} {signal_name}" + ) + case t: + raise NotImplementedError( + f"Unsupported type format {t} for {node_name} {port_name} {signal_name}" + ) + return ( + out + + f""" + .{port_name}_valid({signal_name}_valid), + .{port_name}_ready({signal_name}_ready),""" + ) + + +def wiring_template( + type: str | None, + from_signal: str, + to_signal: str, + direction: Literal["input", "output"], + node_name: str, +): + match type: + case "fixed": + if direction == "input": + out = f""" +assign {to_signal} = {from_signal};""" + else: + out = f""" +assign {from_signal} = {to_signal};""" + case "mxint": + if direction == "input": + out = f""" +assign m_{to_signal} = m_{from_signal}; +assign e_{to_signal} = e_{from_signal};""" + else: + out = f""" +assign m_{from_signal} = m_{to_signal}; +assign e_{from_signal} = e_{to_signal};""" + case None: + raise ValueError( + f"Missing type information for {node_name} {from_signal} {to_signal}" + ) + case t: + raise NotImplementedError( + f"Unsupported type format {t} for {node_name} {from_signal} {to_signal}" + ) + if direction == "input": + out += f""" +assign {from_signal}_ready = {to_signal}_ready; +assign {to_signal}_valid = {from_signal}_valid; + """ + else: + out += f""" +assign {to_signal}_ready = {from_signal}_ready; +assign {from_signal}_valid = {to_signal}_valid; + """ + return out + + # ============================================================================= # Verilog parameters # ============================================================================= @@ -138,27 +246,33 @@ def emit(self, graph, parameter_map): i = 0 for node in nodes_in: node_name = vf(node.name) - for arg_idx, arg in enumerate( - node.meta["mase"].parameters["common"]["args"].keys() + for arg_idx, (arg, info) in enumerate( + node.meta["mase"].parameters["common"]["args"].items() ): if is_real_input_arg(node, arg_idx): - # if "data_in" in arg: arg_name = _cap(arg) parallelism_params = [ param for param in parameter_map if param.startswith(f"{arg_name}_PARALLELISM_DIM") ] - interface += f""" - input [{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0], - input data_in_{i}_valid, - output data_in_{i}_ready,""" + + interface += interface_template( + info.get("type", None), + var_name=f"data_in_{i}", + param_name=arg_name, + parallelism_params=parallelism_params, + node_name=node_name, + direction="input", + ) i += 1 i = 0 for node in nodes_out: node_name = vf(node.name) - for result in node.meta["mase"].parameters["common"]["results"].keys(): + for result, info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): if "data_out" in result: result_name = _cap(result) parallelism_params = [ @@ -166,10 +280,14 @@ def emit(self, graph, parameter_map): for param in parameter_map if param.startswith(f"{result_name}_PARALLELISM_DIM") ] - interface += f""" - output [{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0], - output data_out_{i}_valid, - input data_out_{i}_ready,""" + interface += interface_template( + info.get("type", None), + var_name=f"data_out_{i}", + param_name=result_name, + parallelism_params=parallelism_params, + node_name=node_name, + direction="output", + ) i += 1 # TODO: emit off-chip parameter interface @@ -212,10 +330,14 @@ def _emit_signals_top_internal(self, node, parameter_map): if node.meta["mase"]["common"]["mase_op"] == "getitem": arg = "data_in_0" - signals += f""" -logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0]; -logic {node_name}_{arg}_valid; -logic {node_name}_{arg}_ready;""" + signals += interface_template( + arg_info.get("type", None), + var_name=f"{node_name}_{arg}", + param_name=f"{node_name}_{arg_name}", + parallelism_params=parallelism_params, + node_name=node_name, + direction="logic", + ) # Output signals for result, result_info in ( @@ -238,10 +360,15 @@ def _emit_signals_top_internal(self, node, parameter_map): for param in parameter_map if f"{node_name}_{result_name}_PARALLELISM_DIM" in param ] - signals += f""" -logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_{result} [{'*'.join(parallelism_params)}-1:0]; -logic {node_name}_{result}_valid; -logic {node_name}_{result}_ready;""" + + signals += interface_template( + result_info.get("type", None), + var_name=f"{node_name}_{result}", + param_name=f"{node_name}_{result_name}", + parallelism_params=parallelism_params, + node_name=node_name, + direction="logic", + ) return signals @@ -339,15 +466,20 @@ def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): parameters += f" .{param}({node_name}_{param}),\n" parameters = _remove_last_comma(parameters) + signals = module_interface_template( + value.get("type", None), + port_name="data_out", + signal_name=f"{node_name}_{key}", + node_name=node_name, + ) + signals = _remove_last_comma(signals) return f""" {component_name} #( {parameters} ) {component_name_inst} ( .clk(clk), .rst(rst), - .data_out({node_name}_{key}), - .data_out_ready({node_name}_{key}_ready), - .data_out_valid({node_name}_{key}_valid) + {signals} ); """ @@ -359,7 +491,7 @@ def _emit_getitem_signals(self, node): """ node_name = vf(node.name) - + # TODO@luigi support mxint here too return f""" .data_in_0 ({node_name}_data_in_0), .data_in_0_valid ({node_name}_data_in_0_valid), @@ -371,6 +503,7 @@ def _emit_getitem_signals(self, node): """ def emit(self, node, parameter_map): + node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] signals = "" @@ -398,19 +531,21 @@ def emit(self, node, parameter_map): for key, value in node.meta["mase"].parameters["common"]["args"].items(): if "inplace" in key or not isinstance(value, dict): continue - signals += f""" - .{key}({node_name}_{key}), - .{key}_valid({node_name}_{key}_valid), - .{key}_ready({node_name}_{key}_ready), - """ + signals += module_interface_template( + value.get("type", None), + port_name=key, + signal_name=f"{node_name}_{key}", + node_name=node_name, + ) # Emit component instantiation output signals for key, value in node.meta["mase"].parameters["common"]["results"].items(): - signals += f""" - .{key}({node_name}_{key}), - .{key}_valid({node_name}_{key}_valid), - .{key}_ready({node_name}_{key}_ready), - """ + signals += module_interface_template( + value.get("type", None), + port_name=key, + signal_name=f"{node_name}_{key}", + node_name=node_name, + ) # Remove final comma in signal list signals = _remove_last_comma(signals) @@ -596,26 +731,32 @@ def _emit_top_wires(self): i = 0 for node in nodes_in: node_name = vf(node.name) - for arg_idx, arg in enumerate( - node.meta["mase"].parameters["common"]["args"].keys() + for arg_idx, (arg, arg_info) in enumerate( + node.meta["mase"].parameters["common"]["args"].items() ): if is_real_input_arg(node, arg_idx): - wires += f""" -assign data_in_{i}_ready = {node_name}_{arg}_ready; -assign {node_name}_{arg}_valid = data_in_{i}_valid; -assign {node_name}_{arg} = data_in_{i}; -""" + wires += wiring_template( + arg_info.get("type", None), + from_signal=f"data_in_{i}", + to_signal=f"{node_name}_{arg}", + node_name=node_name, + direction="input", + ) i += 1 i = 0 for node in nodes_out: node_name = vf(node.name) - for result in node.meta["mase"].parameters["common"]["results"].keys(): + for result, result_info in ( + node.meta["mase"].parameters["common"]["results"].items() + ): if "data_out" in result: - wires += f""" -assign data_out_{i}_valid = {node_name}_{result}_valid; -assign {node_name}_{result}_ready = data_out_{i}_ready; -assign data_out_{i} = {node_name}_{result}; -""" + wires += wiring_template( + result_info.get("type", None), + from_signal=f"data_out_{i}", + to_signal=f"{node_name}_{result}", + node_name=node_name, + direction="output", + ) i += 1 # TODO: emit off-chip parameter interface @@ -658,14 +799,26 @@ def _emit_node2node_wires(self): continue to_name = vf(node.name) - for i, node_in in enumerate(node.all_input_nodes): + to_type = node.meta["mase"]["common"]["args"][f"data_in_{i}"].get( + "type", None + ) + from_type = ( + node_in.meta["mase"] + .parameters["common"]["results"]["data_out_0"] + .get("type", None) + ) + assert ( + to_type == from_type + ), f"Incongruent types {to_type=} {from_type=}" from_name = vf(node_in.name) - wires += f""" -assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready; -assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid; -assign {to_name}_data_in_{i} = {from_name}_data_out_0; -""" + wires += wiring_template( + to_type, + from_signal=f"{from_name}_data_out_0", + to_signal=f"{to_name}_data_in_{i}", + node_name=node.name, + direction="input", + ) return wires def emit(self): @@ -769,6 +922,10 @@ def emit_verilog_top_transform_pass(graph, pass_args={}): top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" init_project(project_dir) rtl_dir = os.path.join(project_dir, "hardware", "rtl") + import shutil + + shutil.rmtree(rtl_dir) + os.mkdir(rtl_dir) top = VerilogEmitter(graph).emit(graph, top_name) diff --git a/src/chop/passes/graph/transforms/verilog/mxint_bram_template.py b/src/chop/passes/graph/transforms/verilog/mxint_bram_template.py new file mode 100644 index 000000000..bf217c311 --- /dev/null +++ b/src/chop/passes/graph/transforms/verilog/mxint_bram_template.py @@ -0,0 +1,123 @@ +mxint_template = """ +// ===================================== +// Mase Hardware +// Parameter: {node_param_name} +// {date_time} +// ===================================== + +`timescale 1 ns / 1 ps +module {node_param_name}_mantissa_rom #( + parameter DWIDTH = {m_width}, + parameter MEM_SIZE = {m_mem_size}, + parameter AWIDTH = $clog2(MEM_SIZE) + 1 +) ( + input clk, + input logic [AWIDTH-1:0] addr0, + input ce0, + output logic [DWIDTH-1:0] q0 +); + + logic [DWIDTH-1:0] ram[0:MEM_SIZE-1]; + logic [DWIDTH-1:0] q0_t0; + logic [DWIDTH-1:0] q0_t1; + + initial begin + $readmemb("{filename}_block.dat", ram); + end + + assign q0 = q0_t1; + + always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0; + always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0]; +endmodule + +`timescale 1 ns / 1 ps +module {node_param_name}_exponent_rom #( + parameter DWIDTH = {e_width}, + parameter MEM_SIZE = {e_mem_size}, + parameter AWIDTH = $clog2(MEM_SIZE) + 1 +) ( + input clk, + input logic [AWIDTH-1:0] addr0, + input ce0, + output logic [DWIDTH-1:0] q0 +); + + logic [DWIDTH-1:0] ram[0:MEM_SIZE-1]; + logic [DWIDTH-1:0] q0_t0; + logic [DWIDTH-1:0] q0_t1; + + initial begin + $readmemh("{filename}_exp.dat", ram); + end + + assign q0 = q0_t1; + + always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0; + always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0]; +endmodule + +`timescale 1ns / 1ps +module {node_param_name}_source #( + parameter {verilog_param_name}_TENSOR_SIZE_DIM_1 = 1, + parameter {verilog_param_name}_TENSOR_SIZE_DIM_0 = 32, + parameter {verilog_param_name}_PRECISION_0 = 16, + parameter {verilog_param_name}_PRECISION_1 = 3, + + parameter {verilog_param_name}_PARALLELISM_DIM_0 = 1, + parameter {verilog_param_name}_PARALLELISM_DIM_1 = 1, + parameter OUT_DEPTH = (({verilog_param_name}_TENSOR_SIZE_DIM_0 + {verilog_param_name}_PARALLELISM_DIM_0 - 1) / {verilog_param_name}_PARALLELISM_DIM_0) + * (({verilog_param_name}_TENSOR_SIZE_DIM_1 + {verilog_param_name}_PARALLELISM_DIM_1 - 1) / {verilog_param_name}_PARALLELISM_DIM_1) +) ( + input clk, + input rst, + + output logic [{verilog_param_name}_PRECISION_0-1:0] mdata_out [{verilog_param_name}_PARALLELISM_DIM_0 * {verilog_param_name}_PARALLELISM_DIM_1-1:0], + output logic [{verilog_param_name}_PRECISION_1-1:0] edata_out, + output data_out_valid, + input data_out_ready +); + // 1-bit wider so IN_DEPTH also fits. + localparam COUNTER_WIDTH = $clog2(OUT_DEPTH); + logic [COUNTER_WIDTH:0] counter; + + always_ff @(posedge clk) + if (rst) counter <= 0; + else begin + if (data_out_ready) begin + if (counter == OUT_DEPTH - 1) counter <= 0; + else counter <= counter + 1; + end + end + + logic [1:0] clear; + always_ff @(posedge clk) + if (rst) clear <= 0; + else if ((data_out_ready == 1) && (clear != 2)) clear <= clear + 1; + logic ce0; + assign ce0 = data_out_ready; + + logic [{verilog_param_name}_PRECISION_0*{verilog_param_name}_PARALLELISM_DIM_0*{verilog_param_name}_PARALLELISM_DIM_1-1:0] data_vector; + {node_param_name}_mantissa_rom #() {node_param_name}_mantissa ( + .clk(clk), + .addr0(counter), + .ce0(ce0), + .q0(data_vector) + ); + + {node_param_name}_exponent_rom #() {node_param_name}_exponent ( + .clk(clk), + .addr0(counter), + .ce0(ce0), + .q0(edata_out) + ); + + // Cocotb/verilator does not support array flattening, so + // we need to manually add some reshaping process. + for (genvar j = 0; j < {verilog_param_name}_PARALLELISM_DIM_0 * {verilog_param_name}_PARALLELISM_DIM_1; j++) + assign mdata_out[j] = data_vector[{verilog_param_name}_PRECISION_0*(j+1)-1:{verilog_param_name}_PRECISION_0*j]; + + assign data_out_valid = clear == 2; + +endmodule +""" diff --git a/src/chop/tools/get_input.py b/src/chop/tools/get_input.py index 710724980..75a925905 100644 --- a/src/chop/tools/get_input.py +++ b/src/chop/tools/get_input.py @@ -107,9 +107,15 @@ def get_dummy_input( dummy_inputs = {"x": x} case _: raise ValueError(f"Task {task} is not supported for {model_info.name}") - elif model_info.is_nerf_model: - # TODO: - pass + elif model_info.is_nerf_model or model_info.is_nerf_vision_model: + item = next(train_iter) + for key in item: + item[key] = item[key].to(device) + dummy_inputs = { + "pts": item["pts"], + "viewdirs": item["viewdirs"], + "targets": item, + } elif model_info.is_nlp_model: match task: diff --git a/src/chop/tools/plt_wrapper/__init__.py b/src/chop/tools/plt_wrapper/__init__.py index 86be5bffe..72fe951fc 100644 --- a/src/chop/tools/plt_wrapper/__init__.py +++ b/src/chop/tools/plt_wrapper/__init__.py @@ -11,7 +11,7 @@ def get_model_wrapper(model_info, task: str): if model_info.is_physical_model: return JetSubstructureModelWrapper - elif model_info.is_nerf_model: + elif model_info.is_nerf_model or model_info.is_nerf_vision_model: return NeRFModelWrapper elif model_info.is_vision_model: return VisionModelWrapper diff --git a/src/chop/tools/plt_wrapper/nerf/losses.py b/src/chop/tools/plt_wrapper/nerf/losses.py index 873f40cd3..b4335e41f 100644 --- a/src/chop/tools/plt_wrapper/nerf/losses.py +++ b/src/chop/tools/plt_wrapper/nerf/losses.py @@ -1,4 +1,8 @@ from torch import nn +import torch +from torchmetrics.aggregation import MeanMetric + +from chop.tools.plt_wrapper.nerf.metrics import psnr class ColorLoss(nn.Module): @@ -8,11 +12,118 @@ def __init__(self, coef=1): self.loss = nn.MSELoss(reduction="mean") def forward(self, inputs, targets): - loss = self.loss(inputs["rgb_coarse"], targets) - if "rgb_fine" in inputs: - loss += self.loss(inputs["rgb_fine"], targets) + output = post_render_vision(targets, inputs) + loss = self.loss(output["rgb"], targets["rgbs"]) return self.coef * loss +class NerfPsnr(MeanMetric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def update(self, inputs, targets) -> None: + output = post_render_vision(targets, inputs) + preds, target = output["rgb"], targets["rgbs"] + psnr_ = psnr(preds, target) + if preds.shape != target.shape: + raise ValueError("preds and target must have the same shape") + + super().update(psnr_) + + +def post_render_vision( + x, + raw, + noise_std=1, +): + def raw2outputs(raw, z_vals, rays_d): + """Transforms model's predictions to semantically meaningful values. + + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + + # Function for computing density from model prediction. This value is + # strictly between [0, 1]. + def raw2alpha(raw, dists, act_fn=torch.nn.functional.relu): + return 1.0 - torch.exp(-act_fn(raw) * dists) + + # Compute 'distance' (in time) between each integration time along a ray. + dists = z_vals[..., 1:] - z_vals[..., :-1] + + # The 'distance' from the last integration time is infinity. + dists = torch.concat( + [ + dists, + torch.ones_like(rays_d[:, :1]).to("cuda") * 1e10, + ], # ISSUE expand to [N_rays, 1] + dim=-1, + ) # [N_rays, N_samples] + + # Multiply each distance by the norm of its corresponding direction ray + # to convert to real world distance (accounts for non-unit directions). + dists = dists * torch.norm(rays_d[..., None, :], dim=-1) + + # Extract RGB of each sample position along each ray. + rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] + + # Add noise to model's predictions for density. Can be used to + # regularize network during training (prevents floater artifacts). + noise = 0.0 + if noise_std > 0.0: + noise = torch.rand_like(raw[..., 3]) * noise_std + + # Predict density of each sample along each ray. Higher values imply + # higher likelihood of being absorbed at this point. + alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] + + # Compute weight for RGB of each sample along each ray. A cumprod() is + # used to express the idea of the ray not having reflected up to this + # sample yet. + # [N_rays, N_samples] + weights = alpha * torch.cumprod(1.0 - alpha + 1e-10, dim=-1) + + # Computed weighted color of each sample along each ray. + rgb_map = torch.sum(weights[..., None] * rgb, dim=-2) # [N_rays, 3] + + # Estimated depth map is expected distance. + depth_map = torch.sum(weights * z_vals, dim=-1) + + # Disparity map is inverse depth. + disp_map = 1.0 / torch.maximum( + torch.tensor([1e-10]).to(dists.device), + depth_map / torch.sum(weights, dim=-1), + ) + + # Sum of weights along each ray. This value is in [0, 1] up to numerical error. + acc_map = torch.sum(weights, dim=-1) + + return rgb_map, disp_map, acc_map, weights, depth_map + + rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs( + raw, x["z_vals"], x["rays_d"] + ) + + results = {} + + results["rgb"] = rgb_map + results["depth"] = depth_map + results["weights"] = weights + results["opacity"] = acc_map + results["z_vals"] = x["z_vals"] + results["disp"] = disp_map + + return results + + loss_dict = {"color": ColorLoss} diff --git a/src/chop/tools/plt_wrapper/nerf/nerf.py b/src/chop/tools/plt_wrapper/nerf/nerf.py index bb21acd43..73e2f9c90 100644 --- a/src/chop/tools/plt_wrapper/nerf/nerf.py +++ b/src/chop/tools/plt_wrapper/nerf/nerf.py @@ -4,7 +4,7 @@ from ..base import WrapperBase -from .losses import loss_dict +from .losses import loss_dict, post_render_vision from .metrics import psnr from .visualization import visualize_depth @@ -36,18 +36,19 @@ def __init__( self.psnr_test = MeanMetric() - def forward(self, rays): + def forward(self, pts, viewdirs): """Do batched inference on rays using chunk.""" - return self.model(rays) + return self.model(pts, viewdirs) def training_step(self, batch, batch_idx): - rays, rgbs = batch["rays"], batch["rgbs"] - results = self(rays) - loss = self.loss(results, rgbs) + pts, viewdirs, rgbs = batch["pts"], batch["viewdirs"], batch["rgbs"] + results = self(pts, viewdirs) + loss = self.loss(results, batch) + results = post_render_vision(batch, results) with torch.no_grad(): - typ = "fine" if "rgb_fine" in results else "coarse" - psnr_ = psnr(results[f"rgb_{typ}"], rgbs) + # typ = "fine" if "rgb_fine" in results else "coarse" + psnr_ = psnr(results[f"rgb"], rgbs) # self.log('lr', get_learning_rate(self.optimizer)) self.log("train/loss", loss) @@ -56,26 +57,24 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - rays, rgbs = batch["rays"], batch["rgbs"] - rays = rays.squeeze() # (H*W, 3) - rgbs = rgbs.squeeze() # (H*W, 3) - results = self(rays) - log = {"val_loss": self.loss(results, rgbs)} - typ = "fine" if "rgb_fine" in results else "coarse" + batch = {k: v.squeeze() for k, v in batch.items()} + pts, viewdirs, rgbs = batch["pts"], batch["viewdirs"], batch["rgbs"] + results = self(pts, viewdirs) + log = {"val_loss": self.loss(results, batch)} + results = post_render_vision(batch, raw=results) + # typ = "fine" if "rgb_fine" in results else "coarse" if batch_idx == 0: W, H = self.hparams.img_wh - img = ( - results[f"rgb_{typ}"].view(H, W, 3).permute(2, 0, 1).cpu() - ) # (3, H, W) + img = results["rgb"].view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) - depth = visualize_depth(results[f"depth_{typ}"].view(H, W)) # (3, H, W) + depth = visualize_depth(results["depth"].view(H, W)) # (3, H, W) stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W) self.logger.experiment.add_images( "val/GT_pred_depth", stack, self.global_step ) - psnr_ = psnr(results[f"rgb_{typ}"], rgbs) + psnr_ = psnr(results["rgb"], rgbs) log["val_psnr"] = psnr_ self.loss_val.update(log["val_loss"]) @@ -99,15 +98,14 @@ def on_validation_epoch_end(self): self.log("val_loss_epoch", loss_epoch, prog_bar=True) def test_step(self, batch, batch_idx): - rays, rgbs = batch["rays"], batch["rgbs"] - rays = rays.squeeze() # (H*W, 3) - rgbs = rgbs.squeeze() # (H*W, 3) - results = self(rays) - loss = self.loss(results, rgbs) + batch = {k: v.squeeze() for k, v in batch.items()} + pts, viewdirs, rgbs = batch["pts"], batch["viewdirs"], batch["rgbs"] + results = self(pts, viewdirs) + loss = self.loss(results, batch) self.loss_test.update(loss) + results = post_render_vision(batch, raw=results) - typ = "fine" if "rgb_fine" in results else "coarse" - psnr_ = psnr(results[f"rgb_{typ}"], rgbs) + psnr_ = psnr(results["rgb"], rgbs) self.psnr_test.update(psnr_) return loss diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py index 10ca26820..44767de3b 100644 --- a/src/mase_cocotb/interfaces/streaming.py +++ b/src/mase_cocotb/interfaces/streaming.py @@ -273,13 +273,17 @@ async def _driver_send(self, data) -> None: class MultiSignalStreamMonitor(Monitor): - def __init__(self, clk, data, valid, ready, check=True): + def __init__( + self, clk, data, valid, ready, check=True, signed=True, off_by_one=False + ): super().__init__(clk) self.clk = clk self.data = data self.valid = valid self.ready = ready self.check = check + self.signed = signed + self.off_by_one = off_by_one def _trigger(self): return self.valid.value == 1 and self.ready.value == 1 @@ -287,9 +291,15 @@ def _trigger(self): def _recv(self): def cast_data(value): if type(value) == list: - return [x.signed_integer for x in value] + if self.signed: + return [x.signed_integer for x in value] + else: + return [x.integer for x in value] elif type(value) == BinaryValue: - return value.signed_integer + if self.signed: + return value.signed_integer + else: + return value.integer return tuple([cast_data(target.value) for target in self.data]) @@ -297,4 +307,12 @@ def _check(self, got, exp): if self.check: for g, e in zip(got, exp): if not np.equal(g, e).all(): - raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp)) + diff = np.subtract(g, e) + if self.off_by_one and np.isclose(g, e, atol=1).all(): + self.log.warning( + f"Off-by-one error: {diff=}\nGot {got}\nExpected {exp}" + ) + else: + raise TestFailure( + "\nGot \n%s, \nExpected \n%s,\nDiff \n%s" % (got, exp, diff) + ) diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py index a06529344..feabde7a6 100644 --- a/src/mase_components/deps.py +++ b/src/mase_components/deps.py @@ -259,6 +259,12 @@ "common", "memory", ], + "linear_layers/mxint_operators/mxint_matrix_cat": [ + "linear_layers/mxint_operators", + "linear_layers/fixed_operators", + "common", + "memory", + ], "linear_layers/mxint_operators/mxint_dot_product": [ "linear_layers/mxint_operators", "linear_layers/fixed_operators", @@ -272,6 +278,13 @@ "memory", "cast", ], + "linear_layers/mxint_operators/mxint_relu": [ + "linear_layers/mxint_operators", + "linear_layers/fixed_operators", + "common", + "memory", + "cast", + ], "linear_layers/mxint_operators/mxint_cast": [ "linear_layers/mxint_operators", "linear_layers/fixed_operators", diff --git a/src/mase_components/linear_layers/fixed_linear_layer/rtl/fixed_linear.sv b/src/mase_components/linear_layers/fixed_linear_layer/rtl/fixed_linear.sv index 1e0edce53..6a3cda7a9 100644 --- a/src/mase_components/linear_layers/fixed_linear_layer/rtl/fixed_linear.sv +++ b/src/mase_components/linear_layers/fixed_linear_layer/rtl/fixed_linear.sv @@ -104,10 +104,10 @@ module fixed_linear #( logic matmul_out_valid; logic matmul_out_ready; - logic [BIAS_PRECISION_0-1:0] bias_buffered [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic [BIAS_PRECISION_0-1:0] bias_buffered[BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; logic bias_buffered_valid, bias_buffered_ready; - logic [MMOUT_PRECISION_0-1:0] bias_casted [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic [MMOUT_PRECISION_0-1:0] bias_casted[BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; logic [MMOUT_PRECISION_0:0] add_bias_in [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; logic [DATA_OUT_0_PRECISION_0 - 1:0] add_bias_in_casted [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; logic add_bias_in_valid; diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv b/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv index f3ed5beac..de1b6ebd8 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv @@ -1,18 +1,18 @@ `timescale 1ns / 1ps + /* Module : log2_max_abs -Description : For any given input, this module will calculate ceil(log2(abs(input + 1e-9))). - The 1e-9 is for hardware convenience, for example, if input = 4, this module will output ceil(log2(abs(4 + 1e-9)) = 3 +Description : Computes floor(log2(abs(input))). + For example, if input = 4, the output will be floor(log2(abs(4))) = 2. */ + module log2_max_abs #( parameter IN_SIZE = 2, parameter IN_WIDTH = 32, parameter OUT_WIDTH = $clog2(IN_WIDTH) + 1 ) ( - /* verilator lint_off UNUSEDSIGNAL */ input logic clk, input logic rst, - /* verilator lint_on UNUSEDSIGNAL */ input logic [ IN_WIDTH-1:0] data_in [IN_SIZE-1:0], input logic data_in_valid, output logic data_in_ready, @@ -20,31 +20,40 @@ module log2_max_abs #( output logic data_out_valid, input logic data_out_ready ); + logic [IN_WIDTH - 1:0] or_result; - logic [IN_WIDTH - 1:0] abs_data_in[IN_SIZE - 1:0]; - for (genvar i = 0; i < IN_SIZE; i++) begin - abs #( - .IN_WIDTH(IN_WIDTH) - ) abs_i ( - .data_in (data_in[i]), - .data_out(abs_data_in[i]) - ); - end + logic [IN_WIDTH - 1:0] abs_data_in[IN_SIZE-1:0]; + + // Compute absolute values + generate + for (genvar i = 0; i < IN_SIZE; i++) begin + abs #( + .IN_WIDTH(IN_WIDTH) + ) abs_i ( + .data_in (data_in[i]), + .data_out(abs_data_in[i]) + ); + end + endgenerate + + // OR-tree to find max absolute value as only the floor(log2) is needed an OR-tree is sufficient or_tree #( .IN_SIZE (IN_SIZE), - .IN_WIDTH(IN_WIDTH), + .IN_WIDTH(IN_WIDTH) ) max_bas_i ( - .clk, - .rst, - .data_in(abs_data_in), - .data_in_valid(data_in_valid), - .data_in_ready(data_in_ready), - .data_out(or_result), + .clk (clk), + .rst (rst), + .data_in (abs_data_in), + .data_in_valid (data_in_valid), + .data_in_ready (data_in_ready), + .data_out (or_result), .data_out_valid(data_out_valid), .data_out_ready(data_out_ready) ); + + // Compute log2 log2_value #( - .IN_WIDTH(IN_WIDTH), + .IN_WIDTH(IN_WIDTH) ) log2_i ( .data_in (or_result), .data_out(data_out) @@ -52,34 +61,48 @@ module log2_max_abs #( endmodule +/* +Module : log2_value +Description : Computes log2 of an input value by finding the index of the highest '1' bit. +*/ + module log2_value #( - /* verilator lint_off UNUSEDPARAM */ parameter IN_WIDTH = 32, - parameter OUT_WIDTH = $clog2(IN_WIDTH) + 1 + parameter OUT_WIDTH = $clog2(IN_WIDTH) ) ( - input logic [IN_WIDTH - 1:0] data_in, // 32-bit input number - output logic [OUT_WIDTH-1:0] data_out // 5-bit output log2 result, since log2(32-bit) is max 31 (5-bit) + input logic [IN_WIDTH - 1:0] data_in, + output logic [ OUT_WIDTH-1:0] data_out ); + integer i; - logic [$clog2(IN_WIDTH) - 1:0] unsigned_log_out; + logic [$clog2(IN_WIDTH)-1:0] unsigned_log_out; + always_comb begin - for (i = IN_WIDTH - 1; i >= 0; i = i - 1) begin - if (data_in[i] == 1) begin - unsigned_log_out = i + 1; + unsigned_log_out = 0; + for (i = IN_WIDTH - 1; i >= 0; i--) begin + if (data_in[i]) begin + unsigned_log_out = i; break; end end end - assign data_out = {1'b0, unsigned_log_out}; + + assign data_out = unsigned_log_out; + endmodule +/* +Module : abs +Description : Computes the absolute value of a signed number. +*/ + module abs #( - parameter IN_WIDTH = 8 // Parameter for bit-width, can be adjusted + parameter IN_WIDTH = 8 ) ( - input wire [IN_WIDTH-1:0] data_in, // N-bit input number - output wire [IN_WIDTH-1:0] data_out // N-bit output representing 2's complement + input wire [IN_WIDTH-1:0] data_in, + output wire [IN_WIDTH-1:0] data_out ); - // 2's complement calculation assign data_out = data_in[IN_WIDTH-1] ? ~data_in + 1 : data_in; + endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv index ccbf0ddf3..c51f17a76 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv @@ -1,9 +1,10 @@ `timescale 1ns / 1ps /* Module : mxint_accumulator -Description : The accumulator for mxint. - When inputing different exponent, the mantissa will cast to the same bitwidth then accumulate. +Description : + - The accumulator is designed to accumulate a block of MxInt values. */ + module mxint_accumulator #( parameter DATA_IN_0_PRECISION_0 = 8, parameter DATA_IN_0_PRECISION_1 = 4, @@ -12,89 +13,147 @@ module mxint_accumulator #( parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + 2 ** DATA_IN_0_PRECISION_1 + $clog2( IN_DEPTH ), - parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + $clog2($clog2(IN_DEPTH) + 1), + localparam COUNTER_WIDTH = $clog2(IN_DEPTH) ) ( input logic clk, input logic rst, + // Input Data input logic [DATA_IN_0_PRECISION_0-1:0] mdata_in_0 [BLOCK_SIZE - 1:0], input logic [DATA_IN_0_PRECISION_1-1:0] edata_in_0, input logic data_in_0_valid, output logic data_in_0_ready, + // Output Data output logic [DATA_OUT_0_PRECISION_0-1:0] mdata_out_0 [BLOCK_SIZE - 1:0], output logic [DATA_OUT_0_PRECISION_1-1:0] edata_out_0, output logic data_out_0_valid, - input logic data_out_0_ready + input logic data_out_0_ready, + output logic [ COUNTER_WIDTH:0] accum_count ); - // 1-bit wider so IN_DEPTH also fits. - localparam COUNTER_WIDTH = $clog2(IN_DEPTH); - logic [COUNTER_WIDTH:0] counter; + localparam RIGHT_PADDING = 2 ** DATA_IN_0_PRECISION_1; + localparam LEFT_PADDING = $clog2(IN_DEPTH); + + localparam EXP_IN_BIAS = 2 ** (DATA_IN_0_PRECISION_1 - 1) - 1; + localparam EXP_OUT_BIAS = 2 ** (DATA_OUT_0_PRECISION_1 - 1) - 1; /* verilator lint_off WIDTH */ - assign data_in_0_ready = (counter != IN_DEPTH) || data_out_0_ready; - assign data_out_0_valid = (counter == IN_DEPTH); + assign data_in_0_ready = (accum_count != IN_DEPTH) || data_out_0_ready; + assign data_out_0_valid = (accum_count == IN_DEPTH); /* verilator lint_on WIDTH */ - // mantissa shift - logic [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_in_0[BLOCK_SIZE - 1:0]; - logic [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_out_0[BLOCK_SIZE - 1:0]; + logic signed [DATA_OUT_0_PRECISION_0 - 1:0] padded_mdata_in_0 [BLOCK_SIZE - 1:0]; + logic signed [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_in_0 [BLOCK_SIZE - 1:0]; + logic signed [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_out_0[BLOCK_SIZE - 1:0]; + logic signed [DATA_OUT_0_PRECISION_0 - 1:0] tmp_accumulator [BLOCK_SIZE - 1:0]; + + logic exponent_increment [BLOCK_SIZE - 1:0]; + + logic [ DATA_IN_0_PRECISION_1 - 1:0] max_exponent_d; + logic [ DATA_IN_0_PRECISION_1 - 1:0] max_exponent_q; - logic no_value_in_register; - logic [DATA_IN_0_PRECISION_1 - 1:0] exp_min; + logic signed [DATA_OUT_0_PRECISION_1 - 1:0] shift; + logic no_reg_value; + + + // ============================= + // Exponent Calculation + // ============================= + assign no_reg_value =(accum_count == 0 || (data_out_0_valid && data_out_0_ready && data_in_0_valid)); + assign max_exponent_d = (max_exponent_q < edata_in_0) ? edata_in_0 : max_exponent_q; + assign shift = max_exponent_q - edata_in_0; - assign no_value_in_register =(counter == 0 || (data_out_0_valid && data_out_0_ready && data_in_0_valid)); - assign exp_min = ($signed(edata_out_0) > $signed(edata_in_0)) ? edata_in_0 : edata_out_0; - // counter always_ff @(posedge clk) - if (rst) counter <= 0; + if (rst) accum_count <= 0; else begin + if (data_out_0_valid) begin + if (data_out_0_ready) begin - if (data_in_0_valid) counter <= 1; - else counter <= 0; + if (data_in_0_valid) accum_count <= 1; + else accum_count <= 0; end - end else if (data_in_0_valid && data_in_0_ready) counter <= counter + 1; + + end else if (data_in_0_valid && data_in_0_ready) accum_count <= accum_count + 1; end - // mantissa - - for (genvar i = 0; i < BLOCK_SIZE; i++) begin : mantissa_block - // mantissa shift - for (genvar j = 0; j < 2 ** DATA_IN_0_PRECISION_1; j++) begin : static_shift - always_comb begin - if (($signed(edata_in_0) - $signed(exp_min)) == j) - shifted_mdata_in_0[i] = no_value_in_register ? $signed( - mdata_in_0[i] - ) : $signed( - mdata_in_0[i] - ) <<< j; - if (($signed(edata_out_0) - $signed(exp_min)) == j) - shifted_mdata_out_0[i] = $signed(mdata_out_0[i]) <<< j; + + // ============================= + // Mantissa Shift and Accumulation + // ============================= + for (genvar i = 0; i < BLOCK_SIZE; i++) begin + + always_comb begin + + padded_mdata_in_0[i] = { + {LEFT_PADDING{mdata_in_0[i][DATA_IN_0_PRECISION_0-1]}}, mdata_in_0[i], {RIGHT_PADDING{1'b0}} + }; + + if (no_reg_value) begin + shifted_mdata_in_0[i] = padded_mdata_in_0[i]; + shifted_mdata_out_0[i] = '0; + end else if (shift > 0) begin + shifted_mdata_in_0[i] = padded_mdata_in_0[i] >>> shift; + shifted_mdata_out_0[i] = mdata_out_0[i]; + end else begin + shifted_mdata_in_0[i] = padded_mdata_in_0[i]; + shifted_mdata_out_0[i] = $signed(mdata_out_0[i]) >>> -shift; end + end - // mantissa out - always_ff @(posedge clk) + + end + + // ============================= + // Mantissa Output Update Logic + // ============================= + + genvar i; + for (i = 0; i < BLOCK_SIZE; i++) begin + always_ff @(posedge clk) begin if (rst) mdata_out_0[i] <= '0; else begin if (data_out_0_valid) begin if (data_out_0_ready) begin + if (data_in_0_valid) mdata_out_0[i] <= shifted_mdata_in_0[i]; else mdata_out_0[i] <= '0; + end end else if (data_in_0_valid && data_in_0_ready) - mdata_out_0[i] <= $signed(shifted_mdata_out_0[i]) + $signed(shifted_mdata_in_0[i]); + mdata_out_0[i] <= shifted_mdata_out_0[i] + shifted_mdata_in_0[i]; end + end end - localparam signed [DATA_IN_0_PRECISION_1 - 1:0] MAXIMUM_EXPONENTIAL = 2**(DATA_IN_0_PRECISION_1 - 1) - 1; - // exponent - always_ff @(posedge clk) - if (rst) edata_out_0 <= MAXIMUM_EXPONENTIAL; - else if (data_out_0_valid) begin + + // ============================= + // Exponent Output Update Logic + // ============================= + + always_ff @(posedge clk) begin + + if (rst) begin + edata_out_0 <= '0; + max_exponent_q <= '0; + + end else if (data_out_0_valid) begin if (data_out_0_ready) begin - if (data_in_0_valid) edata_out_0 <= edata_in_0; - else edata_out_0 <= MAXIMUM_EXPONENTIAL; + + if (data_in_0_valid) begin + edata_out_0 <= edata_in_0 - EXP_IN_BIAS + EXP_OUT_BIAS + LEFT_PADDING; + max_exponent_q <= edata_in_0; + end else begin + edata_out_0 <= '0; + max_exponent_q <= '0; + end + end - end else if (data_in_0_valid && data_in_0_ready) edata_out_0 <= exp_min; + end else if (data_in_0_valid && data_in_0_ready) begin + max_exponent_q <= max_exponent_d; + edata_out_0 <= max_exponent_d - EXP_IN_BIAS + EXP_OUT_BIAS + LEFT_PADDING; + end + + end endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv index c5d02830f..647479928 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv @@ -1,8 +1,11 @@ `timescale 1ns / 1ps + /* -Module : Mxint cast -Description : MxInt Cast between Layers. +Module : MxInt Cast +Description : MxInt Cast between Layers. The input layer does not have to be normalized. + The output layer will be normalized. */ + module mxint_cast #( parameter IN_MAN_WIDTH = 1, parameter IN_EXP_WIDTH = 1, @@ -10,28 +13,44 @@ module mxint_cast #( parameter OUT_EXP_WIDTH = 1, parameter BLOCK_SIZE = 1 ) ( - /* verilator lint_off UNUSEDSIGNAL */ - input logic clk, - input logic rst, - /* verilator lint_on UNUSEDSIGNAL */ - input logic [ IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0], - input logic [ IN_EXP_WIDTH-1:0] edata_in, - input logic data_in_valid, - output logic data_in_ready, + input logic clk, + input logic rst, + + // Input Data + input logic [IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0], + input logic [IN_EXP_WIDTH-1:0] edata_in, + input logic data_in_valid, + output logic data_in_ready, + + // Output Data output logic [OUT_MAN_WIDTH-1:0] mdata_out [BLOCK_SIZE-1:0], output logic [OUT_EXP_WIDTH-1:0] edata_out, output logic data_out_valid, input logic data_out_ready ); - //get max_abs_value of input + + // ============================= + // Check Parameters + // ============================= + + initial begin + assert (IN_EXP_WIDTH > 2) + else $fatal("IN_EXP_WIDTH must be greater than 2"); + assert (IN_MAN_WIDTH > 3) + else $fatal("IN_MAN_WIDTH must be greater than 3"); + assert (OUT_EXP_WIDTH > 2) + else $fatal("OUT_EXP_WIDTH must be greater than 2"); + assert (OUT_MAN_WIDTH > 3) + else $fatal("OUT_MAN_WIDTH must be greater than 3"); + end + + // ============================= + // Internal Signals + // ============================= + logic data_for_max_valid, data_for_max_ready, data_for_out_valid, data_for_out_ready; - split2 #() split_i ( - .data_in_valid (data_in_valid), - .data_in_ready (data_in_ready), - .data_out_valid({data_for_max_valid, data_for_out_valid}), - .data_out_ready({data_for_max_ready, data_for_out_ready}) - ); - logic [IN_MAN_WIDTH-1:0] mbuffer_data_for_out [BLOCK_SIZE-1:0]; + logic signed [IN_MAN_WIDTH-1:0] mbuffer_data_for_out[BLOCK_SIZE-1:0]; + logic [IN_MAN_WIDTH-1:0] fifo_out[BLOCK_SIZE-1:0]; logic [IN_EXP_WIDTH-1:0] ebuffer_data_for_out; logic buffer_data_for_out_valid, buffer_data_for_out_ready; @@ -39,16 +58,36 @@ module mxint_cast #( logic [LOG2_WIDTH - 1:0] log2_max_value; logic log2_max_value_valid, log2_max_value_ready; - localparam EBIAS = 2 ** (OUT_EXP_WIDTH - 1); + localparam EBIAS_OUT = 2 ** (OUT_EXP_WIDTH - 1) - 1; + localparam EBIAS_IN = 2 ** (IN_EXP_WIDTH - 1) - 1; localparam LOSSLESSS_EDATA_WIDTH = max(LOG2_WIDTH, IN_EXP_WIDTH, OUT_EXP_WIDTH) + 2; localparam FIFO_DEPTH = $clog2(BLOCK_SIZE); - logic [LOSSLESSS_EDATA_WIDTH - 1:0] edata_out_full; + logic signed [LOSSLESSS_EDATA_WIDTH - 1:0] edata_out_full; + + localparam MAX_DATA_OUT = 2 ** (OUT_MAN_WIDTH - 1) - 1; + localparam MIN_DATA_OUT = -MAX_DATA_OUT; + + // ============================= + // Handshake Signals + // ============================= + + split2 split_i ( + .data_in_valid (data_in_valid), + .data_in_ready (data_in_ready), + .data_out_valid({data_for_max_valid, data_for_out_valid}), + .data_out_ready({data_for_max_ready, data_for_out_ready}) + ); + + // ============================= + // Compute Log2 Max Value + // ============================= + log2_max_abs #( .IN_SIZE (BLOCK_SIZE), - .IN_WIDTH(IN_MAN_WIDTH), + .IN_WIDTH(IN_MAN_WIDTH) ) max_bas_i ( - .clk, - .rst, + .clk(clk), + .rst(rst), .data_in(mdata_in), .data_in_valid(data_for_max_valid), .data_in_ready(data_for_max_ready), @@ -57,14 +96,24 @@ module mxint_cast #( .data_out_ready(log2_max_value_ready) ); + // ============================= + // FIFO + // ============================= + + if (FIFO_DEPTH == 0) begin + always_comb begin - mbuffer_data_for_out = mdata_in; + for (int i = 0; i < BLOCK_SIZE; i++) begin + mbuffer_data_for_out[i] = $signed(mdata_in[i]); + end ebuffer_data_for_out = edata_in; buffer_data_for_out_valid = data_for_out_valid; data_for_out_ready = buffer_data_for_out_ready; end + end else begin + unpacked_mx_fifo #( .DEPTH(FIFO_DEPTH), .MAN_WIDTH(IN_MAN_WIDTH), @@ -77,57 +126,126 @@ module mxint_cast #( .edata_in(edata_in), .data_in_valid(data_for_out_valid), .data_in_ready(data_for_out_ready), - .mdata_out(mbuffer_data_for_out), + .mdata_out(fifo_out), .edata_out(ebuffer_data_for_out), .data_out_valid(buffer_data_for_out_valid), .data_out_ready(buffer_data_for_out_ready) ); + + always_comb begin + for (int i = 0; i < BLOCK_SIZE; i++) begin + mbuffer_data_for_out[i] = $signed(fifo_out[i]); + end + end + end - join2 #() join_inst ( + + // ============================= + // Handshake Signals + // ============================= + + join2 join_inst ( .data_in_ready ({buffer_data_for_out_ready, log2_max_value_ready}), .data_in_valid ({buffer_data_for_out_valid, log2_max_value_valid}), .data_out_valid(data_out_valid), .data_out_ready(data_out_ready) ); - assign edata_out_full = $signed(log2_max_value) + $signed(ebuffer_data_for_out) - EBIAS; - // clamp - signed_clamp #( - .IN_WIDTH (LOSSLESSS_EDATA_WIDTH), - .OUT_WIDTH(OUT_EXP_WIDTH) - ) exp_clamp ( - .in_data (edata_out_full), - .out_data(edata_out) - ); - localparam SHIFT_WIDTH = max(OUT_EXP_WIDTH, IN_EXP_WIDTH, 0) + 1; - logic [SHIFT_WIDTH - 1:0] shift_value; - assign shift_value = $signed(edata_out) - $signed(ebuffer_data_for_out); - logic [SHIFT_WIDTH - 1:0] abs_shift_value; - assign abs_shift_value = (shift_value[SHIFT_WIDTH-1]) ? (~shift_value + 1) : shift_value; - logic [IN_MAN_WIDTH + EBIAS - 1:0] shift_buffer_data_for_out[BLOCK_SIZE - 1:0]; + // ============================= + // Compute Output Exponent + // ============================= + + assign edata_out_full = log2_max_value - IN_MAN_WIDTH + 2 + ebuffer_data_for_out - EBIAS_IN + EBIAS_OUT; + + always_comb begin + + if (log2_max_value == 0) edata_out = 0; + + else if (edata_out_full >= (1 << OUT_EXP_WIDTH)) edata_out = (1 << OUT_EXP_WIDTH) - 1; + + else if (edata_out_full < 0) edata_out = 0; + + else edata_out = edata_out_full; + + end + + // ============================= + // Compute Shift Value + // ============================= + + localparam SHIFT_WIDTH = max(LOSSLESSS_EDATA_WIDTH, OUT_MAN_WIDTH, 0); + logic signed [SHIFT_WIDTH - 1:0] shift_value; + logic [IN_MAN_WIDTH - 1:0] max_value; + + always_comb begin + + shift_value = $signed(edata_out_full - edata_out) + $signed(OUT_MAN_WIDTH - log2_max_value - 2); + + max_value = (1 << (OUT_MAN_WIDTH - shift_value - 1)); + + if (max_value == 0) max_value = 2 ** (IN_MAN_WIDTH - 1); + + end + + // ============================= + // Compute Output Mantissa + // ============================= + for (genvar i = 0; i < BLOCK_SIZE; i++) begin - for (genvar j = 0; j < 2 ** SHIFT_WIDTH; j++) - always_comb - if (abs_shift_value == j) - shift_buffer_data_for_out[i] = (shift_value[SHIFT_WIDTH-1]) ? $signed( - mbuffer_data_for_out[i] - ) <<< j : $signed( - mbuffer_data_for_out[i] - ) >>> j; - signed_clamp #( - .IN_WIDTH (IN_MAN_WIDTH + EBIAS), - .OUT_WIDTH(OUT_MAN_WIDTH) - ) exp_clamp ( - .in_data (shift_buffer_data_for_out[i]), - .out_data(mdata_out[i]) - ); + + always_comb begin + + if (mbuffer_data_for_out[i] == 0) mdata_out[i] = 0; + + else if ((shift_value > 0) && (shift_value >= OUT_MAN_WIDTH)) + if (mbuffer_data_for_out[i] < 0) mdata_out[i] = MIN_DATA_OUT; + else mdata_out[i] = MAX_DATA_OUT; + + // This is really stupid, but system verilog has poor support for signed arithmetic of large numbers + // So -shift_value != twos_complement(shift_value) hence: + else if ((shift_value < 0) && (twos_complement(shift_value) >= IN_MAN_WIDTH)) + if (mbuffer_data_for_out[i] < 0) mdata_out[i] = -1; + else mdata_out[i] = 0; + + else if ((mbuffer_data_for_out[i] > 0) && (mbuffer_data_for_out[i] >= max_value)) + mdata_out[i] = MAX_DATA_OUT; + + else if ((mbuffer_data_for_out[i] < 0) && (twos_complement( + mbuffer_data_for_out[i] + ) >= max_value)) + mdata_out[i] = MIN_DATA_OUT; + + else if (shift_value >= 0) mdata_out[i] = mbuffer_data_for_out[i] <<< shift_value; + + else mdata_out[i] = mbuffer_data_for_out[i] >>> twos_complement(shift_value); + + end + end + + // ============================= + // two's complement + // ============================= + + localparam TWOS_COMPLEMENT_WIDTH = max(IN_MAN_WIDTH, SHIFT_WIDTH, 0); + + function [TWOS_COMPLEMENT_WIDTH - 1:0] twos_complement; + input [TWOS_COMPLEMENT_WIDTH - 1:0] x; + begin + twos_complement = ~x + 1; + end + endfunction + endmodule + +// ============================= +// Max function +// ============================= + function [31:0] max; input [31:0] x, y, z; begin - if (x > y && x > z) max = x; - else if (y > z) max = y; - else max = z; + max = (x > y && x > z) ? x : (y > z) ? y : z; end endfunction + diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_dot_product.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_dot_product.sv index 8b24e9009..27ba12401 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_dot_product.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_dot_product.sv @@ -31,6 +31,13 @@ module mxint_dot_product #( output data_out_0_valid, input data_out_0_ready ); + // Exponent Biases + localparam DATA_IN_BIAS = 2 ** (DATA_IN_0_PRECISION_1 - 1) - 1; + localparam WEIGHT_BIAS = 2 ** (WEIGHT_PRECISION_1 - 1) - 1; + localparam DATA_OUT_BIAS = 2 ** (DATA_OUT_0_PRECISION_1 - 1) - 1; + + localparam FRAC_POINT_ADJ = DATA_OUT_0_PRECISION_0 - 2 - ((DATA_IN_0_PRECISION_0 - 2) + (WEIGHT_PRECISION_0 - 2)); + logic [DATA_OUT_0_PRECISION_0-1:0] mdp [BLOCK_SIZE-1:0]; logic [DATA_OUT_0_PRECISION_1-1:0] edp; logic mdp_valid, mdp_ready; @@ -115,7 +122,11 @@ module mxint_dot_product #( .empty(), .full() ); - assign edata_out_0 = $signed(buffer_eweight) + $signed(buffer_edata_in_0); + assign edata_out_0 = ($signed( + buffer_eweight - WEIGHT_BIAS + ) + $signed( + buffer_edata_in_0 - DATA_IN_BIAS + )) + DATA_OUT_BIAS + FRAC_POINT_ADJ; fixed_dot_product #( .IN_WIDTH(DATA_IN_0_PRECISION_0), .WEIGHT_WIDTH(WEIGHT_PRECISION_0), diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv index b345241db..5682a211e 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv @@ -76,6 +76,15 @@ module mxint_linear #( output logic data_out_0_valid, input logic data_out_0_ready ); + initial begin + assert (DATA_IN_0_PARALLELISM_DIM_0 == WEIGHT_PARALLELISM_DIM_1); + + assert (WEIGHT_PARALLELISM_DIM_1 == WEIGHT_PARALLELISM_DIM_0); + + assert ((DATA_IN_0_TENSOR_SIZE_DIM_0 % DATA_IN_0_PARALLELISM_DIM_0) == 0); + assert ((DATA_IN_0_TENSOR_SIZE_DIM_1 % DATA_IN_0_PARALLELISM_DIM_1) == 0); + end + logic [DATA_IN_0_PRECISION_0-1:0]circular_mdata_in_0[DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; logic [DATA_IN_0_PRECISION_1-1:0] circular_edata_in_0; logic circular_data_in_0_valid, circular_data_in_0_ready; @@ -150,11 +159,11 @@ module mxint_linear #( localparam FDP_WIDTH = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( DATA_IN_0_PARALLELISM_DIM_0 ); - localparam FDP_EXP_WIDTH = (WEIGHT_PRECISION_1 > DATA_IN_0_PRECISION_1)? WEIGHT_PRECISION_1 + 1: DATA_IN_0_PRECISION_1 + 1; - localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH_DIM_0) + 2 ** FDP_EXP_WIDTH; - localparam ACC_EXP_WIDTH = FDP_EXP_WIDTH; - localparam LOSSLESS_OUT_WIDTH = ACC_WIDTH + HAS_BIAS; - localparam LOSSLESS_OUT_EXP_WIDTH = ACC_EXP_WIDTH; + localparam FDP_EXP_WIDTH_IN =(WEIGHT_PRECISION_1 > DATA_IN_0_PRECISION_1)? WEIGHT_PRECISION_1 + 1: DATA_IN_0_PRECISION_1 + 1; + localparam FDP_EXP_WIDTH_OUT = FDP_EXP_WIDTH_IN + $clog2($clog2(IN_0_DEPTH_DIM_0 + HAS_BIAS) + 1); + localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH_DIM_0 + HAS_BIAS) + 2 ** FDP_EXP_WIDTH_IN; + localparam LOSSLESS_OUT_WIDTH = ACC_WIDTH; + localparam LOSSLESS_OUT_EXP_WIDTH = FDP_EXP_WIDTH_OUT; /* verilator lint_off UNUSEDSIGNAL */ // Assume the parallelised hardware above have the same arrival time // which means that they always have the same state. So we can just @@ -163,15 +172,15 @@ module mxint_linear #( fdp_data_ready, fdp_weight_ready; assign circular_weight_ready = fdp_weight_ready[0]; assign circular_data_in_0_ready = fdp_data_ready[0]; - logic [FDP_EXP_WIDTH-1:0] fdp_edata_out [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0]; + logic [FDP_EXP_WIDTH_IN-1:0] fdp_edata_out [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0]; logic [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0] fdp_data_out_valid; - logic [FDP_WIDTH-1:0] acc_mdata_in [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0]; - logic [FDP_EXP_WIDTH-1:0] acc_edata_in; + logic [FDP_WIDTH-1:0] acc_mdata_in[DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0]; + logic [FDP_EXP_WIDTH_IN-1:0] acc_edata_in; logic acc_data_in_valid, acc_data_in_ready; logic [ ACC_WIDTH-1:0] acc_mdata_out [DATA_IN_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0-1:0]; - logic [FDP_EXP_WIDTH-1:0] acc_edata_out; + logic [FDP_EXP_WIDTH_OUT-1:0] acc_edata_out; logic acc_data_out_valid, acc_data_out_ready; logic [LOSSLESS_OUT_WIDTH-1:0] cast_mdata_out_0[DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0-1:0]; logic [LOSSLESS_OUT_EXP_WIDTH-1:0] cast_edata_out_0; @@ -180,7 +189,7 @@ module mxint_linear #( // and each one computes for IN_0_DEPTH iterations for each inputs. for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_1; i = i + 1) begin : out_dim_1 for (genvar j = 0; j < WEIGHT_PARALLELISM_DIM_1; j = j + 1) begin : out_dim_0 - // Assume the weight are transposed and partitioned + // Assume the weight are transposed and partitioned logic [WEIGHT_PRECISION_0-1:0] current_mweight[WEIGHT_PARALLELISM_DIM_0-1:0]; logic [WEIGHT_PRECISION_1-1:0] current_eweight; logic [DATA_IN_0_PRECISION_0-1:0] current_mdata[WEIGHT_PARALLELISM_DIM_0-1:0]; @@ -218,56 +227,68 @@ module mxint_linear #( assign acc_data_in_valid = fdp_data_out_valid[0]; assign acc_edata_in = fdp_edata_out[0]; + localparam FDP_EXP_BIAS = 2 ** (FDP_EXP_WIDTH_IN - 1) - 1; + localparam BIAS_EXP_BIAS = 2 ** (BIAS_PRECISION_1 - 1) - 1; + + logic [$clog2(IN_0_DEPTH_DIM_0 + HAS_BIAS):0] accum_count; + logic acc_ready; + logic acc_valid; + logic add_bias; + + logic [FDP_WIDTH-1:0] acc_mdata[DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_1 - 1:0]; + logic [FDP_EXP_WIDTH_IN-1:0] acc_edata; + logic acc_data_valid, acc_data_ready; + mxint_accumulator #( .DATA_IN_0_PRECISION_0(FDP_WIDTH), - .DATA_IN_0_PRECISION_1(FDP_EXP_WIDTH), - .IN_DEPTH(IN_0_DEPTH_DIM_0), - .BLOCK_SIZE(DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0) + .DATA_IN_0_PRECISION_1(FDP_EXP_WIDTH_IN), + .IN_DEPTH (IN_0_DEPTH_DIM_0 + HAS_BIAS), + .BLOCK_SIZE (DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0) ) accumulator_inst ( - .clk(clk), - .rst(rst), - .mdata_in_0(acc_mdata_in), - .edata_in_0(acc_edata_in), - .data_in_0_valid(acc_data_in_valid), - .data_in_0_ready(acc_data_in_ready), - .mdata_out_0(acc_mdata_out), - .edata_out_0(acc_edata_out), + .clk (clk), + .rst (rst), + .mdata_in_0 (acc_mdata), + .edata_in_0 (acc_edata), + .data_in_0_valid (acc_valid), + .data_in_0_ready (acc_ready), + .mdata_out_0 (acc_mdata_out), + .edata_out_0 (acc_edata_out), .data_out_0_valid(acc_data_out_valid), - .data_out_0_ready(acc_data_out_ready) + .data_out_0_ready(acc_data_out_ready), + .accum_count (accum_count) ); - - logic [BIAS_PRECISION_0-1:0] mbias_sext[DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0-1:0]; - logic [LOSSLESS_OUT_WIDTH-1:0] shifted_mbias[DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0-1:0]; - logic [FDP_EXP_WIDTH - 1:0] exp_difference; - logic [FDP_EXP_WIDTH - 1:0] abs_shift_value; if (HAS_BIAS) begin : bias_cast - for (genvar k = 0; k < DATA_OUT_0_PARALLELISM_DIM_1; k++) - assign mbias_sext[(k+1)*DATA_OUT_0_PARALLELISM_DIM_0 - 1:k*DATA_OUT_0_PARALLELISM_DIM_0] = circular_mbias; - join2 #() acc_join_inst ( - .data_in_ready ({circular_bias_ready, acc_data_out_ready}), - .data_in_valid ({circular_bias_valid, acc_data_out_valid}), - .data_out_valid(cast_data_out_0_valid), - .data_out_ready(cast_data_out_0_ready) - ); - assign exp_difference = $signed(circular_ebias) - $signed(acc_edata_out); - assign abs_shift_value = exp_difference[FDP_EXP_WIDTH - 1]? (~exp_difference + 1): exp_difference; - for (genvar m = 0; m < DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1; m++) begin - assign shifted_mbias[m] = exp_difference[FDP_EXP_WIDTH-1] ? $signed( - mbias_sext[m] - ) >>> abs_shift_value : $signed( - mbias_sext[m] - ) <<< abs_shift_value; - assign cast_mdata_out_0[m] = $signed(shifted_mbias[m]) + $signed(acc_mdata_out[m]); + + assign add_bias = accum_count == IN_0_DEPTH_DIM_0; + assign circular_bias_ready = add_bias && acc_ready; + assign acc_data_in_ready = !add_bias && acc_ready; + assign acc_valid = add_bias ? circular_bias_valid : acc_data_in_valid; + assign acc_edata = add_bias ? circular_ebias - BIAS_EXP_BIAS + FDP_EXP_BIAS : acc_edata_in; + + for (genvar j = 0; j < DATA_IN_0_PARALLELISM_DIM_1; j++) begin + for (genvar i = 0; i < BIAS_PARALLELISM_DIM_0; i++) begin + assign acc_mdata[j*DATA_IN_0_PARALLELISM_DIM_0+i] = add_bias ? {circular_mbias[i], {(FDP_WIDTH - BIAS_PRECISION_0){1'b0}}} : acc_mdata_in[j*DATA_IN_0_PARALLELISM_DIM_0+i]; + end end - assign cast_edata_out_0 = acc_edata_out; + end else begin - assign acc_data_out_ready = cast_data_out_0_ready; - assign cast_data_out_0_valid = acc_data_out_valid; - assign cast_mdata_out_0 = acc_mdata_out; - assign cast_edata_out_0 = acc_edata_out; - assign bias_ready = 1; + assign acc_mdata = acc_mdata_in; + assign acc_edata = acc_edata_in; + assign acc_data_in_ready = acc_ready; + assign acc_valid = acc_data_in_valid; + end + + always_comb begin + for (int i = 0; i < DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0; i++) begin + cast_mdata_out_0[i] = acc_mdata_out[i]; + end end + + assign acc_data_out_ready = cast_data_out_0_ready; + assign cast_data_out_0_valid = acc_data_out_valid; + assign cast_edata_out_0 = acc_edata_out; + mxint_cast #( .IN_MAN_WIDTH(LOSSLESS_OUT_WIDTH), .IN_EXP_WIDTH(LOSSLESS_OUT_EXP_WIDTH), diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matrix_cat.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matrix_cat.sv new file mode 100644 index 000000000..87b53f153 --- /dev/null +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matrix_cat.sv @@ -0,0 +1,295 @@ +`timescale 1ns / 1ps + +/* +* This module implements torch.cat([t1, t2], dim=-1) +* +* This module assumes that the streaming is always happening in the last +* direciton, i.e the concatenating direction. +* +* E.g to concatenate the following 2 x 2 matricies: +* +* 1 2 5 6 --> 1 2 5 6 +* 3 4 7 8 --> 3 4 7 8 +* +* The data is expected to be streamed left to right in blocks. So for a block +* size of 1, this module expects 1 2 3 4, and 5 6 7 8 on each input interface. +* The output would be 1 2 5 6 3 4 7 8, hence concatenating the two matricies. +* +* The limitation of this module is that it requires the PARALLELISM parameters +* of the input and output interfaces to be the same. This mainly a limitation +* of mxint_cast. Can later remove this by adding a flow controller on the +* input of the module. The flow controller would ensure that the concatenation +* internally happens using uniform PARALLELISM. This is not a trivial +*/ + +module mxint_matrix_cat #( + //---------------------------------------------------// + //------------- Software -----------------// + //---------------------------------------------------// + + parameter DATA_IN_0_PRECISION_0 = 1, + parameter DATA_IN_0_PRECISION_1 = 1, + + /* verilator lint_off UNUSEDPARAM */ + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + /* verilator lint_on UNUSEDPARAM */ + + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + + parameter DATA_IN_1_PRECISION_0 = 1, + parameter DATA_IN_1_PRECISION_1 = 1, + parameter DATA_IN_1_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_IN_1_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_0 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_1 = 1, + + parameter DATA_OUT_0_PRECISION_0 = 1, + parameter DATA_OUT_0_PRECISION_1 = 1, + + //---------------------------------------------------// + //------------- Hardware Aliases --------------// + //---------------------------------------------------// + + localparam CONST_DIM = DATA_IN_0_PARALLELISM_DIM_0, + /* verilator lint_off UNUSEDPARAM */ + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = CONST_DIM, + /* verilator lint_on UNUSEDPARAM */ + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 + DATA_IN_1_TENSOR_SIZE_DIM_1, + + localparam DATA_OUT_0_PARALLELISM_DIM_0 = CONST_DIM, + localparam DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + + localparam BLOCK_SIZE = CONST_DIM * DATA_IN_0_PARALLELISM_DIM_1, + localparam CONCAT_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1, + localparam CONCAT_DIM_1 = DATA_IN_1_TENSOR_SIZE_DIM_1 / DATA_IN_1_PARALLELISM_DIM_1, + + localparam PARALLELISM_0 = DATA_IN_0_PARALLELISM_DIM_0, + localparam PARALLELISM_1 = DATA_IN_0_PARALLELISM_DIM_1, + + localparam MWIDTH_IN_0 = DATA_IN_0_PRECISION_0, + localparam EWIDTH_IN_0 = DATA_IN_0_PRECISION_1, + + localparam MWIDTH_IN_1 = DATA_IN_1_PRECISION_0, + localparam EWIDTH_IN_1 = DATA_IN_1_PRECISION_1, + + localparam MWIDTH_OUT = DATA_OUT_0_PRECISION_0, + localparam EWIDTH_OUT = DATA_OUT_0_PRECISION_1, + localparam FIFO_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_1 > DATA_IN_0_TENSOR_SIZE_DIM_1 ? DATA_IN_0_TENSOR_SIZE_DIM_1 : DATA_IN_0_TENSOR_SIZE_DIM_1 +) ( + input wire clk, + input wire rst, + + input logic [MWIDTH_IN_0-1:0] mdata_in_0 [BLOCK_SIZE-1:0], + input wire [EWIDTH_IN_0-1:0] edata_in_0, + input wire data_in_0_valid, + output logic data_in_0_ready, + + input logic [MWIDTH_IN_1-1:0] mdata_in_1 [BLOCK_SIZE-1:0], + input wire [EWIDTH_IN_1-1:0] edata_in_1, + input wire data_in_1_valid, + output logic data_in_1_ready, + + output logic [ MWIDTH_OUT-1:0] mdata_out_0 [BLOCK_SIZE-1:0], + output logic [EWIDTH_IN_0-1:0] edata_out_0, + output logic data_out_0_valid, + input logic data_out_0_ready +); + + function void driveDataOut(logic [MWIDTH_OUT-1:0] mdata[BLOCK_SIZE-1:0], + logic [EWIDTH_OUT-1:0] edata); + for (int i = 0; i < BLOCK_SIZE; i++) begin + assign mdata_out_0[i] = mdata[i]; + end + + assign edata_out_0 = edata; + + endfunction + + localparam COUNTER_WIDTH = $clog2(DATA_OUT_0_TENSOR_SIZE_DIM_1) + 1; + + initial begin + assert (DATA_IN_0_PARALLELISM_DIM_1 == DATA_IN_1_PARALLELISM_DIM_1) + else $error("PARALLELISM Parameters of matrix_cat should all be equal"); + assert (DATA_IN_1_PARALLELISM_DIM_1 == DATA_OUT_0_PARALLELISM_DIM_1) + else $error("PARALLELISM Parameters of matrix_cat should all be equal"); + + assert (DATA_IN_0_PARALLELISM_DIM_0 == DATA_IN_1_PARALLELISM_DIM_0) + else $error("PARALLELISM Parameters of matrix_cat should all be equal"); + assert (DATA_IN_1_PARALLELISM_DIM_0 == DATA_OUT_0_PARALLELISM_DIM_0) + else $error("PARALLELISM Parameters of matrix_cat should all be equal"); + end + + typedef enum integer { + OUT_0, + OUT_1 + } matrix_cat_state_enum; + + matrix_cat_state_enum state_b; + matrix_cat_state_enum state_r; + + logic [COUNTER_WIDTH-1:0] out_cntr_b; + logic [COUNTER_WIDTH-1:0] out_cntr_r; + + logic [ MWIDTH_OUT-1:0] mdata_in_0_c [BLOCK_SIZE-1:0]; + logic [ EWIDTH_OUT-1:0] edata_in_0_c; + logic shift_in_0; + logic fifo_0_full; + + logic [ MWIDTH_OUT-1:0] mdata_in_1_c [BLOCK_SIZE-1:0]; + logic [ EWIDTH_OUT-1:0] edata_in_1_c; + logic shift_in_1; + logic fifo_1_full; + + logic [ MWIDTH_OUT-1:0] mdata_in_1_fifo [BLOCK_SIZE-1:0]; + logic [ EWIDTH_OUT-1:0] edata_in_1_fifo; + logic fifo_1_valid; + logic fifo_1_ready; + + logic [ MWIDTH_OUT-1:0] mdata_in_0_fifo [BLOCK_SIZE-1:0]; + logic [ EWIDTH_OUT-1:0] edata_in_0_fifo; + logic fifo_0_valid; + logic fifo_0_ready; + + if ((DATA_OUT_0_PRECISION_0 == DATA_IN_0_PRECISION_0) && (DATA_OUT_0_PRECISION_1 == DATA_IN_0_PRECISION_1)) + begin : no_cast_in_0_gen + always_comb begin + for (int i = 0; i < BLOCK_SIZE; i++) begin + mdata_in_0_c[i] = mdata_in_0[i]; + end + shift_in_0 = data_in_0_valid; + data_in_0_ready = fifo_0_full; + edata_in_0_c = edata_in_0; + end + end else begin : cast_in_0_gen + mxint_cast #( + .IN_MAN_WIDTH (MWIDTH_IN_0), + .IN_EXP_WIDTH (EWIDTH_IN_0), + .OUT_MAN_WIDTH(MWIDTH_OUT), + .OUT_EXP_WIDTH(EWIDTH_OUT), + .BLOCK_SIZE (BLOCK_SIZE) + ) cast_I ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_in_0), + .edata_in (edata_in_0), + .data_in_valid (data_in_0_valid), + .data_in_ready (data_in_0_ready), + .mdata_out (mdata_in_0_c), + .edata_out (edata_in_0_c), + .data_out_valid(shift_in_0), + .data_out_ready(fifo_0_full) + ); + end + + if ((DATA_OUT_0_PRECISION_0 == DATA_IN_1_PRECISION_0) && (DATA_OUT_0_PRECISION_1 == DATA_IN_1_PRECISION_1)) + begin : no_cast_in_1_gen + always_comb begin + for (int i = 0; i < BLOCK_SIZE; i++) begin + mdata_in_1_c[i] = mdata_in_1[i]; + end + shift_in_1 = data_in_1_valid; + data_in_1_ready = fifo_1_full; + edata_in_1_c = edata_in_1; + end + end else begin : cast_in_1_gen + mxint_cast #( + .IN_MAN_WIDTH (MWIDTH_IN_1), + .IN_EXP_WIDTH (EWIDTH_IN_1), + .OUT_MAN_WIDTH(MWIDTH_OUT), + .OUT_EXP_WIDTH(EWIDTH_OUT), + .BLOCK_SIZE (BLOCK_SIZE) + ) cast_I ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_in_1), + .edata_in (edata_in_1), + .data_in_valid (data_in_1_valid), + .data_in_ready (data_in_1_ready), + .mdata_out (mdata_in_1_c), + .edata_out (edata_in_1_c), + .data_out_valid(shift_in_1), + .data_out_ready(fifo_1_full) + ); + end + + /* verilator lint_off PINMISSING */ + unpacked_mx_fifo #( + .DEPTH (FIFO_DEPTH), + .MAN_WIDTH(MWIDTH_OUT), + .EXP_WIDTH(EWIDTH_OUT), + .IN_SIZE (BLOCK_SIZE) + ) fifo_0_I ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_in_0_c), + .edata_in (edata_in_0_c), + .data_in_valid (shift_in_0), + .data_in_ready (fifo_0_full), + .mdata_out (mdata_in_0_fifo), + .edata_out (edata_in_0_fifo), + .data_out_valid(fifo_0_valid), + .data_out_ready(fifo_0_ready) + ); + + unpacked_mx_fifo #( + .DEPTH (FIFO_DEPTH), + .MAN_WIDTH(MWIDTH_OUT), + .EXP_WIDTH(EWIDTH_OUT), + .IN_SIZE (BLOCK_SIZE) + ) fifo_1_I ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_in_1_c), + .edata_in (edata_in_1_c), + .data_in_valid (shift_in_1), + .data_in_ready (fifo_1_full), + .mdata_out (mdata_in_1_fifo), + .edata_out (edata_in_1_fifo), + .data_out_valid(fifo_1_valid), + .data_out_ready(fifo_1_ready) + ); + /* verilator lint_on PINMISSING */ + + always_comb begin + state_b = state_r; + out_cntr_b = out_cntr_r; + fifo_0_ready = fifo_0_valid && data_out_0_ready && (out_cntr_b < CONCAT_DIM_0) && (state_r == OUT_0); + fifo_1_ready = fifo_1_valid && data_out_0_ready && (out_cntr_b < CONCAT_DIM_1) && (state_r == OUT_1); + + case (state_r) + OUT_0: begin + if (fifo_0_ready) begin + out_cntr_b = out_cntr_r + 1; + driveDataOut(mdata_in_0_fifo, edata_in_0_fifo); + end else if (out_cntr_b >= CONCAT_DIM_0) begin + out_cntr_b = 0; + state_b = OUT_1; + end + end + OUT_1: begin + if (fifo_1_ready) begin + out_cntr_b = out_cntr_r + 1; + driveDataOut(mdata_in_1_fifo, edata_in_1_fifo); + end else if (out_cntr_b >= CONCAT_DIM_1) begin + out_cntr_b = 0; + state_b = OUT_0; + end + end + endcase + end + + always_ff @(posedge clk) begin + if (rst) begin + state_r <= OUT_0; + out_cntr_r <= '0; + end else begin + state_r <= state_b; + out_cntr_r <= out_cntr_b; + end + end + + assign data_out_0_valid = ((state_r == OUT_0) && fifo_0_ready) || ((state_r == OUT_1) && fifo_1_ready); + +endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_relu.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_relu.sv new file mode 100644 index 000000000..ddcc8989e --- /dev/null +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_relu.sv @@ -0,0 +1,97 @@ +/* +Module : mxint_relu +Description : This module performs relu(x) on the input function + + Python equivalent: + out = torch.nn.functional.relu(x) + + x should be the dimension of (DATA_IN_0_TENSOR_SIZE_DIM_1, DATA_IN_0_TENSOR_SIZE_DIM_0) +*/ +`timescale 1ns / 1ps + +module mxint_relu #( + + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 20, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 20, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + + /* verilator lint_off UNUSEDPARAM */ + parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0, + parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + /* verilator lint_on UNUSEDPARAM */ + + localparam IN_0_DEPTH_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + localparam IN_0_DEPTH_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1 +) ( + input clk, + input rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] mdata_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic [DATA_IN_0_PRECISION_1-1:0] edata_in_0, + input logic data_in_0_valid, + output logic data_in_0_ready, + + output logic [DATA_IN_0_PRECISION_0-1:0] mdata_out_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + output logic [DATA_IN_0_PRECISION_1-1:0] edata_out_0, + output logic data_out_0_valid, + input logic data_out_0_ready + +); + // mega jank to have this work with the mase emission + initial begin + assert (DATA_IN_0_PRECISION_0 == DATA_OUT_0_PRECISION_0) + else $error("ReLU: DATA_IN_0_PRECISION_0 must be equal to DATA_OUT_0_PRECISION_0"); + assert (DATA_IN_0_PRECISION_1 == DATA_OUT_0_PRECISION_1) + else $error("ReLU: DATA_IN_0_PRECISION_1 must be equal to DATA_OUT_0_PRECISION_1"); + assert (DATA_IN_0_TENSOR_SIZE_DIM_0 == DATA_OUT_0_TENSOR_SIZE_DIM_0) + else $error("ReLU: DATA_IN_0_TENSOR_SIZE_DIM_0 must be equal to DATA_OUT_0_TENSOR_SIZE_DIM_0"); + assert (DATA_IN_0_TENSOR_SIZE_DIM_1 == DATA_OUT_0_TENSOR_SIZE_DIM_1) + else $error("ReLU: DATA_IN_0_TENSOR_SIZE_DIM_1 must be equal to DATA_OUT_0_TENSOR_SIZE_DIM_1"); + assert (DATA_IN_0_PARALLELISM_DIM_0 == DATA_OUT_0_PARALLELISM_DIM_0) + else $error("ReLU: DATA_IN_0_PARALLELISM_DIM_0 must be equal to DATA_OUT_0_PARALLELISM_DIM_0"); + assert (DATA_IN_0_PARALLELISM_DIM_1 == DATA_OUT_0_PARALLELISM_DIM_1) + else $error("ReLU: DATA_IN_0_PARALLELISM_DIM_1 must be equal to DATA_OUT_0_PARALLELISM_DIM_0"); + end + + logic [DATA_IN_0_PRECISION_0-1:0] mdata_out_0_i [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0]; + logic [DATA_IN_0_PRECISION_1-1:0] edata_out_0_i; + logic data_out_0_valid_i; + logic data_out_0_ready_i; + + always_comb begin + edata_out_0_i = edata_in_0; + + for (int i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++) begin + mdata_out_0_i[i] = mdata_in_0[i][DATA_IN_0_PRECISION_0-1] ? '0 : mdata_in_0[i]; + end + data_out_0_valid_i = data_in_0_valid; + end + + mxint_cast #( + .IN_MAN_WIDTH (DATA_IN_0_PRECISION_0), + .IN_EXP_WIDTH (DATA_IN_0_PRECISION_1), + .OUT_MAN_WIDTH(DATA_IN_0_PRECISION_0), + .OUT_EXP_WIDTH(DATA_IN_0_PRECISION_1), + .BLOCK_SIZE (DATA_IN_0_PARALLELISM_DIM_1 * DATA_IN_0_PARALLELISM_DIM_0) + ) cast_i ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_out_0_i), + .edata_in (edata_out_0_i), + .data_in_valid (data_out_0_valid_i), + .data_in_ready (data_in_0_ready), + .mdata_out (mdata_out_0), + .edata_out (edata_out_0), + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) + ); + +endmodule + diff --git a/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py b/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py index e66d2e0f2..da42aee39 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py @@ -38,7 +38,6 @@ def __init__(self, dut, samples=10): def sw_compute(self): ref = [] for i in range(self.samples): - breakpoint() ref.append( math.ceil(math.log2(max([abs(data) for data in self.inputs.data[i]]))) ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py index aaa16a70f..98e2e6e63 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py @@ -1,7 +1,10 @@ #!/usr/bin/env python3 # This script tests the fixed point linear -import os, logging +import os +import logging +import sys +import pytest import cocotb from cocotb.log import SimLog @@ -13,16 +16,17 @@ MultiSignalStreamMonitor, ) -from mase_cocotb.runner import mase_runner from utils import mxint_quantize +from mase_cocotb.runner import mase_runner import torch from math import ceil, log2 import random from mase_cocotb.utils import bit_driver + logger = logging.getLogger("testbench") -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.WARNING) torch.manual_seed(10) @@ -46,44 +50,66 @@ def __init__(self, dut, num=1) -> None: dut.data_out_0_valid, dut.data_out_0_ready, check=True, + signed=False, ) def generate_inputs(self): - from utils import block_mxint_quant, pack_tensor_to_mx_listed_chunk - from utils import mxint_quantize - from math import ceil, log2 - - data_in = 20 * torch.rand( + data_in = 20 * torch.randn( self.get_parameter("IN_DEPTH"), self.get_parameter("BLOCK_SIZE") ) - config = { + config_in = { "width": self.get_parameter("DATA_IN_0_PRECISION_0"), "exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"), } - parallelism = [1, self.get_parameter("BLOCK_SIZE")] - (qtensor, mtensor, etensor) = block_mxint_quant(data_in, config, parallelism) - - qout, mout, eout = mxint_quantize( - qtensor.sum(dim=0), - width=config["width"] - + 2 ** config["exponent_width"] - + ceil(log2(self.get_parameter("IN_DEPTH"))), - exponent_width=config["exponent_width"], - exponent=int(etensor.min()), - ) - tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism) - exp_outs = [(mout.int().tolist(), int(eout))] + config_out = { + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "exponent_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + } + + (data_in_q, tensor_out) = [], [] + for d in data_in: + d_q, m_d, e_d = mxint_quantize(d, **config_in) + data_in_q.append(d_q.tolist()) + tensor_out.append((m_d.int().tolist(), int(e_d))) + + # kinda jank but have to model the hardware exactly since accumulator doesn't normalize + left_padding = ceil(log2(self.get_parameter("IN_DEPTH"))) + right_padding = config_out["width"] - config_in["width"] - left_padding + # logical model of what the block is doing + # accumulate and adjust the exponent + mant_out, exp_max = tensor_out[0] + mant_out = torch.tensor(mant_out, dtype=torch.float64) + for mant, exp in tensor_out[1:]: + mant = torch.tensor(mant, dtype=torch.float64) + logger.debug( + torch.remainder(mant_out * 2**right_padding, (2 ** config_out["width"])) + ) + if exp > exp_max: + mant_out = (mant_out * (2 ** (exp_max - exp))) + mant + exp_max = exp + else: + mant_out += mant * (2 ** (exp - exp_max)) + logger.debug( + torch.remainder(mant_out * 2**right_padding, (2 ** config_out["width"])) + ) + m_data_out = torch.remainder( + mant_out * 2 ** (right_padding), (2 ** config_out["width"]) + ) + e_bias_in = 2 ** (config_in["exponent_width"] - 1) - 1 + e_bias_out = 2 ** (config_out["exponent_width"] - 1) - 1 + e_data_out = exp_max - e_bias_in + e_bias_out + left_padding - return tensor_inputs, exp_outs + return tensor_out, [(m_data_out.to(torch.int64).tolist(), int(e_data_out))] async def run_test(self, samples, us): await self.reset() - logger.info(f"Reset finished") + logger.info("Reset finished") self.data_out_0_monitor.ready.value = 1 for _ in range(samples): - logger.info(f"generating inputs") + logger.info("generating inputs") inputs, exp_outputs = self.generate_inputs() + logger.info(f"{inputs} {exp_outputs}") # Load the inputs driver self.data_in_0_driver.load_driver(inputs) @@ -115,21 +141,43 @@ async def run_test(self, samples, us): @cocotb.test() async def repeated_mult_valid_backpressure(dut): - tb = MXIntAccumulatorTB(dut, 1) + tb = MXIntAccumulatorTB(dut, 10) tb.data_in_0_driver.set_valid_prob(0.7) cocotb.start_soon(bit_driver(dut.data_out_0_ready, dut.clk, 0.6)) - await tb.run_test(samples=20, us=200) + num_samples = 100 + await tb.run_test( + samples=num_samples, us=0.05 * num_samples * tb.get_parameter("IN_DEPTH") + ) + + +def get_config(seed): + random.seed(seed) + return { + "DATA_IN_0_PRECISION_0": random.randint(2, 16), + "DATA_IN_0_PRECISION_1": random.randint(2, 4), + "BLOCK_SIZE": random.randint(2, 16), + "IN_DEPTH": random.randint(1, 100), + } + + +@pytest.mark.dev +def run_random_tests(): + torch.manual_seed(10) + seed = os.getenv("COCOTB_SEED") + + if seed is not None: + seed = int(seed) + mase_runner(trace=True, module_param_list=[get_config(seed)]) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=40)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[get_config(base_seed + i) for i in range(num_configs)], + jobs=min(num_configs, 10), + ) + print(f"Test seeds: \n{[(i, base_seed + i) for i in range(num_configs)]}") if __name__ == "__main__": - mase_runner( - trace=True, - module_param_list=[ - { - "DATA_IN_0_PRECISION_0": 8, - "DATA_IN_0_PRECISION_1": 4, - "BLOCK_SIZE": 1, - "IN_DEPTH": 1, - }, - ], - ) + run_random_tests() diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py index 963ae40df..1b07c1b41 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py @@ -1,7 +1,11 @@ #!/usr/bin/env python3 # This script tests the fixed point linear -import os, logging +import os, logging, pytest + +import random +import sys +from mase_cocotb.utils import bit_driver import cocotb from cocotb.log import SimLog @@ -41,28 +45,58 @@ def __init__(self, dut, num) -> None: dut.data_out_valid, dut.data_out_ready, check=True, + off_by_one=True, ) + self.data_in_0_driver.log.setLevel(logging.DEBUG) self.data_out_0_monitor.log.setLevel(logging.DEBUG) + def generate_outputs(self, mdata_in, edata_in): + block_size = int(self.get_parameter("BLOCK_SIZE")) + mantissa_width_in = int(self.get_parameter("IN_MAN_WIDTH")) + exponent_width_in = int(self.get_parameter("IN_EXP_WIDTH")) + mantissa_width_out = int(self.get_parameter("OUT_MAN_WIDTH")) + exponent_width_out = int(self.get_parameter("OUT_EXP_WIDTH")) + + data_in = torch.tensor(mdata_in, dtype=torch.float) * ( + 2 ** (2 - mantissa_width_in + edata_in - 2 ** (exponent_width_in - 1) + 1) + ) + data_out, mdata_out, edata_out = mxint_quantize( + data_in, mantissa_width_out, exponent_width_out + ) + + mdata_out = mdata_out.int().tolist() + edata_out = edata_out.int().item() + + bit_mask = 2 ** (exponent_width_out - 1) - 1 + edata_out = (edata_out & bit_mask) - (edata_out & ~bit_mask) + + return mdata_out, edata_out + def generate_inputs(self): + + block_size = int(self.get_parameter("BLOCK_SIZE")) + mantissa_width = int(self.get_parameter("IN_MAN_WIDTH")) + exponent_width = int(self.get_parameter("IN_EXP_WIDTH")) + inputs = [] exp_outputs = [] + for _ in range(self.num): - data = 20 * torch.rand(int(self.dut.BLOCK_SIZE)) - (data_in, mdata_in, edata_in) = mxint_quantize( - data, - int(self.dut.IN_MAN_WIDTH), - int(self.dut.IN_EXP_WIDTH), - ) - exp_out, mexp_out, eexp_out = mxint_quantize( - data_in, - int(self.dut.OUT_MAN_WIDTH), - int(self.dut.OUT_EXP_WIDTH), - ) - breakpoint() - inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) - exp_outputs.append((mexp_out.int().tolist(), eexp_out.int().tolist())) + mdata_in = [ + random.randint( + -(2 ** (mantissa_width - 1)) + 1, 2 ** (mantissa_width - 1) - 1 + ) + for _ in range(block_size) + ] + edata_in = random.randint(0, 2**exponent_width - 1) + + inputs.append((mdata_in, edata_in)) + + mdata_out, edata_out = self.generate_outputs(mdata_in, edata_in) + + exp_outputs.append((mdata_out, edata_out)) + return inputs, exp_outputs async def run_test(self): @@ -73,53 +107,82 @@ async def run_test(self): logger.info(f"generating inputs") inputs, exp_outputs = self.generate_inputs() - # Load the inputs driver self.data_in_0_driver.load_driver(inputs) - # Load the output monitor self.data_out_0_monitor.load_monitor(exp_outputs) await Timer(5, units="us") assert self.data_out_0_monitor.exp_queue.empty() +def get_mxint_cast_config_random(seed, kwargs={}): + random.seed(seed) + + BLOCK_SIZE = random.randint(1, 16) + + MAX_MANTISSA = 32 + MAX_EXPONENT = 6 + + in_man = random.randint(3, MAX_MANTISSA) + in_exp = random.randint(2, MAX_EXPONENT) + + out_man = random.randint(3, MAX_MANTISSA) + out_exp = random.randint(2, MAX_EXPONENT) + + config = { + "IN_MAN_WIDTH": in_man, + "IN_EXP_WIDTH": in_exp, + "OUT_MAN_WIDTH": out_man, + "OUT_EXP_WIDTH": out_exp, + "BLOCK_SIZE": BLOCK_SIZE, + } + + config.update(kwargs) + return config + + +@pytest.mark.dev +def test_mxint_cast_random(): + """ + Fully randomized parameter testing. + """ + torch.manual_seed(10) + seed = os.getenv("COCOTB_SEED") + + # use this to fix a particular parameter value + param_override = { + # "IN_MAN_WIDTH": 32, + # "IN_EXP_WIDTH": 6, + # "OUT_MAN_WIDTH": 6, + # "OUT_EXP_WIDTH": 3, + # "BLOCK_SIZE": 2, + } + + if seed is not None: + seed = int(seed) + mase_runner( + trace=True, + module_param_list=[get_mxint_cast_config_random(seed, param_override)], + ) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=1)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[ + get_mxint_cast_config_random(base_seed + i, param_override) + for i in range(num_configs) + ], + jobs=min(num_configs, os.cpu_count() // 2), + ) + print(f"Test seeds: \n{[(i,base_seed+i) for i in range(num_configs)]}") + + @cocotb.test() async def test(dut): - tb = MXINTVectorMultTB(dut, num=1) + tb = MXINTVectorMultTB(dut, num=100) await tb.run_test() if __name__ == "__main__": - mase_runner( - trace=True, - module_param_list=[ - # { - # "IN_MAN_WIDTH": 6, - # "IN_EXP_WIDTH": 3, - # "OUT_MAN_WIDTH": 12, - # "OUT_EXP_WIDTH": 4, - # "BLOCK_SIZE": 4, - # }, - # { - # "IN_MAN_WIDTH": 8, - # "IN_EXP_WIDTH": 3, - # "OUT_MAN_WIDTH": 8, - # "OUT_EXP_WIDTH": 3, - # "BLOCK_SIZE": 4, - # }, - { - "IN_MAN_WIDTH": 8, - "IN_EXP_WIDTH": 4, - "OUT_MAN_WIDTH": 49, - "OUT_EXP_WIDTH": 5, - "BLOCK_SIZE": 4, - }, - # { - # "IN_MAN_WIDTH": 12, - # "IN_EXP_WIDTH": 3, - # "OUT_MAN_WIDTH": 8, - # "OUT_EXP_WIDTH": 4, - # "BLOCK_SIZE": 4, - # }, - ], - ) + test_mxint_cast_random() diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py index 43e70adcf..4048fe838 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py @@ -2,6 +2,7 @@ # This script tests the fixed point linear import os, logging +import sys import cocotb from cocotb.log import SimLog @@ -19,12 +20,11 @@ import torch from math import ceil, log2 import random +import math logger = logging.getLogger("testbench") logger.setLevel(logging.DEBUG) -torch.manual_seed(10) - class MXIntDotProductTB(Testbench): def __init__(self, dut, num) -> None: @@ -52,30 +52,63 @@ def __init__(self, dut, num) -> None: dut.data_out_0_valid, dut.data_out_0_ready, check=True, + signed=False, ) def generate_inputs(self): inputs = [] weights = [] exp_outputs = [] + + ebias_data = (2 ** (self.dut.DATA_IN_0_PRECISION_1.value - 1)) - 1 + ebias_weight = (2 ** (self.dut.WEIGHT_PRECISION_1.value - 1)) - 1 + ebias_out = (2 ** (self.dut.DATA_OUT_0_PRECISION_1.value - 1)) - 1 + + w_man_w = self.dut.WEIGHT_PRECISION_0.value + in_man_w = self.dut.DATA_IN_0_PRECISION_0.value + out_man_w = self.dut.DATA_OUT_0_PRECISION_0.value + + fractional_point_adjustment = out_man_w - 2 - ((w_man_w - 2) + (in_man_w - 2)) + for _ in range(self.num): - data = torch.rand(int(self.dut.BLOCK_SIZE)) + data = torch.randn(int(self.dut.BLOCK_SIZE)) (data_in, mdata_in, edata_in) = mxint_quantize( data, int(self.dut.DATA_IN_0_PRECISION_0), int(self.dut.DATA_IN_0_PRECISION_1), ) - w = torch.rand(int(self.dut.BLOCK_SIZE)) + w = torch.randn(int(self.dut.BLOCK_SIZE)) (weight, mweight, eweight) = mxint_quantize( w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1), ) - mdp_out = mdata_in @ mweight - edp_out = edata_in + eweight + + # compute the mantissa + mdp_out = 0 + for d, w in zip(mdata_in, mweight): + mdp_out += math.floor(int(d) * int(w)) + # take the mod since the monitor comparison is unsigned + mdp_out_unsigned = int(mdp_out) % (2**out_man_w) + # adjust the exponent by the biases of the different widths + edp_out = (edata_in - ebias_data) + (eweight - ebias_weight) + ebias_out + # compute the result based on the exponent and mantissa found above + out_manual = (mdp_out * 2 ** (-(w_man_w + in_man_w - 4))) * ( + 2 ** (edp_out - ebias_out) + ) + + # compute the quantized output value in "full" precision + out_q = data_in @ weight + # check that the relative error doesn't exceed some amount + assert ( + abs(out_q - out_manual) < 1e-6 + ), "Something went wrong when calculating the expected mantissa and exponents" + inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) weights.append((mweight.int().tolist(), eweight.int().tolist())) - exp_outputs.append((mdp_out.int().tolist(), edp_out.int().tolist())) + exp_outputs.append( + (mdp_out_unsigned, int(edp_out + fractional_point_adjustment)) + ) print(inputs) print(weights) print(exp_outputs) @@ -89,6 +122,8 @@ async def run_test(self): logger.info(f"generating inputs") inputs, weights, exp_outputs = self.generate_inputs() + # self.log.info(f"inputs: {inputs}\n{}") + # Load the inputs driver self.data_in_0_driver.load_driver(inputs) self.weight_driver.load_driver(weights) @@ -102,20 +137,34 @@ async def run_test(self): @cocotb.test() async def test(dut): - tb = MXIntDotProductTB(dut, num=20) + tb = MXIntDotProductTB(dut, num=50) await tb.run_test() +def get_config(seed): + random.seed(seed) + return { + "DATA_IN_0_PRECISION_0": random.randint(2, 16), + "DATA_IN_0_PRECISION_1": random.randint(2, 16), + "WEIGHT_PRECISION_0": random.randint(2, 16), + "WEIGHT_PRECISION_1": random.randint(2, 16), + "BLOCK_SIZE": random.randint(2, 16), + } + + if __name__ == "__main__": - mase_runner( - trace=True, - module_param_list=[ - { - "DATA_IN_0_PRECISION_0": 8, - "DATA_IN_0_PRECISION_1": 4, - "WEIGHT_PRECISION_0": 7, - "WEIGHT_PRECISION_1": 4, - "BLOCK_SIZE": 4, - }, - ], - ) + torch.manual_seed(10) + seed = os.getenv("COCOTB_SEED") + + if seed is not None: + seed = int(seed) + mase_runner(trace=True, module_param_list=[get_config(seed)]) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=40)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[get_config(base_seed + i) for i in range(num_configs)], + jobs=10, + ) + print(f"Test seeds: \n{[(i,base_seed+i) for i in range(num_configs)]}") diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py index 036c5a2e3..2859388ed 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +import random +import sys +from mase_cocotb.utils import bit_driver import os, pytest import torch @@ -23,7 +26,9 @@ class LinearTB(Testbench): - def __init__(self, dut) -> None: + def __init__( + self, dut, data_in_p=1.0, weight_p=1.0, bias_p=1.0, data_out_p=1.0 + ) -> None: super().__init__(dut, dut.clk, dut.rst) if not hasattr(self, "log"): @@ -36,15 +41,19 @@ def __init__(self, dut) -> None: dut.data_in_0_valid, dut.data_in_0_ready, ) + self.data_in_0_driver.set_valid_prob(data_in_p) + self.weight_driver = MultiSignalStreamDriver( dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready ) + self.weight_driver.set_valid_prob(weight_p) if self.get_parameter("HAS_BIAS") == 1: self.bias_driver = MultiSignalStreamDriver( dut.clk, (dut.mbias, dut.ebias), dut.bias_valid, dut.bias_ready ) self.bias_driver.log.setLevel(logging.DEBUG) + self.bias_driver.set_valid_prob(bias_p) self.data_out_0_monitor = MultiSignalStreamMonitor( dut.clk, @@ -52,7 +61,10 @@ def __init__(self, dut) -> None: dut.data_out_0_valid, dut.data_out_0_ready, check=True, + signed=False, + off_by_one=True, ) + cocotb.start_soon(bit_driver(dut.data_out_0_ready, dut.clk, data_out_p)) # Model self.model = MXIntLinear( @@ -101,6 +113,11 @@ def preprocess_tensor_for_mxint(self, tensor, config, parallelism): from utils import pack_tensor_to_mx_listed_chunk (qtensor, mtensor, etensor) = block_mxint_quant(tensor, config, parallelism) + # take the mod to get the unsigned representation of the number + mtensor = mtensor.remainder(2 ** config["width"]) + self.log.info(f"Mantissa Tensor: {mtensor}") + self.log.info(f"Exponent Tensor: {etensor}") + tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism) return tensor_inputs @@ -122,7 +139,7 @@ async def run_test(self, us): # * Load the inputs driver self.log.info(f"Processing inputs: {inputs}") - inputs = self.preprocess_tensor_for_mxint( + inputs_quant = self.preprocess_tensor_for_mxint( tensor=inputs, config={ "width": self.get_parameter("DATA_IN_0_PRECISION_0"), @@ -133,12 +150,12 @@ async def run_test(self, us): self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), ], ) - self.data_in_0_driver.load_driver(inputs) - + self.data_in_0_driver.load_driver(inputs_quant) + self.log.info(f"processed inputs\n{inputs_quant}") # * Load the weights driver weights = self.model.weight - self.log.info(f"Processing weights: {weights}") + self.log.info(f"Processing weights:\n{weights}") weights = self.preprocess_tensor_for_mxint( tensor=weights, config={ @@ -150,6 +167,7 @@ async def run_test(self, us): self.get_parameter("WEIGHT_PARALLELISM_DIM_0"), ], ) + self.log.info(f"Processed weights:\n{weights}") self.weight_driver.load_driver(weights) # * Load the bias driver @@ -182,7 +200,6 @@ async def run_test(self, us): self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), ], ) - breakpoint() self.data_out_0_monitor.load_monitor(outs) await Timer(us, units="us") @@ -191,52 +208,77 @@ async def run_test(self, us): @cocotb.test() async def cocotb_test(dut): - tb = LinearTB(dut) - await tb.run_test(us=100) - - -def get_fixed_linear_config(kwargs={}): - # if pretranspose - # weight1 = in0 - # else - # weight0 = in0 - # currently, we only consider the transposed situation - # config = { - # "HAS_BIAS": 1, - # "DATA_IN_0_TENSOR_SIZE_DIM_0": 2, - # "DATA_IN_0_TENSOR_SIZE_DIM_1": 2, - # "DATA_IN_0_PARALLELISM_DIM_0": 2, - # "DATA_IN_0_PARALLELISM_DIM_1": 1, - # "WEIGHT_TENSOR_SIZE_DIM_0": 2, - # "WEIGHT_TENSOR_SIZE_DIM_1": 2, - # "WEIGHT_PARALLELISM_DIM_0": 2, - # "WEIGHT_PARALLELISM_DIM_1": 1, - # "DATA_IN_0_PRECISION_0": 8, - # "DATA_IN_0_PRECISION_1": 4, - # "WEIGHT_PRECISION_0": 8, - # "WEIGHT_PRECISION_1": 4, - # "BIAS_PRECISION_0": 8, - # "BIAS_PRECISION_1": 4, - # "DATA_OUT_0_PRECISION_0": 10, - # "DATA_OUT_0_PRECISION_1": 4, - # } + probs = [torch.rand(1).item() for _ in range(4)] + # probs = [1 for _ in range(4)] + tb = LinearTB( + dut, + data_in_p=probs[0], + weight_p=probs[1], + bias_p=probs[2], + data_out_p=probs[3], + ) + await tb.run_test(us=200) + + +def get_mxint_linear_config_random(seed, kwargs={}): + MAX_IN_FEATURES = 16 + MAX_OUT_FEATURES = 16 + MAX_BATCH_SIZE = 8 + random.seed(seed) + + BLOCK_SIZE = random.randint(2, 8) + PARALLELISM = random.randint(1, 8) + BATCH_SIZE = random.randint(1, MAX_BATCH_SIZE // PARALLELISM) * PARALLELISM + IN_FEATURES = random.randint(2, MAX_IN_FEATURES // BLOCK_SIZE) * BLOCK_SIZE + OUT_FEATURES = random.randint(2, MAX_OUT_FEATURES // BLOCK_SIZE) * BLOCK_SIZE + + MAX_MANTISSA = 16 + MAX_EXPONENT = 6 + + mantissas = [random.randint(4, MAX_MANTISSA)] * 4 + exps = [random.randint(3, min(mantissas[0], MAX_EXPONENT))] * 4 + + config = { + "HAS_BIAS": random.randint(0, 1), + "DATA_IN_0_TENSOR_SIZE_DIM_0": IN_FEATURES, + "DATA_IN_0_TENSOR_SIZE_DIM_1": BATCH_SIZE, + "DATA_IN_0_PARALLELISM_DIM_0": BLOCK_SIZE, + "DATA_IN_0_PARALLELISM_DIM_1": PARALLELISM, + "WEIGHT_TENSOR_SIZE_DIM_0": IN_FEATURES, + "WEIGHT_TENSOR_SIZE_DIM_1": OUT_FEATURES, + "WEIGHT_PARALLELISM_DIM_0": BLOCK_SIZE, + "WEIGHT_PARALLELISM_DIM_1": BLOCK_SIZE, + "DATA_IN_0_PRECISION_0": mantissas[0], + "DATA_IN_0_PRECISION_1": exps[0], + "WEIGHT_PRECISION_0": mantissas[1], + "WEIGHT_PRECISION_1": exps[1], + "BIAS_PRECISION_0": mantissas[2], + "BIAS_PRECISION_1": exps[2], + "DATA_OUT_0_PRECISION_0": mantissas[3], + "DATA_OUT_0_PRECISION_1": exps[3], + } + config.update(kwargs) + return config + + +def get_mxint_linear_config(kwargs={}): config = { "HAS_BIAS": 1, - "DATA_IN_0_TENSOR_SIZE_DIM_0": 32, - "DATA_IN_0_TENSOR_SIZE_DIM_1": 16, - "DATA_IN_0_PARALLELISM_DIM_0": 4, - "DATA_IN_0_PARALLELISM_DIM_1": 4, - "WEIGHT_TENSOR_SIZE_DIM_0": 32, - "WEIGHT_TENSOR_SIZE_DIM_1": 16, - "WEIGHT_PARALLELISM_DIM_0": 4, - "WEIGHT_PARALLELISM_DIM_1": 4, - "DATA_IN_0_PRECISION_0": 9, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 2, + "DATA_IN_0_TENSOR_SIZE_DIM_1": 2, + "DATA_IN_0_PARALLELISM_DIM_0": 2, + "DATA_IN_0_PARALLELISM_DIM_1": 1, + "WEIGHT_TENSOR_SIZE_DIM_0": 2, + "WEIGHT_TENSOR_SIZE_DIM_1": 2, + "WEIGHT_PARALLELISM_DIM_0": 2, + "WEIGHT_PARALLELISM_DIM_1": 1, + "DATA_IN_0_PRECISION_0": 8, "DATA_IN_0_PRECISION_1": 4, "WEIGHT_PRECISION_0": 8, - "WEIGHT_PRECISION_1": 3, + "WEIGHT_PRECISION_1": 4, "BIAS_PRECISION_0": 8, "BIAS_PRECISION_1": 4, - "DATA_OUT_0_PRECISION_0": 12, + "DATA_OUT_0_PRECISION_0": 10, "DATA_OUT_0_PRECISION_1": 4, } config.update(kwargs) @@ -244,59 +286,55 @@ def get_fixed_linear_config(kwargs={}): @pytest.mark.dev -def test_fixed_linear_smoke(): +def test_mxint_linear_full_random(): """ - Some quick tests to check if the module is working. + Fully randomized parameter testing. """ - mase_runner( - trace=True, - extra_build_args=["--trace-depth", "8"], - module_param_list=[ - get_fixed_linear_config(), - # noticed here if change WEIGHT_PRE_TRANSPOSED also need to change the DIM_SIZE to match ACTIVATION - # get_fixed_linear_config( - # { - # "WEIGHT_TENSOR_SIZE_DIM_0": 32, - # "WEIGHT_TENSOR_SIZE_DIM_1": 16, - # "WEIGHT_PARALLELISM_DIM_0": 4, - # "WEIGHT_PARALLELISM_DIM_1": 2, - # }, - # ), - ], - ) + torch.manual_seed(10) + seed = os.getenv("COCOTB_SEED") + + # use this to fix a particular parameter value + param_override = { + "HAS_BIAS": 1, + } + + if seed is not None: + seed = int(seed) + mase_runner( + trace=True, + module_param_list=[get_mxint_linear_config_random(seed, param_override)], + ) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=5)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[ + get_mxint_linear_config_random(base_seed + i, param_override) + for i in range(num_configs) + ], + jobs=min(num_configs, 10), + ) + print(f"Test seeds: \n{[(i,base_seed+i) for i in range(num_configs)]}") @pytest.mark.dev -def test_fixed_linear_regression(): - """ - More extensive tests to check realistic parameter sizes. - """ +def test_mxint_linear(): mase_runner( trace=True, module_param_list=[ - get_fixed_linear_config( - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 768, - "DATA_IN_0_PARALLELISM_DIM_0": 32, - "WEIGHT_TENSOR_SIZE_DIM_0": 768, - "WEIGHT_TENSOR_SIZE_DIM_1": 768, - "WEIGHT_PARALLELISM_DIM_0": 32, - "WEIGHT_PARALLELISM_DIM_1": 32, - "BIAS_TENSOR_SIZE_DIM_0": 768, - "BIAS_PARALLELISM_DIM_0": 32, - } - ), - get_fixed_linear_config( + get_mxint_linear_config( { "HAS_BIAS": 1, - "DATA_IN_0_TENSOR_SIZE_DIM_0": 768, - "DATA_IN_0_PARALLELISM_DIM_0": 32, - "WEIGHT_TENSOR_SIZE_DIM_0": 768, - "WEIGHT_TENSOR_SIZE_DIM_1": 768, - "WEIGHT_PARALLELISM_DIM_0": 32, - "WEIGHT_PARALLELISM_DIM_1": 32, - "BIAS_TENSOR_SIZE_DIM_0": 768, - "BIAS_PARALLELISM_DIM_0": 32, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 4, + "DATA_IN_0_TENSOR_SIZE_DIM_1": 10, + "DATA_IN_0_PARALLELISM_DIM_0": 2, + "WEIGHT_TENSOR_SIZE_DIM_0": 4, + "WEIGHT_TENSOR_SIZE_DIM_1": 4, + "WEIGHT_PARALLELISM_DIM_0": 2, + "WEIGHT_PARALLELISM_DIM_1": 2, + "BIAS_TENSOR_SIZE_DIM_0": 4, + "BIAS_PARALLELISM_DIM_0": 2, } ), ], @@ -304,5 +342,5 @@ def test_fixed_linear_regression(): if __name__ == "__main__": - test_fixed_linear_smoke() - # test_fixed_linear_regression() + # test_mxint_linear() + test_mxint_linear_full_random() diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matrix_cat_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matrix_cat_tb.py new file mode 100644 index 000000000..df2b03fd1 --- /dev/null +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matrix_cat_tb.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 + +# This script tests the fixed point linear +import os, logging, pytest + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import * + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import ( + MultiSignalStreamDriver, + MultiSignalStreamMonitor, +) + +from mase_cocotb.runner import mase_runner +from utils import mxint_quantize, block_mxint_quant, pack_tensor_to_mx_listed_chunk +import sys + +from chop.nn.quantizers.mxint import _mxint_quantize + +import torch +from math import ceil, log2 +import random + +logger = logging.getLogger("testbench") +logger.setLevel(logging.DEBUG) + +torch.manual_seed(10) + + +class MXIntMatrixCat(Testbench): + def __init__(self, dut, num) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.num = num + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + + self.data_in_0_driver = MultiSignalStreamDriver( + dut.clk, + (dut.mdata_in_0, dut.edata_in_0), + dut.data_in_0_valid, + dut.data_in_0_ready, + ) + + self.data_in_1_driver = MultiSignalStreamDriver( + dut.clk, + (dut.mdata_in_1, dut.edata_in_1), + dut.data_in_1_valid, + dut.data_in_1_ready, + ) + + self.data_out_0_monitor = MultiSignalStreamMonitor( + dut.clk, + (dut.mdata_out_0, dut.edata_out_0), + dut.data_out_0_valid, + dut.data_out_0_ready, + check=True, + ) + self.data_in_0_driver.log.setLevel(logging.DEBUG) + self.data_in_1_driver.log.setLevel(logging.DEBUG) + self.data_out_0_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self): + + din1 = [] + din2 = [] + exp_outputs = [] + + for _ in range(self.num): + + d0 = torch.rand( + int(self.dut.DATA_IN_0_TENSOR_SIZE_DIM_0), + int(self.dut.DATA_IN_0_TENSOR_SIZE_DIM_1), + ) + + d1 = torch.rand( + int(self.dut.DATA_IN_1_TENSOR_SIZE_DIM_0), + int(self.dut.DATA_IN_1_TENSOR_SIZE_DIM_1), + ) + + (data_in_0, mdata_in_0, edata_in_0) = block_mxint_quant( + d0, + { + "width": int(self.dut.DATA_IN_0_PRECISION_0), + "exponent_width": int(self.dut.DATA_IN_0_PRECISION_1), + }, + [ + int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), + int(self.dut.DATA_IN_0_PARALLELISM_DIM_1), + ], + ) + + (data_in_1, mdata_in_1, edata_in_1) = block_mxint_quant( + d1, + { + "width": int(self.dut.DATA_IN_1_PRECISION_0), + "exponent_width": int(self.dut.DATA_IN_1_PRECISION_1), + }, + [ + int(self.dut.DATA_IN_1_PARALLELISM_DIM_0), + int(self.dut.DATA_IN_1_PARALLELISM_DIM_1), + ], + ) + + din1.append( + pack_tensor_to_mx_listed_chunk( + mdata_in_0, + edata_in_0, + [ + int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), + int(self.dut.DATA_IN_0_PARALLELISM_DIM_1), + ], + ) + ) + + din2.append( + pack_tensor_to_mx_listed_chunk( + mdata_in_1, + edata_in_1, + [ + int(self.dut.DATA_IN_1_PARALLELISM_DIM_0), + int(self.dut.DATA_IN_1_PARALLELISM_DIM_1), + ], + ) + ) + + mdp = torch.cat([d0, d1], dim=-1) + + (data_out_0, mdata_out_0, edata_out_0) = block_mxint_quant( + mdp, + { + "width": int(self.dut.DATA_OUT_0_PRECISION_0), + "exponent_width": int(self.dut.DATA_OUT_0_PRECISION_1), + }, + [ + int(self.dut.DATA_OUT_0_PARALLELISM_DIM_0), + int(self.dut.DATA_OUT_0_PARALLELISM_DIM_1), + ], + ) + + exp_outputs.append( + pack_tensor_to_mx_listed_chunk( + mdata_out_0, + edata_out_0, + [ + int(self.dut.DATA_OUT_0_PARALLELISM_DIM_0), + int(self.dut.DATA_OUT_0_PARALLELISM_DIM_1), + ], + ) + ) + + print(f"Din 1: \n {din1}") + print(f"Din 2: \n {din2}") + print(f"Dout: \n {exp_outputs}") + + return din1, din2, exp_outputs + + async def run_test(self): + await self.reset() + + for i in range(0, self.num): + + logger.info(f"Reset finished") + + self.data_out_0_monitor.ready.value = 1 + + logger.info(f"generating inputs") + inputs, weights, exp_outputs = self.generate_inputs() + + # Load the inputs driver + self.data_in_0_driver.load_driver(inputs[i]) + self.data_in_1_driver.load_driver(weights[i]) + # # Load the output monitor + self.data_out_0_monitor.load_monitor(exp_outputs[i]) + + await Timer(100, units="us") + assert self.data_out_0_monitor.exp_queue.empty() + + +@cocotb.test() +async def test(dut): + tb = MXIntMatrixCat(dut, num=1) + await tb.run_test() + + +def get_config(seed: int): + random.seed(seed) + + MAX_MANTISSA = 16 + MAX_EXPONENT = 6 + + cat_dim = random.randint(4, 16) + + dim_din_0 = random.randint(4, 16) + dim_din_1 = random.randint(1, 10) * dim_din_0 + + factors = [i for i in range(1, cat_dim + 1) if cat_dim % i == 0] + parallelism_c = random.choice(factors) + + factors = [i for i in range(1, dim_din_0 + 1) if dim_din_0 % i == 0] + parallelism_0 = random.choice(factors) + + prec_0_0 = random.randint(4, 16) + prec_0_1 = random.randint(4, MAX_EXPONENT) + + prec_1_0 = random.randint(4, 16) + prec_1_1 = random.randint(4, MAX_EXPONENT) + + prec_out_0 = random.randint(3, min([prec_0_0, prec_0_1, prec_1_0, prec_1_1])) + prec_out_1 = random.randint(3, min([prec_0_0, prec_0_1, prec_1_0, prec_1_1])) + + config = { + "DATA_IN_0_PRECISION_0": prec_0_0, + "DATA_IN_0_PRECISION_1": prec_0_1, + "DATA_IN_0_TENSOR_SIZE_DIM_0": cat_dim, + "DATA_IN_0_TENSOR_SIZE_DIM_1": dim_din_0, + "DATA_IN_0_PARALLELISM_DIM_0": parallelism_c, + "DATA_IN_0_PARALLELISM_DIM_1": parallelism_0, + "DATA_IN_1_PRECISION_0": prec_1_0, + "DATA_IN_1_PRECISION_1": prec_1_1, + "DATA_IN_1_TENSOR_SIZE_DIM_0": cat_dim, + "DATA_IN_1_TENSOR_SIZE_DIM_1": dim_din_1, + "DATA_IN_1_PARALLELISM_DIM_0": parallelism_c, + "DATA_IN_1_PARALLELISM_DIM_1": parallelism_0, + "DATA_OUT_0_PRECISION_0": prec_out_0, + "DATA_OUT_0_PRECISION_1": prec_out_1, + } + + return config + + +@pytest.mark.dev +def test_random_cat(): + torch.manual_seed(1) + seed = os.getenv("COCOTB_SEED") + + if seed is not None: + seed = int(seed) + mase_runner(trace=True, module_param_list=[get_config(seed)]) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=40)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[get_config(base_seed + i) for i in range(num_configs)], + jobs=min(num_configs, 10), + ) + print(f"Test seeds: \n{[(i, base_seed + i) for i in range(num_configs)]}") + + +if __name__ == "__main__": + test_random_cat() diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_relu_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_relu_tb.py new file mode 100644 index 000000000..53bf3ada2 --- /dev/null +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_relu_tb.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 + +import os, pytest, random, sys + +import torch +import logging +from functools import partial + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import Timer, RisingEdge + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import ( + MultiSignalStreamDriver, + MultiSignalStreamMonitor, +) +from mase_cocotb.runner import mase_runner + +torch.manual_seed(0) +# from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner +from utils import MXIntRelu + + +class MXIntReluTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.DEBUG) + + self.data_in_0_driver = MultiSignalStreamDriver( + dut.clk, + (dut.mdata_in_0, dut.edata_in_0), + dut.data_in_0_valid, + dut.data_in_0_ready, + ) + + self.data_out_0_monitor = MultiSignalStreamMonitor( + dut.clk, + (dut.mdata_out_0, dut.edata_out_0), + dut.data_out_0_valid, + dut.data_out_0_ready, + check=True, + signed=False, + ) + + # Model + self.model = MXIntRelu( + config={ + "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "data_in_exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + "data_in_parallelism_dim_1": self.get_parameter( + "DATA_IN_0_PARALLELISM_DIM_1" + ), + "data_in_parallelism_dim_0": self.get_parameter( + "DATA_IN_0_PARALLELISM_DIM_0" + ), + }, + bypass=True, + ) + + # Set verbosity of driver and monitor loggers to debug + self.data_in_0_driver.log.setLevel(logging.DEBUG) + self.data_out_0_monitor.log.setLevel(logging.DEBUG) + + def preprocess_tensor_for_mxint(self, tensor, config, parallelism): + from utils import block_mxint_quant + from utils import pack_tensor_to_mx_listed_chunk + + (qtensor, mtensor, etensor) = block_mxint_quant(tensor, config, parallelism) + self.log.info(f"Mantissa Tensor: {mtensor}") + self.log.info(f"Exponenr Tensor: {etensor}") + tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism) + return qtensor, tensor_inputs + + def generate_inputs(self): + return torch.randn( + ( + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_1"), + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"), + ) + ) + + async def run_test(self, us): + await self.reset() + self.log.info(f"Reset finished") + self.data_out_0_monitor.ready.value = 1 + + inputs = self.generate_inputs() + + # * Load the inputs driver + self.log.info(f"Processing inputs: {inputs}") + quantized, inputs = self.preprocess_tensor_for_mxint( + tensor=inputs, + config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], + ) + self.data_in_0_driver.load_driver(inputs) + + exp_out = self.model(quantized) + # * Load the output monitor + self.log.info(f"Processing outputs: {exp_out}") + _, outs = self.preprocess_tensor_for_mxint( + tensor=exp_out, + config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], + ) + self.data_out_0_monitor.load_monitor(outs) + + await Timer(us, units="us") + assert self.data_out_0_monitor.exp_queue.empty() + + +@cocotb.test() +async def cocotb_test(dut): + tb = MXIntReluTB(dut) + await tb.run_test(us=100) + + +def get_relu_config(seed, kwargs={}): + MAX_IN_FEATURES = 16 + MAX_BATCH_SIZE = 8 + random.seed(seed) + + BLOCK_SIZE = random.randint(2, 8) + PARALLELISM = random.randint(1, 8) + BATCH_SIZE = random.randint(1, MAX_BATCH_SIZE // PARALLELISM) * PARALLELISM + IN_FEATURES = random.randint(2, MAX_IN_FEATURES // BLOCK_SIZE) * BLOCK_SIZE + + MAX_MANTISSA = 16 + MAX_EXPONENT = 6 + + mantissa = random.randint(3, MAX_MANTISSA) + exp = random.randint(3, min(mantissa, MAX_EXPONENT)) + + +def get_relu_config(seed, kwargs={}): + MAX_IN_FEATURES = 16 + MAX_BATCH_SIZE = 8 + random.seed(seed) + + BLOCK_SIZE = random.randint(2, 8) + PARALLELISM = random.randint(1, 8) + BATCH_SIZE = random.randint(1, MAX_BATCH_SIZE // PARALLELISM) * PARALLELISM + IN_FEATURES = random.randint(2, MAX_IN_FEATURES // BLOCK_SIZE) * BLOCK_SIZE + + MAX_MANTISSA = 16 + MAX_EXPONENT = 6 + + mantissa = random.randint(3, MAX_MANTISSA) + exp = random.randint(3, min(mantissa, MAX_EXPONENT)) + + config = { + "DATA_IN_0_PRECISION_0": mantissa, + "DATA_IN_0_PRECISION_1": exp, + "DATA_IN_0_TENSOR_SIZE_DIM_0": IN_FEATURES, + "DATA_IN_0_TENSOR_SIZE_DIM_1": BATCH_SIZE, + "DATA_IN_0_PARALLELISM_DIM_0": BLOCK_SIZE, + "DATA_IN_0_PARALLELISM_DIM_1": PARALLELISM, + } + + config.update(kwargs) + return config + + +@pytest.mark.dev +def test_relu(): + """ + Fully randomized parameter testing. + """ + torch.manual_seed(10) + seed = os.getenv("COCOTB_SEED") + + param_override = {} + + if seed is not None: + seed = int(seed) + mase_runner( + trace=True, + module_param_list=[get_relu_config(seed, param_override)], + ) + else: + num_configs = int(os.getenv("NUM_CONFIGS", default=1)) + base_seed = random.randrange(sys.maxsize) + mase_runner( + trace=True, + module_param_list=[ + get_relu_config(base_seed + i, param_override) + for i in range(num_configs) + ], + jobs=min(num_configs, os.cpu_count() // 2), + ) + print(f"Test seeds: \n{[(i,base_seed+i) for i in range(num_configs)]}") + + +if __name__ == "__main__": + test_relu() diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py index 2e7b561b1..54d2bf474 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/utils.py +++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py @@ -1,5 +1,6 @@ import torch from chop.nn.quantizers import integer_floor_quantizer +from chop.nn.quantized.modules import relu from functools import partial import torch.nn.functional as F from torch import Tensor @@ -7,37 +8,34 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = None): """ - - Convert IEEE FP32/64 to Microsoft floating point (MSFP), where an exponent is shared over all elements in a block. - - `e_shared x [(-1)^s1 x mantissa1, (-1)^s2 x mantissa2, ...]` - - See https://proceedings.neurips.cc/paper/2020/file/747e32ab0fea7fbd2ad9ec03daa3f840-Paper.pdf + - Convert IEEE FP32/64 to Microscaling Interger (MXINT), where an exponent is shared over all elements in a block. + - https://arxiv.org/pdf/2310.10537.pdf + - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf --- - - forward: convert IEEE FP32/64 to MSFP + - forward: convert IEEE FP32/64 to MXINT - backward: STE --- - `width`: The number of mantissa bits + 1 (the sign bit) - - `exponent_width`: the number of exponent bits, which is shared over a block - - `exponent_bias`: the exponent bias, if None, `2**(exponent_bits-1)-1` will be used - - `block_size`: a list of integers where each integer is the block size on that dimension. See function `block`. - + - `exponent_width`: the number of exponent bits """ - exponent_bias = 2 ** (exponent_width - 1) - exponent_max = 2**exponent_width - 1 - exponent_bias - exponent_min = -exponent_bias + exponent_bias = 2 ** (exponent_width - 1) - 1 # exponent if exponent == None: - exponent = torch.ceil(torch.log2(x.abs().max())) - exponent_bias - exponent = torch.clamp(exponent, exponent_min, exponent_max) + exponent = torch.floor(torch.log2(x.abs().max())) + exponent_bias + exponent = torch.clamp(exponent, 0, 2**exponent_width - 1) # mantissa - int_min = -(2 ** (width - 1)) - int_max = 2 ** (width - 1) - 1 - mantissa = x / 2**exponent - mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2**exponent) * mantissa - return msfp_x, mantissa, exponent + element_max = 2 ** (width - 1) - 1 + shift = 2 ** (width - 2) + + mantissa = shift * x / 2 ** (exponent - exponent_bias) + mantissa = torch.clamp(mantissa.floor(), -element_max, element_max) + mxint_x = mantissa * 2 ** (exponent - exponent_bias) / shift + + return mxint_x, mantissa, exponent def block_mxint_quant(tensor, q_config, parallelism): @@ -153,6 +151,8 @@ def __init__( q_config={"width": out_width, "exponent_width": out_exponent_width}, parallelism=[out_p1, out_p0], ) + else: + self.out_quantizer = None self.w_quantizer = partial( base_quantizer, q_config={"width": w_width, "exponent_width": w_exponent_width}, @@ -184,3 +184,37 @@ def forward(self, x: Tensor) -> Tensor: if self.out_quantizer is None: return out return self.out_quantizer(out) + + +class MXIntRelu(relu._ReLUBase): + def __init__(self, inplace: bool = False, config=None, bypass=False): + assert config is not None, "config is None!" + super().__init__(inplace) + + self.config = config + self.bypass = bypass + + base_quantizer = block_mxint_quant + + x_width, x_exponent_width = ( + config["data_in_width"], + config["data_in_exponent_width"], + ) + + x_p1, x_p0 = ( + config["data_in_parallelism_dim_1"], + config["data_in_parallelism_dim_0"], + ) + + self.x_quantizer = partial( + base_quantizer, + q_config={"width": x_width, "exponent_width": x_exponent_width}, + parallelism=[x_p1, x_p0], + ) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.relu(x) + else: + y = F.relu(x, self.inplace) + return self.x_quantizer(x) diff --git a/test/passes/graph/analysis/report/test_report_node_hardware_type_analysis_pass.py b/test/passes/graph/analysis/report/test_report_node_hardware_type_analysis_pass.py index a02003605..f0aa419e5 100644 --- a/test/passes/graph/analysis/report/test_report_node_hardware_type_analysis_pass.py +++ b/test/passes/graph/analysis/report/test_report_node_hardware_type_analysis_pass.py @@ -5,6 +5,9 @@ import os import sys +import toml + +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass import torch @@ -40,6 +43,26 @@ def test_add_hardware_metadata_analysis_pass(): mg, {"dummy_in": dummy_in, "add_value": False} ) mg, _ = add_software_metadata_analysis_pass(mg, pass_args) + + # Quantize to fixed-point + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + mg, _ = add_hardware_metadata_analysis_pass(mg, pass_args) mg, _ = report_node_hardware_type_analysis_pass(mg, pass_args) diff --git a/test/passes/graph/analysis/report/test_report_node_meta_param_analysis_pass.py b/test/passes/graph/analysis/report/test_report_node_meta_param_analysis_pass.py index 2189e1221..d4588338b 100644 --- a/test/passes/graph/analysis/report/test_report_node_meta_param_analysis_pass.py +++ b/test/passes/graph/analysis/report/test_report_node_meta_param_analysis_pass.py @@ -5,6 +5,9 @@ import os import sys +import toml + +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass import torch @@ -40,6 +43,26 @@ def test_report_node_meta_param_analysis_pass(): mg, {"dummy_in": dummy_in, "add_value": False} ) mg, _ = add_software_metadata_analysis_pass(mg, pass_args) + + # Quantize to fixed-point + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + mg, _ = add_hardware_metadata_analysis_pass(mg, pass_args) mg, _ = report_node_meta_param_analysis_pass(mg, pass_args) diff --git a/test/passes/graph/analysis/report/test_report_node_shape_analysis_pass.py b/test/passes/graph/analysis/report/test_report_node_shape_analysis_pass.py index c595ed155..12ff5982e 100644 --- a/test/passes/graph/analysis/report/test_report_node_shape_analysis_pass.py +++ b/test/passes/graph/analysis/report/test_report_node_shape_analysis_pass.py @@ -5,6 +5,9 @@ import os import sys +import toml + +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass import torch from chop.tools.logger import set_logging_verbosity @@ -39,6 +42,25 @@ def test_report_node_shape_analysis_pass(): mg, {"dummy_in": dummy_in, "add_value": False} ) mg, _ = add_software_metadata_analysis_pass(mg, pass_args) + # Quantize to fixed-point + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + mg, _ = add_hardware_metadata_analysis_pass(mg, pass_args) mg, _ = report_node_shape_analysis_pass(mg, pass_args) diff --git a/test/passes/graph/analysis/report/test_report_node_type_analysis_pass.py b/test/passes/graph/analysis/report/test_report_node_type_analysis_pass.py index faf3631ba..238134bef 100644 --- a/test/passes/graph/analysis/report/test_report_node_type_analysis_pass.py +++ b/test/passes/graph/analysis/report/test_report_node_type_analysis_pass.py @@ -5,6 +5,9 @@ import os import sys +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass +import toml + import torch from chop.tools.logger import set_logging_verbosity @@ -39,6 +42,26 @@ def test_report_node_type_analysis(): mg, {"dummy_in": dummy_in, "add_value": False} ) mg, _ = add_software_metadata_analysis_pass(mg, pass_args) + + # Quantize to fixed-point + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + mg, _ = add_hardware_metadata_analysis_pass(mg, pass_args) mg, _ = report_node_type_analysis_pass(mg, pass_args) diff --git a/test/passes/graph/analysis/verify/test_verify_metadata_analysis_passes.py b/test/passes/graph/analysis/verify/test_verify_metadata_analysis_passes.py index 4b3e489dc..e507ebb94 100644 --- a/test/passes/graph/analysis/verify/test_verify_metadata_analysis_passes.py +++ b/test/passes/graph/analysis/verify/test_verify_metadata_analysis_passes.py @@ -11,6 +11,9 @@ import os import sys +import toml + +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass import torch @@ -43,6 +46,26 @@ def test_verify_metadata(): mg, _ = init_metadata_analysis_pass(mg, None) mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) mg, _ = add_software_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) + + # Quantize to fixed-point + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + mg, _ = add_hardware_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) # all three verify passes are bundled in one # mg, _ = verify_metadata_analysis_pass(mg, dummy_in) diff --git a/test/passes/graph/transforms/dse/test_apply_random_partition.py b/test/passes/graph/transforms/dse/test_apply_random_partition.py index 0a2a84c04..32bc40428 100644 --- a/test/passes/graph/transforms/dse/test_apply_random_partition.py +++ b/test/passes/graph/transforms/dse/test_apply_random_partition.py @@ -1,3 +1,7 @@ +import os + +from chop.passes.graph.transforms.quantize.quantize import quantize_transform_pass +import toml import torch import torch.nn as nn @@ -52,6 +56,25 @@ def test_apply_random_partition(): mg, _ = passes.add_common_metadata_analysis_pass(mg, pass_args=pass_args) + # Quantize to fixed-point for hardware pass + config_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + "..", + "..", + "..", + "configs", + "tests", + "quantize", + "fixed.toml", + ) + + # load toml config file + with open(config_file, "r") as f: + quan_args = toml.load(f)["passes"]["quantize"] + mg, _ = quantize_transform_pass(mg, quan_args) + # add metadata for hardware in each mase node of graph mg, _ = passes.add_hardware_metadata_analysis_pass(mg, pass_args=pass_args) diff --git a/test/passes/graph/transforms/quantize/test_mxint_linear_quantize.py b/test/passes/graph/transforms/quantize/test_mxint_linear_quantize.py index 96bc36784..03cf46ea1 100644 --- a/test/passes/graph/transforms/quantize/test_mxint_linear_quantize.py +++ b/test/passes/graph/transforms/quantize/test_mxint_linear_quantize.py @@ -70,22 +70,21 @@ def test_quantize(): # torch.manual_seed(0) quan_args = { "by": "type", - "default": {"config": {"name": None}}, - "linear": { + "default": { "config": { - "name": "mxint_hardware", + "name": "mxint", # data "data_in_width": 12, "data_in_exponent_width": 4, - "data_in_parallelism": [1, 2], + "weight_block_size": [1, 2], # weight "weight_width": 12, "weight_exponent_width": 4, - "weight_parallelism": [2, 2], + "bias_block_size": [2, 2], # bias "bias_width": 12, "bias_exponent_width": 4, - "bias_parallelism": [1, 2], + "data_in_block_size": [1, 2], } }, } diff --git a/test/passes/graph/transforms/verilog/generate.tcl b/test/passes/graph/transforms/verilog/generate.tcl new file mode 100644 index 000000000..6315fe3e7 --- /dev/null +++ b/test/passes/graph/transforms/verilog/generate.tcl @@ -0,0 +1,20 @@ +source config.tcl + +create_project -in_memory -part xcku5p-ffvb676-2-e +set_property board_part xilinx.com:kcu116:part0:1.5 [current_project] + +add_files -fileset sources_1 "$top_dir/hardware/rtl/" + +set_property top top [current_fileset] + +puts "Trial: ${trial_number}" + +eval "synth_design -mode out_of_context -top top -part xcku5p-ffvb676-2-e" + +save_project_as -force my_project + +launch_runs synth_1 -jobs 12 +wait_on_run synth_1 + +open_run synth_1 +report_utilization -file "$mase_dir/resources/util_${trial_number}.txt" diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/test/passes/graph/transforms/verilog/test_emit_verilog_linear_fixed.py similarity index 69% rename from test/passes/graph/transforms/verilog/test_emit_verilog_linear.py rename to test/passes/graph/transforms/verilog/test_emit_verilog_linear_fixed.py index e10e61c3e..e77ce1fef 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_linear_fixed.py @@ -48,7 +48,6 @@ def forward(self, x): return x -@pytest.mark.dev def test_emit_verilog_linear(): mlp = MLP() mg = chop.MaseGraph(model=mlp) @@ -80,28 +79,6 @@ def test_emit_verilog_linear(): quan_args = toml.load(f)["passes"]["quantize"] mg, _ = passes.quantize_transform_pass(mg, quan_args) - # There is a bug in the current quantizzation pass, where the results metadata is not uppdated with the precision. - # Here we temporarily update the metadata here so we can test the hardware back end. - for node in mg.fx_graph.nodes: - for arg, _ in node.meta["mase"].parameters["common"]["args"].items(): - if ( - type(node.meta["mase"].parameters["common"]["args"][arg]) == dict - and "type" in node.meta["mase"].parameters["common"]["args"][arg].keys() - ): - node.meta["mase"].parameters["common"]["args"][arg]["type"] = "fixed" - for result, _ in node.meta["mase"].parameters["common"]["results"].items(): - if ( - type(node.meta["mase"].parameters["common"]["results"][result]) == dict - and "type" - in node.meta["mase"].parameters["common"]["results"][result].keys() - ): - node.meta["mase"].parameters["common"]["results"][result][ - "type" - ] = "fixed" - node.meta["mase"].parameters["common"]["results"][result][ - "precision" - ] = [8, 3] - # Increase weight range mg.model.fc1.weight = torch.nn.Parameter( 10 * torch.randn(mg.model.fc1.weight.shape) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mxint.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mxint.py new file mode 100644 index 000000000..3e6815e40 --- /dev/null +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mxint.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# This example converts a simple MLP model to Verilog +import random +import os, sys, logging, traceback, pdb +from chop.passes.graph.analysis.report.report_node import report_node_type_analysis_pass +import pytest +import toml + +import torch +import torch.nn as nn + +import chop as chop +import chop.passes as passes + +from pathlib import Path + +from chop.actions import simulate +from chop.tools.logger import set_logging_verbosity +from chop.tools import get_logger + +set_logging_verbosity("debug") + +logger = get_logger(__name__) + + +# -------------------------------------------------- +# Model specifications +# prefer small models for fast test +# -------------------------------------------------- +class MLP(torch.nn.Module): + def __init__(self, features: list[int]) -> None: + super().__init__() + + layers = [] + for in_f, out_f in zip(features[:-1], features[1:]): + layers.append(nn.Linear(in_f, out_f)) + layers.append(nn.ReLU()) + + self.model = nn.Sequential(*layers) + print(self.model) + + def forward(self, x): + return self.model(x) + + +def test_emit_verilog_mxint_mlp(seed: int = 10): + torch.manual_seed(seed) + random.seed(seed) + + block_size = random.randint(2, 10) + batch_parallelism = random.randint(2, 10) + mlp_depth = random.randint(1, 10) + mlp_features = [block_size * random.randint(1, 10) for _ in range(mlp_depth + 1)] + + params = { + "seed": seed, + "block_size": block_size, + "batch_parallelism": batch_parallelism, + "m_width": (m_width := random.randint(4, 10)), + "e_width": random.randint(3, min(m_width - 1, 10)), + "batches": batch_parallelism * random.randint(1, 20), + "num_batches": random.randint(1, 20), + } + + mlp = MLP(mlp_features) + input_shape = (mlp_features[0],) + logger.info( + f"{block_size=}, {batch_parallelism=}, {params['e_width']=}, {params['m_width']=}, {params['batches']=}" + ) + + shared_emit_verilog_mxint(mlp, input_shape, params) + + +def test_emit_verilog_mxint_linear(seed: int = 10): + torch.manual_seed(seed) + random.seed(seed) + + block_size = random.randint(2, 10) + batch_parallelism = random.randint(2, 10) + IN_FEATURES = block_size * random.randint(1, 10) + OUT_FEATURES = block_size * random.randint(1, 10) + + params = { + "seed": seed, + "block_size": block_size, + "batch_parallelism": batch_parallelism, + "m_width": (m_width := random.randint(5, 10)), + "e_width": random.randint(4, min(m_width - 1, 10)), + "batches": batch_parallelism * random.randint(1, 20), + "num_batches": random.randint(1, 20), + } + + class LinearModel(torch.nn.Module): + def __init__(self, IN_FEATURES, OUT_FEATURES) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(IN_FEATURES, OUT_FEATURES) + + def forward(self, x): + return self.fc1(x) + + linear = LinearModel(IN_FEATURES, OUT_FEATURES) + input_shape = (IN_FEATURES,) + logger.info( + f"{block_size=}, {batch_parallelism=}, {params['e_width']=}, {params['m_width']=}, {params['batches']=}" + ) + + shared_emit_verilog_mxint(linear, input_shape, params) + + +def shared_emit_verilog_mxint(model, input_shape, params: dict, sim: bool = True): + # Set seeds + torch.manual_seed(params["seed"]) + random.seed(params["seed"]) + + block_size = params["block_size"] + batch_parallelism = params["batch_parallelism"] + m_width = params["m_width"] + e_width = params["e_width"] + batches = params["batches"] + num_batches = params["num_batches"] + + mg = chop.MaseGraph(model=model) + x = torch.randn((batches, *input_shape)) + dummy_in = {"x": x} + + mg, _ = passes.init_metadata_analysis_pass(mg, None) + mg, _ = passes.add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) + + quan_args = { + "by": "type", + "default": { + "config": { + "name": "mxint", + "data_in_width": m_width, + "data_in_exponent_width": e_width, + "data_in_block_size": [batch_parallelism, block_size], + "weight_width": m_width, + "weight_exponent_width": e_width, + "weight_block_size": [block_size, block_size], + "bias_width": m_width, + "bias_exponent_width": e_width, + "bias_block_size": [1, block_size], + } + }, + } + + mg, _ = passes.quantize_transform_pass(mg, quan_args) + _ = report_node_type_analysis_pass(mg) + mg, _ = passes.report_node_meta_param_analysis_pass(mg) + + # Parallelism adjustments + for node in mg.fx_graph.nodes: + node_meta = node.meta["mase"].parameters["common"] + args = node_meta["args"] + results = node_meta["results"] + match node_meta["mase_op"]: + case "linear": + args["data_in_0"]["parallelism_0"] = block_size + args["data_in_0"]["parallelism_1"] = batch_parallelism + args["weight"]["parallelism_0"] = block_size + args["weight"]["parallelism_1"] = block_size + args["bias"]["parallelism_0"] = block_size + args["bias"]["parallelism_1"] = 1 + + results["data_out_0"]["parallelism_0"] = block_size + results["data_out_0"]["parallelism_1"] = batch_parallelism + case "relu": + args["data_in_0"]["parallelism_0"] = block_size + args["data_in_0"]["parallelism_1"] = batch_parallelism + results["data_out_0"]["parallelism_0"] = block_size + results["data_out_0"]["parallelism_1"] = batch_parallelism + + mg, _ = passes.add_hardware_metadata_analysis_pass(mg) + mg, _ = passes.report_node_hardware_type_analysis_pass(mg) + + mg, _ = passes.emit_verilog_top_transform_pass(mg) + mg, _ = passes.emit_bram_transform_pass(mg) + mg, _ = passes.emit_internal_rtl_transform_pass(mg) + + if sim: + mg, _ = passes.emit_cocotb_transform_pass( + mg, + pass_args={ + "wait_time": 10 * block_size * batch_parallelism * num_batches, + "wait_unit": "us", + "num_batches": num_batches, + }, + ) + + simulate( + skip_build=False, + skip_test=False, + simulator="verilator", + waves=True, + ) + + logger.info( + f"{block_size=}, {batch_parallelism=}, {m_width=}, {e_width=}, {batches=}" + ) + + return model, mg.model + + +if __name__ == "__main__": + seed = os.getenv("COCOTB_SEED") + if seed is None: + seed = random.randrange(sys.maxsize) + logger.info(f"Generated {seed=}") + else: + seed = int(seed) + logger.info(f"Using provided {seed=}") + # test_emit_verilog_mxint_linear(seed) + test_emit_verilog_mxint_mlp(seed) + logger.info(f"{seed=}") diff --git a/test/passes/graph/transforms/verilog/test_synthesize_mxint_vivado.py b/test/passes/graph/transforms/verilog/test_synthesize_mxint_vivado.py new file mode 100644 index 000000000..33fd8700b --- /dev/null +++ b/test/passes/graph/transforms/verilog/test_synthesize_mxint_vivado.py @@ -0,0 +1,189 @@ +import os, re, random +import optuna +from optuna import study +from optuna.samplers import TPESampler, GridSampler +import json +from chop.tools.logger import set_logging_verbosity +from chop.tools import get_logger +from test_emit_verilog_mxint import MLP, shared_emit_verilog_mxint +import os, sys, logging, traceback, pdb +import pytest +import toml +import torch +import torch.nn as nn +import chop as chop +import chop.passes as passes +from pathlib import Path +from chop.actions import simulate +from chop.passes.graph.analysis.report.report_node import report_node_type_analysis_pass +from chop.tools.logger import set_logging_verbosity +from chop.tools import get_logger + +config_file = "config.tcl" + +set_logging_verbosity("debug") + +logger = get_logger(__name__) + + +def dump_param(trial_number, quan_args, filename="output.json"): + try: + with open(filename, "r") as file: + data = json.load(file) + except (FileNotFoundError, json.JSONDecodeError): + data = {} + + data[str(trial_number)] = quan_args + + with open(filename, "w") as file: + json.dump(data, file, indent=4) + + +def write_value(trial_number, name, value, filename="output.json"): + try: + with open(filename, "r") as file: + data = json.load(file) + except (FileNotFoundError, json.JSONDecodeError): + data = {} + + if str(trial_number) in data.keys(): + data[str(trial_number)][name] = value + else: + data[str(trial_number)] = {name: value} + + with open(filename, "w") as file: + json.dump(data, file, indent=4) + + +def get_params(trial): + + block_size = 2 ** trial.suggest_int("block_size", 1, 4) + batch_parallelism = 2 ** trial.suggest_int("batch_parallelism", 1, 4) + mlp_depth = 3 + mlp_features = [128 for i in range(mlp_depth + 1)] + + params = { + "seed": trial.number, + "block_size": block_size, + "batch_parallelism": batch_parallelism, + "m_width": (m_width := trial.suggest_int("m_width", 4, 10)), + "e_width": trial.suggest_int("e_width", 3, min(m_width - 1, 10)), + "batches": 128, + "num_batches": 10, + } + + mlp = MLP(mlp_features) + input_shape = (mlp_features[0],) + + logger.info( + f"{block_size=}, {batch_parallelism=}, {params['e_width']=}, {params['m_width']=}, {params['batches']=}" + ) + + mg, mlp = shared_emit_verilog_mxint(mlp, input_shape, params, sim=False) + + return params, mg, mlp + + +def writeTrialNumber(trial_number): + with open(config_file, "w") as f: + f.write(f"set trial_number {trial_number}\n") + f.write(f"set top_dir {Path.home()}/.mase/top/\n") + f.write(f"set mase_dir {Path.cwd()}/") + + +def extract_site_type_used_util(filename): + site_data = {} + with open(filename, "r") as file: + lines = file.readlines() + + pattern = re.compile(r"\|\s*([^|]+?)\s*\|\s*(\d+)\s*\|.*?\|\s*(\d+\.\d+)\s*\|") + + for line in lines: + match = pattern.match(line) + if match: + site_type = match.group(1).strip() + used = int(match.group(2).strip()) + util = float(match.group(3).strip()) + site_data[site_type] = {"Used": used, "Util%": util} + + return site_data + + +def get_bram_uram_util(filename): + site_data = extract_site_type_used_util(filename) + bram_util = site_data.get("Block RAM Tile", {}).get("Util%", 0.0) + uram_util = site_data.get("URAM", {}).get("Util%", 0.0) + return {"bram": bram_util, "uram": uram_util} + + +def getResources(trial): + params, mg, mlp = get_params(trial) + dump_param(trial.number, params) + writeTrialNumber(trial.number) + os.system( + f"vivado -mode batch -nolog -nojou -source {Path.cwd()}/test/passes/graph/transforms/verilog/generate.tcl" + ) + bram_utils = get_bram_uram_util(f"{Path.cwd()}/resources/util_{trial.number}.txt") + clb_luts = extract_site_type_used_util( + f"{Path.cwd()}/resources/util_{trial.number}.txt" + ) + out = ( + clb_luts["CLB LUTs*"]["Util%"] + + clb_luts["CLB Registers"]["Util%"] + + clb_luts["CARRY8"]["Util%"] + + bram_utils["bram"] + + bram_utils["uram"] + ) + write_value(trial.number, "resource_score", out) + return out + + +def getAccuracy(trial): + params, mg, mlp = get_params(trial) + quantized = mg.model + + criterion = nn.MSELoss() + total_mse = 0.0 + + for _ in range(100): + x = torch.randn(params["batches"], mg.model[0].in_features) + y1 = quantized(x) + y2 = mlp(x) + mse = criterion(y1, y2) + total_mse += mse.item() + + avg_mse = total_mse / 100 + + write_value(trial.number, "avg_mse", avg_mse) + return avg_mse + + +def main(): + sampler = TPESampler() + + study = optuna.create_study( + directions=["minimize", "minimize"], + study_name="resource_accuracy_optimiser", + sampler=sampler, + ) + + study.optimize( + lambda trial: (getResources(trial), getAccuracy(trial)), + n_trials=10, + timeout=60 * 60 * 24, + n_jobs=1, + ) + + print("Best trials:") + for trial in study.best_trials: + print(f"Trial {trial.number}: {trial.values}") + + +if __name__ == "__main__": + + try: + os.mkdir(f"{Path.cwd()}/resources/") + except: + pass + + main() diff --git a/test/passes/module/transforms/quantize/test_quantize_module.py b/test/passes/module/transforms/quantize/test_quantize_module.py index 853b4327d..7a383d2fc 100644 --- a/test/passes/module/transforms/quantize/test_quantize_module.py +++ b/test/passes/module/transforms/quantize/test_quantize_module.py @@ -48,7 +48,7 @@ def test_quantize_module_transform_pass(): "by": "name", "fc1": { "config": { - "name": "integer", + "name": "fixed", "data_in_width": 8, "data_in_frac_width": 4, "weight_width": 8,