diff --git a/Makefile b/Makefile
index 6f591b541..03d093afb 100644
--- a/Makefile
+++ b/Makefile
@@ -65,6 +65,19 @@ build-docker:
docker pull $(img); \
fi
+# Build the python3.13 variant of the MASE dev image.
+# (Should be declared .PHONY alongside build-docker.)
+build-docker-python13:
+	docker build --build-arg VHLS_PATH=$(vhls) --build-arg VHLS_VERSION=$(vhls_version) -f Docker/Dockerfile-$(PLATFORM)-python13 --tag mase-ubuntu2204-docker-python13 Docker
+
+# Open an interactive shell in the python3.13 image built by build-docker-python13.
+# Previously this ran $(img) (the generic image), contradicting its own
+# hostname and the sibling build target.
+shell-python13:
+	docker run -it --shm-size 256m \
+		--hostname mase-ubuntu2204-docker-python13 \
+		-w /workspace \
+		-v /$(USER_PREFIX)/$(shell whoami)/.gitconfig:/root/.gitconfig \
+		-v /$(USER_PREFIX)/$(shell whoami)/.ssh:/root/.ssh \
+		-v /$(USER_PREFIX)/$(shell whoami)/.mase:/root/.mase:z \
+		-v $(shell pwd):/workspace:z \
+		$(DOCKER_RUN_EXTRA_ARGS) \
+		mase-ubuntu2204-docker-python13 /bin/bash
shell:
docker run -it --shm-size 256m \
--hostname mase-ubuntu2204 \
diff --git a/a_cx_mxint_quant/__init__.py b/a_cx_mxint_quant/__init__.py
new file mode 100644
index 000000000..ac5bdc3d1
--- /dev/null
+++ b/a_cx_mxint_quant/__init__.py
@@ -0,0 +1,86 @@
+from .module_level_tranform import vit_module_level_quantize
+from .quantizers import mxint_hardware, mxint_quant_block
+
+from .linear import MXIntLinear
+from .attention import MXIntAttention
+from .module_level_tranform import MXIntLayerNorm, MXIntGELU
+from .modules import MXIntPatchEmbed, MXIntAddition
+from mase_components import get_module_dependencies
+# Registry of custom MXInt modules, keyed by module class.  For each entry,
+# "args" maps a constructor/runtime argument onto the hardware interface:
+# "data_in" = data path, "config" = hardware configuration, None = ignored.
+# "module" names the RTL top level and "dependence_files" lists its RTL
+# sources, resolved via mase_components.get_module_dependencies.
+# NOTE(review): presumably consumed by MASE's hardware-emit flow — confirm
+# the expected schema against the consumer.
+VIT_CUSTOM_OPS = {
+    "modules": {
+        MXIntPatchEmbed: {
+            "args": {
+                "data_in": "data_in",
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_patch_embed",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_patch_embed"
+            ),
+        },
+        MXIntAttention: {
+            "args": {
+                "data_in": "data_in",
+                "dim": "config",
+                "num_heads": "config",
+                "qkv_bias": "config",
+                "qk_norm": None,
+                "attn_drop": None,
+                "proj_drop": None,
+                "norm_layer": None,
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_vit_attention_wrap",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_vit_attention_wrap"
+            ),
+        },
+        MXIntLayerNorm: {
+            "args": {
+                "data_in": "data_in",
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_layernorm",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_layernorm"
+            ),
+        },
+        MXIntGELU: {
+            "args": {
+                "data_in": "data_in",
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_gelu",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_gelu"
+            ),
+        },
+        MXIntLinear: {
+            "args": {
+                "data_in": "data_in",
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_linear",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_linear"
+            ),
+        },
+        MXIntAddition: {
+            # Two data inputs (elementwise add), one shared config.
+            "args": {
+                "input_0": "data_in",
+                "input_1": "data_in",
+                "q_config": "config",
+            },
+            "toolchain": "INTERNAL_RTL",
+            "module": "mxint_addition",
+            "dependence_files": get_module_dependencies(
+                "linear_layers/mxint_operators/mxint_addition"
+            ),
+        },
+    },
+}
\ No newline at end of file
diff --git a/a_cx_mxint_quant/attention.py b/a_cx_mxint_quant/attention.py
new file mode 100644
index 000000000..34892c16e
--- /dev/null
+++ b/a_cx_mxint_quant/attention.py
@@ -0,0 +1,192 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+from .attention_head import _ViTSelfAttentionHeadBase, ViTSelfAttentionHeadInteger
+
+from chop.nn.quantized.modules.linear import (
+ LinearInteger,
+)
+from chop.nn.quantized.functional import fixed_softermax
+from chop.nn.quantizers import integer_quantizer
+from chop.nn.quantized.functional import matmul_integer
+
+from typing import Optional, Tuple, Union
+
+from .linear import MXIntLinear
+from .attention_head import MXIntViTAttentionHead
+
+class _ViTAttentionBase(nn.Module):
+    """Float reference implementation of ViT multi-head self-attention.
+
+    Q, K and V come from three separate nn.Linear projections (not one fused
+    qkv matrix), the scaled-dot-product core is delegated to
+    _ViTSelfAttentionHeadBase, and the merged heads are re-projected by
+    ``proj``.  Quantized subclasses replace these sub-modules in place.
+
+    NOTE(review): ``qk_norm`` is accepted but never used in this class.
+    """
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        # Separate projections so each can be swapped for a quantized Linear.
+        self.query = nn.Linear(dim, dim, bias=qkv_bias)
+        self.key = nn.Linear(dim, dim, bias=qkv_bias)
+        self.value = nn.Linear(dim, dim, bias=qkv_bias)
+        self.self_attention = _ViTSelfAttentionHeadBase(
+            dim=self.head_dim, num_heads=num_heads, attn_drop=attn_drop
+        )
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (batch B, tokens N, channels C == dim)
+        B, N, C = x.shape
+
+        def _tensor_reshape(x):
+            # (B, N, C) -> (B, num_heads, N, head_dim)
+            return x.reshape(B, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        q, k, v = (
+            _tensor_reshape(self.query(x)),
+            _tensor_reshape(self.key(x)),
+            _tensor_reshape(self.value(x)),
+        )
+        x = self.self_attention(q, k, v)
+        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+class ViTAttentionInteger(_ViTAttentionBase):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ q_config: dict = None,
+ floor=True,
+ ) -> None:
+ super().__init__(dim, num_heads, qkv_bias, qk_norm, attn_drop, proj_drop)
+ self.q_config = q_config
+ self.query = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.key = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.value = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.self_attention = ViTSelfAttentionHeadInteger(
+ dim=self.head_dim,
+ num_heads=num_heads,
+ attn_drop=attn_drop,
+ q_config={
+ "query_width": q_config["qkv_width"],
+ "query_frac_width": q_config["qkv_frac_width"],
+ "key_width": q_config["qkv_width"],
+ "key_frac_width": q_config["qkv_frac_width"],
+ "value_width": q_config["qkv_width"],
+ "value_frac_width": q_config["qkv_frac_width"],
+ "qkmm_out_width": q_config["qkmm_out_width"],
+ "qkmm_out_frac_width": q_config["qkmm_out_frac_width"],
+ "softmax_exp_width": q_config["softmax_exp_width"],
+ "softmax_exp_frac_width": q_config["softmax_exp_frac_width"],
+ "softmax_out_frac_width": q_config["softmax_out_frac_width"],
+ "svmm_out_width": q_config["svmm_out_width"],
+ "svmm_out_frac_width": q_config["svmm_out_frac_width"],
+ },
+ floor=floor,
+ )
+ self.proj = LinearInteger(
+ dim,
+ dim,
+ config={
+ "data_in_width": q_config["svmm_out_width"],
+ "data_in_frac_width": q_config["svmm_out_frac_width"],
+ "weight_width": q_config["proj_weight_width"],
+ "weight_frac_width": q_config["proj_weight_frac_width"],
+ "bias_width": q_config["proj_bias_width"],
+ "bias_frac_width": q_config["proj_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["data_out_width"],
+ "data_out_frac_width": q_config["data_out_frac_width"],
+ },
+ floor=floor,
+ )
+
+class MXIntAttention(_ViTAttentionBase):
+    """MXInt-quantized ViT attention (work in progress).
+
+    Currently only stores ``q_config``; every sub-module is still the float
+    one from _ViTAttentionBase, so forward behaves identically to the base
+    class.  NOTE(review): the MXIntViTAttentionHead swap below is commented
+    out — confirm whether this is intentional before relying on quantized
+    behavior (the class is registered in VIT_CUSTOM_OPS as a hardware op).
+    """
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        q_config: dict = None,
+    ) -> None:
+        super().__init__(dim, num_heads, qkv_bias, qk_norm, attn_drop, proj_drop)
+        self.q_config = q_config
+
+        # Replace attention with MXIntViTAttentionHead
+        # self.self_attention = MXIntViTAttentionHead(
+        #     dim=self.head_dim,
+        #     num_heads=num_heads,
+        #     attn_drop=attn_drop,
+        #     q_config=q_config
+        # )
diff --git a/a_cx_mxint_quant/attention_head.py b/a_cx_mxint_quant/attention_head.py
new file mode 100644
index 000000000..9935948e6
--- /dev/null
+++ b/a_cx_mxint_quant/attention_head.py
@@ -0,0 +1,167 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+import math
+
+from typing import Optional, Tuple
+from functools import partial
+
+from chop.nn.quantized.functional.matmul import (
+ generic_matmul_integer,
+)
+from chop.nn.quantized.functional.softmax import (
+ softmax_integer,
+)
+from chop.nn.quantizers.integer import integer_quantizer, integer_floor_quantizer
+from .quantizers import mxint_quant_block
+
+class _ViTSelfAttentionHeadBase(torch.nn.Module):
+    """Float scaled-dot-product attention core.
+
+    Subclasses override ``matmul1``/``matmul2``/``act``/``mult_data`` with
+    quantized equivalents.  NOTE(review): ``num_heads`` is accepted but not
+    stored or used here.
+    """
+    def __init__(self, dim, num_heads, attn_drop) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(attn_drop)
+
+        # Hooks that quantized subclasses replace.
+        self.matmul1 = torch.matmul
+        self.matmul2 = torch.matmul
+        # Attention scale 1/sqrt(dim); ``dim`` here is the per-head dim.
+        self.mult_data = torch.tensor(1 / math.sqrt(dim))
+        self.act = nn.functional.softmax
+
+    def self_attention_head(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+    ) -> Tensor:
+        # scores = (Q @ K^T) * mult_data
+        attention_scores = self.matmul1(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores * self.mult_data
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = self.act(attention_scores, dim=-1)
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+        context_layer = self.matmul2(attention_probs, value_layer)
+        return context_layer
+
+    def forward(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+    ) -> Tensor:
+        return self.self_attention_head(
+            query_layer=query_layer, key_layer=key_layer, value_layer=value_layer
+        )
+
+from .linear import MXIntLinear, fast_linear
+from .quantizers import mxint_hardware
+
+class MXIntMatMul(nn.Module):
+    """Matrix multiply with MXInt quantization on both operands and the result.
+
+    Note that ``y`` is quantized with the ``weight_*`` keys of ``q_config``
+    even though it is an activation here — the naming mirrors the hardware
+    port names.
+    """
+    def __init__(self, q_config=None):
+        super().__init__()
+        assert q_config is not None, "q_config cannot be None"
+        self.q_config = q_config
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        # Quantize x to MXInt with the data_in precision/parallelism.
+        qx, _, _ = mxint_hardware(
+            x,
+            q_config = {
+                "width": self.q_config["data_in_width"],
+                "exponent_width": self.q_config["data_in_exponent_width"],
+            },
+            parallelism = self.q_config["data_in_parallelism"]
+        )
+        # Quantize y with the weight precision/parallelism.
+        qy, _, _ = mxint_hardware(
+            y,
+            q_config = {
+                "width": self.q_config["weight_width"],
+                "exponent_width": self.q_config["weight_exponent_width"],
+            },
+            parallelism = self.q_config["weight_parallelism"]
+        )
+
+        # Multiply the dequantized values, then re-quantize to data_out.
+        out = qx @ qy
+        out, _, _ = mxint_hardware(
+            out,
+            q_config = {
+                "width": self.q_config["data_out_width"],
+                "exponent_width": self.q_config["data_out_exponent_width"],
+            },
+            parallelism = self.q_config["data_out_parallelism"]
+        )
+        return out
+
+from .softmax import MXIntSoftmax
+class MXIntViTAttentionHead(_ViTSelfAttentionHeadBase):
+    """MXInt variant of the attention core: only softmax is quantized.
+
+    The matmuls and the 1/sqrt(dim) scale below simply restate the
+    base-class defaults; the effective change is ``act`` -> MXIntSoftmax.
+    NOTE(review): ``floor`` is accepted but unused.
+    """
+    def __init__(
+        self, dim, num_heads, attn_drop=0.0, q_config: dict = None, floor=False
+    ) -> None:
+        super().__init__(dim, num_heads, attn_drop)
+        self.dropout = nn.Dropout(attn_drop)
+
+        self.matmul1 = torch.matmul
+        self.matmul2 = torch.matmul
+        # Quantized softmax; q_config is forwarded unchanged.
+        self.act = MXIntSoftmax(q_config=q_config)
+        self.mult_data = torch.tensor(1 / math.sqrt(dim))
+
+class ViTSelfAttentionHeadInteger(_ViTSelfAttentionHeadBase):
+    """Fixed-point attention core: integer matmuls plus integer softmax.
+
+    Subtle but deliberate: the base class sets ``self.mult_data`` to
+    1/sqrt(dim); that value is captured into the softmax config below and
+    ``mult_data`` is then reset to 1, so the attention scale is applied
+    inside ``softmax_integer`` instead of in ``self_attention_head``.
+    """
+    def __init__(
+        self, dim, num_heads, attn_drop=0.0, q_config: dict = None, floor=False
+    ) -> None:
+        super().__init__(dim, num_heads, attn_drop)
+        # Floor-rounding matches the hardware truncation behavior.
+        base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+        self.query_quantizer = partial(
+            base_quantizer,
+            width=q_config["query_width"],
+            frac_width=q_config["query_frac_width"],
+        )
+        self.key_quantizer = partial(
+            base_quantizer,
+            width=q_config["key_width"],
+            frac_width=q_config["key_frac_width"],
+        )
+        self.value_quantizer = partial(
+            base_quantizer,
+            width=q_config["value_width"],
+            frac_width=q_config["value_frac_width"],
+        )
+        # Q @ K^T in integer arithmetic (query = data_in, key = weight port).
+        self.matmul1 = partial(
+            generic_matmul_integer,
+            config={
+                "data_in_width": q_config["query_width"],
+                "data_in_frac_width": q_config["query_frac_width"],
+                "weight_width": q_config["key_width"],
+                "weight_frac_width": q_config["key_frac_width"],
+            },
+            out_config={
+                "data_out_width": q_config["qkmm_out_width"],
+                "data_out_frac_width": q_config["qkmm_out_frac_width"],
+            },
+            floor=floor,
+        )
+        # ``mult_data`` here is still the base class's 1/sqrt(dim); it is
+        # folded into the integer softmax so the scores get scaled there.
+        self.act = partial(
+            softmax_integer,
+            config={
+                "data_in_width": q_config["qkmm_out_width"],
+                "data_in_frac_width": q_config["qkmm_out_frac_width"],
+                "data_in_exp_width": q_config["softmax_exp_width"],
+                "data_in_exp_frac_width": q_config["softmax_exp_frac_width"],
+                "data_out_frac_width": q_config["softmax_out_frac_width"],
+                "mult_data": self.mult_data,
+            },
+            floor=floor,
+        )
+        # Neutralize the multiply in self_attention_head — the scale has
+        # already been absorbed into the softmax above.
+        self.mult_data = torch.tensor(1)
+        # Softmax probabilities lie in [0, 1]; width = frac_width + 2
+        # presumably leaves a sign bit plus one integer bit — TODO confirm.
+        self.matmul2 = partial(
+            generic_matmul_integer,
+            config={
+                "data_in_width": q_config["softmax_out_frac_width"] + 2,
+                "data_in_frac_width": q_config["softmax_out_frac_width"],
+                "weight_width": q_config["value_width"],
+                "weight_frac_width": q_config["value_frac_width"],
+            },
+            out_config={
+                "data_out_width": q_config["svmm_out_width"],
+                "data_out_frac_width": q_config["svmm_out_frac_width"],
+            },
+            floor=floor,
+        )
diff --git a/a_cx_mxint_quant/gelu.drawio b/a_cx_mxint_quant/gelu.drawio
new file mode 100644
index 000000000..bb0299be8
--- /dev/null
+++ b/a_cx_mxint_quant/gelu.drawio
@@ -0,0 +1,220 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/a_cx_mxint_quant/gelu.py b/a_cx_mxint_quant/gelu.py
new file mode 100644
index 000000000..cbe191f1d
--- /dev/null
+++ b/a_cx_mxint_quant/gelu.py
@@ -0,0 +1,77 @@
+# models.py
+import torch
+import torch.nn as nn
+import math
+from typing import List, Union, Optional
+from pathlib import Path
+import torch
+import torch.nn as nn
+from torch import Tensor
+import math
+from typing import Literal, Optional, Tuple, Union, Dict
+from enum import Enum
+from functools import partial
+from tqdm import tqdm
+from chop.nn.quantizers.integer import _integer_quantize
+from .quantizers import mxint_hardware
+from .utils import reshape_to_block, reshape_back
+
+def mxint_gelu(x, q_config):
+    """Vectorized range reduction.
+
+    MXInt approximation of GELU: quantize the input, use ReLU as the default
+    output (GELU ~= ReLU outside (-3, 3)), and replace values inside (-3, 3)
+    with a table-precision ("hash") quantized GELU.  Returns the usual
+    mxint_hardware triple (dequantized value, mantissa, exponent) at the
+    data_out precision.
+    """
+    qx, mx, ex = mxint_hardware(
+        x,
+        {
+            "width": q_config["data_in_width"],
+            "exponent_width": q_config["data_in_exponent_width"],
+            "round_bits": 4,
+        },
+        parallelism=q_config["data_in_parallelism"]
+    )
+    # first
+
+    # Reshape into hardware blocks so each block shares one exponent.
+    original_shape = qx.shape
+    t1, t0 = mx.shape[-2:]
+    p1, p0 = q_config["data_in_parallelism"]
+    qx = reshape_to_block(qx, t1,t0, p1, p0)
+    mx = reshape_to_block(mx, t1, t0, p1, p0)
+    ex = ex.unsqueeze(-1).unsqueeze(-1)
+
+    # Default path: ReLU; NOTE(review): ``eout`` is assigned but never used.
+    qout = torch.relu(qx)
+    eout = ex
+    # Only the central region (-3, 3) goes through the hashed GELU.
+    remaining = (qx > -3) & (qx < 3)
+
+    # data_width_loss
+    # avoid quant_loss here
+    # we will need to shift it to
+    # in hardware qx is lossless
+    VALID_WIDTH = q_config["data_in_width"] + 2
+    HASH_OUT_WIDTH = q_config["hash_out_width"]
+    HASH_OUT_FRAC_WIDTH = HASH_OUT_WIDTH - 3
+    # hash loss
+    # NOTE(review): precedence is (qgelu * 2**(W-1)) // 2**ex, i.e. scale to
+    # mantissa domain then floor-divide by the block exponent — confirm this
+    # matches the hardware LUT scaling.
+    qgelu = _integer_quantize(torch.nn.GELU()(qx), HASH_OUT_WIDTH, HASH_OUT_FRAC_WIDTH)
+    mgelu = qgelu * 2**(HASH_OUT_WIDTH - 1) // 2**ex
+    qgelu = mgelu * 2**ex / 2**(HASH_OUT_WIDTH - 1)
+
+    # Splice the hashed-GELU values into the ReLU default, undo the blocking,
+    # and re-quantize to the output precision.
+    qout[remaining] = qgelu[remaining]
+    qout = reshape_back(qout, t1, t0, p1, p0)
+    qout = qout.reshape(original_shape)
+    qx, mx, ex = mxint_hardware(
+        qout,
+        {
+            "width": q_config["data_out_width"],
+            "exponent_width": q_config["data_out_exponent_width"],
+            "round_bits": 4,
+        },
+        parallelism=q_config["data_out_parallelism"]
+    )
+    return qx, mx, ex
+
+class MXIntGELU(nn.Module):
+ def __init__(self, q_config: Dict = {}):
+ super().__init__()
+ self.q_config = q_config
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out, _, _ = mxint_gelu(x, self.q_config)
+ return out
+
diff --git a/a_cx_mxint_quant/layer_norm.drawio b/a_cx_mxint_quant/layer_norm.drawio
new file mode 100644
index 000000000..8589bf0a0
--- /dev/null
+++ b/a_cx_mxint_quant/layer_norm.drawio
@@ -0,0 +1,469 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/a_cx_mxint_quant/layer_norm.py b/a_cx_mxint_quant/layer_norm.py
new file mode 100644
index 000000000..302307c05
--- /dev/null
+++ b/a_cx_mxint_quant/layer_norm.py
@@ -0,0 +1,166 @@
+from torch import nn
+import torch
+
+from .quantizers import mxint_quant_block, mxint_hardware
+from chop.nn.quantizers import integer_floor_quantizer, integer_quantizer
+from torch import Tensor
+from math import ceil, log2
+
+def mxint_layer_norm(
+    x: torch.Tensor,
+    normalized_shape: "tuple | int",
+    weight=None,
+    bias=None,
+    eps=1e-5,
+    q_config={},
+):
+    """Hardware-faithful MXInt LayerNorm.
+
+    Mean/variance use power-of-two-scaled integer reciprocals of the element
+    count, 1/sqrt(var) is computed on an even-exponent mantissa (so the
+    exponent halves exactly), and weight/bias/output are MXInt-quantized.
+    Returns (dequantized output, mantissa, exponent).
+    NOTE(review): ``q_config`` default is a shared mutable dict — callers
+    always pass one here, but prefer a None sentinel.
+    """
+    def quantize(x, width, frac_width, by_pass=False, floor=False):
+        # Fixed-point quantize helper; ``by_pass`` skips quantization.
+        if not by_pass:
+            if floor:
+                x = integer_floor_quantizer(x, width, frac_width)
+            else:
+                x = integer_quantizer(x, width, frac_width)
+        return x
+
+    def get_dim_and_prodofdim(x, normalized_shape):
+        # Trailing dims covered by normalized_shape, and their element count.
+        dim = tuple(range(0 - len(normalized_shape), 0))
+        num_vals = 1
+        for items in dim:
+            num_vals *= x.shape[items]
+        return dim, num_vals
+    '''
+    actually, we cannot assume that the input is quantized
+    '''
+    def isqrt(x: torch.Tensor):
+        # 1/sqrt(x + eps) in float; operates on the variance mantissa.
+        x = (x + eps).sqrt()
+        x = x.reciprocal()
+        return x
+
+    if isinstance(normalized_shape, int):
+        normalized_shape = (normalized_shape,)
+    dim, num_vals = get_dim_and_prodofdim(x, normalized_shape)
+    inv_num_vals = torch.tensor(1 / num_vals)
+
+    # 1/num_vals as a truncated fixed-point constant, matching the
+    # adder-tree accumulator width.
+    acc_out_width = ceil(log2(num_vals)) + q_config.get("data_in_width")
+    inv_num_vals_quant_0 = 2**acc_out_width // num_vals / 2**acc_out_width
+    # Mean calculation
+    mu_acc = x.sum(dim, keepdim=True)
+    mu = mu_acc * inv_num_vals_quant_0
+    mu = quantize(
+        mu,
+        q_config.get("data_in_width"),
+        q_config.get("data_in_frac_width"),
+        q_config.get("by_pass"),
+        True,
+    )
+    # I hope the output precision here should be $clog2
+    # Variance calculation
+    diff = x - mu
+
+    squares = diff**2
+    sum_squares = torch.sum(squares, dim, keepdim=True)
+    squares_adder_tree_width = 2 * q_config.get("data_in_width") + ceil(log2(num_vals))
+    inv_num_vals_quant_1 = 2**squares_adder_tree_width // num_vals / 2**squares_adder_tree_width
+    var = sum_squares * inv_num_vals_quant_1
+    var = quantize(
+        var,
+        squares_adder_tree_width + 2,
+        2*q_config.get("data_in_width") - 2,
+        floor=True,
+    )
+    var, mvar, evar = mxint_hardware(
+        var,
+        {
+            "width": q_config.get("isqrt_in_width"),
+            "exponent_width": 6,
+        },
+        parallelism=[1, 1],
+    )
+
+    # NOTE(review): this second mxint_hardware call is byte-identical to the
+    # one above — presumably redundant; confirm and remove.
+    var, mvar, evar = mxint_hardware(
+        var,
+        {
+            "width": q_config.get("isqrt_in_width"),
+            "exponent_width": 6,
+        },
+        parallelism=[1, 1],
+    )
+    # Force the exponent even (double mantissa / decrement exponent) so that
+    # -evar/2 below is an integer.
+    mvar[evar %2 !=0] *= 2
+    evar[evar %2 !=0] -= 1
+    minv_sqrt = isqrt(mvar/2**(q_config.get("isqrt_in_width") - 1))
+    minv_sqrt = integer_quantizer(minv_sqrt, q_config.get("isqrt_out_width"), q_config.get("isqrt_out_frac_width"))
+    einv_sqrt = -evar/2
+
+    inv_sqrt = minv_sqrt * 2**einv_sqrt
+
+    # Norm calculation
+    mnorm_out = diff * minv_sqrt
+    enorm_out = einv_sqrt
+    mnorm_out = quantize(
+        mnorm_out,
+        q_config.get("data_out_width"),
+        q_config.get("data_out_frac_width"),
+        q_config.get("by_pass"),
+        floor=True,
+    )
+    qnorm_out = mnorm_out*2**einv_sqrt
+    # Optional affine: elementwise scale then shift, each MXInt-quantized.
+    if weight is not None:
+        qweight, mweight, eweight = mxint_hardware(weight,
+            {
+                "width": q_config.get("weight_width"),
+                "exponent_width": q_config.get("weight_exponent_width"),
+                "round_bits": 4
+            },
+            q_config.get("weight_parallelism"))
+        qnorm_out = qnorm_out * qweight
+    if bias is not None:
+        qbias, mbias, ebias = mxint_hardware(bias,
+            {
+                "width": q_config.get("bias_width"),
+                "exponent_width": q_config.get("bias_exponent_width"),
+                "round_bits": 4
+            },
+            q_config.get("bias_parallelism"))
+        qnorm_out = qnorm_out + qbias
+    # Final output quantization.
+    qnorm_out, mnorm_out, enorm_out = mxint_hardware(qnorm_out,
+        {
+            "width": q_config.get("data_out_width"),
+            "exponent_width": q_config.get("data_out_exponent_width"),
+            "round_bits": 4
+        },
+        q_config.get("data_out_parallelism"))
+    return qnorm_out, mnorm_out, enorm_out
+
+def layer_norm_hardware(
+    x: torch.Tensor,
+    normalized_shape: "tuple | int",
+    weight=None,
+    bias=None,
+    eps=1e-5,
+    q_config=None,
+):
+    """Quantize the input to MXInt, then run mxint_layer_norm.
+
+    Returns only the dequantized output (mantissa/exponent are discarded).
+    """
+    # mx, ex unused — only the dequantized input feeds the layer norm.
+    qx, mx, ex = mxint_quant_block(x, q_config["data_in_width"], q_config["data_in_exponent_width"])
+    qnorm_out, _, _ = mxint_layer_norm(qx, normalized_shape, weight, bias, eps, q_config)
+    return qnorm_out
+
+class MXIntLayerNorm(nn.LayerNorm):
+    """nn.LayerNorm whose forward runs through layer_norm_hardware.
+
+    NOTE(review): defaults differ from nn.LayerNorm — elementwise_affine and
+    bias default to False here, in which case self.weight/self.bias are None
+    and mxint_layer_norm skips the affine path.  Confirm this is intended.
+    """
+    def __init__(
+        self,
+        normalized_shape,
+        eps: float = 0.00001,
+        elementwise_affine: bool = False,
+        bias: bool = False,
+        q_config=None,
+    ) -> None:
+        # Store q_config before super().__init__ so it is a plain attribute,
+        # not intercepted as a parameter/submodule.
+        self.q_config = q_config
+        super().__init__(normalized_shape, eps, elementwise_affine, bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return layer_norm_hardware(
+            x,
+            self.normalized_shape,
+            self.weight,
+            self.bias,
+            self.eps,
+            q_config=self.q_config,
+        )
\ No newline at end of file
diff --git a/a_cx_mxint_quant/linear.py b/a_cx_mxint_quant/linear.py
new file mode 100644
index 000000000..3e6883ffe
--- /dev/null
+++ b/a_cx_mxint_quant/linear.py
@@ -0,0 +1,173 @@
+from chop.nn.quantized.modules.linear import _LinearBase
+from torch import Tensor
+import torch
+from .quantizers import mxint_hardware, reshape_to_block, reshape_back
+
+def fast_linear(x, w, b, config):
+    """Blocked MXInt linear: y = x @ w.T + b with hardware-style accumulation.
+
+    Inputs, weight and bias are MXInt-quantized, then the matmul is computed
+    block-by-block over the in_features depth, aligning each partial product
+    to the running maximum exponent (mantissa shifts), exactly as a hardware
+    accumulator would.  ``config`` carries the x/w/bias/out block geometry
+    built by MXIntLinear.forward.
+    NOTE(review): ``b`` is quantized unconditionally — this will fail if the
+    caller passes b=None; confirm bias is always present.
+    """
+    batch_size, n = x.shape[:2]
+    out_features = w.shape[0]
+    qx, mx, ex = mxint_hardware(x, **{
+        "parallelism":[config["x_config"]["parallism_dim_1"], config["x_config"]["parallism_dim_0"]],
+        "q_config":{
+            "width": config["x_config"]["width"],
+            "exponent_width": config["x_config"]["exponent_width"],
+            "round_bits": config["round_bits"],
+
+        },
+    })
+    qw, mw, ew = mxint_hardware(w, **{
+        "parallelism":[config["w_config"]["parallism_dim_1"], config["w_config"]["parallism_dim_0"]],
+        "q_config":{
+            "width": config["w_config"]["width"],
+            "exponent_width": config["w_config"]["exponent_width"],
+            "round_bits": 8,
+        }
+    })
+    qb, mb, eb = mxint_hardware(b, **{
+        "parallelism":[config["bias_config"]["parallism_dim_1"], config["bias_config"]["parallism_dim_0"]],
+        "q_config":{
+            "width": config["bias_config"]["width"],
+            "exponent_width": config["bias_config"]["exponent_width"],
+            "round_bits": 8,
+        }
+    })
+    x_config = config["x_config"]
+    w_config = config["w_config"]
+    reshaped_mx = reshape_to_block(mx, x_config["dim_1"], x_config["dim_0"], x_config["parallism_dim_1"], x_config["parallism_dim_0"])
+    reshaped_mw = reshape_to_block(mw, w_config["dim_1"], w_config["dim_0"], w_config["parallism_dim_1"], w_config["parallism_dim_0"])
+
+    # move the infeatures depth to the front
+    mx_for_accumulation = reshaped_mx.permute(2, 0, 1, 3, 4)
+    # The dimension will be [depth_in_features, batch_size, depth_n, parallism_n, parallism_in_features]
+    # For every parallelised block, we will have a unique exponent
+    # Original shape of ex is [batch_size, depth_n, depth_in_features]
+    # We will permute it to [depth_in_features, batch_size, depth_n]
+    ex_for_accumulation = ex.permute(2, 0, 1)
+
+    # Same for mw, the shape of mw is [depth_out_features, depth_in_features, parallism_out_features, parallism_in_features]
+    mw_for_accumulation = reshaped_mw.squeeze(0)
+    mw_for_accumulation = mw_for_accumulation.permute(1, 0, 2, 3)
+    ew_for_accumulation = ew.transpose(0, 1)
+
+    # We are trying to do the matmul based on the block partition
+    # mx is [depth_in_features, batch_size, depth_n, parallism_n, parallism_in_features]
+    # mw is [depth_in_features, depth_out_features, parallism_out_features, parallism_in_features]
+    # merge depth_out_features and parallelism_out_features
+    # mw = [depth_in_features, out_features, parallism_in_features]
+    mw_for_accumulation = mw_for_accumulation.reshape(mw_for_accumulation.shape[0], -1, mw_for_accumulation.shape[-1])
+
+    # First depth slice seeds the accumulator.
+    mout = mx_for_accumulation[0] @ mw_for_accumulation[0].transpose(-2, -1)
+    mout = reshape_to_block(mout, x_config["dim_1"], w_config["dim_1"], x_config["parallism_dim_1"], w_config["parallism_dim_1"])
+    # shape of mout is [batch_size, depth_n, parallism_n, out_features]
+    ex_expanded = ex_for_accumulation.unsqueeze(-1) # [depth_in_features, batch_size, depth_n, 1]
+    ew_expanded = ew_for_accumulation.unsqueeze(1).unsqueeze(2) # [depth_in_features, 1, 1, depth_out_features]
+    eout = (ex_expanded[0] + ew_expanded[0]).unsqueeze(-1).unsqueeze(-1)
+    # Accumulate remaining depth slices, shifting mantissas so every partial
+    # product shares the running maximum exponent before adding.
+    for i in range(1, mx_for_accumulation.shape[0]):
+        new_exponent = (ex_expanded[i] + ew_expanded[i]).unsqueeze(-1).unsqueeze(-1)
+        max_exponent = torch.max(eout, new_exponent)
+        mout = mout // 2 ** (max_exponent - eout)
+        current_result = mx_for_accumulation[i] @ mw_for_accumulation[i].transpose(-2, -1)
+        current_result = reshape_to_block(current_result, x_config["dim_1"], w_config["dim_1"], x_config["parallism_dim_1"], w_config["parallism_dim_1"])
+        current_result = current_result // 2 ** (max_exponent - new_exponent)
+        mout += current_result
+        eout = max_exponent
+
+    # the shape of qout will be [batch_size, depth_in_n, depth_out_features, paral_n, paral_out_features]
+    # the shape of mb will be [1, 1, out_features]
+    # reshape mb to [1, 1, depth_out_features, 1, paral_out_features]
+    # broad cast to [batch_size, depth_in_n, depth_out_features, paral_n, paral_out_features]
+
+    # the shape of eout willbe [batch_size, depth_n, depth_out_features]
+    # the shape of eb will be [1, 1, depth_out_featuers]
+
+    # so i wish eb can map back to
+    out_config = config["out_config"]
+    b_config = config["bias_config"]
+    # Align the bias mantissa to the accumulator's fixed-point position.
+    width_difference = x_config["width"] + w_config["width"] - 2 - (b_config["width"] -1)
+    reshaped_mb = mb.reshape(1, 1, out_config["depth_dim_0"], 1, out_config["parallism_dim_0"])
+    reshaped_eb = eb.reshape(1, 1, out_config["depth_dim_0"], 1, 1)
+    mb_for_out = reshaped_mb // 2**(eout - reshaped_eb - width_difference)
+    mout = mout + mb_for_out
+
+    # Dequantize the accumulator and restore the (batch, n, out_features) layout.
+    qout = reshape_back((mout / 2 **(x_config["width"]+w_config["width"] - 2 - eout)), x_config["dim_1"], w_config["dim_1"], x_config["parallism_dim_1"], w_config["parallism_dim_1"])
+    qout = qout.reshape(batch_size, n, out_features)
+
+    return qout
+
+class MXIntLinear(_LinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ q_config=None,
+ ) -> None:
+ super().__init__(in_features, out_features, bias, device, dtype)
+ assert q_config is not None, "config is None!"
+ self.in_features = in_features
+ self.out_features = out_features
+ self.q_config = q_config
+ self.bypass = q_config.get("bypass", False)
+ if self.bypass:
+ return
+ # establish quantizer
+
+ def forward(self, x: Tensor) -> Tensor:
+ # an example of config
+ unroll_in_features = self.q_config["data_in_parallelism"][1]
+ unroll_out_features = self.q_config["data_out_parallelism"][1]
+ unroll_n = self.q_config["data_in_parallelism"][0]
+ in_features = self.in_features
+ out_features = self.out_features
+ n = x.shape[1]
+ batch_size = x.shape[0]
+ assert x.shape[2] == in_features, f"Input shape mismatch: {x.shape[2]} != {in_features}"
+
+ self.config = {
+ "x_config": {
+ "width": self.q_config["data_in_width"],
+ "exponent_width": self.q_config["data_in_exponent_width"],
+ "parallism_dim_0": unroll_in_features,
+ "parallism_dim_1": unroll_n,
+ "depth_dim_0": in_features // unroll_in_features,
+ "depth_dim_1": n // unroll_n,
+ "dim_0": in_features,
+ "dim_1": n,
+ },
+ "w_config": {
+ "width": self.q_config["weight_width"],
+ "exponent_width": self.q_config["weight_exponent_width"],
+ "parallism_dim_0": unroll_in_features,
+ "parallism_dim_1": unroll_out_features,
+ "depth_dim_0": in_features // unroll_in_features,
+ "depth_dim_1": out_features // unroll_out_features,
+ "dim_0": in_features,
+ "dim_1": out_features,
+ },
+ "bias_config": {
+ "width": self.q_config["bias_width"],
+ "exponent_width": self.q_config["bias_exponent_width"],
+ "parallism_dim_0": unroll_out_features,
+ "parallism_dim_1": 1,
+ "depth_dim_0": out_features // unroll_out_features,
+ "depth_dim_1": 1,
+ "dim_0": out_features,
+ "dim_1": 1,
+ },
+ "out_config": {
+ "width": self.q_config["data_out_width"],
+ "exponent_width": self.q_config["data_out_exponent_width"],
+ "parallism_dim_0": unroll_out_features,
+ "parallism_dim_1": unroll_n,
+ "depth_dim_0": out_features // unroll_out_features,
+ "depth_dim_1": n // unroll_n,
+ "dim_0": out_features,
+ "dim_1": n,
+ },
+ "round_bits": self.q_config.get("round_bits", 4),
+ }
+ # out = fast_linear(x, self.weight, self.bias, self.config)
+ out = torch.nn.Linear(in_features, out_features, bias=True)(x)
+ return out
diff --git a/a_cx_mxint_quant/mase_mxint_top_tb.py b/a_cx_mxint_quant/mase_mxint_top_tb.py
new file mode 100644
index 000000000..beb02fe6c
--- /dev/null
+++ b/a_cx_mxint_quant/mase_mxint_top_tb.py
@@ -0,0 +1,334 @@
+from pathlib import Path
+
+import cocotb
+import logging, torch
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+from pathlib import Path
+
+import cocotb
+from mase_cocotb.testbench import Testbench
+from mase_cocotb.interfaces.streaming import MultiSignalStreamDriver, MultiSignalErrorThresholdStreamMonitor, MultiSignalStreamMonitor
+import sys
+from os import getenv, PathLike
+
+import torch
+from pathlib import Path
+import time
+import warnings
+from cocotb.runner import get_runner, get_results
+
+from chop.tools import get_logger
+import mase_components
+from mase_components import get_modules
+
+import glob, os
+from cocotb.utils import get_sim_time
def simulate(
    model: torch.nn.Module = None,
    model_info=None,
    task: str = "",
    dataset_info=None,
    data_module=None,
    load_name: PathLike = None,
    load_type: str = None,
    run_emit: bool = False,
    skip_build: bool = False,
    skip_test: bool = False,
    trace_depth: int = 3,
    gui: bool = False,
    waves: bool = False,
    simulator: str = "verilator",
    pass_args=None,
):
    """Build the emitted `top` RTL and run this file's cocotb test against it.

    Only `skip_build`, `skip_test`, `trace_depth`, `gui`, `waves`,
    `simulator` and `pass_args` are used here; the remaining parameters are
    kept for interface compatibility with the wider MASE `simulate` API.

    :param pass_args: optional dict; `pass_args["project_dir"]` points at the
        emitted project (defaults to ~/.mase/top).
    :raises ValueError: if `simulator` is neither "questa" nor "verilator".
    """
    # `pass_args=None` avoids the shared-mutable-default pitfall
    # (the old `pass_args={}` dict was created once at definition time).
    if pass_args is None:
        pass_args = {}

    # The SIM environment variable overrides the requested simulator backend.
    SIM = getenv("SIM", simulator)
    runner = get_runner(SIM)

    project_dir = pass_args.get("project_dir", Path.home() / ".mase" / "top")

    if not skip_build:
        # To do: extract from mz checkpoint
        if simulator == "questa":
            sources = glob.glob(os.path.join(project_dir / "hardware" / "rtl", "*.sv"))
            build_args = []
        elif simulator == "verilator":
            sources = glob.glob(os.path.join(project_dir / "hardware" / "rtl", "*.sv"))
            build_args = [
                "-Wno-fatal",
                "-Wno-lint",
                "-Wno-style",
                "--trace-fst",       # dump FST waveforms
                "--trace-structs",
                "--trace-depth",
                str(trace_depth),
                "--unroll-count",
                "16384"
            ]
        else:
            raise ValueError(f"Unrecognized simulator: {simulator}")

        # Include the emitted RTL plus every mase_components module's rtl dir.
        includes = [
            project_dir / "hardware" / "rtl",
        ] + [
            Path(mase_components.__file__).parent / module / "rtl"
            for module in get_modules()
        ]

        build_start = time.time()

        runner.build(
            verilog_sources=sources,
            includes=includes,
            hdl_toplevel="top",
            build_args=build_args,
            parameters=[],  # use default parameters
        )

        build_end = time.time()
        logger.info(f"Build finished. Time taken: {build_end - build_start:.2f}s")

    if not skip_test:
        test_start = time.time()
        runner.test(
            hdl_toplevel="top",
            test_module="mase_mxint_top_tb",  # cocotb tests live in this file
            hdl_toplevel_lang="verilog",
            gui=gui,
            waves=waves,
        )
        test_end = time.time()
        logger.info(f"Test finished. Time taken: {test_end - test_start:.2f}s")
+
class MaseGraphTB(Testbench):
    """Testbench for an emitted MXInt MASE graph (`top`).

    Drives the single streaming input as a (mantissa, exponent) signal pair
    sharing one valid/ready handshake, and monitors the single streaming
    output.  The monitor is created with check=False, so output values are
    recorded but not compared.
    """

    def __init__(self, dut, fail_on_checks=True):
        super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks)

        # Instantiate as many drivers as required inputs to the model
        self.input_drivers = {}
        self.output_monitors = {}

        arg = "data_in_0"
        result = "data_out_0"
        # Mantissa and exponent buses are driven together on one handshake.
        self.input_drivers[arg] = MultiSignalStreamDriver(
            dut.clk, (dut.mdata_in_0, dut.edata_in_0),
            dut.data_in_0_valid, dut.data_in_0_ready
        )
        # self.input_drivers[arg].log.setLevel(logging.DEBUG)

        # Instantiate as many monitors as required outputs
        self.output_monitors[result] = MultiSignalStreamMonitor(
            dut.clk,
            (dut.mdata_out_0, dut.edata_out_0),
            dut.data_out_0_valid,
            dut.data_out_0_ready,
            check=False,  # record only; no value comparison
        )
        # self.output_monitors[result].log.setLevel(logging.DEBUG)

    def generate_inputs(self, batches, model=None):
        """
        Generate inputs for the model by sampling a random tensor
        for each input argument, according to its shape

        :param batches: number of batches to generate for each argument
        :type batches: int
        :param model: optional reference model; when given, expected outputs
            are computed by running it on the generated inputs, otherwise
            random placeholders are returned (monitor has check=False)
        :return: a tuple of (inputs, expected outputs)
        :rtype: tuple
        """
        # ! TO DO: iterate through graph.args instead to generalize
        inputs = torch.randn(batches, self.get_parameter(f"DATA_IN_0_TENSOR_SIZE_DIM_1"), self.get_parameter(f"DATA_IN_0_TENSOR_SIZE_DIM_0"))
        if model is not None:
            outputs = model(inputs)
        else:
            outputs = torch.randn(batches, self.get_parameter(f"DATA_OUT_0_TENSOR_SIZE_DIM_1"), self.get_parameter(f"DATA_OUT_0_TENSOR_SIZE_DIM_0"))
        return inputs, outputs

    def preprocess_tensor_for_mxint(self, tensor, config, parallelism):
        # Quantize to MXInt and pack the (mantissa, exponent) pair into the
        # per-beat chunk list format the multi-signal driver expects.
        from mase_components.linear_layers.mxint_operators.test.utils import mxint_hardware
        from mase_components.linear_layers.mxint_operators.test.utils import pack_tensor_to_mx_listed_chunk

        (qtensor, mtensor, etensor) = mxint_hardware(tensor, config, parallelism)
        tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism)
        return tensor_inputs

    def load_drivers(self, in_tensors):
        # One driver load per batch; widths/parallelism read from DUT params.
        for i in range(in_tensors.shape[0]):
            data_0_inputs = self.preprocess_tensor_for_mxint(
                tensor=in_tensors[i],
                config={
                    "width": self.get_parameter("DATA_IN_0_PRECISION_0"),
                    "exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"),
                    "round_bits": 4
                },
                parallelism=[self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0")]
            )
            self.input_drivers["data_in_0"].load_driver(data_0_inputs)

    def load_monitors(self, expectation):
        # Quantize expected outputs the same way and queue them on the
        # monitor (informational only, since check=False).
        for i in range(expectation.shape[0]):
            exp_out = self.preprocess_tensor_for_mxint(
                tensor=expectation[i],
                config={
                    "width": self.get_parameter("DATA_OUT_0_PRECISION_0"),
                    "exponent_width": self.get_parameter("DATA_OUT_0_PRECISION_1"),
                    "round_bits": 4
                },
                parallelism=[self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0")]
            )
            self.output_monitors["data_out_0"].load_monitor(exp_out)
+
+import torch.nn as nn
@cocotb.test()
async def test(dut):
    """Smoke test: stream one random batch through the DUT.

    The monitor was built with check=False, so this verifies the design
    streams to completion rather than numerical correctness.
    """
    # cocotb.start_soon(check_signal(dut))
    tb = MaseGraphTB(dut, fail_on_checks=True)
    await tb.initialize()
    in_tensors, out_tensors = tb.generate_inputs(batches=1)

    tb.load_drivers(in_tensors)
    tb.load_monitors(out_tensors)

    # Fails the test if the output stream never drains within the timeout.
    await tb.wait_end(timeout=100, timeout_unit="ms")
+
+from cocotb.triggers import *
async def check_signal(dut):
    """Passive throughput profiler for the ViT block pipeline.

    Counts valid&ready handshakes on the top-level ports and on each
    internal stage's data_out stream; every `out_depth` handshakes it
    prints the elapsed simulation time since the previous full tile,
    which gives the per-tile latency of each stage.

    NOTE(review): out_depth = 192/4 looks like tensor_dim / parallelism of
    the output stream — confirm against the DUT parameters.
    """
    await Timer(40, units="ns")

    # Handshakes per measured interval (one full tile of the stream).
    out_depth = 192 / 4

    # label -> (valid, ready) pair to watch; the label is reused verbatim
    # in the printed report, matching the original per-stage messages.
    watch = {
        "data_in_0": (dut.data_in_0_valid, dut.data_in_0_ready),
        "data_out_0": (dut.data_out_0_valid, dut.data_out_0_ready),
        "linear1": (dut.stream_blocks_0_linear1_data_out_0_valid,
                    dut.stream_blocks_0_linear1_data_out_0_ready),
        "act": (dut.stream_blocks_0_act_data_out_0_valid,
                dut.stream_blocks_0_act_data_out_0_ready),
        "linear2": (dut.stream_blocks_0_linear2_data_out_0_valid,
                    dut.stream_blocks_0_linear2_data_out_0_ready),
        "norm1": (dut.stream_blocks_0_norm1_data_out_0_valid,
                  dut.stream_blocks_0_norm1_data_out_0_ready),
        "attention": (dut.stream_blocks_0_attention_data_out_0_valid,
                      dut.stream_blocks_0_attention_data_out_0_ready),
        "norm2": (dut.stream_blocks_0_norm2_data_out_0_valid,
                  dut.stream_blocks_0_norm2_data_out_0_ready),
        "add": (dut.stream_blocks_0_add_data_out_0_valid,
                dut.stream_blocks_0_add_data_out_0_ready),
        "add1": (dut.stream_blocks_0_add_1_data_out_0_valid,
                 dut.stream_blocks_0_add_1_data_out_0_ready),
    }

    # Per-stream handshake counters and last-report timestamps.
    counts = {label: 0 for label in watch}
    stamps = {label: get_sim_time(units='ns') for label in watch}

    while True:
        await RisingEdge(dut.clk)
        await ReadOnly()

        for label, (valid, ready) in watch.items():
            if valid.value and ready.value:
                counts[label] += 1
                if counts[label] == out_depth:
                    # Full tile transferred: report interval and reset.
                    counts[label] = 0
                    now = get_sim_time(units='ns')
                    print(f"{label} handshake time: {now - stamps[label]}")
                    stamps[label] = now
+
+
+
+
if __name__ == "__main__":
    # Build the emitted design under ./mxint_vit_block and run this file's
    # cocotb test against it with Verilator (FST waves, trace depth 5).
    pass_args = {
        "project_dir": Path("./mxint_vit_block"),
    }
    simulate(skip_build=False, skip_test=False, simulator="verilator", waves=True, gui=False, trace_depth=5, pass_args=pass_args)
\ No newline at end of file
diff --git a/a_cx_mxint_quant/mase_top_tb.py b/a_cx_mxint_quant/mase_top_tb.py
new file mode 100644
index 000000000..f48a40723
--- /dev/null
+++ b/a_cx_mxint_quant/mase_top_tb.py
@@ -0,0 +1,232 @@
+from pathlib import Path
+
+import cocotb
+import logging, torch
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+from pathlib import Path
+
+import cocotb
+from mase_cocotb.testbench import Testbench
+from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor
+import sys
+from os import getenv, PathLike
+
+import torch
+from pathlib import Path
+import time
+import warnings
+from cocotb.runner import get_runner, get_results
+
+from chop.tools import get_logger
+import mase_components
+from mase_components import get_modules
+
+import glob, os
+
def simulate(
    model: torch.nn.Module = None,
    model_info=None,
    task: str = "",
    dataset_info=None,
    data_module=None,
    load_name: PathLike = None,
    load_type: str = None,
    run_emit: bool = False,
    skip_build: bool = False,
    skip_test: bool = False,
    trace_depth: int = 3,
    gui: bool = False,
    waves: bool = False,
    simulator: str = "verilator",
    pass_args=None,
):
    """Build the emitted `top` RTL and run this file's cocotb test against it.

    Only `skip_build`, `skip_test`, `trace_depth`, `gui`, `waves`,
    `simulator` and `pass_args` are used here; the remaining parameters are
    kept for interface compatibility with the wider MASE `simulate` API.

    :param pass_args: optional dict; `pass_args["project_dir"]` points at the
        emitted project (defaults to ~/.mase/top).
    :raises ValueError: if `simulator` is neither "questa" nor "verilator".
    """
    # `pass_args=None` avoids the shared-mutable-default pitfall
    # (the old `pass_args={}` dict was created once at definition time).
    if pass_args is None:
        pass_args = {}

    # The SIM environment variable overrides the requested simulator backend.
    SIM = getenv("SIM", simulator)
    runner = get_runner(SIM)

    project_dir = pass_args.get("project_dir", Path.home() / ".mase" / "top")

    if not skip_build:
        # To do: extract from mz checkpoint
        if simulator == "questa":
            sources = glob.glob(os.path.join(project_dir / "hardware" / "rtl", "*.sv"))
            build_args = []
        elif simulator == "verilator":
            sources = glob.glob(os.path.join(project_dir / "hardware" / "rtl", "*.sv"))
            build_args = [
                "-Wno-fatal",
                "-Wno-lint",
                "-Wno-style",
                "--trace-fst",       # dump FST waveforms
                "--trace-structs",
                "--trace-depth",
                str(trace_depth),
                "--unroll-count",
                "16384"
            ]
        else:
            raise ValueError(f"Unrecognized simulator: {simulator}")

        # Include the emitted RTL plus every mase_components module's rtl dir.
        includes = [
            project_dir / "hardware" / "rtl",
        ] + [
            Path(mase_components.__file__).parent / module / "rtl"
            for module in get_modules()
        ]

        build_start = time.time()

        runner.build(
            verilog_sources=sources,
            includes=includes,
            hdl_toplevel="top",
            build_args=build_args,
            parameters=[],  # use default parameters
        )

        build_end = time.time()
        logger.info(f"Build finished. Time taken: {build_end - build_start:.2f}s")

    if not skip_test:
        test_start = time.time()
        runner.test(
            hdl_toplevel="top",
            test_module="mase_top_tb",  # cocotb tests live in this file
            hdl_toplevel_lang="verilog",
            gui=gui,
            waves=waves,
        )
        test_end = time.time()
        logger.info(f"Test finished. Time taken: {test_end - test_start:.2f}s")
+
class MaseGraphTB(Testbench):
    """Testbench for an emitted fixed-point MASE graph (`top`).

    Drives the single streaming input with fixed-point data and monitors
    the single streaming output.  The monitor is created with check=False,
    so output values are recorded but not compared.
    """

    def __init__(self, dut, fail_on_checks=True):
        super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks)

        # Instantiate as many drivers as required inputs to the model
        self.input_drivers = {}
        self.output_monitors = {}

        arg = "data_in_0"
        result = "data_out_0"
        self.input_drivers[arg] = StreamDriver(
            dut.clk,
            getattr(dut, arg),
            getattr(dut, f"{arg}_valid"),
            getattr(dut, f"{arg}_ready"),
        )
        self.input_drivers[arg].log.setLevel(logging.DEBUG)

        # Instantiate as many monitors as required outputs
        self.output_monitors[result] = StreamMonitor(
            dut.clk,
            getattr(dut, result),
            getattr(dut, f"{result}_valid"),
            getattr(dut, f"{result}_ready"),
            check=False,  # record only; no value comparison
        )
        self.output_monitors[result].log.setLevel(logging.DEBUG)

    def generate_inputs(self, batches):
        """
        Generate inputs for the model by sampling a random tensor
        for each input argument, according to its shape

        :param batches: number of batches to generate for each argument
        :type batches: int
        :return: a tuple of (random inputs, random placeholder outputs);
            outputs are placeholders because the monitor has check=False
        :rtype: tuple
        """
        # ! TO DO: iterate through graph.args instead to generalize
        inputs = torch.randn(batches, self.get_parameter(f"DATA_IN_0_TENSOR_SIZE_DIM_1"), self.get_parameter(f"DATA_IN_0_TENSOR_SIZE_DIM_0"))
        outputs = torch.randn(batches, self.get_parameter(f"DATA_OUT_0_TENSOR_SIZE_DIM_1"), self.get_parameter(f"DATA_OUT_0_TENSOR_SIZE_DIM_0"))
        return inputs, outputs

    def load_drivers(self, in_tensors):
        # Quantize to fixed point and split into per-beat blocks sized by
        # the DUT's input parallelism.
        from mase_cocotb.utils import fixed_preprocess_tensor

        in_data_blocks = fixed_preprocess_tensor(
            tensor=in_tensors,
            q_config={
                "width": self.get_parameter(f"DATA_IN_0_PRECISION_0"),
                "frac_width": self.get_parameter(
                    f"DATA_IN_0_PRECISION_1"
                ),
            },
            parallelism=[
                self.get_parameter(f"DATA_IN_0_PARALLELISM_DIM_1"),
                self.get_parameter(f"DATA_IN_0_PARALLELISM_DIM_0"),
            ],
            floor=True,
        )

        # Append all input blocks to input driver
        # ! TO DO: generalize
        block_size = self.get_parameter(
            "DATA_IN_0_PARALLELISM_DIM_0"
        ) * self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1")
        for block in in_data_blocks:
            # Zero-pad a trailing partial block to a full beat.
            if len(block) < block_size:
                block = block + [0] * (block_size - len(block))
            self.input_drivers["data_in_0"].append(block)

    def load_monitors(self, expectation):
        from mase_cocotb.utils import fixed_preprocess_tensor

        # Process the expectation tensor
        output_blocks = fixed_preprocess_tensor(
            tensor=expectation,
            q_config={
                "width": self.get_parameter(f"DATA_OUT_0_PRECISION_0"),
                "frac_width": self.get_parameter(f"DATA_OUT_0_PRECISION_1"),
            },
            parallelism=[
                self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"),
                self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"),
            ],
            floor=True,
        )

        # Set expectation for each monitor
        for block in output_blocks:
            # ! TO DO: generalize to multi-output models
            if len(block) < self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"):
                block = block + [0] * (
                    self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0") - len(block)
                )
            self.output_monitors["data_out_0"].expect(block)

        # Drive the in-flight flag for each monitor
        self.output_monitors["data_out_0"].in_flight = True
+
@cocotb.test()
async def test(dut):
    """Smoke test: stream 10 random batches through the DUT.

    The monitor was built with check=False, so this verifies the design
    streams to completion rather than numerical correctness.
    """

    tb = MaseGraphTB(dut, fail_on_checks=True)
    await tb.initialize()

    in_tensors, out_tensors = tb.generate_inputs(batches=10)

    tb.load_drivers(in_tensors)
    tb.load_monitors(out_tensors)

    # Fails the test if the output stream never drains within the timeout.
    await tb.wait_end(timeout=0.1, timeout_unit="s")
+
+
if __name__ == "__main__":
    # Build the emitted design under ./int_linear and run this file's
    # cocotb test against it with Verilator.
    pass_args = {
        "project_dir": Path("./int_linear"),
    }
    simulate(skip_build=False, skip_test=False, simulator="verilator", waves=True, gui=False, pass_args=pass_args)
\ No newline at end of file
diff --git a/a_cx_mxint_quant/mase_utils.py b/a_cx_mxint_quant/mase_utils.py
new file mode 100644
index 000000000..41460a9c0
--- /dev/null
+++ b/a_cx_mxint_quant/mase_utils.py
@@ -0,0 +1,475 @@
+import torch
+from chop.nn.quantizers import integer_floor_quantizer
+from functools import partial
+import torch.nn.functional as F
+from torch import Tensor
+
+import torch
+from functools import partial
+import torch.nn.functional as F
+from torch import Tensor
+
def mxint_quant_block(
    x, width: int = 12, exponent_width: int = 6, exponent: int = None, floor: bool = False
):
    """
    - Idea from https://arxiv.org/pdf/2310.10537
    - Convert IEEE FP32/64 to Integer with sharing scale
    - The main difference is the sharing scale does not support NaN representation
    ---
    - `width`: The number of mantissa bits + 1 (the sign bit)
    - `exponent_width`: the number of exponent bits, which is shared over a block
    - `exponent`: optional fixed shared exponent; when None it is derived from
      the block's absolute maximum along the last dim.  (Previously this
      argument was accepted but silently overwritten.)
    - `floor`: use floor instead of round-to-nearest for the mantissa

    Returns the dequantized tensor (same shape as `x`).
    """
    exponent_bias = 2 ** (exponent_width - 1)
    exponent_max = 2**exponent_width - 1 - exponent_bias
    exponent_min = -exponent_bias

    if exponent is None:
        # Shared exponent from the per-block absolute maximum; `tiny`
        # guards log2(0) for all-zero blocks.
        abs_max = x.abs().max(dim=-1, keepdim=True).values
        log2 = torch.log2(abs_max + torch.finfo(torch.float32).tiny)
        exponent = torch.ceil(log2)
    else:
        # Honor a caller-provided exponent (still saturated below).
        exponent = torch.as_tensor(exponent, dtype=x.dtype)
    exponent = torch.clamp(exponent, exponent_min, exponent_max)

    # Mantissa is a signed integer in [-2^(width-1), 2^(width-1)-1].
    int_min = -(2 ** (width - 1))
    int_max = 2 ** (width - 1) - 1
    mantissa = x * (2 ** (width - 1)) / 2**exponent
    if floor:
        mantissa = torch.floor(mantissa)
    else:
        mantissa = torch.round(mantissa)
    mantissa = torch.clamp(mantissa, int_min, int_max)
    # Dequantize back to a float value.
    q_x = (2**exponent) * mantissa / (2 ** (width - 1))
    return q_x
+
+
def mxint_hardware(tensor, q_config, parallelism):
    """
    Vectorized hardware-aware quantization implementation

    Tiles the last two dims into (p1, p0) hardware blocks, quantizes every
    block with one shared exponent via `mxint_quant_block`, then restores
    the original layout and shape.

    :param tensor: input tensor; a 1-D input is treated as shape (1, n)
    :param q_config: kwargs forwarded to `mxint_quant_block`
    :param parallelism: [p1, p0] block shape; a single value means [1, p0]
    :return: dequantized tensor with the original shape
    """
    original_shape = tensor.shape
    if len(tensor.shape) == 1:
        tensor = tensor.unsqueeze(0)
    if len(parallelism) == 1:
        parallelism = [1, parallelism[0]]

    p1, p0 = parallelism
    t1, t0 = tensor.shape[-2:]

    # Tiling requires the tensor to divide evenly into (p1, p0) blocks.
    assert (t1 % p1 == 0 and t0 % p0 == 0), \
        f"Block size mismatch: t1={t1}, p1={p1}, t0={t0}, p0={p0}"

    # Single reshape and permute operation: one flattened row per block,
    # so the shared exponent in mxint_quant_block spans exactly one block.
    block_tensor = tensor.reshape(-1, t1 // p1, p1, t0 // p0, p0)\
        .permute(0, 1, 3, 2, 4)\
        .reshape(-1, p1 * p0)

    # Direct vectorized quantization without loop
    qtensor = mxint_quant_block(block_tensor, **q_config)

    # Efficient shape restoration (inverse of the tiling permute above)
    return qtensor.reshape(-1, t1 // p1, t0 // p0, p1, p0)\
        .permute(0, 1, 3, 2, 4)\
        .reshape(original_shape)
+
def wrapped_mxint_linear_hardware(x, w, bias, in_features, out_features, config):
    """Convenience wrapper around `mxint_linear_hardware`.

    Builds the per-stream block configs (x/w/bias/out) from a flat `config`
    dict, runs the blockwise linear model, then dequantizes the resulting
    (mantissa, exponent) pair back into a float matrix.

    :param x: (mantissa, exponent) pair of the input activations
    :param w: (mantissa, exponent) pair of the weights
    :param bias: (mantissa, exponent) pair of the bias
    :param config: flat dict with *_width, *_exponent_width and
        *_parallelism entries for data_in / weight / bias / data_out
    :return: dequantized output of shape (n, out_features)

    NOTE(review): assumes the row count n of x is divisible by
    data_in_parallelism[0] — not checked here; verify at the call site.
    """
    # Number of input rows after flattening to (n, in_features).
    mx = x[0]
    n = mx.reshape(-1, in_features).shape[0]
    in_config = {
        "x_config": {
            "width": config["data_in_width"],
            "exponent_width": config["data_in_exponent_width"],
            "parallism_dim_0": config["data_in_parallelism"][1],
            "parallism_dim_1": config["data_in_parallelism"][0],
            "depth_dim_0": in_features // config["data_in_parallelism"][1],
            "depth_dim_1": n // config["data_in_parallelism"][0],
            "dim_0": in_features,
            "dim_1": n,
        },
        "w_config": {
            "width": config["weight_width"],
            "exponent_width": config["weight_exponent_width"],
            "parallism_dim_0": config["weight_parallelism"][1],
            "parallism_dim_1": config["weight_parallelism"][0],
            "depth_dim_0": in_features // config["weight_parallelism"][1],
            "depth_dim_1": out_features // config["weight_parallelism"][0],
            "dim_0": in_features,
            "dim_1": out_features,
        },
        "bias_config": {
            "width": config["bias_width"],
            "exponent_width": config["bias_exponent_width"],
            "parallism_dim_0": config["bias_parallelism"][1],
            "parallism_dim_1": 1,
            "depth_dim_0": out_features // config["bias_parallelism"][1],
            "depth_dim_1": 1,
            "dim_0": out_features,
            "dim_1": 1,
        },
        "out_config": {
            "width": config["data_out_width"],
            "exponent_width": config["data_out_exponent_width"],
            "parallism_dim_0": config["data_out_parallelism"][1],
            "parallism_dim_1": config["data_out_parallelism"][0],
            "depth_dim_0": out_features // config["data_out_parallelism"][1],
            "depth_dim_1": n // config["data_out_parallelism"][0],
            "dim_0": out_features,
            "dim_1": n,
        },
    }
    mout, eout = mxint_linear_hardware(x, w, bias, in_config)
    out_config = in_config["out_config"]
    # Re-tile the mantissa so each (depth_1, depth_0) block lines up with
    # its shared exponent in eout, then dequantize.
    reshaped_mout = mout.reshape(
        out_config["depth_dim_1"],
        out_config["parallism_dim_1"],
        out_config["depth_dim_0"],
        out_config["parallism_dim_0"],
    ).permute(0, 2, 1, 3)
    # Dequantize: mantissa * 2^(exp - (width-1)) per block.
    reshaped_out = reshaped_mout * 2 ** (
        eout[:, :, None, None] - config["data_out_width"] + 1
    )
    # Undo the tiling back to a flat (n, out_features) matrix.
    out = reshaped_out.reshape(
        out_config["depth_dim_1"],
        out_config["depth_dim_0"],
        out_config["parallism_dim_1"],
        out_config["parallism_dim_0"],
    ).permute(0, 2, 1, 3)
    out = out.reshape(out_config["dim_1"], out_config["dim_0"])

    return out
+
+
def mxint_linear_hardware(x, w, bias, config):
    """
    Bit-accurate software model of the MXInt blockwise linear layer.

    assume 2 dimensional input
    x, w, bias are (mantissa, exponent) pairs; mantissas are full matrices,
    exponents are one shared value per hardware block.

    config = {
        "x_config":{
            "width": ,
            "exponent_width" ,
            "parallism_dim_0",
            "parallism_dim_1",
            "depth_dim_0",
            "depth_dim_1",
            "dim_0",
            "dim_1",
        },
        "w_config": {
            ...
        },
        "bias_config": {
            ...
        },
        "out_config": {
            ...
        },
    }

    Returns (man_out, exp_out): man_out is the (dim_1 x dim_1_w) output
    mantissa matrix; exp_out holds one shared exponent per output block.
    """
    mx, ex = x
    mw, ew = w
    x_config = config["x_config"]
    w_config = config["w_config"]
    out_config = config["out_config"]
    from math import ceil, log2

    def DotProductCore(man_x, exp_x, man_y, exp_y):
        # Integer dot product of two mantissa blocks; block exponents add.
        return man_x @ man_y.transpose(0, 1), exp_x + exp_y

    def block_wise_reshape_tensor(x, x_config):
        # Tile a (dim_1, dim_0) matrix into a depth-major flat sequence of
        # (parallism_dim_1, parallism_dim_0) hardware blocks.
        reshaped_x = x.reshape(
            x_config["depth_dim_1"],
            x_config["parallism_dim_1"],
            x_config["depth_dim_0"],
            x_config["parallism_dim_0"],
        ).permute(0, 2, 1, 3)
        reshaped_x = reshaped_x.reshape(
            x_config["depth_dim_1"] * x_config["depth_dim_0"],
            x_config["parallism_dim_1"],
            x_config["parallism_dim_0"],
        )
        return reshaped_x

    # assume 2 dimensional input
    # The reduction (feature) dimension of x and w must tile identically.
    assert (
        x_config["depth_dim_0"] == w_config["depth_dim_0"]
    ), "need to check the setting of dim"
    assert (
        x_config["parallism_dim_0"] == w_config["parallism_dim_0"]
    ), "need to check the setting of dim"
    reshaped_ex = ex.reshape(-1)  # one shared exponent per block
    reshaped_mx = block_wise_reshape_tensor(mx, x_config)
    reshaped_ew = ew.reshape(-1)
    reshaped_mw = block_wise_reshape_tensor(mw, w_config)
    man_out = torch.zeros(
        x_config["depth_dim_1"],
        w_config["depth_dim_1"],
        x_config["parallism_dim_1"] * w_config["parallism_dim_1"],
    )
    exp_out = torch.zeros(x_config["depth_dim_1"], w_config["depth_dim_1"])
    # For each output block (i, j): accumulate partial dot products over the
    # shared reduction depth k, then (optionally) add bias and cast down.
    for i in range(x_config["depth_dim_1"]):
        for j in range(w_config["depth_dim_1"]):
            partial_man_out = torch.zeros(
                w_config["depth_dim_0"],
                x_config["parallism_dim_1"],
                w_config["parallism_dim_1"],
            )
            partial_exp_out = torch.zeros(w_config["depth_dim_0"])
            for k in range(x_config["depth_dim_0"]):
                mx_block = reshaped_mx[i * x_config["depth_dim_0"] + k]
                ex_block = reshaped_ex[i * x_config["depth_dim_0"] + k]
                mw_block = reshaped_mw[j * w_config["depth_dim_0"] + k]
                ew_block = reshaped_ew[j * w_config["depth_dim_0"] + k]
                partial_man_out[k], partial_exp_out[k] = DotProductCore(
                    mx_block, ex_block, mw_block, ew_block
                )
            # Exponent-aligned accumulation across the reduction depth.
            acc_man_out, acc_exp_out = MxIntAccumulator(
                partial_man_out.reshape(w_config["depth_dim_0"], -1), partial_exp_out
            )
            # bias is a (mantissa, exponent) tuple; `!= None` is True for any
            # tuple, so this only skips the literal bias=None case.
            if bias != None:
                bias_config = config["bias_config"]
                mbias, ebias = bias
                reshaped_mbias = mbias.reshape(
                    w_config["depth_dim_1"], w_config["parallism_dim_1"]
                )
                reshaped_ebias = ebias.reshape(w_config["depth_dim_1"])
                # Align the bias mantissa to the accumulator's fixed-point
                # position (x_width + w_width - 2 fractional bits).
                shifted_value = (
                    reshaped_ebias[j]
                    - acc_exp_out
                    + x_config["width"]
                    + w_config["width"]
                    - 2
                    - (bias_config["width"] - 1)
                )
                # Same bias block is broadcast across the x rows of the block.
                shifted_bias = reshaped_mbias[j].repeat(
                    x_config["parallism_dim_1"]
                ) * 2 ** (shifted_value)
                acc_man_out = shifted_bias + acc_man_out
            # Cast the wide accumulator down to the output format; the input
            # width accounts for dot-product growth (log2 of reduction dim).
            man_out[i][j], exp_out[i][j] = MxIntCast(
                acc_man_out,
                acc_exp_out,
                {
                    "in_width": x_config["width"]
                    + w_config["width"]
                    + ceil(log2(x_config["dim_0"])),
                    "in_frac_width": x_config["width"] + w_config["width"] - 2,
                    "in_exponent_width": max(
                        x_config["exponent_width"], w_config["exponent_width"]
                    )
                    + 1,
                    "out_width": out_config["width"],
                    "out_exponent_width": out_config["exponent_width"],
                },
            )
    # Undo the block tiling to produce a plain (dim_1_x, dim_1_w) matrix.
    man_out = (
        man_out.reshape(
            x_config["depth_dim_1"],
            w_config["depth_dim_1"],
            x_config["parallism_dim_1"],
            w_config["parallism_dim_1"],
        )
        .permute(0, 2, 1, 3)
        .reshape(x_config["dim_1"], w_config["dim_1"])
    )
    return man_out, exp_out
+
+
def MXIntMatmulHardware(man_x, exp_x, man_y, exp_y, x_config, y_config, out_config):
    """
    Bit-accurate software model of the MXInt blockwise matmul (x @ y).

    assume 2 dimensional input
    config = {
        "width": ,
        "exponent_width" ,
        "parallism_dim_0",
        "parallism_dim_1",
        "depth_dim_0",
        "depth_dim_1",
        "dim_0",
        "dim_1",
    }
    man.shape = [dim_1 * dim_0]
    exp.shape = [depth_dim_1, depth_dim_0]

    Returns (man_out, exp_out): output mantissa matrix and one shared
    exponent per output block.

    NOTE(review): the final reshape uses x_config["parallism_dim_0"] for the
    block's column parallelism, but each block was produced with
    y_config["parallism_dim_0"] columns — this is only consistent when the
    two are equal; verify against the RTL.
    """
    from math import ceil, log2

    def MatmulCore(man_x, exp_x, man_y, exp_y):
        # Integer matmul of two mantissa blocks; block exponents add.
        return man_x @ man_y, exp_x + exp_y

    # assume 2 dimensional input
    # The reduction dimension of x must tile like the row dimension of y.
    assert (
        x_config["depth_dim_0"] == y_config["depth_dim_1"]
    ), "need to check the setting of dim"

    def block_wise_reshape_tensor(x, x_config):
        # Tile a (dim_1, dim_0) matrix into a depth-major flat sequence of
        # (parallism_dim_1, parallism_dim_0) hardware blocks.
        reshaped_x = x.reshape(
            x_config["depth_dim_1"],
            x_config["parallism_dim_1"],
            x_config["depth_dim_0"],
            x_config["parallism_dim_0"],
        ).permute(0, 2, 1, 3)
        reshaped_x = reshaped_x.reshape(
            x_config["depth_dim_1"] * x_config["depth_dim_0"],
            x_config["parallism_dim_1"],
            x_config["parallism_dim_0"],
        )
        return reshaped_x

    reshaped_exp_x = exp_x.reshape(-1)  # one shared exponent per block
    reshaped_man_x = block_wise_reshape_tensor(man_x, x_config)
    reshaped_exp_y = exp_y.reshape(-1)
    reshaped_man_y = block_wise_reshape_tensor(man_y, y_config)
    man_out = torch.zeros(
        x_config["depth_dim_1"],
        y_config["depth_dim_0"],
        x_config["parallism_dim_1"] * y_config["parallism_dim_0"],
    )
    exp_out = torch.zeros(x_config["depth_dim_1"], y_config["depth_dim_0"])
    # For each output block (i, j): accumulate partial products along the
    # shared reduction depth k, then cast down to the output format.
    for i in range(x_config["depth_dim_1"]):
        for j in range(y_config["depth_dim_0"]):
            partial_man_out = torch.zeros(
                y_config["depth_dim_1"],
                x_config["parallism_dim_1"],
                y_config["parallism_dim_0"],
            )
            partial_exp_out = torch.zeros(y_config["depth_dim_1"])
            for k in range(y_config["depth_dim_1"]):
                man_x_block = reshaped_man_x[i * x_config["depth_dim_0"] + k]
                exp_x_block = reshaped_exp_x[i * x_config["depth_dim_0"] + k]
                man_y_block = reshaped_man_y[k * y_config["depth_dim_0"] + j]
                exp_y_block = reshaped_exp_y[k * y_config["depth_dim_0"] + j]
                partial_man_out[k], partial_exp_out[k] = MatmulCore(
                    man_x_block, exp_x_block, man_y_block, exp_y_block
                )
            # Exponent-aligned accumulation across the reduction depth.
            acc_man_out, acc_exp_out = MxIntAccumulator(
                partial_man_out.reshape(y_config["depth_dim_1"], -1), partial_exp_out
            )
            # Cast the wide accumulator to the output format; input width
            # accounts for product growth (log2 of the reduction dim).
            man_out[i][j], exp_out[i][j] = MxIntCast(
                acc_man_out,
                acc_exp_out,
                {
                    "in_width": x_config["width"]
                    + y_config["width"]
                    + ceil(log2(x_config["dim_0"])),
                    "in_frac_width": x_config["width"] + y_config["width"] - 2,
                    "in_exponent_width": max(
                        x_config["exponent_width"], y_config["exponent_width"]
                    )
                    + 1,
                    "out_width": out_config["width"],
                    "out_exponent_width": out_config["exponent_width"],
                },
            )
    # Undo the block tiling to produce a plain (dim_1_x, dim_0_y) matrix.
    man_out = (
        man_out.reshape(
            x_config["depth_dim_1"],
            y_config["depth_dim_0"],
            x_config["parallism_dim_1"],
            x_config["parallism_dim_0"],
        )
        .permute(0, 2, 1, 3)
        .reshape(x_config["dim_1"], y_config["dim_0"])
    )
    return man_out, exp_out
+
+
def MxIntCast(man_in, exp_in, param):
    """Cast an MXInt (mantissa, exponent) block to a narrower format.

    Picks the output shared exponent from the magnitude of the widest
    mantissa, then right-shifts (floor division, modelling the hardware
    shifter) and saturates the mantissas into the output width.

    :param man_in: input mantissa tensor (one block)
    :param exp_in: shared input exponent (scalar tensor)
    :param param: dict with in_width / in_frac_width / in_exponent_width
        (documented input format) and out_width / out_exponent_width
    :return: (out_man, out_exp) in the output format
    """
    out_width = param["out_width"]
    out_exponent_width = param["out_exponent_width"]
    # Input-format fields are part of the param schema; the arithmetic below
    # only needs in_frac_width.
    in_width = param["in_width"]
    in_frac_width = param["in_frac_width"]
    in_exponent_width = param["in_exponent_width"]

    out_exponent_max = 2 ** (out_exponent_width - 1) - 1
    out_exponent_min = -(2 ** (out_exponent_width - 1))

    out_min = -(2 ** (out_width - 1))
    out_max = 2 ** (out_width - 1) - 1

    # Position of the mantissa's leading bit; the 1e-3 offset guards
    # log2(0) for all-zero blocks.  (A second, unused computation of this
    # value was removed.)
    lma_in = torch.ceil(torch.log2(man_in.abs().max() + 1e-3))
    # New shared exponent, rebased by the input fractional width and
    # saturated to the representable output range.
    out_exp_full = lma_in + exp_in - in_frac_width
    out_exp = torch.clamp(out_exp_full, out_exponent_min, out_exponent_max)
    # Shift mantissas to the new exponent and saturate to the output width.
    out_man = man_in // 2 ** (in_frac_width - exp_in + out_exp - (out_width - 1))
    out_man = torch.clamp(out_man, out_min, out_max)

    return out_man, out_exp
+
def MxIntAccumulator(man, exp):
    """Exponent-aligned accumulation of a depth-wise sequence of MXInt blocks.

    Walks the depth dimension keeping a running maximum exponent; whenever a
    larger exponent arrives, the accumulator is right-shifted (floor
    division, modelling the hardware shifter) before the newly aligned
    mantissa block is added.

    :param man: mantissa blocks of shape (depth, block_size)
    :param exp: shared exponents, one per depth step
    :return: (accumulated mantissas of shape (block_size,), final exponent)
    """
    depth, block_size = man.shape[0], man.shape[1]
    acc = torch.zeros(block_size)
    acc_exp = torch.Tensor([float("-inf")])
    running_max = torch.Tensor([float("-inf")])
    for step in range(depth):
        # Running maximum exponent seen so far.
        if exp[step] > running_max:
            running_max = exp[step]
        # Re-align the accumulator to the (possibly new) maximum exponent.
        acc = acc // 2 ** (running_max - acc_exp)
        acc_exp = running_max
        # Align this block's mantissas and add them in.
        acc = acc + man[step] // 2 ** (running_max - exp[step])

    return acc, acc_exp
+
def quantized_range_reduction(mx, ex, in_man_width, data_out_n_width):
    """Vectorized range reduction: decompose x ~= n*ln(2) + r.

    Operates on an MXInt (mantissa, exponent) input and returns
    (mr, n) where `mr` is the residual r as a fixed-point value with 7
    fractional bits and `n` is a signed integer of `data_out_n_width` bits.

    NOTE(review): relies on `mxint_quantize`, which is neither defined nor
    imported in this module — confirm the intended provider (it returns a
    (quantized, mantissa, exponent) triple, unlike `mxint_quant_block`
    which returns only the dequantized tensor, so they are not
    interchangeable).
    """
    import math  # local import: only needed for the constant e

    def hardware_round(mx, ex, in_man_frac_width, data_out_width):
        # Saturating floor-round of a fixed-point value to a signed integer.
        round_max = 2**(data_out_width-1) - 1
        round_min = -2**(data_out_width-1)
        round_x = mx.reshape(-1) // 2**((in_man_frac_width-ex).reshape(-1))
        return torch.clamp(round_x, round_min, round_max)

    # Quantize the constants log2(e) and ln(2) to 8-bit mantissa / 4-bit
    # exponent, mirroring the coefficient ROM in hardware.
    coefficient_quant_block = partial(
        mxint_quantize,
        width=8,
        exponent_width=4)
    _, mlog2_e, elog2_e = coefficient_quant_block(torch.log2(torch.tensor(math.e)))
    _, mln_2, eln_2 = coefficient_quant_block(torch.log(torch.tensor(2.0)))

    # n = round(x * log2(e)); the +7 accounts for the coefficient's
    # 7 fractional mantissa bits.
    n = hardware_round(mx * mlog2_e, ex + elog2_e, (in_man_width - 1 + 7), data_out_n_width)
    # r = x - n*ln(2), computed on mantissas aligned to ln(2)'s exponent.
    _mx = n * mln_2
    _ex = eln_2
    shifted_mx = mx // 2**(_ex - ex + (in_man_width - 1) - 7)
    mr = shifted_mx - _mx
    # return mr as an fixedpoint ?.7 we can make it 2.7
    # return n as an integer number with width = data_out_width
    return mr, n
+
def fixed_exp(fr):
    """Fixed-point exponential approximation (third-order polynomial).

    `fr` is a fixed-point residual with 7 fractional bits; the result is a
    fixed-point value with the same scaling.  All divisions are floor
    divisions, matching the hardware shifters.
    """
    frac_width = 7
    one = 1 * 2 ** frac_width                     # constant 1.0 in Q.7
    quadratic = fr ** 2 // 2 ** (frac_width + 1)  # fr^2 / 2
    cubic = fr ** 3 * 5 // 2 ** (frac_width + 4)  # ~fr^3 / 3.2
    return one + fr + quadratic + cubic
+
def mxint_softmax(x, q_config):
    """Software model of an MXInt softmax over the first dimension of `x`.

    Per element: quantize, range-reduce (x = n*ln2 + r), approximate exp(r)
    with `fixed_exp`, accumulate exp terms with `MxIntAccumulator`, then
    normalize and re-quantize to the output format.

    Returns (mout, eout): per-element output mantissas and exponents.

    NOTE(review): `mxint_quantize` is neither defined nor imported in this
    module — confirm the intended provider.
    NOTE(review): `extended_mexps // mexps` divides each numerator mantissa
    by itself (a constant 2**frac_width for non-zero mantissas); dividing
    by `m_sum` looks like the intent — verify against the RTL.
    """
    # fixed_r, integer_n
    in_man_width = q_config["in_man_width"]
    in_exp_width = q_config["in_exp_width"]
    data_out_n_width = q_config["data_out_n_width"]
    data_out_man_width = q_config["data_out_man_width"]
    data_out_frac_width = data_out_man_width - 1
    data_out_exp_width = q_config["data_out_exp_width"]

    shape = x.shape[0]
    mout = torch.zeros_like(x)
    eout = torch.zeros_like(x)

    # exp(x[i]) in MXInt form: mantissa from fixed_exp(r), exponent = n.
    list_of_mexps = []
    list_of_eexps = []
    for i in range(shape):
        _, mx, ex = mxint_quantize(x[i], in_man_width, in_exp_width)
        fixed_r, integer_n = quantized_range_reduction(mx, ex, in_man_width, data_out_n_width)
        # fixed_r will be 2.7 bits, integer_n will be data_out_n_width bits
        mexp = fixed_exp(fixed_r)
        eexp = integer_n
        # currently we got mexp ?.7 bits, integer_n data_out_n_width bits
        list_of_mexps.append(mexp)
        list_of_eexps.append(eexp)
    eexps = torch.stack(list_of_eexps)
    mexps = torch.stack(list_of_mexps)
    # Denominator: exponent-aligned sum of all exp terms.
    m_sum, e_sum = MxIntAccumulator(torch.stack(list_of_mexps), torch.stack(list_of_eexps))
    # Normalize each exp term by the sum (see NOTE(review) above).
    extended_mexps = mexps * 2**(data_out_frac_width)
    pre_cast_mout = extended_mexps // mexps
    pre_cast_eout = eexps - e_sum
    pre_cast_out = pre_cast_mout * 2**(pre_cast_eout - 7)
    # Re-quantize each normalized value to the output MXInt format.
    for i in range(shape):
        _, mout[i], eout[i] = mxint_quantize(pre_cast_out[i], data_out_man_width, data_out_exp_width)
    return mout, eout
diff --git a/a_cx_mxint_quant/module_level_tranform.py b/a_cx_mxint_quant/module_level_tranform.py
new file mode 100644
index 000000000..9d82a5087
--- /dev/null
+++ b/a_cx_mxint_quant/module_level_tranform.py
@@ -0,0 +1,147 @@
+import torch.nn as nn
+import chop as chop
+from chop.tools import get_logger
+from chop.tools.logger import set_logging_verbosity
+
+from .attention_head import _ViTSelfAttentionHeadBase
+from .attention import MXIntAttention
+from chop.models.vision.vit.vit import Attention
+
+from .linear import MXIntLinear
+# from .layer_norm import MXIntLayerNorm
+# from .gelu import MXIntGELU
+import torch
+
+
+logger = get_logger(__name__)
+set_logging_verbosity("debug")
+class MXIntGELU(nn.Module):
+    """Placeholder for an MXInt-quantized GELU.
+
+    NOTE(review): forward is currently an identity pass-through — no GELU
+    non-linearity is applied in software. `q_config` is stored but never
+    read here; presumably the real GELU lives in the hardware mapping —
+    confirm before using this model for accuracy evaluation.
+    """
+    def __init__(self, q_config = {}):
+        super().__init__()
+        # Quantization settings; kept for downstream passes, unused in forward.
+        self.q_config = q_config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Identity: returns the input unchanged.
+        out = x
+        return out
+
+class MXIntLayerNorm(nn.LayerNorm):
+    """LayerNorm subclass that carries an MXInt quantization config.
+
+    NOTE(review): forward runs plain floating-point layer_norm; `q_config`
+    is stored but never read here, so no quantization is applied in
+    software. Also note the defaults `elementwise_affine=False` and
+    `bias=False` differ from `nn.LayerNorm`'s defaults of True.
+    """
+    def __init__(
+        self,
+        normalized_shape,
+        eps: float = 0.00001,
+        elementwise_affine: bool = False,
+        bias: bool = False,
+        q_config=None,
+    ) -> None:
+        # Plain-attribute assignment before nn.Module init is safe for
+        # non-Parameter values.
+        self.q_config = q_config
+        super().__init__(normalized_shape, eps, elementwise_affine, bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Exact (unquantized) layer norm over `normalized_shape`.
+        return torch.nn.functional.layer_norm(
+            x,
+            self.normalized_shape,
+            self.weight,
+            self.bias,
+            self.eps,
+        )
+
+def vit_module_level_quantize(model, model_config = {}, q_config = {}):
+    """Swap supported submodules of a ViT `model` for MXInt variants in place.
+
+    Attention, LayerNorm, Linear, and GELU modules are replaced according to
+    `q_config`, matched either "by name" or "by type"; `model_config`
+    supplies the attention dimensions ("dim", "num_heads"). Returns the
+    mutated model. The `{}` defaults are never mutated here.
+    """
+    def parse_q_config(module, q_config):
+        # Returns (skip, config): skip=True leaves the module unquantized.
+        if q_config.get("by") == "name":
+            if module[0] in q_config:
+                return False, q_config[module[0]]["config"]
+            else:
+                return True, None
+        elif q_config.get("by") == "type":
+            module_name = module[1].__class__.__name__
+            if "Linear" in module_name:
+                if any("linear" in key for key in q_config.keys()):
+                    if "linear1" in module[0] and "linear1" in q_config:
+                        return False, q_config["linear1"]["config"]
+                    elif "linear2" in module[0] and "linear2" in q_config:
+                        return False, q_config["linear2"]["config"]
+                    else:
+                        # Fallback bucket; assumes a "linear" entry exists
+                        # whenever any "linear*" key is present — TODO confirm.
+                        return False, q_config["linear"]["config"]
+                else:
+                    return True, None
+            elif "layer_norm" in q_config and "LayerNorm" in module_name:
+                return False, q_config["layer_norm"]["config"]
+            elif "attention" in q_config and "Attention" in module_name:
+                return False, q_config["attention"]["config"]
+            elif "gelu" in q_config and "GELU" in module_name:
+                return False, q_config["gelu"]["config"]
+            else:
+                return True, None
+        else:
+            raise ValueError(f"Invalid q_config: {q_config}")
+
+    from chop.passes.graph.utils import deepsetattr
+    for module in model.named_modules():
+        skip, config = parse_q_config(module, q_config)
+        if skip:
+            continue
+        if isinstance(module[1], Attention):
+            ori_module = module[1]
+            new_module = MXIntAttention(
+                model_config["dim"],
+                model_config["num_heads"],
+                qkv_bias=True,
+                q_config=config,
+            )
+            logger.info(f"Replacing module: {module[0]}")
+            dim = ori_module.head_dim * ori_module.num_heads
+
+            # Split the fused qkv projection into separate q/k/v linears.
+            qkv_weight = ori_module.qkv.weight.reshape(3, dim, dim)
+            new_module.query.weight = nn.Parameter(qkv_weight[0])
+            new_module.key.weight = nn.Parameter(qkv_weight[1])
+            new_module.value.weight = nn.Parameter(qkv_weight[2])
+
+            # FIX: identity comparison instead of "== None" on a Parameter.
+            has_bias = ori_module.qkv.bias is not None
+            if has_bias:
+                qkv_bias = ori_module.qkv.bias.reshape(3, 1, dim)
+                new_module.query.bias = nn.Parameter(qkv_bias[0])
+                new_module.key.bias = nn.Parameter(qkv_bias[1])
+                new_module.value.bias = nn.Parameter(qkv_bias[2])
+
+            new_module.proj.weight = ori_module.proj.weight
+            new_module.proj.bias = ori_module.proj.bias
+            deepsetattr(model, module[0], new_module)
+        elif isinstance(module[1], nn.LayerNorm):
+            ori_module = module[1]
+            # BUG FIX: `bias` was only assigned when ori_module.bias was not
+            # None, raising NameError for bias-free LayerNorms.
+            bias = ori_module.bias is not None
+            new_module = MXIntLayerNorm(
+                ori_module.normalized_shape,
+                eps=ori_module.eps,
+                elementwise_affine=ori_module.elementwise_affine,
+                bias=bias,
+                q_config=config,
+            )
+            new_module.weight = ori_module.weight
+            new_module.bias = ori_module.bias
+            logger.info(f"Replacing module: {module[0]}")
+
+            deepsetattr(model, module[0], new_module)
+        elif isinstance(module[1], nn.Linear) or isinstance(module[1], MXIntLinear):
+            # Attention-internal linears are handled by MXIntAttention above;
+            # the classifier head stays in floating point.
+            if "attention" in module[0]:
+                continue
+            if module[0] == "head":
+                continue
+            ori_module = module[1]
+            new_module = MXIntLinear(
+                ori_module.in_features,
+                ori_module.out_features,
+                q_config=config,
+            )
+            new_module.weight = ori_module.weight
+            new_module.bias = ori_module.bias
+            logger.info(f"Replacing linear module: {module[0]}")
+            deepsetattr(model, module[0], new_module)
+        elif isinstance(module[1], nn.GELU):
+            ori_module = module[1]
+            new_module = MXIntGELU(
+                q_config=config,
+            )
+            logger.info(f"Replacing module: {module[0]}")
+            deepsetattr(model, module[0], new_module)
+    return model
\ No newline at end of file
diff --git a/a_cx_mxint_quant/modules.py b/a_cx_mxint_quant/modules.py
new file mode 100644
index 000000000..be63a5299
--- /dev/null
+++ b/a_cx_mxint_quant/modules.py
@@ -0,0 +1,67 @@
+
+import torch.nn as nn
+
+from chop.nn.quantized.modules.attention import _ViTAttentionBase
+
+import chop as chop
+from chop.tools import get_logger
+from chop.tools.logger import set_logging_verbosity
+
+logger = get_logger(__name__)
+set_logging_verbosity("debug")
+from chop.models.vision.vit.vit import Attention
+import torch
+from mase_components.linear_layers.mxint_operators.test.utils import MXIntLinearHardware
+class MXIntPatchEmbed(nn.Module):
+    """DeiT-style patch embedding with class and distillation tokens.
+
+    A strided Conv2d cuts the image into non-overlapping `patch_size`
+    patches and projects each to `embed_dim`; two learnable tokens are
+    prepended to the resulting patch sequence.
+    NOTE(review): `norm_layer` is accepted but unused (the norm is
+    commented out), and `q_config` is stored but not read in this module.
+    """
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_dim: int,
+        q_config: dict = None,
+        norm_layer: nn.Module = nn.LayerNorm
+    ) -> None:
+        super().__init__()
+        self.q_config = q_config
+        # Patchify: kernel == stride == patch_size (non-overlapping).
+        self.conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        # self.norm = norm_layer(embed_dim)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = (img_size // patch_size) ** 2
+        # Learnable tokens, broadcast over the batch in forward().
+        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+        self.distill_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+
+    def forward(self, x):
+        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps)
+        x = self.conv(x)
+        # Flatten spatial dims, channels last: (B, num_patches, embed_dim)
+        x = x.flatten(2).transpose(1, 2)
+        # x = self.norm(x)
+        # Prepend [cls] and [distill] tokens: (B, num_patches + 2, embed_dim)
+        x = torch.cat((self.cls_token.expand(x.size(0), -1, -1), self.distill_token.expand(x.size(0), -1, -1), x), dim=1)
+        return x
+
+class ViTAttentionMxInt(_ViTAttentionBase):
+    """ViT attention that carries an MXInt quantization config.
+
+    Behaviour is inherited unchanged from `_ViTAttentionBase`
+    (floating point); `q_config` is stored but not read here.
+    """
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+        q_config: dict = None,
+    ) -> None:
+        # norm_layer is accepted for signature compatibility but is not
+        # forwarded to the base class.
+        super().__init__(dim, num_heads, qkv_bias, qk_norm, attn_drop, proj_drop)
+        self.q_config = q_config
+
+
+class MXIntAddition(nn.Module):
+    """Elementwise addition module that carries an MXInt config.
+
+    `q_config` is stored but not read here; `forward` performs an exact
+    (unquantized) tensor addition.
+    """
+
+    def __init__(
+        self,
+        q_config,
+    ) -> None:
+        super().__init__()
+        self.q_config = q_config
+
+    def forward(self, x, y):
+        # Exact residual addition — no quantization applied in software.
+        total = x + y
+        return total
+
diff --git a/src/mase_components/vision_models/vit/rtl/__init__.py b/a_cx_mxint_quant/mxint_cast.drawio
similarity index 100%
rename from src/mase_components/vision_models/vit/rtl/__init__.py
rename to a_cx_mxint_quant/mxint_cast.drawio
diff --git a/a_cx_mxint_quant/quantizers.py b/a_cx_mxint_quant/quantizers.py
new file mode 100644
index 000000000..939b3c76f
--- /dev/null
+++ b/a_cx_mxint_quant/quantizers.py
@@ -0,0 +1,73 @@
+import torch
+from functools import partial
+import torch.nn.functional as F
+from torch import Tensor
+from chop.nn.quantized.modules.linear import _LinearBase
+from .utils import reshape_to_block, reshape_back
+
+def mxint_quant_block(
+    x, width: int = 12, exponent_width: int = 6, exponent: int = None, round_bits: int = 4,
+):
+    """
+    Quantize `x` to MXInt: integer mantissas sharing one exponent per block,
+    where a block is the last dimension of `x`.
+
+    - Idea from https://arxiv.org/pdf/2310.10537
+    - Converts IEEE FP32/64 to integers with a shared scale
+    - Unlike formats with per-element exponents, the shared scale has no
+      NaN representation
+    ---
+    - `width`: number of mantissa bits + 1 (the sign bit)
+    - `exponent_width`: number of exponent bits, shared over a block
+    - `exponent`: NOTE(review) unused — the parameter is overwritten below;
+      kept only for interface compatibility
+    - `round_bits`: extra fractional bits kept while rounding (floor at
+      width+round_bits precision, then round-to-nearest), emulating the
+      hardware rounding pipeline
+
+    Returns (dequantized x, integer mantissa, shared exponent).
+    """
+    exponent_bias = 2 ** (exponent_width - 1)
+    exponent_max = 2**exponent_width - 1 - exponent_bias
+    exponent_min = -exponent_bias
+
+    # Shared exponent from each block's max magnitude (tiny avoids log2(0)).
+    abs_max = x.abs().max(dim=-1, keepdim=True).values
+    log2 = torch.log2(abs_max + torch.finfo(torch.float32).tiny)
+
+    exponent = torch.ceil(log2)
+    # Exact powers of two get one extra bit of headroom so the scaled
+    # mantissa stays strictly below 2**(width-1).
+    exponent[exponent == log2] += 1
+    exponent = torch.clamp(exponent, exponent_min, exponent_max)
+
+    # Scale into the signed integer mantissa range.
+    int_min = -(2 ** (width - 1))
+    int_max = 2 ** (width - 1) - 1
+    mantissa = x * (2 ** (width - 1)) / 2**exponent
+    # Floor at round_bits extra fractional bits, then round to nearest.
+    mantissa = mantissa * 2 ** round_bits
+    mantissa = torch.floor(mantissa)
+    mantissa = mantissa / 2 ** round_bits
+    mantissa = torch.round(mantissa)
+    mantissa = torch.clamp(mantissa, int_min, int_max)
+    # Dequantized software-reference value.
+    q_x = (2**exponent) * mantissa /(2 ** (width - 1))
+    return q_x, mantissa, exponent
+
+def mxint_hardware(tensor, q_config, parallelism):
+    """
+    Hardware-aware MXInt quantization: the last two dims are tiled into
+    (p1, p0) blocks and each block shares one exponent.
+
+    - `q_config`: kwargs forwarded to `mxint_quant_block` (width,
+      exponent_width, ...)
+    - `parallelism`: [p1, p0] hardware tile size; a 1-element list is
+      treated as [1, p0]
+
+    Returns (dequantized tensor, mantissa) shaped like the (possibly
+    2-D-promoted) input, and the per-block exponent with shape
+    (..., t1/p1, t0/p0).
+    """
+
+    # Promote 1-D tensor / parallelism to 2-D form.
+    if len(tensor.shape) == 1:
+        tensor = tensor.unsqueeze(0)
+    if len(parallelism) == 1:
+        parallelism = [1, parallelism[0]]
+
+    p1, p0 = parallelism
+    t1, t0 = tensor.shape[-2:]
+
+    original_mshape = tensor.shape
+    # One exponent per (p1, p0) tile.
+    original_eshape = torch.Size([t1//p1, t0//p0]) if len(tensor.shape) <=2 else torch.Size([*tensor.shape[:-2],t1//p1, t0//p0])
+    assert (t1 % p1 == 0 and t0 % p0 == 0), \
+        f"Block size mismatch: t1={t1}, p1={p1}, t0={t0}, p0={p0}"
+
+    # Gather each (p1, p0) tile into one row so the block quantizer shares
+    # an exponent along the last dimension.
+    block_tensor = reshape_to_block(tensor, t1, t0, p1, p0).reshape(-1, p1*p0)
+    qtensor, mantissa, exponent = mxint_quant_block(block_tensor, **q_config)
+
+    # Restore the original layout.
+    qtensor = reshape_back(qtensor, t1, t0, p1, p0)
+    mantissa = reshape_back(mantissa, t1, t0, p1, p0)
+    qtensor = qtensor.reshape(original_mshape)
+    mantissa = mantissa.reshape(original_mshape)
+    exponent = exponent.reshape(original_eshape)
+    return qtensor, mantissa, exponent
\ No newline at end of file
diff --git a/a_cx_mxint_quant/softmax.drawio b/a_cx_mxint_quant/softmax.drawio
new file mode 100644
index 000000000..152dce204
--- /dev/null
+++ b/a_cx_mxint_quant/softmax.drawio
@@ -0,0 +1,143 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/a_cx_mxint_quant/softmax.py b/a_cx_mxint_quant/softmax.py
new file mode 100644
index 000000000..c502fd221
--- /dev/null
+++ b/a_cx_mxint_quant/softmax.py
@@ -0,0 +1,117 @@
+# models.py
+import torch
+import torch.nn as nn
+import math
+from typing import List, Union, Optional
+from pathlib import Path
+import torch
+import torch.nn as nn
+from torch import Tensor
+import math
+from typing import Literal, Optional, Tuple, Union, Dict
+from enum import Enum
+from .quantizers import mxint_quant_block, mxint_hardware
+from chop.nn.quantizers.integer import integer_quantizer, integer_floor_quantizer
+from functools import partial
+from tqdm import tqdm
+
+class MXIntHardwareExp(nn.Module):
+    """Hardware model of exp(x) in MXInt arithmetic.
+
+    Uses e**x = 2**(x*log2(e)) = 2**r * 2**n with integer n and fractional
+    remainder r, so only 2**r needs to be approximated.
+    """
+    def __init__(self, q_config: Dict = {}):
+        super().__init__()
+        self.q_config = q_config
+
+    def hardware_range_reduction(self, qx, data_r_width, data_n_width) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Range reduction for exp: x * log2(e) = n + r.
+        Returns (r, n): remainder r in [0, 1) and integer power n.
+        """
+        # log2(e) is itself stored as an 8/4-bit MXInt constant, as the
+        # hardware would.
+        coefficient_quant_block = partial(
+            mxint_quant_block,
+            width=8,
+            exponent_width=4
+        )
+        self.log2_e, _, _ = coefficient_quant_block(torch.log2(torch.tensor(math.e)))
+        new_mx = qx * self.log2_e
+        # Floor-quantize the product to (n + r) fixed-point precision.
+        new_mx = integer_floor_quantizer(new_mx, data_n_width + data_r_width - 1, data_r_width - 1)
+        n = new_mx.floor()
+        r = new_mx - n
+        return r, n
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Quantize the input to MXInt first.
+        qx, mx, ex = mxint_hardware(x,
+            {
+                'width': self.q_config.get('data_in_width'),
+                'exponent_width': self.q_config.get('data_in_exponent_width')
+            },
+            parallelism=[1,1])
+
+        mr, n = self.hardware_range_reduction(qx, self.q_config.get('data_r_width'), self.q_config.get('data_out_exponent_width'))
+        # Mantissa part 2**r lies in [1, 2); quantize to data_out_width bits
+        # with (data_out_width - 2) fractional bits, then rescale to an
+        # integer mantissa.
+        mexp = 2 ** mr
+        mexp = integer_quantizer(mexp, self.q_config.get('data_out_width'), self.q_config.get('data_out_width') - 2)
+        mexp = mexp * 2 ** (self.q_config.get('data_out_width') - 2)
+        eexp = n
+        # Dequantized reference: mexp * 2**n back at the fractional scale.
+        qexp = mexp * 2 ** eexp / 2 ** (self.q_config.get('data_out_width') - 2)
+
+        # NOTE(review): despite the declared -> torch.Tensor annotation,
+        # this returns a 3-tuple.
+        return qexp, mexp, eexp
+
+from tqdm import tqdm
+# CX: Set a new search
+# accumulator depth should be in the first dimension
+def mxint_accumulator(mx, ex):
+    """Sum block-floating-point values along dim 0.
+
+    `mx` holds stacked mantissas and `ex` their matching exponents. Running
+    terms are re-aligned to the largest exponent seen so far using floor
+    shifts (integer //), mirroring hardware truncation.
+    Returns (accumulated mantissa, final shared exponent).
+    """
+    acc_m = mx[0]
+    acc_e = ex[0]
+    for idx in range(1, mx.shape[0]):
+        prev_e = acc_e
+        acc_e = torch.max(acc_e, ex[idx])
+        # Align both operands to the new shared exponent before adding.
+        aligned_acc = acc_m // 2**(acc_e - prev_e)
+        aligned_new = mx[idx] // 2**(acc_e - ex[idx])
+        acc_m = aligned_acc + aligned_new
+    return acc_m, acc_e
+
+
+class MXIntSoftmax(nn.Module):
+    """Hardware model of softmax in MXInt arithmetic.
+
+    Pipeline: per-element exp (MXIntHardwareExp) -> block-floating-point
+    accumulation of the exponentials -> integer division to normalise.
+    NOTE(review): the `dim` argument is accepted but never used; the
+    reduction always runs over dim 1 (via transpose(1, 0) below) —
+    confirm this matches the hardware.
+    """
+    def __init__(self,q_config: Dict = {}):
+        super().__init__()
+        self.q_config = q_config
+        self.exp_module = MXIntHardwareExp(q_config=q_config)
+
+    def forward(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        def exp(self, x):
+            # Elementwise exponential in (q, mantissa, exponent) form.
+            qexp, mexp, eexp = self.exp_module(x)
+            return qexp, mexp, eexp
+
+        def exp_sum(self, qexp, mexp, eexp):
+            # Softmax denominator; extra underflow bits preserve precision
+            # during accumulation.
+            exp_sum_underflow_bits = self.q_config["exp_sum_underflow_bits"]
+            mexp = (mexp) * 2**exp_sum_underflow_bits
+
+            # Move the reduction dim to the front for the accumulator.
+            mexp= mexp.transpose(1,0)
+            eexp = eexp.transpose(1,0)
+            mexp_sum, eexp_sum = mxint_accumulator(mexp, eexp)
+            qexp_sum = mexp_sum * 2**eexp_sum / 2**exp_sum_underflow_bits
+            return qexp_sum, mexp_sum, eexp_sum
+
+        def division(self, qexp, mexp, eexp, qexp_sum, mexp_sum, eexp_sum):
+            # Integer division of each exponential by the sum, followed by
+            # a final MXInt cast to the output precision.
+            division_underflow_bits = self.q_config["division_underflow_bits"]
+            exp_sum_underflow_bits = self.q_config["exp_sum_underflow_bits"]
+            mout = mexp * 2**(division_underflow_bits+exp_sum_underflow_bits) // mexp_sum
+            eout = eexp - eexp_sum
+            qout = mout * 2**eout / 2**division_underflow_bits
+
+            qout, _, _ = mxint_hardware(
+                qout,
+                q_config = {
+                    "width": self.q_config["data_width"],
+                    "exponent_width": self.q_config["data_exponent_width"],
+                },
+                parallelism = [1,1]
+            )
+
+            return qout, mout, eout
+
+        qexp, mexp, eexp = exp(self, x)
+        qexp_sum, mexp_sum, eexp_sum = exp_sum(self, qexp, mexp, eexp)
+        qout, mout, eout = division(self, qexp, mexp, eexp, qexp_sum, mexp_sum, eexp_sum)
+
+        return qout
\ No newline at end of file
diff --git a/a_cx_mxint_quant/utils.py b/a_cx_mxint_quant/utils.py
new file mode 100644
index 000000000..ea0e0b102
--- /dev/null
+++ b/a_cx_mxint_quant/utils.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn.functional as F
+
+def _get_similarity(tensor_raw, tensor_sim, metric=None):
+    """Per-row similarity between a raw tensor and its simulated/quantized
+    copy, averaged over the last dimension.
+
+    Supported metrics: "cosine", "pearson", "L1_norm", "L2_norm",
+    "linear_weighted_L2_norm", "square_weighted_L2_norm". Norm-based
+    metrics are negated so that larger is always more similar.
+    Raises NotImplementedError for unknown metrics.
+    """
+    if metric == "cosine":
+        sim = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)
+    elif metric == "pearson":
+        # Pearson correlation == cosine similarity of mean-centred data.
+        centred_raw = tensor_raw - torch.mean(tensor_raw, dim=-1, keepdim=True)
+        centred_sim = tensor_sim - torch.mean(tensor_sim, dim=-1, keepdim=True)
+        sim = F.cosine_similarity(centred_raw, centred_sim, dim=-1)
+    elif metric == "L1_norm":
+        sim = -torch.abs(tensor_raw - tensor_sim)
+    elif metric == "L2_norm":
+        sim = -((tensor_raw - tensor_sim) ** 2)
+    elif metric == "linear_weighted_L2_norm":
+        sim = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
+    elif metric == "square_weighted_L2_norm":
+        sim = -((tensor_raw * (tensor_raw - tensor_sim)) ** 2)
+    else:
+        raise NotImplementedError(f"metric {metric} not implemented!")
+    return torch.mean(sim, dim=-1)
+
+def reshape_to_block(tensor, t1, t0, p1, p0):
+    # Tile the trailing (t1, t0) dims into (p1, p0) hardware blocks:
+    # result shape is (-1, t1/p1, t0/p0, p1, p0), with each block's
+    # elements contiguous in the last two dims.
+    return tensor.reshape(-1, t1 // p1, p1, t0 // p0, p0)\
+        .permute(0, 1, 3, 2, 4)
+
+def reshape_back(tensor, t1, t0, p1, p0):
+    # Inverse of reshape_to_block: regroup per-block data so that a plain
+    # reshape recovers the original (..., t1, t0) layout.
+    return tensor.reshape(-1, t1 // p1, t0 // p0, p1, p0)\
+        .permute(0, 1, 3, 2, 4)
\ No newline at end of file
diff --git a/a_cx_test_files/1attention_test.py b/a_cx_test_files/1attention_test.py
new file mode 100644
index 000000000..ec851237d
--- /dev/null
+++ b/a_cx_test_files/1attention_test.py
@@ -0,0 +1,67 @@
+
+from chop.nn.quantized.modules.attention_head import _ViTSelfAttentionHeadBase, ViTSelfAttentionHeadInteger
+from chop.nn.quantized.modules.attention import _ViTAttentionBase
+
+import torch.nn as nn
+import torch
+
+class ViTAttentionBase(nn.Module):
+    """Reference (timm-style) ViT multi-head self-attention.
+
+    Local floating-point implementation used by this test script to
+    cross-check chop's `_ViTAttentionBase`.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        # 1/sqrt(head_dim) attention scaling factor.
+        self.scale = torch.tensor(self.head_dim**-0.5)
+
+        # Fused q/k/v projection.
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        # (B, N, 3C) -> (3, B, heads, N, head_dim)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv[0], qkv[1], qkv[2]
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        # Scaled dot-product attention.
+        attn = q @ k.transpose(-2, -1)
+        attn = (attn * self.scale).softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+        # Merge heads and project back to (B, N, C).
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+if __name__ == "__main__":
+ dim = 4
+ head = 2
+
+ torch.manual_seed(0)
+ x = torch.rand(1, dim, dim)
+ module = ViTAttentionBase(dim, head)
+ result = module(x)
+ _module = _ViTAttentionBase(dim, head)
+ _module.qkv.weight = module.qkv.weight
+ _module.proj.weight = module.proj.weight
+ _module.qkv.bias = module.qkv.bias
+ _module.proj.bias = module.proj.bias
+ _result = _module(x)
+ print(result==_result)
\ No newline at end of file
diff --git a/a_cx_test_files/2linear_weigth_scatter.py b/a_cx_test_files/2linear_weigth_scatter.py
new file mode 100644
index 000000000..0d265d7c7
--- /dev/null
+++ b/a_cx_test_files/2linear_weigth_scatter.py
@@ -0,0 +1,72 @@
+
+from chop.nn.quantized.modules.attention_head import _ViTSelfAttentionHeadBase, ViTSelfAttentionHeadInteger
+from chop.nn.quantized.modules.attention import _ViTAttentionBase
+
+import torch.nn as nn
+import torch
+
+class ViTAttentionBase(nn.Module):
+    """Reference (timm-style) ViT multi-head self-attention.
+
+    Local floating-point implementation used by this script's fused-qkv
+    weight-scatter experiment.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        # 1/sqrt(head_dim) attention scaling factor.
+        self.scale = torch.tensor(self.head_dim**-0.5)
+
+        # Fused q/k/v projection.
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        # (B, N, 3C) -> (3, B, heads, N, head_dim)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv[0], qkv[1], qkv[2]
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        # Scaled dot-product attention.
+        attn = q @ k.transpose(-2, -1)
+        attn = (attn * self.scale).softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+        # Merge heads and project back to (B, N, C).
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+if __name__ == "__main__":
+ dim = 4
+ head = 2
+ n = 3
+ x = torch.rand(1, n, dim)
+ qkv = nn.Linear(dim, 3 * dim)
+ q = nn.Linear(dim, dim)
+ k = nn.Linear(dim, dim)
+ v = nn.Linear(dim, dim)
+
+ new_weight = qkv.weight.reshape(3, -1, dim)
+ new_bias = qkv.bias.reshape(3, -1, dim)
+ q.weight,k.weight,v.weight = nn.Parameter(new_weight[0]),nn.Parameter(new_weight[1]),nn.Parameter(new_weight[2])
+ q.bias,k.bias,v.bias = nn.Parameter(new_bias[0]),nn.Parameter(new_bias[1]),nn.Parameter(new_bias[2])
+ qkv_x = qkv(x)
+ qkv_x = qkv_x.reshape(-1, 3, dim).permute(1,0,2)
+ print(qkv_x[0] == q(x))
+ print(qkv_x[1] == k(x))
+ print(qkv_x[2] == v(x))
+
\ No newline at end of file
diff --git a/a_cx_test_files/3intattention.py b/a_cx_test_files/3intattention.py
new file mode 100644
index 000000000..4c33ed518
--- /dev/null
+++ b/a_cx_test_files/3intattention.py
@@ -0,0 +1,34 @@
+from chop.nn.quantized import ViTAttentionInteger
+
+import torch.nn as nn
+import torch
+
+from chop.nn.quantized.modules.linear import (
+ LinearInteger,
+)
+
+if __name__ == "__main__":
+ dim = 4
+ head = 2
+
+ torch.manual_seed(0)
+ x = torch.rand(1, dim, dim)
+ q_config = {
+ "data_in_width":8,
+ "data_in_frac_width":4,
+ "qkv_weight_width":8,
+ "qkv_weight_frac_width":4,
+ "qkv_bias_width":8,
+ "qkv_bias_frac_width":4,
+ "qkv_width":8,
+ "qkv_frac_width":4,
+ "qkmm_out_width":4,
+ "qkmm_out_frac_width":8,
+ "softmax_exp_width":4,
+ "softmax_exp_frac_width":8,
+ "softmax_out_frac_width":4,
+ "svmm_out_width":8,
+ "svmm_out_frac_width":4,
+ }
+ module = ViTAttentionInteger(dim, head, q_config=q_config)
+ print(module(x))
\ No newline at end of file
diff --git a/a_cx_test_files/4norm.py b/a_cx_test_files/4norm.py
new file mode 100644
index 000000000..52b4222f9
--- /dev/null
+++ b/a_cx_test_files/4norm.py
@@ -0,0 +1,98 @@
+from chop.nn.quantized import ViTAttentionInteger
+import logging
+
+import torch.nn as nn
+import torch
+
+from chop.nn.quantizers.integer import (
+ integer_floor_quantizer,
+)
+
+logger = logging.getLogger("norm.models")
+logger.setLevel(logging.DEBUG)
+handler = logging.StreamHandler()
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+def quantize(x, width, frac_width, by_pass=False):
+    # Floor-quantize x to fixed point (width total bits, frac_width
+    # fractional bits); returns x unchanged when by_pass is set.
+    if not by_pass:
+        x = integer_floor_quantizer(x, width, frac_width)
+    return x
+def get_dim_and_prodofdim(x, normalized_shape):
+    """Return the trailing dims covered by `normalized_shape` (as negative
+    indices) together with the total number of elements they span in x."""
+    dim = tuple(range(-len(normalized_shape), 0))
+    num_vals = 1
+    for axis in dim:
+        num_vals *= x.shape[axis]
+    return dim, num_vals
+def isqrt(x: torch.Tensor):
+    # Inverse square root: 1 / sqrt(x).
+    return x.sqrt().reciprocal()
+def _fixed_group_norm_2d_model(
+    x: torch.Tensor,
+    normalized_shape: tuple,
+    q_config,
+):
+    """Software model of a fixed-point layer/group norm datapath.
+
+    Each intermediate (mean, sum of squares, variance, inverse sqrt,
+    output) is floor-quantized to the widths given in q_config so the
+    trace matches the hardware pipeline. Setting q_config["by_pass"]
+    disables all quantization. Logs every stage at DEBUG level.
+    """
+    #TODO: add hardware debug info
+    logger.debug(f"Input: \n {x[0]}")
+    dim, num_vals = get_dim_and_prodofdim(x, normalized_shape)
+
+    # Mean calculation
+    mu = x.mean(dim, keepdim=True)
+    logger.debug(f"Mu: \n {mu[0]}")
+    mu = quantize(mu, q_config["in_width"], q_config["in_frac_width"], q_config["by_pass"])
+    logger.debug(f"Mu Quantized: \n {mu[0]}")
+
+    # Variance calculation
+    diff = x - mu
+    logger.debug(f"Diff: \n {diff[0]}")
+
+    squares = diff**2
+    logger.debug(f"Squares: {squares[0]}")
+
+    sum_squares = torch.sum(squares, dim, keepdim=True)
+
+    sum_squares = quantize(sum_squares, q_config["variance_width"], q_config["variance_frac_width"], q_config["by_pass"])
+
+    logger.debug("Num Values: %d" % (num_vals))
+    var = sum_squares / num_vals
+    var = quantize(var, q_config["variance_width"], q_config["variance_frac_width"], q_config["by_pass"])
+    logger.debug(f"Variance: \n {var[0]}")
+
+    # Inverse square root with the usual 1e-05 epsilon for stability.
+    inv_sqrt = isqrt(var + 1e-05)
+    inv_sqrt = quantize(inv_sqrt, q_config["isqrt_width"], q_config["isqrt_frac_width"], q_config["by_pass"])
+    logger.debug(f"INV SQRT INT: \n {inv_sqrt[0]}")
+
+    # Norm calculation
+    norm_out = diff * inv_sqrt
+    logger.debug("Norm:")
+    logger.debug(norm_out[0])
+
+    # Final cast to the output precision.
+    norm_out = quantize(norm_out, q_config["out_width"], q_config["out_frac_width"], q_config["by_pass"])
+    logger.debug(f"Norm (Casted): \n {norm_out[0]}")
+
+    return norm_out
+
+if __name__ == "__main__":
+ dim = 4
+ head = 2
+
+ torch.manual_seed(0)
+ q_config = {
+ "by_pass": False,
+ "in_width":8,
+ "in_frac_width":7,
+ "variance_width":16,
+ "variance_frac_width":8,
+ "isqrt_width":16,
+ "isqrt_frac_width":8,
+ "out_width":8,
+ "out_frac_width":4,
+ }
+ logger.setLevel(logging.DEBUG)
+ x = torch.rand(1, dim)
+ _x = _fixed_group_norm_2d_model(
+ x, (4,), q_config)
+ module = torch.nn.LayerNorm(dim,elementwise_affine=False, bias=False)
+ print(_x)
+ print(module(x))
\ No newline at end of file
diff --git a/a_cx_test_files/Dockerfile-cpu-python13 b/a_cx_test_files/Dockerfile-cpu-python13
new file mode 100644
index 000000000..1a5f846c7
--- /dev/null
+++ b/a_cx_test_files/Dockerfile-cpu-python13
@@ -0,0 +1,32 @@
+# This Dockerfile configures a Docker environment that
+# contains all the required packages for the tool
+FROM ubuntu:22.04
+
+USER root
+
+# Install apt packages
+ADD install-pkgs-python13.sh install-pkgs-python13.sh
+RUN bash install-pkgs-python13.sh
+
+CMD ["bash"]
+
+# Ensure pip is installed for python3.13 if it's missing dependencies
+RUN python3 -m ensurepip --upgrade && \
+ python3 -m pip install --upgrade pip
+
+# Install PyTorch and Torch-MLIR
+RUN pip3 install --upgrade pip
+RUN pip3 install --pre torch-mlir torchvision \
+ -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels \
+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+
+# Install pip packages
+ADD install-pips-python13.sh install-pips-python13.sh
+RUN bash install-pips-python13.sh
+
+# Add environment variables
+ARG VHLS_PATH
+ARG VHLS_VERSION
+ADD install-env.sh install-env.sh
+RUN bash install-env.sh $VHLS_PATH $VHLS_VERSION
+
diff --git a/a_cx_test_files/bash_script/run_hw_test.sh b/a_cx_test_files/bash_script/run_hw_test.sh
new file mode 100644
index 000000000..68b0aa8a0
--- /dev/null
+++ b/a_cx_test_files/bash_script/run_hw_test.sh
@@ -0,0 +1,131 @@
+
+# # Activation_layers
+# # python3 scripts/build-components.py
+# python3 src/mase_components/activation_layers/test/fixed_gelu_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_relu_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_selu_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_softermax_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_softmax_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_softplus_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_softsign_tb.py
+# python3 src/mase_components/activation_layers/test/fixed_tanh_tb.py
+# # python3 src/mase_components/activation_layers/test/softermax_global_norm_tb.py
+# # python3 src/mase_components/activation_layers/test/softermax_local_window_tb.py
+# # python3 src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
+# # python3 src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
+# # python3 src/mase_components/activation_layers/test/test_lint_activation_layers.py
+# # python3 src/mase_components/activation_layers/test/test_synth_activation_layers.py
+# # DEV mode (no intention to fix)
+# # python3 src/mase_components/activation_layers/test/fixed_elu_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_hardswish_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_silu_tb.py
+# # python3 src/mase_components/activation_layers/test/fixed_softshrink_tb.py
+
+# # Cast
+# python3 src/mase_components/cast/test/fixed_cast_tb.py
+# python3 src/mase_components/cast/test/fixed_rounding_tb.py
+# python3 src/mase_components/cast/test/fixed_signed_cast_tb.py
+# # python3 src/mase_components/cast/test/fixed_unsigned_cast_tb.py
+
+# # Common
+# python3 src/mase_components/common/test/comparator_accumulator_tb.py
+# python3 src/mase_components/common/test/cut_data_tb.py
+# python3 src/mase_components/common/test/lut_tb.py
+# python3 src/mase_components/common/test/wrap_data_tb.py
+# # python3 src/mase_components/common/test/register_slice_tb.py
+# # python3 src/mase_components/common/test/test_lint_common.py
+# # DEV
+# # python3 src/mase_components/common/test/comparator_tree_tb.py
+# # python3 src/mase_components/common/test/single_element_repeat_tb.py
+
+# # Convolution_layers
+# python3 src/mase_components/convolution_layers/test/convolution_tb.py
+
+# # Inteface
+# python3 src/mase_components/interface/axi/test/test_lint_axi.py
+# # python3 src/mase_components/interface/axi/test/test_synth_axi.py
+
+# # Language models llmint8
+# python3 src/mase_components/language_models/llmint8/test/find_max_tb.py
+# python3 src/mase_components/language_models/llmint8/test/fixed_comparator_tree_layer_tb.py
+# python3 src/mase_components/language_models/llmint8/test/fixed_comparator_tree_tb.py
+# python3 src/mase_components/language_models/llmint8/test/quantized_matmul_tb.py
+# python3 src/mase_components/language_models/llmint8/test/quantizer_top_tb.py
+# python3 src/mase_components/language_models/llmint8/test/scatter_tb.py
+# # DEV
+# # python3 src/mase_components/language_models/llmint8/test/llm_int8_top_tb.py
+
+# # Linear layers
+# # Linear Layer - fixed_linear_layer DEBUG: use bias causes crash
+# python3 src/mase_components/linear_layers/fixed_linear_layer/test/fixed_linear_tb.py
+# # python3 src/mase_components/linear_layers/fixed_linear_layer/test/binary_activation_binary_linear_tb.py
+# # python3 src/mase_components/linear_layers/fixed_linear_layer/test/fixed_activation_binary_linear_tb.py
+# # Linear Layer - fixed_operators
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_accumulator_tb.py
+# # python3 src/mase_components/linear_layers/fixed_operators/test/fixed_adder_tree_layer_tb.py
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_adder_tree_tb.py
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_dot_product_tb.py
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
+# # python3 src/mase_components/linear_layers/fixed_operators/test/fixed_matmul_core_tb.py
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_mult_tb.py
+# python3 src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
+# # python3 src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
+# # Linear Layer - matmul
+# # python3 src/mase_components/linear_layers/matmul/test/chain_matmul_tb.py
+# # python3 src/mase_components/linear_layers/matmul/test/fixed_mamul_tb.py
+# # python3 src/mase_components/linear_layers/matmul/test/matmul_tb.py
+# # python3 src/mase_components/linear_layers/matmul/test/matrix_stream_transpose_tb.py
+# # python3 src/mase_components/linear_layers/matmul/test/transpose_tb.py
+# # DEV Linear Layer - binary_operators
+# python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_dot_product_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_matmul_core_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_mult_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_vector_mult_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/fixed_activation_binary_dot_product_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/fixed_activation_binary_mult_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/fixed_activation_binary_vector_mult_tb.py
+# # python3 src/mase_components/linear_layers/binarized_operators/test/test_lint_binary_arith.py
+# # MxInt
+# python3 src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
+# python3 src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
+python3 src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py
+python3 src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py
+python3 src/mase_components/linear_layers/mxint_operators/test/mxint_softmax.py
+# Memory
+python3 src/mase_components/memory/test/fifo_tb.py
+# python3 src/mase_components/memory/test/input_buffer_tb.py
+python3 src/mase_components/memory/test/skid_buffer_tb.py
+# python3 src/mase_components/memory/test/unpacked_fifo_tb.py
+# python3 src/mase_components/memory/test/repeat_circular_buffer_tb.py
+# python3 src/mase_components/memory/test/test_lint_memory.py
+
+# Normalization_layers
+python3 src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
+python3 src/mase_components/normalization_layers/test/group_norm_2d_tb.py
+# DEV
+# python3 src/mase_components/normalization_layers/test/channel_selection_tb.py
+# python3 src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
+# python3 src/mase_components/normalization_layers/test/test_lint_norm.py
+
+# Scalar operators
+python3 src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
+python3 src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
+# python3 src/mase_components/scalar_operators/float/test/test_lint_float_arithmetic.py
+# python3 src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
+# python3 src/mase_components/scalar_operators/fixed/test/test_lint_fixed_math.py
+
+# Systolic array
+# python3 src/mase_components/systolic_arrays/test/test_lint_systolic_arrays.py
+
+# Transformer_layers
+python3 src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
+# python3 src/mase_components/transformer_layers/test/fixed_gqa_head_tb.py
+# python3 src/mase_components/transformer_layers/test/fixed_self_attention_tb.py
+# python3 src/mase_components/transformer_layers/test/test_lint_attention.py
diff --git a/a_cx_test_files/bash_script/run_latency_test.sh b/a_cx_test_files/bash_script/run_latency_test.sh
new file mode 100644
index 000000000..18f03abfd
--- /dev/null
+++ b/a_cx_test_files/bash_script/run_latency_test.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+python3 test/passes/graph/transforms/verilog/test_emit_verilog_mxint_vit_block.py
+python3 mase_mxint_top_tb.py
diff --git a/a_cx_test_files/bash_script/run_real_top.sh b/a_cx_test_files/bash_script/run_real_top.sh
new file mode 100644
index 000000000..54a96eeec
--- /dev/null
+++ b/a_cx_test_files/bash_script/run_real_top.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+CONFIG_PATH=$1.yaml python3 test/passes/graph/transforms/verilog/test_emit_verilog_mxint_vit_folded_top.py
+CONFIG_PATH=$1.yaml python3 test/passes/graph/transforms/verilog/test_emit_verilog_mxint_real_top.py
+#cd /scratch/cx922/mase/mxint_$1/hardware/top_build_project
+#vivado -mode batch -log project_build.log -source build.tcl
diff --git a/a_cx_test_files/bash_script/run_vivado.sh b/a_cx_test_files/bash_script/run_vivado.sh
new file mode 100644
index 000000000..b1716cc07
--- /dev/null
+++ b/a_cx_test_files/bash_script/run_vivado.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+#python3 ./test/passes/graph/transforms/verilog/test_emit_verilog_$1.py
+cd $1/hardware/top_build_project
+vivado -mode batch -log project_build.log -source build.tcl
diff --git a/a_cx_test_files/deit_base.yaml b/a_cx_test_files/deit_base.yaml
new file mode 100644
index 000000000..fd5f9df7a
--- /dev/null
+++ b/a_cx_test_files/deit_base.yaml
@@ -0,0 +1,27 @@
+# Parameters for real top test
+img_size: 224
+in_chans: 3
+patch_size: 16
+n: 196
+embed_dim: 768
+num_heads: 12
+
+# Parameters for vit folded top test
+config:
+ data_width: 6
+ data_exponent_width: 8
+ weight_width: 6
+ weight_exponent_width: 8
+ bias_width: 6
+ bias_exponent_width: 8
+
+parallelism: 16
+mlp_parallelism: 64
+
+folded_depth: 6 # number of times to fold/reuse the streaming blocks
+stream_depth: 2 # number of transformer blocks in streaming pipeline
+
+# Project directory
+# project_dir: "/home/cx922/v80_mxint_hardware/deit_base/"
+project_dir: "/home/cx922/optimized1_final_result/deit_base"
+# project_dir: "/home/cx922/fp8_result/deit_tiny"
\ No newline at end of file
diff --git a/a_cx_test_files/deit_small.yaml b/a_cx_test_files/deit_small.yaml
new file mode 100644
index 000000000..f45dec323
--- /dev/null
+++ b/a_cx_test_files/deit_small.yaml
@@ -0,0 +1,26 @@
+
+# Parameters for real top test
+img_size: 224
+in_chans: 3
+patch_size: 16
+n: 196
+embed_dim: 384
+num_heads: 6
+# General parameters
+config:
+ data_width: 6
+ data_exponent_width: 8
+ weight_width: 6
+ weight_exponent_width: 8
+ bias_width: 6
+ bias_exponent_width: 8
+
+parallelism: 16
+mlp_parallelism: 48
+# Parameters for vit folded top test
+folded_depth: 4 # number of times to fold/reuse the streaming blocks
+stream_depth: 3 # number of transformer blocks in streaming pipeline
+# Project directory
+# project_dir: "/home/cx922/fp16_result/deit_small"
+# project_dir: "/home/cx922/optimized_final_result/deit_small"
+project_dir: "/home/cx922/optimized1_final_result/deit_small"
\ No newline at end of file
diff --git a/a_cx_test_files/deit_tiny.yaml b/a_cx_test_files/deit_tiny.yaml
new file mode 100644
index 000000000..85386ae24
--- /dev/null
+++ b/a_cx_test_files/deit_tiny.yaml
@@ -0,0 +1,29 @@
+
+# Parameters for real top test
+img_size: 224
+in_chans: 3
+patch_size: 16
+n: 196
+embed_dim: 192
+num_heads: 3
+# General parameters
+config:
+ data_width: 4
+ data_exponent_width: 4
+ weight_width: 4
+ weight_exponent_width: 4
+ bias_width: 4
+ bias_exponent_width: 4
+
+parallelism: 1
+mlp_parallelism: 1
+# Parameters for vit folded top test
+folded_depth: 1 # number of times to fold/reuse the streaming blocks
+stream_depth: 12 # number of transformer blocks in streaming pipeline
+# Project directory
+# project_dir: "/scratch/cx922/mase/mxint_deit_tiny_f4_m3"
+
+# project_dir: "/home/cx922/optimized_final_result/deit_tiny"
+project_dir: "/home/cx922/fp8_result/deit_tiny"
+# project_dir: "/scratch/cx922/mase/deit_tiny_check"
+
diff --git a/a_cx_test_files/install-pips-python13.sh b/a_cx_test_files/install-pips-python13.sh
new file mode 100644
index 000000000..7abd38a3d
--- /dev/null
+++ b/a_cx_test_files/install-pips-python13.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# --------------------------------------------------------------------
+# This script installs pip packages for both Docker containers
+# --------------------------------------------------------------------
+set -o errexit
+set -o pipefail
+set -o nounset
+
+pip3 install onnx black toml GitPython colorlog cocotb[bus] \
+ pytest pytorch-lightning transformers toml \
+ timm pytorch-nlp datasets ipython ipdb \
+ sentencepiece einops deepspeed pybind11 \
+ tabulate tensorboardx hyperopt accelerate \
+ optuna stable-baselines3[extra] h5py scikit-learn \
+ scipy onnxruntime matplotlib sphinx-rtd-theme \
+ imageio imageio-ffmpeg opencv-python kornia einops \
+ ghp-import optimum pytest-profiling myst_parser \
+ pytest-cov pytest-xdist pytest-sugar pytest-html \
+ lightning wandb bitarray bitstring emoji evaluate pynvml cvxpy \
+ "numpy<2.0" tensorboard \
+ onnxconverter-common absl-py sphinx-glpi-theme prettytable \
+ && pip install -U Pillow \
+ && pip install mpmath==1.3.0
diff --git a/a_cx_test_files/install-pkgs-python13.sh b/a_cx_test_files/install-pkgs-python13.sh
new file mode 100644
index 000000000..0734b0bee
--- /dev/null
+++ b/a_cx_test_files/install-pkgs-python13.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+# --------------------------------------------------------------------
+# This script installs initial packages for both Docker containers
+# --------------------------------------------------------------------
+set -o errexit
+set -o pipefail
+set -o nounset
+
+
+apt-get update -y && apt-get install apt-utils -y
+DEBIAN_FRONTEND="noninteractive" apt-get -y install tzdata
+
+# Install basic packages
+apt-get upgrade -y
+apt-get update -y \
+ && apt-get install -y clang graphviz-dev libclang-dev \
+ pkg-config g++ libxtst6 xdg-utils \
+ libboost-all-dev llvm gcc ninja-build \
+ python3 python3-pip build-essential \
+ libssl-dev git vim wget htop \
+ lld parallel clang-format clang-tidy \
+ libtinfo5 libidn11-dev unzip \
+ locales python3-sphinx graphviz
+
+locale-gen en_US.UTF-8
+
+# Install SystemVerilog formatter
+mkdir -p /srcPkgs \
+ && cd /srcPkgs \
+ && wget https://github.com/chipsalliance/verible/releases/download/v0.0-2776-gbaf0efe9/verible-v0.0-2776-gbaf0efe9-Ubuntu-22.04-jammy-x86_64.tar.gz \
+ && mkdir -p verible \
+ && tar xzvf verible-*-x86_64.tar.gz -C verible --strip-components 1
+# Install verilator from source - version v5.020
+apt-get update -y \
+ && apt-get install -y git perl make autoconf flex bison \
+ ccache libgoogle-perftools-dev numactl \
+ perl-doc libfl2 libfl-dev zlib1g zlib1g-dev \
+ help2man
+# Install Verilator from source
+mkdir -p /srcPkgs \
+ && cd /srcPkgs \
+ && git clone https://github.com/verilator/verilator \
+ && unset VERILATOR_ROOT \
+ && cd verilator \
+ && git checkout v5.020 \
+ && autoconf \
+ && ./configure \
+ && make -j 4 \
+ && make install
+
+# Install latest Cmake from source
+mkdir -p /srcPkgs \
+ && cd /srcPkgs \
+ && wget https://github.com/Kitware/CMake/releases/download/v3.28.0-rc5/cmake-3.28.0-rc5.tar.gz \
+ && mkdir -p cmake \
+ && tar xzvf cmake-*.tar.gz -C cmake --strip-components 1 \
+ && cd cmake \
+ && ./bootstrap --prefix=/usr/local \
+ && make -j 4 \
+ && make install
+
+# Append any packages you need here
+# apt-get ...
+apt-get update -y \
+ && apt-get install -y clang-12
+
+export DEBIAN_FRONTEND=noninteractive \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa \
+ && apt update -y \
+ && apt install -y python3.13 python3.13-distutils \
+ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.13 300 \
+ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 100 \
+ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 200 \
+ && update-alternatives --config python3
+
diff --git a/a_cx_test_files/new_skid_buffer.py b/a_cx_test_files/new_skid_buffer.py
new file mode 100644
index 000000000..d263edd94
--- /dev/null
+++ b/a_cx_test_files/new_skid_buffer.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python3
+
+import os, pytest
+
+import torch
+import logging
+from functools import partial
+
+import cocotb
+from cocotb.log import SimLog
+from cocotb.triggers import Timer, RisingEdge
+
+from mase_cocotb.testbench import Testbench
+from mase_cocotb.interfaces.streaming import (
+ StreamDriver,
+ StreamMonitor,
+ ErrorThresholdStreamMonitor,
+)
+from mase_cocotb.runner import mase_runner
+
+# from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner
+from chop.nn.quantized.modules.linear import LinearInteger
+from chop.nn.quantizers import integer_floor_quantizer
+
+
+class LinearTB(Testbench):
+ def __init__(self, dut) -> None:
+ super().__init__(dut, dut.clk, dut.rst)
+
+ if not hasattr(self, "log"):
+ self.log = SimLog("%s" % (type(self).__qualname__))
+ self.log.setLevel(logging.DEBUG)
+
+ self.data_in_0_driver = StreamDriver(
+ dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready
+ )
+ self.weight_driver = StreamDriver(
+ dut.clk, dut.weight, dut.weight_valid, dut.weight_ready
+ )
+
+ if self.get_parameter("HAS_BIAS") == 1:
+ self.bias_driver = StreamDriver(
+ dut.clk, dut.bias, dut.bias_valid, dut.bias_ready
+ )
+ self.bias_driver.log.setLevel(logging.DEBUG)
+
+ self.data_out_0_monitor = StreamMonitor(
+ dut.clk,
+ dut.data_out_0,
+ dut.data_out_0_valid,
+ dut.data_out_0_ready,
+ check=True,
+ )
+
+ # self.data_out_0_monitor = ErrorThresholdStreamMonitor(
+ # dut.clk,
+ # dut.data_out_0,
+ # dut.data_out_0_valid,
+ # dut.data_out_0_ready,
+ # width=self.get_parameter("DATA_OUT_0_PRECISION_0"),
+ # signed=True,
+ # error_bits=1,
+ # check=True,
+ # )
+
+ # Model
+ self.model = LinearInteger(
+ in_features=self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"),
+ out_features=self.get_parameter("DATA_OUT_0_TENSOR_SIZE_DIM_0"),
+ bias=True if self.get_parameter("HAS_BIAS") == 1 else False,
+ config={
+ "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"),
+ "data_in_frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"),
+ "weight_width": self.get_parameter("WEIGHT_PRECISION_0"),
+ "weight_frac_width": self.get_parameter("WEIGHT_PRECISION_1"),
+ "bias_width": self.get_parameter("BIAS_PRECISION_0"),
+ "bias_frac_width": self.get_parameter("BIAS_PRECISION_1"),
+ },
+ out_config={
+ "data_out_width": self.get_parameter("DATA_OUT_0_PRECISION_0"),
+ "data_out_frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"),
+ },
+ floor=True,
+ )
+
+ # Set verbosity of driver and monitor loggers to debug
+ self.data_in_0_driver.log.setLevel(logging.DEBUG)
+ self.weight_driver.log.setLevel(logging.DEBUG)
+ self.data_out_0_monitor.log.setLevel(logging.DEBUG)
+
+ def generate_inputs(self):
+ return torch.randn(
+ (
+ self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_1"),
+ self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"),
+ )
+ )
+
+ def preprocess_tensor(self, tensor, config, parallelism):
+ if len(tensor.shape) == 1:
+ tensor = tensor.unsqueeze(0)
+
+ # Quantize
+ quantizer = partial(integer_floor_quantizer, **config)
+ q_tensor = quantizer(tensor)
+ self.log.debug(f"Quantized tensor: {q_tensor}")
+
+ # Convert to integer format
+ q_tensor = (q_tensor * 2 ** config["frac_width"]).int()
+ self.log.debug(f"Tensor in integer format: {q_tensor}")
+
+ # Split into chunks according to parallelism in each dimension
+ # parallelism[0]: along rows, parallelism[1]: along columns
+ dim_0_split = q_tensor.split(parallelism[0], dim=0)
+ dim_1_split = [x.split(parallelism[1], dim=1) for x in dim_0_split]
+ blocks = []
+ # Flatten the list of blocks
+ for i in range(len(dim_1_split)):
+ for j in range(len(dim_1_split[i])):
+ blocks.append(dim_1_split[i][j].flatten().tolist())
+ return blocks
+
+ async def run_test(self, batches=1, us=100):
+ await self.reset()
+ self.log.info(f"Reset finished")
+ self.data_out_0_monitor.ready.value = 1
+ for _ in range(batches):
+ inputs = self.generate_inputs()
+ exp_out = self.model(inputs)
+
+ # * Load the inputs driver
+ self.log.info(f"Processing inputs: {inputs}")
+ inputs = self.preprocess_tensor(
+ tensor=inputs,
+ config={
+ "width": self.get_parameter("DATA_IN_0_PRECISION_0"),
+ "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"),
+ self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"),
+ ],
+ )
+ self.data_in_0_driver.load_driver(inputs)
+
+ # * Load the weights driver
+ if self.get_parameter("WEIGHTS_PRE_TRANSPOSED") == 1:
+ weights = self.model.weight.transpose(0, 1)
+ else:
+ weights = self.model.weight
+
+ self.log.info(f"Processing weights: {weights}")
+ weights = self.preprocess_tensor(
+ tensor=weights,
+ config={
+ "width": self.get_parameter("WEIGHT_PRECISION_0"),
+ "frac_width": self.get_parameter("WEIGHT_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("WEIGHT_PARALLELISM_DIM_1"),
+ self.get_parameter("WEIGHT_PARALLELISM_DIM_0"),
+ ],
+ )
+ self.weight_driver.load_driver(weights)
+
+ # * Load the bias driver
+ if self.get_parameter("HAS_BIAS") == 1:
+ bias = self.model.bias
+ self.log.info(f"Processing bias: {bias}")
+ bias = self.preprocess_tensor(
+ tensor=bias,
+ config={
+ "width": self.get_parameter("BIAS_PRECISION_0"),
+ "frac_width": self.get_parameter("BIAS_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("BIAS_PARALLELISM_DIM_1"),
+ self.get_parameter("BIAS_PARALLELISM_DIM_0"),
+ ],
+ )
+ self.bias_driver.load_driver(bias)
+
+ # * Load the output monitor
+ self.log.info(f"Processing outputs: {exp_out}")
+ outs = self.preprocess_tensor(
+ tensor=exp_out,
+ config={
+ "width": self.get_parameter("DATA_OUT_0_PRECISION_0"),
+ "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"),
+ self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"),
+ ],
+ )
+ self.data_out_0_monitor.load_monitor(outs)
+
+ await Timer(us, units="us")
+ assert self.data_out_0_monitor.exp_queue.empty()
+
+
+@cocotb.test()
+async def cocotb_test(dut):
+ tb = LinearTB(dut)
+ await tb.run_test(batches=10, us=100)
+
+
+async def check_signal(dut, log):
+ num = {"data_out_0": 0, "data_in_0": 0}
+ while True:
+ await RisingEdge(dut.clk)
+
+
+# verified case
+# weight pre transposed = 0
+# weight pre transposed = 1
+# has bias = 0
+# has bias = 1
+def get_fixed_linear_config(kwargs={}):
+ # if pretranspose
+ # weight1 = in0
+ # else
+ # weight0 = in0
+ config = {
+ "HAS_BIAS": 1,
+ "WEIGHTS_PRE_TRANSPOSED": 0,
+ "DATA_IN_0_TENSOR_SIZE_DIM_0": 32,
+ "DATA_IN_0_TENSOR_SIZE_DIM_1": 16,
+ "DATA_IN_0_PARALLELISM_DIM_0": 8,
+ "DATA_IN_0_PARALLELISM_DIM_1": 4,
+ "WEIGHT_TENSOR_SIZE_DIM_0": 32,
+ "WEIGHT_TENSOR_SIZE_DIM_1": 16,
+ "WEIGHT_PARALLELISM_DIM_0": 8,
+ "WEIGHT_PARALLELISM_DIM_1": 4,
+ "DATA_IN_0_PRECISION_0": 8,
+ "DATA_IN_0_PRECISION_1": 4,
+ "WEIGHT_PRECISION_0": 10,
+ "WEIGHT_PRECISION_1": 3,
+ "BIAS_PRECISION_0": 5,
+ "BIAS_PRECISION_1": 2,
+ "DATA_OUT_0_PRECISION_0": 8,
+ "DATA_OUT_0_PRECISION_1": 4,
+ }
+ config.update(kwargs)
+ return config
+
+
+@pytest.mark.dev
+def test_fixed_linear_smoke():
+ """
+ Some quick tests to check if the module is working.
+ """
+ mase_runner(
+ trace=True,
+ module_param_list=[
+ get_fixed_linear_config(),
+            # noticed here if change WEIGHTS_PRE_TRANSPOSED also need to change the DIM_SIZE to match ACTIVATION
+ get_fixed_linear_config(
+ {
+ "WEIGHTS_PRE_TRANSPOSED": 0,
+ "WEIGHT_TENSOR_SIZE_DIM_0": 32,
+ "WEIGHT_TENSOR_SIZE_DIM_1": 16,
+ "WEIGHT_PARALLELISM_DIM_0": 4,
+ "WEIGHT_PARALLELISM_DIM_1": 2,
+ },
+ ),
+ ],
+ )
+
+
+# @pytest.mark.dev
+# def test_fixed_linear_regression():
+# """
+# More extensive tests to check realistic parameter sizes.
+# """
+# mase_runner(
+# trace=True,
+# module_param_list=[
+# get_fixed_linear_config(
+# {
+# "DATA_IN_0_TENSOR_SIZE_DIM_0": 768,
+# "DATA_IN_0_PARALLELISM_DIM_0": 32,
+# "WEIGHT_TENSOR_SIZE_DIM_0": 768,
+# "WEIGHT_TENSOR_SIZE_DIM_1": 768,
+# "WEIGHT_PARALLELISM_DIM_0": 32,
+# "WEIGHT_PARALLELISM_DIM_1": 32,
+# "BIAS_TENSOR_SIZE_DIM_0": 768,
+# "BIAS_PARALLELISM_DIM_0": 32,
+# }
+# ),
+# get_fixed_linear_config(
+# {
+# "HAS_BIAS": 1,
+# "WEIGHTS_PRE_TRANSPOSED": 0,
+# "DATA_IN_0_TENSOR_SIZE_DIM_0": 768,
+# "DATA_IN_0_PARALLELISM_DIM_0": 32,
+# "WEIGHT_TENSOR_SIZE_DIM_0": 768,
+# "WEIGHT_TENSOR_SIZE_DIM_1": 768,
+# "WEIGHT_PARALLELISM_DIM_0": 32,
+# "WEIGHT_PARALLELISM_DIM_1": 32,
+# "BIAS_TENSOR_SIZE_DIM_0": 768,
+# "BIAS_PARALLELISM_DIM_0": 32,
+# }
+# ),
+# ],
+# )
+
+torch.manual_seed(3)
+if __name__ == "__main__":
+ test_fixed_linear_smoke()
+ # test_fixed_linear_regression()
diff --git a/a_cx_test_files/new_skid_buffer.sv b/a_cx_test_files/new_skid_buffer.sv
new file mode 100644
index 000000000..13c77c69c
--- /dev/null
+++ b/a_cx_test_files/new_skid_buffer.sv
@@ -0,0 +1,37 @@
+`timescale 1ns / 1ps
+module new_skid_buffer #(
+ parameter DATA_WIDTH = 32
+) (
+ input logic clk,
+ input logic rst,
+
+ input logic [DATA_WIDTH - 1:0] data_in,
+ input logic data_in_valid,
+ output logic data_in_ready,
+
+ output logic [DATA_WIDTH - 1:0] data_out,
+ output logic data_out_valid,
+ input logic data_out_ready
+);
+ // feed the data_out either from
+ // data_in or a buffered copy of data_in
+ logic [DATA_WIDTH - 1:0] buffer;
+ logic buffer_valid;
+ logic buffer_ready;
+ always_ff @(posedge clk) begin
+ if (rst) begin
+ buffer <= 0;
+ buffer_valid <= 0;
+ end else begin
+ buffer <= data_in;
+ buffer_valid <= data_in_valid;
+ data_in_ready <= buffer_ready;
+ end
+ end
+ always_comb begin
+ buffer_ready = data_out_ready;
+ data_out = buffer;
+ data_out_valid = buffer_valid;
+ end
+
+endmodule
diff --git a/a_cx_test_files/roadmap.drawio b/a_cx_test_files/roadmap.drawio
new file mode 100644
index 000000000..c8b3241c6
--- /dev/null
+++ b/a_cx_test_files/roadmap.drawio
@@ -0,0 +1,223 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/a_cx_test_files/source_code_list/_11.py b/a_cx_test_files/source_code_list/_11.py
new file mode 100644
index 000000000..4541111eb
--- /dev/null
+++ b/a_cx_test_files/source_code_list/_11.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+
+import os
+import torch
+import logging
+from functools import partial
+import cocotb
+from cocotb.log import SimLog
+from cocotb.triggers import Timer, RisingEdge, ReadOnly
+from pathlib import Path
+
+from mase_cocotb.testbench import Testbench
+from mase_cocotb.interfaces.streaming import MultiSignalStreamDriver, MultiSignalStreamMonitor
+from mase_cocotb.runner import mase_runner
+from a_cx_mxint_quant import mxint_quant_block, mxint_hardware
+from utils import pack_tensor_to_mx_listed_chunk
+
+class MxIntLayerNorm1DTB(Testbench):
+ def __init__(self, dut) -> None:
+ super().__init__(dut, dut.clk, dut.rst)
+
+ if not hasattr(self, "log"):
+ self.log = SimLog("%s" % (type(self).__qualname__))
+ self.log.setLevel(logging.DEBUG)
+
+ # Input data driver
+ self.data_in_driver = MultiSignalStreamDriver(
+ dut.clk,
+ (dut.mdata_in_0, dut.edata_in_0),
+ dut.data_in_0_valid,
+ dut.data_in_0_ready
+ )
+
+ # Output monitor
+ self.out_monitor = MultiSignalStreamMonitor(
+ dut.clk,
+ (dut.mdata_out_0, dut.edata_out_0),
+ dut.data_out_0_valid,
+ dut.data_out_0_ready,
+ check=True,
+ )
+
+ self.input_drivers = {
+ "data_in": self.data_in_driver,
+ }
+ self.output_monitors = {"out": self.out_monitor}
+
+ # Model parameters
+ self.tensor_size_dim_0 = self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0")
+ self.parallelism_dim_0 = self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0")
+
+ def preprocess_tensor_for_mxint(self, tensor, config, parallelism):
+ (qtensor, mtensor, etensor) = mxint_hardware(tensor, config, parallelism)
+ tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism)
+ return tensor_inputs
+
+ async def run_test(self):
+ await self.reset()
+ self.log.info("Reset finished")
+ self.out_monitor.ready.value = 1
+
+ input_data = torch.randn((1, self.tensor_size_dim_0))
+ # Update config to match RTL parameter names
+ input_config = {
+ "width": self.get_parameter("DATA_IN_0_MAN_WIDTH"),
+ "exponent_width": self.get_parameter("DATA_IN_0_EXP_WIDTH"),
+ "round_bits": 4,
+ }
+
+ input_parallelism = [
+ 1,
+ self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"),
+ ]
+ (qtensor, mtensor, etensor) = mxint_hardware(input_data, input_config, input_parallelism)
+ shape = mtensor.shape
+ mtensor = mtensor.reshape(-1, self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0")).unsqueeze(0)
+ mtensor = mtensor // 2**(etensor.max() - etensor).unsqueeze(-1)
+ etensor = etensor.max().repeat(etensor.shape)
+ input_data_processed = pack_tensor_to_mx_listed_chunk(mtensor, etensor, input_parallelism)
+ self.data_in_driver.load_driver(input_data_processed)
+
+ from a_cx_mxint_quant.layernorm import mxint_layer_norm
+ qinput = mtensor * 2**(etensor.unsqueeze(-1) - input_config["width"] - 1)
+ qinput = qinput.reshape(shape)
+ layer_norm_config = {
+ "name": "mxint_hardware",
+ # data
+ "data_in_width": self.get_parameter("DATA_IN_0_MAN_WIDTH"),
+ "data_in_exponent_width": self.get_parameter("DATA_IN_0_EXP_WIDTH"),
+ "data_in_parallelism": [1, self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0")],
+ "data_out_width": self.get_parameter("DATA_OUT_0_MAN_WIDTH"),
+ "data_out_exponent_width": self.get_parameter("DATA_OUT_0_EXP_WIDTH"),
+ "data_out_parallelism": [1, self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0")],
+ }
+ int_config = {
+ "qx_lossy": True,
+ "num_val_0_lossy": True,
+ "num_val_1_lossy": True,
+ "mean_lossy": True,
+ "var_lossy": True,
+ "isqrt_lossy": True,
+ "data_in_width": layer_norm_config["data_in_width"],
+ "data_in_frac_width": layer_norm_config["data_in_width"] - 1,
+ "isqrt_in_width": self.get_parameter("ISQRT_IN_MAN_WIDTH"),
+ "isqrt_in_exponent_width": 6,
+ "isqrt_out_width": self.get_parameter("ISQRT_OUT_MAN_WIDTH"),
+ "isqrt_out_frac_width": self.get_parameter("ISQRT_OUT_MAN_FRAC_WIDTH"),
+ "isqrt_out_exponent_width": 6,
+ "weight_width": 8,
+ "weight_frac_width": 6,
+ "bias_width": 8,
+ "bias_frac_width": 6,
+ "data_out_width": self.get_parameter("DATA_OUT_0_MAN_WIDTH"),
+ "data_out_frac_width": self.get_parameter("DATA_OUT_0_MAN_FRAC_WIDTH"),
+ }
+ qout_data, mout_data, eout_data = mxint_layer_norm(qinput, (self.tensor_size_dim_0,), None, None, q_config=int_config)
+ eout_data = eout_data.repeat(etensor.shape)
+
+ # Simplified parallelism config since RTL only has one dimension
+ out_parallelism = [
+ 1,
+ self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"),
+ ]
+ out_processed = pack_tensor_to_mx_listed_chunk(mout_data, eout_data, out_parallelism)
+
+ self.out_monitor.load_monitor(out_processed)
+
+ await Timer(100, units="us")
+ if not self.out_monitor.exp_queue.empty():
+ raise RuntimeError("Output monitor is not empty at end of test")
+
+@cocotb.test()
+async def test_mxint_layer_norm(dut):
+ cocotb.start_soon(check_signal(dut))
+ tb = MxIntLayerNorm1DTB(dut)
+ await tb.run_test()
+
+async def check_signal(dut):
+ await Timer(40, units="ns")
+ while True:
+ await RisingEdge(dut.clk)
+ await ReadOnly()
+ print(dut.data_in_0_valid.value, dut.data_in_0_ready.value)
+ print("end")
+
+default_config = {
+ # Input/output dimensions
+ "DATA_IN_0_TENSOR_SIZE_DIM_0": 10, # Changed from 8 to match RTL
+ "DATA_IN_0_PARALLELISM_DIM_0": 2, # Changed from 2 to match RTL
+
+ # Data width parameters
+ "DATA_IN_0_MAN_WIDTH": 8, # Added to match RTL
+ "DATA_IN_0_MAN_FRAC_WIDTH": 7, # Added to match RTL
+ "DATA_IN_0_EXP_WIDTH": 4, # Added to match RTL
+
+ "DATA_OUT_0_MAN_WIDTH": 8, # Added to match RTL
+ "DATA_OUT_0_MAN_FRAC_WIDTH": 7, # Added to match RTL
+ "DATA_OUT_0_EXP_WIDTH": 4, # Added to match RTL
+
+ # ISQRT parameters
+ "ISQRT_IN_MAN_WIDTH": 8, # Added to match RTL
+ "ISQRT_IN_MAN_FRAC_WIDTH": 7, # Added to match RTL
+ "ISQRT_OUT_MAN_WIDTH": 8, # Added to match RTL
+ "ISQRT_OUT_MAN_FRAC_WIDTH": 4, # Added to match RTL
+}
+
+def test_layer_norm_smoke():
+ valid_width = default_config["ISQRT_IN_MAN_WIDTH"] + 1
+ valid_frac_width = default_config["ISQRT_IN_MAN_WIDTH"] - 1
+
+ out_width = default_config["ISQRT_OUT_MAN_WIDTH"]
+ out_frac_width = default_config["ISQRT_OUT_MAN_FRAC_WIDTH"]
+
+ from mase_components.helper import generate_memory
+ generate_memory.generate_sv_lut(
+ "isqrt",
+ valid_width,
+ valid_frac_width,
+ out_width,
+ out_frac_width,
+ path=Path(__file__).parents[1] / "rtl",
+ constant_mult=1,
+ floor=False,
+ )
+ mase_runner(
+ trace=True,
+ module_param_list=[default_config],
+ skip_build=False,
+ sim="verilator",
+
+ )
+
+if __name__ == "__main__":
+ test_layer_norm_smoke()
diff --git a/a_cx_test_files/source_code_list/mxint_cast_log.sv b/a_cx_test_files/source_code_list/mxint_cast_log.sv
new file mode 100644
index 000000000..5894b5c5b
--- /dev/null
+++ b/a_cx_test_files/source_code_list/mxint_cast_log.sv
@@ -0,0 +1,235 @@
+`timescale 1ns / 1ps
+/*
+Module : Mxint cast
+Description : MxInt Cast between Layers.
+*/
+module mxint_cast_log #(
+ parameter IN_MAN_WIDTH = 1,
+ parameter IN_MAN_FRAC_WIDTH = IN_MAN_WIDTH - 1,
+ parameter IN_EXP_WIDTH = 1,
+ parameter OUT_MAN_WIDTH = 1,
+ parameter OUT_EXP_WIDTH = 1,
+ parameter ROUND_BITS = 4,
+ parameter BLOCK_SIZE = 1
+) (
+ /* verilator lint_off UNUSEDSIGNAL */
+ input logic clk,
+ input logic rst,
+ /* verilator lint_on UNUSEDSIGNAL */
+ input logic [ IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0],
+ input logic [ IN_EXP_WIDTH-1:0] edata_in,
+ input logic data_in_valid,
+ output logic data_in_ready,
+ output logic [OUT_MAN_WIDTH-1:0] mdata_out [BLOCK_SIZE-1:0],
+ output logic [OUT_EXP_WIDTH-1:0] edata_out,
+ output logic data_out_valid,
+ input logic data_out_ready
+);
+ //get max_abs_value of input
+ localparam LOG2_WIDTH = $clog2(IN_MAN_WIDTH) + 1;
+
+ localparam LOSSLESSS_EDATA_WIDTH =
+ (LOG2_WIDTH > IN_EXP_WIDTH && LOG2_WIDTH > OUT_EXP_WIDTH) ? LOG2_WIDTH + 2 :
+ (IN_EXP_WIDTH > OUT_EXP_WIDTH) ? IN_EXP_WIDTH + 2:
+ OUT_EXP_WIDTH + 2;
+
+ localparam SHIFT_WIDTH = (OUT_EXP_WIDTH > IN_EXP_WIDTH) ? OUT_EXP_WIDTH + 1 : IN_EXP_WIDTH + 1;
+ localparam SHIFT_DATA_WIDTH = OUT_MAN_WIDTH + 1;
+
+ localparam CAST_WIDTH = OUT_MAN_WIDTH + ROUND_BITS;
+
+ logic [IN_MAN_WIDTH - 1:0] mdata_for_max [BLOCK_SIZE - 1:0];
+ logic data_for_max_valid, data_for_max_ready;
+
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out [BLOCK_SIZE-1:0];
+ logic [IN_EXP_WIDTH-1:0] edata_for_out;
+ logic data_for_out_valid, data_for_out_ready;
+
+ // Add register slice after log2_max_abs
+ logic [LOG2_WIDTH-1:0] log2_max_value_unreg;
+ logic log2_max_value_valid_unreg, log2_max_value_ready_unreg;
+
+ logic [LOG2_WIDTH - 1:0] log2_max_value;
+ logic log2_max_value_valid, log2_max_value_ready;
+
+ logic [LOSSLESSS_EDATA_WIDTH - 1:0] edata_out_full;
+ logic [OUT_EXP_WIDTH - 1:0] edata_out_unreg;
+ logic [SHIFT_WIDTH - 1:0] shift_value;
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_unreg;
+
+ logic data_out_join_valid, data_out_join_ready;
+ // we dont need to implement full shift here, because we'll clamp in the final.
+ // in order to avoid shift loss, we set the shift_data_width = OUT_MAN_WIDTH + 1.
+
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_reg;
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out_reg [BLOCK_SIZE-1:0];
+ logic [SHIFT_WIDTH-1:0] shift_value_reg;
+
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_reg_1;
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out_reg_1 [BLOCK_SIZE-1:0];
+ logic data_out_reg_valid_1;
+ logic data_out_reg_ready_1;
+
+ logic [CAST_WIDTH-1:0] mdata_for_cast_reg [BLOCK_SIZE-1:0];
+
+ logic [OUT_MAN_WIDTH-1:0] mdata_out_reg [BLOCK_SIZE-1:0];
+ logic [OUT_EXP_WIDTH-1:0] edata_out_reg;
+
+ logic data_out_reg_valid;
+ logic data_out_reg_ready;
+ unpacked_mx_split2_with_data #(
+ .DEPTH($clog2(BLOCK_SIZE) + 1),
+ .MAN_WIDTH(IN_MAN_WIDTH),
+ .EXP_WIDTH(IN_EXP_WIDTH),
+ .IN_SIZE(BLOCK_SIZE)
+ ) data_in_0_unpacked_mx_split2_with_data_i (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_in),
+ .edata_in(edata_in),
+ .data_in_valid(data_in_valid),
+ .data_in_ready(data_in_ready),
+ .fifo_mdata_out(mdata_for_out),
+ .fifo_edata_out(edata_for_out),
+ .fifo_data_out_valid(data_for_out_valid),
+ .fifo_data_out_ready(data_for_out_ready),
+ .straight_mdata_out(mdata_for_max),
+ .straight_edata_out(),
+ .straight_data_out_valid(data_for_max_valid),
+ .straight_data_out_ready(data_for_max_ready)
+ );
+
+ log2_max_abs #(
+ .IN_SIZE (BLOCK_SIZE),
+ .IN_WIDTH(IN_MAN_WIDTH)
+ ) max_bas_i (
+ .clk,
+ .rst,
+ .data_in_0(mdata_for_max),
+ .data_in_0_valid(data_for_max_valid),
+ .data_in_0_ready(data_for_max_ready),
+ .data_out_0(log2_max_value_unreg),
+ .data_out_0_valid(log2_max_value_valid_unreg),
+ .data_out_0_ready(log2_max_value_ready_unreg)
+ );
+
+ skid_buffer #(
+ .DATA_WIDTH(LOG2_WIDTH)
+ ) log2_reg_slice (
+ .clk(clk),
+ .rst(rst),
+ .data_in(log2_max_value_unreg),
+ .data_in_valid(log2_max_value_valid_unreg),
+ .data_in_ready(log2_max_value_ready_unreg),
+ .data_out(log2_max_value),
+ .data_out_valid(log2_max_value_valid),
+ .data_out_ready(log2_max_value_ready)
+ );
+
+ assign edata_out_full = $signed(
+ log2_max_value
+ ) + $signed(
+ edata_for_out
+ ) - IN_MAN_FRAC_WIDTH;
+
+ // clamp
+ signed_clamp #(
+ .IN_WIDTH (LOSSLESSS_EDATA_WIDTH),
+ .OUT_WIDTH(OUT_EXP_WIDTH)
+ ) exp_clamp (
+ .in_data (edata_out_full),
+ .out_data(edata_out_unreg)
+ );
+
+ assign shift_value = $signed(
+ edata_out_unreg
+ ) - $signed(
+ edata_for_out
+ ) + IN_MAN_FRAC_WIDTH - (CAST_WIDTH - 1);
+
+ join2 #() join_inst (
+ .data_in_ready ({data_for_out_ready, log2_max_value_ready}),
+ .data_in_valid ({data_for_out_valid, log2_max_value_valid}),
+ .data_out_valid(data_out_join_valid),
+ .data_out_ready(data_out_join_ready)
+ );
+
+ mxint_register_slice #(
+ .DATA_PRECISION_0(IN_MAN_WIDTH),
+ .DATA_PRECISION_1(IN_EXP_WIDTH + SHIFT_WIDTH),
+ .IN_NUM(BLOCK_SIZE)
+ ) shift_value_reg_slice (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_for_out),
+ .edata_in(merge_shift_edata_unreg),
+ .data_in_valid(data_out_join_valid),
+ .data_in_ready(data_out_join_ready),
+ .mdata_out(mdata_for_out_reg_1),
+ .edata_out(merge_shift_edata_reg_1),
+ .data_out_valid(data_out_reg_valid_1),
+ .data_out_ready(data_out_reg_ready_1)
+ );
+
+ mxint_register_slice #(
+ .DATA_PRECISION_0(IN_MAN_WIDTH),
+ .DATA_PRECISION_1(IN_EXP_WIDTH + SHIFT_WIDTH),
+ .IN_NUM(BLOCK_SIZE)
+ ) shift_value_reg_1_slice (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_for_out_reg_1),
+ .edata_in(merge_shift_edata_reg_1),
+ .data_in_valid(data_out_reg_valid_1),
+ .data_in_ready(data_out_reg_ready_1),
+ .mdata_out(mdata_for_out_reg),
+ .edata_out(merge_shift_edata_reg),
+ .data_out_valid(data_out_reg_valid),
+ .data_out_ready(data_out_reg_ready)
+ );
+ assign merge_shift_edata_unreg = {edata_out_unreg, shift_value};
+ assign edata_out_reg = merge_shift_edata_reg[IN_EXP_WIDTH + SHIFT_WIDTH - 1:SHIFT_WIDTH];
+ assign shift_value_reg = merge_shift_edata_reg[SHIFT_WIDTH - 1:0];
+
+ optimized_right_shift #(
+ .IN_WIDTH(IN_MAN_WIDTH),
+ .SHIFT_WIDTH(SHIFT_WIDTH),
+ .OUT_WIDTH(CAST_WIDTH),
+ .BLOCK_SIZE(BLOCK_SIZE)
+ ) ovshift_inst (
+ .data_in(mdata_for_out_reg),
+ .shift_value(shift_value_reg),
+ .data_out(mdata_for_cast_reg)
+ );
+
+ fixed_rounding #(
+ .IN_SIZE(BLOCK_SIZE),
+ .IN_WIDTH(CAST_WIDTH),
+ .IN_FRAC_WIDTH(CAST_WIDTH - 1),
+ .OUT_WIDTH(OUT_MAN_WIDTH),
+ .OUT_FRAC_WIDTH(OUT_MAN_WIDTH - 1)
+ ) fixed_cast_inst (
+ .data_in(mdata_for_cast_reg),
+ .data_out(mdata_out_reg) // Changed to feed into skid buffer
+ );
+
+
+ // Add skid buffer at the end
+ mxint_skid_buffer #(
+ .DATA_PRECISION_0(OUT_MAN_WIDTH),
+ .DATA_PRECISION_1(OUT_EXP_WIDTH),
+ .IN_NUM(BLOCK_SIZE)
+ ) output_skid_buffer (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_out_reg),
+ .edata_in(edata_out_reg),
+ .data_in_valid(data_out_reg_valid),
+ .data_in_ready(data_out_reg_ready),
+ .mdata_out(mdata_out),
+ .edata_out(edata_out),
+ .data_out_valid(data_out_valid),
+ .data_out_ready(data_out_ready)
+ );
+
+endmodule
\ No newline at end of file
diff --git a/a_cx_test_files/source_code_list/mxint_cast_try1.sv b/a_cx_test_files/source_code_list/mxint_cast_try1.sv
new file mode 100644
index 000000000..03e1001e0
--- /dev/null
+++ b/a_cx_test_files/source_code_list/mxint_cast_try1.sv
@@ -0,0 +1,230 @@
+`timescale 1ns / 1ps
+/*
+Module : Mxint cast
+Description : MxInt Cast between Layers.
+*/
+module mxint_cast_try1 #(
+ parameter IN_MAN_WIDTH = 1,
+ parameter IN_MAN_FRAC_WIDTH = IN_MAN_WIDTH - 1,
+ parameter IN_EXP_WIDTH = 1,
+ parameter OUT_MAN_WIDTH = 1,
+ parameter OUT_EXP_WIDTH = 1,
+ parameter ROUND_BITS = 4,
+ parameter BLOCK_SIZE = 1
+) (
+ /* verilator lint_off UNUSEDSIGNAL */
+ input logic clk,
+ input logic rst,
+ /* verilator lint_on UNUSEDSIGNAL */
+ input logic [ IN_MAN_WIDTH-1:0] mdata_in [BLOCK_SIZE-1:0],
+ input logic [ IN_EXP_WIDTH-1:0] edata_in,
+ input logic data_in_valid,
+ output logic data_in_ready,
+ output logic [OUT_MAN_WIDTH-1:0] mdata_out [BLOCK_SIZE-1:0],
+ output logic [OUT_EXP_WIDTH-1:0] edata_out,
+ output logic data_out_valid,
+ input logic data_out_ready
+);
+ //get max_abs_value of input
+ localparam LOG2_WIDTH = $clog2(IN_MAN_WIDTH) + 1;
+
+ localparam LOSSLESSS_EDATA_WIDTH =
+ (LOG2_WIDTH > IN_EXP_WIDTH && LOG2_WIDTH > OUT_EXP_WIDTH) ? LOG2_WIDTH + 2 :
+ (IN_EXP_WIDTH > OUT_EXP_WIDTH) ? IN_EXP_WIDTH + 2:
+ OUT_EXP_WIDTH + 2;
+
+ localparam SHIFT_WIDTH = (OUT_EXP_WIDTH > IN_EXP_WIDTH) ? OUT_EXP_WIDTH + 1 : IN_EXP_WIDTH + 1;
+ localparam SHIFT_DATA_WIDTH = OUT_MAN_WIDTH + 1;
+
+ localparam CAST_WIDTH = OUT_MAN_WIDTH + ROUND_BITS;
+
+ logic [IN_MAN_WIDTH - 1:0] mdata_for_max [BLOCK_SIZE - 1:0];
+ logic data_for_max_valid, data_for_max_ready;
+
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out [BLOCK_SIZE-1:0];
+ logic [IN_EXP_WIDTH-1:0] edata_for_out;
+ logic data_for_out_valid, data_for_out_ready;
+
+ // Add register slice after log2_max_abs
+ logic [LOG2_WIDTH-1:0] log2_max_value_unreg;
+ logic log2_max_value_valid_unreg, log2_max_value_ready_unreg;
+
+ logic [LOG2_WIDTH - 1:0] log2_max_value;
+ logic log2_max_value_valid, log2_max_value_ready;
+
+ logic [LOSSLESSS_EDATA_WIDTH - 1:0] edata_out_full;
+ logic [SHIFT_WIDTH - 1:0] shift_value;
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_unreg;
+
+ logic data_out_join_valid, data_out_join_ready;
+  // We don't need to implement a full-width shift here, because the result is clamped at the end.
+ // in order to avoid shift loss, we set the shift_data_width = OUT_MAN_WIDTH + 1.
+
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_reg;
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out_reg [BLOCK_SIZE-1:0];
+ logic [SHIFT_WIDTH-1:0] shift_value_reg;
+
+ logic [IN_EXP_WIDTH + SHIFT_WIDTH - 1:0] merge_shift_edata_reg_1;
+ logic [IN_MAN_WIDTH-1:0] mdata_for_out_reg_1 [BLOCK_SIZE-1:0];
+ logic data_out_reg_valid_1;
+ logic data_out_reg_ready_1;
+
+ logic [CAST_WIDTH-1:0] mdata_for_cast [BLOCK_SIZE-1:0];
+
+ logic [OUT_MAN_WIDTH-1:0] mdata_out_unreg [BLOCK_SIZE-1:0];
+ logic [OUT_EXP_WIDTH-1:0] edata_out_unreg;
+
+ logic data_out_reg_valid;
+ logic data_out_reg_ready;
+
+ mxint_delay #(
+ .DATA_PRECISION_0(IN_MAN_WIDTH),
+ .DATA_PRECISION_1(IN_EXP_WIDTH),
+ .BLOCK_SIZE(BLOCK_SIZE),
+ .DELAY_REG_COUNT($clog2(BLOCK_SIZE) + 1)
+ ) mxint_delay_inst (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_in),
+ .edata_in(edata_in),
+ .mdata_out(mdata_for_out),
+ .edata_out(edata_for_out)
+ );
+ log2_max_abs #(
+ .IN_SIZE (BLOCK_SIZE),
+ .IN_WIDTH(IN_MAN_WIDTH)
+ ) max_bas_i (
+ .clk,
+ .rst,
+ .data_in_0(mdata_in),
+ .data_in_0_valid(data_in_valid),
+ .data_in_0_ready(data_in_ready),
+ .data_out_0(log2_max_value),
+ .data_out_0_valid(log2_max_value_valid),
+ .data_out_0_ready(log2_max_value_ready)
+ );
+
+ // get edata_out
+ assign edata_out_full = $signed(
+ log2_max_value
+ ) + $signed(
+ edata_for_out
+ ) - IN_MAN_FRAC_WIDTH;
+
+ signed_clamp #(
+ .IN_WIDTH (LOSSLESSS_EDATA_WIDTH),
+ .OUT_WIDTH(OUT_EXP_WIDTH)
+ ) exp_clamp (
+ .in_data (edata_out_full),
+ .out_data(edata_out_unreg)
+ );
+
+  // get shift_value
+ assign shift_value = $signed(
+ edata_out_unreg
+ ) - $signed(
+ edata_for_out
+ ) + IN_MAN_FRAC_WIDTH - (CAST_WIDTH - 1);
+
+ optimized_right_shift #(
+ .IN_WIDTH(IN_MAN_WIDTH),
+ .SHIFT_WIDTH(SHIFT_WIDTH),
+ .OUT_WIDTH(CAST_WIDTH),
+ .BLOCK_SIZE(BLOCK_SIZE)
+ ) ovshift_inst (
+ .data_in(mdata_for_out),
+ .shift_value(shift_value),
+ .data_out(mdata_for_cast)
+ );
+ fixed_rounding #(
+ .IN_SIZE(BLOCK_SIZE),
+ .IN_WIDTH(CAST_WIDTH),
+ .IN_FRAC_WIDTH(CAST_WIDTH - 1),
+ .OUT_WIDTH(OUT_MAN_WIDTH),
+ .OUT_FRAC_WIDTH(OUT_MAN_WIDTH - 1)
+ ) fixed_cast_inst (
+ .data_in(mdata_for_cast),
+ .data_out(mdata_out_unreg) // Changed to feed into skid buffer
+ );
+
+ mxint_register_slice #(
+ .DATA_PRECISION_0(OUT_MAN_WIDTH),
+ .DATA_PRECISION_1(OUT_EXP_WIDTH),
+ .IN_NUM(BLOCK_SIZE)
+ ) register_slice_inst (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_out_unreg),
+ .edata_in(edata_out_unreg),
+ .data_in_valid(log2_max_value_valid),
+ .data_in_ready(log2_max_value_ready),
+ .mdata_out(mdata_out),
+ .edata_out(edata_out),
+ .data_out_valid(data_out_valid),
+ .data_out_ready(data_out_ready)
+ );
+
+
+
+endmodule
+
+module delay_reg #(
+ parameter DATA_PRECISION_0 = 1,
+ parameter DATA_PRECISION_1 = 1,
+ parameter DELAY_REG_COUNT = 1
+) (
+ input logic clk,
+ input logic rst,
+ input logic [DATA_PRECISION_0-1:0] data_in,
+ output logic [DATA_PRECISION_0-1:0] data_out
+);
+ logic [DATA_PRECISION_0-1:0] data_delay[DELAY_REG_COUNT-1:0];
+ always_ff @(posedge clk) begin
+ if (rst) begin
+ for (int i = 0; i < DELAY_REG_COUNT; i++) begin
+ data_delay[i] <= '0;
+ end
+ end else begin
+ data_delay[0] <= data_in;
+ for (int i = 0; i < DELAY_REG_COUNT-1; i++) begin
+ data_delay[i+1] <= data_delay[i];
+ end
+ end
+ end
+ assign data_out = data_delay[DELAY_REG_COUNT-1];
+endmodule
+
+module mxint_delay #(
+ parameter DATA_PRECISION_0 = 1,
+ parameter DATA_PRECISION_1 = 1,
+ parameter BLOCK_SIZE = 1,
+ parameter DELAY_REG_COUNT = 1
+) (
+ input logic clk,
+ input logic rst,
+ input logic [DATA_PRECISION_0-1:0] mdata_in [BLOCK_SIZE-1:0],
+ input logic [DATA_PRECISION_1-1:0] edata_in,
+ output logic [DATA_PRECISION_0-1:0] mdata_out [BLOCK_SIZE-1:0],
+ output logic [DATA_PRECISION_1-1:0] edata_out
+);
+ logic [DATA_PRECISION_0 * BLOCK_SIZE + DATA_PRECISION_1-1:0] data_in_pack;
+ logic [DATA_PRECISION_0 * BLOCK_SIZE + DATA_PRECISION_1-1:0] data_out_pack;
+ for (genvar i = 0; i < BLOCK_SIZE; i++) begin
+ assign data_in_pack[DATA_PRECISION_0 * (i+1) - 1:DATA_PRECISION_0 * i] = mdata_in[i];
+ end
+ assign data_in_pack[DATA_PRECISION_0 * BLOCK_SIZE + DATA_PRECISION_1-1:DATA_PRECISION_0 * BLOCK_SIZE] = edata_in;
+ delay_reg #(
+ .DATA_PRECISION_0(DATA_PRECISION_0 * BLOCK_SIZE + DATA_PRECISION_1),
+ .DATA_PRECISION_1(DATA_PRECISION_1),
+ .DELAY_REG_COUNT(DELAY_REG_COUNT)
+ ) delay_reg_inst (
+ .clk(clk),
+ .rst(rst),
+ .data_in(data_in_pack),
+ .data_out(data_out_pack)
+ );
+ for (genvar i = 0; i < BLOCK_SIZE; i++) begin
+ assign mdata_out[i] = data_out_pack[DATA_PRECISION_0 * (i+1) - 1:DATA_PRECISION_0 * i];
+ end
+ assign edata_out = data_out_pack[DATA_PRECISION_0 * BLOCK_SIZE + DATA_PRECISION_1-1:DATA_PRECISION_0 * BLOCK_SIZE];
+endmodule
diff --git a/src/mase_components/vision_models/vit/test/helpers/__init__.py b/a_cx_test_files/source_code_list/mxint_cast_try2.sv
similarity index 100%
rename from src/mase_components/vision_models/vit/test/helpers/__init__.py
rename to a_cx_test_files/source_code_list/mxint_cast_try2.sv
diff --git a/a_cx_test_files/source_code_list/mxint_dot_product_history.sv b/a_cx_test_files/source_code_list/mxint_dot_product_history.sv
new file mode 100644
index 000000000..d4836636a
--- /dev/null
+++ b/a_cx_test_files/source_code_list/mxint_dot_product_history.sv
@@ -0,0 +1,114 @@
+`timescale 1ns / 1ps
+module mxint_dot_product_history #(
+ // precision_0 represent mantissa width
+ // precision_1 represent exponent width
+ //
+ parameter DATA_IN_0_PRECISION_0 = 8,
+ parameter DATA_IN_0_PRECISION_1 = 8,
+ parameter WEIGHT_PRECISION_0 = 8,
+ parameter WEIGHT_PRECISION_1 = 8,
+ parameter BLOCK_SIZE = 6,
+ parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2(
+ BLOCK_SIZE
+ ),
+ parameter DATA_OUT_0_PRECISION_1 = (DATA_IN_0_PRECISION_1 > WEIGHT_PRECISION_1)? DATA_IN_0_PRECISION_1 + 1 : WEIGHT_PRECISION_1 + 1
+) (
+ input clk,
+ input rst,
+ // m -> mantissa, e -> exponent
+ input logic [DATA_IN_0_PRECISION_0-1:0] mdata_in_0[BLOCK_SIZE - 1:0],
+ input logic [DATA_IN_0_PRECISION_1-1:0] edata_in_0,
+ input data_in_0_valid,
+ output data_in_0_ready,
+
+ input logic [WEIGHT_PRECISION_0-1:0] mweight[BLOCK_SIZE - 1:0],
+ input logic [WEIGHT_PRECISION_1-1:0] eweight,
+ input weight_valid,
+ output weight_ready,
+
+ output logic [DATA_OUT_0_PRECISION_0-1:0] mdata_out_0,
+ output logic [DATA_OUT_0_PRECISION_1-1:0] edata_out_0,
+ output data_out_0_valid,
+ input data_out_0_ready
+);
+
+ logic [DATA_IN_0_PRECISION_0 - 1:0] mdata_in_0_reg_out[BLOCK_SIZE - 1:0];
+ logic mdata_in_0_reg_out_valid, mdata_in_0_reg_out_ready;
+ logic [DATA_IN_0_PRECISION_1 - 1:0] buffer_edata_in_0;
+ logic buffer_edata_in_0_valid, buffer_edata_in_0_ready;
+
+ logic [WEIGHT_PRECISION_0 - 1:0] mweight_reg_out[BLOCK_SIZE - 1:0];
+ logic mweight_reg_out_valid, mweight_reg_out_ready;
+
+ logic [WEIGHT_PRECISION_1-1:0] buffer_eweight;
+ logic buffer_eweight_valid, buffer_eweight_ready;
+
+ logic mdata_out_0_valid, mdata_out_0_ready;
+ mxint_straightm_fifoe #(
+ .DEPTH($clog2(BLOCK_SIZE) + 1),
+ .MAN_WIDTH(DATA_IN_0_PRECISION_0),
+ .EXP_WIDTH(DATA_IN_0_PRECISION_1),
+ .IN_SIZE(BLOCK_SIZE)
+ ) data_in_0_split_m_e (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mdata_in_0),
+ .edata_in(edata_in_0),
+ .data_in_valid(data_in_0_valid),
+ .data_in_ready(data_in_0_ready),
+ .fifo_edata_out(buffer_edata_in_0),
+ .fifo_edata_out_valid(buffer_edata_in_0_valid),
+ .fifo_edata_out_ready(buffer_edata_in_0_ready),
+ .straight_mdata_out(mdata_in_0_reg_out),
+ .straight_mdata_out_valid(mdata_in_0_reg_out_valid),
+ .straight_mdata_out_ready(mdata_in_0_reg_out_ready)
+ );
+
+ mxint_straightm_fifoe #(
+ .DEPTH($clog2(BLOCK_SIZE) + 1),
+ .MAN_WIDTH(WEIGHT_PRECISION_0),
+ .EXP_WIDTH(WEIGHT_PRECISION_1),
+ .IN_SIZE(BLOCK_SIZE)
+ ) weight_split_m_e (
+ .clk(clk),
+ .rst(rst),
+ .mdata_in(mweight),
+ .edata_in(eweight),
+ .data_in_valid(weight_valid),
+ .data_in_ready(weight_ready),
+ .fifo_edata_out(buffer_eweight),
+ .fifo_edata_out_valid(buffer_eweight_valid),
+ .fifo_edata_out_ready(buffer_eweight_ready),
+ .straight_mdata_out(mweight_reg_out),
+ .straight_mdata_out_valid(mweight_reg_out_valid),
+ .straight_mdata_out_ready(mweight_reg_out_ready)
+ );
+ assign edata_out_0 = $signed(buffer_eweight) + $signed(buffer_edata_in_0);
+ fixed_dot_product #(
+ .IN_WIDTH(DATA_IN_0_PRECISION_0),
+ .WEIGHT_WIDTH(WEIGHT_PRECISION_0),
+ .IN_SIZE(BLOCK_SIZE)
+ ) fdp_inst (
+ .clk(clk),
+ .rst(rst),
+ .data_in(mdata_in_0_reg_out),
+ .data_in_valid(mdata_in_0_reg_out_valid),
+ .data_in_ready(mdata_in_0_reg_out_ready),
+ .weight(mweight_reg_out),
+ .weight_valid(mweight_reg_out_valid),
+ .weight_ready(mweight_reg_out_ready),
+ .data_out(mdata_out_0),
+ .data_out_valid(mdata_out_0_valid),
+ .data_out_ready(mdata_out_0_ready)
+ );
+
+ join_n #(
+ .NUM_HANDSHAKES(3)
+ ) join_inst (
+ .data_in_ready ({mdata_out_0_ready, buffer_eweight_ready, buffer_edata_in_0_ready}),
+ .data_in_valid ({mdata_out_0_valid, buffer_eweight_valid, buffer_edata_in_0_valid}),
+ .data_out_valid(data_out_0_valid),
+ .data_out_ready(data_out_0_ready)
+ );
+
+endmodule
diff --git a/a_cx_test_files/source_code_list/some_useless_code.py b/a_cx_test_files/source_code_list/some_useless_code.py
new file mode 100644
index 000000000..2036fdb9f
--- /dev/null
+++ b/a_cx_test_files/source_code_list/some_useless_code.py
@@ -0,0 +1,532 @@
+class MXIntLinearHardware(_LinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ config=None,
+ ) -> None:
+ super().__init__(in_features, out_features, bias, device, dtype)
+ assert config is not None, "config is None!"
+ self.in_features = in_features
+ self.out_features = out_features
+ self.config = config
+ self.bypass = config.get("bypass", False)
+ if self.bypass:
+ return
+ # establish quantizer
+ w_width, w_exponent_width = (
+ config["weight_width"],
+ config["weight_exponent_width"],
+ )
+ w_p1, w_p0 = (
+ config["weight_parallelism"][0],
+ config["weight_parallelism"][1],
+ )
+ x_width, x_exponent_width = (
+ config["data_in_width"],
+ config["data_in_exponent_width"],
+ )
+ x_p1, x_p0 = (
+ config["data_in_parallelism"][0],
+ config["data_in_parallelism"][1],
+ )
+ # check bias quantizer, if not, use weight quantizer
+ b_width, b_exponent_width = config["bias_width"], config["bias_exponent_width"]
+ b_p1, b_p0 = (
+ config["bias_parallelism"][0],
+ config["bias_parallelism"][1],
+ )
+ base_quantizer = block_mxint_quant
+ out_width, out_exponent_width = (
+ config["data_out_width"],
+ config["data_out_exponent_width"],
+ )
+ out_p1, out_p0 = (
+ config["data_out_parallelism"][0],
+ config["data_out_parallelism"][1],
+ )
+ self.out_quantizer = partial(
+ base_quantizer,
+ q_config={"width": out_width, "exponent_width": out_exponent_width},
+ parallelism=[out_p1, out_p0],
+ )
+ self.w_quantizer = partial(
+ base_quantizer,
+ q_config={"width": w_width, "exponent_width": w_exponent_width},
+ parallelism=[w_p1, w_p0],
+ )
+ self.x_quantizer = partial(
+ base_quantizer,
+ q_config={"width": x_width, "exponent_width": x_exponent_width},
+ parallelism=[x_p1, x_p0],
+ )
+ self.b_quantizer = partial(
+ base_quantizer,
+ q_config={"width": b_width, "exponent_width": b_exponent_width},
+ parallelism=[b_p1, b_p0],
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ x, mx, ex = self.x_quantizer(x)
+ in_x = (mx, ex)
+ w, mw, ew = self.w_quantizer(self.weight)
+ in_w = (mw, ew)
+ if self.bias is not None:
+ bias, mbias, ebias = self.b_quantizer(self.bias)
+ in_bias = (mbias, ebias)
+ else:
+ bias = None
+ in_bias = None
+
+ out = wrapped_mxint_linear_hardware(
+ in_x, in_w, in_bias, self.in_features, self.out_features, self.config
+ )
+
+ return out
+
+
+def wrapped_mxint_linear_hardware(x, w, bias, in_features, out_features, config):
+ mx = x[0]
+ n = mx.reshape(-1, in_features).shape[0]
+ in_config = {
+ "x_config": {
+ "width": config["data_in_width"],
+ "exponent_width": config["data_in_exponent_width"],
+ "parallism_dim_0": config["data_in_parallelism"][1],
+ "parallism_dim_1": config["data_in_parallelism"][0],
+ "depth_dim_0": in_features // config["data_in_parallelism"][1],
+ "depth_dim_1": n // config["data_in_parallelism"][0],
+ "dim_0": in_features,
+ "dim_1": n,
+ },
+ "w_config": {
+ "width": config["weight_width"],
+ "exponent_width": config["weight_exponent_width"],
+ "parallism_dim_0": config["weight_parallelism"][1],
+ "parallism_dim_1": config["weight_parallelism"][0],
+ "depth_dim_0": in_features // config["weight_parallelism"][1],
+ "depth_dim_1": out_features // config["weight_parallelism"][0],
+ "dim_0": in_features,
+ "dim_1": out_features,
+ },
+ "bias_config": {
+ "width": config["bias_width"],
+ "exponent_width": config["bias_exponent_width"],
+ "parallism_dim_0": config["bias_parallelism"][1],
+ "parallism_dim_1": 1,
+ "depth_dim_0": out_features // config["bias_parallelism"][1],
+ "depth_dim_1": 1,
+ "dim_0": out_features,
+ "dim_1": 1,
+ },
+ "out_config": {
+ "width": config["data_out_width"],
+ "exponent_width": config["data_out_exponent_width"],
+ "parallism_dim_0": config["data_out_parallelism"][1],
+ "parallism_dim_1": config["data_out_parallelism"][0],
+ "depth_dim_0": out_features // config["data_out_parallelism"][1],
+ "depth_dim_1": n // config["data_out_parallelism"][0],
+ "dim_0": out_features,
+ "dim_1": n,
+ },
+ }
+ mout, eout = mxint_linear_hardware(x, w, bias, in_config)
+ out_config = in_config["out_config"]
+ reshaped_mout = mout.reshape(
+ out_config["depth_dim_1"],
+ out_config["parallism_dim_1"],
+ out_config["depth_dim_0"],
+ out_config["parallism_dim_0"],
+ ).permute(0, 2, 1, 3)
+ reshaped_out = reshaped_mout * 2 ** (
+ eout[:, :, None, None] - config["data_out_width"] + 1
+ )
+ out = reshaped_out.reshape(
+ out_config["depth_dim_1"],
+ out_config["depth_dim_0"],
+ out_config["parallism_dim_1"],
+ out_config["parallism_dim_0"],
+ ).permute(0, 2, 1, 3)
+ out = out.reshape(out_config["dim_1"], out_config["dim_0"])
+
+ return out
+
+
+def mxint_linear_hardware(x, w, bias, config):
+ """
+ assume 2 dimensional input
+ config = {
+ "x_config":{
+ "width": ,
+ "exponent_width" ,
+ "parallism_dim_0",
+ "parallism_dim_1",
+ "depth_dim_0",
+ "depth_dim_1",
+ "dim_0",
+ "dim_1",
+ },
+ "w_config": {
+ ...
+ },
+ "bias_config": {
+ ...
+ },
+ "out_config": {
+ ...
+ },
+ }
+ """
+ mx, ex = x
+ mw, ew = w
+ x_config = config["x_config"]
+ w_config = config["w_config"]
+ out_config = config["out_config"]
+ from math import ceil, log2
+
+ def DotProductCore(man_x, exp_x, man_y, exp_y):
+ return man_x @ man_y.transpose(0, 1), exp_x + exp_y
+
+ def block_wise_reshape_tensor(x, x_config):
+ reshaped_x = x.reshape(
+ x_config["depth_dim_1"],
+ x_config["parallism_dim_1"],
+ x_config["depth_dim_0"],
+ x_config["parallism_dim_0"],
+ ).permute(0, 2, 1, 3)
+ reshaped_x = reshaped_x.reshape(
+ x_config["depth_dim_1"] * x_config["depth_dim_0"],
+ x_config["parallism_dim_1"],
+ x_config["parallism_dim_0"],
+ )
+ return reshaped_x
+
+ # assume 2 dimensional input
+ assert (
+ x_config["depth_dim_0"] == w_config["depth_dim_0"]
+ ), "need to check the setting of dim"
+ assert (
+ x_config["parallism_dim_0"] == w_config["parallism_dim_0"]
+ ), "need to check the setting of dim"
+ reshaped_ex = ex.reshape(-1)
+ reshaped_mx = block_wise_reshape_tensor(mx, x_config)
+ reshaped_ew = ew.reshape(-1)
+ reshaped_mw = block_wise_reshape_tensor(mw, w_config)
+ man_out = torch.zeros(
+ x_config["depth_dim_1"],
+ w_config["depth_dim_1"],
+ x_config["parallism_dim_1"] * w_config["parallism_dim_1"],
+ )
+ exp_out = torch.zeros(x_config["depth_dim_1"], w_config["depth_dim_1"])
+ for i in range(x_config["depth_dim_1"]):
+ for j in range(w_config["depth_dim_1"]):
+ partial_man_out = torch.zeros(
+ w_config["depth_dim_0"],
+ x_config["parallism_dim_1"],
+ w_config["parallism_dim_1"],
+ )
+ partial_exp_out = torch.zeros(w_config["depth_dim_0"])
+ for k in range(x_config["depth_dim_0"]):
+ mx_block = reshaped_mx[i * x_config["depth_dim_0"] + k]
+ ex_block = reshaped_ex[i * x_config["depth_dim_0"] + k]
+ mw_block = reshaped_mw[j * w_config["depth_dim_0"] + k]
+ ew_block = reshaped_ew[j * w_config["depth_dim_0"] + k]
+ partial_man_out[k], partial_exp_out[k] = DotProductCore(
+ mx_block, ex_block, mw_block, ew_block
+ )
+ acc_man_out, acc_exp_out = MxIntAccumulator(
+ partial_man_out.reshape(w_config["depth_dim_0"], -1), partial_exp_out
+ )
+ if bias != None:
+ bias_config = config["bias_config"]
+ mbias, ebias = bias
+ reshaped_mbias = mbias.reshape(
+ w_config["depth_dim_1"], w_config["parallism_dim_1"]
+ )
+ reshaped_ebias = ebias.reshape(w_config["depth_dim_1"])
+ shifted_value = (
+ reshaped_ebias[j]
+ - acc_exp_out
+ + x_config["width"]
+ + w_config["width"]
+ - 2
+ - (bias_config["width"] - 1)
+ )
+ shifted_bias = reshaped_mbias[j].repeat(
+ x_config["parallism_dim_1"]
+ ) * 2 ** (shifted_value)
+ print(reshaped_mbias[j])
+ print(shifted_value)
+ acc_man_out = shifted_bias + acc_man_out
+ print("shfited_bias", shifted_bias)
+ man_out[i][j], exp_out[i][j] = MxIntCast(
+ acc_man_out,
+ acc_exp_out,
+ {
+ "in_width": x_config["width"]
+ + w_config["width"]
+ + ceil(log2(x_config["dim_0"])),
+ "in_frac_width": x_config["width"] + w_config["width"] - 2,
+ "in_exponent_width": max(
+ x_config["exponent_width"], w_config["exponent_width"]
+ )
+ + 1,
+ "out_width": out_config["width"],
+ "out_exponent_width": out_config["exponent_width"],
+ },
+ )
+ man_out = (
+ man_out.reshape(
+ x_config["depth_dim_1"],
+ w_config["depth_dim_1"],
+ x_config["parallism_dim_1"],
+ w_config["parallism_dim_1"],
+ )
+ .permute(0, 2, 1, 3)
+ .reshape(x_config["dim_1"], w_config["dim_1"])
+ )
+ return man_out, exp_out
+
+
+def MXIntMatmulHardware(man_x, exp_x, man_y, exp_y, x_config, y_config, out_config):
+ """
+ assume 2 dimensional input
+ config = {
+ "width": ,
+ "exponent_width" ,
+ "parallism_dim_0",
+ "parallism_dim_1",
+ "depth_dim_0",
+ "depth_dim_1",
+ "dim_0",
+ "dim_1",
+ }
+ man.shape = [dim_1 * dim_0]
+ exp.shape = [depth_dim_1, depth_dim_0]
+ """
+ from math import ceil, log2
+
+ def MatmulCore(man_x, exp_x, man_y, exp_y):
+ return man_x @ man_y, exp_x + exp_y
+
+ # assume 2 dimensional input
+ assert (
+ x_config["depth_dim_0"] == y_config["depth_dim_1"]
+ ), "need to check the setting of dim"
+
+ def block_wise_reshape_tensor(x, x_config):
+ reshaped_x = x.reshape(
+ x_config["depth_dim_1"],
+ x_config["parallism_dim_1"],
+ x_config["depth_dim_0"],
+ x_config["parallism_dim_0"],
+ ).permute(0, 2, 1, 3)
+ reshaped_x = reshaped_x.reshape(
+ x_config["depth_dim_1"] * x_config["depth_dim_0"],
+ x_config["parallism_dim_1"],
+ x_config["parallism_dim_0"],
+ )
+ return reshaped_x
+
+ reshaped_exp_x = exp_x.reshape(-1)
+ reshaped_man_x = block_wise_reshape_tensor(man_x, x_config)
+ reshaped_exp_y = exp_y.reshape(-1)
+ reshaped_man_y = block_wise_reshape_tensor(man_y, y_config)
+ man_out = torch.zeros(
+ x_config["depth_dim_1"],
+ y_config["depth_dim_0"],
+ x_config["parallism_dim_1"] * y_config["parallism_dim_0"],
+ )
+ exp_out = torch.zeros(x_config["depth_dim_1"], y_config["depth_dim_0"])
+ for i in range(x_config["depth_dim_1"]):
+ for j in range(y_config["depth_dim_0"]):
+ partial_man_out = torch.zeros(
+ y_config["depth_dim_1"],
+ x_config["parallism_dim_1"],
+ y_config["parallism_dim_0"],
+ )
+ partial_exp_out = torch.zeros(y_config["depth_dim_1"])
+ for k in range(y_config["depth_dim_1"]):
+ man_x_block = reshaped_man_x[i * x_config["depth_dim_0"] + k]
+ exp_x_block = reshaped_exp_x[i * x_config["depth_dim_0"] + k]
+ man_y_block = reshaped_man_y[k * y_config["depth_dim_0"] + j]
+ exp_y_block = reshaped_exp_y[k * y_config["depth_dim_0"] + j]
+ partial_man_out[k], partial_exp_out[k] = MatmulCore(
+ man_x_block, exp_x_block, man_y_block, exp_y_block
+ )
+ acc_man_out, acc_exp_out = MxIntAccumulator(
+ partial_man_out.reshape(y_config["depth_dim_1"], -1), partial_exp_out
+ )
+ man_out[i][j], exp_out[i][j] = MxIntCast(
+ acc_man_out,
+ acc_exp_out,
+ {
+ "in_width": x_config["width"]
+ + y_config["width"]
+ + ceil(log2(x_config["dim_0"])),
+ "in_frac_width": x_config["width"] + y_config["width"] - 2,
+ "in_exponent_width": max(
+ x_config["exponent_width"], y_config["exponent_width"]
+ )
+ + 1,
+ "out_width": out_config["width"],
+ "out_exponent_width": out_config["exponent_width"],
+ },
+ )
+ man_out = (
+ man_out.reshape(
+ x_config["depth_dim_1"],
+ y_config["depth_dim_0"],
+ x_config["parallism_dim_1"],
+ x_config["parallism_dim_0"],
+ )
+ .permute(0, 2, 1, 3)
+ .reshape(x_config["dim_1"], y_config["dim_0"])
+ )
+ return man_out, exp_out
+
+
+def MxIntCast(man_in, exp_in, param):
+ # In Man Width
+ max_in = torch.ceil(torch.log2(man_in.abs().max()))
+ out_width = param["out_width"]
+ out_exponent_width = param["out_exponent_width"]
+ in_width = param["in_width"]
+ in_frac_width = param["in_frac_width"]
+ in_exponent_width = param["in_exponent_width"]
+
+ out_exponent_max = 2 ** (out_exponent_width - 1) - 1
+ out_exponent_min = -(2 ** (out_exponent_width - 1))
+
+ out_min = -(2 ** (out_width - 1))
+ out_max = 2 ** (out_width - 1) - 1
+ lma_in = torch.ceil(torch.log2(man_in.abs().max() + 1e-3))
+ out_exp_full = lma_in + exp_in - in_frac_width
+ out_exp = torch.clamp(out_exp_full, out_exponent_min, out_exponent_max)
+ out_man = man_in // 2 ** (in_frac_width - exp_in + out_exp - (out_width - 1))
+ out_man = torch.clamp(out_man, out_min, out_max)
+
+ return out_man, out_exp
+
+
+# NOTE(review): this span was garbled (a commented-out header fused with a live
+# body; '<'/'>' characters lost) — reconstruction below, confirm against repo history.
+def MxIntAccumulator(man, exp):
+    IN_DEPTH, BLOCK_SIZE = man.shape[0], man.shape[1]
+    mout = torch.zeros(BLOCK_SIZE)
+    out_exp = torch.Tensor([-64])
+    for i in range(IN_DEPTH):
+        max_exp = exp[i] if exp[i] > out_exp else out_exp
+        mout = mout // 2 ** (max_exp - out_exp)
+        out_exp = max_exp
+        shifted_man = man[i] // 2 ** (max_exp - exp[i])
+        mout = mout + shifted_man
+    return mout, out_exp
+
+def quantized_range_reduction(mx, ex, in_man_width, data_out_n_width):
+ """Vectorized range reduction"""
+ def hardware_round(mx, ex, in_man_frac_width, data_out_width):
+ round_max = 2**(data_out_width-1) - 1
+ round_min = -2**(data_out_width-1)
+ round_x = mx.reshape(-1) // 2**((in_man_frac_width-ex).reshape(-1))
+ return torch.clamp(round_x, round_min, round_max)
+ coefficient_quant_block = partial(
+ mxint_quantize,
+ width=8,
+ exponent_width=4)
+ _, mlog2_e, elog2_e = coefficient_quant_block(torch.log2(torch.tensor(math.e)))
+ _, mln_2, eln_2 = coefficient_quant_block(torch.log(torch.tensor(2.0)))
+ n = hardware_round(mx * mlog2_e, ex + elog2_e, (in_man_width - 1 + 7), data_out_n_width)
+ print(n)
+ _mx = n * mln_2
+ _ex = eln_2
+ shifted_mx = mx // 2**(_ex - ex + (in_man_width - 1) - 7)
+ print(shifted_mx)
+ print(_ex - ex + (in_man_width - 1) - 7)
+ mr = shifted_mx - _mx
+    # return mr as a fixed-point ?.7 value (we could make it 2.7)
+ # return n as an integer number with width = data_out_width
+ return mr, n
+
+def fixed_exp(fr):
+ frac_width = 7
+ exp = 1*2**(frac_width) + fr + fr**2//2**(frac_width + 1) + fr**3*5//2**(frac_width + 4)
+ return exp
+
+
+
+def mxint_softmax(x, q_config):
+ # fixed_r, integer_n
+ in_man_width = q_config["in_man_width"]
+ in_exp_width = q_config["in_exp_width"]
+ data_out_n_width = q_config["data_out_n_width"]
+ data_out_man_width = q_config["data_out_man_width"]
+ data_out_frac_width = data_out_man_width - 1
+ data_out_exp_width = q_config["data_out_exp_width"]
+
+ shape = x.shape[0]
+ mout = torch.zeros_like(x)
+ eout = torch.zeros_like(x)
+
+ list_of_mexps = []
+ list_of_eexps = []
+ for i in range(shape):
+ _, mx, ex = mxint_quantize(x[i], in_man_width, in_exp_width)
+ fixed_r, integer_n = quantized_range_reduction(mx, ex, in_man_width, data_out_n_width)
+ # fixed_r will be 2.7 bits, integer_n will be data_out_n_width bits
+ mexp = fixed_exp(fixed_r)
+ eexp = integer_n
+ # currently we got mexp ?.7 bits, integer_n data_out_n_width bits
+ list_of_mexps.append(mexp)
+ list_of_eexps.append(eexp)
+ eexps = torch.stack(list_of_eexps)
+ mexps = torch.stack(list_of_mexps)
+ m_sum, e_sum = MxIntAccumulator(torch.stack(list_of_mexps), torch.stack(list_of_eexps))
+ extended_mexps = mexps * 2**(data_out_frac_width)
+ pre_cast_mout = extended_mexps // mexps
+ pre_cast_eout = eexps - e_sum
+ pre_cast_out = pre_cast_mout * 2**(pre_cast_eout - 7)
+ for i in range(shape):
+ _, mout[i], eout[i] = mxint_quantize(pre_cast_out[i], data_out_man_width, data_out_exp_width)
+ return mout, eout
+
+
+def preprocess_weight_tensor_for_mxint(self, tensor, config, parallelism):
+ from utils import mxint_quantize
+
+ t1, t0 = tensor.shape[0], tensor.shape[1]
+ p1, p0 = parallelism[0], parallelism[1]
+ reshaped_tensor = tensor.reshape(t1//p1, p1, t0//p0, p0).permute(0, 2, 1, 3)
+ reshaped_tensor = reshaped_tensor.reshape(-1, p1,p0)
+
+ tensor_inputs = []
+ for i in range(t1 * t0 //(p1*p0)):
+ etensors = []
+ mtensors = []
+ for j in range(p1):
+ (qtensor, mtensor, etensor) = mxint_quantize(reshaped_tensor[i][j], width=config["width"], exponent_width=config["exponent_width"])
+ etensors.append(int(etensor))
+ mtensors += mtensor.int().tolist()
+ tensor_inputs.append((mtensors, etensors))
+
+ return tensor_inputs
\ No newline at end of file
diff --git a/a_cx_test_files/test.drawio b/a_cx_test_files/test.drawio
new file mode 100644
index 000000000..e69de29bb
diff --git a/a_cx_test_files/test.tex b/a_cx_test_files/test.tex
new file mode 100644
index 000000000..c77d9eba5
--- /dev/null
+++ b/a_cx_test_files/test.tex
@@ -0,0 +1,40 @@
+\begin{figure*}
+ \begin{subfigure}[b]{0.3\textwidth}
+ \begin{algorithmic}[1] \footnotesize
+ \Require $X$ \Comment{Input features}
+ \Require $H$ \Comment{Number of heads}
+ \Require $L$ \Comment{Number of hidden layers}
+ \State $\quant{X_n} \gets \apprx{LayerNorm(\quant{X})} $
+ \For{$i \in [0, H)$}
+ \State $\quant{Q_i} \gets \quant{W_{Q_i}} \apprx{\times} \quant{X_n}$
+ \State $\quant{K_i} \gets \quant{W_{K_i}} \apprx{\times} \quant{X_n}$
+ \State $\quant{V_i} \gets \quant{W_{V_i}} \apprx{\times} \quant{X_n}$
+ \State $\quant{A_i} \gets \frac{\quant{Q_i} \apprx{\times} \quant{K_i}^T}{\sqrt{d_k}} $
+ \State $\quant{\hat{A}_i} \gets \apprx{softmax(\quant{A_i})} $
+ \State $\quant{B_i} \gets \quant{\hat{A}_i} \apprx{\times} \quant{V_i}$
+ \EndFor
+ \State $\quant{B_c} \gets \apprx{concat(\quant{B_0}.. \quant{B_{H-1}})} $
+ \State $\quant{B_o} \gets \quant{W_0} \apprx{\times} \quant{B_c}$
+ \State $\quant{B_n} \gets \apprx{LayerNorm(\quant{B_o} + \quant{X_n})} $
+ \State $\quant{U} \gets \quant{W_U} \apprx{\times} \quant{B_n}$
+ \State $\quant{D} \gets \quant{W_D} (\apprx{GELU(\quant{U})})$
+ \State $\quant{O} \gets \quant{D} + \quant{B_n}$
+ \State \Return $\quant{O}$
+ \end{algorithmic}
+ \caption{An algorithm view of a block in the ViT model.
+ Values highlighted in \quant{\em blue} represent quantized values, and operations highlighted in \apprx{green} represent approximated operations.}
+        \label{fig:motivation-alg}
+ \end{subfigure}
+ \hfill
+ % \begin{subfigure}[b]{0.01\textwidth}
+ % ~
+ % \end{subfigure}
+ \begin{subfigure}[b]{0.6\textwidth}
+ \caption{An architecture view of the proposed hardware accelerator.
+ The proposed architecture pipelines the model in a hierarchical dataflow, and tailors each operation for high area efficiency.}
+        \label{fig:motivation-arch}
+ \end{subfigure}
+ \caption{An overview of the proposed accelerator architecture.}
+ \label{fig:motivation}
+ \end{figure*}
+
\ No newline at end of file
diff --git a/justfile b/justfile
index 74d5ceee6..fe290e063 100644
--- a/justfile
+++ b/justfile
@@ -22,7 +22,7 @@ test-hw:
# python3 src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
python3 src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
# python3 src/mase_components/activation_layers/test/fixed_softermax_tb.py
- # python3 src/mase_components/activation_layers/test/fixed_softmax_tb.py
+ python3 src/mase_components/activation_layers/test/fixed_softmax_tb.py
python3 src/mase_components/activation_layers/test/fixed_softplus_tb.py
python3 src/mase_components/activation_layers/test/fixed_softsign_tb.py
python3 src/mase_components/activation_layers/test/fixed_tanh_tb.py
@@ -111,6 +111,7 @@ test-hw:
python3 src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
python3 src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py
python3 src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py
+ python3 src/mase_components/linear_layers/mxint_operators/test/mxint_softmax.py
# Memory
python3 src/mase_components/memory/test/fifo_tb.py
# python3 src/mase_components/memory/test/input_buffer_tb.py
@@ -143,6 +144,10 @@ test-hw:
# python3 src/mase_components/transformer_layers/test/fixed_self_attention_tb.py
# python3 src/mase_components/transformer_layers/test/test_lint_attention.py
+ # ViT layers
+ python3 src/mase_components/vision_models/test/fixed_self_attention_head_tb.py
+
+
reformat:
# format python files
black src/chop
diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py
index e56a512d8..56d5bf56e 100644
--- a/src/chop/actions/simulate.py
+++ b/src/chop/actions/simulate.py
@@ -38,11 +38,16 @@ def simulate(
gui: bool = False,
waves: bool = False,
simulator: str = "verilator",
+ pass_args = {},
):
SIM = getenv("SIM", simulator)
runner = get_runner(SIM)
- project_dir = Path.home() / ".mase" / "top"
+ project_dir = (
+ pass_args["project_dir"]
+ if "project_dir" in pass_args.keys()
+ else Path.home() / ".mase" / "top"
+ )
if run_emit:
emit(model, model_info, task, dataset_info, data_module, load_name, load_type)
@@ -64,6 +69,8 @@ def simulate(
"--trace-structs",
"--trace-depth",
str(trace_depth),
+ "--unroll-count",
+ "16384"
]
else:
raise ValueError(f"Unrecognized simulator: {simulator}")
diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py
index fc2fb9fd9..c5f5d2c1f 100644
--- a/src/chop/ir/graph/mase_graph.py
+++ b/src/chop/ir/graph/mase_graph.py
@@ -187,6 +187,8 @@ def is_leaf_module(
custom_leaf_layers = ()
# quantized functions/layers
custom_leaf_functions += tuple(quantized_func_map.values())
+ if custom_ops is not None:
+ custom_leaf_layers += tuple(custom_ops.get("modules", {}).keys())
custom_leaf_layers += tuple(quantized_module_map.values())
# patched functions/layers
patched_nodes = getattr(model, "patched_nodes", None)
diff --git a/src/chop/models/vision/vit/__init__.py b/src/chop/models/vision/vit/__init__.py
new file mode 100644
index 000000000..03a3168cc
--- /dev/null
+++ b/src/chop/models/vision/vit/__init__.py
@@ -0,0 +1 @@
+from .vit import get_vit_tiny_patch16, get_vit_base_patch16
diff --git a/src/chop/models/vision/vit/utils.py b/src/chop/models/vision/vit/utils.py
new file mode 100644
index 000000000..ed7c23fa6
--- /dev/null
+++ b/src/chop/models/vision/vit/utils.py
@@ -0,0 +1,199 @@
+# Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved.
+import math
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+@torch.no_grad()
+def load_weights_from_npz(model, url, check_hash=False, progress=False, prefix=""):
+ """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ def _get_cache_dir(child_dir=""):
+ """
+ Returns the location of the directory where models are cached (and creates it if necessary).
+ """
+ hub_dir = torch.hub.get_dir()
+ child_dir = () if not child_dir else (child_dir,)
+ model_dir = os.path.join(hub_dir, "checkpoints", *child_dir)
+ os.makedirs(model_dir, exist_ok=True)
+ return model_dir
+
+ def _download_cached_file(url, check_hash=True, progress=False):
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(_get_cache_dir(), filename)
+ if not os.path.exists(cached_file):
+ hash_prefix = None
+ if check_hash:
+ r = torch.hub.HASH_REGEX.search(filename) # r is Optional[Match[str]]
+ hash_prefix = r.group(1) if r else None
+ torch.hub.download_url_to_file(
+ url, cached_file, hash_prefix, progress=progress
+ )
+ return cached_file
+
+ def adapt_input_conv(in_chans, conv_weight):
+ conv_type = conv_weight.dtype
+ # Some weights are in torch.half, ensure it's float for sum on CPU
+ conv_weight = conv_weight.float()
+ O, I, J, K = conv_weight.shape
+ if in_chans == 1:
+ if I > 3:
+ assert conv_weight.shape[1] % 3 == 0
+ # For models with space2depth stems
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
+ else:
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
+ elif in_chans != 3:
+ if I != 3:
+ raise NotImplementedError("Weight format not supported by conversion.")
+ else:
+ # NOTE this strategy should be better than random init, but there could be other combinations of
+ # the original RGB input layer weights that'd work better for specific cases.
+ repeat = int(math.ceil(in_chans / 3))
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+ conv_weight *= 3 / float(in_chans)
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
+
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ if not len(gs_new): # backwards compatibility
+ gs_new = [int(math.sqrt(ntok_new))] * 2
+ assert len(gs_new) >= 2
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(
+ posemb_grid, size=gs_new, mode="bicubic", align_corners=False
+ )
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(
+ 1, gs_new[0] * gs_new[1], -1
+ )
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
+ cached_file = _download_cached_file(url, check_hash=check_hash, progress=progress)
+
+ w = np.load(cached_file)
+ if not prefix and "opt/target/embedding/kernel" in w:
+ prefix = "opt/target/"
+
+ if hasattr(model.patch_embed, "backbone"):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, "stem")
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(
+ adapt_input_conv(
+ stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
+ )
+ )
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
+ for r in range(3):
+ getattr(block, f"conv{r + 1}").weight.copy_(
+ _n2p(w[f"{bp}conv{r + 1}/kernel"])
+ )
+ getattr(block, f"norm{r + 1}").weight.copy_(
+ _n2p(w[f"{bp}gn{r + 1}/scale"])
+ )
+ getattr(block, f"norm{r + 1}").bias.copy_(
+ _n2p(w[f"{bp}gn{r + 1}/bias"])
+ )
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(
+ _n2p(w[f"{bp}conv_proj/kernel"])
+ )
+ block.downsample.norm.weight.copy_(
+ _n2p(w[f"{bp}gn_proj/scale"])
+ )
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
+ )
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w,
+ model.pos_embed,
+ getattr(model, "num_tokens", 1),
+ model.patch_embed.grid_size,
+ )
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
+ if (
+ isinstance(model.head, nn.Linear)
+ and model.head.bias.shape[0] == w[f"{prefix}head/bias"].shape[-1]
+ ):
+ model.head.weight.copy_(_n2p(w[f"{prefix}head/kernel"]))
+ model.head.bias.copy_(_n2p(w[f"{prefix}head/bias"]))
+ # if isinstance(getattr(model.pre_logits, 'fc', None),
+ # nn.Linear) and f'{prefix}pre_logits/bias' in w:
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
+ block.attn.qkv.weight.copy_(
+ torch.cat(
+ [
+ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
+ for n in ("query", "key", "value")
+ ]
+ )
+ )
+ block.attn.qkv.bias.copy_(
+ torch.cat(
+ [
+ _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
+ for n in ("query", "key", "value")
+ ]
+ )
+ )
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
+ for r in range(2):
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
+ )
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
+ )
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
diff --git a/src/chop/models/vision/vit/vit.py b/src/chop/models/vision/vit/vit.py
new file mode 100644
index 000000000..23869a8b6
--- /dev/null
+++ b/src/chop/models/vision/vit/vit.py
@@ -0,0 +1,394 @@
+import torch
+import torch.nn as nn
+from functools import partial
+from logging import getLogger
+from timm.layers import (
+ get_act_layer,
+ get_norm_layer,
+ LayerType,
+ DropPath,
+ to_2tuple,
+ trunc_normal_,
+)
+from timm.models._hub import load_state_dict_from_hf
+import numpy as np
+from .utils import load_weights_from_npz
+
+logger = getLogger(__name__)
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ List,
+)
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, patch_norm=False):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
+ # f"img_size {img_size} should be divided by patch_size {patch_size}."
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+ self.norm = nn.LayerNorm(embed_dim)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ H, W = H // self.patch_size[0], W // self.patch_size[1]
+
+ return x, (H, W)
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ # self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ x = self.fc2(x)
+ # x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = torch.tensor(self.head_dim**-0.5)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ print("q", q)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_drop: float = 0.0,
+ attn_drop: float = 0.0,
+ drop_path: float = 0.0,
+ act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module = nn.LayerNorm,
+ mlp_layer: nn.Module = Mlp,
+ ) -> None:
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop,
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = mlp_layer(
+ in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer,
+ drop=proj_drop,
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x + self.drop_path1(self.attn(self.norm1(x)))
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ global_pool: Literal["", "avg", "token"] = "token",
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ class_token: bool = True,
+ fc_norm: Optional[bool] = None,
+ drop_rate: float = 0.0,
+ pos_drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ embed_layer: Callable = PatchEmbed,
+ norm_layer: Optional[LayerType] = None,
+ act_layer: Optional[LayerType] = None,
+ ) -> None:
+ """
+ Args:
+ img_size: Input image size.
+ patch_size: Patch size.
+ in_chans: Number of image input channels.
+ num_classes: Number of classes for classification head.
+ global_pool: Type of global pooling for final sequence (default: 'token').
+ embed_dim: Transformer embedding dimension.
+ depth: Depth of transformer.
+ num_heads: Number of attention heads.
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
+ qkv_bias: Enable bias for qkv projections if True.
+ init_values: Layer-scale init values (layer-scale enabled if not None).
+ class_token: Use class token.
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
+ reg_tokens: Number of register tokens.
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+ drop_rate: Head dropout rate.
+ pos_drop_rate: Position embedding dropout rate.
+ attn_drop_rate: Attention dropout rate.
+ drop_path_rate: Stochastic depth rate.
+ embed_layer: Patch embedding layer.
+ norm_layer: Normalization layer.
+ act_layer: MLP activation layer.
+ block_fn: Transformer block layer.
+ """
+ super().__init__()
+ assert global_pool in ("", "avg", "token")
+ assert class_token or global_pool != "token"
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
+ norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = get_act_layer(act_layer) or nn.GELU
+
+ self.num_classes = num_classes
+ self.global_pool = global_pool
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_prefix_tokens = 1 if class_token else 0
+ self.has_class_token = class_token
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ patch_norm=False,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = (
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+ )
+ num_patches = num_patches + 1 if class_token else num_patches
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.Sequential(
+ *[
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_drop=proj_drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+
+ # Classifier Head
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+ self.head_drop = nn.Dropout(drop_rate)
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ )
+
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+ pos_embed = self.pos_embed
+ x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+ x = x + pos_embed
+
+ return self.pos_drop(x)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x, _ = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.blocks(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+ if self.global_pool == "avg":
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
+ elif self.global_pool:
+ x = x[:, 0] # class token
+ x = self.fc_norm(x)
+ x = self.head_drop(x)
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
+def load_pretrained(pretrained, num_classes, model, url, model_name):
+ if pretrained:
+ checkpoint = np.load(
+ "/root/work/mase-tools/machop/.machop_cache/.machop_cache.ff0e4f759408437a93630130d36308e8.partial"
+ )
+ model.load_state_dict(checkpoint, strict=False)
+ logger.info("Pretrained weights loaded into {}".format(model_name))
+ else:
+ logger.info("{} randomly initialized".format(model_name))
+
+
+def get_vit_tiny_patch16(info, pretrained=False, **kwargs):
+ """ViT-Tiny (Vit-Ti/16)"""
+ num_classes = info.num_classes
+ img_size = info.image_size[-1]
+ model = VisionTransformer(
+ img_size=img_size,
+ num_classes=num_classes,
+ patch_size=16,
+ embed_dim=192,
+ num_heads=3,
+ depth=12,
+ **kwargs,
+ )
+ if img_size == 224:
+ url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz"
+ elif img_size == 384:
+ url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz"
+ else:
+ pretrained = False
+ logger.warning("this image_size is not supported rightnow")
+
+ if pretrained: load_weights_from_npz(model, url, check_hash=True)
+ return model
+
+
+def get_vit_base_patch16(info, pretrained=False, **kwargs):
+ """ViT-Base (Vit-B/16)"""
+ num_classes = info.num_classes
+ img_size = info.image_size[-1]
+ model = VisionTransformer(
+ img_size=img_size,
+ num_classes=num_classes,
+ patch_size=16,
+ embed_dim=768,
+ num_heads=12,
+ depth=12,
+ **kwargs,
+ )
+ if img_size == 224:
+ pre_trained_loc = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"
+ elif img_size == 384:
+ url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth"
+ else:
+ logger.warning("this image_size is not supported rightnow")
+ pretrained = False
+
+ if pretrained:
+ checkpoint = load_state_dict_from_hf(pre_trained_loc)
+ if num_classes != 1000:
+ _ = checkpoint.pop("head.weight")
+ _ = checkpoint.pop("head.bias")
+ logger.warning(
+ f"num_classes (={num_classes}) != 1000. The last classifier layer (head) is randomly initialized"
+ )
+ model.load_state_dict(checkpoint, strict=False)
+ logger.info("Pretrained weights loaded into vit_base_patch16")
+ else:
+ logger.info("vit_base_patch16 randomly initialized")
+
+ return model
diff --git a/src/chop/nn/quantized/__init__.py b/src/chop/nn/quantized/__init__.py
index f9f1389e7..b9dabd9ee 100644
--- a/src/chop/nn/quantized/__init__.py
+++ b/src/chop/nn/quantized/__init__.py
@@ -2,6 +2,8 @@
quantized_module_map,
BertSelfAttentionInteger,
BertSelfAttentionHeadInteger,
+ ViTSelfAttentionHeadInteger,
+ ViTAttentionInteger,
LinearInteger,
LayerNormInteger,
GELUInteger,
diff --git a/src/chop/nn/quantized/functional/__init__.py b/src/chop/nn/quantized/functional/__init__.py
index 65ccb1915..8f215ab67 100644
--- a/src/chop/nn/quantized/functional/__init__.py
+++ b/src/chop/nn/quantized/functional/__init__.py
@@ -1,10 +1,12 @@
from .softermax import fixed_softermax
-
+from .softmax import softmax_integer
+from .layer_norm import IntLayerNormFunc, _int_layer_norm
from .add import (
add_block_fp,
add_block_log,
add_block_minifloat,
add_integer,
+ add_integer_floor,
add_log,
add_minifloat_denorm,
add_minifloat_ieee,
@@ -138,6 +140,7 @@
quantized_func_map = {
"add_block_minifloat": add_block_minifloat,
"add_integer": add_integer,
+ "add_integer_floor": add_integer_floor,
"add_fixed": add_integer,
"add_log": add_log,
"add_minifloat_ieee": add_minifloat_ieee,
diff --git a/src/chop/nn/quantized/functional/add.py b/src/chop/nn/quantized/functional/add.py
index 051c2b260..98ae5f935 100644
--- a/src/chop/nn/quantized/functional/add.py
+++ b/src/chop/nn/quantized/functional/add.py
@@ -7,6 +7,7 @@
block_log_quantizer,
block_minifloat_quantizer,
integer_quantizer,
+ integer_floor_quantizer,
log_quantizer,
minifloat_denorm_quantizer,
minifloat_ieee_quantizer,
@@ -28,6 +29,21 @@ def add_integer(x, y, config):
return x + y
+def add_integer_floor(x, y, config):
+ bypass = config.get("bypass", False)
+ if bypass:
+ return x + y
+ else:
+ # establish quantizers
+ x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
+ x_quantizer = partial(
+ integer_floor_quantizer, width=x_width, frac_width=x_frac_width
+ )
+ x = x_quantizer(x)
+ y = x_quantizer(y)
+ return x + y
+
+
def add_binary(x, y, config):
bypass = config.get("bypass", False)
if bypass:
diff --git a/src/chop/nn/quantized/functional/layer_norm.py b/src/chop/nn/quantized/functional/layer_norm.py
new file mode 100644
index 000000000..1ca105a33
--- /dev/null
+++ b/src/chop/nn/quantized/functional/layer_norm.py
@@ -0,0 +1,141 @@
+from torch import nn
+import torch
+
+from chop.nn.quantizers import integer_floor_quantizer
+from math import ceil, log2
+
+
+def _int_layer_norm(
+ x: torch.Tensor,
+ normalized_shape: tuple or int,
+ weight=None,
+ bias=None,
+ eps=1e-5,
+ q_config={},
+):
+ def quantize(x, width, frac_width, by_pass=False):
+ if not by_pass:
+ x = integer_floor_quantizer(x, width, frac_width)
+ return x
+
+ def get_dim_and_prodofdim(x, normalized_shape):
+ dim = tuple(range(0 - len(normalized_shape), 0))
+ num_vals = 1
+ for items in dim:
+ num_vals *= x.shape[items]
+ return dim, num_vals
+
+ def isqrt(x: torch.Tensor):
+ x = (x + eps).sqrt()
+ x = x.reciprocal()
+ return x
+
+ if isinstance(normalized_shape, int):
+ normalized_shape = (normalized_shape,)
+ dim, num_vals = get_dim_and_prodofdim(x, normalized_shape)
+ x = quantize(
+ x,
+ q_config.get("data_in_width"),
+ q_config.get("data_in_frac_width"),
+ q_config.get("by_pass"),
+ )
+ acc_out_width = ceil(log2(num_vals)) + q_config.get("data_in_width")
+ inv_num_vals_quant_0 = quantize(
+ torch.tensor(1 / num_vals), acc_out_width + 2, acc_out_width
+ )
+ # Mean calculation
+ mu_acc = x.sum(dim, keepdim=True)
+ mu = mu_acc * inv_num_vals_quant_0
+ mu = quantize(
+ mu,
+ q_config.get("data_in_width"),
+ q_config.get("data_in_frac_width"),
+ q_config.get("by_pass"),
+ )
+ print("mu", mu * 2 ** q_config.get("data_in_frac_width"))
+ # I hope the output precision here should be $clog2
+ # Variance calculation
+ diff = x - mu
+
+ squares = diff**2
+ sum_squares = torch.sum(squares, dim, keepdim=True)
+ squares_adder_tree_width = 2 * q_config.get("data_in_width") + ceil(log2(num_vals))
+ inv_num_vals_quant_1 = quantize(
+ torch.tensor(1 / num_vals),
+ squares_adder_tree_width + 2,
+ squares_adder_tree_width,
+ )
+ var = sum_squares * inv_num_vals_quant_1
+ var = quantize(
+ var,
+ q_config.get("isqrt_in_width"),
+ q_config.get("isqrt_in_frac_width"),
+ q_config.get("by_pass"),
+ )
+
+ inv_sqrt = isqrt(var)
+ inv_sqrt = quantize(
+ inv_sqrt,
+ q_config.get("isqrt_out_width"),
+ q_config.get("isqrt_out_frac_width"),
+ q_config.get("by_pass"),
+ )
+
+ # Norm calculation
+ norm_out = diff * inv_sqrt
+
+ norm_out = quantize(
+ norm_out,
+ q_config.get("data_out_width"),
+ q_config.get("data_out_frac_width"),
+ q_config.get("by_pass"),
+ )
+ if weight is not None:
+ qweight = quantize(
+ weight,
+ q_config.get("weight_width"),
+ q_config.get("weight_frac_width"),
+ q_config.get("by_pass"),
+ )
+ norm_out = norm_out * qweight
+ if bias is not None:
+ qbias = quantize(
+ bias,
+ q_config.get("bias_width"),
+ q_config.get("bias_frac_width"),
+ q_config.get("by_pass"),
+ )
+ norm_out = norm_out + qbias
+ norm_out = quantize(
+ norm_out,
+ q_config.get("data_out_width"),
+ q_config.get("data_out_frac_width"),
+ q_config.get("by_pass"),
+ )
+ return norm_out
+
+
+class IntLayerNormFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, input: torch.Tensor, normalized_shape, weight, bias, eps, config, bypass
+ ):
+ with torch.enable_grad():
+ layernormed = nn.functional.layer_norm(
+ input, normalized_shape, weight, bias, eps
+ )
+ ctx.save_for_backward(input, layernormed)
+ output = (
+ _int_layer_norm(input, normalized_shape, weight, bias, eps, config)
+ if not bypass
+ else layernormed
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ input, layernormed = ctx.saved_tensors
+ (grad_input,) = torch.autograd.grad(
+ layernormed, input, grad_outputs=grad_output
+ )
+ return grad_input, None, None, None, None, None, None
diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py
index d06eb1ece..8bf4c2300 100644
--- a/src/chop/nn/quantized/functional/matmul.py
+++ b/src/chop/nn/quantized/functional/matmul.py
@@ -28,7 +28,7 @@ def generic_matmul_integer(x, y, config, style="matmul", out_config=None, floor=
if bypass:
return matmul(x, y)
else:
- base_quantizer = integer_quantizer
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
y_width, y_frac_width = config["weight_width"], config["weight_frac_width"]
diff --git a/src/chop/nn/quantized/functional/softmax.py b/src/chop/nn/quantized/functional/softmax.py
new file mode 100644
index 000000000..f0a9d57c4
--- /dev/null
+++ b/src/chop/nn/quantized/functional/softmax.py
@@ -0,0 +1,34 @@
+from torch import nn
+import torch
+
+from chop.nn.quantizers import integer_quantizer, integer_floor_quantizer
+from math import ceil, log2
+
+
+def softmax_integer(x: torch.Tensor, dim: int, config: dict, floor=False):
+ """
+ This function defines the calculation process of hashsoftmax
+ Exp result is obtained from a hash table
+ All the data in this function will be quantized to fixed-point
+ """
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+ if config["mult_data"] != None:
+ mult = config["mult_data"]
+ else:
+ mult = 1
+ quant_x = base_quantizer(x, config["data_in_width"], config["data_in_frac_width"])
+ print("quant_x = ", quant_x * 2 ** config["data_in_frac_width"])
+ exp_x = (quant_x * mult).exp()
+ quant_exp = base_quantizer(
+ exp_x, config["data_in_exp_width"], config["data_in_exp_frac_width"]
+ )
+ print("quant_exp = ", quant_exp * 2 ** config["data_in_exp_frac_width"])
+ exp_sum = quant_exp.sum(dim=dim, keepdim=True)
+
+ shift_width = config["data_out_frac_width"]
+ if torch.all(quant_exp == exp_sum):
+ out = torch.tensor(1.0, device=x.device).expand(x.shape)
+ else:
+ out = quant_exp * (2 ** (shift_width)) // exp_sum
+ out = out / (2 ** (shift_width))
+ return out
diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py
index 4219d6da9..1b07f192e 100644
--- a/src/chop/nn/quantized/modules/__init__.py
+++ b/src/chop/nn/quantized/modules/__init__.py
@@ -1,5 +1,5 @@
-from .attention_head import BertSelfAttentionHeadInteger
-from .attention import BertSelfAttentionInteger
+from .attention_head import BertSelfAttentionHeadInteger, ViTSelfAttentionHeadInteger
+from .attention import BertSelfAttentionInteger, ViTAttentionInteger
# from .add import AddInteger
from .conv1d import (
@@ -32,6 +32,7 @@
LinearBlockFP,
LinearBlockMinifloat,
LinearInteger,
+ LinearIntegerFloor,
LinearLog,
LinearBlockLog,
LinearMinifloatDenorm,
@@ -43,6 +44,7 @@
LinearLUT,
LinearLogicNets,
LinearMXIntHardware,
+ # LinearMxInt,
)
from .pool2d import (
AdaptiveAvgPool2dInteger,
@@ -67,6 +69,7 @@
)
from .layer_norm import (
LayerNormInteger,
+ LayerNormIntegerFloor,
)
from .group_norm import GroupNormInteger
from .instance_norm2d import InstanceNorm2dInteger
@@ -113,6 +116,7 @@
GELUBlockFP,
GELUBlockMinifloat,
GELUInteger,
+ GELUIntegerFloor,
GELULog,
GELUBlockLog,
GELUMinifloatDenorm,
@@ -151,6 +155,8 @@
GroupedQueryAttentionInteger,
)
+# from mase_components.linear_layers.mxint_operators.test.utils import MXIntLinearHardware
+
quantized_module_map = {
"conv1d_block_minifloat": Conv1dBlockMinifloat,
"conv1d_integer": Conv1dInteger,
@@ -176,6 +182,7 @@
"linear_block_minifloat": LinearBlockMinifloat,
"linear_integer": LinearInteger,
"linear_fixed": LinearInteger,
+ "linear_integer_floor": LinearIntegerFloor,
"linear_log": LinearLog,
"linear_mxint_hardware": LinearMXIntHardware,
"linear_block_log": LinearBlockLog,
@@ -204,6 +211,7 @@
"batch_norm2d_integer": BatchNorm2dInteger,
"batch_norm2d_binary": BatchNorm2dBinary,
"layer_norm_integer": LayerNormInteger,
+ "layer_norm_integer_floor": LayerNormIntegerFloor,
"group_norm_integer": GroupNormInteger,
"instance_norm2d_integer": InstanceNorm2dInteger,
"rms_norm_integer": RMSNormInteger,
@@ -240,6 +248,7 @@
"gelu_block_minifloat": GELUBlockMinifloat,
"gelu_integer": GELUInteger,
"gelu_fixed": GELUInteger,
+ "gelu_integer_floor": GELUIntegerFloor,
"gelu_log": GELULog,
"gelu_block_log": GELUBlockLog,
"gelu_minifloat_ieee": GELUMinifloatIEEE,
@@ -271,5 +280,7 @@
"batch_norm1d_linear": BatchNorm1dInteger,
"bert_self_attention_head_integer": BertSelfAttentionHeadInteger,
"bert_self_attention_integer": BertSelfAttentionInteger,
+ "bert_self_attention_head_integer": ViTSelfAttentionHeadInteger,
+ "vit_self_attention_integer": ViTAttentionInteger,
"grouped_query_attention_integer": GroupedQueryAttentionInteger,
}
diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py
index 45819db75..2fda6273b 100644
--- a/src/chop/nn/quantized/modules/attention.py
+++ b/src/chop/nn/quantized/modules/attention.py
@@ -1,18 +1,21 @@
from functools import partial
import torch
+import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from transformers.models.bert.modeling_bert import BertSelfAttention
+from .attention_head import _ViTSelfAttentionHeadBase, ViTSelfAttentionHeadInteger
from chop.nn.quantized.modules.linear import (
LinearInteger,
)
from chop.nn.quantized.functional import fixed_softermax
+from chop.nn.quantizers import integer_quantizer
from chop.nn.quantized.functional import matmul_integer
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
class _BertSelfAttentionBase(BertSelfAttention):
@@ -56,6 +59,87 @@ def forward(
return out
+class _ViTAttentionBase(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.query = nn.Linear(dim, dim, bias=qkv_bias)
+ self.key = nn.Linear(dim, dim, bias=qkv_bias)
+ self.value = nn.Linear(dim, dim, bias=qkv_bias)
+ self.self_attention = _ViTSelfAttentionHeadBase(
+ dim=self.head_dim, num_heads=num_heads, attn_drop=attn_drop
+ )
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+
+ def _tensor_reshape(x):
+ return x.reshape(B, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+ q, k, v = (
+ _tensor_reshape(self.query(x)),
+ _tensor_reshape(self.key(x)),
+ _tensor_reshape(self.value(x)),
+ )
+ x = self.self_attention(q, k, v)
+ x = x.transpose(1, 2).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class _ViTAttentionBase_before(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.self_attention = _ViTSelfAttentionHeadBase(
+ dim=self.head_dim, num_heads=num_heads, attn_drop=attn_drop
+ )
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ x = self.self_attention(q, k, v)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
class BertSelfAttentionInteger(_BertSelfAttentionBase):
def __init__(
self,
@@ -105,11 +189,120 @@ def __init__(
self.matmul = partial(
matmul_integer,
config={
- "data_in_width": self.q_config["data_out_width"],
- "data_in_frac_width": self.q_config["data_out_frac_width"],
- "weight_width": self.q_config["data_out_width"],
- "weight_frac_width": self.q_config["data_out_frac_width"],
+ "data_in_width": self.q_config["data_in_width"],
+ "data_in_frac_width": self.q_config["data_in_frac_width"],
+ "weight_width": self.q_config["weight_width"],
+ "weight_frac_width": self.q_config["weight_frac_width"],
},
out_config=out_q_config,
floor=floor,
)
+
+
+class ViTAttentionInteger(_ViTAttentionBase):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ q_config: dict = None,
+ floor=True,
+ ) -> None:
+ super().__init__(dim, num_heads, qkv_bias, qk_norm, attn_drop, proj_drop)
+ self.q_config = q_config
+ self.query = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.key = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.value = LinearInteger(
+ dim,
+ dim,
+ bias=qkv_bias,
+ config={
+ "data_in_width": q_config["data_in_width"],
+ "data_in_frac_width": q_config["data_in_frac_width"],
+ "weight_width": q_config["qkv_weight_width"],
+ "weight_frac_width": q_config["qkv_weight_frac_width"],
+ "bias_width": q_config["qkv_bias_width"],
+ "bias_frac_width": q_config["qkv_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkv_width"],
+ "data_out_frac_width": q_config["qkv_frac_width"],
+ },
+ floor=floor,
+ )
+ self.self_attention = ViTSelfAttentionHeadInteger(
+ dim=self.head_dim,
+ num_heads=num_heads,
+ attn_drop=attn_drop,
+ q_config={
+ "query_width": q_config["qkv_width"],
+ "query_frac_width": q_config["qkv_frac_width"],
+ "key_width": q_config["qkv_width"],
+ "key_frac_width": q_config["qkv_frac_width"],
+ "value_width": q_config["qkv_width"],
+ "value_frac_width": q_config["qkv_frac_width"],
+ "qkmm_out_width": q_config["qkmm_out_width"],
+ "qkmm_out_frac_width": q_config["qkmm_out_frac_width"],
+ "softmax_exp_width": q_config["softmax_exp_width"],
+ "softmax_exp_frac_width": q_config["softmax_exp_frac_width"],
+ "softmax_out_frac_width": q_config["softmax_out_frac_width"],
+ "svmm_out_width": q_config["svmm_out_width"],
+ "svmm_out_frac_width": q_config["svmm_out_frac_width"],
+ },
+ floor=floor,
+ )
+ self.proj = LinearInteger(
+ dim,
+ dim,
+ config={
+ "data_in_width": q_config["svmm_out_width"],
+ "data_in_frac_width": q_config["svmm_out_frac_width"],
+ "weight_width": q_config["proj_weight_width"],
+ "weight_frac_width": q_config["proj_weight_frac_width"],
+ "bias_width": q_config["proj_bias_width"],
+ "bias_frac_width": q_config["proj_bias_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["data_out_width"],
+ "data_out_frac_width": q_config["data_out_frac_width"],
+ },
+ floor=floor,
+ )
diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py
index 8f9ea5969..7bf1a53af 100644
--- a/src/chop/nn/quantized/modules/attention_head.py
+++ b/src/chop/nn/quantized/modules/attention_head.py
@@ -9,7 +9,10 @@
from chop.nn.quantized.functional.matmul import (
generic_matmul_integer,
)
-from chop.nn.quantizers.integer import integer_quantizer
+from chop.nn.quantized.functional.softmax import (
+ softmax_integer,
+)
+from chop.nn.quantizers.integer import integer_quantizer, integer_floor_quantizer
class _BertSelfAttentionHeadBase(torch.nn.Module):
@@ -89,3 +92,110 @@ def forward(
value_layer=value_layer,
attention_mask=attention_mask,
)
+
+
+class _ViTSelfAttentionHeadBase(torch.nn.Module):
+ def __init__(self, dim, num_heads, attn_drop) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(attn_drop)
+
+ self.matmul1 = torch.matmul
+ self.matmul2 = torch.matmul
+ self.mult_data = torch.tensor(1 / math.sqrt(dim))
+ self.act = nn.functional.softmax
+
+ def self_attention_head(
+ self,
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ ) -> Tensor:
+ attention_scores = self.matmul1(query_layer, key_layer.transpose(-1, -2))
+ # (debug print of raw attention scores removed)
+ attention_scores = attention_scores * self.mult_data
+
+ # Normalize the attention scores to probabilities.
+ # (debug print of scaled attention scores removed)
+ attention_probs = self.act(attention_scores, dim=-1)
+ # (debug print of attention probabilities removed)
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+ context_layer = self.matmul2(attention_probs, value_layer)
+ # (debug print of value layer removed)
+ # (debug print of context layer removed)
+ return context_layer
+
+ def forward(
+ self,
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ ) -> Tensor:
+ return self.self_attention_head(
+ query_layer=query_layer, key_layer=key_layer, value_layer=value_layer
+ )
+
+
+class ViTSelfAttentionHeadInteger(_ViTSelfAttentionHeadBase):
+ def __init__(
+ self, dim, num_heads, attn_drop=0.0, q_config: dict = None, floor=False
+ ) -> None:
+ super().__init__(dim, num_heads, attn_drop)
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+ self.query_quantizer = partial(
+ base_quantizer,
+ width=q_config["query_width"],
+ frac_width=q_config["query_frac_width"],
+ )
+ self.key_quantizer = partial(
+ base_quantizer,
+ width=q_config["key_width"],
+ frac_width=q_config["key_frac_width"],
+ )
+ self.value_quantizer = partial(
+ base_quantizer,
+ width=q_config["value_width"],
+ frac_width=q_config["value_frac_width"],
+ )
+ self.matmul1 = partial(
+ generic_matmul_integer,
+ config={
+ "data_in_width": q_config["query_width"],
+ "data_in_frac_width": q_config["query_frac_width"],
+ "weight_width": q_config["key_width"],
+ "weight_frac_width": q_config["key_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["qkmm_out_width"],
+ "data_out_frac_width": q_config["qkmm_out_frac_width"],
+ },
+ floor=floor,
+ )
+ self.act = partial(
+ softmax_integer,
+ config={
+ "data_in_width": q_config["qkmm_out_width"],
+ "data_in_frac_width": q_config["qkmm_out_frac_width"],
+ "data_in_exp_width": q_config["softmax_exp_width"],
+ "data_in_exp_frac_width": q_config["softmax_exp_frac_width"],
+ "data_out_frac_width": q_config["softmax_out_frac_width"],
+ "mult_data": self.mult_data,
+ },
+ floor=floor,
+ )
+ self.mult_data = torch.tensor(1)
+ self.matmul2 = partial(
+ generic_matmul_integer,
+ config={
+ "data_in_width": q_config["softmax_out_frac_width"] + 2,
+ "data_in_frac_width": q_config["softmax_out_frac_width"],
+ "weight_width": q_config["value_width"],
+ "weight_frac_width": q_config["value_frac_width"],
+ },
+ out_config={
+ "data_out_width": q_config["svmm_out_width"],
+ "data_out_frac_width": q_config["svmm_out_frac_width"],
+ },
+ floor=floor,
+ )
diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py
index 074cf4df8..59096a099 100644
--- a/src/chop/nn/quantized/modules/gelu.py
+++ b/src/chop/nn/quantized/modules/gelu.py
@@ -11,6 +11,7 @@
block_log_quantizer,
block_minifloat_quantizer,
integer_quantizer,
+ integer_floor_quantizer,
log_quantizer,
minifloat_denorm_quantizer,
minifloat_ieee_quantizer,
@@ -25,13 +26,17 @@ def __init__(self, inplace: bool = False):
self.inplace = inplace
self.bypass = False
self.x_quantizer = None
+ self.out_quantizer = None
def forward(self, x: Tensor) -> Tensor:
if self.bypass:
return F.gelu(x)
else:
x = self.x_quantizer(x)
- return F.gelu(x)
+ out = F.gelu(x)
+ if self.out_quantizer is None:
+ return out
+ return self.out_quantizer(out)
def get_quantized_output(self, x: Tensor) -> Tensor:
x = self.x_quantizer(x)
@@ -58,11 +63,32 @@ def __init__(self, inplace: bool = False, config: dict = None):
self.x_width = x_width
self.x_frac_width = x_frac_width
- # def get_output_bitwidth(self) -> dict:
- # return {
- # "data_out_width": self.config["data_in_width"],
- # "data_out_frac_width": self.config["data_in_frac_width"],
- # }
+
+class GELUIntegerFloor(_GELUBase):
+ bypass = None
+
+ def __init__(self, inplace: bool = False, config: dict = None):
+ super().__init__(inplace)
+ assert config is not None, "config is None!"
+
+ self.config = config
+ self.bypass = config.get("bypass", False)
+ if self.bypass:
+ return
+ # establish quantizers
+ x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
+ out_width, out_frac_width = (
+ config["data_out_width"],
+ config["data_out_frac_width"],
+ )
+ self.x_quantizer = partial(
+ integer_floor_quantizer, width=x_width, frac_width=x_frac_width
+ )
+ self.out_quantizer = partial(
+ integer_floor_quantizer, width=out_width, frac_width=out_frac_width
+ )
+ self.x_width = x_width
+ self.x_frac_width = x_frac_width
class GELUMinifloatDenorm(_GELUBase):
diff --git a/src/chop/nn/quantized/modules/layer_norm.py b/src/chop/nn/quantized/modules/layer_norm.py
index 2ca5c6068..42829d77c 100644
--- a/src/chop/nn/quantized/modules/layer_norm.py
+++ b/src/chop/nn/quantized/modules/layer_norm.py
@@ -1,12 +1,11 @@
from functools import partial
-
+import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from ...quantizers import integer_quantizer
+from ..functional import IntLayerNormFunc
class _LayerNormBase(nn.LayerNorm):
@@ -47,7 +46,37 @@ def __init__(
self.bypass = config.get("bypass", False)
if self.bypass:
return
- x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
+ x_width, x_frac_width = config.get("data_in_width"), config.get(
+ "data_in_frac_width"
+ )
self.x_quantizer = partial(
integer_quantizer, width=x_width, frac_width=x_frac_width
)
+
+
+class LayerNormIntegerFloor(nn.LayerNorm):
+ def __init__(
+ self,
+ normalized_shape,
+ eps: float = 0.00001,
+ elementwise_affine: bool = False,
+ bias: bool = False,
+ device=None,
+ dtype=None,
+ config=None,
+ ) -> None:
+ assert config is not None, "config is None!"
+ super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
+ self.config = config
+ self.bypass = config.get("bypass", False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return IntLayerNormFunc.apply(
+ x,
+ self.normalized_shape,
+ self.weight,
+ self.bias,
+ self.eps,
+ self.config,
+ self.bypass,
+ )
diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py
index 0aaaea611..148478488 100644
--- a/src/chop/nn/quantized/modules/linear.py
+++ b/src/chop/nn/quantized/modules/linear.py
@@ -66,6 +66,7 @@ def forward(self, x: Tensor) -> Tensor:
x = self.x_quantizer(x)
w = self.w_quantizer(self.weight)
bias = self.b_quantizer(self.bias) if self.bias is not None else None
+ # (debug print of quantized weight tensor removed)
out = F.linear(x, w, bias)
if self.out_quantizer is None:
return out
@@ -96,6 +97,11 @@ def __init__(
x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
# check bias quantizer, if not, use weight quantizer
b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
+ if config.get("data_out_width") is not None:
+ out_width, out_frac_width = (
+ config["data_out_width"],
+ config["data_out_frac_width"],
+ )
if out_config is not None:
out_width, out_frac_width = (
out_config["data_out_width"],
@@ -117,6 +123,46 @@ def __init__(
)
+class LinearIntegerFloor(_LinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ config=None,
+ ) -> None:
+ super().__init__(in_features, out_features, bias, device, dtype)
+ assert config is not None, "config is None!"
+ self.config = config
+ self.bypass = config.get("bypass", False)
+ if self.bypass:
+ return
+ # establish quantizer
+ w_width, w_frac_width = config["weight_width"], config["weight_frac_width"]
+ x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
+ # check bias quantizer, if not, use weight quantizer
+ b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
+ out_width, out_frac_width = (
+ config["data_out_width"],
+ config["data_out_frac_width"],
+ )
+
+ self.w_quantizer = partial(
+ integer_floor_quantizer, width=w_width, frac_width=w_frac_width
+ )
+ self.x_quantizer = partial(
+ integer_floor_quantizer, width=x_width, frac_width=x_frac_width
+ )
+ self.b_quantizer = partial(
+ integer_floor_quantizer, width=b_width, frac_width=b_frac_width
+ )
+ self.out_quantizer = partial(
+ integer_floor_quantizer, width=out_width, frac_width=out_frac_width
+ )
+
+
class LinearMinifloatDenorm(_LinearBase):
def __init__(
self,
@@ -1028,6 +1074,91 @@ def forward(self, x: Tensor) -> Tensor:
return self.math_forward(x)
+class LinearMxInt(_LinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ config=None,
+ out_config=None,
+ ) -> None:
+ super().__init__(in_features, out_features, bias, device, dtype)
+ assert config is not None, "config is None!"
+ self.config = config
+ self.out_config = out_config
+ self.bypass = config.get("bypass", False)
+ if self.bypass:
+ return
+ # establish quantizer
+ w_width, w_exponent_width = (
+ config["weight_width"],
+ config["weight_exponent_width"],
+ )
+ w_p1, w_p0 = (
+ config["weight_parallelism"][0],
+ config["weight_parallelism"][1],
+ )
+ x_width, x_exponent_width = (
+ config["data_in_width"],
+ config["data_in_exponent_width"],
+ )
+ x_p1, x_p0 = (
+ config["data_in_parallelism"][0],
+ config["data_in_parallelism"][1],
+ )
+ # check bias quantizer, if not, use weight quantizer
+ b_width, b_exponent_width = config["bias_width"], config["bias_exponent_width"]
+ b_p1, b_p0 = config["bias_parallelism"][0], config["bias_parallelism"][1]
+ base_quantizer = mxint_hardware
+ if out_config is not None:
+ out_width, out_exponent_width = (
+ config["data_out_width"],
+ config["data_out_exponent_width"],
+ )
+ out_p1, out_p0 = (
+ config["data_out_parallelism_dim_1"],
+ config["data_out_parallelism_dim_0"],
+ )
+ self.out_quantizer = partial(
+ base_quantizer,
+ q_config={"width": out_width, "exponent_width": out_exponent_width},
+ parallelism=[out_p1, out_p0],
+ )
+ self.w_quantizer = partial(
+ base_quantizer,
+ q_config={"width": w_width, "exponent_width": w_exponent_width},
+ parallelism=[w_p1, w_p0],
+ )
+ self.x_quantizer = partial(
+ base_quantizer,
+ q_config={"width": x_width, "exponent_width": x_exponent_width},
+ parallelism=[x_p1, x_p0],
+ )
+ self.b_quantizer = partial(
+ base_quantizer,
+ q_config={"width": b_width, "exponent_width": b_exponent_width},
+ parallelism=[b_p1, b_p0],
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.bypass:
+ return F.linear(x, self.weight, self.bias)
+ else:
+ x = self.x_quantizer(x)
+ w = self.w_quantizer(self.weight)
+ if self.bias is not None:
+ bias = self.b_quantizer(self.bias)
+ else:
+ bias = None
+ out = F.linear(x, w, bias)
+ if self.out_quantizer is None:
+ return out
+ return self.out_quantizer(out)
+
+
class LinearMXIntHardware(_LinearBase):
def __init__(
self,
diff --git a/src/chop/nn/quantized/modules/mxint_modules.py b/src/chop/nn/quantized/modules/mxint_modules.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py
index 0c3e06130..e772cd6e9 100644
--- a/src/chop/nn/quantizers/mxint_hardware.py
+++ b/src/chop/nn/quantizers/mxint_hardware.py
@@ -24,14 +24,14 @@ def mxint_quant_block(
# exponent
if exponent == None:
- exponent = torch.ceil(torch.log2(x.abs().max())) - exponent_bias
+ exponent = torch.ceil(torch.log2(x.abs().max()))
exponent = torch.clamp(exponent, exponent_min, exponent_max)
# mantissa
int_min = -(2 ** (width - 1))
int_max = 2 ** (width - 1) - 1
- mantissa = x / 2**exponent
+ mantissa = x * (2 ** (width - 1)) / 2**exponent
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- q_x = (2**exponent) * mantissa
+ q_x = (2**exponent) * mantissa / ((2 ** (width - 1)))
return q_x
diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py
index d5ca3d8cf..4a90e2a38 100644
--- a/src/chop/nn/quantizers/quantizers_for_hw.py
+++ b/src/chop/nn/quantizers/quantizers_for_hw.py
@@ -3,15 +3,15 @@
import torch.nn.functional as F
from torch import Tensor
-# from .quantizers import integer_quantizer
+from .integer import integer_quantizer, integer_floor_quantizer
from .utils import block, my_clamp, my_round, unblock, my_floor
-def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
+def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int, floor=False):
thresh = 2 ** (width - 1)
scale = 2**frac_width
-
- fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1)
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+ fixed_point_value = base_quantizer(x, width, frac_width) * scale
fixed_point_value = fixed_point_value.to(torch.int)
fixed_point_value = fixed_point_value % (2**width)
return fixed_point_value
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
index 346e486e9..3c1ff2efc 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
@@ -125,13 +125,18 @@ def graph_iterator_for_mase_ops(graph):
elif isinstance(module, GroupedQueryAttention):
mase_op = "grouped_query_attention"
else:
- mase_op = None
- for module_cls in graph.model.custom_ops["modules"].keys():
- if isinstance(module, module_cls):
- mase_op = "user_defined_module"
- break
- if mase_op is None:
- raise ValueError(f"Unknown module: {module_name}")
+ from chop.nn.quantized import ViTAttentionInteger
+
+ if isinstance(module, ViTAttentionInteger):
+ mase_op = "vit_self_attention_integer"
+ else:
+ mase_op = None
+ for module_cls in graph.model.custom_ops["modules"].keys():
+ if isinstance(module, module_cls):
+ mase_op = "user_defined_module"
+ break
+ if mase_op is None:
+ raise ValueError(f"Unknown module: {module_name}")
node.meta["mase"].parameters["common"]["mase_type"] = mase_type
node.meta["mase"].parameters["common"]["mase_op"] = mase_op
@@ -252,6 +257,8 @@ def graph_iterator_for_metadata(
# node.shape = result.shape
# node.dtype = result.dtype
+ # print(node.op, node.name, result)
+ # breakpoint()
node.meta["mase"] = analyse_fn(
node.meta["mase"], result, args, kwargs, add_value=add_value
)
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
index 0e2d315ae..d4cb96ac1 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
@@ -47,13 +47,23 @@ def add_component_source(node):
node.meta["mase"]["hardware"]["dependence_files"] = op_info[
"dependence_files"
]
- elif mase_op in INTERNAL_COMP.keys():
- node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
- # take the first ip in the component list by default
- node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"]
- node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][0][
- "dependence_files"
- ]
+ elif mase_op in INTERNAL_COMP or mase_op + "_mxint_hardware" in INTERNAL_COMP:
+ if node.meta["mase"].parameters["common"]["quant_type"] == "mxint_hardware":
+ node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
+ # take the first ip in the component list by default
+ node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[
+ mase_op + "_mxint_hardware"
+ ][0]["name"]
+ node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[
+ mase_op + "_mxint_hardware"
+ ][0]["dependence_files"]
+ else:
+ node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
+ # take the first ip in the component list by default
+ node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"]
+ node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][
+ 0
+ ]["dependence_files"]
else:
node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_HLS"
node.meta["mase"]["hardware"]["module"] = None
@@ -96,7 +106,12 @@ def add_verilog_param(node):
else 1
)
# Check if max parallelism is defined
- if node.meta["mase"]["hardware"]["max_parallelism"] is not None:
+ if arg_info.get("parallelism") is not None:
+ # parallelism only support the last 2 dimension
+ vp[_cap(arg + f"_parallelism_dim_{dim}")] = (
+ arg_info["parallelism"][::-1][dim] if dim <= 1 else 1
+ )
+ elif node.meta["mase"]["hardware"]["max_parallelism"] is not None:
# Take the minimum between...
vp[_cap(arg + f"_parallelism_dim_{dim}")] = min(
# The defined max parallelism for this dimension
@@ -125,7 +140,12 @@ def add_verilog_param(node):
else 1
)
# Check if max parallelism is defined
- if node.meta["mase"]["hardware"]["max_parallelism"] is not None:
+ if result_info.get("parallelism") is not None:
+ # parallelism only support the last 2 dimension
+ vp[_cap(result + f"_parallelism_dim_{dim}")] = (
+ result_info["parallelism"][::-1][dim] if dim <= 1 else 1
+ )
+ elif node.meta["mase"]["hardware"]["max_parallelism"] is not None:
# Take the minimum between...
vp[_cap(result + f"_parallelism_dim_{dim}")] = min(
# The defined max parallelism for this dimension
diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
index 961e514f8..601194b86 100644
--- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
@@ -272,6 +272,7 @@
"elu": {"input": "data_in"},
"softmax": {"input": "data_in"},
"gelu": {"input": "data_in"},
+ "vit_self_attention_integer": {"input": "data_in"},
"grouped_query_attention": {"input": "data_in"},
}
@@ -387,7 +388,6 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value):
ordered_func_data = [(k, v) for k, v in data.items()]
meta.parameters["common"]["args"] = {}
meta_kwargs = {}
-
arg_type, arg_precision = get_type_and_precision(meta)
# * Assign metadata for each argument
diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
index 778a08001..f6b228ab1 100644
--- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
@@ -30,28 +30,84 @@
"normalization_layers/rtl/rms_norm_2d.sv",
"normalization_layers/rtl/batch_norm_2d.sv",
"normalization_layers/rtl/norm.sv",
+ "normalization_layers/rtl/layer_norm_1d.sv",
],
}
-
+linear = {
+ "name": "fixed_linear_with_input_circular",
+ "dependence_files": [
+ "cast/rtl/fixed_round.sv",
+ "cast/rtl/fixed_rounding.sv",
+ "cast/rtl/floor_round.sv",
+ "cast/rtl/signed_clamp.sv",
+ "cast/rtl/fixed_signed_cast.sv",
+ "linear_layers/fixed_operators/rtl/fixed_dot_product.sv",
+ "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv",
+ "linear_layers/fixed_operators/rtl/fixed_accumulator.sv",
+ "linear_layers/fixed_operators/rtl/fixed_adder_tree.sv",
+ "linear_layers/fixed_operators/rtl/fixed_adder_tree_layer.sv",
+ "linear_layers/fixed_operators/rtl/fixed_mult.sv",
+ "common/rtl/register_slice.sv",
+ "common/rtl/join2.sv",
+ "common/rtl/mux.sv",
+ "common/rtl/unpacked_register_slice.sv",
+ "common/rtl/single_element_repeat.sv",
+ "memory/rtl/unpacked_repeat_circular_buffer.sv",
+ "memory/rtl/input_buffer.sv",
+ "memory/rtl/blk_mem_gen_0.sv",
+ "memory/rtl/simple_dual_port_ram.sv",
+ "linear_layers/fixed_linear_layer/rtl/fixed_linear_with_input_circular.sv",
+ "memory/rtl/fifo_for_autogen.sv",
+ "memory/rtl/unpacked_fifo.sv",
+ "memory/rtl/skid_buffer.sv",
+ "memory/rtl/unpacked_skid_buffer.sv",
+ "memory/rtl/simple_dual_port_ram.sv",
+ "memory/rtl/fifo.sv",
+ ],
+}
+unpacked_mx_split2_with_data = [
+ "linear_layers/mxint_operators/rtl/unpacked_mx_split2_with_data.sv",
+ "common/rtl/split2_with_data.sv",
+ "common/rtl/split2.sv",
+ "memory/rtl/fifo.sv",
+]
+mxint_cast = [
+ "linear_layers/mxint_operators/rtl/or_tree_layer.sv",
+ "linear_layers/mxint_operators/rtl/or_tree.sv",
+ "linear_layers/mxint_operators/rtl/log2_max_abs.sv",
+ "linear_layers/mxint_operators/rtl/mxint_cast.sv",
+ "linear_layers/mxint_operators/rtl/optimized_right_shift.sv"
+]
+mxint_linear = linear["dependence_files"] + unpacked_mx_split2_with_data + mxint_cast + [
+ "linear_layers/mxint_operators/rtl/mxint_linear.sv",
+ "linear_layers/mxint_operators/rtl/mxint_register_slice.sv",
+ "linear_layers/mxint_operators/rtl/mxint_skid_buffer.sv",
+ "linear_layers/mxint_operators/rtl/mxint_straightm_fifoe.sv",
+ "linear_layers/mxint_operators/rtl/mxint_accumulator.sv",
+ "linear_layers/mxint_operators/rtl/mxint_circular.sv",
+ "linear_layers/mxint_operators/rtl/mxint_dot_product.sv",
+ "linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv",
+ "common/rtl/join_n.sv",
+ ]
INTERNAL_COMP = {
- "linear": [
+ "linear": [linear],
+ "linear_mxint_hardware": [
+ {
+ "name": "mxint_linear",
+ "dependence_files": mxint_linear
+ }
+ ],
+ "fifo": [
{
- "name": "fixed_linear",
+ "name": "fifo_for_autogen",
"dependence_files": [
- "cast/rtl/fixed_cast.sv",
- "linear_layers/fixed_operators/rtl/fixed_dot_product.sv",
- "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv",
- "linear_layers/fixed_operators/rtl/fixed_accumulator.sv",
- "linear_layers/fixed_operators/rtl/fixed_adder_tree.sv",
- "linear_layers/fixed_operators/rtl/fixed_adder_tree_layer.sv",
- "linear_layers/fixed_operators/rtl/fixed_mult.sv",
- "common/rtl/register_slice.sv",
- "common/rtl/join2.sv",
- "memory/rtl/unpacked_repeat_circular_buffer.sv",
+ "memory/rtl/fifo_for_autogen.sv",
+ "memory/rtl/unpacked_fifo.sv",
"memory/rtl/skid_buffer.sv",
- "linear_layers/fixed_linear_layer/rtl/fixed_linear.sv",
+ "memory/rtl/simple_dual_port_ram.sv",
+ "memory/rtl/fifo.sv",
],
- },
+ }
],
"relu": [
{
@@ -123,10 +179,19 @@
}
],
"batch_norm2d": [norm],
- "layer_norm": [norm],
"group_norm": [norm],
"instance_norm2d": [norm],
"rms_norm": [norm],
+ "layer_norm": [
+ {
+ "name": "layer_norm_2d",
+ "dependence_files": norm["dependence_files"]
+ + [
+ "normalization_layers/rtl/layer_norm_2d.sv",
+ "generated_lut/rtl/isqrt_lut.sv",
+ ],
+ },
+ ],
"selu": [
{
"name": "fixed_selu",
@@ -148,7 +213,62 @@
"name": "fixed_gelu",
"dependence_files": [
"activation_layers/rtl/fixed_gelu.sv",
- "activation_layers/rtl/gelu_lut.sv",
+ "generated_lut/rtl/gelu_lut.sv",
+ "common/rtl/unpacked_register_slice_quick.sv",
+ ],
+ },
+ ],
+ "gelu_mxint_hardware": [
+ {
+ "name": "mxint_gelu",
+ "dependence_files": [
+ "linear_layers/mxint_operators/rtl/mxint_gelu.sv",
+ "generated_lut/rtl/gelu_lut.sv",
+ "linear_layers/mxint_operators/rtl/mxint_register_slice.sv",
+ "linear_layers/mxint_operators/rtl/or_tree_layer.sv",
+ "linear_layers/mxint_operators/rtl/or_tree.sv",
+ "linear_layers/mxint_operators/rtl/log2_max_abs.sv",
+ "linear_layers/mxint_operators/rtl/mxint_accumulator.sv",
+ "linear_layers/mxint_operators/rtl/mxint_cast.sv",
+ "linear_layers/mxint_operators/rtl/mxint_circular.sv",
+ "linear_layers/mxint_operators/rtl/mxint_dot_product.sv",
+ "linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv",
+ "common/rtl/unpacked_register_slice_quick.sv",
+ ],
+ },
+ ],
+ "mx_int_patch_embed_mxint_hardware": [
+ {
+ "name": "mxint_patch_embed",
+ "dependence_files": mxint_linear +
+ [
+ "linear_layers/mxint_operators/rtl/mxint_patch_embed.sv",
+ "convolution_layers/rtl/sliding_window.sv",
+ "convolution_layers/rtl/padding.sv",
+ "convolution_layers/rtl/roller.sv",
+ ]
+ }
+ ],
+ "layer_norm_mxint_hardware": [
+ {
+ "name": "mxint_layernorm",
+ "dependence_files": norm["dependence_files"]
+ + [
+ "linear_layers/mxint_operators/rtl/mxint_layernorm.sv",
+ "linear_layers/mxint_operators/rtl/mxint_gelu.sv",
+ "generated_lut/rtl/isqrt_lut.sv",
+ "generated_lut/rtl/gelu_lut.sv",
+ "linear_layers/mxint_operators/rtl/mxint_register_slice.sv",
+ "linear_layers/mxint_operators/rtl/or_tree_layer.sv",
+ "linear_layers/mxint_operators/rtl/or_tree.sv",
+ "linear_layers/mxint_operators/rtl/log2_max_abs.sv",
+ "linear_layers/mxint_operators/rtl/mxint_accumulator.sv",
+ "linear_layers/mxint_operators/rtl/mxint_cast.sv",
+ "linear_layers/mxint_operators/rtl/mxint_circular.sv",
+ "linear_layers/mxint_operators/rtl/mxint_dot_product.sv",
+ "linear_layers/mxint_operators/rtl/unpacked_mx_fifo.sv",
+ "common/rtl/unpacked_register_slice_quick.sv",
+
],
},
],
@@ -177,6 +297,14 @@
],
}
],
+ "add_mxint_hardware": [
+ {
+ "name": "mxint_addition",
+ "dependence_files": [
+ "linear_layers/mxint_operators/rtl/mxint_addition.sv",
+ ],
+ },
+ ],
"mul": [
{
"name": "fixed_elementwise_multiplier",
@@ -191,6 +319,18 @@
"dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"],
}
],
+ "fork2": [
+ {
+ "name": "fork2",
+ "dependence_files": ["common/rtl/fork2.sv"],
+ }
+ ],
+ "fork2_mxint_hardware": [
+ {
+ "name": "mxint_fork2",
+ "dependence_files": ["linear_layers/mxint_operators/rtl/mxint_fork2.sv"],
+ }
+ ],
"getitem": [
{
"name": "buffer",
@@ -199,6 +339,30 @@
],
}
],
+ "vit_self_attention_integer": [
+ {
+ "name": "fixed_vit_attention_single_precision_wrapper",
+ "dependence_files": linear["dependence_files"]
+ + [
+ "vision_models/vit/rtl/fixed_vit_attention_single_precision_wrapper.sv",
+ "vision_models/vit/rtl/fixed_vit_attention.sv",
+ "vision_models/vit/rtl/fixed_vit_attention_head.sv",
+ "transformer_layers/rtl/self_attention_head_single_scatter.sv",
+ "transformer_layers/rtl/gqa_head_scatter_control.sv",
+ "transformer_layers/rtl/self_attention_head_gather.sv",
+ "vision_models/vit/rtl/fixed_vit_attention_input_block_batched.sv",
+ "transformer_layers/rtl/self_attention_head_scatter.sv",
+ "activation_layers/rtl/fixed_softmax.sv",
+ "scalar_operators/fixed/rtl/fixed_div.sv",
+ "generated_lut/rtl/exp_lut.sv",
+ "common/rtl/find_first_arbiter.sv",
+ "common/rtl/split2.sv",
+ "common/rtl/split_n.sv",
+ "memory/rtl/unpacked_fifo.sv",
+ "memory/rtl/unpacked_skid_buffer.sv",
+ ],
+ }
+ ],
"grouped_query_attention": [
{
"name": "fixed_gqa_wrapper",
diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py
index 612773262..e5a220a3f 100644
--- a/src/chop/passes/graph/transforms/__init__.py
+++ b/src/chop/passes/graph/transforms/__init__.py
@@ -8,6 +8,7 @@
emit_cocotb_transform_pass,
emit_verilog_top_transform_pass,
emit_vivado_project_transform_pass,
+ insert_fork_transform_pass,
)
from .utils import (
conv_bn_fusion_transform_pass,
diff --git a/src/chop/passes/graph/transforms/quantize/modify.py b/src/chop/passes/graph/transforms/quantize/modify.py
index 4ea1145f4..088fd3ade 100644
--- a/src/chop/passes/graph/transforms/quantize/modify.py
+++ b/src/chop/passes/graph/transforms/quantize/modify.py
@@ -164,7 +164,7 @@ def create_new_module(
new_module = new_module_cls(config=config)
elif mase_op == "gelu":
new_module_cls = quantized_module_map[f"gelu_{quant_name}"]
- new_module = new_module_cls(inplace=original_module.inplace, config=config)
+ new_module = new_module_cls(config=config)
elif mase_op == "softsign":
new_module_cls = quantized_module_map[f"softsign_{quant_name}"]
new_module = new_module_cls(inplace=original_module.inplace, config=config)
@@ -203,13 +203,17 @@ def create_new_module(
copy_weights(original_module.bias, new_module.bias)
elif mase_op == "layer_norm":
new_module_cls = quantized_module_map[f"layer_norm_{quant_name}"]
+
new_module = new_module_cls(
normalized_shape=original_module.normalized_shape,
eps=original_module.eps,
elementwise_affine=original_module.elementwise_affine,
- bias=original_module.bias,
config=config,
)
+ if original_module.elementwise_affine:
+ new_module.weight = original_module.weight
+ if original_module.bias is not None:
+ new_module.bias = original_module.bias
elif mase_op == "group_norm":
new_module_cls = quantized_module_map[f"group_norm_{quant_name}"]
new_module = new_module_cls(
diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
index e027b0819..70b997114 100644
--- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
+++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
@@ -23,10 +23,17 @@
"data_in_entries": ("data_in_width", "data_in_frac_width"),
"bias_entries": ("bias_width", "bias_frac_width"),
},
+ "integer_floor": {
+ "weight_entries": ("weight_width", "weight_frac_width"),
+ "data_in_entries": ("data_in_width", "data_in_frac_width"),
+ "bias_entries": ("bias_width", "bias_frac_width"),
+ "data_out_entries": ("data_out_width", "data_out_frac_width"),
+ },
"fixed": {
"weight_entries": ("weight_width", "weight_frac_width"),
"data_in_entries": ("data_in_width", "data_in_frac_width"),
"bias_entries": ("bias_width", "bias_frac_width"),
+ "data_out_entries": ("data_out_width", "data_out_frac_width"),
},
"lutnet": {
"weight_entries": (
@@ -261,6 +268,11 @@
"bias_exponent_width",
"bias_parallelism",
),
+ "data_out_entries": (
+ "data_out_width",
+ "data_out_exponent_width",
+ "data_out_parallelism",
+ ),
},
}
@@ -278,6 +290,10 @@ def cp_bypass(config: dict, p_config: dict, entries=None, strict: bool = True):
cp_multi_values(config, p_config, ("bypass",), strict=strict)
+def cp_floor(config: dict, p_config: dict, entries=None, strict: bool = True):
+ cp_multi_values(config, p_config, ("floor",), strict=strict)
+
+
def cp_weight_entries(config: dict, p_config: dict, entries: dict, strict: bool = True):
cp_multi_values(config, p_config, entries["weight_entries"], strict=strict)
@@ -339,6 +355,7 @@ def cp_data_out_entries(
QUANT_ARITH_TO_CP_FN[quant_arith] = {
"name": partial(cp_name, entries=entries),
"bypass": partial(cp_bypass, entries=entries),
+ "floor": partial(cp_floor, entries=entries),
"weight_entries": partial(cp_weight_entries, entries=entries),
"data_in_entries": partial(cp_data_in_entries, entries=entries),
"bias_entries": partial(cp_bias_entries, entries=entries),
@@ -366,12 +383,18 @@ def cp_data_out_entries(
"mul": (("name", "data_in_entries"), ("bypass",)),
"linear": (
("name", "data_in_entries", "weight_entries"),
- ("bias_entries", "bypass", "data_out_entries", "additional_layers_entries"),
+ (
+ "bias_entries",
+ "bypass",
+ "data_out_entries",
+ "additional_layers_entries",
+ "floor",
+ ),
),
"relu": (("name", "data_in_entries"), ("bypass",)),
"selu": (("name", "data_in_entries"), ("bypass",)),
"tanh": (("name", "data_in_entries"), ("bypass",)),
- "gelu": (("name", "data_in_entries"), ("bypass",)),
+ "gelu": (("name", "data_in_entries"), ("data_out_entries", "bypass")),
"softplus": (("name", "data_in_entries"), ("bypass",)),
"softsign": (("name", "data_in_entries"), ("bypass",)),
"sub": (("name", "data_in_entries"), ("bypass",)),
@@ -385,7 +408,7 @@ def cp_data_out_entries(
),
"layer_norm": (
("name", "data_in_entries"),
- ("bypass",),
+ ("bypass", "isqrt_in_entries", "isqrt_out_entries", "data_out_entries"),
),
"group_norm": (
("name", "data_in_entries"),
@@ -423,6 +446,8 @@ def parse_node_config(config: dict, mase_op: str, strict: bool = True) -> dict:
a missing `bias_frac_width` in linear node config
"""
assert mase_op in MASE_OP_TO_ENTRIES, f"Unknown mase op: {mase_op}"
+ if config.get("noparse", False):
+ return config
if config.get("bypass", False):
return config
op_entries, op_optional_entries = MASE_OP_TO_ENTRIES[mase_op]
diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
index 0c580c4f2..7161006b6 100644
--- a/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
+++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/update_node_meta.py
@@ -9,6 +9,7 @@ def entry_to_list(config: dict, entry: str, suffixes: tuple[str]):
QUANT_ARITH_TO_SUFFIXES = {
"integer": ("width", "frac_width"),
"fixed": ("width", "frac_width"),
+ "integer_floor": ("width", "frac_width"),
"binary": (
"width",
"stochastic",
@@ -69,7 +70,10 @@ def update_arg(node, arg_name, dtype=None, precision=None, size=None):
"softplus": (("data_in",), ("data_in_0",)),
"sub": (("data_in", "data_in"), ("data_in_0", "data_in_1")),
"batch_norm2d": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")),
- "layer_norm": (("data_in",), ("data_in_0",)),
+ "layer_norm": (("data_in", "weight", "bias"), ("data_in_0", "weight", "bias")),
+ "group_norm": (("data_in",), ("data_in_0")),
+ "instance_norm2d": (("data_in",), ("data_in_0")),
+ "rms_norm": (("data_in",), ("data_in_0")),
"group_norm": (("data_in",), ("data_in_0",)),
"instance_norm2d": (("data_in",), ("data_in_0",)),
"rms_norm": (("data_in",), ("data_in_0",)),
diff --git a/src/chop/passes/graph/transforms/quantize/quantize.py b/src/chop/passes/graph/transforms/quantize/quantize.py
index e3682fcdc..72f363b01 100644
--- a/src/chop/passes/graph/transforms/quantize/quantize.py
+++ b/src/chop/passes/graph/transforms/quantize/quantize.py
@@ -230,9 +230,15 @@ def quantize_transform_pass(graph, pass_args=None):
# weight
"weight_width": 8,
"weight_frac_width": 4,
+
+ # optional
# bias
"bias_width": 8,
"bias_frac_width": 4,
+ "data_out_width": 8,
+ "data_out_frac_width": 4,
+ # quantize method
+ "floor": True,
}
},
}
@@ -246,7 +252,7 @@ def quantize_transform_pass(graph, pass_args=None):
- by -> str : different quantization schemes choose from ["type", "name", "regx_name"]
"""
- by = pass_args.pop("by")
+ by = pass_args.get("by")
match by:
case "type":
graph = graph_iterator_quantize_by_type(graph, pass_args)
diff --git a/src/chop/passes/graph/transforms/verilog/__init__.py b/src/chop/passes/graph/transforms/verilog/__init__.py
index 262e7905f..573fdadc3 100644
--- a/src/chop/passes/graph/transforms/verilog/__init__.py
+++ b/src/chop/passes/graph/transforms/verilog/__init__.py
@@ -5,3 +5,4 @@
from .emit_internal import emit_internal_rtl_transform_pass
from .emit_logicnets import emit_logicnets_transform_pass
from .emit_vivado_project import emit_vivado_project_transform_pass
+from .insert_fork import insert_fork_transform_pass
diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py
index 8aeeb663f..eebe182cd 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_bram.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py
@@ -28,6 +28,168 @@ def _cap(name):
return str(name).upper()
+def emit_mxint_parameters_in_mem_internal(node, param_name, file_name, data_name):
+ """
+ Emit single-port ROM hardware components for each parameter
+ (Mostly because Vivado does not support string type parameters...)
+ """
+ # ! TO DO: currently emitting too many parameters
+
+ verilog_param_name = param_name.replace(".", "_")
+ total_size = math.prod(
+ node.meta["mase"].parameters["common"]["args"][verilog_param_name]["shape"]
+ )
+ # edata is packed alongside the mdata words, so the ROM word width is out_size * out_width + out_exponent_width
+ out_size = int(
+ node.meta["mase"].parameters["hardware"]["verilog_param"][
+ f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"
+ ]
+ * node.meta["mase"].parameters["hardware"]["verilog_param"][
+ f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"
+ ]
+ )
+ out_depth = int(total_size / out_size)
+ out_width = int(
+ node.meta["mase"].parameters["common"]["args"][verilog_param_name]["precision"][
+ 0
+ ]
+ )
+ out_exponent_width = int(
+ node.meta["mase"].parameters["common"]["args"][verilog_param_name]["precision"][
+ 1
+ ]
+ )
+
+ addr_width = clog2(out_depth) + 1
+
+ node_param_name = f"{vf(node.name)}_{verilog_param_name}"
+
+ rom_verilog = f"""
+// =====================================
+// Mase Hardware
+// Parameter: {node_param_name}
+// {time.strftime('%d/%m/%Y %H:%M:%S')}
+// =====================================
+
+`timescale 1 ns / 1 ps
+module {node_param_name}_rom #(
+ parameter DWIDTH = {out_size*out_width + out_exponent_width},
+ parameter MEM_SIZE = {out_depth},
+ parameter AWIDTH = $clog2(MEM_SIZE) + 1
+) (
+ input clk,
+ input logic [AWIDTH-1:0] addr0,
+ input ce0,
+ output logic [DWIDTH-1:0] q0
+);
+
+ logic [DWIDTH-1:0] ram[0:MEM_SIZE-1];
+ logic [DWIDTH-1:0] q0_t0;
+ logic [DWIDTH-1:0] q0_t1;
+
+ initial begin
+ $readmemb("{data_name}", ram);
+ end
+
+ assign q0 = q0_t1;
+
+ always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0;
+ always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0];
+
+endmodule
+
+`timescale 1 ns / 1 ps
+module {node_param_name} #(
+ parameter DATA_WIDTH = 32'd{out_width*out_size + out_exponent_width},
+ parameter ADDR_RANGE = 32'd{out_depth},
+ parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1
+) (
+ input reset,
+ input clk,
+ input logic [ADDR_WIDTH - 1:0] address0,
+ input ce0,
+ output logic [DATA_WIDTH - 1:0] q0
+);
+
+ {node_param_name}_rom {node_param_name}_rom_U (
+ .clk(clk),
+ .addr0(address0),
+ .ce0(ce0),
+ .q0(q0)
+ );
+
+endmodule
+
+
+`timescale 1ns / 1ps
+module {node_param_name}_source #(
+ parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = -1,
+ parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = -1,
+ parameter {_cap(verilog_param_name)}_PRECISION_0 = -1,
+ parameter {_cap(verilog_param_name)}_PRECISION_1 = -1,
+
+ parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = -1,
+ parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = -1,
+ parameter OUT_DEPTH = ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 / {_cap(verilog_param_name)}_PARALLELISM_DIM_0) * ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 / {_cap(verilog_param_name)}_PARALLELISM_DIM_1)
+) (
+ input clk,
+ input rst,
+
+ output logic [{_cap(verilog_param_name)}_PRECISION_0-1:0] mdata_out [{_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1-1:0],
+ output logic [{_cap(verilog_param_name)}_PRECISION_1-1:0] edata_out,
+ output data_out_valid,
+ input data_out_ready
+);
+ // 1-bit wider so OUT_DEPTH also fits.
+ localparam COUNTER_WIDTH = $clog2(OUT_DEPTH);
+ logic [COUNTER_WIDTH:0] counter;
+ always_ff @(posedge clk)
+ if (rst) counter <= 0;
+ else begin
+ if (data_out_ready) begin
+ if (counter == OUT_DEPTH - 1) counter <= 0;
+ else counter <= counter + 1;
+ end
+ end
+ logic [1:0] clear;
+ always_ff @(posedge clk)
+ if (rst) clear <= 0;
+ else if ((data_out_ready == 1) && (clear != 2)) clear <= clear + 1;
+ logic ce0;
+ assign ce0 = data_out_ready;
+
+ localparam TOTAL_WIDTH = {_cap(verilog_param_name)}_PRECISION_0*({_cap(verilog_param_name)}_PARALLELISM_DIM_0*{_cap(verilog_param_name)}_PARALLELISM_DIM_1) + {_cap(verilog_param_name)}_PRECISION_1;
+ logic [TOTAL_WIDTH-1:0] data_vector;
+ {node_param_name} #(
+ .DATA_WIDTH(TOTAL_WIDTH),
+ .ADDR_RANGE(OUT_DEPTH)
+ ) {node_param_name}_mem (
+ .clk(clk),
+ .reset(rst),
+ .address0(counter),
+ .ce0(ce0),
+ .q0(data_vector)
+ );
+
+ // Cocotb/verilator does not support array flattening, so
+ // we need to manually add some reshaping process.
+ for (genvar j = 0; j < {_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1; j++)
+ assign mdata_out[j] = data_vector[{_cap(verilog_param_name)}_PRECISION_0*j+{_cap(verilog_param_name)}_PRECISION_0-1 + {_cap(verilog_param_name)}_PRECISION_1:{_cap(verilog_param_name)}_PRECISION_0*j + {_cap(verilog_param_name)}_PRECISION_1];
+ assign edata_out = data_vector[{_cap(verilog_param_name)}_PRECISION_1-1 : 0];
+ assign data_out_valid = clear == 2;
+
+endmodule
+"""
+
+ with open(file_name, "w", encoding="utf-8") as outf:
+ outf.write(rom_verilog)
+ logger.debug(
+ f"ROM module {verilog_param_name} successfully written into {file_name}"
+ )
+ assert os.path.isfile(file_name), "ROM Verilog generation failed."
+ # os.system(f"verible-verilog-format --inplace {file_name}")
+
+
def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
"""
Emit single-port ROM hardware components for each parameter
@@ -84,7 +246,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
logic [DWIDTH-1:0] q0_t1;
initial begin
- $readmemh("{data_name}", ram);
+ $readmemb("{data_name}", ram);
end
assign q0 = q0_t1;
@@ -119,14 +281,14 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
`timescale 1ns / 1ps
module {node_param_name}_source #(
- parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = 32,
- parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = 1,
- parameter {_cap(verilog_param_name)}_PRECISION_0 = 16,
- parameter {_cap(verilog_param_name)}_PRECISION_1 = 3,
-
- parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = 1,
- parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = 1,
- parameter OUT_DEPTH = {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 / {_cap(verilog_param_name)}_PARALLELISM_DIM_0
+ parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = -1,
+ parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = -1,
+ parameter {_cap(verilog_param_name)}_PRECISION_0 = -1,
+ parameter {_cap(verilog_param_name)}_PRECISION_1 = -1,
+
+ parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = -1,
+ parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = -1,
+ parameter OUT_DEPTH = ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 / {_cap(verilog_param_name)}_PARALLELISM_DIM_0) * ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 / {_cap(verilog_param_name)}_PARALLELISM_DIM_1)
) (
input clk,
input rst,
@@ -138,7 +300,6 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
// 1-bit wider so IN_DEPTH also fits.
localparam COUNTER_WIDTH = $clog2(OUT_DEPTH);
logic [COUNTER_WIDTH:0] counter;
-
always_ff @(posedge clk)
if (rst) counter <= 0;
else begin
@@ -147,13 +308,16 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
else counter <= counter + 1;
end
end
-
+ logic [1:0] clear;
+ always_ff @(posedge clk)
+ if (rst) clear <= 0;
+ else if ((data_out_ready == 1) && (clear != 2)) clear <= clear + 1;
logic ce0;
- assign ce0 = 1;
+ assign ce0 = data_out_ready;
- logic [{_cap(verilog_param_name)}_PRECISION_0*{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0-1:0] data_vector;
+ logic [{_cap(verilog_param_name)}_PRECISION_0*{_cap(verilog_param_name)}_PARALLELISM_DIM_0*{_cap(verilog_param_name)}_PARALLELISM_DIM_1-1:0] data_vector;
{node_param_name} #(
- .DATA_WIDTH({_cap(verilog_param_name)}_PRECISION_0 * {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0),
+ .DATA_WIDTH({_cap(verilog_param_name)}_PRECISION_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1),
.ADDR_RANGE(OUT_DEPTH)
) {node_param_name}_mem (
.clk(clk),
@@ -168,7 +332,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
for (genvar j = 0; j < {_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1; j++)
assign data_out[j] = data_vector[{_cap(verilog_param_name)}_PRECISION_0*j+{_cap(verilog_param_name)}_PRECISION_0-1:{_cap(verilog_param_name)}_PRECISION_0*j];
- assign data_out_valid = 1;
+ assign data_out_valid = clear == 2;
endmodule
"""
@@ -204,27 +368,40 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
out_depth = int(total_size / out_size)
data_buff = ""
- param_data = node.meta["mase"].module.get_parameter(param_name).data
+ param_data = (
+ node.meta["mase"].parameters["common"]["args"][verilog_param_name]["value"].data
+ )
+ param_meta = node.meta["mase"].parameters["hardware"]["verilog_param"]
+ # TODO: Currently only support tranpose linear
+
if node.meta["mase"].parameters["hardware"]["interface"][verilog_param_name][
"transpose"
]:
+ raise NotImplementedError("only support linear with non-transposed weight")
+ else:
+ assert (
+ param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1"]
+ % param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"]
+ == 0
+ ) and (
+ param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0"]
+ % param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"]
+ == 0
+ ), "The parallesim parameter must be divisible by the tensor size parameter."
param_data = torch.reshape(
param_data,
(
- node.meta["mase"].parameters["hardware"]["verilog_param"][
- "DATA_OUT_0_SIZE"
- ],
- node.meta["mase"].parameters["hardware"]["verilog_param"][
- "DATA_IN_0_DEPTH"
- ],
- node.meta["mase"].parameters["hardware"]["verilog_param"][
- "DATA_IN_0_SIZE"
- ],
+ -1,
+ param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1"]
+ // param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"],
+ param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"],
+ param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0"]
+ // param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"],
+ param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"],
),
)
- param_data = torch.transpose(param_data, 0, 1)
+ param_data = param_data.permute(0, 1, 3, 2, 4)
param_data = torch.flatten(param_data).tolist()
-
if (
node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"]
== "fixed"
@@ -241,18 +418,54 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
for i in range(0, out_depth):
line_buff = ""
for j in range(0, out_size):
- value = param_data[i * out_size + out_size - 1 - j]
+ value = param_data[i * out_size + j]
value = integer_quantizer_for_hw(
- torch.tensor(value), width, frac_width
+ torch.tensor(value), width, frac_width, floor=True
).item()
- value = str(bin(int(value * scale) % thresh))
+ value = str(bin(value))
value_bits = value[value.find("0b") + 2 :]
value_bits = "0" * (width - len(value_bits)) + value_bits
assert len(value_bits) == width
line_buff += value_bits
-
hex_buff = hex(int(line_buff, 2))
- data_buff += hex_buff[hex_buff.find("0x") + 2 :] + "\n"
+
+ data_buff += line_buff + "\n"
+ elif (
+ node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"]
+ == "mxint_hardware"
+ ):
+ width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][
+ "precision"
+ ][0]
+ exponent_width = node.meta["mase"].parameters["common"]["args"][
+ verilog_param_name
+ ]["precision"][1]
+ from mase_components.linear_layers.mxint_operators.test.utils import mxint_quant_block
+
+ line_buff = ""
+
+ # assert (width >= exponent_width),"current only support width >= exponent_width"
+ def convert_to_bit(value, width):
+ value_bits = str(bin(int(value) & (2**width - 1)))
+ value_bits = value_bits[value_bits.find("0b") + 2 :]
+ value_bits = "0" * (width - len(value_bits)) + value_bits
+ assert len(value_bits) == width
+ return value_bits
+
+ for i in range(0, out_depth):
+ line_buff = ""
+ block_data = param_data[i * out_size : i * out_size + out_size]
+ value, mvalue, evalue = mxint_quant_block(
+ torch.tensor(block_data), width, exponent_width, round_bits=8,
+ )
+
+ for j in range(0, out_size):
+ value_bits = convert_to_bit(mvalue[j], width)
+ line_buff += value_bits
+ evalue_bits = convert_to_bit(evalue, exponent_width)
+ line_buff += evalue_bits
+
+ data_buff += line_buff + "\n"
else:
assert False, "Emitting non-fixed parameters is not supported."
@@ -349,7 +562,14 @@ def emit_bram_handshake(node, rtl_dir):
data_name = os.path.join(
rtl_dir, f"{node_name}_{param_verilog_name}_rom.dat"
)
- emit_parameters_in_mem_internal(node, param_name, verilog_name, data_name)
+ if node.meta["mase"].parameters["common"]["quant_type"] == "mxint_hardware":
+ emit_mxint_parameters_in_mem_internal(
+ node, param_name, verilog_name, data_name
+ )
+ else:
+ emit_parameters_in_mem_internal(
+ node, param_name, verilog_name, data_name
+ )
emit_parameters_in_dat_internal(node, param_name, data_name)
else:
assert False, "Emtting parameters in non-BRAM hardware is not supported."
diff --git a/src/chop/passes/graph/transforms/verilog/emit_tb.py b/src/chop/passes/graph/transforms/verilog/emit_tb.py
index 1afac2fea..4c06ef5df 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_tb.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_tb.py
@@ -20,6 +20,8 @@
import dill
import inspect
+torch.manual_seed(0)
+
def _cap(name):
"""
@@ -57,14 +59,17 @@ async def test(dut):
await tb.wait_end(timeout={wait_time}, timeout_unit="{wait_unit}")
"""
-
- tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb"
+ tb_path = (
+ pass_args["project_dir"] / "hardware" / "test" / "mase_top_tb"
+ if "project_dir" in pass_args.keys()
+ else Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb"
+ )
tb_path.mkdir(parents=True, exist_ok=True)
with open(tb_path / "test.py", "w") as f:
f.write(test_template)
-def _emit_cocotb_tb(graph):
+def _emit_cocotb_tb(graph, pass_args={}):
class MaseGraphTB(Testbench):
def __init__(self, dut, fail_on_checks=True):
super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks)
@@ -145,6 +150,7 @@ def load_drivers(self, in_tensors):
self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"),
self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"),
],
+ floor=True,
)
else:
@@ -175,6 +181,7 @@ def load_monitors(self, expectation):
self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"),
self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"),
],
+ floor=True,
)
# Set expectation for each monitor
@@ -189,11 +196,183 @@ def load_monitors(self, expectation):
# Drive the in-flight flag for each monitor
self.output_monitors["data_out_0"].in_flight = True
+ # Serialize testbench object to be instantiated within test by cocotb runner
+ cls_obj = MaseGraphTB
+ tb_path = (
+ pass_args["project_dir"] / "hardware" / "test" / "mase_top_tb"
+ if "project_dir" in pass_args.keys()
+ else Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb"
+ )
+ tb_path.mkdir(parents=True, exist_ok=True)
+ with open(tb_path / "tb_obj.dill", "wb") as file:
+ import sys
+
+ sys.setrecursionlimit(10000) # Increase recursion limit
+ dill.dump(cls_obj, file)
+ with open(tb_path / "__init__.py", "w") as file:
+ file.write("from .test import test")
+
+
+from mase_components.linear_layers.mxint_operators.test.utils import (
+ mxint_hardware,
+ pack_tensor_to_mx_listed_chunk,
+)
+from mase_cocotb.interfaces.streaming import (
+ MultiSignalStreamDriver,
+ MultiSignalStreamMonitor,
+)
+from cocotb.triggers import Timer, RisingEdge, ReadOnly
+
+
+async def check_signal(dut):
+ await Timer(40, units="ns")
+ while True:
+ await RisingEdge(dut.clk)
+ await ReadOnly()
+ weight_0 = dut.fc1_weight_source_0
+ # if weight_0.data_out_ready.value == 1 and weight_0.data_out_valid.value == 1:
+ # print("mdata_out = ",[x for x in weight_0.mdata_out.value])
+ # print("edata_out = ",weight_0.edata_out.value.signed_integer)
+ print(weight_0.data_vector)
+ print(weight_0.fc1_weight_mem.fc1_weight_rom_U.q0_t0)
+ print(weight_0.fc1_weight_mem.fc1_weight_rom_U.addr0)
+ print(weight_0.fc1_weight_mem.fc1_weight_rom_U.DWIDTH)
+ print(weight_0.fc1_weight_mem.fc1_weight_rom_U.MEM_SIZE)
+ print([x.value for x in weight_0.fc1_weight_mem.fc1_weight_rom_U.ram])
+ print("end")
+
+
+def _emit_cocotb_tb_for_mxint(graph):
+ class MaseGraphTB(Testbench):
+ def __init__(self, dut, fail_on_checks=True):
+ super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks)
+
+ # cocotb.start_soon(check_signal(dut))
+ # Instantiate as many drivers as required inputs to the model
+ self.input_drivers = {}
+ self.output_monitors = {}
+
+ for node in graph.nodes_in:
+ for arg in node.meta["mase"]["common"]["args"].keys():
+ if "data_in" not in arg:
+ continue
+ self.input_drivers[arg] = MultiSignalStreamDriver(
+ dut.clk,
+ (getattr(dut, "m" + arg), getattr(dut, "e" + arg)),
+ getattr(dut, f"{arg}_valid"),
+ getattr(dut, f"{arg}_ready"),
+ )
+ self.input_drivers[arg].log.setLevel(logging.DEBUG)
+
+ # Instantiate as many monitors as required outputs
+ for node in graph.nodes_out:
+ for result in node.meta["mase"]["common"]["results"].keys():
+ if "data_out" not in result:
+ continue
+ self.output_monitors[result] = MultiSignalStreamMonitor(
+ dut.clk,
+ (getattr(dut, "m" + result), getattr(dut, "e" + result)),
+ getattr(dut, f"{result}_valid"),
+ getattr(dut, f"{result}_ready"),
+ check=False,
+ )
+ self.output_monitors[result].log.setLevel(logging.DEBUG)
+
+ self.model = graph.model
+
+ # To do: precision per input argument
+ self.input_precision = graph.meta["mase"]["common"]["args"]["data_in_0"][
+ "precision"
+ ]
+
+ def generate_inputs(self, batches):
+ """
+ Generate inputs for the model by sampling a random tensor
+ for each input argument, according to its shape
+
+ :param batches: number of batches to generate for each argument
+ :type batches: int
+ :return: a dictionary of input arguments and their corresponding tensors
+ :rtype: Dict
+ """
+ # ! TO DO: iterate through graph.args instead to generalize
+ inputs = {}
+ for node in graph.nodes_in:
+ for arg, arg_info in node.meta["mase"]["common"]["args"].items():
+ # Batch dimension always set to 1 in metadata
+ if "data_in" not in arg:
+ continue
+ # print(f"Generating data for node {node}, arg {arg}: {arg_info}")
+ inputs[f"{arg}"] = torch.rand(([batches] + arg_info["shape"][1:]))
+ return inputs
+
+ def preprocess_tensor_for_mxint(self, tensor, q_config, parallelism):
+ (qtensor, mtensor, etensor) = block_mxint_quant(
+ tensor, q_config, parallelism
+ )
+ tensor_inputs = pack_tensor_to_mx_listed_chunk(
+ mtensor, etensor, parallelism
+ )
+ return tensor_inputs
+
+ def load_drivers(self, in_tensors):
+ for arg, arg_batches in in_tensors.items():
+ # Quantize input tensor according to precision
+ if len(self.input_precision) > 1:
+ in_data_blocks = self.preprocess_tensor_for_mxint(
+ tensor=arg_batches,
+ q_config={
+ "width": self.get_parameter(f"{_cap(arg)}_PRECISION_0"),
+ "exponent_width": self.get_parameter(
+ f"{_cap(arg)}_PRECISION_1"
+ ),
+ },
+ parallelism=[
+ self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"),
+ self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"),
+ ],
+ )
+
+ else:
+ # TO DO: convert to integer equivalent of floating point representation
+ pass
+
+ block_size = self.get_parameter(
+ "DATA_IN_0_PARALLELISM_DIM_0"
+ ) * self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1")
+ for block in in_data_blocks:
+ self.input_drivers[arg].append(block)
+
+ def load_monitors(self, expectation):
+ # Process the expectation tensor
+ output_blocks = self.preprocess_tensor_for_mxint(
+ tensor=expectation,
+ q_config={
+ "width": self.get_parameter(f"DATA_OUT_0_PRECISION_0"),
+ "exponent_width": self.get_parameter(f"DATA_OUT_0_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"),
+ self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"),
+ ],
+ )
+
+ # Set expectation for each monitor
+ for block in output_blocks:
+ # ! TO DO: generalize to multi-output models
+ self.output_monitors["data_out_0"].expect(block)
+
+ # Drive the in-flight flag for each monitor
+ self.output_monitors["data_out_0"].in_flight = True
+
# Serialize testbench object to be instantiated within test by cocotb runner
cls_obj = MaseGraphTB
tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb"
tb_path.mkdir(parents=True, exist_ok=True)
with open(tb_path / "tb_obj.dill", "wb") as file:
+ import sys
+
+ sys.setrecursionlimit(10000) # Increase recursion limit
dill.dump(cls_obj, file)
with open(tb_path / "__init__.py", "w") as file:
file.write("from .test import test")
@@ -224,6 +403,7 @@ def emit_cocotb_transform_pass(graph, pass_args={}):
init_project(project_dir)
_emit_cocotb_test(graph, pass_args=pass_args)
- _emit_cocotb_tb(graph)
+ _emit_cocotb_tb(graph, pass_args=pass_args)
+ # _emit_cocotb_tb_for_mxint(graph)
return graph, None
diff --git a/src/chop/passes/graph/transforms/verilog/emit_top.py b/src/chop/passes/graph/transforms/verilog/emit_top.py
index 6da3c0043..983a732c9 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_top.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_top.py
@@ -9,11 +9,12 @@
from chop.passes.graph.utils import vf, v2p, init_project
import mase_components.helper.generate_memory as gen_lut
import torch.nn as nn
-
+import sys
+from pathlib import Path
logger = logging.getLogger(__name__)
-
+from chop.nn.quantized.modules.layer_norm import LayerNormIntegerFloor
+from chop.nn.quantized.modules.attention import ViTAttentionInteger
from .util import get_verilog_parameters
-from pathlib import Path
# =============================================================================
# Utilities
@@ -125,6 +126,7 @@ def emit(self, graph, parameter_map):
i = 0
for node in nodes_in:
node_name = vf(node.name)
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for arg_idx, arg in enumerate(
node.meta["mase"].parameters["common"]["args"].keys()
):
@@ -136,7 +138,14 @@ def emit(self, graph, parameter_map):
for param in parameter_map
if param.startswith(f"{arg_name}_PARALLELISM_DIM")
]
- interface += f"""
+ if quant_type == "mxint_hardware":
+ interface += f"""
+ input [{arg_name}_PRECISION_0-1:0] mdata_in_{i} [{'*'.join(parallelism_params)}-1:0],
+ input [{arg_name}_PRECISION_1-1:0] edata_in_{i},
+ input data_in_{i}_valid,
+ output data_in_{i}_ready,"""
+ else:
+ interface += f"""
input [{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0],
input data_in_{i}_valid,
output data_in_{i}_ready,"""
@@ -145,6 +154,7 @@ def emit(self, graph, parameter_map):
i = 0
for node in nodes_out:
node_name = vf(node.name)
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for result in node.meta["mase"].parameters["common"]["results"].keys():
if "data_out" in result:
result_name = _cap(result)
@@ -153,7 +163,14 @@ def emit(self, graph, parameter_map):
for param in parameter_map
if param.startswith(f"{result_name}_PARALLELISM_DIM")
]
- interface += f"""
+ if quant_type == "mxint_hardware":
+ interface += f"""
+ output [{result_name}_PRECISION_0-1:0] mdata_out_{i} [{'*'.join(parallelism_params)}-1:0],
+ output [{result_name}_PRECISION_1-1:0] edata_out_{i},
+ output data_out_{i}_valid,
+ input data_out_{i}_ready,"""
+ else:
+ interface += f"""
output [{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0],
output data_out_{i}_valid,
input data_out_{i}_ready,"""
@@ -177,6 +194,7 @@ def _emit_signals_top_internal(self, node, parameter_map):
signals = ""
node_name = vf(node.name)
# Input signals
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items():
if not isinstance(arg_info, dict):
continue
@@ -199,7 +217,14 @@ def _emit_signals_top_internal(self, node, parameter_map):
if node.meta["mase"]["common"]["mase_op"] == "getitem":
arg = "data_in_0"
- signals += f"""
+ if quant_type == "mxint_hardware":
+ signals += f"""
+logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_m{arg} [{'*'.join(parallelism_params)}-1:0];
+logic [{node_name}_{arg_name}_PRECISION_1-1:0] {node_name}_e{arg};
+logic {node_name}_{arg}_valid;
+logic {node_name}_{arg}_ready;"""
+ else:
+ signals += f"""
logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0];
logic {node_name}_{arg}_valid;
logic {node_name}_{arg}_ready;"""
@@ -225,7 +250,14 @@ def _emit_signals_top_internal(self, node, parameter_map):
for param in parameter_map
if f"{node_name}_{result_name}_PARALLELISM_DIM" in param
]
- signals += f"""
+ if quant_type == "mxint_hardware":
+ signals += f"""
+logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_m{result} [{'*'.join(parallelism_params)}-1:0];
+logic [{node_name}_{result_name}_PRECISION_1-1:0] {node_name}_e{result};
+logic {node_name}_{result}_valid;
+logic {node_name}_{result}_ready;"""
+ else:
+ signals += f"""
logic [{node_name}_{result_name}_PRECISION_0-1:0] {node_name}_{result} [{'*'.join(parallelism_params)}-1:0];
logic {node_name}_{result}_valid;
logic {node_name}_{result}_ready;"""
@@ -321,12 +353,27 @@ def _emit_module_parameters_top_internal(self, key, value, node, parameter_map):
component_name_inst = f"{component_name}_0"
parameters = ""
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys():
if f"{_cap(key)}_" in param:
parameters += f" .{param}({node_name}_{param}),\n"
parameters = _remove_last_comma(parameters)
- return f"""
+ if quant_type == "mxint_hardware":
+ top_component = f"""
+{component_name} #(
+{parameters}
+) {component_name_inst} (
+ .clk(clk),
+ .rst(rst),
+ .mdata_out({node_name}_m{key}),
+ .edata_out({node_name}_e{key}),
+ .data_out_ready({node_name}_{key}_ready),
+ .data_out_valid({node_name}_{key}_valid)
+);
+"""
+ else:
+ top_component = f"""
{component_name} #(
{parameters}
) {component_name_inst} (
@@ -338,6 +385,8 @@ def _emit_module_parameters_top_internal(self, key, value, node, parameter_map):
);
"""
+ return top_component
+
def _emit_getitem_signals(self, node):
"""
Getitem nodes have arg list like (None, None, None, Arg, None, None)
@@ -346,8 +395,22 @@ def _emit_getitem_signals(self, node):
"""
node_name = vf(node.name)
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
- return f"""
+ if quant_type == "mxint_hardware":
+ component_interface = f"""
+ .mdata_in_0 ({node_name}_mdata_in_0),
+ .edata_in_0 ({node_name}_edata_in_0),
+ .data_in_0_valid ({node_name}_data_in_0_valid),
+ .data_in_0_ready ({node_name}_data_in_0_ready),
+
+ .mdata_out_0 ({node_name}_mdata_out_0),
+ .edata_out_0 ({node_name}_edata_out_0),
+ .data_out_0_valid ({node_name}_data_out_0_valid),
+ .data_out_0_ready ({node_name}_data_out_0_ready),
+ """
+ else:
+ component_interface = f"""
.data_in_0 ({node_name}_data_in_0),
.data_in_0_valid ({node_name}_data_in_0_valid),
.data_in_0_ready ({node_name}_data_in_0_ready),
@@ -357,10 +420,13 @@ def _emit_getitem_signals(self, node):
.data_out_0_ready ({node_name}_data_out_0_ready),
"""
+ return component_interface
+
def emit(self, node, parameter_map):
node_name = vf(node.name)
component_name = node.meta["mase"].parameters["hardware"]["module"]
signals = ""
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
# Emit component instantiation parameters
parameters = ""
@@ -385,7 +451,15 @@ def emit(self, node, parameter_map):
for key, value in node.meta["mase"].parameters["common"]["args"].items():
if "inplace" in key or not isinstance(value, dict):
continue
- signals += f"""
+ if quant_type == "mxint_hardware":
+ signals += f"""
+ .m{key}({node_name}_m{key}),
+ .e{key}({node_name}_e{key}),
+ .{key}_valid({node_name}_{key}_valid),
+ .{key}_ready({node_name}_{key}_ready),
+ """
+ else:
+ signals += f"""
.{key}({node_name}_{key}),
.{key}_valid({node_name}_{key}_valid),
.{key}_ready({node_name}_{key}_ready),
@@ -393,7 +467,15 @@ def emit(self, node, parameter_map):
# Emit component instantiation output signals
for key, value in node.meta["mase"].parameters["common"]["results"].items():
- signals += f"""
+ if quant_type == "mxint_hardware":
+ signals += f"""
+ .m{key}({node_name}_m{key}),
+ .e{key}({node_name}_e{key}),
+ .{key}_valid({node_name}_{key}_valid),
+ .{key}_ready({node_name}_{key}_ready),
+ """
+ else:
+ signals += f"""
.{key}({node_name}_{key}),
.{key}_valid({node_name}_{key}_valid),
.{key}_ready({node_name}_{key}_ready),
@@ -583,10 +665,20 @@ def _emit_top_wires(self):
i = 0
for node in nodes_in:
node_name = vf(node.name)
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for arg_idx, arg in enumerate(
node.meta["mase"].parameters["common"]["args"].keys()
):
- if is_real_input_arg(node, arg_idx):
+ if not is_real_input_arg(node, arg_idx):
+ continue
+ if quant_type == "mxint_hardware":
+ wires += f"""
+assign data_in_{i}_ready = {node_name}_{arg}_ready;
+assign {node_name}_{arg}_valid = data_in_{i}_valid;
+assign {node_name}_m{arg} = mdata_in_{i};
+assign {node_name}_e{arg} = edata_in_{i};
+"""
+ else:
wires += f"""
assign data_in_{i}_ready = {node_name}_{arg}_ready;
assign {node_name}_{arg}_valid = data_in_{i}_valid;
@@ -598,15 +690,21 @@ def _emit_top_wires(self):
node_name = vf(node.name)
for result in node.meta["mase"].parameters["common"]["results"].keys():
if "data_out" in result:
- wires += f"""
+ if quant_type == "mxint_hardware":
+ wires += f"""
+assign data_out_{i}_valid = {node_name}_{result}_valid;
+assign {node_name}_{result}_ready = data_out_{i}_ready;
+assign mdata_out_{i} = {node_name}_m{result};
+assign edata_out_{i} = {node_name}_e{result};
+"""
+ else:
+ wires += f"""
assign data_out_{i}_valid = {node_name}_{result}_valid;
assign {node_name}_{result}_ready = data_out_{i}_ready;
assign data_out_{i} = {node_name}_{result};
"""
i += 1
- # TODO: emit off-chip parameter interface
-
return wires
def _emit_getitem_wires(self, node):
@@ -618,19 +716,30 @@ def _emit_getitem_wires(self, node):
from_name = vf(node.args[0].name)
to_name = vf(node.name)
select = node.args[1]
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
+ if quant_type == "mxint_hardware":
+ getitem_wires = f"""
+assign {from_name}_data_out_{select}_ready = {to_name}_data_in_0_ready;
+assign {to_name}_data_in_0_valid = {from_name}_data_out_{select}_valid;
+assign {to_name}_mdata_in_0 = {from_name}_mdata_out_{select};
+assign {to_name}_edata_in_0 = {from_name}_edata_out_{select};
+"""
- return f"""
+ else:
+ getitem_wires = f"""
assign {from_name}_data_out_{select}_ready = {to_name}_data_in_0_ready;
assign {to_name}_data_in_0_valid = {from_name}_data_out_{select}_valid;
assign {to_name}_data_in_0 = {from_name}_data_out_{select};
"""
+ return getitem_wires
+
def _emit_node2node_wires(self):
nodes_in = self.graph.nodes_in
wires = ""
+ fork_in = {}
for node in self.graph.fx_graph.nodes:
-
if (
# Skip implicit nodes
node.meta["mase"].parameters["hardware"]["is_implicit"]
@@ -645,13 +754,28 @@ def _emit_node2node_wires(self):
continue
to_name = vf(node.name)
-
+ quant_type = node.meta["mase"].parameters["common"]["quant_type"]
for i, node_in in enumerate(node.all_input_nodes):
from_name = vf(node_in.name)
- wires += f"""
-assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready;
-assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid;
-assign {to_name}_data_in_{i} = {from_name}_data_out_0;
+ if "fork2" in from_name:
+ fork_in[from_name] = (
+ 0 if fork_in.get(from_name) == None else fork_in[from_name] + 1
+ )
+ j = fork_in[from_name]
+ else:
+ j = 0
+ if quant_type == "mxint_hardware":
+ wires += f"""
+assign {from_name}_data_out_{j}_ready = {to_name}_data_in_{i}_ready;
+assign {to_name}_data_in_{i}_valid = {from_name}_data_out_{j}_valid;
+assign {to_name}_mdata_in_{i} = {from_name}_mdata_out_{j};
+assign {to_name}_edata_in_{i} = {from_name}_edata_out_{j};
+"""
+ else:
+ wires += f"""
+assign {from_name}_data_out_{j}_ready = {to_name}_data_in_{i}_ready;
+assign {to_name}_data_in_{i}_valid = {from_name}_data_out_{j}_valid;
+assign {to_name}_data_in_{i} = {from_name}_data_out_{j};
"""
return wires
@@ -729,6 +853,255 @@ def emit(self, graph, top_name):
return module_inst
+def emit_folded_bram(folded_gragh, reuse_name, reuse_times):
+ def _emit_module_parameters_top_internal(key, node, reuse_name, reuse_times):
+ node_name = vf(node.name).replace(reuse_name + "_0", reuse_name)
+ component_name = f"{node_name}_{key}_source"
+ component_name_inst = f"{component_name}_0"
+
+ # verilog_param = node_name+"_"+_cap(key)
+ def get_image_depth(key, param_list, node_name):
+ if "weight" in key:
+ image_depth = (
+ param_list[f"{_cap(key)}_TENSOR_SIZE_DIM_0"]
+ * param_list[f"{_cap(key)}_TENSOR_SIZE_DIM_1"]
+ / (
+ param_list[f"{_cap(key)}_PARALLELISM_DIM_0"]
+ * param_list[f"{_cap(key)}_PARALLELISM_DIM_1"]
+ )
+ )
+ elif "bias" in key:
+ if "norm" in node_name:
+ image_depth = (
+ param_list[f"{_cap(key)}_TENSOR_SIZE_DIM_0"]
+ * param_list[f"{_cap(key)}_TENSOR_SIZE_DIM_1"]
+ / (
+ param_list[f"{_cap(key)}_PARALLELISM_DIM_0"]
+ * param_list[f"{_cap(key)}_PARALLELISM_DIM_1"]
+ )
+ )
+ else:
+ image_depth = (
+ param_list[f"{_cap(key)}_TENSOR_SIZE_DIM_0"]
+ / param_list[f"{_cap(key)}_PARALLELISM_DIM_0"]
+ )
+ else:
+ raise NotImplementedError
+ return image_depth
+
+ image_depth = get_image_depth(
+ key, node.meta["mase"].parameters["hardware"]["verilog_param"], node.name
+ )
+ parameters = ""
+ for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys():
+ if f"{_cap(key)}_" in param:
+ parameters += f" .{param}({param}),\n"
+ parameters = _remove_last_comma(parameters)
+ modules = ""
+ signal = ""
+ for i in range(reuse_times):
+ new_node_name = node_name.replace(reuse_name, reuse_name + f"_{i}")
+ new_componet_name = component_name.replace(reuse_name, reuse_name + f"_{i}")
+ new_component_name_inst = component_name_inst.replace(
+ reuse_name, reuse_name + f"_{i}"
+ )
+ signal += f"""
+logic [{_cap(key)}_PRECISION_0 - 1:0] {new_node_name}_{key} [{_cap(key)}_PARALLELISM_DIM_0*{_cap(key)}_PARALLELISM_DIM_1 - 1:0];
+logic {new_node_name}_{key}_valid, {new_node_name}_{key}_ready;
+"""
+ modules += f"""
+{new_componet_name} #(
+{parameters}
+) {new_component_name_inst} (
+ .clk(clk),
+ .rst(rst),
+ .data_out({new_node_name}_{key}),
+ .data_out_ready({new_node_name}_{key}_ready),
+ .data_out_valid({new_node_name}_{key}_valid)
+);
+
+ """
+
+ output_connections = f"""
+always_comb begin"""
+ for item in ["", f"_valid"]:
+ output_connections += f"""
+ data_out{item} = (counter= (REPEAT_TIMES - 1)*IMAGE_DEPTH)? data_out_0_ready: (counter_in < IMAGE_DEPTH) ? 0 : top_block_data_in_0_ready;
+end
+endmodule
+ """
+ return top
+
+
+def emit_verilog_folded_top_file(graph, top_name, pass_args):
+ folded_graph = pass_args["folded_graph"]
+ folded_node_name = pass_args["folded_node_name"]
+ reuse_times = pass_args["reuse_times"]
+ top_block = (
+ VerilogEmitter(folded_graph)
+ .emit(folded_graph, "top_block")
+ .replace(f"{folded_node_name}_0", folded_node_name)
+ )
+ top_bram = emit_folded_bram(folded_graph, folded_node_name, reuse_times)
+ top = emit_verilog_folded_top(graph, reuse_times, top_name)
+ top_file = f"""
+ {top}
+ {top_block}
+ {top_bram}
+ """
+ return top_file
+
+
def emit_verilog_top_transform_pass(graph, pass_args={}):
"""Emit the top-level model design in Verilog
@@ -756,8 +1129,10 @@ def emit_verilog_top_transform_pass(graph, pass_args={}):
top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top"
init_project(project_dir)
rtl_dir = os.path.join(project_dir, "hardware", "rtl")
-
- top = VerilogEmitter(graph).emit(graph, top_name)
+ if pass_args.get("folded_graph", False):
+ top = emit_verilog_folded_top_file(graph, top_name, pass_args)
+ else:
+ top = VerilogEmitter(graph).emit(graph, top_name)
top_file = os.path.join(rtl_dir, f"{top_name}.sv")
with open(top_file, "w") as top_design:
@@ -768,8 +1143,6 @@ def emit_verilog_top_transform_pass(graph, pass_args={}):
# Alternatively, add a class to the emitter that can be called to generate LUTs, for LUT based implementations of activation functions,
# or other functions that require LUTs such as PolyLUT or LUTnet neurons.
for node in graph.fx_graph.nodes:
- # print(vars(node))
- # print(type(node))
if node.op == "call_module":
module = dict(graph.model.named_modules())[node.target]
if isinstance(module, nn.SiLU):
@@ -782,22 +1155,81 @@ def emit_verilog_top_transform_pass(graph, pass_args={}):
func = "logsigmoid"
elif isinstance(module, nn.Softmax):
func = "exp"
+ elif isinstance(module, nn.GELU) or node.meta["mase"]["common"]["mase_op"] == "gelu":
+ func = "gelu"
+ elif isinstance(module, LayerNormIntegerFloor):
+ func = "isqrt"
+ elif isinstance(module, ViTAttentionInteger):
+ func = "exp"
else:
func = "Unknown"
-
+ mult = 1
+ sys.path.append(Path(__file__).resolve().parents[6].as_posix())
+ from a_cx_mxint_quant import MXIntGELU
if func != "Unknown":
- d_in_width = node.meta["mase"].parameters["hardware"]["verilog_param"][
- "DATA_IN_0_PRECISION_0"
- ]
- d_in_f_width = node.meta["mase"].parameters["hardware"][
- "verilog_param"
- ]["DATA_IN_0_PRECISION_1"]
- d_out_width = node.meta["mase"].parameters["hardware"]["verilog_param"][
- "DATA_OUT_0_PRECISION_0"
- ]
- d_out_f_width = node.meta["mase"].parameters["hardware"][
- "verilog_param"
- ]["DATA_OUT_0_PRECISION_1"]
+ if isinstance(module, ViTAttentionInteger):
+ d_in_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["QKMM_OUT_PRECISION_0"]
+ d_in_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["QKMM_OUT_PRECISION_1"]
+ d_out_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["SOFTMAX_EXP_PRECISION_0"]
+ d_out_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["SOFTMAX_EXP_PRECISION_1"]
+ from math import sqrt
+
+ mult = 1 / sqrt(
+ node.meta["mase"].parameters["hardware"]["verilog_param"][
+ "DATA_IN_0_TENSOR_SIZE_DIM_0"
+ ]
+ // node.meta["mase"].parameters["hardware"]["verilog_param"][
+ "NUM_HEADS"
+ ]
+ )
+ elif isinstance(module, LayerNormIntegerFloor):
+ d_in_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["ISQRT_IN_PRECISION_0"]
+ d_in_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["ISQRT_IN_PRECISION_1"]
+ d_out_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["ISQRT_OUT_PRECISION_0"]
+ d_out_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["ISQRT_OUT_PRECISION_1"]
+ elif isinstance(module, MXIntGELU):
+ d_in_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_IN_0_PRECISION_0"] + 2
+ d_in_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_IN_0_PRECISION_1"] - 1
+ d_out_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["HASH_OUT_WIDTH"]
+ d_out_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["HASH_OUT_WIDTH"] - 3
+ else:
+ d_in_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_IN_0_PRECISION_0"]
+ d_in_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_IN_0_PRECISION_1"]
+ d_out_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_OUT_0_PRECISION_0"]
+ d_out_f_width = node.meta["mase"].parameters["hardware"][
+ "verilog_param"
+ ]["DATA_OUT_0_PRECISION_1"]
+ logger.info(f"Generating LUT for {func}")
gen_lut.generate_sv_lut(
func,
d_in_width,
@@ -806,5 +1238,7 @@ def emit_verilog_top_transform_pass(graph, pass_args={}):
d_out_f_width,
path=rtl_dir,
path_with_dtype=False,
+ constant_mult=mult,
+ floor=False,
)
return graph, {}
diff --git a/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py
index 3219a42ac..1633cfc60 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py
@@ -18,8 +18,8 @@ def generate_tcl_script(top_name, vivado_project_path, include_groups, project_d
)
tcl_script_template = f"""
-set_param board.repoPaths {{{str(Path.home())}/shared/board-files}}
-create_project {top_name}_build_project {vivado_project_path} -part xcu280-fsvh2892-2L-e
+# set_param board.repoPaths {{{str(Path.home())}/shared/board-files}}
+create_project -force {top_name}_build_project {vivado_project_path} -part xcu280-fsvh2892-2L-e
set_property board_part xilinx.com:au280:part0:1.1 [current_project]
"""
for include_group in include_groups:
@@ -27,10 +27,21 @@ def generate_tcl_script(top_name, vivado_project_path, include_groups, project_d
tcl_script_template += f"\n\nset_property top top [current_fileset]"
+ tcl_script_template += f"""
+add_files /scratch/cx922/mase/src/mase_components/vivado/constraints.xdc
+read_xdc /scratch/cx922/mase/src/mase_components/vivado/constraints.xdc
+"""
tcl_script_template += f"""
update_compile_order -fileset sources_1
"""
+ # syth and impl
+ tcl_script_template += f"""
+launch_runs synth_1
+wait_on_run synth_1
+launch_runs impl_1
+wait_on_run impl_1
+"""
# * Package IP
tcl_script_template += f"""
ipx::package_project -root_dir {project_dir}/hardware/ip_repo -vendor user.org -library user -taxonomy /UserIP -import_files
@@ -87,11 +98,12 @@ def emit_vivado_project_transform_pass(graph, pass_args={}):
os.makedirs(vivado_project_path, exist_ok=True)
# * List include files
- include_groups = [
- f"{COMPONENTS_PATH / group / 'rtl'}"
- for group in mase_components.get_modules()
- if group != "vivado"
- ] + [project_dir / "hardware" / "rtl"]
+ include_groups = [project_dir / "hardware" / "rtl"]
+ # include_groups = [
+ # f"{COMPONENTS_PATH / group / 'rtl'}"
+ # for group in mase_components.get_modules()
+ # if group != "vivado"
+ # ] + [project_dir / "hardware" / "rtl"]
generate_tcl_script(top_name, vivado_project_path, include_groups, project_dir)
@@ -105,6 +117,6 @@ def emit_vivado_project_transform_pass(graph, pass_args={}):
"-source",
f"{vivado_project_path}/build.tcl",
]
- result = subprocess.run(cmd, capture_output=True, text=True)
+ # result = subprocess.run(cmd, capture_output=True, text=True)
return graph, {}
diff --git a/src/chop/passes/graph/transforms/verilog/insert_fork.py b/src/chop/passes/graph/transforms/verilog/insert_fork.py
new file mode 100644
index 000000000..82aac3d80
--- /dev/null
+++ b/src/chop/passes/graph/transforms/verilog/insert_fork.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+from copy import copy, deepcopy
+from chop.ir.graph import MaseMetadata
+
+
+@torch.fx.wrap
+def fork2(x):
+ out = x
+ return out
+
+
+def insert_fork_transform_pass(graph, pass_args={}):
+ """Insert hardware-explicit forks into the mase graph
+ :param graph: a MaseGraph
+ :type graph: MaseGraph
+ :param pass_args: this pass requires additional arguments which is explained below, defaults to {}
+ :type pass_args: _type_, optional
+ :return: return a tuple of a MaseGraph and an empty dict (no additional info to return)
+ :rtype: tuple(MaseGrap`h, Dict)
+ """
+
+ def generating_mase_metadata(new_node, node, quan_args):
+ new_node.meta["mase"] = MaseMetadata(new_node, node.meta["mase"].model)
+ new_node.meta["mase"].parameters["common"]["mase_type"] = "call_function"
+ new_node.meta["mase"].parameters["common"]["mase_op"] = "fork2"
+ inherited_metadata = deepcopy(
+ node.meta["mase"]["common"]["results"]["data_out_0"]
+ )
+ if quan_args["config"]["name"] == "mxint_hardware":
+ inherited_metadata["precision"] = [quan_args["config"]["data_in_width"], quan_args["config"]["data_in_exponent_width"]],
+ inherited_metadata["type"] = "mxint_hardware"
+ else:
+ inherited_metadata["precision"] = quan_args
+ inherited_metadata["type"] = "fixed"
+ new_node.meta["mase"].parameters["common"]["args"] = {
+ "data_in_0": inherited_metadata
+ }
+ new_node.meta["mase"].parameters["common"]["results"] = {
+ "data_out_0": inherited_metadata,
+ "data_out_1": inherited_metadata,
+ }
+
+ new_node.meta["mase"].parameters["hardware"]["is_implicit"] = False
+
+ nodes_to_fork = []
+ from chop.tools.utils import to_numpy_if_tensor, to_tensor_if_numpy
+ from chop.passes.graph.transforms.utils import (
+ metadata_value_type_cast_transform_pass,
+ )
+
+ graph, _ = metadata_value_type_cast_transform_pass(
+ graph, pass_args={"fn": to_numpy_if_tensor}
+ )
+ for node in graph.fx_graph.nodes:
+ user_count = 0
+ for u in node.users.keys():
+ user_count += 1
+ if user_count > 1:
+ nodes_to_fork.append(node)
+ for node in nodes_to_fork:
+ with graph.fx_graph.inserting_after(node):
+ new_node = graph.fx_graph.call_function(fork2, args=(node,))
+ node.replace_all_uses_with(new_node)
+ new_node.args = (node,)
+ by = pass_args.get("by", "type")
+ if by == "type":
+ generating_mase_metadata(new_node, node, quan_args=pass_args["fork2"])
+ else:
+ generating_mase_metadata(
+ new_node, node, quan_args=pass_args[new_node.name]
+ )
+
+ # test whether the new graph works
+ insert_fifo_after_fork_pass(graph)
+ graph, _ = metadata_value_type_cast_transform_pass(
+ graph, pass_args={"fn": to_tensor_if_numpy}
+ )
+ graph.fx_graph.lint()
+ return graph, None
+
+
+@torch.fx.wrap
+def fifo(x):
+ out = x
+ return out
+
+
+def insert_fifo_after_fork_pass(graph, pass_args={}):
+ def generating_mase_metadata(new_node, node, i):
+ new_node.meta["mase"] = MaseMetadata(new_node, node.meta["mase"].model)
+ new_node.meta["mase"].parameters["common"]["mase_type"] = "call_function"
+ new_node.meta["mase"].parameters["common"]["mase_op"] = "fifo"
+ inherited_metadata = deepcopy(
+ node.meta["mase"]["common"]["args"][f"data_in_{i}"]
+ )
+ new_node.meta["mase"].parameters["common"]["args"] = {
+ "data_in_0": inherited_metadata
+ }
+ new_node.meta["mase"].parameters["common"]["results"] = {
+ "data_out_0": inherited_metadata
+ }
+
+ new_node.meta["mase"].parameters["hardware"]["is_implicit"] = False
+
+ record_list = []
+ for node in graph.fx_graph.nodes:
+ if node.meta["mase"].parameters["common"]["mase_op"] == "fork2":
+ for record_node in list(node.users):
+ if record_node.meta["mase"].parameters["common"]["mase_op"] == "add":
+ record_list.append(record_node)
+ for node in record_list:
+ with graph.fx_graph.inserting_before(node):
+ for i, arg in enumerate(list(node.args)):
+ if arg.meta["mase"].parameters["common"]["mase_op"] == "fork2":
+ new_node = graph.fx_graph.call_function(fifo, args=(arg,))
+ generating_mase_metadata(new_node, node, i)
+ node_args = list(node.args)
+ node_args[i] = new_node
+ node.args = tuple(node_args)
+ return graph, None
+
+
+def insert_fifo_after_specified_modules(graph, pass_args={}):
+ def generating_mase_metadata(new_node, node, parallelism):
+ new_node.meta["mase"] = MaseMetadata(new_node, node.meta["mase"].model)
+ new_node.meta["mase"].parameters["common"]["mase_type"] = "call_function"
+ new_node.meta["mase"].parameters["common"]["mase_op"] = "fifo"
+ inherited_metadata = deepcopy(
+ node.meta["mase"]["common"]["results"][f"data_out_0"]
+ )
+ new_node.meta["mase"].parameters["common"]["args"] = {
+ "data_in_0": inherited_metadata,
+ "depth": inherited_metadata["shape"][-1] // parallelism,
+ }
+ new_node.meta["mase"].parameters["common"]["results"] = {
+ "data_out_0": inherited_metadata
+ }
+
+ new_node.meta["mase"].parameters["hardware"]["is_implicit"] = False
+
+ from chop.tools.utils import to_numpy_if_tensor, to_tensor_if_numpy
+ from chop.passes.graph.transforms.utils import (
+ metadata_value_type_cast_transform_pass,
+ )
+
+ graph, _ = metadata_value_type_cast_transform_pass(
+ graph, pass_args={"fn": to_numpy_if_tensor}
+ )
+ record_list = []
+ for node in graph.fx_graph.nodes:
+ if (
+ node.meta["mase"].parameters["common"]["mase_op"]
+ in pass_args["insert_fifo"]
+ ):
+ record_list.append(node)
+ for node in record_list:
+ with graph.fx_graph.inserting_after(node):
+ new_node = graph.fx_graph.call_function(fifo, args=(node,))
+ node.replace_all_uses_with(new_node)
+ new_node.args = (node,)
+ generating_mase_metadata(new_node, node, pass_args["max_parallelism"])
+
+ graph, _ = metadata_value_type_cast_transform_pass(
+ graph, pass_args={"fn": to_tensor_if_numpy}
+ )
+ graph.fx_graph.lint()
+ return graph, None
diff --git a/src/chop/passes/graph/transforms/verilog/util.py b/src/chop/passes/graph/transforms/verilog/util.py
index 35abe9560..b0cbeea78 100644
--- a/src/chop/passes/graph/transforms/verilog/util.py
+++ b/src/chop/passes/graph/transforms/verilog/util.py
@@ -25,11 +25,18 @@ def get_verilog_parameters(graph):
parameter_map[f"{node_name}_{key}"] = value
# * Return graph level parameters
- for node in graph.nodes_in + graph.nodes_out:
+ for node in graph.nodes_in:
for key, value in (
node.meta["mase"].parameters["hardware"]["verilog_param"].items()
):
- if "DATA_IN" in key or "DATA_OUT" in key:
+ if "DATA_IN" in key:
+ parameter_map[key] = value
+
+ for node in graph.nodes_out:
+ for key, value in (
+ node.meta["mase"].parameters["hardware"]["verilog_param"].items()
+ ):
+ if "DATA_OUT" in key:
parameter_map[key] = value
return parameter_map
diff --git a/src/mase_cocotb/interfaces/random_draw.drawio b/src/mase_cocotb/interfaces/random_draw.drawio
new file mode 100644
index 000000000..0508a71e1
--- /dev/null
+++ b/src/mase_cocotb/interfaces/random_draw.drawio
@@ -0,0 +1,96 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py
index e67ebd1a3..01df3ac6c 100644
--- a/src/mase_cocotb/interfaces/streaming.py
+++ b/src/mase_cocotb/interfaces/streaming.py
@@ -238,50 +238,41 @@ def _check(self, got, exp):
self.log.debug("Passed | Got: %20s Exp: %20s Err: %10s" % (g, e, err))
-class MultiSignalStreamDriver(Driver):
- def __init__(self, clk, data, valid, ready) -> None:
- super().__init__()
- self.clk = clk
- self.data = data
- self.valid = valid
- self.ready = ready
- self.valid_prob = 1.0
-
- def set_valid_prob(self, prob):
- assert prob >= 0.0 and prob <= 1.0
- self.valid_prob = prob
-
- async def _driver_send(self, data) -> None:
+class MultiSignalStreamDriver(StreamDriver):
+ async def _driver_send(self, transaction) -> None:
while True:
await RisingEdge(self.clk)
- for hardware_target, item in zip(self.data, data):
- hardware_target.value = item
-
+ if type(self.data) == tuple:
+ # Drive multiple data bus
+ for wire, val in zip(self.data, transaction):
+ wire.value = val
+ else:
+ # Drive single data
+ self.data.value = transaction
if random.random() > self.valid_prob:
self.valid.value = 0
continue # Try roll random valid again at next clock
self.valid.value = 1
await ReadOnly()
if self.ready.value == 1:
- self.log.debug(f"Sent {data}")
+ if type(self.data) == tuple:
+ # Drive multiple data bus
+ for t in transaction:
+ self.log.debug("Sent %s" % t)
+ else:
+ self.log.debug("Sent %s" % transaction)
+ if self.record_num_beats:
+ self.num_beats += 1
break
+
+ # Load extra
+ # self.load_driver
+
if self.send_queue.empty():
await RisingEdge(self.clk)
self.valid.value = 0
-
-class MultiSignalStreamMonitor(Monitor):
- def __init__(self, clk, data, valid, ready, check=True):
- super().__init__(clk)
- self.clk = clk
- self.data = data
- self.valid = valid
- self.ready = ready
- self.check = check
-
- def _trigger(self):
- return self.valid.value == 1 and self.ready.value == 1
-
+class MultiSignalStreamMonitor(StreamMonitor):
def _recv(self):
def cast_data(value):
if type(value) == list:
@@ -296,3 +287,38 @@ def _check(self, got, exp):
for g, e in zip(got, exp):
if not np.equal(g, e).all():
raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp))
+
+class MultiSignalErrorThresholdStreamMonitor(ErrorThresholdStreamMonitor):
+ def _recv(self):
+ def cast_data(value):
+ if type(value) == list:
+ return [x.signed_integer for x in value]
+ elif type(value) == BinaryValue:
+ return value.signed_integer
+
+ return tuple([cast_data(target.value) for target in self.data])
+
+ def _check(self, got, exp):
+ if self.check:
+ mg, eg = got
+ me, ee = exp
+ if type(mg) == list:
+ mg = np.array(mg)
+ me = np.array(me)
+ mg = mg // 2**(ee - eg)
+ mg = mg.astype(np.int64)
+
+ if self.signed:
+ mg = _sign_extend(mg, self.width)
+ me = _sign_extend(me, self.width)
+ err = np.abs(mg - me)
+ if self.log_error:
+ self.error_log.append(err)
+ self.recv_log.append(got)
+ max_biterr = np.full_like(err, self.error_bits)
+ if not (err <= max_biterr).all():
+ self.log.error("Failed | Got: %20s Exp: %20s Err: %14s" % (mg, me, err))
+ assert False, "Test Failed!"
+ return
+ else:
+ assert False, "Not implemented"
\ No newline at end of file
diff --git a/src/mase_cocotb/monitor.py b/src/mase_cocotb/monitor.py
index 682dd2f6e..7ae5a26d2 100644
--- a/src/mase_cocotb/monitor.py
+++ b/src/mase_cocotb/monitor.py
@@ -15,7 +15,7 @@ def __init__(self, clk, check=True, name=None):
self.exp_queue = Queue()
self.check = check
self.name = name
- self.in_flight = False
+ self.in_flight = True
if not hasattr(self, "log"):
self.log = SimLog(
diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py
index 21490555a..cc68fb6a8 100644
--- a/src/mase_cocotb/runner.py
+++ b/src/mase_cocotb/runner.py
@@ -65,6 +65,7 @@ def _single_test(
comp_path: Path,
test_work_dir: Path,
sim: str = "verilator",
+ gui: bool = False,
extra_build_args: list[str] = [],
seed: int = None,
trace: bool = False,
@@ -126,6 +127,7 @@ def _single_test(
seed=seed,
results_xml="results.xml",
build_dir=test_work_dir,
+ gui=gui,
)
num_tests, fail = get_results(test_work_dir.joinpath("results.xml"))
except Exception as e:
@@ -144,6 +146,7 @@ def mase_runner(
group=None,
module_param_list: list[dict[str, Any]] = [dict()],
sim: str = "verilator",
+ gui: str = False,
extra_build_args: list[str] = [],
seed: int = None,
jobs: int = 1,
@@ -206,6 +209,7 @@ def mase_runner(
comp_path=comp_path,
test_work_dir=test_work_dir,
sim=sim,
+ gui=gui,
extra_build_args=extra_build_args,
seed=seed,
trace=trace,
@@ -237,6 +241,7 @@ def mase_runner(
comp_path=comp_path,
test_work_dir=test_work_dir,
sim=sim,
+ gui=gui,
extra_build_args=extra_build_args,
seed=seed,
trace=trace,
diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py
index be535dba5..e7f7293dd 100644
--- a/src/mase_cocotb/testbench.py
+++ b/src/mase_cocotb/testbench.py
@@ -38,10 +38,6 @@ def get_parameter(self, parameter_name):
parameter = getattr(self.dut, parameter_name)
return int(parameter)
- def get_parameter(self, parameter_name):
- parameter = getattr(self.dut, parameter_name)
- return int(parameter)
-
async def reset(self, active_high=True):
if self.rst is None:
raise Exception(
@@ -53,6 +49,10 @@ async def reset(self, active_high=True):
self.rst.value = 1 if active_high else 0
await RisingEdge(self.clk)
self.rst.value = 0 if active_high else 1
+ for monitor in self.output_monitors.values():
+ monitor.ready.value = 1
+ for driver in self.input_drivers.values():
+ driver.valid.value = 0
await RisingEdge(self.clk)
async def initialize(self):
diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py
index 469681cd1..d62d37694 100644
--- a/src/mase_cocotb/utils.py
+++ b/src/mase_cocotb/utils.py
@@ -12,7 +12,7 @@
from mase_cocotb.z_qlayers import quantize_to_int
from functools import partial
-from chop.nn.quantizers import integer_quantizer
+from chop.nn.quantizers import integer_quantizer, integer_floor_quantizer
# Apparently this function only exists in Python 3.12 ...
@@ -101,7 +101,9 @@ def product_dict(**kwargs):
yield dict(zip(keys, instance))
-def fixed_preprocess_tensor(tensor: Tensor, q_config: dict, parallelism: list) -> list:
+def fixed_preprocess_tensor(
+ tensor: Tensor, q_config: dict, parallelism: list, floor=False
+) -> list:
"""Preprocess a tensor before driving it into the DUT.
1. Quantize to requested fixed-point precision.
2. Convert to integer format to be compatible with Cocotb drivers.
@@ -125,12 +127,13 @@ def fixed_preprocess_tensor(tensor: Tensor, q_config: dict, parallelism: list) -
tensor = tensor.view((-1, tensor.shape[-1]))
# Quantize
- quantizer = partial(integer_quantizer, **q_config)
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+ quantizer = partial(base_quantizer, **q_config)
q_tensor = quantizer(tensor)
-
+ # breakpoint()
# Convert to integer format
q_tensor = (q_tensor * 2 ** q_config["frac_width"]).int()
- q_tensor = signed_to_unsigned(q_tensor, bits=q_config["width"])
+ # q_tensor = signed_to_unsigned(q_tensor, bits=q_config["width"])
# Split into chunks according to parallelism in each dimension
# parallelism[0]: along rows, parallelism[1]: along columns
@@ -175,3 +178,27 @@ def fixed_cast(val, in_width, in_frac_width, out_width, out_frac_width):
val = val
# val = int(val % (1 << out_width))
return val # << out_frac_width # treat data as data
+
+
+async def check_signal(dut, log, signal_list):
+ # TODO: support count start
+ # TODO: support checking signal with different name in valid and ready signal
+ def handshake_signal_check(
+ dut, log, signal_base, valid=None, ready=None, count_start: dict = {}
+ ):
+ data_valid = getattr(dut, f"{signal_base}_valid") if valid is None else valid
+ data_ready = getattr(dut, f"{signal_base}_ready") if ready is None else ready
+ data = getattr(dut, signal_base)
+ svalue = [i.signed_integer for i in data.value]
+ if data_valid.value & data_ready.value:
+ count_start[signal_base] = (
+ count_start[signal_base] + 1
+ if count_start.get(signal_base) is not None
+ else " "
+ )
+ log.debug(f"handshake {count_start[signal_base]} {signal_base} = {svalue}")
+
+ while True:
+ await RisingEdge(dut.clk)
+ for signal in signal_list:
+ handshake_signal_check(dut, log, signal)
diff --git a/src/mase_components/__init__.py b/src/mase_components/__init__.py
index de4c54db6..5779476a7 100644
--- a/src/mase_components/__init__.py
+++ b/src/mase_components/__init__.py
@@ -10,9 +10,21 @@ def get_modules():
for d in os.listdir(current_dir)
if os.path.isdir(os.path.join(current_dir, d))
]
- if "__pycache__" in mods:
- mods.remove("__pycache__")
- return mods
+ detailed_mods = []
+ for mod in mods:
+ new_dir = os.path.join(current_dir, mod)
+ if "rtl" in os.listdir(new_dir):
+ detailed_mods.append(mod)
+ else:
+ update_mods = [
+ mod + "/" + d
+ for d in os.listdir(new_dir)
+ if os.path.isdir(os.path.join(new_dir, d))
+ ]
+ detailed_mods += update_mods
+ if "__pycache__" in detailed_mods:
+ detailed_mods.remove("__pycache__")
+ return detailed_mods
def get_group_files(group):
@@ -27,7 +39,7 @@ def get_group_files(group):
def get_module_dependencies(module):
- group, mod = module.split("/")
+ # group, mod = module.split("/")
group_deps = MASE_HW_DEPS.get(module, [])
file_deps = []
for group_dep in group_deps:
diff --git a/src/mase_components/activation_layers/rtl/fixed_gelu.sv b/src/mase_components/activation_layers/rtl/fixed_gelu.sv
index 9c0cc4235..a1adaf7d5 100644
--- a/src/mase_components/activation_layers/rtl/fixed_gelu.sv
+++ b/src/mase_components/activation_layers/rtl/fixed_gelu.sv
@@ -14,12 +14,12 @@ module fixed_gelu #(
parameter DATA_OUT_0_PRECISION_0 = 8,
parameter DATA_OUT_0_PRECISION_1 = 4,
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 10,
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1,
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = 1,
- parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1,
- parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1,
- parameter DATA_OUT_0_PARALLELISM_DIM_2 = 1
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0,
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1,
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2,
+ parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0,
+ parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1,
+ parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2
) (
/* verilator lint_off UNUSEDSIGNAL */
input clk,
@@ -34,80 +34,19 @@ module fixed_gelu #(
output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]
);
- logic [DATA_IN_0_PRECISION_0-1:0] ff_data[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0];
- logic [DATA_IN_0_PRECISION_0-1:0] roll_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
-
- logic ff_data_valid;
- logic ff_data_ready;
-
- logic roll_data_valid;
- logic roll_data_ready;
-
- unpacked_fifo #(
- .DEPTH(IN_0_DEPTH),
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .IN_NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
- ) roller_buffer (
- .clk(clk),
- .rst(rst),
- .data_in(data_in_0),
- .data_in_valid(data_in_0_valid),
- .data_in_ready(data_in_0_ready), // write enable
- .data_out(ff_data),
- .data_out_valid(ff_data_valid),
- .data_out_ready(ff_data_ready) // read enable
- );
-
- localparam STRAIGHT_THROUGH = (DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1 == DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1);
-
- generate
- if (STRAIGHT_THROUGH) begin
- unpacked_register_slice_quick #(
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
- ) single_roll (
- .clk(clk),
- .rst(rst),
- .in_data(ff_data),
- .in_valid(ff_data_valid),
- .in_ready(ff_data_ready),
- .out_data(roll_data),
- .out_valid(roll_data_valid),
- .out_ready(roll_data_ready)
- );
-
- end else begin
-
- roller #(
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1),
- .ROLL_NUM(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1)
- ) roller_inst (
- .clk(clk),
- .rst(rst),
- .data_in(ff_data),
- .data_in_valid(ff_data_valid),
- .data_in_ready(ff_data_ready),
- .data_out(roll_data),
- .data_out_valid(roll_data_valid),
- .data_out_ready(roll_data_ready)
- );
- end
- endgenerate
-
- for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++) begin : elu
+ for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++) begin : gelu
gelu_lut #(
.DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0),
.DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1),
.DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0),
.DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1)
) elu_map (
- .data_in_0 (roll_data[i]),
+ .data_in_0 (data_in_0[i]),
.data_out_0(data_out_0[i])
);
end
- assign data_out_0_valid = roll_data_valid;
- assign roll_data_ready = data_out_0_ready;
+ assign data_out_0_valid = data_in_0_valid;
+ assign data_in_0_ready = data_out_0_ready;
endmodule
diff --git a/src/mase_components/activation_layers/rtl/fixed_softmax.sv b/src/mase_components/activation_layers/rtl/fixed_softmax.sv
index 1158f9d12..be2ddef2c 100644
--- a/src/mase_components/activation_layers/rtl/fixed_softmax.sv
+++ b/src/mase_components/activation_layers/rtl/fixed_softmax.sv
@@ -4,27 +4,23 @@ module fixed_softmax #(
parameter DATA_IN_0_PRECISION_0 = 8,
parameter DATA_IN_0_PRECISION_1 = 4,
parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 10, // input vector size
- parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, //
- parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, // incoming elements -
- parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, // batch size
+ parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 6, //
+ parameter DATA_IN_0_PARALLELISM_DIM_0 = 3, // incoming elements -
+ parameter DATA_IN_0_PARALLELISM_DIM_1 = 2, // batch size
parameter IN_0_DEPTH = $rtoi($ceil(DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0)),
- parameter DATA_OUT_0_PRECISION_0 = 8,
parameter DATA_OUT_0_PRECISION_1 = 4,
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 10,
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1,
- parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1,
- parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1,
+ parameter DATA_OUT_0_PRECISION_0 = DATA_OUT_0_PRECISION_1 + 2,
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0,
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1,
+ parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0,
+ parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1,
- parameter OUT_0_DEPTH = $rtoi(
- $ceil(DATA_OUT_0_TENSOR_SIZE_DIM_0 / DATA_OUT_0_PARALLELISM_DIM_0)
- ),
+ parameter OUT_0_DEPTH = IN_0_DEPTH,
- parameter DATA_INTERMEDIATE_0_PRECISION_0 = DATA_IN_0_PRECISION_0,
- parameter DATA_INTERMEDIATE_0_PRECISION_1 = DATA_IN_0_PRECISION_1,
-
- parameter IN_PLACE = 0
+ parameter DATA_EXP_0_PRECISION_0 = 12,
+ parameter DATA_EXP_0_PRECISION_1 = 8
) (
/* verilator lint_off UNUSEDSIGNAL */
input rst,
@@ -43,13 +39,12 @@ module fixed_softmax #(
// Can handle multiple batches at once
// each iteration recieves a batch of blocks
- logic [DATA_IN_0_PRECISION_0-1:0] ff_data[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0];
logic [DATA_IN_0_PRECISION_0-1:0] roll_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
- logic [DATA_INTERMEDIATE_0_PRECISION_0-1:0] exp_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
- logic [DATA_INTERMEDIATE_0_PRECISION_0-1:0] ff_exp_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
+ logic [DATA_EXP_0_PRECISION_0-1:0] exp_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
+ logic [DATA_EXP_0_PRECISION_0-1:0] ff_exp_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
- logic ff_data_valid;
- logic ff_data_ready;
+ // logic ff_data_valid;
+ // logic ff_data_ready;
logic roll_data_valid;
logic roll_data_ready;
@@ -60,7 +55,7 @@ module fixed_softmax #(
logic ff_exp_data_valid;
logic ff_exp_data_ready;
- localparam SUM_WIDTH = $clog2(DATA_OUT_0_PARALLELISM_DIM_0) + DATA_INTERMEDIATE_0_PRECISION_0;
+ localparam SUM_WIDTH = $clog2(DATA_OUT_0_PARALLELISM_DIM_0) + DATA_EXP_0_PRECISION_0;
localparam ACC_WIDTH = $clog2(OUT_0_DEPTH) + SUM_WIDTH;
logic [SUM_WIDTH-1:0] summed_exp_data[DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // sum of current block
@@ -71,68 +66,15 @@ module fixed_softmax #(
logic [ACC_WIDTH-1:0] accumulated_exp_data [DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // accumulation of total vector
logic [ACC_WIDTH-1:0] ff_accumulated_exp_data [DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // accumulation of total vector
-
+ logic [ACC_WIDTH-1:0] ff_accumulated_exp_data_dup [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // duplication accumulation of total vector
logic acc_out_valid[DATA_OUT_0_PARALLELISM_DIM_1-1:0];
logic acc_out_ready;
logic ff_acc_valid;
logic ff_acc_ready;
-
- unpacked_fifo #(
- .DEPTH(IN_0_DEPTH),
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .IN_NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
- ) roller_buffer (
- .clk(clk),
- .rst(rst),
- .data_in(data_in_0),
- .data_in_valid(data_in_0_valid),
- .data_in_ready(data_in_0_ready), // write enable
- .data_out(ff_data),
- .data_out_valid(ff_data_valid),
- .data_out_ready(ff_data_ready) // read enable
- );
-
- localparam STRAIGHT_THROUGH = (DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1 == DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1);
-
- generate
- if (STRAIGHT_THROUGH) begin
- unpacked_register_slice_quick #(
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
- ) single_roll (
- .clk(clk),
- .rst(rst),
- .in_data(ff_data),
- .in_valid(ff_data_valid),
- .in_ready(ff_data_ready),
- .out_data(roll_data),
- .out_valid(roll_data_valid),
- .out_ready(roll_data_ready)
- );
-
- end else begin
-
- roller #(
- .DATA_WIDTH(DATA_IN_0_PRECISION_0),
- .NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1),
- .ROLL_NUM(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1)
- ) roller_inst (
- .clk(clk),
- .rst(rst),
- .data_in(ff_data),
- .data_in_valid(ff_data_valid),
- .data_in_ready(ff_data_ready),
- .data_out(roll_data),
- .data_out_valid(roll_data_valid),
- .data_out_ready(roll_data_ready)
- );
- end
- endgenerate
-
split2 #() input_handshake_split (
- .data_in_valid (roll_data_valid),
- .data_in_ready (roll_data_ready),
+ .data_in_valid (data_in_0_valid),
+ .data_in_ready (data_in_0_ready),
.data_out_valid({buffer_valid, summed_in_valid}),
.data_out_ready({buffer_ready, summed_in_ready[0]})
);
@@ -144,17 +86,17 @@ module fixed_softmax #(
exp_lut #(
.DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0),
.DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1),
- .DATA_OUT_0_PRECISION_0(DATA_INTERMEDIATE_0_PRECISION_0),
- .DATA_OUT_0_PRECISION_1(DATA_INTERMEDIATE_0_PRECISION_1)
+ .DATA_OUT_0_PRECISION_0(DATA_EXP_0_PRECISION_0),
+ .DATA_OUT_0_PRECISION_1(DATA_EXP_0_PRECISION_1)
) exp_map (
- .data_in_0 (roll_data[i]),
+ .data_in_0 (data_in_0[i]),
.data_out_0(exp_data[i])
);
end
unpacked_fifo #(
- .DEPTH(OUT_0_DEPTH),
- .DATA_WIDTH(DATA_INTERMEDIATE_0_PRECISION_0),
+ .DEPTH(OUT_0_DEPTH * 8),
+ .DATA_WIDTH(DATA_EXP_0_PRECISION_0),
.IN_NUM(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1)
) out_roller_buffer (
.clk(clk),
@@ -173,7 +115,7 @@ module fixed_softmax #(
if (DATA_OUT_0_PARALLELISM_DIM_0 > 1) begin
fixed_adder_tree #(
.IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0),
- .IN_WIDTH(DATA_INTERMEDIATE_0_PRECISION_0)
+ .IN_WIDTH(DATA_EXP_0_PRECISION_0)
) block_sum (
.clk(clk),
.rst(rst),
@@ -209,10 +151,11 @@ module fixed_softmax #(
end
endgenerate
- hold_buffer #(
+ input_buffer #(
.DATA_WIDTH(ACC_WIDTH),
- .DATA_SIZE(DATA_OUT_0_PARALLELISM_DIM_1),
- .DEPTH(OUT_0_DEPTH)
+ .IN_NUM(DATA_OUT_0_PARALLELISM_DIM_1),
+ .BUFFER_SIZE(1),
+ .REPEAT(IN_0_DEPTH)
) acc_buffer (
.clk(clk),
.rst(rst),
@@ -224,127 +167,45 @@ module fixed_softmax #(
.data_out_ready(ff_acc_ready) // read enable
);
+ //TODO: change to register slice
- logic [DATA_INTERMEDIATE_0_PRECISION_0 - 1 :0] inter_quotient1 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
- logic [DATA_INTERMEDIATE_0_PRECISION_0 - 1 :0] inter_quotient2 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
- logic [DATA_INTERMEDIATE_0_PRECISION_0 - 1 :0] inter_quotient3[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
- logic [DATA_INTERMEDIATE_0_PRECISION_0 - 1 :0] inter_quotient4 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
- logic [DATA_INTERMEDIATE_0_PRECISION_0 + DATA_INTERMEDIATE_0_PRECISION_1 - 1 :0] extended_divisor [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
-
- logic [DATA_INTERMEDIATE_0_PRECISION_0 + DATA_INTERMEDIATE_0_PRECISION_1 - 1 :0] extended_quotient [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for quantization
- logic [DATA_INTERMEDIATE_0_PRECISION_0 - 1 :0] inter_quotient [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for quantization
-
-
+ logic [DATA_EXP_0_PRECISION_0 + DATA_OUT_0_PRECISION_1 - 1 :0] extended_divisor [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for rounding division
+ logic [DATA_OUT_0_PRECISION_0 + DATA_OUT_0_PRECISION_1 - 1 :0] extended_quotient [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; // extra bit for quantization
for (genvar i = 0; i < DATA_OUT_0_PARALLELISM_DIM_1; i++) begin : scale_batches
for (genvar j = 0; j < DATA_OUT_0_PARALLELISM_DIM_0; j++) begin : div_elements
always_comb begin
- extended_divisor[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = ff_exp_data[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] << DATA_INTERMEDIATE_0_PRECISION_1;
- extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = extended_divisor[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] / ff_accumulated_exp_data[i];
- inter_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j][DATA_INTERMEDIATE_0_PRECISION_0-1:0];
- // data_out_0[DATA_OUT_0_PARALLELISM_DIM_1*(i) + j] = extended_quotient[DATA_OUT_0_PARALLELISM_DIM_1*(i) + j][DATA_OUT_0_PRECISION_0-1:0];
+ extended_divisor[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = ff_exp_data[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] << DATA_OUT_0_PRECISION_1;
+ ff_accumulated_exp_data_dup[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = ff_accumulated_exp_data[i];
+ // extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = extended_divisor[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] / ff_accumulated_exp_data[i];
+ data_out_0[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j] = extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j][DATA_OUT_0_PRECISION_0-1:0];
end
- // quick_round #(
- // .DATA_WIDTH(DATA_OUT_0_PRECISION_0)
- // ) round (
- // .data_in(extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j][DATA_OUT_0_PRECISION_0-1:1]),
- // .round_bit(extended_quotient[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j][0]),
- // .data_out(data_out_0[DATA_OUT_0_PARALLELISM_DIM_0*(i) + j])
- // );
end
end
- // assign data_out_0 = inter_quotient;
- // Divide pipeline (retiming)
- logic data_out_0_valid_0;
- logic data_out_0_valid_1;
- logic data_out_0_valid_2;
- logic data_out_0_valid_3;
- logic data_out_0_valid_4;
- always_ff @(posedge clk) begin
- inter_quotient1 <= inter_quotient;
- inter_quotient2 <= inter_quotient1;
- inter_quotient3 <= inter_quotient2;
- inter_quotient4 <= inter_quotient3;
-
- data_out_0_valid_1 <= data_out_0_valid_0;
- data_out_0_valid_2 <= data_out_0_valid_1;
- data_out_0_valid_3 <= data_out_0_valid_2;
- data_out_0_valid_4 <= data_out_0_valid_3;
-
- end
-
- fixed_rounding #(
- .IN_SIZE(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1),
- .IN_WIDTH(DATA_INTERMEDIATE_0_PRECISION_0),
- .IN_FRAC_WIDTH(DATA_INTERMEDIATE_0_PRECISION_1),
- .OUT_WIDTH(DATA_OUT_0_PRECISION_0),
- .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1)
- ) data_out_cast (
- .data_in (inter_quotient4),
- .data_out(data_out_0)
- );
-
- join2 #() output_handshake_split (
- .data_in_valid ({ff_exp_data_valid, ff_acc_valid}),
- .data_in_ready ({ff_exp_data_ready, ff_acc_ready}),
- .data_out_valid(data_out_0_valid_0),
- .data_out_ready(data_out_0_ready)
+ // join2 #() output_handshake_split (
+ // .data_in_valid ({ff_exp_data_valid, ff_acc_valid}),
+ // .data_in_ready ({ff_exp_data_ready, ff_acc_ready}),
+ // .data_out_valid(data_out_0_valid),
+ // .data_out_ready(data_out_0_ready)
+ // );
+ fixed_div #(
+ .IN_NUM(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1),
+ .DIVIDEND_WIDTH(DATA_EXP_0_PRECISION_0 + DATA_OUT_0_PRECISION_1),
+ .DIVISOR_WIDTH(ACC_WIDTH),
+ .QUOTIENT_WIDTH(DATA_OUT_0_PRECISION_0 + DATA_OUT_0_PRECISION_1),
+ .FIFO_DEPTH(DATA_OUT_0_TENSOR_SIZE_DIM_0 * DATA_OUT_0_TENSOR_SIZE_DIM_1 / (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1))
+ ) div_inst (
+ .clk(clk),
+ .rst(rst),
+ .dividend_data(extended_divisor),
+ .dividend_data_valid(ff_exp_data_valid),
+ .dividend_data_ready(ff_exp_data_ready),
+ .divisor_data(ff_accumulated_exp_data_dup),
+ .divisor_data_valid(ff_acc_valid),
+ .divisor_data_ready(ff_acc_ready),
+ .quotient_data(extended_quotient),
+ .quotient_data_valid(data_out_0_valid),
+ .quotient_data_ready(data_out_0_ready)
);
-
- assign data_out_0_valid = data_out_0_valid_4;
-endmodule
-
-/* verilator lint_off DECLFILENAME */
-
-module hold_buffer #(
- parameter DATA_WIDTH = 16,
- parameter DATA_SIZE = 4,
- parameter DEPTH = 1
-) (
- input rst,
- input clk,
-
- input logic [DATA_WIDTH - 1:0] data_in[DATA_SIZE - 1:0],
- input logic data_in_valid,
- output logic data_in_ready,
-
- output logic [DATA_WIDTH - 1:0] data_out[DATA_SIZE - 1:0],
- output logic data_out_valid,
- input logic data_out_ready
-);
-
- logic [$clog2(DEPTH) : 0] count;
- logic [ DATA_WIDTH - 1:0] data_out_register[DATA_SIZE - 1:0];
- assign data_out = data_out_register;
- always_ff @(posedge clk) begin
- if (rst) begin
- count <= 0;
- // data_out_register <= 0;
- data_out_valid <= 0;
- data_in_ready <= 1;
- end else begin
- if (count == 0) begin
- // The buffer is empty
- if (data_in_valid) begin
- data_out_register <= data_in;
- count <= DEPTH;
- data_out_valid <= 1;
- data_in_ready <= 0;
- end else begin
- data_in_ready <= data_out_ready;
- data_out_valid <= 0;
- end
- end else begin
- // The buffer has data
- if (data_out_ready) begin
- count <= count - 1;
- end else begin
- count <= count;
- end
- end
- end
- end
-
- // take an input and output it for depth length preventing further input from entering.
endmodule
diff --git a/src/mase_components/activation_layers/rtl/fixed_tanh.sv b/src/mase_components/activation_layers/rtl/fixed_tanh.sv
index 4cb3e044f..5492e24ec 100644
--- a/src/mase_components/activation_layers/rtl/fixed_tanh.sv
+++ b/src/mase_components/activation_layers/rtl/fixed_tanh.sv
@@ -2,21 +2,21 @@
module fixed_tanh #(
/* verilator lint_off UNUSEDPARAM */
- parameter DATA_IN_0_PRECISION_0 = 16, //total number of bits used to represent each input data
- parameter DATA_IN_0_PRECISION_1 = 8, //fractional bits
- parameter DATA_IN_0_PRECISION_INT = DATA_IN_0_PRECISION_0 - DATA_IN_0_PRECISION_1, //number of integer bits
-
- parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 8, //total input data per tensor along dim 0
- parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, //total input data per tensor along dim 1
- parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, //input data along dim 0 coming in parallel in the same clock cycle
- parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, //input data along dim 1 coming in parallel in the same clock cycle
-
- parameter DATA_OUT_0_PRECISION_0 = 16, //total number of bits used to represent each output data. Typically needs only (2 + fractional) bits since tanh varies between +/-1.
- parameter DATA_OUT_0_PRECISION_1 = 8, //fractional bits. Output of the module is rounded to satisfy this value
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 8, //total output data per tensor along dim 0
- parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, //total output data per tensor along dim 1
- parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, //output data along dim 0 going out in parallel in the same clock cycle
- parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1 //output data along dim 1 going out in parallel in the same clock cycle
+ parameter DATA_IN_0_PRECISION_0 = 16, //total number of bits used to represent each input data
+ parameter DATA_IN_0_PRECISION_1 = 8, //fractional bits
+ parameter DATA_IN_0_PRECISION_INT = DATA_IN_0_PRECISION_0 - DATA_IN_0_PRECISION_1, //number of integer bits
+
+ parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 8, //total input data per tensor along dim 0
+ parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, //total input data per tensor along dim 1
+ parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, //input data along dim 0 coming in parallel in the same clock cycle
+ parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, //input data along dim 1 coming in parallel in the same clock cycle
+
+ parameter DATA_OUT_0_PRECISION_0 = 16, //total number of bits used to represent each output data. Typically needs only (2 + fractional) bits since tanh varies between +/-1.
+ parameter DATA_OUT_0_PRECISION_1 = 8, //fractional bits. Output of the module is rounded to satisfy this value
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 8, //total output data per tensor along dim 0
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, //total output data per tensor along dim 1
+ parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, //output data along dim 0 going out in parallel in the same clock cycle
+ parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1 //output data along dim 1 going out in parallel in the same clock cycle
) (
/* verilator lint_off UNUSEDSIGNAL */
@@ -35,15 +35,15 @@ module fixed_tanh #(
data_out_valid1,
data_out_valid2,
data_out_valid3,
- data_out_valid4; //used to store delayed version of input data valid which is given as output datavalid
+ data_out_valid4; //used to store delayed version of input data valid which is given as output datavalid
- //constants a and b that divides the input range. Stored with 32 bit precision. However, they are rounded to the input precision once specified.
+ //constants a and b that divides the input range. Stored with 32 bit precision. However, they are rounded to the input precision once specified.
const logic signed [33 : 0] a = 34'b0110000101000111101011100001010001;
const logic signed [34 : 0] b = 35'b01010010001111010111000010100011110;
logic signed [DATA_IN_0_PRECISION_0-1:0] a_fixed, b_fixed;
- //rounding a to input precision
+ //rounding a to input precision
fixed_round #(
.IN_WIDTH(34),
.IN_FRAC_WIDTH(32),
@@ -54,7 +54,7 @@ module fixed_tanh #(
.data_out(a_fixed)
);
- //rounding b to input precision
+ //rounding b to input precision
fixed_round #(
.IN_WIDTH(35),
.IN_FRAC_WIDTH(32),
@@ -65,18 +65,18 @@ module fixed_tanh #(
.data_out(b_fixed)
);
- //constants for polynomial approximation. 16 bit fractional precision is used. c1 is 1. Hence not stored.
+ //constants for polynomial approximation. 16 bit fractional precision is used. c1 is 1. Hence not stored.
const logic signed [16 : 0] m1 = 17'b11011101001110111;
const logic signed [16 : 0] d1 = 17'b00000010000011000;
const logic signed [16 : 0] m2 = 17'b11110101001001100;
const logic signed [16 : 0] c2 = 17'b00110110100110001;
const logic signed [16 : 0] d2 = 17'b00111001110101111;
- //generating computation block for each parallel input data
+ //generating computation block for each parallel input data
for (
genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++
) begin : tanh
- // Local variables for computation
+ // Local variables for computation
logic signed [DATA_IN_0_PRECISION_0-1:0] data_in1;
logic signed [DATA_IN_0_PRECISION_0-1:0] data_in2;
logic signed [DATA_IN_0_PRECISION_0-1:0] data_in3;
@@ -95,20 +95,20 @@ module fixed_tanh #(
assign x_abs = ($signed(
data_in_0[i]
- ) >= 0) ? data_in_0[i] : -data_in_0[i]; //calculation of absolute value
+ ) >= 0) ? data_in_0[i] : -data_in_0[i]; //calculation of absolute value
assign x_abs_dum = x_abs;
- assign x_squared = x_abs * x_abs; //squaring of absolute value
+ assign x_squared = x_abs * x_abs; //squaring of absolute value
always_ff @(posedge clk) begin
- if (rst) begin //reset conditions
+ if (rst) begin //reset conditions
term0 <= 0;
term1 <= 0;
term2 <= 0;
temp_result <= 0;
- end
- else if (data_out_0_ready && (data_in_0_valid ||data_out_valid1||data_out_valid2)) begin //Calculation of polynomial approximation.Computation is performed in two pipelined stages
+ end
+ else if (data_out_0_ready && (data_in_0_valid ||data_out_valid1||data_out_valid2)) begin //Calculation of polynomial approximation.Computation is performed in two pipelined stages
if (x_abs_dum <= a_fixed) begin
term0 <= 0;
end else if (x_abs_dum <= b_fixed) begin
@@ -155,7 +155,7 @@ module fixed_tanh #(
data_in3 <= data_in2;
end
- //rounding of the output result
+ //rounding of the output result
fixed_round #(
.IN_WIDTH(2 * DATA_IN_0_PRECISION_0 + 17),
.IN_FRAC_WIDTH(2 * DATA_IN_0_PRECISION_1 + 16),
@@ -165,7 +165,7 @@ module fixed_tanh #(
.data_in (temp_result),
.data_out(temp_out)
);
- //assigning the output with sign based on sign of the input.
+ //assigning the output with sign based on sign of the input.
assign data_out_0[i] = (data_in3 >= 0) ? temp_out : -temp_out;
end
diff --git a/src/mase_components/activation_layers/rtl/softermax_lpw_reciprocal.sv b/src/mase_components/activation_layers/rtl/softermax_lpw_reciprocal.sv
index 8d0fd6d60..64df3290b 100644
--- a/src/mase_components/activation_layers/rtl/softermax_lpw_reciprocal.sv
+++ b/src/mase_components/activation_layers/rtl/softermax_lpw_reciprocal.sv
@@ -41,6 +41,9 @@ module softermax_lpw_reciprocal #(
// Parameters
// -----
+ // let max(a, b) = (a > b) ? a : b;
+ // This is not syntheable
+
localparam ENTRIES_WIDTH = $clog2(ENTRIES);
// Range reduced num: x
diff --git a/src/mase_components/activation_layers/test/fixed_softmax_tb.py b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
index 62f37012f..b7c725ac0 100644
--- a/src/mase_components/activation_layers/test/fixed_softmax_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
@@ -1,220 +1,191 @@
#!/usr/bin/env python3
+import os
import pytest
-import os, logging
-from . import generate_memory
-import pdb
-from bitstring import BitArray
-import cocotb
-from functools import partial
-from cocotb.triggers import *
-from chop.nn.quantizers import integer_quantizer
-from mase_cocotb.testbench import Testbench
-from mase_cocotb.interfaces.streaming import (
- StreamDriver,
- StreamMonitor,
- StreamMonitorFloat,
-)
-from mase_cocotb.z_qlayers import quantize_to_int
-from mase_cocotb.runner import mase_runner
-from mase_cocotb.utils import bit_driver, sign_extend_t
-from math import ceil
-
-# from chop.passes.graph.transforms.quantize.quantized_modules import LinearInteger
import torch
+import logging
+from functools import partial
+from mase_components.helper import generate_memory
+from pathlib import Path
+import cocotb
+from cocotb.log import SimLog
+from cocotb.triggers import Timer
-logger = logging.getLogger("testbench")
-logger.setLevel(logging.INFO)
-
-
-def split_and_flatten_2d_tensor(input_tensor, row_block_size, col_block_size):
- rows, cols = input_tensor.size()
-
- num_row_blocks = rows // row_block_size
- num_col_blocks = cols // col_block_size
+from mase_cocotb.testbench import Testbench
+from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor
+from mase_cocotb.runner import mase_runner
+from mase_cocotb.utils import fixed_preprocess_tensor
- reshaped_tensor = input_tensor.view(
- num_row_blocks, row_block_size, num_col_blocks, col_block_size
- )
- reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
- flattened_tensor = reshaped_tensor.view(-1, row_block_size * col_block_size)
- return flattened_tensor
+from mase_cocotb.utils import bit_driver
+from chop.nn.quantized.functional import softmax_integer
-class fixed_softmax_tb(Testbench):
- def __init__(self, module, dut, dut_params, float_test=False) -> None:
+class SoftmaxTB(Testbench):
+ def __init__(self, dut) -> None:
super().__init__(dut, dut.clk, dut.rst)
- self.data_width = dut_params["DATA_IN_0_PRECISION_0"]
- self.frac_width = dut_params["DATA_IN_0_PRECISION_1"]
-
- self.outputwidth = dut_params["DATA_OUT_0_PRECISION_0"]
- self.outputfracw = dut_params["DATA_OUT_0_PRECISION_1"]
-
- self.num_in_features = dut_params["DATA_IN_0_TENSOR_SIZE_DIM_0"]
- self.num_in_batches = dut_params["DATA_IN_0_TENSOR_SIZE_DIM_1"]
+ if not hasattr(self, "log"):
+ self.log = SimLog("%s" % (type(self).__qualname__))
+ self.log.setLevel(logging.DEBUG)
- self.size_in_feature_blocks = dut_params["DATA_IN_0_PARALLELISM_DIM_0"]
- self.size_in_batch_blocks = dut_params["DATA_IN_0_PARALLELISM_DIM_1"]
-
- self.num_in_feature_splits = int(
- ceil(self.num_in_features / self.size_in_feature_blocks)
- )
- self.num_in_batch_splits = int(
- ceil(self.num_in_batches / self.size_in_batch_blocks)
+ self.in_data_driver = StreamDriver(
+ dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready
)
- self.num_out_features = dut_params["DATA_OUT_0_TENSOR_SIZE_DIM_0"]
- self.num_out_batches = dut_params["DATA_OUT_0_TENSOR_SIZE_DIM_1"]
-
- self.size_out_feature_blocks = dut_params["DATA_OUT_0_PARALLELISM_DIM_0"]
- self.size_out_batch_blocks = dut_params["DATA_OUT_0_PARALLELISM_DIM_1"]
-
- self.num_out_feature_splits = int(
- ceil(self.num_out_features / self.size_out_feature_blocks)
+ self.out_data_monitor = StreamMonitor(
+ dut.clk,
+ dut.data_out_0,
+ dut.data_out_0_valid,
+ dut.data_out_0_ready,
+ check=True,
)
- self.num_out_batch_splits = int(
- ceil(self.num_out_batches / self.size_out_batch_blocks)
+ # Model
+ self.model = partial(
+ softmax_integer,
+ config={
+ "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"),
+ "data_in_frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"),
+ "data_in_exp_width": self.get_parameter("DATA_EXP_0_PRECISION_0"),
+ "data_in_exp_frac_width": self.get_parameter("DATA_EXP_0_PRECISION_1"),
+ "data_out_frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"),
+ "mult_data": CONSTANT_MULT,
+ },
+ dim=-1,
+ floor=True,
)
- self.data_in_0_driver = StreamDriver(
- dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready
- )
+ # Set verbosity of driver and monitor loggers to debug
+ self.in_data_driver.log.setLevel(logging.DEBUG)
+ self.out_data_monitor.log.setLevel(logging.DEBUG)
- if float_test:
- self.data_out_0_monitor = StreamMonitorFloat(
- dut.clk,
- dut.data_out_0,
- dut.data_out_0_valid,
- dut.data_out_0_ready,
- self.outputwidth,
- self.outputfracw,
- )
- else:
- self.data_out_0_monitor = StreamMonitor(
- dut.clk, dut.data_out_0, dut.data_out_0_valid, dut.data_out_0_ready
+ def generate_inputs(self):
+ return torch.randn(
+ (
+ self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_1"),
+ self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"),
)
-
- self.in_dquantizer = partial(
- integer_quantizer,
- width=self.data_width,
- frac_width=self.frac_width,
- is_signed=True,
)
- self.out_dquantizer = partial(
- integer_quantizer,
- width=self.outputwidth,
- frac_width=self.outputfracw,
- is_signed=True,
- )
+ async def run_test(self, batches, us):
+ await self.reset()
+ self.log.info(f"Reset finished")
+
+ for _ in range(batches):
+ inputs = self.generate_inputs()
+ exp_out = self.model(inputs)
+
+ # * Load the inputs driver
+ self.log.info(f"Processing inputs: {inputs}")
+ inputs = fixed_preprocess_tensor(
+ tensor=inputs,
+ q_config={
+ "width": self.get_parameter("DATA_IN_0_PRECISION_0"),
+ "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"),
+ self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"),
+ ],
+ floor=True,
+ )
+ self.in_data_driver.load_driver(inputs)
+
+ # * Load the output monitor
+ self.log.info(f"Processing outputs: {exp_out}")
+ outs = fixed_preprocess_tensor(
+ tensor=exp_out,
+ q_config={
+ "width": self.get_parameter("DATA_OUT_0_PRECISION_0"),
+ "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"),
+ },
+ parallelism=[
+ self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"),
+ self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"),
+ ],
+ )
+ self.out_data_monitor.load_monitor(outs)
- self.model = module
+ await Timer(us, units="us")
+ assert self.out_data_monitor.exp_queue.empty()
- self.real_in_tensor = torch.randn(self.num_in_batches, self.num_in_features)
- self.quant_in_tensor = self.in_dquantizer(self.real_in_tensor)
- self.real_out_tensor = self.model(self.quant_in_tensor)
- logger.info(f"REAL IN TENSOR: \n{self.real_in_tensor}")
- logger.info(f"REAL OUT TENSOR: \n{self.real_out_tensor}")
+@cocotb.test()
+async def single_test(dut):
+ tb = SoftmaxTB(dut)
+ tb.out_data_monitor.ready.value = 1
+ await tb.run_test(batches=50, us=100)
- def exp(self):
- # Run the model with the provided inputs and return the expected integer outputs in the format expected by the monitor
- m = split_and_flatten_2d_tensor(
- self.real_out_tensor,
- self.size_out_batch_blocks,
- self.size_out_feature_blocks,
- ) # match output
- logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
- m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
- return m2
+# @cocotb.test()
+# async def repeated_mult(dut):
+# tb = SoftmaxTB(dut)
+# tb.out_data_monitor.ready.value = 1
+# await tb.run_test(batches=100, us=2000)
- def generate_inputs(self):
- # Generate the integer inputs for the DUT in the format expected by the driver
- inputs = split_and_flatten_2d_tensor(
- self.real_in_tensor, self.size_in_batch_blocks, self.size_in_feature_blocks
- )
- logger.info(f"FLOAT INPUT: \n{inputs}")
- inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
- return intinp, inputs
-
- def doubletofx(self, num, data_width, f_width, type="bin"):
- assert type == "bin" or type == "hex", "type can only be: 'hex' or 'bin'"
- intnum = int(num * 2 ** (f_width))
- intbits = BitArray(int=intnum, length=data_width)
- return str(intbits.bin) if type == "bin" else str(intbits)
-
- async def run_test(self):
- await self.reset()
- logger.info(f"Reset finished")
- self.data_out_0_monitor.ready.value = 1
- for i in range(1):
- inputs, real_tensor = self.generate_inputs()
- exp_out = self.exp()
- inputs = inputs.tolist()
- exp_out = exp_out.tolist()
- logger.info("Inputs and expected generated")
- logger.info(f"DUT IN: {inputs}")
- logger.info(f"DUT EXP OUT: {exp_out}")
- self.data_in_0_driver.load_driver(inputs)
- self.data_out_0_monitor.load_monitor(exp_out)
-
- await Timer(1000, units="us")
- assert self.data_out_0_monitor.exp_queue.empty()
+# @cocotb.test()
+# async def repeated_mult_backpressure(dut):
+# tb = SoftmaxTB(dut)
+# cocotb.start_soon(bit_driver(dut.data_out_0_ready, dut.clk, 0.6))
+# await tb.run_test(batches=10, us=500)
-@cocotb.test()
-async def cocotb_test(dut):
- in_data_width = dut_params["DATA_IN_0_PRECISION_0"]
- in_frac_width = dut_params["DATA_IN_0_PRECISION_1"]
- out_data_width = dut_params["DATA_OUT_0_PRECISION_0"]
- out_frac_width = dut_params["DATA_OUT_0_PRECISION_1"]
- inter_data_width = dut_params["DATA_INTERMEDIATE_0_PRECISION_0"]
- inter_frac_width = dut_params["DATA_INTERMEDIATE_0_PRECISION_1"]
- # generate_memory.generate_sv_lut("exp", in_data_width, in_frac_width, inter_data_width, inter_frac_width)
- # print("Generated memory")
- tb = fixed_softmax_tb(torch.nn.Softmax(), dut, dut_params, float_test=True)
- await tb.run_test()
+# @cocotb.test()
+# async def repeated_mult_valid_backpressure(dut):
+# tb = SoftmaxTB(dut)
+# tb.in_data_driver.set_valid_prob(0.7)
+# cocotb.start_soon(bit_driver(dut.data_out_0_ready, dut.clk, 0.6))
+# await tb.run_test(batches=50, us=200)
dut_params = {
- "DATA_IN_0_TENSOR_SIZE_DIM_0": 12,
- "DATA_IN_0_TENSOR_SIZE_DIM_1": 4,
- "DATA_IN_0_PARALLELISM_DIM_0": 6,
- "DATA_IN_0_PARALLELISM_DIM_1": 2,
"DATA_IN_0_PRECISION_0": 8,
"DATA_IN_0_PRECISION_1": 4,
- "DATA_OUT_0_PRECISION_0": 8,
- "DATA_OUT_0_PRECISION_1": 4,
- "DATA_OUT_0_TENSOR_SIZE_DIM_0": 12,
- "DATA_OUT_0_TENSOR_SIZE_DIM_1": 4,
- "DATA_OUT_0_PARALLELISM_DIM_0": 6,
- "DATA_OUT_0_PARALLELISM_DIM_1": 2,
- "DATA_INTERMEDIATE_0_PRECISION_0": 12,
- "DATA_INTERMEDIATE_0_PRECISION_1": 8,
+ "DATA_IN_0_TENSOR_SIZE_DIM_0": 32,
+ "DATA_IN_0_TENSOR_SIZE_DIM_1": 1,
+ "DATA_IN_0_PARALLELISM_DIM_0": 1,
+ "DATA_IN_0_PARALLELISM_DIM_1": 1,
+ "DATA_EXP_0_PRECISION_0": 8,
+ "DATA_EXP_0_PRECISION_1": 4,
+ "DATA_OUT_0_PRECISION_1": 6,
}
+
+def get_fixed_softmax_config(kwargs={}):
+ config = dut_params
+ config.update(kwargs)
+ return config
+
+
torch.manual_seed(1)
+CONSTANT_MULT = 0.19
@pytest.mark.dev
-def test_fixed_softmax():
- # generate_memory.generate_sv_lut("exp", dut_params["DATA_IN_0_PRECISION_0"], dut_params["DATA_IN_0_PRECISION_1"])
+def test_fixed_softmax_smoke():
+ """
+ Some quick tests to check if the module is working.
+ """
+ path = Path(__file__).parents[1] / "rtl"
generate_memory.generate_sv_lut(
"exp",
dut_params["DATA_IN_0_PRECISION_0"],
dut_params["DATA_IN_0_PRECISION_1"],
- dut_params["DATA_INTERMEDIATE_0_PRECISION_0"],
- dut_params["DATA_INTERMEDIATE_0_PRECISION_1"],
+ dut_params["DATA_EXP_0_PRECISION_0"],
+ dut_params["DATA_EXP_0_PRECISION_1"],
+ path=path,
+ constant_mult=CONSTANT_MULT,
+ floor=True,
+ )
+ mase_runner(
+ trace=True,
+ module_param_list=[
+ get_fixed_softmax_config(),
+ ],
+ # sim="questa",
+ # skip_build=True,
)
- print("Generated memory")
- mase_runner(module_param_list=[dut_params])
if __name__ == "__main__":
- test_fixed_softmax()
+ test_fixed_softmax_smoke()
diff --git a/src/mase_components/cast/rtl/fixed_round.sv b/src/mase_components/cast/rtl/fixed_round.sv
index a49b43001..2a5291053 100644
--- a/src/mase_components/cast/rtl/fixed_round.sv
+++ b/src/mase_components/cast/rtl/fixed_round.sv
@@ -24,12 +24,13 @@ module fixed_round #(
logic carry_in, input_sign;
assign input_sign = data_in[IN_WIDTH-1];
assign input_data = (input_sign) ? ~(data_in[IN_WIDTH-2:0] - 1) : data_in[IN_WIDTH-2:0];
- /* verilator lint_off SELRANGE */
+ logic [IN_WIDTH + OUT_FRAC_WIDTH - 1:0] lsb_check;
+ assign lsb_check = {input_data, {(OUT_FRAC_WIDTH) {1'b0}}};
always_comb begin
- lsb_below[2] = (IN_FRAC_WIDTH >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH] : 0;
- lsb_below[1] = (IN_FRAC_WIDTH-1 >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-1] : 0;
- // lsb_below[0] = (IN_FRAC_WIDTH-2 >= OUT_FRAC_WIDTH) ? |(input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-2:0]): 0;
- lsb_below[0] = '0; // to do: fix
+ lsb_below[2] = (IN_FRAC_WIDTH >= OUT_FRAC_WIDTH) ? lsb_check[IN_FRAC_WIDTH] : 0;
+ lsb_below[1] = (IN_FRAC_WIDTH-1 >= OUT_FRAC_WIDTH) ? lsb_check[IN_FRAC_WIDTH-1] : 0;
+ lsb_below[0] = (IN_FRAC_WIDTH-2 >= OUT_FRAC_WIDTH) ? |(lsb_check[IN_FRAC_WIDTH-2:0]): 0;
+ // lsb_below[0] = '0; // to do: fix
end
always_comb begin
if ((IN_FRAC_WIDTH - OUT_FRAC_WIDTH) >= 0)
diff --git a/src/mase_components/cast/rtl/fixed_rounding.sv b/src/mase_components/cast/rtl/fixed_rounding.sv
index 65a0709eb..dc2fd38cf 100644
--- a/src/mase_components/cast/rtl/fixed_rounding.sv
+++ b/src/mase_components/cast/rtl/fixed_rounding.sv
@@ -11,15 +11,14 @@ module fixed_rounding #(
output [OUT_WIDTH - 1:0] data_out[IN_SIZE - 1:0]
);
for (genvar i = 0; i < IN_SIZE; i++) begin : parallel_round
- fixed_signed_cast #(
+ fixed_round #(
.IN_WIDTH(IN_WIDTH),
.IN_FRAC_WIDTH(IN_FRAC_WIDTH),
.OUT_WIDTH(OUT_WIDTH),
- .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH),
- .ROUND_FLOOR(1)
+ .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH)
) fr_inst (
- .in_data (data_in[i]),
- .out_data(data_out[i])
+ .data_in (data_in[i]),
+ .data_out(data_out[i])
);
end
diff --git a/src/mase_components/common/rtl/comparator_tree.sv b/src/mase_components/common/rtl/comparator_tree.sv
index c6599dd7f..4afe5f7c3 100644
--- a/src/mase_components/common/rtl/comparator_tree.sv
+++ b/src/mase_components/common/rtl/comparator_tree.sv
@@ -35,6 +35,7 @@ module comparator_tree #(
logic [DATA_WIDTH-1:0] data[(2**(LEVELS-level))-1:0];
logic valid;
logic ready;
+ if (level == 0) assign data = in_data;
end
@@ -103,7 +104,7 @@ module comparator_tree #(
end
// Connect up first and last layer wires
- assign vars[0].data = in_data;
+ // assign vars[0].data = in_data;
assign vars[0].valid = in_valid;
assign in_ready = vars[0].ready;
diff --git a/src/mase_components/common/rtl/fork2.sv b/src/mase_components/common/rtl/fork2.sv
new file mode 100644
index 000000000..c7cc13673
--- /dev/null
+++ b/src/mase_components/common/rtl/fork2.sv
@@ -0,0 +1,83 @@
+`timescale 1ns / 1ps
+
+module fork2 #(
+ parameter DATA_IN_0_PRECISION_0 = 8,
+ parameter DATA_IN_0_PRECISION_1 = 3,
+ parameter DATA_OUT_0_PRECISION_0 = 8,
+ parameter DATA_OUT_0_PRECISION_1 = 3,
+ parameter DATA_OUT_1_PRECISION_0 = 8,
+ parameter DATA_OUT_1_PRECISION_1 = 3,
+
+ parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = -1,
+ parameter DATA_IN_0_PARALLELISM_DIM_0 = -1,
+ parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = -1,
+ parameter DATA_IN_0_PARALLELISM_DIM_1 = -1,
+
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = -1,
+ parameter DATA_OUT_0_PARALLELISM_DIM_0 = -1,
+ parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = -1,
+ parameter DATA_OUT_0_PARALLELISM_DIM_1 = -1,
+
+ parameter DATA_OUT_1_TENSOR_SIZE_DIM_0 = -1,
+ parameter DATA_OUT_1_PARALLELISM_DIM_0 = -1,
+ parameter DATA_OUT_1_TENSOR_SIZE_DIM_1 = -1,
+ parameter DATA_OUT_1_PARALLELISM_DIM_1 = -1,
+
+ parameter DATA_OUT_1_FIFO_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 * DATA_IN_0_TENSOR_SIZE_DIM_1 / (DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1),
+ parameter DATA_OUT_0_FIFO_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 * DATA_IN_0_TENSOR_SIZE_DIM_1 / (DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
+) (
+ input logic clk,
+ input logic rst,
+
+ input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0],
+ input logic data_in_0_valid,
+ output logic data_in_0_ready,
+
+ output logic [DATA_IN_0_PRECISION_0-1:0] data_out_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0],
+ output logic data_out_0_valid,
+ input logic data_out_0_ready,
+
+ output logic [DATA_IN_0_PRECISION_0-1:0] data_out_1 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0],
+ output logic data_out_1_valid,
+ input logic data_out_1_ready
+);
+// logic buffered_data_out_1_valid, buffered_data_out_0_valid;
+// logic buffered_data_out_1_ready, buffered_data_out_0_ready;
+
+ split2 #() split2_inst (
+ .data_out_valid({data_out_1_valid, data_out_0_valid}),
+ .data_out_ready({data_out_1_ready, data_out_0_ready}),
+ .data_in_valid (data_in_0_valid),
+ .data_in_ready (data_in_0_ready)
+ );
+ assign data_out_0 = data_in_0;
+ assign data_out_1 = data_in_0;
+// unpacked_fifo #(
+// .DEPTH(DATA_OUT_0_FIFO_DEPTH),
+// .DATA_WIDTH(DATA_IN_0_PRECISION_0),
+// .IN_NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
+// ) data_out_0_buffer (
+// .clk(clk),
+// .rst(rst),
+// .data_in(data_in_0),
+// .data_in_valid(buffered_data_out_0_valid),
+// .data_in_ready(buffered_data_out_0_ready), // write enable
+// .data_out(data_out_0),
+// .data_out_valid(data_out_0_valid),
+// .data_out_ready(data_out_0_ready) // read enable
+// );
+// unpacked_fifo #(
+// .DEPTH(DATA_OUT_1_FIFO_DEPTH),
+// .DATA_WIDTH(DATA_IN_0_PRECISION_0),
+// .IN_NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1)
+// ) data_out_1_buffer (
+// .clk(clk),
+// .rst(rst),
+// .data_in(data_in_0),
+// .data_in_valid(buffered_data_out_1_valid),
+// .data_in_ready(buffered_data_out_1_ready), // write enable
+// .data_out(data_out_1),
+// .data_out_valid(data_out_1_valid),
+// .data_out_ready(data_out_1_ready) // read enable
+// );
+endmodule
diff --git a/src/mase_components/common/rtl/single_element_repeat.sv b/src/mase_components/common/rtl/single_element_repeat.sv
index 9d1b32822..9cf01ce4b 100644
--- a/src/mase_components/common/rtl/single_element_repeat.sv
+++ b/src/mase_components/common/rtl/single_element_repeat.sv
@@ -81,7 +81,7 @@ module single_element_repeat #(
end
- skid_buffer #(
+ register_slice #(
.DATA_WIDTH(DATA_WIDTH)
) output_buffer (
.clk(clk),
diff --git a/src/mase_components/common/rtl/split2_with_data.sv b/src/mase_components/common/rtl/split2_with_data.sv
new file mode 100644
index 000000000..a0683564b
--- /dev/null
+++ b/src/mase_components/common/rtl/split2_with_data.sv
@@ -0,0 +1,49 @@
+/*
+Module : split2_with_data
+Description : This module implements a 1-to-2 streaming interface handshake.
+*/
+
+`timescale 1ns / 1ps
+module split2_with_data #(
+ parameter DATA_WIDTH = -1,
+ parameter FIFO_DEPTH = -1
+) (
+ input logic clk,
+ input logic rst,
+ input logic [DATA_WIDTH - 1:0] data_in,
+ input logic data_in_valid,
+ output logic data_in_ready,
+
+ output logic [DATA_WIDTH - 1:0] fifo_data_out,
+ output logic fifo_data_out_valid,
+ input logic fifo_data_out_ready,
+
+ output logic [DATA_WIDTH - 1:0] straight_data_out,
+ output logic straight_data_out_valid,
+ input logic straight_data_out_ready
+);
+ logic fifo_in_valid, fifo_in_ready;
+ split2 #() data_out_n_split_i (
+ .data_in_valid (data_in_valid),
+ .data_in_ready (data_in_ready),
+ .data_out_valid({fifo_in_valid, straight_data_out_valid}),
+ .data_out_ready({fifo_in_ready, straight_data_out_ready})
+ );
+ fifo #(
+ .DEPTH(FIFO_DEPTH),
+ .DATA_WIDTH(DATA_WIDTH)
+ ) ff_inst (
+ .clk(clk),
+ .rst(rst),
+ .in_data(data_in),
+ .in_valid(fifo_in_valid),
+ .in_ready(fifo_in_ready),
+ .out_data(fifo_data_out),
+ .out_valid(fifo_data_out_valid),
+ .out_ready(fifo_data_out_ready),
+ .empty(),
+ .full()
+ );
+ assign straight_data_out = data_in;
+
+endmodule
diff --git a/src/mase_components/common/rtl/unpacked_split2_with_data.sv b/src/mase_components/common/rtl/unpacked_split2_with_data.sv
new file mode 100644
index 000000000..a4d3b0301
--- /dev/null
+++ b/src/mase_components/common/rtl/unpacked_split2_with_data.sv
@@ -0,0 +1,59 @@
+module unpacked_split2_with_data #(
+ parameter DEPTH = 8,
+ parameter DATA_WIDTH = 8,
+ parameter IN_SIZE = 8
+) (
+ input clk,
+ input rst,
+ // Input interface
+ input [DATA_WIDTH-1:0] data_in[IN_SIZE - 1:0],
+ input logic data_in_valid,
+ output logic data_in_ready,
+ // FIFO output interface
+ output [DATA_WIDTH-1:0] fifo_data_out[IN_SIZE - 1:0],
+ output logic fifo_data_out_valid,
+ input logic fifo_data_out_ready,
+ // Straight output interface
+ output [DATA_WIDTH-1:0] straight_data_out[IN_SIZE - 1:0],
+ output logic straight_data_out_valid,
+ input logic straight_data_out_ready
+);
+ // Flatten the input data
+ logic [DATA_WIDTH * IN_SIZE - 1:0] data_in_flatten;
+ logic [DATA_WIDTH * IN_SIZE - 1:0] fifo_data_out_flatten;
+ logic [DATA_WIDTH * IN_SIZE - 1:0] straight_data_out_flatten;
+
+ // Input flattening
+ for (genvar i = 0; i < IN_SIZE; i++) begin : reshape
+ assign data_in_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH] = data_in[i];
+ end
+
+ // Split2 instance
+ split2_with_data #(
+ .DATA_WIDTH(DATA_WIDTH * IN_SIZE),
+ .FIFO_DEPTH(DEPTH)
+ ) split2_with_data_i (
+ .clk(clk),
+ .rst(rst),
+ .data_in(data_in_flatten),
+ .data_in_valid(data_in_valid),
+ .data_in_ready(data_in_ready),
+ .fifo_data_out(fifo_data_out_flatten),
+ .fifo_data_out_valid(fifo_data_out_valid),
+ .fifo_data_out_ready(fifo_data_out_ready),
+ .straight_data_out(straight_data_out_flatten),
+ .straight_data_out_valid(straight_data_out_valid),
+ .straight_data_out_ready(straight_data_out_ready)
+ );
+
+ // Unflatten FIFO output
+ for (genvar i = 0; i < IN_SIZE; i++) begin : unreshape_fifo
+ assign fifo_data_out[i] = fifo_data_out_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH];
+ end
+
+ // Unflatten straight output
+ for (genvar i = 0; i < IN_SIZE; i++) begin : unreshape_straight
+ assign straight_data_out[i] = straight_data_out_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH];
+ end
+
+endmodule
\ No newline at end of file
diff --git a/src/mase_components/common/test/test_synth_common.py b/src/mase_components/common/test/test_synth_common.py
index aa7d4dd79..50b54d052 100644
--- a/src/mase_components/common/test/test_synth_common.py
+++ b/src/mase_components/common/test/test_synth_common.py
@@ -4,7 +4,7 @@
@pytest.mark.vivado
def test_synth_common():
- run_synth("common")
+ run_synth("common", "comparator_tree.sv")
if __name__ == "__main__":
diff --git a/src/mase_components/convolution_layers/rtl/convolution.sv b/src/mase_components/convolution_layers/rtl/convolution.sv
index 65a9b06c2..2f2b06d1e 100644
--- a/src/mase_components/convolution_layers/rtl/convolution.sv
+++ b/src/mase_components/convolution_layers/rtl/convolution.sv
@@ -20,7 +20,6 @@ module convolution #(
parameter UNROLL_KERNEL_OUT = 4,
parameter UNROLL_OUT_C = 2,
- parameter SLIDING_NUM = 8,
parameter BIAS_SIZE = UNROLL_OUT_C,
parameter STRIDE = 1,
@@ -29,6 +28,10 @@ module convolution #(
parameter PADDING_X = 2,
parameter HAS_BIAS = 1,
+ parameter OUT_Y = (IN_Y - KERNEL_Y + 2 * PADDING_Y + 1) / (STRIDE),
+ parameter OUT_X = (IN_X - KERNEL_X + 2 * PADDING_X + 1) / (STRIDE),
+ parameter SLIDING_NUM = OUT_Y * OUT_X,
+
parameter DATA_OUT_0_PRECISION_0 = 8,
parameter DATA_OUT_0_PRECISION_1 = 4
) (
@@ -77,6 +80,11 @@ module convolution #(
logic [DATA_IN_0_PRECISION_0 - 1:0] kernel[KERNEL_Y * KERNEL_X * UNROLL_IN_C - 1:0];
logic kernel_valid;
logic kernel_ready;
+ localparam ROUND_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2(
+ KERNEL_X * KERNEL_Y * IN_C
+ );
+ localparam ROUND_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1;
+ logic [ROUND_PRECISION_0 -1:0] round_in[UNROLL_OUT_C-1:0];
sliding_window #(
.IMG_WIDTH (IN_X),
.IMG_HEIGHT (IN_Y),
@@ -89,14 +97,15 @@ module convolution #(
.STRIDE (STRIDE)
/* verilator lint_off PINMISSING */
) sw_inst (
+ .clk(clk),
+ .rst(rst),
.data_in(packed_data_in),
.data_in_valid(data_in_0_valid),
.data_in_ready(data_in_0_ready),
.data_out(packed_kernel),
.data_out_valid(kernel_valid),
- .data_out_ready(kernel_ready),
- .*
+ .data_out_ready(kernel_ready)
);
/* verilator lint_on PINMISSING */
for (genvar i = 0; i < KERNEL_Y * KERNEL_X; i++)
@@ -109,21 +118,17 @@ module convolution #(
.NUM(ROLL_IN_NUM),
.ROLL_NUM(UNROLL_KERNEL_OUT)
) roller_inst (
+ .clk(clk),
+ .rst(rst),
.data_in(kernel),
.data_in_valid(kernel_valid),
.data_in_ready(kernel_ready),
.data_out(rolled_k),
.data_out_valid(rolled_k_valid),
- .data_out_ready(rolled_k_ready),
- .*
+ .data_out_ready(rolled_k_ready)
);
- localparam ROUND_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2(
- KERNEL_X * KERNEL_Y * IN_C
- );
- localparam ROUND_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1;
- logic [ROUND_PRECISION_0 -1:0] round_in[UNROLL_OUT_C-1:0];
- convolution_arith #(
+ convolution_compute_core #(
// assume output will only unroll_out_channels
.DATA_IN_0_PRECISION_0(DATA_IN_0_PRECISION_0),
.DATA_IN_0_PRECISION_1(DATA_IN_0_PRECISION_1),
@@ -138,12 +143,21 @@ module convolution #(
.OUT_CHANNELS_DEPTH(OUT_C / UNROLL_OUT_C),
.WEIGHT_REPEATS(SLIDING_NUM),
.HAS_BIAS(HAS_BIAS)
- ) convolution_arith_inst (
+ ) ccc_inst (
+ .clk(clk),
+ .rst(rst),
.data_in_0(rolled_k),
.data_in_0_valid(rolled_k_valid),
.data_in_0_ready(rolled_k_ready),
+ .weight(weight),
+ .weight_valid(weight_valid),
+ .weight_ready(weight_ready),
+ .bias(bias),
+ .bias_valid(bias_valid),
+ .bias_ready(bias_ready),
.data_out_0(round_in),
- .*
+ .data_out_0_valid(data_out_0_valid),
+ .data_out_0_ready(data_out_0_ready)
);
fixed_rounding #(
diff --git a/src/mase_components/convolution_layers/rtl/convolution_arith.sv b/src/mase_components/convolution_layers/rtl/convolution_compute_core.sv
similarity index 99%
rename from src/mase_components/convolution_layers/rtl/convolution_arith.sv
rename to src/mase_components/convolution_layers/rtl/convolution_compute_core.sv
index 2c3568b1d..8e6d0db7f 100644
--- a/src/mase_components/convolution_layers/rtl/convolution_arith.sv
+++ b/src/mase_components/convolution_layers/rtl/convolution_compute_core.sv
@@ -1,5 +1,6 @@
+/* verilator lint_off DECLFILENAME */
`timescale 1ns / 1ps
-module convolution_arith #(
+module convolution_compute_core #(
// assume output will only unroll_out_channels
parameter DATA_IN_0_PRECISION_0 = 16,
parameter DATA_IN_0_PRECISION_1 = 3,
@@ -107,7 +108,6 @@ module convolution_arith #(
// .data_out(data_out_0[i])
// );end
endmodule
-
module simple_convolution_arith #(
parameter DATA_IN_0_PRECISION_0 = 16,
parameter DATA_IN_0_PRECISION_1 = 3,
@@ -119,7 +119,7 @@ module simple_convolution_arith #(
parameter ROLL_OUT_NUM = 2,
parameter IN_CHANNELS_DEPTH = 4,
parameter OUT_CHANNELS_PARALLELISM = 2,
- parameter HAS_BIAS,
+ parameter HAS_BIAS = 1,
parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2(
ROLL_IN_NUM * IN_CHANNELS_DEPTH
),
@@ -174,7 +174,7 @@ module simple_convolution_arith #(
.DATA_IN_0_PRECISION_0(DATA_IN_0_PRECISION_0),
.WEIGHT_PRECISION_0(WEIGHT_PRECISION_0),
.DP_SIZE(ROLL_OUT_NUM),
- .ACC_DEPTH(ROLL_IN_NUM / ROLL_OUT_NUM * IN_CHANNELS_DEPTH),
+ .ACC_DEPTH(ROLL_IN_NUM / ROLL_OUT_NUM * IN_CHANNELS_DEPTH)
) dp_acc_inst (
.clk(clk),
.rst(rst),
diff --git a/src/mase_components/convolution_layers/rtl/padding.sv b/src/mase_components/convolution_layers/rtl/padding.sv
index fb336acb3..715396ff3 100644
--- a/src/mase_components/convolution_layers/rtl/padding.sv
+++ b/src/mase_components/convolution_layers/rtl/padding.sv
@@ -29,7 +29,11 @@ module padding #(
.data_out(reg_out),
.data_out_valid(reg_out_valid),
.data_out_ready(reg_out_ready),
- .*
+ .data_in(data_in),
+ .data_in_valid(data_in_valid),
+ .data_in_ready(data_in_ready),
+ .clk(clk),
+ .rst(rst)
);
logic [C_WIDTH -1:0] count_c;
logic [X_WIDTH -1:0] count_x;
diff --git a/src/mase_components/convolution_layers/rtl/sliding_window.sv b/src/mase_components/convolution_layers/rtl/sliding_window.sv
index 883e54ca6..79734f745 100644
--- a/src/mase_components/convolution_layers/rtl/sliding_window.sv
+++ b/src/mase_components/convolution_layers/rtl/sliding_window.sv
@@ -148,8 +148,6 @@ module sliding_window_buffer #(
end
end
end
-
-
endmodule
module sliding_window_stride #(
@@ -199,13 +197,17 @@ module sliding_window_stride #(
.DATA_WIDTH(DATA_WIDTH),
.CHANNELS(CHANNELS)
) buffer (
+ .clk (clk),
+ .rst (rst),
+ .data_in (data_in),
+ .data_in_valid (data_in_valid),
+ .data_in_ready (data_in_ready),
+ .data_out_valid(buffer_valid),
+ .data_out_ready(buffer_ready),
.data_out (buffer_data),
.out_x (buffer_x),
.out_y (buffer_y),
- .out_c (buffer_c),
- .data_out_valid(buffer_valid),
- .data_out_ready(buffer_ready),
- .*
+ .out_c (buffer_c)
);
// enable stride == 1
logic in_range;
@@ -312,10 +314,14 @@ module sliding_window #(
.DATA_WIDTH(DATA_WIDTH),
.CHANNELS(CHANNELS)
) padding_inst (
+ .clk(clk),
+ .rst(rst),
+ .data_in(data_in),
+ .data_in_valid(data_in_valid),
+ .data_in_ready(data_in_ready),
.data_out(padding_in),
.data_out_valid(padding_in_valid),
- .data_out_ready(padding_in_ready),
- .*
+ .data_out_ready(padding_in_ready)
);
sliding_window_stride #(
diff --git a/src/mase_components/convolution_layers/rtl/binary_activation_binary_convolution.sv b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution.sv
similarity index 100%
rename from src/mase_components/convolution_layers/rtl/binary_activation_binary_convolution.sv
rename to src/mase_components/convolution_layers/test/binary_activation_binary_convolution.sv
diff --git a/src/mase_components/convolution_layers/test/convolution_tb.py b/src/mase_components/convolution_layers/test/convolution_tb.py
index 1a3789ed9..51bea7fdd 100644
--- a/src/mase_components/convolution_layers/test/convolution_tb.py
+++ b/src/mase_components/convolution_layers/test/convolution_tb.py
@@ -269,7 +269,7 @@ async def run_test(self):
self.data_out_0_monitor.load_monitor(o)
# cocotb.start_soon(check_signal(self.dut, self.log))
- await Timer(100, units="us")
+ await Timer(1000, units="us")
assert self.data_out_0_monitor.exp_queue.empty()
@@ -299,16 +299,16 @@ def get_fixed_conv_config(kwargs={}):
config = {
"IN_C": 3,
"UNROLL_IN_C": 3,
- "IN_X": 3,
- "IN_Y": 3,
- "KERNEL_X": 3,
- "KERNEL_Y": 2,
- "UNROLL_KERNEL_OUT": 3,
- "OUT_C": 4,
- "UNROLL_OUT_C": 2,
+ "IN_X": 16,
+ "IN_Y": 16,
+ "KERNEL_X": 4,
+ "KERNEL_Y": 4,
+ "UNROLL_KERNEL_OUT": 4,
+ "OUT_C": 16,
+ "UNROLL_OUT_C": 4,
"STRIDE": 2,
- "PADDING_Y": 1,
- "PADDING_X": 2,
+ "PADDING_Y": 0,
+ "PADDING_X": 0,
"HAS_BIAS": 1,
}
in_y = config["IN_Y"]
@@ -336,6 +336,7 @@ def test_fixed_linear_smoke():
module_param_list=[
get_fixed_conv_config(),
],
+ sim="questa",
)
@@ -357,20 +358,6 @@ def test_fixed_linear_regression():
"OUT_CHANNELS_DEPTH": 96,
}
),
- # get_fixed_linear_config(
- # {
- # "HAS_BIAS": 1,
- # "WEIGHTS_PRE_TRANSPOSED": 0,
- # "DATA_IN_0_TENSOR_SIZE_DIM_0": 768,
- # "DATA_IN_0_PARALLELISM_DIM_0": 32,
- # "WEIGHT_TENSOR_SIZE_DIM_0": 768,
- # "WEIGHT_TENSOR_SIZE_DIM_1": 768,
- # "WEIGHT_PARALLELISM_DIM_0": 32,
- # "WEIGHT_PARALLELISM_DIM_1": 32,
- # "BIAS_TENSOR_SIZE_DIM_0": 768,
- # "BIAS_PARALLELISM_DIM_0": 32,
- # }
- # ),
],
)
diff --git a/src/mase_components/convolution_layers/test/test_lint_conv.py b/src/mase_components/convolution_layers/test/test_lint_conv.py
index d635b192f..dcb176a21 100644
--- a/src/mase_components/convolution_layers/test/test_lint_conv.py
+++ b/src/mase_components/convolution_layers/test/test_lint_conv.py
@@ -5,7 +5,7 @@
@pytest.mark.skip(reason="Needs to be fixed.")
def test_lint_conv():
- run_lint("conv")
+ run_lint("convolution_layers")
if __name__ == "__main__":
diff --git a/src/mase_components/convolution_layers/test/test_synth_conv.py b/src/mase_components/convolution_layers/test/test_synth_conv.py
index 448ae4e3f..708e31c97 100644
--- a/src/mase_components/convolution_layers/test/test_synth_conv.py
+++ b/src/mase_components/convolution_layers/test/test_synth_conv.py
@@ -4,7 +4,7 @@
@pytest.mark.vivado
def test_synth_conv():
- run_synth("conv")
+ run_synth("convolution_layers")
if __name__ == "__main__":
diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py
index a06529344..a68c1dbaf 100644
--- a/src/mase_components/deps.py
+++ b/src/mase_components/deps.py
@@ -21,6 +21,7 @@
"common",
"memory",
"activation_layers",
+ "generated_lut",
],
"activation_layers/fixed_softsign": [
"common",
@@ -37,10 +38,17 @@
"activation_layers/fixed_logsigmoid": ["common", "cast", "activation_layers"],
"activation_layers/fixed_softmax": [
"common",
+ "memory",
+ "scalar_operators/fixed",
"cast",
- "fixed_arithmetic",
- "conv",
+ "linear_layers/fixed_operators",
+ "generated_lut",
"activation_layers",
+ "convolution_layers",
+ "memory",
+ "linear_layers/fixed_operators",
+ "scalar_operators/fixed",
+ "generated_lut",
],
"activation_layers/fixed_softermax_1d": [
"common",
@@ -133,6 +141,13 @@
"common",
],
# Linear
+ "linear_layers/fixed_linear_layer/fixed_linear_with_input_circular": [
+ "cast",
+ "common",
+ "memory",
+ "linear_layers/fixed_operators",
+ "scalar_operators/fixed",
+ ],
"linear_layers/fixed_linear_layer/fixed_linear": [
"cast",
"common",
@@ -258,6 +273,7 @@
"linear_layers/fixed_operators",
"common",
"memory",
+ "cast",
],
"linear_layers/mxint_operators/mxint_dot_product": [
"linear_layers/mxint_operators",
@@ -265,6 +281,34 @@
"common",
"memory",
],
+ "linear_layers/mxint_operators/mxint_range_reduction": [
+ "linear_layers/mxint_operators",
+ "common",
+ "memory",
+ "cast",
+ ],
+ "linear_layers/mxint_operators/mxint_exp": [
+ "linear_layers/mxint_operators",
+ "common",
+ "memory",
+ "cast",
+ "generated_lut",
+ ],
+ "linear_layers/mxint_operators/mxint_softmax": [
+ "linear_layers/mxint_operators",
+ "common",
+ "memory",
+ "cast",
+ "scalar_operators/fixed",
+ "generated_lut",
+ ],
+ "linear_layers/mxint_operators/mxint_addition": [
+ "linear_layers/mxint_operators",
+ "linear_layers/fixed_operators",
+ "common",
+ "memory",
+ "cast",
+ ],
"linear_layers/mxint_operators/mxint_linear": [
"linear_layers/mxint_operators",
"linear_layers/fixed_operators",
@@ -279,6 +323,33 @@
"memory",
"cast",
],
+ "linear_layers/mxint_operators/mxint_gelu": [
+ "linear_layers/mxint_operators",
+ "linear_layers/fixed_operators",
+ "common",
+ "memory",
+ "cast",
+ "generated_lut"
+ ],
+ "linear_layers/mxint_operators/mxint_vit_attention_head": [
+ "linear_layers/mxint_operators",
+ "linear_layers/fixed_operators",
+ "linear_layers/matmul",
+ "common",
+ "memory",
+ "cast",
+ "scalar_operators/fixed",
+ ],
+ "linear_layers/mxint_operators/mxint_vit_attention_wrap": [
+ "linear_layers/mxint_operators",
+ "linear_layers/fixed_operators",
+ "transformer_layers",
+ "linear_layers/matmul",
+ "common",
+ "memory",
+ "cast",
+ "scalar_operators/fixed",
+ ],
"linear_layers/mxint_operators/old_linear": [
"linear_layers/mxint_operators",
"linear_layers/fixed_operators",
@@ -303,11 +374,48 @@
"memory",
"cast",
],
+ "linear_layers/mxint_operators/mxint_patch_embed": [
+ "convolution_layers",
+ "linear_layers/matmul",
+ "linear_layers/mxint_operators",
+ "linear_layers/fixed_operators",
+ "common",
+ "memory",
+ "cast",
+ ],
+ "linear_layers/mxint_operators/mxint_hardware_round": [
+ "linear_layers/mxint_operators",
+ "common",
+ "memory",
+ "cast",
+ ],
"linear_layers/mxint_operators/log2_max_abs": [
"linear_layers/mxint_operators",
"common",
"memory",
],
+ "linear_layers/mxint_operators/mxint_layernorm_1d": [
+ "common",
+ "linear_layers/matmul",
+ "linear_layers/fixed_operators",
+ "scalar_operators/fixed",
+ "normalization_layers",
+ "cast",
+ "memory",
+ "generated_lut",
+ "linear_layers/mxint_operators",
+ ],
+ "linear_layers/mxint_operators/mxint_layernorm": [
+ "common",
+ "linear_layers/matmul",
+ "linear_layers/fixed_operators",
+ "scalar_operators/fixed",
+ "normalization_layers",
+ "cast",
+ "memory",
+ "generated_lut",
+ "linear_layers/mxint_operators",
+ ],
# Memory
"memory/skid_buffer": [],
"memory/fifo": ["memory"],
@@ -316,6 +424,7 @@
"memory/ram_block": [],
"memory/unpacked_fifo": ["memory"],
"memory/unpacked_skid_buffer": ["memory"],
+ "memory/weight_source": ["memory"],
# Normalization Layers
"normalization_layers/batch_norm_2d": [
"normalization_layers",
@@ -343,6 +452,16 @@
"cast",
"memory",
],
+ "normalization_layers/layer_norm_2d": [
+ "common",
+ "linear_layers/matmul",
+ "linear_layers/fixed_operators",
+ "scalar_operators/fixed",
+ "normalization_layers",
+ "cast",
+ "memory",
+ "generated_lut",
+ ],
# Scalar Operators
"scalar_operators/fixed/fixed_isqrt": [
"memory",
@@ -356,6 +475,12 @@
"scalar_operators/fixed",
"linear_layers/fixed_operators",
],
+ "scalar_operators/fixed/fixed_div": [
+ "scalar_operators/fixed",
+ "memory",
+ "cast",
+ "common",
+ ],
# Transformer Layers
"transformer_layers/fixed_self_attention": [
"transformer_layers",
@@ -399,6 +524,31 @@
],
"arithmetic/mac": ["fixed_arithmetic", "float_arithmetic"],
# ViT
+ "vision_models/vit/fixed_vit_attention_head": [
+ "vision_models/attention",
+ "cast",
+ "memory",
+ "common",
+ "linear_layers/fixed_operators",
+ "linear_layers/fixed_linear_layer",
+ "linear_layers/matmul",
+ "activation_layers",
+ "scalar_operators/fixed",
+ "generated_lut",
+ ],
+ "vision_models/vit/fixed_vit_attention": [
+ "vision_models/vit",
+ "transformer_layers",
+ "cast",
+ "memory",
+ "common",
+ "linear_layers/fixed_operators",
+ "linear_layers/fixed_linear_layer",
+ "linear_layers/matmul",
+ "activation_layers",
+ "scalar_operators/fixed",
+ "generated_lut",
+ ],
"ViT/fixed_patch_embed": [
"conv",
"ViT",
diff --git a/src/mase_components/generated_lut/rtl/exp_lut.sv b/src/mase_components/generated_lut/rtl/exp_lut.sv
new file mode 100644
index 000000000..c8af7c8bf
--- /dev/null
+++ b/src/mase_components/generated_lut/rtl/exp_lut.sv
@@ -0,0 +1,277 @@
+
+`timescale 1ns / 1ps
+/* verilator lint_off UNUSEDPARAM */
+module exp_lut #(
+ parameter DATA_IN_0_PRECISION_0 = 16,
+ parameter DATA_IN_0_PRECISION_1 = 8,
+ parameter DATA_OUT_0_PRECISION_0 = 16,
+ parameter DATA_OUT_0_PRECISION_1 = 8
+) (
+ /* verilator lint_off UNUSEDSIGNAL */
+ input logic [7:0] data_in_0,
+ output logic [7:0] data_out_0
+);
+
+
+ always_comb begin
+ case (data_in_0)
+ 8'b00000000: data_out_0 = 8'b00010000;
+ 8'b00000001: data_out_0 = 8'b00010000;
+ 8'b00000010: data_out_0 = 8'b00010000;
+ 8'b00000011: data_out_0 = 8'b00010000;
+ 8'b00000100: data_out_0 = 8'b00010000;
+ 8'b00000101: data_out_0 = 8'b00010000;
+ 8'b00000110: data_out_0 = 8'b00010001;
+ 8'b00000111: data_out_0 = 8'b00010001;
+ 8'b00001000: data_out_0 = 8'b00010001;
+ 8'b00001001: data_out_0 = 8'b00010001;
+ 8'b00001010: data_out_0 = 8'b00010010;
+ 8'b00001011: data_out_0 = 8'b00010010;
+ 8'b00001100: data_out_0 = 8'b00010010;
+ 8'b00001101: data_out_0 = 8'b00010010;
+ 8'b00001110: data_out_0 = 8'b00010010;
+ 8'b00001111: data_out_0 = 8'b00010011;
+ 8'b00010000: data_out_0 = 8'b00010011;
+ 8'b00010001: data_out_0 = 8'b00010011;
+ 8'b00010010: data_out_0 = 8'b00010011;
+ 8'b00010011: data_out_0 = 8'b00010100;
+ 8'b00010100: data_out_0 = 8'b00010100;
+ 8'b00010101: data_out_0 = 8'b00010100;
+ 8'b00010110: data_out_0 = 8'b00010100;
+ 8'b00010111: data_out_0 = 8'b00010101;
+ 8'b00011000: data_out_0 = 8'b00010101;
+ 8'b00011001: data_out_0 = 8'b00010101;
+ 8'b00011010: data_out_0 = 8'b00010101;
+ 8'b00011011: data_out_0 = 8'b00010110;
+ 8'b00011100: data_out_0 = 8'b00010110;
+ 8'b00011101: data_out_0 = 8'b00010110;
+ 8'b00011110: data_out_0 = 8'b00010110;
+ 8'b00011111: data_out_0 = 8'b00010111;
+ 8'b00100000: data_out_0 = 8'b00010111;
+ 8'b00100001: data_out_0 = 8'b00010111;
+ 8'b00100010: data_out_0 = 8'b00010111;
+ 8'b00100011: data_out_0 = 8'b00011000;
+ 8'b00100100: data_out_0 = 8'b00011000;
+ 8'b00100101: data_out_0 = 8'b00011000;
+ 8'b00100110: data_out_0 = 8'b00011001;
+ 8'b00100111: data_out_0 = 8'b00011001;
+ 8'b00101000: data_out_0 = 8'b00011001;
+ 8'b00101001: data_out_0 = 8'b00011010;
+ 8'b00101010: data_out_0 = 8'b00011010;
+ 8'b00101011: data_out_0 = 8'b00011010;
+ 8'b00101100: data_out_0 = 8'b00011010;
+ 8'b00101101: data_out_0 = 8'b00011011;
+ 8'b00101110: data_out_0 = 8'b00011011;
+ 8'b00101111: data_out_0 = 8'b00011011;
+ 8'b00110000: data_out_0 = 8'b00011100;
+ 8'b00110001: data_out_0 = 8'b00011100;
+ 8'b00110010: data_out_0 = 8'b00011100;
+ 8'b00110011: data_out_0 = 8'b00011101;
+ 8'b00110100: data_out_0 = 8'b00011101;
+ 8'b00110101: data_out_0 = 8'b00011110;
+ 8'b00110110: data_out_0 = 8'b00011110;
+ 8'b00110111: data_out_0 = 8'b00011110;
+ 8'b00111000: data_out_0 = 8'b00011111;
+ 8'b00111001: data_out_0 = 8'b00011111;
+ 8'b00111010: data_out_0 = 8'b00011111;
+ 8'b00111011: data_out_0 = 8'b00100000;
+ 8'b00111100: data_out_0 = 8'b00100000;
+ 8'b00111101: data_out_0 = 8'b00100001;
+ 8'b00111110: data_out_0 = 8'b00100001;
+ 8'b00111111: data_out_0 = 8'b00100001;
+ 8'b01000000: data_out_0 = 8'b00100010;
+ 8'b01000001: data_out_0 = 8'b00100010;
+ 8'b01000010: data_out_0 = 8'b00100011;
+ 8'b01000011: data_out_0 = 8'b00100011;
+ 8'b01000100: data_out_0 = 8'b00100011;
+ 8'b01000101: data_out_0 = 8'b00100100;
+ 8'b01000110: data_out_0 = 8'b00100100;
+ 8'b01000111: data_out_0 = 8'b00100101;
+ 8'b01001000: data_out_0 = 8'b00100101;
+ 8'b01001001: data_out_0 = 8'b00100110;
+ 8'b01001010: data_out_0 = 8'b00100110;
+ 8'b01001011: data_out_0 = 8'b00100110;
+ 8'b01001100: data_out_0 = 8'b00100111;
+ 8'b01001101: data_out_0 = 8'b00100111;
+ 8'b01001110: data_out_0 = 8'b00101000;
+ 8'b01001111: data_out_0 = 8'b00101000;
+ 8'b01010000: data_out_0 = 8'b00101001;
+ 8'b01010001: data_out_0 = 8'b00101001;
+ 8'b01010010: data_out_0 = 8'b00101010;
+ 8'b01010011: data_out_0 = 8'b00101010;
+ 8'b01010100: data_out_0 = 8'b00101011;
+ 8'b01010101: data_out_0 = 8'b00101011;
+ 8'b01010110: data_out_0 = 8'b00101100;
+ 8'b01010111: data_out_0 = 8'b00101100;
+ 8'b01011000: data_out_0 = 8'b00101101;
+ 8'b01011001: data_out_0 = 8'b00101110;
+ 8'b01011010: data_out_0 = 8'b00101110;
+ 8'b01011011: data_out_0 = 8'b00101111;
+ 8'b01011100: data_out_0 = 8'b00101111;
+ 8'b01011101: data_out_0 = 8'b00110000;
+ 8'b01011110: data_out_0 = 8'b00110000;
+ 8'b01011111: data_out_0 = 8'b00110001;
+ 8'b01100000: data_out_0 = 8'b00110010;
+ 8'b01100001: data_out_0 = 8'b00110010;
+ 8'b01100010: data_out_0 = 8'b00110011;
+ 8'b01100011: data_out_0 = 8'b00110011;
+ 8'b01100100: data_out_0 = 8'b00110100;
+ 8'b01100101: data_out_0 = 8'b00110101;
+ 8'b01100110: data_out_0 = 8'b00110101;
+ 8'b01100111: data_out_0 = 8'b00110110;
+ 8'b01101000: data_out_0 = 8'b00110111;
+ 8'b01101001: data_out_0 = 8'b00110111;
+ 8'b01101010: data_out_0 = 8'b00111000;
+ 8'b01101011: data_out_0 = 8'b00111001;
+ 8'b01101100: data_out_0 = 8'b00111001;
+ 8'b01101101: data_out_0 = 8'b00111010;
+ 8'b01101110: data_out_0 = 8'b00111011;
+ 8'b01101111: data_out_0 = 8'b00111011;
+ 8'b01110000: data_out_0 = 8'b00111100;
+ 8'b01110001: data_out_0 = 8'b00111101;
+ 8'b01110010: data_out_0 = 8'b00111101;
+ 8'b01110011: data_out_0 = 8'b00111110;
+ 8'b01110100: data_out_0 = 8'b00111111;
+ 8'b01110101: data_out_0 = 8'b01000000;
+ 8'b01110110: data_out_0 = 8'b01000000;
+ 8'b01110111: data_out_0 = 8'b01000001;
+ 8'b01111000: data_out_0 = 8'b01000010;
+ 8'b01111001: data_out_0 = 8'b01000011;
+ 8'b01111010: data_out_0 = 8'b01000100;
+ 8'b01111011: data_out_0 = 8'b01000100;
+ 8'b01111100: data_out_0 = 8'b01000101;
+ 8'b01111101: data_out_0 = 8'b01000110;
+ 8'b01111110: data_out_0 = 8'b01000111;
+ 8'b01111111: data_out_0 = 8'b01001000;
+ 8'b10000000: data_out_0 = 8'b00000011;
+ 8'b10000001: data_out_0 = 8'b00000011;
+ 8'b10000010: data_out_0 = 8'b00000011;
+ 8'b10000011: data_out_0 = 8'b00000011;
+ 8'b10000100: data_out_0 = 8'b00000011;
+ 8'b10000101: data_out_0 = 8'b00000011;
+ 8'b10000110: data_out_0 = 8'b00000011;
+ 8'b10000111: data_out_0 = 8'b00000011;
+ 8'b10001000: data_out_0 = 8'b00000011;
+ 8'b10001001: data_out_0 = 8'b00000011;
+ 8'b10001010: data_out_0 = 8'b00000011;
+ 8'b10001011: data_out_0 = 8'b00000011;
+ 8'b10001100: data_out_0 = 8'b00000100;
+ 8'b10001101: data_out_0 = 8'b00000100;
+ 8'b10001110: data_out_0 = 8'b00000100;
+ 8'b10001111: data_out_0 = 8'b00000100;
+ 8'b10010000: data_out_0 = 8'b00000100;
+ 8'b10010001: data_out_0 = 8'b00000100;
+ 8'b10010010: data_out_0 = 8'b00000100;
+ 8'b10010011: data_out_0 = 8'b00000100;
+ 8'b10010100: data_out_0 = 8'b00000100;
+ 8'b10010101: data_out_0 = 8'b00000100;
+ 8'b10010110: data_out_0 = 8'b00000100;
+ 8'b10010111: data_out_0 = 8'b00000100;
+ 8'b10011000: data_out_0 = 8'b00000100;
+ 8'b10011001: data_out_0 = 8'b00000100;
+ 8'b10011010: data_out_0 = 8'b00000100;
+ 8'b10011011: data_out_0 = 8'b00000100;
+ 8'b10011100: data_out_0 = 8'b00000100;
+ 8'b10011101: data_out_0 = 8'b00000100;
+ 8'b10011110: data_out_0 = 8'b00000100;
+ 8'b10011111: data_out_0 = 8'b00000101;
+ 8'b10100000: data_out_0 = 8'b00000101;
+ 8'b10100001: data_out_0 = 8'b00000101;
+ 8'b10100010: data_out_0 = 8'b00000101;
+ 8'b10100011: data_out_0 = 8'b00000101;
+ 8'b10100100: data_out_0 = 8'b00000101;
+ 8'b10100101: data_out_0 = 8'b00000101;
+ 8'b10100110: data_out_0 = 8'b00000101;
+ 8'b10100111: data_out_0 = 8'b00000101;
+ 8'b10101000: data_out_0 = 8'b00000101;
+ 8'b10101001: data_out_0 = 8'b00000101;
+ 8'b10101010: data_out_0 = 8'b00000101;
+ 8'b10101011: data_out_0 = 8'b00000101;
+ 8'b10101100: data_out_0 = 8'b00000101;
+ 8'b10101101: data_out_0 = 8'b00000101;
+ 8'b10101110: data_out_0 = 8'b00000110;
+ 8'b10101111: data_out_0 = 8'b00000110;
+ 8'b10110000: data_out_0 = 8'b00000110;
+ 8'b10110001: data_out_0 = 8'b00000110;
+ 8'b10110010: data_out_0 = 8'b00000110;
+ 8'b10110011: data_out_0 = 8'b00000110;
+ 8'b10110100: data_out_0 = 8'b00000110;
+ 8'b10110101: data_out_0 = 8'b00000110;
+ 8'b10110110: data_out_0 = 8'b00000110;
+ 8'b10110111: data_out_0 = 8'b00000110;
+ 8'b10111000: data_out_0 = 8'b00000110;
+ 8'b10111001: data_out_0 = 8'b00000110;
+ 8'b10111010: data_out_0 = 8'b00000110;
+ 8'b10111011: data_out_0 = 8'b00000111;
+ 8'b10111100: data_out_0 = 8'b00000111;
+ 8'b10111101: data_out_0 = 8'b00000111;
+ 8'b10111110: data_out_0 = 8'b00000111;
+ 8'b10111111: data_out_0 = 8'b00000111;
+ 8'b11000000: data_out_0 = 8'b00000111;
+ 8'b11000001: data_out_0 = 8'b00000111;
+ 8'b11000010: data_out_0 = 8'b00000111;
+ 8'b11000011: data_out_0 = 8'b00000111;
+ 8'b11000100: data_out_0 = 8'b00000111;
+ 8'b11000101: data_out_0 = 8'b00000111;
+ 8'b11000110: data_out_0 = 8'b00001000;
+ 8'b11000111: data_out_0 = 8'b00001000;
+ 8'b11001000: data_out_0 = 8'b00001000;
+ 8'b11001001: data_out_0 = 8'b00001000;
+ 8'b11001010: data_out_0 = 8'b00001000;
+ 8'b11001011: data_out_0 = 8'b00001000;
+ 8'b11001100: data_out_0 = 8'b00001000;
+ 8'b11001101: data_out_0 = 8'b00001000;
+ 8'b11001110: data_out_0 = 8'b00001000;
+ 8'b11001111: data_out_0 = 8'b00001000;
+ 8'b11010000: data_out_0 = 8'b00001001;
+ 8'b11010001: data_out_0 = 8'b00001001;
+ 8'b11010010: data_out_0 = 8'b00001001;
+ 8'b11010011: data_out_0 = 8'b00001001;
+ 8'b11010100: data_out_0 = 8'b00001001;
+ 8'b11010101: data_out_0 = 8'b00001001;
+ 8'b11010110: data_out_0 = 8'b00001001;
+ 8'b11010111: data_out_0 = 8'b00001001;
+ 8'b11011000: data_out_0 = 8'b00001001;
+ 8'b11011001: data_out_0 = 8'b00001010;
+ 8'b11011010: data_out_0 = 8'b00001010;
+ 8'b11011011: data_out_0 = 8'b00001010;
+ 8'b11011100: data_out_0 = 8'b00001010;
+ 8'b11011101: data_out_0 = 8'b00001010;
+ 8'b11011110: data_out_0 = 8'b00001010;
+ 8'b11011111: data_out_0 = 8'b00001010;
+ 8'b11100000: data_out_0 = 8'b00001010;
+ 8'b11100001: data_out_0 = 8'b00001011;
+ 8'b11100010: data_out_0 = 8'b00001011;
+ 8'b11100011: data_out_0 = 8'b00001011;
+ 8'b11100100: data_out_0 = 8'b00001011;
+ 8'b11100101: data_out_0 = 8'b00001011;
+ 8'b11100110: data_out_0 = 8'b00001011;
+ 8'b11100111: data_out_0 = 8'b00001011;
+ 8'b11101000: data_out_0 = 8'b00001100;
+ 8'b11101001: data_out_0 = 8'b00001100;
+ 8'b11101010: data_out_0 = 8'b00001100;
+ 8'b11101011: data_out_0 = 8'b00001100;
+ 8'b11101100: data_out_0 = 8'b00001100;
+ 8'b11101101: data_out_0 = 8'b00001100;
+ 8'b11101110: data_out_0 = 8'b00001100;
+ 8'b11101111: data_out_0 = 8'b00001101;
+ 8'b11110000: data_out_0 = 8'b00001101;
+ 8'b11110001: data_out_0 = 8'b00001101;
+ 8'b11110010: data_out_0 = 8'b00001101;
+ 8'b11110011: data_out_0 = 8'b00001101;
+ 8'b11110100: data_out_0 = 8'b00001101;
+ 8'b11110101: data_out_0 = 8'b00001110;
+ 8'b11110110: data_out_0 = 8'b00001110;
+ 8'b11110111: data_out_0 = 8'b00001110;
+ 8'b11111000: data_out_0 = 8'b00001110;
+ 8'b11111001: data_out_0 = 8'b00001110;
+ 8'b11111010: data_out_0 = 8'b00001110;
+ 8'b11111011: data_out_0 = 8'b00001111;
+ 8'b11111100: data_out_0 = 8'b00001111;
+ 8'b11111101: data_out_0 = 8'b00001111;
+ 8'b11111110: data_out_0 = 8'b00001111;
+ 8'b11111111: data_out_0 = 8'b00001111;
+ default: data_out_0 = 8'b0;
+ endcase
+ end
+endmodule
diff --git a/src/mase_components/generated_lut/rtl/gelu_lut.sv b/src/mase_components/generated_lut/rtl/gelu_lut.sv
new file mode 100644
index 000000000..5cbedb1bc
--- /dev/null
+++ b/src/mase_components/generated_lut/rtl/gelu_lut.sv
@@ -0,0 +1,37 @@
+
+`timescale 1ns / 1ps
+/* verilator lint_off UNUSEDPARAM */
+module gelu_lut #(
+ parameter DATA_IN_0_PRECISION_0 = 16,
+ parameter DATA_IN_0_PRECISION_1 = 8,
+ parameter DATA_OUT_0_PRECISION_0 = 16,
+ parameter DATA_OUT_0_PRECISION_1 = 8
+) (
+ /* verilator lint_off UNUSEDSIGNAL */
+ input logic [3:0] data_in_0,
+ output logic [7:0] data_out_0
+);
+
+
+ always_comb begin
+ case (data_in_0)
+ 4'b0000: data_out_0 = 8'b00000000;
+ 4'b0001: data_out_0 = 8'b00001010;
+ 4'b0010: data_out_0 = 8'b00010110;
+ 4'b0011: data_out_0 = 8'b00100101;
+ 4'b0100: data_out_0 = 8'b00110110;
+ 4'b0101: data_out_0 = 8'b01001000;
+ 4'b0110: data_out_0 = 8'b01011010;
+ 4'b0111: data_out_0 = 8'b01101100;
+ 4'b1000: data_out_0 = 8'b11111101;
+ 4'b1001: data_out_0 = 8'b11111100;
+ 4'b1010: data_out_0 = 8'b11111010;
+ 4'b1011: data_out_0 = 8'b11111000;
+ 4'b1100: data_out_0 = 8'b11110110;
+ 4'b1101: data_out_0 = 8'b11110101;
+ 4'b1110: data_out_0 = 8'b11110110;
+ 4'b1111: data_out_0 = 8'b11111010;
+ default: data_out_0 = 8'b0;
+ endcase
+ end
+endmodule
diff --git a/src/mase_components/generated_lut/rtl/isqrt_lut.sv b/src/mase_components/generated_lut/rtl/isqrt_lut.sv
new file mode 100644
index 000000000..5b90bcca2
--- /dev/null
+++ b/src/mase_components/generated_lut/rtl/isqrt_lut.sv
@@ -0,0 +1,277 @@
+
+`timescale 1ns / 1ps
+/* verilator lint_off UNUSEDPARAM */
+module isqrt_lut #(
+ parameter DATA_IN_0_PRECISION_0 = 9,
+ parameter DATA_IN_0_PRECISION_1 = 7,
+ parameter DATA_OUT_0_PRECISION_0 = 8,
+ parameter DATA_OUT_0_PRECISION_1 = 4
+) (
+ /* verilator lint_off UNUSEDSIGNAL */
+ input logic [8:0] data_in_0,
+ output logic [7:0] data_out_0
+);
+
+
+ always_comb begin
+ case (data_in_0)
+ 9'b000000000: data_out_0 = 8'b01111111;
+ 9'b000000001: data_out_0 = 8'b01111111;
+ 9'b000000010: data_out_0 = 8'b01111111;
+ 9'b000000011: data_out_0 = 8'b01101000;
+ 9'b000000100: data_out_0 = 8'b01011010;
+ 9'b000000101: data_out_0 = 8'b01010001;
+ 9'b000000110: data_out_0 = 8'b01001010;
+ 9'b000000111: data_out_0 = 8'b01000100;
+ 9'b000001000: data_out_0 = 8'b01000000;
+ 9'b000001001: data_out_0 = 8'b00111100;
+ 9'b000001010: data_out_0 = 8'b00111001;
+ 9'b000001011: data_out_0 = 8'b00110111;
+ 9'b000001100: data_out_0 = 8'b00110100;
+ 9'b000001101: data_out_0 = 8'b00110010;
+ 9'b000001110: data_out_0 = 8'b00110000;
+ 9'b000001111: data_out_0 = 8'b00101111;
+ 9'b000010000: data_out_0 = 8'b00101101;
+ 9'b000010001: data_out_0 = 8'b00101100;
+ 9'b000010010: data_out_0 = 8'b00101011;
+ 9'b000010011: data_out_0 = 8'b00101010;
+ 9'b000010100: data_out_0 = 8'b00101000;
+ 9'b000010101: data_out_0 = 8'b00101000;
+ 9'b000010110: data_out_0 = 8'b00100111;
+ 9'b000010111: data_out_0 = 8'b00100110;
+ 9'b000011000: data_out_0 = 8'b00100101;
+ 9'b000011001: data_out_0 = 8'b00100100;
+ 9'b000011010: data_out_0 = 8'b00100011;
+ 9'b000011011: data_out_0 = 8'b00100011;
+ 9'b000011100: data_out_0 = 8'b00100010;
+ 9'b000011101: data_out_0 = 8'b00100010;
+ 9'b000011110: data_out_0 = 8'b00100001;
+ 9'b000011111: data_out_0 = 8'b00100001;
+ 9'b000100000: data_out_0 = 8'b00100000;
+ 9'b000100001: data_out_0 = 8'b00100000;
+ 9'b000100010: data_out_0 = 8'b00011111;
+ 9'b000100011: data_out_0 = 8'b00011111;
+ 9'b000100100: data_out_0 = 8'b00011110;
+ 9'b000100101: data_out_0 = 8'b00011110;
+ 9'b000100110: data_out_0 = 8'b00011101;
+ 9'b000100111: data_out_0 = 8'b00011101;
+ 9'b000101000: data_out_0 = 8'b00011101;
+ 9'b000101001: data_out_0 = 8'b00011100;
+ 9'b000101010: data_out_0 = 8'b00011100;
+ 9'b000101011: data_out_0 = 8'b00011100;
+ 9'b000101100: data_out_0 = 8'b00011011;
+ 9'b000101101: data_out_0 = 8'b00011011;
+ 9'b000101110: data_out_0 = 8'b00011011;
+ 9'b000101111: data_out_0 = 8'b00011010;
+ 9'b000110000: data_out_0 = 8'b00011010;
+ 9'b000110001: data_out_0 = 8'b00011010;
+ 9'b000110010: data_out_0 = 8'b00011010;
+ 9'b000110011: data_out_0 = 8'b00011001;
+ 9'b000110100: data_out_0 = 8'b00011001;
+ 9'b000110101: data_out_0 = 8'b00011001;
+ 9'b000110110: data_out_0 = 8'b00011001;
+ 9'b000110111: data_out_0 = 8'b00011000;
+ 9'b000111000: data_out_0 = 8'b00011000;
+ 9'b000111001: data_out_0 = 8'b00011000;
+ 9'b000111010: data_out_0 = 8'b00011000;
+ 9'b000111011: data_out_0 = 8'b00011000;
+ 9'b000111100: data_out_0 = 8'b00010111;
+ 9'b000111101: data_out_0 = 8'b00010111;
+ 9'b000111110: data_out_0 = 8'b00010111;
+ 9'b000111111: data_out_0 = 8'b00010111;
+ 9'b001000000: data_out_0 = 8'b00010111;
+ 9'b001000001: data_out_0 = 8'b00010110;
+ 9'b001000010: data_out_0 = 8'b00010110;
+ 9'b001000011: data_out_0 = 8'b00010110;
+ 9'b001000100: data_out_0 = 8'b00010110;
+ 9'b001000101: data_out_0 = 8'b00010110;
+ 9'b001000110: data_out_0 = 8'b00010110;
+ 9'b001000111: data_out_0 = 8'b00010101;
+ 9'b001001000: data_out_0 = 8'b00010101;
+ 9'b001001001: data_out_0 = 8'b00010101;
+ 9'b001001010: data_out_0 = 8'b00010101;
+ 9'b001001011: data_out_0 = 8'b00010101;
+ 9'b001001100: data_out_0 = 8'b00010101;
+ 9'b001001101: data_out_0 = 8'b00010101;
+ 9'b001001110: data_out_0 = 8'b00010100;
+ 9'b001001111: data_out_0 = 8'b00010100;
+ 9'b001010000: data_out_0 = 8'b00010100;
+ 9'b001010001: data_out_0 = 8'b00010100;
+ 9'b001010010: data_out_0 = 8'b00010100;
+ 9'b001010011: data_out_0 = 8'b00010100;
+ 9'b001010100: data_out_0 = 8'b00010100;
+ 9'b001010101: data_out_0 = 8'b00010100;
+ 9'b001010110: data_out_0 = 8'b00010100;
+ 9'b001010111: data_out_0 = 8'b00010011;
+ 9'b001011000: data_out_0 = 8'b00010011;
+ 9'b001011001: data_out_0 = 8'b00010011;
+ 9'b001011010: data_out_0 = 8'b00010011;
+ 9'b001011011: data_out_0 = 8'b00010011;
+ 9'b001011100: data_out_0 = 8'b00010011;
+ 9'b001011101: data_out_0 = 8'b00010011;
+ 9'b001011110: data_out_0 = 8'b00010011;
+ 9'b001011111: data_out_0 = 8'b00010011;
+ 9'b001100000: data_out_0 = 8'b00010010;
+ 9'b001100001: data_out_0 = 8'b00010010;
+ 9'b001100010: data_out_0 = 8'b00010010;
+ 9'b001100011: data_out_0 = 8'b00010010;
+ 9'b001100100: data_out_0 = 8'b00010010;
+ 9'b001100101: data_out_0 = 8'b00010010;
+ 9'b001100110: data_out_0 = 8'b00010010;
+ 9'b001100111: data_out_0 = 8'b00010010;
+ 9'b001101000: data_out_0 = 8'b00010010;
+ 9'b001101001: data_out_0 = 8'b00010010;
+ 9'b001101010: data_out_0 = 8'b00010010;
+ 9'b001101011: data_out_0 = 8'b00010001;
+ 9'b001101100: data_out_0 = 8'b00010001;
+ 9'b001101101: data_out_0 = 8'b00010001;
+ 9'b001101110: data_out_0 = 8'b00010001;
+ 9'b001101111: data_out_0 = 8'b00010001;
+ 9'b001110000: data_out_0 = 8'b00010001;
+ 9'b001110001: data_out_0 = 8'b00010001;
+ 9'b001110010: data_out_0 = 8'b00010001;
+ 9'b001110011: data_out_0 = 8'b00010001;
+ 9'b001110100: data_out_0 = 8'b00010001;
+ 9'b001110101: data_out_0 = 8'b00010001;
+ 9'b001110110: data_out_0 = 8'b00010001;
+ 9'b001110111: data_out_0 = 8'b00010001;
+ 9'b001111000: data_out_0 = 8'b00010001;
+ 9'b001111001: data_out_0 = 8'b00010000;
+ 9'b001111010: data_out_0 = 8'b00010000;
+ 9'b001111011: data_out_0 = 8'b00010000;
+ 9'b001111100: data_out_0 = 8'b00010000;
+ 9'b001111101: data_out_0 = 8'b00010000;
+ 9'b001111110: data_out_0 = 8'b00010000;
+ 9'b001111111: data_out_0 = 8'b00010000;
+ 9'b010000000: data_out_0 = 8'b00010000;
+ 9'b010000001: data_out_0 = 8'b00010000;
+ 9'b010000010: data_out_0 = 8'b00010000;
+ 9'b010000011: data_out_0 = 8'b00010000;
+ 9'b010000100: data_out_0 = 8'b00010000;
+ 9'b010000101: data_out_0 = 8'b00010000;
+ 9'b010000110: data_out_0 = 8'b00010000;
+ 9'b010000111: data_out_0 = 8'b00010000;
+ 9'b010001000: data_out_0 = 8'b00010000;
+ 9'b010001001: data_out_0 = 8'b00001111;
+ 9'b010001010: data_out_0 = 8'b00001111;
+ 9'b010001011: data_out_0 = 8'b00001111;
+ 9'b010001100: data_out_0 = 8'b00001111;
+ 9'b010001101: data_out_0 = 8'b00001111;
+ 9'b010001110: data_out_0 = 8'b00001111;
+ 9'b010001111: data_out_0 = 8'b00001111;
+ 9'b010010000: data_out_0 = 8'b00001111;
+ 9'b010010001: data_out_0 = 8'b00001111;
+ 9'b010010010: data_out_0 = 8'b00001111;
+ 9'b010010011: data_out_0 = 8'b00001111;
+ 9'b010010100: data_out_0 = 8'b00001111;
+ 9'b010010101: data_out_0 = 8'b00001111;
+ 9'b010010110: data_out_0 = 8'b00001111;
+ 9'b010010111: data_out_0 = 8'b00001111;
+ 9'b010011000: data_out_0 = 8'b00001111;
+ 9'b010011001: data_out_0 = 8'b00001111;
+ 9'b010011010: data_out_0 = 8'b00001111;
+ 9'b010011011: data_out_0 = 8'b00001111;
+ 9'b010011100: data_out_0 = 8'b00001110;
+ 9'b010011101: data_out_0 = 8'b00001110;
+ 9'b010011110: data_out_0 = 8'b00001110;
+ 9'b010011111: data_out_0 = 8'b00001110;
+ 9'b010100000: data_out_0 = 8'b00001110;
+ 9'b010100001: data_out_0 = 8'b00001110;
+ 9'b010100010: data_out_0 = 8'b00001110;
+ 9'b010100011: data_out_0 = 8'b00001110;
+ 9'b010100100: data_out_0 = 8'b00001110;
+ 9'b010100101: data_out_0 = 8'b00001110;
+ 9'b010100110: data_out_0 = 8'b00001110;
+ 9'b010100111: data_out_0 = 8'b00001110;
+ 9'b010101000: data_out_0 = 8'b00001110;
+ 9'b010101001: data_out_0 = 8'b00001110;
+ 9'b010101010: data_out_0 = 8'b00001110;
+ 9'b010101011: data_out_0 = 8'b00001110;
+ 9'b010101100: data_out_0 = 8'b00001110;
+ 9'b010101101: data_out_0 = 8'b00001110;
+ 9'b010101110: data_out_0 = 8'b00001110;
+ 9'b010101111: data_out_0 = 8'b00001110;
+ 9'b010110000: data_out_0 = 8'b00001110;
+ 9'b010110001: data_out_0 = 8'b00001110;
+ 9'b010110010: data_out_0 = 8'b00001110;
+ 9'b010110011: data_out_0 = 8'b00001110;
+ 9'b010110100: data_out_0 = 8'b00001101;
+ 9'b010110101: data_out_0 = 8'b00001101;
+ 9'b010110110: data_out_0 = 8'b00001101;
+ 9'b010110111: data_out_0 = 8'b00001101;
+ 9'b010111000: data_out_0 = 8'b00001101;
+ 9'b010111001: data_out_0 = 8'b00001101;
+ 9'b010111010: data_out_0 = 8'b00001101;
+ 9'b010111011: data_out_0 = 8'b00001101;
+ 9'b010111100: data_out_0 = 8'b00001101;
+ 9'b010111101: data_out_0 = 8'b00001101;
+ 9'b010111110: data_out_0 = 8'b00001101;
+ 9'b010111111: data_out_0 = 8'b00001101;
+ 9'b011000000: data_out_0 = 8'b00001101;
+ 9'b011000001: data_out_0 = 8'b00001101;
+ 9'b011000010: data_out_0 = 8'b00001101;
+ 9'b011000011: data_out_0 = 8'b00001101;
+ 9'b011000100: data_out_0 = 8'b00001101;
+ 9'b011000101: data_out_0 = 8'b00001101;
+ 9'b011000110: data_out_0 = 8'b00001101;
+ 9'b011000111: data_out_0 = 8'b00001101;
+ 9'b011001000: data_out_0 = 8'b00001101;
+ 9'b011001001: data_out_0 = 8'b00001101;
+ 9'b011001010: data_out_0 = 8'b00001101;
+ 9'b011001011: data_out_0 = 8'b00001101;
+ 9'b011001100: data_out_0 = 8'b00001101;
+ 9'b011001101: data_out_0 = 8'b00001101;
+ 9'b011001110: data_out_0 = 8'b00001101;
+ 9'b011001111: data_out_0 = 8'b00001101;
+ 9'b011010000: data_out_0 = 8'b00001101;
+ 9'b011010001: data_out_0 = 8'b00001101;
+ 9'b011010010: data_out_0 = 8'b00001100;
+ 9'b011010011: data_out_0 = 8'b00001100;
+ 9'b011010100: data_out_0 = 8'b00001100;
+ 9'b011010101: data_out_0 = 8'b00001100;
+ 9'b011010110: data_out_0 = 8'b00001100;
+ 9'b011010111: data_out_0 = 8'b00001100;
+ 9'b011011000: data_out_0 = 8'b00001100;
+ 9'b011011001: data_out_0 = 8'b00001100;
+ 9'b011011010: data_out_0 = 8'b00001100;
+ 9'b011011011: data_out_0 = 8'b00001100;
+ 9'b011011100: data_out_0 = 8'b00001100;
+ 9'b011011101: data_out_0 = 8'b00001100;
+ 9'b011011110: data_out_0 = 8'b00001100;
+ 9'b011011111: data_out_0 = 8'b00001100;
+ 9'b011100000: data_out_0 = 8'b00001100;
+ 9'b011100001: data_out_0 = 8'b00001100;
+ 9'b011100010: data_out_0 = 8'b00001100;
+ 9'b011100011: data_out_0 = 8'b00001100;
+ 9'b011100100: data_out_0 = 8'b00001100;
+ 9'b011100101: data_out_0 = 8'b00001100;
+ 9'b011100110: data_out_0 = 8'b00001100;
+ 9'b011100111: data_out_0 = 8'b00001100;
+ 9'b011101000: data_out_0 = 8'b00001100;
+ 9'b011101001: data_out_0 = 8'b00001100;
+ 9'b011101010: data_out_0 = 8'b00001100;
+ 9'b011101011: data_out_0 = 8'b00001100;
+ 9'b011101100: data_out_0 = 8'b00001100;
+ 9'b011101101: data_out_0 = 8'b00001100;
+ 9'b011101110: data_out_0 = 8'b00001100;
+ 9'b011101111: data_out_0 = 8'b00001100;
+ 9'b011110000: data_out_0 = 8'b00001100;
+ 9'b011110001: data_out_0 = 8'b00001100;
+ 9'b011110010: data_out_0 = 8'b00001100;
+ 9'b011110011: data_out_0 = 8'b00001100;
+ 9'b011110100: data_out_0 = 8'b00001100;
+ 9'b011110101: data_out_0 = 8'b00001100;
+ 9'b011110110: data_out_0 = 8'b00001100;
+ 9'b011110111: data_out_0 = 8'b00001100;
+ 9'b011111000: data_out_0 = 8'b00001011;
+ 9'b011111001: data_out_0 = 8'b00001011;
+ 9'b011111010: data_out_0 = 8'b00001011;
+ 9'b011111011: data_out_0 = 8'b00001011;
+ 9'b011111100: data_out_0 = 8'b00001011;
+ 9'b011111101: data_out_0 = 8'b00001011;
+ 9'b011111110: data_out_0 = 8'b00001011;
+ 9'b011111111: data_out_0 = 8'b00001011;
+ default: data_out_0 = 8'b0;
+ endcase
+ end
+endmodule
diff --git a/src/mase_components/helper/generate_memory.py b/src/mase_components/helper/generate_memory.py
index e46c99604..fdb2c1a6b 100644
--- a/src/mase_components/helper/generate_memory.py
+++ b/src/mase_components/helper/generate_memory.py
@@ -11,8 +11,14 @@
from pathlib import Path
-def make_quantizer(data_width: int, f_width: int):
- return partial(integer_quantizer, width=data_width, frac_width=f_width)
+def make_quantizer(data_width: int, f_width: int, floor):
+ base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+ return partial(base_quantizer, width=data_width, frac_width=f_width)
+
+
+def isqrt(x):
+ x = (x + 1e-5).sqrt().reciprocal()
+ return x
FUNCTION_TABLE = {
@@ -23,7 +29,9 @@ def make_quantizer(data_width: int, f_width: int):
"softshrink": nn.Softshrink(),
"gelu": nn.GELU(),
"exp": torch.exp,
+ "power2": lambda x: torch.pow(2, x),
"softmax": torch.exp,
+ "isqrt": isqrt,
}
@@ -41,6 +49,9 @@ def doubletofx(data_width: int, f_width: int, num: float, type="hex"):
intbits = BitArray(int=intnum, length=data_width)
return str(intbits.bin) if type == "bin" else str(intbits)
+def inttobit(data_width:int, num: float, signed: bool = True):
+ intbits = BitArray(int=num, length=data_width) if signed else BitArray(uint=num, length=data_width)
+ return intbits
def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
f = FUNCTION_TABLE[function]
@@ -70,7 +81,14 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
def aligned_generate_lookup(
- in_data_width, in_f_width, data_width: int, f_width: int, function: str, type="hex"
+ in_data_width,
+ in_f_width,
+ data_width: int,
+ f_width: int,
+ function: str,
+ type="hex",
+ constant_mult=1,
+ floor=False,
):
f = FUNCTION_TABLE[function]
lut = {
@@ -83,15 +101,15 @@ def aligned_generate_lookup(
# entries = 2 ** data_width
minval = float(-(2 ** (in_data_width - in_f_width - 1)))
maxval = (2 ** (in_data_width - 1) - 1) * 2 ** (-in_f_width)
- inp_quanter = make_quantizer(in_data_width, in_f_width)
- quanter = make_quantizer(data_width, f_width)
+ inp_quanter = make_quantizer(in_data_width, in_f_width, floor)
+ quanter = make_quantizer(data_width, f_width, floor)
count = 0
iarr = []
pi = float(0)
while pi <= maxval:
count += 1
iarr.append(pi)
- val = quanter(f(torch.tensor(pi))) # entry in the lookup table
+ val = quanter(f(torch.tensor(pi * constant_mult))) # entry in the lookup table
lut[
doubletofx(data_width=in_data_width, f_width=in_f_width, num=pi, type=type)
] = doubletofx(
@@ -99,17 +117,22 @@ def aligned_generate_lookup(
)
pi += 2 ** -(in_f_width)
- i = minval
- while i <= -1 * 2 ** -(in_f_width):
- count += 1
- iarr.append(i)
- val = quanter(f(torch.tensor(i))) # entry in the lookup table
- lut[
- doubletofx(data_width=in_data_width, f_width=in_f_width, num=i, type=type)
- ] = doubletofx(
- data_width=data_width, f_width=f_width, num=val.item(), type=type
- )
- i += 2 ** -(in_f_width)
+ if function not in ["isqrt"]:
+ i = minval
+ while i <= -1 * 2 ** -(in_f_width):
+ count += 1
+ iarr.append(i)
+ val = quanter(
+ f(torch.tensor(i * constant_mult))
+ ) # entry in the lookup table
+ lut[
+ doubletofx(
+ data_width=in_data_width, f_width=in_f_width, num=i, type=type
+ )
+ ] = doubletofx(
+ data_width=data_width, f_width=f_width, num=val.item(), type=type
+ )
+ i += 2 ** -(in_f_width)
iarr = [(x * 2 ** (in_f_width)) for x in iarr]
# print(iarr)
@@ -211,6 +234,8 @@ def lookup_to_sv_file(
function: str,
file_path=None,
path_with_dtype=False,
+ constant_mult=1,
+ floor=False,
):
dicto = aligned_generate_lookup(
in_data_width=in_data_width,
@@ -219,6 +244,8 @@ def lookup_to_sv_file(
f_width=f_width,
function=function,
type="bin",
+ constant_mult=constant_mult,
+ floor=floor,
)
dicto = {
k: v
@@ -237,31 +264,32 @@ def lookup_to_sv_file(
`timescale 1ns / 1ps
/* verilator lint_off UNUSEDPARAM */
module {function}_lut{end} #(
- parameter DATA_IN_0_PRECISION_0 = 16,
- parameter DATA_IN_0_PRECISION_1 = 8,
- parameter DATA_OUT_0_PRECISION_0 = 16,
- parameter DATA_OUT_0_PRECISION_1 = 8
-)
-(
+ parameter DATA_IN_0_PRECISION_0 = {in_data_width},
+ parameter DATA_IN_0_PRECISION_1 = {in_f_width},
+ parameter DATA_OUT_0_PRECISION_0 = {data_width},
+ parameter DATA_OUT_0_PRECISION_1 = {f_width}
+) (
/* verilator lint_off UNUSEDSIGNAL */
- input logic [{in_data_width-1}:0] data_in_0,
- output logic [{data_width-1}:0] data_out_0
+ input logic [{in_data_width - 1}:0] data_in_0,
+ output logic [{data_width - 1}:0] data_out_0
);
-
+
+"""
+ sv_code += """
+ always_comb begin
+ case (data_in_0)
"""
- sv_code += " always_comb begin\n"
- sv_code += " case(data_in_0)\n"
# Adding each case
for key, value in dicto.items():
formatted_key = key_format.format(key)
formatted_value = value_format.format(value)
- sv_code += f" {formatted_key}: data_out_0 = {formatted_value};\n"
+ sv_code += f" {formatted_key}: data_out_0 = {formatted_value};\n"
# Ending the case statement and module
- sv_code += f" default: data_out_0 = {data_width}'b0;\n"
- sv_code += " endcase\n"
- sv_code += " end\n"
+ sv_code += f" default: data_out_0 = {data_width}'b0;\n"
+ sv_code += " endcase\n"
+ sv_code += " end\n"
sv_code += "endmodule\n"
# Write the code to a SystemVerilog file
@@ -271,14 +299,96 @@ def lookup_to_sv_file(
print(f"SystemVerilog module generated and saved as {file_path}.")
+def inttobit(data_width:int, num: float, signed: bool = True):
+ intbits = BitArray(int=num, length=data_width) if signed else BitArray(uint=num, length=data_width)
+ return intbits
+class GenerateSVLut:
+ def __init__(self, function_name, parameter, path):
+ assert (
+ function_name in FUNCTION_TABLE
+ ), f"Function {function_name} not found in FUNCTION_TABLE"
+ self.f = FUNCTION_TABLE[function_name]
+ self.parameter = parameter
+ self.path = path
+ def quant_profile(self, bin_in):
+ bin_out = bin_in
+ return bin_out
+
+ def generate_lut_address(self):
+ return NotImplementedError
+
+ def generate_lut(self, lut_address: list):
+ lut = {}
+ for i in lut_address:
+ bin_out = self.quant_profile(i)
+ lut[i] = bin_out
+ return lut
+
+ def generate_sv(self,lut):
+ self.generate_lut()
+ return NotImplementedError
+
+ def pipeline(self):
+ lut_address = self.generate_lut_address(self)
+ lut = self.generate_lut(lut_address)
+ sv = self.generate_sv(lut)
+
+from mase_components.linear_layers.mxint_operators.test.utils import mxint_quant_block
+class GenerateMxIntSVLut(GenerateSVLut):
+ def quant_profile(self, bin_in):
+ in_man_width, in_exp_width, out_man_width, out_exp_width = self.parameter["in_man_width"], self.parameter["in_exp_width"], self.parameter["out_man_width"], self.parameter["out_exp_width"]
+ _bin = BitArray(bin=bin_in)
+ exp_int = _bin[0:in_exp_width].int
+ man_int = _bin[in_exp_width:in_man_width + in_exp_width].int
+ value = man_int / 2**(in_man_width - 1) * 2**(exp_int)
+ exp_value = self.f(torch.tensor(value))
+ quant_value, mx, ex = mxint_quantize(exp_value,out_man_width,out_exp_width)
+ exp_bit = inttobit(out_exp_width, num=ex).bin
+ man_bit = inttobit(out_man_width, num=mx).bin
+ bin_out = exp_bit + man_bit
+ return bin_out
+ def generate_lut_address(self):
+ in_man_width, in_exp_width, out_man_width, out_exp_width = self.parameter["in_man_width"], self.parameter["in_exp_width"], self.parameter["out_man_width"], self.parameter["out_exp_width"]
+ # we can determine the upper bound of the exponent
+ from math import log
+ upperbound_of_mx_output = (2**(out_man_width - 1) - 1) / 2**(out_man_width - 1) * 2**(2**(out_exp_width - 1) - 1)
+ lowerbound_of_mx_output = (1) / 2**(out_man_width - 1) * 2**(-2**(out_exp_width - 1))
+ positive_max_bound = log(upperbound_of_mx_output)
+ negetive_max_bound = log(lowerbound_of_mx_output)
+ # when input > max_bound or input < min_bound, we don't actually need to represent those values
+ max_exp = torch.tensor(max(abs(positive_max_bound), abs(negetive_max_bound)))
+ _, _, max_exp = mxint_quantize(max_exp)
+
+ # actually, we also don't have enough precision to represent values around 1 (i.e. exp(0)),
+ # so the resolution limit for inputs near 0 determines the minimum value of the exponent.
+ # this gives two candidate values, one on the left side and one on the right side
+ _left = (2**(out_man_width - 1) - 1) / 2**(out_man_width - 1)
+ _right = (1*2**(out_man_width - 2) + 1) / 2**(out_man_width - 1) * 2**(1)
+ # to round them, split the gap in half: when a value is smaller than this threshold, we can treat it as 0
+ _left = 1 - (1 - _left)/2
+ _right = 1 + (_right - 1)/2
+ positive_min_bound = log(_left)
+ negetive_min_bound = log(_right)
+ min_exp = torch.tensor(min(abs(positive_min_bound), abs(negetive_min_bound)))
+ _, _, min_exp = mxint_quantize(min_exp)
+ address = []
+ for i in range(int(min_exp), int(max_exp+in_man_width)):
+ for j in range(2**in_man_width):
+ exp_bin = inttobit(in_exp_width,i).bin
+ man_bin = inttobit(in_man_width,j, signed=False).bin
+ address += [str(exp_bin) + str(man_bin)]
+ return address
+
def generate_sv_lut(
function_name,
in_data_width,
in_f_width,
data_width,
f_width,
- path=None,
+ path=None,  # TODO: consider dropping the path parameter, to avoid redundantly generated exp LUTs
path_with_dtype=False,
+ constant_mult=1,
+ floor=False,
):
assert (
function_name in FUNCTION_TABLE
@@ -289,27 +399,18 @@ def generate_sv_lut(
else:
end = ""
- if path is None:
- p = Path(__file__).parents[1] / "rtl"
- lookup_to_sv_file(
- in_data_width,
- in_f_width,
- data_width,
- f_width,
- function_name,
- str(p / f"{function_name}_lut{end}.sv"),
- path_with_dtype=path_with_dtype,
- )
- else:
- lookup_to_sv_file(
- in_data_width,
- in_f_width,
- data_width,
- f_width,
- function_name,
- f"{path}/{function_name}_lut{end}.sv",
- path_with_dtype=path_with_dtype,
- )
+ p = Path(__file__).parents[1] / "generated_lut" / "rtl"
+ lookup_to_sv_file(
+ in_data_width,
+ in_f_width,
+ data_width,
+ f_width,
+ function_name,
+ str(p / f"{function_name}_lut{end}.sv"),
+ path_with_dtype=path_with_dtype,
+ constant_mult=constant_mult,
+ floor=floor,
+ )
if __name__ == "__main__":
diff --git a/src/mase_components/hls/scalar_ops/int_div/README.md b/src/mase_components/hls/scalar_ops/int_div/README.md
new file mode 100644
index 000000000..c829d9a3b
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/README.md
@@ -0,0 +1,7 @@
+# Scalar integer/fixed-point divider with handshake interface
+
+To generate the Verilog, run:
+
+```sh
+vitis_hls vhls.tcl
+```
diff --git a/src/mase_components/hls/scalar_ops/int_div/div.cpp b/src/mase_components/hls/scalar_ops/int_div/div.cpp
new file mode 100644
index 000000000..6966d1ecf
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/div.cpp
@@ -0,0 +1,23 @@
+#include "ap_int.h"
+#include "hls_stream.h"
+
+#define total_width_0 32
+
+#define total_width_1 32
+
+#define total_width_2 16
+
+void div(hls::stream<ap_int<total_width_0>> &data_in_0,
+ hls::stream<ap_int<total_width_1>> &data_in_1,
+ hls::stream<ap_int<total_width_2>> &data_out_0) {
+#pragma HLS PIPELINE II = 1
+ if (data_in_0.empty() || data_in_1.empty())
+ return;
+ ap_int<total_width_0> in0;
+ ap_int<total_width_1> in1;
+ data_in_0.read_nb(in0);
+ data_in_1.read_nb(in1);
+ ap_int<total_width_2> res = in0 / in1;
+ // TODO: #pragma HLS bind_op variable=res op= impl=fabric
+ data_out_0.write_nb(res);
+}
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/hls.app b/src/mase_components/hls/scalar_ops/int_div/prj/hls.app
new file mode 100644
index 000000000..f0f80b456
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/prj/hls.app
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/.autopilot_exit b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/.autopilot_exit
new file mode 100644
index 000000000..cc26a0a77
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/.autopilot_exit
@@ -0,0 +1,2 @@
+22:51:45
+08/04/2024
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/.message_syn.xml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/.message_syn.xml
new file mode 100644
index 000000000..3325f2f29
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/.message_syn.xml
@@ -0,0 +1,61 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g
new file mode 100644
index 000000000..fed4a96e7
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.0.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.0.bc
new file mode 100644
index 000000000..fed4a96e7
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.0.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.1.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.1.bc
new file mode 100644
index 000000000..a9959c961
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.1.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.bc
new file mode 100644
index 000000000..4364432ab
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.prechk.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.prechk.bc
new file mode 100644
index 000000000..6d478880a
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.2.prechk.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.bc
new file mode 100644
index 000000000..fed4a96e7
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc
new file mode 100644
index 000000000..b9c400d9f
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc.clang.reflow.diag.xml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc.clang.reflow.diag.xml
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc.clang.reflow.diag.yml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.0.bc.clang.reflow.diag.yml
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.1.lower.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.1.lower.bc
new file mode 100644
index 000000000..abecf12c8
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.1.lower.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.1.lower.bc.opt.diag.yml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.1.lower.bc.opt.diag.yml
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.2.m1.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.2.m1.bc
new file mode 100644
index 000000000..bfad45749
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.2.m1.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.3.fpc.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.3.fpc.bc
new file mode 100644
index 000000000..db9054f0c
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.3.fpc.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.3.fpc.bc.opt.diag.yml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.3.fpc.bc.opt.diag.yml
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.4.m2.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.4.m2.bc
new file mode 100644
index 000000000..28534701f
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.4.m2.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.5.gdce.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.5.gdce.bc
new file mode 100644
index 000000000..68c051076
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.5.gdce.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.5.gdce.bc.opt.diag.yml b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.ld.5.gdce.bc.opt.diag.yml
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.lto.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.lto.bc
new file mode 100644
index 000000000..fed4a96e7
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.g.lto.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.bc
new file mode 100644
index 000000000..a9959c961
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.tmp.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.tmp.bc
new file mode 100644
index 000000000..6ef86274a
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.1.tmp.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.2.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.2.bc
new file mode 100644
index 000000000..e0b9d8e14
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.2.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.3.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.3.bc
new file mode 100644
index 000000000..1163aec65
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.o.3.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.pp.bc b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.pp.bc
new file mode 100644
index 000000000..fed4a96e7
Binary files /dev/null and b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/a.pp.bc differ
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/all.directive.json b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/all.directive.json
new file mode 100644
index 000000000..8d364558d
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/all.directive.json
@@ -0,0 +1,23 @@
+[
+ {
+ "functionLabel": "",
+ "functionName": "div",
+ "id": 0,
+ "ifcond": "",
+ "insert_position": "",
+ "label": "",
+ "pragma": {
+ "name": "TOP",
+ "option": [
+ {
+ "name": "name",
+ "value": "div"
+ }
+ ]
+ },
+ "slx": false,
+ "sourceFile": "/workspace/src/mase_components/hls/scalar_ops/int_div/vhls.tcl",
+ "sourceLine": 8,
+ "success": true
+ }
+]
\ No newline at end of file
diff --git a/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/apatb_div.cpp b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/apatb_div.cpp
new file mode 100644
index 000000000..8abf6c7ce
--- /dev/null
+++ b/src/mase_components/hls/scalar_ops/int_div/prj/solution1/.autopilot/db/apatb_div.cpp
@@ -0,0 +1,1218 @@
+#include "hls_signal_handler.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include