From 4d544b07a8436fb4e6a4361e45d1ab8f2988d333 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Thu, 23 Jan 2025 11:10:52 +0000
Subject: [PATCH 01/38] MORR ONN transform pass and test script implementation
---
src/chop/nn/__init__.py | 1 +
src/chop/nn/optical/__init__.py | 3 +
src/chop/nn/optical/functional/__init__.py | 57 +
src/chop/nn/optical/functional/compute.py | 1064 +++++++++++++++++
src/chop/nn/optical/functional/general.py | 411 +++++++
src/chop/nn/optical/functional/initializer.py | 152 +++
src/chop/nn/optical/functional/mrr.py | 112 ++
src/chop/nn/optical/functional/mrr_op.py | 404 +++++++
src/chop/nn/optical/functional/quantize.py | 575 +++++++++
src/chop/nn/optical/functional/torch_train.py | 857 +++++++++++++
src/chop/nn/optical/modules/__init__.py | 7 +
src/chop/nn/optical/modules/base_layer.py | 71 ++
src/chop/nn/optical/modules/morr_conv2d.py | 458 +++++++
src/chop/nn/optical/modules/morr_linear.py | 442 +++++++
.../passes/module/module_transform_helper.py | 64 +
src/chop/passes/module/transforms/__init__.py | 1 +
.../module/transforms/optical/__init__.py | 1 +
.../module/transforms/optical/optical.py | 103 ++
.../transforms/optical/test_optical_module.py | 192 +++
.../transforms/optical/train_mnist_cnn.py | 144 +++
20 files changed, 5119 insertions(+)
create mode 100644 src/chop/nn/optical/__init__.py
create mode 100644 src/chop/nn/optical/functional/__init__.py
create mode 100644 src/chop/nn/optical/functional/compute.py
create mode 100644 src/chop/nn/optical/functional/general.py
create mode 100644 src/chop/nn/optical/functional/initializer.py
create mode 100644 src/chop/nn/optical/functional/mrr.py
create mode 100644 src/chop/nn/optical/functional/mrr_op.py
create mode 100644 src/chop/nn/optical/functional/quantize.py
create mode 100644 src/chop/nn/optical/functional/torch_train.py
create mode 100644 src/chop/nn/optical/modules/__init__.py
create mode 100644 src/chop/nn/optical/modules/base_layer.py
create mode 100644 src/chop/nn/optical/modules/morr_conv2d.py
create mode 100644 src/chop/nn/optical/modules/morr_linear.py
create mode 100644 src/chop/passes/module/module_transform_helper.py
create mode 100644 src/chop/passes/module/transforms/optical/__init__.py
create mode 100644 src/chop/passes/module/transforms/optical/optical.py
create mode 100644 test/passes/module/transforms/optical/test_optical_module.py
create mode 100644 test/passes/module/transforms/optical/train_mnist_cnn.py
diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py
index 7ab651324..5533a7608 100644
--- a/src/chop/nn/__init__.py
+++ b/src/chop/nn/__init__.py
@@ -1,3 +1,4 @@
from .quantized import quantized_module_map
+from .optical import optical_module_map
MASE_LEAF_LAYERS = tuple(quantized_module_map.values())
diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py
new file mode 100644
index 000000000..e9c0423d5
--- /dev/null
+++ b/src/chop/nn/optical/__init__.py
@@ -0,0 +1,3 @@
+from .modules import (
+ optical_module_map,
+)
\ No newline at end of file
diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py
new file mode 100644
index 000000000..11a435770
--- /dev/null
+++ b/src/chop/nn/optical/functional/__init__.py
@@ -0,0 +1,57 @@
+from .mrr import (
+ MORRConfig_20um_MQ,
+ MRRConfig_5um_HQ,
+ MRRConfig_5um_MQ,
+ MRRConfig_5um_LQ,
+ MORRConfig_10um_MQ,
+)
+
+from .compute import (
+ im2col_2d,
+ toeplitz,
+)
+
+from .general import (
+ logger,
+)
+
+from .initializer import (
+ morr_uniform_,
+)
+
+from .quantize import (
+ input_quantize_fn,
+ weight_quantize_fn,
+)
+
+from .mrr_op import (
+ mrr_roundtrip_phase_to_tr_func,
+ mrr_roundtrip_phase_to_tr_fused,
+)
+
+
+
+
+
+# """
+# Description:
+# Author: Jiaqi Gu (jqgu@utexas.edu)
+# Date: 2021-06-09 01:40:22
+# LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+# LastEditTime: 2021-06-09 01:40:22
+# """
+
+# import importlib
+# import os
+
+# # automatically import any Python files in this directory
+# for file in sorted(os.listdir(os.path.dirname(__file__))):
+# if file.endswith(".py") and not file.startswith("_"):
+# source = file[: file.find(".py")]
+# module = importlib.import_module("torchonn.layers." + source)
+# if "__all__" in module.__dict__:
+# names = module.__dict__["__all__"]
+# else:
+# # import all names that do not begin with _
+# names = [x for x in module.__dict__ if not x.startswith("_")]
+# globals().update({k: getattr(module, k) for k in names})
diff --git a/src/chop/nn/optical/functional/compute.py b/src/chop/nn/optical/functional/compute.py
new file mode 100644
index 000000000..c43ad9f2d
--- /dev/null
+++ b/src/chop/nn/optical/functional/compute.py
@@ -0,0 +1,1064 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-06 02:17:08
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-06 02:17:08
+"""
+
+import contextlib
+import logging
+from functools import lru_cache
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from scipy.stats import truncnorm
+from torch import Tensor, nn
+from torch.autograd import grad
+from torch.nn.modules.utils import _pair
+from torch.types import Device, _size
+
+from .torch_train import set_torch_deterministic
+
+__all__ = [
+ "shift",
+ "Krylov",
+ "circulant",
+ "toeplitz",
+ "complex_circulant",
+ "complex_mult",
+ "expi",
+ "complex_matvec_mult",
+ "complex_matmul",
+ "real_to_complex",
+ "get_complex_magnitude",
+ "get_complex_energy",
+ "complex_to_polar",
+ "polar_to_complex",
+ "absclamp",
+ "absclamp_",
+ "im2col_2d",
+ "check_identity_matrix",
+ "check_unitary_matrix",
+ "check_equal_tensor",
+ "batch_diag",
+ "batch_eye_cpu",
+ "batch_eye",
+ "merge_chunks",
+ "partition_chunks",
+ "clip_by_std",
+ "percentile",
+ "gen_boolean_mask_cpu",
+ "gen_boolean_mask",
+ "fftshift_cpu",
+ "ifftshift_cpu",
+ "gen_gaussian_noise",
+ "gen_gaussian_filter2d_cpu",
+ "gen_gaussian_filter2d",
+ "add_gaussian_noise_cpu",
+ "add_gaussian_noise",
+ "add_gaussian_noise_",
+ "circulant_multiply",
+ "calc_diagonal_hessian",
+ "calc_jacobian",
+ "polynomial",
+ "gaussian",
+ "lowrank_decompose",
+ "get_conv2d_flops",
+ "interp1d",
+]
+
+
+def shift(v: Tensor, f: float = 1) -> Tensor:
+ return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1)
+
+
+def Krylov(linear_map: Callable, v: Tensor, n: Optional[int] = None) -> Tensor:
+ if n is None:
+ n = v.size(-1)
+ cols = [v]
+ for _ in range(n - 1):
+ v = linear_map(v)
+ cols.append(v)
+ return torch.stack(cols, dim=-2)
+
+
+def circulant(eigens: Tensor) -> Tensor:
+ circ = Krylov(shift, eigens).transpose(-1, -2)
+ return circ
+
+
+@lru_cache(maxsize=4)
+def _get_toeplitz_indices(n: int, device: Device) -> Tensor:
+ # cached toeplitz indices. avoid repeatedly generate the indices.
+ indices = circulant(torch.arange(n, device=device))
+ return indices
+
+
+def toeplitz(col: Tensor) -> Tensor:
+ """
+    Efficient Toeplitz matrix generation from the first column. The column vector must be in the last dimension. Batch generation is supported. Suitable for AutoGrad. Circulant matrix multiplication is ~4x faster than rfft-based implementation!\\
+ @col {torch.Tensor} (Batched) column vectors.\\
+ return out {torch.Tensor} (Batched) circulant matrices
+ """
+ n = col.size(-1)
+ indices = _get_toeplitz_indices(n, device=col.device)
+ return col[..., indices]
+
+
+def complex_circulant(eigens: Tensor) -> Tensor:
+ circ = Krylov(shift, eigens).transpose(-1, -2)
+ return circ
+
+
+def complex_mult(X: Tensor, Y: Tensor) -> Tensor:
+ """Complex-valued element-wise multiplication
+
+ Args:
+ X (Tensor): Real tensor with last dim of 2 or complex tensor
+ Y (Tensor): Real tensor with last dim of 2 or complex tensor
+
+ Returns:
+ Tensor: tensor with the same type as input
+ """
+ if not torch.is_complex(X) and not torch.is_complex(Y):
+ assert (
+ X.shape[-1] == 2 and Y.shape[-1] == 2
+ ), "Last dimension of real-valued tensor must be 2"
+ if hasattr(torch, "view_as_complex"):
+ return torch.view_as_real(
+ torch.view_as_complex(X) * torch.view_as_complex(Y)
+ )
+ else:
+ return torch.stack(
+ (
+ X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1],
+ X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0],
+ ),
+ dim=-1,
+ )
+ else:
+ return X.mul(Y)
+
+
+def complex_matvec_mult(W: Tensor, X: Tensor) -> Tensor:
+ return torch.sum(complex_mult(W, X.unsqueeze(0).repeat(W.size(0), 1, 1)), dim=1)
+
+
+def complex_matmul(X: Tensor, Y: Tensor) -> Tensor:
+ assert X.shape[-1] == 2 and Y.shape[-1] == 2, "Last dimension must be 2"
+ if torch.__version__ >= "1.8" or (
+ torch.__version__ >= "1.7" and X.shape[:-3] == Y.shape[:-3]
+ ):
+ return torch.view_as_real(
+ torch.matmul(torch.view_as_complex(X), torch.view_as_complex(Y))
+ )
+
+ return torch.stack(
+ [
+ X[..., 0].matmul(Y[..., 0]) - X[..., 1].matmul(Y[..., 1]),
+ X[..., 0].matmul(Y[..., 1]) + X[..., 1].matmul(Y[..., 0]),
+ ],
+ dim=-1,
+ )
+
+
+def expi(x: Tensor) -> Tensor:
+ if torch.__version__ >= "1.8" or (
+ torch.__version__ >= "1.7" and not x.requires_grad
+ ):
+ return torch.exp(1j * x)
+ else:
+ return x.cos().type(torch.cfloat) + 1j * x.sin().type(torch.cfloat)
+
+
+def real_to_complex(x: Tensor) -> Tensor:
+ if torch.__version__ < "1.7":
+ return torch.stack((x, torch.zeros_like(x).to(x.device)), dim=-1)
+ else:
+ return torch.view_as_real(x.to(torch.complex64))
+
+
+def get_complex_magnitude(x: Tensor) -> Tensor:
+ assert x.size(-1) == 2, "[E] Input must be complex Tensor"
+ return torch.sqrt(x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1])
+
+
+def complex_to_polar(x: Tensor) -> Tensor:
+ # real and imag to magnitude and angle
+ if isinstance(x, torch.Tensor):
+ mag = x.norm(p=2, dim=-1)
+ angle = torch.view_as_complex(x).angle()
+ x = torch.stack([mag, angle], dim=-1)
+ elif isinstance(x, np.ndarray):
+ x = x.astype(np.complex64)
+ mag = np.abs(x)
+ angle = np.angle(x)
+ x = np.stack([mag, angle], axis=-1)
+ else:
+ raise NotImplementedError
+ return x
+
+
+def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor:
+ # magnitude and angle to real and imag
+ if angle is None:
+        return real_to_complex(mag)
+ if mag is None:
+ if isinstance(angle, torch.Tensor):
+ x = torch.stack([angle.cos(), angle.sin()], dim=-1)
+ elif isinstance(angle, np.ndarray):
+ x = np.stack([np.cos(angle), np.sin(angle)], axis=-1)
+ else:
+ raise NotImplementedError
+ else:
+ if isinstance(angle, torch.Tensor):
+ x = torch.stack([mag * angle.cos(), mag * angle.sin()], dim=-1)
+ elif isinstance(angle, np.ndarray):
+ x = np.stack([mag * np.cos(angle), mag * np.sin(angle)], axis=-1)
+ else:
+ raise NotImplementedError
+ return x
+
+
+def get_complex_energy(x: Tensor) -> Tensor:
+ assert x.size(-1) == 2, "[E] Input must be complex Tensor"
+ return x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]
+
+
+def absclamp(
+ x: Tensor, min: Optional[float] = None, max: Optional[float] = None
+) -> Tensor:
+ if isinstance(x, torch.Tensor):
+ mag = x.norm(p=2, dim=-1).clamp(min=min, max=max)
+ angle = torch.view_as_complex(x).angle()
+ x = polar_to_complex(mag, angle)
+ elif isinstance(x, np.ndarray):
+ x = x.astype(np.complex64)
+ mag = np.clip(np.abs(x), a_min=min, a_max=max)
+ angle = np.angle(x)
+ x = polar_to_complex(mag, angle)
+ else:
+ raise NotImplementedError
+ return x
+
+
+def absclamp_(
+ x: Tensor, min: Optional[float] = None, max: Optional[float] = None
+) -> Tensor:
+ if isinstance(x, torch.Tensor):
+ y = torch.view_as_complex(x)
+ mag = y.abs().clamp(min=min, max=max)
+ angle = y.angle()
+ x.data.copy_(polar_to_complex(mag, angle))
+ elif isinstance(x, np.ndarray):
+ y = x.astype(np.complex64)
+ mag = np.clip(np.abs(y), a_min=min, a_max=max)
+ angle = np.angle(y)
+ x[:] = polar_to_complex(mag, angle)
+ else:
+ raise NotImplementedError
+ return x
+
+
+def im2col_2d(
+ W: Optional[Tensor] = None,
+ X: Optional[Tensor] = None,
+ stride: int = 1,
+ padding: int = 0,
+ w_size: Optional[_size] = None,
+) -> Tuple[Tensor, Tensor, int, int]:
+ if W is not None:
+ W_col = W.view(W.size(0), -1)
+ else:
+ W_col = None
+
+ if X is not None:
+ n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size
+ n_x, d_x, h_x, w_x = X.size()
+
+ h_out = (h_x - h_filter + 2 * padding) / stride + 1
+ w_out = (w_x - w_filter + 2 * padding) / stride + 1
+
+ h_out, w_out = int(h_out), int(w_out)
+ X_col = torch.nn.functional.unfold(
+ X.view(1, -1, h_x, w_x),
+ h_filter,
+ dilation=1,
+ padding=padding,
+ stride=stride,
+ ).view(n_x, -1, h_out * w_out)
+ X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1)
+ else:
+ X_col, h_out, w_out = None, None, None
+
+ return W_col, X_col, h_out, w_out
+
+
+def check_identity_matrix(W: Tensor) -> bool:
+ if isinstance(W, np.ndarray):
+ W_numpy = W.copy().astype(np.float64)
+ elif isinstance(W, torch.Tensor):
+ W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
+ else:
+ assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
+
+ return (W_numpy.shape[0] == W_numpy.shape[1]) and np.allclose(
+ W_numpy, np.eye(W_numpy.shape[0])
+ )
+
+
+def check_unitary_matrix(W: Tensor) -> bool:
+ if isinstance(W, np.ndarray):
+ W_numpy = W.copy().astype(np.float64)
+ elif isinstance(W, torch.Tensor):
+ W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
+ else:
+ assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
+ M = np.dot(W_numpy, W_numpy.T)
+ # print(M)
+ return check_identity_matrix(M)
+
+
+def check_equal_tensor(W1: Tensor, W2: Tensor) -> bool:
+ if isinstance(W1, np.ndarray):
+ W1_numpy = W1.copy().astype(np.float64)
+ elif isinstance(W1, torch.Tensor):
+ W1_numpy = W1.detach().cpu().numpy().copy().astype(np.float64)
+ else:
+ assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
+
+ if isinstance(W2, np.ndarray):
+ W2_numpy = W2.copy().astype(np.float64)
+ elif isinstance(W2, torch.Tensor):
+ W2_numpy = W2.detach().cpu().numpy().copy().astype(np.float64)
+ else:
+ assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
+ return (W1_numpy.shape == W2_numpy.shape) and np.allclose(W1_numpy, W2_numpy)
+
+
+def batch_diag(x: Tensor) -> Tensor:
+ # x[..., N, N] -> [..., N]
+ assert (
+ len(x.shape) >= 2
+ ), f"At least 2-D array/tensor is expected, but got shape {x.shape}"
+ if isinstance(x, np.ndarray):
+ size = list(x.shape)
+ x = x.reshape(size[:-2] + [size[-2] * size[-1]])
+ x = x[..., :: size[-1] + 1]
+ elif isinstance(x, torch.Tensor):
+ size = list(x.size())
+ x = x.flatten(-2, -1)
+ x = x[..., :: size[-1] + 1]
+ else:
+ raise NotImplementedError
+ return x
+
+
+def batch_eye_cpu(N: int, batch_shape: List[int], dtype: np.dtype) -> np.ndarray:
+ x = np.zeros(list(batch_shape) + [N, N], dtype=dtype)
+ x.reshape(-1, N * N)[..., :: N + 1] = 1
+ return x
+
+
+def batch_eye(
+ N: int,
+ batch_shape: List[int],
+ dtype: torch.dtype,
+ device: Device = torch.device("cuda"),
+) -> torch.Tensor:
+ x = torch.zeros(list(batch_shape) + [N, N], dtype=dtype, device=device)
+ x.view(-1, N * N)[..., :: N + 1] = 1
+ return x
+
+
+def merge_chunks(x: Tensor, complex: bool = False) -> Tensor:
+ """Merge a chunked/blocked tensors into a 2D matrix
+
+ Args:
+ x (Tensor): Tensor of shape [h1, w1, h2, w2, ...., hk, wk] if complex=False; [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True
+ complex (bool, optional): True if the tensor x has a last dimension with size 2 for real/imag representation. Defaults to False.
+
+ Returns:
+ Tensor: [h1*h2*...*hk, w1*w2*...*wk] or [h1*h2*...*hk, w1*w2*...*wk, 2]
+ """
+ if isinstance(x, torch.Tensor):
+ permute = torch.permute
+ elif isinstance(x, np.ndarray):
+ permute = np.transpose
+ else:
+ raise NotImplementedError
+
+ if not complex:
+ dim = len(x.shape)
+ x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2)))
+ x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1)
+ else:
+ dim = len(x.shape) - 1
+        x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2)) + [dim])
+ x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1, 2)
+
+ return x
+
+
+def partition_chunks(
+ x: Tensor, out_shape: int | Tuple[int, ...], complex: bool = False
+) -> Tensor:
+ """Partition a tensor into square chunks, similar to Rearrange in einops
+
+ Args:
+ x (Tensor): 2D tensor of shape [h1*h2*...*hk, w1*w2*...*wk] or 3D tensor of shape [h1*h2*...*hk, w1*w2*...*wk, 2] if complex=True
+ out_shape (Tuple[int]): output blocked shape (h1, w1, h2, w2, ...); Do not include the last dimension even if complex=True
+ complex (bool, optional): whether x is complex tensor. Defaults to False.
+
+ Returns:
+ [Tensor]: Tensor of shape [h1, w1, h2, w2, ...., hk, wk] or [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True
+ """
+ if complex:
+ assert len(x.shape) == 3
+ x_shape = (np.prod(out_shape[::2]), np.prod(out_shape[1::2]))
+ if isinstance(x, torch.Tensor):
+ permute = torch.permute
+ pad_fn = lambda x, padding: torch.nn.functional.pad(x[None, None], padding)[
+ 0, 0
+ ]
+ is_tensor = True
+ elif isinstance(x, np.ndarray):
+ permute = np.transpose
+ pad_fn = np.pad
+ is_tensor = False
+ else:
+ raise NotImplementedError
+
+ if x_shape != x.shape[:2]:
+ ## if x cannot be partitioned into out_shape, we need to pad it
+ if is_tensor:
+ ## torch from the last dim
+ padding = (0, x_shape[1] - x.shape[1], 0, x_shape[0] - x.shape[0])
+ if complex:
+ padding = (0, 0) + padding
+ else:
+ ## np from the first dim
+ padding = ((0, x_shape[0] - x.shape[0]), (0, x_shape[1] - x.shape[1]))
+ if complex:
+ padding = padding + (0, 0)
+
+ x = pad_fn(x, padding)
+
+ in_shape = list(out_shape[::2]) + list(out_shape[1::2])
+ permute_shape = np.arange(len(out_shape)).reshape(2, -1).T.reshape(-1).tolist()
+ if complex:
+ in_shape.append(2)
+ permute_shape.append(len(permute_shape))
+ x = x.reshape(in_shape) # [h1, h2, ..., hk, w1, w2, ..., wk]
+
+ x = permute(x, permute_shape) # [h1, w1, h2, w2, ...., hk, wk]
+
+ return x
+
+
+def clip_by_std(x: Tensor, n_std_neg: float = 3.0, n_std_pos: float = 3.0) -> Tensor:
+ if isinstance(x, np.ndarray):
+ std = np.std(x)
+ mean = np.mean(x)
+ out = np.clip(x, a_min=mean - n_std_neg * std, a_max=mean + n_std_pos * std)
+ elif isinstance(x, torch.Tensor):
+ std = x.data.std()
+ mean = x.data.mean()
+ out = x.clamp(min=mean - n_std_neg * std, max=mean + n_std_pos * std)
+ else:
+ raise NotImplementedError
+ return out
+
+
+def percentile(t: Tensor, q: float) -> Tensor:
+ """
+ Return the ``q``-th percentile of the flattened input tensor's data.
+
+ CAUTION:
+ * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
+ * Values are not interpolated, which corresponds to
+ ``numpy.percentile(..., interpolation="nearest")``.
+
+ :param t: Input tensor.
+ :param q: Percentile to compute, which must be between 0 and 100 inclusive.
+ :return: Resulting value (scalar).
+ """
+ # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
+ # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
+ # so that ``round()`` returns an integer, even if q is a np.float32.
+ if isinstance(t, torch.Tensor):
+ k = 1 + round(0.01 * float(q) * (t.numel() - 1))
+ result = t.view(-1).kthvalue(k).values.item()
+ elif isinstance(t, np.ndarray):
+ result = np.percentile(t, q=q)
+ else:
+ raise NotImplementedError
+ return result
+
+
+def gen_boolean_mask_cpu(size: _size, true_prob: float) -> np.ndarray:
+ assert 0 <= true_prob <= 1, "[E] Wrong probability for True"
+ return np.random.choice(a=[False, True], size=size, p=[1 - true_prob, true_prob])
+
+
+def gen_boolean_mask(
+ size: _size,
+ true_prob: float,
+ random_state: Optional[int] = None,
+ device: Device = torch.device("cuda"),
+) -> Tensor:
+ assert 0 <= true_prob <= 1, "[E] Wrong probability for True"
+ if true_prob > 1 - 1e-9:
+ return torch.ones(size, device=device, dtype=torch.bool)
+ elif true_prob < 1e-9:
+ return torch.zeros(size, device=device, dtype=torch.bool)
+ if random_state is not None:
+ with torch.random.fork_rng():
+ torch.random.manual_seed(random_state)
+ return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(
+ true_prob
+ )
+ else:
+ return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(true_prob)
+
+
+def fftshift_cpu(
+ x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None
+) -> Union[Tensor, np.ndarray]:
+ if isinstance(x, np.ndarray):
+ if dim is None:
+ if batched:
+ dim = tuple(range(1, len(x.shape)))
+ else:
+ dim = tuple(range(0, len(x.shape)))
+ out = np.fft.fftshift(x, axes=dim)
+ elif isinstance(x, torch.Tensor):
+ device = x.device
+ x = x.cpu().detach().numpy()
+ if dim is None:
+ if batched:
+ dim = tuple(range(1, len(x.shape)))
+ else:
+ dim = tuple(range(0, len(x.shape)))
+ out = np.fft.fftshift(x, axes=dim)
+ out = torch.from_numpy(out).to(device)
+ return out
+
+
+def ifftshift_cpu(
+ x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None
+) -> Union[Tensor, np.ndarray]:
+ if isinstance(x, np.ndarray):
+ if dim is None:
+ if batched:
+ dim = tuple(range(1, len(x.shape)))
+ else:
+ dim = tuple(range(0, len(x.shape)))
+ out = np.fft.ifftshift(x, axes=dim)
+ elif isinstance(x, torch.Tensor):
+ device = x.device
+ x = x.cpu().detach().numpy()
+ if dim is None:
+ if batched:
+ dim = tuple(range(1, len(x.shape)))
+ else:
+ dim = tuple(range(0, len(x.shape)))
+ out = np.fft.ifftshift(x, axes=dim)
+ out = torch.from_numpy(out).to(device)
+ return out
+
+
+def gen_gaussian_noise(
+ W: Union[Tensor, np.ndarray],
+ noise_mean: float = 0.0,
+ noise_std: float = 0.002,
+ trunc_range: Tuple = (),
+ random_state: Optional[int] = None,
+) -> Union[Tensor, np.ndarray]:
+ if random_state is not None:
+ set_torch_deterministic(random_state)
+ if isinstance(W, np.ndarray):
+ if not trunc_range:
+ noises = np.random.normal(noise_mean, noise_std, W.shape)
+ else:
+ a = (trunc_range[0] - noise_mean) / noise_std
+ b = (trunc_range[1] - noise_mean) / noise_std
+ noises = truncnorm.rvs(
+ a, b, loc=noise_mean, scale=noise_std, size=W.shape, random_state=None
+ )
+ elif isinstance(W, torch.Tensor):
+ if not trunc_range:
+ noises = torch.zeros_like(W).normal_(mean=noise_mean, std=noise_std)
+ else:
+ size = W.shape
+ tmp = W.new_empty(size + (4,)).normal_()
+ a = (trunc_range[0] - noise_mean) / noise_std
+ b = (trunc_range[1] - noise_mean) / noise_std
+ valid = (tmp < b) & (tmp > a)
+ ind = valid.max(-1, keepdim=True)[1]
+ noises = tmp.gather(-1, ind).squeeze(-1).mul_(noise_std).add_(noise_mean)
+ # noises = truncated_normal(W, mean=noise_mean, std=noise_std, a=trunc_range[0], b=trunc_range[1])
+ else:
+ assert 0, logging.error(
+ f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}"
+ )
+ return noises
+
+
+def gen_gaussian_filter2d_cpu(size: int = 3, std: float = 0.286) -> np.ndarray:
+ assert (
+ size % 2 == 1
+ ), f"Gaussian filter can only be odd size, but size={size} is given."
+ ax = np.linspace(-(size - 1) / 2.0, (size - 1) / 2.0, size)
+ xx, yy = np.meshgrid(ax, ax)
+ kernel = np.exp(-0.5 / np.square(std) * (np.square(xx) + np.square(yy)))
+ kernel = kernel / np.sum(kernel)
+ kernel[size // 2, size // 2] = 1
+ return kernel
+
+
+def gen_gaussian_filter2d(
+ size: int = 3,
+ std: float = 0.286,
+ center_one: bool = True,
+ device: Device = torch.device("cuda"),
+) -> Tensor:
+ assert (
+ size % 2 == 1
+ ), f"Gaussian filter can only be odd size, but size={size} is given."
+ if std > 1e-8:
+ ax = torch.linspace(
+ -(size - 1) / 2.0,
+ (size - 1) / 2.0,
+ size,
+ dtype=torch.float32,
+ device=device,
+ )
+ xx, yy = torch.meshgrid(ax, ax)
+ kernel = torch.exp(-0.5 / (std**2) * (xx.square() + yy.square()))
+ kernel = kernel.div_(kernel.sum())
+ if center_one:
+ kernel[size // 2, size // 2] = 1
+ else:
+ kernel = torch.zeros(size, size, dtype=torch.float32, device=device)
+ kernel[size // 2, size // 2] = 1
+
+ return kernel
+
+
+def add_gaussian_noise(
+ W: Union[Tensor, np.ndarray],
+ noise_mean: float = 0,
+ noise_std: float = 0.002,
+ trunc_range: Tuple = (),
+ random_state: Optional[int] = None,
+) -> Union[Tensor, np.ndarray]:
+ noises = gen_gaussian_noise(
+ W,
+ noise_mean=noise_mean,
+ noise_std=noise_std,
+ trunc_range=trunc_range,
+ random_state=random_state,
+ )
+ output = W + noises
+ return output
+
+
+def add_gaussian_noise_(
+ W: Union[Tensor, np.ndarray],
+ noise_mean: float = 0,
+ noise_std: float = 0.002,
+ trunc_range: Tuple = (),
+ random_state: Optional[int] = None,
+) -> Union[Tensor, np.ndarray]:
+ noises = gen_gaussian_noise(
+ W,
+ noise_mean=noise_mean,
+ noise_std=noise_std,
+ trunc_range=trunc_range,
+ random_state=random_state,
+ )
+ if isinstance(W, np.ndarray):
+ W += noises
+ elif isinstance(W, torch.Tensor):
+ W.data += noises
+ else:
+ assert 0, logging.error(
+ f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}"
+ )
+ return W
+
+
+def add_gaussian_noise_cpu(
+ W: Union[Tensor, np.ndarray],
+ noise_mean: float = 0,
+ noise_std: float = 0.002,
+ trunc_range: Tuple = (),
+) -> Union[Tensor, np.ndarray]:
+ if isinstance(W, np.ndarray):
+ W_numpy = W.copy().astype(np.float64)
+ elif isinstance(W, torch.Tensor):
+ W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
+ else:
+ assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
+ if not trunc_range:
+ noises = np.random.normal(noise_mean, noise_std, W_numpy.shape)
+ else:
+ a = (trunc_range[0] - noise_mean) / noise_std
+ b = (trunc_range[1] - noise_mean) / noise_std
+ noises = truncnorm.rvs(
+ a, b, loc=noise_mean, scale=noise_std, size=W_numpy.shape, random_state=None
+ )
+ return W_numpy + noises
+
+
+def circulant_multiply(c: Tensor, x: Tensor) -> Tensor:
+ """Multiply circulant matrix with first column c by x
+ Parameters:
+ c: (n, )
+ x: (batch_size, n) or (n, )
+ Return:
+ prod: (batch_size, n) or (n, )
+ """
+ return torch.irfft(
+ complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1],)
+ )
+
+
+def calc_diagonal_hessian(weight_dict, loss, model):
+ model.zero_grad()
+ hessian_dict = {}
+ for name, weight in weight_dict.items():
+ first_gradient = grad(loss, weight, create_graph=True)[0]
+ second_gradient = grad(first_gradient.sum(), weight, create_graph=True)[0]
+ hessian_dict[name] = second_gradient.clone()
+ model.zero_grad()
+ return hessian_dict
+
+
+def calc_jacobian(
+ weight_dict: Dict[str, Tensor], loss: Callable, model: nn.Module
+) -> Dict[str, Tensor]:
+ model.zero_grad()
+ jacobian_dict = {}
+ for name, weight in weight_dict.items():
+ first_gradient = grad(loss, weight, create_graph=True)[0]
+ jacobian_dict[name] = first_gradient.clone()
+ model.zero_grad()
+ return jacobian_dict
+
+
+@lru_cache(maxsize=4)
+def _polynomial_order_base(order: int, device: Device) -> Tensor:
+ return torch.arange(order - 1, -1, -1, device=device)
+
+
+def polynomial(x: Tensor | np.ndarray, coeff: Tensor | np.ndarray) -> Tensor:
+ """calculate polynomial function of x given coefficient coeff
+
+ Args:
+ x (Tensor): input tensor
+ coeff (Tensor): Tensor of shape [n], where n is the degree of polynomial. Orders: [n, n-1, ..., 2, 1, constant]
+
+ Returns:
+ Tensor: output tensor coeff[0]*x^n + coeff[1]*x^{n-1} + ... + coeff[n-1]*x + coeff[n]
+ """
+ # xs = [x]
+ # for i in range(2, coeff.size(0)):
+ # xs.append(xs[-1]*x)
+ # xs.reverse()
+ # x = torch.stack(xs, dim=-1)
+
+ # Deprecated implementation
+ # x = torch.stack([x**i for i in range(coeff.size(0) - 1, 0, -1)], dim=-1)
+ # out = (x * coeff[:-1]).sum(dim=-1) + coeff[-1].data.item()
+ # return out
+
+ ### x^n, x^{n-1}, ..., x^2, x, 1
+ order = coeff.shape[0] # n+1
+ if isinstance(x, Tensor):
+ ## torch from highest order to constant
+ x = x[..., None].expand([-1] * x.dim() + [order])
+ order_base = _polynomial_order_base(order, x.device)
+ return x.pow(order_base).matmul(coeff)
+ elif isinstance(x, np.ndarray):
+ ## numpy polyval from constant to higher order
+ return np.polynomial.polynomial.polyval(x, coeff[::-1])
+ else:
+ raise NotImplementedError
+
+
+def gaussian(x: Tensor, coeff: Tensor) -> Tensor:
+ # coeff : [n, 3], includes a, b, c
+ ## a * exp(-((x-b)/c)^2) + ...
+ size = x.size()
+ x = x.view(-1).unsqueeze(0)
+ x = (
+ (coeff[:, 0:1] * torch.exp(-((x - coeff[:, 1:2]) / coeff[:, 2:3]).square()))
+ .sum(dim=0)
+ .view(size)
+ )
+ return x
+
+
+def lowrank_decompose(
+ x: Tensor,
+ r: int,
+ u_ortho: bool = False,
+ out_u: Optional[Tensor] = None,
+ out_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+ """low rank decomposition on x. x ~ uv.
+
+ Args:
+        x (Tensor): tensor to decompose
+ r (int): rank
+ u_ortho (bool, optional): whether u is orthogonal matrix. Defaults to False.
+ out_u (Optional[Tensor], optional): output buffer for u. Defaults to None.
+ out_v (Optional[Tensor], optional): output buffer for v. Defaults to None.
+
+ Returns:
+ Tuple[Tensor, Tensor]: [description]
+ """
+ ### x [..., m, n]
+ # r rank
+ u, s, v = x.data.svd(some=True)
+ v = v.transpose(-2, -1).contiguous()
+ u = u[..., :, :r]
+ s = s[..., :r]
+ v = v[..., :r, :]
+ if u_ortho == False:
+ u.mul_(s.unsqueeze(-2))
+ else:
+ v.mul_(s.unsqueeze(-1))
+ if out_u is not None:
+ out_u.data.copy_(u)
+ if out_v is not None:
+ out_v.data.copy_(v)
+ return u, v
+
+
+def get_conv2d_flops(
+ input_shape: _size,
+ conv_filter: _size,
+ stride: _pair = (1, 1),
+ padding: _pair = (1, 1),
+) -> float:
+ # input_shape = (4, 3,300,300) # Format:(batch, channels, rows,cols)
+ # conv_filter = (64,3,3,3) # Format: (num_filters, channels, rows, cols)
+ # stride = (1, 1) in (height, width)
+ # padding = (1, 1) in (height, width)
+ if type(stride) not in {list, tuple}:
+ stride = [stride, stride]
+ if type(padding) not in {list, tuple}:
+ padding = [padding, padding]
+ n = conv_filter[1] * conv_filter[2] * conv_filter[3] # vector_length
+    # general definition for number of flops (n: multiplications and n-1: additions)
+ flops_per_instance = n + 1
+
+ num_instances_per_filter = (
+ (input_shape[2] - conv_filter[2] + 2 * padding[0]) / stride[0]
+ ) + 1 # for rows
+ # multiplying with cols
+ num_instances_per_filter *= (
+ (input_shape[3] - conv_filter[3] + 2 * padding[1]) / stride[1]
+ ) + 1
+
+ flops_per_filter = num_instances_per_filter * flops_per_instance
+    # multiply with number of filters and batch
+ total_flops_per_layer = flops_per_filter * conv_filter[0] * input_shape[0]
+ return total_flops_per_layer
+
+
class Interp1d(torch.autograd.Function):
    """Autograd-aware batched linear 1-D interpolation (adapted from torchinterp1d)."""

    @staticmethod
    def forward(ctx, x, y, xnew, out=None):
        """
        Batched Linear 1D interpolation on the GPU for Pytorch.
        This function returns interpolated values of a set of 1-D functions at
        the desired query points `xnew`. Any point exceeding the border of
        [xmin, xmax] will be filled with 0 and receives no gradient.
        This function is working similarly to Matlab™ or scipy functions with
        the `linear` interpolation mode on, except that it parallelises over
        any number of desired interpolation problems.
        The code will run on GPU if all the tensors provided are on a cuda
        device.
        https://github.com/aliutkus/torchinterp1d

        Parameters
        ----------
        x : (N, ) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values.
        y : (N,) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values. The length of `y` along its
            last dimension must be the same as that of `x`
        xnew : (P,) or (D, P) Pytorch Tensor
            A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
            _both_ `x` and `y` are 1-D. Otherwise, its length along the first
            dimension must be the same as that of whichever `x` and `y` is 2-D.
        out : Pytorch Tensor, same shape as `xnew`
            Tensor for the output. If None: allocated automatically.

        """
        # making the vectors at least 2D
        is_flat = {}
        require_grad = {}
        v = {}
        device = []
        # eps guards against division by zero when computing segment slopes below
        eps = torch.finfo(y.dtype).eps
        for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
            assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D."
            if len(vec.shape) == 1:
                v[name] = vec[None, :]
            else:
                v[name] = vec
            is_flat[name] = v[name].shape[0] == 1
            require_grad[name] = vec.requires_grad
            device = list(set(device + [str(vec.device)]))
        assert len(device) == 1, "All parameters must be on the same device."
        device = device[0]

        # Checking for the dimensions
        assert v["x"].shape[1] == v["y"].shape[1] and (
            v["x"].shape[0] == v["y"].shape[0]
            or v["x"].shape[0] == 1
            or v["y"].shape[0] == 1
        ), (
            "x and y must have the same number of columns, and either "
            "the same number of row or one of them having only one "
            "row."
        )

        reshaped_xnew = False
        if (
            (v["x"].shape[0] == 1)
            and (v["y"].shape[0] == 1)
            and (v["xnew"].shape[0] > 1)
        ):
            # if there is only one row for both x and y, there is no need to
            # loop over the rows of xnew because they will all have to face the
            # same interpolation problem. We should just stack them together to
            # call interp1d and put them back in place afterwards.
            original_xnew_shape = v["xnew"].shape
            v["xnew"] = v["xnew"].contiguous().view(1, -1)
            reshaped_xnew = True

        # identify the dimensions of output and check if the one provided is ok
        D = max(v["x"].shape[0], v["xnew"].shape[0])
        shape_ynew = (D, v["xnew"].shape[-1])
        if out is not None:
            if out.numel() != shape_ynew[0] * shape_ynew[1]:
                # The output provided is of incorrect shape.
                # Going for a new one
                out = None
            else:
                ynew = out.reshape(shape_ynew)
        if out is None:
            ynew = torch.zeros(*shape_ynew, device=device)

        # moving everything to the desired device in case it was not there
        # already (not handling the case things do not fit entirely, user will
        # do it if required.)
        for name in v:
            v[name] = v[name].to(device)

        # calling searchsorted on the x values.
        # NOTE: this reuses ynew's storage shape to allocate an int64 buffer
        # that searchsorted will fill via its `out=` argument below.
        ind = ynew.long()

        # expanding xnew to match the number of rows of x in case only one xnew is
        # provided
        if v["xnew"].shape[0] == 1:
            v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1)

        # the squeeze is because torch.searchsorted does accept either a nd with
        # matching shapes for x and xnew or a 1d vector for x. Here we would
        # have (1,len) for x sometimes
        torch.searchsorted(
            v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind
        )

        # the `-1` is because searchsorted looks for the index where the values
        # must be inserted to preserve order. And we want the index of the
        # preceeding value.
        ind -= 1
        # we clamp the index, because the number of intervals is x.shape-1,
        # and the left neighbour should hence be at most number of intervals
        # -1, i.e. number of columns in x -2
        ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1)

        # helper function to select stuff according to the found indices.
        def sel(name):
            if is_flat[name]:
                return v[name].contiguous().view(-1)[ind]
            return torch.gather(v[name], 1, ind)

        # activating gradient storing for everything now
        enable_grad = False
        saved_inputs = []
        for name in ["x", "y", "xnew"]:
            if require_grad[name]:
                enable_grad = True
                saved_inputs += [v[name]]
            else:
                saved_inputs += [
                    None,
                ]
        # assuming x are sorted in the dimension 1, computing the slopes for
        # the segments
        is_flat["slopes"] = is_flat["x"]
        # now we have found the indices of the neighbors, we start building the
        # output. Hence, we start also activating gradient tracking
        with torch.enable_grad() if enable_grad else contextlib.suppress():
            v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / (
                eps + (v["x"][:, 1:] - v["x"][:, :-1])
            )

            # now build the linear interpolation
            ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x"))

            mask = (v["xnew"] > v["x"][:, -1:]) | (
                v["xnew"] < v["x"][:, :1]
            )  # exceed left/right border
            ynew = ynew.masked_fill(mask, 0)

            if reshaped_xnew:
                ynew = ynew.view(original_xnew_shape)

        ctx.save_for_backward(ynew, *saved_inputs)
        return ynew

    @staticmethod
    def backward(ctx, grad_out):
        # saved_tensors[0] is the forward output (still attached to the graph
        # built under enable_grad); the rest are the inputs that required grad.
        inputs = ctx.saved_tensors[1:]
        gradients = torch.autograd.grad(
            ctx.saved_tensors[0],
            [i for i in inputs if i is not None],
            grad_out,
            retain_graph=True,
        )
        # NOTE(review): 5 slots while forward takes 4 tensor args after ctx
        # (x, y, xnew, out); the extra trailing None appears harmless but
        # should be confirmed against torch's gradient-count check.
        result = [
            None,
        ] * 5
        pos = 0
        for index in range(len(inputs)):
            if inputs[index] is not None:
                result[index] = gradients[pos]
                pos += 1
        return (*result,)
+
+
def interp1d(x: Tensor, y: Tensor, xnew: Tensor, out: Tensor | None = None) -> Tensor:
    """Batched 1-D linear interpolation, the torch counterpart of ``numpy.interp``.

    Args:
        x (Tensor): x coordinates of the sample points.
        y (Tensor): y coordinates of the sample points.
        xnew (Tensor): query x coordinates to interpolate at.
        out (Tensor, optional): pre-allocated output tensor. Defaults to None.

    Returns:
        Tensor: interpolated y values; queries outside [x.min, x.max] yield 0.
    """
    return Interp1d.apply(x, y, xnew, out)
diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py
new file mode 100644
index 000000000..1d4f2990d
--- /dev/null
+++ b/src/chop/nn/optical/functional/general.py
@@ -0,0 +1,411 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-06 01:55:29
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-06 01:55:30
+"""
+
+import os
+import argparse
+import json
+import logging
+import logging.handlers
+import time
+from collections import OrderedDict
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+
+
+__all__ = [
+ "ensure_dir",
+ "read_json",
+ "write_json",
+ "profile",
+ "print_stat",
+ "Timer",
+ "TimerCtx",
+ "TorchTracemalloc",
+ "fullprint",
+ "setup_default_logging",
+ "Logger",
+ "logger",
+ "get_logger",
+ "ArgParser",
+ "disable_tf_warning",
+ "AverageMeter",
+]
+
+
def ensure_dir(dirname, exist_ok: bool = True):
    """Create directory `dirname` (including parents) if it does not exist yet."""
    path = Path(dirname)
    if not path.is_dir():
        path.mkdir(parents=True, exist_ok=exist_ok)
+
+
def read_json(fname):
    """Load a JSON file, preserving key order via OrderedDict."""
    with open(fname, "rt") as fp:
        return json.load(fp, object_hook=OrderedDict)
+
+
def write_json(content, fname):
    """Serialize `content` to `fname` as indented JSON, keeping key order."""
    with open(fname, "wt") as fp:
        json.dump(content, fp, indent=4, sort_keys=False)
+
+
def profile(func=None, timer=True):
    """Decorator that reports a function's wall-clock runtime in milliseconds.

    Usable both bare (``@profile``) and parameterized (``@profile(timer=False)``).

    Args:
        func: the decorated function (None when used with keyword arguments).
        timer: when False, the wrapped call is executed without timing.

    Returns:
        The wrapped function (or a partially-applied decorator).
    """
    from functools import wraps, partial
    import time

    if func is None:  # fix: was `func == None`; identity check is the idiom
        return partial(profile, timer=timer)

    @wraps(func)
    def wrapper(*args, **kw):
        if not timer:
            return func(*args, **kw)
        # perf_counter is monotonic, unlike time.time(), so intervals are reliable
        start = time.perf_counter()
        res = func(*args, **kw)
        elapsed = time.perf_counter() - start
        print("[I] <%s> runtime: %.3f ms" % (func.__name__, elapsed * 1000))
        return res

    return wrapper
+
+
def print_stat(x, message="", verbose=True):
    """Print min/max/mean/std of a torch tensor or numpy array, prefixed by `message`.

    No-op when `verbose` is False or `x` is neither a Tensor nor an ndarray.
    """
    if not verbose:
        return
    if isinstance(x, torch.Tensor):
        if torch.is_complex(x):
            # reduce over the real view (..., 2) of a complex tensor
            x = torch.view_as_real(x)
        print(
            message + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}"
        )
    elif isinstance(x, np.ndarray):
        print(
            message + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}"
        )
+
+
class Timer(object):
    """Simple wall-clock stopwatch based on datetime."""

    def __init__(self):
        self.cache = datetime.now()

    def check(self):
        """Return seconds elapsed since the last check/reset and restart the timer."""
        now = datetime.now()
        elapsed = (now - self.cache).total_seconds()
        self.cache = now
        return elapsed

    def reset(self):
        """Restart the timer without reporting elapsed time."""
        self.cache = datetime.now()
+
+
class TimerCtx:
    """Context manager that stores elapsed wall time in `self.interval` (seconds)."""

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
+
+
class TorchTracemalloc(object):
    """Context manager reporting CUDA memory usage (in MB) for the wrapped region.

    On exit sets: `begin`/`end`/`peak` (absolute MB) and `used`/`peaked`
    (deltas relative to entry, MB). Requires a CUDA device.

    Args:
        verbose: when True, print the deltas and absolute figures on exit.
    """

    def __init__(self, verbose: bool = False) -> None:
        super().__init__()
        self.verbose = verbose

    def _b2mb(self, x):
        """Convert bytes to megabytes."""
        return x / 2 ** 20

    def __enter__(self):
        self.begin = self._b2mb(torch.cuda.memory_allocated())
        # Fix: reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
        # is the supported API for zeroing the peak gauge.
        torch.cuda.reset_peak_memory_stats()
        return self

    def __exit__(self, *exc):
        self.end = self._b2mb(torch.cuda.memory_allocated())
        self.peak = self._b2mb(torch.cuda.max_memory_allocated())
        self.used = self.end - self.begin
        self.peaked = self.peak - self.begin
        if self.verbose:
            print(f"Delta used/peaked {self.used:.2f} MB / {self.peaked:.2f} MB")
            print(f"Current used/peaked {self.end:.2f} MB / {self.peak:.2f} MB")
+
+
class fullprint:
    """Context manager that prints numpy arrays in full (no truncation)."""

    def __init__(self, **kwargs):
        """Accepts any numpy print options, e.g. linewidth=75, precision=8."""
        kwargs.setdefault("threshold", np.inf)
        self.opt = kwargs

    def __enter__(self):
        # remember the current options so they can be restored on exit
        self._opt = np.get_printoptions()
        np.set_printoptions(**self.opt)

    def __exit__(self, type, value, traceback):
        np.set_printoptions(**self._opt)
+
+
class CustomFormatter(logging.Formatter):
    """Logging formatter that colors records by severity (ANSI escapes)."""

    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    green = "\x1b[32;21m"
    reset = "\x1b[0m"
    # Fix: the class attribute was named `format`, shadowing Formatter.format();
    # renamed so both the string and the method remain accessible.
    _fmt = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"

    # Pre-built per-level formatters: avoids allocating a logging.Formatter on
    # every record, which the original did.
    FORMATS = {
        logging.DEBUG: logging.Formatter(grey + _fmt + reset),
        logging.INFO: logging.Formatter(grey + _fmt + reset),
        logging.WARNING: logging.Formatter(yellow + _fmt + reset),
        logging.ERROR: logging.Formatter(red + _fmt + reset),
        logging.CRITICAL: logging.Formatter(bold_red + _fmt + reset),
    }

    def format(self, record):
        """Format `record` with the color matching its level."""
        formatter = self.FORMATS.get(record.levelno)
        if formatter is None:
            # unknown level: fall back to the uncolored base format
            formatter = logging.Formatter(self._fmt)
        return formatter.format(record)
+
+
def setup_default_logging(default_level=logging.INFO, default_file_level=logging.INFO, log_path=""):
    """Attach a colored console handler (and optional rotating file handler) to the root logger.

    Note: handlers are appended on every call, so calling this repeatedly
    duplicates log output.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(CustomFormatter())
    logging.root.addHandler(stream_handler)
    logging.root.setLevel(default_level)
    if not log_path:
        return
    # rotate at 2 MB, keeping up to 3 backups
    rotating = logging.handlers.RotatingFileHandler(
        log_path, maxBytes=2 * 1024 ** 2, backupCount=3
    )
    rotating.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
        )
    )
    rotating.setLevel(default_file_level)
    logging.root.addHandler(rotating)
+
+
class Logger(object):
    """Small wrapper around a dedicated "my_logger" logging.Logger.

    Emits to the console and/or a log file with colored formatting; at least
    one of the two sinks must be enabled.
    """

    def __init__(self, console=True, logfile=None, console_level=logging.INFO, logfile_level=logging.INFO):
        super().__init__()
        self.logfile = logfile
        self.console_level = console_level
        # NOTE(review): attribute name is misspelled ("logifle"); kept as-is
        # because external code may already read it.
        self.logifle_level = logfile_level
        assert (
            console == True or logfile is not None
        ), "At least enable one from console or logfile for Logger"
        # Step 1: create the underlying logger
        self.logger = logging.getLogger("my_logger")
        self.logger.setLevel(logging.INFO)  # master on/off switch for the log level
        self.logger.propagate = False

        # formatter = logging.Formatter(
        #     "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
        formatter = CustomFormatter()

        # Step 2: create a handler that writes to the console
        if console:
            ch = logging.StreamHandler()
            ch.setLevel(self.console_level)  # level threshold for console output
            ch.setFormatter(formatter)
            self.logger.addHandler(ch)
        if self.logfile is not None:
            fh = logging.FileHandler(self.logfile, mode="w")
            fh.setLevel(self.logifle_level)  # level threshold for file output
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

    def debug(self, message):
        self.logger.debug(message)

    def info(self, message):
        self.logger.info(message)

    def warning(self, message):
        self.logger.warning(message)

    def error(self, message):
        self.logger.error(message)

    def critical(self, message):
        self.logger.critical(message)
+
+
def get_logger(name="default", default_level=logging.INFO, default_file_level=logging.INFO, log_path=""):
    """Configure root logging (console + optional file) and return the named logger."""
    setup_default_logging(
        default_level=default_level,
        default_file_level=default_file_level,
        log_path=log_path,
    )
    return logging.getLogger(name)
+
+
# Module-level convenience logger; note that importing this module attaches a
# console handler to the root logger as a side effect of get_logger().
logger = get_logger()
+
+
class ArgParser(object):
    """Thin argparse wrapper that can seed arguments from, and dump them to, JSON.

    Args:
        load_json: optional path to a JSON file whose values pre-populate the
            parsed namespace (CLI arguments are then ignored).
        save_json: optional default path used by dump_args().
    """

    def __init__(self, load_json=None, save_json=None):
        super().__init__()
        self.load_json = load_json
        self.save_json = save_json
        self.args = None
        self.parser = argparse.ArgumentParser("Argument Parser")

    def add_arg(self, *args, **keywords):
        """Forward to argparse.ArgumentParser.add_argument."""
        self.parser.add_argument(*args, **keywords)

    def parse_args(self):
        """Parse arguments, seeding defaults from `load_json` when configured.

        Returns:
            argparse.Namespace: the parsed arguments (also stored on self.args).

        Raises:
            FileNotFoundError: when `load_json` points to a missing file.
        """
        if self.load_json is not None:
            # Fix: the original used `assert cond, logging.error(...)` — the check
            # disappears under `python -O` and the assert message is always None
            # (logging.error returns None). Raise a real exception instead.
            if not os.path.exists(self.load_json):
                raise FileNotFoundError(
                    f"Configuration JSON {self.load_json} not found"
                )
            config = read_json(self.load_json)  # renamed: was shadowing the json module
            namespace = argparse.Namespace()
            namespace.__dict__.update(config)
            self.args = self.parser.parse_args(args=[], namespace=namespace)
        else:
            self.args = self.parser.parse_args()
        return self.args

    def print_args(self):
        """Pretty-print the parsed arguments to stdout."""
        print("Arguments:")
        for key, value in vars(self.args).items():
            print(f"\t{key:30}{str(value):20}")
        print("\n")

    def dump_args(self, json_file=None):
        """Write the parsed arguments as JSON to `json_file` or `self.save_json`.

        Returns:
            bool: True on success, False when no target path is available.
            (Fix: the original returned None on success but False on failure,
            so truthiness checks on the result were inverted.)
        """
        if json_file is None:
            if self.save_json is None:
                logging.error("Skip dump configuration JSON. Please specify json_file")
                return False
            ensure_dir(os.path.dirname(self.save_json))
            logging.warning(f"Dump to the initialized JSON file {self.save_json}")
            write_json(vars(self.args), self.save_json)
            return True
        ensure_dir(os.path.dirname(json_file))
        logging.info(f"Dump to JSON file {json_file}")
        write_json(vars(self.args), json_file)
        return True
+
+
def disable_tf_warning():
    """Silence TensorFlow's C++ and Python-side warning chatter.

    NOTE(review): imports tensorflow at call time — calling this without TF
    installed raises ImportError; confirm callers only use it in TF contexts.
    """
    import os

    # 3 = suppress INFO, WARNING and ERROR messages from the TF C++ backend
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    import warnings

    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    import tensorflow as tf

    # tf.contrib only exists on TF 1.x; guard so TF 2.x does not break
    if hasattr(tf, "contrib") and type(tf.contrib) != type(tf):
        tf.contrib._warning = None
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # tf.logging.set_verbosity(tf.logging.ERROR)

    import logging

    logging.getLogger("tensorflow").setLevel(logging.ERROR)
+
+
class Meter(object):
    """Base class for Meters.

    Subclasses accumulate a metric across updates; `state_dict` /
    `load_state_dict` allow checkpointing, `reset` clears the accumulator,
    and `smoothed_value` is the scalar reported in logs.
    """

    def __init__(self):
        pass

    def state_dict(self):
        # default: nothing to serialize
        return {}

    def load_state_dict(self, state_dict):
        # default: nothing to restore
        pass

    def reset(self):
        # subclasses must clear their accumulated state
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError
+
+
def safe_round(number, ndigits):
    """Round `number` to `ndigits` when possible; otherwise return it unchanged.

    Handles plain numbers, 1-element torch tensors and 0-d numpy scalars.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    if torch is not None and torch.is_tensor(number) and number.numel() == 1:
        # unwrap the scalar tensor and retry
        return safe_round(number.item(), ndigits)
    if np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        # unwrap the numpy scalar and retry
        return safe_round(number.item(), ndigits)
    return number
+
+
def type_as(a, b):
    """Cast tensor `a` to tensor `b`'s dtype/device; non-tensor inputs pass through."""
    if not (torch.is_tensor(a) and torch.is_tensor(b)):
        return a
    return a.to(b)
+
+
class AverageMeter(Meter):
    """Tracks the latest value and a weighted running average."""

    def __init__(self, name: str, fmt: str = ":f", round: Optional[int] = None) -> None:
        self.name = name
        self.fmt = fmt
        self.round = round
        self.reset()

    def reset(self):
        """Forget all accumulated statistics."""
        self.val = None  # most recent update
        self.sum = 0  # weighted sum of all updates
        self.count = 0  # total weight seen so far
        self.avg = 0

    def update(self, val, n=1):
        """Record `val` with weight `n`; `val=None` is silently ignored."""
        if val is None:
            return
        self.val = val
        if n > 0:
            self.sum = type_as(self.sum, val) + (val * n)
            self.count = type_as(self.count, n) + n
        self.avg = self.sum / self.count if self.count > 0 else self.val

    def state_dict(self):
        """Serializable snapshot; `avg` is recomputable and therefore omitted."""
        return {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        """Restore from `state_dict` (missing "round" defaults to None)."""
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        """Running average, optionally rounded to `self.round` digits."""
        value = self.avg
        if self.round is not None and value is not None:
            value = safe_round(value, self.round)
        return value

    def __str__(self) -> str:
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/functional/initializer.py
new file mode 100644
index 000000000..53002c398
--- /dev/null
+++ b/src/chop/nn/optical/functional/initializer.py
@@ -0,0 +1,152 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-06 01:57:16
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-06 01:57:18
+"""
+import numpy as np
+import torch
+
+__all__ = [
+ "quant_kaiming_uniform",
+ "quant_kaiming_uniform_",
+ "truncated_normal",
+ "truncated_normal_",
+ "morr_uniform_",
+ "morr_uniform",
+]
+
+
def quant_kaiming_uniform(w, nbit, beta=1.5):
    """Out-of-place quantization-aware Kaiming uniform init (https://arxiv.org/pdf/1802.04680.pdf).

    Draws from U(-L, L) with L = max(sqrt(6 / fan_in), beta * 2**(1 - nbit)),
    so the range never collapses below the quantization step. `w` is unchanged.
    """
    receptive_field = w[0, 0, ...].numel() if w.dim() > 2 else 1
    fan_in = w.size(1) * receptive_field
    quant_floor = beta * 2 ** (1 - nbit)
    bound = max(np.sqrt(6 / fan_in), quant_floor)
    return w.clone().uniform_(-bound, bound)
+
+
def quant_kaiming_uniform_(w, nbit, beta=1.5):
    """In-place quantization-aware Kaiming uniform init (https://arxiv.org/pdf/1802.04680.pdf).

    Returns (w, scale): `w` filled from U(-L, L) with
    L = max(sqrt(6 / fan_in), beta * 2**(1 - nbit)), and `scale` a
    power-of-two factor (clamped to >= 1) relating the quantization floor
    to the Kaiming bound.
    """
    receptive_field = w[0, 0, ...].numel() if w.dim() > 2 else 1
    fan_in = w.size(1) * receptive_field
    sigma = 2 ** (1 - nbit)
    bound = np.sqrt(6 / fan_in)
    quant_floor = beta * sigma
    scale = max(2 ** round(np.log2(quant_floor / bound)), 1.0)
    bound = max(bound, quant_floor)

    return torch.nn.init.uniform_(w, -bound, bound), scale
+
+
def truncated_normal(tensor, mean=0, std=1, a=-2, b=2):
    """Out-of-place truncated normal sampling in the style of TF's initializer.

    Draws 4 standard-normal candidates per element and keeps the first that
    falls inside the normalized interval (a, b), then rescales by `std` and
    shifts by `mean`. `tensor` only provides shape/dtype/device.
    """
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    lo = (a - mean) / std
    hi = (b - mean) / std
    inside = (candidates < hi) & (candidates > lo)
    # index of the first in-range candidate along the trailing dim
    pick = inside.max(-1, keepdim=True)[1]
    return candidates.gather(-1, pick).squeeze(-1).mul_(std).add_(mean)
+
+
def truncated_normal_(tensor, mean=0, std=1, a=-2, b=2):
    """In-place variant of `truncated_normal`; fills and returns `tensor`."""
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    lo = (a - mean) / std
    hi = (b - mean) / std
    inside = (candidates < hi) & (candidates > lo)
    # index of the first in-range candidate along the trailing dim
    pick = inside.max(-1, keepdim=True)[1]
    tensor.data.copy_(candidates.gather(-1, pick).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
+
+
def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
    """In-place uniform init for MORR-array tensor cores [SqueezeLight, Gu+, DATE'21].

    Only accounts for how `n_op` operands influence one MORR's output; vector
    length balancing is left to a learnable balancing factor.

    Args:
        tensor (torch.Tensor): weight tensor/parameter, filled in place.
        MORRConfig: MORR configuration (radius, effective_index,
            resonance_wavelength, bandwidth) from onnlib/model/layer/device/mrr.
        n_op (int): number of operands on one MORR.
        biased (bool): True -> weights in [0, L]; False -> [-L/2, L/2].
        gain (float): activation gain (ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2).

    Returns:
        torch.Tensor: `tensor`, filled in place.
    """
    # FWHM of the MORR resonance expressed as roundtrip phase.
    morr_fwhm = (
        -4
        * np.pi ** 2
        * MORRConfig.radius
        * MORRConfig.effective_index
        * (
            1 / MORRConfig.resonance_wavelength
            - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2)
        )
    )
    # Heuristic bound assuming 4*std(phi) = 3*FWHM with E[x]=0, D[x]=1, W ~ U[0, L].
    bound = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain
    lo, hi = (0, bound) if biased else (-bound / 2, bound / 2)
    return torch.nn.init.uniform_(tensor, lo, hi)
+
+
def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1):
    """Out-of-place uniform init for MORR-array tensor cores [SqueezeLight, Gu+, DATE'21].

    Same distribution as `morr_uniform_`, but returns a fresh tensor and
    leaves the input untouched.

    Args:
        tensor (torch.Tensor): template for shape/dtype/device.
        MORRConfig: MORR configuration (radius, effective_index,
            resonance_wavelength, bandwidth) from onnlib/model/layer/device/mrr.
        n_op (int): number of operands on one MORR.
        biased (bool): True -> weights in [0, L]; False -> [-L/2, L/2].
        gain (float): activation gain (ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2).

    Returns:
        torch.Tensor: a new tensor with the initialized values.
    """
    # FWHM of the MORR resonance expressed as roundtrip phase.
    morr_fwhm = (
        -4
        * np.pi ** 2
        * MORRConfig.radius
        * MORRConfig.effective_index
        * (
            1 / MORRConfig.resonance_wavelength
            - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2)
        )
    )
    # Heuristic bound assuming 4*std(phi) = 3*FWHM with E[x]=0, D[x]=1, W ~ U[0, L].
    bound = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain
    result = tensor.clone()
    lo, hi = (0, bound) if biased else (-bound / 2, bound / 2)
    return result.uniform_(lo, hi)
diff --git a/src/chop/nn/optical/functional/mrr.py b/src/chop/nn/optical/functional/mrr.py
new file mode 100644
index 000000000..53fd56fd9
--- /dev/null
+++ b/src/chop/nn/optical/functional/mrr.py
@@ -0,0 +1,112 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-07-18 00:03:04
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-07-18 00:03:05
+"""
+
+import numpy as np
+
+
+__all__ = [
+ "MORRConfig_20um_MQ",
+ "MRRConfig_5um_HQ",
+ "MRRConfig_5um_MQ",
+ "MRRConfig_5um_LQ",
+ "MORRConfig_10um_MQ",
+]
+
+
class MORRConfig_20um_MQ:
    """20 um-radius multi-operand ring resonator (MORR), medium quality factor."""

    attenuation_factor = 0.8578  # single-pass field attenuation `a`
    coupling_factor = 0.8985  # self-coupling coefficient
    radius = 20000  # nm
    group_index = 2.35316094
    effective_index = 2.35
    resonance_wavelength = 1554.252  # nm
    bandwidth = 0.67908  # nm (FWHM)
    quality_factor = 2288.7644639  # = resonance_wavelength / bandwidth
+
+
class MRRConfig_5um_HQ:
    """5 um-radius micro-ring resonator (MRR), high quality factor."""

    attenuation_factor = 0.987  # single-pass field attenuation `a`
    coupling_factor = 0.99  # self-coupling coefficient
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 0.2278  # nm (FWHM)
    quality_factor = 6754.780509  # = resonance_wavelength / bandwidth
+
+
class MRRConfig_5um_MQ:
    """5 um-radius micro-ring resonator (MRR), medium quality factor."""

    attenuation_factor = 0.925  # single-pass field attenuation `a`
    coupling_factor = 0.93  # self-coupling coefficient
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 1.5068  # nm (FWHM)
    quality_factor = 1021.1965755  # = resonance_wavelength / bandwidth
+
+
class MRRConfig_5um_LQ:
    """5 um-radius micro-ring resonator (MRR), low quality factor."""

    attenuation_factor = 0.845  # single-pass field attenuation `a`
    coupling_factor = 0.85  # self-coupling coefficient
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 2.522  # nm (FWHM)
    quality_factor = 610.1265  # = resonance_wavelength / bandwidth
+
+
class MORRConfig_10um_MQ:
    """10 um-radius multi-operand ring resonator (MORR), medium quality factor."""

    attenuation_factor = 0.8578  # single-pass field attenuation `a`
    coupling_factor = 0.8985  # self-coupling coefficient
    radius = 10000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 1.6702  # nm (FWHM)
    # NOTE(review): does not equal resonance_wavelength / bandwidth (~921);
    # confirm this value against its source.
    quality_factor = 1213.047
+
+
def plot_curve(config):
    """Plot an MRR's through-port transmission vs wavelength and phase shift.

    Saves `mrr_tr_wl.png` and `mrr_tr_ps.png` in the working directory and
    prints the first sample pair bracketing the half-maximum crossing.

    Args:
        config: an MRR/MORR config class with attenuation_factor,
            coupling_factor, radius, effective_index, resonance_wavelength.
    """
    import matplotlib.pyplot as plt

    lambda0 = config.resonance_wavelength
    lambda_vec = np.linspace(1546, lambda0, 9400)
    aa = config.attenuation_factor  # attenuation a

    t = config.coupling_factor  # self-coupling
    # r = np.sqrt(1 - t**2)  # cross coupling coef

    R = config.radius  # radius (nm)
    neff = config.effective_index  # effective refractive index
    phi = -4 * np.pi * np.pi * R * neff / lambda_vec

    phase_shift = np.linspace(phi[0], phi[-1], len(phi))
    phase_shift = phase_shift - np.min(phase_shift)
    print(phase_shift)
    # all-pass MRR through-port field transmission
    tr = (t - aa * np.exp(1j * phi)) / (1 - t * aa * np.exp(1j * phi))
    energy = abs(tr) ** 2
    print(energy)
    plt.figure()
    plt.plot(lambda_vec, energy)
    plt.savefig("mrr_tr_wl.png")
    plt.figure()
    plt.plot(phase_shift, energy)
    plt.savefig("mrr_tr_ps.png")

    # Report the first downward crossing of the half-maximum level.
    # Fix: the original called exit(1) here, which kills the whole interpreter
    # when this helper is invoked from another script or an interactive session.
    for i in range(len(energy) - 1):
        if energy[i] >= 0.5 and energy[i + 1] <= 0.5:
            print(i, i + 1)
            print(energy[i], energy[i + 1])
            print(lambda_vec[i], lambda_vec[i + 1])
            return
+
+
if __name__ == "__main__":
    # Ad-hoc visual check: plot transmission curves for the 5 um medium-Q config.
    plot_curve(MRRConfig_5um_MQ)
diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/functional/mrr_op.py
new file mode 100644
index 000000000..42952971e
--- /dev/null
+++ b/src/chop/nn/optical/functional/mrr_op.py
@@ -0,0 +1,404 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-07-18 00:01:34
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-07-18 00:01:36
+"""
+
+from .compute import (
+ complex_mult,
+ polar_to_complex,
+ polynomial,
+)
+import logging
+
+import numpy as np
+import torch
+
+torch._C._jit_set_profiling_executor(False)
+
+
+__all__ = [
+ "mrr_voltage_to_delta_lambda",
+ "mrr_tr_to_roundtrip_phase",
+ "mrr_roundtrip_phase_to_tr",
+ "mrr_roundtrip_phase_to_tr_fused",
+ "mrr_roundtrip_phase_to_tr_grad_fused",
+ "mrr_roundtrip_phase_to_tr_func",
+ "mrr_roundtrip_phase_to_out_phase",
+ "mrr_tr_to_out_phase",
+ "mrr_roundtrip_phase_to_tr_phase",
+ "mrr_roundtrip_phase_to_tr_phase_fused",
+ "mrr_modulator",
+ "mrr_filter",
+ "morr_filter",
+ "mrr_fwhm_to_ng",
+ "mrr_ng_to_fsr",
+ "mrr_finesse",
+]
+
+
def mrr_voltage_to_delta_lambda(v, alpha, k, gamma, n_g, lambda_0):
    """MRR resonance-wavelength drift caused by heater voltage.

    delta_lambda = delta_n_eff * lambda_0 / n_g, with
    delta_n_eff = gamma * k * alpha * v**2.

    Args:
        v: voltage (torch.Tensor or np.ndarray).
        alpha: voltage-squared to temperature-change coefficient.
        k: thermo-optic parameter.
        gamma: power to phase-shift coefficient.
        n_g: group index (typically 4 to 4.5).
        lambda_0: central wavelength (torch.Tensor or np.ndarray).

    Returns:
        Resonance wavelength drift, same type as the inputs.
    """
    delta_n_eff = gamma * k * alpha * (v * v)
    return delta_n_eff * lambda_0 / n_g
+
+
def mrr_tr_to_roundtrip_phase(t, a, r):
    """Map through-port transmission back to the roundtrip phase shift.

    Args:
        t: field transmission in [0, 1] (torch.Tensor or np.ndarray).
        a: attenuation coefficient, in [0, 1].
        r: coupling coefficient, in [0, 1].

    Returns:
        (phi, cos_phi): |roundtrip phase lag| in [0, pi] (the phase lag itself
        is negative; the sign is folded into the equation) and its cosine.

    Raises:
        NotImplementedError: when `t` is neither a Tensor nor an ndarray.
    """
    # The transmission curve has multiple valleys, so a given t maps to
    # infinitely many phases; we restrict to [-pi, 0], i.e. |phase| in [0, pi],
    # which matches acos's range.
    assert 0 <= a <= 1, f"Expect a from [0,1] but got {a}"
    assert 0 <= r <= 1, f"Expect r from [0,1] but got {r}"
    # Given a and r the curve is fixed; its max/min may not be exactly 1 and 0,
    # hence the clamp before acos.
    cos_phi = (a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)

    # Fix: the original called .clamp() on cos_phi before the type dispatch,
    # which raised AttributeError for numpy inputs (ndarray has no .clamp).
    if isinstance(cos_phi, torch.Tensor):
        cos_phi = cos_phi.clamp(0, 1)
        return cos_phi.acos(), cos_phi
    elif isinstance(cos_phi, np.ndarray):
        cos_phi = np.clip(cos_phi, 0, 1)
        return np.arccos(cos_phi), cos_phi
    else:
        raise NotImplementedError
+
+
def mrr_roundtrip_phase_to_tr(
    rt_phi, a: float = 0.8, r: float = 0.9, poly_coeff=None, intensity: bool = False
):
    """Map roundtrip phase shift to MRR through-port transmission.

    Args:
        rt_phi: |roundtrip phase lag|, in [0, pi].
        a: attenuation coefficient.
        r: self-coupling coefficient.
        poly_coeff: optional polynomial fit of the intensity-vs-phase curve;
            when given, a fast polynomial evaluation replaces the exact
            formula (incoherent/intensity use only).
        intensity: return intensity transmission instead of field transmission.

    Returns:
        Through-port field or intensity transmission.
    """
    if poly_coeff is None:
        # Slow but exact all-pass MRR equation.
        cos_term = -2 * r * a * rt_phi.cos()
        t = (a * a + r * r + cos_term) / (1 + r * r * a * a + cos_term)
        if not intensity:
            # as long as a != r, t cannot be 0, so sqrt is safe
            t = t.sqrt()
        return t
    # Fast path: polynomial prediction of the intensity transmission curve.
    t = polynomial(rt_phi.clamp(0, np.pi), poly_coeff).clamp(1e-8, 1)
    if not intensity:
        # epsilon keeps the sqrt gradient finite near 0
        t = (t + 1e-12).sqrt()
    return t
+
+
@torch.jit.script
def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False):
    """Roundtrip phase -> MRR through-port transmission (TorchScript-fused).

    Args:
        rt_phi: |roundtrip phase lag| in [0, pi].
        a: attenuation coefficient.
        r: self-coupling coefficient.
        intensity: return intensity transmission instead of field transmission.

    Returns:
        Through-port field or intensity transmission tensor.
    """
    cos_term = (-2.0 * r * a) * rt_phi.cos()
    out = (a * a + r * r + cos_term) / (1.0 + r * r * a * a + cos_term)
    if not intensity:
        # as long as a != r, the transmission cannot be 0, so sqrt is safe
        out = out.sqrt()
    return out
+
+
@torch.jit.script
def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False):
    """Analytic gradient of the MRR transmission w.r.t. roundtrip phase (TorchScript-fused).

    Args:
        rt_phi: |roundtrip phase lag| in [0, pi].
        a: attenuation coefficient.
        r: self-coupling coefficient.
        intensity: gradient of the intensity (True) or field (False) transmission.

    Returns:
        d(transmission)/d(rt_phi) tensor.
    """
    cos_term = 2 * a * r * rt_phi.cos()
    if intensity:
        return ((a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r * rt_phi.sin()) / (
            a ** 2 * r ** 2 + 1 - cos_term
        ) ** 2
    return (a * r * (a ** 2 - 1) * (r ** 2 - 1) * rt_phi.sin()) / (
        (a ** 2 + r ** 2 - cos_term) ** 0.5
        * (a ** 2 * r ** 2 + 1 - cos_term) ** 1.5
    )
+
+
def mrr_roundtrip_phase_to_tr_func(a: float = 0.8, r: float = 0.9, intensity: bool = False):
    """Build a memory-efficient autograd Function mapping roundtrip phase to transmission.

    Returns a callable computing the same values as
    `mrr_roundtrip_phase_to_tr_fused(rt_phi, a, r, intensity)` but with a
    hand-written forward/backward using in-place ops; only the input tensor is
    saved for backward.

    Fix: the field-transmission (intensity=False) backward used
    `(den - 1) ** 1.5` with den = 1 + a^2 r^2 - 2 a r cos(phi); the analytic
    gradient (cf. `mrr_roundtrip_phase_to_tr_grad_fused`) requires `den ** 1.5`.

    Args:
        a: attenuation coefficient.
        r: self-coupling coefficient.
        intensity: compute intensity (True) or field (False) transmission.
    """
    # Constant folding: den = c1*cos(phi) + c2 + c3 = 1 + a^2 r^2 - 2 a r cos(phi)
    # and t = (c2 + c1*cos(phi)) / den = 1 - c3 / den.
    c1 = -2 * a * r
    c2 = a * a + r * r
    c3 = 1 + r * r * a * a - a * a - r * r
    c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r

    class MRRRoundTripPhaseToTrFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            # t = 1 - c3 / (c1*cos(phi) + c2 + c3), fused via in-place ops
            t = input.cos().mul_(c1).add_(c2 + c3).reciprocal_().mul_(-c3).add_(1)
            if not intensity:
                # as long as a != r, t cannot be 0
                t.sqrt_()
            return t

        @staticmethod
        def backward(ctx, grad_output):
            (input,) = ctx.saved_tensors
            # denominator = 1 + a^2 r^2 - 2 a r cos(phi)
            denominator = input.cos().mul_(c1).add_(c2 + c3)

            if intensity:
                denominator.square_()
                numerator = input.sin().mul_(c4)
            else:
                numerator = input.sin().mul_(c4 / 2)
                # den^1.5 * sqrt(a^2 + r^2 - 2 a r cos(phi)); note den.sub(c3)
                # is evaluated against the un-modified denominator.
                denominator = denominator.pow(1.5).mul_(denominator.sub(c3).sqrt_())
            grad_input = numerator.div_(denominator).mul_(grad_output)
            return grad_input

    return MRRRoundTripPhaseToTrFunction.apply
+
+
def mrr_roundtrip_phase_to_out_phase(rt_phi, a, r):
    """Map roundtrip phase shift to the all-pass MRR output phase response.

    rt_phi: roundtrip phase shift (torch.Tensor or np.ndarray).
    a: attenuation coefficient.
    r: self-coupling coefficient.
    Returns the output phase response, same type as ``rt_phi``.
    Raises NotImplementedError for unsupported input types.
    """
    if isinstance(rt_phi, torch.Tensor):
        arctan, sin, cos = torch.atan2, torch.sin, torch.cos
    elif isinstance(rt_phi, np.ndarray):
        arctan, sin, cos = np.arctan2, np.sin, np.cos
    else:
        raise NotImplementedError
    s = sin(rt_phi)
    c = cos(rt_phi)
    # two arctan terms: coupler response and ring feedback response
    coupler_term = arctan(r * s, a - r * c)
    feedback_term = arctan(r * a * s, 1 - r * a * c)
    return np.pi - rt_phi - coupler_term - feedback_term
+
+
def mrr_tr_to_out_phase(t, a, r, onesided=True):
    """Map through-port field transmission to the MRR output phase response.

    t: field transmission in [0, 1] (torch.Tensor or np.ndarray).
    a: attenuation coefficient.
    r: self-coupling coefficient.
    onesided: True to use only half of the curve (output phase range [0, pi]);
        currently has no effect on the computation.
    Returns the roundtrip output phase, same type as ``t``.
    """
    # recover roundtrip phase (and its cosine) from the transmission level
    rt_phi, cos_rt_phi = mrr_tr_to_roundtrip_phase(t, a, r)
    if isinstance(t, torch.Tensor):
        arctan, sin = torch.atan2, torch.sin
    elif isinstance(t, np.ndarray):
        arctan, sin = np.arctan2, np.sin
    else:
        raise NotImplementedError
    sin_rt_phi = sin(rt_phi)
    coupler_term = arctan(r * sin_rt_phi, a - r * cos_rt_phi)
    feedback_term = arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi)
    phi = np.pi - rt_phi - coupler_term - feedback_term
    if onesided:
        pass  # one-sided handling is intentionally a no-op in the original implementation
    return phi
+
+
def mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r):
    """
    description: from round trip phase to the complex through-port transmission
    t = (r - a e^{-j phi}) / (1 - r a e^{-j phi}) \\
    rt_phi {torch.Tensor} round trip phase shift (torch only; view_as_complex is used)\\
    a {scalar} attenuation coefficient\\
    r {scalar} self-coupling coefficient\\
    return output {torch.Tensor} real view of shape (..., 2) holding the complex transmission
    """
    # e^(-j phi); polar_to_complex with mag=None presumably returns a unit-magnitude
    # (..., 2) real tensor -- TODO confirm against .general.polar_to_complex
    ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi))
    a_ephi = -a * ephi
    # all-pass transfer function (r + a_ephi) / (1 + r * a_ephi), returned as a real view
    output = torch.view_as_real((r + a_ephi) / (1 + r * a_ephi))
    return output
+
+
@torch.jit.script
def mrr_roundtrip_phase_to_tr_phase_fused(rt_phi, a: float, r: float):
    """Fused (TorchScript) map from roundtrip phase to the complex through-port
    transmission t = (r - a e^{-j phi}) / (1 - r a e^{-j phi}).

    rt_phi: roundtrip phase-shift tensor.
    a: attenuation coefficient.
    r: self-coupling coefficient.
    Returns a real view of shape (..., 2) of the complex transmission.
    """
    neg_phi = -rt_phi
    # -a * e^{-j phi}
    a_ephi = torch.complex(neg_phi.cos(), neg_phi.sin()) * (-a)
    return torch.view_as_real((r + a_ephi) / (1 + r * a_ephi))
+
+
def mrr_modulator(t, a=0.9, r=0.8):
    """All-pass MRR used as a modulator: attach the MRR phase response to a real
    field-transmission factor, producing a coherent (complex-valued) signal.

    t: field transmission factor (torch.Tensor).
    a: attenuation factor in [0, 1]. Default: 0.9.
    r: transmission/self-coupling factor in [0, 1]. Default: 0.8.
    Returns a real tensor of shape (..., 2): [t*cos(phase), t*sin(phase)].
    """
    phase = mrr_tr_to_out_phase(t, a, r)
    return torch.stack([t * torch.cos(phase), t * torch.sin(phase)], dim=-1)
+
+
def mrr_filter(x, t, a=0.9, r=0.8):
    """
    @description: all-pass MRR as a filter. Map from the input complex light signal to output signal with through port transmission\\
    @x {torch.Tensor or np.ndarray} complexed input light signal\\
    @t {torch.Tensor or np.ndarray} field intensity modulation factor\\
    @a {float} attenuation factor from [0,1]. Default: 0.9\\
    @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\
    @return: complexed light signal
    """
    # phase response the MRR imprints for transmission level t
    phase = mrr_tr_to_out_phase(t, a, r)
    cos_phase, sin_phase = torch.cos(phase), torch.sin(phase)
    phase_shift = torch.complex(cos_phase, sin_phase)  # unit phasor e^{j*phase}
    # scale the phase-shifted input by the (real) transmission factor t
    # NOTE(review): complex_mult comes from .general; assumes x is representable in the
    # same complex format as phase_shift -- confirm at call sites
    out = t * complex_mult(x, phase_shift)
    return out
+
+
def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False):
    """
    description: from round trip phase shift to output signal \\
    rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\
    tr_poly_coeff {Callable} polynomial coefficients of tranmission-roundtrip phase curve. Default set to None. None for slow computation\\
    a {float} attenuation factor from [0,1]. Default: 0.9\\
    r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\
    x {torch.Tensor or np.ndarray, Optional} input complex light signal {None, real tensor or complex tensor}. Default set to None\\
    coherent {bool scalar, Optional} coherent output or not. Default set to False\\
    intensity {bool scalar, Optional} whether use intensity or field transmission. Default set to False\\
    return output {torch.Tensor or np.ndarray} real tensor if incoherent, complex tensor if coherent
    """
    # four cases: (in)coherent light x (unit | non-unit) input.
    # mrr_roundtrip_phase_to_tr / polar_to_complex / complex_mult come from sibling
    # modules; behavior below assumes they follow the conventions documented there.
    if not coherent:
        if x is None:
            # unit laser input with incoherent light, 1e^j0
            t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity)
            return t
        else:
            # incoherent light with non-unit input, input must be real number
            t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity)
            return x * t
    else:
        if x is None:
            # coherent light with unit laser, 1e^j0, treat morr as a mrr modulator
            phase = polar_to_complex(mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r))
            return phase
        else:
            # coherent light with complex input
            return complex_mult(mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r), x)
+
+
def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm):
    """Group index n_g of an all-pass MRR from its FWHM and resonance wavelength.

    Follows Bogaerts et al., "Silicon microring resonators", Laser & Photonics
    Reviews 2011, Eq. (7).

    a: attenuation coefficient.
    r: self-coupling coefficient.
    radius: MRR radius (nm).
    lambda0: resonance wavelength (nm).
    fwhm: full width at half maximum / bandwidth (nm).
    Returns the group index (float).
    """
    ra = r * a
    return (1 - ra) * lambda0 ** 2 / (2 * np.pi * np.pi * radius * np.sqrt(ra) * fwhm)
+
+
def mrr_ng_to_fsr(lambda0, n_g, radius):
    """Free-spectral range (FSR) of an MRR from resonance wavelength, group index
    and radius (Bogaerts et al., "Silicon microring resonators", 2011, Eq. (9)).

    lambda0: resonance wavelength (nm).
    n_g: group index.
    radius: MRR radius (nm).
    Returns the FSR in the same length unit as lambda0.
    """
    circumference = 2 * np.pi * radius
    return lambda0 * lambda0 / (n_g * circumference)
+
+
def mrr_finesse(a, r):
    """Finesse of an all-pass MRR: FSR / FWHM = pi*sqrt(ra)/(1-ra)
    (Bogaerts et al., "Silicon microring resonators", 2011, Eq. (21)).

    a: attenuation coefficient.
    r: self-coupling coefficient.
    Returns the finesse (float).
    """
    ra = r * a
    return np.pi * np.sqrt(ra) / (1 - ra)
diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/functional/quantize.py
new file mode 100644
index 000000000..09b7ef28d
--- /dev/null
+++ b/src/chop/nn/optical/functional/quantize.py
@@ -0,0 +1,575 @@
+"""
Description: Quantization primitives (uniform/DoReFa-style, EWGS, PACT) for network weights and activations.
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-06 03:15:00
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-06 03:15:00
+"""
+
+import numpy as np
+import torch
+import logging
+
+
+__all__ = [
+ "uniform_quantize_cpu",
+ "pact_quantize",
+ "PACT_Act",
+ "uniform_quantize",
+ "uniform_quantize_new",
+ "ewgs_quantize",
+ "input_quantize_fn",
+ "weight_quantize_fn",
+]
+
+
class uniform_quantize_cpu(object):
    """Numpy-based uniform quantizer mirroring ``uniform_quantize`` for CPU arrays.

    bits == 32 passes the input through unchanged, bits == 1 binarizes with sign,
    otherwise values are snapped onto a uniform grid with 2**bits - 1 steps.
    """

    def __init__(self, bits):
        # BUGFIX: the original called super(uniform_quantize_cpu).__init__(),
        # which initializes the unbound super proxy instead of this instance.
        super().__init__()
        self.bits = bits

    def __call__(self, input):
        """Quantize ``input`` (np.ndarray) and return the quantized array."""
        if self.bits == 32:
            return input
        if self.bits == 1:
            return np.sign(input)
        n = float(2 ** self.bits - 1)
        return np.round(input * n) / n
+
+
def uniform_quantize(k, gradient_clip=False):
    """Factory for a straight-through-estimator uniform quantizer.

    k: bitwidth; 32 is a pass-through, 1 binarizes with sign, otherwise the
       input is rounded onto a grid of 2**k - 1 uniform steps.
    gradient_clip: clamp the straight-through gradient to [-1, 1].
    Returns the autograd Function's ``.apply`` callable.
    """

    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            if k == 32:
                return input
            if k == 1:
                return torch.sign(input)
            levels = float(2 ** k - 1)
            return input.mul(levels).round().div(levels)

        @staticmethod
        def backward(ctx, grad_output):
            # straight-through estimator: gradient passes unchanged (optionally clipped)
            grad_input = grad_output.clone()
            return grad_input.clamp_(-1, 1) if gradient_clip else grad_input

    return qfn.apply
+
+
+############ add observer and new quant based on range and zeropoint for activation
+def uniform_quantize_new(k, gradient_clip=False):
+ """
+ Support uniform quantization with auto-adjusted input data range
+ args:
+ k: bitwidth
+ scale, zeropoint: obtained from observer
+ """
+
+ class qfn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, scale, zero_point):
+ if k == 32:
+ out = input
+ elif k == 1:
+ out = torch.sign(input)
+ else:
+ n = float(2**k - 1)
+ # out = torch.round(input * n) / n
+ # out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale
+ out = input.div(scale).add_(zero_point).round_().clamp_(0, n).sub_(zero_point).mul_(scale)
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ if gradient_clip:
+ grad_input.clamp_(-1, 1)
+ return grad_input, None, None
+
+ return qfn.apply
+
+
def ewgs_quantize(num_levels, gradient_clip=False, scaling_factor: float = 1e-3):
    """Factory for the EWGS quantizer (element-wise gradient scaling).

    Network Quantization with Element-wise Gradient Scaling, CVPR 2021.
    https://github.com/cvlab-yonsei/EWGS/blob/main/CIFAR10/custom_modules.py

    num_levels: number of discrete levels; inputs in [0, 1] are snapped onto them.
    gradient_clip: clamp the scaled gradient to [-1, 1].
    scaling_factor: backward scaling factor, typically fixed to 1e-3.
    Returns the autograd Function's ``.apply`` callable.
    """

    class EWGS_quantizer(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            out = input.mul(num_levels - 1).round_().mul_(1 / (num_levels - 1))
            ctx._scaling_factor = scaling_factor
            # keep the quantization residual for gradient scaling in backward
            ctx.save_for_backward(input - out)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            (residual,) = ctx.saved_tensors
            # scale = 1 + delta * sign(grad) * (input - output); residual is consumed in place
            scale = residual.mul_(grad_output.sign()).mul_(ctx._scaling_factor).add_(1)
            grad_input = grad_output * scale
            if gradient_clip:
                grad_input.clamp_(-1, 1)
            return grad_input

    return EWGS_quantizer.apply
+
+
class input_quantize_fn(torch.nn.Module):
    """Differentiable input/activation quantizer with optional Quant-Noise.

    Two algorithms:
      * "dorefa": clamp inputs to [0, 1] and quantize on a fixed uniform grid;
      * "normal": track the input range with a moving-average min/max observer
        and quantize with the observed scale/zero-point (only for in_bit <= 8).
    """

    def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0):
        """Input quantizer with Quant_Noise supported
        Args:
            in_bit (int): Input quantization bitwidth.
            device (Device, optional): torch Device. Defaults to torch.device("cuda:0").
            quant_ratio (float, optional): Quantization ratio. Defaults to 1.0.
        """
        super(input_quantize_fn, self).__init__()
        assert 1 <= in_bit <= 32
        self.in_bit = in_bit
        self.alg = alg
        assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}"
        self.quant_ratio = quant_ratio
        assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
        self.device = device

        # define quant style
        # dorefa: clamp to 0-1
        # normal: obtain scale and zero_point via observer

        if self.alg == "dorefa":
            self.uniform_q = uniform_quantize(k=in_bit)
        elif self.alg == "normal":
            self.uniform_q = uniform_quantize_new(k=in_bit)
            self.scale = None
            self.zero_point = None
            ### select scale and zero-point using EMA: exponential moving averages
            # AT: MovingAverageMinMaxObserver only support self-defined quant bitwidths for pytorch1.7
            # obs = torch.quantization.observer.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8,
            # qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=0, quant_max=2**self.in_bit-1)
            # Thus use our version
            ### torch version must be higher than 1.7
            if 1 <= self.in_bit <= 8: # observer does not support higher than 8-bit
                self.obs = torch.quantization.observer.MovingAverageMinMaxObserver(
                    averaging_constant=0.01,
                    dtype=torch.quint8,
                    qscheme=torch.per_tensor_affine,
                    reduce_range=False,
                    quant_min=0,
                    quant_max=2**self.in_bit - 1,
                ).to(self.device)
            else:
                self.obs = None

    def set_bitwidth(self, bit: int) -> None:
        ### regenerate quantizer without changing observation statistics
        if bit != self.in_bit:
            if self.alg == "dorefa":
                self.uniform_q = uniform_quantize(k=bit)
            elif self.alg == "normal":
                self.uniform_q = uniform_quantize_new(k=bit)
            self.in_bit = bit

    def set_alg(self, alg: str) -> None:
        # switch quantization algorithm in place; rebuilds the quantizer callable
        assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}"
        if alg != self.alg:
            if alg == "dorefa":
                self.uniform_q = uniform_quantize(k=self.in_bit)
            elif alg == "normal":
                self.uniform_q = uniform_quantize_new(k=self.in_bit)
            self.alg = alg

    def set_quant_ratio(self, quant_ratio=None):
        if quant_ratio is None:
            ### get recommended value
            # table indexed by bitwidth (capped at 16): lower bitwidths quantize
            # a smaller fraction of elements during training
            quant_ratio = [
                None,
                0.2,
                0.3,
                0.4,
                0.5,
                0.55,
                0.6,
                0.7,
                0.8,
                0.83,
                0.86,
                0.89,
                0.92,
                0.95,
                0.98,
                0.99,
                1,
            ][min(self.in_bit, 16)]
        assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
        self.quant_ratio = quant_ratio

    def forward(self, x):
        # Quant-Noise: during training, only a random subset of elements is
        # quantized (mask marks elements kept at full precision)
        if self.quant_ratio < 1 and self.training:
            ### implementation from fairseq
            ### must fully quantize during inference
            quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio)
        else:
            quant_noise_mask = None

        if self.in_bit == 32:
            input_q = x
        elif self.in_bit == 1:
            x = x.clamp(0, 1)
            # binarize around 0.5, then map {-1, 1} back to {0, 1}
            input_q = (self.uniform_q(x - 0.5) + 1) / 2
            if quant_noise_mask is not None:
                noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                ### unquantized inputs have to be clamped
                input_q = x + noise
        else:
            ### dorefa-style clamp for input data
            if self.alg == "dorefa":
                x = x.clamp(0, 1)
                input_q = self.uniform_q(x)
            elif self.alg == "normal":
                if self.obs is not None:
                    if self.training:
                        # update running min/max statistics only while training
                        self.obs(x)
                    scale, zero_point = self.obs.calculate_qparams()
                    # convert scale and zero_point type from qint8
                    self.scale = scale.to(x)
                    self.zero_point = zero_point.to(x)
                    input_q = self.uniform_q(x, self.scale, self.zero_point)
                else:
                    input_q = x  # if no observer (in_bit > 8), do not quantize
            else:
                raise NotImplementedError

        # add noise
        # NOTE(review): this block runs for all bitwidth paths. For in_bit == 1 the
        # noise is applied a second time (idempotent), and for in_bit == 32 with a
        # mask the in-place sub_ mutates x itself -- verify this is intended.
        if quant_noise_mask is not None:
            noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
            ### unquantized inputs have to be clamped
            input_q = x + noise

        return input_q
+
+
class weight_quantize_fn(torch.nn.Module):
    def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0):
        """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization.

        Args:
            w_bit (int): quantization bitwidth
            mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv".
            alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa".
            quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0.
        """
        super(weight_quantize_fn, self).__init__()
        assert 1 <= w_bit <= 32, logging.error(f"Only support 1 - 32 bit quantization, but got {w_bit}")
        self.w_bit = w_bit
        self.alg = alg
        self.mode = mode
        assert alg in {"dorefa", "dorefa_sym", "qnn", "dorefa_pos"}, logging.error(
            f"Only support (dorefa, dorefa_sym, qnn, dorefa_pos) algorithms, but got {alg}"
        )
        self.quant_ratio = quant_ratio
        assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
        # STE-based uniform quantizer; straight-through gradients are clipped to [-1, 1]
        self.uniform_q = uniform_quantize(k=w_bit, gradient_clip=True)

    def set_quant_ratio(self, quant_ratio=None):
        if quant_ratio is None:
            ### get recommended value
            # table indexed by bitwidth (capped at 16): lower bitwidths quantize
            # a smaller fraction of elements during training
            quant_ratio = [
                None,
                0.2,
                0.3,
                0.4,
                0.5,
                0.55,
                0.6,
                0.7,
                0.8,
                0.83,
                0.86,
                0.89,
                0.92,
                0.95,
                0.98,
                0.99,
                1,
            ][min(self.w_bit, 16)]
        assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
        self.quant_ratio = quant_ratio

    def set_bitwidth(self, bit: int) -> None:
        ### regenerate quantizer without changing observation statistics
        if bit != self.w_bit:
            self.uniform_q = uniform_quantize(k=bit, gradient_clip=True)
            self.w_bit = bit

    def forward(self, x):
        """Quantize weights ``x``; optionally leave a random subset at full precision (Quant-Noise)."""
        if self.quant_ratio < 1 and self.training:
            ### implementation from fairseq
            ### must fully quantize during inference
            quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio)
        else:
            quant_noise_mask = None

        if self.w_bit == 32:
            # full precision: only reparameterize with tanh and normalize to [-1, 1]
            weight_q = torch.tanh(x)
            weight_q = weight_q / torch.max(torch.abs(weight_q))
        elif self.w_bit == 1:
            if self.mode == "ringonn":
                # map sign(x) in {-1, 1} to the ring-friendly levels {0.25, 0.75}
                weight_q = (self.uniform_q(x) / 4) + 0.5
            else:
                if self.alg == "dorefa":
                    E = x.data.abs().mean()
                    weight_q = (self.uniform_q(x / E) * E + E) / 2  # [0, E]
                    if quant_noise_mask is not None:
                        x = (x + E) / 2
                        noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                        ### unquantized weights have to follow reparameterization, i.e., tanh and scale
                        weight_q = x + noise
                elif self.alg == "dorefa_sym":
                    E = x.data.abs().mean()
                    weight_q = self.uniform_q(x / E) * E  # [-E, E]
                    if quant_noise_mask is not None:
                        noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                        ### unquantized weights have to follow reparameterization, i.e., tanh and scale
                        weight_q = x + noise
                else:
                    # BUGFIX: was "assert NotImplementedError", which is always truthy and
                    # never raised (execution fell through to an UnboundLocalError).
                    raise NotImplementedError(f"Unsupported algorithm {self.alg} for 1-bit quantization")
        else:
            if self.alg == "dorefa":
                weight = torch.tanh(x)  # [-1, 1]
                weight = weight / 2 / torch.max(torch.abs(weight.data)) + 0.5
                # weight = weight / 2 + 0.5
                weight_q = self.uniform_q(weight)
                if quant_noise_mask is not None:
                    noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
                    ### unquantized weights have to follow reparameterization, i.e., tanh and scale
                    weight_q = weight + noise

            elif self.alg == "dorefa_sym":
                weight = torch.tanh(x)  # [-1, 1]
                r = torch.max(torch.abs(weight.data))
                # weight = weight / 2 + 0.5
                weight_q = self.uniform_q(weight / (2 * r) + 0.5) * (2 * r) - r
                if quant_noise_mask is not None:
                    noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
                    ### unquantized weights have to follow reparameterization, i.e., tanh
                    weight_q = weight + noise
            elif self.alg == "dorefa_pos":
                weight = torch.tanh(x)  # [-1, 1]
                r = torch.max(torch.abs(weight.data))
                weight = weight + r
                # weight = weight / 2 + 0.5
                weight_q = self.uniform_q(weight / (2 * r)) * 2 * r
                if quant_noise_mask is not None:
                    noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
                    ### unquantized weights have to follow reparameterization, i.e., tanh
                    weight_q = weight + noise

            elif self.alg == "qnn":
                x_min = torch.min(x.data)
                x_max = torch.max(x.data)
                x_range = x_max - x_min
                weight_q = self.uniform_q((x - x_min) / x_range) * x_range + x_min
                if quant_noise_mask is not None:
                    noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                    ### unquantized weights have to follow reparameterization, i.e., tanh
                    weight_q = x + noise
            else:
                # BUGFIX: was "assert NotImplementedError" (never raised); unreachable given
                # the constructor assertion, but kept as a defensive guard.
                raise NotImplementedError(f"Unsupported algorithm {self.alg}")

        return weight_q
+
+
+# PACT activation: https://arxiv.org/pdf/1805.06085.pdf
class PACT_QuantFunc(torch.autograd.Function):
    r"""PACT (PArametrized Clipping acTivation) quantization function for activations.
    Implements a :py:class:`torch.autograd.Function` for quantizing activations in :math:`Q` bits using the PACT strategy.
    In forward propagation, the function is defined as

    .. math::
        \mathbf{y} = f(\mathbf{x}) = 1/\varepsilon \cdot \left\lfloor\mathrm{clip}_{ [0,\alpha) } (\mathbf{x})\right\rfloor \cdot \varepsilon

    where :math:`\varepsilon` is the quantization precision:

    .. math::
        \varepsilon = \alpha / (2^Q - 1)

    In backward propagation, using the Straight-Through Estimator, the gradient of the function is defined as

    .. math::
        \mathbf{\nabla}_\mathbf{x} \mathcal{L} &\doteq \mathbf{\nabla}_\mathbf{y} \mathcal{L}

    It can be applied by using its static `.apply` method:

    :param input: the tensor containing :math:`x`, the activations to be quantized.
    :type input: `torch.Tensor`
    :param eps: the precomputed value of :math:`\varepsilon`.
    :type eps: `torch.Tensor` or float
    :param alpha: the value of :math:`\alpha`.
    :type alpha: `torch.Tensor` or float
    :param delta: constant to sum to `eps` for numerical stability (default unused, 0 ).
    :type delta: `torch.Tensor` or float

    :return: The quantized input activations tensor.
    :rtype: `torch.Tensor`
    """

    @staticmethod
    def forward(ctx, input, eps, alpha):
        # remember where the input fell outside [0, alpha) (input gradient zeroed there)
        # and where it stayed below alpha (alpha receives no gradient there)
        where_input_clipped = (input < 0) | (input >= alpha)
        where_input_ltalpha = input < alpha
        ctx.save_for_backward(where_input_clipped, where_input_ltalpha)
        # floor onto the eps grid, then clamp to the representable range [0, alpha - eps];
        # alpha and eps are accessed as 1-element tensors here
        return ((input / (eps)).floor() * eps).clamp(0.0, alpha.data[0] - eps.data[0])

    @staticmethod
    def backward(ctx, grad_output):
        # see Hubara et al., Section 2.3
        where_input_clipped, where_input_ltalpha = ctx.saved_tensors
        # zero = torch.zeros(1, device=where_input_nonclipped.device)
        # straight-through estimator inside the clipping window, zero outside
        grad_input = grad_output.masked_fill(where_input_clipped, 0)
        # grad_input = torch.where(where_input_nonclipped, grad_output, zero)
        # alpha accumulates the gradient only where input >= alpha (1-element tensor)
        grad_alpha = grad_output.masked_fill(where_input_ltalpha, 0).sum().expand(1)
        # grad_alpha = torch.where(where_input_gtalpha, grad_output, zero).sum().expand(1)
        return grad_input, None, grad_alpha


# module-level convenience handle for PACT_QuantFunc, mirroring torch conventions
pact_quantize = PACT_QuantFunc.apply
+
+
class PACT_Act(torch.nn.Module):
    r"""PACT (PArametrized Clipping acTivation) activation.
    Implements a :py:class:`torch.nn.Module` to implement PACT-style activations. It is meant to replace :py:class:`torch.nn.ReLU`, :py:class:`torch.nn.ReLU6` and
    similar activations in a PACT-quantized network.
    This layer can also operate in a special mode, defined by the `statistics_only` member, in which the layer runs in
    forward-prop without quantization, collecting statistics on the activations that can then be
    used to reset the value of :math:`\alpha`.
    In this mode, the layer collects:
    - tensor-wise maximum value ever seen
    - running average with momentum 0.9
    - running variance with momentum 0.9
    """

    def __init__(
        self,
        precision=None,
        alpha=1.0,
        backprop_alpha=True,
        statistics_only=False,
        leaky=None,
        device=torch.device("cuda"),
    ):
        r"""Constructor. Initializes a :py:class:`torch.nn.Parameter` for :math:`\alpha` and sets
        up the initial value of the `statistics_only` member.
        :param precision: instance defining the current quantization level (default `None`).
        :type precision: :py:class:`nemo.precision.Precision`
        :param alpha: the value of :math:`\alpha`.
        :type alpha: `torch.Tensor` or float
        :param backprop_alpha: default `True`; if `False`, do not update the value of `\alpha` with backpropagation.
        :type backprop_alpha: bool
        :param statistics_only: initialization value of `statistics_only` member.
        :type statistics_only: bool
        """

        super(PACT_Act, self).__init__()
        self.precision = precision
        self.device = device
        # learnable clipping bound alpha, stored as a 1-element parameter
        self.alpha = torch.nn.Parameter(torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha)
        self.alpha_p = alpha
        self.statistics_only = statistics_only
        self.deployment = False
        self.eps_in = None
        self.leaky = leaky
        # NOTE(review): requantization_factor is referenced by set_static_precision but
        # the assignment below is commented out, so that method raises AttributeError
        # as written -- confirm the intended default before enabling deployment.
        # self.requantization_factor = requantization_factor

        # these are only used to gather statistics
        self.max = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
        self.min = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
        self.running_mean = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
        self.running_var = torch.nn.Parameter(torch.ones_like(self.alpha.data).to(device), requires_grad=False)

        self.precise = False

    def set_static_precision(self, limit_at_32_bits=True, **kwargs):
        r"""Sets static parameters used only for deployment."""
        # item() --> conversion to float
        # apparently causes a slight, but not invisibile, numerical divergence
        # between FQ and QD stages
        # NOTE(review): requires self.precision to be numeric, self.eps_in to be set
        # externally, and self.requantization_factor to exist (see __init__ note).
        self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1)
        self.alpha_static = self.alpha.clone().detach()
        # D is selected as a power-of-two
        D = 2.0 ** torch.ceil(torch.log2(self.requantization_factor * self.eps_static / self.eps_in))
        if not limit_at_32_bits:
            self.D = D
        else:
            self.D = min(D, 2.0 ** (32 - 1 - (self.precision)))

    def get_output_eps(self, eps_in):
        r"""Get the output quantum (:math:`\varepsilon`) given the input one.
        :param eps_in: input quantum :math:`\varepsilon_{in}`.
        :type eps_in: :py:class:`torch.Tensor`
        :return: output quantum :math:`\varepsilon_{out}`.
        :rtype: :py:class:`torch.Tensor`
        """

        # eps_in is accepted for interface compatibility but unused: the output
        # quantum depends only on alpha and the current precision
        return self.alpha / (2.0 ** (self.precision) - 1)

    def reset_alpha(self, use_max=True, nb_std=5.0):
        r"""Reset the value of :math:`\alpha`. If `use_max` is `True`, then the highest tensor-wise value collected
        in the statistics collection phase is used. If `False`, the collected standard deviation multiplied by
        `nb_std` is used as a parameter
        :param use_max: if True, use the tensor-wise maximum value collected in the statistics run as new :math:`\alpha` (default True).
        :type use_max: bool
        :param nb_std: number of standard deviations to be used to initialize :math:`\alpha` if `use_max` is False.
        :type nb_std: float
        """

        if use_max:
            self.alpha.data[0] = self.max.item()
        else:
            self.alpha.data[0] = nb_std * torch.sqrt(self.running_var).item()

    def get_statistics(self):
        r"""Returns the statistics collected up to now.

        :return: The collected statistics (maximum, running average, running variance).
        :rtype: tuple of floats
        """
        return self.max.item(), self.running_mean.item(), self.running_var.item()

    def forward(self, x):
        r"""Forward-prop function for PACT-quantized activations.

        See :py:class:`nemo.quant.pact_quant.PACT_QuantFunc` for details on the normal operation performed by this layer.
        In statistics mode, it uses a normal ReLU and collects statistics in the background.
        :param x: input activations tensor.
        :type x: :py:class:`torch.Tensor`

        :return: output activations tensor.
        :rtype: :py:class:`torch.Tensor`
        """

        if self.statistics_only:
            # statistics mode: behave as a (leaky) ReLU and update running stats
            if self.leaky is None:
                x = torch.nn.functional.relu(x)
            else:
                x = torch.nn.functional.leaky_relu(x, self.leaky)
            with torch.no_grad():
                # momentum-0.9 running statistics; max/min are all-time extrema
                self.max[:] = max(self.max.item(), x.max())
                self.min[:] = min(self.min.item(), x.min())
                self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean()
                self.running_var[:] = 0.9 * self.running_var.item() + 0.1 * x.std() * x.std()
            return x
        else:
            # NOTE(review): self.precision defaults to None and is used as a number
            # here; it must be set before calling forward in quantized mode.
            eps = self.alpha / (2.0 ** (self.precision) - 1)
            return pact_quantize(x, eps, self.alpha + eps)
diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/functional/torch_train.py
new file mode 100644
index 000000000..c62ba3e91
--- /dev/null
+++ b/src/chop/nn/optical/functional/torch_train.py
@@ -0,0 +1,857 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-06 03:15:06
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-06 03:15:06
+"""
+
+import csv
+import os
+import random
+import time
+import traceback
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from scipy import interpolate
+from torch.nn.modules.batchnorm import _BatchNorm
+try:
+ from torchsummary import summary
+except:
+ print("[W] Cannot import torchsummary")
+from .general import ensure_dir
+
+__all__ = [
+ "DeterministicCtx",
+ "set_torch_deterministic",
+ "set_torch_stochastic",
+ "get_random_state",
+ "summary_model",
+ "save_model",
+ "BestKModelSaver",
+ "load_model",
+ "count_parameters",
+ "check_converge",
+ "ThresholdScheduler",
+ "ThresholdScheduler_tf",
+ "ValueRegister",
+ "ValueTracer",
+ "EMA",
+ "SWA",
+ "export_traces_to_csv",
+ "set_learning_rate",
+ "get_learning_rate",
+ "apply_weight_decay",
+ "disable_bn",
+ "enable_bn",
+]
+
+class DeterministicCtx:
+ def __init__(self, random_state: int | None = None) -> None:
+ self.random_state = random_state
+
+
+ def __enter__(self):
+ self.random_state = random.getstate()
+ self.numpy_random_state = np.random.get_state()
+ self.torch_random_state = torch.random.get_rng_state()
+ self.torch_cuda_random_state = torch.cuda.get_rng_state()
+ set_torch_deterministic(self.random_state)
+ return self
+
+
+ def __exit__(self, *args):
+ random.setstate(self.random_state)
+ np.random.seed(self.numpy_random_state)
+ np.random.set_state(self.numpy_random_state)
+ torch.random.set_rng_state(self.torch_random_state)
+ torch.cuda.set_rng_state(self.torch_cuda_random_state)
+
+def set_torch_deterministic(random_state: int = 0) -> None:
+ random_state = int(random_state) % (2**32)
+ torch.manual_seed(random_state)
+ np.random.seed(random_state)
+ if torch.cuda.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.cuda.manual_seed_all(random_state)
+ random.seed(random_state)
+
+
def set_torch_stochastic():
    """Re-seed numpy/torch RNGs from the wall clock (non-reproducible runs)."""
    clock_seed = int(time.time() * 1000) % (2**32)
    torch.manual_seed(clock_seed)
    np.random.seed(clock_seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = False
        torch.cuda.manual_seed_all(clock_seed)
+
+
def get_random_state():
    """Return the first word of numpy's Mersenne-Twister state vector."""
    state_vector = np.random.get_state()[1]
    return state_vector[0]
+
+
def summary_model(model, input):
    """Print a torchsummary-style summary of ``model`` for the given input spec.

    NOTE(review): relies on the optional top-of-file ``torchsummary`` import;
    if that import failed, ``summary`` is undefined and this raises NameError.
    ``input`` shadows the builtin of the same name (kept for API compatibility).
    """
    summary(model, input)
+
+
def save_model(model, path="./checkpoint/model.pt", print_msg=True):
    """Save a PyTorch model's state_dict to ``path``.

    Args:
        model (torch.nn.Module): model whose ``state_dict`` is saved.
        path (str, optional): full checkpoint path; parent directories are
            created as needed. Defaults to "./checkpoint/model.pt".
        print_msg (bool, optional): control of message print. Defaults to True.
    """
    dir = os.path.dirname(path)
    # makedirs handles nested paths and is a no-op when the directory exists;
    # the original os.mkdir failed on missing parents, and on the empty
    # dirname of a bare filename.
    if dir:
        os.makedirs(dir, exist_ok=True)
    try:
        torch.save(model.state_dict(), path)
        if print_msg:
            print(f"[I] Model saved to {path}")
    except Exception:
        if print_msg:
            print(f"[E] Model failed to be saved to {path}")
        # print_exc() takes no exception argument; the original passed one,
        # which itself raised a TypeError.
        traceback.print_exc()
+
+
class BestKModelSaver(object):
    """Keeps the best-k checkpoints on disk according to a scalar metric.

    Checkpoint filenames encode the metric value (and optionally the epoch);
    when a new model beats the current worst of the cached k, the worst
    checkpoint file is deleted and replaced.
    """

    def __init__(
        self,
        k: int = 1,
        descend: bool = True,
        truncate: int = 2,
        metric_name: str = "acc",
        format: str = "{:.2f}",
    ):
        """
        Args:
            k: number of checkpoints to keep.
            descend: True if a larger metric is better (e.g. accuracy).
            truncate: number of decimal digits the metric is rounded to.
            metric_name: metric tag embedded in the checkpoint filename.
            format: format spec used to render the metric in the filename.
        """
        super().__init__()
        self.k = k
        self.descend = descend
        self.truncate = truncate
        self.metric_name = metric_name
        self.format = format
        # Minimum improvement required to replace a cached checkpoint.
        self.epsilon = 0.1**truncate
        # path -> (metric, epoch) for the currently kept checkpoints.
        self.model_cache = OrderedDict()

    def better_op(self, a, b):
        """Return True if metric ``a`` improves on ``b`` by at least epsilon."""
        if self.descend:
            return a >= b + self.epsilon
        else:
            return a <= b - self.epsilon

    def _checkpoint_path(self, dir, checkpoint_name, metric, epoch):
        # Filename pattern: <name>_<metric_name>-<metric>[_epoch-<epoch>].pt
        name = (
            f"{checkpoint_name}_{self.metric_name}-"
            + self.format.format(metric)
            + f"{'' if epoch is None else '_epoch-'+str(epoch)}"
        )
        return os.path.join(dir, name + ".pt")

    def __insert_model_record(self, metric, dir, checkpoint_name, epoch=None):
        """Insert (metric, epoch) into the cache.

        Returns:
            (new_path, del_path): path to save the new checkpoint to (or None
            if it does not make the top-k), and the path of the evicted
            checkpoint to delete (or None).
        """
        metric = round(metric * 10**self.truncate) / 10**self.truncate
        if len(self.model_cache) < self.k:
            path = self._checkpoint_path(dir, checkpoint_name, metric, epoch)
            self.model_cache[path] = (metric, epoch)
            return path, None

        worst_metric, worst_epoch = sorted(
            self.model_cache.values(),
            key=lambda x: x[0],
            reverse=not self.descend,
        )[0]
        if not self.better_op(metric, worst_metric):
            return None, None

        # BUGFIX: the evicted filename must be rebuilt from *worst_epoch*;
        # the original tested ``epoch is None`` here, producing a wrong
        # delete path whenever the current and evicted epoch tags differed.
        del_path = self._checkpoint_path(dir, checkpoint_name, worst_metric, worst_epoch)
        try:
            del self.model_cache[del_path]
        except KeyError:
            print(
                "[W] Cannot remove checkpoint: {} from cache".format(del_path),
                flush=True,
            )
        path = self._checkpoint_path(dir, checkpoint_name, metric, epoch)
        self.model_cache[path] = (metric, epoch)
        return path, del_path

    def get_topk_model_path(self, topk: int = 1):
        """Return the paths of the best ``topk`` cached checkpoints, best first."""
        if topk <= 0:
            return []
        if topk > len(self.model_cache):
            topk = len(self.model_cache)
        return [
            i[0]
            for i in sorted(
                self.model_cache.items(), key=lambda x: x[1][0], reverse=self.descend
            )[:topk]
        ]

    def save_model(
        self,
        model,
        metric,
        epoch=None,
        path="./checkpoint/model.pt",
        other_params=None,
        save_model=False,
        print_msg=True,
    ):
        """Save ``model`` if ``metric`` makes it into the best-k.

        Args:
            model (torch.nn.Module): model to checkpoint.
            metric (float or torch.Tensor): scalar metric value.
            epoch (int, optional): epoch tag for the filename. Defaults to None.
            path (str, optional): base checkpoint path; the actual filename is
                derived from it. Defaults to "./checkpoint/model.pt".
            other_params (dict, optional): extra entries stored in the
                checkpoint dict. Defaults to None.
            save_model (bool, optional): also pickle the full nn.Module, not
                just its state_dict. Defaults to False.
            print_msg (bool, optional): control of message print. Defaults to True.

        Returns:
            str or None: path the checkpoint was saved to, or None if skipped.
        """
        dir = os.path.dirname(path)
        ensure_dir(dir)
        checkpoint_name = os.path.splitext(os.path.basename(path))[0]
        if isinstance(metric, torch.Tensor):
            metric = metric.data.item()
        new_path, del_path = self.__insert_model_record(
            metric, dir, checkpoint_name, epoch
        )

        if del_path is not None:
            try:
                os.remove(del_path)
                print(f"[I] Model {del_path} is removed", flush=True)
            except Exception:
                if print_msg:
                    print(f"[E] Model {del_path} failed to be removed", flush=True)
                    # print_exc() takes no exception argument (original bug).
                    traceback.print_exc()

        if new_path is None:
            if print_msg:
                best_list = sorted(self.model_cache.values(), reverse=self.descend)
                print(
                    f"[I] Not best {self.k}: {best_list}, skip this model ("
                    + self.format.format(metric)
                    + f"): {path}",
                    flush=True,
                )
            return new_path

        try:
            saved_dict = other_params if other_params is not None else {}
            saved_dict.update(
                {
                    "model": model if save_model else None,
                    "state_dict": model.state_dict(),
                }
            )
            torch.save(saved_dict, new_path)
            if print_msg:
                best_list = sorted(self.model_cache.values(), reverse=self.descend)
                print(
                    f"[I] Model saved to {new_path}. Current best {self.k}: {best_list}",
                    flush=True,
                )
        except Exception:
            if print_msg:
                print(f"[E] Model failed to be saved to {new_path}", flush=True)
                traceback.print_exc()
        return new_path
+
+
def load_model(
    model,
    path="./checkpoint/model.pt",
    ignore_size_mismatch: bool = False,
    print_msg=True,
):
    """Load a checkpoint from ``path`` into ``model`` (in place).

    Accepts either a raw ``state_dict`` or a dict with a "state_dict" entry.
    Keys present in only one of (checkpoint, model) are reported and skipped.

    Args:
        model (PyTorch model): PyTorch model
        path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt".
        ignore_size_mismatch (bool, optional): Whether ignore tensor size mismatch. Defaults to False.
        print_msg (bool, optional): Control of message print. Defaults to True.
    """
    try:
        raw_data = torch.load(path, map_location=lambda storage, location: storage)
        if isinstance(raw_data, OrderedDict) and "state_dict" not in raw_data:
            ### checkpoint is a bare state_dict: OrderedDict
            state_dict = raw_data
        else:
            ### checkpoint is {"state_dict": ..., "model": ...}
            state_dict = raw_data["state_dict"]
        load_keys = set(state_dict.keys())
        model_keys = set(model.state_dict().keys())
        common_keys = load_keys & model_keys
        diff_keys = load_keys ^ model_keys
        extra_keys = load_keys - model_keys
        lack_keys = model_keys - load_keys
        cur_state_dict = model.state_dict()
        if ignore_size_mismatch:
            # drop tensors whose shapes disagree instead of letting
            # load_state_dict raise
            size_mismatch_keys = set(
                key
                for key in common_keys
                if cur_state_dict[key].size() != state_dict[key].size()
            )
            print(
                f"[W] {size_mismatch_keys} are ignored due to size mismatch", flush=True
            )
            common_keys = common_keys - size_mismatch_keys

        cur_state_dict.update({key: state_dict[key] for key in common_keys})
        if len(diff_keys) > 0:
            print(
                f"[W] Warning! Model is not the same as the checkpoint. not found keys {lack_keys}. extra unused keys {extra_keys}"
            )

        model.load_state_dict(cur_state_dict)
        if print_msg:
            print(f"[I] Model loaded from {path}")
    except Exception:
        # BUGFIX: print_exc() takes no exception argument; the original passed
        # one, raising a TypeError and masking the real load failure.
        traceback.print_exc()
        if print_msg:
            print(f"[E] Model failed to be loaded from {path}")
+
+
def count_parameters(model):
    """Number of trainable (requires_grad) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
+
+
def check_converge(trace, epsilon=0.002):
    """True when the last relative change in ``trace`` drops below ``epsilon``."""
    if len(trace) <= 1:
        # not enough history to measure a change
        return False
    last, prev = trace[-1], trace[-2]
    rel_change = np.abs(last - prev) / (np.abs(last) + 1e-8)
    return bool(rel_change < epsilon)
+
+
class ThresholdScheduler(object):
    """Interpolation between begin point and end point. step must be within two endpoints"""

    def __init__(self, step_beg, step_end, thres_beg, thres_end, mode="tanh"):
        # mode "linear": straight ramp; "tanh": saturating ramp that flattens
        # out as the step approaches step_end.
        assert mode in {
            "linear",
            "tanh",
        }, "Threshold scheduler only supports linear and tanh modes"
        self.mode = mode
        self.step_beg = step_beg
        self.step_end = step_end
        self.thres_beg = thres_beg
        self.thres_end = thres_end
        self.func = self.createFunc()

    def normalize(self, step, factor=2):
        # Map step from [step_beg, step_end] onto [0, factor]; factor=2 puts
        # step_end at tanh(2) ~ 0.96, i.e. near saturation for "tanh" mode.
        return (step - self.step_beg) / (self.step_end - self.step_beg) * factor

    def createFunc(self):
        if self.mode == "linear":
            # NOTE(review): __call__ feeds normalize(x) in [0, 2] into this
            # lambda, so at step_end the linear mode yields
            # thres_beg + 2*(thres_end - thres_beg), overshooting thres_end.
            # The linear mode likely needs factor=1 — confirm intent.
            return lambda x: (self.thres_end - self.thres_beg) * x + self.thres_beg
        elif self.mode == "tanh":
            # Tabulate the tanh ramp over the normalized step range and
            # interpolate between the samples.
            x = self.normalize(
                np.arange(self.step_beg, self.step_end + 1).astype(np.float32)
            )
            y = np.tanh(x) * (self.thres_end - self.thres_beg) + self.thres_beg
            return interpolate.interp1d(x, y)

    def __call__(self, x):
        # Evaluate the schedule at step x (must lie within [step_beg, step_end]).
        return self.func(self.normalize(x)).tolist()
+
+
class ThresholdScheduler_tf(object):
    """smooth increasing threshold with tensorflow model pruning scheduler"""

    def __init__(self, step_beg, step_end, thres_beg, thres_end):
        # Heavy dependencies imported lazily so the rest of this module works
        # without TensorFlow installed.
        import tensorflow as tf
        import tensorflow_model_optimization as tfmot

        gpus = tf.config.list_physical_devices("GPU")
        if gpus:
            # Restrict TensorFlow to only allocate 1GB of memory on the first GPU
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized
                print(e)
        self.step_beg = step_beg
        self.step_end = step_end
        self.thres_beg = thres_beg
        self.thres_end = thres_end
        # Normalize to an increasing ramp of width thres_range regardless of
        # direction; the direction is re-applied in __call__ via self.descend.
        if thres_beg < thres_end:
            self.thres_min = thres_beg
            self.thres_range = thres_end - thres_beg
            self.descend = False

        else:
            self.thres_min = thres_end
            self.thres_range = thres_beg - thres_end
            self.descend = True

        # PolynomialDecay sparsity schedule reused as a smooth 0 -> ~1 ramp
        # (kept just below 1, presumably because PolynomialDecay requires
        # final_sparsity < 1 — confirm against the tfmot version used).
        self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0,
            final_sparsity=0.9999999,
            begin_step=self.step_beg,
            end_step=self.step_end,
        )

    def __call__(self, x):
        # Clamp to the exact endpoint values outside the scheduled window.
        if x < self.step_beg:
            return self.thres_beg
        elif x > self.step_end:
            return self.thres_end
        # pruning_schedule(step) returns a pair; index [1] is taken as the
        # sparsity fraction in [0, 1) — TODO confirm against the tfmot API.
        res_norm = self.pruning_schedule(x)[1].numpy()
        if self.descend == False:
            res = res_norm * self.thres_range + self.thres_beg
        else:
            res = self.thres_beg - res_norm * self.thres_range

        # Snap to the exact end value once numerically converged.
        if np.abs(res - self.thres_end) <= 1e-6:
            res = self.thres_end
        return res
+
+
class ValueRegister(object):
    """Accumulates a running value with a user-supplied binary reduction.

    Example: ``ValueRegister(max)`` keeps the largest value seen so far.
    """

    def __init__(self, operator, name="", show=True):
        self.op = operator  # binary reduction: (new_value, cache) -> cache
        self.cache = None   # reduced value so far (None until first call)
        self.show = show
        self.name = name if len(name) > 0 else "value"

    def register_value(self, x):
        """Fold ``x`` into the cache and optionally report the result."""
        if self.cache is None:
            self.cache = x
        else:
            self.cache = self.op(x, self.cache)
        if self.show:
            print(f"Recorded {self.name} is {self.cache}")
+
+
class ValueTracer(object):
    """Records per-name {step: value} traces, serialized via torch.save/load."""

    def __init__(self, show=True):
        self.cache = {}  # name -> {step: value}
        self.show = show

    def add_value(self, name, value, step):
        """Record ``value`` for trace ``name`` at ``step``."""
        self.cache.setdefault(name, {})[step] = value
        if self.show:
            print(f"Recorded {name}: step = {step}, value = {value}")

    def get_trace_by_name(self, name):
        """Return the {step: value} dict for ``name`` (empty if unknown)."""
        return self.cache.get(name, {})

    def get_all_traces(self):
        """Return the full name -> trace mapping."""
        return self.cache

    def __len__(self):
        return len(self.cache)

    def get_num_trace(self):
        """Number of distinct trace names recorded."""
        return len(self.cache)

    def get_len_trace_by_name(self, name):
        """Number of recorded steps for trace ``name`` (0 if unknown)."""
        return len(self.get_trace_by_name(name))

    def dump_trace_to_file(self, name, file):
        """Serialize one trace with torch.save; warn if the name is unknown."""
        if name not in self.cache:
            print(f"[W] Trace name '{name}' not found in tracer")
            return
        torch.save(self.cache[name], file)
        print(f"[I] Trace {name} saved to {file}")

    def dump_all_traces_to_file(self, file):
        """Serialize the whole name -> trace mapping with torch.save."""
        torch.save(self.cache, file)
        print(f"[I] All traces saved to {file}")

    def load_all_traces_from_file(self, file):
        """Replace the cache with traces loaded via torch.load and return it."""
        self.cache = torch.load(file)
        return self.cache
+
+
class EMA(object):
    """Exponential moving average of named tensors: new = (1-mu)*x + mu*old."""

    def __init__(self, mu):
        super().__init__()
        self.mu = mu      # smoothing factor (weight on the previous average)
        self.shadow = {}  # name -> averaged tensor

    def register(self, name, val):
        """Initialize the average for ``name`` with a copy of ``val``."""
        self.shadow[name] = val.clone().data

    def __call__(self, name, x, mask=None):
        """Update and return the average for ``name`` given new value ``x``."""
        if name not in self.shadow:
            self.register(name, x)
            return x.data

        prev_avg = self.shadow[name]
        next_avg = (1 - self.mu) * x + self.mu * prev_avg
        if mask is not None:
            # NOTE(review): boolean-mask indexing returns a copy, so this
            # copy_ likely does not modify next_avg — confirm intent.
            next_avg[mask].copy_(prev_avg[mask])
        self.shadow[name] = next_avg.clone()
        return next_avg.data
+
+
class SWA(torch.nn.Module):
    """Stochastic Weight Averaging.

    # Paper
        title: Averaging Weights Leads to Wider Optima and Better Generalization
        link: https://arxiv.org/abs/1803.05407

    # Arguments
        start_epoch: integer, epoch when swa should start.
        lr_schedule: string, type of learning rate schedule.
        swa_lr: float, learning rate for swa.
        swa_lr2: float, upper bound of cyclic learning rate.
        swa_freq: integer, length of learning rate cycle.
        batch_size: integer, batch size (for batch norm with generator)
        verbose: integer, verbosity mode, 0 or 1.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        start_epoch: int,
        epochs: int,  # total epochs
        steps,  # total steps per epoch
        lr_schedule="manual",
        swa_lr="auto",
        swa_lr2="auto",
        swa_freq=1,
        batch_size=None,
        verbose=0,
    ):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        # user-facing start_epoch is 1-indexed; stored 0-indexed
        self.start_epoch = start_epoch - 1
        self.epochs = epochs
        self.steps = steps
        self.lr_schedule = lr_schedule
        self.swa_lr = swa_lr

        # if no user determined upper bound, make one based off of the lower bound
        # NOTE(review): the default swa_lr2 is "auto" (not None), so this
        # fallback only fires when None is passed explicitly — and then
        # 10 * swa_lr misbehaves if swa_lr is still the string "auto". Confirm.
        self.swa_lr2 = swa_lr2 if swa_lr2 is not None else 10 * swa_lr
        self.swa_freq = swa_freq
        self.batch_size = batch_size
        self.verbose = verbose

        if start_epoch < 2:
            raise ValueError('"swa_start" attribute cannot be lower than 2.')

        schedules = ["manual", "constant", "cyclic"]

        if self.lr_schedule not in schedules:
            raise ValueError(
                '"{}" is not a valid learning rate schedule'.format(self.lr_schedule)
            )

        if self.lr_schedule == "cyclic" and self.swa_freq < 2:
            raise ValueError('"swa_freq" must be higher than 1 for cyclic schedule.')

        if self.swa_lr == "auto" and self.swa_lr2 != "auto":
            raise ValueError(
                '"swa_lr2" cannot be manually set if "swa_lr" is automatic.'
            )

        if (
            self.lr_schedule == "cyclic"
            and self.swa_lr != "auto"
            and self.swa_lr2 != "auto"
            and self.swa_lr > self.swa_lr2
        ):
            raise ValueError('"swa_lr" must be lower than "swa_lr2".')

    def on_train_begin(self):
        """Validate configuration and resolve "auto" learning rates.

        Must be called once before the first epoch.
        """
        self.lr_record = []

        if self.start_epoch >= self.epochs - 1:
            raise ValueError('"swa_start" attribute must be lower than "epochs".')

        self.init_lr = self.optimizer.param_groups[0]["lr"]

        # automatic swa_lr
        if self.swa_lr == "auto":
            self.swa_lr = 0.1 * self.init_lr

        if self.init_lr < self.swa_lr:
            raise ValueError('"swa_lr" must be lower than rate set in optimizer.')

        # automatic swa_lr2 between initial lr and swa_lr
        if self.lr_schedule == "cyclic" and self.swa_lr2 == "auto":
            self.swa_lr2 = self.swa_lr + (self.init_lr - self.swa_lr) * 0.25

        self._check_batch_norm()

        if self.has_batch_norm and self.batch_size is None:
            raise ValueError(
                '"batch_size" needs to be set for models with batch normalization layers.'
            )

    def on_epoch_begin(self, epoch):
        """Per-epoch hook: refresh schedule flags, snapshot weights at the SWA
        start epoch, and set up the final batch-norm calibration epoch."""
        # input epoch is from 0 to epochs-1

        self.current_epoch = epoch
        self._scheduler(epoch)

        # constant schedule is updated epoch-wise
        if self.lr_schedule == "constant":
            self._update_lr(epoch)

        if self.is_swa_start_epoch:
            # snapshot current parameters as the initial SWA average
            # self.swa_weights = self.model.get_weights()
            self.swa_weights = {
                name: p.data.clone() for name, p in self.model.named_parameters()
            }

            if self.verbose > 0:
                print(
                    "\nEpoch %05d: starting stochastic weight averaging" % (epoch + 1)
                )

        if self.is_batch_norm_epoch:
            # last epoch: load the averaged weights and re-estimate BN stats
            # with a plain forward pass (lr is forced to 0 in _update_lr)
            self._set_swa_weights(epoch)

            if self.verbose > 0:
                print(
                    "\nEpoch %05d: reinitializing batch normalization layers"
                    % (epoch + 1)
                )

            self._reset_batch_norm()

            if self.verbose > 0:
                print(
                    "\nEpoch %05d: running forward pass to adjust batch normalization"
                    % (epoch + 1)
                )

    def on_batch_begin(self, batch):
        """Per-batch hook: cyclic lr update and BN momentum adjustment."""
        # update lr each batch for cyclic lr schedule
        if self.lr_schedule == "cyclic":
            self._update_lr(self.current_epoch, batch)

        if self.is_batch_norm_epoch:
            batch_size = self.batch_size
            # this is for tensorflow momentum, applied to the running stat
            # momentum = batch_size / (batch * batch_size + batch_size)

            # we need to convert it to torch momentum, applied to the batch stat
            momentum = 1 - batch_size / (batch * batch_size + batch_size)

            for layer in self.batch_norm_layers:
                layer.momentum = momentum

    def on_epoch_end(self, epoch):
        """Per-epoch hook: fold this epoch's weights into the running average."""
        if self.is_swa_start_epoch:
            self.swa_start_epoch = epoch

        if self.is_swa_epoch and not self.is_batch_norm_epoch:
            self.swa_weights = self._average_weights(epoch)

    def on_train_end(self):
        """Final hook: install averaged weights (or restore BN momentum)."""
        if not self.has_batch_norm:
            self._set_swa_weights(self.epochs)
        else:
            self._restore_batch_norm()

        ## TODO: what is meaning here?
        # for batch_lr in self.lr_record:
        #     self.model.history.history.setdefault("lr", []).append(batch_lr)

    def _scheduler(self, epoch):
        # refresh the per-epoch state flags used by the other hooks
        swa_epoch = epoch - self.start_epoch

        self.is_swa_epoch = epoch >= self.start_epoch and swa_epoch % self.swa_freq == 0
        self.is_swa_start_epoch = epoch == self.start_epoch
        self.is_batch_norm_epoch = epoch == self.epochs - 1 and self.has_batch_norm

    def _average_weights(self, epoch):
        """Running average of the stored SWA weights with the current model
        weights, keyed by parameter name."""
        # return [
        #     (swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w)
        #     / ((epoch - self.start_epoch) / self.swa_freq + 1)
        #     for swa_w, w in zip(self.swa_weights, self.model.get_weights())
        # ]
        out = {}
        with torch.no_grad():
            for name, w in self.model.named_parameters():
                swa_w = self.swa_weights[name]
                out[name] = (
                    swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w.data
                ) / ((epoch - self.start_epoch) / self.swa_freq + 1)
        return out

    def _update_lr(self, epoch, batch=None):
        # lr forced to 0 during the final BN-calibration epoch
        if self.is_batch_norm_epoch:
            lr = 0
            # K.set_value(self.model.optimizer.lr, lr)
            set_learning_rate(lr, self.optimizer)
        elif self.lr_schedule == "constant":
            lr = self._constant_schedule(epoch)
            # K.set_value(self.model.optimizer.lr, lr)
            set_learning_rate(lr, self.optimizer)
        elif self.lr_schedule == "cyclic":
            lr = self._cyclic_schedule(epoch, batch)
            # K.set_value(self.model.optimizer.lr, lr)
            set_learning_rate(lr, self.optimizer)
        self.lr_record.append(lr)

    def _constant_schedule(self, epoch):
        # Piecewise ramp: hold init_lr until half of start_epoch, then linearly
        # anneal to swa_lr by 90%, then hold swa_lr.
        t = epoch / self.start_epoch
        lr_ratio = self.swa_lr / self.init_lr
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        return self.init_lr * factor

    def _cyclic_schedule(self, epoch, batch):
        """Designed after Section 3.1 of Averaging Weights Leads to
        Wider Optima and Better Generalization(https://arxiv.org/abs/1803.05407)
        """
        # steps are mini-batches per epoch, equal to training_samples / batch_size
        steps = self.steps

        swa_epoch = (epoch - self.start_epoch) % self.swa_freq
        cycle_length = self.swa_freq * steps

        # batch 0 indexed, so need to add 1
        i = (swa_epoch * steps) + (batch + 1)
        if epoch >= self.start_epoch:
            t = (((i - 1) % cycle_length) + 1) / cycle_length
            return (1 - t) * self.swa_lr2 + t * self.swa_lr
        else:
            return self._constant_schedule(epoch)

    def _set_swa_weights(self, epoch):
        # copy the stored averages into the live model parameters
        # self.model.set_weights(self.swa_weights)
        for name, p in self.model.named_parameters():
            p.data.copy_(self.swa_weights[name])

        if self.verbose > 0:
            print(
                "\nEpoch %05d: final model weights set to stochastic weight average"
                % (epoch + 1)
            )

    def _check_batch_norm(self):
        # collect BN layers and their momenta so they can be reset/restored
        self.batch_norm_momentums = []
        self.batch_norm_layers = []
        self.has_batch_norm = False
        self.running_bn_epoch = False

        for layer in self.model.modules():
            if isinstance(layer, _BatchNorm):
                self.has_batch_norm = True
                self.batch_norm_momentums.append(layer.momentum)
                self.batch_norm_layers.append(layer)

        if self.verbose > 0 and self.has_batch_norm:
            print(
                "Model uses batch normalization. SWA will require last epoch "
                "to be a forward pass and will run with no learning rate"
            )

    def _reset_batch_norm(self):
        for layer in self.batch_norm_layers:
            # initialized moving mean and
            # moving var weights
            layer.reset_running_stats()

    def _restore_batch_norm(self):
        # undo the per-batch momentum overrides from on_batch_begin
        for layer, momentum in zip(self.batch_norm_layers, self.batch_norm_momentums):
            layer.momentum = momentum
+
+
def export_traces_to_csv(trace_file, csv_file, fieldnames=None):
    """Dump traces saved with torch.save into a CSV file.

    Each field becomes a column; rows are padded with "" where a trace is
    shorter than the longest one. Tensor values are unwrapped via .item().
    """
    traces = torch.load(trace_file)

    with open(csv_file, "w", newline="") as handle:
        if fieldnames is None:
            fieldnames = list(traces.keys())
        writer = csv.DictWriter(handle, fieldnames=fieldnames)

        writer.writeheader()
        n_rows = max(len(traces[field]) for field in fieldnames)

        for idx in range(n_rows):
            row = {}
            for field in fieldnames:
                trace = traces[field]
                cell = trace[idx] if idx < len(trace) else ""
                if isinstance(cell, torch.Tensor):
                    cell = cell.data.item()
                row[field] = cell
            writer.writerow(row)
+
+
def set_learning_rate(lr, optimizer):
    """Set ``lr`` on every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group["lr"] = lr
+
+
def get_learning_rate(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group["lr"]
+
+
def apply_weight_decay(W, decay_rate, learning_rate, mask=None):
    """In-place decoupled weight decay: W <- W - W * decay_rate * learning_rate.

    In ``mask``, True/1 marks fixed entries left untouched; False/0 marks
    trainable entries that are decayed.
    """
    shrink = decay_rate * learning_rate
    if mask is None:
        W -= W * shrink
    else:
        W[~mask] -= W[~mask] * shrink
+
+
def disable_bn(model: torch.nn.Module) -> None:
    """Put every batch-norm layer of ``model`` into eval mode (freezes running stats)."""
    bn_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )
    for module in model.modules():
        if isinstance(module, bn_types):
            module.eval()
+
+
def enable_bn(model: torch.nn.Module) -> None:
    """Put every batch-norm layer of ``model`` back into train mode."""
    bn_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )
    for module in model.modules():
        if isinstance(module, bn_types):
            module.train()
diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py
new file mode 100644
index 000000000..46b6f21e9
--- /dev/null
+++ b/src/chop/nn/optical/modules/__init__.py
@@ -0,0 +1,7 @@
from .morr_linear import AllPassMORRCirculantLinear
from .morr_conv2d import AllPassMORRCirculantConv2d

# Registry mapping a config key (e.g. "linear_morr") to the optical MORR
# layer class that replaces the corresponding standard torch layer —
# presumably consumed by the optical module-transform pass; verify usage.
optical_module_map = {
    "linear_morr": AllPassMORRCirculantLinear,
    "conv2d_morr": AllPassMORRCirculantConv2d,
}
\ No newline at end of file
diff --git a/src/chop/nn/optical/modules/base_layer.py b/src/chop/nn/optical/modules/base_layer.py
new file mode 100644
index 000000000..cb1fbe9dd
--- /dev/null
+++ b/src/chop/nn/optical/modules/base_layer.py
@@ -0,0 +1,71 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-06-08 18:55:05
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-06-08 18:55:05
+"""
+from typing import Any, Dict, Optional
+import torch
+from torch import nn
+from torch.types import Device
+
+__all__ = ["ONNBaseLayer"]
+
+
class ONNBaseLayer(nn.Module):
    """Common base class for optical neural-network (ONN) layers.

    Holds shared knobs (noise, crosstalk, bitwidths) and small utilities;
    concrete layers implement parameter construction and the forward pass.
    """

    def __init__(self, *args, device: Device = torch.device("cpu"), **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Device this layer's tensors live on; cuda or cpu, defaults to cpu.
        self.device = device

    def build_parameters(self) -> None:
        # Concrete layers allocate their trainable tensors here.
        raise NotImplementedError

    def reset_parameters(self) -> None:
        # Concrete layers (re-)initialize their trainable tensors here.
        raise NotImplementedError

    @classmethod
    def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module:
        # Concrete layers convert a standard torch layer into an ONN layer here.
        raise NotImplementedError

    def get_num_parameters(self) -> int:
        """Count trainable scalar parameters in this layer."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def enable_fast_forward(self) -> None:
        """Set the fast-forward flag (consumed by subclasses)."""
        self.fast_forward_flag = True

    def disable_fast_forward(self) -> None:
        """Clear the fast-forward flag (consumed by subclasses)."""
        self.fast_forward_flag = False

    def set_phase_variation(self, noise_std: float, random_state: Optional[int] = None) -> None:
        # NOTE: random_state is accepted for interface compatibility but unused here.
        self.phase_noise_std = noise_std

    def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
        # NOTE: random_state is accepted for interface compatibility but unused here.
        self.gamma_noise_std = noise_std

    def set_crosstalk_factor(self, crosstalk_factor: float) -> None:
        self.crosstalk_factor = crosstalk_factor

    def set_weight_bitwidth(self, w_bit: int) -> None:
        self.w_bit = w_bit

    def set_input_bitwidth(self, in_bit: int) -> None:
        self.in_bit = in_bit

    def load_parameters(self, param_dict: Dict[str, Any]) -> None:
        """Copy tensors from ``param_dict`` into the same-named attributes.

        :param param_dict: mapping of ``{param_name: param_tensor}``.
        """
        for param_name, tensor in param_dict.items():
            getattr(self, param_name).data.copy_(tensor)

    def switch_mode_to(self, mode: str) -> None:
        self.mode = mode

    def forward(self, x):
        raise NotImplementedError

    def extra_repr(self) -> str:
        return ""
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
new file mode 100644
index 000000000..83f0bf1fe
--- /dev/null
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -0,0 +1,458 @@
+"""
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2021-01-27 01:08:44
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2021-07-18 00:40:18
+"""
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.fft
+from ..functional import im2col_2d, toeplitz
+from ..functional import logger
+from ..functional import morr_uniform_
+from ..functional import input_quantize_fn, weight_quantize_fn
+from torch import Tensor, nn
+from torch.nn import Parameter, init
+from torch.nn.modules.utils import _pair
+from torch.types import Device, _size
+from ..functional import MORRConfig_20um_MQ
+from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+
+
+from .base_layer import ONNBaseLayer
+
+__all__ = ["AllPassMORRCirculantConv2d"]
+
+
class AllPassMORRCirculantConv2d(ONNBaseLayer):
    """
    All-pass MORR Conv2d layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors.
    J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators"
    https://doi.org/10.23919/DATE51398.2021.9474147

    The flattened weight matrix is partitioned into a grid of
    [grid_dim_y, grid_dim_x] circulant blocks of size miniblock x miniblock;
    only the first column of each block (shape [p, q, k]) is stored.
    """

    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
        "miniblock",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    _in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    transposed: bool
    output_padding: Tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]
    miniblock: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size,
        stride: _size = 1,
        padding: _size = 0,
        dilation: _size = 1,
        groups: int = 1,
        bias: bool = True,
        miniblock: int = 4,
        ### morr parameter
        MORRConfig=MORRConfig_20um_MQ,
        morr_init: bool = True,  # whether to use initialization method customized for MORR
        ### trainable MORR nonlinearity
        trainable_morr_bias: bool = False,
        trainable_morr_scale: bool = False,
        device: Device = torch.device("cuda"),
    ) -> None:
        super(AllPassMORRCirculantConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        assert groups == 1, f"Currently group convolution is not supported, but got group: {groups}"
        # Flattened input dimension of the im2col matrix; padded up to a
        # multiple of miniblock so it tiles exactly into circulant blocks.
        self.in_channels_flat = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        self.grid_dim_x = int(np.ceil(self.in_channels_flat / miniblock))
        self.grid_dim_y = int(np.ceil(self.out_channels / miniblock))
        self.in_channels_pad = self.grid_dim_x * miniblock
        self.out_channels_pad = self.grid_dim_y * miniblock
        self.miniblock = miniblock

        # Modulator drive parameters (volts); gamma converts voltage to phase.
        self.v_max = 10.8
        self.v_pi = 4.36
        self.gamma = np.pi / self.v_pi ** 2
        self.w_bit = 32
        self.in_bit = 32
        self.MORRConfig = MORRConfig
        self.morr_init = morr_init
        self.mrr_a = MORRConfig.attenuation_factor
        self.mrr_r = MORRConfig.coupling_factor
        self.trainable_morr_bias = trainable_morr_bias
        self.trainable_morr_scale = trainable_morr_scale
        self.device = device

        ### calculate FWHM (rad) of the MORR resonance from the ring geometry
        self.morr_fwhm = (
            -4
            * np.pi ** 2
            * MORRConfig.radius
            * MORRConfig.effective_index
            * (
                1 / MORRConfig.resonance_wavelength
                - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2)
            )
        )

        ### allocate parameters
        self.weight = None
        self.x_zero_pad = None
        self.morr_output_scale = None  ## learnable balancing factors implemented by MRRs
        self.morr_input_bias = None  ## round-trip phase shift bias within MORR
        self.morr_input_scale = None  ## scaling factor for the round-trip phase shift within MORR
        self.morr_gain = (
            100 / (self.in_channels_flat // self.miniblock)
        ) ** 0.5  ## set this TIA gain such that output variance is around 1
        ### build trainable parameters
        self.build_parameters()

        ### quantization tool
        self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device)
        self.weight_quantizer = weight_quantize_fn(
            self.w_bit, alg="dorefa_pos"
        )  ## [0-1] positive only, maintain the original scale
        self.morr_output_scale_quantizer = weight_quantize_fn(
            self.w_bit, alg="dorefa_sym"
        )  ## [-1,1] full-range

        # Analytic round-trip-phase -> transmission curve of the ring.
        self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func(
            a=self.mrr_a, r=self.mrr_r, intensity=True
        )

        ### default set to slow forward
        self.disable_fast_forward()
        ### default set no gamma noise
        self.set_gamma_noise(0)
        ### default set no crosstalk
        self.disable_crosstalk()
        ### default set no phase variation
        self.disable_phase_variation()

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels).to(self.device))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters(morr_init=morr_init)

        # support fine-grained structured pruning for MORRs
        self.finegrain_drop_mask = None

    def build_parameters(self) -> None:
        """Allocate the circulant weight columns, balancing factors, and the
        optional trainable MORR bias/scale tensors."""
        ### MORR weights
        self.weight = Parameter(
            torch.ones(
                self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float
            )
        )
        ### learnable balancing factor achieved by MRRs (morr_output_scale)
        ### We use a single scaling factor for each block
        self.morr_output_scale = Parameter(torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device))
        if self.trainable_morr_bias:
            ### initialize with the finest-granularity, i.e., per mini-block
            self.morr_input_bias = Parameter(
                torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
            )
        if self.trainable_morr_scale:
            ### initialize with the finest-granularity, i.e., per mini-block
            self.morr_input_scale = Parameter(
                torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
            )

    def reset_parameters(self, morr_init: bool = False) -> None:
        """Initialize parameters; with ``morr_init`` the distributions are
        matched to the MORR transmission nonlinearity."""
        if morr_init:
            ### nonlinear curve aware initialization
            morr_uniform_(
                self.weight,
                MORRConfig=self.MORRConfig,
                n_op=self.miniblock,
                biased=self.w_bit >= 16,
                gain=2 if self.in_bit < 16 else 1,
            )
            self.sigma_weight = self.weight.data.std().item()
            self.weight_quant_gain = None
            ### output distribution aware initialization to output scaling factor
            t1 = mrr_roundtrip_phase_to_tr_fused(
                torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
            )
            t2 = mrr_roundtrip_phase_to_tr_fused(
                torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
            )
            g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item()  ## 0~2.4 FWHM slope as a linear approximation

            self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
            self.out_scale_quant_gain = None
            init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)

        else:
            # NOTE(review): morr_output_scale is 1-D here, but
            # nn.init.kaiming_normal_ requires >=2-D tensors and would raise.
            # This branch appears unexercised (morr_init defaults to True) —
            # confirm before relying on it.
            nn.init.kaiming_normal_(self.weight)
            nn.init.kaiming_normal_(self.morr_output_scale)
            self.sigma_weight = self.weight.data.std().item()
            self.weight_quant_gain = None
            self.sigma_out_scale = self.morr_output_scale.data.std().item()
            self.out_scale_quant_gain = None

        if self.morr_input_bias is not None:
            init.zeros_(self.morr_input_bias.data)
        if self.morr_input_scale is not None:
            init.zeros_(self.morr_input_scale.data)

        if self.bias is not None:
            # uniform_(b, 0, 0) zero-fills the bias
            init.uniform_(self.bias, 0, 0)

    def sync_parameters(self, src: str = "weight") -> None:
        """
        description: synchronize all parameters from the source parameters
        """

        raise NotImplementedError

    def build_weight(self) -> Tensor:
        """Return the (optionally quantized and pruned) non-negative weight
        columns used to build the circulant blocks."""
        if self.w_bit < 16:
            ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016)
            weight = self.weight_quantizer(self.weight)
        else:
            weight = self.weight.abs()  ## have to be all positive
        if self.finegrain_drop_mask is not None:
            weight = weight.mul(self.finegrain_drop_mask.float())

        return weight

    def enable_fast_forward(self) -> None:
        self.fast_forward_flag = True

    def disable_fast_forward(self) -> None:
        self.fast_forward_flag = False

    def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
        # NOTE(review): random_state is accepted but currently unused.
        self.gamma_noise_std = noise_std

    def load_parameters(self, param_dict) -> None:
        """
        description: update parameters based on this parameter dictionary\\
        param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...}
        """
        for name, param in param_dict.items():
            getattr(self, name).data.copy_(param)

    def set_weight_bitwidth(self, w_bit: int) -> None:
        # Keep both weight-side quantizers in sync with the new bitwidth.
        self.w_bit = w_bit
        self.weight_quantizer.set_bitwidth(w_bit)
        self.morr_output_scale_quantizer.set_bitwidth(w_bit)

    def set_input_bitwidth(self, in_bit: int) -> None:
        self.in_bit = in_bit
        self.input_quantizer.set_bitwidth(in_bit)

    def input_modulator(self, x: Tensor) -> Tensor:
        ### voltage to power, which is proportional to the phase shift
        return x * x

    def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None:
        ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned.
        ### See SqueezeLight paper
        ### drop-perc is the pruning percentage.
        assert 0 <= coupling_factor <= 1, logger.error(
            f"Coupling factor must in [0,1], but got {coupling_factor}"
        )

        self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor

    def enable_crosstalk(self) -> None:
        self.enable_thermal_crosstalk = True

    def disable_crosstalk(self) -> None:
        self.enable_thermal_crosstalk = False

    def set_phase_variation(self, phase_noise_std: float = 0) -> None:
        self.phase_noise_std = phase_noise_std

    def enable_phase_variation(self) -> None:
        self.enable_phase_noise = True

    def disable_phase_variation(self) -> None:
        self.enable_phase_noise = False

    def enable_trainable_morr_scale(self) -> None:
        self.trainable_morr_scale = True

    def disable_trainable_morr_scale(self) -> None:
        self.trainable_morr_scale = False

    def enable_trainable_morr_bias(self) -> None:
        self.trainable_morr_bias = True

    def disable_trainable_morr_bias(self) -> None:
        self.trainable_morr_bias = False

    @property
    def morr_scale(self) -> Tensor:
        # sigmoid(.) + 0.2 keeps the scale strictly positive (~[0.2, 1.2]).
        # NOTE(review): assumes morr_input_scale was allocated
        # (trainable_morr_scale=True at construction) — otherwise None.unsqueeze raises.
        return torch.sigmoid(self.morr_input_scale.unsqueeze(0).unsqueeze(-1)) + 0.2

    @property
    def morr_bias(self) -> Tensor:
        # tanh bounds the learnable detuning bias to +/- one FWHM.
        # NOTE(review): assumes morr_input_bias was allocated
        # (trainable_morr_bias=True at construction) — otherwise None.unsqueeze raises.
        return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1))

    def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
        """
        @description: propagate through the analytically calculated transfer matrix of morr.
        @param weight {torch.Tensor} first column vectors in the block-circulant matrix
        @param x {torch.Tensor} input
        @return: y {torch.Tensor} output of MORR array
        """
        ### weights: [p, q, k]
        ### x: [ks*ks*inc, h_out*w_out*bs]

        x = x.t()  # [h_out*w_out*bs, ks*ks*inc]
        x = x.view(x.size(0), self.grid_dim_x, self.miniblock)  # [h_out*w_out*bs, q, k]

        ### injecting crosstalk into weights is more efficient
        if self.enable_thermal_crosstalk and self.crosstalk_factor > 1:
            weight = weight * self.crosstalk_factor

        ### construct block-circulant matrix
        weight = toeplitz(weight).unsqueeze(0)  # [1, p, q, k, k]
        x = x.unsqueeze(1).unsqueeze(-1)  # [h*w*bs, 1, q, k, 1]
        x = weight.matmul(x).squeeze(-1)  # [h*w*bs, p, q, k]

        if self.enable_phase_noise and self.phase_noise_std > 1e-5:
            # additive Gaussian phase noise on the accumulated phase shift
            x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std)  # [h*w*bs, p, q, k]

        ### input scaling, learnable MORR nonlinearity
        if self.trainable_morr_scale:
            x = x * self.morr_scale  # [h*w*bs, p, q, k]
        ### input biasing, learnable MORR nonlinearity
        if self.trainable_morr_bias:
            x = x - self.morr_bias

        ### Use theoretical transmission function for trainable MORR nonlinearity
        ### x is the phase detuning, x=0 means on-resonance
        ### x: [h_out*w_out*bs, p, q, k]
        x = self.mrr_roundtrip_phase_to_tr(x)

        ### output scaling or learnable balancing factors
        if self.w_bit < 16:
            morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale)
            if self.out_scale_quant_gain is None:
                # cache rescaling gain so quantized scales keep the init std
                self.out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item()
            morr_output_scale = morr_output_scale.mul(
                self.out_scale_quant_gain
            )  ### gain factor from Tanh used in quantization
        else:
            morr_output_scale = self.morr_output_scale

        scale = morr_output_scale[:-1]
        scale_pad = morr_output_scale[-1:]

        ### differential rails: pair each block with a negated counterpart
        if self.grid_dim_x % 2 == 0:
            # even blocks
            scale = torch.cat([scale, -scale], dim=0)
        else:
            # odd blocks
            if self.grid_dim_x > 1:
                scale = torch.cat([morr_output_scale, -scale], dim=0)
            else:
                scale = scale_pad
        scale = scale.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # [1, 1, 1, q]

        x = scale.matmul(x)  # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k]
        x = x.view(x.size(0), -1).t()  # [p*k, h_out*w_out*bs]
        if self.out_channels_pad > self.out_channels:
            # drop the zero-padded output rows
            x = x[: self.out_channels, :]  # [outc, h_out*w_out*bs]
        return x

    def morr_conv2d(self, X: Tensor, W: Tensor) -> Tensor:
        """Convolution via im2col + circulant matmul through the MORR array."""
        ### W : [p, q, k]
        n_x = X.size(0)

        _, X_col, h_out, w_out = im2col_2d(
            None,
            X,
            stride=self.stride[0],
            padding=self.padding[0],
            w_size=(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]),
        )
        ## zero-padding X_col so rows tile into miniblock-sized groups
        if self.in_channels_pad > self.in_channels_flat:
            if self.x_zero_pad is None or self.x_zero_pad.size(1) != X_col.size(1):
                # cache the zero pad; rebuilt only when the column count changes
                self.x_zero_pad = torch.zeros(
                    self.in_channels_pad - self.in_channels_flat,
                    X_col.size(1),
                    dtype=torch.float32,
                    device=self.device,
                )

            X_col = torch.cat([X_col, self.x_zero_pad], dim=0)
        # matmul
        out = self.propagate_morr(W, X_col)  # [outc, w_out]
        out = out.view(self.out_channels, h_out, w_out, n_x)
        out = out.permute(3, 0, 1, 2).contiguous()

        return out

    def get_finegrain_drop_mask(self, topk: int) -> Tensor:
        """Build and cache a mask keeping only the top-k largest weights per
        circulant column; the rest are marked for pruning."""
        if self.w_bit < 16:
            weight = self.weight_quantizer(self.weight.data)  # [p, q, k]
        else:
            weight = self.weight.data.abs()
        indices = weight.argsort(dim=-1)
        mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device)

        drop_indices = indices[:, :, 0:-topk]
        mask.scatter_(2, drop_indices, 0)
        self.finegrain_drop_mask = mask
        return mask

    def apply_finegrain_drop_mask(self, mask: Tensor) -> None:
        # In quantized mode fill with a large negative so the quantizer maps
        # the pruned entries to zero; otherwise zero-fill directly.
        if self.w_bit < 16:
            self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000)
        else:
            self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0)

    def get_output_dim(self, img_height: int, img_width: int) -> Tuple[int, int]:
        """
        get the output features size (dilation is not accounted for)
        """
        h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1
        w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] + 1
        return (int(h_out), int(w_out))

    def forward(self, x: Tensor) -> Tensor:
        """Quantize input, modulate voltage->power, run the MORR conv, add bias."""
        if self.in_bit < 16:
            x = self.input_quantizer(x)
        weight = self.build_weight()
        x = self.input_modulator(x)
        x = self.morr_conv2d(x, weight)

        if self.bias is not None:
            x = x + self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        return x
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
new file mode 100644
index 000000000..4846fb763
--- /dev/null
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -0,0 +1,442 @@
+"""
+Description:
+Author: Jiaqi Gu (jqgu@utexas.edu)
+Date: 2022-04-18 14:19:57
+LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+LastEditTime: 2022-04-18 16:21:37
+"""
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.fft
+from ..functional import toeplitz
+from ..functional import logger
+from ..functional import morr_uniform_
+from ..functional import input_quantize_fn, weight_quantize_fn
+from torch import Tensor
+from torch.nn import Parameter, init
+from torch.types import Device
+from ..functional import MORRConfig_20um_MQ
+from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+
+from .base_layer import ONNBaseLayer
+
+__all__ = ["AllPassMORRCirculantLinear"]
+
+
class AllPassMORRCirculantLinear(ONNBaseLayer):
    """
    All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors.
    J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators"
    https://doi.org/10.23919/DATE51398.2021.9474147

    Recognized ``config`` keys (all optional):
        miniblock (int, default 4): circulant block size k.
        MORRConfig (default MORRConfig_20um_MQ): device parameter set.
        morr_init (bool, default True): MORR-nonlinearity-aware initialization.
        trainable_morr_bias (bool, default False): learnable phase bias.
        trainable_morr_scale (bool, default False): learnable phase scale.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    miniblock: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        config=None,
        device: Device = torch.device("cuda"),
    ) -> None:
        super(AllPassMORRCirculantLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Bug fix: config used to default to None and then crash on
        # config.get(...); treat a missing config as "all defaults".
        if config is None:
            config = {}

        miniblock_size = config.get("miniblock", 4)
        self.miniblock = miniblock_size
        # Pad both dimensions up to multiples of the miniblock size so the
        # weight matrix tiles exactly into circulant blocks.
        self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size))
        self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size))
        self.in_features_pad = self.grid_dim_x * miniblock_size
        self.out_features_pad = self.grid_dim_y * miniblock_size

        # Modulator drive parameters (volts); gamma converts voltage to phase.
        self.v_max = 10.8
        self.v_pi = 4.36
        self.gamma = np.pi / self.v_pi ** 2
        self.w_bit = 32
        self.in_bit = 32

        morr_config = config.get("MORRConfig", MORRConfig_20um_MQ)
        # Bug fix: the defaults below were previously MORRConfig_20um_MQ (a
        # class object, always truthy), which silently enabled the trainable
        # MORR bias/scale branches whenever the keys were omitted. The
        # intended defaults are morr_init=True, trainable_*=False.
        morr_init_val = config.get("morr_init", True)
        self.MORRConfig = morr_config
        self.morr_init = morr_init_val
        self.mrr_a = morr_config.attenuation_factor
        self.mrr_r = morr_config.coupling_factor
        self.trainable_morr_bias = config.get("trainable_morr_bias", False)
        self.trainable_morr_scale = config.get("trainable_morr_scale", False)
        self.device = device
        ### calculate FWHM (rad) of the MORR resonance from the ring geometry
        self.morr_fwhm = (
            -4
            * np.pi ** 2
            * morr_config.radius
            * morr_config.effective_index
            * (
                1 / morr_config.resonance_wavelength
                - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2)
            )
        )

        ### allocate parameters
        self.weight = None
        self.x_zero_pad = None
        self.morr_output_scale = None  ## learnable balancing factors implemented by MRRs
        self.morr_input_bias = None  ## round-trip phase shift bias within MORR
        self.morr_input_scale = None  ## scaling factor for the round-trip phase shift within MORR
        # NOTE(review): in_features // miniblock is 0 when
        # in_features < miniblock, which would divide by zero here — confirm
        # callers never construct such a layer.
        self.morr_gain = (
            100 / (self.in_features // self.miniblock)
        ) ** 0.5  ## TIA gain, calculated such that output variance is around 1
        ### build trainable parameters
        self.build_parameters()

        ### quantization tool
        self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device)
        self.weight_quantizer = weight_quantize_fn(
            self.w_bit, alg="dorefa_pos"
        )  ## [0-1] positive only, maintain the original scale
        self.morr_output_scale_quantizer = weight_quantize_fn(
            self.w_bit, alg="dorefa_sym"
        )  ## [-1,1] full-range

        # Analytic round-trip-phase -> transmission curve of the ring.
        self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func(
            a=self.mrr_a, r=self.mrr_r, intensity=True
        )

        ### default set to slow forward
        self.disable_fast_forward()
        ### default set no gamma noise
        self.set_gamma_noise(0)
        ### default set no crosstalk
        self.disable_crosstalk()
        ### default set no phase variation
        self.disable_phase_variation()

        if bias:
            self.bias = Parameter(torch.Tensor(out_features).to(self.device))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters(morr_init=morr_init_val)
        # support fine-grained structured pruning for MORRs
        self.finegrain_drop_mask = None

    def build_parameters(self) -> None:
        """Allocate the circulant weight columns, balancing factors, and the
        optional trainable MORR bias/scale tensors."""
        self.weight = Parameter(
            torch.ones(
                self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float
            )
        )
        ### Learnable balancing factor (morr_output_scale)
        ### We use a single scaling factor for each block
        self.morr_output_scale = Parameter(
            torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device)
        )
        if self.trainable_morr_bias:
            ### initialize with the finest-granularity, i.e., per mini-block
            self.morr_input_bias = Parameter(
                torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
            )
        if self.trainable_morr_scale:
            ### initialize with the finest-granularity, i.e., per mini-block
            self.morr_input_scale = Parameter(
                torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
            )

    def reset_parameters(self, morr_init: bool = False) -> None:
        """Initialize parameters; with ``morr_init`` the distributions are
        matched to the MORR transmission nonlinearity."""
        ### nonlinear curve aware initialization
        if morr_init:
            ## initialize weight
            morr_uniform_(
                self.weight,
                MORRConfig=self.MORRConfig,
                n_op=self.miniblock,
                biased=self.w_bit >= 16,
                gain=2 if self.in_bit < 16 else 1,
            )  # quantization needs zero-center
            self.sigma_weight = self.weight.data.std().item()
            self.weight_quant_gain = None

            ## output distribution aware initialization to output scaling factor
            t1 = mrr_roundtrip_phase_to_tr_fused(
                torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
            )
            t2 = mrr_roundtrip_phase_to_tr_fused(
                torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
            )
            g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item()  ## 0~2.4 FWHM slope as a linear approximation

            self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
            self.out_scale_quant_gain = None
            init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
        else:
            init.kaiming_normal_(self.weight.data)
            init.kaiming_normal_(self.morr_output_scale.data)
            self.sigma_weight = self.weight.data.std().item()
            self.weight_quant_gain = None
            self.sigma_out_scale = self.morr_output_scale.data.std().item()
            self.out_scale_quant_gain = None

        if self.morr_input_bias is not None:
            self.morr_input_bias.data.zero_()
        if self.morr_input_scale is not None:
            ### after sigmoid, it corresponds to ~1 scale
            init.normal_(self.morr_input_scale.data, 2, 0.1)

        if self.bias is not None:
            # uniform_(b, 0, 0) zero-fills the bias
            init.uniform_(self.bias, 0, 0)

    def sync_parameters(self, src: str = "weight") -> None:
        """
        description: synchronize all parameters from the source parameters
        """

        raise NotImplementedError

    def build_weight(self) -> Tensor:
        """Return ``(weight, morr_output_scale)`` ready for propagation:
        quantized/rescaled weight columns and the differential-rail balancing
        factors reshaped to [1, 1, 1, q]."""
        if self.w_bit < 16:
            ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016)
            weight = self.weight_quantizer(self.weight)

            ## rescale weights after quantization can maintain the initialization distribution
            if self.weight_quant_gain is None:
                self.weight_quant_gain = self.sigma_weight / weight.data.std()
            if self.trainable_morr_scale:
                morr_scale = self.morr_scale * self.weight_quant_gain
            else:
                morr_scale = self.weight_quant_gain
            weight = weight.mul(morr_scale)  ### gain factor from Tanh used in quantization

            ### quantize learnable balancing factor
            morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale)
        else:
            weight = self.weight.abs()  # positive only
            morr_output_scale = self.morr_output_scale - self.morr_output_scale.data.mean()

        if self.finegrain_drop_mask is not None:
            weight = weight.mul(self.finegrain_drop_mask.float())

        ## differential balancing factor concatenation
        scale = morr_output_scale[..., :-1, :]
        scale_pad = morr_output_scale[..., -1:, :]
        if self.grid_dim_x % 2 == 0:
            # even blocks
            scale = torch.cat([scale, -scale], dim=2)  # [1, 1, q, 1]
        else:
            # odd blocks
            if self.grid_dim_x > 1:
                scale = torch.cat([morr_output_scale, -scale], dim=2)  # [1, 1, q, 1]
            else:
                scale = scale_pad  # [1, 1, q, 1]
        morr_output_scale = scale.squeeze(-1).unsqueeze(0)  # [1 ,1, 1, q]

        return weight, morr_output_scale

    def enable_fast_forward(self) -> None:
        self.fast_forward_flag = True

    def disable_fast_forward(self) -> None:
        self.fast_forward_flag = False

    def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
        # NOTE(review): random_state is accepted but currently unused.
        self.gamma_noise_std = noise_std

    def load_parameters(self, param_dict) -> None:
        """
        description: update parameters based on this parameter dictionary\\
        param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...}
        """
        for name, param in param_dict.items():
            getattr(self, name).data.copy_(param)

    def set_weight_bitwidth(self, w_bit: int) -> None:
        # Keep both weight-side quantizers in sync with the new bitwidth.
        self.w_bit = w_bit
        self.weight_quantizer.set_bitwidth(w_bit)
        self.morr_output_scale_quantizer.set_bitwidth(w_bit)

    def set_input_bitwidth(self, in_bit: int) -> None:
        self.in_bit = in_bit
        self.input_quantizer.set_bitwidth(in_bit)

    def input_modulator(self, x: Tensor) -> Tensor:
        ### voltage to power, which is proportional to the phase shift
        return x * x

    def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None:
        ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned.
        ### drop-perc is the pruning percentage.
        assert 0 <= coupling_factor <= 1, logger.error(
            f"Coupling factor must in [0,1], but got {coupling_factor}"
        )

        self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor

    def enable_crosstalk(self) -> None:
        self.enable_thermal_crosstalk = True

    def disable_crosstalk(self) -> None:
        self.enable_thermal_crosstalk = False

    def set_phase_variation(self, phase_noise_std: float = 0) -> None:
        self.phase_noise_std = phase_noise_std

    def enable_phase_variation(self) -> None:
        self.enable_phase_noise = True

    def disable_phase_variation(self) -> None:
        self.enable_phase_noise = False

    def enable_trainable_morr_scale(self) -> None:
        self.trainable_morr_scale = True

    def disable_trainable_morr_scale(self) -> None:
        self.trainable_morr_scale = False

    def enable_trainable_morr_bias(self) -> None:
        self.trainable_morr_bias = True

    def disable_trainable_morr_bias(self) -> None:
        self.trainable_morr_bias = False

    @property
    def morr_bias(self) -> Tensor:
        """Learnable phase-detuning bias, bounded to +/- one FWHM by tanh;
        None when the bias parameter was not allocated."""
        if self.morr_input_bias is None:
            return None
        return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1))

    @property
    def morr_scale(self) -> Tensor:
        """Learnable phase scale, sigmoid(.) + 0.2 keeps it strictly positive
        (~[0.2, 1.2]); None when the scale parameter was not allocated."""
        if self.morr_input_scale is None:
            return None
        return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2  # [p, q, 1]

    def propagate_morr(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor:
        """
        @description: propagate through the analytically calculated transfer matrix of the MORR array, implemented as fast circulant matmul.
        @param weight {torch.Tensor} first column vectors of the block-circulant matrix, [p, q, k]
        @param x {torch.Tensor} modulated input, [bs, q, k]
        @param morr_output_scale {torch.Tensor} learnable balancing factors, [1, 1, 1, q]
        @return: y {torch.Tensor} output of the MORR array, [bs, p*k]
        """
        ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable
        ## build circulant weight matrix
        # crosstalk on the weights are much cheaper to compute than on the phase shift
        if self.enable_thermal_crosstalk and self.crosstalk_factor > 1:
            weight = weight * self.crosstalk_factor
        weight = toeplitz(weight).unsqueeze(0)  # [1, p, q, k, k]
        x = x.unsqueeze(1).unsqueeze(-1)  # [bs, 1, q, k, 1]
        x = weight.matmul(x).squeeze(-1)  # [bs, p, q, k]

        if self.enable_phase_noise and self.phase_noise_std > 1e-5:
            # additive Gaussian phase noise on the accumulated phase shift
            x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std)

        ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0]
        if self.trainable_morr_bias:
            x = x - self.morr_bias

        ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21]
        ### x is the phase detuning, x=0 means on-resonance
        ### phase: [bs, p, q, k]
        x = self.mrr_roundtrip_phase_to_tr(x)  # 3x faster than autograd

        ## balancing-factor quantization and differential-rail concatenation
        ## already happened in build_weight(); here it reduces to a dot-product.
        x = morr_output_scale.matmul(x)  # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k]
        x = x.flatten(1)  # [bs, p*k]
        return x

    def get_finegrain_drop_mask(self, topk: int) -> Tensor:
        """Build and cache a mask keeping only the top-k largest weights per
        circulant column; the rest are marked for pruning."""
        if self.w_bit < 16:
            weight = self.weight_quantizer(self.weight.data)  # [p, q, k]
        else:
            weight = self.weight.data.abs()
        indices = weight.argsort(dim=-1)
        mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device)

        drop_indices = indices[:, :, 0:-topk]
        mask.scatter_(2, drop_indices, 0)
        self.finegrain_drop_mask = mask
        return mask

    def apply_finegrain_drop_mask(self, mask: Tensor) -> None:
        # In quantized mode fill with a large negative so the quantizer maps
        # the pruned entries to zero; otherwise zero-fill directly.
        if self.w_bit < 16:
            self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000)
        else:
            self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0)

    def forward(self, x: Tensor) -> Tensor:
        """Quantize input, zero-pad to the block grid, modulate, propagate
        through the MORR array, crop padding and add bias."""
        assert (
            x.size(-1) == self.in_features
        ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))"
        if self.in_bit < 16:
            x = self.input_quantizer(x)

        weight, morr_output_scale = self.build_weight()
        if self.in_features_pad > self.in_features:
            if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0):
                # cache the zero pad; rebuilt only when the batch size changes
                self.x_zero_pad = torch.zeros(
                    x.size(0), self.in_features_pad - self.in_features, device=x.device, dtype=x.dtype
                )
            x = torch.cat([x, self.x_zero_pad], dim=1)

        x = x.view(-1, self.grid_dim_x, self.miniblock)

        ### modulation
        ### x: [bs, q, k] -> [bs, q, k]
        x = self.input_modulator(x)

        ### propagate through morr array
        ### x: [bs, q, k] -> [bs, p*k]
        x = self.propagate_morr(weight, x, morr_output_scale)

        if self.out_features < self.out_features_pad:
            # crop zero-padded output columns
            x = x[..., : self.out_features]
        if self.bias is not None:
            x = x + self.bias.unsqueeze(0)

        return x
diff --git a/src/chop/passes/module/module_transform_helper.py b/src/chop/passes/module/module_transform_helper.py
new file mode 100644
index 000000000..b644b087a
--- /dev/null
+++ b/src/chop/passes/module/module_transform_helper.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from chop.nn.optical.modules import optical_module_map
+from chop.passes.module.module_modify_helper import get_module_by_name, set_module_by_name
+
def replace_by_name_optical(network, module_name: str, new_module):
    """Swap ``module_name`` in ``network`` for ``new_module`` after
    copying the original module's weights into it.

    Returns the network with the replacement installed.
    """
    old = get_module_by_name(network, module_name)
    return set_module_by_name(
        network, module_name, weight_replacement_optical(old, new_module)
    )
+
+
def weight_replacement_optical(x, y):
    """Initialize an AllPassMORRCirculantLinear ``y`` from a standard
    ``nn.Linear`` ``x``.

    Each [1 x miniblock] row-slice of the (zero-padded) dense weight is
    compressed to a single scalar by averaging, producing the
    [grid_dim_y, grid_dim_x, miniblock] tensor the optical layer expects.
    Only weights are copied; the bias is left untouched.
    """
    W = x.weight.data  # shape: (out_features, in_features)

    out_features_pad = y.out_features_pad  # padded out_features in y
    in_features_pad = y.in_features_pad  # padded in_features in y
    miniblock = y.miniblock
    grid_dim_y = y.grid_dim_y
    grid_dim_x = y.grid_dim_x

    with torch.no_grad():
        # Zero-pad so both dims are multiples of the miniblock size.
        W_padded = W.new_zeros((out_features_pad, in_features_pad))
        W_padded[: W.size(0), : W.size(1)] = W

        # new_weight[p, q, k] =
        #     mean(W_padded[p*miniblock + k, q*miniblock:(q+1)*miniblock])
        # Vectorized replacement of the former triple Python loop:
        # view -> [p, k, q, miniblock], average the miniblock columns,
        # then reorder to [p, q, k].
        new_weight = (
            W_padded.view(grid_dim_y, miniblock, grid_dim_x, miniblock)
            .mean(dim=-1)
            .permute(0, 2, 1)
            .contiguous()
        )

    # Hand the compressed weight to the optical layer's own loader.
    y.load_parameters({"weight": new_weight})
    return y
diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py
index efbb0ed14..8773589ec 100644
--- a/src/chop/passes/module/transforms/__init__.py
+++ b/src/chop/passes/module/transforms/__init__.py
@@ -1,3 +1,4 @@
from .autosharding import resharding_transform_pass
from .quantize import quantize_module_transform_pass
from .autosharding import resharding_transform_pass
+from .optical import optical_module_transform_pass
\ No newline at end of file
diff --git a/src/chop/passes/module/transforms/optical/__init__.py b/src/chop/passes/module/transforms/optical/__init__.py
new file mode 100644
index 000000000..9b1840c4e
--- /dev/null
+++ b/src/chop/passes/module/transforms/optical/__init__.py
@@ -0,0 +1 @@
+from .optical import optical_module_transform_pass
diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py
new file mode 100644
index 000000000..da475d778
--- /dev/null
+++ b/src/chop/passes/module/transforms/optical/optical.py
@@ -0,0 +1,103 @@
+import torch
+
+from chop.nn.optical.modules import optical_module_map
+from ...module_modify_helper import replace_by_name, instantiate_module
+from ...module_transform_helper import replace_by_name_optical
+
+
def get_config(config: dict, name: str):
    """Return the per-module config for ``name``, falling back to the
    ``"default"`` entry when no specific one exists."""
    entry = config.get(name, config["default"])
    return entry["config"]
+
+
def optical_by_type(network, pass_args):
    """Replace every module of each configured type (``linear`` or
    ``conv2d``) in ``network`` with its optical counterpart.

    Raises:
        ValueError: for any type name other than ``linear``/``conv2d``.
    """
    type_map = {"linear": torch.nn.Linear, "conv2d": torch.nn.Conv2d}
    for type_name, config in pass_args.items():
        # Snapshot first: replacement mutates the module tree mid-scan.
        named = dict(network.named_modules())

        if type_name not in type_map:
            raise ValueError(f"{type_name} is not supported!")
        target_cls = type_map[type_name]
        config = config["config"]
        postfix = config.pop("name")
        for name, module in named.items():
            if isinstance(module, target_cls):
                replacement = instantiate_module(
                    module, postfix, optical_module_map, {"config": config}
                )
                network = replace_by_name_optical(network, name, replacement)
    return network
+
+
def optical_by_name(network, pass_args):
    """Replace the modules whose qualified names appear in ``pass_args``
    with their optical counterparts."""
    # Snapshot first: replacement mutates the module tree mid-scan.
    named = dict(network.named_modules())
    for name, module in named.items():
        if name not in pass_args:
            continue
        config = pass_args[name]["config"]
        postfix = config.pop("name")

        replacement = instantiate_module(
            module, postfix, optical_module_map, {"config": config}
        )
        network = replace_by_name_optical(network, name, replacement)
    return network
+
+
+def optical_module_transform_pass(network, pass_args):
+ """
+ Apply optical transformation to the given nn.Module.
+
+ :param network: The input network to be transformed.
+ :type network: torch.nn.Module
+
+ :param pass_args: Additional arguments for the transformation.
+ :type pass_args: dict, optional
+
+ Examples pass_args:
+
+ .. code-block:: python
+
+ pass_args = {
+ "by": "type", # quantize by type, name, or regex_name
+ "default": {"config": {"name": None}}, # default config, this would be used for any node that does not have a specific config
+ "linear": {
+ "config": {
+ "name": "integer", # quantization scheme name supported are ["integer", "fixed" (equivalent to integer), "lutnet" (dev mode), "logicnets" (dev mode), "binary", "binary_residual", "ternary", "minifloat_ieee", "minifloat_denorm", "log", "block_fp", "block_minifloat", "block_log"]
+ # data
+ "data_in_width": 8,
+ "data_in_frac_width": 4,
+ # weight
+ "weight_width": 8,
+ "weight_frac_width": 4,
+ # bias
+ "bias_width": 8,
+ "bias_frac_width": 4,
+ }
+ },
+ }
+
+ :return: The transformed torch.nn.Module.
+ :rtype: tuple
+ :raises ValueError: If the quantize "by" argument is unsupported.
+
+ """
+ by = pass_args.pop("by")
+ match by:
+ case "type":
+ network = optical_by_type(network, pass_args)
+ case "name":
+ network = optical_by_name(network, pass_args)
+ case _:
+ raise ValueError(f'Unsupported quantize "by": {by}')
+ return network, {}
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
new file mode 100644
index 000000000..5409e64f5
--- /dev/null
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# This example converts a simple MLP model to Verilog
+import logging
+import os
+import sys
+
+import torch
+import torch.nn as nn
+
+from pathlib import Path
+
+sys.path.append(Path(__file__).resolve().parents[5].as_posix())
+
+
+# from chop.passes.module.transforms import quantize_module_transform_pass
+from chop.passes.module.transforms import optical_module_transform_pass
+from chop.passes.module import report_trainable_parameters_analysis_pass
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+from train_mnist_cnn import test, train, Net
+
+# --------------------------------------------------
+# Model specifications
+# --------------------------------------------------
+# class MLP(torch.nn.Module):
+# """
+# Toy quantized FC model for digit recognition on MNIST
+# """
+
+# def __init__(self) -> None:
+# super().__init__()
+
+# self.fc1 = nn.Linear(28 * 28, 28 * 28)
+# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
+# self.fc3 = nn.Linear(28 * 28 * 4, 10)
+
+# def forward(self, x):
+# x = torch.flatten(x, start_dim=1, end_dim=-1)
+# x = torch.nn.functional.relu(self.fc1(x))
+# # w = torch.randn((4, 28 * 28))
+# # x = torch.nn.functional.relu(nn.functional.linear(x, w))
+# x = torch.nn.functional.relu(self.fc2(x))
+# x = self.fc3(x)
+# return x
+
def load_my_model(model_path, device="cpu"):
    """Load a fully pickled model from ``model_path`` onto ``device`` and
    switch it to eval mode.

    The ``device`` default fixes callers that omit it (e.g. the name-based
    transform test).

    NOTE(security): this deserializes an arbitrary pickle
    (``weights_only=False``); only load checkpoints you trust.
    """
    # weights_only=False is required to restore a whole nn.Module object
    # (not just a state_dict) on torch >= 2.6, matching the original intent.
    loaded_model = torch.load(model_path, map_location=device, weights_only=False)
    # eval() matters for layers like BatchNorm or Dropout
    loaded_model.eval()
    return loaded_model
+
def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
    """Convert every ``nn.Linear`` in ``model`` to a MORR optical layer,
    checkpoint the resulting state_dict to ``save_path``, and return the
    transformed model."""
    morr_config = {
        "name": "morr",
        "miniblock": 4,
        "morr_init": True,
        "trainable_morr_bias": False,
        "trainable_morr_scale": False,
    }
    pass_args = {"by": "type", "linear": {"config": morr_config}}
    onn_model, _ = optical_module_transform_pass(model, pass_args)
    torch.save(onn_model.state_dict(), save_path)
    return onn_model
+
def test_optical_module_transform_pass():
    """Load the pretrained MNIST CNN and transform its ``fc1`` layer
    (selected by name) into a MORR optical linear layer, then save the
    result as a full pickled module."""
    model_path = "mase_output/sample_mnist_cnn.pt"
    # BUG FIX: load_my_model requires a device argument; run on CPU here.
    mnist_cnn = load_my_model(model_path, torch.device("cpu"))
    # Sanity check and report
    pass_args = {
        "by": "name",
        "fc1": {
            "config": {
                "name": "morr",
                "miniblock": 4,
                "morr_init": True,
                "trainable_morr_bias": False,
                "trainable_morr_scale": False,
            }
        },
    }
    onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args)
    torch.save(onn_cnn, "mase_output/onn_cnn.pt")
+
+
+
if __name__ == '__main__':

    # NOTE(review): `if True:` only adds an indentation level; kept verbatim
    # to avoid touching behavior.
    if True:
        parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
        parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                            help='input batch size for training (default: 64)')
        parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                            help='input batch size for testing (default: 1000)')
        parser.add_argument('--epochs', type=int, default=14, metavar='N',
                            help='number of epochs to train (default: 14)')
        parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                            help='learning rate (default: 1.0)')
        parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                            help='Learning rate step gamma (default: 0.7)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--no-mps', action='store_true', default=False,
                            help='disables macOS GPU training')
        parser.add_argument('--dry-run', action='store_true', default=False,
                            help='quickly check a single pass')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
        parser.add_argument('--save-model', action='store_true', default=True,
                            help='For Saving the current Model')
        parser.add_argument('--gpu-id', type=int, default=0,
                            help='Which GPU device to use [default: 0]')

        args = parser.parse_args()
        use_cuda = not args.no_cuda and torch.cuda.is_available()
        use_mps = not args.no_mps and torch.backends.mps.is_available()

        torch.manual_seed(args.seed)

        # Device priority: CUDA (with selectable GPU id) > MPS > CPU.
        if not args.no_cuda and torch.cuda.is_available():
            device = torch.device(f"cuda:{args.gpu_id}")
        elif use_mps:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        train_kwargs = {'batch_size': args.batch_size}
        test_kwargs = {'batch_size': args.test_batch_size}
        if use_cuda:
            cuda_kwargs = {'num_workers': 1,
                           'pin_memory': True,
                           'shuffle': True}
            train_kwargs.update(cuda_kwargs)
            test_kwargs.update(cuda_kwargs)

        # Standard MNIST normalization constants (mean, std).
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
            ])
        dataset1 = datasets.MNIST('../data', train=True, download=True,
                           transform=transform)
        dataset2 = datasets.MNIST('../data', train=False,
                           transform=transform)
        train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
        test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

        # Baseline: accuracy and parameter count of the pretrained CNN.
        cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
        print("-------------- Testing the original cnn model -------------------")
        test(cnn, device, test_loader)
        _, _ = report_trainable_parameters_analysis_pass(cnn)

        # Transform linear layers into optical MORR modules, then re-evaluate
        # the transformed network before any fine-tuning.
        # onn = load_my_model("mase_output/onn_cnn.pt", device)
        onn_model = perform_optical_module_transform_pass(cnn)
        onn_model.to(device)
        print("-------------- Testing the transformed onn model -------------------")
        test(onn_model, device, test_loader)
        _, _ = report_trainable_parameters_analysis_pass(onn_model)


        ##################################################################
        ######### Training the onn model
        ##################################################################
        optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        for epoch in range(1, args.epochs + 1):
            train(args, onn_model, device, train_loader, optimizer, epoch)
            test(onn_model, device, test_loader)
            scheduler.step()


        torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")

        print("-------------- Testing the trained onn model -------------------")
        test(onn_model, device, test_loader)
        _, _ = report_trainable_parameters_analysis_pass(onn_model)



        # test_optical_module_transform_pass()
\ No newline at end of file
diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py
new file mode 100644
index 000000000..de3bbe385
--- /dev/null
+++ b/test/passes/module/transforms/optical/train_mnist_cnn.py
@@ -0,0 +1,144 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+
class Net(nn.Module):
    """Small CNN for MNIST: two conv layers, two dropouts, two fully
    connected layers, log-softmax output."""

    def __init__(self):
        super(Net, self).__init__()
        # Layer creation order is kept identical so parameter
        # initialization matches the original under a fixed seed.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """[B, 1, 28, 28] input -> [B, 10] log-probabilities."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = self.fc1(torch.flatten(x, 1))
        x = self.dropout2(F.relu(x))
        return F.log_softmax(self.fc2(x), dim=1)
+
+
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of NLL-loss training over ``train_loader``,
    logging progress every ``args.log_interval`` batches; ``args.dry_run``
    stops after the first logged batch."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            done = batch_idx * len(data)
            total = len(train_loader.dataset)
            pct = 100. * batch_idx / len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, done, total, pct, loss.item()))
            if args.dry_run:
                break
+
+
+def test(model, device, test_loader):
+ model.eval()
+ test_loss = 0
+ correct = 0
+ with torch.no_grad():
+ for data, target in test_loader:
+ data, target = data.to(device), target.to(device)
+ output = model(data)
+ test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+ correct += pred.eq(target.view_as(pred)).sum().item()
+
+ test_loss /= len(test_loader.dataset)
+
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+ test_loss, correct, len(test_loader.dataset),
+ 100. * correct / len(test_loader.dataset)))
+
+
def main():
    """Train the MNIST CNN from scratch and (optionally) save the full
    pickled model to ``mase_output/sample_mnist_cnn.pt``."""
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    # BUG FIX: the default was raised to 29 but the help text still said 14.
    parser.add_argument('--epochs', type=int, default=29, metavar='N',
                        help='number of epochs to train (default: 29)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    # Device priority: CUDA > MPS > CPU.
    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # Standard MNIST normalization constants (mean, std).
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        # NOTE: saves the whole pickled module (not a state_dict) on purpose,
        # because the optical transform test loads the full object.
        torch.save(model, "mase_output/sample_mnist_cnn.pt")


if __name__ == '__main__':
    main()
From 587154305012a1e220fe9ec6293093f903a2759c Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 26 Jan 2025 20:58:13 +0000
Subject: [PATCH 02/38] add support for onn conv2d layer
---
src/chop/nn/optical/modules/morr_conv2d.py | 23 +-
.../optical}/module_transform_helper.py | 54 +++-
.../module/transforms/optical/optical.py | 8 +-
.../transforms/optical/test_optical_module.py | 250 ++++++++++--------
4 files changed, 209 insertions(+), 126 deletions(-)
rename src/chop/passes/module/{ => transforms/optical}/module_transform_helper.py (50%)
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index 83f0bf1fe..87fd8ed42 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -71,16 +71,25 @@ def __init__(
dilation: _size = 1,
groups: int = 1,
bias: bool = True,
- miniblock: int = 4,
- ### morr parameter
- MORRConfig=MORRConfig_20um_MQ,
- morr_init: bool = True, # whether to use initialization method customized for MORR
- ### trainable MORR nonlinearity
- trainable_morr_bias: bool = False,
- trainable_morr_scale: bool = False,
+ padding_mode = None,
+ # miniblock: int = 4,
+ # ### morr parameter
+ # MORRConfig=MORRConfig_20um_MQ,
+ # morr_init: bool = True, # whether to use initialization method customized for MORR
+ # ### trainable MORR nonlinearity
+ # trainable_morr_bias: bool = False,
+ # trainable_morr_scale: bool = False,
+ config = None,
device: Device = torch.device("cuda"),
) -> None:
super(AllPassMORRCirculantConv2d, self).__init__()
+ miniblock = config.get("miniblock", 4)
+ MORRConfig = config.get("MORRConfig", MORRConfig_20um_MQ)
+ morr_init = config.get("morr_init", True)
+ trainable_morr_bias = config.get("trainable_morr_bias", False)
+ trainable_morr_scale = config.get("trainable_morr_scale", False)
+
+
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
diff --git a/src/chop/passes/module/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py
similarity index 50%
rename from src/chop/passes/module/module_transform_helper.py
rename to src/chop/passes/module/transforms/optical/module_transform_helper.py
index b644b087a..dc110d9b0 100644
--- a/src/chop/passes/module/module_transform_helper.py
+++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py
@@ -16,8 +16,15 @@ def replace_by_name_optical(
return network
+def weight_replacement_optical(original, new_module):
+ if isinstance(original, nn.Linear):
+ return weight_replacement_linear_optical(original, new_module)
+ elif isinstance(original, nn.Conv2d):
+ return weight_replacement_conv2d_optical(original, new_module)
+ else:
+ raise NotImplementedError("weight replacement function for the optical module not implemented")
-def weight_replacement_optical(x, y):
+def weight_replacement_linear_optical(x, y):
"""
Replace the weights of AllPassMORRCirculantLinear (y)
with those from a standard nn.Linear (x).
@@ -62,3 +69,48 @@ def weight_replacement_optical(x, y):
# y.weight.data.copy_(new_weight)
return y
+
+
+
def weight_replacement_conv2d_optical(x, y):
    """Copy a standard ``nn.Conv2d``'s parameters into an
    AllPassMORRCirculantConv2d as a block-circulant approximation.

    The flattened conv weight is zero-padded to
    [out_channels_pad, in_channels_pad], split into miniblock x miniblock
    tiles, and the first column of each tile becomes the circulant
    generator stored at ``y.weight[p, q, :]``. The bias (if both modules
    have one) is copied verbatim.

    Args:
        x (nn.Conv2d): A standard PyTorch Conv2d module
        y (AllPassMORRCirculantConv2d): An already-constructed optical Conv2d
            module into which we copy weights/bias.
    """
    with torch.no_grad():
        # 1) Copy bias (if both x and y actually have one).
        if x.bias is not None and y.bias is not None:
            y.bias.copy_(x.bias)

        # 2) Flatten the conv weight => [out_channels, in_channels*kh*kw]
        w_flat = x.weight.data.view(x.out_channels, -1)

        # 3) Zero-pad to (out_channels_pad, in_channels_pad)
        outC_pad = y.out_channels_pad  # == y.grid_dim_y * y.miniblock
        inC_pad = y.in_channels_pad  # == y.grid_dim_x * y.miniblock

        W = torch.zeros(outC_pad, inC_pad, device=w_flat.device, dtype=w_flat.dtype)
        W[: x.out_channels, : w_flat.size(1)] = w_flat

        # 4) Tile view [p, k, q, k]; keep column 0 of every k x k tile and
        # reorder to [p, q, k] — one vectorized copy instead of the former
        # O(p*q) Python loop over tiles.
        p, q, k = y.grid_dim_y, y.grid_dim_x, y.miniblock
        y.weight.data.copy_(W.view(p, k, q, k)[..., 0].permute(0, 2, 1))

    return y
\ No newline at end of file
diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py
index da475d778..a91e7403d 100644
--- a/src/chop/passes/module/transforms/optical/optical.py
+++ b/src/chop/passes/module/transforms/optical/optical.py
@@ -1,8 +1,8 @@
import torch
from chop.nn.optical.modules import optical_module_map
-from ...module_modify_helper import replace_by_name, instantiate_module
-from ...module_transform_helper import replace_by_name_optical
+from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module
+from chop.passes.module.transforms.optical.module_transform_helper import replace_by_name_optical
def get_config(config: dict, name: str):
@@ -36,12 +36,12 @@ def optical_by_type(network, pass_args):
def optical_by_name(network, pass_args):
- quantize_names = pass_args.keys()
+ optical_names = pass_args.keys()
n_m = {}
for n, m in network.named_modules():
n_m[n] = m
for n, m in n_m.items():
- if n in quantize_names:
+ if n in optical_names:
quan_config = pass_args[n]
quan_config = quan_config["config"]
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index 5409e64f5..35c959a70 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -50,17 +50,21 @@
# x = self.fc3(x)
# return x
-def load_my_model(model_path, device):
+def load_my_model(model_path, device="cpu"):
# Load the model from the .pt file
loaded_model = torch.load(model_path, map_location=device)
# Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout)
loaded_model.eval()
return loaded_model
-def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
+
+def test_optical_module_transform_pass():
+ model_path = "mase_output/sample_mnist_cnn.pt"
+ mnist_cnn = load_my_model(model_path)
+ # Sanity check and report
pass_args = {
- "by": "type",
- "linear": {
+ "by": "name",
+ "fc1": {
"config": {
"name": "morr",
"miniblock": 4,
@@ -69,18 +73,7 @@ def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.
"trainable_morr_scale": False,
}
},
- }
- onn_model, _ = optical_module_transform_pass(model, pass_args)
- torch.save(onn_model.state_dict(), save_path)
- return onn_model
-
-def test_optical_module_transform_pass():
- model_path = "mase_output/sample_mnist_cnn.pt"
- mnist_cnn = load_my_model(model_path)
- # Sanity check and report
- pass_args = {
- "by": "name",
- "fc1": {
+ "conv1": {
"config": {
"name": "morr",
"miniblock": 4,
@@ -90,103 +83,132 @@ def test_optical_module_transform_pass():
}
},
}
- onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args)
- torch.save(onn_cnn, "mase_output/onn_cnn.pt")
-
-
-
-if __name__ == '__main__':
-
- if True:
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=14, metavar='N',
- help='number of epochs to train (default: 14)')
- parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
- help='learning rate (default: 1.0)')
- parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
- help='Learning rate step gamma (default: 0.7)')
- parser.add_argument('--no-cuda', action='store_true', default=False,
- help='disables CUDA training')
- parser.add_argument('--no-mps', action='store_true', default=False,
- help='disables macOS GPU training')
- parser.add_argument('--dry-run', action='store_true', default=False,
- help='quickly check a single pass')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--log-interval', type=int, default=10, metavar='N',
- help='how many batches to wait before logging training status')
- parser.add_argument('--save-model', action='store_true', default=True,
- help='For Saving the current Model')
- parser.add_argument('--gpu-id', type=int, default=0,
- help='Which GPU device to use [default: 0]')
-
- args = parser.parse_args()
- use_cuda = not args.no_cuda and torch.cuda.is_available()
- use_mps = not args.no_mps and torch.backends.mps.is_available()
-
- torch.manual_seed(args.seed)
-
- if not args.no_cuda and torch.cuda.is_available():
- device = torch.device(f"cuda:{args.gpu_id}")
- elif use_mps:
- device = torch.device("mps")
- else:
- device = torch.device("cpu")
-
- train_kwargs = {'batch_size': args.batch_size}
- test_kwargs = {'batch_size': args.test_batch_size}
- if use_cuda:
- cuda_kwargs = {'num_workers': 1,
- 'pin_memory': True,
- 'shuffle': True}
- train_kwargs.update(cuda_kwargs)
- test_kwargs.update(cuda_kwargs)
-
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
- test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
-
- cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
- print("-------------- Testing the original cnn model -------------------")
- test(cnn, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(cnn)
-
- # onn = load_my_model("mase_output/onn_cnn.pt", device)
- onn_model = perform_optical_module_transform_pass(cnn)
- onn_model.to(device)
- print("-------------- Testing the transformed onn model -------------------")
- test(onn_model, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
-
- ##################################################################
- ######### Training the onn model
- ##################################################################
- optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
- scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
- for epoch in range(1, args.epochs + 1):
- train(args, onn_model, device, train_loader, optimizer, epoch)
- test(onn_model, device, test_loader)
- scheduler.step()
-
-
- torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
-
- print("-------------- Testing the trained onn model -------------------")
- test(onn_model, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(onn_model)
+ optical_module_transform_pass(mnist_cnn, pass_args)
+ # torch.save(onn_cnn, "mase_output/onn_cnn.pt")
+
+test_optical_module_transform_pass()
+
+# if __name__ == '__main__':
+# finetune = False
+
+# if True:
+# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+# parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+# help='input batch size for training (default: 64)')
+# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+# help='input batch size for testing (default: 1000)')
+# parser.add_argument('--epochs', type=int, default=14, metavar='N',
+# help='number of epochs to train (default: 14)')
+# parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
+# help='learning rate (default: 1.0)')
+# parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
+# help='Learning rate step gamma (default: 0.7)')
+# parser.add_argument('--no-cuda', action='store_true', default=False,
+# help='disables CUDA training')
+# parser.add_argument('--no-mps', action='store_true', default=False,
+# help='disables macOS GPU training')
+# parser.add_argument('--dry-run', action='store_true', default=False,
+# help='quickly check a single pass')
+# parser.add_argument('--seed', type=int, default=1, metavar='S',
+# help='random seed (default: 1)')
+# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+# help='how many batches to wait before logging training status')
+# parser.add_argument('--save-model', action='store_true', default=True,
+# help='For Saving the current Model')
+# parser.add_argument('--gpu-id', type=int, default=0,
+# help='Which GPU device to use [default: 0]')
+
+# args = parser.parse_args()
+# use_cuda = not args.no_cuda and torch.cuda.is_available()
+# use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+# torch.manual_seed(args.seed)
+
+# if not args.no_cuda and torch.cuda.is_available():
+# device = torch.device(f"cuda:{args.gpu_id}")
+# elif use_mps:
+# device = torch.device("mps")
+# else:
+# device = torch.device("cpu")
+
+# train_kwargs = {'batch_size': args.batch_size}
+# test_kwargs = {'batch_size': args.test_batch_size}
+# if use_cuda:
+# cuda_kwargs = {'num_workers': 1,
+# 'pin_memory': True,
+# 'shuffle': True}
+# train_kwargs.update(cuda_kwargs)
+# test_kwargs.update(cuda_kwargs)
+
+# transform=transforms.Compose([
+# transforms.ToTensor(),
+# transforms.Normalize((0.1307,), (0.3081,))
+# ])
+# dataset1 = datasets.MNIST('../data', train=True, download=True,
+# transform=transform)
+# dataset2 = datasets.MNIST('../data', train=False,
+# transform=transform)
+# train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+# test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+# # load pre-trained cnn
+# cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
+# print("-------------- Testing the original cnn model -------------------")
+# _, _ = report_trainable_parameters_analysis_pass(cnn)
+# test(cnn, device, test_loader)
+
+# ## transform cnn into onn
+
+# # onn = load_my_model("mase_output/onn_cnn.pt", device)
+# onn_model = perform_optical_module_transform_pass(cnn)
+# onn_model.to(device)
+# print("-------------- Testing the transformed onn model -------------------")
+# _, _ = report_trainable_parameters_analysis_pass(onn_model)
+# test(onn_model, device, test_loader)
+
+# # Training the onn model
+# if finetune:
+# optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
+# scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+# for epoch in range(1, args.epochs + 1):
+# train(args, onn_model, device, train_loader, optimizer, epoch)
+# test(onn_model, device, test_loader)
+# scheduler.step()
+
+
+# torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
+
+# print("-------------- Testing the trained onn model -------------------")
+# test(onn_model, device, test_loader)
+# _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
- # test_optical_module_transform_pass()
\ No newline at end of file
+
+# def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
+# # pass_args = {
+# # "by": "type",
+# # "linear": {
+# # "config": {
+# # "name": "morr",
+# # "miniblock": 4,
+# # "morr_init": True,
+# # "trainable_morr_bias": False,
+# # "trainable_morr_scale": False,
+# # }
+# # },
+# # }
+# pass_args = {
+# "by": "type",
+# "conv2d": {
+# "config": {
+# "name": "morr",
+# "miniblock": 4,
+# "morr_init": True,
+# "trainable_morr_bias": False,
+# "trainable_morr_scale": False,
+# }
+# },
+# }
+# onn_model, _ = optical_module_transform_pass(model, pass_args)
+# torch.save(onn_model.state_dict(), save_path)
+# return onn_model
From 5dc0974ccd54a673a125154f1e5809076bf44aa6 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 26 Jan 2025 21:00:10 +0000
Subject: [PATCH 03/38] black format
---
src/chop/nn/optical/__init__.py | 2 +-
src/chop/nn/optical/functional/__init__.py | 3 -
src/chop/nn/optical/functional/general.py | 40 +++--
src/chop/nn/optical/functional/initializer.py | 5 +-
src/chop/nn/optical/functional/mrr_op.py | 56 ++++---
src/chop/nn/optical/functional/quantize.py | 101 +++++++++---
src/chop/nn/optical/functional/torch_train.py | 37 ++---
src/chop/nn/optical/modules/__init__.py | 2 +-
src/chop/nn/optical/modules/base_layer.py | 11 +-
src/chop/nn/optical/modules/morr_conv2d.py | 100 +++++++++---
src/chop/nn/optical/modules/morr_linear.py | 79 +++++++---
src/chop/passes/module/transforms/__init__.py | 2 +-
.../optical/module_transform_helper.py | 55 +++----
.../module/transforms/optical/optical.py | 4 +-
.../transforms/optical/test_optical_module.py | 7 +-
.../transforms/optical/train_mnist_cnn.py | 146 ++++++++++++------
16 files changed, 450 insertions(+), 200 deletions(-)
diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py
index e9c0423d5..b74e7df6b 100644
--- a/src/chop/nn/optical/__init__.py
+++ b/src/chop/nn/optical/__init__.py
@@ -1,3 +1,3 @@
from .modules import (
optical_module_map,
-)
\ No newline at end of file
+)
diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py
index 11a435770..84f94e7b6 100644
--- a/src/chop/nn/optical/functional/__init__.py
+++ b/src/chop/nn/optical/functional/__init__.py
@@ -30,9 +30,6 @@
)
-
-
-
# """
# Description:
# Author: Jiaqi Gu (jqgu@utexas.edu)
diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py
index 1d4f2990d..f5aa87e14 100644
--- a/src/chop/nn/optical/functional/general.py
+++ b/src/chop/nn/optical/functional/general.py
@@ -70,7 +70,10 @@ def wrapper(*args, **kw):
local_time = time.time()
res = func(*args, **kw)
end_time = time.time()
- print("[I] <%s> runtime: %.3f ms" % (func.__name__, (end_time - local_time) * 1000))
+ print(
+ "[I] <%s> runtime: %.3f ms"
+ % (func.__name__, (end_time - local_time) * 1000)
+ )
else:
res = func(*args, **kw)
return res
@@ -84,11 +87,13 @@ def print_stat(x, message="", verbose=True):
if torch.is_complex(x):
x = torch.view_as_real(x)
print(
- message + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}"
+ message
+ + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}"
)
elif isinstance(x, np.ndarray):
print(
- message + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}"
+ message
+ + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}"
)
@@ -127,7 +132,7 @@ def __enter__(self):
return self
def _b2mb(self, x):
- return x / 2 ** 20
+ return x / 2**20
def __exit__(self, *exc):
self.end = self._b2mb(torch.cuda.memory_allocated())
@@ -181,13 +186,17 @@ def format(self, record):
return formatter.format(record)
-def setup_default_logging(default_level=logging.INFO, default_file_level=logging.INFO, log_path=""):
+def setup_default_logging(
+ default_level=logging.INFO, default_file_level=logging.INFO, log_path=""
+):
console_handler = logging.StreamHandler()
console_handler.setFormatter(CustomFormatter())
logging.root.addHandler(console_handler)
logging.root.setLevel(default_level)
if log_path:
- file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
+ file_handler = logging.handlers.RotatingFileHandler(
+ log_path, maxBytes=(1024**2 * 2), backupCount=3
+ )
file_formatter = logging.Formatter(
"%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
)
@@ -197,7 +206,13 @@ def setup_default_logging(default_level=logging.INFO, default_file_level=logging
class Logger(object):
- def __init__(self, console=True, logfile=None, console_level=logging.INFO, logfile_level=logging.INFO):
+ def __init__(
+ self,
+ console=True,
+ logfile=None,
+ console_level=logging.INFO,
+ logfile_level=logging.INFO,
+ ):
super().__init__()
self.logfile = logfile
self.console_level = console_level
@@ -242,9 +257,16 @@ def critical(self, message):
self.logger.critical(message)
-def get_logger(name="default", default_level=logging.INFO, default_file_level=logging.INFO, log_path=""):
+def get_logger(
+ name="default",
+ default_level=logging.INFO,
+ default_file_level=logging.INFO,
+ log_path="",
+):
setup_default_logging(
- default_level=default_level, default_file_level=default_file_level, log_path=log_path
+ default_level=default_level,
+ default_file_level=default_file_level,
+ log_path=log_path,
)
return logging.getLogger(name)
diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/functional/initializer.py
index 53002c398..f0591c33a 100644
--- a/src/chop/nn/optical/functional/initializer.py
+++ b/src/chop/nn/optical/functional/initializer.py
@@ -5,6 +5,7 @@
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-06-06 01:57:18
"""
+
import numpy as np
import torch
@@ -83,7 +84,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
"""
morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* MORRConfig.radius
* MORRConfig.effective_index
* (
@@ -124,7 +125,7 @@ def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1):
"""
morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* MORRConfig.radius
* MORRConfig.effective_index
* (
diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/functional/mrr_op.py
index 42952971e..e98dc4fd9 100644
--- a/src/chop/nn/optical/functional/mrr_op.py
+++ b/src/chop/nn/optical/functional/mrr_op.py
@@ -67,7 +67,9 @@ def mrr_tr_to_roundtrip_phase(t, a, r):
assert 0 <= a <= 1, logging.error(f"Expect a from [0,1] but got {a}")
assert 0 <= r <= 1, logging.error(f"Expect r from [0,1] but got {r}")
# given a and r, the curve is fixed, the max and min may not be 1 and 0
- cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp(0, 1)
+ cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp(
+ 0, 1
+ )
if isinstance(cos_phi, torch.Tensor):
return cos_phi.acos(), cos_phi
@@ -120,7 +122,9 @@ def mrr_roundtrip_phase_to_tr(
@torch.jit.script
-def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False):
+def mrr_roundtrip_phase_to_tr_fused(
+ rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False
+):
"""
description: round trip phase shift to field transmission
rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\
@@ -151,7 +155,9 @@ def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, inte
@torch.jit.script
-def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False):
+def mrr_roundtrip_phase_to_tr_grad_fused(
+ rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False
+):
"""
description: round trip phase shift to the gradient of field transmission
rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\
@@ -161,22 +167,24 @@ def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9,
return g {torch.Tensor or np.ndarray} the gradient of mrr through port field/intensity transmission
"""
if not intensity:
- g = (a * r * (a ** 2 - 1) * (r ** 2 - 1) * rt_phi.sin()) / (
- (a ** 2 + r ** 2 - 2 * a * r * rt_phi.cos()) ** (1 / 2)
- * (a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5
+ g = (a * r * (a**2 - 1) * (r**2 - 1) * rt_phi.sin()) / (
+ (a**2 + r**2 - 2 * a * r * rt_phi.cos()) ** (1 / 2)
+ * (a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5
)
else:
- g = ((a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r * rt_phi.sin()) / (
- a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos()
+ g = ((a**2 - 1) * (r**2 - 1) * 2 * a * r * rt_phi.sin()) / (
+ a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()
) ** 2
return g
-def mrr_roundtrip_phase_to_tr_func(a: float = 0.8, r: float = 0.9, intensity: bool = False):
+def mrr_roundtrip_phase_to_tr_func(
+ a: float = 0.8, r: float = 0.9, intensity: bool = False
+):
c1 = -2 * a * r
c2 = a * a + r * r
c3 = 1 + r * r * a * a - a * a - r * r
- c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r
+ c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r
class MRRRoundTripPhaseToTrFunction(torch.autograd.Function):
@staticmethod
@@ -202,7 +210,9 @@ def backward(ctx, grad_output):
numerator = input.sin().mul_(c4)
else:
numerator = input.sin().mul_(c4 / 2)
- denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_())
+ denominator = (
+ denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_())
+ )
grad_input = numerator.div_(denominator).mul_(grad_output)
return grad_input
@@ -334,7 +344,9 @@ def mrr_filter(x, t, a=0.9, r=0.8):
return out
-def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False):
+def morr_filter(
+ rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False
+):
"""
description: from round trip phase shift to output signal \\
rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\
@@ -349,16 +361,22 @@ def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False
if not coherent:
if x is None:
# unit laser input with incoherent light, 1e^j0
- t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity)
+ t = mrr_roundtrip_phase_to_tr(
+ rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity
+ )
return t
else:
# incoherent light with non-unit input, input must be real number
- t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity)
+ t = mrr_roundtrip_phase_to_tr(
+ rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity
+ )
return x * t
else:
if x is None:
# coherent light with unit laser, 1e^j0, treat morr as a mrr modulator
- phase = polar_to_complex(mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r))
+ phase = polar_to_complex(
+ mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r)
+ )
return phase
else:
# coherent light with complex input
@@ -375,7 +393,9 @@ def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm):
fwhm {float} bandwidth or full width half maximum (unit: nm)\\
return n_g {float} Group index of the MRR
"""
- n_g = (1 - r * a) * lambda0 ** 2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm)
+ n_g = (
+ (1 - r * a) * lambda0**2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm)
+ )
return n_g
@@ -388,7 +408,7 @@ def mrr_ng_to_fsr(lambda0, n_g, radius):
radius {float} Radius of the MRR (unit: nm)\\
return fsr {float} Free-spectral range
"""
- fsr = lambda0 ** 2 / (n_g * 2 * np.pi * radius)
+ fsr = lambda0**2 / (n_g * 2 * np.pi * radius)
return fsr
@@ -400,5 +420,5 @@ def mrr_finesse(a, r):
return finesse {float} Finesse of the MRR
"""
ra = r * a
- finesse = np.pi * ra ** 0.5 / (1 - ra)
+ finesse = np.pi * ra**0.5 / (1 - ra)
return finesse
diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/functional/quantize.py
index 09b7ef28d..f60e8b240 100644
--- a/src/chop/nn/optical/functional/quantize.py
+++ b/src/chop/nn/optical/functional/quantize.py
@@ -82,7 +82,14 @@ def forward(ctx, input, scale, zero_point):
n = float(2**k - 1)
# out = torch.round(input * n) / n
# out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale
- out = input.div(scale).add_(zero_point).round_().clamp_(0, n).sub_(zero_point).mul_(scale)
+ out = (
+ input.div(scale)
+ .add_(zero_point)
+ .round_()
+ .clamp_(0, n)
+ .sub_(zero_point)
+ .mul_(scale)
+ )
return out
@staticmethod
@@ -108,7 +115,7 @@ class EWGS_quantizer(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
- out = input.mul(num_levels - 1).round_().mul_(1/(num_levels - 1))
+ out = input.mul(num_levels - 1).round_().mul_(1 / (num_levels - 1))
ctx._scaling_factor = scaling_factor
ctx.save_for_backward(input - out)
@@ -128,7 +135,9 @@ def backward(ctx, grad_output):
class input_quantize_fn(torch.nn.Module):
- def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0):
+ def __init__(
+ self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0
+ ):
"""Input quantizer with Quant_Noise supported
Args:
in_bit (int): Input quantization bitwidth.
@@ -139,9 +148,14 @@ def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ra
assert 1 <= in_bit <= 32
self.in_bit = in_bit
self.alg = alg
- assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}"
+ assert alg in {
+ "dorefa",
+ "normal",
+ }, f"Only support (dorefa, normal), but got {alg}"
self.quant_ratio = quant_ratio
- assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
+ assert 0 <= quant_ratio <= 1, logging.error(
+ f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
+ )
self.device = device
# define quant style
@@ -182,7 +196,10 @@ def set_bitwidth(self, bit: int) -> None:
self.in_bit = bit
def set_alg(self, alg: str) -> None:
- assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}"
+ assert alg in {
+ "dorefa",
+ "normal",
+ }, f"Only support (dorefa, normal), but got {alg}"
if alg != self.alg:
if alg == "dorefa":
self.uniform_q = uniform_quantize(k=self.in_bit)
@@ -212,14 +229,18 @@ def set_quant_ratio(self, quant_ratio=None):
0.99,
1,
][min(self.in_bit, 16)]
- assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
+ assert 0 <= quant_ratio <= 1, logging.error(
+ f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
+ )
self.quant_ratio = quant_ratio
def forward(self, x):
if self.quant_ratio < 1 and self.training:
### implementation from fairseq
### must fully quantize during inference
- quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio)
+ quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(
+ 1 - self.quant_ratio
+ )
else:
quant_noise_mask = None
@@ -271,7 +292,9 @@ def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0):
quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0.
"""
super(weight_quantize_fn, self).__init__()
- assert 1 <= w_bit <= 32, logging.error(f"Only support 1 - 32 bit quantization, but got {w_bit}")
+ assert 1 <= w_bit <= 32, logging.error(
+ f"Only support 1 - 32 bit quantization, but got {w_bit}"
+ )
self.w_bit = w_bit
self.alg = alg
self.mode = mode
@@ -279,7 +302,9 @@ def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0):
f"Only support (dorefa, dorefa_sym, qnn, dorefa_pos) algorithms, but got {alg}"
)
self.quant_ratio = quant_ratio
- assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
+ assert 0 <= quant_ratio <= 1, logging.error(
+ f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
+ )
self.uniform_q = uniform_quantize(k=w_bit, gradient_clip=True)
def set_quant_ratio(self, quant_ratio=None):
@@ -304,7 +329,9 @@ def set_quant_ratio(self, quant_ratio=None):
0.99,
1,
][min(self.w_bit, 16)]
- assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}")
+ assert 0 <= quant_ratio <= 1, logging.error(
+ f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
+ )
self.quant_ratio = quant_ratio
def set_bitwidth(self, bit: int) -> None:
@@ -317,7 +344,9 @@ def forward(self, x):
if self.quant_ratio < 1 and self.training:
### implementation from fairseq
### must fully quantize during inference
- quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio)
+ quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(
+ 1 - self.quant_ratio
+ )
else:
quant_noise_mask = None
@@ -333,14 +362,18 @@ def forward(self, x):
weight_q = (self.uniform_q(x / E) * E + E) / 2 # [0, E]
if quant_noise_mask is not None:
x = (x + E) / 2
- noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
+ noise = weight_q.data.sub_(x.data).masked_fill_(
+ quant_noise_mask, 0
+ )
### unquantized weights have to follow reparameterization, i.e., tanh and scale
weight_q = x + noise
elif self.alg == "dorefa_sym":
E = x.data.abs().mean()
weight_q = self.uniform_q(x / E) * E # [-E, E]
if quant_noise_mask is not None:
- noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
+ noise = weight_q.data.sub_(x.data).masked_fill_(
+ quant_noise_mask, 0
+ )
### unquantized weights have to follow reparameterization, i.e., tanh and scale
weight_q = x + noise
else:
@@ -352,7 +385,9 @@ def forward(self, x):
# weight = weight / 2 + 0.5
weight_q = self.uniform_q(weight)
if quant_noise_mask is not None:
- noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
+ noise = weight_q.data.sub_(weight.data).masked_fill_(
+ quant_noise_mask, 0
+ )
### unquantized weights have to follow reparameterization, i.e., tanh and scale
weight_q = weight + noise
@@ -362,7 +397,9 @@ def forward(self, x):
# weight = weight / 2 + 0.5
weight_q = self.uniform_q(weight / (2 * r) + 0.5) * (2 * r) - r
if quant_noise_mask is not None:
- noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
+ noise = weight_q.data.sub_(weight.data).masked_fill_(
+ quant_noise_mask, 0
+ )
### unquantized weights have to follow reparameterization, i.e., tanh
weight_q = weight + noise
elif self.alg == "dorefa_pos":
@@ -372,7 +409,9 @@ def forward(self, x):
# weight = weight / 2 + 0.5
weight_q = self.uniform_q(weight / (2 * r)) * 2 * r
if quant_noise_mask is not None:
- noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0)
+ noise = weight_q.data.sub_(weight.data).masked_fill_(
+ quant_noise_mask, 0
+ )
### unquantized weights have to follow reparameterization, i.e., tanh
weight_q = weight + noise
@@ -484,7 +523,9 @@ def __init__(
super(PACT_Act, self).__init__()
self.precision = precision
self.device = device
- self.alpha = torch.nn.Parameter(torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha)
+ self.alpha = torch.nn.Parameter(
+ torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha
+ )
self.alpha_p = alpha
self.statistics_only = statistics_only
self.deployment = False
@@ -493,10 +534,18 @@ def __init__(
# self.requantization_factor = requantization_factor
# these are only used to gather statistics
- self.max = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
- self.min = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
- self.running_mean = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False)
- self.running_var = torch.nn.Parameter(torch.ones_like(self.alpha.data).to(device), requires_grad=False)
+ self.max = torch.nn.Parameter(
+ torch.zeros_like(self.alpha.data).to(device), requires_grad=False
+ )
+ self.min = torch.nn.Parameter(
+ torch.zeros_like(self.alpha.data).to(device), requires_grad=False
+ )
+ self.running_mean = torch.nn.Parameter(
+ torch.zeros_like(self.alpha.data).to(device), requires_grad=False
+ )
+ self.running_var = torch.nn.Parameter(
+ torch.ones_like(self.alpha.data).to(device), requires_grad=False
+ )
self.precise = False
@@ -508,7 +557,9 @@ def set_static_precision(self, limit_at_32_bits=True, **kwargs):
self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1)
self.alpha_static = self.alpha.clone().detach()
# D is selected as a power-of-two
- D = 2.0 ** torch.ceil(torch.log2(self.requantization_factor * self.eps_static / self.eps_in))
+ D = 2.0 ** torch.ceil(
+ torch.log2(self.requantization_factor * self.eps_static / self.eps_in)
+ )
if not limit_at_32_bits:
self.D = D
else:
@@ -568,7 +619,9 @@ def forward(self, x):
self.max[:] = max(self.max.item(), x.max())
self.min[:] = min(self.min.item(), x.min())
self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean()
- self.running_var[:] = 0.9 * self.running_var.item() + 0.1 * x.std() * x.std()
+ self.running_var[:] = (
+ 0.9 * self.running_var.item() + 0.1 * x.std() * x.std()
+ )
return x
else:
eps = self.alpha / (2.0 ** (self.precision) - 1)
diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/functional/torch_train.py
index c62ba3e91..41571effa 100644
--- a/src/chop/nn/optical/functional/torch_train.py
+++ b/src/chop/nn/optical/functional/torch_train.py
@@ -17,6 +17,7 @@
import torch
from scipy import interpolate
from torch.nn.modules.batchnorm import _BatchNorm
+
try:
from torchsummary import summary
except:
@@ -48,26 +49,26 @@
"enable_bn",
]
-class DeterministicCtx:
- def __init__(self, random_state: int | None = None) -> None:
- self.random_state = random_state
-
-
- def __enter__(self):
- self.random_state = random.getstate()
- self.numpy_random_state = np.random.get_state()
- self.torch_random_state = torch.random.get_rng_state()
- self.torch_cuda_random_state = torch.cuda.get_rng_state()
- set_torch_deterministic(self.random_state)
- return self
+class DeterministicCtx:
+ def __init__(self, random_state: int | None = None) -> None:
+ self.random_state = random_state
+
+ def __enter__(self):
+ self.random_state = random.getstate()
+ self.numpy_random_state = np.random.get_state()
+ self.torch_random_state = torch.random.get_rng_state()
+ self.torch_cuda_random_state = torch.cuda.get_rng_state()
+ set_torch_deterministic(self.random_state)
+ return self
+
+ def __exit__(self, *args):
+ random.setstate(self.random_state)
+ np.random.seed(self.numpy_random_state)
+ np.random.set_state(self.numpy_random_state)
+ torch.random.set_rng_state(self.torch_random_state)
+ torch.cuda.set_rng_state(self.torch_cuda_random_state)
- def __exit__(self, *args):
- random.setstate(self.random_state)
- np.random.seed(self.numpy_random_state)
- np.random.set_state(self.numpy_random_state)
- torch.random.set_rng_state(self.torch_random_state)
- torch.cuda.set_rng_state(self.torch_cuda_random_state)
def set_torch_deterministic(random_state: int = 0) -> None:
random_state = int(random_state) % (2**32)
diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py
index 46b6f21e9..b1d7c5629 100644
--- a/src/chop/nn/optical/modules/__init__.py
+++ b/src/chop/nn/optical/modules/__init__.py
@@ -4,4 +4,4 @@
optical_module_map = {
"linear_morr": AllPassMORRCirculantLinear,
"conv2d_morr": AllPassMORRCirculantConv2d,
-}
\ No newline at end of file
+}
diff --git a/src/chop/nn/optical/modules/base_layer.py b/src/chop/nn/optical/modules/base_layer.py
index cb1fbe9dd..4ae2b35bf 100644
--- a/src/chop/nn/optical/modules/base_layer.py
+++ b/src/chop/nn/optical/modules/base_layer.py
@@ -5,6 +5,7 @@
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-06-08 18:55:05
"""
+
from typing import Any, Dict, Optional
import torch
from torch import nn
@@ -24,7 +25,7 @@ def build_parameters(self) -> None:
def reset_parameters(self) -> None:
raise NotImplementedError
-
+
@classmethod
def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module:
raise NotImplementedError
@@ -38,10 +39,14 @@ def enable_fast_forward(self) -> None:
def disable_fast_forward(self) -> None:
self.fast_forward_flag = False
- def set_phase_variation(self, noise_std: float, random_state: Optional[int] = None) -> None:
+ def set_phase_variation(
+ self, noise_std: float, random_state: Optional[int] = None
+ ) -> None:
self.phase_noise_std = noise_std
- def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
+ def set_gamma_noise(
+ self, noise_std: float, random_state: Optional[int] = None
+ ) -> None:
self.gamma_noise_std = noise_std
def set_crosstalk_factor(self, crosstalk_factor: float) -> None:
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index 87fd8ed42..a554bc6c2 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -4,6 +4,7 @@
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-07-18 00:40:18
"""
+
from typing import Optional, Tuple
import numpy as np
@@ -71,7 +72,7 @@ def __init__(
dilation: _size = 1,
groups: int = 1,
bias: bool = True,
- padding_mode = None,
+ padding_mode=None,
# miniblock: int = 4,
# ### morr parameter
# MORRConfig=MORRConfig_20um_MQ,
@@ -79,7 +80,7 @@ def __init__(
# ### trainable MORR nonlinearity
# trainable_morr_bias: bool = False,
# trainable_morr_scale: bool = False,
- config = None,
+ config=None,
device: Device = torch.device("cuda"),
) -> None:
super(AllPassMORRCirculantConv2d, self).__init__()
@@ -89,7 +90,6 @@ def __init__(
trainable_morr_bias = config.get("trainable_morr_bias", False)
trainable_morr_scale = config.get("trainable_morr_scale", False)
-
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
@@ -97,8 +97,12 @@ def __init__(
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
- assert groups == 1, f"Currently group convolution is not supported, but got group: {groups}"
- self.in_channels_flat = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
+ assert (
+ groups == 1
+ ), f"Currently group convolution is not supported, but got group: {groups}"
+ self.in_channels_flat = (
+ self.in_channels * self.kernel_size[0] * self.kernel_size[1]
+ )
self.grid_dim_x = int(np.ceil(self.in_channels_flat / miniblock))
self.grid_dim_y = int(np.ceil(self.out_channels / miniblock))
self.in_channels_pad = self.grid_dim_x * miniblock
@@ -107,7 +111,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi ** 2
+ self.gamma = np.pi / self.v_pi**2
self.w_bit = 32
self.in_bit = 32
self.MORRConfig = MORRConfig
@@ -121,7 +125,7 @@ def __init__(
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* MORRConfig.radius
* MORRConfig.effective_index
* (
@@ -135,7 +139,9 @@ def __init__(
self.x_zero_pad = None
self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs
self.morr_input_bias = None ## round-trip phase shift bias within MORR
- self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR
+ self.morr_input_scale = (
+ None ## scaling factor for the round-trip phase shift within MORR
+ )
self.morr_gain = (
100 / (self.in_channels_flat // self.miniblock)
) ** 0.5 ## set this TIA gain such that output variance is around 1
@@ -178,21 +184,37 @@ def build_parameters(self) -> None:
### MORR weights
self.weight = Parameter(
torch.ones(
- self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float
+ self.grid_dim_y,
+ self.grid_dim_x,
+ self.miniblock,
+ device=self.device,
+ dtype=torch.float,
)
)
### learnable balancing factor achieved by MRRs (morr_output_scale)
### We use a single scaling factor for each block
- self.morr_output_scale = Parameter(torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device))
+ self.morr_output_scale = Parameter(
+ torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device)
+ )
if self.trainable_morr_bias:
### initialize with the finest-granularity, i.e., per mini-block
self.morr_input_bias = Parameter(
- torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
+ torch.zeros(
+ self.grid_dim_y,
+ self.grid_dim_x,
+ device=self.device,
+ dtype=torch.float,
+ )
)
if self.trainable_morr_scale:
### initialize with the finest-granularity, i.e., per mini-block
self.morr_input_scale = Parameter(
- torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
+ torch.zeros(
+ self.grid_dim_y,
+ self.grid_dim_x,
+ device=self.device,
+ dtype=torch.float,
+ )
)
def reset_parameters(self, morr_init: bool = False) -> None:
@@ -212,11 +234,16 @@ def reset_parameters(self, morr_init: bool = False) -> None:
torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
)
t2 = mrr_roundtrip_phase_to_tr_fused(
- torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
+ torch.tensor([self.morr_fwhm * 2.4]).float(),
+ a=self.mrr_a,
+ r=self.mrr_r,
+ intensity=True,
)
- g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation
+ g = (
+ (t2 - t1) / (2.4 * self.morr_fwhm)
+ ).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
@@ -260,7 +287,9 @@ def enable_fast_forward(self) -> None:
def disable_fast_forward(self) -> None:
self.fast_forward_flag = False
- def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
+ def set_gamma_noise(
+ self, noise_std: float, random_state: Optional[int] = None
+ ) -> None:
self.gamma_noise_std = noise_std
def load_parameters(self, param_dict) -> None:
@@ -284,7 +313,9 @@ def input_modulator(self, x: Tensor) -> Tensor:
### voltage to power, which is proportional to the phase shift
return x * x
- def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None:
+ def set_crosstalk_coupling_matrix(
+ self, coupling_factor: float, drop_perc: float = 0
+ ) -> None:
### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned.
### See SqueezeLight paper
### drop-perc is the pruning percentage.
@@ -292,7 +323,9 @@ def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float
f"Coupling factor must in [0,1], but got {coupling_factor}"
)
- self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor
+ self.crosstalk_factor = (
+ 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor
+ )
def enable_crosstalk(self) -> None:
self.enable_thermal_crosstalk = True
@@ -327,7 +360,9 @@ def morr_scale(self) -> Tensor:
@property
def morr_bias(self) -> Tensor:
- return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1))
+ return self.morr_fwhm * torch.tanh(
+ self.morr_input_bias.unsqueeze(0).unsqueeze(-1)
+ )
def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
"""
@@ -352,7 +387,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
x = weight.matmul(x).squeeze(-1) # [h*w*bs, p, q, k]
if self.enable_phase_noise and self.phase_noise_std > 1e-5:
- x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) # [h*w*bs, p, q, k]
+ x = x + torch.zeros_like(x).normal_(
+ 0, self.phase_noise_std
+ ) # [h*w*bs, p, q, k]
### input scaling, learnable MORR nonlinearity
if self.trainable_morr_scale:
@@ -370,7 +407,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
if self.w_bit < 16:
morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale)
if self.out_scale_quant_gain is None:
- self.out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item()
+ self.out_scale_quant_gain = (
+ self.sigma_out_scale / morr_output_scale.data.std().item()
+ )
morr_output_scale = morr_output_scale.mul(
self.out_scale_quant_gain
) ### gain factor from Tanh used in quantization
@@ -392,7 +431,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
scale = scale_pad
scale = scale.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, q]
- x = scale.matmul(x) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k]
+ x = scale.matmul(
+ x
+ ) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k]
x = x.view(x.size(0), -1).t() # [p*k, h_out*w_out*bs]
if self.out_channels_pad > self.out_channels:
x = x[: self.out_channels, :] # [outc, h_out*w_out*bs]
@@ -407,7 +448,12 @@ def morr_conv2d(self, X: Tensor, W: Tensor) -> Tensor:
X,
stride=self.stride[0],
padding=self.padding[0],
- w_size=(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]),
+ w_size=(
+ self.out_channels,
+ self.in_channels,
+ self.kernel_size[0],
+ self.kernel_size[1],
+ ),
)
## zero-padding X_col
if self.in_channels_pad > self.in_channels_flat:
@@ -450,8 +496,12 @@ def get_output_dim(self, img_height: int, img_width: int) -> Tuple[int, int]:
"""
get the output features size
"""
- h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1
- w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] + 1
+ h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[
+ 0
+ ] + 1
+ w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[
+ 1
+ ] + 1
return (int(h_out), int(w_out))
def forward(self, x: Tensor) -> Tensor:
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
index 4846fb763..9a3f483e5 100644
--- a/src/chop/nn/optical/modules/morr_linear.py
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -5,6 +5,7 @@
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2022-04-18 16:21:37
"""
+
from typing import Optional
import numpy as np
@@ -43,7 +44,7 @@ def __init__(
in_features: int,
out_features: int,
bias: bool = False,
- config = None,
+ config=None,
# miniblock: int = 4,
# ### mrr parameter
# MORRConfig=MORRConfig_20um_MQ,
@@ -66,7 +67,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi ** 2
+ self.gamma = np.pi / self.v_pi**2
self.w_bit = 32
self.in_bit = 32
@@ -77,12 +78,14 @@ def __init__(
self.mrr_a = morr_config.attenuation_factor
self.mrr_r = morr_config.coupling_factor
self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ)
- self.trainable_morr_scale = config.get("trainable_morr_scale", MORRConfig_20um_MQ)
+ self.trainable_morr_scale = config.get(
+ "trainable_morr_scale", MORRConfig_20um_MQ
+ )
self.device = device
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* morr_config.radius
* morr_config.effective_index
* (
@@ -96,7 +99,9 @@ def __init__(
self.x_zero_pad = None
self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs
self.morr_input_bias = None ## round-trip phase shift bias within MORR
- self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR
+ self.morr_input_scale = (
+ None ## scaling factor for the round-trip phase shift within MORR
+ )
self.morr_gain = (
100 / (self.in_features // self.miniblock)
) ** 0.5 ## TIA gain, calculated such that output variance is around 1
@@ -137,7 +142,11 @@ def build_parameters(self) -> None:
self.weight = Parameter(
torch.ones(
- self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float
+ self.grid_dim_y,
+ self.grid_dim_x,
+ self.miniblock,
+ device=self.device,
+ dtype=torch.float,
)
)
### Learnable balancing factor (morr_output_scale)
@@ -148,12 +157,22 @@ def build_parameters(self) -> None:
if self.trainable_morr_bias:
### initialize with the finest-granularity, i.e., per mini-block
self.morr_input_bias = Parameter(
- torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
+ torch.zeros(
+ self.grid_dim_y,
+ self.grid_dim_x,
+ device=self.device,
+ dtype=torch.float,
+ )
)
if self.trainable_morr_scale:
### initialize with the finest-granularity, i.e., per mini-block
self.morr_input_scale = Parameter(
- torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float)
+ torch.zeros(
+ self.grid_dim_y,
+ self.grid_dim_x,
+ device=self.device,
+ dtype=torch.float,
+ )
)
def reset_parameters(self, morr_init: bool = False) -> None:
@@ -175,11 +194,16 @@ def reset_parameters(self, morr_init: bool = False) -> None:
torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
)
t2 = mrr_roundtrip_phase_to_tr_fused(
- torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True
+ torch.tensor([self.morr_fwhm * 2.4]).float(),
+ a=self.mrr_a,
+ r=self.mrr_r,
+ intensity=True,
)
- g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation
+ g = (
+ (t2 - t1) / (2.4 * self.morr_fwhm)
+ ).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
else:
@@ -218,13 +242,17 @@ def build_weight(self) -> Tensor:
morr_scale = self.morr_scale * self.weight_quant_gain
else:
morr_scale = self.weight_quant_gain
- weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization
+ weight = weight.mul(
+ morr_scale
+ ) ### gain factor from Tanh used in quantization
### quantize learnable balancing factor
morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale)
else:
weight = self.weight.abs() # positive only
- morr_output_scale = self.morr_output_scale - self.morr_output_scale.data.mean()
+ morr_output_scale = (
+ self.morr_output_scale - self.morr_output_scale.data.mean()
+ )
if self.finegrain_drop_mask is not None:
weight = weight.mul(self.finegrain_drop_mask.float())
@@ -251,7 +279,9 @@ def enable_fast_forward(self) -> None:
def disable_fast_forward(self) -> None:
self.fast_forward_flag = False
- def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None:
+ def set_gamma_noise(
+ self, noise_std: float, random_state: Optional[int] = None
+ ) -> None:
self.gamma_noise_std = noise_std
def load_parameters(self, param_dict) -> None:
@@ -275,14 +305,18 @@ def input_modulator(self, x: Tensor) -> Tensor:
### voltage to power, which is proportional to the phase shift
return x * x
- def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None:
+ def set_crosstalk_coupling_matrix(
+ self, coupling_factor: float, drop_perc: float = 0
+ ) -> None:
### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned.
### drop-perc is the pruning percentage.
assert 0 <= coupling_factor <= 1, logger.error(
f"Coupling factor must in [0,1], but got {coupling_factor}"
)
- self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor
+ self.crosstalk_factor = (
+ 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor
+ )
def enable_crosstalk(self) -> None:
self.enable_thermal_crosstalk = True
@@ -316,7 +350,9 @@ def morr_bias(self) -> Tensor:
if self.morr_input_bias is None:
return None
# return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1))
- return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1))
+ return self.morr_fwhm * torch.tanh(
+ self.morr_input_bias.unsqueeze(0).unsqueeze(-1)
+ )
@property
def morr_scale(self) -> Tensor:
@@ -324,7 +360,9 @@ def morr_scale(self) -> Tensor:
return None
return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1]
- def propagate_morr(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor:
+ def propagate_morr(
+ self, weight: Tensor, x: Tensor, morr_output_scale: Tensor
+ ) -> Tensor:
"""
@description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul
@param weight {torch.Tensor} two phase shifters in the MZI-based attenuators
@@ -420,7 +458,10 @@ def forward(self, x: Tensor) -> Tensor:
if self.in_features_pad > self.in_features:
if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0):
self.x_zero_pad = torch.zeros(
- x.size(0), self.in_features_pad - self.in_features, device=x.device, dtype=x.dtype
+ x.size(0),
+ self.in_features_pad - self.in_features,
+ device=x.device,
+ dtype=x.dtype,
)
x = torch.cat([x, self.x_zero_pad], dim=1)
diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py
index 8773589ec..caaff1b65 100644
--- a/src/chop/passes/module/transforms/__init__.py
+++ b/src/chop/passes/module/transforms/__init__.py
@@ -1,4 +1,4 @@
from .autosharding import resharding_transform_pass
from .quantize import quantize_module_transform_pass
from .autosharding import resharding_transform_pass
-from .optical import optical_module_transform_pass
\ No newline at end of file
+from .optical import optical_module_transform_pass
diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py
index dc110d9b0..1cef06928 100644
--- a/src/chop/passes/module/transforms/optical/module_transform_helper.py
+++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py
@@ -2,13 +2,13 @@
import torch.nn as nn
import numpy as np
from chop.nn.optical.modules import optical_module_map
-from chop.passes.module.module_modify_helper import get_module_by_name, set_module_by_name
+from chop.passes.module.module_modify_helper import (
+ get_module_by_name,
+ set_module_by_name,
+)
-def replace_by_name_optical(
- network,
- module_name: str,
- new_module
-):
+
+def replace_by_name_optical(network, module_name: str, new_module):
original = get_module_by_name(network, module_name)
updated_module = weight_replacement_optical(original, new_module)
@@ -16,60 +16,63 @@ def replace_by_name_optical(
return network
+
def weight_replacement_optical(original, new_module):
if isinstance(original, nn.Linear):
return weight_replacement_linear_optical(original, new_module)
elif isinstance(original, nn.Conv2d):
return weight_replacement_conv2d_optical(original, new_module)
else:
- raise NotImplementedError("weight replacement function for the optical module not implemented")
+ raise NotImplementedError(
+ "weight replacement function for the optical module not implemented"
+ )
+
def weight_replacement_linear_optical(x, y):
"""
- Replace the weights of AllPassMORRCirculantLinear (y)
+ Replace the weights of AllPassMORRCirculantLinear (y)
with those from a standard nn.Linear (x).
Focuses only on weight copying (no bias copying).
"""
# Fetch original linear weight [out_features, in_features]
W = x.weight.data # shape: (out_features, in_features)
-
+
# Grab dimensions and zero-pad if needed
- out_features_pad = y.out_features_pad # padded out_features in y
- in_features_pad = y.in_features_pad # padded in_features in y
- miniblock = y.miniblock
- grid_dim_y = y.grid_dim_y
- grid_dim_x = y.grid_dim_x
-
+ out_features_pad = y.out_features_pad # padded out_features in y
+ in_features_pad = y.in_features_pad # padded in_features in y
+ miniblock = y.miniblock
+ grid_dim_y = y.grid_dim_y
+ grid_dim_x = y.grid_dim_x
+
# Construct padded weight tensor
W_padded = W.new_zeros((out_features_pad, in_features_pad))
W_padded[: W.size(0), : W.size(1)] = W # copy original into top-left
-
+
# Now we create a new tensor of shape [grid_dim_y, grid_dim_x, miniblock]
# by compressing each row-block [1 x miniblock] from W_padded into a single scalar.
# This is a simple example that takes the mean across the miniblock slice.
new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock))
-
+
# Fill new_weight by averaging the corresponding sub-blocks in W_padded
with torch.no_grad():
for p in range(grid_dim_y):
for q in range(grid_dim_x):
for k in range(miniblock):
# The row in W_padded we look at:
- row_idx = p * miniblock + k
+ row_idx = p * miniblock + k
# The columns we look at:
col_start = q * miniblock
- col_end = (q + 1) * miniblock
-
+ col_end = (q + 1) * miniblock
+
block = W_padded[row_idx, col_start:col_end]
new_weight[p, q, k] = block.mean()
-
+
# Copy the result into y.weight
y.load_parameters({"weight": new_weight})
# y.weight.data.copy_(new_weight)
-
- return y
+ return y
def weight_replacement_conv2d_optical(x, y):
@@ -91,8 +94,8 @@ def weight_replacement_conv2d_optical(x, y):
w_flat = x.weight.data.view(x.out_channels, -1)
# 3) Zero-pad to match (out_channels_pad, in_channels_pad)
- outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock
- inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock
+ outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock
+ inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock
W = torch.zeros(outC_pad, inC_pad, device=w_flat.device, dtype=w_flat.dtype)
# Copy as many channels/elements as we have
@@ -113,4 +116,4 @@ def weight_replacement_conv2d_optical(x, y):
# Done. At this point, y.weight and y.bias (if present) have been overwritten
# with a simple block-circulant approximation of x's parameters.
- return y
\ No newline at end of file
+ return y
diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py
index a91e7403d..b911c070d 100644
--- a/src/chop/passes/module/transforms/optical/optical.py
+++ b/src/chop/passes/module/transforms/optical/optical.py
@@ -2,7 +2,9 @@
from chop.nn.optical.modules import optical_module_map
from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module
-from chop.passes.module.transforms.optical.module_transform_helper import replace_by_name_optical
+from chop.passes.module.transforms.optical.module_transform_helper import (
+ replace_by_name_optical,
+)
def get_config(config: dict, name: str):
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index 35c959a70..45b540273 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -50,6 +50,7 @@
# x = self.fc3(x)
# return x
+
def load_my_model(model_path, device="cpu"):
# Load the model from the .pt file
loaded_model = torch.load(model_path, map_location=device)
@@ -86,6 +87,7 @@ def test_optical_module_transform_pass():
optical_module_transform_pass(mnist_cnn, pass_args)
# torch.save(onn_cnn, "mase_output/onn_cnn.pt")
+
test_optical_module_transform_pass()
# if __name__ == '__main__':
@@ -175,15 +177,14 @@ def test_optical_module_transform_pass():
# test(onn_model, device, test_loader)
# scheduler.step()
-
+
# torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
-
+
# print("-------------- Testing the trained onn model -------------------")
# test(onn_model, device, test_loader)
# _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
# def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
# # pass_args = {
# # "by": "type",
diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py
index de3bbe385..3d32593d5 100644
--- a/test/passes/module/transforms/optical/train_mnist_cnn.py
+++ b/test/passes/module/transforms/optical/train_mnist_cnn.py
@@ -43,9 +43,15 @@ def train(args, model, device, train_loader, optimizer, epoch):
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
- print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- epoch, batch_idx * len(data), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item()))
+ print(
+ "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+ epoch,
+ batch_idx * len(data),
+ len(train_loader.dataset),
+ 100.0 * batch_idx / len(train_loader),
+ loss.item(),
+ )
+ )
if args.dry_run:
break
@@ -58,42 +64,95 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
- test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
- pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.argmax(
+ dim=1, keepdim=True
+ ) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
- print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
- test_loss, correct, len(test_loader.dataset),
- 100. * correct / len(test_loader.dataset)))
+ print(
+ "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+ test_loss,
+ correct,
+ len(test_loader.dataset),
+ 100.0 * correct / len(test_loader.dataset),
+ )
+ )
def main():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=29, metavar='N',
- help='number of epochs to train (default: 14)')
- parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
- help='learning rate (default: 1.0)')
- parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
- help='Learning rate step gamma (default: 0.7)')
- parser.add_argument('--no-cuda', action='store_true', default=False,
- help='disables CUDA training')
- parser.add_argument('--no-mps', action='store_true', default=False,
- help='disables macOS GPU training')
- parser.add_argument('--dry-run', action='store_true', default=False,
- help='quickly check a single pass')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--log-interval', type=int, default=10, metavar='N',
- help='how many batches to wait before logging training status')
- parser.add_argument('--save-model', action='store_true', default=True,
- help='For Saving the current Model')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=29,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--no-mps",
+ action="store_true",
+ default=False,
+ help="disables macOS GPU training",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=True,
+ help="For Saving the current Model",
+ )
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -107,24 +166,19 @@ def main():
else:
device = torch.device("cpu")
- train_kwargs = {'batch_size': args.batch_size}
- test_kwargs = {'batch_size': args.test_batch_size}
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
if use_cuda:
- cuda_kwargs = {'num_workers': 1,
- 'pin_memory': True,
- 'shuffle': True}
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
+ dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
+ dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net().to(device)
@@ -140,5 +194,5 @@ def main():
torch.save(model, "mase_output/sample_mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
From bd44a07035d04791254772ac7dae8c1cf54d762a Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 26 Jan 2025 22:13:58 +0000
Subject: [PATCH 04/38] fix optical test script
---
.../transforms/optical/test_optical_module.py | 64 +++---
.../transforms/optical/train_mnist_cnn.py | 198 ------------------
2 files changed, 37 insertions(+), 225 deletions(-)
delete mode 100644 test/passes/module/transforms/optical/train_mnist_cnn.py
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index 45b540273..f52acd559 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -26,29 +26,30 @@
from train_mnist_cnn import test, train, Net
-# --------------------------------------------------
-# Model specifications
-# --------------------------------------------------
-# class MLP(torch.nn.Module):
-# """
-# Toy quantized FC model for digit recognition on MNIST
-# """
-
-# def __init__(self) -> None:
-# super().__init__()
-
-# self.fc1 = nn.Linear(28 * 28, 28 * 28)
-# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
-# self.fc3 = nn.Linear(28 * 28 * 4, 10)
-
-# def forward(self, x):
-# x = torch.flatten(x, start_dim=1, end_dim=-1)
-# x = torch.nn.functional.relu(self.fc1(x))
-# # w = torch.randn((4, 28 * 28))
-# # x = torch.nn.functional.relu(nn.functional.linear(x, w))
-# x = torch.nn.functional.relu(self.fc2(x))
-# x = self.fc3(x)
-# return x
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
+ self.dropout1 = nn.Dropout(0.25)
+ self.dropout2 = nn.Dropout(0.5)
+ self.fc1 = nn.Linear(9216, 128)
+ self.fc2 = nn.Linear(128, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = self.dropout1(x)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ output = F.log_softmax(x, dim=1)
+ return output
def load_my_model(model_path, device="cpu"):
@@ -60,8 +61,9 @@ def load_my_model(model_path, device="cpu"):
def test_optical_module_transform_pass():
- model_path = "mase_output/sample_mnist_cnn.pt"
- mnist_cnn = load_my_model(model_path)
+ # model_path = "mase_output/sample_mnist_cnn.pt"
+ # mnist_cnn = load_my_model(model_path)
+ model = Net()
# Sanity check and report
pass_args = {
"by": "name",
@@ -84,12 +86,20 @@ def test_optical_module_transform_pass():
}
},
}
- optical_module_transform_pass(mnist_cnn, pass_args)
+ optical_module_transform_pass(model, pass_args)
# torch.save(onn_cnn, "mase_output/onn_cnn.pt")
-
test_optical_module_transform_pass()
+
+
+
+
+
+
+
+
+
# if __name__ == '__main__':
# finetune = False
diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py
deleted file mode 100644
index 3d32593d5..000000000
--- a/test/passes/module/transforms/optical/train_mnist_cnn.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import argparse
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torchvision import datasets, transforms
-from torch.optim.lr_scheduler import StepLR
-
-
-class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 32, 3, 1)
- self.conv2 = nn.Conv2d(32, 64, 3, 1)
- self.dropout1 = nn.Dropout(0.25)
- self.dropout2 = nn.Dropout(0.5)
- self.fc1 = nn.Linear(9216, 128)
- self.fc2 = nn.Linear(128, 10)
-
- def forward(self, x):
- x = self.conv1(x)
- x = F.relu(x)
- x = self.conv2(x)
- x = F.relu(x)
- x = F.max_pool2d(x, 2)
- x = self.dropout1(x)
- x = torch.flatten(x, 1)
- x = self.fc1(x)
- x = F.relu(x)
- x = self.dropout2(x)
- x = self.fc2(x)
- output = F.log_softmax(x, dim=1)
- return output
-
-
-def train(args, model, device, train_loader, optimizer, epoch):
- model.train()
- for batch_idx, (data, target) in enumerate(train_loader):
- data, target = data.to(device), target.to(device)
- optimizer.zero_grad()
- output = model(data)
- loss = F.nll_loss(output, target)
- loss.backward()
- optimizer.step()
- if batch_idx % args.log_interval == 0:
- print(
- "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
- epoch,
- batch_idx * len(data),
- len(train_loader.dataset),
- 100.0 * batch_idx / len(train_loader),
- loss.item(),
- )
- )
- if args.dry_run:
- break
-
-
-def test(model, device, test_loader):
- model.eval()
- test_loss = 0
- correct = 0
- with torch.no_grad():
- for data, target in test_loader:
- data, target = data.to(device), target.to(device)
- output = model(data)
- test_loss += F.nll_loss(
- output, target, reduction="sum"
- ).item() # sum up batch loss
- pred = output.argmax(
- dim=1, keepdim=True
- ) # get the index of the max log-probability
- correct += pred.eq(target.view_as(pred)).sum().item()
-
- test_loss /= len(test_loader.dataset)
-
- print(
- "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
- test_loss,
- correct,
- len(test_loader.dataset),
- 100.0 * correct / len(test_loader.dataset),
- )
- )
-
-
-def main():
- # Training settings
- parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
- parser.add_argument(
- "--batch-size",
- type=int,
- default=64,
- metavar="N",
- help="input batch size for training (default: 64)",
- )
- parser.add_argument(
- "--test-batch-size",
- type=int,
- default=1000,
- metavar="N",
- help="input batch size for testing (default: 1000)",
- )
- parser.add_argument(
- "--epochs",
- type=int,
- default=29,
- metavar="N",
- help="number of epochs to train (default: 14)",
- )
- parser.add_argument(
- "--lr",
- type=float,
- default=1.0,
- metavar="LR",
- help="learning rate (default: 1.0)",
- )
- parser.add_argument(
- "--gamma",
- type=float,
- default=0.7,
- metavar="M",
- help="Learning rate step gamma (default: 0.7)",
- )
- parser.add_argument(
- "--no-cuda", action="store_true", default=False, help="disables CUDA training"
- )
- parser.add_argument(
- "--no-mps",
- action="store_true",
- default=False,
- help="disables macOS GPU training",
- )
- parser.add_argument(
- "--dry-run",
- action="store_true",
- default=False,
- help="quickly check a single pass",
- )
- parser.add_argument(
- "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
- )
- parser.add_argument(
- "--log-interval",
- type=int,
- default=10,
- metavar="N",
- help="how many batches to wait before logging training status",
- )
- parser.add_argument(
- "--save-model",
- action="store_true",
- default=True,
- help="For Saving the current Model",
- )
- args = parser.parse_args()
- use_cuda = not args.no_cuda and torch.cuda.is_available()
- use_mps = not args.no_mps and torch.backends.mps.is_available()
-
- torch.manual_seed(args.seed)
-
- if use_cuda:
- device = torch.device("cuda")
- elif use_mps:
- device = torch.device("mps")
- else:
- device = torch.device("cpu")
-
- train_kwargs = {"batch_size": args.batch_size}
- test_kwargs = {"batch_size": args.test_batch_size}
- if use_cuda:
- cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
- train_kwargs.update(cuda_kwargs)
- test_kwargs.update(cuda_kwargs)
-
- transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
- )
- dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
- dataset2 = datasets.MNIST("../data", train=False, transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
- test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
-
- model = Net().to(device)
- optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
-
- scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
- for epoch in range(1, args.epochs + 1):
- train(args, model, device, train_loader, optimizer, epoch)
- test(model, device, test_loader)
- scheduler.step()
-
- if args.save_model:
- torch.save(model, "mase_output/sample_mnist_cnn.pt")
-
-
-if __name__ == "__main__":
- main()
From 913d8faeb8f8eb85340bd79b460c2995a7327108 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 26 Jan 2025 22:47:57 +0000
Subject: [PATCH 05/38] black format fixed
---
.../module/transforms/optical/test_optical_module.py | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index f52acd559..2041b9ccf 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -26,6 +26,7 @@
from train_mnist_cnn import test, train, Net
+
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
@@ -89,15 +90,8 @@ def test_optical_module_transform_pass():
optical_module_transform_pass(model, pass_args)
# torch.save(onn_cnn, "mase_output/onn_cnn.pt")
-test_optical_module_transform_pass()
-
-
-
-
-
-
-
+test_optical_module_transform_pass()
# if __name__ == '__main__':
From b681e30fc5c6bec4641967fe5b82ec659d2152c5 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 26 Jan 2025 23:06:04 +0000
Subject: [PATCH 06/38] remove unnecessary import
---
test/passes/module/transforms/optical/test_optical_module.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index 2041b9ccf..e61489f55 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -24,7 +24,7 @@
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
-from train_mnist_cnn import test, train, Net
+# from train_mnist_cnn import test, train, Net
class Net(nn.Module):
From 3d99177d960bba882923931905932c8bc7b429db Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Tue, 28 Jan 2025 15:44:56 +0000
Subject: [PATCH 07/38] fix cuda error for morr_conv2d and morr_linear
---
src/chop/nn/optical/modules/morr_conv2d.py | 2 +-
src/chop/nn/optical/modules/morr_linear.py | 2 +-
.../module/transforms/optical/test_optical_module.py | 11 -----------
3 files changed, 2 insertions(+), 13 deletions(-)
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index a554bc6c2..3cc56ad2f 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -81,7 +81,7 @@ def __init__(
# trainable_morr_bias: bool = False,
# trainable_morr_scale: bool = False,
config=None,
- device: Device = torch.device("cuda"),
+ device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> None:
super(AllPassMORRCirculantConv2d, self).__init__()
miniblock = config.get("miniblock", 4)
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
index 9a3f483e5..af85357a2 100644
--- a/src/chop/nn/optical/modules/morr_linear.py
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -52,7 +52,7 @@ def __init__(
# ### trainable MORR nonlinearity
# trainable_morr_bias: bool = False,
# trainable_morr_scale: bool = False,
- device: Device = torch.device("cuda"),
+ device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> None:
super(AllPassMORRCirculantLinear, self).__init__()
self.in_features = in_features
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index e61489f55..d05512cc1 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -14,17 +14,6 @@
# from chop.passes.module.transforms import quantize_module_transform_pass
from chop.passes.module.transforms import optical_module_transform_pass
-from chop.passes.module import report_trainable_parameters_analysis_pass
-
-import argparse
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torchvision import datasets, transforms
-from torch.optim.lr_scheduler import StepLR
-
-# from train_mnist_cnn import test, train, Net
class Net(nn.Module):
From 09da13d3c1445a2a9284c9dd8d3e2b7ca3d58da7 Mon Sep 17 00:00:00 2001
From: Cheng Zhang
Date: Tue, 11 Feb 2025 16:19:00 +0000
Subject: [PATCH 08/38] reduce sphinx docstring error
---
.../documentation/tutorials/advanced/cli.md | 18 +--
.../advanced/onnxrt_quantization_tutorial.md | 69 ++++-----
.../tensorRT_quantization_tutorial.md | 137 +++++++++---------
.../mixed_precision_search_on_manual_model.md | 16 +-
.../mixed_precision_search_on_mase_graph.md | 6 +-
.../tutorials/cli/simple_train_flow.md | 8 +-
.../developer/Add-model-to-machop.md | 6 +-
.../modules/hardware/activations/gelu.md | 2 +-
.../modules/hardware/activations/selu.md | 4 +-
.../modules/hardware/activations/softplus.md | 2 +-
.../modules/hardware/linear/fixed_linear.md | 4 +-
.../systolic_modules/output_stationary.md | 2 +-
.../modules/labs_2024/lab_0_introduction.rst | 2 +-
13 files changed, 137 insertions(+), 139 deletions(-)
diff --git a/docs/source/modules/documentation/tutorials/advanced/cli.md b/docs/source/modules/documentation/tutorials/advanced/cli.md
index 7e418f7a6..0b126316c 100644
--- a/docs/source/modules/documentation/tutorials/advanced/cli.md
+++ b/docs/source/modules/documentation/tutorials/advanced/cli.md
@@ -16,7 +16,7 @@ In this case, we can try a toymodel, the command looks like the following
```bash
# assuming you you are at the our-stuff/mase directory
-cd src
+cd src
./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3
```
@@ -24,7 +24,7 @@ cd src
You can fetch all command-line arguments:
-```bash
+```text
[nix-shell:~/Projects/mase/src]$ ./ch -help
INFO Set logging level to debug
WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda.
@@ -180,7 +180,7 @@ This directory includes
### Training Logs
-MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
+MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
Run Tensorboard to visualise the logs using:
@@ -213,7 +213,7 @@ To test the model trained above you can use:
```bash
# After training, you will have your checkpoint under mase-tools/mase_output
-# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
+# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt```
@@ -383,7 +383,7 @@ cd machop
```
When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna.
-```txt
+```text
Best trial(s):
| | number | software_metrics | hardware_metrics | scaled_metrics |
|----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------|
@@ -401,7 +401,7 @@ Here is part of the `log.json` recording all search details.
For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands.
-```json
+```text
{
"0":{
"number":0,
@@ -481,13 +481,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for
### Commands
-First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation.
+First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to *Run the train action with the CLI* for a more detailed explanation.
The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search.
```bash
-cd src
+cd src
./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0
```
@@ -628,7 +628,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck
Here is part of the `log.json`
-```json
+```text
{
"0":{
"number":0,
diff --git a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
index e01c120b4..255697394 100644
--- a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
+++ b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
@@ -54,7 +54,7 @@ set_logging_verbosity("info")
I0329 13:46:21.531338 140128553666368 logger.py:44] Set logging level to info
-We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see [Section 2](#section-2-int8-quantization))
+We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see Section 2).
```python
@@ -120,8 +120,8 @@ mg = MaseGraph(model=model)
Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. You may want to switch to GPU for this task - it will not affect the cpu optimizations later on.
-```python
-!ch train --config {JSC_TOML_PATH} --accelerator gpu
+```bash
+ch train --config {JSC_TOML_PATH} --accelerator gpu
```
Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory.
@@ -154,7 +154,7 @@ We then run the `onnx_runtime_interface_pass` which completes the optimizations
- `onnx_path` (the optimized model)
- `onnx_dynamic_quantized_path` (the dynamically )
-In this case, since we are not quantizing the model, only the `onnx_path` is available.
+In this case, since we are not quantizing the model, only the `onnx_path` is available.
The models are also stored in the directory:
```
@@ -178,7 +178,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config)
I0327 14:20:12.539771 140012160939840 onnx_runtime.py:50] Project will be created at /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx[0m
I0327 14:20:12.751212 140012160939840 onnx_runtime.py:68] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
@@ -194,7 +194,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config)
| 9 | /seq_blocks.9/BatchNormalization | BatchNormalization | /seq_blocks.8/Gemm_output_0, seq_blocks.9.weight, seq_blocks.9.bias, seq_blocks.9.running_mean, seq_blocks.9.running_var | /seq_blocks.9/BatchNormalization_output_0 | epsilon, momentum |
| 10 | /seq_blocks.10/Relu | Relu | /seq_blocks.9/BatchNormalization_output_0 | 37 | |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+[0m
- I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary:
+ I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary:
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
@@ -237,7 +237,7 @@ _, _ = runtime_analysis_pass(mg_original, pass_args=runtime_analysis_config)
| Average GPU Power Usage | 21.816 W |
| Inference Energy Consumption | 0.0048292 mWh |
+------------------------------+---------------+[0m
- I0327 14:20:19.793779 140012160939840 analysis.py:398]
+ I0327 14:20:19.793779 140012160939840 analysis.py:398]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -280,7 +280,7 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_
| Average GPU Power Usage | 21.575 W |
| Inference Energy Consumption | 0.0013275 mWh |
+------------------------------+---------------+[0m
- I0327 14:20:35.876071 140012160939840 analysis.py:398]
+ I0327 14:20:35.876071 140012160939840 analysis.py:398]
Results jsc-toy-onnx:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -298,14 +298,14 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_
I0327 14:20:35.878773 140012160939840 analysis.py:84] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-27/onnx/version_0/model.json
-As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT.
+As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT.
Lets now run the same optimzations, this time using a GPU and a larger model - the `vgg7`. We will also utilse the chop action from the terminal which runs the same `onnx_runtime_interface_pass` pass.
First lets train the `vgg7` model using the machop `train` action with the config from the new toml file and then load the trained checkpoint it into the `transform` pass.
-```python
+```bash
VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
# !ch train --config {VGG_TOML_PATH}
@@ -313,7 +313,7 @@ VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
# Load in the checkpoint from the previous train - modify accordingly
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
-!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
+ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
```
[2024-03-28 23:09:44,122] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
@@ -387,7 +387,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0328 23:10:35.119032 140014036379456 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx[0m
I0328 23:10:43.779212 140014036379456 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -414,7 +414,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | |
| 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+[0m
- I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary:
+ I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -485,7 +485,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 50.352 W |
| Inference Energy Consumption | 0.12032 mWh |
+------------------------------+-------------+[0m
- I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521]
+ I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -521,7 +521,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 55.829 W |
| Inference Energy Consumption | 0.094099 mWh |
+------------------------------+--------------+[0m
- I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521]
+ I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -560,7 +560,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 54.555 W |
| Inference Energy Consumption | 0.14388 mWh |
+------------------------------+-------------+[0m
- I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521]
+ I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -599,7 +599,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 22.86 W |
| Inference Energy Consumption | 0.1748 mWh |
+------------------------------+------------+[0m
- I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521]
+ I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -635,7 +635,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 50.354 W |
| Inference Energy Consumption | 0.076714 mWh |
+------------------------------+--------------+[0m
- I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521]
+ I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -653,13 +653,13 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0328 23:18:33.855542 140014036379456 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/onnx/version_12/model.json
-As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT.
+As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT.
-We will now look at quantization to further speed up the model.
+We will now look at quantization to further speed up the model.
## Section 2. Quantization
-We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance.
+We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance.
There are three types of quantization for ONNXRT and can be set in `onnxruntime.default.config` under `quantization_types`. The differences of the first two are for how they calibrate i.e. set the scale and zero points which are only relevant for integer based quantization:
- **Static Quantization**:
@@ -670,26 +670,27 @@ There are three types of quantization for ONNXRT and can be set in `onnxruntime.
- The scale and zero point of activations are calculated on-the-fly (online) and are specific for each forward pass.
- This approach is more accurate but introduces extra computational overhead
-The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32.
+The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32.
- **Auto Mixed Precision Quantization**:
- Automatically adjusts between FP16 and FP32 precisions to retain certain level of accuracy
- The `precision` parameter does not need to be set in the config since the whole process is automatic.
- Unfortunately, this process is currently only supported on GPU.
- This approach is most beneficial when INT8 or FP16 exclusive quantizations (static or dynamic) are giving poor results.
-All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory.
+All three methodologies first pre-process the model before quantization, adding further optimizations. This intermediate model is stored to the `pre-processed` directory.
-For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider).
+For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider).
We will also set the `precision_types` to `['static', 'dynamic', 'auto']` to compare all three quantization methods, whilst keeping the other settings the exact same for a fair comparison against the optimized `vgg7` model used in the previous section.
-```python
+```bash
VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
```
+```text
[2024-03-29 13:49:26,029] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
WARNING: Logging before flag parsing goes to stderr.
@@ -761,7 +762,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0329 13:50:32.508896 139783521261376 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx[0m
I0329 13:50:53.587861 139783521261376 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -788,7 +789,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | |
| 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+[0m
- I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary:
+ I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -859,7 +860,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 52.579 W |
| Inference Energy Consumption | 0.125 mWh |
+------------------------------+-----------+[0m
- I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521]
+ I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-----------+
| Metric (Per Batch) | Value |
@@ -895,7 +896,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 53.26 W |
| Inference Energy Consumption | 0.089695 mWh |
+------------------------------+--------------+[0m
- I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521]
+ I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -934,7 +935,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 58.211 W |
| Inference Energy Consumption | 0.14364 mWh |
+------------------------------+-------------+[0m
- I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521]
+ I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -973,7 +974,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 22.924 W |
| Inference Energy Consumption | 0.0742 mWh |
+------------------------------+------------+[0m
- I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521]
+ I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -1009,7 +1010,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 49.313 W |
| Inference Energy Consumption | 0.067374 mWh |
+------------------------------+--------------+[0m
- I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521]
+ I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -1029,6 +1030,6 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0329 14:06:04.342311 139783521261376 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/vgg7_cls_cifar10_2024-03-29/software/transform/transformed_ckpt
[32mINFO [0m [34mTransformation is completed[0m
I0329 14:06:04.342653 139783521261376 cli.py:388] Transformation is completed
+```
-
-As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly.
+As we can see, the optimized ONNX model still outperforms PyTorch on the VGG model due to its runtime optimizations. The static performs the best, then the automatic mixed precision, which outperforms the dynamic quantization due to the latter's requirement of calculating activations on-the-fly.
diff --git a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
index b208f6473..5de3d3c22 100644
--- a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
+++ b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
@@ -5,11 +5,11 @@ This notebook is designed to show the features of the TensorRT passes integrated
## Section 1. INT8 Quantization
Firstly, we will show you how to do a int8 quantization of a simple model, `jsc-toy`, and compare the quantized model to the original model using the `Machop API`. The quantization process is split into the following stages, each using their own individual pass, and are explained in depth at each subsection:
-1. [Fake quantization](#section-11-fake-quantization): `tensorrt_fake_quantize_transform_pass`
-2. [Calibration](#section-12-calibration): `tensorrt_calibrate_transform_pass`
-3. [Quantized Aware Training](#section-13-quantized-aware-training-qat): `tensorrt_fine_tune_transform_pass`
-4. [Quantization](#section-14-tensorrt-quantization): `tensorrt_engine_interface_pass`
-5. [Analysis](#section-15-performance-analysis): `tensorrt_analysis_pass`
+1. Fake quantization: `tensorrt_fake_quantize_transform_pass`
+2. Calibration: `tensorrt_calibrate_transform_pass`
+3. Quantized Aware Training: `tensorrt_fine_tune_transform_pass`
+4. Quantization: `tensorrt_engine_interface_pass`
+5. Analysis: `tensorrt_analysis_pass`
We start by loading in the required libraries and passes required for the notebook as well as ensuring the correct path is set for machop to be used.
@@ -63,7 +63,7 @@ set_logging_verbosity("info")
I0329 12:52:20.742465 139924298352448 logger.py:44] Set logging level to info
-Next, we load in the toml file used for quantization. To view the configuration, click [here](../../../machop/configs/tensorrt/jsc_toy_INT8_quantization_by_type.toml).
+Next, we load in the toml file used for quantization.
```python
@@ -136,8 +136,8 @@ mg = MaseGraph(model=model)
Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file.
-```python
-!ch train --config {JSC_TOML_PATH}
+```bash
+ch train --config {JSC_TOML_PATH}
```
Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory.
@@ -169,12 +169,12 @@ mg_original = deepcopy_mase_graph(mg)
Firstly, we fake quantize the module in order to perform calibration and fine tuning before actually quantizing - this is only used if we have int8 calibration as other precisions are not currently supported within [pytorch-quantization](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#) library.
-This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis.
+This is achieved through the `tensorrt_fake_quantize_transform_pass`, which goes through the model, either by type or by name, and replaces each layer appropriately with a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per-name or per-type basis.
Currently the quantizable layers are:
- Linear
-- Conv1d, Conv2d, Conv3d
-- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+- Conv1d, Conv2d, Conv3d
+- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
- MaxPool1d, MaxPool2d, MaxPool3d
- AvgPool1d, AvgPool2d, AvgPool3d
- LSTM, LSTMCell
@@ -202,7 +202,7 @@ summarize_quantization_analysis_pass(mg_original, mg)
| ReLU | relu | 4 | 0 | 4 |
| output | output | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0329 12:52:41.851252 139924298352448 summary.py:85]
+ I0329 12:52:41.851252 139924298352448 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d | batch_norm1d | 4 | 0 | 4 |
@@ -212,16 +212,16 @@ summarize_quantization_analysis_pass(mg_original, mg)
| x | placeholder | 1 | 0 | 1 |
-As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See [Section 4](#section-4-layer-wise-mixed-precision) for how to apply quantization layerwise - i.e. only first and second layers for example.
+As you can see we have successfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done typewise i.e. for linear layers only. See Section 4 for how to apply quantization layerwise - i.e. only first and second layers for example.
### Section 1.2 Calibration
-Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations.
+Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations.
Calibrators can be added as a search space parameter to examine the best performing calibrator. The calibrators have been included in the toml as follows.
For example: `calibrators = ["percentile", "mse", "entropy"]`
-Note:
+Note:
- To use `percentile` calibration, a list of percentiles must be given
- To use `max` calibration, the `histogram` weight and input calibrators must be removed and replaced with `max`. This will use global maximum absolute value to calibrate the model.
- If `post_calibration_analysis` is set true the `tensorrt_analysis_pass` will be run for each calibrator tested to evaluate the most suitable calibrator for the model.
@@ -300,7 +300,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.239 W |
| Inference Energy Consumption | 0.018661 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -354,7 +354,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.21 W |
| Inference Energy Consumption | 0.018664 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -408,7 +408,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.252 W |
| Inference Energy Consumption | 0.017687 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -462,7 +462,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.426 W |
| Inference Energy Consumption | 0.018681 mWh |
+------------------------------+--------------+[0m
- I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521]
+ I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -516,7 +516,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.525 W |
| Inference Energy Consumption | 0.018149 mWh |
+------------------------------+--------------+[0m
- I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521]
+ I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -538,7 +538,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
I0329 12:53:22.701544 139924298352448 calibrate.py:213] Succeeded in calibrating the model in PyTorch!
-From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well.
+From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well.
### Section 1.3 Quantized Aware Training (QAT)
@@ -576,14 +576,14 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config)
I0329 12:53:59.800536 139924298352448 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
- I0329 12:53:59.814722 139924298352448 model_summary.py:94]
+ I0329 12:53:59.814722 139924298352448 model_summary.py:94]
| Name | Type | Params
-------------------------------------------------
- 0 | model | GraphModule | 327
- 1 | loss_fn | CrossEntropyLoss | 0
- 2 | acc_train | MulticlassAccuracy | 0
- 3 | loss_val | MeanMetric | 0
- 4 | loss_test | MeanMetric | 0
+ 0 | model | GraphModule | 327
+ 1 | loss_fn | CrossEntropyLoss | 0
+ 2 | acc_train | MulticlassAccuracy | 0
+ 3 | loss_val | MeanMetric | 0
+ 4 | loss_test | MeanMetric | 0
-------------------------------------------------
327 Trainable params
0 Non-trainable params
@@ -620,7 +620,7 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config)
After QAT, we are now ready to convert the model to a tensorRT engine so that it can be run with the superior inference speeds. To do so, we use the `tensorrt_engine_interface_pass` which converts the `MaseGraph`'s model from a Pytorch one to an ONNX format as an intermediate stage of the conversion.
-During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in [Section 1.3](#section-13-quantized-aware-training-qat).
+During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in Section 1.3.
This interface pass returns a dictionary containing the `onnx_path` and `trt_engine_path`.
@@ -677,7 +677,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
| Average GPU Power Usage | 23.792 W |
| Inference Energy Consumption | 0.0057535 mWh |
+------------------------------+---------------+[0m
- I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521]
+ I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -709,7 +709,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
| Average GPU Power Usage | 23.043 W |
| Inference Energy Consumption | 0.00085532 mWh |
+------------------------------+----------------+[0m
- I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521]
+ I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521]
Results jsc-toy-trt_quantized:
+------------------------------+----------------+
| Metric (Per Batch) | Value |
@@ -727,19 +727,18 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
I0329 13:03:34.506492 139924298352448 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-29/tensorrt/version_0/model.json
-As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage.
+As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage.
## Section 2. FP16 Quantization
-We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in [Section 1](#section-1---int8-quantization).
+We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in Section 1.
-Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ).
+Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ).
-```python
+```text
JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantization_by_type.toml"
!ch transform --config {JSC_FP16_BY_TYPE_TOML} --load {JSC_CHECKPOINT_PATH} --load-type pl
-```
8808.24s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
[2024-03-28 09:37:03,989] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
@@ -811,7 +810,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| ReLU | relu | 4 | 0 | 4 |
| output | output | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 09:37:11.551648 140201001654080 summary.py:85]
+ I0328 09:37:11.551648 140201001654080 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d | batch_norm1d | 4 | 0 | 4 |
@@ -849,7 +848,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| Average GPU Power Usage | 22.024 W |
| Inference Energy Consumption | 0.0049148 mWh |
+------------------------------+---------------+[0m
- I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437]
+ I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -871,7 +870,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
------|---------|----------|----------------------|----------------------|-----------------------
0 | Input | FLOAT | (64, 16) | (64, 16) | input
1 | Output | FLOAT | (64, 5) | (64, 5) | 37[0m
- I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167]
+ I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167]
TensorRT Engine Input/Output Information:
Index | Type | DataType | Static Shape | Dynamic Shape | Name
------|---------|----------|----------------------|----------------------|-----------------------
@@ -893,7 +892,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| Average GPU Power Usage | 21.706 W |
| Inference Energy Consumption | 0.00055067 mWh |
+------------------------------+----------------+[0m
- I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437]
+ I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437]
Results jsc-toy-trt_quantized:
+------------------------------+----------------+
| Metric (Per Batch) | Value |
@@ -913,7 +912,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
I0328 09:37:43.132117 140201001654080 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/jsc-toy_cls_jsc_2024-03-28/software/transform/transformed_ckpt
[32mINFO [0m [34mTransformation is completed[0m
I0328 09:37:43.132461 140201001654080 cli.py:383] Transformation is completed
-
+```
As you can see, `fp16` achieves a slightly higher test accuracy but a slightly lower latency (~30%) from that of int8 quantization; it is still ~2.5x faster than the unquantized model. Now lets apply quantization to a more complicated model.
@@ -925,31 +924,30 @@ In this case, we set:
- The `by` parameter to `type`
- The `quantize` parameter to true for `passes.tensorrt.conv2d.config` and `precision` parameter to 'int8'.
- The `input` and `weight` quantize axis for the conv2d layers.
-- The default `passes.tensorrt.default.config` precision to true.
+- The default `passes.tensorrt.default.config` precision to true.
During the TensorRT quantization, the model's conv2d layers will be converted to an int8 fake quantized form, whilst the linear layers are kept to their default 'fp16'. Calibration of the conv2d layers and then fine tuning will be undergone before quantization and inference.
-You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below.
+You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below.
-```python
+```bash
VGG_TYPEWISE_TOML = "../../../machop/configs/tensorrt/vgg7_typewise_mixed_precision.toml"
-!ch train --config {VGG_TYPEWISE_TOML}
+ch train --config {VGG_TYPEWISE_TOML}
```
-We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in [Section 1.5](#section-15-performance-analysis)
+We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in Section 1.5
-```python
+```bash
# Change this checkpoint path accordingly
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
```
-```python
-!ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
-```
+```text
+ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
[2024-03-28 23:00:09,016] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
@@ -1033,7 +1031,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| output | output | 1 | 0 | 1 |
| view | view | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 23:00:37.270473 139939454809920 summary.py:85]
+ I0328 23:00:37.270473 139939454809920 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm2d | batch_norm2d | 6 | 0 | 6 |
@@ -1161,7 +1159,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 59.019 W |
| Inference Energy Consumption | 0.24777 mWh |
+------------------------------+-------------+[0m
- I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521]
+ I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1233,7 +1231,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 59.653 W |
| Inference Energy Consumption | 0.25317 mWh |
+------------------------------+-------------+[0m
- I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521]
+ I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1270,14 +1268,14 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
Files already downloaded and verified
Files already downloaded and verified
I0328 23:01:12.623627 139939454809920 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
- I0328 23:01:12.632704 139939454809920 model_summary.py:94]
+ I0328 23:01:12.632704 139939454809920 model_summary.py:94]
| Name | Type | Params
-------------------------------------------------
0 | model | GraphModule | 14.0 M
- 1 | loss_fn | CrossEntropyLoss | 0
- 2 | acc_train | MulticlassAccuracy | 0
- 3 | loss_val | MeanMetric | 0
- 4 | loss_test | MeanMetric | 0
+ 1 | loss_fn | CrossEntropyLoss | 0
+ 2 | acc_train | MulticlassAccuracy | 0
+ 3 | loss_val | MeanMetric | 0
+ 4 | loss_test | MeanMetric | 0
-------------------------------------------------
14.0 M Trainable params
0 Non-trainable params
@@ -1645,7 +1643,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 58.043 W |
| Inference Energy Consumption | 0.1441 mWh |
+------------------------------+------------+[0m
- I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521]
+ I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -1677,7 +1675,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 52.687 W |
| Inference Energy Consumption | 0.12441 mWh |
+------------------------------+-------------+[0m
- I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
+ I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
Results vgg7-trt_quantized:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1693,13 +1691,13 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
+------------------------------+-------------+
[32mINFO [0m [34mRuntime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json[0m
I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json
+```
-
-By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model.
+By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model.
## Section 4. Layer-wise Mixed Precision
-So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized.
+So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized.
For this, we set:
- The `by` parameter to `name`
@@ -1708,11 +1706,10 @@ For this, we set:
- The `precision` to 'int8' for `passes.tensorrt.feature_layers_2.config and passes.tensorrt.feature_layers_3.config` (although this is not necessary since the default is already set to 'int8')
-```python
+```text
VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_precision.toml"
-!ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
-```
+ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
[2024-03-28 23:25:51,157] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
@@ -1796,7 +1793,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec
| output | output | 1 | 0 | 1 |
| view | view | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 23:26:12.941653 140449214740288 summary.py:85]
+ I0328 23:26:12.941653 140449214740288 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm2d | batch_norm2d | 6 | 0 | 6 |
@@ -1956,7 +1953,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec
| Average GPU Power Usage | 57.532 W |
| Inference Energy Consumption | 0.28607 mWh |
+------------------------------+-------------+[0m
- I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521]
+ I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2040,7 +2037,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec
| Average GPU Power Usage | 57.867 W |
| Inference Energy Consumption | 0.29145 mWh |
+------------------------------+-------------+[0m
- I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521]
+ I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2099,7 +2096,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec
| Average GPU Power Usage | 55.687 W |
| Inference Energy Consumption | 0.12102 mWh |
+------------------------------+-------------+[0m
- I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
+ I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
Results vgg7-trt_quantized:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2115,6 +2112,6 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec
+------------------------------+-------------+
[32mINFO [0m [34mRuntime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json[0m
I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json
-
+```
In this case, we can see through the quantized summary that one convolutional layer (feature_layers_1) has not been quantized as its precision will be configured to 'fp16' in the tensorrt engine conversion stage whilst the remaining convolutional and linear layers have been quantized.
diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
index 8f4d69cf1..9fbb14c38 100644
--- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
+++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
@@ -82,7 +82,7 @@ cd machop
```
When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna.
-```txt
+```text
Best trial(s):
| | number | software_metrics | hardware_metrics | scaled_metrics |
|----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------|
@@ -100,7 +100,7 @@ Here is part of the `log.json` recording all search details.
For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands.
-```json
+```text
{
"0":{
"number":0,
@@ -120,8 +120,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"bias_width":2,
"bias_frac_width":8
}
- },
- ...
+ }
+ // ...
},
"user_attrs_scaled_metrics":{
"accuracy":0.5,
@@ -154,8 +154,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"bias_width":4,
"bias_frac_width":3
}
- },
- ...
+ }
+ // ...
},
"user_attrs_scaled_metrics":{
"accuracy":0.5747232437,
@@ -169,7 +169,7 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"datetime_start":1694095031290,
"datetime_complete":1694095032462,
"duration":1172
- },
- ...
+ }
+ // ...
}
```
\ No newline at end of file
diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
index e82d69d2d..10b309d90 100644
--- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
+++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
@@ -4,13 +4,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for
## Commands
-First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation.
+First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to Run the train action with the CLI for more detailed explanation.
The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search.
```bash
-cd src
+cd src
./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0
```
@@ -151,7 +151,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck
Here is part of the `log.json`
-```json
+```text
{
"0":{
"number":0,
diff --git a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
index 8f9eb8cb1..52c21ca40 100644
--- a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
+++ b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
@@ -12,7 +12,7 @@ In this case, we can try a toymodel, the command looks like the following
```bash
# assuming you you are at the our-stuff/mase directory
-cd src
+cd src
./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3
```
@@ -20,7 +20,7 @@ cd src
You can fetch all command-line arguments:
-```bash
+```text
[nix-shell:~/Projects/mase/src]$ ./ch -help
INFO Set logging level to debug
WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda.
@@ -176,7 +176,7 @@ This directory includes
## Training Logs
-MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
+MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
Run Tensorboard to visualise the logs using:
@@ -209,5 +209,5 @@ To test the model trained above you can use:
```bash
# After training, you will have your checkpoint under mase-tools/mase_output
-# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
+# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt```
diff --git a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
index a30a172fb..c7750e262 100644
--- a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
+++ b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
@@ -3,10 +3,10 @@
This document includes steps to add a new model into Machop
## Overall Structure
### Model
-All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in [\_\_init\_\_](%2E%2E%5Cmachop%5Cchop%5Cmodels%5C%5F%5Finit%5F%5F.py) file.
+All models that Machop supports are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model functions will be exported into a dictionary in `__init__` file.
### Command Line Interface
-[Command Line Interface (cli)](..\machop\chop\cli.py) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then.
+Command Line Interface (cli) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get functions, use the get-function to create a model, and do training then.
## What To Do
1. Find the GitHub repositories of the original paper, find the code that defines the models, and copy it into the right folder under **mase-tools\machop\chop\models**
@@ -20,7 +20,7 @@ All models that Machop support are defined inside **mase-tools/machop/chop/model
## Get model function
- **Info** should be used as one of the input variables. It is a dictionary that contains information about the dataset, e.g., number of classes; input image size.
-- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in [cli.py](..\machop\chop\cli.py) for more detail.
+- Other than **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in cli.py for more detail.
- function name of get-function should be in smaller case
- keys of the dictionary should also be in smaller case
diff --git a/docs/source/modules/hardware/activations/gelu.md b/docs/source/modules/hardware/activations/gelu.md
index cb2d094e8..25b2560b4 100644
--- a/docs/source/modules/hardware/activations/gelu.md
+++ b/docs/source/modules/hardware/activations/gelu.md
@@ -11,7 +11,7 @@ When the approximate argument is set to 'tanh', GELU is estimated with:
`GELU(x) = 0.5 * x * (1 + Tanh(2/π * (x + 0.044715 * x^3)))`
-### Parameters:
+## Parameters:
- `approximate` (str, optional): The GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'.
diff --git a/docs/source/modules/hardware/activations/selu.md b/docs/source/modules/hardware/activations/selu.md
index 8413420a1..882e5c8d5 100644
--- a/docs/source/modules/hardware/activations/selu.md
+++ b/docs/source/modules/hardware/activations/selu.md
@@ -8,7 +8,7 @@ where:
- α = 1.6732632423543772848170429916717
- scale = 1.0507009873554804934193349852946
-### Parameters
+## Parameters
- `inplace` (bool, optional): Can optionally do the operation in-place. Default: False.
@@ -30,7 +30,7 @@ A hybrid approach is used for implementing exponential function $e^{-|x|}$ for a
2. **Representation of Binary Number**:
The N-bit binary number $a = b_{N-1}b_{N-2}...b_1b_0$ is represented, where $b_0$ is the least significant bit, and each bit $b_i$ has a place value $p_i$ given by $p_i = 2^{-P} \times 2^i$.
-
+
3. **Exponential Computation**:
$e^{-a} = \prod e^{-p_i \times b_i}$
diff --git a/docs/source/modules/hardware/activations/softplus.md b/docs/source/modules/hardware/activations/softplus.md
index ed0066cfb..0d7f3e77a 100644
--- a/docs/source/modules/hardware/activations/softplus.md
+++ b/docs/source/modules/hardware/activations/softplus.md
@@ -9,7 +9,7 @@ where:
- Softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
- For numerical stability, the implementation reverts to the linear function when `input * β > threshold`.
-### Parameters:
+## Parameters:
- `beta` (int): The β value for the Softplus formulation. Default: 1.
- `threshold` (int): Values above this revert to a linear function. Default: 20.
diff --git a/docs/source/modules/hardware/linear/fixed_linear.md b/docs/source/modules/hardware/linear/fixed_linear.md
index 672ea765a..33ddf350e 100644
--- a/docs/source/modules/hardware/linear/fixed_linear.md
+++ b/docs/source/modules/hardware/linear/fixed_linear.md
@@ -32,7 +32,7 @@ The module has the following parameters, following the hardware metadata standar
| Parameter | Default Value | Definition |
|------------------------------ |-------------------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| DATA_IN_0_PARALLELISM_DIM_0 | 4 | Number of elements per transaction at the input interface. Dictates the number of transactions to compute the full layer. |
-| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see [Latency Analysis](#latency-analysis) below) |
+| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see Latency Analysis below) |
| DATA_OUT_0_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the output interface. |
| BIAS_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the bias interface. Dictates the number of fixed-point adders. |
@@ -56,7 +56,7 @@ The same process is repeated with the second input sub-vector $X_2$ and weight s
-## Latency Analysis
+## Latency Analysis
The time taken to compute a linear layer using the `fixed_linear` module, $L_{FL}$ can be broken down into 2 phases, the input driving phase $L_L$, and the pipeline unloading phase $L_U$ that begins after the last input beat is transferred.
diff --git a/docs/source/modules/hardware/systolic_modules/output_stationary.md b/docs/source/modules/hardware/systolic_modules/output_stationary.md
index d5c605a41..d0e5370b2 100644
--- a/docs/source/modules/hardware/systolic_modules/output_stationary.md
+++ b/docs/source/modules/hardware/systolic_modules/output_stationary.md
@@ -8,6 +8,6 @@ The MAC units in each PE perform the multiply-accumulate operation over 2 cycles

-### Systolic Module Driver
+## Systolic Module Driver
The Systolic Module Driver generates pulse signals in the format required to drive the read interface of an on-chip buffer such that data signals are made available with the required timing for the processing elements of a systolic module. This is achieved through a shift register of size BUFFER_SLOT_COUNT. After receiving a starting pulse, the least significant bit is set to 1. Subsequently, the register shifts after every shift pulse, up to a runtime-parametrizable pulse limit count parameter (this is set to the number of output features for the layer being executed). The driver should then pulse a subsequent BUFFER_SLOT_COUNT times until the register is flushed.
\ No newline at end of file
diff --git a/docs/source/modules/labs_2024/lab_0_introduction.rst b/docs/source/modules/labs_2024/lab_0_introduction.rst
index 0bde49047..896c0c769 100644
--- a/docs/source/modules/labs_2024/lab_0_introduction.rst
+++ b/docs/source/modules/labs_2024/lab_0_introduction.rst
@@ -46,7 +46,7 @@ TroubleShooting
You may find that you have to use `Python3.11` but Google Colab only provides `Python3.10`. In this case, you can use the following command to force the kernel ot use `Python3.11`:
-.. code-block:: python
+.. code-block:: text
#The code below installs 3.11 (assuming you now have 3.10 in colab) and restarts environment, so you can run your cells.
import sys #for version checker
From aac9097c1b2f00de5a092bb31cbd4bf230f6f351 Mon Sep 17 00:00:00 2001
From: Cheng Zhang
Date: Tue, 11 Feb 2025 16:45:31 +0000
Subject: [PATCH 09/38] undo docstrings changeRevert "reduce sphinx docstring
error"
This reverts commit 09da13d3c1445a2a9284c9dd8d3e2b7ca3d58da7.
---
.../documentation/tutorials/advanced/cli.md | 18 +--
.../advanced/onnxrt_quantization_tutorial.md | 69 +++++----
.../tensorRT_quantization_tutorial.md | 137 +++++++++---------
.../mixed_precision_search_on_manual_model.md | 16 +-
.../mixed_precision_search_on_mase_graph.md | 6 +-
.../tutorials/cli/simple_train_flow.md | 8 +-
.../developer/Add-model-to-machop.md | 6 +-
.../modules/hardware/activations/gelu.md | 2 +-
.../modules/hardware/activations/selu.md | 4 +-
.../modules/hardware/activations/softplus.md | 2 +-
.../modules/hardware/linear/fixed_linear.md | 4 +-
.../systolic_modules/output_stationary.md | 2 +-
.../modules/labs_2024/lab_0_introduction.rst | 2 +-
13 files changed, 139 insertions(+), 137 deletions(-)
diff --git a/docs/source/modules/documentation/tutorials/advanced/cli.md b/docs/source/modules/documentation/tutorials/advanced/cli.md
index 0b126316c..7e418f7a6 100644
--- a/docs/source/modules/documentation/tutorials/advanced/cli.md
+++ b/docs/source/modules/documentation/tutorials/advanced/cli.md
@@ -16,7 +16,7 @@ In this case, we can try a toymodel, the command looks like the following
```bash
# assuming you you are at the our-stuff/mase directory
-cd src
+cd src
./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3
```
@@ -24,7 +24,7 @@ cd src
You can fetch all command-line arguments:
-```text
+```bash
[nix-shell:~/Projects/mase/src]$ ./ch -help
INFO Set logging level to debug
WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda.
@@ -180,7 +180,7 @@ This directory includes
### Training Logs
-MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
+MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
Run Tensorboard to visualise the logs using:
@@ -213,7 +213,7 @@ To test the model trained above you can use:
```bash
# After training, you will have your checkpoint under mase-tools/mase_output
-# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
+# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt```
@@ -383,7 +383,7 @@ cd machop
```
When the search is done, the best quantization config will be printed out. Since we run a multi-objective search, there may be multiple best trials found by Optuna.
-```text
+```txt
Best trial(s):
| | number | software_metrics | hardware_metrics | scaled_metrics |
|----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------|
@@ -401,7 +401,7 @@ Here is part of the `log.json` recording all search details.
For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands.
-```text
+```json
{
"0":{
"number":0,
@@ -481,13 +481,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for
### Commands
-First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to *Run the train action with the CLI* for more detailed explanation.
+First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation.
The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search.
```bash
-cd src
+cd src
./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0
```
@@ -628,7 +628,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck
Here is part of the `log.json`
-```text
+```json
{
"0":{
"number":0,
diff --git a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
index 255697394..e01c120b4 100644
--- a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
+++ b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md
@@ -54,7 +54,7 @@ set_logging_verbosity("info")
I0329 13:46:21.531338 140128553666368 logger.py:44] Set logging level to info
-We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see Section 2).
+We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see [Section 2](#section-2-int8-quantization))
```python
@@ -120,8 +120,8 @@ mg = MaseGraph(model=model)
Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. You may want to switch to GPU for this task - it will not affect the cpu optimizations later on.
-```bash
-ch train --config {JSC_TOML_PATH} --accelerator gpu
+```python
+!ch train --config {JSC_TOML_PATH} --accelerator gpu
```
Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory.
@@ -154,7 +154,7 @@ We then run the `onnx_runtime_interface_pass` which completes the optimizations
- `onnx_path` (the optimized model)
- `onnx_dynamic_quantized_path` (the dynamically )
-In this case, since we are not quantizing the model, only the `onnx_path` is available.
+In this case, since we are not quantizing the model, only the `onnx_path` is available.
The models are also stored in the directory:
```
@@ -178,7 +178,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config)
I0327 14:20:12.539771 140012160939840 onnx_runtime.py:50] Project will be created at /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx[0m
I0327 14:20:12.751212 140012160939840 onnx_runtime.py:68] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
@@ -194,7 +194,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config)
| 9 | /seq_blocks.9/BatchNormalization | BatchNormalization | /seq_blocks.8/Gemm_output_0, seq_blocks.9.weight, seq_blocks.9.bias, seq_blocks.9.running_mean, seq_blocks.9.running_var | /seq_blocks.9/BatchNormalization_output_0 | epsilon, momentum |
| 10 | /seq_blocks.10/Relu | Relu | /seq_blocks.9/BatchNormalization_output_0 | 37 | |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+[0m
- I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary:
+ I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary:
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+
@@ -237,7 +237,7 @@ _, _ = runtime_analysis_pass(mg_original, pass_args=runtime_analysis_config)
| Average GPU Power Usage | 21.816 W |
| Inference Energy Consumption | 0.0048292 mWh |
+------------------------------+---------------+[0m
- I0327 14:20:19.793779 140012160939840 analysis.py:398]
+ I0327 14:20:19.793779 140012160939840 analysis.py:398]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -280,7 +280,7 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_
| Average GPU Power Usage | 21.575 W |
| Inference Energy Consumption | 0.0013275 mWh |
+------------------------------+---------------+[0m
- I0327 14:20:35.876071 140012160939840 analysis.py:398]
+ I0327 14:20:35.876071 140012160939840 analysis.py:398]
Results jsc-toy-onnx:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -298,14 +298,14 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_
I0327 14:20:35.878773 140012160939840 analysis.py:84] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-27/onnx/version_0/model.json
-As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT.
+As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT.
Let's now run the same optimizations, this time using a GPU and a larger model - the `vgg7`. We will also utilise the chop action from the terminal which runs the same `onnx_runtime_interface_pass` pass.
First let's train the `vgg7` model using the machop `train` action with the config from the new toml file and then load the trained checkpoint into the `transform` pass.
-```bash
+```python
VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
# !ch train --config {VGG_TOML_PATH}
@@ -313,7 +313,7 @@ VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
# Load in the checkpoint from the previous train - modify accordingly
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
-ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
+!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
```
[2024-03-28 23:09:44,122] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
@@ -387,7 +387,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
I0328 23:10:35.119032 140014036379456 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx[0m
I0328 23:10:43.779212 140014036379456 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -414,7 +414,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | |
| 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+[0m
- I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary:
+ I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -485,7 +485,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| Average GPU Power Usage | 50.352 W |
| Inference Energy Consumption | 0.12032 mWh |
+------------------------------+-------------+[0m
- I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521]
+ I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -521,7 +521,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| Average GPU Power Usage | 55.829 W |
| Inference Energy Consumption | 0.094099 mWh |
+------------------------------+--------------+[0m
- I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521]
+ I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -560,7 +560,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| Average GPU Power Usage | 54.555 W |
| Inference Energy Consumption | 0.14388 mWh |
+------------------------------+-------------+[0m
- I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521]
+ I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -599,7 +599,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| Average GPU Power Usage | 22.86 W |
| Inference Energy Consumption | 0.1748 mWh |
+------------------------------+------------+[0m
- I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521]
+ I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -635,7 +635,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
| Average GPU Power Usage | 50.354 W |
| Inference Energy Consumption | 0.076714 mWh |
+------------------------------+--------------+[0m
- I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521]
+ I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -653,13 +653,13 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p
I0328 23:18:33.855542 140014036379456 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/onnx/version_12/model.json
-As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT.
+As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT.
-We will now look at quantization to further speed up the model.
+We will now look at quantization to further speed up the model.
## Section 2. Quantization
-We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance.
+We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance.
There are three types of quantization for ONNXRT and can be set in `onnxruntime.default.config` under `quantization_types`. The differences of the first two are for how they calibrate i.e. set the scale and zero points which are only relevant for integer based quantization:
- **Static Quantization**:
@@ -670,27 +670,26 @@ There are three types of quantization for ONNXRT and can be set in `onnxruntime.
- The scale and zero point of activations are calculated on-the-fly (online) and are specific for each forward pass.
- This approach is more accurate but introduces extra computational overhead
-The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32.
+The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32.
- **Auto Mixed Precision Quantization**:
- Automatically adjusts between FP16 and FP32 precisions to retain certain level of accuracy
- The `precision` parameter does not need to be set in the config since the whole process is automatic.
- Unfortunately, this process is currently only supported on GPU.
- This approach is most beneficial when INT8 or FP16 exclusive quantizations (static or dynamic) are giving poor results.
-All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory.
+All three methodologies first pre-process the model before quantization, adding further optimizations. This intermediate model is stored to the `pre-processed` directory. 
-For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider).
+For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider).
We will also set the `precision_types` to `['static', 'dynamic', 'auto']` to compare all three quantization methods, whilst keeping the other settings the exact same for a fair comparison against the optimized `vgg7` model used in the previous section.
-```bash
+```python
VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml"
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl
```
-```text
[2024-03-29 13:49:26,029] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
WARNING: Logging before flag parsing goes to stderr.
@@ -762,7 +761,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0329 13:50:32.508896 139783521261376 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29
[32mINFO [0m [34mONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx[0m
I0329 13:50:53.587861 139783521261376 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx
- [32mINFO [0m [34mONNX Model Summary:
+ [32mINFO [0m [34mONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -789,7 +788,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | |
| 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+[0m
- I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary:
+ I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary:
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
| Index | Name | Type | Inputs | Outputs | Attributes |
+-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+
@@ -860,7 +859,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 52.579 W |
| Inference Energy Consumption | 0.125 mWh |
+------------------------------+-----------+[0m
- I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521]
+ I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-----------+
| Metric (Per Batch) | Value |
@@ -896,7 +895,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 53.26 W |
| Inference Energy Consumption | 0.089695 mWh |
+------------------------------+--------------+[0m
- I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521]
+ I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -935,7 +934,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 58.211 W |
| Inference Energy Consumption | 0.14364 mWh |
+------------------------------+-------------+[0m
- I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521]
+ I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -974,7 +973,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 22.924 W |
| Inference Energy Consumption | 0.0742 mWh |
+------------------------------+------------+[0m
- I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521]
+ I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -1010,7 +1009,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
| Average GPU Power Usage | 49.313 W |
| Inference Energy Consumption | 0.067374 mWh |
+------------------------------+--------------+[0m
- I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521]
+ I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521]
Results vgg7-onnx:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -1030,6 +1029,6 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck
I0329 14:06:04.342311 139783521261376 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/vgg7_cls_cifar10_2024-03-29/software/transform/transformed_ckpt
[32mINFO [0m [34mTransformation is completed[0m
I0329 14:06:04.342653 139783521261376 cli.py:388] Transformation is completed
-```
-As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly.
+
+As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to its runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly.
diff --git a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
index 5de3d3c22..b208f6473 100644
--- a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
+++ b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md
@@ -5,11 +5,11 @@ This notebook is designed to show the features of the TensorRT passes integrated
## Section 1. INT8 Quantization
Firstly, we will show you how to do a int8 quantization of a simple model, `jsc-toy`, and compare the quantized model to the original model using the `Machop API`. The quantization process is split into the following stages, each using their own individual pass, and are explained in depth at each subsection:
-1. Fake quantization: `tensorrt_fake_quantize_transform_pass`
-2. Calibration: `tensorrt_calibrate_transform_pass`
-3. Quantized Aware Training: `tensorrt_fine_tune_transform_pass`
-4. Quantization: `tensorrt_engine_interface_pass`
-5. Analysis: `tensorrt_analysis_pass`
+1. [Fake quantization](#section-11-fake-quantization): `tensorrt_fake_quantize_transform_pass`
+2. [Calibration](#section-12-calibration): `tensorrt_calibrate_transform_pass`
+3. [Quantized Aware Training](#section-13-quantized-aware-training-qat): `tensorrt_fine_tune_transform_pass`
+4. [Quantization](#section-14-tensorrt-quantization): `tensorrt_engine_interface_pass`
+5. [Analysis](#section-15-performance-analysis): `tensorrt_analysis_pass`
We start by loading in the required libraries and passes required for the notebook as well as ensuring the correct path is set for machop to be used.
@@ -63,7 +63,7 @@ set_logging_verbosity("info")
I0329 12:52:20.742465 139924298352448 logger.py:44] Set logging level to info
-Next, we load in the toml file used for quantization.
+Next, we load in the toml file used for quantization. To view the configuration, click [here](../../../machop/configs/tensorrt/jsc_toy_INT8_quantization_by_type.toml).
```python
@@ -136,8 +136,8 @@ mg = MaseGraph(model=model)
Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file.
-```bash
-ch train --config {JSC_TOML_PATH}
+```python
+!ch train --config {JSC_TOML_PATH}
```
Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory.
@@ -169,12 +169,12 @@ mg_original = deepcopy_mase_graph(mg)
Firstly, we fake quantize the module in order to perform calibration and fine tuning before actually quantizing - this is only used if we have int8 calibration as other precisions are not currently supported within [pytorch-quantization](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#) library.
-This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis.
+This is achieved through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis.
Currently the quantizable layers are:
- Linear
-- Conv1d, Conv2d, Conv3d
-- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+- Conv1d, Conv2d, Conv3d
+- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
- MaxPool1d, MaxPool2d, MaxPool3d
- AvgPool1d, AvgPool2d, AvgPool3d
- LSTM, LSTMCell
@@ -202,7 +202,7 @@ summarize_quantization_analysis_pass(mg_original, mg)
| ReLU | relu | 4 | 0 | 4 |
| output | output | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0329 12:52:41.851252 139924298352448 summary.py:85]
+ I0329 12:52:41.851252 139924298352448 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d | batch_norm1d | 4 | 0 | 4 |
@@ -212,16 +212,16 @@ summarize_quantization_analysis_pass(mg_original, mg)
| x | placeholder | 1 | 0 | 1 |
-As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See Section 4 for how to apply quantization layerwise - i.e. only first and second layers for example.
+As you can see we have successfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See [Section 4](#section-4-layer-wise-mixed-precision) for how to apply quantization layerwise - i.e. only first and second layers for example.
### Section 1.2 Calibration
-Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations.
+Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations.
Calibrators can be added as a search space parameter to examine the best performing calibrator. The calibrators have been included in the toml as follows.
For example: `calibrators = ["percentile", "mse", "entropy"]`
-Note:
+Note:
- To use `percentile` calibration, a list of percentiles must be given
- To use `max` calibration, the `histogram` weight and input calibrators must be removed and replaced with `max`. This will use global maximum absolute value to calibrate the model.
- If `post_calibration_analysis` is set true the `tensorrt_analysis_pass` will be run for each calibrator tested to evaluate the most suitable calibrator for the model.
@@ -300,7 +300,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.239 W |
| Inference Energy Consumption | 0.018661 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -354,7 +354,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.21 W |
| Inference Energy Consumption | 0.018664 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -408,7 +408,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.252 W |
| Inference Energy Consumption | 0.017687 mWh |
+------------------------------+--------------+[0m
- I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521]
+ I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -462,7 +462,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.426 W |
| Inference Energy Consumption | 0.018681 mWh |
+------------------------------+--------------+[0m
- I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521]
+ I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -516,7 +516,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
| Average GPU Power Usage | 22.525 W |
| Inference Energy Consumption | 0.018149 mWh |
+------------------------------+--------------+[0m
- I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521]
+ I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+--------------+
| Metric (Per Batch) | Value |
@@ -538,7 +538,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config)
I0329 12:53:22.701544 139924298352448 calibrate.py:213] Succeeded in calibrating the model in PyTorch!
-From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well.
+From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well.
### Section 1.3 Quantized Aware Training (QAT)
@@ -576,14 +576,14 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config)
I0329 12:53:59.800536 139924298352448 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
- I0329 12:53:59.814722 139924298352448 model_summary.py:94]
+ I0329 12:53:59.814722 139924298352448 model_summary.py:94]
| Name | Type | Params
-------------------------------------------------
- 0 | model | GraphModule | 327
- 1 | loss_fn | CrossEntropyLoss | 0
- 2 | acc_train | MulticlassAccuracy | 0
- 3 | loss_val | MeanMetric | 0
- 4 | loss_test | MeanMetric | 0
+ 0 | model | GraphModule | 327
+ 1 | loss_fn | CrossEntropyLoss | 0
+ 2 | acc_train | MulticlassAccuracy | 0
+ 3 | loss_val | MeanMetric | 0
+ 4 | loss_test | MeanMetric | 0
-------------------------------------------------
327 Trainable params
0 Non-trainable params
@@ -620,7 +620,7 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config)
After QAT, we are now ready to convert the model to a tensorRT engine so that it can be run with the superior inference speeds. To do so, we use the `tensorrt_engine_interface_pass` which converts the `MaseGraph`'s model from a Pytorch one to an ONNX format as an intermediate stage of the conversion.
-During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in Section 1.3.
+During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in [Section 1.3](#section-13-quantized-aware-training-qat).
This interface pass returns a dictionary containing the `onnx_path` and `trt_engine_path`.
@@ -677,7 +677,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
| Average GPU Power Usage | 23.792 W |
| Inference Energy Consumption | 0.0057535 mWh |
+------------------------------+---------------+[0m
- I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521]
+ I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -709,7 +709,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
| Average GPU Power Usage | 23.043 W |
| Inference Energy Consumption | 0.00085532 mWh |
+------------------------------+----------------+[0m
- I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521]
+ I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521]
Results jsc-toy-trt_quantized:
+------------------------------+----------------+
| Metric (Per Batch) | Value |
@@ -727,18 +727,19 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis
I0329 13:03:34.506492 139924298352448 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-29/tensorrt/version_0/model.json
-As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage.
+As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage.
## Section 2. FP16 Quantization
-We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in Section 1.
+We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in [Section 1](#section-1-int8-quantization).
-Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ).
+Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ).
-```text
+```python
JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantization_by_type.toml"
!ch transform --config {JSC_FP16_BY_TYPE_TOML} --load {JSC_CHECKPOINT_PATH} --load-type pl
+```
8808.24s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
[2024-03-28 09:37:03,989] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
@@ -810,7 +811,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| ReLU | relu | 4 | 0 | 4 |
| output | output | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 09:37:11.551648 140201001654080 summary.py:85]
+ I0328 09:37:11.551648 140201001654080 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d | batch_norm1d | 4 | 0 | 4 |
@@ -848,7 +849,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| Average GPU Power Usage | 22.024 W |
| Inference Energy Consumption | 0.0049148 mWh |
+------------------------------+---------------+[0m
- I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437]
+ I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437]
Results jsc-toy:
+------------------------------+---------------+
| Metric (Per Batch) | Value |
@@ -870,7 +871,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
------|---------|----------|----------------------|----------------------|-----------------------
0 | Input | FLOAT | (64, 16) | (64, 16) | input
1 | Output | FLOAT | (64, 5) | (64, 5) | 37[0m
- I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167]
+ I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167]
TensorRT Engine Input/Output Information:
Index | Type | DataType | Static Shape | Dynamic Shape | Name
------|---------|----------|----------------------|----------------------|-----------------------
@@ -892,7 +893,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
| Average GPU Power Usage | 21.706 W |
| Inference Energy Consumption | 0.00055067 mWh |
+------------------------------+----------------+[0m
- I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437]
+ I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437]
Results jsc-toy-trt_quantized:
+------------------------------+----------------+
| Metric (Per Batch) | Value |
@@ -912,7 +913,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat
I0328 09:37:43.132117 140201001654080 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/jsc-toy_cls_jsc_2024-03-28/software/transform/transformed_ckpt
[32mINFO [0m [34mTransformation is completed[0m
I0328 09:37:43.132461 140201001654080 cli.py:383] Transformation is completed
-```
+
As you can see, `fp16` acheives a slighty higher test accuracy but a slightly lower latency (~30%) from that of int8 quantization; it is still ~2.5x faster than the unquantized model. Now lets apply quantization to a more complicated model.
@@ -924,30 +925,31 @@ In this case, we set:
- The `by` parameter to `type`
- The `quantize` parameter to true for `passes.tensorrt.conv2d.config` and `precision` parameter to 'int8'.
- The `input` and `weight` quantize axis for the conv2d layers.
-- The default `passes.tensorrt.default.config` precision to true.
+- The default `passes.tensorrt.default.config` precision to true.
During the TensorRT quantization, the model's conv2d layers will be converted to an int8 fake quantized form, whilst the linear layers are kept to their default 'fp16'. Calibration of the conv2d layers and then fine tuning will be undergone before quantization and inference.
-You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below.
+You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below.
-```bash
+```python
VGG_TYPEWISE_TOML = "../../../machop/configs/tensorrt/vgg7_typewise_mixed_precision.toml"
-ch train --config {VGG_TYPEWISE_TOML}
+!ch train --config {VGG_TYPEWISE_TOML}
```
-We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in Section 1.5
+We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in [Section 1.5](#section-15-performance-analysis)
-```bash
+```python
# Change this checkpoint path accordingly
VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt"
```
-```text
-ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
+```python
+!ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
+```
[2024-03-28 23:00:09,016] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
@@ -1031,7 +1033,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
| output | output | 1 | 0 | 1 |
| view | view | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 23:00:37.270473 139939454809920 summary.py:85]
+ I0328 23:00:37.270473 139939454809920 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm2d | batch_norm2d | 6 | 0 | 6 |
@@ -1159,7 +1161,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
| Average GPU Power Usage | 59.019 W |
| Inference Energy Consumption | 0.24777 mWh |
+------------------------------+-------------+[0m
- I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521]
+ I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1231,7 +1233,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
| Average GPU Power Usage | 59.653 W |
| Inference Energy Consumption | 0.25317 mWh |
+------------------------------+-------------+[0m
- I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521]
+ I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1268,14 +1270,14 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
Files already downloaded and verified
Files already downloaded and verified
I0328 23:01:12.623627 139939454809920 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
- I0328 23:01:12.632704 139939454809920 model_summary.py:94]
+ I0328 23:01:12.632704 139939454809920 model_summary.py:94]
| Name | Type | Params
-------------------------------------------------
0 | model | GraphModule | 14.0 M
- 1 | loss_fn | CrossEntropyLoss | 0
- 2 | acc_train | MulticlassAccuracy | 0
- 3 | loss_val | MeanMetric | 0
- 4 | loss_test | MeanMetric | 0
+ 1 | loss_fn | CrossEntropyLoss | 0
+ 2 | acc_train | MulticlassAccuracy | 0
+ 3 | loss_val | MeanMetric | 0
+ 4 | loss_test | MeanMetric | 0
-------------------------------------------------
14.0 M Trainable params
0 Non-trainable params
@@ -1643,7 +1645,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
| Average GPU Power Usage | 58.043 W |
| Inference Energy Consumption | 0.1441 mWh |
+------------------------------+------------+[0m
- I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521]
+ I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521]
Results vgg7:
+------------------------------+------------+
| Metric (Per Batch) | Value |
@@ -1675,7 +1677,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
| Average GPU Power Usage | 52.687 W |
| Inference Energy Consumption | 0.12441 mWh |
+------------------------------+-------------+[0m
- I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
+ I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
Results vgg7-trt_quantized:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -1691,13 +1693,13 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty
+------------------------------+-------------+
[32mINFO [0m [34mRuntime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json[0m
I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json
-```
-By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model.
+
+By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model.
## Section 4. Layer-wise Mixed Precision
-So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized.
+So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized.
For this, we set:
- The `by` parameter to `name`
@@ -1706,10 +1708,11 @@ For this, we set:
- The `precision` to 'int8' for `passes.tensorrt.feature_layers_2.config and passes.tensorrt.feature_layers_3.config` (although this is not necessary since the default is already set to 'int8')
-```text
+```python
VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_precision.toml"
-ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
+!ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl
+```
[2024-03-28 23:25:51,157] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO: Seed set to 0
@@ -1793,7 +1796,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t
| output | output | 1 | 0 | 1 |
| view | view | 1 | 0 | 1 |
| x | placeholder | 1 | 0 | 1 |[0m
- I0328 23:26:12.941653 140449214740288 summary.py:85]
+ I0328 23:26:12.941653 140449214740288 summary.py:85]
| Original type | OP | Total | Changed | Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm2d | batch_norm2d | 6 | 0 | 6 |
@@ -1953,7 +1956,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t
| Average GPU Power Usage | 57.532 W |
| Inference Energy Consumption | 0.28607 mWh |
+------------------------------+-------------+[0m
- I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521]
+ I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2037,7 +2040,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t
| Average GPU Power Usage | 57.867 W |
| Inference Energy Consumption | 0.29145 mWh |
+------------------------------+-------------+[0m
- I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521]
+ I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521]
Results vgg7:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2096,7 +2099,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t
| Average GPU Power Usage | 55.687 W |
| Inference Energy Consumption | 0.12102 mWh |
+------------------------------+-------------+[0m
- I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
+ I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521]
Results vgg7-trt_quantized:
+------------------------------+-------------+
| Metric (Per Batch) | Value |
@@ -2112,6 +2115,6 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t
+------------------------------+-------------+
[32mINFO [0m [34mRuntime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json[0m
I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json
-```
+
In this case, we can see through the quantized summary that one convolutional layer (feature_layers_1) has not been quantized as its precision will be configured to 'fp16' in the tensorrt engine conversion stage whilst the remaining convolutional and linear layers have been quantized.
diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
index 9fbb14c38..8f4d69cf1 100644
--- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
+++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md
@@ -82,7 +82,7 @@ cd machop
```
When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna.
-```text
+```txt
Best trial(s):
| | number | software_metrics | hardware_metrics | scaled_metrics |
|----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------|
@@ -100,7 +100,7 @@ Here is part of the `log.json` recording all search details.
For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands.
-```text
+```json
{
"0":{
"number":0,
@@ -120,8 +120,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"bias_width":2,
"bias_frac_width":8
}
- }
- // ...
+ },
+ ...
},
"user_attrs_scaled_metrics":{
"accuracy":0.5,
@@ -154,8 +154,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"bias_width":4,
"bias_frac_width":3
}
- }
- // ...
+ },
+ ...
},
"user_attrs_scaled_metrics":{
"accuracy":0.5747232437,
@@ -169,7 +169,7 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization
"datetime_start":1694095031290,
"datetime_complete":1694095032462,
"duration":1172
- }
- // ...
+ },
+ ...
}
```
\ No newline at end of file
diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
index 10b309d90..e82d69d2d 100644
--- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
+++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md
@@ -4,13 +4,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for
## Commands
-First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to Run the train action with the CLI for more detailed explanation.
+First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation.
The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search.
```bash
-cd src
+cd src
./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0
```
@@ -151,7 +151,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck
Here is part of the `log.json`
-```text
+```json
{
"0":{
"number":0,
diff --git a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
index 52c21ca40..8f9eb8cb1 100644
--- a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
+++ b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md
@@ -12,7 +12,7 @@ In this case, we can try a toymodel, the command looks like the following
```bash
# assuming you you are at the our-stuff/mase directory
-cd src
+cd src
./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3
```
@@ -20,7 +20,7 @@ cd src
You can fetch all command-line arguments:
-```text
+```bash
[nix-shell:~/Projects/mase/src]$ ./ch -help
INFO Set logging level to debug
WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda.
@@ -176,7 +176,7 @@ This directory includes
## Training Logs
-MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
+MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`.
Run Tensorboard to visualise the logs using:
@@ -209,5 +209,5 @@ To test the model trained above you can use:
```bash
# After training, you will have your checkpoint under mase-tools/mase_output
-# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
+# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt
./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt```
diff --git a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
index c7750e262..a30a172fb 100644
--- a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
+++ b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md
@@ -3,10 +3,10 @@
This document includes steps to add a new model into Machop
## Overall Structure
### Model
-All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in `__init__` file.
+All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in [\_\_init\_\_](%2E%2E%5Cmachop%5Cchop%5Cmodels%5C%5F%5Finit%5F%5F.py) file.
### Command Line Interface
-Command Line Interface (cli) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then.
+[Command Line Interface (cli)](..\machop\chop\cli.py) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then.
## What To Do
1. Find the GitHub repositories of the original paper, find the code that defines the models, and copy it into the right folder under **mase-tools\machop\chop\models**
@@ -20,7 +20,7 @@ Command Line Interface (cli) will take the input config, and perform the task de
## Get model function
- **Info** should be used as one of the input variables. It is a dictionary that contains information about the dataset, e.g., number of classes; input image size.
-- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in cli.py for more detail.
+- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in [cli.py](..\machop\chop\cli.py) for more detail.
- function name of get-function should be in smaller case
- keys of the dictionary should also be in smaller case
diff --git a/docs/source/modules/hardware/activations/gelu.md b/docs/source/modules/hardware/activations/gelu.md
index 25b2560b4..cb2d094e8 100644
--- a/docs/source/modules/hardware/activations/gelu.md
+++ b/docs/source/modules/hardware/activations/gelu.md
@@ -11,7 +11,7 @@ When the approximate argument is set to 'tanh', GELU is estimated with:
`GELU(x) = 0.5 * x * (1 + Tanh(2/π * (x + 0.044715 * x^3)))`
-## Parameters:
+### Parameters:
- `approximate` (str, optional): The GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'.
diff --git a/docs/source/modules/hardware/activations/selu.md b/docs/source/modules/hardware/activations/selu.md
index 882e5c8d5..8413420a1 100644
--- a/docs/source/modules/hardware/activations/selu.md
+++ b/docs/source/modules/hardware/activations/selu.md
@@ -8,7 +8,7 @@ where:
- α = 1.6732632423543772848170429916717
- scale = 1.0507009873554804934193349852946
-## Parameters
+### Parameters
- `inplace` (bool, optional): Can optionally do the operation in-place. Default: False.
@@ -30,7 +30,7 @@ A hybrid approach is used for implementing exponential function $e^{-|x|}$ for a
2. **Representation of Binary Number**:
The N-bit binary number $a = b_{N-1}b_{N-2}...b_1b_0$ is represented, where $b_0$ is the least significant bit, and each bit $b_i$ has a place value $p_i$ given by $p_i = 2^{-P} \times 2^i$.
-
+
3. **Exponential Computation**:
$e^{-a} = \prod e^{-p_i \times b_i}$
diff --git a/docs/source/modules/hardware/activations/softplus.md b/docs/source/modules/hardware/activations/softplus.md
index 0d7f3e77a..ed0066cfb 100644
--- a/docs/source/modules/hardware/activations/softplus.md
+++ b/docs/source/modules/hardware/activations/softplus.md
@@ -9,7 +9,7 @@ where:
- Softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
- For numerical stability, the implementation reverts to the linear function when `input * β > threshold`.
-## Parameters:
+### Parameters:
- `beta` (int): The β value for the Softplus formulation. Default: 1.
- `threshold` (int): Values above this revert to a linear function. Default: 20.
diff --git a/docs/source/modules/hardware/linear/fixed_linear.md b/docs/source/modules/hardware/linear/fixed_linear.md
index 33ddf350e..672ea765a 100644
--- a/docs/source/modules/hardware/linear/fixed_linear.md
+++ b/docs/source/modules/hardware/linear/fixed_linear.md
@@ -32,7 +32,7 @@ The module has the following parameters, following the hardware metadata standar
| Parameter | Default Value | Definition |
|------------------------------ |-------------------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| DATA_IN_0_PARALLELISM_DIM_0 | 4 | Number of elements per transaction at the input interface. Dictates the number of transactions to compute the full layer. |
-| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see Latency Analysis below) |
+| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see [Latency Analysis](#latency-analysis) below) |
| DATA_OUT_0_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the output interface. |
| BIAS_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the bias interface. Dictates the number of fixed-point adders. |
@@ -56,7 +56,7 @@ The same process is repeated with the second input sub-vector $X_2$ and weight s
-## Latency Analysis
+## Latency Analysis
The time taken to compute a linear layer using the `fixed_linear` module, $L_{FL}$ can be broken down into 2 phases, the input driving phase $L_L$, and the pipeline unloading phase $L_U$ that begins after the last input beat is transferred.
diff --git a/docs/source/modules/hardware/systolic_modules/output_stationary.md b/docs/source/modules/hardware/systolic_modules/output_stationary.md
index d0e5370b2..d5c605a41 100644
--- a/docs/source/modules/hardware/systolic_modules/output_stationary.md
+++ b/docs/source/modules/hardware/systolic_modules/output_stationary.md
@@ -8,6 +8,6 @@ The MAC units in each PE perform the multiply-accumulate operation over 2 cycles

-## Systolic Module Driver
+### Systolic Module Driver
The Systolic Module Driver generates pulse signals in the format required to drive the read interface of an on-chip buffer such that data signals are made available with the required timing for the processing elements of a systolic module. This is achieved through a shift register of size BUFFER_SLOT_COUNT. After receiving a starting pulse, the least significant bit is set to 1. Subsequently, the register shifts after every shift pulse, up to a runtime-parametrizable pulse limit count parameter (this is set to the number of output features for the layer being executed). The driver should then pulse a subsequent BUFFER_SLOT_COUNT times until the register is flushed.
\ No newline at end of file
diff --git a/docs/source/modules/labs_2024/lab_0_introduction.rst b/docs/source/modules/labs_2024/lab_0_introduction.rst
index 896c0c769..0bde49047 100644
--- a/docs/source/modules/labs_2024/lab_0_introduction.rst
+++ b/docs/source/modules/labs_2024/lab_0_introduction.rst
@@ -46,7 +46,7 @@ TroubleShooting
You may find that you have to use `Python3.11` but Google Colab only provides `Python3.10`. In this case, you can use the following command to force the kernel ot use `Python3.11`:
-.. code-block:: text
+.. code-block:: python
#The code below installs 3.11 (assuming you now have 3.10 in colab) and restarts environment, so you can run your cells.
import sys #for version checker
From 9dba4fde4bb557c582f3b65b6d13892817f9c7a5 Mon Sep 17 00:00:00 2001
From: Cheng Zhang
Date: Tue, 11 Feb 2025 20:19:06 +0000
Subject: [PATCH 10/38] rename files
---
src/chop/nn/__init__.py | 1 -
src/chop/nn/optical/__init__.py | 4 +-
src/chop/nn/optical/functional/__init__.py | 54 ---
src/chop/nn/optical/functional/general.py | 433 ------------------
src/chop/nn/optical/modules/morr_conv2d.py | 43 +-
src/chop/nn/optical/modules/morr_linear.py | 23 +-
src/chop/nn/optical/utils/__init__.py | 26 ++
.../optical/{functional => utils}/compute.py | 0
.../{functional => utils}/initializer.py | 0
.../nn/optical/{functional => utils}/mrr.py | 0
.../optical/{functional => utils}/mrr_op.py | 0
.../optical/{functional => utils}/quantize.py | 0
.../{functional => utils}/torch_train.py | 0
.../optical/module_transform_helper.py | 2 -
.../module/transforms/optical/optical.py | 44 +-
.../transforms/optical/test_optical_module.py | 144 +-----
16 files changed, 70 insertions(+), 704 deletions(-)
delete mode 100644 src/chop/nn/optical/functional/__init__.py
delete mode 100644 src/chop/nn/optical/functional/general.py
create mode 100644 src/chop/nn/optical/utils/__init__.py
rename src/chop/nn/optical/{functional => utils}/compute.py (100%)
rename src/chop/nn/optical/{functional => utils}/initializer.py (100%)
rename src/chop/nn/optical/{functional => utils}/mrr.py (100%)
rename src/chop/nn/optical/{functional => utils}/mrr_op.py (100%)
rename src/chop/nn/optical/{functional => utils}/quantize.py (100%)
rename src/chop/nn/optical/{functional => utils}/torch_train.py (100%)
diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py
index 5533a7608..7ab651324 100644
--- a/src/chop/nn/__init__.py
+++ b/src/chop/nn/__init__.py
@@ -1,4 +1,3 @@
from .quantized import quantized_module_map
-from .optical import optical_module_map
MASE_LEAF_LAYERS = tuple(quantized_module_map.values())
diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py
index b74e7df6b..0310afb71 100644
--- a/src/chop/nn/optical/__init__.py
+++ b/src/chop/nn/optical/__init__.py
@@ -1,3 +1 @@
-from .modules import (
- optical_module_map,
-)
+from .modules import optical_module_map
diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py
deleted file mode 100644
index 84f94e7b6..000000000
--- a/src/chop/nn/optical/functional/__init__.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from .mrr import (
- MORRConfig_20um_MQ,
- MRRConfig_5um_HQ,
- MRRConfig_5um_MQ,
- MRRConfig_5um_LQ,
- MORRConfig_10um_MQ,
-)
-
-from .compute import (
- im2col_2d,
- toeplitz,
-)
-
-from .general import (
- logger,
-)
-
-from .initializer import (
- morr_uniform_,
-)
-
-from .quantize import (
- input_quantize_fn,
- weight_quantize_fn,
-)
-
-from .mrr_op import (
- mrr_roundtrip_phase_to_tr_func,
- mrr_roundtrip_phase_to_tr_fused,
-)
-
-
-# """
-# Description:
-# Author: Jiaqi Gu (jqgu@utexas.edu)
-# Date: 2021-06-09 01:40:22
-# LastEditors: Jiaqi Gu (jqgu@utexas.edu)
-# LastEditTime: 2021-06-09 01:40:22
-# """
-
-# import importlib
-# import os
-
-# # automatically import any Python files in this directory
-# for file in sorted(os.listdir(os.path.dirname(__file__))):
-# if file.endswith(".py") and not file.startswith("_"):
-# source = file[: file.find(".py")]
-# module = importlib.import_module("torchonn.layers." + source)
-# if "__all__" in module.__dict__:
-# names = module.__dict__["__all__"]
-# else:
-# # import all names that do not begin with _
-# names = [x for x in module.__dict__ if not x.startswith("_")]
-# globals().update({k: getattr(module, k) for k in names})
diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py
deleted file mode 100644
index f5aa87e14..000000000
--- a/src/chop/nn/optical/functional/general.py
+++ /dev/null
@@ -1,433 +0,0 @@
-"""
-Description:
-Author: Jiaqi Gu (jqgu@utexas.edu)
-Date: 2021-06-06 01:55:29
-LastEditors: Jiaqi Gu (jqgu@utexas.edu)
-LastEditTime: 2021-06-06 01:55:30
-"""
-
-import os
-import argparse
-import json
-import logging
-import logging.handlers
-import time
-from collections import OrderedDict
-from datetime import datetime
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import torch
-
-
-__all__ = [
- "ensure_dir",
- "read_json",
- "write_json",
- "profile",
- "print_stat",
- "Timer",
- "TimerCtx",
- "TorchTracemalloc",
- "fullprint",
- "setup_default_logging",
- "Logger",
- "logger",
- "get_logger",
- "ArgParser",
- "disable_tf_warning",
- "AverageMeter",
-]
-
-
-def ensure_dir(dirname, exist_ok: bool = True):
- dirname = Path(dirname)
- if not dirname.is_dir():
- dirname.mkdir(parents=True, exist_ok=exist_ok)
-
-
-def read_json(fname):
- with open(fname, "rt") as handle:
- return json.load(handle, object_hook=OrderedDict)
-
-
-def write_json(content, fname):
- with open(fname, "wt") as handle:
- json.dump(content, handle, indent=4, sort_keys=False)
-
-
-def profile(func=None, timer=True):
- from functools import wraps, partial
- import time
-
- if func == None:
- return partial(profile, timer=timer)
-
- @wraps(func)
- def wrapper(*args, **kw):
- if timer:
- local_time = time.time()
- res = func(*args, **kw)
- end_time = time.time()
- print(
- "[I] <%s> runtime: %.3f ms"
- % (func.__name__, (end_time - local_time) * 1000)
- )
- else:
- res = func(*args, **kw)
- return res
-
- return wrapper
-
-
-def print_stat(x, message="", verbose=True):
- if verbose:
- if isinstance(x, torch.Tensor):
- if torch.is_complex(x):
- x = torch.view_as_real(x)
- print(
- message
- + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}"
- )
- elif isinstance(x, np.ndarray):
- print(
- message
- + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}"
- )
-
-
-class Timer(object):
- def __init__(self):
- self.cache = datetime.now()
-
- def check(self):
- now = datetime.now()
- duration = now - self.cache
- self.cache = now
- return duration.total_seconds()
-
- def reset(self):
- self.cache = datetime.now()
-
-
-class TimerCtx:
- def __enter__(self):
- self.start = time.time()
- return self
-
- def __exit__(self, *args):
- self.end = time.time()
- self.interval = self.end - self.start
-
-
-class TorchTracemalloc(object):
- def __init__(self, verbose: bool = False) -> None:
- super().__init__()
- self.verbose = verbose
-
- def __enter__(self):
- self.begin = self._b2mb(torch.cuda.memory_allocated())
- torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
- return self
-
- def _b2mb(self, x):
- return x / 2**20
-
- def __exit__(self, *exc):
- self.end = self._b2mb(torch.cuda.memory_allocated())
- self.peak = self._b2mb(torch.cuda.max_memory_allocated())
- self.used = self.end - self.begin
- self.peaked = self.peak - self.begin
- if self.verbose:
- print(f"Delta used/peaked {self.used:.2f} MB / {self.peaked:.2f} MB")
- print(f"Current used/peaked {self.end:.2f} MB / {self.peak:.2f} MB")
-
-
-class fullprint:
- "context manager for printing full numpy arrays"
-
- def __init__(self, **kwargs):
- """linewidth=75; precision=8"""
- kwargs.setdefault("threshold", np.inf)
- self.opt = kwargs
-
- def __enter__(self):
- self._opt = np.get_printoptions()
- np.set_printoptions(**self.opt)
-
- def __exit__(self, type, value, traceback):
- np.set_printoptions(**self._opt)
-
-
-class CustomFormatter(logging.Formatter):
- """Logging Formatter to add colors and count warning / errors"""
-
- grey = "\x1b[38;21m"
- yellow = "\x1b[33;21m"
- red = "\x1b[31;21m"
- bold_red = "\x1b[31;1m"
- green = "\x1b[32;21m"
- reset = "\x1b[0m"
- # format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
- format = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-
- FORMATS = {
- logging.DEBUG: grey + format + reset,
- logging.INFO: grey + format + reset,
- logging.WARNING: yellow + format + reset,
- logging.ERROR: red + format + reset,
- logging.CRITICAL: bold_red + format + reset,
- }
-
- def format(self, record):
- log_fmt = self.FORMATS.get(record.levelno)
- formatter = logging.Formatter(log_fmt)
- return formatter.format(record)
-
-
-def setup_default_logging(
- default_level=logging.INFO, default_file_level=logging.INFO, log_path=""
-):
- console_handler = logging.StreamHandler()
- console_handler.setFormatter(CustomFormatter())
- logging.root.addHandler(console_handler)
- logging.root.setLevel(default_level)
- if log_path:
- file_handler = logging.handlers.RotatingFileHandler(
- log_path, maxBytes=(1024**2 * 2), backupCount=3
- )
- file_formatter = logging.Formatter(
- "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
- )
- file_handler.setFormatter(file_formatter)
- file_handler.setLevel(default_file_level)
- logging.root.addHandler(file_handler)
-
-
-class Logger(object):
- def __init__(
- self,
- console=True,
- logfile=None,
- console_level=logging.INFO,
- logfile_level=logging.INFO,
- ):
- super().__init__()
- self.logfile = logfile
- self.console_level = console_level
- self.logifle_level = logfile_level
- assert (
- console == True or logfile is not None
- ), "At least enable one from console or logfile for Logger"
- # 第一步,创建一个logger
- self.logger = logging.getLogger("my_logger")
- self.logger.setLevel(logging.INFO) # Log等级总开关
- self.logger.propagate = False
-
- # formatter = logging.Formatter(
- # "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
- formatter = CustomFormatter()
-
- # 第三步,再创建一个handler,用于输出到控制台
- if console:
- ch = logging.StreamHandler()
- ch.setLevel(self.console_level) # 输出到console的log等级的开关
- ch.setFormatter(formatter)
- self.logger.addHandler(ch)
- if self.logfile is not None:
- fh = logging.FileHandler(self.logfile, mode="w")
- fh.setLevel(self.logifle_level) # 输出到file的log等级的开关
- fh.setFormatter(formatter)
- self.logger.addHandler(fh)
-
- def debug(self, message):
- self.logger.debug(message)
-
- def info(self, message):
- self.logger.info(message)
-
- def warning(self, message):
- self.logger.warning(message)
-
- def error(self, message):
- self.logger.error(message)
-
- def critical(self, message):
- self.logger.critical(message)
-
-
-def get_logger(
- name="default",
- default_level=logging.INFO,
- default_file_level=logging.INFO,
- log_path="",
-):
- setup_default_logging(
- default_level=default_level,
- default_file_level=default_file_level,
- log_path=log_path,
- )
- return logging.getLogger(name)
-
-
-logger = get_logger()
-
-
-class ArgParser(object):
- def __init__(self, load_json=None, save_json=None):
- super().__init__()
- self.load_json = load_json
- self.save_json = save_json
- self.args = None
- self.parser = argparse.ArgumentParser("Argument Parser")
-
- def add_arg(self, *args, **keywords):
- self.parser.add_argument(*args, **keywords)
-
- def parse_args(self):
- if self.load_json is not None:
- assert os.path.exists(self.load_json), logging.error(
- f"Configuration JSON {self.load_json} not found"
- )
- json = read_json(self.load_json)
- t_args = argparse.Namespace()
- t_args.__dict__.update(json)
- self.args = self.parser.parse_args(args=[], namespace=t_args)
- else:
- self.args = self.parser.parse_args()
- return self.args
-
- def print_args(self):
- # Print arguments to std out
- # and save argument values to yaml file
- print("Arguments:")
- for p in vars(self.args).items():
- print(f"\t{p[0]:30}{str(p[1]):20}")
- print("\n")
-
- def dump_args(self, json_file=None):
- if json_file is None:
- if self.save_json is None:
- logging.error("Skip dump configuration JSON. Please specify json_file")
- return False
- else:
- ensure_dir(os.path.dirname(self.save_json))
- logging.warning(f"Dump to the initialized JSON file {self.save_json}")
- write_json(vars(self.args), self.save_json)
- else:
- ensure_dir(os.path.dirname(json_file))
- logging.info(f"Dump to JSON file {json_file}")
- write_json(vars(self.args), json_file)
- # with open(self.file, 'w') as f:
- # yaml.dump(vars(self.args), f, default_flow_style=False)
- # print(f"[I] Arguments dumped to {file}")
-
-
-def disable_tf_warning():
- import os
-
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- import warnings
-
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", category=DeprecationWarning)
-
- import tensorflow as tf
-
- if hasattr(tf, "contrib") and type(tf.contrib) != type(tf):
- tf.contrib._warning = None
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
- # tf.logging.set_verbosity(tf.logging.ERROR)
-
- import logging
-
- logging.getLogger("tensorflow").setLevel(logging.ERROR)
-
-
-class Meter(object):
- """Base class for Meters."""
-
- def __init__(self):
- pass
-
- def state_dict(self):
- return {}
-
- def load_state_dict(self, state_dict):
- pass
-
- def reset(self):
- raise NotImplementedError
-
- @property
- def smoothed_value(self) -> float:
- """Smoothed value used for logging."""
- raise NotImplementedError
-
-
-def safe_round(number, ndigits):
- if hasattr(number, "__round__"):
- return round(number, ndigits)
- elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
- return safe_round(number.item(), ndigits)
- elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
- return safe_round(number.item(), ndigits)
- else:
- return number
-
-
-def type_as(a, b):
- if torch.is_tensor(a) and torch.is_tensor(b):
- return a.to(b)
- else:
- return a
-
-
-class AverageMeter(Meter):
- """Computes and stores the average and current value"""
-
- def __init__(self, name: str, fmt: str = ":f", round: Optional[int] = None) -> None:
- self.name = name
- self.fmt = fmt
- self.round = round
- self.reset()
-
- def reset(self):
- self.val = None # most recent update
- self.sum = 0 # sum from all updates
- self.count = 0 # total n from all updates
- self.avg = 0
-
- def update(self, val, n=1):
- if val is not None:
- self.val = val
- if n > 0:
- self.sum = type_as(self.sum, val) + (val * n)
- self.count = type_as(self.count, n) + n
- self.avg = self.sum / self.count if self.count > 0 else self.val
-
- def state_dict(self):
- return {
- "val": self.val,
- "sum": self.sum,
- "count": self.count,
- "round": self.round,
- }
-
- def load_state_dict(self, state_dict):
- self.val = state_dict["val"]
- self.sum = state_dict["sum"]
- self.count = state_dict["count"]
- self.round = state_dict.get("round", None)
-
- @property
- def smoothed_value(self) -> float:
- val = self.avg
- if self.round is not None and val is not None:
- val = safe_round(val, self.round)
- return val
-
- def __str__(self) -> str:
- fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
- return fmtstr.format(**self.__dict__)
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index 3cc56ad2f..f104e397f 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -6,24 +6,26 @@
"""
from typing import Optional, Tuple
+import logging
import numpy as np
import torch
import torch.fft
-from ..functional import im2col_2d, toeplitz
-from ..functional import logger
-from ..functional import morr_uniform_
-from ..functional import input_quantize_fn, weight_quantize_fn
from torch import Tensor, nn
from torch.nn import Parameter, init
from torch.nn.modules.utils import _pair
from torch.types import Device, _size
-from ..functional import MORRConfig_20um_MQ
-from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+from ..utils import MORRConfig_20um_MQ
+from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+from ..utils import im2col_2d, toeplitz
+from ..utils import morr_uniform_
+from ..utils import input_quantize_fn, weight_quantize_fn
from .base_layer import ONNBaseLayer
+logger = logging.getLogger(__name__)
+
__all__ = ["AllPassMORRCirculantConv2d"]
@@ -72,18 +74,13 @@ def __init__(
dilation: _size = 1,
groups: int = 1,
bias: bool = True,
- padding_mode=None,
- # miniblock: int = 4,
- # ### morr parameter
- # MORRConfig=MORRConfig_20um_MQ,
- # morr_init: bool = True, # whether to use initialization method customized for MORR
- # ### trainable MORR nonlinearity
- # trainable_morr_bias: bool = False,
- # trainable_morr_scale: bool = False,
+ padding_mode=None, # @johnny: unused argument
config=None,
- device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+ device: Device = torch.device("cpu"),
) -> None:
super(AllPassMORRCirculantConv2d, self).__init__()
+ assert config is not None
+
miniblock = config.get("miniblock", 4)
MORRConfig = config.get("MORRConfig", MORRConfig_20um_MQ)
morr_init = config.get("morr_init", True)
@@ -365,14 +362,16 @@ def morr_bias(self) -> Tensor:
)
def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor:
+ """Propagate through the analytically calculated transfer matrix of MORR.
+
+ :param weight: First column vectors in the block-circulant matrix.
+ :type weight: Tensor
+ :param x: Input tensor.
+ :type x: Tensor
+
+ :return: Output of MORR array.
+ :rtype: Tensor
"""
- @description: propagate through the analytically calculated transfer matrix of morr.
- @param weight {torch.Tensor} first column vectors in the block-circulant matrix
- @param x {torch.Tensor} input
- @return: y {torch.Tensor} output of MORR array
- """
- ### weights: [p, q, k]
- ### x: [ks*ks*inc, h_out*w_out*bs]
x = x.t() # [h_out*w_out*bs, ks*ks*inc]
x = x.view(x.size(0), self.grid_dim_x, self.miniblock) # [h_out*w_out*bs, q, k]
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
index af85357a2..f93281af5 100644
--- a/src/chop/nn/optical/modules/morr_linear.py
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -7,22 +7,24 @@
"""
from typing import Optional
+import logging
import numpy as np
import torch
import torch.fft
-from ..functional import toeplitz
-from ..functional import logger
-from ..functional import morr_uniform_
-from ..functional import input_quantize_fn, weight_quantize_fn
from torch import Tensor
from torch.nn import Parameter, init
from torch.types import Device
-from ..functional import MORRConfig_20um_MQ
-from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+from ..utils import MORRConfig_20um_MQ
+from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused
+from ..utils import toeplitz
+from ..utils import morr_uniform_
+from ..utils import input_quantize_fn, weight_quantize_fn
from .base_layer import ONNBaseLayer
+logger = logging.getLogger(__name__)
+
__all__ = ["AllPassMORRCirculantLinear"]
@@ -45,14 +47,7 @@ def __init__(
out_features: int,
bias: bool = False,
config=None,
- # miniblock: int = 4,
- # ### mrr parameter
- # MORRConfig=MORRConfig_20um_MQ,
- # morr_init: bool = True,
- # ### trainable MORR nonlinearity
- # trainable_morr_bias: bool = False,
- # trainable_morr_scale: bool = False,
- device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+ device: Device = torch.device("cpu"),
) -> None:
super(AllPassMORRCirculantLinear, self).__init__()
self.in_features = in_features
diff --git a/src/chop/nn/optical/utils/__init__.py b/src/chop/nn/optical/utils/__init__.py
new file mode 100644
index 000000000..88b71f264
--- /dev/null
+++ b/src/chop/nn/optical/utils/__init__.py
@@ -0,0 +1,26 @@
+from .mrr import (
+ MORRConfig_20um_MQ,
+ MRRConfig_5um_HQ,
+ MRRConfig_5um_MQ,
+ MRRConfig_5um_LQ,
+ MORRConfig_10um_MQ,
+)
+
+from .compute import (
+ im2col_2d,
+ toeplitz,
+)
+
+from .initializer import (
+ morr_uniform_,
+)
+
+from .quantize import (
+ input_quantize_fn,
+ weight_quantize_fn,
+)
+
+from .mrr_op import (
+ mrr_roundtrip_phase_to_tr_func,
+ mrr_roundtrip_phase_to_tr_fused,
+)
diff --git a/src/chop/nn/optical/functional/compute.py b/src/chop/nn/optical/utils/compute.py
similarity index 100%
rename from src/chop/nn/optical/functional/compute.py
rename to src/chop/nn/optical/utils/compute.py
diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/utils/initializer.py
similarity index 100%
rename from src/chop/nn/optical/functional/initializer.py
rename to src/chop/nn/optical/utils/initializer.py
diff --git a/src/chop/nn/optical/functional/mrr.py b/src/chop/nn/optical/utils/mrr.py
similarity index 100%
rename from src/chop/nn/optical/functional/mrr.py
rename to src/chop/nn/optical/utils/mrr.py
diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py
similarity index 100%
rename from src/chop/nn/optical/functional/mrr_op.py
rename to src/chop/nn/optical/utils/mrr_op.py
diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/utils/quantize.py
similarity index 100%
rename from src/chop/nn/optical/functional/quantize.py
rename to src/chop/nn/optical/utils/quantize.py
diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/utils/torch_train.py
similarity index 100%
rename from src/chop/nn/optical/functional/torch_train.py
rename to src/chop/nn/optical/utils/torch_train.py
diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py
index 1cef06928..3d5f9670b 100644
--- a/src/chop/passes/module/transforms/optical/module_transform_helper.py
+++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py
@@ -1,7 +1,5 @@
import torch
import torch.nn as nn
-import numpy as np
-from chop.nn.optical.modules import optical_module_map
from chop.passes.module.module_modify_helper import (
get_module_by_name,
set_module_by_name,
diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py
index b911c070d..550484880 100644
--- a/src/chop/passes/module/transforms/optical/optical.py
+++ b/src/chop/passes/module/transforms/optical/optical.py
@@ -1,7 +1,7 @@
import torch
from chop.nn.optical.modules import optical_module_map
-from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module
+from chop.passes.module.module_modify_helper import instantiate_module
from chop.passes.module.transforms.optical.module_transform_helper import (
replace_by_name_optical,
)
@@ -14,7 +14,7 @@ def get_config(config: dict, name: str):
return config["default"]["config"]
-def optical_by_type(network, pass_args):
+def optical_transform_by_type(network, pass_args):
for type_name, config in pass_args.items():
n_m = {}
for n, m in network.named_modules():
@@ -37,20 +37,20 @@ def optical_by_type(network, pass_args):
return network
-def optical_by_name(network, pass_args):
+def optical_transform_by_name(network, pass_args):
optical_names = pass_args.keys()
n_m = {}
for n, m in network.named_modules():
n_m[n] = m
for n, m in n_m.items():
if n in optical_names:
- quan_config = pass_args[n]
+ optical_config = pass_args[n]
- quan_config = quan_config["config"]
- postfix = quan_config.pop("name")
+ optical_config = optical_config["config"]
+ postfix = optical_config.pop("name")
new_m = instantiate_module(
- m, postfix, optical_module_map, {"config": quan_config}
+ m, postfix, optical_module_map, {"config": optical_config}
)
network = replace_by_name_optical(network, n, new_m)
return network
@@ -66,40 +66,16 @@ def optical_module_transform_pass(network, pass_args):
:param pass_args: Additional arguments for the transformation.
:type pass_args: dict, optional
- Examples pass_args:
-
- .. code-block:: python
-
- pass_args = {
- "by": "type", # quantize by type, name, or regex_name
- "default": {"config": {"name": None}}, # default config, this would be used for any node that does not have a specific config
- "linear": {
- "config": {
- "name": "integer", # quantization scheme name supported are ["integer", "fixed" (equivalent to integer), "lutnet" (dev mode), "logicnets" (dev mode), "binary", "binary_residual", "ternary", "minifloat_ieee", "minifloat_denorm", "log", "block_fp", "block_minifloat", "block_log"]
- # data
- "data_in_width": 8,
- "data_in_frac_width": 4,
- # weight
- "weight_width": 8,
- "weight_frac_width": 4,
- # bias
- "bias_width": 8,
- "bias_frac_width": 4,
- }
- },
- }
-
:return: The transformed torch.nn.Module.
:rtype: tuple
- :raises ValueError: If the quantize "by" argument is unsupported.
-
+ :raises ValueError: If the "by" argument is unsupported.
"""
by = pass_args.pop("by")
match by:
case "type":
- network = optical_by_type(network, pass_args)
+ network = optical_transform_by_type(network, pass_args)
case "name":
- network = optical_by_name(network, pass_args)
+ network = optical_transform_by_name(network, pass_args)
case _:
raise ValueError(f'Unsupported quantize "by": {by}')
return network, {}
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
index d05512cc1..e65546822 100644
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -1,19 +1,17 @@
#!/usr/bin/env python3
-# This example converts a simple MLP model to Verilog
-import logging
-import os
+# This example converts a simple MLP model to an ONN model
import sys
import torch
import torch.nn as nn
+import torch.nn.functional as F
from pathlib import Path
sys.path.append(Path(__file__).resolve().parents[5].as_posix())
-# from chop.passes.module.transforms import quantize_module_transform_pass
-from chop.passes.module.transforms import optical_module_transform_pass
+from chop.passes.module.transforms.optical import optical_module_transform_pass
class Net(nn.Module):
@@ -42,17 +40,7 @@ def forward(self, x):
return output
-def load_my_model(model_path, device="cpu"):
- # Load the model from the .pt file
- loaded_model = torch.load(model_path, map_location=device)
- # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout)
- loaded_model.eval()
- return loaded_model
-
-
def test_optical_module_transform_pass():
- # model_path = "mase_output/sample_mnist_cnn.pt"
- # mnist_cnn = load_my_model(model_path)
model = Net()
# Sanity check and report
pass_args = {
@@ -77,132 +65,6 @@ def test_optical_module_transform_pass():
},
}
optical_module_transform_pass(model, pass_args)
- # torch.save(onn_cnn, "mase_output/onn_cnn.pt")
test_optical_module_transform_pass()
-
-
-# if __name__ == '__main__':
-# finetune = False
-
-# if True:
-# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-# parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-# help='input batch size for training (default: 64)')
-# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-# help='input batch size for testing (default: 1000)')
-# parser.add_argument('--epochs', type=int, default=14, metavar='N',
-# help='number of epochs to train (default: 14)')
-# parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-# help='learning rate (default: 1.0)')
-# parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-# help='Learning rate step gamma (default: 0.7)')
-# parser.add_argument('--no-cuda', action='store_true', default=False,
-# help='disables CUDA training')
-# parser.add_argument('--no-mps', action='store_true', default=False,
-# help='disables macOS GPU training')
-# parser.add_argument('--dry-run', action='store_true', default=False,
-# help='quickly check a single pass')
-# parser.add_argument('--seed', type=int, default=1, metavar='S',
-# help='random seed (default: 1)')
-# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-# help='how many batches to wait before logging training status')
-# parser.add_argument('--save-model', action='store_true', default=True,
-# help='For Saving the current Model')
-# parser.add_argument('--gpu-id', type=int, default=0,
-# help='Which GPU device to use [default: 0]')
-
-# args = parser.parse_args()
-# use_cuda = not args.no_cuda and torch.cuda.is_available()
-# use_mps = not args.no_mps and torch.backends.mps.is_available()
-
-# torch.manual_seed(args.seed)
-
-# if not args.no_cuda and torch.cuda.is_available():
-# device = torch.device(f"cuda:{args.gpu_id}")
-# elif use_mps:
-# device = torch.device("mps")
-# else:
-# device = torch.device("cpu")
-
-# train_kwargs = {'batch_size': args.batch_size}
-# test_kwargs = {'batch_size': args.test_batch_size}
-# if use_cuda:
-# cuda_kwargs = {'num_workers': 1,
-# 'pin_memory': True,
-# 'shuffle': True}
-# train_kwargs.update(cuda_kwargs)
-# test_kwargs.update(cuda_kwargs)
-
-# transform=transforms.Compose([
-# transforms.ToTensor(),
-# transforms.Normalize((0.1307,), (0.3081,))
-# ])
-# dataset1 = datasets.MNIST('../data', train=True, download=True,
-# transform=transform)
-# dataset2 = datasets.MNIST('../data', train=False,
-# transform=transform)
-# train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
-# test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
-
-# # load pre-trained cnn
-# cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
-# print("-------------- Testing the original cnn model -------------------")
-# _, _ = report_trainable_parameters_analysis_pass(cnn)
-# test(cnn, device, test_loader)
-
-# ## transform cnn into onn
-
-# # onn = load_my_model("mase_output/onn_cnn.pt", device)
-# onn_model = perform_optical_module_transform_pass(cnn)
-# onn_model.to(device)
-# print("-------------- Testing the transformed onn model -------------------")
-# _, _ = report_trainable_parameters_analysis_pass(onn_model)
-# test(onn_model, device, test_loader)
-
-# # Training the onn model
-# if finetune:
-# optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
-# scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
-# for epoch in range(1, args.epochs + 1):
-# train(args, onn_model, device, train_loader, optimizer, epoch)
-# test(onn_model, device, test_loader)
-# scheduler.step()
-
-
-# torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
-
-# print("-------------- Testing the trained onn model -------------------")
-# test(onn_model, device, test_loader)
-# _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
-
-# def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
-# # pass_args = {
-# # "by": "type",
-# # "linear": {
-# # "config": {
-# # "name": "morr",
-# # "miniblock": 4,
-# # "morr_init": True,
-# # "trainable_morr_bias": False,
-# # "trainable_morr_scale": False,
-# # }
-# # },
-# # }
-# pass_args = {
-# "by": "type",
-# "conv2d": {
-# "config": {
-# "name": "morr",
-# "miniblock": 4,
-# "morr_init": True,
-# "trainable_morr_bias": False,
-# "trainable_morr_scale": False,
-# }
-# },
-# }
-# onn_model, _ = optical_module_transform_pass(model, pass_args)
-# torch.save(onn_model.state_dict(), save_path)
-# return onn_model
From 188688492889b8c677a788269b14314f1737ed2d Mon Sep 17 00:00:00 2001
From: Cheng Zhang
Date: Tue, 11 Feb 2025 20:22:31 +0000
Subject: [PATCH 11/38] remove torch_train.py
---
src/chop/nn/optical/utils/compute.py | 13 +-
src/chop/nn/optical/utils/torch_train.py | 858 -----------------------
2 files changed, 11 insertions(+), 860 deletions(-)
delete mode 100644 src/chop/nn/optical/utils/torch_train.py
diff --git a/src/chop/nn/optical/utils/compute.py b/src/chop/nn/optical/utils/compute.py
index c43ad9f2d..d8d36a354 100644
--- a/src/chop/nn/optical/utils/compute.py
+++ b/src/chop/nn/optical/utils/compute.py
@@ -19,8 +19,6 @@
from torch.nn.modules.utils import _pair
from torch.types import Device, _size
-from .torch_train import set_torch_deterministic
-
__all__ = [
"shift",
"Krylov",
@@ -70,6 +68,17 @@
]
+def set_torch_deterministic(random_state: int = 0) -> None:
+ random_state = int(random_state) % (2**32)
+ torch.manual_seed(random_state)
+ np.random.seed(random_state)
+ if torch.cuda.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.cuda.manual_seed_all(random_state)
+ random.seed(random_state)
+
+
def shift(v: Tensor, f: float = 1) -> Tensor:
return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1)
diff --git a/src/chop/nn/optical/utils/torch_train.py b/src/chop/nn/optical/utils/torch_train.py
deleted file mode 100644
index 41571effa..000000000
--- a/src/chop/nn/optical/utils/torch_train.py
+++ /dev/null
@@ -1,858 +0,0 @@
-"""
-Description:
-Author: Jiaqi Gu (jqgu@utexas.edu)
-Date: 2021-06-06 03:15:06
-LastEditors: Jiaqi Gu (jqgu@utexas.edu)
-LastEditTime: 2021-06-06 03:15:06
-"""
-
-import csv
-import os
-import random
-import time
-import traceback
-from collections import OrderedDict
-
-import numpy as np
-import torch
-from scipy import interpolate
-from torch.nn.modules.batchnorm import _BatchNorm
-
-try:
- from torchsummary import summary
-except:
- print("[W] Cannot import torchsummary")
-from .general import ensure_dir
-
-__all__ = [
- "DeterministicCtx",
- "set_torch_deterministic",
- "set_torch_stochastic",
- "get_random_state",
- "summary_model",
- "save_model",
- "BestKModelSaver",
- "load_model",
- "count_parameters",
- "check_converge",
- "ThresholdScheduler",
- "ThresholdScheduler_tf",
- "ValueRegister",
- "ValueTracer",
- "EMA",
- "SWA",
- "export_traces_to_csv",
- "set_learning_rate",
- "get_learning_rate",
- "apply_weight_decay",
- "disable_bn",
- "enable_bn",
-]
-
-
-class DeterministicCtx:
- def __init__(self, random_state: int | None = None) -> None:
- self.random_state = random_state
-
- def __enter__(self):
- self.random_state = random.getstate()
- self.numpy_random_state = np.random.get_state()
- self.torch_random_state = torch.random.get_rng_state()
- self.torch_cuda_random_state = torch.cuda.get_rng_state()
- set_torch_deterministic(self.random_state)
- return self
-
- def __exit__(self, *args):
- random.setstate(self.random_state)
- np.random.seed(self.numpy_random_state)
- np.random.set_state(self.numpy_random_state)
- torch.random.set_rng_state(self.torch_random_state)
- torch.cuda.set_rng_state(self.torch_cuda_random_state)
-
-
-def set_torch_deterministic(random_state: int = 0) -> None:
- random_state = int(random_state) % (2**32)
- torch.manual_seed(random_state)
- np.random.seed(random_state)
- if torch.cuda.is_available():
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- torch.cuda.manual_seed_all(random_state)
- random.seed(random_state)
-
-
-def set_torch_stochastic():
- seed = int(time.time() * 1000) % (2**32)
- torch.manual_seed(seed)
- np.random.seed(seed)
- if torch.cuda.is_available():
- torch.backends.cudnn.deterministic = False
- torch.cuda.manual_seed_all(seed)
-
-
-def get_random_state():
- return np.random.get_state()[1][0]
-
-
-def summary_model(model, input):
- summary(model, input)
-
-
-def save_model(model, path="./checkpoint/model.pt", print_msg=True):
- """Save PyTorch model in path
-
- Args:
- model (PyTorch model): PyTorch model
- path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt".
- print_msg (bool, optional): Control of message print. Defaults to True.
- """
- dir = os.path.dirname(path)
- if not os.path.exists(dir):
- os.mkdir(dir)
- try:
- torch.save(model.state_dict(), path)
- if print_msg:
- print(f"[I] Model saved to {path}")
- except Exception as e:
- if print_msg:
- print(f"[E] Model failed to be saved to {path}")
- traceback.print_exc(e)
-
-
-class BestKModelSaver(object):
- def __init__(
- self,
- k: int = 1,
- descend: bool = True,
- truncate: int = 2,
- metric_name: str = "acc",
- format: str = "{:.2f}",
- ):
- super().__init__()
- self.k = k
- self.descend = descend
- self.truncate = truncate
- self.metric_name = metric_name
- self.format = format
- self.epsilon = 0.1**truncate
- self.model_cache = OrderedDict()
-
- def better_op(self, a, b):
- if self.descend:
- return a >= b + self.epsilon
- else:
- return a <= b - self.epsilon
-
- def __insert_model_record(self, metric, dir, checkpoint_name, epoch=None):
- metric = round(metric * 10**self.truncate) / 10**self.truncate
- if len(self.model_cache) < self.k:
- new_checkpoint_name = (
- f"{checkpoint_name}_{self.metric_name}-"
- + self.format.format(metric)
- + f"{'' if epoch is None else '_epoch-'+str(epoch)}"
- )
- path = os.path.join(dir, new_checkpoint_name + ".pt")
- self.model_cache[path] = (metric, epoch)
- return path, None
- else:
- worst_metric, worst_epoch = sorted(
- list(self.model_cache.values()),
- key=lambda x: x[0],
- reverse=False if self.descend else True,
- )[0]
- if self.better_op(metric, worst_metric):
- del_checkpoint_name = (
- f"{checkpoint_name}_{self.metric_name}-"
- + self.format.format(worst_metric)
- + f"{'' if epoch is None else '_epoch-'+str(worst_epoch)}"
- )
- del_path = os.path.join(dir, del_checkpoint_name + ".pt")
- try:
- del self.model_cache[del_path]
- except:
- print(
- "[W] Cannot remove checkpoint: {} from cache".format(del_path),
- flush=True,
- )
- new_checkpoint_name = (
- f"{checkpoint_name}_{self.metric_name}-"
- + self.format.format(metric)
- + f"{'' if epoch is None else '_epoch-'+str(epoch)}"
- )
- path = os.path.join(dir, new_checkpoint_name + ".pt")
- self.model_cache[path] = (metric, epoch)
- return path, del_path
- # elif(acc == min_acc):
- # new_checkpoint_name = f"{checkpoint_name}_acc-{acc:.2f}{'' if epoch is None else '_epoch-'+str(epoch)}"
- # path = os.path.join(dir, new_checkpoint_name+".pt")
- # self.model_cache[path] = (acc, epoch)
- # return path, None
- else:
- return None, None
-
- def get_topk_model_path(self, topk: int = 1):
- if topk <= 0:
- return []
- if topk > len(self.model_cache):
- topk = len(self.model_cache)
- return [
- i[0]
- for i in sorted(
- self.model_cache.items(), key=lambda x: x[1][0], reverse=self.descend
- )[:topk]
- ]
-
- def save_model(
- self,
- model,
- metric,
- epoch=None,
- path="./checkpoint/model.pt",
- other_params=None,
- save_model=False,
- print_msg=True,
- ):
- """Save PyTorch model in path
-
- Args:
- model (PyTorch model): PyTorch model
- acc (scalar): accuracy
- epoch (scalar, optional): epoch. Defaults to None
- path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt".
- other_params (dict, optional): Other saved params. Defaults to None
- save_model (bool, optional): whether save source code of nn.Module. Defaults to False
- print_msg (bool, optional): Control of message print. Defaults to True.
- """
- dir = os.path.dirname(path)
- ensure_dir(dir)
- checkpoint_name = os.path.splitext(os.path.basename(path))[0]
- if isinstance(metric, torch.Tensor):
- metric = metric.data.item()
- new_path, del_path = self.__insert_model_record(
- metric, dir, checkpoint_name, epoch
- )
-
- if del_path is not None:
- try:
- os.remove(del_path)
- print(f"[I] Model {del_path} is removed", flush=True)
- except Exception as e:
- if print_msg:
- print(f"[E] Model {del_path} failed to be removed", flush=True)
- traceback.print_exc(e)
-
- if new_path is None:
- if print_msg:
- if self.descend:
- best_list = list(reversed(sorted(list(self.model_cache.values()))))
- else:
- best_list = list(sorted(list(self.model_cache.values())))
- print(
- f"[I] Not best {self.k}: {best_list}, skip this model ("
- + self.format.format(metric)
- + f"): {path}",
- flush=True,
- )
- else:
- try:
- # torch.save(model.state_dict(), new_path)
- if other_params is not None:
- saved_dict = other_params
- else:
- saved_dict = {}
- if save_model:
- saved_dict.update(
- {"model": model, "state_dict": model.state_dict()}
- )
- torch.save(saved_dict, new_path)
- else:
- saved_dict.update({"model": None, "state_dict": model.state_dict()})
- torch.save(saved_dict, new_path)
- if print_msg:
- if self.descend:
- best_list = list(
- reversed(sorted(list(self.model_cache.values())))
- )
- else:
- best_list = list(sorted(list(self.model_cache.values())))
-
- print(
- f"[I] Model saved to {new_path}. Current best {self.k}: {best_list}",
- flush=True,
- )
- except Exception as e:
- if print_msg:
- print(f"[E] Model failed to be saved to {new_path}", flush=True)
- traceback.print_exc(e)
- return new_path
-
-
-def load_model(
- model,
- path="./checkpoint/model.pt",
- ignore_size_mismatch: bool = False,
- print_msg=True,
-):
- """Load PyTorch model in path
-
- Args:
- model (PyTorch model): PyTorch model
- path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt".
- ignore_size_mismatch (bool, optional): Whether ignore tensor size mismatch. Defaults to False.
- print_msg (bool, optional): Control of message print. Defaults to True.
- """
- try:
- raw_data = torch.load(path, map_location=lambda storage, location: storage)
- if isinstance(raw_data, OrderedDict) and "state_dict" not in raw_data:
- ### state_dict: OrderedDict
- state_dict = raw_data
- else:
- ### {"state_dict": ..., "model": ...}
- state_dict = raw_data["state_dict"]
- load_keys = set(state_dict.keys())
- model_keys = set(model.state_dict().keys())
- common_dict = load_keys & model_keys
- diff_dict = load_keys ^ model_keys
- extra_keys = load_keys - model_keys
- lack_keys = model_keys - load_keys
- cur_state_dict = model.state_dict()
- if ignore_size_mismatch:
- size_mismatch_dict = set(
- key
- for key in common_dict
- if model.state_dict()[key].size() != state_dict[key].size()
- )
- print(
- f"[W] {size_mismatch_dict} are ignored due to size mismatch", flush=True
- )
- common_dict = common_dict - size_mismatch_dict
-
- cur_state_dict.update({key: state_dict[key] for key in common_dict})
- if len(diff_dict) > 0:
- print(
- f"[W] Warning! Model is not the same as the checkpoint. not found keys {lack_keys}. extra unused keys {extra_keys}"
- )
-
- model.load_state_dict(cur_state_dict)
- if print_msg:
- print(f"[I] Model loaded from {path}")
- except Exception as e:
- traceback.print_exc(e)
- if print_msg:
- print(f"[E] Model failed to be loaded from {path}")
-
-
-def count_parameters(model):
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def check_converge(trace, epsilon=0.002):
- if len(trace) <= 1:
- return False
- if np.abs(trace[-1] - trace[-2]) / (np.abs(trace[-1]) + 1e-8) < epsilon:
- return True
- return False
-
-
-class ThresholdScheduler(object):
- """Intepolation between begin point and end point. step must be within two endpoints"""
-
- def __init__(self, step_beg, step_end, thres_beg, thres_end, mode="tanh"):
- assert mode in {
- "linear",
- "tanh",
- }, "Threshold scheduler only supports linear and tanh modes"
- self.mode = mode
- self.step_beg = step_beg
- self.step_end = step_end
- self.thres_beg = thres_beg
- self.thres_end = thres_end
- self.func = self.createFunc()
-
- def normalize(self, step, factor=2):
- return (step - self.step_beg) / (self.step_end - self.step_beg) * factor
-
- def createFunc(self):
- if self.mode == "linear":
- return lambda x: (self.thres_end - self.thres_beg) * x + self.thres_beg
- elif self.mode == "tanh":
- x = self.normalize(
- np.arange(self.step_beg, self.step_end + 1).astype(np.float32)
- )
- y = np.tanh(x) * (self.thres_end - self.thres_beg) + self.thres_beg
- return interpolate.interp1d(x, y)
-
- def __call__(self, x):
- return self.func(self.normalize(x)).tolist()
-
-
-class ThresholdScheduler_tf(object):
- """smooth increasing threshold with tensorflow model pruning scheduler"""
-
- def __init__(self, step_beg, step_end, thres_beg, thres_end):
- import tensorflow as tf
- import tensorflow_model_optimization as tfmot
-
- gpus = tf.config.list_physical_devices("GPU")
- if gpus:
- # Restrict TensorFlow to only allocate 1GB of memory on the first GPU
- try:
- for gpu in gpus:
- tf.config.experimental.set_memory_growth(gpu, True)
- except RuntimeError as e:
- # Virtual devices must be set before GPUs have been initialized
- print(e)
- self.step_beg = step_beg
- self.step_end = step_end
- self.thres_beg = thres_beg
- self.thres_end = thres_end
- if thres_beg < thres_end:
- self.thres_min = thres_beg
- self.thres_range = thres_end - thres_beg
- self.descend = False
-
- else:
- self.thres_min = thres_end
- self.thres_range = thres_beg - thres_end
- self.descend = True
-
- self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
- initial_sparsity=0,
- final_sparsity=0.9999999,
- begin_step=self.step_beg,
- end_step=self.step_end,
- )
-
- def __call__(self, x):
- if x < self.step_beg:
- return self.thres_beg
- elif x > self.step_end:
- return self.thres_end
- res_norm = self.pruning_schedule(x)[1].numpy()
- if self.descend == False:
- res = res_norm * self.thres_range + self.thres_beg
- else:
- res = self.thres_beg - res_norm * self.thres_range
-
- if np.abs(res - self.thres_end) <= 1e-6:
- res = self.thres_end
- return res
-
-
-class ValueRegister(object):
- def __init__(self, operator, name="", show=True):
- self.op = operator
- self.cache = None
- self.show = show
- self.name = name if len(name) > 0 else "value"
-
- def register_value(self, x):
- self.cache = self.op(x, self.cache) if self.cache is not None else x
- if self.show:
- print(f"Recorded {self.name} is {self.cache}")
-
-
-class ValueTracer(object):
- def __init__(self, show=True):
- self.cache = {}
- self.show = show
-
- def add_value(self, name, value, step):
- if name not in self.cache:
- self.cache[name] = {}
- self.cache[name][step] = value
- if self.show:
- print(f"Recorded {name}: step = {step}, value = {value}")
-
- def get_trace_by_name(self, name):
- return self.cache.get(name, {})
-
- def get_all_traces(self):
- return self.cache
-
- def __len__(self):
- return len(self.cache)
-
- def get_num_trace(self):
- return len(self.cache)
-
- def get_len_trace_by_name(self, name):
- return len(self.cache.get(name, {}))
-
- def dump_trace_to_file(self, name, file):
- if name not in self.cache:
- print(f"[W] Trace name '{name}' not found in tracer")
- return
- torch.save(self.cache[name], file)
- print(f"[I] Trace {name} saved to {file}")
-
- def dump_all_traces_to_file(self, file):
- torch.save(self.cache, file)
- print(f"[I] All traces saved to {file}")
-
- def load_all_traces_from_file(self, file):
- self.cache = torch.load(file)
- return self.cache
-
-
-class EMA(object):
- def __init__(self, mu):
- super().__init__()
- self.mu = mu
- self.shadow = {}
-
- def register(self, name, val):
- self.shadow[name] = val.clone().data
-
- def __call__(self, name, x, mask=None):
- if name not in self.shadow:
- self.register(name, x)
- return x.data
-
- old_average = self.shadow[name]
- new_average = (1 - self.mu) * x + self.mu * old_average
- if mask is not None:
- new_average[mask].copy_(old_average[mask])
- self.shadow[name] = new_average.clone()
- return new_average.data
-
-
-class SWA(torch.nn.Module):
- """Stochastic Weight Averging.
-
- # Paper
- title: Averaging Weights Leads to Wider Optima and Better Generalization
- link: https://arxiv.org/abs/1803.05407
-
- # Arguments
- start_epoch: integer, epoch when swa should start.
- lr_schedule: string, type of learning rate schedule.
- swa_lr: float, learning rate for swa.
- swa_lr2: float, upper bound of cyclic learning rate.
- swa_freq: integer, length of learning rate cycle.
- batch_size integer, batch size (for batch norm with generator)
- verbose: integer, verbosity mode, 0 or 1.
- """
-
- def __init__(
- self,
- model: torch.nn.Module,
- optimizer: torch.optim.Optimizer,
- start_epoch: int,
- epochs: int, # total epochs
- steps, # total steps per epoch
- lr_schedule="manual",
- swa_lr="auto",
- swa_lr2="auto",
- swa_freq=1,
- batch_size=None,
- verbose=0,
- ):
- super().__init__()
- self.model = model
- self.optimizer = optimizer
- self.start_epoch = start_epoch - 1
- self.epochs = epochs
- self.steps = steps
- self.lr_schedule = lr_schedule
- self.swa_lr = swa_lr
-
- # if no user determined upper bound, make one based off of the lower bound
- self.swa_lr2 = swa_lr2 if swa_lr2 is not None else 10 * swa_lr
- self.swa_freq = swa_freq
- self.batch_size = batch_size
- self.verbose = verbose
-
- if start_epoch < 2:
- raise ValueError('"swa_start" attribute cannot be lower than 2.')
-
- schedules = ["manual", "constant", "cyclic"]
-
- if self.lr_schedule not in schedules:
- raise ValueError(
- '"{}" is not a valid learning rate schedule'.format(self.lr_schedule)
- )
-
- if self.lr_schedule == "cyclic" and self.swa_freq < 2:
- raise ValueError('"swa_freq" must be higher than 1 for cyclic schedule.')
-
- if self.swa_lr == "auto" and self.swa_lr2 != "auto":
- raise ValueError(
- '"swa_lr2" cannot be manually set if "swa_lr" is automatic.'
- )
-
- if (
- self.lr_schedule == "cyclic"
- and self.swa_lr != "auto"
- and self.swa_lr2 != "auto"
- and self.swa_lr > self.swa_lr2
- ):
- raise ValueError('"swa_lr" must be lower than "swa_lr2".')
-
- def on_train_begin(self):
- self.lr_record = []
-
- if self.start_epoch >= self.epochs - 1:
- raise ValueError('"swa_start" attribute must be lower than "epochs".')
-
- self.init_lr = self.optimizer.param_groups[0]["lr"]
-
- # automatic swa_lr
- if self.swa_lr == "auto":
- self.swa_lr = 0.1 * self.init_lr
-
- if self.init_lr < self.swa_lr:
- raise ValueError('"swa_lr" must be lower than rate set in optimizer.')
-
- # automatic swa_lr2 between initial lr and swa_lr
- if self.lr_schedule == "cyclic" and self.swa_lr2 == "auto":
- self.swa_lr2 = self.swa_lr + (self.init_lr - self.swa_lr) * 0.25
-
- self._check_batch_norm()
-
- if self.has_batch_norm and self.batch_size is None:
- raise ValueError(
- '"batch_size" needs to be set for models with batch normalization layers.'
- )
-
- def on_epoch_begin(self, epoch):
- # input epoch is from 0 to epochs-1
-
- self.current_epoch = epoch
- self._scheduler(epoch)
-
- # constant schedule is updated epoch-wise
- if self.lr_schedule == "constant":
- self._update_lr(epoch)
-
- if self.is_swa_start_epoch:
- # self.swa_weights = self.model.get_weights()
- self.swa_weights = {
- name: p.data.clone() for name, p in self.model.named_parameters()
- }
-
- if self.verbose > 0:
- print(
- "\nEpoch %05d: starting stochastic weight averaging" % (epoch + 1)
- )
-
- if self.is_batch_norm_epoch:
- self._set_swa_weights(epoch)
-
- if self.verbose > 0:
- print(
- "\nEpoch %05d: reinitializing batch normalization layers"
- % (epoch + 1)
- )
-
- self._reset_batch_norm()
-
- if self.verbose > 0:
- print(
- "\nEpoch %05d: running forward pass to adjust batch normalization"
- % (epoch + 1)
- )
-
- def on_batch_begin(self, batch):
- # update lr each batch for cyclic lr schedule
- if self.lr_schedule == "cyclic":
- self._update_lr(self.current_epoch, batch)
-
- if self.is_batch_norm_epoch:
- batch_size = self.batch_size
- # this is for tensorflow momentum, applied to the running stat
- # momentum = batch_size / (batch * batch_size + batch_size)
-
- # we need to convert it to torch momentum, applied to the batch stat
- momentum = 1 - batch_size / (batch * batch_size + batch_size)
-
- for layer in self.batch_norm_layers:
- layer.momentum = momentum
-
- def on_epoch_end(self, epoch):
- if self.is_swa_start_epoch:
- self.swa_start_epoch = epoch
-
- if self.is_swa_epoch and not self.is_batch_norm_epoch:
- self.swa_weights = self._average_weights(epoch)
-
- def on_train_end(self):
- if not self.has_batch_norm:
- self._set_swa_weights(self.epochs)
- else:
- self._restore_batch_norm()
-
- ## TODO: what is meaning here?
- # for batch_lr in self.lr_record:
- # self.model.history.history.setdefault("lr", []).append(batch_lr)
-
- def _scheduler(self, epoch):
- swa_epoch = epoch - self.start_epoch
-
- self.is_swa_epoch = epoch >= self.start_epoch and swa_epoch % self.swa_freq == 0
- self.is_swa_start_epoch = epoch == self.start_epoch
- self.is_batch_norm_epoch = epoch == self.epochs - 1 and self.has_batch_norm
-
- def _average_weights(self, epoch):
- # return [
- # (swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w)
- # / ((epoch - self.start_epoch) / self.swa_freq + 1)
- # for swa_w, w in zip(self.swa_weights, self.model.get_weights())
- # ]
- out = {}
- with torch.no_grad():
- for name, w in self.model.named_parameters():
- swa_w = self.swa_weights[name]
- out[name] = (
- swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w.data
- ) / ((epoch - self.start_epoch) / self.swa_freq + 1)
- return out
-
- def _update_lr(self, epoch, batch=None):
- if self.is_batch_norm_epoch:
- lr = 0
- # K.set_value(self.model.optimizer.lr, lr)
- set_learning_rate(lr, self.optimizer)
- elif self.lr_schedule == "constant":
- lr = self._constant_schedule(epoch)
- # K.set_value(self.model.optimizer.lr, lr)
- set_learning_rate(lr, self.optimizer)
- elif self.lr_schedule == "cyclic":
- lr = self._cyclic_schedule(epoch, batch)
- # K.set_value(self.model.optimizer.lr, lr)
- set_learning_rate(lr, self.optimizer)
- self.lr_record.append(lr)
-
- def _constant_schedule(self, epoch):
- t = epoch / self.start_epoch
- lr_ratio = self.swa_lr / self.init_lr
- if t <= 0.5:
- factor = 1.0
- elif t <= 0.9:
- factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
- else:
- factor = lr_ratio
- return self.init_lr * factor
-
- def _cyclic_schedule(self, epoch, batch):
- """Designed after Section 3.1 of Averaging Weights Leads to
- Wider Optima and Better Generalization(https://arxiv.org/abs/1803.05407)
- """
- # steps are mini-batches per epoch, equal to training_samples / batch_size
- steps = self.steps
-
- swa_epoch = (epoch - self.start_epoch) % self.swa_freq
- cycle_length = self.swa_freq * steps
-
- # batch 0 indexed, so need to add 1
- i = (swa_epoch * steps) + (batch + 1)
- if epoch >= self.start_epoch:
- t = (((i - 1) % cycle_length) + 1) / cycle_length
- return (1 - t) * self.swa_lr2 + t * self.swa_lr
- else:
- return self._constant_schedule(epoch)
-
- def _set_swa_weights(self, epoch):
- # self.model.set_weights(self.swa_weights)
- for name, p in self.model.named_parameters():
- p.data.copy_(self.swa_weights[name])
-
- if self.verbose > 0:
- print(
- "\nEpoch %05d: final model weights set to stochastic weight average"
- % (epoch + 1)
- )
-
- def _check_batch_norm(self):
- self.batch_norm_momentums = []
- self.batch_norm_layers = []
- self.has_batch_norm = False
- self.running_bn_epoch = False
-
- for layer in self.model.modules():
- if isinstance(layer, _BatchNorm):
- self.has_batch_norm = True
- self.batch_norm_momentums.append(layer.momentum)
- self.batch_norm_layers.append(layer)
-
- if self.verbose > 0 and self.has_batch_norm:
- print(
- "Model uses batch normalization. SWA will require last epoch "
- "to be a forward pass and will run with no learning rate"
- )
-
- def _reset_batch_norm(self):
- for layer in self.batch_norm_layers:
- # initialized moving mean and
- # moving var weights
- layer.reset_running_stats()
-
- def _restore_batch_norm(self):
- for layer, momentum in zip(self.batch_norm_layers, self.batch_norm_momentums):
- layer.momentum = momentum
-
-
-def export_traces_to_csv(trace_file, csv_file, fieldnames=None):
- traces = torch.load(trace_file)
-
- with open(csv_file, "w", newline="") as csvfile:
- if fieldnames is None:
- fieldnames = list(traces.keys())
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-
- writer.writeheader()
- max_len = max([len(traces[field]) for field in fieldnames])
-
- for idx in range(max_len):
- row = {}
- for field in fieldnames:
- value = traces[field][idx] if idx < len(traces[field]) else ""
- row[field] = (
- value.data.item() if isinstance(value, torch.Tensor) else value
- )
- writer.writerow(row)
-
-
-def set_learning_rate(lr, optimizer):
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
-
-
-def get_learning_rate(optimizer):
- return optimizer.param_groups[0]["lr"]
-
-
-def apply_weight_decay(W, decay_rate, learning_rate, mask=None):
- # in mask, 1 represents fixed variables, 0 represents trainable variables
- if mask is not None:
- W[~mask] -= W[~mask] * decay_rate * learning_rate
- else:
- W -= W * decay_rate * learning_rate
-
-
-def disable_bn(model: torch.nn.Module) -> None:
- for m in model.modules():
- if isinstance(
- m,
- (
- torch.nn.BatchNorm1d,
- torch.nn.BatchNorm2d,
- torch.nn.BatchNorm3d,
- torch.nn.SyncBatchNorm,
- ),
- ):
- m.eval()
-
-
-def enable_bn(model: torch.nn.Module) -> None:
- for m in model.modules():
- if isinstance(
- m,
- (
- torch.nn.BatchNorm1d,
- torch.nn.BatchNorm2d,
- torch.nn.BatchNorm3d,
- torch.nn.SyncBatchNorm,
- ),
- ):
- m.train()
From 6d9dddd8c1fc23d86633413ff67520c24f2f001a Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sat, 15 Feb 2025 19:05:08 +0000
Subject: [PATCH 12/38] remove redundent onn functions
---
scripts/mase-hls.py | 30 +-
scripts/stat-to-conf.py | 7 +-
setup.py | 4 +-
src/chop/actions/emit.py | 5 +-
src/chop/actions/search/search.py | 4 +-
.../actions/search/search_space/nas_bert.py | 8 +-
.../quantization/manual_hf_module.py | 3 +-
src/chop/actions/simulate.py | 4 +-
src/chop/actions/train.py | 3 +-
src/chop/dataset/nerf/blender.py | 4 +-
src/chop/dataset/nlp/language_modeling.py | 10 +-
src/chop/dataset/vision/transforms/cifar.py | 5 +-
src/chop/distributed/launcher.py | 16 +-
src/chop/distributed/tensor/__init__.py | 18 +-
src/chop/distributed/tensor/_dispatch.py | 16 +-
src/chop/distributed/tensor/_redistribute.py | 17 +-
src/chop/distributed/tensor/_sharding_prop.py | 15 +-
src/chop/distributed/tensor/api.py | 56 +-
.../distributed/tensor/ops/basic_strategy.py | 5 +-
.../distributed/tensor/ops/common_rules.py | 15 +-
src/chop/distributed/tensor/ops/conv_ops.py | 23 +-
src/chop/distributed/tensor/ops/math_ops.py | 34 +-
src/chop/distributed/tensor/ops/matrix_ops.py | 10 +-
.../distributed/tensor/ops/pointwise_ops.py | 9 +-
src/chop/distributed/tensor/ops/tensor_ops.py | 16 +-
src/chop/distributed/tensor/ops/utils.py | 4 +-
src/chop/distributed/tensor/ops/view_ops.py | 9 +-
src/chop/ir/graph/mase_graph.py | 49 +-
src/chop/ir/graph/mase_metadata.py | 4 +-
src/chop/ir/onnx/mase_onnx_graph.py | 5 +-
src/chop/ir/onnx/utils.py | 9 +-
.../models/bert/modeling_bert_quantized.py | 8 +-
src/chop/models/bert/quant_config_bert.py | 5 +-
src/chop/models/cnv/cnv.py | 17 +-
src/chop/models/cswin/cswintransformer.py | 8 +-
src/chop/models/deit/deit_v2.py | 2 +-
src/chop/models/efficientnet/efficientnet.py | 23 +-
src/chop/models/lfc/lfc.py | 4 +-
src/chop/models/llama/modeling_llama_llora.py | 5 +-
.../models/llama/modeling_llama_sparse.py | 5 +-
src/chop/models/mobilenet_v2/mobilenet_v2.py | 3 +-
src/chop/models/nerf/nerf_vision.py | 4 +-
src/chop/models/opt/modeling_opt.py | 2 +-
src/chop/models/opt/modeling_opt_lora.py | 2 +-
src/chop/models/opt/modeling_opt_quantized.py | 2 +-
src/chop/models/opt/modeling_opt_sparse.py | 2 +-
.../models/opt/quant_config_opt_quantized.py | 4 +-
src/chop/models/pvt/pvt.py | 9 +-
src/chop/models/pvt/pvt_v2.py | 2 +-
src/chop/models/repvgg/repvgg.py | 4 +-
src/chop/models/resnet/resnet.py | 56 +-
src/chop/models/toy/toy.py | 15 +-
src/chop/models/vgg_cifar/vgg_orig.py | 61 +-
src/chop/models/vision/snn/snn_toy.py | 4 +-
.../models/vision/snn/spikingResformer.py | 42 +-
src/chop/nn/backward/modules/__init__.py | 4 +-
src/chop/nn/functional/softermax.py | 2 +-
src/chop/nn/modules/gqa.py | 11 +-
src/chop/nn/modules/lora.py | 13 +-
src/chop/nn/modules/sparse.py | 19 +-
src/chop/nn/mx/activations.py | 5 +-
src/chop/nn/mx/bmm.py | 12 +-
src/chop/nn/mx/convolution.py | 32 +-
src/chop/nn/mx/elemwise_ops.py | 8 +-
src/chop/nn/mx/formats.py | 8 +-
src/chop/nn/mx/linear.py | 32 +-
src/chop/nn/mx/matmul.py | 12 +-
src/chop/nn/mx/mx_ops.py | 9 +-
src/chop/nn/mx/quantize.py | 4 +-
src/chop/nn/mx/simd_ops.py | 2 +-
src/chop/nn/mx/transpose_convolution.py | 32 +-
src/chop/nn/optical/modules/morr_conv2d.py | 7 +-
src/chop/nn/optical/modules/morr_linear.py | 6 +-
src/chop/nn/optical/utils/__init__.py | 4 +-
src/chop/nn/optical/utils/compute.py | 950 +-----------------
src/chop/nn/optical/utils/initializer.py | 105 +-
src/chop/nn/optical/utils/mrr.py | 39 -
src/chop/nn/optical/utils/mrr_op.py | 341 +------
src/chop/nn/optical/utils/quantize.py | 264 +----
src/chop/nn/quantized/functional/gelu.py | 4 +-
src/chop/nn/quantized/functional/linear.py | 83 +-
src/chop/nn/quantized/functional/matmul.py | 8 +-
src/chop/nn/quantized/functional/relu.py | 4 +-
src/chop/nn/quantized/functional/selu.py | 4 +-
src/chop/nn/quantized/functional/softplus.py | 4 +-
src/chop/nn/quantized/functional/softsign.py | 4 +-
src/chop/nn/quantized/functional/tanh.py | 4 +-
src/chop/nn/quantized/modules/__init__.py | 12 +-
src/chop/nn/quantized/modules/attention.py | 4 +-
.../nn/quantized/modules/attention_head.py | 9 +-
src/chop/nn/quantized/modules/batch_norm1d.py | 4 +-
src/chop/nn/quantized/modules/conv1d.py | 12 +-
src/chop/nn/quantized/modules/conv2d.py | 39 +-
src/chop/nn/quantized/modules/gelu.py | 8 +-
src/chop/nn/quantized/modules/gqa.py | 5 +-
src/chop/nn/quantized/modules/group_norm.py | 4 +-
.../nn/quantized/modules/instance_norm2d.py | 4 +-
src/chop/nn/quantized/modules/layer_norm.py | 4 +-
src/chop/nn/quantized/modules/linear.py | 18 +-
src/chop/nn/quantized/modules/relu.py | 8 +-
src/chop/nn/quantized/modules/rms_norm.py | 4 +-
src/chop/nn/quantized/modules/selu.py | 8 +-
src/chop/nn/quantized/modules/silu.py | 8 +-
src/chop/nn/quantized/modules/softplus.py | 8 +-
src/chop/nn/quantized/modules/softsign.py | 8 +-
src/chop/nn/quantized/modules/tanh.py | 8 +-
.../nn/quantizers/LUTNet/BaseInitializer.py | 5 +-
src/chop/nn/quantizers/LUTNet/BaseTrainer.py | 2 +-
src/chop/nn/quantizers/block_fp.py | 17 +-
src/chop/nn/quantizers/block_log.py | 8 +-
src/chop/nn/quantizers/block_minifloat.py | 9 +-
src/chop/nn/quantizers/integer.py | 8 +-
src/chop/nn/quantizers/log.py | 14 +-
src/chop/nn/quantizers/minifloat.py | 34 +-
src/chop/nn/quantizers/mxint_hardware.py | 6 +-
src/chop/nn/quantizers/quantizers_for_hw.py | 14 +-
src/chop/nn/quantizers/ternary.py | 3 +-
src/chop/nn/quantizers/utils.py | 32 +-
src/chop/nn/snn/auto_cuda/generator.py | 5 +-
src/chop/nn/snn/modules/neuron/ifnode.py | 20 +-
src/chop/nn/snn/modules/neuron/lifnode.py | 112 ++-
.../nn/snn/modules/spiking_self_attention.py | 8 +-
src/chop/passes/__init__.py | 4 +-
src/chop/passes/graph/__init__.py | 10 +-
.../add_metadata/add_common_metadata.py | 12 +-
.../add_metadata/add_hardware_metadata.py | 3 +-
.../add_metadata/common_metadata_layers.py | 57 +-
.../add_metadata/hardware_metadata_layers.py | 37 +-
.../add_metadata/software_metadata_layers.py | 16 +-
.../autosharding/alpa_cost_modelling.py | 10 +-
.../analysis/autosharding/autosharding.py | 44 +-
.../graph/analysis/autosharding/megatron.py | 4 +-
.../autosharding/strategies/basic_strategy.py | 5 +-
.../autosharding/strategies/common.py | 6 +-
.../autosharding/strategies/matrix_ops.py | 38 +-
.../autosharding/strategies/pointwise_ops.py | 4 +-
.../autosharding/strategies/tensor_ops.py | 4 +-
.../autosharding/strategies/view_ops.py | 7 +-
.../flop_estimator/calculator/calc_modules.py | 2 +-
.../passes/graph/analysis/plot/plot_graph.py | 5 +-
.../graph/interface/tensorrt/quantize.py | 1 +
.../passes/graph/transforms/dse/run_dse.py | 2 +-
src/chop/passes/graph/transforms/lora.py | 6 +-
.../graph/transforms/onnxrt/quantize.py | 4 +-
.../transforms/pruning/pruning_methods.py | 28 +-
.../quant_parsers/parse_quant_config.py | 82 +-
.../graph/transforms/training/modify.py | 5 +-
.../transforms/utils/logicnets_fusion.py | 8 +-
.../graph/transforms/verilog/emit_bram.py | 8 +-
.../graph/transforms/verilog/emit_hls.py | 7 +-
.../verilog/logicnets/emit_linear.py | 4 +-
src/chop/passes/module/analysis/report.py | 3 +-
src/chop/passes/utils.py | 8 +-
src/chop/pipelines/auto_pipeline.py | 22 +-
src/chop/tools/check_dependency.py | 7 +-
src/chop/tools/huggingface.py | 14 +-
.../tools/plt_wrapper/nlp/classification.py | 4 +-
src/chop/tools/plt_wrapper/nlp/lm.py | 4 +-
src/chop/tools/utils.py | 6 +-
src/mase_cocotb/interfaces/streaming.py | 24 +-
src/mase_cocotb/runner.py | 13 +-
src/mase_cocotb/testbench.py | 7 +-
src/mase_cocotb/utils.py | 4 +-
src/mase_cocotb/z_qlayers/tensor_cast.py | 6 +-
.../activation_layers/test/fixed_elu_tb.py | 14 +-
.../activation_layers/test/fixed_gelu_tb.py | 14 +-
.../test/fixed_hardshrink_tb.py | 6 +-
.../test/fixed_hardswish_tb.py | 10 +-
.../test/fixed_leaky_relu_tb.py | 6 +-
.../test/fixed_logsigmoid_tb.py | 6 +-
.../activation_layers/test/fixed_relu_tb.py | 6 +-
.../activation_layers/test/fixed_selu_tb.py | 10 +-
.../test/fixed_sigmoid_tb.py | 6 +-
.../activation_layers/test/fixed_silu_tb.py | 6 +-
.../test/fixed_softermax_1d_tb.py | 23 +-
.../test/fixed_softermax_tb.py | 4 +-
.../test/fixed_softmax_tb.py | 6 +-
.../test/fixed_softplus_tb.py | 10 +-
.../test/fixed_softshrink_tb.py | 6 +-
.../test/fixed_softsign_tb.py | 10 +-
.../activation_layers/test/fixed_tanh_tb.py | 10 +-
.../activation_layers/test/softermax.py | 2 +-
.../test/softermax_global_norm_tb.py | 26 +-
.../test/softermax_local_window_tb.py | 8 +-
.../test/softermax_lpw_pow2_tb.py | 51 +-
.../test/softermax_lpw_reciprocal_tb.py | 40 +-
.../cast/test/fixed_rounding_tb.py | 2 +-
.../cast/test/fixed_signed_cast_tb.py | 13 +-
.../cast/test/fixed_unsigned_cast_tb.py | 8 +-
.../common/test/comparator_accumulator_tb.py | 8 +-
.../common/test/comparator_tree_tb.py | 5 +-
.../common/test/register_slice_tb.py | 4 +-
.../common/test/single_element_repeat_tb.py | 6 +-
...binary_activation_binary_convolution_tb.py | 4 +-
.../convolution_layers/test/convolution_tb.py | 34 +-
.../convolution_layers/test/padding_tb.py | 10 +-
.../convolution_layers/test/roller_tb.py | 4 +-
.../test/sliding_window_tb.py | 6 +-
src/mase_components/deps.py | 12 +-
src/mase_components/helper/generate_memory.py | 8 +-
.../hls/bfp_arith/bfp_adder.py | 6 +-
.../hls/bfp_arith/bfp_multiplier.py | 6 +-
src/mase_components/hls/elastic/buffer.py | 8 +-
src/mase_components/hls/hls_regression.py | 5 +-
.../hls/int_arith/int_layernorm.py | 8 +-
src/mase_components/hls/int_arith/int_relu.py | 8 +-
src/mase_components/hls/int_arith/int_silu.py | 8 +-
.../hls/int_arith/int_softmax.py | 8 +-
.../hls/int_arith/int_transpose.py | 8 +-
.../hls/regression_gen/bfp_add_dse.py | 11 +-
.../hls/regression_gen/bfp_linear2d_dse.py | 8 +-
.../hls/regression_gen/bfp_mult_dse.py | 11 +-
.../hls/regression_gen/buffer_dse.py | 11 +-
.../hls/regression_gen/fork_dse.py | 11 +-
.../hls/regression_gen/int_add_dse.py | 11 +-
.../hls/regression_gen/int_layernorm_dse.py | 11 +-
.../hls/regression_gen/int_linear2d_dse.py | 8 +-
.../hls/regression_gen/int_matmul_dse.py | 8 +-
.../hls/regression_gen/int_mult_dse.py | 11 +-
.../hls/regression_gen/int_relu_dse.py | 11 +-
.../hls/regression_gen/int_rmsnorm_dse.py | 11 +-
.../hls/regression_gen/int_rope_dse.py | 11 +-
.../hls/regression_gen/int_silu_dse.py | 11 +-
.../hls/regression_gen/int_softmax_dse.py | 11 +-
.../hls/regression_gen/int_transpose_dse.py | 11 +-
src/mase_components/hls/scripts/bl_bfp.py | 15 +-
...y_activation_binary_adder_tree_layer_tb.py | 3 +-
.../fixed_operators/test/fixed_isqrt_tb.py | 18 +-
.../test/fixed_lut_index_tb.py | 3 +-
.../fixed_operators/test/fixed_nr_stage_tb.py | 4 +-
.../test/fixed_range_augmentation_tb.py | 2 +-
.../test/fixed_range_reduction_tb.py | 3 +-
.../fixed_operators/test/isqrt_sw.py | 10 +-
.../matmul/test/fixed_matmul_tb.py | 8 +-
.../matmul/test/simple_matmul_tb.py | 8 +-
.../linear_layers/matmul/test/transpose_tb.py | 6 +-
.../mxint_operators/test/mxint_cast_tb.py | 13 +-
.../test/mxint_dot_product_tb.py | 9 +-
.../mxint_operators/test/mxint_matmul_tb.py | 10 +-
.../test/mxint_vector_mult_tb.py | 9 +-
.../mxint_operators/test/test.py | 16 +-
.../mxint_operators/test/utils.py | 6 +-
src/mase_components/memory/test/fifo_tb.py | 9 +-
.../memory/test/repeat_circular_buffer_tb.py | 2 +-
.../memory/test/unpacked_fifo_tb.py | 2 +-
.../process_synth_impl.py | 11 +-
.../test/batch_norm_2d_tb.py | 6 +-
.../test/channel_selection_tb.py | 1 -
.../test/group_norm_2d_tb.py | 13 +-
.../normalization_layers/test/models.py | 14 +-
.../test/rms_norm_2d_tb.py | 11 +-
.../fixed/test/fixed_isqrt_tb.py | 18 +-
.../fixed/test/fixed_nr_stage_tb.py | 4 +-
.../scalar_operators/fixed/test/isqrt_sw.py | 10 +-
...ixed_grouped_query_attention_wrapper_tb.py | 67 +-
.../test/fixed_self_attention_head_tb.py | 32 +-
.../vision_models/vit/test/fixed_mlp_tb.py | 6 +-
.../vit/test/fixed_patch_embed_tb.py | 9 +-
.../vision_models/vit/test/fixed_pvt_tb.py | 12 +-
.../vision_models/vit/test/hash_exp_tb.py | 4 +-
.../vision_models/vit/test/hash_softmax_tb.py | 6 +-
.../vit/test/helpers/ha_softmax.py | 11 +-
.../vit/test/helpers/pvt_quant.py | 2 +-
test/nn/quantized/modules/attention_head.py | 4 +-
test/nn/snn/test_ann2snn.py | 8 +-
.../add_metadata/test_add_common_metadata.py | 16 +-
.../analysis/pruning/test_hook_inspect.py | 12 +-
.../test_statistic_profiler.py | 16 +-
.../graph/transforms/prune/test_prune.py | 12 +-
.../prune/test_prune_detach_hook.py | 12 +-
.../quantize/test_quantize_lutnet_conv2d.py | 3 +-
.../quantize/test_quantize_lutnet_linear_2.py | 4 +-
.../training/test_training_base_pass.py | 10 +-
.../verilog/test_emit_activation_gelu.py | 6 +-
.../verilog/test_emit_activation_selu.py | 6 +-
.../verilog/test_emit_activation_softplus.py | 6 +-
.../verilog/test_emit_activation_softsign.py | 6 +-
.../verilog/test_emit_activation_tanh.py | 6 +-
.../verilog/test_emit_verilog_bert.py | 11 +-
.../verilog/test_emit_verilog_llama.py | 11 +-
.../verilog/test_emit_verilog_mistral.py | 11 +-
.../verilog/test_emit_verilog_norm.py | 27 +-
.../onnx/analysis/test_export_fx_graph.py | 7 +-
test/self/test_optical_module.py | 252 +++++
test/self/train_mnist_cnn.py | 247 +++++
test/tools/test_onnx_operators.py | 57 +-
286 files changed, 1477 insertions(+), 4061 deletions(-)
create mode 100644 test/self/test_optical_module.py
create mode 100644 test/self/train_mnist_cnn.py
diff --git a/scripts/mase-hls.py b/scripts/mase-hls.py
index f2097bcd1..925ec2300 100755
--- a/scripts/mase-hls.py
+++ b/scripts/mase-hls.py
@@ -44,33 +44,17 @@ def build(self):
def quick_test(self):
shutil.copy(
- os.path.join(
- self.root,
- "test",
- "test_in.mlir",
- ),
- os.path.join(
- self.root,
- "test",
- "test.mlir",
- ),
+ os.path.join(self.root, "test", "test_in.mlir",),
+ os.path.join(self.root, "test", "test.mlir",),
)
result = False
cmd = [
"mase-opt",
"--preprocess-func=func-name=relu",
"--canonicalize",
- os.path.join(
- self.root,
- "test",
- "test.mlir",
- ),
+ os.path.join(self.root, "test", "test.mlir",),
"-o",
- os.path.join(
- self.root,
- "test",
- "test1.mlir",
- ),
+ os.path.join(self.root, "test", "test1.mlir",),
]
result |= self.execute(cmd, log_output=True, cwd=self.root)
@@ -83,11 +67,7 @@ def quick_test(self):
# "test",
# "test.cpp",
# ),
- os.path.join(
- self.root,
- "test",
- "test1.mlir",
- ),
+ os.path.join(self.root, "test", "test1.mlir",),
"--debug",
]
result |= self.execute(cmd, log_output=True, cwd=self.root)
diff --git a/scripts/stat-to-conf.py b/scripts/stat-to-conf.py
index 01e0d0cbc..182e66d64 100755
--- a/scripts/stat-to-conf.py
+++ b/scripts/stat-to-conf.py
@@ -41,12 +41,7 @@
}
-def set_stat(
- entry_name: str,
- mean=None,
- median=None,
- max=None,
-) -> dict[str, Any]:
+def set_stat(entry_name: str, mean=None, median=None, max=None,) -> dict[str, Any]:
"""Return a dictionary containing the format of the stats required to use
ternary quantiser. If statistics are not specified, "NA" will be set as the value,
this interally is being interpreted as None when the .toml is loaded"""
diff --git a/setup.py b/setup.py
index 27d865de1..e0e91be1d 100644
--- a/setup.py
+++ b/setup.py
@@ -105,9 +105,7 @@ def get_system():
author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk",
license_files=("LICENSE",),
python_requires=">=3.11.9",
- package_dir={
- "": "src",
- },
+ package_dir={"": "src",},
packages=find_packages("src"),
install_requires=requirements,
)
diff --git a/src/chop/actions/emit.py b/src/chop/actions/emit.py
index 7e326bc7b..9318b4476 100644
--- a/src/chop/actions/emit.py
+++ b/src/chop/actions/emit.py
@@ -41,10 +41,7 @@ def emit(
data_module.prepare_data()
data_module.setup()
dummy_in = get_dummy_input(
- model_info=model_info,
- data_module=data_module,
- task=task,
- device="cpu",
+ model_info=model_info, data_module=data_module, task=task, device="cpu",
)
mg, _ = add_common_metadata_analysis_pass(
mg, {"dummy_in": dummy_in, "add_value": False}
diff --git a/src/chop/actions/search/search.py b/src/chop/actions/search/search.py
index fe82d3574..48c2d21d7 100644
--- a/src/chop/actions/search/search.py
+++ b/src/chop/actions/search/search.py
@@ -15,9 +15,7 @@
logger = logging.getLogger(__name__)
-def parse_search_config(
- search_config: dict,
-):
+def parse_search_config(search_config: dict,):
"""
Parse search config from a dict or a toml file and do sanity check. The search config must consist of two parts: strategy and search_space.
diff --git a/src/chop/actions/search/search_space/nas_bert.py b/src/chop/actions/search/search_space/nas_bert.py
index 3c4838ea1..760d3b27f 100644
--- a/src/chop/actions/search/search_space/nas_bert.py
+++ b/src/chop/actions/search/search_space/nas_bert.py
@@ -1199,9 +1199,11 @@ def forward(
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = (
- encoder_hidden_states.size()
- )
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
diff --git a/src/chop/actions/search/search_space/quantization/manual_hf_module.py b/src/chop/actions/search/search_space/quantization/manual_hf_module.py
index 094330da1..a9c9f1bfb 100644
--- a/src/chop/actions/search/search_space/quantization/manual_hf_module.py
+++ b/src/chop/actions/search/search_space/quantization/manual_hf_module.py
@@ -61,8 +61,7 @@ def rebuild_model(self, sampled_config: dict, is_eval_mode: bool):
with init_empty_weights():
model = self.model_cls(config)
device_map = infer_auto_device_map(
- model,
- no_split_module_classes=model._no_split_modules,
+ model, no_split_module_classes=model._no_split_modules,
)
model = load_checkpoint_and_dispatch(
model, checkpoint=self.model_name, device_map=device_map
diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py
index e56a512d8..8a0e93486 100644
--- a/src/chop/actions/simulate.py
+++ b/src/chop/actions/simulate.py
@@ -68,9 +68,7 @@ def simulate(
else:
raise ValueError(f"Unrecognized simulator: {simulator}")
- includes = [
- project_dir / "hardware" / "rtl",
- ] + [
+ includes = [project_dir / "hardware" / "rtl",] + [
Path(mase_components.__file__).parent / module / "rtl"
for module in get_modules()
]
diff --git a/src/chop/actions/train.py b/src/chop/actions/train.py
index 5daa9b49a..93524e756 100644
--- a/src/chop/actions/train.py
+++ b/src/chop/actions/train.py
@@ -109,8 +109,7 @@ def train(
trainer = pl.Trainer(**plt_trainer_args)
trainer.fit(
- pl_model,
- datamodule=data_module,
+ pl_model, datamodule=data_module,
)
# Save the trained model along with relevant metadata in the training_ckpts folder.
diff --git a/src/chop/dataset/nerf/blender.py b/src/chop/dataset/nerf/blender.py
index 53865d621..748dacca7 100644
--- a/src/chop/dataset/nerf/blender.py
+++ b/src/chop/dataset/nerf/blender.py
@@ -34,9 +34,7 @@ def _download_lego_dataset(path: Path) -> None:
# Unzip the file
subprocess.run(
- f"unzip {folder_path.as_posix()} -d {path.as_posix()}",
- shell=True,
- check=True,
+ f"unzip {folder_path.as_posix()} -d {path.as_posix()}", shell=True, check=True,
)
diff --git a/src/chop/dataset/nlp/language_modeling.py b/src/chop/dataset/nlp/language_modeling.py
index ba2bdee55..8d17b9466 100644
--- a/src/chop/dataset/nlp/language_modeling.py
+++ b/src/chop/dataset/nlp/language_modeling.py
@@ -281,10 +281,7 @@ def _tokenize(text, tokenizer, max_length):
prompt_len = prompt_tokenized.ne(tokenizer.pad_token_id).sum().item()
target_tokenized[:prompt_len] = ignore_id
- return dict(
- input_ids=input_ids,
- labels=target_tokenized,
- )
+ return dict(input_ids=input_ids, labels=target_tokenized,)
def prepare_data(self):
dataset_dict = self._download_dataset()
@@ -316,10 +313,7 @@ def setup(self):
dataset_dict = self._download_dataset()
dataset_dict = dataset_dict["train"].train_test_split(test_size=0.1, seed=42)
dataset_dict = hf_datasets.DatasetDict(
- {
- "train": dataset_dict["train"],
- "validation": dataset_dict["test"],
- }
+ {"train": dataset_dict["train"], "validation": dataset_dict["test"],}
)
dataset_dict = dataset_dict.map(
function=partial(
diff --git a/src/chop/dataset/vision/transforms/cifar.py b/src/chop/dataset/vision/transforms/cifar.py
index ef44d3721..84e7e1287 100644
--- a/src/chop/dataset/vision/transforms/cifar.py
+++ b/src/chop/dataset/vision/transforms/cifar.py
@@ -35,10 +35,7 @@
def _get_cifar_default_transform(train: bool, mean: tuple[float], std: tuple[float]):
if train:
- transform = create_transform(
- **DEFAULT_CIFAR_PREPROCESS_ARGS,
- is_training=True,
- )
+ transform = create_transform(**DEFAULT_CIFAR_PREPROCESS_ARGS, is_training=True,)
transform.transforms[0] = tv_transforms.RandomCrop(
DEFAULT_CIFAR_PREPROCESS_ARGS["input_size"], padding=4
)
diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py
index 712bb026f..21337be8c 100644
--- a/src/chop/distributed/launcher.py
+++ b/src/chop/distributed/launcher.py
@@ -33,10 +33,7 @@ def distributed_average_timing(fn, repeat, args):
times = []
for itr in range(repeat):
rlog(
- logger,
- dist.get_rank(),
- f"Running teration {itr}",
- "debug",
+ logger, dist.get_rank(), f"Running teration {itr}", "debug",
)
dist.barrier(async_op=True)
start = time()
@@ -45,10 +42,7 @@ def distributed_average_timing(fn, repeat, args):
end = time()
times.append(end - start)
rlog(
- logger,
- dist.get_rank(),
- f"Time taken: {end - start}s",
- "debug",
+ logger, dist.get_rank(), f"Time taken: {end - start}s", "debug",
)
return result, sum(times[2:]) / len(times[2:])
@@ -138,11 +132,7 @@ def device_fn(
distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()])
for in_tensor in inputs
]
- _, time_taken = distributed_average_timing(
- fn=model,
- repeat=10,
- args=inputs,
- )
+ _, time_taken = distributed_average_timing(fn=model, repeat=10, args=inputs,)
rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info")
dist.destroy_process_group()
diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py
index 826ed05cc..b9aa9c272 100644
--- a/src/chop/distributed/tensor/__init__.py
+++ b/src/chop/distributed/tensor/__init__.py
@@ -34,11 +34,7 @@
def _dtensor_init_helper(
- init_op,
- size: torch.Size,
- device_mesh=None,
- placements=None,
- **kwargs,
+ init_op, size: torch.Size, device_mesh=None, placements=None, **kwargs,
) -> DTensor:
from torch.distributed.tensor.placement_types import _DTensorSpec, TensorMeta
@@ -82,18 +78,10 @@ def _dtensor_init_helper(
spec = _DTensorSpec(
device_mesh,
tuple(placements),
- tensor_meta=TensorMeta(
- size,
- torch_stride,
- local_tensor.dtype,
- ),
+ tensor_meta=TensorMeta(size, torch_stride, local_tensor.dtype,),
)
- return DTensor(
- local_tensor,
- spec,
- requires_grad=kwargs["requires_grad"],
- )
+ return DTensor(local_tensor, spec, requires_grad=kwargs["requires_grad"],)
def ones(
diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py
index 9aa96ae68..dcc74852f 100644
--- a/src/chop/distributed/tensor/_dispatch.py
+++ b/src/chop/distributed/tensor/_dispatch.py
@@ -42,9 +42,7 @@
def decompose_handler(
- op_call: torch._ops.OpOverload,
- args: Tuple[object, ...],
- kwargs: Dict[str, object],
+ op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object],
) -> object:
"""
Decomposes a op to core ATen op, this handler is mostly here
@@ -58,9 +56,7 @@ def decompose_handler(
def is_same_size_handler(
- op_call: torch._ops.OpOverload,
- args: Tuple[object, ...],
- kwargs: Dict[str, object],
+ op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object],
) -> bool:
lhs = cast(torch.Tensor, args[0])
rhs = cast(torch.Tensor, args[1])
@@ -201,8 +197,9 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor:
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
- first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
- torch.Tensor, local_tensor_args[0]
+ first_arg, first_local_arg = (
+ cast(dtensor.DTensor, args[0]),
+ cast(torch.Tensor, local_tensor_args[0]),
)
rng_context = (
random._rng_tracker._distribute_region(first_arg._spec)
@@ -254,8 +251,7 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor:
@staticmethod
def redistribute_local_args(
- op_info: OpInfo,
- suggested_input_schema: OpSchema,
+ op_info: OpInfo, suggested_input_schema: OpSchema,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py
index e04495126..e65766eb7 100644
--- a/src/chop/distributed/tensor/_redistribute.py
+++ b/src/chop/distributed/tensor/_redistribute.py
@@ -48,8 +48,7 @@ def _replicate_then_shard(val: _TransformInfo) -> int:
@lru_cache(maxsize=None)
def _gen_transform_infos(
- src_spec: _DTensorSpec,
- dst_spec: _DTensorSpec,
+ src_spec: _DTensorSpec, dst_spec: _DTensorSpec,
) -> List[_TransformInfo]:
"""
Generate the transform infos from the source placements to the target placements.
@@ -88,9 +87,7 @@ def _gen_transform_infos(
# calculate and save the logical shape for this sharding
mesh_dim_size = device_mesh.size(mesh_dim=i)
local_shard_size, _ = src._local_shard_size_on_dim(
- current_logical_shape[src.dim],
- mesh_dim_size,
- my_coordinate[i],
+ current_logical_shape[src.dim], mesh_dim_size, my_coordinate[i],
)
new_logical_shape = list(current_logical_shape)
new_logical_shape[src.dim] = local_shard_size
@@ -288,11 +285,7 @@ def forward( # type: ignore[override]
output = input._local_tensor
target_spec = current_spec
- return dtensor.DTensor(
- output,
- target_spec,
- requires_grad=input.requires_grad,
- )
+ return dtensor.DTensor(output, target_spec, requires_grad=input.requires_grad,)
@staticmethod
def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
@@ -327,9 +320,7 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
),
)
output_dtensor = dtensor.DTensor(
- output,
- spec,
- requires_grad=grad_output.requires_grad,
+ output, spec, requires_grad=grad_output.requires_grad,
)
return (
diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py
index 7351b7895..8e1e26db1 100644
--- a/src/chop/distributed/tensor/_sharding_prop.py
+++ b/src/chop/distributed/tensor/_sharding_prop.py
@@ -48,8 +48,7 @@ class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: Dict[
- OpOverload,
- Callable[[DeviceMesh, OpSchema], StrategyType],
+ OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType],
] = {}
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
@@ -344,10 +343,8 @@ def spec_to_strategy(spec: object) -> object:
expected_input_spec = selected_strategies[idx].input_spec(
tensor_or_list_tensor_arg_idx
)
- expected_input_spec = (
- expected_input_spec.shallow_copy_with_tensor_meta(
- arg_spec.tensor_meta
- )
+ expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta(
+ arg_spec.tensor_meta
)
if arg_spec.placements != expected_input_spec.placements:
needs_redistribute = True
@@ -363,10 +360,8 @@ def spec_to_strategy(spec: object) -> object:
expected_input_spec = selected_strategies[0].input_spec(
tensor_or_list_tensor_arg_idx
)
- expected_input_spec = (
- expected_input_spec.shallow_copy_with_tensor_meta(
- arg.tensor_meta
- )
+ expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta(
+ arg.tensor_meta
)
if arg.placements != expected_input_spec.placements:
needs_redistribute = True
diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py
index dcef33b24..a27c848f4 100644
--- a/src/chop/distributed/tensor/api.py
+++ b/src/chop/distributed/tensor/api.py
@@ -64,9 +64,7 @@
class _ToTorchTensor(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
- ctx,
- input: "DTensor",
- grad_placements: Optional[Sequence[Placement]],
+ ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]],
):
ctx.dtensor_spec = input._spec
ctx.grad_placements = grad_placements
@@ -100,11 +98,7 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
)
return (
- DTensor(
- grad_output,
- grad_spec,
- requires_grad=grad_output.requires_grad,
- ),
+ DTensor(grad_output, grad_spec, requires_grad=grad_output.requires_grad,),
None,
)
@@ -160,11 +154,7 @@ def forward( # type: ignore[override]
dist_spec = _DTensorSpec(
device_mesh,
placements,
- tensor_meta=TensorMeta(
- tensor_shape,
- tensor_stride,
- input.dtype,
- ),
+ tensor_meta=TensorMeta(tensor_shape, tensor_stride, input.dtype,),
)
# We want a fresh Tensor object that shares memory with the input tensor
@@ -217,11 +207,7 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
@staticmethod
@torch._disable_dynamo
def __new__(
- cls,
- local_tensor: torch.Tensor,
- spec: _DTensorSpec,
- *,
- requires_grad: bool,
+ cls, local_tensor: torch.Tensor, spec: _DTensorSpec, *, requires_grad: bool,
) -> "DTensor":
"""
Construct a DTensor from a local tensor, device mesh, and placement and
@@ -277,20 +263,12 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta(
- shape=outer_size,
- stride=outer_stride,
- dtype=spec.tensor_meta.dtype,
+ shape=outer_size, stride=outer_stride, dtype=spec.tensor_meta.dtype,
)
unflatten_spec = _DTensorSpec(
- spec.mesh,
- spec.placements,
- tensor_meta=unflatten_tensor_meta,
- )
- return DTensor(
- local_tensor,
- unflatten_spec,
- requires_grad=requires_grad,
+ spec.mesh, spec.placements, tensor_meta=unflatten_tensor_meta,
)
+ return DTensor(local_tensor, unflatten_spec, requires_grad=requires_grad,)
def __coerce_tangent_metadata__(self):
if not any(isinstance(p, Partial) for p in self.placements):
@@ -303,8 +281,7 @@ def __coerce_tangent_metadata__(self):
def __coerce_same_metadata_as_tangent__(self, flatten_spec):
(spec, _) = flatten_spec # Result of tensor_flatten()
return self.redistribute(
- device_mesh=self.device_mesh,
- placements=spec.placements,
+ device_mesh=self.device_mesh, placements=spec.placements,
)
@classmethod
@@ -312,11 +289,7 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- return DTensor._op_dispatcher.dispatch(
- func,
- args,
- kwargs or {},
- )
+ return DTensor._op_dispatcher.dispatch(func, args, kwargs or {},)
@staticmethod
def from_local(
@@ -394,12 +367,7 @@ def from_local(
# created should flow back the gradients to the local_tensor, so we call an autograd
# function to construct the dist tensor instead.
return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func
- local_tensor,
- device_mesh,
- tuple(placements),
- run_check,
- shape,
- stride,
+ local_tensor, device_mesh, tuple(placements), run_check, shape, stride,
)
def to_local(
@@ -699,9 +667,7 @@ def distribute_tensor(
mesh=device_mesh,
placements=placements,
tensor_meta=TensorMeta(
- shape=tensor.size(),
- stride=tensor.stride(),
- dtype=tensor.dtype,
+ shape=tensor.size(), stride=tensor.stride(), dtype=tensor.dtype,
),
)
return DTensor(
diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py
index 70ba54416..3a9e81ae2 100644
--- a/src/chop/distributed/tensor/ops/basic_strategy.py
+++ b/src/chop/distributed/tensor/ops/basic_strategy.py
@@ -84,10 +84,7 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
def gen_einsum_strategies(
- equation: str,
- mesh: DeviceMesh,
- *,
- linearity: bool = False,
+ equation: str, mesh: DeviceMesh, *, linearity: bool = False,
) -> OpStrategy:
"""
Generate a strategy list for the ops that follow einsum style notation.
diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py
index e0e9a6e34..1324f7806 100644
--- a/src/chop/distributed/tensor/ops/common_rules.py
+++ b/src/chop/distributed/tensor/ops/common_rules.py
@@ -37,10 +37,7 @@ def _gen_reshard_suggestions(
)
suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
- return OutputSharding(
- None,
- redistribute_schema=suggested_schema,
- )
+ return OutputSharding(None, redistribute_schema=suggested_schema,)
def einop_rule(
@@ -218,10 +215,7 @@ def merge_sharding(dim: str, a: int, b: int) -> int:
)
return OutputSharding(
_DTensorSpec.from_dim_map(
- input_specs[0].mesh,
- output_dim_map,
- pending_sums,
- tensor_meta=tensor_meta,
+ input_specs[0].mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta,
)
)
@@ -281,8 +275,5 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
enforce_sharding[out_dimchar] = mesh_dim
return einop_rule(
- fmt,
- op_schema,
- linearity=linearity,
- enforce_sharding=enforce_sharding,
+ fmt, op_schema, linearity=linearity, enforce_sharding=enforce_sharding,
)
diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py
index 707f391d6..9a4d5f425 100644
--- a/src/chop/distributed/tensor/ops/conv_ops.py
+++ b/src/chop/distributed/tensor/ops/conv_ops.py
@@ -50,16 +50,11 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
pending_sums = input_spec.sums
tensor_meta = TensorMeta(
- torch.Size(output_shape),
- output_stride,
- input_spec.tensor_meta.dtype,
+ torch.Size(output_shape), output_stride, input_spec.tensor_meta.dtype,
)
return OutputSharding(
_DTensorSpec.from_dim_map(
- input_spec.mesh,
- output_dim_map,
- pending_sums,
- tensor_meta=tensor_meta,
+ input_spec.mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta,
)
)
@@ -88,22 +83,14 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
assert input_spec.tensor_meta is not None
weight_tensor_meta = weight_spec.tensor_meta
bias_tensor_meta = TensorMeta(
- torch.Size(bias_shape_opt),
- (1,),
- input_spec.tensor_meta.dtype,
+ torch.Size(bias_shape_opt), (1,), input_spec.tensor_meta.dtype,
)
grad_input_spec = input_spec
grad_weight_spec = _DTensorSpec.from_dim_map(
- input_spec.mesh,
- [-1, -1, -1, -1],
- [0],
- tensor_meta=weight_tensor_meta,
+ input_spec.mesh, [-1, -1, -1, -1], [0], tensor_meta=weight_tensor_meta,
)
grad_bias_spec = _DTensorSpec.from_dim_map(
- input_spec.mesh,
- [-1],
- [0],
- tensor_meta=bias_tensor_meta,
+ input_spec.mesh, [-1], [0], tensor_meta=bias_tensor_meta,
)
return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py
index 770514219..da8dad8d5 100644
--- a/src/chop/distributed/tensor/ops/math_ops.py
+++ b/src/chop/distributed/tensor/ops/math_ops.py
@@ -128,7 +128,7 @@ def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
if self.reduce_op == "sum":
assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
if self.norm_type != 0 and self.norm_type != 1:
- return tensor**self.norm_type
+ return tensor ** self.norm_type
return tensor
def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
@@ -289,10 +289,7 @@ def common_reduction_strategy(
redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
reduction_strategy.strategies.append(
PlacementStrategy(
- output_specs=_DTensorSpec(
- mesh=mesh,
- placements=out_placements,
- ),
+ output_specs=_DTensorSpec(mesh=mesh, placements=out_placements,),
input_specs=(input_spec,),
redistribute_cost=redistribute_cost,
)
@@ -478,10 +475,7 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(
- [
- aten._log_softmax_backward_data.default,
- aten._softmax_backward_data.default,
- ],
+ [aten._log_softmax_backward_data.default, aten._softmax_backward_data.default,],
schema_info=RuntimeSchemaInfo(2),
)
def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@@ -614,21 +608,14 @@ def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
reduce_dims_map,
reduction_op,
)
- output_expected_spec = _DTensorSpec(
- mesh=mesh,
- placements=out_placements,
- )
+ output_expected_spec = _DTensorSpec(mesh=mesh, placements=out_placements,)
# whether reduction is sum or mean, the total weight has to be summed up if not replicated
total_weight_placements = map_placements_after_reduction(
- target_expected_spec.placements,
- reduce_dims,
- reduce_dims_map,
- "sum",
+ target_expected_spec.placements, reduce_dims, reduce_dims_map, "sum",
)
total_weight_expected_spec = _DTensorSpec(
- mesh=mesh,
- placements=total_weight_placements,
+ mesh=mesh, placements=total_weight_placements,
)
output_strategy.strategies.append(
@@ -761,8 +748,7 @@ def rlog(msg):
@register_op_strategy(
- [aten.native_layer_norm.default],
- schema_info=RuntimeSchemaInfo(1),
+ [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
# args must be: input, normalized_shape, weight, bias, eps
@@ -856,8 +842,7 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(
- [aten.native_layer_norm_backward.default],
- schema_info=RuntimeSchemaInfo(2),
+ [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2),
)
def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
# args must be: grad_out, input, normalized_shape, mean, rstd,
@@ -1014,8 +999,7 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
@register_op_strategy(
- [aten.topk.default],
- schema_info=RuntimeSchemaInfo(2),
+ [aten.topk.default], schema_info=RuntimeSchemaInfo(2),
)
def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py
index f76dca190..029a795df 100644
--- a/src/chop/distributed/tensor/ops/matrix_ops.py
+++ b/src/chop/distributed/tensor/ops/matrix_ops.py
@@ -384,10 +384,7 @@ def scaled_dot_product_efficient_attention_strategy(
single_mesh_dim_strategies.append(num_heads_dim_sharding)
return expand_to_full_mesh_op_strategy(
- mesh,
- op_schema,
- single_mesh_dim_strategies,
- input_index=4,
+ mesh, op_schema, single_mesh_dim_strategies, input_index=4,
)
@@ -452,8 +449,5 @@ def scaled_dot_product_efficient_attention_backward_strategy(
single_mesh_dim_strategies.append(num_heads_dim_sharding)
return expand_to_full_mesh_op_strategy(
- mesh,
- op_schema,
- single_mesh_dim_strategies,
- input_index=4,
+ mesh, op_schema, single_mesh_dim_strategies, input_index=4,
)
diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py
index 221001f01..656fd2996 100644
--- a/src/chop/distributed/tensor/ops/pointwise_ops.py
+++ b/src/chop/distributed/tensor/ops/pointwise_ops.py
@@ -483,9 +483,7 @@ def common_pointwise_strategy(
common_shape, input_arg_spec.shape
)
input_target_placements = map_placements_after_broadcast(
- tuple(out_placements),
- common_shape,
- input_arg_dims_map,
+ tuple(out_placements), common_shape, input_arg_dims_map,
)
input_arg_target_spec = _DTensorSpec(
mesh=mesh,
@@ -499,10 +497,7 @@ def common_pointwise_strategy(
pointwise_strategy.strategies.append(
PlacementStrategy(
- output_specs=_DTensorSpec(
- mesh=mesh,
- placements=tuple(out_placements),
- ),
+ output_specs=_DTensorSpec(mesh=mesh, placements=tuple(out_placements),),
input_specs=input_specs,
redistribute_cost=redistribute_costs,
)
diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py
index dcddcb98c..f54640e8a 100644
--- a/src/chop/distributed/tensor/ops/tensor_ops.py
+++ b/src/chop/distributed/tensor/ops/tensor_ops.py
@@ -76,10 +76,7 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
@register_op_strategy(
- [
- aten.equal.default,
- aten.is_same_size.default,
- ]
+ [aten.equal.default, aten.is_same_size.default,]
)
def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
# equal_strategy deals with ops that comparing two tensor, we need to make sure
@@ -128,8 +125,7 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
@register_op_strategy(
- [aten.full_like.default],
- schema_info=RuntimeSchemaInfo(2, ["dtype"]),
+ [aten.full_like.default], schema_info=RuntimeSchemaInfo(2, ["dtype"]),
)
@register_op_strategy(
[
@@ -696,8 +692,7 @@ def place(vp: Placement, ip: Placement) -> Placement:
)
result = OutputSharding(
output_spec=_DTensorSpec(
- mesh=values_spec.mesh,
- placements=value_placements,
+ mesh=values_spec.mesh, placements=value_placements,
)
)
return result
@@ -784,10 +779,7 @@ def size_split(N, i):
else split_size_or_sections
)
output_spec_list = [
- _DTensorSpec(
- mesh=input_spec.mesh,
- placements=input_spec.placements,
- )
+ _DTensorSpec(mesh=input_spec.mesh, placements=input_spec.placements,)
for _ in range(len(output_size_list))
]
return OutputSharding(output_spec_list)
diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py
index 28e5245c3..e005501d6 100644
--- a/src/chop/distributed/tensor/ops/utils.py
+++ b/src/chop/distributed/tensor/ops/utils.py
@@ -196,9 +196,7 @@ def infer_broadcast_dims_map(
def map_placements_after_broadcast(
- placements: Tuple[Placement, ...],
- shape: torch.Size,
- broadcast_dims_map: List[int],
+ placements: Tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
"""Map each placement based on the output shape after broadcast."""
new_placements: List[Placement] = []
diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py
index 5c91f6d64..6b89771ec 100644
--- a/src/chop/distributed/tensor/ops/view_ops.py
+++ b/src/chop/distributed/tensor/ops/view_ops.py
@@ -238,9 +238,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
def dim_movedim(
- ndim: int,
- input: Union[int, Sequence[int]],
- destination: Union[int, Sequence[int]],
+ ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]],
) -> DimMap:
input = normalize_dims(input, ndim)
destination = normalize_dims(destination, ndim)
@@ -606,10 +604,7 @@ def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_src_spec = input_placement_strategy.output_spec
input_tgt_placements, output_placements = propagate_shape_and_sharding(
- input_src_spec.placements,
- tuple(global_in_shape),
- rules,
- mesh.shape,
+ input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape,
)
# TODO: optimize this. we shouldn't simply blindly replicate
diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py
index 322195869..5ee6b8bb3 100644
--- a/src/chop/ir/graph/mase_graph.py
+++ b/src/chop/ir/graph/mase_graph.py
@@ -76,13 +76,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
is_fx_built_in_leaf_module = super().is_leaf_module(m, module_qualified_name)
is_mase_leaf_layers = isinstance(m, MASE_LEAF_LAYERS)
is_custom_layer = isinstance(m, self.custom_leaf_layers)
- return any(
- (
- is_fx_built_in_leaf_module,
- is_mase_leaf_layers,
- is_custom_layer,
- )
- )
+ return any((is_fx_built_in_leaf_module, is_mase_leaf_layers, is_custom_layer,))
def trace_torch_module(
@@ -128,19 +122,13 @@ def is_leaf_module(
self, m: torch.nn.Module, module_qualified_name: str
) -> bool:
is_hf_built_in_leaf_module = hf_is_leaf_module(
- self,
- m,
- module_qualified_name,
+ self, m, module_qualified_name,
)
is_custom_module = isinstance(m, custom_modules)
is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS)
return any(
- (
- is_hf_built_in_leaf_module,
- is_custom_module,
- is_mase_leaf_layer,
- )
+ (is_hf_built_in_leaf_module, is_custom_module, is_mase_leaf_layer,)
)
return is_leaf_module
@@ -152,9 +140,7 @@ def is_leaf_module(
)
graph_module = hf_symbolic_trace(
- model,
- tracer_cls=tracer_cls,
- input_names=hf_input_names,
+ model, tracer_cls=tracer_cls, input_names=hf_input_names,
)
graph_module.custom_ops = custom_ops
@@ -307,10 +293,7 @@ def __init__(
self.model.additional_inputs = []
elif isinstance(model, torch.nn.Module):
self.model = trace_torch_module(
- model,
- cf_args,
- custom_ops,
- hf_input_names=hf_input_names,
+ model, cf_args, custom_ops, hf_input_names=hf_input_names,
)
else:
raise ValueError(
@@ -349,16 +332,11 @@ def from_module(
), f"model must be a torch.nn.Module. Received: {type(model)}"
graph_module = trace_torch_module(model, cf_args, custom_ops)
- return cls(
- model=graph_module,
- cf_args=cf_args,
- )
+ return cls(model=graph_module, cf_args=cf_args,)
@classmethod
def from_checkpoint(
- cls,
- checkpoint: str,
- propagate_missing_metadata: bool = True,
+ cls, checkpoint: str, propagate_missing_metadata: bool = True,
):
"""
Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two files:
@@ -393,18 +371,12 @@ def from_checkpoint(
for node in mg.nodes:
if node.name in loaded_meta.keys():
parameters = loaded_meta[node.name]
- node.meta["mase"] = MaseMetadata(
- node=node,
- model=loaded_model,
- )
+ node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,)
node.meta["mase"].parameters = parameters
else:
# todo: propagate metadata for missing nodes
logger.warning(f"Node {node.name} not found in loaded metadata.")
- node.meta["mase"] = MaseMetadata(
- node=node,
- model=loaded_model,
- )
+ node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,)
for attr in [
"class_for_deserialization",
@@ -417,8 +389,7 @@ def from_checkpoint(
return mg
def export(
- self,
- fname: str = "masegraph",
+ self, fname: str = "masegraph",
):
"""
Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz.
diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py
index 6f01d83c8..52fcf0127 100644
--- a/src/chop/ir/graph/mase_metadata.py
+++ b/src/chop/ir/graph/mase_metadata.py
@@ -100,9 +100,7 @@ class MaseMetadata:
known_storage = ["BRAM"]
def __init__(
- self,
- node=None,
- model=None,
+ self, node=None, model=None,
):
# Top-level model
self.model = model
diff --git a/src/chop/ir/onnx/mase_onnx_graph.py b/src/chop/ir/onnx/mase_onnx_graph.py
index 40bb7915c..e4a8f52ed 100644
--- a/src/chop/ir/onnx/mase_onnx_graph.py
+++ b/src/chop/ir/onnx/mase_onnx_graph.py
@@ -9,11 +9,8 @@
class MaseOnnxGraph:
-
def __init__(
- self,
- model_proto: onnx.onnx_ml_pb2.ModelProto,
- model_name: str = None,
+ self, model_proto: onnx.onnx_ml_pb2.ModelProto, model_name: str = None,
):
self.model_proto = model_proto
self.graph = model_proto.graph
diff --git a/src/chop/ir/onnx/utils.py b/src/chop/ir/onnx/utils.py
index bb65aa8b2..c5250bba0 100644
--- a/src/chop/ir/onnx/utils.py
+++ b/src/chop/ir/onnx/utils.py
@@ -188,10 +188,7 @@ def onnx_to_torch_dtype(dtype):
"target": torch.mean,
"input_mapping": ["input"],
"attribute_mapping": {"keepdims": "", "axes": ""},
- "attribute_transform": {
- "keepdims": None,
- "axes": None,
- },
+ "attribute_transform": {"keepdims": None, "axes": None,},
"attribute_default": {"keepdims": 1, "axes": None},
},
"Expand": {
@@ -335,9 +332,7 @@ def onnx_to_torch_dtype(dtype):
"input_mapping": ["input"],
"attribute_mapping": {"perm": "dims"},
"attribute_transform": {"perm": lambda x: [i for i in x]},
- "attribute_default": {
- "perm": None,
- },
+ "attribute_default": {"perm": None,},
},
"Max": {
"fx_op": "call_function",
diff --git a/src/chop/models/bert/modeling_bert_quantized.py b/src/chop/models/bert/modeling_bert_quantized.py
index 8061d2593..c6747f0e9 100644
--- a/src/chop/models/bert/modeling_bert_quantized.py
+++ b/src/chop/models/bert/modeling_bert_quantized.py
@@ -561,9 +561,7 @@ def __init__(self, config, quant_config: dict):
super().__init__()
# self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dense = get_quantized_cls("linear", quant_config["dense"])(
- config.hidden_size,
- config.intermediate_size,
- config=quant_config["dense"],
+ config.hidden_size, config.intermediate_size, config=quant_config["dense"],
)
self.quant_config = quant_config
if isinstance(config.hidden_act, str):
@@ -582,9 +580,7 @@ def __init__(self, config, quant_config):
super().__init__()
# self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dense = get_quantized_cls("linear", quant_config["dense"])(
- config.intermediate_size,
- config.hidden_size,
- config=quant_config["dense"],
+ config.intermediate_size, config.hidden_size, config=quant_config["dense"],
)
self.quant_config = quant_config
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
diff --git a/src/chop/models/bert/quant_config_bert.py b/src/chop/models/bert/quant_config_bert.py
index f66c45a4c..4dadb97f0 100644
--- a/src/chop/models/bert/quant_config_bert.py
+++ b/src/chop/models/bert/quant_config_bert.py
@@ -88,10 +88,7 @@ def create_a_layer_config(
return qc
-def _parse_and_complete_config(
- config: dict,
- num_hidden_layers: int,
-) -> dict:
+def _parse_and_complete_config(config: dict, num_hidden_layers: int,) -> dict:
assert "default" in config, "Must provide a default config"
default_qc: dict = config["default"]
linear_qc: dict = parse_node_config(
diff --git a/src/chop/models/cnv/cnv.py b/src/chop/models/cnv/cnv.py
index 00df962a6..ee1912f18 100644
--- a/src/chop/models/cnv/cnv.py
+++ b/src/chop/models/cnv/cnv.py
@@ -4,12 +4,8 @@
from typing import Any
import numpy as np
-from chop.nn.quantized.modules.conv2d import (
- Conv2dBinaryResidualSign,
-)
-from chop.nn.quantized.modules.linear import (
- LinearBinaryResidualSign,
-)
+from chop.nn.quantized.modules.conv2d import Conv2dBinaryResidualSign
+from chop.nn.quantized.modules.linear import LinearBinaryResidualSign
from chop.models.utils import register_mase_model, register_mase_checkpoint
"""
@@ -206,8 +202,7 @@ def forward(self, x: Tensor) -> Tensor:
# Getters ------------------------------------------------------------------------------
@register_mase_checkpoint("cnv-toy")
def get_cnv_toy(
- pretrained=False,
- **kwargs: Any,
+ pretrained=False, **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
@@ -217,8 +212,7 @@ def get_cnv_toy(
@register_mase_checkpoint("cnv")
def get_cnv(
- pretrained=False,
- **kwargs: Any,
+ pretrained=False, **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
@@ -228,8 +222,7 @@ def get_cnv(
@register_mase_checkpoint("cnv_residual")
def get_cnv_residual(
- pretrained=False,
- **kwargs: Any,
+ pretrained=False, **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
diff --git a/src/chop/models/cswin/cswintransformer.py b/src/chop/models/cswin/cswintransformer.py
index cbd85cda6..449a22d30 100644
--- a/src/chop/models/cswin/cswintransformer.py
+++ b/src/chop/models/cswin/cswintransformer.py
@@ -90,7 +90,7 @@ def __init__(
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
- self.scale = qk_scale or head_dim**-0.5
+ self.scale = qk_scale or head_dim ** -0.5
if idx == -1:
H_sp, W_sp = self.resolution, self.resolution
elif idx == 0:
@@ -362,9 +362,9 @@ def __init__(
super().__init__()
self.use_chk = use_chk
self.num_classes = num_classes
- self.num_features = self.embed_dim = (
- embed_dim # num_features for consistency with other models
- )
+ self.num_features = (
+ self.embed_dim
+ ) = embed_dim # num_features for consistency with other models
heads = num_heads
self.stage1_conv_embed = nn.Sequential(
diff --git a/src/chop/models/deit/deit_v2.py b/src/chop/models/deit/deit_v2.py
index 4f0128216..21b2a5ad4 100644
--- a/src/chop/models/deit/deit_v2.py
+++ b/src/chop/models/deit/deit_v2.py
@@ -28,7 +28,7 @@ def __init__(
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
+ self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
diff --git a/src/chop/models/efficientnet/efficientnet.py b/src/chop/models/efficientnet/efficientnet.py
index 5d2f7730a..b3af62d05 100644
--- a/src/chop/models/efficientnet/efficientnet.py
+++ b/src/chop/models/efficientnet/efficientnet.py
@@ -499,8 +499,7 @@ def forward(self, x: Tensor) -> Tensor:
def _efficientnet_conf(
- arch: str,
- **kwargs: Any,
+ arch: str, **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
if arch.startswith("efficientnet_b"):
@@ -592,9 +591,7 @@ def _efficientnet(
def get_efficientnet_b0(
- info: Dict,
- pretrained: bool = False,
- **kwargs: Any,
+ info: Dict, pretrained: bool = False, **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -616,9 +613,7 @@ def get_efficientnet_b0(
def get_efficientnet_b3(
- info: Dict,
- pretrained: bool = False,
- **kwargs: Any,
+ info: Dict, pretrained: bool = False, **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -641,9 +636,7 @@ def get_efficientnet_b3(
def get_efficientnet_v2_s(
- info: Dict,
- pretrained: bool = False,
- **kwargs: Any,
+ info: Dict, pretrained: bool = False, **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -663,9 +656,7 @@ def get_efficientnet_v2_s(
def get_efficientnet_v2_m(
- info: Dict,
- pretrained: bool = False,
- **kwargs: Any,
+ info: Dict, pretrained: bool = False, **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -685,9 +676,7 @@ def get_efficientnet_v2_m(
def get_efficientnet_v2_l(
- info: Dict,
- pretrained: bool = False,
- **kwargs: Any,
+ info: Dict, pretrained: bool = False, **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
diff --git a/src/chop/models/lfc/lfc.py b/src/chop/models/lfc/lfc.py
index 70da9df45..410359a25 100644
--- a/src/chop/models/lfc/lfc.py
+++ b/src/chop/models/lfc/lfc.py
@@ -42,9 +42,7 @@ def forward(self, x):
# Getters ------------------------------------------------------------------------------
def get_lfc(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
image_size = info["image_size"]
num_classes = info.num_classes
diff --git a/src/chop/models/llama/modeling_llama_llora.py b/src/chop/models/llama/modeling_llama_llora.py
index b45dd8494..45eb2f210 100644
--- a/src/chop/models/llama/modeling_llama_llora.py
+++ b/src/chop/models/llama/modeling_llama_llora.py
@@ -195,10 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
class LlamaMLP(nn.Module):
def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
+ self, hidden_size: int, intermediate_size: int, hidden_act: str,
):
super().__init__()
# fmt: off
diff --git a/src/chop/models/llama/modeling_llama_sparse.py b/src/chop/models/llama/modeling_llama_sparse.py
index f5f37eba0..21359ea34 100644
--- a/src/chop/models/llama/modeling_llama_sparse.py
+++ b/src/chop/models/llama/modeling_llama_sparse.py
@@ -195,10 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
class LlamaMLP(nn.Module):
def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
+ self, hidden_size: int, intermediate_size: int, hidden_act: str,
):
super().__init__()
# fmt: off
diff --git a/src/chop/models/mobilenet_v2/mobilenet_v2.py b/src/chop/models/mobilenet_v2/mobilenet_v2.py
index d8f4020dd..0980d58dd 100644
--- a/src/chop/models/mobilenet_v2/mobilenet_v2.py
+++ b/src/chop/models/mobilenet_v2/mobilenet_v2.py
@@ -288,8 +288,7 @@ def __init__(
# building classifier
self.classifier = nn.Sequential(
- nn.Dropout(p=dropout),
- nn.Linear(self.last_channel, num_classes),
+ nn.Dropout(p=dropout), nn.Linear(self.last_channel, num_classes),
)
# weight initialization
diff --git a/src/chop/models/nerf/nerf_vision.py b/src/chop/models/nerf/nerf_vision.py
index f8591f1a4..18338940b 100644
--- a/src/chop/models/nerf/nerf_vision.py
+++ b/src/chop/models/nerf/nerf_vision.py
@@ -139,9 +139,7 @@ def load_weights_from_keras(self, weights):
# Getters ------------------------------------------------------------------------------
def get_nerf(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
# image_size = info["image_size"]
num_classes = info.num_classes
diff --git a/src/chop/models/opt/modeling_opt.py b/src/chop/models/opt/modeling_opt.py
index e88db8c1f..430f77df8 100644
--- a/src/chop/models/opt/modeling_opt.py
+++ b/src/chop/models/opt/modeling_opt.py
@@ -132,7 +132,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim**-0.5
+ self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
diff --git a/src/chop/models/opt/modeling_opt_lora.py b/src/chop/models/opt/modeling_opt_lora.py
index 810cfe0d6..30bb21f20 100644
--- a/src/chop/models/opt/modeling_opt_lora.py
+++ b/src/chop/models/opt/modeling_opt_lora.py
@@ -130,7 +130,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim**-0.5
+ self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
lora_config = config.lora_config[f"model_layer_{layer_id}"]["self_attn"]
diff --git a/src/chop/models/opt/modeling_opt_quantized.py b/src/chop/models/opt/modeling_opt_quantized.py
index d1086f413..d6eb94343 100644
--- a/src/chop/models/opt/modeling_opt_quantized.py
+++ b/src/chop/models/opt/modeling_opt_quantized.py
@@ -168,7 +168,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim**-0.5
+ self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
# fmt:off
diff --git a/src/chop/models/opt/modeling_opt_sparse.py b/src/chop/models/opt/modeling_opt_sparse.py
index f39a29cc9..4addd6525 100644
--- a/src/chop/models/opt/modeling_opt_sparse.py
+++ b/src/chop/models/opt/modeling_opt_sparse.py
@@ -130,7 +130,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim**-0.5
+ self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
sparse_config = config.sparse_config[f"model_layer_{layer_id}"]["self_attn"]
diff --git a/src/chop/models/opt/quant_config_opt_quantized.py b/src/chop/models/opt/quant_config_opt_quantized.py
index a76f8adb8..150a1b2d6 100644
--- a/src/chop/models/opt/quant_config_opt_quantized.py
+++ b/src/chop/models/opt/quant_config_opt_quantized.py
@@ -32,9 +32,7 @@
def create_a_layer_config(
- linear_qc: dict = None,
- bmm_qc: dict = None,
- layer_qc=None,
+ linear_qc: dict = None, bmm_qc: dict = None, layer_qc=None,
) -> dict:
if (layer_qc is None and bmm_qc is None) and layer_qc is None:
raise ValueError("Must provide either (linear_qc & bmm_qc ) or layer_qc")
diff --git a/src/chop/models/pvt/pvt.py b/src/chop/models/pvt/pvt.py
index ac3a18708..2474fb388 100644
--- a/src/chop/models/pvt/pvt.py
+++ b/src/chop/models/pvt/pvt.py
@@ -58,7 +58,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
+ self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
@@ -182,12 +182,7 @@ def forward(self, x):
@register_mase_model(
"pvt",
- checkpoints=[
- "pvt_tiny",
- "pvt_small",
- "pvt_medium",
- "pvt_large",
- ],
+ checkpoints=["pvt_tiny", "pvt_small", "pvt_medium", "pvt_large",],
model_source="vision_others",
task_type="vision",
image_classification=True,
diff --git a/src/chop/models/pvt/pvt_v2.py b/src/chop/models/pvt/pvt_v2.py
index 3ffb45f6b..b70acb6d1 100644
--- a/src/chop/models/pvt/pvt_v2.py
+++ b/src/chop/models/pvt/pvt_v2.py
@@ -82,7 +82,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
+ self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
diff --git a/src/chop/models/repvgg/repvgg.py b/src/chop/models/repvgg/repvgg.py
index 086501fc5..5ef7998e0 100644
--- a/src/chop/models/repvgg/repvgg.py
+++ b/src/chop/models/repvgg/repvgg.py
@@ -143,14 +143,14 @@ def get_custom_L2(self):
.detach()
)
- l2_loss_circle = (K3**2).sum() - (
+ l2_loss_circle = (K3 ** 2).sum() - (
K3[:, :, 1:2, 1:2] ** 2
).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
eq_kernel = (
K3[:, :, 1:2, 1:2] * t3 + K1 * t1
) # The equivalent resultant central point of 3x3 kernel.
l2_loss_eq_kernel = (
- eq_kernel**2 / (t3**2 + t1**2)
+ eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)
).sum() # Normalize for an L2 coefficient comparable to regular L2.
return l2_loss_eq_kernel + l2_loss_circle
diff --git a/src/chop/models/resnet/resnet.py b/src/chop/models/resnet/resnet.py
index 313c8ce36..b339b72a7 100644
--- a/src/chop/models/resnet/resnet.py
+++ b/src/chop/models/resnet/resnet.py
@@ -157,13 +157,7 @@ def forward(self, x: Tensor) -> Tensor:
@register_mase_model(
name="resnet",
- checkpoints=[
- "resnet18",
- "resnet34",
- "resnet50",
- "resnet101",
- "wide_resnet50_2",
- ],
+ checkpoints=["resnet18", "resnet34", "resnet50", "resnet101", "wide_resnet50_2",],
model_source="torchvision",
task_type="vision",
image_classification=True,
@@ -337,10 +331,7 @@ def _resnet(
@register_mase_checkpoint("resnet18")
-def get_resnet18(
- pretrained: bool = False,
- **kwargs: Any,
-) -> ResNet:
+def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet:
"""ResNet-18 from `Deep Residual Learning for Image Recognition `__."""
if pretrained:
pretrained_weight_cls = ResNet18_Weights.IMAGENET1K_V1
@@ -348,18 +339,12 @@ def get_resnet18(
pretrained_weight_cls = None
return _resnet(
- BasicBlock,
- [2, 2, 2, 2],
- pretrained_weight_cls=pretrained_weight_cls,
- **kwargs,
+ BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
)
@register_mase_checkpoint("resnet34")
-def get_resnet34(
- pretrained: bool = False,
- **kwargs: Any,
-) -> ResNet:
+def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet:
"""ResNet-34 from `Deep Residual Learning for Image Recognition `__."""
if pretrained:
pretrained_weight_cls = ResNet34_Weights.IMAGENET1K_V1
@@ -367,18 +352,12 @@ def get_resnet34(
pretrained_weight_cls = None
return _resnet(
- BasicBlock,
- [2, 2, 2, 2],
- pretrained_weight_cls=pretrained_weight_cls,
- **kwargs,
+ BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
)
@register_mase_checkpoint("resnet50")
-def get_resnet50(
- pretrained: bool = False,
- **kwargs: Any,
-) -> ResNet:
+def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet:
"""ResNet-50 from `Deep Residual Learning for Image Recognition `__."""
info = kwargs["dataset_info"]
if pretrained:
@@ -387,18 +366,12 @@ def get_resnet50(
pretrained_weight_cls = None
return _resnet(
- Bottleneck,
- [3, 4, 6, 3],
- pretrained_weight_cls=pretrained_weight_cls,
- **kwargs,
+ Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
)
@register_mase_checkpoint("resnet101")
-def get_resnet101(
- pretrained: bool = False,
- **kwargs: Any,
-) -> ResNet:
+def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet:
"""ResNet-101 from `Deep Residual Learning for Image Recognition `__."""
info = kwargs["dataset_info"]
if pretrained:
@@ -407,17 +380,13 @@ def get_resnet101(
pretrained_weight_cls = None
return _resnet(
- BasicBlock,
- [2, 2, 2, 2],
- pretrained_weight_cls=pretrained_weight_cls,
- **kwargs,
+ BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
)
@register_mase_checkpoint("wide_resnet50_2")
def get_wide_resnet50_2(
- pretrained: bool = False,
- **kwargs,
+ pretrained: bool = False, **kwargs,
):
"""
`Wide Residual Networks `_.
@@ -435,8 +404,5 @@ def get_wide_resnet50_2(
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(
- Bottleneck,
- [3, 4, 6, 3],
- pretrained_weight_cls=pretrained_weight_cls,
- **kwargs,
+ Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
)
diff --git a/src/chop/models/toy/toy.py b/src/chop/models/toy/toy.py
index e88e030f6..3890c4a0f 100644
--- a/src/chop/models/toy/toy.py
+++ b/src/chop/models/toy/toy.py
@@ -160,8 +160,7 @@ def _conv_block(self, conv_class, *args):
# Getters ------------------------------------------------------------------------------
@register_mase_checkpoint("toy")
def get_toynet(
- pretrained=False,
- **kwargs: Any,
+ pretrained=False, **kwargs: Any,
):
info = kwargs["dataset_info"]
image_size = info.image_size
@@ -170,9 +169,7 @@ def get_toynet(
def get_toy_tiny(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
image_size = info.image_size
num_classes = info.num_classes
@@ -180,9 +177,7 @@ def get_toy_tiny(
def get_toy_testmodel(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
image_size = info["image_size"]
num_classes = info.num_classes
@@ -191,9 +186,7 @@ def get_toy_testmodel(
def get_toy_convnet(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
# NOTE: The model isn't configurable through the CLI or a configuration file yet.
num_classes = info.num_classes
diff --git a/src/chop/models/vgg_cifar/vgg_orig.py b/src/chop/models/vgg_cifar/vgg_orig.py
index cf906e888..f744c4404 100644
--- a/src/chop/models/vgg_cifar/vgg_orig.py
+++ b/src/chop/models/vgg_cifar/vgg_orig.py
@@ -189,12 +189,7 @@ class VGG11_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 132863336,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 69.020,
- "acc@5": 88.628,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 69.020, "acc@5": 88.628,}},
"_ops": 7.609,
"_file_size": 506.84,
},
@@ -209,12 +204,7 @@ class VGG11_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 132868840,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 70.370,
- "acc@5": 89.810,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 70.370, "acc@5": 89.810,}},
"_ops": 7.609,
"_file_size": 506.881,
},
@@ -229,12 +219,7 @@ class VGG13_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 133047848,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 69.928,
- "acc@5": 89.246,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 69.928, "acc@5": 89.246,}},
"_ops": 11.308,
"_file_size": 507.545,
},
@@ -249,12 +234,7 @@ class VGG13_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 133053736,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 71.586,
- "acc@5": 90.374,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 71.586, "acc@5": 90.374,}},
"_ops": 11.308,
"_file_size": 507.59,
},
@@ -269,12 +249,7 @@ class VGG16_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 138357544,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 71.592,
- "acc@5": 90.382,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 71.592, "acc@5": 90.382,}},
"_ops": 15.47,
"_file_size": 527.796,
},
@@ -294,10 +269,7 @@ class VGG16_Weights(WeightsEnum):
"categories": None,
"recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
"_metrics": {
- "ImageNet-1K": {
- "acc@1": float("nan"),
- "acc@5": float("nan"),
- }
+ "ImageNet-1K": {"acc@1": float("nan"), "acc@5": float("nan"),}
},
"_ops": 15.47,
"_file_size": 527.802,
@@ -318,12 +290,7 @@ class VGG16_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 138365992,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 73.360,
- "acc@5": 91.516,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 73.360, "acc@5": 91.516,}},
"_ops": 15.47,
"_file_size": 527.866,
},
@@ -338,12 +305,7 @@ class VGG19_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 143667240,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 72.376,
- "acc@5": 90.876,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 72.376, "acc@5": 90.876,}},
"_ops": 19.632,
"_file_size": 548.051,
},
@@ -358,12 +320,7 @@ class VGG19_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 143678248,
- "_metrics": {
- "ImageNet-1K": {
- "acc@1": 74.218,
- "acc@5": 91.842,
- }
- },
+ "_metrics": {"ImageNet-1K": {"acc@1": 74.218, "acc@5": 91.842,}},
"_ops": 19.632,
"_file_size": 548.143,
},
diff --git a/src/chop/models/vision/snn/snn_toy.py b/src/chop/models/vision/snn/snn_toy.py
index 5915f1ec6..52c86d821 100644
--- a/src/chop/models/vision/snn/snn_toy.py
+++ b/src/chop/models/vision/snn/snn_toy.py
@@ -26,9 +26,7 @@ def forward(self, x: torch.Tensor):
# Getters ------------------------------------------------------------------------------
def get_snn_toy(
- info,
- pretrained=False,
- **kwargs: Any,
+ info, pretrained=False, **kwargs: Any,
):
tau = info["tau"]
num_classes = info.num_classes
diff --git a/src/chop/models/vision/snn/spikingResformer.py b/src/chop/models/vision/snn/spikingResformer.py
index ff74d76df..153cae155 100644
--- a/src/chop/models/vision/snn/spikingResformer.py
+++ b/src/chop/models/vision/snn/spikingResformer.py
@@ -124,11 +124,7 @@ def no_weight_decay(self):
@register_model
def spikingresformer_ti(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[64, 192, 384],
[1, 3, 6],
[4, 2, 1],
@@ -140,11 +136,7 @@ def spikingresformer_ti(**kwargs):
@register_model
def spikingresformer_s(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[64, 256, 512],
[1, 4, 8],
[4, 2, 1],
@@ -156,11 +148,7 @@ def spikingresformer_s(**kwargs):
@register_model
def spikingresformer_m(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[64, 384, 768],
[1, 6, 12],
[4, 2, 1],
@@ -172,11 +160,7 @@ def spikingresformer_m(**kwargs):
@register_model
def spikingresformer_l(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[128, 512, 1024],
[2, 8, 16],
[4, 2, 1],
@@ -188,18 +172,13 @@ def spikingresformer_l(**kwargs):
@register_model
def spikingresformer_dvsg(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[32, 96, 192],
[1, 3, 6],
[4, 2, 1],
in_channels=3,
prologue=nn.Sequential(
- Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"),
- BN(32),
+ Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), BN(32),
),
group_size=32,
activation=PLIF,
@@ -210,18 +189,13 @@ def spikingresformer_dvsg(**kwargs):
@register_model
def spikingresformer_cifar(**kwargs):
return SpikingResformer(
- [
- ["DSSA", "GWFFN"] * 1,
- ["DSSA", "GWFFN"] * 2,
- ["DSSA", "GWFFN"] * 3,
- ],
+ [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
[64, 192, 384],
[1, 3, 6],
[4, 2, 1],
in_channels=3,
prologue=nn.Sequential(
- Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"),
- BN(64),
+ Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), BN(64),
),
**kwargs,
)
diff --git a/src/chop/nn/backward/modules/__init__.py b/src/chop/nn/backward/modules/__init__.py
index 1dcc39e65..279e0c867 100644
--- a/src/chop/nn/backward/modules/__init__.py
+++ b/src/chop/nn/backward/modules/__init__.py
@@ -1,6 +1,4 @@
-from .linear import (
- CustomLinear,
-)
+from .linear import CustomLinear
custom_module_map = {
diff --git a/src/chop/nn/functional/softermax.py b/src/chop/nn/functional/softermax.py
index 8653a6aa0..348d7ebea 100644
--- a/src/chop/nn/functional/softermax.py
+++ b/src/chop/nn/functional/softermax.py
@@ -11,7 +11,7 @@ def softermax(input: Tensor, dim: int) -> Tensor:
Tensor: Output tensor
"""
out = input - input.max(dim=dim, keepdim=True).values.floor()
- out = 2**out
+ out = 2 ** out
row_sum = out.sum(dim=dim, keepdim=True)
# Elementwise division
out = out / row_sum
diff --git a/src/chop/nn/modules/gqa.py b/src/chop/nn/modules/gqa.py
index f30b7a8d8..38188d51d 100644
--- a/src/chop/nn/modules/gqa.py
+++ b/src/chop/nn/modules/gqa.py
@@ -127,12 +127,7 @@ def _qkv_states(self, x: Tensor, batch_size: int, seq_len: int):
# return x
def _attention_mechanism(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- batch_size: int,
- seq_len: int,
+ self, query: Tensor, key: Tensor, value: Tensor, batch_size: int, seq_len: int,
):
key = repeat_kv(key, n_rep=self.group_size)
value = repeat_kv(value, n_rep=self.group_size)
@@ -169,9 +164,7 @@ def forward(self, x: Tensor):
GROUPS = 4
gqa_module = GroupedQueryAttention(
- embed_dim=EMBED_DIM,
- num_heads=NUM_HEADS,
- num_kv_heads=GROUPS,
+ embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_kv_heads=GROUPS,
)
x_in = torch.rand(BATCH, SEQ_LEN, EMBED_DIM)
diff --git a/src/chop/nn/modules/lora.py b/src/chop/nn/modules/lora.py
index bbd2ea023..f101d0c03 100644
--- a/src/chop/nn/modules/lora.py
+++ b/src/chop/nn/modules/lora.py
@@ -93,11 +93,7 @@ def reset_lora_parameters(self, adapter_name):
class LinearLora(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
- self,
- in_features: int,
- out_features: int,
- config: dict = None,
- **kwargs,
+ self, in_features: int, out_features: int, config: dict = None, **kwargs,
):
self.config = config
init_lora_weights = self.config.get("init_lora_weights", True)
@@ -222,12 +218,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
# Simple Lora implementation from https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
class LoRALinear(nn.Module):
def __init__(
- self,
- in_dim: int,
- out_dim: int,
- rank: int,
- alpha: float,
- dropout: float,
+ self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float,
):
super().__init__()
# These are the weights from the original pretrained model
diff --git a/src/chop/nn/modules/sparse.py b/src/chop/nn/modules/sparse.py
index 67508cf3d..5a3f749a1 100644
--- a/src/chop/nn/modules/sparse.py
+++ b/src/chop/nn/modules/sparse.py
@@ -92,11 +92,7 @@ def reset_sparse_parameters(self, adapter_name):
class LinearSparse(nn.Linear, SparseLayer):
def __init__(
- self,
- in_features: int,
- out_features: int,
- config: dict = None,
- **kwargs,
+ self, in_features: int, out_features: int, config: dict = None, **kwargs,
):
self.config = config
init_sparse_weights = self.config.get("init_sparse_weights", True)
@@ -159,9 +155,7 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor:
def update_weight_selection(self, k):
w_flat = self.weight.flatten()
_, self.idx = torch.topk(
- self.index_method(w_flat, self.idx_method),
- k,
- sorted=True,
+ self.index_method(w_flat, self.idx_method), k, sorted=True,
)
self.selected_weights = torch.gather(w_flat, dim=0, index=self.idx)
@@ -192,10 +186,7 @@ def forward(self, x: torch.Tensor):
# Scatter adapted values into weight tensor
adapted_weights = torch.scatter(
- self.zero_tensor.to(x.device),
- dim=0,
- index=self.idx,
- src=scaled_output,
+ self.zero_tensor.to(x.device), dim=0, index=self.idx, src=scaled_output,
).view(self.unflattened_size)
self.step += 1
@@ -203,9 +194,7 @@ def forward(self, x: torch.Tensor):
x = x.to(sparse.weight.dtype)
result = F.linear(
- dropout(x),
- transpose(new_weight, self.fan_in_fan_out),
- bias=self.bias,
+ dropout(x), transpose(new_weight, self.fan_in_fan_out), bias=self.bias,
)
else:
diff --git a/src/chop/nn/mx/activations.py b/src/chop/nn/mx/activations.py
index 39702cc29..27fafc352 100644
--- a/src/chop/nn/mx/activations.py
+++ b/src/chop/nn/mx/activations.py
@@ -434,10 +434,7 @@ def forward(ctx, input, inplace=False, mx_specs=None, name=None):
@staticmethod
def backward(ctx, grad_output):
- (
- y,
- sig_x,
- ) = ctx.saved_tensors
+ (y, sig_x,) = ctx.saved_tensors
grad_output = vec_quantize(grad_output, mx_specs=ctx.mx_specs)
temp = vec_sub(1.0, sig_x, mx_specs=ctx.mx_specs)
diff --git a/src/chop/nn/mx/bmm.py b/src/chop/nn/mx/bmm.py
index 6fac4eef8..50071d06e 100644
--- a/src/chop/nn/mx/bmm.py
+++ b/src/chop/nn/mx/bmm.py
@@ -67,9 +67,7 @@ def backward(ctx, grad_out):
in1, in2 = ctx.saved_tensors
grad_out = quantize_elemwise_op(
- grad_out,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -111,14 +109,10 @@ def backward(ctx, grad_out):
# element-wise quantize for grad_in1 and grad_in2
grad_in1 = quantize_elemwise_op(
- grad_in1,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
grad_in2 = quantize_elemwise_op(
- grad_in2,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
return (grad_in1, grad_in2, None, None)
diff --git a/src/chop/nn/mx/convolution.py b/src/chop/nn/mx/convolution.py
index 62b8dd2d2..9d2aaed17 100644
--- a/src/chop/nn/mx/convolution.py
+++ b/src/chop/nn/mx/convolution.py
@@ -180,16 +180,10 @@ def forward(
# weight is (out_channels, in_channels/groups, ..)
# quantize along in_channels
qid_input = quantize_mx_op(
- bf_in,
- mx_specs,
- elem_format=mx_specs["a_elem_format"],
- axes=[1],
+ bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1],
)
qid_weight = quantize_mx_op(
- bf_weight,
- mx_specs,
- elem_format=mx_specs["w_elem_format"],
- axes=[1],
+ bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[1],
)
# compute output
@@ -213,9 +207,7 @@ def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_output = quantize_elemwise_op(
- grad_output,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -225,10 +217,7 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# quantize along the batch dim
qex_input = quantize_mx_op(
- input,
- ctx.mx_specs,
- elem_format=ctx.mx_specs["a_elem_format"],
- axes=[0],
+ input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0],
)
qex_grad_output = quantize_mx_op(
grad_output,
@@ -251,9 +240,7 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_weight
grad_weight = quantize_elemwise_op(
- grad_weight,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_weight"],
+ grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -264,10 +251,7 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# reduction dim is out_channels
qod_weight = quantize_mx_op(
- weight,
- ctx.mx_specs,
- elem_format=ctx.mx_specs["w_elem_format"],
- axes=[0],
+ weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[0],
)
qod_grad_output = quantize_mx_op(
grad_output,
@@ -289,9 +273,7 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_input
grad_input = quantize_elemwise_op(
- grad_input,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/mx/elemwise_ops.py b/src/chop/nn/mx/elemwise_ops.py
index 5d2b4b8ec..21067bba1 100644
--- a/src/chop/nn/mx/elemwise_ops.py
+++ b/src/chop/nn/mx/elemwise_ops.py
@@ -33,16 +33,16 @@
# exponents smaller than -126
def _safe_lshift(x, bits, exp):
if exp is None:
- return x * (2**bits)
+ return x * (2 ** bits)
else:
- return x / (2**exp) * (2**bits)
+ return x / (2 ** exp) * (2 ** bits)
def _safe_rshift(x, bits, exp):
if exp is None:
- return x / (2**bits)
+ return x / (2 ** bits)
else:
- return x / (2**bits) * (2**exp)
+ return x / (2 ** bits) * (2 ** exp)
def _round_mantissa(A, bits, round, clamp=False):
diff --git a/src/chop/nn/mx/formats.py b/src/chop/nn/mx/formats.py
index 9a2c35189..ff2a9d726 100644
--- a/src/chop/nn/mx/formats.py
+++ b/src/chop/nn/mx/formats.py
@@ -52,14 +52,14 @@ def from_str(s):
def _get_min_norm(ebits):
"""Valid for all float formats"""
emin = 2 - (2 ** (ebits - 1))
- return 0 if ebits == 0 else 2**emin
+ return 0 if ebits == 0 else 2 ** emin
def _get_max_norm(ebits, mbits):
"""Valid only for floats that define NaN"""
assert ebits >= 5, "invalid for floats that don't define NaN"
emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1
- return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+ return 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
_FORMAT_CACHE = {}
@@ -121,9 +121,9 @@ def _get_format_params(fmt):
raise Exception("Unknown element format %s" % fmt)
if fmt != ElemFormat.fp8_e4m3:
- max_norm = 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+ max_norm = 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
else:
- max_norm = 2**emax * 1.75 # FP8 has custom max_norm
+ max_norm = 2 ** emax * 1.75 # FP8 has custom max_norm
min_norm = _get_min_norm(ebits)
diff --git a/src/chop/nn/mx/linear.py b/src/chop/nn/mx/linear.py
index 64f761c86..ef78fcedc 100644
--- a/src/chop/nn/mx/linear.py
+++ b/src/chop/nn/mx/linear.py
@@ -18,12 +18,7 @@
class LinearFunction(torch.autograd.Function):
@staticmethod
def forward(
- ctx,
- input,
- weight,
- bias=None,
- mx_specs=None,
- name=None,
+ ctx, input, weight, bias=None, mx_specs=None, name=None,
):
# element-wise quantize for input
bf_in = quantize_elemwise_op(
@@ -90,9 +85,7 @@ def backward(ctx, grad_output):
in_dim = weight.shape[1]
grad_output = quantize_elemwise_op(
- grad_output,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -122,9 +115,7 @@ def backward(ctx, grad_output):
# Compute grad_weight
grad_weight = torch_matmul(qex_grad_output.transpose(0, 1), qex_input)
grad_weight = quantize_elemwise_op(
- grad_weight,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_weight"],
+ grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -150,9 +141,7 @@ def backward(ctx, grad_output):
# Compute grad_input
grad_input = torch_matmul(qos_grad_output, qos_weight)
grad_input = quantize_elemwise_op(
- grad_input,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -172,11 +161,7 @@ def backward(ctx, grad_output):
def linear(
- input,
- weight,
- bias=None,
- mx_specs=None,
- name=None,
+ input, weight, bias=None, mx_specs=None, name=None,
):
mx_assert_test(mx_specs)
if mx_specs is None:
@@ -189,12 +174,7 @@ def linear(
class Linear(torch.nn.Linear):
def __init__(
- self,
- in_features,
- out_features,
- bias=True,
- mx_specs=None,
- name=None,
+ self, in_features, out_features, bias=True, mx_specs=None, name=None,
):
mx_assert_test(mx_specs)
self.mx_none = mx_specs is None
diff --git a/src/chop/nn/mx/matmul.py b/src/chop/nn/mx/matmul.py
index 8c18914b7..3c2e0ffb4 100644
--- a/src/chop/nn/mx/matmul.py
+++ b/src/chop/nn/mx/matmul.py
@@ -114,9 +114,7 @@ def backward(ctx, grad_out):
in1, in2 = ctx.saved_tensors
grad_out = quantize_elemwise_op(
- grad_out,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -160,14 +158,10 @@ def backward(ctx, grad_out):
# element-wise quantize for grad_in1 and grad_in2
grad_in1 = quantize_elemwise_op(
- grad_in1,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
grad_in2 = quantize_elemwise_op(
- grad_in2,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/mx/mx_ops.py b/src/chop/nn/mx/mx_ops.py
index 136441bea..aad12a52f 100644
--- a/src/chop/nn/mx/mx_ops.py
+++ b/src/chop/nn/mx/mx_ops.py
@@ -281,10 +281,7 @@ def _quantize_mx(
else:
# Get shared exponents
shared_exp = _shared_exponents(
- A,
- method=shared_exp_method,
- axes=shared_exp_axes,
- ebits=0,
+ A, method=shared_exp_method, axes=shared_exp_axes, ebits=0,
)
# Flush subnormal FP32 inputs to zero
@@ -299,7 +296,7 @@ def _quantize_mx(
shared_exp[shared_exp > scale_emax] = float("NaN")
shared_exp[shared_exp < -scale_emax] = -scale_emax
- A = A / (2**shared_exp)
+ A = A / (2 ** shared_exp)
A = _quantize_elemwise_core(
A,
@@ -312,7 +309,7 @@ def _quantize_mx(
custom_cuda=custom_cuda,
)
- A = A * (2**shared_exp)
+ A = A * (2 ** shared_exp)
# Undo tile reshaping
if block_size:
diff --git a/src/chop/nn/mx/quantize.py b/src/chop/nn/mx/quantize.py
index 7e0b907fa..b979b18d7 100644
--- a/src/chop/nn/mx/quantize.py
+++ b/src/chop/nn/mx/quantize.py
@@ -42,9 +42,7 @@ def forward(ctx, x, mx_specs, round=None):
@staticmethod
def backward(ctx, grad_output):
grad_input = quantize_elemwise_op(
- grad_output,
- mx_specs=ctx.mx_specs,
- round=ctx.round,
+ grad_output, mx_specs=ctx.mx_specs, round=ctx.round,
)
return (grad_input, None, None)
diff --git a/src/chop/nn/mx/simd_ops.py b/src/chop/nn/mx/simd_ops.py
index 475b47ef7..28180b758 100644
--- a/src/chop/nn/mx/simd_ops.py
+++ b/src/chop/nn/mx/simd_ops.py
@@ -307,7 +307,7 @@ def forward(ctx, in1, mx_specs=None):
else:
ctx.save_for_backward(in1)
- return vec_quantize(qin1**2, mx_specs=mx_specs)
+ return vec_quantize(qin1 ** 2, mx_specs=mx_specs)
@staticmethod
def backward(ctx, g):
diff --git a/src/chop/nn/mx/transpose_convolution.py b/src/chop/nn/mx/transpose_convolution.py
index af570025b..5d6c5da8a 100644
--- a/src/chop/nn/mx/transpose_convolution.py
+++ b/src/chop/nn/mx/transpose_convolution.py
@@ -75,16 +75,10 @@ def forward(
# weight is (in_channels, out_channels/groups, ...)
# quantize along in_channels
qid_input = quantize_mx_op(
- bf_in,
- mx_specs,
- elem_format=mx_specs["a_elem_format"],
- axes=[1],
+ bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1],
)
qid_weight = quantize_mx_op(
- bf_weight,
- mx_specs,
- elem_format=mx_specs["w_elem_format"],
- axes=[0],
+ bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[0],
)
# compute output
@@ -114,9 +108,7 @@ def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_output = quantize_elemwise_op(
- grad_output,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -126,10 +118,7 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# quantize along the batch dim
qex_input = quantize_mx_op(
- input,
- ctx.mx_specs,
- elem_format=ctx.mx_specs["a_elem_format"],
- axes=[0],
+ input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0],
)
qex_grad_output = quantize_mx_op(
grad_output,
@@ -150,9 +139,7 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_weight
grad_weight = quantize_elemwise_op(
- grad_weight,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_weight"],
+ grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -162,10 +149,7 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# reduction dim is out_channels
qod_weight = quantize_mx_op(
- weight,
- ctx.mx_specs,
- elem_format=ctx.mx_specs["w_elem_format"],
- axes=[1],
+ weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[1],
)
qod_grad_output = quantize_mx_op(
grad_output,
@@ -187,9 +171,7 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_input
grad_input = quantize_elemwise_op(
- grad_input,
- mx_specs=ctx.mx_specs,
- round=ctx.mx_specs["round_grad_input"],
+ grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index f104e397f..c95cd1bb3 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -86,6 +86,7 @@ def __init__(
morr_init = config.get("morr_init", True)
trainable_morr_bias = config.get("trainable_morr_bias", False)
trainable_morr_scale = config.get("trainable_morr_scale", False)
+ device = config.get("device", device)
self.in_channels = in_channels
self.out_channels = out_channels
@@ -108,7 +109,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi**2
+ self.gamma = np.pi / self.v_pi ** 2
self.w_bit = 32
self.in_bit = 32
self.MORRConfig = MORRConfig
@@ -122,7 +123,7 @@ def __init__(
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi**2
+ * np.pi ** 2
* MORRConfig.radius
* MORRConfig.effective_index
* (
@@ -240,7 +241,7 @@ def reset_parameters(self, morr_init: bool = False) -> None:
(t2 - t1) / (2.4 * self.morr_fwhm)
).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
index f93281af5..c1181cfc1 100644
--- a/src/chop/nn/optical/modules/morr_linear.py
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -62,7 +62,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi**2
+ self.gamma = np.pi / self.v_pi ** 2
self.w_bit = 32
self.in_bit = 32
@@ -80,7 +80,7 @@ def __init__(
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi**2
+ * np.pi ** 2
* morr_config.radius
* morr_config.effective_index
* (
@@ -198,7 +198,7 @@ def reset_parameters(self, morr_init: bool = False) -> None:
(t2 - t1) / (2.4 * self.morr_fwhm)
).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
else:
diff --git a/src/chop/nn/optical/utils/__init__.py b/src/chop/nn/optical/utils/__init__.py
index 88b71f264..248e214da 100644
--- a/src/chop/nn/optical/utils/__init__.py
+++ b/src/chop/nn/optical/utils/__init__.py
@@ -11,9 +11,7 @@
toeplitz,
)
-from .initializer import (
- morr_uniform_,
-)
+from .initializer import morr_uniform_
from .quantize import (
input_quantize_fn,
diff --git a/src/chop/nn/optical/utils/compute.py b/src/chop/nn/optical/utils/compute.py
index d8d36a354..8ae1b3279 100644
--- a/src/chop/nn/optical/utils/compute.py
+++ b/src/chop/nn/optical/utils/compute.py
@@ -20,65 +20,11 @@
from torch.types import Device, _size
__all__ = [
- "shift",
- "Krylov",
- "circulant",
"toeplitz",
- "complex_circulant",
- "complex_mult",
- "expi",
- "complex_matvec_mult",
- "complex_matmul",
- "real_to_complex",
- "get_complex_magnitude",
- "get_complex_energy",
- "complex_to_polar",
- "polar_to_complex",
- "absclamp",
- "absclamp_",
"im2col_2d",
- "check_identity_matrix",
- "check_unitary_matrix",
- "check_equal_tensor",
- "batch_diag",
- "batch_eye_cpu",
- "batch_eye",
- "merge_chunks",
- "partition_chunks",
- "clip_by_std",
- "percentile",
- "gen_boolean_mask_cpu",
- "gen_boolean_mask",
- "fftshift_cpu",
- "ifftshift_cpu",
- "gen_gaussian_noise",
- "gen_gaussian_filter2d_cpu",
- "gen_gaussian_filter2d",
- "add_gaussian_noise_cpu",
- "add_gaussian_noise",
- "add_gaussian_noise_",
- "circulant_multiply",
- "calc_diagonal_hessian",
- "calc_jacobian",
- "polynomial",
- "gaussian",
- "lowrank_decompose",
- "get_conv2d_flops",
- "interp1d",
]
-def set_torch_deterministic(random_state: int = 0) -> None:
- random_state = int(random_state) % (2**32)
- torch.manual_seed(random_state)
- np.random.seed(random_state)
- if torch.cuda.is_available():
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- torch.cuda.manual_seed_all(random_state)
- random.seed(random_state)
-
-
def shift(v: Tensor, f: float = 1) -> Tensor:
return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1)
@@ -116,9 +62,38 @@ def toeplitz(col: Tensor) -> Tensor:
return col[..., indices]
-def complex_circulant(eigens: Tensor) -> Tensor:
- circ = Krylov(shift, eigens).transpose(-1, -2)
- return circ
+def im2col_2d(
+ W: Optional[Tensor] = None,
+ X: Optional[Tensor] = None,
+ stride: int = 1,
+ padding: int = 0,
+ w_size: Optional[_size] = None,
+) -> Tuple[Tensor, Tensor, int, int]:
+ if W is not None:
+ W_col = W.view(W.size(0), -1)
+ else:
+ W_col = None
+
+ if X is not None:
+ n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size
+ n_x, d_x, h_x, w_x = X.size()
+
+ h_out = (h_x - h_filter + 2 * padding) / stride + 1
+ w_out = (w_x - w_filter + 2 * padding) / stride + 1
+
+ h_out, w_out = int(h_out), int(w_out)
+ X_col = torch.nn.functional.unfold(
+ X.view(1, -1, h_x, w_x),
+ h_filter,
+ dilation=1,
+ padding=padding,
+ stride=stride,
+ ).view(n_x, -1, h_out * w_out)
+ X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1)
+ else:
+ X_col, h_out, w_out = None, None, None
+
+ return W_col, X_col, h_out, w_out
def complex_mult(X: Tensor, Y: Tensor) -> Tensor:
@@ -151,65 +126,6 @@ def complex_mult(X: Tensor, Y: Tensor) -> Tensor:
return X.mul(Y)
-def complex_matvec_mult(W: Tensor, X: Tensor) -> Tensor:
- return torch.sum(complex_mult(W, X.unsqueeze(0).repeat(W.size(0), 1, 1)), dim=1)
-
-
-def complex_matmul(X: Tensor, Y: Tensor) -> Tensor:
- assert X.shape[-1] == 2 and Y.shape[-1] == 2, "Last dimension must be 2"
- if torch.__version__ >= "1.8" or (
- torch.__version__ >= "1.7" and X.shape[:-3] == Y.shape[:-3]
- ):
- return torch.view_as_real(
- torch.matmul(torch.view_as_complex(X), torch.view_as_complex(Y))
- )
-
- return torch.stack(
- [
- X[..., 0].matmul(Y[..., 0]) - X[..., 1].matmul(Y[..., 1]),
- X[..., 0].matmul(Y[..., 1]) + X[..., 1].matmul(Y[..., 0]),
- ],
- dim=-1,
- )
-
-
-def expi(x: Tensor) -> Tensor:
- if torch.__version__ >= "1.8" or (
- torch.__version__ >= "1.7" and not x.requires_grad
- ):
- return torch.exp(1j * x)
- else:
- return x.cos().type(torch.cfloat) + 1j * x.sin().type(torch.cfloat)
-
-
-def real_to_complex(x: Tensor) -> Tensor:
- if torch.__version__ < "1.7":
- return torch.stack((x, torch.zeros_like(x).to(x.device)), dim=-1)
- else:
- return torch.view_as_real(x.to(torch.complex64))
-
-
-def get_complex_magnitude(x: Tensor) -> Tensor:
- assert x.size(-1) == 2, "[E] Input must be complex Tensor"
- return torch.sqrt(x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1])
-
-
-def complex_to_polar(x: Tensor) -> Tensor:
- # real and imag to magnitude and angle
- if isinstance(x, torch.Tensor):
- mag = x.norm(p=2, dim=-1)
- angle = torch.view_as_complex(x).angle()
- x = torch.stack([mag, angle], dim=-1)
- elif isinstance(x, np.ndarray):
- x = x.astype(np.complex64)
- mag = np.abs(x)
- angle = np.angle(x)
- x = np.stack([mag, angle], axis=-1)
- else:
- raise NotImplementedError
- return x
-
-
def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor:
# magnitude and angle to real and imag
if angle is None:
@@ -231,534 +147,6 @@ def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor:
return x
-def get_complex_energy(x: Tensor) -> Tensor:
- assert x.size(-1) == 2, "[E] Input must be complex Tensor"
- return x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]
-
-
-def absclamp(
- x: Tensor, min: Optional[float] = None, max: Optional[float] = None
-) -> Tensor:
- if isinstance(x, torch.Tensor):
- mag = x.norm(p=2, dim=-1).clamp(min=min, max=max)
- angle = torch.view_as_complex(x).angle()
- x = polar_to_complex(mag, angle)
- elif isinstance(x, np.ndarray):
- x = x.astype(np.complex64)
- mag = np.clip(np.abs(x), a_min=min, a_max=max)
- angle = np.angle(x)
- x = polar_to_complex(mag, angle)
- else:
- raise NotImplementedError
- return x
-
-
-def absclamp_(
- x: Tensor, min: Optional[float] = None, max: Optional[float] = None
-) -> Tensor:
- if isinstance(x, torch.Tensor):
- y = torch.view_as_complex(x)
- mag = y.abs().clamp(min=min, max=max)
- angle = y.angle()
- x.data.copy_(polar_to_complex(mag, angle))
- elif isinstance(x, np.ndarray):
- y = x.astype(np.complex64)
- mag = np.clip(np.abs(y), a_min=min, a_max=max)
- angle = np.angle(y)
- x[:] = polar_to_complex(mag, angle)
- else:
- raise NotImplementedError
- return x
-
-
-def im2col_2d(
- W: Optional[Tensor] = None,
- X: Optional[Tensor] = None,
- stride: int = 1,
- padding: int = 0,
- w_size: Optional[_size] = None,
-) -> Tuple[Tensor, Tensor, int, int]:
- if W is not None:
- W_col = W.view(W.size(0), -1)
- else:
- W_col = None
-
- if X is not None:
- n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size
- n_x, d_x, h_x, w_x = X.size()
-
- h_out = (h_x - h_filter + 2 * padding) / stride + 1
- w_out = (w_x - w_filter + 2 * padding) / stride + 1
-
- h_out, w_out = int(h_out), int(w_out)
- X_col = torch.nn.functional.unfold(
- X.view(1, -1, h_x, w_x),
- h_filter,
- dilation=1,
- padding=padding,
- stride=stride,
- ).view(n_x, -1, h_out * w_out)
- X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1)
- else:
- X_col, h_out, w_out = None, None, None
-
- return W_col, X_col, h_out, w_out
-
-
-def check_identity_matrix(W: Tensor) -> bool:
- if isinstance(W, np.ndarray):
- W_numpy = W.copy().astype(np.float64)
- elif isinstance(W, torch.Tensor):
- W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
- else:
- assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
-
- return (W_numpy.shape[0] == W_numpy.shape[1]) and np.allclose(
- W_numpy, np.eye(W_numpy.shape[0])
- )
-
-
-def check_unitary_matrix(W: Tensor) -> bool:
- if isinstance(W, np.ndarray):
- W_numpy = W.copy().astype(np.float64)
- elif isinstance(W, torch.Tensor):
- W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
- else:
- assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
- M = np.dot(W_numpy, W_numpy.T)
- # print(M)
- return check_identity_matrix(M)
-
-
-def check_equal_tensor(W1: Tensor, W2: Tensor) -> bool:
- if isinstance(W1, np.ndarray):
- W1_numpy = W1.copy().astype(np.float64)
- elif isinstance(W1, torch.Tensor):
- W1_numpy = W1.detach().cpu().numpy().copy().astype(np.float64)
- else:
- assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
-
- if isinstance(W2, np.ndarray):
- W2_numpy = W2.copy().astype(np.float64)
- elif isinstance(W2, torch.Tensor):
- W2_numpy = W2.detach().cpu().numpy().copy().astype(np.float64)
- else:
- assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
- return (W1_numpy.shape == W2_numpy.shape) and np.allclose(W1_numpy, W2_numpy)
-
-
-def batch_diag(x: Tensor) -> Tensor:
- # x[..., N, N] -> [..., N]
- assert (
- len(x.shape) >= 2
- ), f"At least 2-D array/tensor is expected, but got shape {x.shape}"
- if isinstance(x, np.ndarray):
- size = list(x.shape)
- x = x.reshape(size[:-2] + [size[-2] * size[-1]])
- x = x[..., :: size[-1] + 1]
- elif isinstance(x, torch.Tensor):
- size = list(x.size())
- x = x.flatten(-2, -1)
- x = x[..., :: size[-1] + 1]
- else:
- raise NotImplementedError
- return x
-
-
-def batch_eye_cpu(N: int, batch_shape: List[int], dtype: np.dtype) -> np.ndarray:
- x = np.zeros(list(batch_shape) + [N, N], dtype=dtype)
- x.reshape(-1, N * N)[..., :: N + 1] = 1
- return x
-
-
-def batch_eye(
- N: int,
- batch_shape: List[int],
- dtype: torch.dtype,
- device: Device = torch.device("cuda"),
-) -> torch.Tensor:
- x = torch.zeros(list(batch_shape) + [N, N], dtype=dtype, device=device)
- x.view(-1, N * N)[..., :: N + 1] = 1
- return x
-
-
-def merge_chunks(x: Tensor, complex: bool = False) -> Tensor:
- """Merge a chunked/blocked tensors into a 2D matrix
-
- Args:
- x (Tensor): Tensor of shape [h1, w1, h2, w2, ...., hk, wk] if complex=False; [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True
- complex (bool, optional): True if the tensor x has a last dimension with size 2 for real/imag representation. Defaults to False.
-
- Returns:
- Tensor: [h1*h2*...*hk, w1*w2*...*wk] or [h1*h2*...*hk, w1*w2*...*wk, 2]
- """
- if isinstance(x, torch.Tensor):
- permute = torch.permute
- elif isinstance(x, np.ndarray):
- permute = np.transpose
- else:
- raise NotImplementedError
-
- if not complex:
- dim = len(x.shape)
- x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2)))
- x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1)
- else:
- dim = len(x.shape) - 1
- x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2) + [dim]))
- x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1, 2)
-
- return x
-
-
-def partition_chunks(
- x: Tensor, out_shape: int | Tuple[int, ...], complex: bool = False
-) -> Tensor:
- """Partition a tensor into square chunks, similar to Rearrange in einops
-
- Args:
- x (Tensor): 2D tensor of shape [h1*h2*...*hk, w1*w2*...*wk] or 3D tensor of shape [h1*h2*...*hk, w1*w2*...*wk, 2] if complex=True
- out_shape (Tuple[int]): output blocked shape (h1, w1, h2, w2, ...); Do not include the last dimension even if complex=True
- complex (bool, optional): whether x is complex tensor. Defaults to False.
-
- Returns:
- [Tensor]: Tensor of shape [h1, w1, h2, w2, ...., hk, wk] or [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True
- """
- if complex:
- assert len(x.shape) == 3
- x_shape = (np.prod(out_shape[::2]), np.prod(out_shape[1::2]))
- if isinstance(x, torch.Tensor):
- permute = torch.permute
- pad_fn = lambda x, padding: torch.nn.functional.pad(x[None, None], padding)[
- 0, 0
- ]
- is_tensor = True
- elif isinstance(x, np.ndarray):
- permute = np.transpose
- pad_fn = np.pad
- is_tensor = False
- else:
- raise NotImplementedError
-
- if x_shape != x.shape[:2]:
- ## if x cannot be partitioned into out_shape, we need to pad it
- if is_tensor:
- ## torch from the last dim
- padding = (0, x_shape[1] - x.shape[1], 0, x_shape[0] - x.shape[0])
- if complex:
- padding = (0, 0) + padding
- else:
- ## np from the first dim
- padding = ((0, x_shape[0] - x.shape[0]), (0, x_shape[1] - x.shape[1]))
- if complex:
- padding = padding + (0, 0)
-
- x = pad_fn(x, padding)
-
- in_shape = list(out_shape[::2]) + list(out_shape[1::2])
- permute_shape = np.arange(len(out_shape)).reshape(2, -1).T.reshape(-1).tolist()
- if complex:
- in_shape.append(2)
- permute_shape.append(len(permute_shape))
- x = x.reshape(in_shape) # [h1, h2, ..., hk, w1, w2, ..., wk]
-
- x = permute(x, permute_shape) # [h1, w1, h2, w2, ...., hk, wk]
-
- return x
-
-
-def clip_by_std(x: Tensor, n_std_neg: float = 3.0, n_std_pos: float = 3.0) -> Tensor:
- if isinstance(x, np.ndarray):
- std = np.std(x)
- mean = np.mean(x)
- out = np.clip(x, a_min=mean - n_std_neg * std, a_max=mean + n_std_pos * std)
- elif isinstance(x, torch.Tensor):
- std = x.data.std()
- mean = x.data.mean()
- out = x.clamp(min=mean - n_std_neg * std, max=mean + n_std_pos * std)
- else:
- raise NotImplementedError
- return out
-
-
-def percentile(t: Tensor, q: float) -> Tensor:
- """
- Return the ``q``-th percentile of the flattened input tensor's data.
-
- CAUTION:
- * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
- * Values are not interpolated, which corresponds to
- ``numpy.percentile(..., interpolation="nearest")``.
-
- :param t: Input tensor.
- :param q: Percentile to compute, which must be between 0 and 100 inclusive.
- :return: Resulting value (scalar).
- """
- # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
- # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
- # so that ``round()`` returns an integer, even if q is a np.float32.
- if isinstance(t, torch.Tensor):
- k = 1 + round(0.01 * float(q) * (t.numel() - 1))
- result = t.view(-1).kthvalue(k).values.item()
- elif isinstance(t, np.ndarray):
- result = np.percentile(t, q=q)
- else:
- raise NotImplementedError
- return result
-
-
-def gen_boolean_mask_cpu(size: _size, true_prob: float) -> np.ndarray:
- assert 0 <= true_prob <= 1, "[E] Wrong probability for True"
- return np.random.choice(a=[False, True], size=size, p=[1 - true_prob, true_prob])
-
-
-def gen_boolean_mask(
- size: _size,
- true_prob: float,
- random_state: Optional[int] = None,
- device: Device = torch.device("cuda"),
-) -> Tensor:
- assert 0 <= true_prob <= 1, "[E] Wrong probability for True"
- if true_prob > 1 - 1e-9:
- return torch.ones(size, device=device, dtype=torch.bool)
- elif true_prob < 1e-9:
- return torch.zeros(size, device=device, dtype=torch.bool)
- if random_state is not None:
- with torch.random.fork_rng():
- torch.random.manual_seed(random_state)
- return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(
- true_prob
- )
- else:
- return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(true_prob)
-
-
-def fftshift_cpu(
- x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None
-) -> Union[Tensor, np.ndarray]:
- if isinstance(x, np.ndarray):
- if dim is None:
- if batched:
- dim = tuple(range(1, len(x.shape)))
- else:
- dim = tuple(range(0, len(x.shape)))
- out = np.fft.fftshift(x, axes=dim)
- elif isinstance(x, torch.Tensor):
- device = x.device
- x = x.cpu().detach().numpy()
- if dim is None:
- if batched:
- dim = tuple(range(1, len(x.shape)))
- else:
- dim = tuple(range(0, len(x.shape)))
- out = np.fft.fftshift(x, axes=dim)
- out = torch.from_numpy(out).to(device)
- return out
-
-
-def ifftshift_cpu(
- x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None
-) -> Union[Tensor, np.ndarray]:
- if isinstance(x, np.ndarray):
- if dim is None:
- if batched:
- dim = tuple(range(1, len(x.shape)))
- else:
- dim = tuple(range(0, len(x.shape)))
- out = np.fft.ifftshift(x, axes=dim)
- elif isinstance(x, torch.Tensor):
- device = x.device
- x = x.cpu().detach().numpy()
- if dim is None:
- if batched:
- dim = tuple(range(1, len(x.shape)))
- else:
- dim = tuple(range(0, len(x.shape)))
- out = np.fft.ifftshift(x, axes=dim)
- out = torch.from_numpy(out).to(device)
- return out
-
-
-def gen_gaussian_noise(
- W: Union[Tensor, np.ndarray],
- noise_mean: float = 0.0,
- noise_std: float = 0.002,
- trunc_range: Tuple = (),
- random_state: Optional[int] = None,
-) -> Union[Tensor, np.ndarray]:
- if random_state is not None:
- set_torch_deterministic(random_state)
- if isinstance(W, np.ndarray):
- if not trunc_range:
- noises = np.random.normal(noise_mean, noise_std, W.shape)
- else:
- a = (trunc_range[0] - noise_mean) / noise_std
- b = (trunc_range[1] - noise_mean) / noise_std
- noises = truncnorm.rvs(
- a, b, loc=noise_mean, scale=noise_std, size=W.shape, random_state=None
- )
- elif isinstance(W, torch.Tensor):
- if not trunc_range:
- noises = torch.zeros_like(W).normal_(mean=noise_mean, std=noise_std)
- else:
- size = W.shape
- tmp = W.new_empty(size + (4,)).normal_()
- a = (trunc_range[0] - noise_mean) / noise_std
- b = (trunc_range[1] - noise_mean) / noise_std
- valid = (tmp < b) & (tmp > a)
- ind = valid.max(-1, keepdim=True)[1]
- noises = tmp.gather(-1, ind).squeeze(-1).mul_(noise_std).add_(noise_mean)
- # noises = truncated_normal(W, mean=noise_mean, std=noise_std, a=trunc_range[0], b=trunc_range[1])
- else:
- assert 0, logging.error(
- f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}"
- )
- return noises
-
-
-def gen_gaussian_filter2d_cpu(size: int = 3, std: float = 0.286) -> np.ndarray:
- assert (
- size % 2 == 1
- ), f"Gaussian filter can only be odd size, but size={size} is given."
- ax = np.linspace(-(size - 1) / 2.0, (size - 1) / 2.0, size)
- xx, yy = np.meshgrid(ax, ax)
- kernel = np.exp(-0.5 / np.square(std) * (np.square(xx) + np.square(yy)))
- kernel = kernel / np.sum(kernel)
- kernel[size // 2, size // 2] = 1
- return kernel
-
-
-def gen_gaussian_filter2d(
- size: int = 3,
- std: float = 0.286,
- center_one: bool = True,
- device: Device = torch.device("cuda"),
-) -> Tensor:
- assert (
- size % 2 == 1
- ), f"Gaussian filter can only be odd size, but size={size} is given."
- if std > 1e-8:
- ax = torch.linspace(
- -(size - 1) / 2.0,
- (size - 1) / 2.0,
- size,
- dtype=torch.float32,
- device=device,
- )
- xx, yy = torch.meshgrid(ax, ax)
- kernel = torch.exp(-0.5 / (std**2) * (xx.square() + yy.square()))
- kernel = kernel.div_(kernel.sum())
- if center_one:
- kernel[size // 2, size // 2] = 1
- else:
- kernel = torch.zeros(size, size, dtype=torch.float32, device=device)
- kernel[size // 2, size // 2] = 1
-
- return kernel
-
-
-def add_gaussian_noise(
- W: Union[Tensor, np.ndarray],
- noise_mean: float = 0,
- noise_std: float = 0.002,
- trunc_range: Tuple = (),
- random_state: Optional[int] = None,
-) -> Union[Tensor, np.ndarray]:
- noises = gen_gaussian_noise(
- W,
- noise_mean=noise_mean,
- noise_std=noise_std,
- trunc_range=trunc_range,
- random_state=random_state,
- )
- output = W + noises
- return output
-
-
-def add_gaussian_noise_(
- W: Union[Tensor, np.ndarray],
- noise_mean: float = 0,
- noise_std: float = 0.002,
- trunc_range: Tuple = (),
- random_state: Optional[int] = None,
-) -> Union[Tensor, np.ndarray]:
- noises = gen_gaussian_noise(
- W,
- noise_mean=noise_mean,
- noise_std=noise_std,
- trunc_range=trunc_range,
- random_state=random_state,
- )
- if isinstance(W, np.ndarray):
- W += noises
- elif isinstance(W, torch.Tensor):
- W.data += noises
- else:
- assert 0, logging.error(
- f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}"
- )
- return W
-
-
-def add_gaussian_noise_cpu(
- W: Union[Tensor, np.ndarray],
- noise_mean: float = 0,
- noise_std: float = 0.002,
- trunc_range: Tuple = (),
-) -> Union[Tensor, np.ndarray]:
- if isinstance(W, np.ndarray):
- W_numpy = W.copy().astype(np.float64)
- elif isinstance(W, torch.Tensor):
- W_numpy = W.detach().cpu().numpy().copy().astype(np.float64)
- else:
- assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor"
- if not trunc_range:
- noises = np.random.normal(noise_mean, noise_std, W_numpy.shape)
- else:
- a = (trunc_range[0] - noise_mean) / noise_std
- b = (trunc_range[1] - noise_mean) / noise_std
- noises = truncnorm.rvs(
- a, b, loc=noise_mean, scale=noise_std, size=W_numpy.shape, random_state=None
- )
- return W_numpy + noises
-
-
-def circulant_multiply(c: Tensor, x: Tensor) -> Tensor:
- """Multiply circulant matrix with first column c by x
- Parameters:
- c: (n, )
- x: (batch_size, n) or (n, )
- Return:
- prod: (batch_size, n) or (n, )
- """
- return torch.irfft(
- complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1],)
- )
-
-
-def calc_diagonal_hessian(weight_dict, loss, model):
- model.zero_grad()
- hessian_dict = {}
- for name, weight in weight_dict.items():
- first_gradient = grad(loss, weight, create_graph=True)[0]
- second_gradient = grad(first_gradient.sum(), weight, create_graph=True)[0]
- hessian_dict[name] = second_gradient.clone()
- model.zero_grad()
- return hessian_dict
-
-
-def calc_jacobian(
- weight_dict: Dict[str, Tensor], loss: Callable, model: nn.Module
-) -> Dict[str, Tensor]:
- model.zero_grad()
- jacobian_dict = {}
- for name, weight in weight_dict.items():
- first_gradient = grad(loss, weight, create_graph=True)[0]
- jacobian_dict[name] = first_gradient.clone()
- model.zero_grad()
- return jacobian_dict
-
-
@lru_cache(maxsize=4)
def _polynomial_order_base(order: int, device: Device) -> Tensor:
return torch.arange(order - 1, -1, -1, device=device)
@@ -797,277 +185,3 @@ def polynomial(x: Tensor | np.ndarray, coeff: Tensor | np.ndarray) -> Tensor:
return np.polynomial.polynomial.polyval(x, coeff[::-1])
else:
raise NotImplementedError
-
-
-def gaussian(x: Tensor, coeff: Tensor) -> Tensor:
- # coeff : [n, 3], includes a, b, c
- ## a * exp(-((x-b)/c)^2) + ...
- size = x.size()
- x = x.view(-1).unsqueeze(0)
- x = (
- (coeff[:, 0:1] * torch.exp(-((x - coeff[:, 1:2]) / coeff[:, 2:3]).square()))
- .sum(dim=0)
- .view(size)
- )
- return x
-
-
-def lowrank_decompose(
- x: Tensor,
- r: int,
- u_ortho: bool = False,
- out_u: Optional[Tensor] = None,
- out_v: Optional[Tensor] = None,
-) -> Tuple[Tensor, Tensor]:
- """low rank decomposition on x. x ~ uv.
-
- Args:
- x (Tensor): tensor to decomplse
- r (int): rank
- u_ortho (bool, optional): whether u is orthogonal matrix. Defaults to False.
- out_u (Optional[Tensor], optional): output buffer for u. Defaults to None.
- out_v (Optional[Tensor], optional): output buffer for v. Defaults to None.
-
- Returns:
- Tuple[Tensor, Tensor]: [description]
- """
- ### x [..., m, n]
- # r rank
- u, s, v = x.data.svd(some=True)
- v = v.transpose(-2, -1).contiguous()
- u = u[..., :, :r]
- s = s[..., :r]
- v = v[..., :r, :]
- if u_ortho == False:
- u.mul_(s.unsqueeze(-2))
- else:
- v.mul_(s.unsqueeze(-1))
- if out_u is not None:
- out_u.data.copy_(u)
- if out_v is not None:
- out_v.data.copy_(v)
- return u, v
-
-
-def get_conv2d_flops(
- input_shape: _size,
- conv_filter: _size,
- stride: _pair = (1, 1),
- padding: _pair = (1, 1),
-) -> float:
- # input_shape = (4, 3,300,300) # Format:(batch, channels, rows,cols)
- # conv_filter = (64,3,3,3) # Format: (num_filters, channels, rows, cols)
- # stride = (1, 1) in (height, width)
- # padding = (1, 1) in (height, width)
- if type(stride) not in {list, tuple}:
- stride = [stride, stride]
- if type(padding) not in {list, tuple}:
- padding = [padding, padding]
- n = conv_filter[1] * conv_filter[2] * conv_filter[3] # vector_length
- # general defination for number of flops (n: multiplications and n-1: additions)
- flops_per_instance = n + 1
-
- num_instances_per_filter = (
- (input_shape[2] - conv_filter[2] + 2 * padding[0]) / stride[0]
- ) + 1 # for rows
- # multiplying with cols
- num_instances_per_filter *= (
- (input_shape[3] - conv_filter[3] + 2 * padding[1]) / stride[1]
- ) + 1
-
- flops_per_filter = num_instances_per_filter * flops_per_instance
- # multiply with number of filters adn batch
- total_flops_per_layer = flops_per_filter * conv_filter[0] * input_shape[0]
- return total_flops_per_layer
-
-
-class Interp1d(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x, y, xnew, out=None):
- """
- Batched Linear 1D interpolation on the GPU for Pytorch.
- This function returns interpolated values of a set of 1-D functions at
- the desired query points `xnew`. Any point exceeds the border of [xmin, xmax]
- will be filled with 0 and no grad.
- This function is working similarly to Matlab™ or scipy functions with
- the `linear` interpolation mode on, except that it parallelises over
- any number of desired interpolation problems.
- The code will run on GPU if all the tensors provided are on a cuda
- device.
- https://github.com/aliutkus/torchinterp1d
-
- Parameters
- ----------
- x : (N, ) or (D, N) Pytorch Tensor
- A 1-D or 2-D tensor of real values.
- y : (N,) or (D, N) Pytorch Tensor
- A 1-D or 2-D tensor of real values. The length of `y` along its
- last dimension must be the same as that of `x`
- xnew : (P,) or (D, P) Pytorch Tensor
- A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
- _both_ `x` and `y` are 1-D. Otherwise, its length along the first
- dimension must be the same as that of whichever `x` and `y` is 2-D.
- out : Pytorch Tensor, same shape as `xnew`
- Tensor for the output. If None: allocated automatically.
-
- """
- # making the vectors at least 2D
- is_flat = {}
- require_grad = {}
- v = {}
- device = []
- eps = torch.finfo(y.dtype).eps
- for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
- assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D."
- if len(vec.shape) == 1:
- v[name] = vec[None, :]
- else:
- v[name] = vec
- is_flat[name] = v[name].shape[0] == 1
- require_grad[name] = vec.requires_grad
- device = list(set(device + [str(vec.device)]))
- assert len(device) == 1, "All parameters must be on the same device."
- device = device[0]
-
- # Checking for the dimensions
- assert v["x"].shape[1] == v["y"].shape[1] and (
- v["x"].shape[0] == v["y"].shape[0]
- or v["x"].shape[0] == 1
- or v["y"].shape[0] == 1
- ), (
- "x and y must have the same number of columns, and either "
- "the same number of row or one of them having only one "
- "row."
- )
-
- reshaped_xnew = False
- if (
- (v["x"].shape[0] == 1)
- and (v["y"].shape[0] == 1)
- and (v["xnew"].shape[0] > 1)
- ):
- # if there is only one row for both x and y, there is no need to
- # loop over the rows of xnew because they will all have to face the
- # same interpolation problem. We should just stack them together to
- # call interp1d and put them back in place afterwards.
- original_xnew_shape = v["xnew"].shape
- v["xnew"] = v["xnew"].contiguous().view(1, -1)
- reshaped_xnew = True
-
- # identify the dimensions of output and check if the one provided is ok
- D = max(v["x"].shape[0], v["xnew"].shape[0])
- shape_ynew = (D, v["xnew"].shape[-1])
- if out is not None:
- if out.numel() != shape_ynew[0] * shape_ynew[1]:
- # The output provided is of incorrect shape.
- # Going for a new one
- out = None
- else:
- ynew = out.reshape(shape_ynew)
- if out is None:
- ynew = torch.zeros(*shape_ynew, device=device)
-
- # moving everything to the desired device in case it was not there
- # already (not handling the case things do not fit entirely, user will
- # do it if required.)
- for name in v:
- v[name] = v[name].to(device)
-
- # calling searchsorted on the x values.
- ind = ynew.long()
-
- # expanding xnew to match the number of rows of x in case only one xnew is
- # provided
- if v["xnew"].shape[0] == 1:
- v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1)
-
- # the squeeze is because torch.searchsorted does accept either a nd with
- # matching shapes for x and xnew or a 1d vector for x. Here we would
- # have (1,len) for x sometimes
- torch.searchsorted(
- v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind
- )
-
- # the `-1` is because searchsorted looks for the index where the values
- # must be inserted to preserve order. And we want the index of the
- # preceeding value.
- ind -= 1
- # we clamp the index, because the number of intervals is x.shape-1,
- # and the left neighbour should hence be at most number of intervals
- # -1, i.e. number of columns in x -2
- ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1)
-
- # helper function to select stuff according to the found indices.
- def sel(name):
- if is_flat[name]:
- return v[name].contiguous().view(-1)[ind]
- return torch.gather(v[name], 1, ind)
-
- # activating gradient storing for everything now
- enable_grad = False
- saved_inputs = []
- for name in ["x", "y", "xnew"]:
- if require_grad[name]:
- enable_grad = True
- saved_inputs += [v[name]]
- else:
- saved_inputs += [
- None,
- ]
- # assuming x are sorted in the dimension 1, computing the slopes for
- # the segments
- is_flat["slopes"] = is_flat["x"]
- # now we have found the indices of the neighbors, we start building the
- # output. Hence, we start also activating gradient tracking
- with torch.enable_grad() if enable_grad else contextlib.suppress():
- v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / (
- eps + (v["x"][:, 1:] - v["x"][:, :-1])
- )
-
- # now build the linear interpolation
- ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x"))
-
- mask = (v["xnew"] > v["x"][:, -1:]) | (
- v["xnew"] < v["x"][:, :1]
- ) # exceed left/right border
- ynew = ynew.masked_fill(mask, 0)
-
- if reshaped_xnew:
- ynew = ynew.view(original_xnew_shape)
-
- ctx.save_for_backward(ynew, *saved_inputs)
- return ynew
-
- @staticmethod
- def backward(ctx, grad_out):
- inputs = ctx.saved_tensors[1:]
- gradients = torch.autograd.grad(
- ctx.saved_tensors[0],
- [i for i in inputs if i is not None],
- grad_out,
- retain_graph=True,
- )
- result = [
- None,
- ] * 5
- pos = 0
- for index in range(len(inputs)):
- if inputs[index] is not None:
- result[index] = gradients[pos]
- pos += 1
- return (*result,)
-
-
-def interp1d(x: Tensor, y: Tensor, xnew: Tensor, out: Tensor | None = None) -> Tensor:
- """numpy.interp for pytorch. Only 1D
-
- Args:
- x (Tensor): input vector x coordinates
- y (Tensor): input vector y coordinates
- xnew (Tensor): new x coordinates to be interpolated
- out (Tensor, optional): output tensor. Defaults to None.
-
- Returns:
- Tensor: interpolated y coordinates
- """
- return Interp1d.apply(x, y, xnew, out)
diff --git a/src/chop/nn/optical/utils/initializer.py b/src/chop/nn/optical/utils/initializer.py
index f0591c33a..19c97b6dd 100644
--- a/src/chop/nn/optical/utils/initializer.py
+++ b/src/chop/nn/optical/utils/initializer.py
@@ -10,68 +10,15 @@
import torch
__all__ = [
- "quant_kaiming_uniform",
- "quant_kaiming_uniform_",
- "truncated_normal",
- "truncated_normal_",
+ # "quant_kaiming_uniform",
+ # "quant_kaiming_uniform_",
+ # "truncated_normal",
+ # "truncated_normal_",
"morr_uniform_",
- "morr_uniform",
+ # "morr_uniform",
]
-def quant_kaiming_uniform(w, nbit, beta=1.5):
- """https://arxiv.org/pdf/1802.04680.pdf"""
- if w.dim() > 2:
- receptive_field = w[0, 0, ...].numel()
- else:
- receptive_field = 1
- fan_in = w.size(1) * receptive_field
- sigma = 2 ** (1 - nbit)
- L_min = beta * sigma
- L = max(np.sqrt(6 / fan_in), L_min)
- return w.clone().uniform_(-L, L)
-
-
-def quant_kaiming_uniform_(w, nbit, beta=1.5):
- """https://arxiv.org/pdf/1802.04680.pdf"""
- if w.dim() > 2:
- receptive_field = w[0, 0, ...].numel()
- else:
- receptive_field = 1
- fan_in = w.size(1) * receptive_field
- sigma = 2 ** (1 - nbit)
- L = np.sqrt(6 / fan_in)
- L_min = beta * sigma
- scale = 2 ** round(np.log2(L_min / L))
- scale = max(scale, 1.0)
- L = max(L, L_min)
-
- return torch.nn.init.uniform_(w, -L, L), scale
-
-
-def truncated_normal(tensor, mean=0, std=1, a=-2, b=2):
- size = tensor.shape
- tmp = tensor.new_empty(size + (4,)).normal_()
- a = (a - mean) / std
- b = (b - mean) / std
- valid = (tmp < b) & (tmp > a)
- ind = valid.max(-1, keepdim=True)[1]
- output = tmp.gather(-1, ind).squeeze(-1).mul_(std).add_(mean)
- return output
-
-
-def truncated_normal_(tensor, mean=0, std=1, a=-2, b=2):
- size = tensor.shape
- tmp = tensor.new_empty(size + (4,)).normal_()
- a = (a - mean) / std
- b = (b - mean) / std
- valid = (tmp < b) & (tmp > a)
- ind = valid.max(-1, keepdim=True)[1]
- tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
- tensor.data.mul_(std).add_(mean)
- return tensor
-
-
def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
"""
description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]. We only consider how n_op influence one MORR's output. How to balance vector length should be considered in learnable balancing factor\\
@@ -84,7 +31,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
"""
morr_fwhm = (
-4
- * np.pi**2
+ * np.pi ** 2
* MORRConfig.radius
* MORRConfig.effective_index
* (
@@ -111,43 +58,3 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
return torch.nn.init.uniform_(tensor, 0, L)
else:
return torch.nn.init.uniform_(tensor, -L / 2, L / 2)
-
-
-def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1):
- """
- description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]\\
- @tensor {torch.Tensor} weight tensor/parameter\\
- @MORRConfig {Config} MORR configuration defined in the onnlib/model/layer/device/mrr\\
- @n_op {int scalar} Number of operands on an MORR\\
- @biased {bool} biased=True, weight in [0, L]; otherwise in [-L/2, L/2].\\
- @gain {float} Gain due to activation. ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2\\
- return {}
- """
- morr_fwhm = (
- -4
- * np.pi**2
- * MORRConfig.radius
- * MORRConfig.effective_index
- * (
- 1 / MORRConfig.resonance_wavelength
- - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2)
- )
- )
- ### first we need to calculate the information gain of an MORR, estimated by linear estimation at 0 and FWHM
- # t1 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([0]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True)
- # t2 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([morr_fwhm]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True)
- # g = (t2 - t1) / morr_fwhm
-
- # var_phi = 1 ## assume the input is normalized to have variance 1
- # var_w = 1/(3/2*g**4*n_op*var_phi)
-
- # ### calculate range of uniform distribution U(-L,L)
- # L = (3 * var_w)**0.5
- # return tensor.clone().uniform_(-L, L)
-
- ## approximation by assuming 4*std(phi)= 3*FWHM, E[x]=0, D[x]=1, W ~ U[0, L]
- L = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain
- if biased:
- return tensor.clone().uniform_(0, L)
- else:
- return tensor.clone().uniform_(-L / 2, L / 2)
diff --git a/src/chop/nn/optical/utils/mrr.py b/src/chop/nn/optical/utils/mrr.py
index 53fd56fd9..826e8ad64 100644
--- a/src/chop/nn/optical/utils/mrr.py
+++ b/src/chop/nn/optical/utils/mrr.py
@@ -71,42 +71,3 @@ class MORRConfig_10um_MQ:
resonance_wavelength = 1538.739 # nm
bandwidth = 1.6702 # nm
quality_factor = 1213.047
-
-
-def plot_curve(config):
- import matplotlib.pyplot as plt
-
- lambda0 = config.resonance_wavelength
- lambda_vec = np.linspace(1546, lambda0, 9400)
- aa = config.attenuation_factor # attenuation a
-
- t = config.coupling_factor # self-coupling
- # r = np.sqrt(1 - t**2) # cross coupling coef
-
- R = config.radius # radius
- neff = config.effective_index # refractive index
- phi = -4 * np.pi * np.pi * R * neff / lambda_vec
-
- phase_shift = np.linspace(phi[0], phi[-1], len(phi))
- phase_shift = phase_shift - np.min(phase_shift)
- print(phase_shift)
- tr = (t - aa * np.exp(1j * phi)) / (1 - t * aa * np.exp(1j * phi))
- energy = abs(tr) ** 2
- print(energy)
- plt.figure()
- plt.plot(lambda_vec, energy)
- plt.savefig("mrr_tr_wl.png")
- plt.figure()
- plt.plot(phase_shift, energy)
- plt.savefig("mrr_tr_ps.png")
-
- for i, e in enumerate(energy[:-1]):
- if energy[i] >= 0.5 and energy[i + 1] <= 0.5:
- print(i, i + 1)
- print(energy[i], energy[i + 1])
- print(lambda_vec[i], lambda_vec[i + 1])
- exit(1)
-
-
-if __name__ == "__main__":
- plot_curve(MRRConfig_5um_MQ)
diff --git a/src/chop/nn/optical/utils/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py
index e98dc4fd9..6c397b3f5 100644
--- a/src/chop/nn/optical/utils/mrr_op.py
+++ b/src/chop/nn/optical/utils/mrr_op.py
@@ -20,107 +20,25 @@
__all__ = [
- "mrr_voltage_to_delta_lambda",
- "mrr_tr_to_roundtrip_phase",
- "mrr_roundtrip_phase_to_tr",
+ # "mrr_voltage_to_delta_lambda",
+ # "mrr_tr_to_roundtrip_phase",
+ # "mrr_roundtrip_phase_to_tr",
"mrr_roundtrip_phase_to_tr_fused",
- "mrr_roundtrip_phase_to_tr_grad_fused",
+ # "mrr_roundtrip_phase_to_tr_grad_fused",
"mrr_roundtrip_phase_to_tr_func",
- "mrr_roundtrip_phase_to_out_phase",
- "mrr_tr_to_out_phase",
- "mrr_roundtrip_phase_to_tr_phase",
- "mrr_roundtrip_phase_to_tr_phase_fused",
- "mrr_modulator",
- "mrr_filter",
- "morr_filter",
- "mrr_fwhm_to_ng",
- "mrr_ng_to_fsr",
- "mrr_finesse",
+ # "mrr_roundtrip_phase_to_out_phase",
+ # "mrr_tr_to_out_phase",
+ # "mrr_roundtrip_phase_to_tr_phase",
+ # "mrr_roundtrip_phase_to_tr_phase_fused",
+ # "mrr_modulator",
+ # "mrr_filter",
+ # "morr_filter",
+ # "mrr_fwhm_to_ng",
+ # "mrr_ng_to_fsr",
+ # "mrr_finesse",
]
-def mrr_voltage_to_delta_lambda(v, alpha, k, gamma, n_g, lambda_0):
- """
- description: micro-ring resonator (MRR) wavelength modulation, \delta\lambda=\delta\n_eff\times\lambda/n_g, \deltan_eff=\gamma k \delta T=\gamma k \alpha v^2\\
- v {torch.Tensor ro np.ndarray} voltage \\
- alpha {scalar} voltage square to temperature change coefficient \\
- k {scalar} parameter \\
- gamma {scalar} power to phase shift coefficient \\
- n_g {scalar} group index, typically from 4 to 4.5\\
- lambda_0 {torch.Tensor or np.ndarray} central wavelength\\
- return delta_lambda {torch.Tensor or np.ndarray} resonance wavelength drift
- """
- delta_neff = gamma * k * alpha * v * v
- delta_lambda = delta_neff * lambda_0 / n_g
- return delta_lambda
-
-
-def mrr_tr_to_roundtrip_phase(t, a, r):
- """
- description: field transmission to round trip phase shift
- t {torch.Tensor or np.ndarray} field transmission from [0,1] \\
- a {scalar} attenuation coefficient\\
- r {scalar} coupling coefficient\\
- return phi {torch.Tensor or np.ndarray} roune trip phase shift (abs(phase lag))[0, pi], center is 0. phase lag is negative, the sign is moved to the equation
- """
- # the curve has multiple valleies, thus given a t, there is infinite number of rt_phi, we only want [-pi, 0], thus the abs(phase lag) is in [0, pi], acos returns [0, pi], which matches our assumption
- assert 0 <= a <= 1, logging.error(f"Expect a from [0,1] but got {a}")
- assert 0 <= r <= 1, logging.error(f"Expect r from [0,1] but got {r}")
- # given a and r, the curve is fixed, the max and min may not be 1 and 0
- cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp(
- 0, 1
- )
-
- if isinstance(cos_phi, torch.Tensor):
- return cos_phi.acos(), cos_phi
- elif isinstance(cos_phi, np.ndarray):
- return np.arccos(cos_phi), cos_phi
- else:
- raise NotImplementedError
-
-
-def mrr_roundtrip_phase_to_tr(
- rt_phi, a: float = 0.8, r: float = 0.9, poly_coeff=None, intensity: bool = False
-):
- """
- description: round trip phase shift to field transmission
- rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\
- a {scalar} attenuation coefficient\\
- r {scalar} self-coupling coefficient\\
- poly_coeff {Callable} polynomial coefficients of intensity tranmission-roundtrip phase curve. Default set to None. None for slow computation\\
- intensity {bool scalar} whether output intensity tranmission or field transmission
- return t {torch.Tensor or np.ndarray} mrr through port field/intensity transmission
- """
- if poly_coeff is not None:
- # fast mode, use polynomial to predict the intensity transmission curve
- # if using polynomial, we want fast intensity transmission estimation, instead of field
- # if using coherent light, we will use complex output, we won't use polynomial fit
- t = polynomial(rt_phi.clamp(0, np.pi), poly_coeff).clamp(1e-8, 1)
- if not intensity:
- # avoid NAN
- t = (t + 1e-12).sqrt()
- else:
- # use slow but accurate mode from theoretical equation
- # create e^(-j phi) first
- # with torch.autograd.profiler.profile(use_cuda=True) as prof:
- # ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi)) ## this sign is from the negativity of phase lag
- # ### Jiaqi: Since PyTorch 1.7 rsub is not supported for autograd of complex, so have to use negate and add
- # a_ephi = -a * ephi
- # t = torch.view_as_real((r + a_ephi)/(1 + r * a_ephi))
-
- # if(intensity):
- # t = get_complex_energy(t)
- # else:
- # t = get_complex_magnitude(t)
- # print(prof.key_averages(group_by_stack_n=5).table(sort_by='cuda_time', row_limit=5))
- ra_cosphi_by_n2 = -2 * r * a * rt_phi.cos()
- t = (a * a + r * r + ra_cosphi_by_n2) / (1 + r * r * a * a + ra_cosphi_by_n2)
- if not intensity:
- # as long as a is not equal to r, t cannot be 0.
- t = t.sqrt()
- return t
-
-
@torch.jit.script
def mrr_roundtrip_phase_to_tr_fused(
rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False
@@ -154,37 +72,13 @@ def mrr_roundtrip_phase_to_tr_fused(
return t
-@torch.jit.script
-def mrr_roundtrip_phase_to_tr_grad_fused(
- rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False
-):
- """
- description: round trip phase shift to the gradient of field transmission
- rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\
- a {scalar} attenuation coefficient\\
- r {scalar} self-coupling coefficient\\
- intensity {bool scalar} whether output intensity tranmission or field transmission\\
- return g {torch.Tensor or np.ndarray} the gradient of mrr through port field/intensity transmission
- """
- if not intensity:
- g = (a * r * (a**2 - 1) * (r**2 - 1) * rt_phi.sin()) / (
- (a**2 + r**2 - 2 * a * r * rt_phi.cos()) ** (1 / 2)
- * (a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5
- )
- else:
- g = ((a**2 - 1) * (r**2 - 1) * 2 * a * r * rt_phi.sin()) / (
- a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()
- ) ** 2
- return g
-
-
def mrr_roundtrip_phase_to_tr_func(
a: float = 0.8, r: float = 0.9, intensity: bool = False
):
c1 = -2 * a * r
c2 = a * a + r * r
c3 = 1 + r * r * a * a - a * a - r * r
- c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r
+ c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r
class MRRRoundTripPhaseToTrFunction(torch.autograd.Function):
@staticmethod
@@ -217,208 +111,3 @@ def backward(ctx, grad_output):
return grad_input
return MRRRoundTripPhaseToTrFunction.apply
-
-
-def mrr_roundtrip_phase_to_out_phase(rt_phi, a, r):
- """
- description: from round trip phase to output phase response \\
- rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\
- a {scalar} attenuation coefficient\\
- r {scalar} coupling coefficient\\
- return phase {torch.Tensor or np.ndarray} output phase response
- """
- if isinstance(rt_phi, torch.Tensor):
- arctan = torch.atan2
- sin = torch.sin
- cos = torch.cos
- elif isinstance(rt_phi, np.ndarray):
- arctan = np.arctan2
- sin = np.sin
- cos = np.cos
- else:
- raise NotImplementedError
- sin_rt_phi = sin(rt_phi)
- cos_rt_phi = cos(rt_phi)
- # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi))
- phi = (
- np.pi
- - rt_phi
- - arctan(r * sin_rt_phi, a - r * cos_rt_phi)
- - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi)
- )
- return phi
-
-
-def mrr_tr_to_out_phase(t, a, r, onesided=True):
- """
- description: field transmission to round trip phase shift
- t {torch.Tensor or np.ndarray} field transmission from [0,1] \\
- a {scalar} attenuation coefficient\\
- r {scalar} coupling coefficient\\
- onesided {bool scalar} True if only use half of the curve, output phase range [0, pi]
- return phi {torch.Tensor or np.ndarray} roune trip phase shift
- """
- rt_phi, cos_rt_phi = mrr_tr_to_roundtrip_phase(t, a, r)
- if isinstance(t, torch.Tensor):
- arctan = torch.atan2
- sin = torch.sin
- elif isinstance(t, np.ndarray):
- arctan = np.arctan2
- sin = np.sin
- else:
- raise NotImplementedError
- sin_rt_phi = sin(rt_phi)
- # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi))
- phi = (
- np.pi
- - rt_phi
- - arctan(r * sin_rt_phi, a - r * cos_rt_phi)
- - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi)
- )
- if onesided:
- pass
- return phi
-
-
-def mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r):
- """
- description: from round trip phase to output transmission with phase response \\
- rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\
- a {scalar} attenuation coefficient\\
- r {scalar} coupling coefficient\\
- return output {torch.Tensor or np.ndarray} transmission with phase response
- """
- # e^(-j phi)
- ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi))
- a_ephi = -a * ephi
- output = torch.view_as_real((r + a_ephi) / (1 + r * a_ephi))
- return output
-
-
-@torch.jit.script
-def mrr_roundtrip_phase_to_tr_phase_fused(rt_phi, a: float, r: float):
- """
- description: from round trip phase to output transmission with phase response \\
- rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\
- a {scalar} attenuation coefficient\\
- r {scalar} coupling coefficient\\
- return output {torch.Tensor or np.ndarray} transmission with phase response
- """
- # e^(-j phi)
- rt_phi = -rt_phi
- rt_phi = torch.complex(rt_phi.cos(), rt_phi.sin())
- rt_phi = -a * rt_phi
- output = torch.view_as_real((r + rt_phi) / (1 + r * rt_phi))
- return output
-
-
-def mrr_modulator(t, a=0.9, r=0.8):
- """
- @description: all-pass MRR as a modulator. Map from the field intensity of through port transmission to coherent light with phase reponse\\
- @t {torch.Tensor or np.ndarray} field intensity modulation factor\\
- @a {float} attenuation factor from [0,1]. Default: 0.9\\
- @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\
- @return: complexed light signal
- """
- phase = mrr_tr_to_out_phase(t, a, r)
- cos_phase, sin_phase = torch.cos(phase), torch.sin(phase)
- output_real = t * cos_phase
- output_imag = t * sin_phase
- output = torch.stack([output_real, output_imag], dim=-1)
- return output
-
-
-def mrr_filter(x, t, a=0.9, r=0.8):
- """
- @description: all-pass MRR as a filter. Map from the input complex light signal to output signal with through port transmission\\
- @x {torch.Tensor or np.ndarray} complexed input light signal\\
- @t {torch.Tensor or np.ndarray} field intensity modulation factor\\
- @a {float} attenuation factor from [0,1]. Default: 0.9\\
- @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\
- @return: complexed light signal
- """
- phase = mrr_tr_to_out_phase(t, a, r)
- cos_phase, sin_phase = torch.cos(phase), torch.sin(phase)
- phase_shift = torch.complex(cos_phase, sin_phase)
- out = t * complex_mult(x, phase_shift)
- return out
-
-
-def morr_filter(
- rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False
-):
- """
- description: from round trip phase shift to output signal \\
- rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\
- tr_poly_coeff {Callable} polynomial coefficients of tranmission-roundtrip phase curve. Default set to None. None for slow computation\\
- a {float} attenuation factor from [0,1]. Default: 0.9\\
- r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\
- x {torch.Tensor or np.ndarray, Optional} input complex light signal {None, real tensor or complex tensor}. Default set to None\\
- coherent {bool scalar, Optional} coherent output or not. Default set to False\\
- intensity {bool scalar, Optional} whether use intensity or field transmission. Default set to False\\
- return output {torch.Tensor or np.ndarray} real tensor if incoherent, complex tensor if coherent
- """
- if not coherent:
- if x is None:
- # unit laser input with incoherent light, 1e^j0
- t = mrr_roundtrip_phase_to_tr(
- rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity
- )
- return t
- else:
- # incoherent light with non-unit input, input must be real number
- t = mrr_roundtrip_phase_to_tr(
- rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity
- )
- return x * t
- else:
- if x is None:
- # coherent light with unit laser, 1e^j0, treat morr as a mrr modulator
- phase = polar_to_complex(
- mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r)
- )
- return phase
- else:
- # coherent light with complex input
- return complex_mult(mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r), x)
-
-
-def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm):
- """
- description: from full-width half maximum (FWHM) and resonance wavelength to group index n_g (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(7))\\
- a {float} Attention coefficient\\
- r {float} Self-coupling coefficient\\
- radius {float} Radius of the MRR (unit: nm)\\
- lambda0 {float} Resonance wavelength (unit: nm)\\
- fwhm {float} bandwidth or full width half maximum (unit: nm)\\
- return n_g {float} Group index of the MRR
- """
- n_g = (
- (1 - r * a) * lambda0**2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm)
- )
- return n_g
-
-
-def mrr_ng_to_fsr(lambda0, n_g, radius):
- """
- description: Calculate the free-spectral range (FSR) based on the central resonance wavelength, group index and MRR radius.
- (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(9))\\
- lambda0 {float} Resonance wavelength (unit: nm)\\
- n_g {float} Group index\\
- radius {float} Radius of the MRR (unit: nm)\\
- return fsr {float} Free-spectral range
- """
- fsr = lambda0**2 / (n_g * 2 * np.pi * radius)
- return fsr
-
-
-def mrr_finesse(a, r):
- """
- description: Calculate the finesse of the MRR, i.e., finesse=FSR/FWHM=pi*sqrt(ra)/(1-ra) (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(21))\\
- a {float} Attention coefficient\\
- r {float} Self-coupling coefficient\\
- return finesse {float} Finesse of the MRR
- """
- ra = r * a
- finesse = np.pi * ra**0.5 / (1 - ra)
- return finesse
diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py
index f60e8b240..828da142d 100644
--- a/src/chop/nn/optical/utils/quantize.py
+++ b/src/chop/nn/optical/utils/quantize.py
@@ -12,33 +12,17 @@
__all__ = [
- "uniform_quantize_cpu",
- "pact_quantize",
- "PACT_Act",
- "uniform_quantize",
- "uniform_quantize_new",
- "ewgs_quantize",
+ # "uniform_quantize_cpu",
+ # "pact_quantize",
+ # "PACT_Act",
+ # "uniform_quantize",
+ # "uniform_quantize_new",
+ # "ewgs_quantize",
"input_quantize_fn",
"weight_quantize_fn",
]
-class uniform_quantize_cpu(object):
- def __init__(self, bits):
- super(uniform_quantize_cpu).__init__()
- self.bits = bits
-
- def __call__(self, input):
- if self.bits == 32:
- out = input
- elif self.bits == 1:
- out = np.sign(input)
- else:
- n = float(2**self.bits - 1)
- out = np.round(input * n) / n
- return out
-
-
def uniform_quantize(k, gradient_clip=False):
class qfn(torch.autograd.Function):
@staticmethod
@@ -48,7 +32,7 @@ def forward(ctx, input):
elif k == 1:
out = torch.sign(input)
else:
- n = float(2**k - 1)
+ n = float(2 ** k - 1)
out = torch.round(input * n) / n
return out
@@ -79,7 +63,7 @@ def forward(ctx, input, scale, zero_point):
elif k == 1:
out = torch.sign(input)
else:
- n = float(2**k - 1)
+ n = float(2 ** k - 1)
# out = torch.round(input * n) / n
# out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale
out = (
@@ -102,38 +86,6 @@ def backward(ctx, grad_output):
return qfn.apply
-def ewgs_quantize(num_levels, gradient_clip=False, scaling_factor: float = 1e-3):
- class EWGS_quantizer(torch.autograd.Function):
- """
- Network Quantization with Element-wise Gradient Scaling, CVPR 2021
- https://github.com/cvlab-yonsei/EWGS/blob/main/CIFAR10/custom_modules.py
- x_in: continuous inputs within the range of [0,1]
- num_levels: number of discrete levels
- scaling_factor: backward scaling factor, typically fixed to 1e-3
- x_out: discretized version of x_in within the range of [0,1]
- """
-
- @staticmethod
- def forward(ctx, input):
- out = input.mul(num_levels - 1).round_().mul_(1 / (num_levels - 1))
-
- ctx._scaling_factor = scaling_factor
- ctx.save_for_backward(input - out)
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- diff = ctx.saved_tensors[0]
- delta = ctx._scaling_factor
- scale = diff.mul_(grad_output.sign()).mul_(delta).add_(1)
- grad_input = grad_output * scale
- if gradient_clip:
- grad_input.clamp_(-1, 1)
- return grad_input
-
- return EWGS_quantizer.apply
-
-
class input_quantize_fn(torch.nn.Module):
def __init__(
self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0
@@ -181,7 +133,7 @@ def __init__(
qscheme=torch.per_tensor_affine,
reduce_range=False,
quant_min=0,
- quant_max=2**self.in_bit - 1,
+ quant_max=2 ** self.in_bit - 1,
).to(self.device)
else:
self.obs = None
@@ -428,201 +380,3 @@ def forward(self, x):
assert NotImplementedError
return weight_q
-
-
-# PACT activation: https://arxiv.org/pdf/1805.06085.pdf
-class PACT_QuantFunc(torch.autograd.Function):
- r"""PACT (PArametrized Clipping acTivation) quantization function for activations.
- Implements a :py:class:`torch.autograd.Function` for quantizing activations in :math:`Q` bits using the PACT strategy.
- In forward propagation, the function is defined as
-
- .. math::
- \mathbf{y} = f(\mathbf{x}) = 1/\varepsilon \cdot \left\lfloor\mathrm{clip}_{ [0,\alpha) } (\mathbf{x})\right\rfloor \cdot \varepsilon
-
- where :math:`\varepsilon` is the quantization precision:
-
- .. math::
- \varepsilon = \alpha / (2^Q - 1)
-
- In backward propagation, using the Straight-Through Estimator, the gradient of the function is defined as
-
- .. math::
- \mathbf{\nabla}_\mathbf{x} \mathcal{L} &\doteq \mathbf{\nabla}_\mathbf{y} \mathcal{L}
-
- It can be applied by using its static `.apply` method:
-
- :param input: the tensor containing :math:`x`, the activations to be quantized.
- :type input: `torch.Tensor`
- :param eps: the precomputed value of :math:`\varepsilon`.
- :type eps: `torch.Tensor` or float
- :param alpha: the value of :math:`\alpha`.
- :type alpha: `torch.Tensor` or float
- :param delta: constant to sum to `eps` for numerical stability (default unused, 0 ).
- :type delta: `torch.Tensor` or float
-
- :return: The quantized input activations tensor.
- :rtype: `torch.Tensor`
- """
-
- @staticmethod
- def forward(ctx, input, eps, alpha):
- where_input_clipped = (input < 0) | (input >= alpha)
- where_input_ltalpha = input < alpha
- ctx.save_for_backward(where_input_clipped, where_input_ltalpha)
- return ((input / (eps)).floor() * eps).clamp(0.0, alpha.data[0] - eps.data[0])
-
- @staticmethod
- def backward(ctx, grad_output):
- # see Hubara et al., Section 2.3
- where_input_clipped, where_input_ltalpha = ctx.saved_tensors
- # zero = torch.zeros(1, device=where_input_nonclipped.device)
- grad_input = grad_output.masked_fill(where_input_clipped, 0)
- # grad_input = torch.where(where_input_nonclipped, grad_output, zero)
- grad_alpha = grad_output.masked_fill(where_input_ltalpha, 0).sum().expand(1)
- # grad_alpha = torch.where(where_input_gtalpha, grad_output, zero).sum().expand(1)
- return grad_input, None, grad_alpha
-
-
-pact_quantize = PACT_QuantFunc.apply
-
-
-class PACT_Act(torch.nn.Module):
- r"""PACT (PArametrized Clipping acTivation) activation.
- Implements a :py:class:`torch.nn.Module` to implement PACT-style activations. It is meant to replace :py:class:`torch.nn.ReLU`, :py:class:`torch.nn.ReLU6` and
- similar activations in a PACT-quantized network.
- This layer can also operate in a special mode, defined by the `statistics_only` member, in which the layer runs in
- forward-prop without quantization, collecting statistics on the activations that can then be
- used to reset the value of :math:`\alpha`.
- In this mode, the layer collects:
- - tensor-wise maximum value ever seen
- - running average with momentum 0.9
- - running variance with momentum 0.9
- """
-
- def __init__(
- self,
- precision=None,
- alpha=1.0,
- backprop_alpha=True,
- statistics_only=False,
- leaky=None,
- device=torch.device("cuda"),
- ):
- r"""Constructor. Initializes a :py:class:`torch.nn.Parameter` for :math:`\alpha` and sets
- up the initial value of the `statistics_only` member.
- :param precision: instance defining the current quantization level (default `None`).
- :type precision: :py:class:`nemo.precision.Precision`
- :param alpha: the value of :math:`\alpha`.
- :type alpha: `torch.Tensor` or float
- :param backprop_alpha: default `True`; if `False`, do not update the value of `\alpha` with backpropagation.
- :type backprop_alpha: bool
- :param statistics_only: initialization value of `statistics_only` member.
- :type statistics_only: bool
- """
-
- super(PACT_Act, self).__init__()
- self.precision = precision
- self.device = device
- self.alpha = torch.nn.Parameter(
- torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha
- )
- self.alpha_p = alpha
- self.statistics_only = statistics_only
- self.deployment = False
- self.eps_in = None
- self.leaky = leaky
- # self.requantization_factor = requantization_factor
-
- # these are only used to gather statistics
- self.max = torch.nn.Parameter(
- torch.zeros_like(self.alpha.data).to(device), requires_grad=False
- )
- self.min = torch.nn.Parameter(
- torch.zeros_like(self.alpha.data).to(device), requires_grad=False
- )
- self.running_mean = torch.nn.Parameter(
- torch.zeros_like(self.alpha.data).to(device), requires_grad=False
- )
- self.running_var = torch.nn.Parameter(
- torch.ones_like(self.alpha.data).to(device), requires_grad=False
- )
-
- self.precise = False
-
- def set_static_precision(self, limit_at_32_bits=True, **kwargs):
- r"""Sets static parameters used only for deployment."""
- # item() --> conversion to float
- # apparently causes a slight, but not invisibile, numerical divergence
- # between FQ and QD stages
- self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1)
- self.alpha_static = self.alpha.clone().detach()
- # D is selected as a power-of-two
- D = 2.0 ** torch.ceil(
- torch.log2(self.requantization_factor * self.eps_static / self.eps_in)
- )
- if not limit_at_32_bits:
- self.D = D
- else:
- self.D = min(D, 2.0 ** (32 - 1 - (self.precision)))
-
- def get_output_eps(self, eps_in):
- r"""Get the output quantum (:math:`\varepsilon`) given the input one.
- :param eps_in: input quantum :math:`\varepsilon_{in}`.
- :type eps_in: :py:class:`torch.Tensor`
- :return: output quantum :math:`\varepsilon_{out}`.
- :rtype: :py:class:`torch.Tensor`
- """
-
- return self.alpha / (2.0 ** (self.precision) - 1)
-
- def reset_alpha(self, use_max=True, nb_std=5.0):
- r"""Reset the value of :math:`\alpha`. If `use_max` is `True`, then the highest tensor-wise value collected
- in the statistics collection phase is used. If `False`, the collected standard deviation multiplied by
- `nb_std` is used as a parameter
- :param use_max: if True, use the tensor-wise maximum value collected in the statistics run as new :math:`\alpha` (default True).
- :type use_max: bool
- :param nb_std: number of standard deviations to be used to initialize :math:`\alpha` if `use_max` is False.
- :type nb_std: float
- """
-
- if use_max:
- self.alpha.data[0] = self.max.item()
- else:
- self.alpha.data[0] = nb_std * torch.sqrt(self.running_var).item()
-
- def get_statistics(self):
- r"""Returns the statistics collected up to now.
-
- :return: The collected statistics (maximum, running average, running variance).
- :rtype: tuple of floats
- """
- return self.max.item(), self.running_mean.item(), self.running_var.item()
-
- def forward(self, x):
- r"""Forward-prop function for PACT-quantized activations.
-
- See :py:class:`nemo.quant.pact_quant.PACT_QuantFunc` for details on the normal operation performed by this layer.
- In statistics mode, it uses a normal ReLU and collects statistics in the background.
- :param x: input activations tensor.
- :type x: :py:class:`torch.Tensor`
-
- :return: output activations tensor.
- :rtype: :py:class:`torch.Tensor`
- """
-
- if self.statistics_only:
- if self.leaky is None:
- x = torch.nn.functional.relu(x)
- else:
- x = torch.nn.functional.leaky_relu(x, self.leaky)
- with torch.no_grad():
- self.max[:] = max(self.max.item(), x.max())
- self.min[:] = min(self.min.item(), x.min())
- self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean()
- self.running_var[:] = (
- 0.9 * self.running_var.item() + 0.1 * x.std() * x.std()
- )
- return x
- else:
- eps = self.alpha / (2.0 ** (self.precision) - 1)
- return pact_quantize(x, eps, self.alpha + eps)
diff --git a/src/chop/nn/quantized/functional/gelu.py b/src/chop/nn/quantized/functional/gelu.py
index cee5e3317..225ff70ce 100644
--- a/src/chop/nn/quantized/functional/gelu.py
+++ b/src/chop/nn/quantized/functional/gelu.py
@@ -107,9 +107,7 @@ def gelu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.gelu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py
index 5fda700de..f518f4968 100644
--- a/src/chop/nn/quantized/functional/linear.py
+++ b/src/chop/nn/quantized/functional/linear.py
@@ -72,10 +72,7 @@ def linearInteger(
def linearMinifloatDenorm(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -122,10 +119,7 @@ def linearMinifloatDenorm(
def linearMinifloatIEEE(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -172,10 +166,7 @@ def linearMinifloatIEEE(
def linearMinifloatIEEE(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -222,10 +213,7 @@ def linearMinifloatIEEE(
def linearLog(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_width, w_exponent_bias = (
config["weight_width"],
@@ -240,23 +228,11 @@ def linearLog(
config["bias_exponent_bias"],
)
- w_quantizer = partial(
- log_quantizer,
- width=w_width,
- exponent_bias=w_exponent_bias,
- )
+ w_quantizer = partial(log_quantizer, width=w_width, exponent_bias=w_exponent_bias,)
- x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
- )
+ x_quantizer = partial(log_quantizer, width=x_width, exponent_bias=x_exponent_bias,)
- b_quantizer = partial(
- log_quantizer,
- width=b_width,
- exponent_bias=b_exponent_bias,
- )
+ b_quantizer = partial(log_quantizer, width=b_width, exponent_bias=b_exponent_bias,)
x = x_quantizer(x)
weight = w_quantizer(weight)
@@ -266,10 +242,7 @@ def linearLog(
def linearBlockFP(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
# establish quantizers
w_width, w_exponent_width, w_exponent_bias, w_block_size = (
@@ -327,10 +300,7 @@ def linearBlockFP(
def linearBlockMinifloat(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
# establish quantizers
w_width, w_exponent_width, w_exponent_bias_width, w_block_size = (
@@ -388,10 +358,7 @@ def linearBlockMinifloat(
def linearBlockLog(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
# establish quantizers
w_width, w_exponent_bias_width, w_block_size = (
@@ -443,10 +410,7 @@ def linearBlockLog(
def linearBinary(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_stochastic = config["weight_stochastic"]
w_bipolar = config["weight_bipolar"]
@@ -462,10 +426,7 @@ def linearBinary(
def linearBinaryScaling(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
"""
Binary scaling variant of the linear transformation layer.
@@ -513,10 +474,7 @@ def linearBinaryScaling(
def linearTernary(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
w_scaling_factor = config["weight_scaling_factor"]
w_mean = get_stats(config, "weight_mean")
@@ -540,28 +498,19 @@ def linearTernary(
def linearBinaryResidualSign(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
raise NotImplementedError
def linearLUT(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
raise NotImplementedError
def linearLogicNets(
- x: Tensor,
- weight: Tensor,
- bias: Tensor = None,
- config: dict = None,
+ x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
):
raise NotImplementedError
diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py
index d06eb1ece..f6487dc8f 100644
--- a/src/chop/nn/quantized/functional/matmul.py
+++ b/src/chop/nn/quantized/functional/matmul.py
@@ -176,14 +176,10 @@ def generic_matmul_log(x, y, config, style="matmul"):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
y_quantizer = partial(
- log_quantizer,
- width=y_width,
- exponent_bias=y_exponent_bias,
+ log_quantizer, width=y_width, exponent_bias=y_exponent_bias,
)
x = x_quantizer(x)
y = y_quantizer(y)
diff --git a/src/chop/nn/quantized/functional/relu.py b/src/chop/nn/quantized/functional/relu.py
index 57daed04a..cb57d078e 100644
--- a/src/chop/nn/quantized/functional/relu.py
+++ b/src/chop/nn/quantized/functional/relu.py
@@ -107,9 +107,7 @@ def relu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.relu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/selu.py b/src/chop/nn/quantized/functional/selu.py
index 12956c392..b1edebdbf 100644
--- a/src/chop/nn/quantized/functional/selu.py
+++ b/src/chop/nn/quantized/functional/selu.py
@@ -107,9 +107,7 @@ def selu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.selu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/softplus.py b/src/chop/nn/quantized/functional/softplus.py
index a9bd7dafc..c873e7104 100644
--- a/src/chop/nn/quantized/functional/softplus.py
+++ b/src/chop/nn/quantized/functional/softplus.py
@@ -107,9 +107,7 @@ def softplus_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.softplus(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/softsign.py b/src/chop/nn/quantized/functional/softsign.py
index 3eaab47fa..c60f5f757 100644
--- a/src/chop/nn/quantized/functional/softsign.py
+++ b/src/chop/nn/quantized/functional/softsign.py
@@ -107,9 +107,7 @@ def softsign_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.softsign(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/tanh.py b/src/chop/nn/quantized/functional/tanh.py
index 7d3c67c31..8b1009ac0 100644
--- a/src/chop/nn/quantized/functional/tanh.py
+++ b/src/chop/nn/quantized/functional/tanh.py
@@ -107,9 +107,7 @@ def tanh_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
return F.tanh(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py
index 4219d6da9..6759fc67f 100644
--- a/src/chop/nn/quantized/modules/__init__.py
+++ b/src/chop/nn/quantized/modules/__init__.py
@@ -65,9 +65,7 @@
BatchNorm2dInteger,
BatchNorm2dBinary,
)
-from .layer_norm import (
- LayerNormInteger,
-)
+from .layer_norm import LayerNormInteger
from .group_norm import GroupNormInteger
from .instance_norm2d import InstanceNorm2dInteger
@@ -144,12 +142,8 @@
SoftplusBinary,
SoftplusTernary,
)
-from .batch_norm1d import (
- BatchNorm1dInteger,
-)
-from .gqa import (
- GroupedQueryAttentionInteger,
-)
+from .batch_norm1d import BatchNorm1dInteger
+from .gqa import GroupedQueryAttentionInteger
quantized_module_map = {
"conv1d_block_minifloat": Conv1dBlockMinifloat,
diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py
index 45819db75..315a74ec3 100644
--- a/src/chop/nn/quantized/modules/attention.py
+++ b/src/chop/nn/quantized/modules/attention.py
@@ -6,9 +6,7 @@
from transformers.models.bert.modeling_bert import BertSelfAttention
-from chop.nn.quantized.modules.linear import (
- LinearInteger,
-)
+from chop.nn.quantized.modules.linear import LinearInteger
from chop.nn.quantized.functional import fixed_softermax
from chop.nn.quantized.functional import matmul_integer
diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py
index 8f9ea5969..8176f4d52 100644
--- a/src/chop/nn/quantized/modules/attention_head.py
+++ b/src/chop/nn/quantized/modules/attention_head.py
@@ -6,9 +6,7 @@
from typing import Optional, Tuple
from functools import partial
-from chop.nn.quantized.functional.matmul import (
- generic_matmul_integer,
-)
+from chop.nn.quantized.functional.matmul import generic_matmul_integer
from chop.nn.quantizers.integer import integer_quantizer
@@ -65,10 +63,7 @@ class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase):
def __init__(self, config, q_config: dict = None) -> None:
super().__init__(config)
- self.query_quantizer = partial(
- integer_quantizer,
- **q_config,
- )
+ self.query_quantizer = partial(integer_quantizer, **q_config,)
self.key_quantizer = partial(integer_quantizer, **q_config)
self.value_quantizer = partial(integer_quantizer, **q_config)
diff --git a/src/chop/nn/quantized/modules/batch_norm1d.py b/src/chop/nn/quantized/modules/batch_norm1d.py
index b84c0d131..bafb96f9b 100644
--- a/src/chop/nn/quantized/modules/batch_norm1d.py
+++ b/src/chop/nn/quantized/modules/batch_norm1d.py
@@ -4,9 +4,7 @@
from torch import Tensor
from torch.nn import functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
class _BatchNorm1dBase(torch.nn.BatchNorm1d):
diff --git a/src/chop/nn/quantized/modules/conv1d.py b/src/chop/nn/quantized/modules/conv1d.py
index 67654d917..03662097a 100644
--- a/src/chop/nn/quantized/modules/conv1d.py
+++ b/src/chop/nn/quantized/modules/conv1d.py
@@ -274,21 +274,15 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer,
- width=w_width,
- exponent_bias=w_exponent_bias,
+ log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer,
- width=b_width,
- exponent_bias=b_exponent_bias,
+ log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/conv2d.py b/src/chop/nn/quantized/modules/conv2d.py
index cc8cd982a..3d297842e 100644
--- a/src/chop/nn/quantized/modules/conv2d.py
+++ b/src/chop/nn/quantized/modules/conv2d.py
@@ -363,21 +363,15 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer,
- width=w_width,
- exponent_bias=w_exponent_bias,
+ log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer,
- width=b_width,
- exponent_bias=b_exponent_bias,
+ log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
)
@@ -430,21 +424,15 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer,
- width=w_width,
- exponent_bias=w_exponent_bias,
+ log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer,
- width=b_width,
- exponent_bias=b_exponent_bias,
+ log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
)
@@ -1140,10 +1128,7 @@ def __init__(
)
self.unfold = torch.nn.Unfold(
- kernel_size=kernel_size,
- dilation=dilation,
- padding=padding,
- stride=stride,
+ kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride,
)
self.fold = torch.nn.Fold(
@@ -1243,11 +1228,7 @@ def forward(
expanded_input, targets, initalize
).squeeze() # [10, 589824]
output = output.view(
- batch_size,
- self.out_channels,
- self._out_dim(0),
- self._out_dim(1),
- -1,
+ batch_size, self.out_channels, self._out_dim(0), self._out_dim(1), -1,
).sum(
-1
) # [10, 256, 1, 1, 2304] -> [10, 256, 1, 1]
@@ -1430,10 +1411,10 @@ def forward(self, x: Tensor) -> Tensor:
return self.decode(self.lut_forward(x))
def encode(self, input: Tensor) -> Tensor:
- return input * 2**self.x_frac_width
+ return input * 2 ** self.x_frac_width
def decode(self, input: Tensor) -> Tensor:
- return input / 2**self.x_frac_width
+ return input / 2 ** self.x_frac_width
def math_forward(self, input: Tensor) -> Tensor:
return self.y_quantizer(
diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py
index 4e579efd1..ace6f5510 100644
--- a/src/chop/nn/quantized/modules/gelu.py
+++ b/src/chop/nn/quantized/modules/gelu.py
@@ -123,9 +123,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -143,9 +141,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/gqa.py b/src/chop/nn/quantized/modules/gqa.py
index 1445b05ec..6cac74dda 100644
--- a/src/chop/nn/quantized/modules/gqa.py
+++ b/src/chop/nn/quantized/modules/gqa.py
@@ -89,10 +89,7 @@ def __init__(
)
self.v_matmul_func = partial(
- matmul_integer,
- config=config,
- out_config=out_config,
- floor=floor,
+ matmul_integer, config=config, out_config=out_config, floor=floor,
)
o_projection_q_config = {
diff --git a/src/chop/nn/quantized/modules/group_norm.py b/src/chop/nn/quantized/modules/group_norm.py
index a90e5b651..25721aae9 100644
--- a/src/chop/nn/quantized/modules/group_norm.py
+++ b/src/chop/nn/quantized/modules/group_norm.py
@@ -7,9 +7,7 @@
from torch import Tensor
import torch.nn.functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
from mase_components.scalar_operators.fixed.test.isqrt_sw import isqrt_sw2
diff --git a/src/chop/nn/quantized/modules/instance_norm2d.py b/src/chop/nn/quantized/modules/instance_norm2d.py
index 0a7260443..d3946401f 100644
--- a/src/chop/nn/quantized/modules/instance_norm2d.py
+++ b/src/chop/nn/quantized/modules/instance_norm2d.py
@@ -4,9 +4,7 @@
from torch import Tensor
import torch.nn.functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
class _InstanceNorm2dBase(nn.InstanceNorm2d):
diff --git a/src/chop/nn/quantized/modules/layer_norm.py b/src/chop/nn/quantized/modules/layer_norm.py
index 2ca5c6068..0d7e4d413 100644
--- a/src/chop/nn/quantized/modules/layer_norm.py
+++ b/src/chop/nn/quantized/modules/layer_norm.py
@@ -4,9 +4,7 @@
from torch import Tensor
import torch.nn.functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
class _LayerNormBase(nn.LayerNorm):
diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py
index 5d8d389a5..b49c86c54 100644
--- a/src/chop/nn/quantized/modules/linear.py
+++ b/src/chop/nn/quantized/modules/linear.py
@@ -66,11 +66,7 @@ def __init__(
dtype=None,
) -> None:
super().__init__(
- in_features,
- out_features,
- bias,
- device,
- dtype,
+ in_features, out_features, bias, device, dtype,
)
self.bypass = False
self.pruning_masks = None
@@ -448,9 +444,7 @@ def forward(self, x: Tensor) -> Tensor:
if self.binary_training:
w = self.w_quantizer(self.weight)
return F.linear(
- x_expanded,
- w * self.gamma.abs() * self.pruning_masks,
- self.bias,
+ x_expanded, w * self.gamma.abs() * self.pruning_masks, self.bias,
)
else:
self.weigh = self.weight.data.clamp_(-1, 1)
@@ -560,9 +554,7 @@ def forward(
output = output.view(batch_size, -1)
assert output.shape[-1] == self.tables_count
output = output.view(
- batch_size,
- self.out_features,
- int(self.tables_count / self.out_features),
+ batch_size, self.out_features, int(self.tables_count / self.out_features),
)
output = output.sum(-1)
if self.bias is not None:
@@ -773,10 +765,10 @@ def run_layers(self, input: Tensor, layers) -> Tensor:
return y
def encode(self, input: Tensor) -> Tensor:
- return input * 2**self.x_frac_width
+ return input * 2 ** self.x_frac_width
def decode(self, input: Tensor) -> Tensor:
- return input / 2**self.x_frac_width
+ return input / 2 ** self.x_frac_width
def forward(self, x: Tensor) -> Tensor:
if self.is_lut_inference:
diff --git a/src/chop/nn/quantized/modules/relu.py b/src/chop/nn/quantized/modules/relu.py
index 2bc527161..a4d1acfbf 100644
--- a/src/chop/nn/quantized/modules/relu.py
+++ b/src/chop/nn/quantized/modules/relu.py
@@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/rms_norm.py b/src/chop/nn/quantized/modules/rms_norm.py
index 91dd9d9d6..a6b893b32 100644
--- a/src/chop/nn/quantized/modules/rms_norm.py
+++ b/src/chop/nn/quantized/modules/rms_norm.py
@@ -5,9 +5,7 @@
from torch import Tensor
import torch.nn.functional as F
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
def _rms_norm(x: Tensor, eps, scale: Tensor | None):
diff --git a/src/chop/nn/quantized/modules/selu.py b/src/chop/nn/quantized/modules/selu.py
index 066ffc0b7..482c7c5d1 100644
--- a/src/chop/nn/quantized/modules/selu.py
+++ b/src/chop/nn/quantized/modules/selu.py
@@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/silu.py b/src/chop/nn/quantized/modules/silu.py
index 07f18ef6e..30dac9e5d 100644
--- a/src/chop/nn/quantized/modules/silu.py
+++ b/src/chop/nn/quantized/modules/silu.py
@@ -113,9 +113,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -133,9 +131,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/softplus.py b/src/chop/nn/quantized/modules/softplus.py
index 4e8465c56..458558b63 100644
--- a/src/chop/nn/quantized/modules/softplus.py
+++ b/src/chop/nn/quantized/modules/softplus.py
@@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/softsign.py b/src/chop/nn/quantized/modules/softsign.py
index 5497426aa..fe3e53b62 100644
--- a/src/chop/nn/quantized/modules/softsign.py
+++ b/src/chop/nn/quantized/modules/softsign.py
@@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/tanh.py b/src/chop/nn/quantized/modules/tanh.py
index fce343489..3378a612f 100644
--- a/src/chop/nn/quantized/modules/tanh.py
+++ b/src/chop/nn/quantized/modules/tanh.py
@@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
@@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer,
- width=x_width,
- exponent_bias=x_exponent_bias,
+ log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
index 14cf53839..72478f46d 100644
--- a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
+++ b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
@@ -62,10 +62,7 @@ def update_luts_weights(self) -> torch.Tensor:
key = row.detach().cpu().flatten().sign().numpy().tolist()
new_weights.append(key)
new_weights = torch.tensor(
- new_weights,
- dtype=torch.float32,
- requires_grad=True,
- device=self.device,
+ new_weights, dtype=torch.float32, requires_grad=True, device=self.device,
).view(-1, self.kk)
return new_weights
diff --git a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
index ff48436e0..45d95805c 100644
--- a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
+++ b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
@@ -38,7 +38,7 @@ def __init__(
levels (int): Number of residual level to use.
"""
self.k = k
- self.kk = 2**k
+ self.kk = 2 ** k
self.binarization_level = binarization_level
self.input_expanded = input_expanded
self.tables_count = tables_count
diff --git a/src/chop/nn/quantizers/block_fp.py b/src/chop/nn/quantizers/block_fp.py
index bfe8534e3..826133862 100644
--- a/src/chop/nn/quantizers/block_fp.py
+++ b/src/chop/nn/quantizers/block_fp.py
@@ -47,10 +47,10 @@ def _block_fp_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
- exponent_max = 2**exponent_width - 1 - exponent_bias
+ exponent_max = 2 ** exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
- mantissa_integer_max = 2**mantissa_bits - 1
+ mantissa_integer_max = 2 ** mantissa_bits - 1
# sign
per_block_sign = torch.sign(blocked_x + 1e-9)
# exponent
@@ -58,14 +58,14 @@ def _block_fp_quantize(
per_block_exponent = torch.ceil(torch.log2(per_block_max))
per_block_exponent = my_clamp(per_block_exponent, exponent_min, exponent_max)
# mantissa
- per_block_mantissa = per_block_value / 2**per_block_exponent
- shift = 2**mantissa_bits
+ per_block_mantissa = per_block_value / 2 ** per_block_exponent
+ shift = 2 ** mantissa_bits
per_block_mantissa_integer = my_clamp(
my_round(per_block_mantissa * shift), 0, mantissa_integer_max
)
per_block_mantissa = per_block_mantissa_integer / shift
- per_block_msfp = per_block_sign * (2**per_block_exponent) * per_block_mantissa
+ per_block_msfp = per_block_sign * (2 ** per_block_exponent) * per_block_mantissa
msfp_x = unblock(
per_block_msfp,
x_shape_before_blocking=x_shape_before_blocking,
@@ -133,10 +133,5 @@ def block_fp_quantizer(
"""
return BlockFPQuantize.apply(
- x,
- width,
- exponent_width,
- exponent_bias,
- block_size,
- skip_first_dim,
+ x, width, exponent_width, exponent_bias, block_size, skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/block_log.py b/src/chop/nn/quantizers/block_log.py
index 8e65c91ea..1773b42b8 100644
--- a/src/chop/nn/quantizers/block_log.py
+++ b/src/chop/nn/quantizers/block_log.py
@@ -40,7 +40,7 @@ def _block_log_quantize(
per_block_max_exponent = torch.ceil(torch.log2(per_block_max))
per_block_bias = my_clamp(
- 2**exponent_bits - 1 - per_block_max_exponent, 0, 2**exponent_bias_width - 1
+ 2 ** exponent_bits - 1 - per_block_max_exponent, 0, 2 ** exponent_bias_width - 1
)
per_block_lq_x = _log_quantize(blocked_x, width=width, exponent_bias=per_block_bias)
@@ -98,9 +98,5 @@ def block_log_quantizer(
- `block_size`: a list of integers where each integer is the block size along the corresponding dim
"""
return BlockLogQuantize.apply(
- x,
- width,
- exponent_bias_width,
- block_size,
- skip_first_dim,
+ x, width, exponent_bias_width, block_size, skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/block_minifloat.py b/src/chop/nn/quantizers/block_minifloat.py
index 34e00bbcb..ccef6649d 100644
--- a/src/chop/nn/quantizers/block_minifloat.py
+++ b/src/chop/nn/quantizers/block_minifloat.py
@@ -41,7 +41,7 @@ def _block_minifloat_quantize(
per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min()
per_block_exponent_bias = my_clamp(
- torch.floor(torch.log2(per_block_max)), 0, 2**exponent_bias_width - 1
+ torch.floor(torch.log2(per_block_max)), 0, 2 ** exponent_bias_width - 1
)
per_block_bm_x = _minifloat_ieee_quantize(
blocked_x,
@@ -118,10 +118,5 @@ def block_minifloat_quantizer(
"""
return BlockMinifloatQuantize.apply(
- x,
- width,
- exponent_width,
- exponent_bias_width,
- block_size,
- skip_first_dim,
+ x, width, exponent_width, exponent_bias_width, block_size, skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/integer.py b/src/chop/nn/quantizers/integer.py
index 9f3ffec1e..8c2573f22 100644
--- a/src/chop/nn/quantizers/integer.py
+++ b/src/chop/nn/quantizers/integer.py
@@ -34,9 +34,9 @@ def _integer_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2**width - 1
+ int_max = 2 ** width - 1
# thresh = 2 ** (width - 1)
- scale = 2**frac_width
+ scale = 2 ** frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
@@ -57,8 +57,8 @@ def _integer_floor_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2**width - 1
- scale = 2**frac_width
+ int_max = 2 ** width - 1
+ scale = 2 ** frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_floor(x.mul(scale)), int_min, int_max).div(scale)
diff --git a/src/chop/nn/quantizers/log.py b/src/chop/nn/quantizers/log.py
index 98ab45e76..926e33cdb 100644
--- a/src/chop/nn/quantizers/log.py
+++ b/src/chop/nn/quantizers/log.py
@@ -8,9 +8,7 @@
def _log_quantize(
- x: Tensor | ndarray,
- width: int,
- exponent_bias: int | Tensor | ndarray | None,
+ x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None,
):
"""
- Use non-uniform, base-2 logarithmic representation to encode IEEE FP32/64
@@ -32,16 +30,16 @@ def _log_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_bits - 1) - 1
- exponent_max = 2**exponent_bits - 1 - exponent_bias
+ exponent_max = 2 ** exponent_bits - 1 - exponent_bias
exponent_min = -exponent_bias
- min_pos = 2**exponent_min
+ min_pos = 2 ** exponent_min
sign = torch.sign(x + min_pos * 0.1)
value = torch.abs(x) + min_pos * 0.1
exponent = my_clamp(my_round(torch.log2(value)), exponent_min, exponent_max)
- return sign * (2**exponent)
+ return sign * (2 ** exponent)
class LogQuantize(torch.autograd.Function):
@@ -58,9 +56,7 @@ def backward(ctx, grad_output):
def log_quantizer(
- x: Tensor | ndarray,
- width: int,
- exponent_bias: int | Tensor | ndarray | None,
+ x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None,
):
"""
Convert IEEE FP32/64 to base-2 log quantized values
diff --git a/src/chop/nn/quantizers/minifloat.py b/src/chop/nn/quantizers/minifloat.py
index 2d6f23103..f19097fde 100644
--- a/src/chop/nn/quantizers/minifloat.py
+++ b/src/chop/nn/quantizers/minifloat.py
@@ -5,10 +5,7 @@
def _minifloat_denorm_quantize(
- x: Tensor,
- width: int,
- exponent_width: int,
- exponent_bias: int = None,
+ x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
):
"""
- Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas.
@@ -37,10 +34,10 @@ def _minifloat_denorm_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
- exponent_max = 2**exponent_width - 1 - exponent_bias
+ exponent_max = 2 ** exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# if the mantissa is an integer, the max mantissa value will be (2**mantissa_bits -1)
- shifted_mantissa_max = 2**mantissa_bits - 1
+ shifted_mantissa_max = 2 ** mantissa_bits - 1
shifted_mantissa_min = 0
sign = torch.sign(x + 1e-9)
@@ -52,8 +49,8 @@ def _minifloat_denorm_quantize(
# divide value by clipped exponent. this ensures the simulated minifloat value is correct
# when x is too large (minifloat will saturate) or too close to 0.
- mantissa = value / 2**exponent
- shift = 2**mantissa_bits
+ mantissa = value / 2 ** exponent
+ shift = 2 ** mantissa_bits
shifted_mantissa = my_round(mantissa * shift)
# clip the integer mantissa.
shifted_mantissa = my_clamp(
@@ -71,11 +68,7 @@ def _minifloat_denorm_quantize(
class MinifloatDenormQuantize(torch.autograd.Function):
@staticmethod
def forward(
- ctx,
- x: Tensor,
- width: int,
- exponent_width: int,
- exponent_bias: int = None,
+ ctx, x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
):
return _minifloat_denorm_quantize(
x, width=width, exponent_width=exponent_width, exponent_bias=exponent_bias
@@ -88,10 +81,7 @@ def backward(ctx, grad_output):
def minifloat_denorm_quantizer(
- x: Tensor,
- width: int,
- exponent_width: int,
- exponent_bias: int = None,
+ x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
):
"""
- Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas.
@@ -148,11 +138,11 @@ def _minifloat_ieee_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
# upper and lower bound of shifted exponent
- exponent_max = 2**exponent_width - 1 - exponent_bias
+ exponent_max = 2 ** exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# upper and lower bound of shifted minifloat mantissa
- shift = 2**mantissa_bits
- shifted_mantissa_max = 2**mantissa_bits - 1
+ shift = 2 ** mantissa_bits
+ shifted_mantissa_max = 2 ** mantissa_bits - 1
shifted_mantissa_min = 0
sign = torch.sign(x + 1e-9)
@@ -162,9 +152,9 @@ def _minifloat_ieee_quantize(
exponent = torch.floor(torch.log2(value + 1e-9))
exponent = my_clamp(exponent, exponent_min, exponent_max)
- mantissa = value / 2**exponent
+ mantissa = value / 2 ** exponent
- shift = 2**mantissa_bits
+ shift = 2 ** mantissa_bits
# fmt: off
# if the clipped exponent is zero, the minifloat is in a subnormal form
# this `is_normal` also help the grad keeps 1 if input x is 0, or the zero-initialized value will be trapped in 0
diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py
index 0c3e06130..1a61fb70c 100644
--- a/src/chop/nn/quantizers/mxint_hardware.py
+++ b/src/chop/nn/quantizers/mxint_hardware.py
@@ -19,7 +19,7 @@ def mxint_quant_block(
"""
exponent_bias = 2 ** (exponent_width - 1)
- exponent_max = 2**exponent_width - 1 - exponent_bias
+ exponent_max = 2 ** exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# exponent
@@ -29,9 +29,9 @@ def mxint_quant_block(
# mantissa
int_min = -(2 ** (width - 1))
int_max = 2 ** (width - 1) - 1
- mantissa = x / 2**exponent
+ mantissa = x / 2 ** exponent
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- q_x = (2**exponent) * mantissa
+ q_x = (2 ** exponent) * mantissa
return q_x
diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py
index d5ca3d8cf..ccf57319f 100644
--- a/src/chop/nn/quantizers/quantizers_for_hw.py
+++ b/src/chop/nn/quantizers/quantizers_for_hw.py
@@ -9,31 +9,31 @@
def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
thresh = 2 ** (width - 1)
- scale = 2**frac_width
+ scale = 2 ** frac_width
fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2**width)
+ fixed_point_value = fixed_point_value % (2 ** width)
return fixed_point_value
def unsigned_integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
- thresh = 2**width - 1
- scale = 2**frac_width
+ thresh = 2 ** width - 1
+ scale = 2 ** frac_width
fixed_point_value = my_clamp(my_floor(x.mul(scale)), 0, thresh)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2**width)
+ fixed_point_value = fixed_point_value % (2 ** width)
return fixed_point_value
def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
thresh = 2 ** (width - 1)
- scale = 2**frac_width
+ scale = 2 ** frac_width
fixed_point_value = my_clamp(my_floor(x.mul(scale)), -thresh, thresh - 1)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2**width)
+ fixed_point_value = fixed_point_value % (2 ** width)
return fixed_point_value
diff --git a/src/chop/nn/quantizers/ternary.py b/src/chop/nn/quantizers/ternary.py
index a27951c23..89fcfe75b 100644
--- a/src/chop/nn/quantizers/ternary.py
+++ b/src/chop/nn/quantizers/ternary.py
@@ -45,8 +45,7 @@ def ternary_quantizer(
# )
if scaling_factor:
x = ternarised_scaled_op(
- x,
- threshold, # abs_mean=mean
+ x, threshold, # abs_mean=mean
) # [mean, 0 ,-mean] # this function determines the mean on the fly, maybe we could make an alternative which uses the metadata?
else:
x = ternarised_op(x, threshold) # [1, 0 ,-1]
diff --git a/src/chop/nn/quantizers/utils.py b/src/chop/nn/quantizers/utils.py
index ae881328d..9efc8b512 100644
--- a/src/chop/nn/quantizers/utils.py
+++ b/src/chop/nn/quantizers/utils.py
@@ -224,16 +224,8 @@ def forward(ctx, input, _threshold):
alpha = TernaryScaled.alpha(input, delta)
output = torch.zeros_like(input)
- pos_one = torch.where(
- input > delta,
- 1.0,
- 0.0,
- )
- neg_one = torch.where(
- input < -delta,
- -1.0,
- 0.0,
- )
+ pos_one = torch.where(input > delta, 1.0, 0.0,)
+ neg_one = torch.where(input < -delta, -1.0, 0.0,)
output = (pos_one + neg_one) * alpha.view(-1, 1, 1, 1).expand(
-1, input.size()[1], input.size()[2], input.size()[3]
)
@@ -295,16 +287,8 @@ def forward(ctx, input, _threshold):
alpha = TernaryScaled.alpha(input, delta)
output = torch.zeros_like(input)
- pos_one = torch.where(
- input > delta,
- 1.0,
- 0.0,
- )
- neg_one = torch.where(
- input < -delta,
- -1.0,
- 0.0,
- )
+ pos_one = torch.where(input > delta, 1.0, 0.0,)
+ neg_one = torch.where(input < -delta, -1.0, 0.0,)
output = pos_one + neg_one
return output
@@ -413,8 +397,7 @@ def _block_1d_bias(x: Tensor, block_shape: List[int]):
def _unblock_to_1d_bias(
- blocked_x: Tensor,
- x_shape_before_blocking: List[int],
+ blocked_x: Tensor, x_shape_before_blocking: List[int],
):
"""
blocked bias shape: [num_blocks, block_size] -> [output_features]
@@ -609,10 +592,7 @@ def unblock(
return _unblock_to_2d_activation(blocked_x, x_shape_before_blocking)
else:
return _unblock_to_2d_weight(
- blocked_x,
- x_shape_before_blocking,
- padded_x_shape,
- block_shape,
+ blocked_x, x_shape_before_blocking, padded_x_shape, block_shape,
)
elif len(x_shape_before_blocking) == 3:
if skipped_first_dim_when_blocking:
diff --git a/src/chop/nn/snn/auto_cuda/generator.py b/src/chop/nn/snn/auto_cuda/generator.py
index 639cf0379..f1637ddf9 100644
--- a/src/chop/nn/snn/auto_cuda/generator.py
+++ b/src/chop/nn/snn/auto_cuda/generator.py
@@ -310,10 +310,7 @@ def gen_forward_codes(
params.append(("v_reset", "const float &"))
params.extend(
- [
- ("neuron_num", "const int &"),
- ("numel", "const int &"),
- ]
+ [("neuron_num", "const int &"), ("numel", "const int &"),]
)
params_name = []
for item in params:
diff --git a/src/chop/nn/snn/modules/neuron/ifnode.py b/src/chop/nn/snn/modules/neuron/ifnode.py
index e5e3c373f..c2376d310 100644
--- a/src/chop/nn/snn/modules/neuron/ifnode.py
+++ b/src/chop/nn/snn/modules/neuron/ifnode.py
@@ -244,10 +244,12 @@ def multi_step_forward(self, x_seq: torch.Tensor):
self.v_float_to_tensor(x_seq[0])
if self.v_reset is None:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_soft_reset_with_v_seq(
- x_seq, self.v, self.v_threshold
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(
+ x_seq, self.v, self.v_threshold
)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(
@@ -255,10 +257,12 @@ def multi_step_forward(self, x_seq: torch.Tensor):
)
else:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_hard_reset_with_v_seq(
- x_seq, self.v, self.v_threshold, self.v_reset
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(
+ x_seq, self.v, self.v_threshold, self.v_reset
)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(
diff --git a/src/chop/nn/snn/modules/neuron/lifnode.py b/src/chop/nn/snn/modules/neuron/lifnode.py
index 90b02289c..0f0522f4d 100644
--- a/src/chop/nn/snn/modules/neuron/lifnode.py
+++ b/src/chop/nn/snn/modules/neuron/lifnode.py
@@ -417,29 +417,33 @@ def single_step_forward(self, x: torch.Tensor):
self.v_float_to_tensor(x)
if self.v_reset is None:
if self.decay_input:
- spike, self.v = (
- self.jit_eval_single_step_forward_soft_reset_decay_input(
- x, self.v, self.v_threshold, self.tau
- )
+ (
+ spike,
+ self.v,
+ ) = self.jit_eval_single_step_forward_soft_reset_decay_input(
+ x, self.v, self.v_threshold, self.tau
)
else:
- spike, self.v = (
- self.jit_eval_single_step_forward_soft_reset_no_decay_input(
- x, self.v, self.v_threshold, self.tau
- )
+ (
+ spike,
+ self.v,
+ ) = self.jit_eval_single_step_forward_soft_reset_no_decay_input(
+ x, self.v, self.v_threshold, self.tau
)
else:
if self.decay_input:
- spike, self.v = (
- self.jit_eval_single_step_forward_hard_reset_decay_input(
- x, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike,
+ self.v,
+ ) = self.jit_eval_single_step_forward_hard_reset_decay_input(
+ x, self.v, self.v_threshold, self.v_reset, self.tau
)
else:
- spike, self.v = (
- self.jit_eval_single_step_forward_hard_reset_no_decay_input(
- x, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike,
+ self.v,
+ ) = self.jit_eval_single_step_forward_hard_reset_no_decay_input(
+ x, self.v, self.v_threshold, self.v_reset, self.tau
)
return spike
@@ -514,56 +518,68 @@ def multi_step_forward(self, x_seq: torch.Tensor):
if self.v_reset is None:
if self.decay_input:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(
- x_seq, self.v, self.v_threshold, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(
+ x_seq, self.v, self.v_threshold, self.tau
)
else:
- spike_seq, self.v = (
- self.jit_eval_multi_step_forward_soft_reset_decay_input(
- x_seq, self.v, self.v_threshold, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ ) = self.jit_eval_multi_step_forward_soft_reset_decay_input(
+ x_seq, self.v, self.v_threshold, self.tau
)
else:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(
- x_seq, self.v, self.v_threshold, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(
+ x_seq, self.v, self.v_threshold, self.tau
)
else:
- spike_seq, self.v = (
- self.jit_eval_multi_step_forward_soft_reset_no_decay_input(
- x_seq, self.v, self.v_threshold, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input(
+ x_seq, self.v, self.v_threshold, self.tau
)
else:
if self.decay_input:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(
- x_seq, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(
+ x_seq, self.v, self.v_threshold, self.v_reset, self.tau
)
else:
- spike_seq, self.v = (
- self.jit_eval_multi_step_forward_hard_reset_decay_input(
- x_seq, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ ) = self.jit_eval_multi_step_forward_hard_reset_decay_input(
+ x_seq, self.v, self.v_threshold, self.v_reset, self.tau
)
else:
if self.store_v_seq:
- spike_seq, self.v, self.v_seq = (
- self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(
- x_seq, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ self.v_seq,
+ ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(
+ x_seq, self.v, self.v_threshold, self.v_reset, self.tau
)
else:
- spike_seq, self.v = (
- self.jit_eval_multi_step_forward_hard_reset_no_decay_input(
- x_seq, self.v, self.v_threshold, self.v_reset, self.tau
- )
+ (
+ spike_seq,
+ self.v,
+ ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input(
+ x_seq, self.v, self.v_threshold, self.v_reset, self.tau
)
return spike_seq
diff --git a/src/chop/nn/snn/modules/spiking_self_attention.py b/src/chop/nn/snn/modules/spiking_self_attention.py
index 87ee98a30..3fa47d1d8 100644
--- a/src/chop/nn/snn/modules/spiking_self_attention.py
+++ b/src/chop/nn/snn/modules/spiking_self_attention.py
@@ -145,9 +145,7 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L
super().__init__()
inner_channels = in_channels * ratio
self.up = nn.Sequential(
- activation(),
- Conv1x1(in_channels, inner_channels),
- BN(inner_channels),
+ activation(), Conv1x1(in_channels, inner_channels), BN(inner_channels),
)
self.conv = nn.ModuleList()
for _ in range(num_conv):
@@ -163,9 +161,7 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L
)
)
self.down = nn.Sequential(
- activation(),
- Conv1x1(inner_channels, in_channels),
- BN(in_channels),
+ activation(), Conv1x1(inner_channels, in_channels), BN(in_channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py
index bfc69b47e..d1d63d8f5 100644
--- a/src/chop/passes/__init__.py
+++ b/src/chop/passes/__init__.py
@@ -40,8 +40,6 @@
from .module.analysis import calculate_avg_bits_module_analysis_pass
from .module.transforms import quantize_module_transform_pass, resharding_transform_pass
-from .onnx.analysis import (
- export_fx_graph_analysis_pass,
-)
+from .onnx.analysis import export_fx_graph_analysis_pass
from .graph.analysis.autosharding import autosharding_analysis_pass
diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py
index 188038a14..3274c9143 100644
--- a/src/chop/passes/graph/__init__.py
+++ b/src/chop/passes/graph/__init__.py
@@ -46,9 +46,7 @@
from .transforms.quantize.quant_parsers import parse_node_config
-from chop.passes.graph.analysis.runtime.runtime_analysis import (
- runtime_analysis_pass,
-)
+from chop.passes.graph.analysis.runtime.runtime_analysis import runtime_analysis_pass
from .interface import tensorrt_engine_interface_pass
@@ -151,6 +149,6 @@
if check_dependencies("tensorrt_fake_quantize_transform_pass"):
TRANSFORM_PASSES.append("tensorrt_fake_quantize_transform_pass")
- PASSES["tensorrt_fake_quantize_transform_pass"] = (
- tensorrt_fake_quantize_transform_pass
- )
+ PASSES[
+ "tensorrt_fake_quantize_transform_pass"
+ ] = tensorrt_fake_quantize_transform_pass
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
index 66fe9cc28..9fde9d7cb 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
@@ -203,10 +203,7 @@ def graph_iterator_for_mase_ops(graph):
def graph_iterator_for_metadata(
- graph,
- dummy_in=None,
- add_value=True,
- force_device_meta=False,
+ graph, dummy_in=None, add_value=True, force_device_meta=False,
):
"""
largely adapted from https://pytorch.org/docs/stable/fx.html
@@ -291,12 +288,7 @@ def _add_graph_metadata(graph):
def add_common_metadata_analysis_pass(
- graph,
- pass_args={
- "dummy_in": None,
- "add_value": True,
- "force_device_meta": False,
- },
+ graph, pass_args={"dummy_in": None, "add_value": True, "force_device_meta": False,},
):
"""add common metadata
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
index 39c8ac690..9fd2e9466 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
@@ -39,8 +39,7 @@ def add_component_source(node):
if mase_op == "user_defined_module":
for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items():
if isinstance(
- deepgetattr(node.meta["mase"].model, node.target),
- custom_op,
+ deepgetattr(node.meta["mase"].model, node.target), custom_op,
):
node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
node.meta["mase"]["hardware"]["module"] = op_info["module"]
diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
index 6f5ee0147..ed3249c9c 100644
--- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
@@ -231,10 +231,7 @@
},
# Inserted ops from the replace_method_with_function pass
"torch_size": {"input": "data_in", "dim": "config"},
- "torch_contiguous": {
- "input": "data_in",
- "memory_format": "config",
- },
+ "torch_contiguous": {"input": "data_in", "memory_format": "config",},
# arbitrary length - support up to 4
"torch_expand": {
"input": "data_in",
@@ -257,11 +254,7 @@
"shape_2": "config",
"shape_3": "config",
},
- "torch_split": {
- "input": "data_in",
- "split_size": "config",
- "dim": "config",
- },
+ "torch_split": {"input": "data_in", "split_size": "config", "dim": "config",},
"torch_permute": {
"input": "data_in",
"dim_0": "config",
@@ -269,11 +262,7 @@
"dim_2": "config",
"dim_3": "config",
},
- "torch_transpose": {
- "input": "data_in",
- "dim0": "config",
- "dim1": "config",
- },
+ "torch_transpose": {"input": "data_in", "dim0": "config", "dim1": "config",},
# DTensor ops
"dtensor_arange": {
"device_mesh": "config",
@@ -287,9 +276,7 @@
"requires_grad": "config",
},
# tensor constructor
- "tensor": {
- "data": "data_in",
- },
+ "tensor": {"data": "data_in",},
# https://pytorch.org/docs/stable/generated/torch.nn.functional.dropout.html
"dropout": {
"input": "data_in",
@@ -353,10 +340,7 @@
"softmax": {"input": "data_in"},
"gelu": {"input": "data_in"},
# https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
- "crossentropyloss": {
- "input": "data_in",
- "target": "data_in",
- },
+ "crossentropyloss": {"input": "data_in", "target": "data_in",},
# chop.nn.modules.lora.LoRALinear
"loralinear": {"input": "data_in"},
"grouped_query_attention": {"input": "data_in"},
@@ -405,24 +389,15 @@
"size_3": "config",
},
# Tensor.max(dim=None, keepdim=False)
- "max": {
- "dim": "config",
- "keepdim": "config",
- },
+ "max": {"dim": "config", "keepdim": "config",},
# https://pytorch.org/docs/stable/generated/torch.Tensor.sum.html
- "sum": {
- "dim": "config",
- "keepdim": "config",
- },
+ "sum": {"dim": "config", "keepdim": "config",},
# https://pytorch.org/docs/stable/generated/torch.Tensor.round.html
"round": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.floor.html
"floor": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.clamp.html
- "clamp": {
- "min": "config",
- "max": "config",
- },
+ "clamp": {"min": "config", "max": "config",},
# https://pytorch.org/docs/stable/generated/torch.Tensor.dim.html
"dim": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html#torch.Tensor.permute
@@ -451,11 +426,7 @@
# https://pytorch.org/docs/stable/generated/torch.Tensor.type_as.html
"type_as": {"tensor": "data_in"},
# https://pytorch.org/docs/stable/generated/torch.Tensor.index_select.html
- "index_select": {
- "input": "data_in",
- "dim": "config",
- "index": "data_in",
- },
+ "index_select": {"input": "data_in", "dim": "config", "index": "data_in",},
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
"detach": {"input": "data_in"},
}
@@ -508,11 +479,7 @@ def deepgetattr(obj, attr):
def _annotate_arg_metadata(
- meta: MaseMetadata,
- args: list,
- kwargs: dict,
- func_data: dict,
- add_value: bool,
+ meta: MaseMetadata, args: list, kwargs: dict, func_data: dict, add_value: bool,
):
"""
Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"]
@@ -621,9 +588,7 @@ def _annotate_arg_metadata(
def _annotate_result_metadata(
- meta: MaseMetadata,
- result,
- add_value: bool,
+ meta: MaseMetadata, result, add_value: bool,
) -> MaseMetadata:
"""
Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata.
diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
index b084821ae..a92958cff 100644
--- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
@@ -64,17 +64,13 @@
"relu": [
{
"name": "fixed_relu",
- "dependence_files": [
- "activation_layers/rtl/fixed_relu.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_relu.sv",],
},
],
"hardshrink": [
{
"name": "fixed_hardshrink",
- "dependence_files": [
- "activation_layers/rtl/fixed_hardshrink.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_hardshrink.sv",],
},
],
"silu": [
@@ -107,9 +103,7 @@
"softshrink": [
{
"name": "fixed_softshrink",
- "dependence_files": [
- "activation_layers/rtl/fixed_softshrink.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_softshrink.sv",],
},
],
"logsigmoid": [
@@ -138,17 +132,13 @@
"selu": [
{
"name": "fixed_selu",
- "dependence_files": [
- "activation_layers/rtl/fixed_selu.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_selu.sv",],
},
],
"tanh": [
{
"name": "fixed_tanh",
- "dependence_files": [
- "activation_layers/rtl/fixed_tanh.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_tanh.sv",],
},
],
"gelu": [
@@ -172,17 +162,13 @@
"softplus": [
{
"name": "fixed_softplus",
- "dependence_files": [
- "activation_layers/rtl/fixed_softplus.sv",
- ],
+ "dependence_files": ["activation_layers/rtl/fixed_softplus.sv",],
},
],
"add": [
{
"name": "fixed_adder",
- "dependence_files": [
- "linear_layers/fixed_operators/rtl/fixed_adder.sv",
- ],
+ "dependence_files": ["linear_layers/fixed_operators/rtl/fixed_adder.sv",],
}
],
"mul": [
@@ -199,14 +185,7 @@
"dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"],
}
],
- "getitem": [
- {
- "name": "buffer",
- "dependence_files": [
- "memory/rtl/buffer.sv",
- ],
- }
- ],
+ "getitem": [{"name": "buffer", "dependence_files": ["memory/rtl/buffer.sv",],}],
"grouped_query_attention": [
{
"name": "fixed_gqa_wrapper",
diff --git a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
index d7f4fd8ed..1e6ad2a57 100644
--- a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
@@ -243,16 +243,8 @@ def analyze_software_meta_param_patched_func_default(meta):
"getitem": analyze_software_meta_param_implicit_func_default,
"getattr": analyze_software_meta_param_implicit_func_default,
},
- "placeholder": {
- "placeholder": analyze_software_meta_param_placeholder,
- },
- "get_attr": {
- "get_attr": analyze_software_meta_param_get_attr,
- },
- "output": {
- "output": analyze_software_meta_param_output,
- },
- "patched_func": {
- "default": analyze_software_meta_param_patched_func_default,
- },
+ "placeholder": {"placeholder": analyze_software_meta_param_placeholder,},
+ "get_attr": {"get_attr": analyze_software_meta_param_get_attr,},
+ "output": {"output": analyze_software_meta_param_output,},
+ "patched_func": {"default": analyze_software_meta_param_patched_func_default,},
}
diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
index d9fea1aa5..30216eea3 100644
--- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
+++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
@@ -74,10 +74,7 @@ def get_resharding_cost(
and (dest[0] == SpmdShard.R)
):
ag_dim = 1 if src[0] == dest[0] else 0
- return mesh.all_gather_cost(
- num_bytes=num_bytes,
- mesh_dim=ag_dim,
- )
+ return mesh.all_gather_cost(num_bytes=num_bytes, mesh_dim=ag_dim,)
# All-to-all
# E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1)
@@ -85,10 +82,7 @@ def get_resharding_cost(
# all to all
a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value
try:
- return mesh.all_to_all_cost(
- num_bytes=num_bytes,
- mesh_dim=a2a_dim,
- )
+ return mesh.all_to_all_cost(num_bytes=num_bytes, mesh_dim=a2a_dim,)
except:
assert False
diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py
index f1ad3a4fe..b49154c20 100644
--- a/src/chop/passes/graph/analysis/autosharding/autosharding.py
+++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py
@@ -23,10 +23,7 @@ def deepgetattr(obj, attr, default=None):
def _import_solution(
- mg,
- solution: dict,
- mesh: MeshModel,
- extrapolate_sharding: bool = True,
+ mg, solution: dict, mesh: MeshModel, extrapolate_sharding: bool = True,
):
"""Import an autosharding solution into the metadata of the MaseGraph.
@@ -71,18 +68,14 @@ def _import_solution(
# Annotate the metadata for each argument
for arg, arg_spec in solution[node.name].get("args", {}).items():
node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = _DTensorSpec(
- mesh=mesh,
- placements=arg_spec,
+ mesh=mesh, placements=arg_spec,
)
# Annotate the metadata for each result
for result, result_spec in solution[node.name].get("results", {}).items():
- node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = (
- _DTensorSpec(
- mesh=mesh,
- placements=result_spec,
- )
- )
+ node.meta["mase"]["common"]["results"][result][
+ "dtensor_spec"
+ ] = _DTensorSpec(mesh=mesh, placements=result_spec,)
return mg, {}
@@ -114,10 +107,7 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"):
logger.warning(
f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution."
)
- spec = _DTensorSpec(
- None,
- (Replicate(), Replicate()),
- )
+ spec = _DTensorSpec(None, (Replicate(), Replicate()),)
else:
spec = arg_info["dtensor_spec"]
@@ -132,10 +122,7 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"):
logger.warning(
f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution."
)
- spec = _DTensorSpec(
- None,
- (Replicate(), Replicate()),
- )
+ spec = _DTensorSpec(None, (Replicate(), Replicate()),)
else:
spec = result_info["dtensor_spec"]
out_dict[node_name]["results"][result] = spec.placements
@@ -197,9 +184,7 @@ def _get_sharding_map(mg):
if module not in tensor_sharding_map:
tensor_sharding_map[module] = {
"node": node.name,
- "sharding": {
- attr: out_specs,
- },
+ "sharding": {attr: out_specs,},
}
else:
tensor_sharding_map[module]["sharding"][attr] = out_specs
@@ -302,8 +287,11 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}):
if not pass_args.get(f"skip_forward", False):
tensor_sharding_map = _get_sharding_map(mg)
- return mg, {
- "autosharding_time": autosharding_time,
- "tensor_sharding_map": tensor_sharding_map,
- **pass_outs,
- }
+ return (
+ mg,
+ {
+ "autosharding_time": autosharding_time,
+ "tensor_sharding_map": tensor_sharding_map,
+ **pass_outs,
+ },
+ )
diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py
index 30cd36f7e..ab9134a36 100644
--- a/src/chop/passes/graph/analysis/autosharding/megatron.py
+++ b/src/chop/passes/graph/analysis/autosharding/megatron.py
@@ -3,9 +3,7 @@
def megatron_autosharding_pass(
- mg: MaseGraph,
- mesh: MeshModel,
- pass_args: dict,
+ mg: MaseGraph, mesh: MeshModel, pass_args: dict,
):
for node in mg.fx_graph.nodes:
meta = node.meta["mase"]["common"]
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
index 40b704ca8..d68f61cb5 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
@@ -86,10 +86,7 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
def gen_einsum_strategies(
- equation: str,
- mesh: tuple,
- *,
- linearity: bool = False,
+ equation: str, mesh: tuple, *, linearity: bool = False,
) -> OpStrategy:
"""
Generate a strategy list for the ops that follow einsum style notation.
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py
index 2635a5587..e58980c60 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py
@@ -81,11 +81,7 @@ def fully_replicated_strategy(meta, mesh):
in_spec = _DTensorSpec(
mesh,
sharding,
- tensor_meta=TensorMeta(
- shape=in_shape,
- stride=None,
- dtype=in_dtype,
- ),
+ tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype,),
)
dtype_key = (
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
index 78328a95d..60a3de959 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
@@ -24,10 +24,7 @@
from chop.ir.graph import MaseMetadata
-def transpose_strategy(
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
parent_node = meta.node.args[0]
self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"]
@@ -59,11 +56,7 @@ def transpose_strategy(
return OpStrategy(strategies=transpose_strategies)
-def _mm_like_strategy(
- mm_equation: str,
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()]
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
@@ -102,9 +95,7 @@ def _mm_like_strategy(
def _addmm_like_strategy(
- mm_equation: str,
- meta: MaseMetadata,
- mesh: tuple,
+ mm_equation: str, meta: MaseMetadata, mesh: tuple,
) -> OpStrategy:
self_shape, mat1_shape, mat2_shape = [
@@ -170,37 +161,24 @@ def _addmm_like_strategy(
return mm_strategy
-def mm_strategy(
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def mm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
return _mm_like_strategy("mk,kn->mn", meta, mesh)
-def addmm_strategy(
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def addmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
return _addmm_like_strategy("mk,kn->mn", meta, mesh)
-def bmm_strategy(
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def bmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
return _mm_like_strategy("bmk,bkn->bmn", meta, mesh)
-def baddmm_strategy(
- meta: MaseMetadata,
- mesh: tuple,
-) -> OpStrategy:
+def baddmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh)
def scaled_dot_product_flash_attention_strategy(
- meta: MaseMetadata,
- mesh: tuple,
+ meta: MaseMetadata, mesh: tuple,
) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
index c3c06d35e..a7d8a739b 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
@@ -125,9 +125,7 @@ def common_pointwise_strategy(
arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"],
)
input_target_placements = map_placements_after_broadcast(
- tuple(out_placements),
- common_shape,
- input_arg_dims_map,
+ tuple(out_placements), common_shape, input_arg_dims_map,
)
input_arg_target_spec = _DTensorSpec(
mesh=mesh,
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py
index 05c76160a..96cbfa445 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py
@@ -6,9 +6,7 @@
PlacementStrategy,
StrategyType,
)
-from torch.distributed.tensor.ops.utils import (
- is_tensor_partial,
-)
+from torch.distributed.tensor.ops.utils import is_tensor_partial
from torch.distributed.tensor.placement_types import (
_DTensorSpec,
Partial,
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
index 505a2440b..6cd423580 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
@@ -236,9 +236,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
def dim_movedim(
- ndim: int,
- input: Union[int, Sequence[int]],
- destination: Union[int, Sequence[int]],
+ ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]],
) -> DimMap:
input = normalize_dims(input, ndim)
destination = normalize_dims(destination, ndim)
@@ -622,8 +620,7 @@ def reshape_strategy(meta, mesh):
)
output_strategy.strategies.append(
PlacementStrategy(
- output_specs=output_spec,
- input_specs=(input_tgt_spec,),
+ output_specs=output_spec, input_specs=(input_tgt_spec,),
)
)
diff --git a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
index b8ca08310..04dde1876 100644
--- a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
+++ b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
@@ -36,7 +36,7 @@ def calculate_modules(module, in_data, out_data):
# Kernel size here can be either a single int for square kernel
# or a tuple (see
# https://pytorch.org/docs/stable/nn.html#torch.nn.MaxPool2d )
- window_size = module.kernel_size**2
+ window_size = module.kernel_size ** 2
else:
window_size = module.kernel_size[0] * module.kernel_size[1]
diff --git a/src/chop/passes/graph/analysis/plot/plot_graph.py b/src/chop/passes/graph/analysis/plot/plot_graph.py
index 8af151b36..adcc93177 100644
--- a/src/chop/passes/graph/analysis/plot/plot_graph.py
+++ b/src/chop/passes/graph/analysis/plot/plot_graph.py
@@ -4,10 +4,7 @@
def plot_graph_analysis_pass(
- graph,
- pass_args={
- "file_name": None,
- },
+ graph, pass_args={"file_name": None,},
):
graph.draw(pass_args["file_name"])
# nx_graph = nx.DiGraph()
diff --git a/src/chop/passes/graph/interface/tensorrt/quantize.py b/src/chop/passes/graph/interface/tensorrt/quantize.py
index 0de0cda60..0a6773754 100644
--- a/src/chop/passes/graph/interface/tensorrt/quantize.py
+++ b/src/chop/passes/graph/interface/tensorrt/quantize.py
@@ -21,6 +21,7 @@ def Quantizer(config):
"pytorch_quantization is not installed. Cannot use tensorrt quantize pass."
)
+
else:
import tensorrt as trt
from pytorch_quantization import quant_modules, calib
diff --git a/src/chop/passes/graph/transforms/dse/run_dse.py b/src/chop/passes/graph/transforms/dse/run_dse.py
index 28446c4ca..991f9e1bd 100644
--- a/src/chop/passes/graph/transforms/dse/run_dse.py
+++ b/src/chop/passes/graph/transforms/dse/run_dse.py
@@ -16,7 +16,7 @@ def get_factors(n):
set(
functools.reduce(
list.__add__,
- ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
+ ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0),
)
)
)
diff --git a/src/chop/passes/graph/transforms/lora.py b/src/chop/passes/graph/transforms/lora.py
index 9c4559661..87fb2362f 100644
--- a/src/chop/passes/graph/transforms/lora.py
+++ b/src/chop/passes/graph/transforms/lora.py
@@ -10,8 +10,7 @@
def insert_lora_adapter_transform_pass(
- mg: MaseGraph,
- pass_args={},
+ mg: MaseGraph, pass_args={},
):
rank = pass_args.get("rank", 0)
@@ -42,8 +41,7 @@ def insert_lora_adapter_transform_pass(
def fuse_lora_weights_transform_pass(
- mg: MaseGraph,
- pass_args={},
+ mg: MaseGraph, pass_args={},
):
for node in mg.nodes:
target = (
diff --git a/src/chop/passes/graph/transforms/onnxrt/quantize.py b/src/chop/passes/graph/transforms/onnxrt/quantize.py
index cacbb6272..39d568236 100644
--- a/src/chop/passes/graph/transforms/onnxrt/quantize.py
+++ b/src/chop/passes/graph/transforms/onnxrt/quantize.py
@@ -54,9 +54,7 @@ def quantize_dynamic(self, model_path: PosixPath, quantized_model_path: PosixPat
)
quantized_model = quantize_dynamic(
- model_path,
- quantized_model_path,
- weight_type=precision,
+ model_path, quantized_model_path, weight_type=precision,
)
self.logger.info("Quantization complete. Model is now dynamically quantized.")
diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
index 665abc17e..aca81dc51 100644
--- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py
+++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
@@ -128,31 +128,11 @@ def neurons_random_fan_in(
weight_criteria_map = {
- "local": {
- "elementwise": {
- "random": random,
- "l1-norm": l1,
- }
- },
- "global": {
- "elementwise": {
- "random": random,
- "l1-norm": global_weight_l1,
- }
- },
+ "local": {"elementwise": {"random": random, "l1-norm": l1,}},
+ "global": {"elementwise": {"random": random, "l1-norm": global_weight_l1,}},
}
activation_criteria_map = {
- "local": {
- "elementwise": {
- "random": random,
- "l1-norm": l1,
- }
- },
- "global": {
- "elementwise": {
- "random": random,
- "l1-norm": global_activation_l1,
- }
- },
+ "local": {"elementwise": {"random": random, "l1-norm": l1,}},
+ "global": {"elementwise": {"random": random, "l1-norm": global_activation_l1,}},
}
diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
index dea246fc2..890497a95 100644
--- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
+++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
@@ -27,10 +27,7 @@
"weight_entries": ("weight_width", "weight_frac_width"),
"data_in_entries": ("data_in_width", "data_in_frac_width"),
"bias_entries": ("bias_width", "bias_frac_width"),
- "data_out_entries": (
- "data_out_width",
- "data_out_frac_width",
- ),
+ "data_out_entries": ("data_out_width", "data_out_frac_width",),
"additional_layers_entries": ("floor"),
},
"lutnet": {
@@ -65,18 +62,9 @@
"weight_width",
"weight_frac_width",
),
- "bias_entries": (
- "bias_width",
- "bias_frac_width",
- ),
- "data_in_entries": (
- "data_in_width",
- "data_in_frac_width",
- ),
- "data_out_entries": (
- "data_out_width",
- "data_out_frac_width",
- ),
+ "bias_entries": ("bias_width", "bias_frac_width",),
+ "data_in_entries": ("data_in_width", "data_in_frac_width",),
+ "data_out_entries": ("data_out_width", "data_out_frac_width",),
"additional_layers_entries": {
"additional_layers_inputs",
"additional_layers_outputs",
@@ -84,21 +72,9 @@
},
},
"binary": {
- "weight_entries": (
- "weight_width",
- "weight_stochastic",
- "weight_bipolar",
- ),
- "data_in_entries": (
- "data_in_width",
- "data_in_stochastic",
- "data_in_bipolar",
- ),
- "bias_entries": (
- "bias_width",
- "bias_stochastic",
- "bias_bipolar",
- ),
+ "weight_entries": ("weight_width", "weight_stochastic", "weight_bipolar",),
+ "data_in_entries": ("data_in_width", "data_in_stochastic", "data_in_bipolar",),
+ "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
},
"binary_residual": {
"weight_entries": (
@@ -114,11 +90,7 @@
"data_in_residual_sign",
"data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet
),
- "bias_entries": (
- "bias_width",
- "bias_stochastic",
- "bias_bipolar",
- ),
+ "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
},
"binary_residual": {
"weight_entries": (
@@ -134,11 +106,7 @@
"data_in_residual_sign",
"data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet
),
- "bias_entries": (
- "bias_width",
- "bias_stochastic",
- "bias_bipolar",
- ),
+ "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
},
"ternary": {
"weight_entries": (
@@ -245,11 +213,7 @@
"data_in_exponent_bias_width",
"data_in_block_size",
),
- "bias_entries": (
- "bias_width",
- "bias_exponent_bias_width",
- "bias_block_size",
- ),
+ "bias_entries": ("bias_width", "bias_exponent_bias_width", "bias_block_size",),
},
"mxint_hardware": {
"weight_entries": (
@@ -262,11 +226,7 @@
"data_in_exponent_width",
"data_in_parallelism",
),
- "bias_entries": (
- "bias_width",
- "bias_exponent_width",
- "bias_parallelism",
- ),
+ "bias_entries": ("bias_width", "bias_exponent_width", "bias_parallelism",),
},
}
@@ -389,22 +349,10 @@ def cp_data_out_entries(
("name", "data_in_entries"),
("weight_entries", "bias_entries", "bypass"),
),
- "layer_norm": (
- ("name", "data_in_entries"),
- ("bypass",),
- ),
- "group_norm": (
- ("name", "data_in_entries"),
- ("bypass",),
- ),
- "instance_norm2d": (
- ("name", "data_in_entries"),
- ("bypass",),
- ),
- "rms_norm": (
- ("name", "data_in_entries"),
- ("bypass",),
- ),
+ "layer_norm": (("name", "data_in_entries"), ("bypass",),),
+ "group_norm": (("name", "data_in_entries"), ("bypass",),),
+ "instance_norm2d": (("name", "data_in_entries"), ("bypass",),),
+ "rms_norm": (("name", "data_in_entries"), ("bypass",),),
"grouped_query_attention": (
("name", "data_in_entries", "weight_entries"),
("bypass", "bias_entries"),
diff --git a/src/chop/passes/graph/transforms/training/modify.py b/src/chop/passes/graph/transforms/training/modify.py
index eff5583d0..bf84faeee 100644
--- a/src/chop/passes/graph/transforms/training/modify.py
+++ b/src/chop/passes/graph/transforms/training/modify.py
@@ -65,10 +65,7 @@ def attach_backward_fn(q_fn: torch.autograd.Function, mase_op: str, q_fn_cfg: di
def create_new_module(
- mase_op: str,
- original_module: nn.Module,
- config: dict,
- node_meta: dict,
+ mase_op: str, original_module: nn.Module, config: dict, node_meta: dict,
):
original_module_cls = type(original_module)
diff --git a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py
index 5bd5977b3..772206878 100644
--- a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py
+++ b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py
@@ -12,12 +12,8 @@
matches_module_pattern,
replace_node_module,
)
-from chop.nn.quantized.modules.linear import (
- LinearLogicNets,
-)
-from chop.nn.quantized.modules.conv2d import (
- Conv2DLogicNets,
-)
+from chop.nn.quantized.modules.linear import LinearLogicNets
+from chop.nn.quantized.modules.conv2d import Conv2DLogicNets
# Housekeeping -------------------------------------------------------------------------
logger = logging.getLogger(__file__)
diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py
index a6a000c79..e562a5fb1 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_bram.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py
@@ -245,8 +245,8 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
else:
base_quantizer = integer_quantizer_for_hw
- scale = 2**frac_width
- thresh = 2**width
+ scale = 2 ** frac_width
+ thresh = 2 ** width
for i in range(0, out_depth):
line_buff = ""
for j in range(0, out_size):
@@ -301,8 +301,8 @@ def emit_parameters_in_dat_hls(node, param_name, file_name):
"precision"
][1]
- scale = 2**frac_width
- thresh = 2**width
+ scale = 2 ** frac_width
+ thresh = 2 ** width
for i in range(0, out_depth):
line_buff = ""
for j in range(0, out_size):
diff --git a/src/chop/passes/graph/transforms/verilog/emit_hls.py b/src/chop/passes/graph/transforms/verilog/emit_hls.py
index 3efe2c037..ce348c43d 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_hls.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_hls.py
@@ -121,12 +121,7 @@ def _call_hls_flow(node, node_dir):
# Call Vitis HLS for synthesis
vitis_hls = os.path.abspath(
os.path.join(
- os.path.dirname(__file__),
- "..",
- "..",
- "..",
- "scripts",
- "run-vitis-hls.sh",
+ os.path.dirname(__file__), "..", "..", "..", "scripts", "run-vitis-hls.sh",
)
)
assert os.path.isfile(
diff --git a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py
index 3d8c0a16e..f0a52f75a 100644
--- a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py
+++ b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py
@@ -4,9 +4,7 @@
import torch.nn as nn
from chop.passes.graph.utils import init_project
-from chop.nn.quantized.modules.linear import (
- LinearLogicNets,
-)
+from chop.nn.quantized.modules.linear import LinearLogicNets
from .util import (
generate_lut_verilog,
diff --git a/src/chop/passes/module/analysis/report.py b/src/chop/passes/module/analysis/report.py
index 514f75ce8..c384ceafd 100644
--- a/src/chop/passes/module/analysis/report.py
+++ b/src/chop/passes/module/analysis/report.py
@@ -26,8 +26,7 @@ def get_submodule_summary(name: str, module: nn.Module, level: int = 0):
def report_trainable_parameters_analysis_pass(
- module: torch.nn.Module,
- pass_args: dict = {},
+ module: torch.nn.Module, pass_args: dict = {},
):
submodule_summary, total_params = get_submodule_summary("", module)
table = [(name, params) for _, name, params in submodule_summary]
diff --git a/src/chop/passes/utils.py b/src/chop/passes/utils.py
index 912250077..8d7ea71eb 100644
--- a/src/chop/passes/utils.py
+++ b/src/chop/passes/utils.py
@@ -23,9 +23,7 @@ def _nightly_torch_installed():
return False
-def find_missing_dependencies(
- pass_name: str,
-):
+def find_missing_dependencies(pass_name: str,):
dependencies = PassFactory._dependencies_dict.get(pass_name, None)
if dependencies is None:
@@ -40,9 +38,7 @@ def find_missing_dependencies(
def register_mase_pass(
- name: str,
- dependencies: list = [],
- requires_nightly_torch: bool = False,
+ name: str, dependencies: list = [], requires_nightly_torch: bool = False,
):
"""This decorator registers a mase pass as PassFactory class attributes which can be used globally."""
diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py
index c9b784795..1bd382983 100644
--- a/src/chop/pipelines/auto_pipeline.py
+++ b/src/chop/pipelines/auto_pipeline.py
@@ -12,11 +12,7 @@ class AutoPipeline:
The output of each pass is stored in a dictionary and can be accessed by the next pass.
"""
- def __init__(
- self,
- pass_groups=None,
- run_training: bool = False,
- ) -> None:
+ def __init__(self, pass_groups=None, run_training: bool = False,) -> None:
"""Initializes the AutoPipeline.
Args:
@@ -26,11 +22,7 @@ def __init__(
self.pass_outputs = [{}] * len(pass_groups)
def _run_pass_group(
- self,
- mg: MaseGraph,
- pass_group: list,
- pass_args: dict,
- skip_passes: list = [],
+ self, mg: MaseGraph, pass_group: list, pass_args: dict, skip_passes: list = [],
):
pass_outputs = {}
@@ -56,10 +48,7 @@ def _run_pass_group(
return mg, pass_outputs
def __call__(
- self,
- mg: MaseGraph,
- pass_args: dict,
- skip_passes: list = [],
+ self, mg: MaseGraph, pass_args: dict, skip_passes: list = [],
):
for idx, pass_group in enumerate(self.pass_groups):
@@ -70,10 +59,7 @@ def __call__(
)
mg, pass_outputs = self._run_pass_group(
- mg,
- pass_group,
- pass_args,
- skip_passes,
+ mg, pass_group, pass_args, skip_passes,
)
self.pass_outputs[idx] = pass_outputs
diff --git a/src/chop/tools/check_dependency.py b/src/chop/tools/check_dependency.py
index dbf13cd12..c71ca088e 100644
--- a/src/chop/tools/check_dependency.py
+++ b/src/chop/tools/check_dependency.py
@@ -23,9 +23,7 @@ def check_deps_tensorRT_pass(silent: bool = True):
return all(availabilities)
-def find_missing_dependencies(
- pass_name: str,
-):
+def find_missing_dependencies(pass_name: str,):
dependencies = PassFactory._dependencies_dict.get(pass_name, None)
if dependencies is None:
@@ -40,8 +38,7 @@ def find_missing_dependencies(
def check_dependencies(
- pass_name: str,
- silent: bool = True,
+ pass_name: str, silent: bool = True,
):
unavailable_deps = find_missing_dependencies(pass_name)
diff --git a/src/chop/tools/huggingface.py b/src/chop/tools/huggingface.py
index 7cf675460..efa0fb1f6 100644
--- a/src/chop/tools/huggingface.py
+++ b/src/chop/tools/huggingface.py
@@ -39,10 +39,7 @@ def get_hf_dummy_in(model):
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
dummy_input = tokenizer(
- [
- "AI may take over the world one day",
- "This is why you should learn ADLS",
- ],
+ ["AI may take over the world one day", "This is why you should learn ADLS",],
return_tensors="pt",
)
@@ -53,9 +50,7 @@ def get_hf_dummy_in(model):
def get_tokenized_dataset(
- dataset: str,
- checkpoint: str,
- return_tokenizer: bool = False,
+ dataset: str, checkpoint: str, return_tokenizer: bool = False,
):
"""
Tokenizes a dataset using the AutoTokenizer from Huggingface.
@@ -81,10 +76,7 @@ def get_tokenized_dataset(
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_function(example):
- return tokenizer(
- example["text"],
- truncation=True,
- )
+ return tokenizer(example["text"], truncation=True,)
# Tokenize
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
diff --git a/src/chop/tools/plt_wrapper/nlp/classification.py b/src/chop/tools/plt_wrapper/nlp/classification.py
index f96227cba..a69b2ac0f 100644
--- a/src/chop/tools/plt_wrapper/nlp/classification.py
+++ b/src/chop/tools/plt_wrapper/nlp/classification.py
@@ -41,9 +41,7 @@ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
)
else:
outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- labels=labels,
+ input_ids, attention_mask=attention_mask, labels=labels,
)
return outputs
diff --git a/src/chop/tools/plt_wrapper/nlp/lm.py b/src/chop/tools/plt_wrapper/nlp/lm.py
index 60739da47..a2b669e87 100644
--- a/src/chop/tools/plt_wrapper/nlp/lm.py
+++ b/src/chop/tools/plt_wrapper/nlp/lm.py
@@ -46,9 +46,7 @@ def training_step(self, batch, batch_idx):
self.log("train_loss_step", loss, prog_bar=True)
self.log(
- "train_perplexity_step",
- perplexity,
- prog_bar=True,
+ "train_perplexity_step", perplexity, prog_bar=True,
)
return loss
diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py
index 2ae59855e..d070eff92 100644
--- a/src/chop/tools/utils.py
+++ b/src/chop/tools/utils.py
@@ -92,7 +92,7 @@ def get_factors(n):
set(
functools.reduce(
list.__add__,
- ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
+ ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0),
)
)
)
@@ -184,9 +184,7 @@ def init_Conv2dLUT_weight(
# Initialize the weight based on the trained binaried network
# weight shape of the lagrange trainer [tables_count, self.kk]
input_mask = new_module.input_mask.reshape(
- -1,
- in_channels * kernel_size[0] * kernel_size[1] * k,
- 3,
+ -1, in_channels * kernel_size[0] * kernel_size[1] * k, 3,
) # [oc, k * kh * kw * ic ,3[ic,kh,kw]]
expanded_original_weight = original_weight[
np.arange(out_channels)[:, np.newaxis],
diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py
index 10ca26820..711d2760f 100644
--- a/src/mase_cocotb/interfaces/streaming.py
+++ b/src/mase_cocotb/interfaces/streaming.py
@@ -17,14 +17,7 @@ def _sign_extend(value: int, bits: int):
class StreamDriver(Driver):
- def __init__(
- self,
- clk,
- data,
- valid,
- ready,
- record_num_beats=False,
- ) -> None:
+ def __init__(self, clk, data, valid, ready, record_num_beats=False,) -> None:
super().__init__()
self.clk = clk
self.data = data
@@ -74,14 +67,7 @@ async def _driver_send(self, transaction) -> None:
class StreamMonitor(Monitor):
def __init__(
- self,
- clk,
- data,
- valid,
- ready,
- check=True,
- name=None,
- unsigned=False,
+ self, clk, data, valid, ready, check=True, name=None, unsigned=False,
):
super().__init__(clk, check=check, name=name)
self.clk = clk
@@ -165,9 +151,9 @@ def __init__(self, clk, data, valid, ready, data_width, frac_width, check=True):
def _check(self, got, exp):
if self.check:
- float_got = [x * 2**-self.frac_width for x in got]
- float_exp = [x * 2**-self.frac_width for x in exp]
- if not np.isclose(float_got, float_exp, atol=2**-self.frac_width).all():
+ float_got = [x * 2 ** -self.frac_width for x in got]
+ float_exp = [x * 2 ** -self.frac_width for x in exp]
+ if not np.isclose(float_got, float_exp, atol=2 ** -self.frac_width).all():
# raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp))
raise TestFailure(
f"\nGot int \n{got}, \nExpected int \n{exp} \nGot float \n{float_got}, \nExpected float \n{float_exp}"
diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py
index 21490555a..a74bff0d4 100644
--- a/src/mase_cocotb/runner.py
+++ b/src/mase_cocotb/runner.py
@@ -110,10 +110,7 @@ def _single_test(
verilog_sources=sources,
includes=includes,
hdl_toplevel=module,
- build_args=[
- *tool_args,
- *extra_build_args,
- ],
+ build_args=[*tool_args, *extra_build_args,],
# Do not use params in hierarchical verilation
parameters=module_params if not hierarchical else {},
build_dir=test_work_dir,
@@ -312,14 +309,10 @@ def simulate_pass(
verilog_sources=[rtl_dir / "top.sv"],
includes=[rtl_dir],
hdl_toplevel="top",
- build_args=[
- *_verilator_args(False, trace) * extra_build_args,
- ],
+ build_args=[*_verilator_args(False, trace) * extra_build_args,],
parameters=module_params,
build_dir=sim_dir,
)
runner.test(
- hdl_toplevel="top",
- test_module="test",
- results_xml="results.xml",
+ hdl_toplevel="top", test_module="test", results_xml="results.xml",
)
diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py
index 0ad58ba1b..d746f74b5 100644
--- a/src/mase_cocotb/testbench.py
+++ b/src/mase_cocotb/testbench.py
@@ -8,12 +8,7 @@ class Testbench:
__test__ = False # so pytest doesn't confuse this with a test
def __init__(
- self,
- dut,
- clk=None,
- rst=None,
- fail_on_checks=True,
- clk_period_ns=20,
+ self, dut, clk=None, rst=None, fail_on_checks=True, clk_period_ns=20,
) -> None:
self.dut = dut
self.clk = clk
diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py
index 7271389f2..25b6d83d8 100644
--- a/src/mase_cocotb/utils.py
+++ b/src/mase_cocotb/utils.py
@@ -78,8 +78,8 @@ def int_floor_quantizer(x: Tensor, width: int, frac_width: int, signed=True):
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2**width - 1
- scale = 2**frac_width
+ int_max = 2 ** width - 1
+ scale = 2 ** frac_width
return torch.clamp(torch.floor(x.mul(scale)), int_min, int_max).div(scale)
diff --git a/src/mase_cocotb/z_qlayers/tensor_cast.py b/src/mase_cocotb/z_qlayers/tensor_cast.py
index ce651bc5c..cf212cf79 100644
--- a/src/mase_cocotb/z_qlayers/tensor_cast.py
+++ b/src/mase_cocotb/z_qlayers/tensor_cast.py
@@ -46,9 +46,9 @@ def _integer_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2**width - 1
+ int_max = 2 ** width - 1
# thresh = 2 ** (width - 1)
- scale = 2**frac_width
+ scale = 2 ** frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
@@ -59,7 +59,7 @@ def _integer_quantize(
def quantize_to_int(x: Tensor, width: int, frac_width: int):
- x = (_integer_quantize(x, width, frac_width) * (2**frac_width)).int()
+ x = (_integer_quantize(x, width, frac_width) * (2 ** frac_width)).int()
return x
diff --git a/src/mase_components/activation_layers/test/fixed_elu_tb.py b/src/mase_components/activation_layers/test/fixed_elu_tb.py
index 87ff69fd6..1c7e850e4 100644
--- a/src/mase_components/activation_layers/test/fixed_elu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_elu_tb.py
@@ -79,10 +79,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
count += 1
iarr.append(i)
val = quanter(f(torch.tensor(i))) # entry in the lookup table
- lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = (
- doubletofx(
- data_width=data_width, f_width=f_width, num=val.item(), type=type
- )
+ lut[
+ doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)
+ ] = doubletofx(
+ data_width=data_width, f_width=f_width, num=val.item(), type=type
)
i += 2 ** -(f_width)
return lut
@@ -444,8 +444,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -456,7 +456,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_gelu_tb.py b/src/mase_components/activation_layers/test/fixed_gelu_tb.py
index ab7d75f4c..2a8e79787 100644
--- a/src/mase_components/activation_layers/test/fixed_gelu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_gelu_tb.py
@@ -9,9 +9,7 @@
from mase_cocotb.runner import mase_runner
-from mase_components.helper.generate_memory import (
- generate_sv_lut,
-)
+from mase_components.helper.generate_memory import generate_sv_lut
DATA_IN_0_PRECISION_1 = 8
@@ -31,13 +29,13 @@ async def cocotb_test_fixed_gelu(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -48,9 +46,9 @@ async def cocotb_test_fixed_gelu(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2**DATA_IN_0_PRECISION_1)]
+ a = [b / (2 ** DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -58,7 +56,7 @@ async def cocotb_test_fixed_gelu(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
index efb6f0045..57f3940fb 100644
--- a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
@@ -62,11 +62,11 @@ def exp(self, inputs):
# cond = torch.logical_not(torch.logical_and(inputs <= self.thresh*2**self.fracw, inputs >= -1 * self.thresh *2**self.fracw))
# out = torch.where(cond, inputs, torch.tensor(0))
# unsignedout = torch.where(out < 0, torch.tensor(out % (2**self.width)), out)
- m = torch.nn.Hardshrink(self.thresh * 2**self.fracw)(inputs.to(torch.float))
+ m = torch.nn.Hardshrink(self.thresh * 2 ** self.fracw)(inputs.to(torch.float))
mout = m.clamp(
min=-1 * 2 ** (self.outputwidth - 1), max=2 ** (self.outputwidth - 1) - 1
)
- m2 = torch.where(mout < 0, torch.tensor(mout % (2**self.outputwidth)), mout)
+ m2 = torch.where(mout < 0, torch.tensor(mout % (2 ** self.outputwidth)), mout)
return m2.to(torch.int32).tolist()
def generate_inputs(self, w, fracw):
@@ -75,7 +75,7 @@ def generate_inputs(self, w, fracw):
)
realinp = torch.randn(self.samples)
inputs = self.dquantizer(realinp)
- intinp = (inputs * 2**self.fracw).to(torch.int64)
+ intinp = (inputs * 2 ** self.fracw).to(torch.int64)
intinp.clamp(
min=-(2 ** (self.width - self.fracw - 1)),
max=2 ** (self.width - self.fracw - 1) - 1,
diff --git a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
index ceb194e35..53f79755e 100644
--- a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
@@ -56,12 +56,14 @@ def __init__(self, dut) -> None:
def exp(self, inputs):
# Run the model with the provided inputs and return the outputs
- tmp0 = 3 * 2**self.fracw
+ tmp0 = 3 * 2 ** self.fracw
tmp1 = inputs + tmp0
- tmp2 = tmp1 * (2**-3) + tmp1 * (2**-4)
+ tmp2 = tmp1 * (2 ** -3) + tmp1 * (2 ** -4)
# qtmps = self.dquantizer(tmp2)
tmp3 = tmp2 * inputs
- unsignedout = torch.where(tmp3 < 0, torch.tensor(tmp3 % (2**self.width)), tmp3)
+ unsignedout = torch.where(
+ tmp3 < 0, torch.tensor(tmp3 % (2 ** self.width)), tmp3
+ )
# return unsignedout.tolist()
return unsignedout
@@ -71,7 +73,7 @@ def generate_inputs(self, w, fracw):
)
realinp = torch.randn(self.samples)
inputs = self.dquantizer(realinp)
- intinp = (inputs * 2**self.fracw).to(torch.int64)
+ intinp = (inputs * 2 ** self.fracw).to(torch.int64)
intinp.clamp(
min=-(2 ** (self.width - self.fracw - 1)),
max=2 ** (self.width - self.fracw - 1) - 1,
diff --git a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
index d0cd8796e..a44efe0b8 100644
--- a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
@@ -28,9 +28,9 @@ def get_in_and_out(x, fn, width, frac_width):
ins = integer_quantizer(x, width=width, frac_width=frac_width)
y = fn(x)
outs = integer_quantizer(y, width=width, frac_width=frac_width)
- outs = outs * 2**frac_width
+ outs = outs * 2 ** frac_width
outs = outs.int()
- ins = ins * 2**frac_width
+ ins = ins * 2 ** frac_width
ins = ins.int()
return (ins, outs)
@@ -78,7 +78,7 @@ async def cocotb_test(dut):
logger.info(f"Reset finished")
tb.data_out_0_monitor.ready.value = 1
- inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2**-4)
+ inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2 ** -4)
tb.data_in_0_driver.append(inputs.tolist())
diff --git a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
index 570ddc328..9b9a0a766 100644
--- a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
@@ -142,8 +142,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -154,7 +154,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_relu_tb.py b/src/mase_components/activation_layers/test/fixed_relu_tb.py
index 118296c48..dbe8ac146 100644
--- a/src/mase_components/activation_layers/test/fixed_relu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_relu_tb.py
@@ -47,7 +47,7 @@ def backward(ctx, grad_output):
def quantize(x, bits, bias): # bits = 32
"""Do linear quantization to input according to a scale and number of bits"""
thresh = 2 ** (bits - 1)
- scale = 2**bias
+ scale = 2 ** bias
return my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1).div(scale)
@@ -83,7 +83,7 @@ def get_dut_parameters(self):
def get_dut_input(self, i):
inputs = self.inputs[i]
- shifted_integers = (inputs * (2**self.bias)).int()
+ shifted_integers = (inputs * (2 ** self.bias)).int()
return shifted_integers.numpy().tolist()
def get_dut_output(self, i):
@@ -92,7 +92,7 @@ def get_dut_output(self, i):
return shifted_integers
def convert_to_fixed(self, x):
- return (x * (2**self.bias)).int().numpy().tolist()
+ return (x * (2 ** self.bias)).int().numpy().tolist()
@cocotb.test()
diff --git a/src/mase_components/activation_layers/test/fixed_selu_tb.py b/src/mase_components/activation_layers/test/fixed_selu_tb.py
index 647c4fff1..459334c6c 100644
--- a/src/mase_components/activation_layers/test/fixed_selu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_selu_tb.py
@@ -25,13 +25,13 @@ async def cocotb_test_fixed_selu(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -42,9 +42,9 @@ async def cocotb_test_fixed_selu(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2**DATA_IN_0_PRECISION_1)]
+ a = [b / (2 ** DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -52,7 +52,7 @@ async def cocotb_test_fixed_selu(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
index e2a794d22..427220089 100644
--- a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
@@ -128,8 +128,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -140,7 +140,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_silu_tb.py b/src/mase_components/activation_layers/test/fixed_silu_tb.py
index 944a83e5e..19c4ac3a2 100644
--- a/src/mase_components/activation_layers/test/fixed_silu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_silu_tb.py
@@ -143,8 +143,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -155,7 +155,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
index d48214384..41e67e205 100644
--- a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
@@ -19,9 +19,7 @@
from chop.nn.quantized.functional import fixed_softermax
-from chop.nn.quantizers import (
- integer_quantizer,
-)
+from chop.nn.quantizers import integer_quantizer
class SoftermaxTB(Testbench):
@@ -66,10 +64,7 @@ def __init__(self, dut) -> None:
self.model = partial(
fixed_softermax,
dim=0,
- q_config={
- "width": self.IN_WIDTH,
- "frac_width": self.IN_FRAC_WIDTH,
- },
+ q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,},
)
# Set verbosity of driver and monitor loggers to debug
@@ -77,9 +72,7 @@ def __init__(self, dut) -> None:
# self.out_data_monitor.log.setLevel(logging.DEBUG)
def generate_inputs(self, batches):
- return torch.randn(
- (batches, self.TOTAL_DIM),
- )
+ return torch.randn((batches, self.TOTAL_DIM),)
async def run_test(self, batches, us):
await self.reset()
@@ -95,10 +88,7 @@ async def run_test(self, batches, us):
self.log.debug(f"Processing inputs: {batch}")
driver_input = fixed_preprocess_tensor(
tensor=batch,
- q_config={
- "width": self.IN_WIDTH,
- "frac_width": self.IN_FRAC_WIDTH,
- },
+ q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,},
parallelism=[self.PARALLELISM],
)
self.in_data_driver.load_driver(driver_input)
@@ -107,10 +97,7 @@ async def run_test(self, batches, us):
self.log.debug(f"Processing outputs: {exp_out}")
outs = fixed_preprocess_tensor(
tensor=exp_out,
- q_config={
- "width": self.OUT_WIDTH,
- "frac_width": self.OUT_FRAC_WIDTH,
- },
+ q_config={"width": self.OUT_WIDTH, "frac_width": self.OUT_FRAC_WIDTH,},
parallelism=[self.PARALLELISM],
)
self.out_data_monitor.load_monitor(outs)
diff --git a/src/mase_components/activation_layers/test/fixed_softermax_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_tb.py
index 9b7d0aee1..84c6aa4e5 100644
--- a/src/mase_components/activation_layers/test/fixed_softermax_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softermax_tb.py
@@ -133,9 +133,7 @@ def test_fixed_softermax_smoke():
"""
mase_runner(
trace=True,
- module_param_list=[
- get_fixed_softermax_config(),
- ],
+ module_param_list=[get_fixed_softermax_config(),],
# skip_build=True,
)
diff --git a/src/mase_components/activation_layers/test/fixed_softmax_tb.py b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
index 62f37012f..afc8876ac 100644
--- a/src/mase_components/activation_layers/test/fixed_softmax_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
@@ -128,8 +128,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -140,7 +140,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softplus_tb.py b/src/mase_components/activation_layers/test/fixed_softplus_tb.py
index 121fa5c60..f39467ee7 100644
--- a/src/mase_components/activation_layers/test/fixed_softplus_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softplus_tb.py
@@ -25,13 +25,13 @@ async def cocotb_test_fixed_softplus(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -42,9 +42,9 @@ async def cocotb_test_fixed_softplus(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2**DATA_IN_0_PRECISION_1)]
+ a = [b / (2 ** DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -52,7 +52,7 @@ async def cocotb_test_fixed_softplus(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
index b79ccdaf5..9c738602e 100644
--- a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
@@ -142,8 +142,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2**self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2**self.outputwidth)
+ m2 = (m * 2 ** self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2 ** self.outputwidth)
return m2
@@ -154,7 +154,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2**self.frac_width).to(torch.int64)
+ intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softsign_tb.py b/src/mase_components/activation_layers/test/fixed_softsign_tb.py
index 5fec6b341..c8776ac32 100644
--- a/src/mase_components/activation_layers/test/fixed_softsign_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softsign_tb.py
@@ -24,13 +24,13 @@ async def cocotb_test_fixed_softsign(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -41,9 +41,9 @@ async def cocotb_test_fixed_softsign(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2**DATA_IN_0_PRECISION_1)]
+ a = [b / (2 ** DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -51,7 +51,7 @@ async def cocotb_test_fixed_softsign(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_tanh_tb.py b/src/mase_components/activation_layers/test/fixed_tanh_tb.py
index 4e8c6d268..937fb930c 100644
--- a/src/mase_components/activation_layers/test/fixed_tanh_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_tanh_tb.py
@@ -24,13 +24,13 @@ async def cocotb_test_fixed_tanh(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -41,9 +41,9 @@ async def cocotb_test_fixed_tanh(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2**DATA_IN_0_PRECISION_1)]
+ a = [b / (2 ** DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -51,7 +51,7 @@ async def cocotb_test_fixed_tanh(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/softermax.py b/src/mase_components/activation_layers/test/softermax.py
index a1b3271f5..7c30c2986 100644
--- a/src/mase_components/activation_layers/test/softermax.py
+++ b/src/mase_components/activation_layers/test/softermax.py
@@ -53,7 +53,7 @@ def _softmax_model(l: list[int], parallelism: int, pow2=False):
for diff, vals in zip(local_max_diff, local_values_buffer):
if pow2:
- adj = [x * (2**-diff) for x in vals]
+ adj = [x * (2 ** -diff) for x in vals]
else:
adj = [x * exp(-diff) for x in vals]
norm += sum(adj)
diff --git a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
index bc74e42ee..6ebbfc131 100644
--- a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
@@ -46,7 +46,7 @@ def __init__(self, dut) -> None:
# Specify Error Threshold
self.percentage_error = 0.05 # 5%
- self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH))
+ self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH))
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -63,15 +63,15 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=10):
# TODO: Take a look at all zero case again
local_vals = torch.randint(
- 1, 2**self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM)
+ 1, 2 ** self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM)
)
local_max = torch.randint(
- 0, 2**self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1)
+ 0, 2 ** self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1)
)
logger.debug("local_vals: %s" % (local_vals))
logger.debug(
- "local_vals (float): %s" % (local_vals / (2**self.IN_VALUE_FRAC_WIDTH))
+ "local_vals (float): %s" % (local_vals / (2 ** self.IN_VALUE_FRAC_WIDTH))
)
logger.debug("local_max: %s" % (local_max))
logger.debug(
@@ -87,7 +87,7 @@ def model(self, inputs):
for batch in batched_in:
local_vals, local_max = list(zip(*batch))
local_vals = torch.tensor(list(local_vals), dtype=torch.float) / (
- 2**self.IN_VALUE_FRAC_WIDTH
+ 2 ** self.IN_VALUE_FRAC_WIDTH
)
local_max = torch.tensor(list(local_max), dtype=torch.float)
local_max = sign_extend_t(
@@ -97,7 +97,7 @@ def model(self, inputs):
global_max = local_max.max()
adj_amt = global_max - local_max.reshape(self.DEPTH, 1)
adj_values = integer_floor_quantizer(
- x=local_vals / (2**adj_amt),
+ x=local_vals / (2 ** adj_amt),
width=self.IN_VALUE_WIDTH,
frac_width=self.IN_VALUE_FRAC_WIDTH,
is_signed=False,
@@ -226,10 +226,7 @@ def in_value_cfgs(cfgs: list):
for cfg in cfgs:
for in_width in [4, 7, 10]:
new_cfgs.append(
- {
- **cfg,
- "IN_VALUE_WIDTH": in_width,
- }
+ {**cfg, "IN_VALUE_WIDTH": in_width,}
)
return new_cfgs
@@ -238,10 +235,7 @@ def in_max_cfgs(cfgs: list):
for cfg in cfgs:
for in_max in [2, 3, 4]:
new_cfgs.append(
- {
- **cfg,
- "IN_MAX_WIDTH": in_max,
- }
+ {**cfg, "IN_MAX_WIDTH": in_max,}
)
return new_cfgs
@@ -256,7 +250,5 @@ def in_max_cfgs(cfgs: list):
# cfgs = [{'TOTAL_DIM': 32, 'PARALLELISM': 4, 'IN_VALUE_WIDTH': 16, 'IN_MAX_WIDTH': 2}]
mase_runner(
- module_param_list=cfgs,
- trace=True,
- jobs=12,
+ module_param_list=cfgs, trace=True, jobs=12,
)
diff --git a/src/mase_components/activation_layers/test/softermax_local_window_tb.py b/src/mase_components/activation_layers/test/softermax_local_window_tb.py
index 10b11a569..c52b7c647 100644
--- a/src/mase_components/activation_layers/test/softermax_local_window_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_local_window_tb.py
@@ -53,7 +53,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=10):
return [
- [randint(0, 2**self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)]
+ [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)]
for _ in range(batches)
]
@@ -80,7 +80,7 @@ def model(self, inputs):
sign_ext = sign_extend_t(
torch.tensor(inputs, dtype=torch.float), bits=self.IN_WIDTH
)
- float_inputs = sign_ext / (2**self.IN_FRAC_WIDTH)
+ float_inputs = sign_ext / (2 ** self.IN_FRAC_WIDTH)
# float_inputs = torch.tensor([[-31.5, -32]])
rounded_inputs_float, rounded_inputs_uint = _fixed_signed_cast_model(
float_inputs, self.MAX_WIDTH, 0, False, "floor"
@@ -89,9 +89,9 @@ def model(self, inputs):
local_max_uint = signed_to_unsigned(local_max.int(), self.MAX_WIDTH)
difference = float_inputs - local_max
- pow2 = 2**difference
+ pow2 = 2 ** difference
res = torch.clamp(
- (pow2 * 2**self.OUT_FRAC_WIDTH).int(), 0, 2**self.OUT_WIDTH - 1
+ (pow2 * 2 ** self.OUT_FRAC_WIDTH).int(), 0, 2 ** self.OUT_WIDTH - 1
)
logger.debug("float_inputs: %s" % float_inputs)
diff --git a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
index 058fa10b0..0a4b52af5 100644
--- a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
@@ -37,7 +37,7 @@ def __init__(self, dut) -> None:
self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready)
# 0.1% bit error
- self.error_threshold_bits = ceil((2**self.IN_WIDTH) * 0.001)
+ self.error_threshold_bits = ceil((2 ** self.IN_WIDTH) * 0.001)
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -53,10 +53,10 @@ def __init__(self, dut) -> None:
def sweep_inputs(self):
negative_nums = torch.arange(
- start=2 ** (self.IN_WIDTH - 1), end=2**self.IN_WIDTH, dtype=torch.int32
+ start=2 ** (self.IN_WIDTH - 1), end=2 ** self.IN_WIDTH, dtype=torch.int32
)
zero_to_one = torch.arange(
- start=0, end=2**self.IN_FRAC_WIDTH, dtype=torch.int32 # one
+ start=0, end=2 ** self.IN_FRAC_WIDTH, dtype=torch.int32 # one
)
return torch.cat((negative_nums, zero_to_one)).tolist()
@@ -68,14 +68,14 @@ def generate_inputs(self, batches=1):
# Negative Numbers
torch.randint(
low=2 ** (self.IN_WIDTH - 1),
- high=2**self.IN_WIDTH,
+ high=2 ** self.IN_WIDTH,
size=(negative_nums,),
dtype=torch.int32,
),
# Numbers between 0 and 1
torch.randint(
low=0,
- high=2**self.IN_FRAC_WIDTH,
+ high=2 ** self.IN_FRAC_WIDTH,
size=(zero_to_one_nums,),
dtype=torch.int32,
),
@@ -85,10 +85,10 @@ def generate_inputs(self, batches=1):
def model(self, inputs):
in_t = torch.tensor(inputs)
- num = sign_extend_t(in_t, self.IN_WIDTH) / (2**self.IN_FRAC_WIDTH)
- res = 2**num
- res = (res * 2**self.OUT_FRAC_WIDTH).int()
- res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1)
+ num = sign_extend_t(in_t, self.IN_WIDTH) / (2 ** self.IN_FRAC_WIDTH)
+ res = 2 ** num
+ res = (res * 2 ** self.OUT_FRAC_WIDTH).int()
+ res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1)
return res.tolist()
async def run_test(self, batches, us):
@@ -128,7 +128,7 @@ async def sweep(dut):
tb.in_driver.load_driver(inputs)
tb.output_monitor.load_monitor(exp_out)
- ns = ((2**tb.IN_WIDTH) * 1000) // 5
+ ns = ((2 ** tb.IN_WIDTH) * 1000) // 5
logger.info("Waiting %d ns..." % ns)
await Timer(ns, "ns")
assert tb.output_monitor.exp_queue.empty()
@@ -137,14 +137,14 @@ async def sweep(dut):
recv_log = tb.output_monitor.recv_log
assert len(exp_out) == len(recv_log)
- x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2**tb.IN_FRAC_WIDTH)
- ref = 2**x
- ref *= 2**tb.OUT_FRAC_WIDTH # scale up
- ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1)
+ x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2 ** tb.IN_FRAC_WIDTH)
+ ref = 2 ** x
+ ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up
+ ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1)
- software_ref = ref / (2**tb.OUT_FRAC_WIDTH)
- software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out]
- hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log]
+ software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH)
+ software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out]
+ hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log]
data = pd.DataFrame(
{
@@ -174,10 +174,7 @@ async def sweep(dut):
),
color=alt.Color("Type"),
)
- .properties(
- width=600,
- height=220,
- )
+ .properties(width=600, height=220,)
)
error_data = data[["x", "hardware error"]]
@@ -188,10 +185,7 @@ async def sweep(dut):
x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"),
y=alt.Y("hardware error").title(f"Error"),
)
- .properties(
- width=600,
- height=100,
- )
+ .properties(width=600, height=100,)
)
(curve_fig & error_fig).save(
@@ -302,12 +296,7 @@ def test_high_width():
def test_smoke():
mase_runner(
module_param_list=[
- {
- "IN_WIDTH": 8,
- "IN_FRAC_WIDTH": 4,
- "OUT_WIDTH": 8,
- "OUT_FRAC_WIDTH": 4,
- }
+ {"IN_WIDTH": 8, "IN_FRAC_WIDTH": 4, "OUT_WIDTH": 8, "OUT_FRAC_WIDTH": 4,}
]
)
diff --git a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
index f2c70b67c..842881dc9 100644
--- a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
@@ -38,7 +38,7 @@ def __init__(self, dut) -> None:
# Specify Error Threshold
self.percentage_error = 0.05 # 5%
- self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH))
+ self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH))
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -53,17 +53,17 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, batches=100):
- return [randint(0, 2**self.IN_WIDTH - 1) for _ in range(batches)]
+ return [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(batches)]
def sweep_input(self):
- return list(range(2**self.IN_WIDTH))
+ return list(range(2 ** self.IN_WIDTH))
def model(self, inputs):
- in_t = torch.tensor(inputs) / (2**self.IN_FRAC_WIDTH)
+ in_t = torch.tensor(inputs) / (2 ** self.IN_FRAC_WIDTH)
recip = 1.0 / in_t
- res = torch.floor(recip * 2**self.OUT_FRAC_WIDTH)
+ res = torch.floor(recip * 2 ** self.OUT_FRAC_WIDTH)
res = torch.nan_to_num(res)
- res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1)
+ res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1)
res = res.int()
return res.tolist()
@@ -107,14 +107,14 @@ async def sweep(dut):
recv_log = tb.output_monitor.recv_log
assert len(exp_out) == len(recv_log)
- x = torch.tensor(inputs) / (2**tb.IN_FRAC_WIDTH)
+ x = torch.tensor(inputs) / (2 ** tb.IN_FRAC_WIDTH)
ref = 1.0 / x
- ref *= 2**tb.OUT_FRAC_WIDTH # scale up
- ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1)
+ ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up
+ ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1)
- software_ref = ref / (2**tb.OUT_FRAC_WIDTH)
- software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out]
- hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log]
+ software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH)
+ software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out]
+ hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log]
data = pd.DataFrame(
{
@@ -150,10 +150,7 @@ async def sweep(dut):
),
color=alt.Color("Type"),
)
- .properties(
- width=600,
- height=300,
- )
+ .properties(width=600, height=300,)
)
error_data = data[["x", "hardware error"]]
@@ -164,10 +161,7 @@ async def sweep(dut):
x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"),
y=alt.Y("hardware error").title(f"Error"),
)
- .properties(
- width=600,
- height=100,
- )
+ .properties(width=600, height=100,)
)
(curve_fig & error_fig).save(
@@ -222,7 +216,7 @@ def width_cfgs():
for width in range(2, 16 + 1):
frac_width = width // 2
if frac_width < 3:
- entries = 2**frac_width
+ entries = 2 ** frac_width
else:
entries = 8
cfgs.append(
@@ -260,9 +254,7 @@ def test_width_cfgs():
"OUT_FRAC_WIDTH": 4,
}
]
- mase_runner(
- module_param_list=cfgs,
- )
+ mase_runner(module_param_list=cfgs,)
def test_smoke():
diff --git a/src/mase_components/cast/test/fixed_rounding_tb.py b/src/mase_components/cast/test/fixed_rounding_tb.py
index a7f2f935c..300283856 100644
--- a/src/mase_components/cast/test/fixed_rounding_tb.py
+++ b/src/mase_components/cast/test/fixed_rounding_tb.py
@@ -44,7 +44,7 @@ def single_run(self):
def sw_cast(self, inputs):
outputs = (
integer_floor_quantizer(inputs, self.out_width, self.out_frac_width)
- * 2**self.out_frac_width
+ * 2 ** self.out_frac_width
)
# breakpoint()
return outputs
diff --git a/src/mase_components/cast/test/fixed_signed_cast_tb.py b/src/mase_components/cast/test/fixed_signed_cast_tb.py
index 2ce3bbe5f..9f5a6956d 100644
--- a/src/mase_components/cast/test/fixed_signed_cast_tb.py
+++ b/src/mase_components/cast/test/fixed_signed_cast_tb.py
@@ -18,7 +18,7 @@
def _fixed_signed_cast_model(
float_input, out_width, out_frac_width, symmetric, rounding_mode
):
- scaled_float = float_input * (2**out_frac_width)
+ scaled_float = float_input * (2 ** out_frac_width)
if rounding_mode == "floor":
out_int = my_floor(scaled_float)
elif rounding_mode == "round_nearest_half_even":
@@ -30,7 +30,7 @@ def _fixed_signed_cast_model(
-(2 ** (out_width - 1)) + 1 if symmetric else -(2 ** (out_width - 1)),
(2 ** (out_width - 1)) - 1,
)
- out_float = out_int / (2**out_frac_width)
+ out_float = out_int / (2 ** out_frac_width)
# out_uint is a non-differentiable path
out_uint = signed_to_unsigned(out_int.int(), out_width)
return out_float, out_uint
@@ -58,9 +58,9 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self):
- uints = torch.arange(2**self.IN_WIDTH)
+ uints = torch.arange(2 ** self.IN_WIDTH)
num_int = sign_extend_t(uints, self.IN_WIDTH)
- num_float = num_int / (2**self.IN_FRAC_WIDTH)
+ num_float = num_int / (2 ** self.IN_FRAC_WIDTH)
return num_int, num_float
def rounding_mode(self):
@@ -150,10 +150,7 @@ def gen_symmetric(cfg_list):
l = list()
for cfg in cfg_list:
l.extend(
- [
- {**cfg, "SYMMETRIC": 0},
- {**cfg, "SYMMETRIC": 1},
- ]
+ [{**cfg, "SYMMETRIC": 0}, {**cfg, "SYMMETRIC": 1},]
)
return l
diff --git a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
index cfeca8318..75910ed9e 100644
--- a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
+++ b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
@@ -51,7 +51,7 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self):
- return torch.arange(2**self.IN_WIDTH)
+ return torch.arange(2 ** self.IN_WIDTH)
def rounding_mode(self):
if self.ROUND_FLOOR:
@@ -64,10 +64,10 @@ def rounding_mode(self):
raise Exception("Rounding mode not recognised.")
def model(self, inputs):
- float_input = inputs / (2**self.IN_FRAC_WIDTH)
- scaled_float = float_input * (2**self.OUT_FRAC_WIDTH)
+ float_input = inputs / (2 ** self.IN_FRAC_WIDTH)
+ scaled_float = float_input * (2 ** self.OUT_FRAC_WIDTH)
rounded = torch.floor(scaled_float)
- model_out = torch.clamp(rounded, 0, (2**self.OUT_WIDTH - 1))
+ model_out = torch.clamp(rounded, 0, (2 ** self.OUT_WIDTH - 1))
return model_out
async def run_test(self):
diff --git a/src/mase_components/common/test/comparator_accumulator_tb.py b/src/mase_components/common/test/comparator_accumulator_tb.py
index fbd8af937..d05f41a09 100644
--- a/src/mase_components/common/test/comparator_accumulator_tb.py
+++ b/src/mase_components/common/test/comparator_accumulator_tb.py
@@ -33,7 +33,9 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, batches=3):
- return [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)]
+ return [
+ randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)
+ ]
def model(self, inputs):
@@ -152,7 +154,5 @@ def signed_max_min_cfgs(cfglist: list):
cfgs = signed_max_min_cfgs(cfgs)
mase_runner(
- module_param_list=cfgs,
- trace=True,
- jobs=12,
+ module_param_list=cfgs, trace=True, jobs=12,
)
diff --git a/src/mase_components/common/test/comparator_tree_tb.py b/src/mase_components/common/test/comparator_tree_tb.py
index 5637791b3..ae2a7fc61 100644
--- a/src/mase_components/common/test/comparator_tree_tb.py
+++ b/src/mase_components/common/test/comparator_tree_tb.py
@@ -34,7 +34,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=3):
return [
- [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
+ [randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
for _ in range(batches)
]
@@ -136,6 +136,5 @@ def signed_max_min_cfgs(cfglist: list):
cfgs = signed_max_min_cfgs(cfgs)
mase_runner(
- module_param_list=cfgs,
- trace=True,
+ module_param_list=cfgs, trace=True,
)
diff --git a/src/mase_components/common/test/register_slice_tb.py b/src/mase_components/common/test/register_slice_tb.py
index 638d53da7..ffcec4f3e 100644
--- a/src/mase_components/common/test/register_slice_tb.py
+++ b/src/mase_components/common/test/register_slice_tb.py
@@ -60,9 +60,7 @@ def in_out_wave(dut, name):
)
logger.debug(
"{} State: (shift_reg, buffer) = ({},{})".format(
- name,
- int(dut.shift_reg.value),
- int(dut.buffer.value),
+ name, int(dut.shift_reg.value), int(dut.buffer.value),
)
)
diff --git a/src/mase_components/common/test/single_element_repeat_tb.py b/src/mase_components/common/test/single_element_repeat_tb.py
index 247cc2b27..6d39b0c35 100644
--- a/src/mase_components/common/test/single_element_repeat_tb.py
+++ b/src/mase_components/common/test/single_element_repeat_tb.py
@@ -26,7 +26,7 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, num=10):
- return [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(num)]
+ return [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(num)]
def model(self, inputs):
exp_out = []
@@ -103,7 +103,5 @@ def generate_random_params():
]
mase_runner(
- module_param_list=cfgs,
- trace=True,
- jobs=8,
+ module_param_list=cfgs, trace=True, jobs=8,
)
diff --git a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
index 8091990e4..1481671a3 100644
--- a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
+++ b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
@@ -244,9 +244,7 @@ def sw_compute(self):
bias = \n\
{} \n\
".format(
- data,
- weight,
- bias,
+ data, weight, bias,
)
)
for i in range(self.samples):
diff --git a/src/mase_components/convolution_layers/test/convolution_tb.py b/src/mase_components/convolution_layers/test/convolution_tb.py
index 1a3789ed9..ec3c07a37 100644
--- a/src/mase_components/convolution_layers/test/convolution_tb.py
+++ b/src/mase_components/convolution_layers/test/convolution_tb.py
@@ -130,11 +130,7 @@ def get_manual_result(
# out2 = get_manual_result(x, w, b, 2,1,2,2,4,4,0,0,12,4)
# data_in_pack
- x = q2i(
- x,
- config["data_in_width"],
- config["data_in_frac_width"],
- )
+ x = q2i(x, config["data_in_width"], config["data_in_frac_width"],)
self.log.info(f"x = {x}")
# from (samples, c, h, w) to (samples*h*w*c/unroll_in_c, unroll_in_c)
@@ -144,16 +140,8 @@ def get_manual_result(
self.log.info(f"weight = {w}")
self.log.info(f"bias = {b}")
- w = q2i(
- w,
- config["weight_width"],
- config["weight_frac_width"],
- )
- b = q2i(
- b,
- config["bias_width"],
- config["bias_frac_width"],
- )
+ w = q2i(w, config["weight_width"], config["weight_frac_width"],)
+ b = q2i(b, config["bias_width"], config["bias_frac_width"],)
self.log.info(f"weight = {w}")
self.log.info(f"bias = {b}")
hw_w, hw_b = self.conv_pack(
@@ -169,11 +157,7 @@ def get_manual_result(
unroll_kernel_out=self.get_parameter("UNROLL_KERNEL_OUT"),
unroll_out_channels=self.get_parameter("UNROLL_OUT_C"),
)
- exp_out = q2i(
- out,
- config["out_width"],
- config["out_frac_width"],
- )
+ exp_out = q2i(out, config["out_width"], config["out_frac_width"],)
exp_out = (
exp_out.reshape(
-1, self.get_parameter("OUT_C"), self.get_parameter("SLIDING_NUM")
@@ -223,10 +207,7 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(
- -1,
- unroll_out_channels * unroll_kernel_out,
- )
+ w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
w_in = w_tensor.type(torch.int).tolist()
# bias_pack
bias_tensor = (
@@ -332,10 +313,7 @@ def test_fixed_linear_smoke():
Some quick tests to check if the module is working.
"""
mase_runner(
- trace=True,
- module_param_list=[
- get_fixed_conv_config(),
- ],
+ trace=True, module_param_list=[get_fixed_conv_config(),],
)
diff --git a/src/mase_components/convolution_layers/test/padding_tb.py b/src/mase_components/convolution_layers/test/padding_tb.py
index 6cb470069..d0504ef5d 100644
--- a/src/mase_components/convolution_layers/test/padding_tb.py
+++ b/src/mase_components/convolution_layers/test/padding_tb.py
@@ -96,9 +96,9 @@ def data_pack(self):
for j in range(in_channels):
for k in range(in_height):
for s in range(in_width):
- re_data_tensor[i][j][k + padding_height][s + padding_width] = (
- data_tensor[i][k][s][j]
- )
+ re_data_tensor[i][j][k + padding_height][
+ s + padding_width
+ ] = data_tensor[i][k][s][j]
return re_data_tensor
@@ -248,9 +248,7 @@ def runner():
print(extra_args)
runner = get_runner(sim)
runner.build(
- verilog_sources=verilog_sources,
- hdl_toplevel="padding",
- build_args=extra_args,
+ verilog_sources=verilog_sources, hdl_toplevel="padding", build_args=extra_args,
)
runner.test(hdl_toplevel="padding", test_module="padding_tb")
diff --git a/src/mase_components/convolution_layers/test/roller_tb.py b/src/mase_components/convolution_layers/test/roller_tb.py
index 4bf56b273..8c7a24d02 100644
--- a/src/mase_components/convolution_layers/test/roller_tb.py
+++ b/src/mase_components/convolution_layers/test/roller_tb.py
@@ -195,9 +195,7 @@ def runner():
print(extra_args)
runner = get_runner(sim)()
runner.build(
- verilog_sources=verilog_sources,
- toplevel="roller",
- extra_args=extra_args,
+ verilog_sources=verilog_sources, toplevel="roller", extra_args=extra_args,
)
runner.test(toplevel="roller", py_module="roller_tb")
diff --git a/src/mase_components/convolution_layers/test/sliding_window_tb.py b/src/mase_components/convolution_layers/test/sliding_window_tb.py
index f893d6f68..73facf247 100644
--- a/src/mase_components/convolution_layers/test/sliding_window_tb.py
+++ b/src/mase_components/convolution_layers/test/sliding_window_tb.py
@@ -119,9 +119,9 @@ def data_pack(self):
for j in range(in_channels):
for k in range(in_height):
for s in range(in_width):
- re_data_tensor[i][j][k + padding_height][s + padding_width] = (
- data_tensor[i][k][s][j]
- )
+ re_data_tensor[i][j][k + padding_height][
+ s + padding_width
+ ] = data_tensor[i][k][s][j]
return re_data_tensor
diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py
index a06529344..9596adedc 100644
--- a/src/mase_components/deps.py
+++ b/src/mase_components/deps.py
@@ -17,11 +17,7 @@
"activation_layers",
"scalar_operators/fixed",
],
- "activation_layers/fixed_gelu": [
- "common",
- "memory",
- "activation_layers",
- ],
+ "activation_layers/fixed_gelu": ["common", "memory", "activation_layers",],
"activation_layers/fixed_softsign": [
"common",
"activation_layers",
@@ -127,11 +123,7 @@
"common",
"cast",
],
- "language_models/llmint8/scatter": [
- "language_models/llmint8",
- "memory",
- "common",
- ],
+ "language_models/llmint8/scatter": ["language_models/llmint8", "memory", "common",],
# Linear
"linear_layers/fixed_linear_layer/fixed_linear": [
"cast",
diff --git a/src/mase_components/helper/generate_memory.py b/src/mase_components/helper/generate_memory.py
index e46c99604..c780d6c7f 100644
--- a/src/mase_components/helper/generate_memory.py
+++ b/src/mase_components/helper/generate_memory.py
@@ -60,10 +60,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
count += 1
iarr.append(i)
val = quanter(f(torch.tensor(i))) # entry in the lookup table
- lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = (
- doubletofx(
- data_width=data_width, f_width=f_width, num=val.item(), type=type
- )
+ lut[
+ doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)
+ ] = doubletofx(
+ data_width=data_width, f_width=f_width, num=val.item(), type=type
)
i += 2 ** -(f_width)
return lut
diff --git a/src/mase_components/hls/bfp_arith/bfp_adder.py b/src/mase_components/hls/bfp_arith/bfp_adder.py
index 5b2a401d9..0ba95741e 100644
--- a/src/mase_components/hls/bfp_arith/bfp_adder.py
+++ b/src/mase_components/hls/bfp_arith/bfp_adder.py
@@ -2,11 +2,7 @@
def bfp_adder_gen(
- writer,
- x_exp_width=16,
- x_man_width=8,
- w_exp_width=16,
- w_man_width=8,
+ writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8,
):
"""
This script generates a element-level bfp add in HLS
diff --git a/src/mase_components/hls/bfp_arith/bfp_multiplier.py b/src/mase_components/hls/bfp_arith/bfp_multiplier.py
index e3695295e..fd588ea3c 100644
--- a/src/mase_components/hls/bfp_arith/bfp_multiplier.py
+++ b/src/mase_components/hls/bfp_arith/bfp_multiplier.py
@@ -1,9 +1,5 @@
def bfp_multiplier_gen(
- writer,
- x_exp_width=16,
- x_man_width=8,
- w_exp_width=16,
- w_man_width=8,
+ writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8,
):
"""
This script generates a element-level bfp mult in HLS
diff --git a/src/mase_components/hls/elastic/buffer.py b/src/mase_components/hls/elastic/buffer.py
index d25c1fd56..a0fda6b90 100644
--- a/src/mase_components/hls/elastic/buffer.py
+++ b/src/mase_components/hls/elastic/buffer.py
@@ -2,13 +2,7 @@
def buffer_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a buffer in HLS
diff --git a/src/mase_components/hls/hls_regression.py b/src/mase_components/hls/hls_regression.py
index a02dca60a..719b20380 100755
--- a/src/mase_components/hls/hls_regression.py
+++ b/src/mase_components/hls/hls_regression.py
@@ -106,10 +106,7 @@ def main():
parser = ArgumentParser(usage=USAGE)
parser.add_argument(
- "--op",
- dest="op",
- default=None,
- help="Op name to explore",
+ "--op", dest="op", default=None, help="Op name to explore",
)
parser.add_argument(
"--dir",
diff --git a/src/mase_components/hls/int_arith/int_layernorm.py b/src/mase_components/hls/int_arith/int_layernorm.py
index 059b15d4d..9521a68c0 100644
--- a/src/mase_components/hls/int_arith/int_layernorm.py
+++ b/src/mase_components/hls/int_arith/int_layernorm.py
@@ -2,13 +2,7 @@
def int_layernorm_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a fixed-point layernorm in HLS
diff --git a/src/mase_components/hls/int_arith/int_relu.py b/src/mase_components/hls/int_arith/int_relu.py
index f934640c9..09f7462c2 100644
--- a/src/mase_components/hls/int_arith/int_relu.py
+++ b/src/mase_components/hls/int_arith/int_relu.py
@@ -2,13 +2,7 @@
def int_relu_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a fixed-point relu in HLS
diff --git a/src/mase_components/hls/int_arith/int_silu.py b/src/mase_components/hls/int_arith/int_silu.py
index c4fc32abe..c90511c47 100644
--- a/src/mase_components/hls/int_arith/int_silu.py
+++ b/src/mase_components/hls/int_arith/int_silu.py
@@ -2,13 +2,7 @@
def int_silu_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a fixed-point silu in HLS
diff --git a/src/mase_components/hls/int_arith/int_softmax.py b/src/mase_components/hls/int_arith/int_softmax.py
index c0b4ca5b1..8ad5901cf 100644
--- a/src/mase_components/hls/int_arith/int_softmax.py
+++ b/src/mase_components/hls/int_arith/int_softmax.py
@@ -2,13 +2,7 @@
def int_softmax_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a fixed-point softmax in HLS
diff --git a/src/mase_components/hls/int_arith/int_transpose.py b/src/mase_components/hls/int_arith/int_transpose.py
index 6b7fd0a15..ee9c6d2d7 100644
--- a/src/mase_components/hls/int_arith/int_transpose.py
+++ b/src/mase_components/hls/int_arith/int_transpose.py
@@ -2,13 +2,7 @@
def int_transpose_gen(
- writer,
- x_width=16,
- x_frac_width=8,
- x_row=3,
- x_col=2,
- x_row_depth=3,
- x_col_depth=2,
+ writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
):
"""
This script generates a fixed-point transpose in HLS
diff --git a/src/mase_components/hls/regression_gen/bfp_add_dse.py b/src/mase_components/hls/regression_gen/bfp_add_dse.py
index 834644b24..c283301d2 100644
--- a/src/mase_components/hls/regression_gen/bfp_add_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_add_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -157,8 +151,7 @@ def bfp_add_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "bfp_add_2"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
index d6f78f30c..9a894b143 100644
--- a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
diff --git a/src/mase_components/hls/regression_gen/bfp_mult_dse.py b/src/mase_components/hls/regression_gen/bfp_mult_dse.py
index b4b5fa58c..1510ea1f3 100644
--- a/src/mase_components/hls/regression_gen/bfp_mult_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_mult_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -157,8 +151,7 @@ def bfp_mult_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "bfp_mult_2"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
if hr is None:
continue
diff --git a/src/mase_components/hls/regression_gen/buffer_dse.py b/src/mase_components/hls/regression_gen/buffer_dse.py
index 08cc6d56e..bcc003c6d 100644
--- a/src/mase_components/hls/regression_gen/buffer_dse.py
+++ b/src/mase_components/hls/regression_gen/buffer_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -129,8 +123,7 @@ def buffer_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "buffer_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/fork_dse.py b/src/mase_components/hls/regression_gen/fork_dse.py
index f045cd8a3..c8a5b3c0b 100644
--- a/src/mase_components/hls/regression_gen/fork_dse.py
+++ b/src/mase_components/hls/regression_gen/fork_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -139,8 +133,7 @@ def fork_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "fork_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_add_dse.py b/src/mase_components/hls/regression_gen/int_add_dse.py
index d4c4d9ccc..a88247d24 100644
--- a/src/mase_components/hls/regression_gen/int_add_dse.py
+++ b/src/mase_components/hls/regression_gen/int_add_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -157,8 +151,7 @@ def int_add_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_add_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_layernorm_dse.py b/src/mase_components/hls/regression_gen/int_layernorm_dse.py
index 3e6a61aff..d0ef2afdb 100644
--- a/src/mase_components/hls/regression_gen/int_layernorm_dse.py
+++ b/src/mase_components/hls/regression_gen/int_layernorm_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -130,8 +124,7 @@ def int_layernorm_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_layernorm_1"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_linear2d_dse.py b/src/mase_components/hls/regression_gen/int_linear2d_dse.py
index 1c89861eb..d75122306 100644
--- a/src/mase_components/hls/regression_gen/int_linear2d_dse.py
+++ b/src/mase_components/hls/regression_gen/int_linear2d_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
diff --git a/src/mase_components/hls/regression_gen/int_matmul_dse.py b/src/mase_components/hls/regression_gen/int_matmul_dse.py
index 1d1f35fd2..ac09aaa5c 100644
--- a/src/mase_components/hls/regression_gen/int_matmul_dse.py
+++ b/src/mase_components/hls/regression_gen/int_matmul_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import DSE_MODES, get_tcl_buff
from hls.int_arith import int_matmul_gen
diff --git a/src/mase_components/hls/regression_gen/int_mult_dse.py b/src/mase_components/hls/regression_gen/int_mult_dse.py
index 275ffb163..f824ec99e 100644
--- a/src/mase_components/hls/regression_gen/int_mult_dse.py
+++ b/src/mase_components/hls/regression_gen/int_mult_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -157,8 +151,7 @@ def int_mult_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_mult_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_relu_dse.py b/src/mase_components/hls/regression_gen/int_relu_dse.py
index 7224b1444..788e3f593 100644
--- a/src/mase_components/hls/regression_gen/int_relu_dse.py
+++ b/src/mase_components/hls/regression_gen/int_relu_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -131,8 +125,7 @@ def int_relu_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_relu_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
index 262794af2..c559ecee8 100644
--- a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
+++ b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -151,8 +145,7 @@ def int_rmsnorm_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_rmsnorm_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_rope_dse.py b/src/mase_components/hls/regression_gen/int_rope_dse.py
index 421a37e21..bfc2bc385 100644
--- a/src/mase_components/hls/regression_gen/int_rope_dse.py
+++ b/src/mase_components/hls/regression_gen/int_rope_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -152,8 +146,7 @@ def int_rope_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_rope_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_silu_dse.py b/src/mase_components/hls/regression_gen/int_silu_dse.py
index 0d80397cd..e3ce1bd11 100644
--- a/src/mase_components/hls/regression_gen/int_silu_dse.py
+++ b/src/mase_components/hls/regression_gen/int_silu_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -131,8 +125,7 @@ def int_silu_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_silu_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_softmax_dse.py b/src/mase_components/hls/regression_gen/int_softmax_dse.py
index c10c091ef..796829278 100644
--- a/src/mase_components/hls/regression_gen/int_softmax_dse.py
+++ b/src/mase_components/hls/regression_gen/int_softmax_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -131,8 +125,7 @@ def int_softmax_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_softmax_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_transpose_dse.py b/src/mase_components/hls/regression_gen/int_transpose_dse.py
index 7b5c2b715..4563a9eca 100644
--- a/src/mase_components/hls/regression_gen/int_transpose_dse.py
+++ b/src/mase_components/hls/regression_gen/int_transpose_dse.py
@@ -1,13 +1,7 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.regression_gen.utils import (
DSE_MODES,
@@ -131,8 +125,7 @@ def int_transpose_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_transpose_0"
hr = get_hls_results(
- project=os.path.join(top, file_name),
- top=top_name,
+ project=os.path.join(top, file_name), top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/scripts/bl_bfp.py b/src/mase_components/hls/scripts/bl_bfp.py
index beff47050..29f34c04d 100644
--- a/src/mase_components/hls/scripts/bl_bfp.py
+++ b/src/mase_components/hls/scripts/bl_bfp.py
@@ -1,12 +1,6 @@
import os, sys
-sys.path.append(
- os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- )
-)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
from hls.bfp_arith import bfp_mm_gen
from hls.bfp_arith import bfp_add_gen
@@ -14,12 +8,7 @@
def get_big_little_bfp(
- HIGH_MAN_WIDTH=7,
- LOW_MAN_WIDTH=3,
- X_ROW=1,
- X_COL=4096,
- W_COL=11008,
- A_COL=32,
+ HIGH_MAN_WIDTH=7, LOW_MAN_WIDTH=3, X_ROW=1, X_COL=4096, W_COL=11008, A_COL=32,
):
W_ROW = X_COL
A_ROW = X_COL
diff --git a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
index 91552b68a..2cdb58b59 100644
--- a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
+++ b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
@@ -95,8 +95,7 @@ def runner():
extra_args.append(f"-G{k}={v}")
mase_runner(
- trace=True,
- module_param_list=[test_case.get_dut_parameters()],
+ trace=True, module_param_list=[test_case.get_dut_parameters()],
)
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
index 0aa8f5239..27e729200 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
@@ -24,31 +24,23 @@ class VerificationCase(Testbench):
def __init__(self, dut):
super().__init__(dut, dut.clk, dut.rst)
self.assign_self_params(
- [
- "IN_WIDTH",
- "IN_FRAC_WIDTH",
- "LUT_POW",
- ]
+ ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",]
)
self.input_driver = StreamDriver(
dut.clk, dut.in_data, dut.in_valid, dut.in_ready
)
self.output_monitor = StreamMonitor(
- dut.clk,
- dut.out_data,
- dut.out_valid,
- dut.out_ready,
- name="Output ISQRT",
+ dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT",
)
def generate_inputs(self, num=10000):
- maxnum = (2**self.IN_WIDTH) - 1
+ maxnum = (2 ** self.IN_WIDTH) - 1
return [random.randint(0, maxnum) for _ in range(num)], num
def model(self, data_in):
ref = []
- lut_size = 2**self.LUT_POW
+ lut_size = 2 ** self.LUT_POW
lut = make_lut(lut_size, self.IN_WIDTH)
for x in data_in:
expected = isqrt_sw2(
@@ -127,7 +119,7 @@ async def valid_backpressure(dut):
makedirs(mem_dir, exist_ok=True)
def single_cfg(width, frac_width, lut_pow, str_id):
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
lut = make_lut(lut_size, width)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, width)
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
index 42bcb4edc..84fde6666 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
@@ -22,7 +22,7 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH", "LUT_POW"])
def generate_inputs(self):
- samples = 2**self.WIDTH
+ samples = 2 ** self.WIDTH
data_x = []
msb_indices = []
for x in range(samples):
@@ -94,7 +94,6 @@ async def cocotb_test_fixed_lut_index(dut):
@pytest.mark.skip(reason="Needs to be fixed.")
def test_fixed_lut_index():
-
def full_sweep():
parameter_list = []
lut_pow = 5
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
index 8c0b70ed1..07a199810 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
@@ -24,12 +24,12 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH"])
def generate_inputs(self, lut_pow):
- samples = 2**self.WIDTH
+ samples = 2 ** self.WIDTH
int_width = 1
frac_width = self.WIDTH - 1
data_x = []
initial_guesses = []
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
lut = make_lut(lut_size, self.WIDTH)
# NOTE: since negative values are not supported by fixed formats since
# isqrt only outputs positive results we cannot test every single com-
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
index cbf9cf776..a0801d23a 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
@@ -22,7 +22,7 @@ def __init__(self, dut) -> None:
self.assign_self_params(["WIDTH", "FRAC_WIDTH"])
def generate_inputs(self):
- samples = 2**self.WIDTH
+ samples = 2 ** self.WIDTH
data_x = []
msb_indices = []
for x in range(samples):
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
index 1b54ea013..c4d00af2a 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
@@ -20,7 +20,7 @@ def __init__(self, dut) -> None:
self.assign_self_params(["WIDTH"])
def generate_inputs(self):
- samples = 2**self.WIDTH
+ samples = 2 ** self.WIDTH
return [val for val in range(0, samples)], samples
def model(self, inputs):
@@ -71,7 +71,6 @@ async def cocotb_test_fixed_range_reduction(dut):
@pytest.mark.skip(reason="Needs to be fixed.")
def test_fixed_range_reduction():
-
def full_sweep():
parameter_list = []
for width in range(1, 17):
diff --git a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
index c4723f971..a001b27bc 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
@@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int:
def float_to_int(x: float, int_width: int, frac_width: int) -> int:
integer = int(x)
x -= integer
- res = integer * (2**frac_width)
+ res = integer * (2 ** frac_width)
for i in range(1, frac_width + 1):
power = 2 ** (-i)
if power <= x:
@@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int:
def int_to_float(x: int, int_width: int, frac_width: int) -> float:
- integer = x / (2**frac_width)
- fraction = x - integer * 2**frac_width
+ integer = x / (2 ** frac_width)
+ fraction = x - integer * 2 ** frac_width
res = integer
for i in range(1, frac_width + 1):
@@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int:
res = 0
else:
res = x_red - 2 ** (width - 1)
- res = res * 2**lut_pow
+ res = res * 2 ** lut_pow
res = res / 2 ** (width - 1)
# FORMAT OUTPUT: Q(WIDTH).0
return int(res)
@@ -258,7 +258,7 @@ def test_sw_model():
def debug_single():
lut_pow = 5
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
int_width = 2
frac_width = 1
width = int_width + frac_width
diff --git a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
index 815e9ed07..5734fb891 100644
--- a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
@@ -147,10 +147,7 @@ def data_generate(self):
weight_tensor = {} \n\
data_in = {} \n\
weight_in = {} ".format(
- data_tensor,
- weight_tensor,
- data_in,
- weight_in,
+ data_tensor, weight_tensor, data_in, weight_in,
)
)
data_in.reverse()
@@ -381,8 +378,7 @@ def runner():
build_args=extra_args,
)
runner.test(
- hdl_toplevel="fixed_matmul",
- test_module="fixed_matmul_tb",
+ hdl_toplevel="fixed_matmul", test_module="fixed_matmul_tb",
)
diff --git a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
index 9f058695c..0cd5b07b9 100644
--- a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
@@ -40,8 +40,8 @@ def __init__(self, dut) -> None:
]
)
- self.X_MAX = 2**self.X_WIDTH - 1
- self.Y_MAX = 2**self.Y_WIDTH - 1
+ self.X_MAX = 2 ** self.X_WIDTH - 1
+ self.Y_MAX = 2 ** self.Y_WIDTH - 1
self.x_driver = StreamDriver(dut.clk, dut.x_data, dut.x_valid, dut.x_ready)
self.y_driver = StreamDriver(dut.clk, dut.y_data, dut.y_valid, dut.y_ready)
@@ -87,10 +87,10 @@ def model(self, X, Y):
logger.debug("Sign Extended & Scaled")
X_input = sign_extend_t(X_input, self.X_WIDTH).type(torch.float32) / (
- 2**self.X_FRAC_WIDTH
+ 2 ** self.X_FRAC_WIDTH
)
Y_input = sign_extend_t(Y_input, self.Y_WIDTH).type(torch.float32) / (
- 2**self.Y_FRAC_WIDTH
+ 2 ** self.Y_FRAC_WIDTH
)
logger.debug(X_input)
logger.debug(Y_input)
diff --git a/src/mase_components/linear_layers/matmul/test/transpose_tb.py b/src/mase_components/linear_layers/matmul/test/transpose_tb.py
index c859749ed..003382ff6 100644
--- a/src/mase_components/linear_layers/matmul/test/transpose_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/transpose_tb.py
@@ -56,11 +56,7 @@ def generate_random_params(num=3):
cfgs = list()
for _ in range(num):
cfgs.append(
- {
- "WIDTH": randint(1, 16),
- "DIM0": randint(2, 12),
- "DIM1": randint(2, 12),
- }
+ {"WIDTH": randint(1, 16), "DIM0": randint(2, 12), "DIM1": randint(2, 12),}
)
return cfgs
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
index 963ae40df..ce742183b 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
@@ -29,10 +29,7 @@ def __init__(self, dut, num) -> None:
self.log = SimLog("%s" % (type(self).__qualname__))
self.data_in_0_driver = MultiSignalStreamDriver(
- dut.clk,
- (dut.mdata_in, dut.edata_in),
- dut.data_in_valid,
- dut.data_in_ready,
+ dut.clk, (dut.mdata_in, dut.edata_in), dut.data_in_valid, dut.data_in_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -51,14 +48,10 @@ def generate_inputs(self):
for _ in range(self.num):
data = 20 * torch.rand(int(self.dut.BLOCK_SIZE))
(data_in, mdata_in, edata_in) = mxint_quantize(
- data,
- int(self.dut.IN_MAN_WIDTH),
- int(self.dut.IN_EXP_WIDTH),
+ data, int(self.dut.IN_MAN_WIDTH), int(self.dut.IN_EXP_WIDTH),
)
exp_out, mexp_out, eexp_out = mxint_quantize(
- data_in,
- int(self.dut.OUT_MAN_WIDTH),
- int(self.dut.OUT_EXP_WIDTH),
+ data_in, int(self.dut.OUT_MAN_WIDTH), int(self.dut.OUT_EXP_WIDTH),
)
breakpoint()
inputs.append((mdata_in.int().tolist(), edata_in.int().tolist()))
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
index 43e70adcf..cb1b9638f 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
@@ -40,10 +40,7 @@ def __init__(self, dut, num) -> None:
dut.data_in_0_ready,
)
self.weight_driver = MultiSignalStreamDriver(
- dut.clk,
- (dut.mweight, dut.eweight),
- dut.weight_valid,
- dut.weight_ready,
+ dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -67,9 +64,7 @@ def generate_inputs(self):
)
w = torch.rand(int(self.dut.BLOCK_SIZE))
(weight, mweight, eweight) = mxint_quantize(
- w,
- int(self.dut.WEIGHT_PRECISION_0),
- int(self.dut.WEIGHT_PRECISION_1),
+ w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1),
)
mdp_out = mdata_in @ mweight
edp_out = edata_in + eweight
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
index f31d1ab61..e01e056e1 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
@@ -48,16 +48,10 @@ def __init__(self, dut) -> None:
self.log.setLevel(logging.DEBUG)
self.a_driver = MultiSignalStreamDriver(
- dut.clk,
- (dut.ma_data, dut.ea_data),
- dut.a_valid,
- dut.a_ready,
+ dut.clk, (dut.ma_data, dut.ea_data), dut.a_valid, dut.a_ready,
)
self.b_driver = MultiSignalStreamDriver(
- dut.clk,
- (dut.mb_data, dut.eb_data),
- dut.b_valid,
- dut.b_ready,
+ dut.clk, (dut.mb_data, dut.eb_data), dut.b_valid, dut.b_ready,
)
self.output_monitor = MultiSignalStreamMonitor(
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
index 58fc472ba..06f970039 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
@@ -35,10 +35,7 @@ def __init__(self, dut, num) -> None:
dut.data_in_0_ready,
)
self.weight_driver = MultiSignalStreamDriver(
- dut.clk,
- (dut.mweight, dut.eweight),
- dut.weight_valid,
- dut.weight_ready,
+ dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -66,9 +63,7 @@ def generate_inputs(self):
w = 20 * torch.rand(int(self.dut.BLOCK_SIZE))
(weight, mweight, eweight) = mxint_quantize(
- w,
- int(self.dut.WEIGHT_PRECISION_0),
- int(self.dut.WEIGHT_PRECISION_1),
+ w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1),
)
exp_out, mexp_out, eexp_out = mxint_quantize(
data_in * weight,
diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.py b/src/mase_components/linear_layers/mxint_operators/test/test.py
index f58f382cf..85cdaa995 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/test.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/test.py
@@ -8,16 +8,8 @@
d_man_width = 12
w_man_width = 8
e_width = 4
-(data_in, mdata_in, edata_in) = mxint_quantize(
- data,
- d_man_width,
- e_width,
-)
-(weight, mweight, eweight) = mxint_quantize(
- w,
- w_man_width,
- e_width,
-)
+(data_in, mdata_in, edata_in) = mxint_quantize(data, d_man_width, e_width,)
+(weight, mweight, eweight) = mxint_quantize(w, w_man_width, e_width,)
linear = torch.nn.Linear(10, 10, bias=False)
linear.weight = torch.nn.Parameter(weight)
target_x = linear(data_in)
@@ -36,7 +28,7 @@ def hardware_quant(hardware_in, be_value, e_width, width):
exponent_bias = 2 ** (e_width - 1) - 1
# exponent
- exponent_max = 2**e_width - 1 - exponent_bias
+ exponent_max = 2 ** e_width - 1 - exponent_bias
exponent_min = -exponent_bias
exponent = (
torch.ceil(torch.log2(hardware_in.abs().max())) + be_value - exponent_bias
@@ -48,7 +40,7 @@ def hardware_quant(hardware_in, be_value, e_width, width):
breakpoint()
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- msfp_x = (2**exponent) * mantissa
+ msfp_x = (2 ** exponent) * mantissa
return msfp_x, mantissa, exponent
diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py
index 7edb9f6ed..43b3b0b87 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/utils.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py
@@ -24,7 +24,7 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int =
"""
exponent_bias = 2 ** (exponent_width - 1)
- exponent_max = 2**exponent_width - 1 - exponent_bias
+ exponent_max = 2 ** exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# exponent
@@ -34,9 +34,9 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int =
# mantissa
int_min = -(2 ** (width - 1))
int_max = 2 ** (width - 1) - 1
- mantissa = x / 2**exponent
+ mantissa = x / 2 ** exponent
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- msfp_x = (2**exponent) * mantissa
+ msfp_x = (2 ** exponent) * mantissa
return msfp_x, mantissa, exponent
diff --git a/src/mase_components/memory/test/fifo_tb.py b/src/mase_components/memory/test/fifo_tb.py
index 9dc4f54ff..5f247bede 100644
--- a/src/mase_components/memory/test/fifo_tb.py
+++ b/src/mase_components/memory/test/fifo_tb.py
@@ -27,7 +27,7 @@ def __init__(self, dut) -> None:
# self.output_monitor.log.setLevel("DEBUG")
def generate_inputs(self, num=20):
- return [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(num)]
+ return [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(num)]
@cocotb.test()
@@ -120,12 +120,7 @@ async def cocotb_test_soak(dut):
@pytest.mark.dev
def test_fifo():
mase_runner(
- module_param_list=[
- {"DEPTH": 1},
- {"DEPTH": 7},
- {"DEPTH": 8},
- {"DEPTH": 81},
- ],
+ module_param_list=[{"DEPTH": 1}, {"DEPTH": 7}, {"DEPTH": 8}, {"DEPTH": 81},],
trace=True,
)
diff --git a/src/mase_components/memory/test/repeat_circular_buffer_tb.py b/src/mase_components/memory/test/repeat_circular_buffer_tb.py
index 11d7d85b1..db6bfc7b8 100644
--- a/src/mase_components/memory/test/repeat_circular_buffer_tb.py
+++ b/src/mase_components/memory/test/repeat_circular_buffer_tb.py
@@ -30,7 +30,7 @@ def generate_inputs(self, num=10):
inputs = []
for _ in range(num):
inputs.extend(
- [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
+ [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
)
return inputs
diff --git a/src/mase_components/memory/test/unpacked_fifo_tb.py b/src/mase_components/memory/test/unpacked_fifo_tb.py
index cf0a99592..734a2edcb 100644
--- a/src/mase_components/memory/test/unpacked_fifo_tb.py
+++ b/src/mase_components/memory/test/unpacked_fifo_tb.py
@@ -26,7 +26,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=20):
return [
- [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)]
+ [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)]
for _ in range(batches)
]
diff --git a/src/mase_components/normalization_layers/process_synth_impl.py b/src/mase_components/normalization_layers/process_synth_impl.py
index c65d1b8a9..5a12ee170 100644
--- a/src/mase_components/normalization_layers/process_synth_impl.py
+++ b/src/mase_components/normalization_layers/process_synth_impl.py
@@ -103,19 +103,14 @@ def gather_data(build_dir: Path):
if __name__ == "__main__":
data = gather_data(Path("build"))
data["ns"] = data["clk_period"] - data["wns"]
- data["fmax"] = 1 / (data["ns"] * (10**-9))
+ data["fmax"] = 1 / (data["ns"] * (10 ** -9))
data["fmax_mhz"] = data["fmax"] / 1_000_000
print(data)
def plot(col):
- alt.Chart(data).mark_line().encode(
- x="width",
- y=col,
- color="norm",
- ).properties(
- width=400,
- height=200,
+ alt.Chart(data).mark_line().encode(x="width", y=col, color="norm",).properties(
+ width=400, height=200,
).save(f"{col}_plot.png", scale_factor=3)
plot("wns")
diff --git a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
index 02297dd3b..d43976845 100644
--- a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
@@ -158,7 +158,7 @@ def model(self, inputs):
-1, self.NUM_CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0
)
x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / (
- 2**self.IN_FRAC_WIDTH
+ 2 ** self.IN_FRAC_WIDTH
)
# Float Model
@@ -344,7 +344,6 @@ async def valid_backpressure(dut):
@pytest.mark.skip(reason="Needs to be fixed.")
def test_batch_norm_2d():
-
def gen_cfg(
total_dim0: int = 4,
total_dim1: int = 4,
@@ -398,8 +397,7 @@ def gen_cfg(
]
mase_runner(
- module_param_list=test_cfgs,
- trace=True,
+ module_param_list=test_cfgs, trace=True,
)
diff --git a/src/mase_components/normalization_layers/test/channel_selection_tb.py b/src/mase_components/normalization_layers/test/channel_selection_tb.py
index ceb65df8a..0fb603971 100644
--- a/src/mase_components/normalization_layers/test/channel_selection_tb.py
+++ b/src/mase_components/normalization_layers/test/channel_selection_tb.py
@@ -61,7 +61,6 @@ async def basic(dut):
@pytest.mark.skip(reason="Needs to be fixed.")
def test_channel_selection():
-
def gen_cfg(num_channels, num_blocks):
return {"NUM_CHANNELS": num_channels, "NUM_SPATIAL_BLOCKS": num_blocks}
diff --git a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
index e31203c25..cffd16822 100644
--- a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
@@ -79,7 +79,7 @@ def __init__(self, dut) -> None:
self.out_width_tup = self.OUT_WIDTH, self.OUT_FRAC_WIDTH
# Inverse Square Root LUT
- self.isqrt_lut = make_lut(2**5, 16)
+ self.isqrt_lut = make_lut(2 ** 5, 16)
self.num_groups = randint(2, 3)
self.total_channels = self.GROUP_CHANNELS * self.num_groups
@@ -141,7 +141,7 @@ def model(self, inputs):
-1, self.total_channels, self.TOTAL_DIM1, self.TOTAL_DIM0
)
x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / (
- 2**self.IN_FRAC_WIDTH
+ 2 ** self.IN_FRAC_WIDTH
)
# Float Model
@@ -239,12 +239,7 @@ def test_group_norm_2d():
makedirs(mem_dir, exist_ok=True)
def isqrt_width(
- total_dim0,
- total_dim1,
- compute_dim0,
- compute_dim1,
- group_channels,
- in_width,
+ total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width,
):
depth_dim0 = total_dim0 // compute_dim0
depth_dim1 = total_dim1 // compute_dim1
@@ -274,7 +269,7 @@ def gen_cfg(
isqrt_w = isqrt_width(
total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width
)
- lut = make_lut(2**LUT_POW, isqrt_w)
+ lut = make_lut(2 ** LUT_POW, isqrt_w)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, isqrt_w)
params = {
diff --git a/src/mase_components/normalization_layers/test/models.py b/src/mase_components/normalization_layers/test/models.py
index 0f10b228e..dfeadac98 100644
--- a/src/mase_components/normalization_layers/test/models.py
+++ b/src/mase_components/normalization_layers/test/models.py
@@ -56,11 +56,11 @@ def _fixed_group_norm_2d_model(
logger.debug("Diff:")
logger.debug(diff[0])
- squares = diff**2
+ squares = diff ** 2
logger.debug("Squares:")
logger.debug(squares[0])
- squares_int = (squares * (2**square_frac_width)).int()
- logger.debug(squares * (2**square_frac_width))
+ squares_int = (squares * (2 ** square_frac_width)).int()
+ logger.debug(squares * (2 ** square_frac_width))
sum_squares = torch.sum(squares, dim=(1, 2, 3), keepdim=True)
sum_squares = integer_floor_quantizer(
@@ -81,10 +81,12 @@ def _fixed_group_norm_2d_model(
logger.debug(f"{var[0]}")
# Clamp down variance to isqrt width
- var_clamp = torch.clamp(var, 0.0, ((2**isqrt_width) - 1) / (2**isqrt_frac_width))
+ var_clamp = torch.clamp(
+ var, 0.0, ((2 ** isqrt_width) - 1) / (2 ** isqrt_frac_width)
+ )
logger.debug("Variance Clamped:")
logger.debug(f"{var_clamp[0]}")
- var_clamp_int = (var_clamp * (2**isqrt_frac_width)).int()
+ var_clamp_int = (var_clamp * (2 ** isqrt_frac_width)).int()
# Inverse Square Root calculation
lut_pow = ceil(log2(len(isqrt_lut)))
@@ -104,7 +106,7 @@ def _fixed_group_norm_2d_model(
logger.debug("INV SQRT INT:")
logger.debug(f"{inv_sqrt_int[0]}")
- inv_sqrt = inv_sqrt_int / (2**isqrt_frac_width)
+ inv_sqrt = inv_sqrt_int / (2 ** isqrt_frac_width)
logger.debug("Inverse SQRT:")
logger.debug(f"{inv_sqrt[0]}")
diff --git a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
index 22ef8bdd1..0c998f83d 100644
--- a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
@@ -151,7 +151,7 @@ def reconstruct_tensor(self, x, width, frac_width):
x = torch.stack(matrix_list).reshape(
-1, self.CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0
)
- x = sign_extend_t(x, width).to(dtype=torch.float32) / (2**frac_width)
+ x = sign_extend_t(x, width).to(dtype=torch.float32) / (2 ** frac_width)
return x
def output_monitor_split(self, x, width, frac_width):
@@ -248,12 +248,7 @@ def test_rms_norm_2d():
makedirs(mem_dir, exist_ok=True)
def isqrt_width(
- total_dim0,
- total_dim1,
- compute_dim0,
- compute_dim1,
- group_channels,
- in_width,
+ total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width,
):
depth_dim0 = total_dim0 // compute_dim0
depth_dim1 = total_dim1 // compute_dim1
@@ -282,7 +277,7 @@ def gen_cfg(
isqrt_w = isqrt_width(
total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width
)
- lut = make_lut(2**LUT_POW, isqrt_w)
+ lut = make_lut(2 ** LUT_POW, isqrt_w)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, isqrt_w)
params = {
diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
index 6ee7a93dd..f0f5cafb1 100644
--- a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
+++ b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
@@ -24,31 +24,23 @@ class VerificationCase(Testbench):
def __init__(self, dut):
super().__init__(dut, dut.clk, dut.rst)
self.assign_self_params(
- [
- "IN_WIDTH",
- "IN_FRAC_WIDTH",
- "LUT_POW",
- ]
+ ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",]
)
self.input_driver = StreamDriver(
dut.clk, dut.in_data, dut.in_valid, dut.in_ready
)
self.output_monitor = StreamMonitor(
- dut.clk,
- dut.out_data,
- dut.out_valid,
- dut.out_ready,
- name="Output ISQRT",
+ dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT",
)
def generate_inputs(self, num=10000):
- maxnum = (2**self.IN_WIDTH) - 1
+ maxnum = (2 ** self.IN_WIDTH) - 1
return [random.randint(0, maxnum) for _ in range(num)], num
def model(self, data_in):
ref = []
- lut_size = 2**self.LUT_POW
+ lut_size = 2 ** self.LUT_POW
lut = make_lut(lut_size, self.IN_WIDTH)
for x in data_in:
expected = isqrt_sw2(
@@ -131,7 +123,7 @@ def test_fixed_isqrt():
makedirs(mem_dir, exist_ok=True)
def single_cfg(width, frac_width, lut_pow, str_id):
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
lut = make_lut(lut_size, width)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, width)
diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
index d09305ef0..5a3019f8e 100644
--- a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
+++ b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
@@ -24,12 +24,12 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH"])
def generate_inputs(self, lut_pow):
- samples = 2**self.WIDTH
+ samples = 2 ** self.WIDTH
int_width = 1
frac_width = self.WIDTH - 1
data_x = []
initial_guesses = []
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
lut = make_lut(lut_size, self.WIDTH)
# NOTE: since negative values are not supported by fixed formats since
# isqrt only outputs positive results we cannot test every single com-
diff --git a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
index f6fc798a4..901ba964d 100644
--- a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
+++ b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
@@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int:
def float_to_int(x: float, int_width: int, frac_width: int) -> int:
integer = int(x)
x -= integer
- res = integer * (2**frac_width)
+ res = integer * (2 ** frac_width)
for i in range(1, frac_width + 1):
power = 2 ** (-i)
if power <= x:
@@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int:
def int_to_float(x: int, int_width: int, frac_width: int) -> float:
- integer = x / (2**frac_width)
- fraction = x - integer * 2**frac_width
+ integer = x / (2 ** frac_width)
+ fraction = x - integer * 2 ** frac_width
res = integer
for i in range(1, frac_width + 1):
@@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int:
res = 0
else:
res = x_red - 2 ** (width - 1)
- res = res * 2**lut_pow
+ res = res * 2 ** lut_pow
res = res / 2 ** (width - 1)
# FORMAT OUTPUT: Q(WIDTH).0
return int(res)
@@ -258,7 +258,7 @@ def test_isqrt_sw_model():
def debug_single():
lut_pow = 5
- lut_size = 2**lut_pow
+ lut_size = 2 ** lut_pow
int_width = 2
frac_width = 1
width = int_width + frac_width
diff --git a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
index cc7d23428..a334a1469 100644
--- a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
+++ b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
@@ -89,13 +89,16 @@ def forward(self, x: Tensor):
out = self.o_projection(attn_output)
- return out, {
- "query": query,
- "key": key.transpose(1, 2), # Key is transposed in hardware
- "value": value,
- "heads_out": heads_out,
- "attn_output": attn_output,
- }
+ return (
+ out,
+ {
+ "query": query,
+ "key": key.transpose(1, 2), # Key is transposed in hardware
+ "value": value,
+ "heads_out": heads_out,
+ "attn_output": attn_output,
+ },
+ )
class FixedGroupedQueryAttentionTB(Testbench):
@@ -192,28 +195,16 @@ def __init__(self, dut) -> None:
if self.HAS_BIAS == 1:
self.bias_q_driver = StreamDriver(
- dut.clk,
- dut.bias_query,
- dut.bias_query_valid,
- dut.bias_query_ready,
+ dut.clk, dut.bias_query, dut.bias_query_valid, dut.bias_query_ready,
)
self.bias_k_driver = StreamDriver(
- dut.clk,
- dut.bias_key,
- dut.bias_key_valid,
- dut.bias_key_ready,
+ dut.clk, dut.bias_key, dut.bias_key_valid, dut.bias_key_ready,
)
self.bias_v_driver = StreamDriver(
- dut.clk,
- dut.bias_value,
- dut.bias_value_valid,
- dut.bias_value_ready,
+ dut.clk, dut.bias_value, dut.bias_value_valid, dut.bias_value_ready,
)
self.bias_o_driver = StreamDriver(
- dut.clk,
- dut.bias_output,
- dut.bias_output_valid,
- dut.bias_output_ready,
+ dut.clk, dut.bias_output, dut.bias_output_valid, dut.bias_output_ready,
)
self.error_threshold = 2
@@ -478,11 +469,11 @@ async def run_memory_bandwidth_test(self, us: int = 500):
num_v_weight_beats_sent = self.weight_v_driver.num_beats
num_o_weight_beats_sent = self.weight_o_driver.num_beats
- input_beats_per_sec = num_input_beats_sent / (nanosec * (10**-9))
- num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10**-9))
- num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10**-9))
- num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10**-9))
- num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10**-9))
+ input_beats_per_sec = num_input_beats_sent / (nanosec * (10 ** -9))
+ num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10 ** -9))
+ num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10 ** -9))
+ num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10 ** -9))
+ num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10 ** -9))
self.log.info("Test length (ns): %.4f" % nanosec)
@@ -600,9 +591,7 @@ def test_fixed_linear_smoke():
]
mase_runner(
- module_param_list=cfgs,
- hierarchical=True,
- template=True,
+ module_param_list=cfgs, hierarchical=True, template=True,
)
@@ -614,9 +603,7 @@ def test_parallelism_sweep():
cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par))
mase_runner(
- module_param_list=cfgs,
- hierarchical=True,
- template=True,
+ module_param_list=cfgs, hierarchical=True, template=True,
)
@@ -628,9 +615,7 @@ def test_small_parallelism():
cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par))
mase_runner(
- module_param_list=cfgs,
- hierarchical=True,
- template=True,
+ module_param_list=cfgs, hierarchical=True, template=True,
)
@@ -640,9 +625,7 @@ def test_heads_sweep():
cfgs.append(get_config(256, 256, 16, kv_heads, 16, 1))
mase_runner(
- module_param_list=cfgs,
- hierarchical=True,
- template=True,
+ module_param_list=cfgs, hierarchical=True, template=True,
)
@@ -654,9 +637,7 @@ def test_bitwidth_sweep():
)
mase_runner(
- module_param_list=cfgs,
- hierarchical=True,
- template=True,
+ module_param_list=cfgs, hierarchical=True, template=True,
)
diff --git a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
index a46032646..70df56d43 100644
--- a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
+++ b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
@@ -44,11 +44,7 @@ def __init__(self, dut) -> None:
)
self.out_monitor = StreamMonitor(
- dut.clk,
- dut.out,
- dut.out_valid,
- dut.out_ready,
- check=False,
+ dut.clk, dut.out, dut.out_valid, dut.out_ready, check=False,
)
# Model
@@ -60,8 +56,7 @@ def __init__(self, dut) -> None:
"frac_width": self.get_parameter("IN_DATA_PRECISION_1"),
}
self.model = BertSelfAttentionHeadInteger(
- config=self.config,
- q_config=self.q_config,
+ config=self.config, q_config=self.q_config,
)
# Set verbosity of driver and monitor loggers to debug
@@ -119,27 +114,21 @@ async def run_test(self):
# * Load the query driver
self.log.info(f"Processing query inputs: {inputs['query_layer']}")
query_inputs = self.preprocess_tensor(
- tensor=inputs["query_layer"],
- config=self.q_config,
- parallelism=parallelism,
+ tensor=inputs["query_layer"], config=self.q_config, parallelism=parallelism,
)
self.query_driver.load_driver(query_inputs)
# * Load the key driver
self.log.info(f"Processing key inputs: {inputs['key_layer']}")
key_inputs = self.preprocess_tensor(
- tensor=inputs["key_layer"],
- config=self.q_config,
- parallelism=parallelism,
+ tensor=inputs["key_layer"], config=self.q_config, parallelism=parallelism,
)
self.key_driver.load_driver(key_inputs)
# * Load the value driver
self.log.info(f"Processing value inputs: {inputs['value_layer']}")
value_inputs = self.preprocess_tensor(
- tensor=inputs["value_layer"],
- config=self.q_config,
- parallelism=parallelism,
+ tensor=inputs["value_layer"], config=self.q_config, parallelism=parallelism,
)
self.value_driver.load_driver(value_inputs)
@@ -198,18 +187,11 @@ def test_fixed_self_attention_head_smoke():
# * Generate exponential LUT for softmax
generate_memory.generate_sv_lut(
- "exp",
- 16,
- 3,
- 16,
- 3,
- path=Path(__file__).parents[1] / "rtl",
+ "exp", 16, 3, 16, 3, path=Path(__file__).parents[1] / "rtl",
)
mase_runner(
trace=True,
- module_param_list=[
- get_fixed_self_attention_head_config(),
- ],
+ module_param_list=[get_fixed_self_attention_head_config(),],
skip_build=False,
)
diff --git a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
index 2e19bed4b..823c65db4 100644
--- a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
@@ -108,8 +108,7 @@ def __init__(self, samples=1):
depth_in_num = int(self.in_num / self.tile_in_num)
depth_out_features = int(self.out_features / self.tile_out_features)
self.outputs = RandomSink(
- samples=samples * depth_out_features * depth_in_num,
- debug=debug,
+ samples=samples * depth_out_features * depth_in_num, debug=debug,
)
self.ref = self.sw_compute()
@@ -471,8 +470,7 @@ def runner():
)
for _ in range(1):
runner.test(
- hdl_toplevel="fixed_mlp",
- test_module="fixed_mlp_tb",
+ hdl_toplevel="fixed_mlp", test_module="fixed_mlp_tb",
)
diff --git a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
index 05cf0f2cf..b92e6dc4f 100644
--- a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
@@ -50,14 +50,14 @@ def __init__(self, samples=1):
self.pe_unroll_kernel_out = 3
self.pe_unroll_in_c = 3
self.pe_unroll_embed_dim = 8
- self.num_patch = int(self.in_y * self.in_x // (self.patch_size**2))
+ self.num_patch = int(self.in_y * self.in_x // (self.patch_size ** 2))
# self.num_classes = 10
# self.head_unroll_out_x = 5
self.samples = samples
self.pe_iter_weight = int(
- (self.patch_size**2)
+ (self.patch_size ** 2)
* self.in_c
* self.embed_dim
/ self.pe_unroll_kernel_out
@@ -247,10 +247,7 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(
- -1,
- unroll_out_channels * unroll_kernel_out,
- )
+ w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
w_in = w_tensor.type(torch.int).flip(0).tolist()
# bias_pack
bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels)
diff --git a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
index 06b921e81..532e3fb07 100644
--- a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
@@ -49,7 +49,7 @@ def __init__(self, samples=1):
self.in_x = 224
self.embed_dim = 384
self.patch_size = 16
- self.num_patch = self.in_y * self.in_x // (self.patch_size**2)
+ self.num_patch = self.in_y * self.in_x // (self.patch_size ** 2)
self.num_heads = 6
self.mlp_ratio = 2
@@ -64,7 +64,7 @@ def __init__(self, samples=1):
self.head_unroll_out_x = 1
self.pe_iter_weight = int(
- (self.patch_size**2)
+ (self.patch_size ** 2)
* self.in_c
* self.embed_dim
/ self.pe_unroll_kernel_out
@@ -1148,10 +1148,7 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(
- -1,
- unroll_out_channels * unroll_kernel_out,
- )
+ w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
w_in = w_tensor.type(torch.int).flip(0).tolist()
# bias_pack
bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels)
@@ -1514,8 +1511,7 @@ def runner():
build_args=extra_args,
)
runner.test(
- hdl_toplevel="fixed_pvt",
- test_module="fixed_pvt_tb",
+ hdl_toplevel="fixed_pvt", test_module="fixed_pvt_tb",
)
diff --git a/src/mase_components/vision_models/vit/test/hash_exp_tb.py b/src/mase_components/vision_models/vit/test/hash_exp_tb.py
index e8b107503..719ee89dd 100644
--- a/src/mase_components/vision_models/vit/test/hash_exp_tb.py
+++ b/src/mase_components/vision_models/vit/test/hash_exp_tb.py
@@ -147,9 +147,7 @@ def runner():
print(extra_args)
runner = get_runner(sim)
runner.build(
- verilog_sources=verilog_sources,
- hdl_toplevel="hash_exp",
- build_args=extra_args,
+ verilog_sources=verilog_sources, hdl_toplevel="hash_exp", build_args=extra_args,
)
runner.test(hdl_toplevel="hash_exp", test_module="hash_exp_tb")
diff --git a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
index 98ceea63f..5c0522fb9 100644
--- a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
+++ b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
@@ -45,11 +45,7 @@ def __init__(self, samples=1):
},
}
self.d_config = {
- "softmax": {
- "in_size": 1,
- "out_size": 1,
- "in_depth": 4,
- },
+ "softmax": {"in_size": 1, "out_size": 1, "in_depth": 4,},
}
in_size = self.d_config["softmax"]["in_size"]
out_size = self.d_config["softmax"]["out_size"]
diff --git a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
index 62e5b840d..8dac8afee 100644
--- a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
+++ b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
@@ -5,8 +5,8 @@
def quantize_to_int(x: Tensor, width: int, frac_width: int):
- x = _integer_quantize(x, width, frac_width) * (2**frac_width)
- x = x.int() & (2**width - 1)
+ x = _integer_quantize(x, width, frac_width) * (2 ** frac_width)
+ x = x.int() & (2 ** width - 1)
return x
@@ -25,7 +25,7 @@ def twos_complement_to_float(binary_string: str, width: int, frac_width: int):
integer_magnitude = -(2 ** (width - 1)) + integer_magnitude
# Calculate scaling factor
- scaling_factor = 2**frac_width
+ scaling_factor = 2 ** frac_width
# Calculate floating-point value
float_value = integer_magnitude / scaling_factor
@@ -79,8 +79,7 @@ def generate_table_div_software(width, out_width, out_frac_width):
class QHashSoftmax(torch.nn.Module):
def __init__(
- self,
- config,
+ self, config,
):
super(QHashSoftmax, self).__init__()
self.in_width = config["data_in_width"]
@@ -109,7 +108,7 @@ def forward(self, x, scale):
# quantize to div_width
one_over_div = _integer_quantize(exp_sum // exp, self.div_width + 1, 0)
one_over_div = torch.where(
- exp == 0, torch.tensor(2**self.div_width - 1), one_over_div
+ exp == 0, torch.tensor(2 ** self.div_width - 1), one_over_div
)
one_over_div = torch.tensor(one_over_div, dtype=int)
diff --git a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
index d7f4e8464..4863f2ca9 100644
--- a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
+++ b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
@@ -50,7 +50,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
+ self.scale = qk_scale or head_dim ** -0.5
self.q = get_quantized_cls("linear", config["q_proj"])(
dim, dim, bias=qkv_bias, config=config["q_proj"]
diff --git a/test/nn/quantized/modules/attention_head.py b/test/nn/quantized/modules/attention_head.py
index e0f52a61e..72b625f90 100644
--- a/test/nn/quantized/modules/attention_head.py
+++ b/test/nn/quantized/modules/attention_head.py
@@ -1,6 +1,4 @@
-from chop.nn.quantized.modules.attention_head import (
- BertSelfAttentionHeadInteger,
-)
+from chop.nn.quantized.modules.attention_head import BertSelfAttentionHeadInteger
from transformers import AutoConfig
import torch
diff --git a/test/nn/snn/test_ann2snn.py b/test/nn/snn/test_ann2snn.py
index 1b8bbbcfa..3a5d2b7e2 100644
--- a/test/nn/snn/test_ann2snn.py
+++ b/test/nn/snn/test_ann2snn.py
@@ -191,13 +191,7 @@ def val(net, device, data_loader, T=None):
"by": "type",
"default": {"config": {"name": None}},
"fuse": True,
- "relu": {
- "config": {
- "name": "IFNode",
- "mode": "99.9%",
- "momentum": 0.1,
- }
- },
+ "relu": {"config": {"name": "IFNode", "mode": "99.9%", "momentum": 0.1,}},
"train_data_loader": input_generator,
"device": "cpu", # "device": "cuda",
}
diff --git a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
index 9740bdd38..a7bce45cd 100644
--- a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
+++ b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
@@ -27,22 +27,10 @@ def add_common_metadata(model_cls_name: str) -> MaseGraph:
# mg.fx_graph.print_tabular()
input_ids = torch.randint(
- 0,
- config.vocab_size,
- (
- 1,
- 128,
- config.hidden_size,
- ),
- device="meta",
+ 0, config.vocab_size, (1, 128, config.hidden_size,), device="meta",
)
mg, _ = passes.add_common_metadata_analysis_pass(
- mg,
- pass_args={
- "dummy_in": {
- "input_ids": input_ids,
- },
- },
+ mg, pass_args={"dummy_in": {"input_ids": input_ids,},},
)
return mg
diff --git a/test/passes/graph/analysis/pruning/test_hook_inspect.py b/test/passes/graph/analysis/pruning/test_hook_inspect.py
index dfa7fa0fd..b34b579d3 100644
--- a/test/passes/graph/analysis/pruning/test_hook_inspect.py
+++ b/test/passes/graph/analysis/pruning/test_hook_inspect.py
@@ -111,15 +111,9 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": [
- "conv2d",
- ],
- "target_activation_nodes": [
- "conv2d",
- ],
- "weight_statistics": {
- "variance_precise": {"device": "cpu", "dims": "all"},
- },
+ "target_weight_nodes": ["conv2d",],
+ "target_activation_nodes": ["conv2d",],
+ "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
index 2b9828ace..084170c70 100644
--- a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
+++ b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
@@ -52,9 +52,7 @@ def test_statistic_profiler():
dataset_info = get_dataset_info("cifar10")
model = get_model(
- checkpoint="resnet18",
- pretrained=False,
- dataset_info=dataset_info,
+ checkpoint="resnet18", pretrained=False, dataset_info=dataset_info,
)
dummy_in = {"x": next(iter(datamodule.train_dataloader()))[0]}
@@ -68,15 +66,9 @@ def test_statistic_profiler():
pass_arg = {
"by": "type",
- "target_weight_nodes": [
- "conv2d",
- ],
- "target_activation_nodes": [
- "relu",
- ],
- "weight_statistics": {
- "variance_precise": {"device": "cpu", "dims": "all"},
- },
+ "target_weight_nodes": ["conv2d",],
+ "target_activation_nodes": ["relu",],
+ "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/prune/test_prune.py b/test/passes/graph/transforms/prune/test_prune.py
index 8bddad906..8c6de1d2a 100644
--- a/test/passes/graph/transforms/prune/test_prune.py
+++ b/test/passes/graph/transforms/prune/test_prune.py
@@ -106,15 +106,9 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": [
- "conv2d",
- ],
- "target_activation_nodes": [
- "conv2d",
- ],
- "weight_statistics": {
- "variance_precise": {"device": "cpu", "dims": "all"},
- },
+ "target_weight_nodes": ["conv2d",],
+ "target_activation_nodes": ["conv2d",],
+ "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/prune/test_prune_detach_hook.py b/test/passes/graph/transforms/prune/test_prune_detach_hook.py
index bf15caa86..f7afe9a8d 100644
--- a/test/passes/graph/transforms/prune/test_prune_detach_hook.py
+++ b/test/passes/graph/transforms/prune/test_prune_detach_hook.py
@@ -105,15 +105,9 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": [
- "conv2d",
- ],
- "target_activation_nodes": [
- "conv2d",
- ],
- "weight_statistics": {
- "variance_precise": {"device": "cpu", "dims": "all"},
- },
+ "target_weight_nodes": ["conv2d",],
+ "target_activation_nodes": ["conv2d",],
+ "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
index 89327f8f0..5ca6f6b27 100644
--- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
+++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
@@ -181,8 +181,7 @@ def test_quantize_lutnet_conv2d():
first_table = connection[0]
assert any(initialized_weight[0, :] == first_table) and any(
initialized_weight[
- input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1),
- :,
+ input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), :,
]
== first_table
)
diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py
index f1115cf56..72a0ce71d 100644
--- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py
+++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py
@@ -10,9 +10,7 @@
from pathlib import Path
sys.path.append(Path(__file__).resolve().parents[5].as_posix())
-from chop.nn.quantized.modules.linear import (
- LinearLogicNets,
-)
+from chop.nn.quantized.modules.linear import LinearLogicNets
def generate_input_tensor(batch_size, input_features, min_val, max_val):
diff --git a/test/passes/graph/transforms/training/test_training_base_pass.py b/test/passes/graph/transforms/training/test_training_base_pass.py
index 04ed2167c..e46777de0 100644
--- a/test/passes/graph/transforms/training/test_training_base_pass.py
+++ b/test/passes/graph/transforms/training/test_training_base_pass.py
@@ -17,9 +17,7 @@
verify_common_metadata_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph
-from chop.passes.graph.transforms import (
- training_base_pass,
-)
+from chop.passes.graph.transforms import training_base_pass
from chop.passes.graph.utils import deepcopy_mase_graph
from chop.tools.logger import set_logging_verbosity
@@ -172,11 +170,7 @@ def test_training_base_backward_only():
"default": {"config": {"name": None}},
"linear": {
"config": {
- "forward": {
- "bypass": True,
- "pass": "quantize",
- "name": "integer",
- },
+ "forward": {"bypass": True, "pass": "quantize", "name": "integer",},
"backward": {
"pass": "quantize",
"name": "integer",
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
index d78ac8525..68d405e5a 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
@@ -76,11 +76,7 @@ def test_emit_activation_gelu():
)
config_file = os.path.join(
- os.path.abspath(""),
- "configs",
- "tests",
- "quantize",
- "fixed.toml",
+ os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
index ea1a11928..8a59b2e44 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
@@ -77,11 +77,7 @@ def test_emit_activation_selu():
)
config_file = os.path.join(
- os.path.abspath(""),
- "configs",
- "tests",
- "quantize",
- "fixed.toml",
+ os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
index 8c7a1e107..9e60ba7a3 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
@@ -76,11 +76,7 @@ def test_emit_activation_softplus():
)
config_file = os.path.join(
- os.path.abspath(""),
- "configs",
- "tests",
- "quantize",
- "fixed.toml",
+ os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
index f50f5fce5..3de23e480 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
@@ -76,11 +76,7 @@ def test_emit_activation_softsign():
)
config_file = os.path.join(
- os.path.abspath(""),
- "configs",
- "tests",
- "quantize",
- "fixed.toml",
+ os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
index 205211ea6..116d53b36 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
@@ -75,11 +75,7 @@ def test_emit_activation_tanh():
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
config_file = os.path.join(
- os.path.abspath(""),
- "configs",
- "tests",
- "quantize",
- "fixed.toml",
+ os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
index b50fb3776..3e0f16599 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
@@ -174,10 +174,7 @@ def emit_verilog_bert(
mg, _ = bert_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg,
- pass_args={
- "max_parallelism": [max_parallelism] * 4,
- },
+ mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
)
# * Save the metadata to a file for debugging
@@ -193,11 +190,7 @@ def emit_verilog_bert(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg,
- pass_args={
- "wait_time": wait_count,
- "wait_unit": wait_unit,
- },
+ mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
index f268184ce..5447a3782 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
@@ -86,10 +86,7 @@ def emit_verilog_llama(
mg, _ = llama_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg,
- pass_args={
- "max_parallelism": [max_parallelism] * 4,
- },
+ mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
)
# * Save the metadata to a file for debugging
@@ -105,11 +102,7 @@ def emit_verilog_llama(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg,
- pass_args={
- "wait_time": wait_count,
- "wait_unit": wait_unit,
- },
+ mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
index 64625495e..51ae9dc7d 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
@@ -87,10 +87,7 @@ def emit_verilog_mistral(
mg, _ = mistral_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg,
- pass_args={
- "max_parallelism": [max_parallelism] * 4,
- },
+ mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
)
# * Save the metadata to a file for debugging
@@ -106,11 +103,7 @@ def emit_verilog_mistral(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg,
- pass_args={
- "wait_time": wait_count,
- "wait_unit": wait_unit,
- },
+ mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
index daf78fbee..acc659331 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
@@ -11,9 +11,7 @@
from mase_components.scalar_operators.fixed.test.isqrt_sw import make_lut
from mase_components.common.test.lut_tb import write_memb
from chop.passes.graph.utils import get_module_by_name
-from chop.nn.quantizers.quantizers_for_hw import (
- integer_quantizer_for_hw,
-)
+from chop.nn.quantizers.quantizers_for_hw import integer_quantizer_for_hw
# import chop.models.manual.rms_norm as rms
@@ -118,7 +116,7 @@ def add_norm_metadata_gen_lut_analysis_pass(mg, config={}):
mem_dir = Path(__file__).parent / "build" / "norm" / "mem"
makedirs(mem_dir, exist_ok=True)
- lut = make_lut(2**LUT_POW, ISQRT_WIDTH)
+ lut = make_lut(2 ** LUT_POW, ISQRT_WIDTH)
mem_path = mem_dir / f"norm_isqrt_lut.mem"
write_memb(mem_path, lut, ISQRT_WIDTH)
mem_id = 0
@@ -226,23 +224,10 @@ def test_emit_verilog_norm():
shape = [10, 4, 8, 8]
normalizations = [
- nn.BatchNorm2d(
- num_features=shape[1],
- affine=False,
- ),
- nn.LayerNorm(
- normalized_shape=shape[1:],
- elementwise_affine=False,
- ),
- nn.GroupNorm(
- num_groups=2,
- num_channels=shape[1],
- affine=False,
- ),
- nn.InstanceNorm2d(
- num_features=shape[1],
- affine=False,
- ),
+ nn.BatchNorm2d(num_features=shape[1], affine=False,),
+ nn.LayerNorm(normalized_shape=shape[1:], elementwise_affine=False,),
+ nn.GroupNorm(num_groups=2, num_channels=shape[1], affine=False,),
+ nn.InstanceNorm2d(num_features=shape[1], affine=False,),
# rms.RMSNorm(
# normalized_shape=shape[1:],
# ),
diff --git a/test/passes/onnx/analysis/test_export_fx_graph.py b/test/passes/onnx/analysis/test_export_fx_graph.py
index 395991cda..fc081f1ee 100644
--- a/test/passes/onnx/analysis/test_export_fx_graph.py
+++ b/test/passes/onnx/analysis/test_export_fx_graph.py
@@ -78,16 +78,13 @@ def test_export_fx_graph_bert():
@pytest.mark.skip
def test_export_fx_graph_mistral():
- export_fx_graph_model(
- "mistral-community/Mistral-7B-v0.2",
- )
+ export_fx_graph_model("mistral-community/Mistral-7B-v0.2",)
@pytest.mark.skip
def test_export_fx_graph_whisper():
export_fx_graph_model(
- "openai/whisper-tiny",
- skip_export=True,
+ "openai/whisper-tiny", skip_export=True,
)
diff --git a/test/self/test_optical_module.py b/test/self/test_optical_module.py
new file mode 100644
index 000000000..e56921881
--- /dev/null
+++ b/test/self/test_optical_module.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python3
+# This example converts a simple MLP model to Verilog
+import logging
+import os
+import sys
+
+import torch
+import torch.nn as nn
+
+from torch.profiler import profile, record_function, ProfilerActivity
+import torchvision.models as models
+import torchvision.transforms as transforms
+import torch.utils.data as data
+from pathlib import Path
+
+sys.path.append(Path(__file__).resolve().parents[2].as_posix())
+
+
+# from chop.passes.module.transforms import quantize_module_transform_pass
+from chop.passes.module.transforms import optical_module_transform_pass
+from chop.passes.module import report_trainable_parameters_analysis_pass
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+from train_mnist_cnn import test, train, Net, test_memory_detailed
+
+# --------------------------------------------------
+# Model specifications
+# --------------------------------------------------
+# class MLP(torch.nn.Module):
+# """
+# Toy quantized FC model for digit recognition on MNIST
+# """
+
+# def __init__(self) -> None:
+# super().__init__()
+
+# self.fc1 = nn.Linear(28 * 28, 28 * 28)
+# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
+# self.fc3 = nn.Linear(28 * 28 * 4, 10)
+
+# def forward(self, x):
+# x = torch.flatten(x, start_dim=1, end_dim=-1)
+# x = torch.nn.functional.relu(self.fc1(x))
+# # w = torch.randn((4, 28 * 28))
+# # x = torch.nn.functional.relu(nn.functional.linear(x, w))
+# x = torch.nn.functional.relu(self.fc2(x))
+# x = self.fc3(x)
+# return x
+
+
+def load_my_model(model_path, device):
+ # Load the model from the .pt file
+ loaded_model = torch.load(model_path, map_location=device)
+ # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout)
+ loaded_model.eval()
+ return loaded_model
+
+
+def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
+ pass_args = {
+ "by": "type",
+ "linear": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ "device": device,
+ }
+ },
+ "conv2d": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ "device": device,
+ }
+ },
+ }
+ onn_model, _ = optical_module_transform_pass(model, pass_args)
+ torch.save(onn_model.state_dict(), save_path)
+ return onn_model
+
+
+def test_optical_module_transform_pass(device=torch.device("cpu")):
+    model_path = "mase_output/sample_mnist_cnn.pt"
+    mnist_cnn = load_my_model(model_path, device)
+ # Sanity check and report
+ pass_args = {
+ "by": "name",
+ "fc1": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ "device": device,
+ }
+ },
+ }
+ onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args)
+ torch.save(onn_cnn, "mase_output/onn_cnn.pt")
+
+
+if __name__ == "__main__":
+ finetune = True
+ if True:
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=100,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=1,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--no-cuda",
+ action="store_true",
+ default=False,
+ help="disables CUDA training",
+ )
+ parser.add_argument(
+ "--no-mps",
+ action="store_true",
+ default=False,
+ help="disables macOS GPU training",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=True,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--gpu-id", type=int, default=1, help="Which GPU device to use [default: 0]"
+ )
+
+ args = parser.parse_args()
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+ torch.manual_seed(args.seed)
+
+ if not args.no_cuda and torch.cuda.is_available():
+ device = torch.device(f"cuda:{args.gpu_id}")
+ elif use_mps:
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
+ if use_cuda:
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+ train_kwargs.update(cuda_kwargs)
+ test_kwargs.update(cuda_kwargs)
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
+ dataset1 = datasets.MNIST(
+ "../data", train=True, download=True, transform=transform
+ )
+ dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+ test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+ cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
+ print("-------------- Testing the original cnn model -------------------")
+ test(cnn, device, test_loader)
+ _, _ = report_trainable_parameters_analysis_pass(cnn)
+
+ # onn = load_my_model("mase_output/onn_cnn.pt", device)
+ onn_model = perform_optical_module_transform_pass(cnn)
+ onn_model.to(device)
+
+ print("-------------- Testing the transformed onn model -------------------")
+ test(onn_model, device, test_loader)
+ _, _ = report_trainable_parameters_analysis_pass(onn_model)
+
+ ######### Training the onn model
+ if finetune:
+ optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
+ scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+ for epoch in range(1, args.epochs + 1):
+ train(args, onn_model, device, train_loader, optimizer, epoch)
+ test(onn_model, device, test_loader)
+ scheduler.step()
+
+ torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
+
+ print("-------------- Testing the trained onn model -------------------")
+ test(onn_model, device, test_loader)
+ _, _ = report_trainable_parameters_analysis_pass(onn_model)
+
+ # test_optical_module_transform_pass()
diff --git a/test/self/train_mnist_cnn.py b/test/self/train_mnist_cnn.py
new file mode 100644
index 000000000..a955087da
--- /dev/null
+++ b/test/self/train_mnist_cnn.py
@@ -0,0 +1,247 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
+ self.dropout1 = nn.Dropout(0.25)
+ self.dropout2 = nn.Dropout(0.5)
+ self.fc1 = nn.Linear(9216, 128)
+ self.fc2 = nn.Linear(128, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = self.dropout1(x)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ output = F.log_softmax(x, dim=1)
+ return output
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+ model.train()
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = model(data)
+ loss = F.nll_loss(output, target)
+ loss.backward()
+ optimizer.step()
+ if batch_idx % args.log_interval == 0:
+ print(
+ "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+ epoch,
+ batch_idx * len(data),
+ len(train_loader.dataset),
+ 100.0 * batch_idx / len(train_loader),
+ loss.item(),
+ )
+ )
+ if args.dry_run:
+ break
+
+
+def test(model, device, test_loader):
+ model.eval()
+ test_loss = 0
+ correct = 0
+ with torch.no_grad():
+ for data, target in test_loader:
+ data, target = data.to(device), target.to(device)
+ output = model(data)
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.argmax(
+ dim=1, keepdim=True
+ ) # get the index of the max log-probability
+ correct += pred.eq(target.view_as(pred)).sum().item()
+
+ test_loss /= len(test_loader.dataset)
+
+ print(
+ "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+ test_loss,
+ correct,
+ len(test_loader.dataset),
+ 100.0 * correct / len(test_loader.dataset),
+ )
+ )
+
+
+# Custom Function
+from torch.profiler import profile, record_function, ProfilerActivity
+
+
+def test_memory_detailed(model, device, test_loader):
+ """
+ Use PyTorch Profiler to record detailed memory usage on the first batch,
+ so you can see exactly which ops consume how much memory.
+ """
+
+ # Put the model in eval mode
+ model.eval()
+
+ # Get just 1 batch (or a few) for profiling
+ data_iter = iter(test_loader)
+ try:
+ data, target = next(data_iter)
+ except StopIteration:
+ print("test_loader is empty. Cannot profile.")
+ return
+
+ # Move to device
+ data, target = data.to(device), target.to(device)
+
+ # Now start the profiler
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=True, # to see tensor shapes
+ profile_memory=True, # track memory usage
+ ) as prof:
+ with record_function("test_first_batch"):
+ # Forward pass on just this batch
+ output = model(data)
+ # Suppose you also compute a loss for some reason
+ loss = F.nll_loss(output, target, reduction="sum")
+ # If purely inference, you might not do backward
+ # but let's illustrate:
+ loss.backward() # If you want to see backward pass memory usage
+
+ # Print the summarized table
+ # Sort by self_cuda_memory_usage to see which ops used the most GPU memory
+ print(
+ prof.key_averages(group_by_input_shape=True).table(
+ sort_by="self_cuda_memory_usage",
+ row_limit=200, # show as many rows as you need
+ )
+ )
+
+
+def main():
+ # Training settings
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=29,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--no-mps",
+ action="store_true",
+ default=False,
+ help="disables macOS GPU training",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=True,
+ help="For Saving the current Model",
+ )
+ args = parser.parse_args()
+ use_cuda = not args.no_cuda and torch.cuda.is_available()
+ use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+ torch.manual_seed(args.seed)
+
+ if use_cuda:
+ device = torch.device("cuda")
+ elif use_mps:
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
+ if use_cuda:
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+ train_kwargs.update(cuda_kwargs)
+ test_kwargs.update(cuda_kwargs)
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
+ dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
+ dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+ test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+ model = Net().to(device)
+ optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+ scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+ for epoch in range(1, args.epochs + 1):
+ train(args, model, device, train_loader, optimizer, epoch)
+ test(model, device, test_loader)
+ scheduler.step()
+
+ if args.save_model:
+ torch.save(model, "mase_output/sample_mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/tools/test_onnx_operators.py b/test/tools/test_onnx_operators.py
index c15a44d89..7393cb41f 100644
--- a/test/tools/test_onnx_operators.py
+++ b/test/tools/test_onnx_operators.py
@@ -14,58 +14,20 @@ def excepthook(exc_type, exc_value, exc_traceback):
def test_gather():
- data1 = torch.Tensor(
- [
- [1.0, 1.2, 1.9],
- [2.3, 3.4, 3.9],
- [4.5, 5.7, 5.9],
- ]
- )
+ data1 = torch.Tensor([[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9],])
- data2 = torch.Tensor(
- [
- [1.0, 1.2],
- [2.3, 3.4],
- [4.5, 5.7],
- ]
- )
+ data2 = torch.Tensor([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7],])
- indices1 = torch.Tensor(
- [
- [0, 2],
- ]
- ).to(torch.int64)
+ indices1 = torch.Tensor([[0, 2],]).to(torch.int64)
- indices2 = torch.Tensor(
- [
- [0, 1],
- [1, 2],
- ]
- ).to(torch.int64)
+ indices2 = torch.Tensor([[0, 1], [1, 2],]).to(torch.int64)
obs_out1 = onnx_gather(data1, 1, indices1)
obs_out2 = onnx_gather(data2, 0, indices2)
- exp_out1 = torch.Tensor(
- [
- [[1.0, 1.9]],
- [[2.3, 3.9]],
- [[4.5, 5.9]],
- ]
- )
+ exp_out1 = torch.Tensor([[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]],])
- exp_out2 = torch.Tensor(
- [
- [
- [1.0, 1.2],
- [2.3, 3.4],
- ],
- [
- [2.3, 3.4],
- [4.5, 5.7],
- ],
- ]
- )
+ exp_out2 = torch.Tensor([[[1.0, 1.2], [2.3, 3.4],], [[2.3, 3.4], [4.5, 5.7],],])
print(obs_out2)
print(exp_out2)
@@ -75,12 +37,7 @@ def test_gather():
def test_slice():
- data = torch.Tensor(
- [
- [1, 2, 3, 4],
- [5, 6, 7, 8],
- ]
- )
+ data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8],])
test1 = onnx_slice(
data,
From 8aa1dbddbcacef0c8e369b29bc7f3bdc124d844b Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Sun, 16 Feb 2025 22:51:29 +0000
Subject: [PATCH 13/38] .
---
scripts/mase-hls.py | 30 +++++--
scripts/stat-to-conf.py | 7 +-
setup.py | 4 +-
src/chop/actions/emit.py | 5 +-
src/chop/actions/search/search.py | 4 +-
.../quantization/manual_hf_module.py | 3 +-
src/chop/actions/simulate.py | 4 +-
src/chop/actions/train.py | 3 +-
src/chop/dataset/nerf/blender.py | 4 +-
src/chop/dataset/nlp/language_modeling.py | 10 ++-
src/chop/dataset/vision/transforms/cifar.py | 5 +-
src/chop/distributed/launcher.py | 16 +++-
src/chop/distributed/tensor/__init__.py | 18 +++-
src/chop/distributed/tensor/_dispatch.py | 11 ++-
src/chop/distributed/tensor/_redistribute.py | 17 +++-
src/chop/distributed/tensor/_sharding_prop.py | 15 ++--
src/chop/distributed/tensor/api.py | 56 ++++++++++---
.../distributed/tensor/ops/basic_strategy.py | 5 +-
.../distributed/tensor/ops/common_rules.py | 15 +++-
src/chop/distributed/tensor/ops/conv_ops.py | 23 +++--
src/chop/distributed/tensor/ops/math_ops.py | 34 ++++++--
src/chop/distributed/tensor/ops/matrix_ops.py | 10 ++-
.../distributed/tensor/ops/pointwise_ops.py | 9 +-
src/chop/distributed/tensor/ops/tensor_ops.py | 16 +++-
src/chop/distributed/tensor/ops/utils.py | 4 +-
src/chop/distributed/tensor/ops/view_ops.py | 9 +-
src/chop/ir/graph/mase_graph.py | 49 ++++++++---
src/chop/ir/graph/mase_metadata.py | 4 +-
src/chop/ir/onnx/mase_onnx_graph.py | 4 +-
src/chop/ir/onnx/utils.py | 9 +-
.../models/bert/modeling_bert_quantized.py | 8 +-
src/chop/models/bert/quant_config_bert.py | 5 +-
src/chop/models/cnv/cnv.py | 9 +-
src/chop/models/cswin/cswintransformer.py | 8 +-
src/chop/models/deit/deit_v2.py | 2 +-
src/chop/models/efficientnet/efficientnet.py | 23 +++--
src/chop/models/lfc/lfc.py | 4 +-
src/chop/models/llama/modeling_llama_llora.py | 5 +-
.../models/llama/modeling_llama_sparse.py | 5 +-
src/chop/models/mobilenet_v2/mobilenet_v2.py | 3 +-
src/chop/models/nerf/nerf_vision.py | 4 +-
src/chop/models/opt/modeling_opt.py | 2 +-
src/chop/models/opt/modeling_opt_lora.py | 2 +-
src/chop/models/opt/modeling_opt_quantized.py | 2 +-
src/chop/models/opt/modeling_opt_sparse.py | 2 +-
.../models/opt/quant_config_opt_quantized.py | 4 +-
src/chop/models/pvt/pvt.py | 9 +-
src/chop/models/pvt/pvt_v2.py | 2 +-
src/chop/models/repvgg/repvgg.py | 4 +-
src/chop/models/resnet/resnet.py | 56 ++++++++++---
src/chop/models/toy/toy.py | 15 +++-
src/chop/models/vgg_cifar/vgg_orig.py | 61 ++++++++++++--
src/chop/models/vision/snn/snn_toy.py | 4 +-
.../models/vision/snn/spikingResformer.py | 42 ++++++++--
src/chop/nn/functional/softermax.py | 2 +-
src/chop/nn/modules/gqa.py | 11 ++-
src/chop/nn/modules/lora.py | 13 ++-
src/chop/nn/modules/sparse.py | 19 ++++-
src/chop/nn/mx/activations.py | 5 +-
src/chop/nn/mx/bmm.py | 12 ++-
src/chop/nn/mx/convolution.py | 32 +++++--
src/chop/nn/mx/elemwise_ops.py | 8 +-
src/chop/nn/mx/formats.py | 8 +-
src/chop/nn/mx/linear.py | 32 +++++--
src/chop/nn/mx/matmul.py | 12 ++-
src/chop/nn/mx/mx_ops.py | 9 +-
src/chop/nn/mx/quantize.py | 4 +-
src/chop/nn/mx/simd_ops.py | 2 +-
src/chop/nn/mx/transpose_convolution.py | 32 +++++--
src/chop/nn/optical/modules/morr_conv2d.py | 6 +-
src/chop/nn/optical/modules/morr_linear.py | 6 +-
src/chop/nn/optical/utils/initializer.py | 2 +-
src/chop/nn/optical/utils/mrr_op.py | 2 +-
src/chop/nn/optical/utils/quantize.py | 6 +-
src/chop/nn/quantized/functional/gelu.py | 4 +-
src/chop/nn/quantized/functional/linear.py | 83 +++++++++++++++----
src/chop/nn/quantized/functional/matmul.py | 8 +-
src/chop/nn/quantized/functional/relu.py | 4 +-
src/chop/nn/quantized/functional/selu.py | 4 +-
src/chop/nn/quantized/functional/softplus.py | 4 +-
src/chop/nn/quantized/functional/softsign.py | 4 +-
src/chop/nn/quantized/functional/tanh.py | 4 +-
.../nn/quantized/modules/attention_head.py | 5 +-
src/chop/nn/quantized/modules/conv1d.py | 12 ++-
src/chop/nn/quantized/modules/conv2d.py | 39 ++++++---
src/chop/nn/quantized/modules/gelu.py | 8 +-
src/chop/nn/quantized/modules/gqa.py | 5 +-
src/chop/nn/quantized/modules/linear.py | 18 ++--
src/chop/nn/quantized/modules/relu.py | 8 +-
src/chop/nn/quantized/modules/selu.py | 8 +-
src/chop/nn/quantized/modules/silu.py | 8 +-
src/chop/nn/quantized/modules/softplus.py | 8 +-
src/chop/nn/quantized/modules/softsign.py | 8 +-
src/chop/nn/quantized/modules/tanh.py | 8 +-
.../nn/quantizers/LUTNet/BaseInitializer.py | 5 +-
src/chop/nn/quantizers/LUTNet/BaseTrainer.py | 2 +-
src/chop/nn/quantizers/block_fp.py | 17 ++--
src/chop/nn/quantizers/block_log.py | 8 +-
src/chop/nn/quantizers/block_minifloat.py | 9 +-
src/chop/nn/quantizers/integer.py | 8 +-
src/chop/nn/quantizers/log.py | 14 ++--
src/chop/nn/quantizers/minifloat.py | 34 +++++---
src/chop/nn/quantizers/mxint_hardware.py | 6 +-
src/chop/nn/quantizers/quantizers_for_hw.py | 14 ++--
src/chop/nn/quantizers/ternary.py | 3 +-
src/chop/nn/quantizers/utils.py | 32 +++++--
src/chop/nn/snn/auto_cuda/generator.py | 5 +-
.../nn/snn/modules/spiking_self_attention.py | 8 +-
src/chop/passes/graph/__init__.py | 6 +-
.../add_metadata/add_common_metadata.py | 12 ++-
.../add_metadata/add_hardware_metadata.py | 3 +-
.../add_metadata/common_metadata_layers.py | 57 ++++++++++---
.../add_metadata/hardware_metadata_layers.py | 37 +++++++--
.../add_metadata/software_metadata_layers.py | 16 +++-
.../autosharding/alpa_cost_modelling.py | 10 ++-
.../analysis/autosharding/autosharding.py | 31 +++++--
.../graph/analysis/autosharding/megatron.py | 4 +-
.../autosharding/strategies/basic_strategy.py | 5 +-
.../autosharding/strategies/common.py | 6 +-
.../autosharding/strategies/matrix_ops.py | 38 +++++++--
.../autosharding/strategies/pointwise_ops.py | 4 +-
.../autosharding/strategies/view_ops.py | 7 +-
.../flop_estimator/calculator/calc_modules.py | 2 +-
.../passes/graph/analysis/plot/plot_graph.py | 5 +-
.../graph/interface/tensorrt/quantize.py | 1 -
.../passes/graph/transforms/dse/run_dse.py | 2 +-
src/chop/passes/graph/transforms/lora.py | 6 +-
.../graph/transforms/onnxrt/quantize.py | 4 +-
.../transforms/pruning/pruning_methods.py | 28 ++++++-
.../quant_parsers/parse_quant_config.py | 82 ++++++++++++++----
.../graph/transforms/training/modify.py | 5 +-
.../graph/transforms/verilog/emit_bram.py | 8 +-
.../graph/transforms/verilog/emit_hls.py | 7 +-
src/chop/passes/module/analysis/report.py | 3 +-
src/chop/passes/utils.py | 8 +-
src/chop/pipelines/auto_pipeline.py | 22 ++++-
src/chop/tools/check_dependency.py | 7 +-
src/chop/tools/huggingface.py | 14 +++-
.../tools/plt_wrapper/nlp/classification.py | 4 +-
src/chop/tools/plt_wrapper/nlp/lm.py | 4 +-
src/chop/tools/utils.py | 6 +-
src/mase_cocotb/interfaces/streaming.py | 24 ++++--
src/mase_cocotb/runner.py | 13 ++-
src/mase_cocotb/testbench.py | 7 +-
src/mase_cocotb/utils.py | 4 +-
src/mase_cocotb/z_qlayers/tensor_cast.py | 6 +-
.../activation_layers/test/fixed_elu_tb.py | 14 ++--
.../activation_layers/test/fixed_gelu_tb.py | 10 +--
.../test/fixed_hardshrink_tb.py | 6 +-
.../test/fixed_hardswish_tb.py | 10 +--
.../test/fixed_leaky_relu_tb.py | 6 +-
.../test/fixed_logsigmoid_tb.py | 6 +-
.../activation_layers/test/fixed_relu_tb.py | 6 +-
.../activation_layers/test/fixed_selu_tb.py | 10 +--
.../test/fixed_sigmoid_tb.py | 6 +-
.../activation_layers/test/fixed_silu_tb.py | 6 +-
.../test/fixed_softermax_1d_tb.py | 19 ++++-
.../test/fixed_softermax_tb.py | 4 +-
.../test/fixed_softmax_tb.py | 6 +-
.../test/fixed_softplus_tb.py | 10 +--
.../test/fixed_softshrink_tb.py | 6 +-
.../test/fixed_softsign_tb.py | 10 +--
.../activation_layers/test/fixed_tanh_tb.py | 10 +--
.../activation_layers/test/softermax.py | 2 +-
.../test/softermax_global_norm_tb.py | 26 ++++--
.../test/softermax_local_window_tb.py | 8 +-
.../test/softermax_lpw_pow2_tb.py | 51 +++++++-----
.../test/softermax_lpw_reciprocal_tb.py | 40 +++++----
.../cast/test/fixed_rounding_tb.py | 2 +-
.../cast/test/fixed_signed_cast_tb.py | 13 +--
.../cast/test/fixed_unsigned_cast_tb.py | 8 +-
.../common/test/comparator_accumulator_tb.py | 8 +-
.../common/test/comparator_tree_tb.py | 5 +-
.../common/test/register_slice_tb.py | 4 +-
.../common/test/single_element_repeat_tb.py | 6 +-
...binary_activation_binary_convolution_tb.py | 4 +-
.../convolution_layers/test/convolution_tb.py | 34 ++++++--
.../convolution_layers/test/padding_tb.py | 10 ++-
.../convolution_layers/test/roller_tb.py | 4 +-
.../test/sliding_window_tb.py | 6 +-
src/mase_components/deps.py | 12 ++-
src/mase_components/helper/generate_memory.py | 8 +-
.../hls/bfp_arith/bfp_adder.py | 6 +-
.../hls/bfp_arith/bfp_multiplier.py | 6 +-
src/mase_components/hls/elastic/buffer.py | 8 +-
src/mase_components/hls/hls_regression.py | 5 +-
.../hls/int_arith/int_layernorm.py | 8 +-
src/mase_components/hls/int_arith/int_relu.py | 8 +-
src/mase_components/hls/int_arith/int_silu.py | 8 +-
.../hls/int_arith/int_softmax.py | 8 +-
.../hls/int_arith/int_transpose.py | 8 +-
.../hls/regression_gen/bfp_add_dse.py | 11 ++-
.../hls/regression_gen/bfp_linear2d_dse.py | 8 +-
.../hls/regression_gen/bfp_mult_dse.py | 11 ++-
.../hls/regression_gen/buffer_dse.py | 11 ++-
.../hls/regression_gen/fork_dse.py | 11 ++-
.../hls/regression_gen/int_add_dse.py | 11 ++-
.../hls/regression_gen/int_layernorm_dse.py | 11 ++-
.../hls/regression_gen/int_linear2d_dse.py | 8 +-
.../hls/regression_gen/int_matmul_dse.py | 8 +-
.../hls/regression_gen/int_mult_dse.py | 11 ++-
.../hls/regression_gen/int_relu_dse.py | 11 ++-
.../hls/regression_gen/int_rmsnorm_dse.py | 11 ++-
.../hls/regression_gen/int_rope_dse.py | 11 ++-
.../hls/regression_gen/int_silu_dse.py | 11 ++-
.../hls/regression_gen/int_softmax_dse.py | 11 ++-
.../hls/regression_gen/int_transpose_dse.py | 11 ++-
src/mase_components/hls/scripts/bl_bfp.py | 15 +++-
...y_activation_binary_adder_tree_layer_tb.py | 3 +-
.../fixed_operators/test/fixed_isqrt_tb.py | 18 ++--
.../test/fixed_lut_index_tb.py | 2 +-
.../fixed_operators/test/fixed_nr_stage_tb.py | 4 +-
.../test/fixed_range_augmentation_tb.py | 2 +-
.../test/fixed_range_reduction_tb.py | 2 +-
.../fixed_operators/test/isqrt_sw.py | 10 +--
.../matmul/test/fixed_matmul_tb.py | 8 +-
.../matmul/test/simple_matmul_tb.py | 8 +-
.../linear_layers/matmul/test/transpose_tb.py | 6 +-
.../mxint_operators/test/mxint_cast_tb.py | 13 ++-
.../test/mxint_dot_product_tb.py | 9 +-
.../mxint_operators/test/mxint_matmul_tb.py | 10 ++-
.../test/mxint_vector_mult_tb.py | 9 +-
.../mxint_operators/test/test.py | 16 +++-
.../mxint_operators/test/utils.py | 6 +-
src/mase_components/memory/test/fifo_tb.py | 9 +-
.../memory/test/repeat_circular_buffer_tb.py | 2 +-
.../memory/test/unpacked_fifo_tb.py | 2 +-
.../process_synth_impl.py | 11 ++-
.../test/batch_norm_2d_tb.py | 5 +-
.../test/group_norm_2d_tb.py | 13 ++-
.../normalization_layers/test/models.py | 14 ++--
.../test/rms_norm_2d_tb.py | 11 ++-
.../fixed/test/fixed_isqrt_tb.py | 18 ++--
.../fixed/test/fixed_nr_stage_tb.py | 4 +-
.../scalar_operators/fixed/test/isqrt_sw.py | 10 +--
...ixed_grouped_query_attention_wrapper_tb.py | 50 +++++++----
.../test/fixed_self_attention_head_tb.py | 32 +++++--
.../vision_models/vit/test/fixed_mlp_tb.py | 6 +-
.../vit/test/fixed_patch_embed_tb.py | 9 +-
.../vision_models/vit/test/fixed_pvt_tb.py | 12 ++-
.../vision_models/vit/test/hash_exp_tb.py | 4 +-
.../vision_models/vit/test/hash_softmax_tb.py | 6 +-
.../vit/test/helpers/ha_softmax.py | 11 +--
.../vit/test/helpers/pvt_quant.py | 2 +-
test/nn/snn/test_ann2snn.py | 8 +-
.../add_metadata/test_add_common_metadata.py | 16 +++-
.../analysis/pruning/test_hook_inspect.py | 12 ++-
.../test_statistic_profiler.py | 16 +++-
.../graph/transforms/prune/test_prune.py | 12 ++-
.../prune/test_prune_detach_hook.py | 12 ++-
.../quantize/test_quantize_lutnet_conv2d.py | 3 +-
.../training/test_training_base_pass.py | 6 +-
.../verilog/test_emit_activation_gelu.py | 6 +-
.../verilog/test_emit_activation_selu.py | 6 +-
.../verilog/test_emit_activation_softplus.py | 6 +-
.../verilog/test_emit_activation_softsign.py | 6 +-
.../verilog/test_emit_activation_tanh.py | 6 +-
.../verilog/test_emit_verilog_bert.py | 11 ++-
.../verilog/test_emit_verilog_llama.py | 11 ++-
.../verilog/test_emit_verilog_mistral.py | 11 ++-
.../verilog/test_emit_verilog_norm.py | 23 +++--
.../onnx/analysis/test_export_fx_graph.py | 7 +-
test/tools/test_onnx_operators.py | 57 +++++++++++--
263 files changed, 2273 insertions(+), 792 deletions(-)
diff --git a/scripts/mase-hls.py b/scripts/mase-hls.py
index 925ec2300..f2097bcd1 100755
--- a/scripts/mase-hls.py
+++ b/scripts/mase-hls.py
@@ -44,17 +44,33 @@ def build(self):
def quick_test(self):
shutil.copy(
- os.path.join(self.root, "test", "test_in.mlir",),
- os.path.join(self.root, "test", "test.mlir",),
+ os.path.join(
+ self.root,
+ "test",
+ "test_in.mlir",
+ ),
+ os.path.join(
+ self.root,
+ "test",
+ "test.mlir",
+ ),
)
result = False
cmd = [
"mase-opt",
"--preprocess-func=func-name=relu",
"--canonicalize",
- os.path.join(self.root, "test", "test.mlir",),
+ os.path.join(
+ self.root,
+ "test",
+ "test.mlir",
+ ),
"-o",
- os.path.join(self.root, "test", "test1.mlir",),
+ os.path.join(
+ self.root,
+ "test",
+ "test1.mlir",
+ ),
]
result |= self.execute(cmd, log_output=True, cwd=self.root)
@@ -67,7 +83,11 @@ def quick_test(self):
# "test",
# "test.cpp",
# ),
- os.path.join(self.root, "test", "test1.mlir",),
+ os.path.join(
+ self.root,
+ "test",
+ "test1.mlir",
+ ),
"--debug",
]
result |= self.execute(cmd, log_output=True, cwd=self.root)
diff --git a/scripts/stat-to-conf.py b/scripts/stat-to-conf.py
index 182e66d64..01e0d0cbc 100755
--- a/scripts/stat-to-conf.py
+++ b/scripts/stat-to-conf.py
@@ -41,7 +41,12 @@
}
-def set_stat(entry_name: str, mean=None, median=None, max=None,) -> dict[str, Any]:
+def set_stat(
+ entry_name: str,
+ mean=None,
+ median=None,
+ max=None,
+) -> dict[str, Any]:
"""Return a dictionary containing the format of the stats required to use
ternary quantiser. If statistics are not specified, "NA" will be set as the value,
this interally is being interpreted as None when the .toml is loaded"""
diff --git a/setup.py b/setup.py
index e0e91be1d..27d865de1 100644
--- a/setup.py
+++ b/setup.py
@@ -105,7 +105,9 @@ def get_system():
author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk",
license_files=("LICENSE",),
python_requires=">=3.11.9",
- package_dir={"": "src",},
+ package_dir={
+ "": "src",
+ },
packages=find_packages("src"),
install_requires=requirements,
)
diff --git a/src/chop/actions/emit.py b/src/chop/actions/emit.py
index 9318b4476..7e326bc7b 100644
--- a/src/chop/actions/emit.py
+++ b/src/chop/actions/emit.py
@@ -41,7 +41,10 @@ def emit(
data_module.prepare_data()
data_module.setup()
dummy_in = get_dummy_input(
- model_info=model_info, data_module=data_module, task=task, device="cpu",
+ model_info=model_info,
+ data_module=data_module,
+ task=task,
+ device="cpu",
)
mg, _ = add_common_metadata_analysis_pass(
mg, {"dummy_in": dummy_in, "add_value": False}
diff --git a/src/chop/actions/search/search.py b/src/chop/actions/search/search.py
index 48c2d21d7..fe82d3574 100644
--- a/src/chop/actions/search/search.py
+++ b/src/chop/actions/search/search.py
@@ -15,7 +15,9 @@
logger = logging.getLogger(__name__)
-def parse_search_config(search_config: dict,):
+def parse_search_config(
+ search_config: dict,
+):
"""
Parse search config from a dict or a toml file and do sanity check. The search config must consist of two parts: strategy and search_space.
diff --git a/src/chop/actions/search/search_space/quantization/manual_hf_module.py b/src/chop/actions/search/search_space/quantization/manual_hf_module.py
index a9c9f1bfb..094330da1 100644
--- a/src/chop/actions/search/search_space/quantization/manual_hf_module.py
+++ b/src/chop/actions/search/search_space/quantization/manual_hf_module.py
@@ -61,7 +61,8 @@ def rebuild_model(self, sampled_config: dict, is_eval_mode: bool):
with init_empty_weights():
model = self.model_cls(config)
device_map = infer_auto_device_map(
- model, no_split_module_classes=model._no_split_modules,
+ model,
+ no_split_module_classes=model._no_split_modules,
)
model = load_checkpoint_and_dispatch(
model, checkpoint=self.model_name, device_map=device_map
diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py
index 8a0e93486..e56a512d8 100644
--- a/src/chop/actions/simulate.py
+++ b/src/chop/actions/simulate.py
@@ -68,7 +68,9 @@ def simulate(
else:
raise ValueError(f"Unrecognized simulator: {simulator}")
- includes = [project_dir / "hardware" / "rtl",] + [
+ includes = [
+ project_dir / "hardware" / "rtl",
+ ] + [
Path(mase_components.__file__).parent / module / "rtl"
for module in get_modules()
]
diff --git a/src/chop/actions/train.py b/src/chop/actions/train.py
index 93524e756..5daa9b49a 100644
--- a/src/chop/actions/train.py
+++ b/src/chop/actions/train.py
@@ -109,7 +109,8 @@ def train(
trainer = pl.Trainer(**plt_trainer_args)
trainer.fit(
- pl_model, datamodule=data_module,
+ pl_model,
+ datamodule=data_module,
)
# Save the trained model along with relevant metadata in the training_ckpts folder.
diff --git a/src/chop/dataset/nerf/blender.py b/src/chop/dataset/nerf/blender.py
index 748dacca7..53865d621 100644
--- a/src/chop/dataset/nerf/blender.py
+++ b/src/chop/dataset/nerf/blender.py
@@ -34,7 +34,9 @@ def _download_lego_dataset(path: Path) -> None:
# Unzip the file
subprocess.run(
- f"unzip {folder_path.as_posix()} -d {path.as_posix()}", shell=True, check=True,
+ f"unzip {folder_path.as_posix()} -d {path.as_posix()}",
+ shell=True,
+ check=True,
)
diff --git a/src/chop/dataset/nlp/language_modeling.py b/src/chop/dataset/nlp/language_modeling.py
index 8d17b9466..ba2bdee55 100644
--- a/src/chop/dataset/nlp/language_modeling.py
+++ b/src/chop/dataset/nlp/language_modeling.py
@@ -281,7 +281,10 @@ def _tokenize(text, tokenizer, max_length):
prompt_len = prompt_tokenized.ne(tokenizer.pad_token_id).sum().item()
target_tokenized[:prompt_len] = ignore_id
- return dict(input_ids=input_ids, labels=target_tokenized,)
+ return dict(
+ input_ids=input_ids,
+ labels=target_tokenized,
+ )
def prepare_data(self):
dataset_dict = self._download_dataset()
@@ -313,7 +316,10 @@ def setup(self):
dataset_dict = self._download_dataset()
dataset_dict = dataset_dict["train"].train_test_split(test_size=0.1, seed=42)
dataset_dict = hf_datasets.DatasetDict(
- {"train": dataset_dict["train"], "validation": dataset_dict["test"],}
+ {
+ "train": dataset_dict["train"],
+ "validation": dataset_dict["test"],
+ }
)
dataset_dict = dataset_dict.map(
function=partial(
diff --git a/src/chop/dataset/vision/transforms/cifar.py b/src/chop/dataset/vision/transforms/cifar.py
index 84e7e1287..ef44d3721 100644
--- a/src/chop/dataset/vision/transforms/cifar.py
+++ b/src/chop/dataset/vision/transforms/cifar.py
@@ -35,7 +35,10 @@
def _get_cifar_default_transform(train: bool, mean: tuple[float], std: tuple[float]):
if train:
- transform = create_transform(**DEFAULT_CIFAR_PREPROCESS_ARGS, is_training=True,)
+ transform = create_transform(
+ **DEFAULT_CIFAR_PREPROCESS_ARGS,
+ is_training=True,
+ )
transform.transforms[0] = tv_transforms.RandomCrop(
DEFAULT_CIFAR_PREPROCESS_ARGS["input_size"], padding=4
)
diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py
index 21337be8c..712bb026f 100644
--- a/src/chop/distributed/launcher.py
+++ b/src/chop/distributed/launcher.py
@@ -33,7 +33,10 @@ def distributed_average_timing(fn, repeat, args):
times = []
for itr in range(repeat):
rlog(
- logger, dist.get_rank(), f"Running teration {itr}", "debug",
+ logger,
+ dist.get_rank(),
+ f"Running teration {itr}",
+ "debug",
)
dist.barrier(async_op=True)
start = time()
@@ -42,7 +45,10 @@ def distributed_average_timing(fn, repeat, args):
end = time()
times.append(end - start)
rlog(
- logger, dist.get_rank(), f"Time taken: {end - start}s", "debug",
+ logger,
+ dist.get_rank(),
+ f"Time taken: {end - start}s",
+ "debug",
)
return result, sum(times[2:]) / len(times[2:])
@@ -132,7 +138,11 @@ def device_fn(
distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()])
for in_tensor in inputs
]
- _, time_taken = distributed_average_timing(fn=model, repeat=10, args=inputs,)
+ _, time_taken = distributed_average_timing(
+ fn=model,
+ repeat=10,
+ args=inputs,
+ )
rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info")
dist.destroy_process_group()
diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py
index b9aa9c272..826ed05cc 100644
--- a/src/chop/distributed/tensor/__init__.py
+++ b/src/chop/distributed/tensor/__init__.py
@@ -34,7 +34,11 @@
def _dtensor_init_helper(
- init_op, size: torch.Size, device_mesh=None, placements=None, **kwargs,
+ init_op,
+ size: torch.Size,
+ device_mesh=None,
+ placements=None,
+ **kwargs,
) -> DTensor:
from torch.distributed.tensor.placement_types import _DTensorSpec, TensorMeta
@@ -78,10 +82,18 @@ def _dtensor_init_helper(
spec = _DTensorSpec(
device_mesh,
tuple(placements),
- tensor_meta=TensorMeta(size, torch_stride, local_tensor.dtype,),
+ tensor_meta=TensorMeta(
+ size,
+ torch_stride,
+ local_tensor.dtype,
+ ),
)
- return DTensor(local_tensor, spec, requires_grad=kwargs["requires_grad"],)
+ return DTensor(
+ local_tensor,
+ spec,
+ requires_grad=kwargs["requires_grad"],
+ )
def ones(
diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py
index dcc74852f..f5e08f714 100644
--- a/src/chop/distributed/tensor/_dispatch.py
+++ b/src/chop/distributed/tensor/_dispatch.py
@@ -42,7 +42,9 @@
def decompose_handler(
- op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object],
+ op_call: torch._ops.OpOverload,
+ args: Tuple[object, ...],
+ kwargs: Dict[str, object],
) -> object:
"""
Decomposes a op to core ATen op, this handler is mostly here
@@ -56,7 +58,9 @@ def decompose_handler(
def is_same_size_handler(
- op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object],
+ op_call: torch._ops.OpOverload,
+ args: Tuple[object, ...],
+ kwargs: Dict[str, object],
) -> bool:
lhs = cast(torch.Tensor, args[0])
rhs = cast(torch.Tensor, args[1])
@@ -251,7 +255,8 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor:
@staticmethod
def redistribute_local_args(
- op_info: OpInfo, suggested_input_schema: OpSchema,
+ op_info: OpInfo,
+ suggested_input_schema: OpSchema,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py
index e65766eb7..e04495126 100644
--- a/src/chop/distributed/tensor/_redistribute.py
+++ b/src/chop/distributed/tensor/_redistribute.py
@@ -48,7 +48,8 @@ def _replicate_then_shard(val: _TransformInfo) -> int:
@lru_cache(maxsize=None)
def _gen_transform_infos(
- src_spec: _DTensorSpec, dst_spec: _DTensorSpec,
+ src_spec: _DTensorSpec,
+ dst_spec: _DTensorSpec,
) -> List[_TransformInfo]:
"""
Generate the transform infos from the source placements to the target placements.
@@ -87,7 +88,9 @@ def _gen_transform_infos(
# calculate and save the logical shape for this sharding
mesh_dim_size = device_mesh.size(mesh_dim=i)
local_shard_size, _ = src._local_shard_size_on_dim(
- current_logical_shape[src.dim], mesh_dim_size, my_coordinate[i],
+ current_logical_shape[src.dim],
+ mesh_dim_size,
+ my_coordinate[i],
)
new_logical_shape = list(current_logical_shape)
new_logical_shape[src.dim] = local_shard_size
@@ -285,7 +288,11 @@ def forward( # type: ignore[override]
output = input._local_tensor
target_spec = current_spec
- return dtensor.DTensor(output, target_spec, requires_grad=input.requires_grad,)
+ return dtensor.DTensor(
+ output,
+ target_spec,
+ requires_grad=input.requires_grad,
+ )
@staticmethod
def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
@@ -320,7 +327,9 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
),
)
output_dtensor = dtensor.DTensor(
- output, spec, requires_grad=grad_output.requires_grad,
+ output,
+ spec,
+ requires_grad=grad_output.requires_grad,
)
return (
diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py
index 8e1e26db1..7351b7895 100644
--- a/src/chop/distributed/tensor/_sharding_prop.py
+++ b/src/chop/distributed/tensor/_sharding_prop.py
@@ -48,7 +48,8 @@ class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: Dict[
- OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType],
+ OpOverload,
+ Callable[[DeviceMesh, OpSchema], StrategyType],
] = {}
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
@@ -343,8 +344,10 @@ def spec_to_strategy(spec: object) -> object:
expected_input_spec = selected_strategies[idx].input_spec(
tensor_or_list_tensor_arg_idx
)
- expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta(
- arg_spec.tensor_meta
+ expected_input_spec = (
+ expected_input_spec.shallow_copy_with_tensor_meta(
+ arg_spec.tensor_meta
+ )
)
if arg_spec.placements != expected_input_spec.placements:
needs_redistribute = True
@@ -360,8 +363,10 @@ def spec_to_strategy(spec: object) -> object:
expected_input_spec = selected_strategies[0].input_spec(
tensor_or_list_tensor_arg_idx
)
- expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta(
- arg.tensor_meta
+ expected_input_spec = (
+ expected_input_spec.shallow_copy_with_tensor_meta(
+ arg.tensor_meta
+ )
)
if arg.placements != expected_input_spec.placements:
needs_redistribute = True
diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py
index a27c848f4..dcef33b24 100644
--- a/src/chop/distributed/tensor/api.py
+++ b/src/chop/distributed/tensor/api.py
@@ -64,7 +64,9 @@
class _ToTorchTensor(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
- ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]],
+ ctx,
+ input: "DTensor",
+ grad_placements: Optional[Sequence[Placement]],
):
ctx.dtensor_spec = input._spec
ctx.grad_placements = grad_placements
@@ -98,7 +100,11 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
)
return (
- DTensor(grad_output, grad_spec, requires_grad=grad_output.requires_grad,),
+ DTensor(
+ grad_output,
+ grad_spec,
+ requires_grad=grad_output.requires_grad,
+ ),
None,
)
@@ -154,7 +160,11 @@ def forward( # type: ignore[override]
dist_spec = _DTensorSpec(
device_mesh,
placements,
- tensor_meta=TensorMeta(tensor_shape, tensor_stride, input.dtype,),
+ tensor_meta=TensorMeta(
+ tensor_shape,
+ tensor_stride,
+ input.dtype,
+ ),
)
# We want a fresh Tensor object that shares memory with the input tensor
@@ -207,7 +217,11 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
@staticmethod
@torch._disable_dynamo
def __new__(
- cls, local_tensor: torch.Tensor, spec: _DTensorSpec, *, requires_grad: bool,
+ cls,
+ local_tensor: torch.Tensor,
+ spec: _DTensorSpec,
+ *,
+ requires_grad: bool,
) -> "DTensor":
"""
Construct a DTensor from a local tensor, device mesh, and placement and
@@ -263,12 +277,20 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta(
- shape=outer_size, stride=outer_stride, dtype=spec.tensor_meta.dtype,
+ shape=outer_size,
+ stride=outer_stride,
+ dtype=spec.tensor_meta.dtype,
)
unflatten_spec = _DTensorSpec(
- spec.mesh, spec.placements, tensor_meta=unflatten_tensor_meta,
+ spec.mesh,
+ spec.placements,
+ tensor_meta=unflatten_tensor_meta,
+ )
+ return DTensor(
+ local_tensor,
+ unflatten_spec,
+ requires_grad=requires_grad,
)
- return DTensor(local_tensor, unflatten_spec, requires_grad=requires_grad,)
def __coerce_tangent_metadata__(self):
if not any(isinstance(p, Partial) for p in self.placements):
@@ -281,7 +303,8 @@ def __coerce_tangent_metadata__(self):
def __coerce_same_metadata_as_tangent__(self, flatten_spec):
(spec, _) = flatten_spec # Result of tensor_flatten()
return self.redistribute(
- device_mesh=self.device_mesh, placements=spec.placements,
+ device_mesh=self.device_mesh,
+ placements=spec.placements,
)
@classmethod
@@ -289,7 +312,11 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- return DTensor._op_dispatcher.dispatch(func, args, kwargs or {},)
+ return DTensor._op_dispatcher.dispatch(
+ func,
+ args,
+ kwargs or {},
+ )
@staticmethod
def from_local(
@@ -367,7 +394,12 @@ def from_local(
# created should flow back the gradients to the local_tensor, so we call an autograd
# function to construct the dist tensor instead.
return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func
- local_tensor, device_mesh, tuple(placements), run_check, shape, stride,
+ local_tensor,
+ device_mesh,
+ tuple(placements),
+ run_check,
+ shape,
+ stride,
)
def to_local(
@@ -667,7 +699,9 @@ def distribute_tensor(
mesh=device_mesh,
placements=placements,
tensor_meta=TensorMeta(
- shape=tensor.size(), stride=tensor.stride(), dtype=tensor.dtype,
+ shape=tensor.size(),
+ stride=tensor.stride(),
+ dtype=tensor.dtype,
),
)
return DTensor(
diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py
index 3a9e81ae2..70ba54416 100644
--- a/src/chop/distributed/tensor/ops/basic_strategy.py
+++ b/src/chop/distributed/tensor/ops/basic_strategy.py
@@ -84,7 +84,10 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
def gen_einsum_strategies(
- equation: str, mesh: DeviceMesh, *, linearity: bool = False,
+ equation: str,
+ mesh: DeviceMesh,
+ *,
+ linearity: bool = False,
) -> OpStrategy:
"""
Generate a strategy list for the ops that follow einsum style notation.
diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py
index 1324f7806..e0e9a6e34 100644
--- a/src/chop/distributed/tensor/ops/common_rules.py
+++ b/src/chop/distributed/tensor/ops/common_rules.py
@@ -37,7 +37,10 @@ def _gen_reshard_suggestions(
)
suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
- return OutputSharding(None, redistribute_schema=suggested_schema,)
+ return OutputSharding(
+ None,
+ redistribute_schema=suggested_schema,
+ )
def einop_rule(
@@ -215,7 +218,10 @@ def merge_sharding(dim: str, a: int, b: int) -> int:
)
return OutputSharding(
_DTensorSpec.from_dim_map(
- input_specs[0].mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta,
+ input_specs[0].mesh,
+ output_dim_map,
+ pending_sums,
+ tensor_meta=tensor_meta,
)
)
@@ -275,5 +281,8 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
enforce_sharding[out_dimchar] = mesh_dim
return einop_rule(
- fmt, op_schema, linearity=linearity, enforce_sharding=enforce_sharding,
+ fmt,
+ op_schema,
+ linearity=linearity,
+ enforce_sharding=enforce_sharding,
)
diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py
index 9a4d5f425..707f391d6 100644
--- a/src/chop/distributed/tensor/ops/conv_ops.py
+++ b/src/chop/distributed/tensor/ops/conv_ops.py
@@ -50,11 +50,16 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
pending_sums = input_spec.sums
tensor_meta = TensorMeta(
- torch.Size(output_shape), output_stride, input_spec.tensor_meta.dtype,
+ torch.Size(output_shape),
+ output_stride,
+ input_spec.tensor_meta.dtype,
)
return OutputSharding(
_DTensorSpec.from_dim_map(
- input_spec.mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta,
+ input_spec.mesh,
+ output_dim_map,
+ pending_sums,
+ tensor_meta=tensor_meta,
)
)
@@ -83,14 +88,22 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
assert input_spec.tensor_meta is not None
weight_tensor_meta = weight_spec.tensor_meta
bias_tensor_meta = TensorMeta(
- torch.Size(bias_shape_opt), (1,), input_spec.tensor_meta.dtype,
+ torch.Size(bias_shape_opt),
+ (1,),
+ input_spec.tensor_meta.dtype,
)
grad_input_spec = input_spec
grad_weight_spec = _DTensorSpec.from_dim_map(
- input_spec.mesh, [-1, -1, -1, -1], [0], tensor_meta=weight_tensor_meta,
+ input_spec.mesh,
+ [-1, -1, -1, -1],
+ [0],
+ tensor_meta=weight_tensor_meta,
)
grad_bias_spec = _DTensorSpec.from_dim_map(
- input_spec.mesh, [-1], [0], tensor_meta=bias_tensor_meta,
+ input_spec.mesh,
+ [-1],
+ [0],
+ tensor_meta=bias_tensor_meta,
)
return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py
index da8dad8d5..770514219 100644
--- a/src/chop/distributed/tensor/ops/math_ops.py
+++ b/src/chop/distributed/tensor/ops/math_ops.py
@@ -128,7 +128,7 @@ def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
if self.reduce_op == "sum":
assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
if self.norm_type != 0 and self.norm_type != 1:
- return tensor ** self.norm_type
+ return tensor**self.norm_type
return tensor
def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
@@ -289,7 +289,10 @@ def common_reduction_strategy(
redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
reduction_strategy.strategies.append(
PlacementStrategy(
- output_specs=_DTensorSpec(mesh=mesh, placements=out_placements,),
+ output_specs=_DTensorSpec(
+ mesh=mesh,
+ placements=out_placements,
+ ),
input_specs=(input_spec,),
redistribute_cost=redistribute_cost,
)
@@ -475,7 +478,10 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(
- [aten._log_softmax_backward_data.default, aten._softmax_backward_data.default,],
+ [
+ aten._log_softmax_backward_data.default,
+ aten._softmax_backward_data.default,
+ ],
schema_info=RuntimeSchemaInfo(2),
)
def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@@ -608,14 +614,21 @@ def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
reduce_dims_map,
reduction_op,
)
- output_expected_spec = _DTensorSpec(mesh=mesh, placements=out_placements,)
+ output_expected_spec = _DTensorSpec(
+ mesh=mesh,
+ placements=out_placements,
+ )
# whether reduction is sum or mean, the total weight has to be summed up if not replicated
total_weight_placements = map_placements_after_reduction(
- target_expected_spec.placements, reduce_dims, reduce_dims_map, "sum",
+ target_expected_spec.placements,
+ reduce_dims,
+ reduce_dims_map,
+ "sum",
)
total_weight_expected_spec = _DTensorSpec(
- mesh=mesh, placements=total_weight_placements,
+ mesh=mesh,
+ placements=total_weight_placements,
)
output_strategy.strategies.append(
@@ -748,7 +761,8 @@ def rlog(msg):
@register_op_strategy(
- [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1),
+ [aten.native_layer_norm.default],
+ schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
# args must be: input, normalized_shape, weight, bias, eps
@@ -842,7 +856,8 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(
- [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2),
+ [aten.native_layer_norm_backward.default],
+ schema_info=RuntimeSchemaInfo(2),
)
def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
# args must be: grad_out, input, normalized_shape, mean, rstd,
@@ -999,7 +1014,8 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
@register_op_strategy(
- [aten.topk.default], schema_info=RuntimeSchemaInfo(2),
+ [aten.topk.default],
+ schema_info=RuntimeSchemaInfo(2),
)
def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py
index 029a795df..f76dca190 100644
--- a/src/chop/distributed/tensor/ops/matrix_ops.py
+++ b/src/chop/distributed/tensor/ops/matrix_ops.py
@@ -384,7 +384,10 @@ def scaled_dot_product_efficient_attention_strategy(
single_mesh_dim_strategies.append(num_heads_dim_sharding)
return expand_to_full_mesh_op_strategy(
- mesh, op_schema, single_mesh_dim_strategies, input_index=4,
+ mesh,
+ op_schema,
+ single_mesh_dim_strategies,
+ input_index=4,
)
@@ -449,5 +452,8 @@ def scaled_dot_product_efficient_attention_backward_strategy(
single_mesh_dim_strategies.append(num_heads_dim_sharding)
return expand_to_full_mesh_op_strategy(
- mesh, op_schema, single_mesh_dim_strategies, input_index=4,
+ mesh,
+ op_schema,
+ single_mesh_dim_strategies,
+ input_index=4,
)
diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py
index 656fd2996..221001f01 100644
--- a/src/chop/distributed/tensor/ops/pointwise_ops.py
+++ b/src/chop/distributed/tensor/ops/pointwise_ops.py
@@ -483,7 +483,9 @@ def common_pointwise_strategy(
common_shape, input_arg_spec.shape
)
input_target_placements = map_placements_after_broadcast(
- tuple(out_placements), common_shape, input_arg_dims_map,
+ tuple(out_placements),
+ common_shape,
+ input_arg_dims_map,
)
input_arg_target_spec = _DTensorSpec(
mesh=mesh,
@@ -497,7 +499,10 @@ def common_pointwise_strategy(
pointwise_strategy.strategies.append(
PlacementStrategy(
- output_specs=_DTensorSpec(mesh=mesh, placements=tuple(out_placements),),
+ output_specs=_DTensorSpec(
+ mesh=mesh,
+ placements=tuple(out_placements),
+ ),
input_specs=input_specs,
redistribute_cost=redistribute_costs,
)
diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py
index f54640e8a..dcddcb98c 100644
--- a/src/chop/distributed/tensor/ops/tensor_ops.py
+++ b/src/chop/distributed/tensor/ops/tensor_ops.py
@@ -76,7 +76,10 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
@register_op_strategy(
- [aten.equal.default, aten.is_same_size.default,]
+ [
+ aten.equal.default,
+ aten.is_same_size.default,
+ ]
)
def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
# equal_strategy deals with ops that comparing two tensor, we need to make sure
@@ -125,7 +128,8 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
@register_op_strategy(
- [aten.full_like.default], schema_info=RuntimeSchemaInfo(2, ["dtype"]),
+ [aten.full_like.default],
+ schema_info=RuntimeSchemaInfo(2, ["dtype"]),
)
@register_op_strategy(
[
@@ -692,7 +696,8 @@ def place(vp: Placement, ip: Placement) -> Placement:
)
result = OutputSharding(
output_spec=_DTensorSpec(
- mesh=values_spec.mesh, placements=value_placements,
+ mesh=values_spec.mesh,
+ placements=value_placements,
)
)
return result
@@ -779,7 +784,10 @@ def size_split(N, i):
else split_size_or_sections
)
output_spec_list = [
- _DTensorSpec(mesh=input_spec.mesh, placements=input_spec.placements,)
+ _DTensorSpec(
+ mesh=input_spec.mesh,
+ placements=input_spec.placements,
+ )
for _ in range(len(output_size_list))
]
return OutputSharding(output_spec_list)
diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py
index e005501d6..28e5245c3 100644
--- a/src/chop/distributed/tensor/ops/utils.py
+++ b/src/chop/distributed/tensor/ops/utils.py
@@ -196,7 +196,9 @@ def infer_broadcast_dims_map(
def map_placements_after_broadcast(
- placements: Tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: List[int],
+ placements: Tuple[Placement, ...],
+ shape: torch.Size,
+ broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
"""Map each placement based on the output shape after broadcast."""
new_placements: List[Placement] = []
diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py
index 6b89771ec..5c91f6d64 100644
--- a/src/chop/distributed/tensor/ops/view_ops.py
+++ b/src/chop/distributed/tensor/ops/view_ops.py
@@ -238,7 +238,9 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
def dim_movedim(
- ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]],
+ ndim: int,
+ input: Union[int, Sequence[int]],
+ destination: Union[int, Sequence[int]],
) -> DimMap:
input = normalize_dims(input, ndim)
destination = normalize_dims(destination, ndim)
@@ -604,7 +606,10 @@ def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_src_spec = input_placement_strategy.output_spec
input_tgt_placements, output_placements = propagate_shape_and_sharding(
- input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape,
+ input_src_spec.placements,
+ tuple(global_in_shape),
+ rules,
+ mesh.shape,
)
# TODO: optimize this. we shouldn't simply blindly replicate
diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py
index 5ee6b8bb3..322195869 100644
--- a/src/chop/ir/graph/mase_graph.py
+++ b/src/chop/ir/graph/mase_graph.py
@@ -76,7 +76,13 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
is_fx_built_in_leaf_module = super().is_leaf_module(m, module_qualified_name)
is_mase_leaf_layers = isinstance(m, MASE_LEAF_LAYERS)
is_custom_layer = isinstance(m, self.custom_leaf_layers)
- return any((is_fx_built_in_leaf_module, is_mase_leaf_layers, is_custom_layer,))
+ return any(
+ (
+ is_fx_built_in_leaf_module,
+ is_mase_leaf_layers,
+ is_custom_layer,
+ )
+ )
def trace_torch_module(
@@ -122,13 +128,19 @@ def is_leaf_module(
self, m: torch.nn.Module, module_qualified_name: str
) -> bool:
is_hf_built_in_leaf_module = hf_is_leaf_module(
- self, m, module_qualified_name,
+ self,
+ m,
+ module_qualified_name,
)
is_custom_module = isinstance(m, custom_modules)
is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS)
return any(
- (is_hf_built_in_leaf_module, is_custom_module, is_mase_leaf_layer,)
+ (
+ is_hf_built_in_leaf_module,
+ is_custom_module,
+ is_mase_leaf_layer,
+ )
)
return is_leaf_module
@@ -140,7 +152,9 @@ def is_leaf_module(
)
graph_module = hf_symbolic_trace(
- model, tracer_cls=tracer_cls, input_names=hf_input_names,
+ model,
+ tracer_cls=tracer_cls,
+ input_names=hf_input_names,
)
graph_module.custom_ops = custom_ops
@@ -293,7 +307,10 @@ def __init__(
self.model.additional_inputs = []
elif isinstance(model, torch.nn.Module):
self.model = trace_torch_module(
- model, cf_args, custom_ops, hf_input_names=hf_input_names,
+ model,
+ cf_args,
+ custom_ops,
+ hf_input_names=hf_input_names,
)
else:
raise ValueError(
@@ -332,11 +349,16 @@ def from_module(
), f"model must be a torch.nn.Module. Received: {type(model)}"
graph_module = trace_torch_module(model, cf_args, custom_ops)
- return cls(model=graph_module, cf_args=cf_args,)
+ return cls(
+ model=graph_module,
+ cf_args=cf_args,
+ )
@classmethod
def from_checkpoint(
- cls, checkpoint: str, propagate_missing_metadata: bool = True,
+ cls,
+ checkpoint: str,
+ propagate_missing_metadata: bool = True,
):
"""
Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two files:
@@ -371,12 +393,18 @@ def from_checkpoint(
for node in mg.nodes:
if node.name in loaded_meta.keys():
parameters = loaded_meta[node.name]
- node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,)
+ node.meta["mase"] = MaseMetadata(
+ node=node,
+ model=loaded_model,
+ )
node.meta["mase"].parameters = parameters
else:
# todo: propagate metadata for missing nodes
logger.warning(f"Node {node.name} not found in loaded metadata.")
- node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,)
+ node.meta["mase"] = MaseMetadata(
+ node=node,
+ model=loaded_model,
+ )
for attr in [
"class_for_deserialization",
@@ -389,7 +417,8 @@ def from_checkpoint(
return mg
def export(
- self, fname: str = "masegraph",
+ self,
+ fname: str = "masegraph",
):
"""
Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz.
diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py
index 52fcf0127..6f01d83c8 100644
--- a/src/chop/ir/graph/mase_metadata.py
+++ b/src/chop/ir/graph/mase_metadata.py
@@ -100,7 +100,9 @@ class MaseMetadata:
known_storage = ["BRAM"]
def __init__(
- self, node=None, model=None,
+ self,
+ node=None,
+ model=None,
):
# Top-level model
self.model = model
diff --git a/src/chop/ir/onnx/mase_onnx_graph.py b/src/chop/ir/onnx/mase_onnx_graph.py
index e4a8f52ed..f6e1d31e6 100644
--- a/src/chop/ir/onnx/mase_onnx_graph.py
+++ b/src/chop/ir/onnx/mase_onnx_graph.py
@@ -10,7 +10,9 @@
class MaseOnnxGraph:
def __init__(
- self, model_proto: onnx.onnx_ml_pb2.ModelProto, model_name: str = None,
+ self,
+ model_proto: onnx.onnx_ml_pb2.ModelProto,
+ model_name: str = None,
):
self.model_proto = model_proto
self.graph = model_proto.graph
diff --git a/src/chop/ir/onnx/utils.py b/src/chop/ir/onnx/utils.py
index c5250bba0..bb65aa8b2 100644
--- a/src/chop/ir/onnx/utils.py
+++ b/src/chop/ir/onnx/utils.py
@@ -188,7 +188,10 @@ def onnx_to_torch_dtype(dtype):
"target": torch.mean,
"input_mapping": ["input"],
"attribute_mapping": {"keepdims": "", "axes": ""},
- "attribute_transform": {"keepdims": None, "axes": None,},
+ "attribute_transform": {
+ "keepdims": None,
+ "axes": None,
+ },
"attribute_default": {"keepdims": 1, "axes": None},
},
"Expand": {
@@ -332,7 +335,9 @@ def onnx_to_torch_dtype(dtype):
"input_mapping": ["input"],
"attribute_mapping": {"perm": "dims"},
"attribute_transform": {"perm": lambda x: [i for i in x]},
- "attribute_default": {"perm": None,},
+ "attribute_default": {
+ "perm": None,
+ },
},
"Max": {
"fx_op": "call_function",
diff --git a/src/chop/models/bert/modeling_bert_quantized.py b/src/chop/models/bert/modeling_bert_quantized.py
index c6747f0e9..8061d2593 100644
--- a/src/chop/models/bert/modeling_bert_quantized.py
+++ b/src/chop/models/bert/modeling_bert_quantized.py
@@ -561,7 +561,9 @@ def __init__(self, config, quant_config: dict):
super().__init__()
# self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dense = get_quantized_cls("linear", quant_config["dense"])(
- config.hidden_size, config.intermediate_size, config=quant_config["dense"],
+ config.hidden_size,
+ config.intermediate_size,
+ config=quant_config["dense"],
)
self.quant_config = quant_config
if isinstance(config.hidden_act, str):
@@ -580,7 +582,9 @@ def __init__(self, config, quant_config):
super().__init__()
# self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dense = get_quantized_cls("linear", quant_config["dense"])(
- config.intermediate_size, config.hidden_size, config=quant_config["dense"],
+ config.intermediate_size,
+ config.hidden_size,
+ config=quant_config["dense"],
)
self.quant_config = quant_config
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
diff --git a/src/chop/models/bert/quant_config_bert.py b/src/chop/models/bert/quant_config_bert.py
index 4dadb97f0..f66c45a4c 100644
--- a/src/chop/models/bert/quant_config_bert.py
+++ b/src/chop/models/bert/quant_config_bert.py
@@ -88,7 +88,10 @@ def create_a_layer_config(
return qc
-def _parse_and_complete_config(config: dict, num_hidden_layers: int,) -> dict:
+def _parse_and_complete_config(
+ config: dict,
+ num_hidden_layers: int,
+) -> dict:
assert "default" in config, "Must provide a default config"
default_qc: dict = config["default"]
linear_qc: dict = parse_node_config(
diff --git a/src/chop/models/cnv/cnv.py b/src/chop/models/cnv/cnv.py
index ee1912f18..8bb9ad468 100644
--- a/src/chop/models/cnv/cnv.py
+++ b/src/chop/models/cnv/cnv.py
@@ -202,7 +202,8 @@ def forward(self, x: Tensor) -> Tensor:
# Getters ------------------------------------------------------------------------------
@register_mase_checkpoint("cnv-toy")
def get_cnv_toy(
- pretrained=False, **kwargs: Any,
+ pretrained=False,
+ **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
@@ -212,7 +213,8 @@ def get_cnv_toy(
@register_mase_checkpoint("cnv")
def get_cnv(
- pretrained=False, **kwargs: Any,
+ pretrained=False,
+ **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
@@ -222,7 +224,8 @@ def get_cnv(
@register_mase_checkpoint("cnv_residual")
def get_cnv_residual(
- pretrained=False, **kwargs: Any,
+ pretrained=False,
+ **kwargs: Any,
):
# image_size = info["image_size"]
info = kwargs["dataset_info"]
diff --git a/src/chop/models/cswin/cswintransformer.py b/src/chop/models/cswin/cswintransformer.py
index 449a22d30..cbd85cda6 100644
--- a/src/chop/models/cswin/cswintransformer.py
+++ b/src/chop/models/cswin/cswintransformer.py
@@ -90,7 +90,7 @@ def __init__(
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
- self.scale = qk_scale or head_dim ** -0.5
+ self.scale = qk_scale or head_dim**-0.5
if idx == -1:
H_sp, W_sp = self.resolution, self.resolution
elif idx == 0:
@@ -362,9 +362,9 @@ def __init__(
super().__init__()
self.use_chk = use_chk
self.num_classes = num_classes
- self.num_features = (
- self.embed_dim
- ) = embed_dim # num_features for consistency with other models
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
heads = num_heads
self.stage1_conv_embed = nn.Sequential(
diff --git a/src/chop/models/deit/deit_v2.py b/src/chop/models/deit/deit_v2.py
index 21b2a5ad4..4f0128216 100644
--- a/src/chop/models/deit/deit_v2.py
+++ b/src/chop/models/deit/deit_v2.py
@@ -28,7 +28,7 @@ def __init__(
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
+ self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
diff --git a/src/chop/models/efficientnet/efficientnet.py b/src/chop/models/efficientnet/efficientnet.py
index b3af62d05..5d2f7730a 100644
--- a/src/chop/models/efficientnet/efficientnet.py
+++ b/src/chop/models/efficientnet/efficientnet.py
@@ -499,7 +499,8 @@ def forward(self, x: Tensor) -> Tensor:
def _efficientnet_conf(
- arch: str, **kwargs: Any,
+ arch: str,
+ **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
if arch.startswith("efficientnet_b"):
@@ -591,7 +592,9 @@ def _efficientnet(
def get_efficientnet_b0(
- info: Dict, pretrained: bool = False, **kwargs: Any,
+ info: Dict,
+ pretrained: bool = False,
+ **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -613,7 +616,9 @@ def get_efficientnet_b0(
def get_efficientnet_b3(
- info: Dict, pretrained: bool = False, **kwargs: Any,
+ info: Dict,
+ pretrained: bool = False,
+ **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -636,7 +641,9 @@ def get_efficientnet_b3(
def get_efficientnet_v2_s(
- info: Dict, pretrained: bool = False, **kwargs: Any,
+ info: Dict,
+ pretrained: bool = False,
+ **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -656,7 +663,9 @@ def get_efficientnet_v2_s(
def get_efficientnet_v2_m(
- info: Dict, pretrained: bool = False, **kwargs: Any,
+ info: Dict,
+ pretrained: bool = False,
+ **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
@@ -676,7 +685,9 @@ def get_efficientnet_v2_m(
def get_efficientnet_v2_l(
- info: Dict, pretrained: bool = False, **kwargs: Any,
+ info: Dict,
+ pretrained: bool = False,
+ **kwargs: Any,
):
num_classes = info.num_classes
if pretrained:
diff --git a/src/chop/models/lfc/lfc.py b/src/chop/models/lfc/lfc.py
index 410359a25..70da9df45 100644
--- a/src/chop/models/lfc/lfc.py
+++ b/src/chop/models/lfc/lfc.py
@@ -42,7 +42,9 @@ def forward(self, x):
# Getters ------------------------------------------------------------------------------
def get_lfc(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
image_size = info["image_size"]
num_classes = info.num_classes
diff --git a/src/chop/models/llama/modeling_llama_llora.py b/src/chop/models/llama/modeling_llama_llora.py
index 45eb2f210..b45dd8494 100644
--- a/src/chop/models/llama/modeling_llama_llora.py
+++ b/src/chop/models/llama/modeling_llama_llora.py
@@ -195,7 +195,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
class LlamaMLP(nn.Module):
def __init__(
- self, hidden_size: int, intermediate_size: int, hidden_act: str,
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
):
super().__init__()
# fmt: off
diff --git a/src/chop/models/llama/modeling_llama_sparse.py b/src/chop/models/llama/modeling_llama_sparse.py
index 21359ea34..f5f37eba0 100644
--- a/src/chop/models/llama/modeling_llama_sparse.py
+++ b/src/chop/models/llama/modeling_llama_sparse.py
@@ -195,7 +195,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
class LlamaMLP(nn.Module):
def __init__(
- self, hidden_size: int, intermediate_size: int, hidden_act: str,
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
):
super().__init__()
# fmt: off
diff --git a/src/chop/models/mobilenet_v2/mobilenet_v2.py b/src/chop/models/mobilenet_v2/mobilenet_v2.py
index 0980d58dd..d8f4020dd 100644
--- a/src/chop/models/mobilenet_v2/mobilenet_v2.py
+++ b/src/chop/models/mobilenet_v2/mobilenet_v2.py
@@ -288,7 +288,8 @@ def __init__(
# building classifier
self.classifier = nn.Sequential(
- nn.Dropout(p=dropout), nn.Linear(self.last_channel, num_classes),
+ nn.Dropout(p=dropout),
+ nn.Linear(self.last_channel, num_classes),
)
# weight initialization
diff --git a/src/chop/models/nerf/nerf_vision.py b/src/chop/models/nerf/nerf_vision.py
index 18338940b..f8591f1a4 100644
--- a/src/chop/models/nerf/nerf_vision.py
+++ b/src/chop/models/nerf/nerf_vision.py
@@ -139,7 +139,9 @@ def load_weights_from_keras(self, weights):
# Getters ------------------------------------------------------------------------------
def get_nerf(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
# image_size = info["image_size"]
num_classes = info.num_classes
diff --git a/src/chop/models/opt/modeling_opt.py b/src/chop/models/opt/modeling_opt.py
index 430f77df8..e88db8c1f 100644
--- a/src/chop/models/opt/modeling_opt.py
+++ b/src/chop/models/opt/modeling_opt.py
@@ -132,7 +132,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim ** -0.5
+ self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
diff --git a/src/chop/models/opt/modeling_opt_lora.py b/src/chop/models/opt/modeling_opt_lora.py
index 30bb21f20..810cfe0d6 100644
--- a/src/chop/models/opt/modeling_opt_lora.py
+++ b/src/chop/models/opt/modeling_opt_lora.py
@@ -130,7 +130,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim ** -0.5
+ self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
lora_config = config.lora_config[f"model_layer_{layer_id}"]["self_attn"]
diff --git a/src/chop/models/opt/modeling_opt_quantized.py b/src/chop/models/opt/modeling_opt_quantized.py
index d6eb94343..d1086f413 100644
--- a/src/chop/models/opt/modeling_opt_quantized.py
+++ b/src/chop/models/opt/modeling_opt_quantized.py
@@ -168,7 +168,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim ** -0.5
+ self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
# fmt:off
diff --git a/src/chop/models/opt/modeling_opt_sparse.py b/src/chop/models/opt/modeling_opt_sparse.py
index 4addd6525..f39a29cc9 100644
--- a/src/chop/models/opt/modeling_opt_sparse.py
+++ b/src/chop/models/opt/modeling_opt_sparse.py
@@ -130,7 +130,7 @@ def __init__(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
- self.scaling = self.head_dim ** -0.5
+ self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
sparse_config = config.sparse_config[f"model_layer_{layer_id}"]["self_attn"]
diff --git a/src/chop/models/opt/quant_config_opt_quantized.py b/src/chop/models/opt/quant_config_opt_quantized.py
index 150a1b2d6..a76f8adb8 100644
--- a/src/chop/models/opt/quant_config_opt_quantized.py
+++ b/src/chop/models/opt/quant_config_opt_quantized.py
@@ -32,7 +32,9 @@
def create_a_layer_config(
- linear_qc: dict = None, bmm_qc: dict = None, layer_qc=None,
+ linear_qc: dict = None,
+ bmm_qc: dict = None,
+ layer_qc=None,
) -> dict:
if (layer_qc is None and bmm_qc is None) and layer_qc is None:
raise ValueError("Must provide either (linear_qc & bmm_qc ) or layer_qc")
diff --git a/src/chop/models/pvt/pvt.py b/src/chop/models/pvt/pvt.py
index 2474fb388..ac3a18708 100644
--- a/src/chop/models/pvt/pvt.py
+++ b/src/chop/models/pvt/pvt.py
@@ -58,7 +58,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
+ self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
@@ -182,7 +182,12 @@ def forward(self, x):
@register_mase_model(
"pvt",
- checkpoints=["pvt_tiny", "pvt_small", "pvt_medium", "pvt_large",],
+ checkpoints=[
+ "pvt_tiny",
+ "pvt_small",
+ "pvt_medium",
+ "pvt_large",
+ ],
model_source="vision_others",
task_type="vision",
image_classification=True,
diff --git a/src/chop/models/pvt/pvt_v2.py b/src/chop/models/pvt/pvt_v2.py
index b70acb6d1..3ffb45f6b 100644
--- a/src/chop/models/pvt/pvt_v2.py
+++ b/src/chop/models/pvt/pvt_v2.py
@@ -82,7 +82,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
+ self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
diff --git a/src/chop/models/repvgg/repvgg.py b/src/chop/models/repvgg/repvgg.py
index 5ef7998e0..086501fc5 100644
--- a/src/chop/models/repvgg/repvgg.py
+++ b/src/chop/models/repvgg/repvgg.py
@@ -143,14 +143,14 @@ def get_custom_L2(self):
.detach()
)
- l2_loss_circle = (K3 ** 2).sum() - (
+ l2_loss_circle = (K3**2).sum() - (
K3[:, :, 1:2, 1:2] ** 2
).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
eq_kernel = (
K3[:, :, 1:2, 1:2] * t3 + K1 * t1
) # The equivalent resultant central point of 3x3 kernel.
l2_loss_eq_kernel = (
- eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)
+ eq_kernel**2 / (t3**2 + t1**2)
).sum() # Normalize for an L2 coefficient comparable to regular L2.
return l2_loss_eq_kernel + l2_loss_circle
diff --git a/src/chop/models/resnet/resnet.py b/src/chop/models/resnet/resnet.py
index b339b72a7..313c8ce36 100644
--- a/src/chop/models/resnet/resnet.py
+++ b/src/chop/models/resnet/resnet.py
@@ -157,7 +157,13 @@ def forward(self, x: Tensor) -> Tensor:
@register_mase_model(
name="resnet",
- checkpoints=["resnet18", "resnet34", "resnet50", "resnet101", "wide_resnet50_2",],
+ checkpoints=[
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "wide_resnet50_2",
+ ],
model_source="torchvision",
task_type="vision",
image_classification=True,
@@ -331,7 +337,10 @@ def _resnet(
@register_mase_checkpoint("resnet18")
-def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet:
+def get_resnet18(
+ pretrained: bool = False,
+ **kwargs: Any,
+) -> ResNet:
"""ResNet-18 from `Deep Residual Learning for Image Recognition `__."""
if pretrained:
pretrained_weight_cls = ResNet18_Weights.IMAGENET1K_V1
@@ -339,12 +348,18 @@ def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet:
pretrained_weight_cls = None
return _resnet(
- BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
+ BasicBlock,
+ [2, 2, 2, 2],
+ pretrained_weight_cls=pretrained_weight_cls,
+ **kwargs,
)
@register_mase_checkpoint("resnet34")
-def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet:
+def get_resnet34(
+ pretrained: bool = False,
+ **kwargs: Any,
+) -> ResNet:
"""ResNet-34 from `Deep Residual Learning for Image Recognition `__."""
if pretrained:
pretrained_weight_cls = ResNet34_Weights.IMAGENET1K_V1
@@ -352,12 +367,18 @@ def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet:
pretrained_weight_cls = None
return _resnet(
- BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
+ BasicBlock,
+ [2, 2, 2, 2],
+ pretrained_weight_cls=pretrained_weight_cls,
+ **kwargs,
)
@register_mase_checkpoint("resnet50")
-def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet:
+def get_resnet50(
+ pretrained: bool = False,
+ **kwargs: Any,
+) -> ResNet:
"""ResNet-50 from `Deep Residual Learning for Image Recognition `__."""
info = kwargs["dataset_info"]
if pretrained:
@@ -366,12 +387,18 @@ def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet:
pretrained_weight_cls = None
return _resnet(
- Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
+ Bottleneck,
+ [3, 4, 6, 3],
+ pretrained_weight_cls=pretrained_weight_cls,
+ **kwargs,
)
@register_mase_checkpoint("resnet101")
-def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet:
+def get_resnet101(
+ pretrained: bool = False,
+ **kwargs: Any,
+) -> ResNet:
"""ResNet-101 from `Deep Residual Learning for Image Recognition `__."""
info = kwargs["dataset_info"]
if pretrained:
@@ -380,13 +407,17 @@ def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet:
pretrained_weight_cls = None
return _resnet(
- BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
+ BasicBlock,
+ [2, 2, 2, 2],
+ pretrained_weight_cls=pretrained_weight_cls,
+ **kwargs,
)
@register_mase_checkpoint("wide_resnet50_2")
def get_wide_resnet50_2(
- pretrained: bool = False, **kwargs,
+ pretrained: bool = False,
+ **kwargs,
):
"""
`Wide Residual Networks `_.
@@ -404,5 +435,8 @@ def get_wide_resnet50_2(
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(
- Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs,
+ Bottleneck,
+ [3, 4, 6, 3],
+ pretrained_weight_cls=pretrained_weight_cls,
+ **kwargs,
)
diff --git a/src/chop/models/toy/toy.py b/src/chop/models/toy/toy.py
index 3890c4a0f..e88e030f6 100644
--- a/src/chop/models/toy/toy.py
+++ b/src/chop/models/toy/toy.py
@@ -160,7 +160,8 @@ def _conv_block(self, conv_class, *args):
# Getters ------------------------------------------------------------------------------
@register_mase_checkpoint("toy")
def get_toynet(
- pretrained=False, **kwargs: Any,
+ pretrained=False,
+ **kwargs: Any,
):
info = kwargs["dataset_info"]
image_size = info.image_size
@@ -169,7 +170,9 @@ def get_toynet(
def get_toy_tiny(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
image_size = info.image_size
num_classes = info.num_classes
@@ -177,7 +180,9 @@ def get_toy_tiny(
def get_toy_testmodel(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
image_size = info["image_size"]
num_classes = info.num_classes
@@ -186,7 +191,9 @@ def get_toy_testmodel(
def get_toy_convnet(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
# NOTE: The model isn't configurable through the CLI or a configuration file yet.
num_classes = info.num_classes
diff --git a/src/chop/models/vgg_cifar/vgg_orig.py b/src/chop/models/vgg_cifar/vgg_orig.py
index f744c4404..cf906e888 100644
--- a/src/chop/models/vgg_cifar/vgg_orig.py
+++ b/src/chop/models/vgg_cifar/vgg_orig.py
@@ -189,7 +189,12 @@ class VGG11_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 132863336,
- "_metrics": {"ImageNet-1K": {"acc@1": 69.020, "acc@5": 88.628,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 69.020,
+ "acc@5": 88.628,
+ }
+ },
"_ops": 7.609,
"_file_size": 506.84,
},
@@ -204,7 +209,12 @@ class VGG11_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 132868840,
- "_metrics": {"ImageNet-1K": {"acc@1": 70.370, "acc@5": 89.810,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 70.370,
+ "acc@5": 89.810,
+ }
+ },
"_ops": 7.609,
"_file_size": 506.881,
},
@@ -219,7 +229,12 @@ class VGG13_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 133047848,
- "_metrics": {"ImageNet-1K": {"acc@1": 69.928, "acc@5": 89.246,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 69.928,
+ "acc@5": 89.246,
+ }
+ },
"_ops": 11.308,
"_file_size": 507.545,
},
@@ -234,7 +249,12 @@ class VGG13_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 133053736,
- "_metrics": {"ImageNet-1K": {"acc@1": 71.586, "acc@5": 90.374,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.586,
+ "acc@5": 90.374,
+ }
+ },
"_ops": 11.308,
"_file_size": 507.59,
},
@@ -249,7 +269,12 @@ class VGG16_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 138357544,
- "_metrics": {"ImageNet-1K": {"acc@1": 71.592, "acc@5": 90.382,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.592,
+ "acc@5": 90.382,
+ }
+ },
"_ops": 15.47,
"_file_size": 527.796,
},
@@ -269,7 +294,10 @@ class VGG16_Weights(WeightsEnum):
"categories": None,
"recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
"_metrics": {
- "ImageNet-1K": {"acc@1": float("nan"), "acc@5": float("nan"),}
+ "ImageNet-1K": {
+ "acc@1": float("nan"),
+ "acc@5": float("nan"),
+ }
},
"_ops": 15.47,
"_file_size": 527.802,
@@ -290,7 +318,12 @@ class VGG16_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 138365992,
- "_metrics": {"ImageNet-1K": {"acc@1": 73.360, "acc@5": 91.516,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 73.360,
+ "acc@5": 91.516,
+ }
+ },
"_ops": 15.47,
"_file_size": 527.866,
},
@@ -305,7 +338,12 @@ class VGG19_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 143667240,
- "_metrics": {"ImageNet-1K": {"acc@1": 72.376, "acc@5": 90.876,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 72.376,
+ "acc@5": 90.876,
+ }
+ },
"_ops": 19.632,
"_file_size": 548.051,
},
@@ -320,7 +358,12 @@ class VGG19_BN_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 143678248,
- "_metrics": {"ImageNet-1K": {"acc@1": 74.218, "acc@5": 91.842,}},
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 74.218,
+ "acc@5": 91.842,
+ }
+ },
"_ops": 19.632,
"_file_size": 548.143,
},
diff --git a/src/chop/models/vision/snn/snn_toy.py b/src/chop/models/vision/snn/snn_toy.py
index 52c86d821..5915f1ec6 100644
--- a/src/chop/models/vision/snn/snn_toy.py
+++ b/src/chop/models/vision/snn/snn_toy.py
@@ -26,7 +26,9 @@ def forward(self, x: torch.Tensor):
# Getters ------------------------------------------------------------------------------
def get_snn_toy(
- info, pretrained=False, **kwargs: Any,
+ info,
+ pretrained=False,
+ **kwargs: Any,
):
tau = info["tau"]
num_classes = info.num_classes
diff --git a/src/chop/models/vision/snn/spikingResformer.py b/src/chop/models/vision/snn/spikingResformer.py
index 153cae155..ff74d76df 100644
--- a/src/chop/models/vision/snn/spikingResformer.py
+++ b/src/chop/models/vision/snn/spikingResformer.py
@@ -124,7 +124,11 @@ def no_weight_decay(self):
@register_model
def spikingresformer_ti(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[64, 192, 384],
[1, 3, 6],
[4, 2, 1],
@@ -136,7 +140,11 @@ def spikingresformer_ti(**kwargs):
@register_model
def spikingresformer_s(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[64, 256, 512],
[1, 4, 8],
[4, 2, 1],
@@ -148,7 +156,11 @@ def spikingresformer_s(**kwargs):
@register_model
def spikingresformer_m(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[64, 384, 768],
[1, 6, 12],
[4, 2, 1],
@@ -160,7 +172,11 @@ def spikingresformer_m(**kwargs):
@register_model
def spikingresformer_l(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[128, 512, 1024],
[2, 8, 16],
[4, 2, 1],
@@ -172,13 +188,18 @@ def spikingresformer_l(**kwargs):
@register_model
def spikingresformer_dvsg(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[32, 96, 192],
[1, 3, 6],
[4, 2, 1],
in_channels=3,
prologue=nn.Sequential(
- Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), BN(32),
+ Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"),
+ BN(32),
),
group_size=32,
activation=PLIF,
@@ -189,13 +210,18 @@ def spikingresformer_dvsg(**kwargs):
@register_model
def spikingresformer_cifar(**kwargs):
return SpikingResformer(
- [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,],
+ [
+ ["DSSA", "GWFFN"] * 1,
+ ["DSSA", "GWFFN"] * 2,
+ ["DSSA", "GWFFN"] * 3,
+ ],
[64, 192, 384],
[1, 3, 6],
[4, 2, 1],
in_channels=3,
prologue=nn.Sequential(
- Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), BN(64),
+ Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"),
+ BN(64),
),
**kwargs,
)
diff --git a/src/chop/nn/functional/softermax.py b/src/chop/nn/functional/softermax.py
index 348d7ebea..8653a6aa0 100644
--- a/src/chop/nn/functional/softermax.py
+++ b/src/chop/nn/functional/softermax.py
@@ -11,7 +11,7 @@ def softermax(input: Tensor, dim: int) -> Tensor:
Tensor: Output tensor
"""
out = input - input.max(dim=dim, keepdim=True).values.floor()
- out = 2 ** out
+ out = 2**out
row_sum = out.sum(dim=dim, keepdim=True)
# Elementwise division
out = out / row_sum
diff --git a/src/chop/nn/modules/gqa.py b/src/chop/nn/modules/gqa.py
index 38188d51d..f30b7a8d8 100644
--- a/src/chop/nn/modules/gqa.py
+++ b/src/chop/nn/modules/gqa.py
@@ -127,7 +127,12 @@ def _qkv_states(self, x: Tensor, batch_size: int, seq_len: int):
# return x
def _attention_mechanism(
- self, query: Tensor, key: Tensor, value: Tensor, batch_size: int, seq_len: int,
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ batch_size: int,
+ seq_len: int,
):
key = repeat_kv(key, n_rep=self.group_size)
value = repeat_kv(value, n_rep=self.group_size)
@@ -164,7 +169,9 @@ def forward(self, x: Tensor):
GROUPS = 4
gqa_module = GroupedQueryAttention(
- embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_kv_heads=GROUPS,
+ embed_dim=EMBED_DIM,
+ num_heads=NUM_HEADS,
+ num_kv_heads=GROUPS,
)
x_in = torch.rand(BATCH, SEQ_LEN, EMBED_DIM)
diff --git a/src/chop/nn/modules/lora.py b/src/chop/nn/modules/lora.py
index f101d0c03..bbd2ea023 100644
--- a/src/chop/nn/modules/lora.py
+++ b/src/chop/nn/modules/lora.py
@@ -93,7 +93,11 @@ def reset_lora_parameters(self, adapter_name):
class LinearLora(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
- self, in_features: int, out_features: int, config: dict = None, **kwargs,
+ self,
+ in_features: int,
+ out_features: int,
+ config: dict = None,
+ **kwargs,
):
self.config = config
init_lora_weights = self.config.get("init_lora_weights", True)
@@ -218,7 +222,12 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
# Simple Lora implementation from https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
class LoRALinear(nn.Module):
def __init__(
- self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float,
+ self,
+ in_dim: int,
+ out_dim: int,
+ rank: int,
+ alpha: float,
+ dropout: float,
):
super().__init__()
# These are the weights from the original pretrained model
diff --git a/src/chop/nn/modules/sparse.py b/src/chop/nn/modules/sparse.py
index 5a3f749a1..67508cf3d 100644
--- a/src/chop/nn/modules/sparse.py
+++ b/src/chop/nn/modules/sparse.py
@@ -92,7 +92,11 @@ def reset_sparse_parameters(self, adapter_name):
class LinearSparse(nn.Linear, SparseLayer):
def __init__(
- self, in_features: int, out_features: int, config: dict = None, **kwargs,
+ self,
+ in_features: int,
+ out_features: int,
+ config: dict = None,
+ **kwargs,
):
self.config = config
init_sparse_weights = self.config.get("init_sparse_weights", True)
@@ -155,7 +159,9 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor:
def update_weight_selection(self, k):
w_flat = self.weight.flatten()
_, self.idx = torch.topk(
- self.index_method(w_flat, self.idx_method), k, sorted=True,
+ self.index_method(w_flat, self.idx_method),
+ k,
+ sorted=True,
)
self.selected_weights = torch.gather(w_flat, dim=0, index=self.idx)
@@ -186,7 +192,10 @@ def forward(self, x: torch.Tensor):
# Scatter adapted values into weight tensor
adapted_weights = torch.scatter(
- self.zero_tensor.to(x.device), dim=0, index=self.idx, src=scaled_output,
+ self.zero_tensor.to(x.device),
+ dim=0,
+ index=self.idx,
+ src=scaled_output,
).view(self.unflattened_size)
self.step += 1
@@ -194,7 +203,9 @@ def forward(self, x: torch.Tensor):
x = x.to(sparse.weight.dtype)
result = F.linear(
- dropout(x), transpose(new_weight, self.fan_in_fan_out), bias=self.bias,
+ dropout(x),
+ transpose(new_weight, self.fan_in_fan_out),
+ bias=self.bias,
)
else:
diff --git a/src/chop/nn/mx/activations.py b/src/chop/nn/mx/activations.py
index 27fafc352..39702cc29 100644
--- a/src/chop/nn/mx/activations.py
+++ b/src/chop/nn/mx/activations.py
@@ -434,7 +434,10 @@ def forward(ctx, input, inplace=False, mx_specs=None, name=None):
@staticmethod
def backward(ctx, grad_output):
- (y, sig_x,) = ctx.saved_tensors
+ (
+ y,
+ sig_x,
+ ) = ctx.saved_tensors
grad_output = vec_quantize(grad_output, mx_specs=ctx.mx_specs)
temp = vec_sub(1.0, sig_x, mx_specs=ctx.mx_specs)
diff --git a/src/chop/nn/mx/bmm.py b/src/chop/nn/mx/bmm.py
index 50071d06e..6fac4eef8 100644
--- a/src/chop/nn/mx/bmm.py
+++ b/src/chop/nn/mx/bmm.py
@@ -67,7 +67,9 @@ def backward(ctx, grad_out):
in1, in2 = ctx.saved_tensors
grad_out = quantize_elemwise_op(
- grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_out,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -109,10 +111,14 @@ def backward(ctx, grad_out):
# element-wise quantize for grad_in1 and grad_in2
grad_in1 = quantize_elemwise_op(
- grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_in1,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
grad_in2 = quantize_elemwise_op(
- grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_in2,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
return (grad_in1, grad_in2, None, None)
diff --git a/src/chop/nn/mx/convolution.py b/src/chop/nn/mx/convolution.py
index 9d2aaed17..62b8dd2d2 100644
--- a/src/chop/nn/mx/convolution.py
+++ b/src/chop/nn/mx/convolution.py
@@ -180,10 +180,16 @@ def forward(
# weight is (out_channels, in_channels/groups, ..)
# quantize along in_channels
qid_input = quantize_mx_op(
- bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1],
+ bf_in,
+ mx_specs,
+ elem_format=mx_specs["a_elem_format"],
+ axes=[1],
)
qid_weight = quantize_mx_op(
- bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[1],
+ bf_weight,
+ mx_specs,
+ elem_format=mx_specs["w_elem_format"],
+ axes=[1],
)
# compute output
@@ -207,7 +213,9 @@ def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_output = quantize_elemwise_op(
- grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_output,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -217,7 +225,10 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# quantize along the batch dim
qex_input = quantize_mx_op(
- input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0],
+ input,
+ ctx.mx_specs,
+ elem_format=ctx.mx_specs["a_elem_format"],
+ axes=[0],
)
qex_grad_output = quantize_mx_op(
grad_output,
@@ -240,7 +251,9 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_weight
grad_weight = quantize_elemwise_op(
- grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
+ grad_weight,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -251,7 +264,10 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# reduction dim is out_channels
qod_weight = quantize_mx_op(
- weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[0],
+ weight,
+ ctx.mx_specs,
+ elem_format=ctx.mx_specs["w_elem_format"],
+ axes=[0],
)
qod_grad_output = quantize_mx_op(
grad_output,
@@ -273,7 +289,9 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_input
grad_input = quantize_elemwise_op(
- grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_input,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/mx/elemwise_ops.py b/src/chop/nn/mx/elemwise_ops.py
index 21067bba1..5d2b4b8ec 100644
--- a/src/chop/nn/mx/elemwise_ops.py
+++ b/src/chop/nn/mx/elemwise_ops.py
@@ -33,16 +33,16 @@
# exponents smaller than -126
def _safe_lshift(x, bits, exp):
if exp is None:
- return x * (2 ** bits)
+ return x * (2**bits)
else:
- return x / (2 ** exp) * (2 ** bits)
+ return x / (2**exp) * (2**bits)
def _safe_rshift(x, bits, exp):
if exp is None:
- return x / (2 ** bits)
+ return x / (2**bits)
else:
- return x / (2 ** bits) * (2 ** exp)
+ return x / (2**bits) * (2**exp)
def _round_mantissa(A, bits, round, clamp=False):
diff --git a/src/chop/nn/mx/formats.py b/src/chop/nn/mx/formats.py
index ff2a9d726..9a2c35189 100644
--- a/src/chop/nn/mx/formats.py
+++ b/src/chop/nn/mx/formats.py
@@ -52,14 +52,14 @@ def from_str(s):
def _get_min_norm(ebits):
"""Valid for all float formats"""
emin = 2 - (2 ** (ebits - 1))
- return 0 if ebits == 0 else 2 ** emin
+ return 0 if ebits == 0 else 2**emin
def _get_max_norm(ebits, mbits):
"""Valid only for floats that define NaN"""
assert ebits >= 5, "invalid for floats that don't define NaN"
emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1
- return 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+ return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
_FORMAT_CACHE = {}
@@ -121,9 +121,9 @@ def _get_format_params(fmt):
raise Exception("Unknown element format %s" % fmt)
if fmt != ElemFormat.fp8_e4m3:
- max_norm = 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+ max_norm = 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
else:
- max_norm = 2 ** emax * 1.75 # FP8 has custom max_norm
+ max_norm = 2**emax * 1.75 # FP8 has custom max_norm
min_norm = _get_min_norm(ebits)
diff --git a/src/chop/nn/mx/linear.py b/src/chop/nn/mx/linear.py
index ef78fcedc..64f761c86 100644
--- a/src/chop/nn/mx/linear.py
+++ b/src/chop/nn/mx/linear.py
@@ -18,7 +18,12 @@
class LinearFunction(torch.autograd.Function):
@staticmethod
def forward(
- ctx, input, weight, bias=None, mx_specs=None, name=None,
+ ctx,
+ input,
+ weight,
+ bias=None,
+ mx_specs=None,
+ name=None,
):
# element-wise quantize for input
bf_in = quantize_elemwise_op(
@@ -85,7 +90,9 @@ def backward(ctx, grad_output):
in_dim = weight.shape[1]
grad_output = quantize_elemwise_op(
- grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_output,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -115,7 +122,9 @@ def backward(ctx, grad_output):
# Compute grad_weight
grad_weight = torch_matmul(qex_grad_output.transpose(0, 1), qex_input)
grad_weight = quantize_elemwise_op(
- grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
+ grad_weight,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -141,7 +150,9 @@ def backward(ctx, grad_output):
# Compute grad_input
grad_input = torch_matmul(qos_grad_output, qos_weight)
grad_input = quantize_elemwise_op(
- grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_input,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -161,7 +172,11 @@ def backward(ctx, grad_output):
def linear(
- input, weight, bias=None, mx_specs=None, name=None,
+ input,
+ weight,
+ bias=None,
+ mx_specs=None,
+ name=None,
):
mx_assert_test(mx_specs)
if mx_specs is None:
@@ -174,7 +189,12 @@ def linear(
class Linear(torch.nn.Linear):
def __init__(
- self, in_features, out_features, bias=True, mx_specs=None, name=None,
+ self,
+ in_features,
+ out_features,
+ bias=True,
+ mx_specs=None,
+ name=None,
):
mx_assert_test(mx_specs)
self.mx_none = mx_specs is None
diff --git a/src/chop/nn/mx/matmul.py b/src/chop/nn/mx/matmul.py
index 3c2e0ffb4..8c18914b7 100644
--- a/src/chop/nn/mx/matmul.py
+++ b/src/chop/nn/mx/matmul.py
@@ -114,7 +114,9 @@ def backward(ctx, grad_out):
in1, in2 = ctx.saved_tensors
grad_out = quantize_elemwise_op(
- grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_out,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -158,10 +160,14 @@ def backward(ctx, grad_out):
# element-wise quantize for grad_in1 and grad_in2
grad_in1 = quantize_elemwise_op(
- grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_in1,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
grad_in2 = quantize_elemwise_op(
- grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_in2,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/mx/mx_ops.py b/src/chop/nn/mx/mx_ops.py
index aad12a52f..136441bea 100644
--- a/src/chop/nn/mx/mx_ops.py
+++ b/src/chop/nn/mx/mx_ops.py
@@ -281,7 +281,10 @@ def _quantize_mx(
else:
# Get shared exponents
shared_exp = _shared_exponents(
- A, method=shared_exp_method, axes=shared_exp_axes, ebits=0,
+ A,
+ method=shared_exp_method,
+ axes=shared_exp_axes,
+ ebits=0,
)
# Flush subnormal FP32 inputs to zero
@@ -296,7 +299,7 @@ def _quantize_mx(
shared_exp[shared_exp > scale_emax] = float("NaN")
shared_exp[shared_exp < -scale_emax] = -scale_emax
- A = A / (2 ** shared_exp)
+ A = A / (2**shared_exp)
A = _quantize_elemwise_core(
A,
@@ -309,7 +312,7 @@ def _quantize_mx(
custom_cuda=custom_cuda,
)
- A = A * (2 ** shared_exp)
+ A = A * (2**shared_exp)
# Undo tile reshaping
if block_size:
diff --git a/src/chop/nn/mx/quantize.py b/src/chop/nn/mx/quantize.py
index b979b18d7..7e0b907fa 100644
--- a/src/chop/nn/mx/quantize.py
+++ b/src/chop/nn/mx/quantize.py
@@ -42,7 +42,9 @@ def forward(ctx, x, mx_specs, round=None):
@staticmethod
def backward(ctx, grad_output):
grad_input = quantize_elemwise_op(
- grad_output, mx_specs=ctx.mx_specs, round=ctx.round,
+ grad_output,
+ mx_specs=ctx.mx_specs,
+ round=ctx.round,
)
return (grad_input, None, None)
diff --git a/src/chop/nn/mx/simd_ops.py b/src/chop/nn/mx/simd_ops.py
index 28180b758..475b47ef7 100644
--- a/src/chop/nn/mx/simd_ops.py
+++ b/src/chop/nn/mx/simd_ops.py
@@ -307,7 +307,7 @@ def forward(ctx, in1, mx_specs=None):
else:
ctx.save_for_backward(in1)
- return vec_quantize(qin1 ** 2, mx_specs=mx_specs)
+ return vec_quantize(qin1**2, mx_specs=mx_specs)
@staticmethod
def backward(ctx, g):
diff --git a/src/chop/nn/mx/transpose_convolution.py b/src/chop/nn/mx/transpose_convolution.py
index 5d6c5da8a..af570025b 100644
--- a/src/chop/nn/mx/transpose_convolution.py
+++ b/src/chop/nn/mx/transpose_convolution.py
@@ -75,10 +75,16 @@ def forward(
# weight is (in_channels, out_channels/groups, ...)
# quantize along in_channels
qid_input = quantize_mx_op(
- bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1],
+ bf_in,
+ mx_specs,
+ elem_format=mx_specs["a_elem_format"],
+ axes=[1],
)
qid_weight = quantize_mx_op(
- bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[0],
+ bf_weight,
+ mx_specs,
+ elem_format=mx_specs["w_elem_format"],
+ axes=[0],
)
# compute output
@@ -108,7 +114,9 @@ def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_output = quantize_elemwise_op(
- grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_output,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
@@ -118,7 +126,10 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# quantize along the batch dim
qex_input = quantize_mx_op(
- input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0],
+ input,
+ ctx.mx_specs,
+ elem_format=ctx.mx_specs["a_elem_format"],
+ axes=[0],
)
qex_grad_output = quantize_mx_op(
grad_output,
@@ -139,7 +150,9 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_weight
grad_weight = quantize_elemwise_op(
- grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"],
+ grad_weight,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_weight"],
)
#####################################################
@@ -149,7 +162,10 @@ def backward(ctx, grad_output):
# output is (batch, out_channels, ...)
# reduction dim is out_channels
qod_weight = quantize_mx_op(
- weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[1],
+ weight,
+ ctx.mx_specs,
+ elem_format=ctx.mx_specs["w_elem_format"],
+ axes=[1],
)
qod_grad_output = quantize_mx_op(
grad_output,
@@ -171,7 +187,9 @@ def backward(ctx, grad_output):
# element-wise quantize for grad_input
grad_input = quantize_elemwise_op(
- grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"],
+ grad_input,
+ mx_specs=ctx.mx_specs,
+ round=ctx.mx_specs["round_grad_input"],
)
#####################################################
diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py
index c95cd1bb3..13f9532c3 100644
--- a/src/chop/nn/optical/modules/morr_conv2d.py
+++ b/src/chop/nn/optical/modules/morr_conv2d.py
@@ -109,7 +109,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi ** 2
+ self.gamma = np.pi / self.v_pi**2
self.w_bit = 32
self.in_bit = 32
self.MORRConfig = MORRConfig
@@ -123,7 +123,7 @@ def __init__(
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* MORRConfig.radius
* MORRConfig.effective_index
* (
@@ -241,7 +241,7 @@ def reset_parameters(self, morr_init: bool = False) -> None:
(t2 - t1) / (2.4 * self.morr_fwhm)
).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py
index c1181cfc1..f93281af5 100644
--- a/src/chop/nn/optical/modules/morr_linear.py
+++ b/src/chop/nn/optical/modules/morr_linear.py
@@ -62,7 +62,7 @@ def __init__(
self.v_max = 10.8
self.v_pi = 4.36
- self.gamma = np.pi / self.v_pi ** 2
+ self.gamma = np.pi / self.v_pi**2
self.w_bit = 32
self.in_bit = 32
@@ -80,7 +80,7 @@ def __init__(
### calculate FWHM (rad)
self.morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* morr_config.radius
* morr_config.effective_index
* (
@@ -198,7 +198,7 @@ def reset_parameters(self, morr_init: bool = False) -> None:
(t2 - t1) / (2.4 * self.morr_fwhm)
).item() ## 0~2.4 FWHM slope as a linear approximation
- self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm)
+ self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm)
self.out_scale_quant_gain = None
init.normal_(self.morr_output_scale, 0, self.sigma_out_scale)
else:
diff --git a/src/chop/nn/optical/utils/initializer.py b/src/chop/nn/optical/utils/initializer.py
index 19c97b6dd..cbdd1f83c 100644
--- a/src/chop/nn/optical/utils/initializer.py
+++ b/src/chop/nn/optical/utils/initializer.py
@@ -31,7 +31,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1):
"""
morr_fwhm = (
-4
- * np.pi ** 2
+ * np.pi**2
* MORRConfig.radius
* MORRConfig.effective_index
* (
diff --git a/src/chop/nn/optical/utils/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py
index 6c397b3f5..189db8761 100644
--- a/src/chop/nn/optical/utils/mrr_op.py
+++ b/src/chop/nn/optical/utils/mrr_op.py
@@ -78,7 +78,7 @@ def mrr_roundtrip_phase_to_tr_func(
c1 = -2 * a * r
c2 = a * a + r * r
c3 = 1 + r * r * a * a - a * a - r * r
- c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r
+ c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r
class MRRRoundTripPhaseToTrFunction(torch.autograd.Function):
@staticmethod
diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py
index 828da142d..84372c8c7 100644
--- a/src/chop/nn/optical/utils/quantize.py
+++ b/src/chop/nn/optical/utils/quantize.py
@@ -32,7 +32,7 @@ def forward(ctx, input):
elif k == 1:
out = torch.sign(input)
else:
- n = float(2 ** k - 1)
+ n = float(2**k - 1)
out = torch.round(input * n) / n
return out
@@ -63,7 +63,7 @@ def forward(ctx, input, scale, zero_point):
elif k == 1:
out = torch.sign(input)
else:
- n = float(2 ** k - 1)
+ n = float(2**k - 1)
# out = torch.round(input * n) / n
# out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale
out = (
@@ -133,7 +133,7 @@ def __init__(
qscheme=torch.per_tensor_affine,
reduce_range=False,
quant_min=0,
- quant_max=2 ** self.in_bit - 1,
+ quant_max=2**self.in_bit - 1,
).to(self.device)
else:
self.obs = None
diff --git a/src/chop/nn/quantized/functional/gelu.py b/src/chop/nn/quantized/functional/gelu.py
index 225ff70ce..cee5e3317 100644
--- a/src/chop/nn/quantized/functional/gelu.py
+++ b/src/chop/nn/quantized/functional/gelu.py
@@ -107,7 +107,9 @@ def gelu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.gelu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py
index f518f4968..5fda700de 100644
--- a/src/chop/nn/quantized/functional/linear.py
+++ b/src/chop/nn/quantized/functional/linear.py
@@ -72,7 +72,10 @@ def linearInteger(
def linearMinifloatDenorm(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -119,7 +122,10 @@ def linearMinifloatDenorm(
def linearMinifloatIEEE(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -166,7 +172,10 @@ def linearMinifloatIEEE(
def linearMinifloatIEEE(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
@@ -213,7 +222,10 @@ def linearMinifloatIEEE(
def linearLog(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_width, w_exponent_bias = (
config["weight_width"],
@@ -228,11 +240,23 @@ def linearLog(
config["bias_exponent_bias"],
)
- w_quantizer = partial(log_quantizer, width=w_width, exponent_bias=w_exponent_bias,)
+ w_quantizer = partial(
+ log_quantizer,
+ width=w_width,
+ exponent_bias=w_exponent_bias,
+ )
- x_quantizer = partial(log_quantizer, width=x_width, exponent_bias=x_exponent_bias,)
+ x_quantizer = partial(
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
+ )
- b_quantizer = partial(log_quantizer, width=b_width, exponent_bias=b_exponent_bias,)
+ b_quantizer = partial(
+ log_quantizer,
+ width=b_width,
+ exponent_bias=b_exponent_bias,
+ )
x = x_quantizer(x)
weight = w_quantizer(weight)
@@ -242,7 +266,10 @@ def linearLog(
def linearBlockFP(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
# establish quantizers
w_width, w_exponent_width, w_exponent_bias, w_block_size = (
@@ -300,7 +327,10 @@ def linearBlockFP(
def linearBlockMinifloat(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
# establish quantizers
w_width, w_exponent_width, w_exponent_bias_width, w_block_size = (
@@ -358,7 +388,10 @@ def linearBlockMinifloat(
def linearBlockLog(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
# establish quantizers
w_width, w_exponent_bias_width, w_block_size = (
@@ -410,7 +443,10 @@ def linearBlockLog(
def linearBinary(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_stochastic = config["weight_stochastic"]
w_bipolar = config["weight_bipolar"]
@@ -426,7 +462,10 @@ def linearBinary(
def linearBinaryScaling(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
"""
Binary scaling variant of the linear transformation layer.
@@ -474,7 +513,10 @@ def linearBinaryScaling(
def linearTernary(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
w_scaling_factor = config["weight_scaling_factor"]
w_mean = get_stats(config, "weight_mean")
@@ -498,19 +540,28 @@ def linearTernary(
def linearBinaryResidualSign(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
raise NotImplementedError
def linearLUT(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
raise NotImplementedError
def linearLogicNets(
- x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None,
+ x: Tensor,
+ weight: Tensor,
+ bias: Tensor = None,
+ config: dict = None,
):
raise NotImplementedError
diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py
index f6487dc8f..d06eb1ece 100644
--- a/src/chop/nn/quantized/functional/matmul.py
+++ b/src/chop/nn/quantized/functional/matmul.py
@@ -176,10 +176,14 @@ def generic_matmul_log(x, y, config, style="matmul"):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
y_quantizer = partial(
- log_quantizer, width=y_width, exponent_bias=y_exponent_bias,
+ log_quantizer,
+ width=y_width,
+ exponent_bias=y_exponent_bias,
)
x = x_quantizer(x)
y = y_quantizer(y)
diff --git a/src/chop/nn/quantized/functional/relu.py b/src/chop/nn/quantized/functional/relu.py
index cb57d078e..57daed04a 100644
--- a/src/chop/nn/quantized/functional/relu.py
+++ b/src/chop/nn/quantized/functional/relu.py
@@ -107,7 +107,9 @@ def relu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.relu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/selu.py b/src/chop/nn/quantized/functional/selu.py
index b1edebdbf..12956c392 100644
--- a/src/chop/nn/quantized/functional/selu.py
+++ b/src/chop/nn/quantized/functional/selu.py
@@ -107,7 +107,9 @@ def selu_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.selu(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/softplus.py b/src/chop/nn/quantized/functional/softplus.py
index c873e7104..a9bd7dafc 100644
--- a/src/chop/nn/quantized/functional/softplus.py
+++ b/src/chop/nn/quantized/functional/softplus.py
@@ -107,7 +107,9 @@ def softplus_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.softplus(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/softsign.py b/src/chop/nn/quantized/functional/softsign.py
index c60f5f757..3eaab47fa 100644
--- a/src/chop/nn/quantized/functional/softsign.py
+++ b/src/chop/nn/quantized/functional/softsign.py
@@ -107,7 +107,9 @@ def softsign_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.softsign(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/functional/tanh.py b/src/chop/nn/quantized/functional/tanh.py
index 8b1009ac0..7d3c67c31 100644
--- a/src/chop/nn/quantized/functional/tanh.py
+++ b/src/chop/nn/quantized/functional/tanh.py
@@ -107,7 +107,9 @@ def tanh_log(x, inplace=False, config=None):
)
x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
return F.tanh(x_quantizer(x), inplace=inplace)
diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py
index 8176f4d52..93f4801e8 100644
--- a/src/chop/nn/quantized/modules/attention_head.py
+++ b/src/chop/nn/quantized/modules/attention_head.py
@@ -63,7 +63,10 @@ class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase):
def __init__(self, config, q_config: dict = None) -> None:
super().__init__(config)
- self.query_quantizer = partial(integer_quantizer, **q_config,)
+ self.query_quantizer = partial(
+ integer_quantizer,
+ **q_config,
+ )
self.key_quantizer = partial(integer_quantizer, **q_config)
self.value_quantizer = partial(integer_quantizer, **q_config)
diff --git a/src/chop/nn/quantized/modules/conv1d.py b/src/chop/nn/quantized/modules/conv1d.py
index 03662097a..67654d917 100644
--- a/src/chop/nn/quantized/modules/conv1d.py
+++ b/src/chop/nn/quantized/modules/conv1d.py
@@ -274,15 +274,21 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
+ log_quantizer,
+ width=w_width,
+ exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
+ log_quantizer,
+ width=b_width,
+ exponent_bias=b_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/conv2d.py b/src/chop/nn/quantized/modules/conv2d.py
index 3d297842e..cc8cd982a 100644
--- a/src/chop/nn/quantized/modules/conv2d.py
+++ b/src/chop/nn/quantized/modules/conv2d.py
@@ -363,15 +363,21 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
+ log_quantizer,
+ width=w_width,
+ exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
+ log_quantizer,
+ width=b_width,
+ exponent_bias=b_exponent_bias,
)
@@ -424,15 +430,21 @@ def __init__(
)
self.w_quantizer = partial(
- log_quantizer, width=w_width, exponent_bias=w_exponent_bias,
+ log_quantizer,
+ width=w_width,
+ exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
- log_quantizer, width=b_width, exponent_bias=b_exponent_bias,
+ log_quantizer,
+ width=b_width,
+ exponent_bias=b_exponent_bias,
)
@@ -1128,7 +1140,10 @@ def __init__(
)
self.unfold = torch.nn.Unfold(
- kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride,
)
self.fold = torch.nn.Fold(
@@ -1228,7 +1243,11 @@ def forward(
expanded_input, targets, initalize
).squeeze() # [10, 589824]
output = output.view(
- batch_size, self.out_channels, self._out_dim(0), self._out_dim(1), -1,
+ batch_size,
+ self.out_channels,
+ self._out_dim(0),
+ self._out_dim(1),
+ -1,
).sum(
-1
) # [10, 256, 1, 1, 2304] -> [10, 256, 1, 1]
@@ -1411,10 +1430,10 @@ def forward(self, x: Tensor) -> Tensor:
return self.decode(self.lut_forward(x))
def encode(self, input: Tensor) -> Tensor:
- return input * 2 ** self.x_frac_width
+ return input * 2**self.x_frac_width
def decode(self, input: Tensor) -> Tensor:
- return input / 2 ** self.x_frac_width
+ return input / 2**self.x_frac_width
def math_forward(self, input: Tensor) -> Tensor:
return self.y_quantizer(
diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py
index ace6f5510..4e579efd1 100644
--- a/src/chop/nn/quantized/modules/gelu.py
+++ b/src/chop/nn/quantized/modules/gelu.py
@@ -123,7 +123,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -141,7 +143,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/gqa.py b/src/chop/nn/quantized/modules/gqa.py
index 6cac74dda..1445b05ec 100644
--- a/src/chop/nn/quantized/modules/gqa.py
+++ b/src/chop/nn/quantized/modules/gqa.py
@@ -89,7 +89,10 @@ def __init__(
)
self.v_matmul_func = partial(
- matmul_integer, config=config, out_config=out_config, floor=floor,
+ matmul_integer,
+ config=config,
+ out_config=out_config,
+ floor=floor,
)
o_projection_q_config = {
diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py
index b49c86c54..5d8d389a5 100644
--- a/src/chop/nn/quantized/modules/linear.py
+++ b/src/chop/nn/quantized/modules/linear.py
@@ -66,7 +66,11 @@ def __init__(
dtype=None,
) -> None:
super().__init__(
- in_features, out_features, bias, device, dtype,
+ in_features,
+ out_features,
+ bias,
+ device,
+ dtype,
)
self.bypass = False
self.pruning_masks = None
@@ -444,7 +448,9 @@ def forward(self, x: Tensor) -> Tensor:
if self.binary_training:
w = self.w_quantizer(self.weight)
return F.linear(
- x_expanded, w * self.gamma.abs() * self.pruning_masks, self.bias,
+ x_expanded,
+ w * self.gamma.abs() * self.pruning_masks,
+ self.bias,
)
else:
self.weigh = self.weight.data.clamp_(-1, 1)
@@ -554,7 +560,9 @@ def forward(
output = output.view(batch_size, -1)
assert output.shape[-1] == self.tables_count
output = output.view(
- batch_size, self.out_features, int(self.tables_count / self.out_features),
+ batch_size,
+ self.out_features,
+ int(self.tables_count / self.out_features),
)
output = output.sum(-1)
if self.bias is not None:
@@ -765,10 +773,10 @@ def run_layers(self, input: Tensor, layers) -> Tensor:
return y
def encode(self, input: Tensor) -> Tensor:
- return input * 2 ** self.x_frac_width
+ return input * 2**self.x_frac_width
def decode(self, input: Tensor) -> Tensor:
- return input / 2 ** self.x_frac_width
+ return input / 2**self.x_frac_width
def forward(self, x: Tensor) -> Tensor:
if self.is_lut_inference:
diff --git a/src/chop/nn/quantized/modules/relu.py b/src/chop/nn/quantized/modules/relu.py
index a4d1acfbf..2bc527161 100644
--- a/src/chop/nn/quantized/modules/relu.py
+++ b/src/chop/nn/quantized/modules/relu.py
@@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/selu.py b/src/chop/nn/quantized/modules/selu.py
index 482c7c5d1..066ffc0b7 100644
--- a/src/chop/nn/quantized/modules/selu.py
+++ b/src/chop/nn/quantized/modules/selu.py
@@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/silu.py b/src/chop/nn/quantized/modules/silu.py
index 30dac9e5d..07f18ef6e 100644
--- a/src/chop/nn/quantized/modules/silu.py
+++ b/src/chop/nn/quantized/modules/silu.py
@@ -113,7 +113,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -131,7 +133,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/softplus.py b/src/chop/nn/quantized/modules/softplus.py
index 458558b63..4e8465c56 100644
--- a/src/chop/nn/quantized/modules/softplus.py
+++ b/src/chop/nn/quantized/modules/softplus.py
@@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/softsign.py b/src/chop/nn/quantized/modules/softsign.py
index fe3e53b62..5497426aa 100644
--- a/src/chop/nn/quantized/modules/softsign.py
+++ b/src/chop/nn/quantized/modules/softsign.py
@@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantized/modules/tanh.py b/src/chop/nn/quantized/modules/tanh.py
index 3378a612f..fce343489 100644
--- a/src/chop/nn/quantized/modules/tanh.py
+++ b/src/chop/nn/quantized/modules/tanh.py
@@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
@@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None):
config["data_in_exponent_bias"],
)
self.x_quantizer = partial(
- log_quantizer, width=x_width, exponent_bias=x_exponent_bias,
+ log_quantizer,
+ width=x_width,
+ exponent_bias=x_exponent_bias,
)
diff --git a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
index 72478f46d..14cf53839 100644
--- a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
+++ b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py
@@ -62,7 +62,10 @@ def update_luts_weights(self) -> torch.Tensor:
key = row.detach().cpu().flatten().sign().numpy().tolist()
new_weights.append(key)
new_weights = torch.tensor(
- new_weights, dtype=torch.float32, requires_grad=True, device=self.device,
+ new_weights,
+ dtype=torch.float32,
+ requires_grad=True,
+ device=self.device,
).view(-1, self.kk)
return new_weights
diff --git a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
index 45d95805c..ff48436e0 100644
--- a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
+++ b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py
@@ -38,7 +38,7 @@ def __init__(
levels (int): Number of residual level to use.
"""
self.k = k
- self.kk = 2 ** k
+ self.kk = 2**k
self.binarization_level = binarization_level
self.input_expanded = input_expanded
self.tables_count = tables_count
diff --git a/src/chop/nn/quantizers/block_fp.py b/src/chop/nn/quantizers/block_fp.py
index 826133862..bfe8534e3 100644
--- a/src/chop/nn/quantizers/block_fp.py
+++ b/src/chop/nn/quantizers/block_fp.py
@@ -47,10 +47,10 @@ def _block_fp_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
- exponent_max = 2 ** exponent_width - 1 - exponent_bias
+ exponent_max = 2**exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
- mantissa_integer_max = 2 ** mantissa_bits - 1
+ mantissa_integer_max = 2**mantissa_bits - 1
# sign
per_block_sign = torch.sign(blocked_x + 1e-9)
# exponent
@@ -58,14 +58,14 @@ def _block_fp_quantize(
per_block_exponent = torch.ceil(torch.log2(per_block_max))
per_block_exponent = my_clamp(per_block_exponent, exponent_min, exponent_max)
# mantissa
- per_block_mantissa = per_block_value / 2 ** per_block_exponent
- shift = 2 ** mantissa_bits
+ per_block_mantissa = per_block_value / 2**per_block_exponent
+ shift = 2**mantissa_bits
per_block_mantissa_integer = my_clamp(
my_round(per_block_mantissa * shift), 0, mantissa_integer_max
)
per_block_mantissa = per_block_mantissa_integer / shift
- per_block_msfp = per_block_sign * (2 ** per_block_exponent) * per_block_mantissa
+ per_block_msfp = per_block_sign * (2**per_block_exponent) * per_block_mantissa
msfp_x = unblock(
per_block_msfp,
x_shape_before_blocking=x_shape_before_blocking,
@@ -133,5 +133,10 @@ def block_fp_quantizer(
"""
return BlockFPQuantize.apply(
- x, width, exponent_width, exponent_bias, block_size, skip_first_dim,
+ x,
+ width,
+ exponent_width,
+ exponent_bias,
+ block_size,
+ skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/block_log.py b/src/chop/nn/quantizers/block_log.py
index 1773b42b8..8e65c91ea 100644
--- a/src/chop/nn/quantizers/block_log.py
+++ b/src/chop/nn/quantizers/block_log.py
@@ -40,7 +40,7 @@ def _block_log_quantize(
per_block_max_exponent = torch.ceil(torch.log2(per_block_max))
per_block_bias = my_clamp(
- 2 ** exponent_bits - 1 - per_block_max_exponent, 0, 2 ** exponent_bias_width - 1
+ 2**exponent_bits - 1 - per_block_max_exponent, 0, 2**exponent_bias_width - 1
)
per_block_lq_x = _log_quantize(blocked_x, width=width, exponent_bias=per_block_bias)
@@ -98,5 +98,9 @@ def block_log_quantizer(
- `block_size`: a list of integers where each integer is the block size along the corresponding dim
"""
return BlockLogQuantize.apply(
- x, width, exponent_bias_width, block_size, skip_first_dim,
+ x,
+ width,
+ exponent_bias_width,
+ block_size,
+ skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/block_minifloat.py b/src/chop/nn/quantizers/block_minifloat.py
index ccef6649d..34e00bbcb 100644
--- a/src/chop/nn/quantizers/block_minifloat.py
+++ b/src/chop/nn/quantizers/block_minifloat.py
@@ -41,7 +41,7 @@ def _block_minifloat_quantize(
per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min()
per_block_exponent_bias = my_clamp(
- torch.floor(torch.log2(per_block_max)), 0, 2 ** exponent_bias_width - 1
+ torch.floor(torch.log2(per_block_max)), 0, 2**exponent_bias_width - 1
)
per_block_bm_x = _minifloat_ieee_quantize(
blocked_x,
@@ -118,5 +118,10 @@ def block_minifloat_quantizer(
"""
return BlockMinifloatQuantize.apply(
- x, width, exponent_width, exponent_bias_width, block_size, skip_first_dim,
+ x,
+ width,
+ exponent_width,
+ exponent_bias_width,
+ block_size,
+ skip_first_dim,
)
diff --git a/src/chop/nn/quantizers/integer.py b/src/chop/nn/quantizers/integer.py
index 8c2573f22..9f3ffec1e 100644
--- a/src/chop/nn/quantizers/integer.py
+++ b/src/chop/nn/quantizers/integer.py
@@ -34,9 +34,9 @@ def _integer_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2 ** width - 1
+ int_max = 2**width - 1
# thresh = 2 ** (width - 1)
- scale = 2 ** frac_width
+ scale = 2**frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
@@ -57,8 +57,8 @@ def _integer_floor_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2 ** width - 1
- scale = 2 ** frac_width
+ int_max = 2**width - 1
+ scale = 2**frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_floor(x.mul(scale)), int_min, int_max).div(scale)
diff --git a/src/chop/nn/quantizers/log.py b/src/chop/nn/quantizers/log.py
index 926e33cdb..98ab45e76 100644
--- a/src/chop/nn/quantizers/log.py
+++ b/src/chop/nn/quantizers/log.py
@@ -8,7 +8,9 @@
def _log_quantize(
- x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None,
+ x: Tensor | ndarray,
+ width: int,
+ exponent_bias: int | Tensor | ndarray | None,
):
"""
- Use non-uniform, base-2 logarithmic representation to encode IEEE FP32/64
@@ -30,16 +32,16 @@ def _log_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_bits - 1) - 1
- exponent_max = 2 ** exponent_bits - 1 - exponent_bias
+ exponent_max = 2**exponent_bits - 1 - exponent_bias
exponent_min = -exponent_bias
- min_pos = 2 ** exponent_min
+ min_pos = 2**exponent_min
sign = torch.sign(x + min_pos * 0.1)
value = torch.abs(x) + min_pos * 0.1
exponent = my_clamp(my_round(torch.log2(value)), exponent_min, exponent_max)
- return sign * (2 ** exponent)
+ return sign * (2**exponent)
class LogQuantize(torch.autograd.Function):
@@ -56,7 +58,9 @@ def backward(ctx, grad_output):
def log_quantizer(
- x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None,
+ x: Tensor | ndarray,
+ width: int,
+ exponent_bias: int | Tensor | ndarray | None,
):
"""
Convert IEEE FP32/64 to base-2 log quantized values
diff --git a/src/chop/nn/quantizers/minifloat.py b/src/chop/nn/quantizers/minifloat.py
index f19097fde..2d6f23103 100644
--- a/src/chop/nn/quantizers/minifloat.py
+++ b/src/chop/nn/quantizers/minifloat.py
@@ -5,7 +5,10 @@
def _minifloat_denorm_quantize(
- x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
+ x: Tensor,
+ width: int,
+ exponent_width: int,
+ exponent_bias: int = None,
):
"""
- Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas.
@@ -34,10 +37,10 @@ def _minifloat_denorm_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
- exponent_max = 2 ** exponent_width - 1 - exponent_bias
+ exponent_max = 2**exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# if the mantissa is an integer, the max mantissa value will be (2**mantissa_bits -1)
- shifted_mantissa_max = 2 ** mantissa_bits - 1
+ shifted_mantissa_max = 2**mantissa_bits - 1
shifted_mantissa_min = 0
sign = torch.sign(x + 1e-9)
@@ -49,8 +52,8 @@ def _minifloat_denorm_quantize(
# divide value by clipped exponent. this ensures the simulated minifloat value is correct
# when x is too large (minifloat will saturate) or too close to 0.
- mantissa = value / 2 ** exponent
- shift = 2 ** mantissa_bits
+ mantissa = value / 2**exponent
+ shift = 2**mantissa_bits
shifted_mantissa = my_round(mantissa * shift)
# clip the integer mantissa.
shifted_mantissa = my_clamp(
@@ -68,7 +71,11 @@ def _minifloat_denorm_quantize(
class MinifloatDenormQuantize(torch.autograd.Function):
@staticmethod
def forward(
- ctx, x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
+ ctx,
+ x: Tensor,
+ width: int,
+ exponent_width: int,
+ exponent_bias: int = None,
):
return _minifloat_denorm_quantize(
x, width=width, exponent_width=exponent_width, exponent_bias=exponent_bias
@@ -81,7 +88,10 @@ def backward(ctx, grad_output):
def minifloat_denorm_quantizer(
- x: Tensor, width: int, exponent_width: int, exponent_bias: int = None,
+ x: Tensor,
+ width: int,
+ exponent_width: int,
+ exponent_bias: int = None,
):
"""
- Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas.
@@ -138,11 +148,11 @@ def _minifloat_ieee_quantize(
if exponent_bias in (None, "none", "None"):
exponent_bias = 2 ** (exponent_width - 1) - 1
# upper and lower bound of shifted exponent
- exponent_max = 2 ** exponent_width - 1 - exponent_bias
+ exponent_max = 2**exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# upper and lower bound of shifted minifloat mantissa
- shift = 2 ** mantissa_bits
- shifted_mantissa_max = 2 ** mantissa_bits - 1
+ shift = 2**mantissa_bits
+ shifted_mantissa_max = 2**mantissa_bits - 1
shifted_mantissa_min = 0
sign = torch.sign(x + 1e-9)
@@ -152,9 +162,9 @@ def _minifloat_ieee_quantize(
exponent = torch.floor(torch.log2(value + 1e-9))
exponent = my_clamp(exponent, exponent_min, exponent_max)
- mantissa = value / 2 ** exponent
+ mantissa = value / 2**exponent
- shift = 2 ** mantissa_bits
+ shift = 2**mantissa_bits
# fmt: off
# if the clipped exponent is zero, the minifloat is in a subnormal form
# this `is_normal` also help the grad keeps 1 if input x is 0, or the zero-initialized value will be trapped in 0
diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py
index 1a61fb70c..0c3e06130 100644
--- a/src/chop/nn/quantizers/mxint_hardware.py
+++ b/src/chop/nn/quantizers/mxint_hardware.py
@@ -19,7 +19,7 @@ def mxint_quant_block(
"""
exponent_bias = 2 ** (exponent_width - 1)
- exponent_max = 2 ** exponent_width - 1 - exponent_bias
+ exponent_max = 2**exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# exponent
@@ -29,9 +29,9 @@ def mxint_quant_block(
# mantissa
int_min = -(2 ** (width - 1))
int_max = 2 ** (width - 1) - 1
- mantissa = x / 2 ** exponent
+ mantissa = x / 2**exponent
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- q_x = (2 ** exponent) * mantissa
+ q_x = (2**exponent) * mantissa
return q_x
diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py
index ccf57319f..d5ca3d8cf 100644
--- a/src/chop/nn/quantizers/quantizers_for_hw.py
+++ b/src/chop/nn/quantizers/quantizers_for_hw.py
@@ -9,31 +9,31 @@
def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
thresh = 2 ** (width - 1)
- scale = 2 ** frac_width
+ scale = 2**frac_width
fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2 ** width)
+ fixed_point_value = fixed_point_value % (2**width)
return fixed_point_value
def unsigned_integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
- thresh = 2 ** width - 1
- scale = 2 ** frac_width
+ thresh = 2**width - 1
+ scale = 2**frac_width
fixed_point_value = my_clamp(my_floor(x.mul(scale)), 0, thresh)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2 ** width)
+ fixed_point_value = fixed_point_value % (2**width)
return fixed_point_value
def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
thresh = 2 ** (width - 1)
- scale = 2 ** frac_width
+ scale = 2**frac_width
fixed_point_value = my_clamp(my_floor(x.mul(scale)), -thresh, thresh - 1)
fixed_point_value = fixed_point_value.to(torch.int)
- fixed_point_value = fixed_point_value % (2 ** width)
+ fixed_point_value = fixed_point_value % (2**width)
return fixed_point_value
diff --git a/src/chop/nn/quantizers/ternary.py b/src/chop/nn/quantizers/ternary.py
index 89fcfe75b..a27951c23 100644
--- a/src/chop/nn/quantizers/ternary.py
+++ b/src/chop/nn/quantizers/ternary.py
@@ -45,7 +45,8 @@ def ternary_quantizer(
# )
if scaling_factor:
x = ternarised_scaled_op(
- x, threshold, # abs_mean=mean
+ x,
+ threshold, # abs_mean=mean
) # [mean, 0 ,-mean] # this function determines the mean on the fly, maybe we could make an alternative which uses the metadata?
else:
x = ternarised_op(x, threshold) # [1, 0 ,-1]
diff --git a/src/chop/nn/quantizers/utils.py b/src/chop/nn/quantizers/utils.py
index 9efc8b512..ae881328d 100644
--- a/src/chop/nn/quantizers/utils.py
+++ b/src/chop/nn/quantizers/utils.py
@@ -224,8 +224,16 @@ def forward(ctx, input, _threshold):
alpha = TernaryScaled.alpha(input, delta)
output = torch.zeros_like(input)
- pos_one = torch.where(input > delta, 1.0, 0.0,)
- neg_one = torch.where(input < -delta, -1.0, 0.0,)
+ pos_one = torch.where(
+ input > delta,
+ 1.0,
+ 0.0,
+ )
+ neg_one = torch.where(
+ input < -delta,
+ -1.0,
+ 0.0,
+ )
output = (pos_one + neg_one) * alpha.view(-1, 1, 1, 1).expand(
-1, input.size()[1], input.size()[2], input.size()[3]
)
@@ -287,8 +295,16 @@ def forward(ctx, input, _threshold):
alpha = TernaryScaled.alpha(input, delta)
output = torch.zeros_like(input)
- pos_one = torch.where(input > delta, 1.0, 0.0,)
- neg_one = torch.where(input < -delta, -1.0, 0.0,)
+ pos_one = torch.where(
+ input > delta,
+ 1.0,
+ 0.0,
+ )
+ neg_one = torch.where(
+ input < -delta,
+ -1.0,
+ 0.0,
+ )
output = pos_one + neg_one
return output
@@ -397,7 +413,8 @@ def _block_1d_bias(x: Tensor, block_shape: List[int]):
def _unblock_to_1d_bias(
- blocked_x: Tensor, x_shape_before_blocking: List[int],
+ blocked_x: Tensor,
+ x_shape_before_blocking: List[int],
):
"""
blocked bias shape: [num_blocks, block_size] -> [output_features]
@@ -592,7 +609,10 @@ def unblock(
return _unblock_to_2d_activation(blocked_x, x_shape_before_blocking)
else:
return _unblock_to_2d_weight(
- blocked_x, x_shape_before_blocking, padded_x_shape, block_shape,
+ blocked_x,
+ x_shape_before_blocking,
+ padded_x_shape,
+ block_shape,
)
elif len(x_shape_before_blocking) == 3:
if skipped_first_dim_when_blocking:
diff --git a/src/chop/nn/snn/auto_cuda/generator.py b/src/chop/nn/snn/auto_cuda/generator.py
index f1637ddf9..639cf0379 100644
--- a/src/chop/nn/snn/auto_cuda/generator.py
+++ b/src/chop/nn/snn/auto_cuda/generator.py
@@ -310,7 +310,10 @@ def gen_forward_codes(
params.append(("v_reset", "const float &"))
params.extend(
- [("neuron_num", "const int &"), ("numel", "const int &"),]
+ [
+ ("neuron_num", "const int &"),
+ ("numel", "const int &"),
+ ]
)
params_name = []
for item in params:
diff --git a/src/chop/nn/snn/modules/spiking_self_attention.py b/src/chop/nn/snn/modules/spiking_self_attention.py
index 3fa47d1d8..87ee98a30 100644
--- a/src/chop/nn/snn/modules/spiking_self_attention.py
+++ b/src/chop/nn/snn/modules/spiking_self_attention.py
@@ -145,7 +145,9 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L
super().__init__()
inner_channels = in_channels * ratio
self.up = nn.Sequential(
- activation(), Conv1x1(in_channels, inner_channels), BN(inner_channels),
+ activation(),
+ Conv1x1(in_channels, inner_channels),
+ BN(inner_channels),
)
self.conv = nn.ModuleList()
for _ in range(num_conv):
@@ -161,7 +163,9 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L
)
)
self.down = nn.Sequential(
- activation(), Conv1x1(inner_channels, in_channels), BN(in_channels),
+ activation(),
+ Conv1x1(inner_channels, in_channels),
+ BN(in_channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py
index 3274c9143..4b1eb96ef 100644
--- a/src/chop/passes/graph/__init__.py
+++ b/src/chop/passes/graph/__init__.py
@@ -149,6 +149,6 @@
if check_dependencies("tensorrt_fake_quantize_transform_pass"):
TRANSFORM_PASSES.append("tensorrt_fake_quantize_transform_pass")
- PASSES[
- "tensorrt_fake_quantize_transform_pass"
- ] = tensorrt_fake_quantize_transform_pass
+ PASSES["tensorrt_fake_quantize_transform_pass"] = (
+ tensorrt_fake_quantize_transform_pass
+ )
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
index 9fde9d7cb..66fe9cc28 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py
@@ -203,7 +203,10 @@ def graph_iterator_for_mase_ops(graph):
def graph_iterator_for_metadata(
- graph, dummy_in=None, add_value=True, force_device_meta=False,
+ graph,
+ dummy_in=None,
+ add_value=True,
+ force_device_meta=False,
):
"""
largely adapted from https://pytorch.org/docs/stable/fx.html
@@ -288,7 +291,12 @@ def _add_graph_metadata(graph):
def add_common_metadata_analysis_pass(
- graph, pass_args={"dummy_in": None, "add_value": True, "force_device_meta": False,},
+ graph,
+ pass_args={
+ "dummy_in": None,
+ "add_value": True,
+ "force_device_meta": False,
+ },
):
"""add common metadata
diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
index 9fd2e9466..39c8ac690 100644
--- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
+++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py
@@ -39,7 +39,8 @@ def add_component_source(node):
if mase_op == "user_defined_module":
for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items():
if isinstance(
- deepgetattr(node.meta["mase"].model, node.target), custom_op,
+ deepgetattr(node.meta["mase"].model, node.target),
+ custom_op,
):
node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL"
node.meta["mase"]["hardware"]["module"] = op_info["module"]
diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
index ed3249c9c..6f5ee0147 100644
--- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py
@@ -231,7 +231,10 @@
},
# Inserted ops from the replace_method_with_function pass
"torch_size": {"input": "data_in", "dim": "config"},
- "torch_contiguous": {"input": "data_in", "memory_format": "config",},
+ "torch_contiguous": {
+ "input": "data_in",
+ "memory_format": "config",
+ },
# arbitrary length - support up to 4
"torch_expand": {
"input": "data_in",
@@ -254,7 +257,11 @@
"shape_2": "config",
"shape_3": "config",
},
- "torch_split": {"input": "data_in", "split_size": "config", "dim": "config",},
+ "torch_split": {
+ "input": "data_in",
+ "split_size": "config",
+ "dim": "config",
+ },
"torch_permute": {
"input": "data_in",
"dim_0": "config",
@@ -262,7 +269,11 @@
"dim_2": "config",
"dim_3": "config",
},
- "torch_transpose": {"input": "data_in", "dim0": "config", "dim1": "config",},
+ "torch_transpose": {
+ "input": "data_in",
+ "dim0": "config",
+ "dim1": "config",
+ },
# DTensor ops
"dtensor_arange": {
"device_mesh": "config",
@@ -276,7 +287,9 @@
"requires_grad": "config",
},
# tensor constructor
- "tensor": {"data": "data_in",},
+ "tensor": {
+ "data": "data_in",
+ },
# https://pytorch.org/docs/stable/generated/torch.nn.functional.dropout.html
"dropout": {
"input": "data_in",
@@ -340,7 +353,10 @@
"softmax": {"input": "data_in"},
"gelu": {"input": "data_in"},
# https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
- "crossentropyloss": {"input": "data_in", "target": "data_in",},
+ "crossentropyloss": {
+ "input": "data_in",
+ "target": "data_in",
+ },
# chop.nn.modules.lora.LoRALinear
"loralinear": {"input": "data_in"},
"grouped_query_attention": {"input": "data_in"},
@@ -389,15 +405,24 @@
"size_3": "config",
},
# Tensor.max(dim=None, keepdim=False)
- "max": {"dim": "config", "keepdim": "config",},
+ "max": {
+ "dim": "config",
+ "keepdim": "config",
+ },
# https://pytorch.org/docs/stable/generated/torch.Tensor.sum.html
- "sum": {"dim": "config", "keepdim": "config",},
+ "sum": {
+ "dim": "config",
+ "keepdim": "config",
+ },
# https://pytorch.org/docs/stable/generated/torch.Tensor.round.html
"round": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.floor.html
"floor": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.clamp.html
- "clamp": {"min": "config", "max": "config",},
+ "clamp": {
+ "min": "config",
+ "max": "config",
+ },
# https://pytorch.org/docs/stable/generated/torch.Tensor.dim.html
"dim": {},
# https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html#torch.Tensor.permute
@@ -426,7 +451,11 @@
# https://pytorch.org/docs/stable/generated/torch.Tensor.type_as.html
"type_as": {"tensor": "data_in"},
# https://pytorch.org/docs/stable/generated/torch.Tensor.index_select.html
- "index_select": {"input": "data_in", "dim": "config", "index": "data_in",},
+ "index_select": {
+ "input": "data_in",
+ "dim": "config",
+ "index": "data_in",
+ },
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
"detach": {"input": "data_in"},
}
@@ -479,7 +508,11 @@ def deepgetattr(obj, attr):
def _annotate_arg_metadata(
- meta: MaseMetadata, args: list, kwargs: dict, func_data: dict, add_value: bool,
+ meta: MaseMetadata,
+ args: list,
+ kwargs: dict,
+ func_data: dict,
+ add_value: bool,
):
"""
Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"]
@@ -588,7 +621,9 @@ def _annotate_arg_metadata(
def _annotate_result_metadata(
- meta: MaseMetadata, result, add_value: bool,
+ meta: MaseMetadata,
+ result,
+ add_value: bool,
) -> MaseMetadata:
"""
Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata.
diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
index a92958cff..b084821ae 100644
--- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py
@@ -64,13 +64,17 @@
"relu": [
{
"name": "fixed_relu",
- "dependence_files": ["activation_layers/rtl/fixed_relu.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_relu.sv",
+ ],
},
],
"hardshrink": [
{
"name": "fixed_hardshrink",
- "dependence_files": ["activation_layers/rtl/fixed_hardshrink.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_hardshrink.sv",
+ ],
},
],
"silu": [
@@ -103,7 +107,9 @@
"softshrink": [
{
"name": "fixed_softshrink",
- "dependence_files": ["activation_layers/rtl/fixed_softshrink.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_softshrink.sv",
+ ],
},
],
"logsigmoid": [
@@ -132,13 +138,17 @@
"selu": [
{
"name": "fixed_selu",
- "dependence_files": ["activation_layers/rtl/fixed_selu.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_selu.sv",
+ ],
},
],
"tanh": [
{
"name": "fixed_tanh",
- "dependence_files": ["activation_layers/rtl/fixed_tanh.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_tanh.sv",
+ ],
},
],
"gelu": [
@@ -162,13 +172,17 @@
"softplus": [
{
"name": "fixed_softplus",
- "dependence_files": ["activation_layers/rtl/fixed_softplus.sv",],
+ "dependence_files": [
+ "activation_layers/rtl/fixed_softplus.sv",
+ ],
},
],
"add": [
{
"name": "fixed_adder",
- "dependence_files": ["linear_layers/fixed_operators/rtl/fixed_adder.sv",],
+ "dependence_files": [
+ "linear_layers/fixed_operators/rtl/fixed_adder.sv",
+ ],
}
],
"mul": [
@@ -185,7 +199,14 @@
"dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"],
}
],
- "getitem": [{"name": "buffer", "dependence_files": ["memory/rtl/buffer.sv",],}],
+ "getitem": [
+ {
+ "name": "buffer",
+ "dependence_files": [
+ "memory/rtl/buffer.sv",
+ ],
+ }
+ ],
"grouped_query_attention": [
{
"name": "fixed_gqa_wrapper",
diff --git a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
index 1e6ad2a57..d7f4fd8ed 100644
--- a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
+++ b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py
@@ -243,8 +243,16 @@ def analyze_software_meta_param_patched_func_default(meta):
"getitem": analyze_software_meta_param_implicit_func_default,
"getattr": analyze_software_meta_param_implicit_func_default,
},
- "placeholder": {"placeholder": analyze_software_meta_param_placeholder,},
- "get_attr": {"get_attr": analyze_software_meta_param_get_attr,},
- "output": {"output": analyze_software_meta_param_output,},
- "patched_func": {"default": analyze_software_meta_param_patched_func_default,},
+ "placeholder": {
+ "placeholder": analyze_software_meta_param_placeholder,
+ },
+ "get_attr": {
+ "get_attr": analyze_software_meta_param_get_attr,
+ },
+ "output": {
+ "output": analyze_software_meta_param_output,
+ },
+ "patched_func": {
+ "default": analyze_software_meta_param_patched_func_default,
+ },
}
diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
index 30216eea3..d9fea1aa5 100644
--- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
+++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py
@@ -74,7 +74,10 @@ def get_resharding_cost(
and (dest[0] == SpmdShard.R)
):
ag_dim = 1 if src[0] == dest[0] else 0
- return mesh.all_gather_cost(num_bytes=num_bytes, mesh_dim=ag_dim,)
+ return mesh.all_gather_cost(
+ num_bytes=num_bytes,
+ mesh_dim=ag_dim,
+ )
# All-to-all
# E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1)
@@ -82,7 +85,10 @@ def get_resharding_cost(
# all to all
a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value
try:
- return mesh.all_to_all_cost(num_bytes=num_bytes, mesh_dim=a2a_dim,)
+ return mesh.all_to_all_cost(
+ num_bytes=num_bytes,
+ mesh_dim=a2a_dim,
+ )
except:
assert False
diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py
index b49154c20..b149edf4a 100644
--- a/src/chop/passes/graph/analysis/autosharding/autosharding.py
+++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py
@@ -23,7 +23,10 @@ def deepgetattr(obj, attr, default=None):
def _import_solution(
- mg, solution: dict, mesh: MeshModel, extrapolate_sharding: bool = True,
+ mg,
+ solution: dict,
+ mesh: MeshModel,
+ extrapolate_sharding: bool = True,
):
"""Import an autosharding solution into the metadata of the MaseGraph.
@@ -68,14 +71,18 @@ def _import_solution(
# Annotate the metadata for each argument
for arg, arg_spec in solution[node.name].get("args", {}).items():
node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = _DTensorSpec(
- mesh=mesh, placements=arg_spec,
+ mesh=mesh,
+ placements=arg_spec,
)
# Annotate the metadata for each result
for result, result_spec in solution[node.name].get("results", {}).items():
- node.meta["mase"]["common"]["results"][result][
- "dtensor_spec"
- ] = _DTensorSpec(mesh=mesh, placements=result_spec,)
+ node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = (
+ _DTensorSpec(
+ mesh=mesh,
+ placements=result_spec,
+ )
+ )
return mg, {}
@@ -107,7 +114,10 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"):
logger.warning(
f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution."
)
- spec = _DTensorSpec(None, (Replicate(), Replicate()),)
+ spec = _DTensorSpec(
+ None,
+ (Replicate(), Replicate()),
+ )
else:
spec = arg_info["dtensor_spec"]
@@ -122,7 +132,10 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"):
logger.warning(
f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution."
)
- spec = _DTensorSpec(None, (Replicate(), Replicate()),)
+ spec = _DTensorSpec(
+ None,
+ (Replicate(), Replicate()),
+ )
else:
spec = result_info["dtensor_spec"]
out_dict[node_name]["results"][result] = spec.placements
@@ -184,7 +197,9 @@ def _get_sharding_map(mg):
if module not in tensor_sharding_map:
tensor_sharding_map[module] = {
"node": node.name,
- "sharding": {attr: out_specs,},
+ "sharding": {
+ attr: out_specs,
+ },
}
else:
tensor_sharding_map[module]["sharding"][attr] = out_specs
diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py
index ab9134a36..30cd36f7e 100644
--- a/src/chop/passes/graph/analysis/autosharding/megatron.py
+++ b/src/chop/passes/graph/analysis/autosharding/megatron.py
@@ -3,7 +3,9 @@
def megatron_autosharding_pass(
- mg: MaseGraph, mesh: MeshModel, pass_args: dict,
+ mg: MaseGraph,
+ mesh: MeshModel,
+ pass_args: dict,
):
for node in mg.fx_graph.nodes:
meta = node.meta["mase"]["common"]
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
index d68f61cb5..40b704ca8 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py
@@ -86,7 +86,10 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
def gen_einsum_strategies(
- equation: str, mesh: tuple, *, linearity: bool = False,
+ equation: str,
+ mesh: tuple,
+ *,
+ linearity: bool = False,
) -> OpStrategy:
"""
Generate a strategy list for the ops that follow einsum style notation.
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py
index e58980c60..2635a5587 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py
@@ -81,7 +81,11 @@ def fully_replicated_strategy(meta, mesh):
in_spec = _DTensorSpec(
mesh,
sharding,
- tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype,),
+ tensor_meta=TensorMeta(
+ shape=in_shape,
+ stride=None,
+ dtype=in_dtype,
+ ),
)
dtype_key = (
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
index 60a3de959..78328a95d 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py
@@ -24,7 +24,10 @@
from chop.ir.graph import MaseMetadata
-def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def transpose_strategy(
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
parent_node = meta.node.args[0]
self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"]
@@ -56,7 +59,11 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
return OpStrategy(strategies=transpose_strategies)
-def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def _mm_like_strategy(
+ mm_equation: str,
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()]
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
@@ -95,7 +102,9 @@ def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpS
def _addmm_like_strategy(
- mm_equation: str, meta: MaseMetadata, mesh: tuple,
+ mm_equation: str,
+ meta: MaseMetadata,
+ mesh: tuple,
) -> OpStrategy:
self_shape, mat1_shape, mat2_shape = [
@@ -161,24 +170,37 @@ def _addmm_like_strategy(
return mm_strategy
-def mm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def mm_strategy(
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
return _mm_like_strategy("mk,kn->mn", meta, mesh)
-def addmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def addmm_strategy(
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
return _addmm_like_strategy("mk,kn->mn", meta, mesh)
-def bmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def bmm_strategy(
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
return _mm_like_strategy("bmk,bkn->bmn", meta, mesh)
-def baddmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy:
+def baddmm_strategy(
+ meta: MaseMetadata,
+ mesh: tuple,
+) -> OpStrategy:
return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh)
def scaled_dot_product_flash_attention_strategy(
- meta: MaseMetadata, mesh: tuple,
+ meta: MaseMetadata,
+ mesh: tuple,
) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
index a7d8a739b..c3c06d35e 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py
@@ -125,7 +125,9 @@ def common_pointwise_strategy(
arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"],
)
input_target_placements = map_placements_after_broadcast(
- tuple(out_placements), common_shape, input_arg_dims_map,
+ tuple(out_placements),
+ common_shape,
+ input_arg_dims_map,
)
input_arg_target_spec = _DTensorSpec(
mesh=mesh,
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
index 6cd423580..505a2440b 100644
--- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
+++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py
@@ -236,7 +236,9 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
def dim_movedim(
- ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]],
+ ndim: int,
+ input: Union[int, Sequence[int]],
+ destination: Union[int, Sequence[int]],
) -> DimMap:
input = normalize_dims(input, ndim)
destination = normalize_dims(destination, ndim)
@@ -620,7 +622,8 @@ def reshape_strategy(meta, mesh):
)
output_strategy.strategies.append(
PlacementStrategy(
- output_specs=output_spec, input_specs=(input_tgt_spec,),
+ output_specs=output_spec,
+ input_specs=(input_tgt_spec,),
)
)
diff --git a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
index 04dde1876..b8ca08310 100644
--- a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
+++ b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py
@@ -36,7 +36,7 @@ def calculate_modules(module, in_data, out_data):
# Kernel size here can be either a single int for square kernel
# or a tuple (see
# https://pytorch.org/docs/stable/nn.html#torch.nn.MaxPool2d )
- window_size = module.kernel_size ** 2
+ window_size = module.kernel_size**2
else:
window_size = module.kernel_size[0] * module.kernel_size[1]
diff --git a/src/chop/passes/graph/analysis/plot/plot_graph.py b/src/chop/passes/graph/analysis/plot/plot_graph.py
index adcc93177..8af151b36 100644
--- a/src/chop/passes/graph/analysis/plot/plot_graph.py
+++ b/src/chop/passes/graph/analysis/plot/plot_graph.py
@@ -4,7 +4,10 @@
def plot_graph_analysis_pass(
- graph, pass_args={"file_name": None,},
+ graph,
+ pass_args={
+ "file_name": None,
+ },
):
graph.draw(pass_args["file_name"])
# nx_graph = nx.DiGraph()
diff --git a/src/chop/passes/graph/interface/tensorrt/quantize.py b/src/chop/passes/graph/interface/tensorrt/quantize.py
index 0a6773754..0de0cda60 100644
--- a/src/chop/passes/graph/interface/tensorrt/quantize.py
+++ b/src/chop/passes/graph/interface/tensorrt/quantize.py
@@ -21,7 +21,6 @@ def Quantizer(config):
"pytorch_quantization is not installed. Cannot use tensorrt quantize pass."
)
-
else:
import tensorrt as trt
from pytorch_quantization import quant_modules, calib
diff --git a/src/chop/passes/graph/transforms/dse/run_dse.py b/src/chop/passes/graph/transforms/dse/run_dse.py
index 991f9e1bd..28446c4ca 100644
--- a/src/chop/passes/graph/transforms/dse/run_dse.py
+++ b/src/chop/passes/graph/transforms/dse/run_dse.py
@@ -16,7 +16,7 @@ def get_factors(n):
set(
functools.reduce(
list.__add__,
- ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0),
+ ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
)
)
)
diff --git a/src/chop/passes/graph/transforms/lora.py b/src/chop/passes/graph/transforms/lora.py
index 87fb2362f..9c4559661 100644
--- a/src/chop/passes/graph/transforms/lora.py
+++ b/src/chop/passes/graph/transforms/lora.py
@@ -10,7 +10,8 @@
def insert_lora_adapter_transform_pass(
- mg: MaseGraph, pass_args={},
+ mg: MaseGraph,
+ pass_args={},
):
rank = pass_args.get("rank", 0)
@@ -41,7 +42,8 @@ def insert_lora_adapter_transform_pass(
def fuse_lora_weights_transform_pass(
- mg: MaseGraph, pass_args={},
+ mg: MaseGraph,
+ pass_args={},
):
for node in mg.nodes:
target = (
diff --git a/src/chop/passes/graph/transforms/onnxrt/quantize.py b/src/chop/passes/graph/transforms/onnxrt/quantize.py
index 39d568236..cacbb6272 100644
--- a/src/chop/passes/graph/transforms/onnxrt/quantize.py
+++ b/src/chop/passes/graph/transforms/onnxrt/quantize.py
@@ -54,7 +54,9 @@ def quantize_dynamic(self, model_path: PosixPath, quantized_model_path: PosixPat
)
quantized_model = quantize_dynamic(
- model_path, quantized_model_path, weight_type=precision,
+ model_path,
+ quantized_model_path,
+ weight_type=precision,
)
self.logger.info("Quantization complete. Model is now dynamically quantized.")
diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
index aca81dc51..665abc17e 100644
--- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py
+++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
@@ -128,11 +128,31 @@ def neurons_random_fan_in(
weight_criteria_map = {
- "local": {"elementwise": {"random": random, "l1-norm": l1,}},
- "global": {"elementwise": {"random": random, "l1-norm": global_weight_l1,}},
+ "local": {
+ "elementwise": {
+ "random": random,
+ "l1-norm": l1,
+ }
+ },
+ "global": {
+ "elementwise": {
+ "random": random,
+ "l1-norm": global_weight_l1,
+ }
+ },
}
activation_criteria_map = {
- "local": {"elementwise": {"random": random, "l1-norm": l1,}},
- "global": {"elementwise": {"random": random, "l1-norm": global_activation_l1,}},
+ "local": {
+ "elementwise": {
+ "random": random,
+ "l1-norm": l1,
+ }
+ },
+ "global": {
+ "elementwise": {
+ "random": random,
+ "l1-norm": global_activation_l1,
+ }
+ },
}
diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
index 890497a95..dea246fc2 100644
--- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
+++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py
@@ -27,7 +27,10 @@
"weight_entries": ("weight_width", "weight_frac_width"),
"data_in_entries": ("data_in_width", "data_in_frac_width"),
"bias_entries": ("bias_width", "bias_frac_width"),
- "data_out_entries": ("data_out_width", "data_out_frac_width",),
+ "data_out_entries": (
+ "data_out_width",
+ "data_out_frac_width",
+ ),
"additional_layers_entries": ("floor"),
},
"lutnet": {
@@ -62,9 +65,18 @@
"weight_width",
"weight_frac_width",
),
- "bias_entries": ("bias_width", "bias_frac_width",),
- "data_in_entries": ("data_in_width", "data_in_frac_width",),
- "data_out_entries": ("data_out_width", "data_out_frac_width",),
+ "bias_entries": (
+ "bias_width",
+ "bias_frac_width",
+ ),
+ "data_in_entries": (
+ "data_in_width",
+ "data_in_frac_width",
+ ),
+ "data_out_entries": (
+ "data_out_width",
+ "data_out_frac_width",
+ ),
"additional_layers_entries": {
"additional_layers_inputs",
"additional_layers_outputs",
@@ -72,9 +84,21 @@
},
},
"binary": {
- "weight_entries": ("weight_width", "weight_stochastic", "weight_bipolar",),
- "data_in_entries": ("data_in_width", "data_in_stochastic", "data_in_bipolar",),
- "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
+ "weight_entries": (
+ "weight_width",
+ "weight_stochastic",
+ "weight_bipolar",
+ ),
+ "data_in_entries": (
+ "data_in_width",
+ "data_in_stochastic",
+ "data_in_bipolar",
+ ),
+ "bias_entries": (
+ "bias_width",
+ "bias_stochastic",
+ "bias_bipolar",
+ ),
},
"binary_residual": {
"weight_entries": (
@@ -90,7 +114,11 @@
"data_in_residual_sign",
"data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet
),
- "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
+ "bias_entries": (
+ "bias_width",
+ "bias_stochastic",
+ "bias_bipolar",
+ ),
},
"binary_residual": {
"weight_entries": (
@@ -106,7 +134,11 @@
"data_in_residual_sign",
"data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet
),
- "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",),
+ "bias_entries": (
+ "bias_width",
+ "bias_stochastic",
+ "bias_bipolar",
+ ),
},
"ternary": {
"weight_entries": (
@@ -213,7 +245,11 @@
"data_in_exponent_bias_width",
"data_in_block_size",
),
- "bias_entries": ("bias_width", "bias_exponent_bias_width", "bias_block_size",),
+ "bias_entries": (
+ "bias_width",
+ "bias_exponent_bias_width",
+ "bias_block_size",
+ ),
},
"mxint_hardware": {
"weight_entries": (
@@ -226,7 +262,11 @@
"data_in_exponent_width",
"data_in_parallelism",
),
- "bias_entries": ("bias_width", "bias_exponent_width", "bias_parallelism",),
+ "bias_entries": (
+ "bias_width",
+ "bias_exponent_width",
+ "bias_parallelism",
+ ),
},
}
@@ -349,10 +389,22 @@ def cp_data_out_entries(
("name", "data_in_entries"),
("weight_entries", "bias_entries", "bypass"),
),
- "layer_norm": (("name", "data_in_entries"), ("bypass",),),
- "group_norm": (("name", "data_in_entries"), ("bypass",),),
- "instance_norm2d": (("name", "data_in_entries"), ("bypass",),),
- "rms_norm": (("name", "data_in_entries"), ("bypass",),),
+ "layer_norm": (
+ ("name", "data_in_entries"),
+ ("bypass",),
+ ),
+ "group_norm": (
+ ("name", "data_in_entries"),
+ ("bypass",),
+ ),
+ "instance_norm2d": (
+ ("name", "data_in_entries"),
+ ("bypass",),
+ ),
+ "rms_norm": (
+ ("name", "data_in_entries"),
+ ("bypass",),
+ ),
"grouped_query_attention": (
("name", "data_in_entries", "weight_entries"),
("bypass", "bias_entries"),
diff --git a/src/chop/passes/graph/transforms/training/modify.py b/src/chop/passes/graph/transforms/training/modify.py
index bf84faeee..eff5583d0 100644
--- a/src/chop/passes/graph/transforms/training/modify.py
+++ b/src/chop/passes/graph/transforms/training/modify.py
@@ -65,7 +65,10 @@ def attach_backward_fn(q_fn: torch.autograd.Function, mase_op: str, q_fn_cfg: di
def create_new_module(
- mase_op: str, original_module: nn.Module, config: dict, node_meta: dict,
+ mase_op: str,
+ original_module: nn.Module,
+ config: dict,
+ node_meta: dict,
):
original_module_cls = type(original_module)
diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py
index e562a5fb1..a6a000c79 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_bram.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py
@@ -245,8 +245,8 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
else:
base_quantizer = integer_quantizer_for_hw
- scale = 2 ** frac_width
- thresh = 2 ** width
+ scale = 2**frac_width
+ thresh = 2**width
for i in range(0, out_depth):
line_buff = ""
for j in range(0, out_size):
@@ -301,8 +301,8 @@ def emit_parameters_in_dat_hls(node, param_name, file_name):
"precision"
][1]
- scale = 2 ** frac_width
- thresh = 2 ** width
+ scale = 2**frac_width
+ thresh = 2**width
for i in range(0, out_depth):
line_buff = ""
for j in range(0, out_size):
diff --git a/src/chop/passes/graph/transforms/verilog/emit_hls.py b/src/chop/passes/graph/transforms/verilog/emit_hls.py
index ce348c43d..3efe2c037 100644
--- a/src/chop/passes/graph/transforms/verilog/emit_hls.py
+++ b/src/chop/passes/graph/transforms/verilog/emit_hls.py
@@ -121,7 +121,12 @@ def _call_hls_flow(node, node_dir):
# Call Vitis HLS for synthesis
vitis_hls = os.path.abspath(
os.path.join(
- os.path.dirname(__file__), "..", "..", "..", "scripts", "run-vitis-hls.sh",
+ os.path.dirname(__file__),
+ "..",
+ "..",
+ "..",
+ "scripts",
+ "run-vitis-hls.sh",
)
)
assert os.path.isfile(
diff --git a/src/chop/passes/module/analysis/report.py b/src/chop/passes/module/analysis/report.py
index c384ceafd..514f75ce8 100644
--- a/src/chop/passes/module/analysis/report.py
+++ b/src/chop/passes/module/analysis/report.py
@@ -26,7 +26,8 @@ def get_submodule_summary(name: str, module: nn.Module, level: int = 0):
def report_trainable_parameters_analysis_pass(
- module: torch.nn.Module, pass_args: dict = {},
+ module: torch.nn.Module,
+ pass_args: dict = {},
):
submodule_summary, total_params = get_submodule_summary("", module)
table = [(name, params) for _, name, params in submodule_summary]
diff --git a/src/chop/passes/utils.py b/src/chop/passes/utils.py
index 8d7ea71eb..912250077 100644
--- a/src/chop/passes/utils.py
+++ b/src/chop/passes/utils.py
@@ -23,7 +23,9 @@ def _nightly_torch_installed():
return False
-def find_missing_dependencies(pass_name: str,):
+def find_missing_dependencies(
+ pass_name: str,
+):
dependencies = PassFactory._dependencies_dict.get(pass_name, None)
if dependencies is None:
@@ -38,7 +40,9 @@ def find_missing_dependencies(pass_name: str,):
def register_mase_pass(
- name: str, dependencies: list = [], requires_nightly_torch: bool = False,
+ name: str,
+ dependencies: list = [],
+ requires_nightly_torch: bool = False,
):
"""This decorator registers a mase pass as PassFactory class attributes which can be used globally."""
diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py
index 1bd382983..c9b784795 100644
--- a/src/chop/pipelines/auto_pipeline.py
+++ b/src/chop/pipelines/auto_pipeline.py
@@ -12,7 +12,11 @@ class AutoPipeline:
The output of each pass is stored in a dictionary and can be accessed by the next pass.
"""
- def __init__(self, pass_groups=None, run_training: bool = False,) -> None:
+ def __init__(
+ self,
+ pass_groups=None,
+ run_training: bool = False,
+ ) -> None:
"""Initializes the AutoPipeline.
Args:
@@ -22,7 +26,11 @@ def __init__(self, pass_groups=None, run_training: bool = False,) -> None:
self.pass_outputs = [{}] * len(pass_groups)
def _run_pass_group(
- self, mg: MaseGraph, pass_group: list, pass_args: dict, skip_passes: list = [],
+ self,
+ mg: MaseGraph,
+ pass_group: list,
+ pass_args: dict,
+ skip_passes: list = [],
):
pass_outputs = {}
@@ -48,7 +56,10 @@ def _run_pass_group(
return mg, pass_outputs
def __call__(
- self, mg: MaseGraph, pass_args: dict, skip_passes: list = [],
+ self,
+ mg: MaseGraph,
+ pass_args: dict,
+ skip_passes: list = [],
):
for idx, pass_group in enumerate(self.pass_groups):
@@ -59,7 +70,10 @@ def __call__(
)
mg, pass_outputs = self._run_pass_group(
- mg, pass_group, pass_args, skip_passes,
+ mg,
+ pass_group,
+ pass_args,
+ skip_passes,
)
self.pass_outputs[idx] = pass_outputs
diff --git a/src/chop/tools/check_dependency.py b/src/chop/tools/check_dependency.py
index c71ca088e..dbf13cd12 100644
--- a/src/chop/tools/check_dependency.py
+++ b/src/chop/tools/check_dependency.py
@@ -23,7 +23,9 @@ def check_deps_tensorRT_pass(silent: bool = True):
return all(availabilities)
-def find_missing_dependencies(pass_name: str,):
+def find_missing_dependencies(
+ pass_name: str,
+):
dependencies = PassFactory._dependencies_dict.get(pass_name, None)
if dependencies is None:
@@ -38,7 +40,8 @@ def find_missing_dependencies(pass_name: str,):
def check_dependencies(
- pass_name: str, silent: bool = True,
+ pass_name: str,
+ silent: bool = True,
):
unavailable_deps = find_missing_dependencies(pass_name)
diff --git a/src/chop/tools/huggingface.py b/src/chop/tools/huggingface.py
index efa0fb1f6..7cf675460 100644
--- a/src/chop/tools/huggingface.py
+++ b/src/chop/tools/huggingface.py
@@ -39,7 +39,10 @@ def get_hf_dummy_in(model):
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
dummy_input = tokenizer(
- ["AI may take over the world one day", "This is why you should learn ADLS",],
+ [
+ "AI may take over the world one day",
+ "This is why you should learn ADLS",
+ ],
return_tensors="pt",
)
@@ -50,7 +53,9 @@ def get_hf_dummy_in(model):
def get_tokenized_dataset(
- dataset: str, checkpoint: str, return_tokenizer: bool = False,
+ dataset: str,
+ checkpoint: str,
+ return_tokenizer: bool = False,
):
"""
Tokenizes a dataset using the AutoTokenizer from Huggingface.
@@ -76,7 +81,10 @@ def get_tokenized_dataset(
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_function(example):
- return tokenizer(example["text"], truncation=True,)
+ return tokenizer(
+ example["text"],
+ truncation=True,
+ )
# Tokenize
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
diff --git a/src/chop/tools/plt_wrapper/nlp/classification.py b/src/chop/tools/plt_wrapper/nlp/classification.py
index a69b2ac0f..f96227cba 100644
--- a/src/chop/tools/plt_wrapper/nlp/classification.py
+++ b/src/chop/tools/plt_wrapper/nlp/classification.py
@@ -41,7 +41,9 @@ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
)
else:
outputs = self.model(
- input_ids, attention_mask=attention_mask, labels=labels,
+ input_ids,
+ attention_mask=attention_mask,
+ labels=labels,
)
return outputs
diff --git a/src/chop/tools/plt_wrapper/nlp/lm.py b/src/chop/tools/plt_wrapper/nlp/lm.py
index a2b669e87..60739da47 100644
--- a/src/chop/tools/plt_wrapper/nlp/lm.py
+++ b/src/chop/tools/plt_wrapper/nlp/lm.py
@@ -46,7 +46,9 @@ def training_step(self, batch, batch_idx):
self.log("train_loss_step", loss, prog_bar=True)
self.log(
- "train_perplexity_step", perplexity, prog_bar=True,
+ "train_perplexity_step",
+ perplexity,
+ prog_bar=True,
)
return loss
diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py
index d070eff92..2ae59855e 100644
--- a/src/chop/tools/utils.py
+++ b/src/chop/tools/utils.py
@@ -92,7 +92,7 @@ def get_factors(n):
set(
functools.reduce(
list.__add__,
- ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0),
+ ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
)
)
)
@@ -184,7 +184,9 @@ def init_Conv2dLUT_weight(
# Initialize the weight based on the trained binaried network
# weight shape of the lagrange trainer [tables_count, self.kk]
input_mask = new_module.input_mask.reshape(
- -1, in_channels * kernel_size[0] * kernel_size[1] * k, 3,
+ -1,
+ in_channels * kernel_size[0] * kernel_size[1] * k,
+ 3,
) # [oc, k * kh * kw * ic ,3[ic,kh,kw]]
expanded_original_weight = original_weight[
np.arange(out_channels)[:, np.newaxis],
diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py
index 711d2760f..10ca26820 100644
--- a/src/mase_cocotb/interfaces/streaming.py
+++ b/src/mase_cocotb/interfaces/streaming.py
@@ -17,7 +17,14 @@ def _sign_extend(value: int, bits: int):
class StreamDriver(Driver):
- def __init__(self, clk, data, valid, ready, record_num_beats=False,) -> None:
+ def __init__(
+ self,
+ clk,
+ data,
+ valid,
+ ready,
+ record_num_beats=False,
+ ) -> None:
super().__init__()
self.clk = clk
self.data = data
@@ -67,7 +74,14 @@ async def _driver_send(self, transaction) -> None:
class StreamMonitor(Monitor):
def __init__(
- self, clk, data, valid, ready, check=True, name=None, unsigned=False,
+ self,
+ clk,
+ data,
+ valid,
+ ready,
+ check=True,
+ name=None,
+ unsigned=False,
):
super().__init__(clk, check=check, name=name)
self.clk = clk
@@ -151,9 +165,9 @@ def __init__(self, clk, data, valid, ready, data_width, frac_width, check=True):
def _check(self, got, exp):
if self.check:
- float_got = [x * 2 ** -self.frac_width for x in got]
- float_exp = [x * 2 ** -self.frac_width for x in exp]
- if not np.isclose(float_got, float_exp, atol=2 ** -self.frac_width).all():
+ float_got = [x * 2**-self.frac_width for x in got]
+ float_exp = [x * 2**-self.frac_width for x in exp]
+ if not np.isclose(float_got, float_exp, atol=2**-self.frac_width).all():
# raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp))
raise TestFailure(
f"\nGot int \n{got}, \nExpected int \n{exp} \nGot float \n{float_got}, \nExpected float \n{float_exp}"
diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py
index a74bff0d4..21490555a 100644
--- a/src/mase_cocotb/runner.py
+++ b/src/mase_cocotb/runner.py
@@ -110,7 +110,10 @@ def _single_test(
verilog_sources=sources,
includes=includes,
hdl_toplevel=module,
- build_args=[*tool_args, *extra_build_args,],
+ build_args=[
+ *tool_args,
+ *extra_build_args,
+ ],
# Do not use params in hierarchical verilation
parameters=module_params if not hierarchical else {},
build_dir=test_work_dir,
@@ -309,10 +312,14 @@ def simulate_pass(
verilog_sources=[rtl_dir / "top.sv"],
includes=[rtl_dir],
hdl_toplevel="top",
- build_args=[*_verilator_args(False, trace) * extra_build_args,],
+ build_args=[
+ *_verilator_args(False, trace), *extra_build_args,
+ ],
parameters=module_params,
build_dir=sim_dir,
)
runner.test(
- hdl_toplevel="top", test_module="test", results_xml="results.xml",
+ hdl_toplevel="top",
+ test_module="test",
+ results_xml="results.xml",
)
diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py
index d746f74b5..0ad58ba1b 100644
--- a/src/mase_cocotb/testbench.py
+++ b/src/mase_cocotb/testbench.py
@@ -8,7 +8,12 @@ class Testbench:
__test__ = False # so pytest doesn't confuse this with a test
def __init__(
- self, dut, clk=None, rst=None, fail_on_checks=True, clk_period_ns=20,
+ self,
+ dut,
+ clk=None,
+ rst=None,
+ fail_on_checks=True,
+ clk_period_ns=20,
) -> None:
self.dut = dut
self.clk = clk
diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py
index 25b6d83d8..7271389f2 100644
--- a/src/mase_cocotb/utils.py
+++ b/src/mase_cocotb/utils.py
@@ -78,8 +78,8 @@ def int_floor_quantizer(x: Tensor, width: int, frac_width: int, signed=True):
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2 ** width - 1
- scale = 2 ** frac_width
+ int_max = 2**width - 1
+ scale = 2**frac_width
return torch.clamp(torch.floor(x.mul(scale)), int_min, int_max).div(scale)
diff --git a/src/mase_cocotb/z_qlayers/tensor_cast.py b/src/mase_cocotb/z_qlayers/tensor_cast.py
index cf212cf79..ce651bc5c 100644
--- a/src/mase_cocotb/z_qlayers/tensor_cast.py
+++ b/src/mase_cocotb/z_qlayers/tensor_cast.py
@@ -46,9 +46,9 @@ def _integer_quantize(
int_max = 2 ** (width - 1) - 1
else:
int_min = 0
- int_max = 2 ** width - 1
+ int_max = 2**width - 1
# thresh = 2 ** (width - 1)
- scale = 2 ** frac_width
+ scale = 2**frac_width
if isinstance(x, (Tensor, ndarray)):
return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale)
@@ -59,7 +59,7 @@ def _integer_quantize(
def quantize_to_int(x: Tensor, width: int, frac_width: int):
- x = (_integer_quantize(x, width, frac_width) * (2 ** frac_width)).int()
+ x = (_integer_quantize(x, width, frac_width) * (2**frac_width)).int()
return x
diff --git a/src/mase_components/activation_layers/test/fixed_elu_tb.py b/src/mase_components/activation_layers/test/fixed_elu_tb.py
index 1c7e850e4..87ff69fd6 100644
--- a/src/mase_components/activation_layers/test/fixed_elu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_elu_tb.py
@@ -79,10 +79,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
count += 1
iarr.append(i)
val = quanter(f(torch.tensor(i))) # entry in the lookup table
- lut[
- doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)
- ] = doubletofx(
- data_width=data_width, f_width=f_width, num=val.item(), type=type
+ lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = (
+ doubletofx(
+ data_width=data_width, f_width=f_width, num=val.item(), type=type
+ )
)
i += 2 ** -(f_width)
return lut
@@ -444,8 +444,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -456,7 +456,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_gelu_tb.py b/src/mase_components/activation_layers/test/fixed_gelu_tb.py
index 2a8e79787..1a7d760b6 100644
--- a/src/mase_components/activation_layers/test/fixed_gelu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_gelu_tb.py
@@ -29,13 +29,13 @@ async def cocotb_test_fixed_gelu(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -46,9 +46,9 @@ async def cocotb_test_fixed_gelu(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2 ** DATA_IN_0_PRECISION_1)]
+ a = [b / (2**DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -56,7 +56,7 @@ async def cocotb_test_fixed_gelu(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
index 57f3940fb..efb6f0045 100644
--- a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py
@@ -62,11 +62,11 @@ def exp(self, inputs):
# cond = torch.logical_not(torch.logical_and(inputs <= self.thresh*2**self.fracw, inputs >= -1 * self.thresh *2**self.fracw))
# out = torch.where(cond, inputs, torch.tensor(0))
# unsignedout = torch.where(out < 0, torch.tensor(out % (2**self.width)), out)
- m = torch.nn.Hardshrink(self.thresh * 2 ** self.fracw)(inputs.to(torch.float))
+ m = torch.nn.Hardshrink(self.thresh * 2**self.fracw)(inputs.to(torch.float))
mout = m.clamp(
min=-1 * 2 ** (self.outputwidth - 1), max=2 ** (self.outputwidth - 1) - 1
)
- m2 = torch.where(mout < 0, torch.tensor(mout % (2 ** self.outputwidth)), mout)
+ m2 = torch.where(mout < 0, torch.tensor(mout % (2**self.outputwidth)), mout)
return m2.to(torch.int32).tolist()
def generate_inputs(self, w, fracw):
@@ -75,7 +75,7 @@ def generate_inputs(self, w, fracw):
)
realinp = torch.randn(self.samples)
inputs = self.dquantizer(realinp)
- intinp = (inputs * 2 ** self.fracw).to(torch.int64)
+ intinp = (inputs * 2**self.fracw).to(torch.int64)
intinp.clamp(
min=-(2 ** (self.width - self.fracw - 1)),
max=2 ** (self.width - self.fracw - 1) - 1,
diff --git a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
index 53f79755e..ceb194e35 100644
--- a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py
@@ -56,14 +56,12 @@ def __init__(self, dut) -> None:
def exp(self, inputs):
# Run the model with the provided inputs and return the outputs
- tmp0 = 3 * 2 ** self.fracw
+ tmp0 = 3 * 2**self.fracw
tmp1 = inputs + tmp0
- tmp2 = tmp1 * (2 ** -3) + tmp1 * (2 ** -4)
+ tmp2 = tmp1 * (2**-3) + tmp1 * (2**-4)
# qtmps = self.dquantizer(tmp2)
tmp3 = tmp2 * inputs
- unsignedout = torch.where(
- tmp3 < 0, torch.tensor(tmp3 % (2 ** self.width)), tmp3
- )
+ unsignedout = torch.where(tmp3 < 0, torch.tensor(tmp3 % (2**self.width)), tmp3)
# return unsignedout.tolist()
return unsignedout
@@ -73,7 +71,7 @@ def generate_inputs(self, w, fracw):
)
realinp = torch.randn(self.samples)
inputs = self.dquantizer(realinp)
- intinp = (inputs * 2 ** self.fracw).to(torch.int64)
+ intinp = (inputs * 2**self.fracw).to(torch.int64)
intinp.clamp(
min=-(2 ** (self.width - self.fracw - 1)),
max=2 ** (self.width - self.fracw - 1) - 1,
diff --git a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
index a44efe0b8..d0cd8796e 100644
--- a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py
@@ -28,9 +28,9 @@ def get_in_and_out(x, fn, width, frac_width):
ins = integer_quantizer(x, width=width, frac_width=frac_width)
y = fn(x)
outs = integer_quantizer(y, width=width, frac_width=frac_width)
- outs = outs * 2 ** frac_width
+ outs = outs * 2**frac_width
outs = outs.int()
- ins = ins * 2 ** frac_width
+ ins = ins * 2**frac_width
ins = ins.int()
return (ins, outs)
@@ -78,7 +78,7 @@ async def cocotb_test(dut):
logger.info(f"Reset finished")
tb.data_out_0_monitor.ready.value = 1
- inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2 ** -4)
+ inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2**-4)
tb.data_in_0_driver.append(inputs.tolist())
diff --git a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
index 9b9a0a766..570ddc328 100644
--- a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py
@@ -142,8 +142,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -154,7 +154,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_relu_tb.py b/src/mase_components/activation_layers/test/fixed_relu_tb.py
index dbe8ac146..118296c48 100644
--- a/src/mase_components/activation_layers/test/fixed_relu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_relu_tb.py
@@ -47,7 +47,7 @@ def backward(ctx, grad_output):
def quantize(x, bits, bias): # bits = 32
"""Do linear quantization to input according to a scale and number of bits"""
thresh = 2 ** (bits - 1)
- scale = 2 ** bias
+ scale = 2**bias
return my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1).div(scale)
@@ -83,7 +83,7 @@ def get_dut_parameters(self):
def get_dut_input(self, i):
inputs = self.inputs[i]
- shifted_integers = (inputs * (2 ** self.bias)).int()
+ shifted_integers = (inputs * (2**self.bias)).int()
return shifted_integers.numpy().tolist()
def get_dut_output(self, i):
@@ -92,7 +92,7 @@ def get_dut_output(self, i):
return shifted_integers
def convert_to_fixed(self, x):
- return (x * (2 ** self.bias)).int().numpy().tolist()
+ return (x * (2**self.bias)).int().numpy().tolist()
@cocotb.test()
diff --git a/src/mase_components/activation_layers/test/fixed_selu_tb.py b/src/mase_components/activation_layers/test/fixed_selu_tb.py
index 459334c6c..647c4fff1 100644
--- a/src/mase_components/activation_layers/test/fixed_selu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_selu_tb.py
@@ -25,13 +25,13 @@ async def cocotb_test_fixed_selu(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -42,9 +42,9 @@ async def cocotb_test_fixed_selu(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2 ** DATA_IN_0_PRECISION_1)]
+ a = [b / (2**DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -52,7 +52,7 @@ async def cocotb_test_fixed_selu(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
index 427220089..e2a794d22 100644
--- a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py
@@ -128,8 +128,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -140,7 +140,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_silu_tb.py b/src/mase_components/activation_layers/test/fixed_silu_tb.py
index 19c4ac3a2..944a83e5e 100644
--- a/src/mase_components/activation_layers/test/fixed_silu_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_silu_tb.py
@@ -143,8 +143,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -155,7 +155,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
index 41e67e205..fa2b2aa75 100644
--- a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py
@@ -64,7 +64,10 @@ def __init__(self, dut) -> None:
self.model = partial(
fixed_softermax,
dim=0,
- q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,},
+ q_config={
+ "width": self.IN_WIDTH,
+ "frac_width": self.IN_FRAC_WIDTH,
+ },
)
# Set verbosity of driver and monitor loggers to debug
@@ -72,7 +75,9 @@ def __init__(self, dut) -> None:
# self.out_data_monitor.log.setLevel(logging.DEBUG)
def generate_inputs(self, batches):
- return torch.randn((batches, self.TOTAL_DIM),)
+ return torch.randn(
+ (batches, self.TOTAL_DIM),
+ )
async def run_test(self, batches, us):
await self.reset()
@@ -88,7 +93,10 @@ async def run_test(self, batches, us):
self.log.debug(f"Processing inputs: {batch}")
driver_input = fixed_preprocess_tensor(
tensor=batch,
- q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,},
+ q_config={
+ "width": self.IN_WIDTH,
+ "frac_width": self.IN_FRAC_WIDTH,
+ },
parallelism=[self.PARALLELISM],
)
self.in_data_driver.load_driver(driver_input)
@@ -97,7 +105,10 @@ async def run_test(self, batches, us):
self.log.debug(f"Processing outputs: {exp_out}")
outs = fixed_preprocess_tensor(
tensor=exp_out,
- q_config={"width": self.OUT_WIDTH, "frac_width": self.OUT_FRAC_WIDTH,},
+ q_config={
+ "width": self.OUT_WIDTH,
+ "frac_width": self.OUT_FRAC_WIDTH,
+ },
parallelism=[self.PARALLELISM],
)
self.out_data_monitor.load_monitor(outs)
diff --git a/src/mase_components/activation_layers/test/fixed_softermax_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_tb.py
index 84c6aa4e5..9b7d0aee1 100644
--- a/src/mase_components/activation_layers/test/fixed_softermax_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softermax_tb.py
@@ -133,7 +133,9 @@ def test_fixed_softermax_smoke():
"""
mase_runner(
trace=True,
- module_param_list=[get_fixed_softermax_config(),],
+ module_param_list=[
+ get_fixed_softermax_config(),
+ ],
# skip_build=True,
)
diff --git a/src/mase_components/activation_layers/test/fixed_softmax_tb.py b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
index afc8876ac..62f37012f 100644
--- a/src/mase_components/activation_layers/test/fixed_softmax_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softmax_tb.py
@@ -128,8 +128,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -140,7 +140,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softplus_tb.py b/src/mase_components/activation_layers/test/fixed_softplus_tb.py
index f39467ee7..121fa5c60 100644
--- a/src/mase_components/activation_layers/test/fixed_softplus_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softplus_tb.py
@@ -25,13 +25,13 @@ async def cocotb_test_fixed_softplus(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -42,9 +42,9 @@ async def cocotb_test_fixed_softplus(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2 ** DATA_IN_0_PRECISION_1)]
+ a = [b / (2**DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -52,7 +52,7 @@ async def cocotb_test_fixed_softplus(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
index 9c738602e..b79ccdaf5 100644
--- a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py
@@ -142,8 +142,8 @@ def exp(self):
) # match output
logger.info(f"EXP - FLOAT OUTPUT: \n{m}")
m = self.out_dquantizer(m)
- m2 = (m * 2 ** self.outputfracw).to(torch.int64)
- m2 = m2.clone().detach() % (2 ** self.outputwidth)
+ m2 = (m * 2**self.outputfracw).to(torch.int64)
+ m2 = m2.clone().detach() % (2**self.outputwidth)
return m2
@@ -154,7 +154,7 @@ def generate_inputs(self):
)
logger.info(f"FLOAT INPUT: \n{inputs}")
inputs = self.in_dquantizer(inputs)
- intinp = (inputs * 2 ** self.frac_width).to(torch.int64)
+ intinp = (inputs * 2**self.frac_width).to(torch.int64)
return intinp, inputs
def doubletofx(self, num, data_width, f_width, type="bin"):
diff --git a/src/mase_components/activation_layers/test/fixed_softsign_tb.py b/src/mase_components/activation_layers/test/fixed_softsign_tb.py
index c8776ac32..5fec6b341 100644
--- a/src/mase_components/activation_layers/test/fixed_softsign_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_softsign_tb.py
@@ -24,13 +24,13 @@ async def cocotb_test_fixed_softsign(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -41,9 +41,9 @@ async def cocotb_test_fixed_softsign(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2 ** DATA_IN_0_PRECISION_1)]
+ a = [b / (2**DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -51,7 +51,7 @@ async def cocotb_test_fixed_softsign(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/fixed_tanh_tb.py b/src/mase_components/activation_layers/test/fixed_tanh_tb.py
index 937fb930c..4e8c6d268 100644
--- a/src/mase_components/activation_layers/test/fixed_tanh_tb.py
+++ b/src/mase_components/activation_layers/test/fixed_tanh_tb.py
@@ -24,13 +24,13 @@ async def cocotb_test_fixed_tanh(dut):
resolution = (max_value - min_value) / (num_values - 1)
# Convert the resolution into fixed-point format
- resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1))
+ resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1))
# Generate the equidistant values
values = np.linspace(min_value, max_value, num_values)
# Convert values to fixed-point format
- values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int)
+ values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int)
tensor_tanh = torch.Tensor(values)
@@ -41,9 +41,9 @@ async def cocotb_test_fixed_tanh(dut):
for i in range(87):
# a = tanh_values[i]
- b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
+ b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1)
- a = [b / (2 ** DATA_IN_0_PRECISION_1)]
+ a = [b / (2**DATA_IN_0_PRECISION_1)]
tensor_tanh = torch.Tensor(a)
c = model(tensor_tanh)
@@ -51,7 +51,7 @@ async def cocotb_test_fixed_tanh(dut):
# Scale Tanh output to fixed-point range and convert to integers
scaled_value = np.round(
- tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1)
+ tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1)
).astype(int)
dut.data_in_0[0].value = b
diff --git a/src/mase_components/activation_layers/test/softermax.py b/src/mase_components/activation_layers/test/softermax.py
index 7c30c2986..a1b3271f5 100644
--- a/src/mase_components/activation_layers/test/softermax.py
+++ b/src/mase_components/activation_layers/test/softermax.py
@@ -53,7 +53,7 @@ def _softmax_model(l: list[int], parallelism: int, pow2=False):
for diff, vals in zip(local_max_diff, local_values_buffer):
if pow2:
- adj = [x * (2 ** -diff) for x in vals]
+ adj = [x * (2**-diff) for x in vals]
else:
adj = [x * exp(-diff) for x in vals]
norm += sum(adj)
diff --git a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
index 6ebbfc131..bc74e42ee 100644
--- a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py
@@ -46,7 +46,7 @@ def __init__(self, dut) -> None:
# Specify Error Threshold
self.percentage_error = 0.05 # 5%
- self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH))
+ self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH))
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -63,15 +63,15 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=10):
# TODO: Take a look at all zero case again
local_vals = torch.randint(
- 1, 2 ** self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM)
+ 1, 2**self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM)
)
local_max = torch.randint(
- 0, 2 ** self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1)
+ 0, 2**self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1)
)
logger.debug("local_vals: %s" % (local_vals))
logger.debug(
- "local_vals (float): %s" % (local_vals / (2 ** self.IN_VALUE_FRAC_WIDTH))
+ "local_vals (float): %s" % (local_vals / (2**self.IN_VALUE_FRAC_WIDTH))
)
logger.debug("local_max: %s" % (local_max))
logger.debug(
@@ -87,7 +87,7 @@ def model(self, inputs):
for batch in batched_in:
local_vals, local_max = list(zip(*batch))
local_vals = torch.tensor(list(local_vals), dtype=torch.float) / (
- 2 ** self.IN_VALUE_FRAC_WIDTH
+ 2**self.IN_VALUE_FRAC_WIDTH
)
local_max = torch.tensor(list(local_max), dtype=torch.float)
local_max = sign_extend_t(
@@ -97,7 +97,7 @@ def model(self, inputs):
global_max = local_max.max()
adj_amt = global_max - local_max.reshape(self.DEPTH, 1)
adj_values = integer_floor_quantizer(
- x=local_vals / (2 ** adj_amt),
+ x=local_vals / (2**adj_amt),
width=self.IN_VALUE_WIDTH,
frac_width=self.IN_VALUE_FRAC_WIDTH,
is_signed=False,
@@ -226,7 +226,10 @@ def in_value_cfgs(cfgs: list):
for cfg in cfgs:
for in_width in [4, 7, 10]:
new_cfgs.append(
- {**cfg, "IN_VALUE_WIDTH": in_width,}
+ {
+ **cfg,
+ "IN_VALUE_WIDTH": in_width,
+ }
)
return new_cfgs
@@ -235,7 +238,10 @@ def in_max_cfgs(cfgs: list):
for cfg in cfgs:
for in_max in [2, 3, 4]:
new_cfgs.append(
- {**cfg, "IN_MAX_WIDTH": in_max,}
+ {
+ **cfg,
+ "IN_MAX_WIDTH": in_max,
+ }
)
return new_cfgs
@@ -250,5 +256,7 @@ def in_max_cfgs(cfgs: list):
# cfgs = [{'TOTAL_DIM': 32, 'PARALLELISM': 4, 'IN_VALUE_WIDTH': 16, 'IN_MAX_WIDTH': 2}]
mase_runner(
- module_param_list=cfgs, trace=True, jobs=12,
+ module_param_list=cfgs,
+ trace=True,
+ jobs=12,
)
diff --git a/src/mase_components/activation_layers/test/softermax_local_window_tb.py b/src/mase_components/activation_layers/test/softermax_local_window_tb.py
index c52b7c647..10b11a569 100644
--- a/src/mase_components/activation_layers/test/softermax_local_window_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_local_window_tb.py
@@ -53,7 +53,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=10):
return [
- [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)]
+ [randint(0, 2**self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)]
for _ in range(batches)
]
@@ -80,7 +80,7 @@ def model(self, inputs):
sign_ext = sign_extend_t(
torch.tensor(inputs, dtype=torch.float), bits=self.IN_WIDTH
)
- float_inputs = sign_ext / (2 ** self.IN_FRAC_WIDTH)
+ float_inputs = sign_ext / (2**self.IN_FRAC_WIDTH)
# float_inputs = torch.tensor([[-31.5, -32]])
rounded_inputs_float, rounded_inputs_uint = _fixed_signed_cast_model(
float_inputs, self.MAX_WIDTH, 0, False, "floor"
@@ -89,9 +89,9 @@ def model(self, inputs):
local_max_uint = signed_to_unsigned(local_max.int(), self.MAX_WIDTH)
difference = float_inputs - local_max
- pow2 = 2 ** difference
+ pow2 = 2**difference
res = torch.clamp(
- (pow2 * 2 ** self.OUT_FRAC_WIDTH).int(), 0, 2 ** self.OUT_WIDTH - 1
+ (pow2 * 2**self.OUT_FRAC_WIDTH).int(), 0, 2**self.OUT_WIDTH - 1
)
logger.debug("float_inputs: %s" % float_inputs)
diff --git a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
index 0a4b52af5..058fa10b0 100644
--- a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py
@@ -37,7 +37,7 @@ def __init__(self, dut) -> None:
self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready)
# 0.1% bit error
- self.error_threshold_bits = ceil((2 ** self.IN_WIDTH) * 0.001)
+ self.error_threshold_bits = ceil((2**self.IN_WIDTH) * 0.001)
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -53,10 +53,10 @@ def __init__(self, dut) -> None:
def sweep_inputs(self):
negative_nums = torch.arange(
- start=2 ** (self.IN_WIDTH - 1), end=2 ** self.IN_WIDTH, dtype=torch.int32
+ start=2 ** (self.IN_WIDTH - 1), end=2**self.IN_WIDTH, dtype=torch.int32
)
zero_to_one = torch.arange(
- start=0, end=2 ** self.IN_FRAC_WIDTH, dtype=torch.int32 # one
+ start=0, end=2**self.IN_FRAC_WIDTH, dtype=torch.int32 # one
)
return torch.cat((negative_nums, zero_to_one)).tolist()
@@ -68,14 +68,14 @@ def generate_inputs(self, batches=1):
# Negative Numbers
torch.randint(
low=2 ** (self.IN_WIDTH - 1),
- high=2 ** self.IN_WIDTH,
+ high=2**self.IN_WIDTH,
size=(negative_nums,),
dtype=torch.int32,
),
# Numbers between 0 and 1
torch.randint(
low=0,
- high=2 ** self.IN_FRAC_WIDTH,
+ high=2**self.IN_FRAC_WIDTH,
size=(zero_to_one_nums,),
dtype=torch.int32,
),
@@ -85,10 +85,10 @@ def generate_inputs(self, batches=1):
def model(self, inputs):
in_t = torch.tensor(inputs)
- num = sign_extend_t(in_t, self.IN_WIDTH) / (2 ** self.IN_FRAC_WIDTH)
- res = 2 ** num
- res = (res * 2 ** self.OUT_FRAC_WIDTH).int()
- res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1)
+ num = sign_extend_t(in_t, self.IN_WIDTH) / (2**self.IN_FRAC_WIDTH)
+ res = 2**num
+ res = (res * 2**self.OUT_FRAC_WIDTH).int()
+ res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1)
return res.tolist()
async def run_test(self, batches, us):
@@ -128,7 +128,7 @@ async def sweep(dut):
tb.in_driver.load_driver(inputs)
tb.output_monitor.load_monitor(exp_out)
- ns = ((2 ** tb.IN_WIDTH) * 1000) // 5
+ ns = ((2**tb.IN_WIDTH) * 1000) // 5
logger.info("Waiting %d ns..." % ns)
await Timer(ns, "ns")
assert tb.output_monitor.exp_queue.empty()
@@ -137,14 +137,14 @@ async def sweep(dut):
recv_log = tb.output_monitor.recv_log
assert len(exp_out) == len(recv_log)
- x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2 ** tb.IN_FRAC_WIDTH)
- ref = 2 ** x
- ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up
- ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1)
+ x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2**tb.IN_FRAC_WIDTH)
+ ref = 2**x
+ ref *= 2**tb.OUT_FRAC_WIDTH # scale up
+ ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1)
- software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH)
- software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out]
- hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log]
+ software_ref = ref / (2**tb.OUT_FRAC_WIDTH)
+ software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out]
+ hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log]
data = pd.DataFrame(
{
@@ -174,7 +174,10 @@ async def sweep(dut):
),
color=alt.Color("Type"),
)
- .properties(width=600, height=220,)
+ .properties(
+ width=600,
+ height=220,
+ )
)
error_data = data[["x", "hardware error"]]
@@ -185,7 +188,10 @@ async def sweep(dut):
x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"),
y=alt.Y("hardware error").title(f"Error"),
)
- .properties(width=600, height=100,)
+ .properties(
+ width=600,
+ height=100,
+ )
)
(curve_fig & error_fig).save(
@@ -296,7 +302,12 @@ def test_high_width():
def test_smoke():
mase_runner(
module_param_list=[
- {"IN_WIDTH": 8, "IN_FRAC_WIDTH": 4, "OUT_WIDTH": 8, "OUT_FRAC_WIDTH": 4,}
+ {
+ "IN_WIDTH": 8,
+ "IN_FRAC_WIDTH": 4,
+ "OUT_WIDTH": 8,
+ "OUT_FRAC_WIDTH": 4,
+ }
]
)
diff --git a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
index 842881dc9..f2c70b67c 100644
--- a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
+++ b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py
@@ -38,7 +38,7 @@ def __init__(self, dut) -> None:
# Specify Error Threshold
self.percentage_error = 0.05 # 5%
- self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH))
+ self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH))
self.output_monitor = ErrorThresholdStreamMonitor(
dut.clk,
@@ -53,17 +53,17 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, batches=100):
- return [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(batches)]
+ return [randint(0, 2**self.IN_WIDTH - 1) for _ in range(batches)]
def sweep_input(self):
- return list(range(2 ** self.IN_WIDTH))
+ return list(range(2**self.IN_WIDTH))
def model(self, inputs):
- in_t = torch.tensor(inputs) / (2 ** self.IN_FRAC_WIDTH)
+ in_t = torch.tensor(inputs) / (2**self.IN_FRAC_WIDTH)
recip = 1.0 / in_t
- res = torch.floor(recip * 2 ** self.OUT_FRAC_WIDTH)
+ res = torch.floor(recip * 2**self.OUT_FRAC_WIDTH)
res = torch.nan_to_num(res)
- res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1)
+ res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1)
res = res.int()
return res.tolist()
@@ -107,14 +107,14 @@ async def sweep(dut):
recv_log = tb.output_monitor.recv_log
assert len(exp_out) == len(recv_log)
- x = torch.tensor(inputs) / (2 ** tb.IN_FRAC_WIDTH)
+ x = torch.tensor(inputs) / (2**tb.IN_FRAC_WIDTH)
ref = 1.0 / x
- ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up
- ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1)
+ ref *= 2**tb.OUT_FRAC_WIDTH # scale up
+ ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1)
- software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH)
- software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out]
- hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log]
+ software_ref = ref / (2**tb.OUT_FRAC_WIDTH)
+ software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out]
+ hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log]
data = pd.DataFrame(
{
@@ -150,7 +150,10 @@ async def sweep(dut):
),
color=alt.Color("Type"),
)
- .properties(width=600, height=300,)
+ .properties(
+ width=600,
+ height=300,
+ )
)
error_data = data[["x", "hardware error"]]
@@ -161,7 +164,10 @@ async def sweep(dut):
x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"),
y=alt.Y("hardware error").title(f"Error"),
)
- .properties(width=600, height=100,)
+ .properties(
+ width=600,
+ height=100,
+ )
)
(curve_fig & error_fig).save(
@@ -216,7 +222,7 @@ def width_cfgs():
for width in range(2, 16 + 1):
frac_width = width // 2
if frac_width < 3:
- entries = 2 ** frac_width
+ entries = 2**frac_width
else:
entries = 8
cfgs.append(
@@ -254,7 +260,9 @@ def test_width_cfgs():
"OUT_FRAC_WIDTH": 4,
}
]
- mase_runner(module_param_list=cfgs,)
+ mase_runner(
+ module_param_list=cfgs,
+ )
def test_smoke():
diff --git a/src/mase_components/cast/test/fixed_rounding_tb.py b/src/mase_components/cast/test/fixed_rounding_tb.py
index 300283856..a7f2f935c 100644
--- a/src/mase_components/cast/test/fixed_rounding_tb.py
+++ b/src/mase_components/cast/test/fixed_rounding_tb.py
@@ -44,7 +44,7 @@ def single_run(self):
def sw_cast(self, inputs):
outputs = (
integer_floor_quantizer(inputs, self.out_width, self.out_frac_width)
- * 2 ** self.out_frac_width
+ * 2**self.out_frac_width
)
# breakpoint()
return outputs
diff --git a/src/mase_components/cast/test/fixed_signed_cast_tb.py b/src/mase_components/cast/test/fixed_signed_cast_tb.py
index 9f5a6956d..2ce3bbe5f 100644
--- a/src/mase_components/cast/test/fixed_signed_cast_tb.py
+++ b/src/mase_components/cast/test/fixed_signed_cast_tb.py
@@ -18,7 +18,7 @@
def _fixed_signed_cast_model(
float_input, out_width, out_frac_width, symmetric, rounding_mode
):
- scaled_float = float_input * (2 ** out_frac_width)
+ scaled_float = float_input * (2**out_frac_width)
if rounding_mode == "floor":
out_int = my_floor(scaled_float)
elif rounding_mode == "round_nearest_half_even":
@@ -30,7 +30,7 @@ def _fixed_signed_cast_model(
-(2 ** (out_width - 1)) + 1 if symmetric else -(2 ** (out_width - 1)),
(2 ** (out_width - 1)) - 1,
)
- out_float = out_int / (2 ** out_frac_width)
+ out_float = out_int / (2**out_frac_width)
# out_uint is a non-differentiable path
out_uint = signed_to_unsigned(out_int.int(), out_width)
return out_float, out_uint
@@ -58,9 +58,9 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self):
- uints = torch.arange(2 ** self.IN_WIDTH)
+ uints = torch.arange(2**self.IN_WIDTH)
num_int = sign_extend_t(uints, self.IN_WIDTH)
- num_float = num_int / (2 ** self.IN_FRAC_WIDTH)
+ num_float = num_int / (2**self.IN_FRAC_WIDTH)
return num_int, num_float
def rounding_mode(self):
@@ -150,7 +150,10 @@ def gen_symmetric(cfg_list):
l = list()
for cfg in cfg_list:
l.extend(
- [{**cfg, "SYMMETRIC": 0}, {**cfg, "SYMMETRIC": 1},]
+ [
+ {**cfg, "SYMMETRIC": 0},
+ {**cfg, "SYMMETRIC": 1},
+ ]
)
return l
diff --git a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
index 75910ed9e..cfeca8318 100644
--- a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
+++ b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py
@@ -51,7 +51,7 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self):
- return torch.arange(2 ** self.IN_WIDTH)
+ return torch.arange(2**self.IN_WIDTH)
def rounding_mode(self):
if self.ROUND_FLOOR:
@@ -64,10 +64,10 @@ def rounding_mode(self):
raise Exception("Rounding mode not recognised.")
def model(self, inputs):
- float_input = inputs / (2 ** self.IN_FRAC_WIDTH)
- scaled_float = float_input * (2 ** self.OUT_FRAC_WIDTH)
+ float_input = inputs / (2**self.IN_FRAC_WIDTH)
+ scaled_float = float_input * (2**self.OUT_FRAC_WIDTH)
rounded = torch.floor(scaled_float)
- model_out = torch.clamp(rounded, 0, (2 ** self.OUT_WIDTH - 1))
+ model_out = torch.clamp(rounded, 0, (2**self.OUT_WIDTH - 1))
return model_out
async def run_test(self):
diff --git a/src/mase_components/common/test/comparator_accumulator_tb.py b/src/mase_components/common/test/comparator_accumulator_tb.py
index d05f41a09..fbd8af937 100644
--- a/src/mase_components/common/test/comparator_accumulator_tb.py
+++ b/src/mase_components/common/test/comparator_accumulator_tb.py
@@ -33,9 +33,7 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, batches=3):
- return [
- randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)
- ]
+ return [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)]
def model(self, inputs):
@@ -154,5 +152,7 @@ def signed_max_min_cfgs(cfglist: list):
cfgs = signed_max_min_cfgs(cfgs)
mase_runner(
- module_param_list=cfgs, trace=True, jobs=12,
+ module_param_list=cfgs,
+ trace=True,
+ jobs=12,
)
diff --git a/src/mase_components/common/test/comparator_tree_tb.py b/src/mase_components/common/test/comparator_tree_tb.py
index ae2a7fc61..5637791b3 100644
--- a/src/mase_components/common/test/comparator_tree_tb.py
+++ b/src/mase_components/common/test/comparator_tree_tb.py
@@ -34,7 +34,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=3):
return [
- [randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
+ [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
for _ in range(batches)
]
@@ -136,5 +136,6 @@ def signed_max_min_cfgs(cfglist: list):
cfgs = signed_max_min_cfgs(cfgs)
mase_runner(
- module_param_list=cfgs, trace=True,
+ module_param_list=cfgs,
+ trace=True,
)
diff --git a/src/mase_components/common/test/register_slice_tb.py b/src/mase_components/common/test/register_slice_tb.py
index ffcec4f3e..638d53da7 100644
--- a/src/mase_components/common/test/register_slice_tb.py
+++ b/src/mase_components/common/test/register_slice_tb.py
@@ -60,7 +60,9 @@ def in_out_wave(dut, name):
)
logger.debug(
"{} State: (shift_reg, buffer) = ({},{})".format(
- name, int(dut.shift_reg.value), int(dut.buffer.value),
+ name,
+ int(dut.shift_reg.value),
+ int(dut.buffer.value),
)
)
diff --git a/src/mase_components/common/test/single_element_repeat_tb.py b/src/mase_components/common/test/single_element_repeat_tb.py
index 6d39b0c35..247cc2b27 100644
--- a/src/mase_components/common/test/single_element_repeat_tb.py
+++ b/src/mase_components/common/test/single_element_repeat_tb.py
@@ -26,7 +26,7 @@ def __init__(self, dut) -> None:
)
def generate_inputs(self, num=10):
- return [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(num)]
+ return [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(num)]
def model(self, inputs):
exp_out = []
@@ -103,5 +103,7 @@ def generate_random_params():
]
mase_runner(
- module_param_list=cfgs, trace=True, jobs=8,
+ module_param_list=cfgs,
+ trace=True,
+ jobs=8,
)
diff --git a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
index 1481671a3..8091990e4 100644
--- a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
+++ b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py
@@ -244,7 +244,9 @@ def sw_compute(self):
bias = \n\
{} \n\
".format(
- data, weight, bias,
+ data,
+ weight,
+ bias,
)
)
for i in range(self.samples):
diff --git a/src/mase_components/convolution_layers/test/convolution_tb.py b/src/mase_components/convolution_layers/test/convolution_tb.py
index ec3c07a37..1a3789ed9 100644
--- a/src/mase_components/convolution_layers/test/convolution_tb.py
+++ b/src/mase_components/convolution_layers/test/convolution_tb.py
@@ -130,7 +130,11 @@ def get_manual_result(
# out2 = get_manual_result(x, w, b, 2,1,2,2,4,4,0,0,12,4)
# data_in_pack
- x = q2i(x, config["data_in_width"], config["data_in_frac_width"],)
+ x = q2i(
+ x,
+ config["data_in_width"],
+ config["data_in_frac_width"],
+ )
self.log.info(f"x = {x}")
# from (samples, c, h, w) to (samples*h*w*c/unroll_in_c, unroll_in_c)
@@ -140,8 +144,16 @@ def get_manual_result(
self.log.info(f"weight = {w}")
self.log.info(f"bias = {b}")
- w = q2i(w, config["weight_width"], config["weight_frac_width"],)
- b = q2i(b, config["bias_width"], config["bias_frac_width"],)
+ w = q2i(
+ w,
+ config["weight_width"],
+ config["weight_frac_width"],
+ )
+ b = q2i(
+ b,
+ config["bias_width"],
+ config["bias_frac_width"],
+ )
self.log.info(f"weight = {w}")
self.log.info(f"bias = {b}")
hw_w, hw_b = self.conv_pack(
@@ -157,7 +169,11 @@ def get_manual_result(
unroll_kernel_out=self.get_parameter("UNROLL_KERNEL_OUT"),
unroll_out_channels=self.get_parameter("UNROLL_OUT_C"),
)
- exp_out = q2i(out, config["out_width"], config["out_frac_width"],)
+ exp_out = q2i(
+ out,
+ config["out_width"],
+ config["out_frac_width"],
+ )
exp_out = (
exp_out.reshape(
-1, self.get_parameter("OUT_C"), self.get_parameter("SLIDING_NUM")
@@ -207,7 +223,10 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
+ w_tensor = w_tensor.reshape(
+ -1,
+ unroll_out_channels * unroll_kernel_out,
+ )
w_in = w_tensor.type(torch.int).tolist()
# bias_pack
bias_tensor = (
@@ -313,7 +332,10 @@ def test_fixed_linear_smoke():
Some quick tests to check if the module is working.
"""
mase_runner(
- trace=True, module_param_list=[get_fixed_conv_config(),],
+ trace=True,
+ module_param_list=[
+ get_fixed_conv_config(),
+ ],
)
diff --git a/src/mase_components/convolution_layers/test/padding_tb.py b/src/mase_components/convolution_layers/test/padding_tb.py
index d0504ef5d..6cb470069 100644
--- a/src/mase_components/convolution_layers/test/padding_tb.py
+++ b/src/mase_components/convolution_layers/test/padding_tb.py
@@ -96,9 +96,9 @@ def data_pack(self):
for j in range(in_channels):
for k in range(in_height):
for s in range(in_width):
- re_data_tensor[i][j][k + padding_height][
- s + padding_width
- ] = data_tensor[i][k][s][j]
+ re_data_tensor[i][j][k + padding_height][s + padding_width] = (
+ data_tensor[i][k][s][j]
+ )
return re_data_tensor
@@ -248,7 +248,9 @@ def runner():
print(extra_args)
runner = get_runner(sim)
runner.build(
- verilog_sources=verilog_sources, hdl_toplevel="padding", build_args=extra_args,
+ verilog_sources=verilog_sources,
+ hdl_toplevel="padding",
+ build_args=extra_args,
)
runner.test(hdl_toplevel="padding", test_module="padding_tb")
diff --git a/src/mase_components/convolution_layers/test/roller_tb.py b/src/mase_components/convolution_layers/test/roller_tb.py
index 8c7a24d02..4bf56b273 100644
--- a/src/mase_components/convolution_layers/test/roller_tb.py
+++ b/src/mase_components/convolution_layers/test/roller_tb.py
@@ -195,7 +195,9 @@ def runner():
print(extra_args)
runner = get_runner(sim)()
runner.build(
- verilog_sources=verilog_sources, toplevel="roller", extra_args=extra_args,
+ verilog_sources=verilog_sources,
+ toplevel="roller",
+ extra_args=extra_args,
)
runner.test(toplevel="roller", py_module="roller_tb")
diff --git a/src/mase_components/convolution_layers/test/sliding_window_tb.py b/src/mase_components/convolution_layers/test/sliding_window_tb.py
index 73facf247..f893d6f68 100644
--- a/src/mase_components/convolution_layers/test/sliding_window_tb.py
+++ b/src/mase_components/convolution_layers/test/sliding_window_tb.py
@@ -119,9 +119,9 @@ def data_pack(self):
for j in range(in_channels):
for k in range(in_height):
for s in range(in_width):
- re_data_tensor[i][j][k + padding_height][
- s + padding_width
- ] = data_tensor[i][k][s][j]
+ re_data_tensor[i][j][k + padding_height][s + padding_width] = (
+ data_tensor[i][k][s][j]
+ )
return re_data_tensor
diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py
index 9596adedc..a06529344 100644
--- a/src/mase_components/deps.py
+++ b/src/mase_components/deps.py
@@ -17,7 +17,11 @@
"activation_layers",
"scalar_operators/fixed",
],
- "activation_layers/fixed_gelu": ["common", "memory", "activation_layers",],
+ "activation_layers/fixed_gelu": [
+ "common",
+ "memory",
+ "activation_layers",
+ ],
"activation_layers/fixed_softsign": [
"common",
"activation_layers",
@@ -123,7 +127,11 @@
"common",
"cast",
],
- "language_models/llmint8/scatter": ["language_models/llmint8", "memory", "common",],
+ "language_models/llmint8/scatter": [
+ "language_models/llmint8",
+ "memory",
+ "common",
+ ],
# Linear
"linear_layers/fixed_linear_layer/fixed_linear": [
"cast",
diff --git a/src/mase_components/helper/generate_memory.py b/src/mase_components/helper/generate_memory.py
index c780d6c7f..e46c99604 100644
--- a/src/mase_components/helper/generate_memory.py
+++ b/src/mase_components/helper/generate_memory.py
@@ -60,10 +60,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"):
count += 1
iarr.append(i)
val = quanter(f(torch.tensor(i))) # entry in the lookup table
- lut[
- doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)
- ] = doubletofx(
- data_width=data_width, f_width=f_width, num=val.item(), type=type
+ lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = (
+ doubletofx(
+ data_width=data_width, f_width=f_width, num=val.item(), type=type
+ )
)
i += 2 ** -(f_width)
return lut
diff --git a/src/mase_components/hls/bfp_arith/bfp_adder.py b/src/mase_components/hls/bfp_arith/bfp_adder.py
index 0ba95741e..5b2a401d9 100644
--- a/src/mase_components/hls/bfp_arith/bfp_adder.py
+++ b/src/mase_components/hls/bfp_arith/bfp_adder.py
@@ -2,7 +2,11 @@
def bfp_adder_gen(
- writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8,
+ writer,
+ x_exp_width=16,
+ x_man_width=8,
+ w_exp_width=16,
+ w_man_width=8,
):
"""
This script generates a element-level bfp add in HLS
diff --git a/src/mase_components/hls/bfp_arith/bfp_multiplier.py b/src/mase_components/hls/bfp_arith/bfp_multiplier.py
index fd588ea3c..e3695295e 100644
--- a/src/mase_components/hls/bfp_arith/bfp_multiplier.py
+++ b/src/mase_components/hls/bfp_arith/bfp_multiplier.py
@@ -1,5 +1,9 @@
def bfp_multiplier_gen(
- writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8,
+ writer,
+ x_exp_width=16,
+ x_man_width=8,
+ w_exp_width=16,
+ w_man_width=8,
):
"""
This script generates a element-level bfp mult in HLS
diff --git a/src/mase_components/hls/elastic/buffer.py b/src/mase_components/hls/elastic/buffer.py
index a0fda6b90..d25c1fd56 100644
--- a/src/mase_components/hls/elastic/buffer.py
+++ b/src/mase_components/hls/elastic/buffer.py
@@ -2,7 +2,13 @@
def buffer_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a buffer in HLS
diff --git a/src/mase_components/hls/hls_regression.py b/src/mase_components/hls/hls_regression.py
index 719b20380..a02dca60a 100755
--- a/src/mase_components/hls/hls_regression.py
+++ b/src/mase_components/hls/hls_regression.py
@@ -106,7 +106,10 @@ def main():
parser = ArgumentParser(usage=USAGE)
parser.add_argument(
- "--op", dest="op", default=None, help="Op name to explore",
+ "--op",
+ dest="op",
+ default=None,
+ help="Op name to explore",
)
parser.add_argument(
"--dir",
diff --git a/src/mase_components/hls/int_arith/int_layernorm.py b/src/mase_components/hls/int_arith/int_layernorm.py
index 9521a68c0..059b15d4d 100644
--- a/src/mase_components/hls/int_arith/int_layernorm.py
+++ b/src/mase_components/hls/int_arith/int_layernorm.py
@@ -2,7 +2,13 @@
def int_layernorm_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a fixed-point layernorm in HLS
diff --git a/src/mase_components/hls/int_arith/int_relu.py b/src/mase_components/hls/int_arith/int_relu.py
index 09f7462c2..f934640c9 100644
--- a/src/mase_components/hls/int_arith/int_relu.py
+++ b/src/mase_components/hls/int_arith/int_relu.py
@@ -2,7 +2,13 @@
def int_relu_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a fixed-point relu in HLS
diff --git a/src/mase_components/hls/int_arith/int_silu.py b/src/mase_components/hls/int_arith/int_silu.py
index c90511c47..c4fc32abe 100644
--- a/src/mase_components/hls/int_arith/int_silu.py
+++ b/src/mase_components/hls/int_arith/int_silu.py
@@ -2,7 +2,13 @@
def int_silu_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a fixed-point silu in HLS
diff --git a/src/mase_components/hls/int_arith/int_softmax.py b/src/mase_components/hls/int_arith/int_softmax.py
index 8ad5901cf..c0b4ca5b1 100644
--- a/src/mase_components/hls/int_arith/int_softmax.py
+++ b/src/mase_components/hls/int_arith/int_softmax.py
@@ -2,7 +2,13 @@
def int_softmax_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a fixed-point softmax in HLS
diff --git a/src/mase_components/hls/int_arith/int_transpose.py b/src/mase_components/hls/int_arith/int_transpose.py
index ee9c6d2d7..6b7fd0a15 100644
--- a/src/mase_components/hls/int_arith/int_transpose.py
+++ b/src/mase_components/hls/int_arith/int_transpose.py
@@ -2,7 +2,13 @@
def int_transpose_gen(
- writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2,
+ writer,
+ x_width=16,
+ x_frac_width=8,
+ x_row=3,
+ x_col=2,
+ x_row_depth=3,
+ x_col_depth=2,
):
"""
This script generates a fixed-point transpose in HLS
diff --git a/src/mase_components/hls/regression_gen/bfp_add_dse.py b/src/mase_components/hls/regression_gen/bfp_add_dse.py
index c283301d2..834644b24 100644
--- a/src/mase_components/hls/regression_gen/bfp_add_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_add_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -151,7 +157,8 @@ def bfp_add_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "bfp_add_2"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
index 9a894b143..d6f78f30c 100644
--- a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
diff --git a/src/mase_components/hls/regression_gen/bfp_mult_dse.py b/src/mase_components/hls/regression_gen/bfp_mult_dse.py
index 1510ea1f3..b4b5fa58c 100644
--- a/src/mase_components/hls/regression_gen/bfp_mult_dse.py
+++ b/src/mase_components/hls/regression_gen/bfp_mult_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -151,7 +157,8 @@ def bfp_mult_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "bfp_mult_2"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
if hr is None:
continue
diff --git a/src/mase_components/hls/regression_gen/buffer_dse.py b/src/mase_components/hls/regression_gen/buffer_dse.py
index bcc003c6d..08cc6d56e 100644
--- a/src/mase_components/hls/regression_gen/buffer_dse.py
+++ b/src/mase_components/hls/regression_gen/buffer_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -123,7 +129,8 @@ def buffer_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "buffer_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/fork_dse.py b/src/mase_components/hls/regression_gen/fork_dse.py
index c8a5b3c0b..f045cd8a3 100644
--- a/src/mase_components/hls/regression_gen/fork_dse.py
+++ b/src/mase_components/hls/regression_gen/fork_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -133,7 +139,8 @@ def fork_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "fork_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_add_dse.py b/src/mase_components/hls/regression_gen/int_add_dse.py
index a88247d24..d4c4d9ccc 100644
--- a/src/mase_components/hls/regression_gen/int_add_dse.py
+++ b/src/mase_components/hls/regression_gen/int_add_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -151,7 +157,8 @@ def int_add_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_add_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_layernorm_dse.py b/src/mase_components/hls/regression_gen/int_layernorm_dse.py
index d0ef2afdb..3e6a61aff 100644
--- a/src/mase_components/hls/regression_gen/int_layernorm_dse.py
+++ b/src/mase_components/hls/regression_gen/int_layernorm_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -124,7 +130,8 @@ def int_layernorm_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_layernorm_1"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_linear2d_dse.py b/src/mase_components/hls/regression_gen/int_linear2d_dse.py
index d75122306..1c89861eb 100644
--- a/src/mase_components/hls/regression_gen/int_linear2d_dse.py
+++ b/src/mase_components/hls/regression_gen/int_linear2d_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
diff --git a/src/mase_components/hls/regression_gen/int_matmul_dse.py b/src/mase_components/hls/regression_gen/int_matmul_dse.py
index ac09aaa5c..1d1f35fd2 100644
--- a/src/mase_components/hls/regression_gen/int_matmul_dse.py
+++ b/src/mase_components/hls/regression_gen/int_matmul_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import DSE_MODES, get_tcl_buff
from hls.int_arith import int_matmul_gen
diff --git a/src/mase_components/hls/regression_gen/int_mult_dse.py b/src/mase_components/hls/regression_gen/int_mult_dse.py
index f824ec99e..275ffb163 100644
--- a/src/mase_components/hls/regression_gen/int_mult_dse.py
+++ b/src/mase_components/hls/regression_gen/int_mult_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -151,7 +157,8 @@ def int_mult_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_mult_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_relu_dse.py b/src/mase_components/hls/regression_gen/int_relu_dse.py
index 788e3f593..7224b1444 100644
--- a/src/mase_components/hls/regression_gen/int_relu_dse.py
+++ b/src/mase_components/hls/regression_gen/int_relu_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -125,7 +131,8 @@ def int_relu_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_relu_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
index c559ecee8..262794af2 100644
--- a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
+++ b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -145,7 +151,8 @@ def int_rmsnorm_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_rmsnorm_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_rope_dse.py b/src/mase_components/hls/regression_gen/int_rope_dse.py
index bfc2bc385..421a37e21 100644
--- a/src/mase_components/hls/regression_gen/int_rope_dse.py
+++ b/src/mase_components/hls/regression_gen/int_rope_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -146,7 +152,8 @@ def int_rope_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_rope_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_silu_dse.py b/src/mase_components/hls/regression_gen/int_silu_dse.py
index e3ce1bd11..0d80397cd 100644
--- a/src/mase_components/hls/regression_gen/int_silu_dse.py
+++ b/src/mase_components/hls/regression_gen/int_silu_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -125,7 +131,8 @@ def int_silu_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_silu_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_softmax_dse.py b/src/mase_components/hls/regression_gen/int_softmax_dse.py
index 796829278..c10c091ef 100644
--- a/src/mase_components/hls/regression_gen/int_softmax_dse.py
+++ b/src/mase_components/hls/regression_gen/int_softmax_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -125,7 +131,8 @@ def int_softmax_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_softmax_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/regression_gen/int_transpose_dse.py b/src/mase_components/hls/regression_gen/int_transpose_dse.py
index 4563a9eca..7b5c2b715 100644
--- a/src/mase_components/hls/regression_gen/int_transpose_dse.py
+++ b/src/mase_components/hls/regression_gen/int_transpose_dse.py
@@ -1,7 +1,13 @@
# TODO: Temporary working solution
import sys, os
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.regression_gen.utils import (
DSE_MODES,
@@ -125,7 +131,8 @@ def int_transpose_dse(mode=None, top=None, threads=16):
if mode in ["report", "all"]:
top_name = "int_transpose_0"
hr = get_hls_results(
- project=os.path.join(top, file_name), top=top_name,
+ project=os.path.join(top, file_name),
+ top=top_name,
)
data_points.append(
[
diff --git a/src/mase_components/hls/scripts/bl_bfp.py b/src/mase_components/hls/scripts/bl_bfp.py
index 29f34c04d..beff47050 100644
--- a/src/mase_components/hls/scripts/bl_bfp.py
+++ b/src/mase_components/hls/scripts/bl_bfp.py
@@ -1,6 +1,12 @@
import os, sys
-sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",))
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ )
+)
from hls.bfp_arith import bfp_mm_gen
from hls.bfp_arith import bfp_add_gen
@@ -8,7 +14,12 @@
def get_big_little_bfp(
- HIGH_MAN_WIDTH=7, LOW_MAN_WIDTH=3, X_ROW=1, X_COL=4096, W_COL=11008, A_COL=32,
+ HIGH_MAN_WIDTH=7,
+ LOW_MAN_WIDTH=3,
+ X_ROW=1,
+ X_COL=4096,
+ W_COL=11008,
+ A_COL=32,
):
W_ROW = X_COL
A_ROW = X_COL
diff --git a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
index 2cdb58b59..91552b68a 100644
--- a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
+++ b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py
@@ -95,7 +95,8 @@ def runner():
extra_args.append(f"-G{k}={v}")
mase_runner(
- trace=True, module_param_list=[test_case.get_dut_parameters()],
+ trace=True,
+ module_param_list=[test_case.get_dut_parameters()],
)
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
index 27e729200..0aa8f5239 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py
@@ -24,23 +24,31 @@ class VerificationCase(Testbench):
def __init__(self, dut):
super().__init__(dut, dut.clk, dut.rst)
self.assign_self_params(
- ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",]
+ [
+ "IN_WIDTH",
+ "IN_FRAC_WIDTH",
+ "LUT_POW",
+ ]
)
self.input_driver = StreamDriver(
dut.clk, dut.in_data, dut.in_valid, dut.in_ready
)
self.output_monitor = StreamMonitor(
- dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT",
+ dut.clk,
+ dut.out_data,
+ dut.out_valid,
+ dut.out_ready,
+ name="Output ISQRT",
)
def generate_inputs(self, num=10000):
- maxnum = (2 ** self.IN_WIDTH) - 1
+ maxnum = (2**self.IN_WIDTH) - 1
return [random.randint(0, maxnum) for _ in range(num)], num
def model(self, data_in):
ref = []
- lut_size = 2 ** self.LUT_POW
+ lut_size = 2**self.LUT_POW
lut = make_lut(lut_size, self.IN_WIDTH)
for x in data_in:
expected = isqrt_sw2(
@@ -119,7 +127,7 @@ async def valid_backpressure(dut):
makedirs(mem_dir, exist_ok=True)
def single_cfg(width, frac_width, lut_pow, str_id):
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
lut = make_lut(lut_size, width)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, width)
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
index 84fde6666..89c80aa83 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py
@@ -22,7 +22,7 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH", "LUT_POW"])
def generate_inputs(self):
- samples = 2 ** self.WIDTH
+ samples = 2**self.WIDTH
data_x = []
msb_indices = []
for x in range(samples):
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
index 07a199810..8c0b70ed1 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py
@@ -24,12 +24,12 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH"])
def generate_inputs(self, lut_pow):
- samples = 2 ** self.WIDTH
+ samples = 2**self.WIDTH
int_width = 1
frac_width = self.WIDTH - 1
data_x = []
initial_guesses = []
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
lut = make_lut(lut_size, self.WIDTH)
# NOTE: since negative values are not supported by fixed formats since
# isqrt only outputs positive results we cannot test every single com-
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
index a0801d23a..cbf9cf776 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py
@@ -22,7 +22,7 @@ def __init__(self, dut) -> None:
self.assign_self_params(["WIDTH", "FRAC_WIDTH"])
def generate_inputs(self):
- samples = 2 ** self.WIDTH
+ samples = 2**self.WIDTH
data_x = []
msb_indices = []
for x in range(samples):
diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
index c4d00af2a..ed155214a 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py
@@ -20,7 +20,7 @@ def __init__(self, dut) -> None:
self.assign_self_params(["WIDTH"])
def generate_inputs(self):
- samples = 2 ** self.WIDTH
+ samples = 2**self.WIDTH
return [val for val in range(0, samples)], samples
def model(self, inputs):
diff --git a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
index a001b27bc..c4723f971 100644
--- a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
+++ b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py
@@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int:
def float_to_int(x: float, int_width: int, frac_width: int) -> int:
integer = int(x)
x -= integer
- res = integer * (2 ** frac_width)
+ res = integer * (2**frac_width)
for i in range(1, frac_width + 1):
power = 2 ** (-i)
if power <= x:
@@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int:
def int_to_float(x: int, int_width: int, frac_width: int) -> float:
- integer = x / (2 ** frac_width)
- fraction = x - integer * 2 ** frac_width
+ integer = x / (2**frac_width)
+ fraction = x - integer * 2**frac_width
res = integer
for i in range(1, frac_width + 1):
@@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int:
res = 0
else:
res = x_red - 2 ** (width - 1)
- res = res * 2 ** lut_pow
+ res = res * 2**lut_pow
res = res / 2 ** (width - 1)
# FORMAT OUTPUT: Q(WIDTH).0
return int(res)
@@ -258,7 +258,7 @@ def test_sw_model():
def debug_single():
lut_pow = 5
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
int_width = 2
frac_width = 1
width = int_width + frac_width
diff --git a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
index 5734fb891..815e9ed07 100644
--- a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py
@@ -147,7 +147,10 @@ def data_generate(self):
weight_tensor = {} \n\
data_in = {} \n\
weight_in = {} ".format(
- data_tensor, weight_tensor, data_in, weight_in,
+ data_tensor,
+ weight_tensor,
+ data_in,
+ weight_in,
)
)
data_in.reverse()
@@ -378,7 +381,8 @@ def runner():
build_args=extra_args,
)
runner.test(
- hdl_toplevel="fixed_matmul", test_module="fixed_matmul_tb",
+ hdl_toplevel="fixed_matmul",
+ test_module="fixed_matmul_tb",
)
diff --git a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
index 0cd5b07b9..9f058695c 100644
--- a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py
@@ -40,8 +40,8 @@ def __init__(self, dut) -> None:
]
)
- self.X_MAX = 2 ** self.X_WIDTH - 1
- self.Y_MAX = 2 ** self.Y_WIDTH - 1
+ self.X_MAX = 2**self.X_WIDTH - 1
+ self.Y_MAX = 2**self.Y_WIDTH - 1
self.x_driver = StreamDriver(dut.clk, dut.x_data, dut.x_valid, dut.x_ready)
self.y_driver = StreamDriver(dut.clk, dut.y_data, dut.y_valid, dut.y_ready)
@@ -87,10 +87,10 @@ def model(self, X, Y):
logger.debug("Sign Extended & Scaled")
X_input = sign_extend_t(X_input, self.X_WIDTH).type(torch.float32) / (
- 2 ** self.X_FRAC_WIDTH
+ 2**self.X_FRAC_WIDTH
)
Y_input = sign_extend_t(Y_input, self.Y_WIDTH).type(torch.float32) / (
- 2 ** self.Y_FRAC_WIDTH
+ 2**self.Y_FRAC_WIDTH
)
logger.debug(X_input)
logger.debug(Y_input)
diff --git a/src/mase_components/linear_layers/matmul/test/transpose_tb.py b/src/mase_components/linear_layers/matmul/test/transpose_tb.py
index 003382ff6..c859749ed 100644
--- a/src/mase_components/linear_layers/matmul/test/transpose_tb.py
+++ b/src/mase_components/linear_layers/matmul/test/transpose_tb.py
@@ -56,7 +56,11 @@ def generate_random_params(num=3):
cfgs = list()
for _ in range(num):
cfgs.append(
- {"WIDTH": randint(1, 16), "DIM0": randint(2, 12), "DIM1": randint(2, 12),}
+ {
+ "WIDTH": randint(1, 16),
+ "DIM0": randint(2, 12),
+ "DIM1": randint(2, 12),
+ }
)
return cfgs
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
index ce742183b..963ae40df 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py
@@ -29,7 +29,10 @@ def __init__(self, dut, num) -> None:
self.log = SimLog("%s" % (type(self).__qualname__))
self.data_in_0_driver = MultiSignalStreamDriver(
- dut.clk, (dut.mdata_in, dut.edata_in), dut.data_in_valid, dut.data_in_ready,
+ dut.clk,
+ (dut.mdata_in, dut.edata_in),
+ dut.data_in_valid,
+ dut.data_in_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -48,10 +51,14 @@ def generate_inputs(self):
for _ in range(self.num):
data = 20 * torch.rand(int(self.dut.BLOCK_SIZE))
(data_in, mdata_in, edata_in) = mxint_quantize(
- data, int(self.dut.IN_MAN_WIDTH), int(self.dut.IN_EXP_WIDTH),
+ data,
+ int(self.dut.IN_MAN_WIDTH),
+ int(self.dut.IN_EXP_WIDTH),
)
exp_out, mexp_out, eexp_out = mxint_quantize(
- data_in, int(self.dut.OUT_MAN_WIDTH), int(self.dut.OUT_EXP_WIDTH),
+ data_in,
+ int(self.dut.OUT_MAN_WIDTH),
+ int(self.dut.OUT_EXP_WIDTH),
)
breakpoint()
inputs.append((mdata_in.int().tolist(), edata_in.int().tolist()))
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
index cb1b9638f..43e70adcf 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py
@@ -40,7 +40,10 @@ def __init__(self, dut, num) -> None:
dut.data_in_0_ready,
)
self.weight_driver = MultiSignalStreamDriver(
- dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready,
+ dut.clk,
+ (dut.mweight, dut.eweight),
+ dut.weight_valid,
+ dut.weight_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -64,7 +67,9 @@ def generate_inputs(self):
)
w = torch.rand(int(self.dut.BLOCK_SIZE))
(weight, mweight, eweight) = mxint_quantize(
- w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1),
+ w,
+ int(self.dut.WEIGHT_PRECISION_0),
+ int(self.dut.WEIGHT_PRECISION_1),
)
mdp_out = mdata_in @ mweight
edp_out = edata_in + eweight
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
index e01e056e1..f31d1ab61 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py
@@ -48,10 +48,16 @@ def __init__(self, dut) -> None:
self.log.setLevel(logging.DEBUG)
self.a_driver = MultiSignalStreamDriver(
- dut.clk, (dut.ma_data, dut.ea_data), dut.a_valid, dut.a_ready,
+ dut.clk,
+ (dut.ma_data, dut.ea_data),
+ dut.a_valid,
+ dut.a_ready,
)
self.b_driver = MultiSignalStreamDriver(
- dut.clk, (dut.mb_data, dut.eb_data), dut.b_valid, dut.b_ready,
+ dut.clk,
+ (dut.mb_data, dut.eb_data),
+ dut.b_valid,
+ dut.b_ready,
)
self.output_monitor = MultiSignalStreamMonitor(
diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
index 06f970039..58fc472ba 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py
@@ -35,7 +35,10 @@ def __init__(self, dut, num) -> None:
dut.data_in_0_ready,
)
self.weight_driver = MultiSignalStreamDriver(
- dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready,
+ dut.clk,
+ (dut.mweight, dut.eweight),
+ dut.weight_valid,
+ dut.weight_ready,
)
self.data_out_0_monitor = MultiSignalStreamMonitor(
@@ -63,7 +66,9 @@ def generate_inputs(self):
w = 20 * torch.rand(int(self.dut.BLOCK_SIZE))
(weight, mweight, eweight) = mxint_quantize(
- w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1),
+ w,
+ int(self.dut.WEIGHT_PRECISION_0),
+ int(self.dut.WEIGHT_PRECISION_1),
)
exp_out, mexp_out, eexp_out = mxint_quantize(
data_in * weight,
diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.py b/src/mase_components/linear_layers/mxint_operators/test/test.py
index 85cdaa995..f58f382cf 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/test.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/test.py
@@ -8,8 +8,16 @@
d_man_width = 12
w_man_width = 8
e_width = 4
-(data_in, mdata_in, edata_in) = mxint_quantize(data, d_man_width, e_width,)
-(weight, mweight, eweight) = mxint_quantize(w, w_man_width, e_width,)
+(data_in, mdata_in, edata_in) = mxint_quantize(
+ data,
+ d_man_width,
+ e_width,
+)
+(weight, mweight, eweight) = mxint_quantize(
+ w,
+ w_man_width,
+ e_width,
+)
linear = torch.nn.Linear(10, 10, bias=False)
linear.weight = torch.nn.Parameter(weight)
target_x = linear(data_in)
@@ -28,7 +36,7 @@ def hardware_quant(hardware_in, be_value, e_width, width):
exponent_bias = 2 ** (e_width - 1) - 1
# exponent
- exponent_max = 2 ** e_width - 1 - exponent_bias
+ exponent_max = 2**e_width - 1 - exponent_bias
exponent_min = -exponent_bias
exponent = (
torch.ceil(torch.log2(hardware_in.abs().max())) + be_value - exponent_bias
@@ -40,7 +48,7 @@ def hardware_quant(hardware_in, be_value, e_width, width):
breakpoint()
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- msfp_x = (2 ** exponent) * mantissa
+ msfp_x = (2**exponent) * mantissa
return msfp_x, mantissa, exponent
diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py
index 43b3b0b87..7edb9f6ed 100644
--- a/src/mase_components/linear_layers/mxint_operators/test/utils.py
+++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py
@@ -24,7 +24,7 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int =
"""
exponent_bias = 2 ** (exponent_width - 1)
- exponent_max = 2 ** exponent_width - 1 - exponent_bias
+ exponent_max = 2**exponent_width - 1 - exponent_bias
exponent_min = -exponent_bias
# exponent
@@ -34,9 +34,9 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int =
# mantissa
int_min = -(2 ** (width - 1))
int_max = 2 ** (width - 1) - 1
- mantissa = x / 2 ** exponent
+ mantissa = x / 2**exponent
mantissa = torch.clamp(mantissa.floor(), int_min, int_max)
- msfp_x = (2 ** exponent) * mantissa
+ msfp_x = (2**exponent) * mantissa
return msfp_x, mantissa, exponent
diff --git a/src/mase_components/memory/test/fifo_tb.py b/src/mase_components/memory/test/fifo_tb.py
index 5f247bede..9dc4f54ff 100644
--- a/src/mase_components/memory/test/fifo_tb.py
+++ b/src/mase_components/memory/test/fifo_tb.py
@@ -27,7 +27,7 @@ def __init__(self, dut) -> None:
# self.output_monitor.log.setLevel("DEBUG")
def generate_inputs(self, num=20):
- return [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(num)]
+ return [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(num)]
@cocotb.test()
@@ -120,7 +120,12 @@ async def cocotb_test_soak(dut):
@pytest.mark.dev
def test_fifo():
mase_runner(
- module_param_list=[{"DEPTH": 1}, {"DEPTH": 7}, {"DEPTH": 8}, {"DEPTH": 81},],
+ module_param_list=[
+ {"DEPTH": 1},
+ {"DEPTH": 7},
+ {"DEPTH": 8},
+ {"DEPTH": 81},
+ ],
trace=True,
)
diff --git a/src/mase_components/memory/test/repeat_circular_buffer_tb.py b/src/mase_components/memory/test/repeat_circular_buffer_tb.py
index db6bfc7b8..11d7d85b1 100644
--- a/src/mase_components/memory/test/repeat_circular_buffer_tb.py
+++ b/src/mase_components/memory/test/repeat_circular_buffer_tb.py
@@ -30,7 +30,7 @@ def generate_inputs(self, num=10):
inputs = []
for _ in range(num):
inputs.extend(
- [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
+ [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)]
)
return inputs
diff --git a/src/mase_components/memory/test/unpacked_fifo_tb.py b/src/mase_components/memory/test/unpacked_fifo_tb.py
index 734a2edcb..cf0a99592 100644
--- a/src/mase_components/memory/test/unpacked_fifo_tb.py
+++ b/src/mase_components/memory/test/unpacked_fifo_tb.py
@@ -26,7 +26,7 @@ def __init__(self, dut) -> None:
def generate_inputs(self, batches=20):
return [
- [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)]
+ [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)]
for _ in range(batches)
]
diff --git a/src/mase_components/normalization_layers/process_synth_impl.py b/src/mase_components/normalization_layers/process_synth_impl.py
index 5a12ee170..c65d1b8a9 100644
--- a/src/mase_components/normalization_layers/process_synth_impl.py
+++ b/src/mase_components/normalization_layers/process_synth_impl.py
@@ -103,14 +103,19 @@ def gather_data(build_dir: Path):
if __name__ == "__main__":
data = gather_data(Path("build"))
data["ns"] = data["clk_period"] - data["wns"]
- data["fmax"] = 1 / (data["ns"] * (10 ** -9))
+ data["fmax"] = 1 / (data["ns"] * (10**-9))
data["fmax_mhz"] = data["fmax"] / 1_000_000
print(data)
def plot(col):
- alt.Chart(data).mark_line().encode(x="width", y=col, color="norm",).properties(
- width=400, height=200,
+ alt.Chart(data).mark_line().encode(
+ x="width",
+ y=col,
+ color="norm",
+ ).properties(
+ width=400,
+ height=200,
).save(f"{col}_plot.png", scale_factor=3)
plot("wns")
diff --git a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
index d43976845..5d9da4bc6 100644
--- a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py
@@ -158,7 +158,7 @@ def model(self, inputs):
-1, self.NUM_CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0
)
x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / (
- 2 ** self.IN_FRAC_WIDTH
+ 2**self.IN_FRAC_WIDTH
)
# Float Model
@@ -397,7 +397,8 @@ def gen_cfg(
]
mase_runner(
- module_param_list=test_cfgs, trace=True,
+ module_param_list=test_cfgs,
+ trace=True,
)
diff --git a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
index cffd16822..e31203c25 100644
--- a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py
@@ -79,7 +79,7 @@ def __init__(self, dut) -> None:
self.out_width_tup = self.OUT_WIDTH, self.OUT_FRAC_WIDTH
# Inverse Square Root LUT
- self.isqrt_lut = make_lut(2 ** 5, 16)
+ self.isqrt_lut = make_lut(2**5, 16)
self.num_groups = randint(2, 3)
self.total_channels = self.GROUP_CHANNELS * self.num_groups
@@ -141,7 +141,7 @@ def model(self, inputs):
-1, self.total_channels, self.TOTAL_DIM1, self.TOTAL_DIM0
)
x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / (
- 2 ** self.IN_FRAC_WIDTH
+ 2**self.IN_FRAC_WIDTH
)
# Float Model
@@ -239,7 +239,12 @@ def test_group_norm_2d():
makedirs(mem_dir, exist_ok=True)
def isqrt_width(
- total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width,
+ total_dim0,
+ total_dim1,
+ compute_dim0,
+ compute_dim1,
+ group_channels,
+ in_width,
):
depth_dim0 = total_dim0 // compute_dim0
depth_dim1 = total_dim1 // compute_dim1
@@ -269,7 +274,7 @@ def gen_cfg(
isqrt_w = isqrt_width(
total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width
)
- lut = make_lut(2 ** LUT_POW, isqrt_w)
+ lut = make_lut(2**LUT_POW, isqrt_w)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, isqrt_w)
params = {
diff --git a/src/mase_components/normalization_layers/test/models.py b/src/mase_components/normalization_layers/test/models.py
index dfeadac98..0f10b228e 100644
--- a/src/mase_components/normalization_layers/test/models.py
+++ b/src/mase_components/normalization_layers/test/models.py
@@ -56,11 +56,11 @@ def _fixed_group_norm_2d_model(
logger.debug("Diff:")
logger.debug(diff[0])
- squares = diff ** 2
+ squares = diff**2
logger.debug("Squares:")
logger.debug(squares[0])
- squares_int = (squares * (2 ** square_frac_width)).int()
- logger.debug(squares * (2 ** square_frac_width))
+ squares_int = (squares * (2**square_frac_width)).int()
+ logger.debug(squares * (2**square_frac_width))
sum_squares = torch.sum(squares, dim=(1, 2, 3), keepdim=True)
sum_squares = integer_floor_quantizer(
@@ -81,12 +81,10 @@ def _fixed_group_norm_2d_model(
logger.debug(f"{var[0]}")
# Clamp down variance to isqrt width
- var_clamp = torch.clamp(
- var, 0.0, ((2 ** isqrt_width) - 1) / (2 ** isqrt_frac_width)
- )
+ var_clamp = torch.clamp(var, 0.0, ((2**isqrt_width) - 1) / (2**isqrt_frac_width))
logger.debug("Variance Clamped:")
logger.debug(f"{var_clamp[0]}")
- var_clamp_int = (var_clamp * (2 ** isqrt_frac_width)).int()
+ var_clamp_int = (var_clamp * (2**isqrt_frac_width)).int()
# Inverse Square Root calculation
lut_pow = ceil(log2(len(isqrt_lut)))
@@ -106,7 +104,7 @@ def _fixed_group_norm_2d_model(
logger.debug("INV SQRT INT:")
logger.debug(f"{inv_sqrt_int[0]}")
- inv_sqrt = inv_sqrt_int / (2 ** isqrt_frac_width)
+ inv_sqrt = inv_sqrt_int / (2**isqrt_frac_width)
logger.debug("Inverse SQRT:")
logger.debug(f"{inv_sqrt[0]}")
diff --git a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
index 0c998f83d..22ef8bdd1 100644
--- a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
+++ b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py
@@ -151,7 +151,7 @@ def reconstruct_tensor(self, x, width, frac_width):
x = torch.stack(matrix_list).reshape(
-1, self.CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0
)
- x = sign_extend_t(x, width).to(dtype=torch.float32) / (2 ** frac_width)
+ x = sign_extend_t(x, width).to(dtype=torch.float32) / (2**frac_width)
return x
def output_monitor_split(self, x, width, frac_width):
@@ -248,7 +248,12 @@ def test_rms_norm_2d():
makedirs(mem_dir, exist_ok=True)
def isqrt_width(
- total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width,
+ total_dim0,
+ total_dim1,
+ compute_dim0,
+ compute_dim1,
+ group_channels,
+ in_width,
):
depth_dim0 = total_dim0 // compute_dim0
depth_dim1 = total_dim1 // compute_dim1
@@ -277,7 +282,7 @@ def gen_cfg(
isqrt_w = isqrt_width(
total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width
)
- lut = make_lut(2 ** LUT_POW, isqrt_w)
+ lut = make_lut(2**LUT_POW, isqrt_w)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, isqrt_w)
params = {
diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
index f0f5cafb1..6ee7a93dd 100644
--- a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
+++ b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py
@@ -24,23 +24,31 @@ class VerificationCase(Testbench):
def __init__(self, dut):
super().__init__(dut, dut.clk, dut.rst)
self.assign_self_params(
- ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",]
+ [
+ "IN_WIDTH",
+ "IN_FRAC_WIDTH",
+ "LUT_POW",
+ ]
)
self.input_driver = StreamDriver(
dut.clk, dut.in_data, dut.in_valid, dut.in_ready
)
self.output_monitor = StreamMonitor(
- dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT",
+ dut.clk,
+ dut.out_data,
+ dut.out_valid,
+ dut.out_ready,
+ name="Output ISQRT",
)
def generate_inputs(self, num=10000):
- maxnum = (2 ** self.IN_WIDTH) - 1
+ maxnum = (2**self.IN_WIDTH) - 1
return [random.randint(0, maxnum) for _ in range(num)], num
def model(self, data_in):
ref = []
- lut_size = 2 ** self.LUT_POW
+ lut_size = 2**self.LUT_POW
lut = make_lut(lut_size, self.IN_WIDTH)
for x in data_in:
expected = isqrt_sw2(
@@ -123,7 +131,7 @@ def test_fixed_isqrt():
makedirs(mem_dir, exist_ok=True)
def single_cfg(width, frac_width, lut_pow, str_id):
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
lut = make_lut(lut_size, width)
mem_path = mem_dir / f"lutmem-{str_id}.mem"
write_memb(mem_path, lut, width)
diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
index 5a3019f8e..d09305ef0 100644
--- a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
+++ b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py
@@ -24,12 +24,12 @@ def __init__(self, dut):
self.assign_self_params(["WIDTH"])
def generate_inputs(self, lut_pow):
- samples = 2 ** self.WIDTH
+ samples = 2**self.WIDTH
int_width = 1
frac_width = self.WIDTH - 1
data_x = []
initial_guesses = []
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
lut = make_lut(lut_size, self.WIDTH)
# NOTE: since negative values are not supported by fixed formats since
# isqrt only outputs positive results we cannot test every single com-
diff --git a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
index 901ba964d..f6fc798a4 100644
--- a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
+++ b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py
@@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int:
def float_to_int(x: float, int_width: int, frac_width: int) -> int:
integer = int(x)
x -= integer
- res = integer * (2 ** frac_width)
+ res = integer * (2**frac_width)
for i in range(1, frac_width + 1):
power = 2 ** (-i)
if power <= x:
@@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int:
def int_to_float(x: int, int_width: int, frac_width: int) -> float:
- integer = x / (2 ** frac_width)
- fraction = x - integer * 2 ** frac_width
+ integer = x / (2**frac_width)
+ fraction = x - integer * 2**frac_width
res = integer
for i in range(1, frac_width + 1):
@@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int:
res = 0
else:
res = x_red - 2 ** (width - 1)
- res = res * 2 ** lut_pow
+ res = res * 2**lut_pow
res = res / 2 ** (width - 1)
# FORMAT OUTPUT: Q(WIDTH).0
return int(res)
@@ -258,7 +258,7 @@ def test_isqrt_sw_model():
def debug_single():
lut_pow = 5
- lut_size = 2 ** lut_pow
+ lut_size = 2**lut_pow
int_width = 2
frac_width = 1
width = int_width + frac_width
diff --git a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
index a334a1469..d999fdc59 100644
--- a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
+++ b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py
@@ -195,16 +195,28 @@ def __init__(self, dut) -> None:
if self.HAS_BIAS == 1:
self.bias_q_driver = StreamDriver(
- dut.clk, dut.bias_query, dut.bias_query_valid, dut.bias_query_ready,
+ dut.clk,
+ dut.bias_query,
+ dut.bias_query_valid,
+ dut.bias_query_ready,
)
self.bias_k_driver = StreamDriver(
- dut.clk, dut.bias_key, dut.bias_key_valid, dut.bias_key_ready,
+ dut.clk,
+ dut.bias_key,
+ dut.bias_key_valid,
+ dut.bias_key_ready,
)
self.bias_v_driver = StreamDriver(
- dut.clk, dut.bias_value, dut.bias_value_valid, dut.bias_value_ready,
+ dut.clk,
+ dut.bias_value,
+ dut.bias_value_valid,
+ dut.bias_value_ready,
)
self.bias_o_driver = StreamDriver(
- dut.clk, dut.bias_output, dut.bias_output_valid, dut.bias_output_ready,
+ dut.clk,
+ dut.bias_output,
+ dut.bias_output_valid,
+ dut.bias_output_ready,
)
self.error_threshold = 2
@@ -469,11 +481,11 @@ async def run_memory_bandwidth_test(self, us: int = 500):
num_v_weight_beats_sent = self.weight_v_driver.num_beats
num_o_weight_beats_sent = self.weight_o_driver.num_beats
- input_beats_per_sec = num_input_beats_sent / (nanosec * (10 ** -9))
- num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10 ** -9))
- num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10 ** -9))
- num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10 ** -9))
- num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10 ** -9))
+ input_beats_per_sec = num_input_beats_sent / (nanosec * (10**-9))
+ num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10**-9))
+ num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10**-9))
+ num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10**-9))
+ num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10**-9))
self.log.info("Test length (ns): %.4f" % nanosec)
@@ -591,7 +603,9 @@ def test_fixed_linear_smoke():
]
mase_runner(
- module_param_list=cfgs, hierarchical=True, template=True,
+ module_param_list=cfgs,
+ hierarchical=True,
+ template=True,
)
@@ -603,7 +617,9 @@ def test_parallelism_sweep():
cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par))
mase_runner(
- module_param_list=cfgs, hierarchical=True, template=True,
+ module_param_list=cfgs,
+ hierarchical=True,
+ template=True,
)
@@ -615,7 +631,9 @@ def test_small_parallelism():
cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par))
mase_runner(
- module_param_list=cfgs, hierarchical=True, template=True,
+ module_param_list=cfgs,
+ hierarchical=True,
+ template=True,
)
@@ -625,7 +643,9 @@ def test_heads_sweep():
cfgs.append(get_config(256, 256, 16, kv_heads, 16, 1))
mase_runner(
- module_param_list=cfgs, hierarchical=True, template=True,
+ module_param_list=cfgs,
+ hierarchical=True,
+ template=True,
)
@@ -637,7 +657,9 @@ def test_bitwidth_sweep():
)
mase_runner(
- module_param_list=cfgs, hierarchical=True, template=True,
+ module_param_list=cfgs,
+ hierarchical=True,
+ template=True,
)
diff --git a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
index 70df56d43..a46032646 100644
--- a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
+++ b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py
@@ -44,7 +44,11 @@ def __init__(self, dut) -> None:
)
self.out_monitor = StreamMonitor(
- dut.clk, dut.out, dut.out_valid, dut.out_ready, check=False,
+ dut.clk,
+ dut.out,
+ dut.out_valid,
+ dut.out_ready,
+ check=False,
)
# Model
@@ -56,7 +60,8 @@ def __init__(self, dut) -> None:
"frac_width": self.get_parameter("IN_DATA_PRECISION_1"),
}
self.model = BertSelfAttentionHeadInteger(
- config=self.config, q_config=self.q_config,
+ config=self.config,
+ q_config=self.q_config,
)
# Set verbosity of driver and monitor loggers to debug
@@ -114,21 +119,27 @@ async def run_test(self):
# * Load the query driver
self.log.info(f"Processing query inputs: {inputs['query_layer']}")
query_inputs = self.preprocess_tensor(
- tensor=inputs["query_layer"], config=self.q_config, parallelism=parallelism,
+ tensor=inputs["query_layer"],
+ config=self.q_config,
+ parallelism=parallelism,
)
self.query_driver.load_driver(query_inputs)
# * Load the key driver
self.log.info(f"Processing key inputs: {inputs['key_layer']}")
key_inputs = self.preprocess_tensor(
- tensor=inputs["key_layer"], config=self.q_config, parallelism=parallelism,
+ tensor=inputs["key_layer"],
+ config=self.q_config,
+ parallelism=parallelism,
)
self.key_driver.load_driver(key_inputs)
# * Load the value driver
self.log.info(f"Processing value inputs: {inputs['value_layer']}")
value_inputs = self.preprocess_tensor(
- tensor=inputs["value_layer"], config=self.q_config, parallelism=parallelism,
+ tensor=inputs["value_layer"],
+ config=self.q_config,
+ parallelism=parallelism,
)
self.value_driver.load_driver(value_inputs)
@@ -187,11 +198,18 @@ def test_fixed_self_attention_head_smoke():
# * Generate exponential LUT for softmax
generate_memory.generate_sv_lut(
- "exp", 16, 3, 16, 3, path=Path(__file__).parents[1] / "rtl",
+ "exp",
+ 16,
+ 3,
+ 16,
+ 3,
+ path=Path(__file__).parents[1] / "rtl",
)
mase_runner(
trace=True,
- module_param_list=[get_fixed_self_attention_head_config(),],
+ module_param_list=[
+ get_fixed_self_attention_head_config(),
+ ],
skip_build=False,
)
diff --git a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
index 823c65db4..2e19bed4b 100644
--- a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py
@@ -108,7 +108,8 @@ def __init__(self, samples=1):
depth_in_num = int(self.in_num / self.tile_in_num)
depth_out_features = int(self.out_features / self.tile_out_features)
self.outputs = RandomSink(
- samples=samples * depth_out_features * depth_in_num, debug=debug,
+ samples=samples * depth_out_features * depth_in_num,
+ debug=debug,
)
self.ref = self.sw_compute()
@@ -470,7 +471,8 @@ def runner():
)
for _ in range(1):
runner.test(
- hdl_toplevel="fixed_mlp", test_module="fixed_mlp_tb",
+ hdl_toplevel="fixed_mlp",
+ test_module="fixed_mlp_tb",
)
diff --git a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
index b92e6dc4f..05cf0f2cf 100644
--- a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py
@@ -50,14 +50,14 @@ def __init__(self, samples=1):
self.pe_unroll_kernel_out = 3
self.pe_unroll_in_c = 3
self.pe_unroll_embed_dim = 8
- self.num_patch = int(self.in_y * self.in_x // (self.patch_size ** 2))
+ self.num_patch = int(self.in_y * self.in_x // (self.patch_size**2))
# self.num_classes = 10
# self.head_unroll_out_x = 5
self.samples = samples
self.pe_iter_weight = int(
- (self.patch_size ** 2)
+ (self.patch_size**2)
* self.in_c
* self.embed_dim
/ self.pe_unroll_kernel_out
@@ -247,7 +247,10 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
+ w_tensor = w_tensor.reshape(
+ -1,
+ unroll_out_channels * unroll_kernel_out,
+ )
w_in = w_tensor.type(torch.int).flip(0).tolist()
# bias_pack
bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels)
diff --git a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
index 532e3fb07..06b921e81 100644
--- a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
+++ b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py
@@ -49,7 +49,7 @@ def __init__(self, samples=1):
self.in_x = 224
self.embed_dim = 384
self.patch_size = 16
- self.num_patch = self.in_y * self.in_x // (self.patch_size ** 2)
+ self.num_patch = self.in_y * self.in_x // (self.patch_size**2)
self.num_heads = 6
self.mlp_ratio = 2
@@ -64,7 +64,7 @@ def __init__(self, samples=1):
self.head_unroll_out_x = 1
self.pe_iter_weight = int(
- (self.patch_size ** 2)
+ (self.patch_size**2)
* self.in_c
* self.embed_dim
/ self.pe_unroll_kernel_out
@@ -1148,7 +1148,10 @@ def conv_pack(
unroll_out_channels,
).permute(0, 3, 1, 4, 2)
- w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,)
+ w_tensor = w_tensor.reshape(
+ -1,
+ unroll_out_channels * unroll_kernel_out,
+ )
w_in = w_tensor.type(torch.int).flip(0).tolist()
# bias_pack
bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels)
@@ -1511,7 +1514,8 @@ def runner():
build_args=extra_args,
)
runner.test(
- hdl_toplevel="fixed_pvt", test_module="fixed_pvt_tb",
+ hdl_toplevel="fixed_pvt",
+ test_module="fixed_pvt_tb",
)
diff --git a/src/mase_components/vision_models/vit/test/hash_exp_tb.py b/src/mase_components/vision_models/vit/test/hash_exp_tb.py
index 719ee89dd..e8b107503 100644
--- a/src/mase_components/vision_models/vit/test/hash_exp_tb.py
+++ b/src/mase_components/vision_models/vit/test/hash_exp_tb.py
@@ -147,7 +147,9 @@ def runner():
print(extra_args)
runner = get_runner(sim)
runner.build(
- verilog_sources=verilog_sources, hdl_toplevel="hash_exp", build_args=extra_args,
+ verilog_sources=verilog_sources,
+ hdl_toplevel="hash_exp",
+ build_args=extra_args,
)
runner.test(hdl_toplevel="hash_exp", test_module="hash_exp_tb")
diff --git a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
index 5c0522fb9..98ceea63f 100644
--- a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
+++ b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py
@@ -45,7 +45,11 @@ def __init__(self, samples=1):
},
}
self.d_config = {
- "softmax": {"in_size": 1, "out_size": 1, "in_depth": 4,},
+ "softmax": {
+ "in_size": 1,
+ "out_size": 1,
+ "in_depth": 4,
+ },
}
in_size = self.d_config["softmax"]["in_size"]
out_size = self.d_config["softmax"]["out_size"]
diff --git a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
index 8dac8afee..62e5b840d 100644
--- a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
+++ b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py
@@ -5,8 +5,8 @@
def quantize_to_int(x: Tensor, width: int, frac_width: int):
- x = _integer_quantize(x, width, frac_width) * (2 ** frac_width)
- x = x.int() & (2 ** width - 1)
+ x = _integer_quantize(x, width, frac_width) * (2**frac_width)
+ x = x.int() & (2**width - 1)
return x
@@ -25,7 +25,7 @@ def twos_complement_to_float(binary_string: str, width: int, frac_width: int):
integer_magnitude = -(2 ** (width - 1)) + integer_magnitude
# Calculate scaling factor
- scaling_factor = 2 ** frac_width
+ scaling_factor = 2**frac_width
# Calculate floating-point value
float_value = integer_magnitude / scaling_factor
@@ -79,7 +79,8 @@ def generate_table_div_software(width, out_width, out_frac_width):
class QHashSoftmax(torch.nn.Module):
def __init__(
- self, config,
+ self,
+ config,
):
super(QHashSoftmax, self).__init__()
self.in_width = config["data_in_width"]
@@ -108,7 +109,7 @@ def forward(self, x, scale):
# quantize to div_width
one_over_div = _integer_quantize(exp_sum // exp, self.div_width + 1, 0)
one_over_div = torch.where(
- exp == 0, torch.tensor(2 ** self.div_width - 1), one_over_div
+ exp == 0, torch.tensor(2**self.div_width - 1), one_over_div
)
one_over_div = torch.tensor(one_over_div, dtype=int)
diff --git a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
index 4863f2ca9..d7f4e8464 100644
--- a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
+++ b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py
@@ -50,7 +50,7 @@ def __init__(
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
+ self.scale = qk_scale or head_dim**-0.5
self.q = get_quantized_cls("linear", config["q_proj"])(
dim, dim, bias=qkv_bias, config=config["q_proj"]
diff --git a/test/nn/snn/test_ann2snn.py b/test/nn/snn/test_ann2snn.py
index 3a5d2b7e2..1b8bbbcfa 100644
--- a/test/nn/snn/test_ann2snn.py
+++ b/test/nn/snn/test_ann2snn.py
@@ -191,7 +191,13 @@ def val(net, device, data_loader, T=None):
"by": "type",
"default": {"config": {"name": None}},
"fuse": True,
- "relu": {"config": {"name": "IFNode", "mode": "99.9%", "momentum": 0.1,}},
+ "relu": {
+ "config": {
+ "name": "IFNode",
+ "mode": "99.9%",
+ "momentum": 0.1,
+ }
+ },
"train_data_loader": input_generator,
"device": "cpu", # "device": "cuda",
}
diff --git a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
index a7bce45cd..9740bdd38 100644
--- a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
+++ b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py
@@ -27,10 +27,22 @@ def add_common_metadata(model_cls_name: str) -> MaseGraph:
# mg.fx_graph.print_tabular()
input_ids = torch.randint(
- 0, config.vocab_size, (1, 128, config.hidden_size,), device="meta",
+ 0,
+ config.vocab_size,
+ (
+ 1,
+ 128,
+ config.hidden_size,
+ ),
+ device="meta",
)
mg, _ = passes.add_common_metadata_analysis_pass(
- mg, pass_args={"dummy_in": {"input_ids": input_ids,},},
+ mg,
+ pass_args={
+ "dummy_in": {
+ "input_ids": input_ids,
+ },
+ },
)
return mg
diff --git a/test/passes/graph/analysis/pruning/test_hook_inspect.py b/test/passes/graph/analysis/pruning/test_hook_inspect.py
index b34b579d3..dfa7fa0fd 100644
--- a/test/passes/graph/analysis/pruning/test_hook_inspect.py
+++ b/test/passes/graph/analysis/pruning/test_hook_inspect.py
@@ -111,9 +111,15 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": ["conv2d",],
- "target_activation_nodes": ["conv2d",],
- "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
+ "target_weight_nodes": [
+ "conv2d",
+ ],
+ "target_activation_nodes": [
+ "conv2d",
+ ],
+ "weight_statistics": {
+ "variance_precise": {"device": "cpu", "dims": "all"},
+ },
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
index 084170c70..2b9828ace 100644
--- a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
+++ b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py
@@ -52,7 +52,9 @@ def test_statistic_profiler():
dataset_info = get_dataset_info("cifar10")
model = get_model(
- checkpoint="resnet18", pretrained=False, dataset_info=dataset_info,
+ checkpoint="resnet18",
+ pretrained=False,
+ dataset_info=dataset_info,
)
dummy_in = {"x": next(iter(datamodule.train_dataloader()))[0]}
@@ -66,9 +68,15 @@ def test_statistic_profiler():
pass_arg = {
"by": "type",
- "target_weight_nodes": ["conv2d",],
- "target_activation_nodes": ["relu",],
- "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
+ "target_weight_nodes": [
+ "conv2d",
+ ],
+ "target_activation_nodes": [
+ "relu",
+ ],
+ "weight_statistics": {
+ "variance_precise": {"device": "cpu", "dims": "all"},
+ },
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/prune/test_prune.py b/test/passes/graph/transforms/prune/test_prune.py
index 8c6de1d2a..8bddad906 100644
--- a/test/passes/graph/transforms/prune/test_prune.py
+++ b/test/passes/graph/transforms/prune/test_prune.py
@@ -106,9 +106,15 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": ["conv2d",],
- "target_activation_nodes": ["conv2d",],
- "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
+ "target_weight_nodes": [
+ "conv2d",
+ ],
+ "target_activation_nodes": [
+ "conv2d",
+ ],
+ "weight_statistics": {
+ "variance_precise": {"device": "cpu", "dims": "all"},
+ },
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/prune/test_prune_detach_hook.py b/test/passes/graph/transforms/prune/test_prune_detach_hook.py
index f7afe9a8d..bf15caa86 100644
--- a/test/passes/graph/transforms/prune/test_prune_detach_hook.py
+++ b/test/passes/graph/transforms/prune/test_prune_detach_hook.py
@@ -105,9 +105,15 @@ def run_with_config(config_file):
profile_pass_arg = {
"by": "type",
- "target_weight_nodes": ["conv2d",],
- "target_activation_nodes": ["conv2d",],
- "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},},
+ "target_weight_nodes": [
+ "conv2d",
+ ],
+ "target_activation_nodes": [
+ "conv2d",
+ ],
+ "weight_statistics": {
+ "variance_precise": {"device": "cpu", "dims": "all"},
+ },
"activation_statistics": {
"variance_precise": {"device": "cpu", "dims": "all"},
},
diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
index 5ca6f6b27..89327f8f0 100644
--- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
+++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py
@@ -181,7 +181,8 @@ def test_quantize_lutnet_conv2d():
first_table = connection[0]
assert any(initialized_weight[0, :] == first_table) and any(
initialized_weight[
- input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), :,
+ input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1),
+ :,
]
== first_table
)
diff --git a/test/passes/graph/transforms/training/test_training_base_pass.py b/test/passes/graph/transforms/training/test_training_base_pass.py
index e46777de0..06d8371dd 100644
--- a/test/passes/graph/transforms/training/test_training_base_pass.py
+++ b/test/passes/graph/transforms/training/test_training_base_pass.py
@@ -170,7 +170,11 @@ def test_training_base_backward_only():
"default": {"config": {"name": None}},
"linear": {
"config": {
- "forward": {"bypass": True, "pass": "quantize", "name": "integer",},
+ "forward": {
+ "bypass": True,
+ "pass": "quantize",
+ "name": "integer",
+ },
"backward": {
"pass": "quantize",
"name": "integer",
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
index 68d405e5a..d78ac8525 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py
@@ -76,7 +76,11 @@ def test_emit_activation_gelu():
)
config_file = os.path.join(
- os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
+ os.path.abspath(""),
+ "configs",
+ "tests",
+ "quantize",
+ "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
index 8a59b2e44..ea1a11928 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py
@@ -77,7 +77,11 @@ def test_emit_activation_selu():
)
config_file = os.path.join(
- os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
+ os.path.abspath(""),
+ "configs",
+ "tests",
+ "quantize",
+ "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
index 9e60ba7a3..8c7a1e107 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py
@@ -76,7 +76,11 @@ def test_emit_activation_softplus():
)
config_file = os.path.join(
- os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
+ os.path.abspath(""),
+ "configs",
+ "tests",
+ "quantize",
+ "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
index 3de23e480..f50f5fce5 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py
@@ -76,7 +76,11 @@ def test_emit_activation_softsign():
)
config_file = os.path.join(
- os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
+ os.path.abspath(""),
+ "configs",
+ "tests",
+ "quantize",
+ "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
index 116d53b36..205211ea6 100644
--- a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
+++ b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py
@@ -75,7 +75,11 @@ def test_emit_activation_tanh():
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
config_file = os.path.join(
- os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml",
+ os.path.abspath(""),
+ "configs",
+ "tests",
+ "quantize",
+ "fixed.toml",
)
with open(config_file, "r") as f:
quan_args = toml.load(f)["passes"]["quantize"]
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
index 3e0f16599..b50fb3776 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py
@@ -174,7 +174,10 @@ def emit_verilog_bert(
mg, _ = bert_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
+ mg,
+ pass_args={
+ "max_parallelism": [max_parallelism] * 4,
+ },
)
# * Save the metadata to a file for debugging
@@ -190,7 +193,11 @@ def emit_verilog_bert(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
+ mg,
+ pass_args={
+ "wait_time": wait_count,
+ "wait_unit": wait_unit,
+ },
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
index 5447a3782..f268184ce 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py
@@ -86,7 +86,10 @@ def emit_verilog_llama(
mg, _ = llama_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
+ mg,
+ pass_args={
+ "max_parallelism": [max_parallelism] * 4,
+ },
)
# * Save the metadata to a file for debugging
@@ -102,7 +105,11 @@ def emit_verilog_llama(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
+ mg,
+ pass_args={
+ "wait_time": wait_count,
+ "wait_unit": wait_unit,
+ },
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
index 51ae9dc7d..64625495e 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py
@@ -87,7 +87,10 @@ def emit_verilog_mistral(
mg, _ = mistral_update_metadata(mg, q_config)
mg, _ = passes.add_hardware_metadata_analysis_pass(
- mg, pass_args={"max_parallelism": [max_parallelism] * 4,},
+ mg,
+ pass_args={
+ "max_parallelism": [max_parallelism] * 4,
+ },
)
# * Save the metadata to a file for debugging
@@ -103,7 +106,11 @@ def emit_verilog_mistral(
mg, _ = passes.emit_bram_transform_pass(mg)
mg, _ = passes.emit_internal_rtl_transform_pass(mg)
mg, _ = passes.emit_cocotb_transform_pass(
- mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,},
+ mg,
+ pass_args={
+ "wait_time": wait_count,
+ "wait_unit": wait_unit,
+ },
)
mg, _ = passes.emit_vivado_project_transform_pass(mg)
diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
index acc659331..324a2e29d 100644
--- a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
+++ b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py
@@ -116,7 +116,7 @@ def add_norm_metadata_gen_lut_analysis_pass(mg, config={}):
mem_dir = Path(__file__).parent / "build" / "norm" / "mem"
makedirs(mem_dir, exist_ok=True)
- lut = make_lut(2 ** LUT_POW, ISQRT_WIDTH)
+ lut = make_lut(2**LUT_POW, ISQRT_WIDTH)
mem_path = mem_dir / f"norm_isqrt_lut.mem"
write_memb(mem_path, lut, ISQRT_WIDTH)
mem_id = 0
@@ -224,10 +224,23 @@ def test_emit_verilog_norm():
shape = [10, 4, 8, 8]
normalizations = [
- nn.BatchNorm2d(num_features=shape[1], affine=False,),
- nn.LayerNorm(normalized_shape=shape[1:], elementwise_affine=False,),
- nn.GroupNorm(num_groups=2, num_channels=shape[1], affine=False,),
- nn.InstanceNorm2d(num_features=shape[1], affine=False,),
+ nn.BatchNorm2d(
+ num_features=shape[1],
+ affine=False,
+ ),
+ nn.LayerNorm(
+ normalized_shape=shape[1:],
+ elementwise_affine=False,
+ ),
+ nn.GroupNorm(
+ num_groups=2,
+ num_channels=shape[1],
+ affine=False,
+ ),
+ nn.InstanceNorm2d(
+ num_features=shape[1],
+ affine=False,
+ ),
# rms.RMSNorm(
# normalized_shape=shape[1:],
# ),
diff --git a/test/passes/onnx/analysis/test_export_fx_graph.py b/test/passes/onnx/analysis/test_export_fx_graph.py
index fc081f1ee..395991cda 100644
--- a/test/passes/onnx/analysis/test_export_fx_graph.py
+++ b/test/passes/onnx/analysis/test_export_fx_graph.py
@@ -78,13 +78,16 @@ def test_export_fx_graph_bert():
@pytest.mark.skip
def test_export_fx_graph_mistral():
- export_fx_graph_model("mistral-community/Mistral-7B-v0.2",)
+ export_fx_graph_model(
+ "mistral-community/Mistral-7B-v0.2",
+ )
@pytest.mark.skip
def test_export_fx_graph_whisper():
export_fx_graph_model(
- "openai/whisper-tiny", skip_export=True,
+ "openai/whisper-tiny",
+ skip_export=True,
)
diff --git a/test/tools/test_onnx_operators.py b/test/tools/test_onnx_operators.py
index 7393cb41f..c15a44d89 100644
--- a/test/tools/test_onnx_operators.py
+++ b/test/tools/test_onnx_operators.py
@@ -14,20 +14,58 @@ def excepthook(exc_type, exc_value, exc_traceback):
def test_gather():
- data1 = torch.Tensor([[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9],])
+ data1 = torch.Tensor(
+ [
+ [1.0, 1.2, 1.9],
+ [2.3, 3.4, 3.9],
+ [4.5, 5.7, 5.9],
+ ]
+ )
- data2 = torch.Tensor([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7],])
+ data2 = torch.Tensor(
+ [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+ )
- indices1 = torch.Tensor([[0, 2],]).to(torch.int64)
+ indices1 = torch.Tensor(
+ [
+ [0, 2],
+ ]
+ ).to(torch.int64)
- indices2 = torch.Tensor([[0, 1], [1, 2],]).to(torch.int64)
+ indices2 = torch.Tensor(
+ [
+ [0, 1],
+ [1, 2],
+ ]
+ ).to(torch.int64)
obs_out1 = onnx_gather(data1, 1, indices1)
obs_out2 = onnx_gather(data2, 0, indices2)
- exp_out1 = torch.Tensor([[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]],])
+ exp_out1 = torch.Tensor(
+ [
+ [[1.0, 1.9]],
+ [[2.3, 3.9]],
+ [[4.5, 5.9]],
+ ]
+ )
- exp_out2 = torch.Tensor([[[1.0, 1.2], [2.3, 3.4],], [[2.3, 3.4], [4.5, 5.7],],])
+ exp_out2 = torch.Tensor(
+ [
+ [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ ],
+ [
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ],
+ ]
+ )
print(obs_out2)
print(exp_out2)
@@ -37,7 +75,12 @@ def test_gather():
def test_slice():
- data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8],])
+ data = torch.Tensor(
+ [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ ]
+ )
test1 = onnx_slice(
data,
From d9e15b3aeb70b1ca1e2ffcd799fb09789b50fcf2 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Thu, 20 Feb 2025 23:39:48 +0000
Subject: [PATCH 14/38] remove optical test to prevent doc generation error
---
.../transforms/optical/test_optical_module.py | 70 -------------------
1 file changed, 70 deletions(-)
delete mode 100644 test/passes/module/transforms/optical/test_optical_module.py
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
deleted file mode 100644
index e65546822..000000000
--- a/test/passes/module/transforms/optical/test_optical_module.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-# This example converts a simple MLP model to an ONN model
-import sys
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from pathlib import Path
-
-sys.path.append(Path(__file__).resolve().parents[5].as_posix())
-
-
-from chop.passes.module.transforms.optical import optical_module_transform_pass
-
-
-class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 32, 3, 1)
- self.conv2 = nn.Conv2d(32, 64, 3, 1)
- self.dropout1 = nn.Dropout(0.25)
- self.dropout2 = nn.Dropout(0.5)
- self.fc1 = nn.Linear(9216, 128)
- self.fc2 = nn.Linear(128, 10)
-
- def forward(self, x):
- x = self.conv1(x)
- x = F.relu(x)
- x = self.conv2(x)
- x = F.relu(x)
- x = F.max_pool2d(x, 2)
- x = self.dropout1(x)
- x = torch.flatten(x, 1)
- x = self.fc1(x)
- x = F.relu(x)
- x = self.dropout2(x)
- x = self.fc2(x)
- output = F.log_softmax(x, dim=1)
- return output
-
-
-def test_optical_module_transform_pass():
- model = Net()
- # Sanity check and report
- pass_args = {
- "by": "name",
- "fc1": {
- "config": {
- "name": "morr",
- "miniblock": 4,
- "morr_init": True,
- "trainable_morr_bias": False,
- "trainable_morr_scale": False,
- }
- },
- "conv1": {
- "config": {
- "name": "morr",
- "miniblock": 4,
- "morr_init": True,
- "trainable_morr_bias": False,
- "trainable_morr_scale": False,
- }
- },
- }
- optical_module_transform_pass(model, pass_args)
-
-
-test_optical_module_transform_pass()
From ea77e742714472ca427677adc6298c4c3f640cd6 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Thu, 20 Feb 2025 23:54:01 +0000
Subject: [PATCH 15/38] reformat
---
src/chop/passes/module/transforms/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py
index 05df7efe2..54f545bfd 100644
--- a/src/chop/passes/module/transforms/__init__.py
+++ b/src/chop/passes/module/transforms/__init__.py
@@ -1,4 +1,4 @@
from .autosharding import resharding_transform_pass
from .quantize import quantize_module_transform_pass
from .autosharding import resharding_transform_pass
-from .snn import ann2snn_module_transform_pass
\ No newline at end of file
+from .snn import ann2snn_module_transform_pass
From 6be34c906ebb61b8e0af6fe299c1be9de71b07a9 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Fri, 21 Feb 2025 00:25:38 +0000
Subject: [PATCH 16/38] remove unorganized test
---
test/self/test_optical_module.py | 252 -------------------------------
test/self/train_mnist_cnn.py | 247 ------------------------------
2 files changed, 499 deletions(-)
delete mode 100644 test/self/test_optical_module.py
delete mode 100644 test/self/train_mnist_cnn.py
diff --git a/test/self/test_optical_module.py b/test/self/test_optical_module.py
deleted file mode 100644
index e56921881..000000000
--- a/test/self/test_optical_module.py
+++ /dev/null
@@ -1,252 +0,0 @@
-#!/usr/bin/env python3
-# This example converts a simple MLP model to Verilog
-import logging
-import os
-import sys
-
-import torch
-import torch.nn as nn
-
-from torch.profiler import profile, record_function, ProfilerActivity
-import torchvision.models as models
-import torchvision.transforms as transforms
-import torch.utils.data as data
-from pathlib import Path
-
-sys.path.append(Path(__file__).resolve().parents[5].as_posix())
-
-
-# from chop.passes.module.transforms import quantize_module_transform_pass
-from chop.passes.module.transforms import optical_module_transform_pass
-from chop.passes.module import report_trainable_parameters_analysis_pass
-
-import argparse
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torchvision import datasets, transforms
-from torch.optim.lr_scheduler import StepLR
-
-from train_mnist_cnn import test, train, Net, test_memory_detailed
-
-# --------------------------------------------------
-# Model specifications
-# --------------------------------------------------
-# class MLP(torch.nn.Module):
-# """
-# Toy quantized FC model for digit recognition on MNIST
-# """
-
-# def __init__(self) -> None:
-# super().__init__()
-
-# self.fc1 = nn.Linear(28 * 28, 28 * 28)
-# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
-# self.fc3 = nn.Linear(28 * 28 * 4, 10)
-
-# def forward(self, x):
-# x = torch.flatten(x, start_dim=1, end_dim=-1)
-# x = torch.nn.functional.relu(self.fc1(x))
-# # w = torch.randn((4, 28 * 28))
-# # x = torch.nn.functional.relu(nn.functional.linear(x, w))
-# x = torch.nn.functional.relu(self.fc2(x))
-# x = self.fc3(x)
-# return x
-
-
-def load_my_model(model_path, device):
- # Load the model from the .pt file
- loaded_model = torch.load(model_path, map_location=device)
- # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout)
- loaded_model.eval()
- return loaded_model
-
-
-def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"):
- pass_args = {
- "by": "type",
- "linear": {
- "config": {
- "name": "morr",
- "miniblock": 4,
- "morr_init": True,
- "trainable_morr_bias": False,
- "trainable_morr_scale": False,
- "device": device,
- }
- },
- "conv2d": {
- "config": {
- "name": "morr",
- "miniblock": 4,
- "morr_init": True,
- "trainable_morr_bias": False,
- "trainable_morr_scale": False,
- "device": device,
- }
- },
- }
- onn_model, _ = optical_module_transform_pass(model, pass_args)
- torch.save(onn_model.state_dict(), save_path)
- return onn_model
-
-
-def test_optical_module_transform_pass():
- model_path = "mase_output/sample_mnist_cnn.pt"
- mnist_cnn = load_my_model(model_path)
- # Sanity check and report
- pass_args = {
- "by": "name",
- "fc1": {
- "config": {
- "name": "morr",
- "miniblock": 4,
- "morr_init": True,
- "trainable_morr_bias": False,
- "trainable_morr_scale": False,
- "device": device,
- }
- },
- }
- onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args)
- torch.save(onn_cnn, "mase_output/onn_cnn.pt")
-
-
-if __name__ == "__main__":
- finetune = True
- if True:
- parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
- parser.add_argument(
- "--batch-size",
- type=int,
- default=64,
- metavar="N",
- help="input batch size for training (default: 64)",
- )
- parser.add_argument(
- "--test-batch-size",
- type=int,
- default=100,
- metavar="N",
- help="input batch size for testing (default: 1000)",
- )
- parser.add_argument(
- "--epochs",
- type=int,
- default=1,
- metavar="N",
- help="number of epochs to train (default: 14)",
- )
- parser.add_argument(
- "--lr",
- type=float,
- default=1.0,
- metavar="LR",
- help="learning rate (default: 1.0)",
- )
- parser.add_argument(
- "--gamma",
- type=float,
- default=0.7,
- metavar="M",
- help="Learning rate step gamma (default: 0.7)",
- )
- parser.add_argument(
- "--no-cuda",
- action="store_true",
- default=False,
- help="disables CUDA training",
- )
- parser.add_argument(
- "--no-mps",
- action="store_true",
- default=False,
- help="disables macOS GPU training",
- )
- parser.add_argument(
- "--dry-run",
- action="store_true",
- default=False,
- help="quickly check a single pass",
- )
- parser.add_argument(
- "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
- )
- parser.add_argument(
- "--log-interval",
- type=int,
- default=10,
- metavar="N",
- help="how many batches to wait before logging training status",
- )
- parser.add_argument(
- "--save-model",
- action="store_true",
- default=True,
- help="For Saving the current Model",
- )
- parser.add_argument(
- "--gpu-id", type=int, default=1, help="Which GPU device to use [default: 0]"
- )
-
- args = parser.parse_args()
- use_cuda = not args.no_cuda and torch.cuda.is_available()
- use_mps = not args.no_mps and torch.backends.mps.is_available()
-
- torch.manual_seed(args.seed)
-
- if not args.no_cuda and torch.cuda.is_available():
- device = torch.device(f"cuda:{args.gpu_id}")
- elif use_mps:
- device = torch.device("mps")
- else:
- device = torch.device("cpu")
-
- train_kwargs = {"batch_size": args.batch_size}
- test_kwargs = {"batch_size": args.test_batch_size}
- if use_cuda:
- cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
- train_kwargs.update(cuda_kwargs)
- test_kwargs.update(cuda_kwargs)
-
- transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
- )
- dataset1 = datasets.MNIST(
- "../data", train=True, download=True, transform=transform
- )
- dataset2 = datasets.MNIST("../data", train=False, transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
- test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
-
- cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device)
- print("-------------- Testing the original cnn model -------------------")
- test(cnn, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(cnn)
-
- # onn = load_my_model("mase_output/onn_cnn.pt", device)
- onn_model = perform_optical_module_transform_pass(cnn)
- onn_model.to(device)
-
- print("-------------- Testing the transformed onn model -------------------")
- test(onn_model, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
- ######### Training the onn model
- if finetune:
- optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr)
- scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
-
- for epoch in range(1, args.epochs + 1):
- train(args, onn_model, device, train_loader, optimizer, epoch)
- test(onn_model, device, test_loader)
- scheduler.step()
-
- torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt")
-
- print("-------------- Testing the trained onn model -------------------")
- test(onn_model, device, test_loader)
- _, _ = report_trainable_parameters_analysis_pass(onn_model)
-
- # test_optical_module_transform_pass()
diff --git a/test/self/train_mnist_cnn.py b/test/self/train_mnist_cnn.py
deleted file mode 100644
index a955087da..000000000
--- a/test/self/train_mnist_cnn.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import argparse
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torchvision import datasets, transforms
-from torch.optim.lr_scheduler import StepLR
-
-
-class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 32, 3, 1)
- self.conv2 = nn.Conv2d(32, 64, 3, 1)
- self.dropout1 = nn.Dropout(0.25)
- self.dropout2 = nn.Dropout(0.5)
- self.fc1 = nn.Linear(9216, 128)
- self.fc2 = nn.Linear(128, 10)
-
- def forward(self, x):
- x = self.conv1(x)
- x = F.relu(x)
- x = self.conv2(x)
- x = F.relu(x)
- x = F.max_pool2d(x, 2)
- x = self.dropout1(x)
- x = torch.flatten(x, 1)
- x = self.fc1(x)
- x = F.relu(x)
- x = self.dropout2(x)
- x = self.fc2(x)
- output = F.log_softmax(x, dim=1)
- return output
-
-
-def train(args, model, device, train_loader, optimizer, epoch):
- model.train()
- for batch_idx, (data, target) in enumerate(train_loader):
- data, target = data.to(device), target.to(device)
- optimizer.zero_grad()
- output = model(data)
- loss = F.nll_loss(output, target)
- loss.backward()
- optimizer.step()
- if batch_idx % args.log_interval == 0:
- print(
- "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
- epoch,
- batch_idx * len(data),
- len(train_loader.dataset),
- 100.0 * batch_idx / len(train_loader),
- loss.item(),
- )
- )
- if args.dry_run:
- break
-
-
-def test(model, device, test_loader):
- model.eval()
- test_loss = 0
- correct = 0
- with torch.no_grad():
- for data, target in test_loader:
- data, target = data.to(device), target.to(device)
- output = model(data)
- test_loss += F.nll_loss(
- output, target, reduction="sum"
- ).item() # sum up batch loss
- pred = output.argmax(
- dim=1, keepdim=True
- ) # get the index of the max log-probability
- correct += pred.eq(target.view_as(pred)).sum().item()
-
- test_loss /= len(test_loader.dataset)
-
- print(
- "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
- test_loss,
- correct,
- len(test_loader.dataset),
- 100.0 * correct / len(test_loader.dataset),
- )
- )
-
-
-# Custom Function
-from torch.profiler import profile, record_function, ProfilerActivity
-
-
-def test_memory_detailed(model, device, test_loader):
- """
- Use PyTorch Profiler to record detailed memory usage on the first batch,
- so you can see exactly which ops consume how much memory.
- """
-
- # Put the model in eval mode
- model.eval()
-
- # Get just 1 batch (or a few) for profiling
- data_iter = iter(test_loader)
- try:
- data, target = next(data_iter)
- except StopIteration:
- print("test_loader is empty. Cannot profile.")
- return
-
- # Move to device
- data, target = data.to(device), target.to(device)
-
- # Now start the profiler
- with profile(
- activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
- record_shapes=True, # to see tensor shapes
- profile_memory=True, # track memory usage
- ) as prof:
- with record_function("test_first_batch"):
- # Forward pass on just this batch
- output = model(data)
- # Suppose you also compute a loss for some reason
- loss = F.nll_loss(output, target, reduction="sum")
- # If purely inference, you might not do backward
- # but let's illustrate:
- loss.backward() # If you want to see backward pass memory usage
-
- # Print the summarized table
- # Sort by self_cuda_memory_usage to see which ops used the most GPU memory
- print(
- prof.key_averages(group_by_input_shape=True).table(
- sort_by="self_cuda_memory_usage",
- row_limit=200, # show as many rows as you need
- )
- )
-
-
-def main():
- # Training settings
- parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
- parser.add_argument(
- "--batch-size",
- type=int,
- default=64,
- metavar="N",
- help="input batch size for training (default: 64)",
- )
- parser.add_argument(
- "--test-batch-size",
- type=int,
- default=1000,
- metavar="N",
- help="input batch size for testing (default: 1000)",
- )
- parser.add_argument(
- "--epochs",
- type=int,
- default=29,
- metavar="N",
- help="number of epochs to train (default: 14)",
- )
- parser.add_argument(
- "--lr",
- type=float,
- default=1.0,
- metavar="LR",
- help="learning rate (default: 1.0)",
- )
- parser.add_argument(
- "--gamma",
- type=float,
- default=0.7,
- metavar="M",
- help="Learning rate step gamma (default: 0.7)",
- )
- parser.add_argument(
- "--no-cuda", action="store_true", default=False, help="disables CUDA training"
- )
- parser.add_argument(
- "--no-mps",
- action="store_true",
- default=False,
- help="disables macOS GPU training",
- )
- parser.add_argument(
- "--dry-run",
- action="store_true",
- default=False,
- help="quickly check a single pass",
- )
- parser.add_argument(
- "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
- )
- parser.add_argument(
- "--log-interval",
- type=int,
- default=10,
- metavar="N",
- help="how many batches to wait before logging training status",
- )
- parser.add_argument(
- "--save-model",
- action="store_true",
- default=True,
- help="For Saving the current Model",
- )
- args = parser.parse_args()
- use_cuda = not args.no_cuda and torch.cuda.is_available()
- use_mps = not args.no_mps and torch.backends.mps.is_available()
-
- torch.manual_seed(args.seed)
-
- if use_cuda:
- device = torch.device("cuda")
- elif use_mps:
- device = torch.device("mps")
- else:
- device = torch.device("cpu")
-
- train_kwargs = {"batch_size": args.batch_size}
- test_kwargs = {"batch_size": args.test_batch_size}
- if use_cuda:
- cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
- train_kwargs.update(cuda_kwargs)
- test_kwargs.update(cuda_kwargs)
-
- transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
- )
- dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
- dataset2 = datasets.MNIST("../data", train=False, transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
- test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
-
- model = Net().to(device)
- optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
-
- scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
- for epoch in range(1, args.epochs + 1):
- train(args, model, device, train_loader, optimizer, epoch)
- test(model, device, test_loader)
- scheduler.step()
-
- if args.save_model:
- torch.save(model, "mase_output/sample_mnist_cnn.pt")
-
-
-if __name__ == "__main__":
- main()
From 87a222b7d0d90fb082dd96548dd4bbd75644a7e3 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Fri, 21 Feb 2025 00:29:46 +0000
Subject: [PATCH 17/38] remove self test file
---
.gitignore | 4 +-
src/chop/nn/optical/utils/quantize.py | 54 +++++++++++++--------------
2 files changed, 30 insertions(+), 28 deletions(-)
diff --git a/.gitignore b/.gitignore
index acd171321..e7118d94c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,4 +166,6 @@ prof/
# HuggingFace trainer cache
mase-trainer/
-test-trainer/
\ No newline at end of file
+test-trainer/
+
+test/self
\ No newline at end of file
diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py
index 84372c8c7..0bb53f90c 100644
--- a/src/chop/nn/optical/utils/quantize.py
+++ b/src/chop/nn/optical/utils/quantize.py
@@ -1,10 +1,10 @@
-"""
-Description:
-Author: Jiaqi Gu (jqgu@utexas.edu)
-Date: 2021-06-06 03:15:00
-LastEditors: Jiaqi Gu (jqgu@utexas.edu)
-LastEditTime: 2021-06-06 03:15:00
-"""
+# """
+# Description:
+# Author: Jiaqi Gu (jqgu@utexas.edu)
+# Date: 2021-06-06 03:15:00
+# LastEditors: Jiaqi Gu (jqgu@utexas.edu)
+# LastEditTime: 2021-06-06 03:15:00
+# """
import numpy as np
import torch
@@ -48,12 +48,12 @@ def backward(ctx, grad_output):
############ add observer and new quant based on range and zeropoint for activation
def uniform_quantize_new(k, gradient_clip=False):
- """
- Support uniform quantization with auto-adjusted input data range
- args:
- k: bitwidth
- scale, zeropoint: obtained from observer
- """
+ # """
+ # Support uniform quantization with auto-adjusted input data range
+ # args:
+ # k: bitwidth
+ # scale, zeropoint: obtained from observer
+ # """
class qfn(torch.autograd.Function):
@staticmethod
@@ -90,12 +90,12 @@ class input_quantize_fn(torch.nn.Module):
def __init__(
self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0
):
- """Input quantizer with Quant_Noise supported
- Args:
- in_bit (int): Input quantization bitwidth.
- device (Device, optional): torch Device. Defaults to torch.device("cuda:0").
- quant_ratio (float, optional): Quantization ratio. Defaults to 1.0.
- """
+ # """Input quantizer with Quant_Noise supported
+ # Args:
+ # in_bit (int): Input quantization bitwidth.
+ # device (Device, optional): torch Device. Defaults to torch.device("cuda:0").
+ # quant_ratio (float, optional): Quantization ratio. Defaults to 1.0.
+ # """
super(input_quantize_fn, self).__init__()
assert 1 <= in_bit <= 32
self.in_bit = in_bit
@@ -235,14 +235,14 @@ def forward(self, x):
class weight_quantize_fn(torch.nn.Module):
def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0):
- """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization.
-
- Args:
- w_bit (int): quantization bitwidth
- mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv".
- alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa".
- quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0.
- """
+ # """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization.
+
+ # Args:
+ # w_bit (int): quantization bitwidth
+ # mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv".
+ # alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa".
+ # quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0.
+ # """
super(weight_quantize_fn, self).__init__()
assert 1 <= w_bit <= 32, logging.error(
f"Only support 1 - 32 bit quantization, but got {w_bit}"
From 67b91b505a75ee70a9eced9f7eda95df7c5fbff4 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Mon, 10 Mar 2025 19:35:28 +0000
Subject: [PATCH 18/38] add back test_optical_module
---
.../passes/module/module_modify_helper.py | 2 +-
.../transforms/optical/test_optical_module.py | 70 +++++++++++++++++++
2 files changed, 71 insertions(+), 1 deletion(-)
create mode 100644 test/passes/module/transforms/optical/test_optical_module.py
diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py
index 2e6a577e0..73806b2ac 100644
--- a/src/chop/passes/module/module_modify_helper.py
+++ b/src/chop/passes/module/module_modify_helper.py
@@ -131,7 +131,7 @@ def instantiate_conv2d(module, postfix, module_map, additional_module_args):
has_bias = not (module.bias is None)
# TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args.
# Need to handle this better
- if "config" in inspect.signature(conv2d.__init__).parameters:
+ if "config" in inspect.signature(conv2d_cls.__init__).parameters:
conv2d = conv2d_cls(
in_channels=module.in_channels,
out_channels=module.out_channels,
diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py
new file mode 100644
index 000000000..e65546822
--- /dev/null
+++ b/test/passes/module/transforms/optical/test_optical_module.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# This example converts a simple MLP model to an ONN model
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pathlib import Path
+
+sys.path.append(Path(__file__).resolve().parents[5].as_posix())
+
+
+from chop.passes.module.transforms.optical import optical_module_transform_pass
+
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
+ self.dropout1 = nn.Dropout(0.25)
+ self.dropout2 = nn.Dropout(0.5)
+ self.fc1 = nn.Linear(9216, 128)
+ self.fc2 = nn.Linear(128, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = self.dropout1(x)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ output = F.log_softmax(x, dim=1)
+ return output
+
+
+def test_optical_module_transform_pass():
+ model = Net()
+ # Sanity check and report
+ pass_args = {
+ "by": "name",
+ "fc1": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ }
+ },
+ "conv1": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ }
+ },
+ }
+ optical_module_transform_pass(model, pass_args)
+
+
+test_optical_module_transform_pass()
From ce0f533953e61f60f1caafa58d270925b087a368 Mon Sep 17 00:00:00 2001
From: Johnny1882
Date: Tue, 11 Mar 2025 19:06:57 +0000
Subject: [PATCH 19/38] add bert transform file
---
...tfevents.1741637561.ee-tarrasque.3914313.0 | Bin 0 -> 5028 bytes
...tfevents.1741637649.ee-tarrasque.3918878.0 | Bin 0 -> 6507 bytes
...tfevents.1741641748.ee-tarrasque.4023347.0 | Bin 0 -> 5038 bytes
.../transforms/optical/bert-finetune.py | 71 ++
.../module/transforms/optical/run_glue.py | 637 ++++++++++++++++++
wandb/latest-run | 1 +
.../files/config.yaml | 497 ++++++++++++++
.../files/requirements.txt | 295 ++++++++
.../files/wandb-metadata.json | 60 ++
.../files/wandb-summary.json | 1 +
.../run-60c8phhh.wandb | Bin 0 -> 164998 bytes
.../files/requirements.txt | 295 ++++++++
.../files/wandb-metadata.json | 60 ++
.../run-x0wxhan7.wandb | Bin 0 -> 98304 bytes
14 files changed, 1917 insertions(+)
create mode 100644 model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0
create mode 100644 model_sst2/runs/Mar10_20-14-07_ee-tarrasque/events.out.tfevents.1741637649.ee-tarrasque.3918878.0
create mode 100644 model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0
create mode 100644 test/passes/module/transforms/optical/bert-finetune.py
create mode 100644 test/passes/module/transforms/optical/run_glue.py
create mode 120000 wandb/latest-run
create mode 100644 wandb/run-20250310_202153-60c8phhh/files/config.yaml
create mode 100644 wandb/run-20250310_202153-60c8phhh/files/requirements.txt
create mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json
create mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json
create mode 100644 wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb
create mode 100644 wandb/run-20250310_212229-x0wxhan7/files/requirements.txt
create mode 100644 wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json
create mode 100644 wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb
diff --git a/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 b/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0
new file mode 100644
index 0000000000000000000000000000000000000000..21fed95b763c01ecfbd0a4f987eaf9cc96de2140
GIT binary patch
literal 5028
zcmaJ_zmFwH5eDlP?w*7M&LCR7(DJ@}yL;<9TNV;0**X~?M#hd#R;y>GXI^_}rpMhq
zyZ3CJ2%I`1frN~Rgb<0pfrJ1dNRbFhkPsrU_^P_6XL|P?8|-SktE=nBS6@}nlW&Bd
zuYd96*Y|$=$MTpdR_cE9&Xwx@TW23`R2J7WtYZMAV@*GaqITwCpx
z-6)b0wQiLND_=eM=-=;tH|Rb3qAAfcH#p?KYZAL$gUR2ty-#T7*rc$kN&Y!Jc=iMu1`7)8o)v|Sp1u7#v
z`GCXOqPmhT*1R2)++bCuOP@)Sm~hbGRkJ4p)L_b8xI4;g(~H(PCrsax$b8M9tf*BW6w;SF;q>PUa*x=)g$&>I5|L&Y59#Q~woPhC=C+O`(pXt(%8)F}u^U_I$#Un_
z;9$)utKs7Z_mdhHK@ueSy5&>eSh%IS!S6*W%biME)2UWT+v~m&6iFPsP$%09K76fo
zvqOGW88mK!PFcxL6w1p}x@Pkx%cjM%u=wU;8E)v-P@YerI|R3r&jq9tAfj|X+(=6B3!Nq>72e&@rr3jST(e|lq3Ni6wz4%*;I%=oUO>n1a0VT
zDD<~V=)N6$cmJuhr(1EhJvlu)Ie#RSI`Pt4={_6h+xh&m41(qC;-tRT)}fGrCtN-k
z!L&X3dXb!ErM>Kh?rWkAaLCeZC$9)dbLlPajC%pP7u%qc7rpVQ*j9I%3OV5~9FVmN
z{6!f`@!`E2!JF-myX{^u_hi0V<$!$!(6?ZvyhL^=*V#s=ywxQ|wip|1&Uu(to_&%c+ZtIBjaxH_R=PZw
z%!L!M(;bwMc#LegHuh3kvLc#1ZzF0zBuciHC4jUv6!q^>7!@ZfQW45k^s-|Hy4eUT
z?PP(@>p!cA>1Lm->I@2L?2-5i0cpps&*0z60xgyFW|1NfK)|pfTqV5-X`NMvK#JPf
zh2YYO#wyf<71|M%hx$%o(hXQyi2&bAowG*mDq4t%p#Mat;XUFxNGrQJoV-_>zSfjW
z8W}J&9G1n#^ERuzup(H;6=5wox7^_tf?7_=$0dQ@`5-8&cL+iUPtg|fyKJ5C22+qE
zCz8i#tO{L3BGpJlZm|Qe`!FDI9+rk6w*yZOJRD?cf|)dSEe=k!h8-$(A_fMAlih&J
zVb4OMr-DUUT~Z^-y+1rIr2a-ur7%Goir(gSkjTgvu>?;mRcvYEnv%C(uncBMr$IxQ?Pju>BjiwAQ33ui_bn9>N?O&(a!)Ky6L6SlOBm2Nf`o{hdFh70&{I9EL%L`d2t|{dv8Wqdq|`3Qt1#ar!9EhpYr~?gl(jGfZSV2j>ZX;l!gH$Yp^3
z4?nMO@_NGP1PYEQaEi9!6hJ^FGDl^SqQ=K43l-sr*Yh`RuWCQ`3e&$vLCbd7QKw7S
z1xm$a-++(M(ea$g+I`3hgYgiAc>~P?1FlVNkb}+ip&OWwDW~Vh%WOU$BZ)@(9Yjp=
zsk-sPM!8$rU~r*IcX*sydP|3AGg-*WRxjSWyZYAZ4t*
z?^OJ?;s?5r16}(tA8Ws5N%r&rnN*Rj0_!RD;|+
zP$9rJ_n38|tUoEFQzs+p7Nb>!IeHjAhLnB`gJO*3
zLOAVcaQM(W!WdHlr7O`C@Cpe9NE{
z+?K*{e%plf=9h)T@Qkk)Z0};|a{>^iw-1H(!REokO?(-Vt>OxgmST!T@RbUJzmwoV
XLRcnuzP&g*TfJD#sK5B|=gp(M_Mw$6779hPJ}rcX-M-y+yGw1?Wrb2LrRAgifsx6)d*{7(+55%J
zz5CHpgNg#7wXSF}L@J1ym>?=9g28|we;^tqM(`t|#)guZ7S>=fA!7ZVnfva&v+u3@
z$8L7!oSE}G=XXBlZeL=4{`=+a5A>|rw}07h-~8j{7rdt~xfc#xIVHkm^U!qzl}i&5
zsz}x&F5Ll@1_76|1Jlw=gd7kg#TuS3m|3~$wc35V&o<{CJ>w6L_gwoObMCt5`YO-#
zReHH>s`ZJONmzEx{&zH8;Om8YG#bmDN|zDi$THA<5>O_=M+>gK*`5V^u|f`&dh4n#;vfxSo?(k|f3yPDn&^mJASa
z(|JZN(v;lwC_>zE(jfJDLYAvU#0u67>SMHe*Zuq$ErK9$!dlZ>NhASF1aUIU9PYG4
zerm4e3ob)1Y%=mB6o*BBZCEQ$3$NKie1&699582cmj|pa5vcv)!~
zFg!n^I1l1%k|{uPQyU$vo85&pulhed$i0s>M&~X{0F`@P%tkyCEAVtgVs|;j6YBC>5M7*rVTq$o#WLLCz
zm=v+^Ih-VL26?oQKr{us53nv7nNu6*Hp%rJDYRaup<2I_%fT8Os`U?U?5~Y5A^H<8
zC0Dnm#qC~tZLGsNSbb>7L*|7IDjTRFWwRBZA|e&3
zSGqIXGt?cfnIrXL5+%r3-wQlS0{Xc>%>+(C3G
z*49QK66rbQ*?QW$|paSd)AM*OJ!U}G4r{yW2tLlo7njs6p&1fY?zMZq>yAq
zW^%C&uYr+B*#>t&q_UyN{{)HQ>Of_s2d(iXcKtr6DNb!Zm&xgjOQ>ypHBdH&{L+EGWgW$8=wwFcaY7Xya1Wh-68y
z45ujtGD9*?QBgOUo;%EIqCN{|h&<`p^`oI&TF62||FOV>AK=-TOCA)C)q749Hav8V6$VkjQO
zxE%JBBtlAEB-kXi6dQf_!6EN-1S)rfD^QbTQ=f9QhF2uIC0P`gwwK>Pq%ENS*?2d`SHH>d>4LqWonSCpN?Lkh{29e_W&3YHHdmiVH{omr~q
zGz**AHnKA^PF0IY-AV~aaXLzE8i+2Xz+$CfE(Tc3gSQSg{*RZwdv5<|7KAbu#y-`K)4Owq!MRHAU`OT&CHevv$(I>
zWuu|$f8yY++o2*i`6tK2%M(XeuyCW4(0>C+RuC
zB`p+hPuWdO*9LbMP0EV1T`8rUC%k%CWZDWc#5{61V}(BQh22;MU79h%w&Y3x^BO2L(xdm
zP~aWlA)&b+x2V(Q*O`)v8NZ2ogpQ7fjn(c=R5Tg8QCM%FSzy$ac^lNu*5$)%!hFm-
zp?<~`=4&;QX!4{5k*I!)Sn$?Hx=Y@~;6j5>s*Qc6nTr$Dx*1i7(RcEYO6SUR;;r`lyd5$_5LY3F>;KEFz9NX^os_i=+rIzQq
z7!xpXV!7HBewtNLv%bi+nO64cVq}wfUn4yekNfzZA!#gcQZAvVn|BEwKFDCS5vt3X
z;)dFY*)-D;JPza6NsJTIGJ^)Y#IezIbG9CjXD0ZhEul#w&99)s=HA?=|v79sQkT
z)oV|@Jbv%%H>|iMQyMwNmh^n)%3SHf$_15kvP;UJbyh7~v3vC0Yfscam?MhC}M3|ny4UYIL(fitmnbpOAWSqd-x;@MAY
zIER;Z;5q!*fw3Le&s}s;hF|SEx}@jWuL^~(fYvxaAHVUg+a6h$X+3&y?qtvMUl&@7
zB6{V?t4IC2zxGTu)12wu`c6;fH-%;gq`~>Qt$R-Fx!Z#D+<^yg*N~p?gmmPMJI4FGzEo%}3h9NFJ4TPav+iOG(%GS`Nbb#rLRUy1dSdnETeojIWkGt26KyF){I4a2!vN<-t<-T(*NQdqy6uLsXc=bKwC+EMp$$~Wh`9f<^NQchbHu{5B^BWeV
Ukvj{`4oF8Q){Pzb^R>PI1Jnp%`Tzg`
literal 0
HcmV?d00001
diff --git a/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 b/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0
new file mode 100644
index 0000000000000000000000000000000000000000..0c1232b6c0a3604c85c2345800ba755b7517067a
GIT binary patch
literal 5038
zcmaJ_&yOTG6^1Aic3U99-jdU(jwsVR-Mf=bq9_q7C0YTS1V|1DAXzuGp@g
z=|p>h|9}HG4oKWNaz^4W;J~jd5)xc_w#!wX>CGHwH1_lJ`|-W+J=c@3g`fZZ
z>d9~J-h1zzlfQiN_wT*(tD6TQ`@)ZEl-J9Zs6hxL__PMEbI
zDb`3OSXlX@x%khA-wt~J`O@G1aQEdu2fg3@)9LDur>pO9)7tw^Tssz3J^tlcaUidy#_D-A
z08+F(T@_qbx@8JQM7S&lZmrylc*V5Cj2M)xI7vVUMRb-xHU**&XKOMtK^uBo3jM7T
zx^Kqb-hax?=A3OeXWQ-B_7M}}%yDD5{dAmf_ve>o5G?mDPU;(JED9NT!sT-jOxu&M
z7s*+#wC6pOeM7VX4q2LQ`JR9@m)_#exEG*%aUN9qMW-DqR>@9MAt(HW1H2J|zbHc~
zK7Q{;@Mintt~xN!J(+JtSYTfO^h%Yyx`eraw`Gmd9Rp~p%R$RNXOv#EYOE=P^Btpw
zj$ypD)(tiiNVa8#sIl;H>*HT@$5~uT{(d5zx7QOwc9`yxgSgHX%VzUhw$Kbm7MLE~
zjL;7@0eU;G8U`+`q$NV3IEbF?EBNSikRko70a@L)hi-b#mTMkRp(m%5Mxaqk{UugI
z&g*(u+?XWMG8h$kh3rtSvyDzTBP)t*IkwoG^DwRa?B^-6O7og%+~|QRZu4AvE}Vd!
z?x2LkV`RgXHkZPX7188*8&LxyQL+uM0Hmd%sDFpTsC}X$6`@>N&pXdRr7$Qq^g$-B
z(L-bhW)HEhrR61hu=}(o&YQh6qSGj;u?GZdghq{BpA~}QC7LVUkj0FA0TsiZ(0#dM
zA;Gii5K`7?vk+Xm(;9)wFict^^-$?4V6p{23l@NUDRcg4Y)uO>CG;QbB;tU`4&3l=
z4y@m+bl*tIDh&@9ACAjn<631k7*?3qu_x3e=axIZnV_l@3~)&xdVdgP)msFigQsYa
z_+2T>oW?xF$%*7KTC6}n;Yd0Xky~uR@GguBJ`YPnklVq}5PnR^(gZVU09zcK2o5_`
z^h6OJ7?yVfE{8n}iJnRq1$Id-CHMaL;E*~UIhDdt3vluf_$hfCaAiBK-HsYjLi^MX
z!7GxmT2O<(LclOj1Tgy@7JF(^2(hFig{NqZMo%pdps2oeyQLRJSc;%eMHB^nv7@tU
zloS{XU<=q;%(~th1oDdaby;9LG6c}rph~Br@`yx_@#UV(^&~Zd79BbT;9KH~eAA%B
zDysN$Y?_;=7kmNWcc@}Rv)7coV!=0H0xLmCnDS=YnImLTTu}l3cotkLB9ti6^6Ef*
zPP1^BZFOFe!)SK|>fK0=%-TV_`$pHzdi2`M(#l{G#^m{GtFFG@q=Y#c4lwF_1vj%fFF>GGjk@wGVFWw
z$i7lzjn3NRv4Q~rbv*}o4$&l*??-M6lQqTz-B;k)xx`bZeXl&)VCn2QP4x4~w0y11
zSv#78MOK+ME$`)LUf(jBOPM(zAAbCkCrAjpnU`(|3_aCD5z<9dK`5H+tuqJ7&j~f1
z0(ncDH!;r*^#M&P%d_FYOQqv8*iolT*9A(&jNgKf(9!X*$=ZF$iU;E{2>T5*3yiuZwLu?jPam>{`IvWl|BRW<
z_s2-0t+dS$C;-@mhZY4sv~HTd__Z(jfF
zh46i$`{nTE-Jkw9eqZ><$?C5stCz#P5*lp&o#FQ<={v*kPW*QU#+}1cgjyrbQlXXz
zt5%SksRn)XK!pI?zQ+_aiJg_)in9@Qi_t2=96b!5MJhLj5j6Ic$B=$co2T4fiu#97
z=T?I2aV-(wO6{y*G*z7EgP3PY%sjYJ5e#4@+X`D%)t&(coTHO2%CoLDYe
zp~lFfVf{JR##8ar#iQ54`x@oxBv<&JVQ4IGsg}^w!@C3zA5<{f2=Q_}_~86e*bLJV
zJz|hv!!y2P_<$~kKFa`MdIeEvAFdxhTE{mM*~<2KzLZlW
fg0EB%{)-6?B!p#h56&;P+tstxjQU^j_kZ+%3`T29
literal 0
HcmV?d00001
diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py
new file mode 100644
index 000000000..e2ac5438c
--- /dev/null
+++ b/test/passes/module/transforms/optical/bert-finetune.py
@@ -0,0 +1,71 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+
+import numpy as np
+import evaluate
+from datasets import load_dataset
+from transformers import (
+ AutoTokenizer,
+ AutoModelForSequenceClassification,
+ Trainer,
+ TrainingArguments,
+ DataCollatorWithPadding,
+)
+from chop.passes.module.transforms.optical import optical_module_transform_pass
+
+def bert_onn_transform(model):
+ pass_args = {
+ "by": "type",
+ "linear": {
+ "config": {
+ "name": "morr",
+ "miniblock": 4,
+ "morr_init": True,
+ "trainable_morr_bias": False,
+ "trainable_morr_scale": False,
+ }
+ },
+ }
+ model, _ = optical_module_transform_pass(model, pass_args)
+ return model
+
+def main():
+ model_name = "bert-base-uncased"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
+ print(model)
+
+ # Placeholder for modifications
+ model = bert_onn_transform(model)
+
+ dataset = load_dataset("glue", "sst2")
+ def preprocess(examples):
+ return tokenizer(examples["sentence"], truncation=True, padding=True)
+ dataset = dataset.map(preprocess, batched=True)
+
+ data_collator = DataCollatorWithPadding(tokenizer)
+ metric = evaluate.load("accuracy")
+ def compute_metrics(eval_pred):
+ logits, labels = eval_pred
+ return metric.compute(predictions=np.argmax(logits, axis=1), references=labels)
+
+ training_args = TrainingArguments(
+ output_dir="model_sst2",
+ run_name="bert_sst2_experiment",
+ evaluation_strategy="epoch",
+ num_train_epochs=3,
+ logging_steps=50
+ )
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset["train"],
+ eval_dataset=dataset["validation"],
+ data_collator=data_collator,
+ compute_metrics=compute_metrics
+ )
+ trainer.train()
+
+if __name__ == "__main__":
+ main()
diff --git a/test/passes/module/transforms/optical/run_glue.py b/test/passes/module/transforms/optical/run_glue.py
new file mode 100644
index 000000000..b5d5aeb77
--- /dev/null
+++ b/test/passes/module/transforms/optical/run_glue.py
@@ -0,0 +1,637 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Finetuning the library models for sequence classification on GLUE."""
+# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
+
+import logging
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+import evaluate
+import numpy as np
+from datasets import load_dataset
+
+import transformers
+from transformers import (
+ AutoConfig,
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ HfArgumentParser,
+ PretrainedConfig,
+ Trainer,
+ TrainingArguments,
+ default_data_collator,
+ set_seed,
+)
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.50.0.dev0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
+
+# Maps each GLUE task to the names of its (first, second) input columns;
+# single-sentence tasks (cola, sst2) use None for the second key.
+task_to_keys = {
+    "cola": ("sentence", None),
+    "mnli": ("premise", "hypothesis"),
+    "mrpc": ("sentence1", "sentence2"),
+    "qnli": ("question", "sentence"),
+    "qqp": ("question1", "question2"),
+    "rte": ("sentence1", "sentence2"),
+    "sst2": ("sentence", None),
+    "stsb": ("sentence1", "sentence2"),
+    "wnli": ("sentence1", "sentence2"),
+}
+
+# Module-level logger; level is configured later in main() from TrainingArguments.
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+
+    Using `HfArgumentParser` we can turn this class
+    into argparse arguments to be able to specify them on
+    the command line.
+    """
+
+    task_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
+    )
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": (
+                "The maximum total input sequence length after tokenization. Sequences longer "
+                "than this will be truncated, sequences shorter will be padded."
+            )
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+    )
+    pad_to_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "Whether to pad all samples to `max_seq_length`. "
+                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+            )
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of training examples to this "
+                "value if set."
+            )
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                "value if set."
+            )
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+                "value if set."
+            )
+        },
+    )
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "A csv or a json file containing the training data."}
+    )
+    validation_file: Optional[str] = field(
+        default=None, metadata={"help": "A csv or a json file containing the validation data."}
+    )
+    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
+
+    def __post_init__(self):
+        # Validate the three mutually exclusive data sources: a GLUE task
+        # name, a Hub dataset name, or local train/validation files.
+        if self.task_name is not None:
+            # Normalize to lowercase so "SST2" and "sst2" both resolve.
+            self.task_name = self.task_name.lower()
+            if self.task_name not in task_to_keys.keys():
+                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
+        elif self.dataset_name is not None:
+            pass
+        elif self.train_file is None or self.validation_file is None:
+            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
+        else:
+            # Local files: both must share the same csv/json extension.
+            train_extension = self.train_file.split(".")[-1]
+            assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+            validation_extension = self.validation_file.split(".")[-1]
+            assert (
+                validation_extension == train_extension
+            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    # Annotation widened to Optional[str]: the default is None, in which case
+    # downstream loaders fall back to the cached `huggingface-cli login` token.
+    token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
+            )
+        },
+    )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to trust the execution of code from datasets/models defined on the Hub."
+                " This option should only be set to `True` for repositories you trust and in which you have read the"
+                " code, as it will execute code present on the Hub on your local machine."
+            )
+        },
+    )
+    ignore_mismatched_sizes: bool = field(
+        default=False,
+        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+    )
+
+
+def main():
+    """Run GLUE fine-tuning end-to-end: parse args, load data and model, then train/evaluate/predict."""
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+    # information sent is the one passed as arguments along with your Python/PyTorch versions.
+    send_example_telemetry("run_glue", model_args, data_args)
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
+    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
+    # label if at least two columns are provided.
+    #
+    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
+    # single column. You can easily tweak this behavior (see below)
+    #
+    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+    # download the dataset.
+    if data_args.task_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(
+            "nyu-mll/glue",
+            data_args.task_name,
+            cache_dir=model_args.cache_dir,
+            token=model_args.token,
+        )
+    elif data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
+        )
+    else:
+        # Loading a dataset from your local files.
+        # CSV/JSON training and evaluation files are needed.
+        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
+
+        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
+        # when you use `do_predict` without specifying a GLUE benchmark task.
+        if training_args.do_predict:
+            if data_args.test_file is not None:
+                train_extension = data_args.train_file.split(".")[-1]
+                test_extension = data_args.test_file.split(".")[-1]
+                assert (
+                    test_extension == train_extension
+                ), "`test_file` should have the same extension (csv or json) as `train_file`."
+                data_files["test"] = data_args.test_file
+            else:
+                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
+
+        for key in data_files.keys():
+            logger.info(f"load a local file for {key}: {data_files[key]}")
+
+        if data_args.train_file.endswith(".csv"):
+            # Loading a dataset from local csv files
+            raw_datasets = load_dataset(
+                "csv",
+                data_files=data_files,
+                cache_dir=model_args.cache_dir,
+                token=model_args.token,
+            )
+        else:
+            # Loading a dataset from local json files
+            raw_datasets = load_dataset(
+                "json",
+                data_files=data_files,
+                cache_dir=model_args.cache_dir,
+                token=model_args.token,
+            )
+    # See more about loading any type of standard or custom dataset at
+    # https://huggingface.co/docs/datasets/loading_datasets.
+
+    # Labels
+    if data_args.task_name is not None:
+        # stsb is the only regression task in GLUE; everything else is classification.
+        is_regression = data_args.task_name == "stsb"
+        if not is_regression:
+            label_list = raw_datasets["train"].features["label"].names
+            num_labels = len(label_list)
+        else:
+            num_labels = 1
+    else:
+        # Trying to have good defaults here, don't hesitate to tweak to your needs.
+        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
+        if is_regression:
+            num_labels = 1
+        else:
+            # A useful fast method:
+            # https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.unique
+            label_list = raw_datasets["train"].unique("label")
+            label_list.sort()  # Let's sort it for determinism
+            num_labels = len(label_list)
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
+    )
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
+        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
+    )
+
+    # Preprocessing the raw_datasets
+    if data_args.task_name is not None:
+        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
+    else:
+        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
+        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
+        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+            sentence1_key, sentence2_key = "sentence1", "sentence2"
+        else:
+            if len(non_label_column_names) >= 2:
+                sentence1_key, sentence2_key = non_label_column_names[:2]
+            else:
+                sentence1_key, sentence2_key = non_label_column_names[0], None
+
+    # Padding strategy
+    if data_args.pad_to_max_length:
+        padding = "max_length"
+    else:
+        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+        padding = False
+
+    # Some models have set the order of the labels to use, so let's make sure we do use it.
+    label_to_id = None
+    if (
+        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+        and data_args.task_name is not None
+        and not is_regression
+    ):
+        # Some have all caps in their config, some don't.
+        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
+            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
+        else:
+            logger.warning(
+                "Your model seems to have been trained with labels, but they don't match the dataset: "
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
+                "\nIgnoring the model labels as a result.",
+            )
+    elif data_args.task_name is None and not is_regression:
+        label_to_id = {v: i for i, v in enumerate(label_list)}
+
+    if label_to_id is not None:
+        model.config.label2id = label_to_id
+        model.config.id2label = {id: label for label, id in config.label2id.items()}
+    elif data_args.task_name is not None and not is_regression:
+        model.config.label2id = {l: i for i, l in enumerate(label_list)}
+        model.config.id2label = {id: label for label, id in config.label2id.items()}
+
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warning(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+    def preprocess_function(examples):
+        """Tokenize one batch of examples, remapping labels when label_to_id is set."""
+        # Tokenize the texts
+        args = (
+            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
+        )
+        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
+
+        # Map labels to IDs (not necessary for GLUE tasks)
+        if label_to_id is not None and "label" in examples:
+            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
+        return result
+
+    with training_args.main_process_first(desc="dataset map pre-processing"):
+        raw_datasets = raw_datasets.map(
+            preprocess_function,
+            batched=True,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )
+    if training_args.do_train:
+        if "train" not in raw_datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = raw_datasets["train"]
+        if data_args.max_train_samples is not None:
+            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+            train_dataset = train_dataset.select(range(max_train_samples))
+
+    if training_args.do_eval:
+        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
+        if data_args.max_eval_samples is not None:
+            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+            eval_dataset = eval_dataset.select(range(max_eval_samples))
+
+    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
+        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
+        if data_args.max_predict_samples is not None:
+            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+            predict_dataset = predict_dataset.select(range(max_predict_samples))
+
+    # Log a few random samples from the training set:
+    if training_args.do_train:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+    # Get the metric function
+    if data_args.task_name is not None:
+        metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir)
+    elif is_regression:
+        metric = evaluate.load("mse", cache_dir=model_args.cache_dir)
+    else:
+        metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
+
+    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
+    # predictions and label_ids field) and has to return a dictionary string to float.
+    def compute_metrics(p: EvalPrediction):
+        """Compute task metrics; adds a combined score when the metric returns several values."""
+        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
+        result = metric.compute(predictions=preds, references=p.label_ids)
+        if len(result) > 1:
+            result["combined_score"] = np.mean(list(result.values())).item()
+        return result
+
+    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
+    # we already did the padding.
+    if data_args.pad_to_max_length:
+        data_collator = default_data_collator
+    elif training_args.fp16:
+        # Pad to a multiple of 8 so fp16 tensor cores can be used.
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    else:
+        data_collator = None
+
+    # Initialize our Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
+        compute_metrics=compute_metrics,
+        processing_class=tokenizer,
+        data_collator=data_collator,
+    )
+
+    # Training
+    if training_args.do_train:
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        metrics = train_result.metrics
+        max_train_samples = (
+            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+        )
+        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+        trainer.save_model()  # Saves the tokenizer too for easy upload
+
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
+
+    # Evaluation
+    if training_args.do_eval:
+        logger.info("*** Evaluate ***")
+
+        # Loop to handle MNLI double evaluation (matched, mis-matched)
+        tasks = [data_args.task_name]
+        eval_datasets = [eval_dataset]
+        if data_args.task_name == "mnli":
+            tasks.append("mnli-mm")
+            valid_mm_dataset = raw_datasets["validation_mismatched"]
+            if data_args.max_eval_samples is not None:
+                max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples)
+                valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples))
+            eval_datasets.append(valid_mm_dataset)
+            combined = {}
+
+        for eval_dataset, task in zip(eval_datasets, tasks):
+            metrics = trainer.evaluate(eval_dataset=eval_dataset)
+
+            max_eval_samples = (
+                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+            )
+            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+            if task == "mnli-mm":
+                metrics = {k + "_mm": v for k, v in metrics.items()}
+            if task is not None and "mnli" in task:
+                combined.update(metrics)
+
+            trainer.log_metrics("eval", metrics)
+            trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics)
+
+    if training_args.do_predict:
+        logger.info("*** Predict ***")
+
+        # Loop to handle MNLI double evaluation (matched, mis-matched)
+        tasks = [data_args.task_name]
+        predict_datasets = [predict_dataset]
+        if data_args.task_name == "mnli":
+            tasks.append("mnli-mm")
+            predict_datasets.append(raw_datasets["test_mismatched"])
+
+        for predict_dataset, task in zip(predict_datasets, tasks):
+            # Removing the `label` columns because it contains -1 and Trainer won't like that.
+            predict_dataset = predict_dataset.remove_columns("label")
+            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
+            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
+
+            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
+            if trainer.is_world_process_zero():
+                with open(output_predict_file, "w") as writer:
+                    logger.info(f"***** Predict results {task} *****")
+                    writer.write("index\tprediction\n")
+                    for index, item in enumerate(predictions):
+                        if is_regression:
+                            writer.write(f"{index}\t{item:3.3f}\n")
+                        else:
+                            item = label_list[item]
+                            writer.write(f"{index}\t{item}\n")
+
+    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
+    if data_args.task_name is not None:
+        kwargs["language"] = "en"
+        kwargs["dataset_tags"] = "glue"
+        kwargs["dataset_args"] = data_args.task_name
+        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"
+
+    if training_args.push_to_hub:
+        trainer.push_to_hub(**kwargs)
+    else:
+        trainer.create_model_card(**kwargs)
+
+
+def _mp_fn(index):
+    """TPU entry point for xla_spawn; *index* is the process ordinal (unused here)."""
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/wandb/latest-run b/wandb/latest-run
new file mode 120000
index 000000000..d0ad236b2
--- /dev/null
+++ b/wandb/latest-run
@@ -0,0 +1 @@
+run-20250310_212229-x0wxhan7
\ No newline at end of file
diff --git a/wandb/run-20250310_202153-60c8phhh/files/config.yaml b/wandb/run-20250310_202153-60c8phhh/files/config.yaml
new file mode 100644
index 000000000..4b489b501
--- /dev/null
+++ b/wandb/run-20250310_202153-60c8phhh/files/config.yaml
@@ -0,0 +1,497 @@
+_attn_implementation_autoset:
+ value: true
+_name_or_path:
+ value: bert-base-uncased
+_wandb:
+ value:
+ cli_version: 0.19.1
+ m:
+ - "1": train/global_step
+ "6":
+ - 3
+ "7": []
+ - "1": train/learning_rate
+ "5": 1
+ "6":
+ - 1
+ - 3
+ "7": []
+ - "1": train/epoch
+ "5": 1
+ "6":
+ - 1
+ - 3
+ "7": []
+ - "1": train/loss
+ "5": 1
+ "6":
+ - 1
+ - 3
+ "7": []
+ - "1": train/grad_norm
+ "5": 1
+ "6":
+ - 1
+ - 3
+ "7": []
+ python_version: 3.11.11
+ t:
+ "1":
+ - 1
+ - 2
+ - 3
+ - 5
+ - 11
+ - 41
+ - 49
+ - 51
+ - 53
+ - 55
+ - 71
+ - 100
+ "2":
+ - 1
+ - 2
+ - 3
+ - 5
+ - 11
+ - 41
+ - 49
+ - 51
+ - 53
+ - 55
+ - 71
+ - 100
+ "3":
+ - 7
+ - 13
+ - 19
+ - 23
+ - 55
+ - 66
+ "4": 3.11.11
+ "5": 0.19.1
+ "6": 4.47.1
+ "8":
+ - 5
+ "9":
+ "1": transformers_trainer
+ "12": 0.19.1
+ "13": linux-x86_64
+accelerator_config:
+ value:
+ dispatch_batches: null
+ even_batches: true
+ gradient_accumulation_kwargs: null
+ non_blocking: false
+ split_batches: false
+ use_seedable_sampler: true
+adafactor:
+ value: false
+adam_beta1:
+ value: 0.9
+adam_beta2:
+ value: 0.999
+adam_epsilon:
+ value: 1e-08
+add_cross_attention:
+ value: false
+architectures:
+ value:
+ - BertForMaskedLM
+attention_probs_dropout_prob:
+ value: 0.1
+auto_find_batch_size:
+ value: false
+average_tokens_across_devices:
+ value: false
+bad_words_ids:
+ value: null
+batch_eval_metrics:
+ value: false
+begin_suppress_tokens:
+ value: null
+bf16:
+ value: false
+bf16_full_eval:
+ value: false
+bos_token_id:
+ value: null
+chunk_size_feed_forward:
+ value: 0
+classifier_dropout:
+ value: null
+cross_attention_hidden_size:
+ value: null
+data_seed:
+ value: null
+dataloader_drop_last:
+ value: false
+dataloader_num_workers:
+ value: 0
+dataloader_persistent_workers:
+ value: false
+dataloader_pin_memory:
+ value: true
+dataloader_prefetch_factor:
+ value: null
+ddp_backend:
+ value: null
+ddp_broadcast_buffers:
+ value: null
+ddp_bucket_cap_mb:
+ value: null
+ddp_find_unused_parameters:
+ value: null
+ddp_timeout:
+ value: 1800
+debug:
+ value: []
+decoder_start_token_id:
+ value: null
+deepspeed:
+ value: null
+disable_tqdm:
+ value: false
+dispatch_batches:
+ value: null
+diversity_penalty:
+ value: 0
+do_eval:
+ value: true
+do_predict:
+ value: false
+do_sample:
+ value: false
+do_train:
+ value: false
+early_stopping:
+ value: false
+encoder_no_repeat_ngram_size:
+ value: 0
+eos_token_id:
+ value: null
+eval_accumulation_steps:
+ value: null
+eval_delay:
+ value: 0
+eval_do_concat_batches:
+ value: true
+eval_on_start:
+ value: false
+eval_steps:
+ value: null
+eval_strategy:
+ value: epoch
+eval_use_gather_object:
+ value: false
+evaluation_strategy:
+ value: epoch
+exponential_decay_length_penalty:
+ value: null
+finetuning_task:
+ value: null
+forced_bos_token_id:
+ value: null
+forced_eos_token_id:
+ value: null
+fp16:
+ value: false
+fp16_backend:
+ value: auto
+fp16_full_eval:
+ value: false
+fp16_opt_level:
+ value: O1
+fsdp:
+ value: []
+fsdp_config:
+ value:
+ min_num_params: 0
+ xla: false
+ xla_fsdp_grad_ckpt: false
+ xla_fsdp_v2: false
+fsdp_min_num_params:
+ value: 0
+fsdp_transformer_layer_cls_to_wrap:
+ value: null
+full_determinism:
+ value: false
+gradient_accumulation_steps:
+ value: 1
+gradient_checkpointing:
+ value: false
+gradient_checkpointing_kwargs:
+ value: null
+greater_is_better:
+ value: null
+group_by_length:
+ value: false
+half_precision_backend:
+ value: auto
+hidden_act:
+ value: gelu
+hidden_dropout_prob:
+ value: 0.1
+hidden_size:
+ value: 768
+hub_always_push:
+ value: false
+hub_model_id:
+ value: null
+hub_private_repo:
+ value: null
+hub_strategy:
+ value: every_save
+hub_token:
+ value:
+id2label:
+ value:
+ "0": LABEL_0
+ "1": LABEL_1
+ignore_data_skip:
+ value: false
+include_for_metrics:
+ value: []
+include_inputs_for_metrics:
+ value: false
+include_num_input_tokens_seen:
+ value: false
+include_tokens_per_second:
+ value: false
+initializer_range:
+ value: 0.02
+intermediate_size:
+ value: 3072
+is_decoder:
+ value: false
+is_encoder_decoder:
+ value: false
+jit_mode_eval:
+ value: false
+label_names:
+ value: null
+label_smoothing_factor:
+ value: 0
+label2id:
+ value:
+ LABEL_0: 0
+ LABEL_1: 1
+layer_norm_eps:
+ value: 1e-12
+learning_rate:
+ value: 5e-05
+length_column_name:
+ value: length
+length_penalty:
+ value: 1
+load_best_model_at_end:
+ value: false
+local_rank:
+ value: 0
+log_level:
+ value: passive
+log_level_replica:
+ value: warning
+log_on_each_node:
+ value: true
+logging_dir:
+ value: model_sst2/runs/Mar10_20-14-07_ee-tarrasque
+logging_first_step:
+ value: false
+logging_nan_inf_filter:
+ value: true
+logging_steps:
+ value: 50
+logging_strategy:
+ value: steps
+lr_scheduler_type:
+ value: linear
+max_grad_norm:
+ value: 1
+max_length:
+ value: 20
+max_position_embeddings:
+ value: 512
+max_steps:
+ value: -1
+metric_for_best_model:
+ value: null
+min_length:
+ value: 0
+model/num_parameters:
+ value: 109483778
+model_type:
+ value: bert
+mp_parameters:
+ value: ""
+neftune_noise_alpha:
+ value: null
+no_cuda:
+ value: false
+no_repeat_ngram_size:
+ value: 0
+num_attention_heads:
+ value: 12
+num_beam_groups:
+ value: 1
+num_beams:
+ value: 1
+num_hidden_layers:
+ value: 12
+num_return_sequences:
+ value: 1
+num_train_epochs:
+ value: 3
+optim:
+ value: adamw_torch
+optim_args:
+ value: null
+optim_target_modules:
+ value: null
+output_attentions:
+ value: false
+output_dir:
+ value: model_sst2
+output_hidden_states:
+ value: false
+output_scores:
+ value: false
+overwrite_output_dir:
+ value: false
+pad_token_id:
+ value: 0
+past_index:
+ value: -1
+per_device_eval_batch_size:
+ value: 8
+per_device_train_batch_size:
+ value: 8
+per_gpu_eval_batch_size:
+ value: null
+per_gpu_train_batch_size:
+ value: null
+position_embedding_type:
+ value: absolute
+prediction_loss_only:
+ value: false
+prefix:
+ value: null
+problem_type:
+ value: null
+push_to_hub:
+ value: false
+push_to_hub_model_id:
+ value: null
+push_to_hub_organization:
+ value: null
+push_to_hub_token:
+ value:
+ray_scope:
+ value: last
+remove_invalid_values:
+ value: false
+remove_unused_columns:
+ value: true
+repetition_penalty:
+ value: 1
+report_to:
+ value:
+ - tensorboard
+ - wandb
+restore_callback_states_from_checkpoint:
+ value: false
+resume_from_checkpoint:
+ value: null
+return_dict:
+ value: true
+return_dict_in_generate:
+ value: false
+run_name:
+ value: bert_sst2_experiment
+save_on_each_node:
+ value: false
+save_only_model:
+ value: false
+save_safetensors:
+ value: true
+save_steps:
+ value: 500
+save_strategy:
+ value: steps
+save_total_limit:
+ value: null
+seed:
+ value: 42
+sep_token_id:
+ value: null
+skip_memory_metrics:
+ value: true
+split_batches:
+ value: null
+suppress_tokens:
+ value: null
+task_specific_params:
+ value: null
+temperature:
+ value: 1
+tf_legacy_loss:
+ value: false
+tf32:
+ value: null
+tie_encoder_decoder:
+ value: false
+tie_word_embeddings:
+ value: true
+tokenizer_class:
+ value: null
+top_k:
+ value: 50
+top_p:
+ value: 1
+torch_compile:
+ value: false
+torch_compile_backend:
+ value: null
+torch_compile_mode:
+ value: null
+torch_dtype:
+ value: null
+torch_empty_cache_steps:
+ value: null
+torchdynamo:
+ value: null
+torchscript:
+ value: false
+tpu_metrics_debug:
+ value: false
+tpu_num_cores:
+ value: null
+transformers_version:
+ value: 4.47.1
+type_vocab_size:
+ value: 2
+typical_p:
+ value: 1
+use_bfloat16:
+ value: false
+use_cache:
+ value: true
+use_cpu:
+ value: false
+use_ipex:
+ value: false
+use_legacy_prediction_loop:
+ value: false
+use_liger_kernel:
+ value: false
+use_mps_device:
+ value: false
+vocab_size:
+ value: 30522
+warmup_ratio:
+ value: 0
+warmup_steps:
+ value: 0
+weight_decay:
+ value: 0
diff --git a/wandb/run-20250310_202153-60c8phhh/files/requirements.txt b/wandb/run-20250310_202153-60c8phhh/files/requirements.txt
new file mode 100644
index 000000000..67cb06e8a
--- /dev/null
+++ b/wandb/run-20250310_202153-60c8phhh/files/requirements.txt
@@ -0,0 +1,295 @@
+pydantic==2.10.4
+urllib3==2.3.0
+scipy==1.15.0
+myst-nb==1.1.2
+pure_eval==0.2.3
+wcwidth==0.2.13
+attr-dot-dict==0.1.0
+emoji==2.14.0
+mkl_random==1.2.8
+keras==3.8.0
+nvidia-cuda-runtime-cu12==12.4.127
+torchvision==0.20.1
+cocotb==1.8.0
+wheel==0.44.0
+imageio==2.36.1
+dill==0.3.8
+pydot==3.0.4
+transformers==4.47.1
+sphinx-book-theme==1.1.3
+myst-parser==4.0.0
+traitlets==5.14.3
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-curand-cu12==10.3.5.147
+kiwisolver==1.4.8
+pygame==2.6.1
+greenlet==3.1.1
+pytest-profiling==1.8.1
+requests==2.32.3
+aiosignal==1.2.0
+aiosignal==1.3.2
+Sphinx==8.1.3
+torch-summary==1.4.5
+Farama-Notifications==0.0.4
+sphinxcontrib-plantuml==0.30
+ptyprocess==0.7.0
+pexpect==4.9.0
+yarl==1.18.0
+yarl==1.18.3
+filelock==3.16.1
+filelock==3.13.1
+datasets==3.2.0
+datasets==3.3.2
+bitstring==4.3.0
+triton==3.1.0
+py4j==0.10.9.8
+pybind11==2.13.6
+pluggy==1.5.0
+regex==2024.11.6
+cvxpy==1.6.0
+sphinx-test-reports==1.1.0
+jsonschema-specifications==2024.10.1
+fastjsonschema==2.21.1
+pytest-xdist==3.6.1
+smmap==5.0.2
+onnx==1.17.0
+tornado==6.4.2
+GitPython==3.1.44
+sphinxcontrib-htmlhelp==2.1.0
+iniconfig==2.0.0
+threadpoolctl==3.5.0
+cycler==0.12.1
+tzdata==2024.2
+tzdata==2023.3
+certifi==2024.12.14
+certifi==2025.1.31
+numpy==1.26.4
+gast==0.6.0
+frozenlist==1.5.0
+opt_einsum==3.4.0
+astunparse==1.6.3
+colorlog==6.9.0
+grpcio==1.69.0
+jupyter_core==5.7.2
+torchmetrics==1.6.1
+gprof2dot==2024.6.6
+nvidia-ml-py==12.560.30
+multidict==6.1.0
+etils==1.11.0
+jupyter_client==8.6.3
+sphinxcontrib-jsmath==1.0.1
+tensorboard-plugin-profile==2.19.0
+clarabel==0.9.0
+idna==3.7
+idna==3.10
+pylance==0.21.0
+ipykernel==6.29.5
+matplotlib-inline==0.1.7
+jedi==0.19.2
+lightning-utilities==0.11.9
+namex==0.0.8
+kornia==0.7.4
+docker-pycreds==0.4.0
+mkl-service==2.4.0
+fonttools==4.55.3
+tensorboard-data-server==0.7.2
+beautifulsoup4==4.12.3
+Werkzeug==3.1.3
+Markdown==3.7
+asttokens==3.0.0
+huggingface-hub==0.27.1
+huggingface_hub==0.29.2
+pytest-sugar==1.0.0
+tensorflow==2.18.0
+pytest==8.3.4
+joblib==1.4.2
+ipython==8.31.0
+mdurl==0.1.2
+optimum==1.23.3
+pytest-metadata==3.1.1
+debugpy==1.8.11
+absl-py==2.1.0
+mkl_fft==1.3.11
+sphinxcontrib-serializinghtml==2.0.0
+MarkupSafe==3.0.2
+sympy==1.13.1
+six==1.16.0
+six==1.17.0
+multiprocess==0.70.15
+multiprocess==0.70.16
+snowballstemmer==2.2.0
+zipp==3.21.0
+ale-py==0.10.1
+scs==3.2.7.post2
+find_libpython==0.4.0
+sphinxcontrib-jquery==4.1
+decorator==5.1.1
+nvidia-nvtx-cu12==12.4.127
+prompt_toolkit==3.0.48
+charset-normalizer==3.4.1
+charset-normalizer==3.3.2
+nvidia-cuda-nvrtc-cu12==12.4.127
+evaluate==0.4.3
+tensorboard==2.18.0
+lightning==2.5.0.post0
+py-cpuinfo==9.0.0
+prettytable==3.12.0
+nbclient==0.10.2
+execnet==2.1.1
+torch-tb-profiler==0.4.3
+kornia_rs==0.1.8
+contourpy==1.3.1
+pydata-sphinx-theme==0.16.1
+pip==24.2
+requests-file==2.1.0
+jsonschema==4.23.0
+sphinx_glpi_theme==0.6
+imagesize==1.4.1
+osqp==0.6.7.post3
+importlib_resources==6.5.2
+termcolor==2.5.0
+importlib_metadata==8.5.0
+cocotb-bus==0.2.1
+future==1.0.0
+pyarrow==18.1.0
+pyarrow==19.0.0
+packaging==24.2
+sentry-sdk==2.19.2
+einops==0.8.0
+nvidia-cuda-cupti-cu12==12.4.127
+bitarray==3.0.0
+aiohttp==3.11.10
+aiohttp==3.11.11
+nvidia-cufft-cu12==11.2.1.3
+scikit-learn==1.6.0
+pyzmq==26.2.0
+Mako==1.3.8
+platformdirs==4.3.6
+nvidia-cusolver-cu12==11.6.1.9
+markdown-it-py==3.0.0
+wrapt==1.17.0
+tensorboardX==2.6.2.2
+protobuf==3.20.2
+propcache==0.2.1
+propcache==0.2.0
+pytz==2024.1
+pytz==2024.2
+wandb==0.19.1
+libclang==18.1.1
+nvidia-cublas-cu12==12.4.5.8
+alembic==1.14.0
+nvidia-nvjitlink-cu12==12.4.127
+click==8.1.8
+gymnasium==1.0.0
+Brotli==1.0.9
+lxml==5.3.0
+tensorflow-io-gcs-filesystem==0.37.1
+matplotlib==3.10.0
+tqdm==4.67.1
+annotated-types==0.7.0
+ghp-import==2.1.0
+pillow==10.4.0
+onnxconverter-common==1.14.0
+stable_baselines3==2.4.0
+imageio-ffmpeg==0.5.1
+onnxruntime==1.20.1
+typing_extensions==4.12.2
+Pygments==2.19.0
+coloredlogs==15.0.1
+sentencepiece==0.2.0
+torch==2.5.1
+timm==1.0.12
+mdit-py-plugins==0.4.2
+PyYAML==6.0.2
+gviz-api==1.10.0
+xxhash==3.5.0
+setuptools==75.1.0
+pytorch-nlp==0.5.0
+babel==2.16.0
+soupsieve==2.6
+ipdb==0.13.13
+python-dateutil==2.9.0.post0
+comm==0.2.2
+flatbuffers==24.12.23
+rpds-py==0.22.3
+psutil==6.1.1
+h5py==3.12.1
+numexpr==2.10.1
+optuna==4.1.0
+accessible-pygments==0.0.5
+tf_keras==2.18.0
+mypy-extensions==1.0.0
+pytest-html==4.1.1
+hyperopt==0.2.7
+tabulate==0.9.0
+fsspec==2024.12.0
+fsspec==2024.9.0
+parso==0.8.4
+sphinxcontrib-qthelp==2.0.0
+qdldl==0.1.7.post5
+nvidia-cusparse-cu12==12.3.1.170
+sphinx-data-viewer==0.1.5
+mase-cuda==0.0.1
+cloudpickle==3.1.0
+coverage==7.6.10
+pandas==2.2.3
+Jinja2==3.1.5
+black==24.10.0
+pathspec==0.12.1
+sphinxcontrib-devhelp==2.0.0
+mpmath==1.3.0
+pytorch-lightning==2.5.0.post0
+alabaster==1.0.0
+jupyter-cache==1.0.1
+stack-data==0.6.3
+sphinx-rtd-theme==3.0.2
+accelerate==1.2.1
+pyparsing==3.2.1
+docutils==0.21.2
+pytest-cov==6.0.0
+rich==13.9.4
+safetensors==0.5.3
+safetensors==0.5.0
+humanfriendly==10.0
+PySocks==1.7.1
+toml==0.10.2
+Bottleneck==1.4.2
+setproctitle==1.3.4
+opencv-python==4.10.0.84
+referencing==0.35.1
+nvidia-nccl-cu12==2.21.5
+tokenizers==0.21.0
+attrs==24.3.0
+aiohappyeyeballs==2.4.4
+optree==0.13.1
+networkx==3.4.2
+sphinx-needs==4.1.0
+nbformat==5.10.4
+gitdb==4.0.12
+SQLAlchemy==2.0.36
+executing==2.1.0
+google-pasta==0.2.0
+ml-dtypes==0.4.1
+pynvml==12.0.0
+nest-asyncio==1.6.0
+sphinxcontrib-applehelp==2.0.0
+pydantic_core==2.27.2
+transformers==4.47.1
+mase-tools==1.0.0
+more-itertools==10.3.0
+typing_extensions==4.12.2
+inflect==7.3.1
+typeguard==4.3.0
+tomli==2.0.1
+jaraco.context==5.3.0
+jaraco.functools==4.0.1
+platformdirs==4.2.2
+packaging==24.1
+autocommand==2.2.2
+jaraco.text==3.12.1
+zipp==3.19.2
+jaraco.collections==5.1.0
+importlib_metadata==8.0.0
+wheel==0.43.0
+backports.tarfile==1.2.0
+importlib_resources==6.4.0
diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json
new file mode 100644
index 000000000..2bfa1fc7b
--- /dev/null
+++ b/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json
@@ -0,0 +1,60 @@
+{
+ "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34",
+ "python": "CPython 3.11.11",
+ "startedAt": "2025-03-10T20:21:53.597208Z",
+ "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py",
+ "codePath": "test/passes/module/transforms/optical/bert-finetune.py",
+ "git": {
+ "remote": "https://github.com/Johnny1882/mase.git",
+ "commit": "758710333d8ca4b7444930df91d86c7642652426"
+ },
+ "email": "jw3621@ic.ac.uk",
+ "root": "/home/jw3621/Projects/bert-onn",
+ "host": "ee-tarrasque",
+ "executable": "/home/jw3621/anaconda3/envs/mase/bin/python",
+ "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py",
+ "cpu_count": 16,
+ "cpu_count_logical": 32,
+ "gpu": "NVIDIA GeForce RTX 3090",
+ "gpu_count": 4,
+ "disk": {
+ "/": {
+ "total": "75125227520",
+ "used": "61026897920"
+ }
+ },
+ "memory": {
+ "total": "269555560448"
+ },
+ "cpu": {
+ "count": 16,
+ "countLogical": 32
+ },
+ "gpu_nvidia": [
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ }
+ ],
+ "cudaVersion": "12.7"
+}
\ No newline at end of file
diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json
new file mode 100644
index 000000000..76b497321
--- /dev/null
+++ b/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json
@@ -0,0 +1 @@
+{"train/global_step":350,"_wandb":{"runtime":63},"train/learning_rate":4.722882026920032e-05,"_timestamp":1.7416381701494613e+09,"train/epoch":0.166270783847981,"train/grad_norm":4.34669828414917,"train/loss":0.2263,"_runtime":56.552776547,"_step":6}
\ No newline at end of file
diff --git a/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb b/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb
new file mode 100644
index 0000000000000000000000000000000000000000..28b610420d3100fe8e76d63abfdf1515169b8a57
GIT binary patch
literal 164998
zcmeFa31Adew*PMjWN8*j5Y)KExG!{S@9v1>=sd?|+{X3IsMFIV4Ft29MMTkw0VN1x
zKsHe^;)b#axJ49=iY%@HMNv@!6;V_IxI_^6f6u*Do!d+3s^|RXy{`OuJl~_cs=De^
z=iYP9UCuZE%Hz)5@X~&tHdIdv@8{{}sr6)@d4R0=eG>jh&R2EG&|dB7^>#C)fq
z;gbGCkNK}ZdOY1SYCT8g_Pnk#P*OH**!AJ-)q)c|`~KZKJYzH3?jpO#(CQ&Wib{tJ
z4i`jn5AgQ#b_op+mlh7H^<><1^O*+@id0mEDl4m$P~?X4NJUXeq_k>S&q!oIRk)%e
zTzOq}>J{dMD59PM${m
z?d)mj;>qqd)YIspf9G~peX@+dC*ipWgwjRxX?~x4X2uDgVLgkBN~>=ea6`Zd8QK$`
zOT5Pot*R=o%+JdkQdBjxdXTT6tR(NevZ1A=Bji9p$tww0Mtt}x>jb|Z@XL~_s)d1q
zur|oAXxR79?;E~?OioER2&%+
zE*KFiE~{+2d|z+R@`~!xNMUGbBwUEkVc_1wynU;RBB9}B6@{Tl$)HGKA=(1IlDvNH6ga$`&?%=YD;o*wHT8~G<$yr5}p-5>#Sz)9iR2X6ZH+m-rdcF8I
z{?lW;hZa=ewn9+=QB?R1>*)^n_6-*nhT@KlUK(H1UAOnknO<*6_y#~I9a1$Ee>
zKKQ7pH2M*4zOT1?VOgj$TvA?)pEX&7UL;&mJOb~nth^j3qTd|g?Os}45*ic4Yu|zu2QHQi5J247pjfE*zPg?y}bs73z?BrhKk5^xDC>E!NhC_
zm1y+16j?j)9OUg~T6$$cS%uI9j`AKPG@KA#%8*DYF==}XFiMzN0jzCMS!JlI>{>LW
zBA6Wi)wXXVNxnVO+YcR2+4T`z`TB5iQDF%GtrmX&L~jnd^0HEDec%fF45OhiQ)$D-
zf9YWF0hQI|<+ywEE}~C`cwt*B5b4iAbH*Lrep%9V2Sb8^o=-(*D5+A2Z?#bN3p_~Rbn
z?Sf|=Ty%r=K|gQLWG^5K9SO{#?Y*Ei;m~An10PpL%7sJuFCFRaZ}tn+LBhkfU5g)Z
zxc5M!5vnYY6ci0EDhQQ_VYSga#bwJY$_5ojNG#bg2R>jB
z9OHoM(gOUS!d&mM-ovQ(flX1{CnptVjVWARRaO~^_T@+69xB46l`zE;wC50O)Mcf$
zo^H9^N9THTy#3Ld3X5=wP{Ghh!L{XOMKCL&+vtaW2fbsld7oX_{p7-Xc&{%j2oEyb
z;r^Cp{t^l2;g6Hqe9y(Nt
z_YjIuH@m+a8Gwr&=RGL#mrT&!-0+~vvf}EhNUk@_+nroxK^U(HS7na+P;VBqeHa(H
zF8(N%-5=`pntfMcQAMq%XD)w*6b9xp7Y|byUV+9P5)U(7NS3!7{0c0m2&Q6w*2miw
zpGD9US-%ehz$a*KJbEyXJH*?I4nm`#{%{B!N7vj)d0D|w_Os^E1eP3osI0UYN0_G`
zM>8MFkQ5^PxfEs)C_bS08}2W($X~ABsETDDS~^#E|mpBv;@qGY8PkLlUmQ
zkHDQo8%4OFpt_{Gn76OWD)#o^9B>hKa@*S;~|UyIO-qBDxz=%K(kOO`u5Vnp}|GP=5@dx&OWOut3rz`E-EP^&%k1D>#NG};0U7l$};Mat$rnwofV5L+Pn`izeI$|
z`jqH1$=>4;Ccwa{pDHdM6fS`CW4S|Ua79^(;9a7pWO}
z5xP(MsHB`iaq1hQ-}LZyu0%u+Ytj_zBPrzz@lT@dX;@JeyHT_6GJghL99^@hJc3Zc
z{H(XP^Ps^(9)f=lE*JiNkheGeJ5*L)Wm#h9+>7K~?@`|VL&L>`sV^)js-(`19zIfv
z8|#uwE;*N609`vY7>Q5po!}dG)x>Xlc{^8OB*A)R=ue8_aEr;eAjm?0g0!lzu)OWZ
z_%kEaMfw(^_G^Z%{BtC3^f^-FilPESIjFiVH$V(;evon1HRK=&CGs+s6~VQH5M{OH
zbA8RO&9q&9{lgFuwVgyikCg0M=IU{;M|-n{Zz)KNP|uDr3s{Wxh{$@-va4YEh4fyi
ze>vLQ4NZ~7M75s&EcS>W%*dtQt+1#vjA$xUbzNae^o=6MH6!WjQoL?tWALXXl&^AA
z2S65cvdhz_tZg79q9;eMid*4+6G;bqou;MWx*Q|rlEQJ;6%Y&zsX(TJ{t)>idS~8-
z+10bXeTs%45Qv1hwO(7))-AxE2Uiy2-}Dj=K?;F?ged7FnpXS>n1j7N*jEK*rGty`
z60f)^H})vR$rLr;P#nhp$MSamuMj)!dL`MX%m*O;tFpel*?Y1#mz^L4kuU^=Z-}}n
zp9l@F2;*)!wii6!EO^5Ryg(T56s-BpJ>96xv&G~#o1Q`
z3E!Znr)Grj*rycn;8SX|)V|rP=810wE@ysQN-Ym<HA>IR6<{cVjWpGjBJrW*-=XBg^%{Bnn%0A6S@Yo$5uZ-m$wVs2o
z%HHc66CKH2zq!kJKlS^_&v@nLeGJOBiccjG>{ipqCJ?W
zTV0iPnPa@!^#09(p_ovztgQ`U2&%lgawu6R#uo*e{o$Reky#};BmR&`$Z)*b-rG|3
z!}-T;9=9i%W(9^<;D~dPadfA%Sn@|^Fz}qq&kS95(O=HG@DD8ZpfAfSimsq=Z8`;$?LMMAO#Y0TO{c;CheEFs40C7kpjd-f0ocEBXRcGX~goC^nk|^n10s!{AlkXmI*M5
zus8JJl^E-!2*qT9HfkM2&W2eOH|t93aoZ9>w9(SYV2t-f5KI@LJ%)?Rha!ICMv3>%
zdp?Yb3_(YNTSrxk;Y3b2m(4jqyxAW-+eQ
zGv#&uKH2l1ecMJnRv%|1N#xg~gjUa22o+P~ug`p_Gjj@dmn!I$c#
z*?Wd_7s?aP#WVCy`3xr=<3*Rs;_kd?+-vsg-DN4L1ysL3aO>iummP!I(*sgoZ`dAB
zR*!BQ+5F~$iBl`GfAI9m$&rp7nd7cMBXguYLzVTMt0XC3k_MiBO3vv;Re6=ygm%9%
z(Fat0QlqNN$m-ec3x;aNicLReWdGvH8(LNp$s0CYH5557TpA{?5mxgerPo)o+1k9~
zqCt7(BdUg$mC{6T9>z^01Im&AqQ9-oV~#AZv^0+*&f?-oaUN3O(gGS_SdeSq|^dbC8IZ4#_E`nQO|fDYY3?U4@T}bC{pXF-I_R
zhDRzQIhDwf!xfdM<_s@l2O!i87v^9%2Xi}-8zVWG8qeX=H#y;)>)|NMDtz9Y9CO5|
z(ntT#3Eri<)@(f%&@xj)+v|x!Tl@HLA9r$>fsM?O1+$U|s`-+ZuLVxY$&r12qcX3k
z>YD3zu6dT)VpeKatSx@Y46Nn0DSvVsDkTP$oUiEllEP3a8f%N)w_P%Y%&Ko{RLrb6
zshy2IlxRyjE%b5wzrGEi6Z70xQ?XRLfDe|`1wV0B<>R<}~W#aT5^esq;5yT$C|
zH%`6F9liX>94&^H8pv0u5h^~v%4i+#f2ovc^>apR@#g6nZltBh(vtFJKWM2~nZ#&)
zcIuSdiB^9{v>I+*{kbb?88NibYcX1yPc|5>)s0>6C0YkLqBVW;;zv8V(YE}tw3K{_
zY)kX`bw+FC4SnxF4zvzVO{*pYq3Ku59d4fdKpz3t1h;fsfmm1wO{tUDeF}#)__w8>
z5v)TT!D^Vh{0cY13dX>a1No}KU>O|Np(p)$GquC))UcRG25
zJL*8$pRdV`R*=V9lYN^?h}NNwXtnjC5G{gmZXA~VfFDB?A#I|T8BBJHT$tC
z^F@z%rz^TG*p?W{gOeOkF1{hpOSBGmL~H%rIlsA*mKe!1TDs0`OOcpItld>#
zNVJY{M62<|4__DK2iMq^7|Cv2`db&hD0)QK&=TW#(9-jDjnOiAul3>2uO3abavjlXdFzu}v5$8REisU1
zv@~v8W^8)Wfg6Vqt)m^$s@pwdr7LNPkvylx`*?%*@mGFu9viJW#u2Ti59Zw^dc<)q
z>9-Uyl4rD7-edUuyx;1%>cG>(CC309bcRX{u=n`E+OAO?7a)~N;
ziKc1I|G|)m)(MViHEdk|wn*y^w{%-#AkS&>w8tEh|Jxt){zbG-bVRGJ`OV3qOLPq_
zF_Mp=#XI?}T{347t&<$ls@b^yV^`7=LwQb%#!cZ8c_+W^*C9)Y*2#`&ZEjxtpy(1^
zV_RY
z_+ZpfH^LGFdB8IAb((g?uqmHN?Ecf7=cyg~Qo~~LgK1eiC%h$kM7NlhDhBe#Kuln;
zNow6^a9St5r&bWHJV&(Fel>5o*vGqumKewz0~I-+4asZ1V1SKnd0y^*0MU{h(Q4fE
z+H$jxmyaE}acZ4Qx-C_V?YxaV_M(h>uC|3Ep&SBvOAjju#{bHleyCxVuin${>$bRUxn$`UmbryVHV9G8-fQvPt|x$(KQ*sf
zT95QCbBwcJZaUhXxWtGaxQu+2XVxkohq!mf*?%Kk0Y|uM?s|GyEG}0pv%#`cF{Ed>
zWIiOQaqIeZ;nXXsH3n1TVy%%|*Yw3J7JA&tOAP7*Yz~%>&g;DQnx5DFN8*+5jMwI;
zTV2Xa4C}!Qv#=&Fo&@zjf6O-Gb*dv?&7UuRInFz}r9Fy~J>#X(>QlpK@b2s4v7N6c
zUIU%+nloceoOg7^3KH~RVrTCJFLK5LJsIpVc``t8jwa*uyV&5K17Tli4wil_5qyL-2oml)i0Tzn45;O2GJxCv)bb38pYF6JINuDaSM
zZjI~i-NH)@?t|3bvqih;zW7#yt{WQvM7;j!h}YUr$FGRvHQo)wUg*BW;GXeP&F-GN
z$K^*X$st~6IO5eZWygd#_vjX0Vsy`WnRBwd`}%rK)&a!pOh>%xKAE$|oxF50yyv{k
zc}5;iEcah=nSp)Ve;S11j=Y>eOoh5+yY4f(dGhx=5C7eJdhS{uPX4b!C!ugGt3|+T
z0Vdh1A3Z7egu2=RIO(iGCwjf+FQQi9K7Z(lfyet4ob%^FKGX)~MGvf!cKbZ>n(ICr
zdl*hUJNsa7|17g`c2-xEoT6$TD8wX7H*MLld_P#sIoZ2B-ea@kFAqDyq^vAS_}_pL
z2&lTIdwp5SSfPLw#`XPD)5T?raNS&?N;l2PI8
z;H*S$xJ6WsJ37l7q8mopj1rW>IS^ZOHQ7)DGA^R1h8#Ir3VIL8>K2MVSWg+%x)r~y
z;&u#gkF5R80(Gb^)h|oHjw&Z8R~`)LJ^#6BZU!)%o3-2Xzk4VDt1lsIzZ;%7FhhmO
zoR_)>>5b7%?$BG7KK@ZguPI+#O`CuJlKM=x`In#CJm#s7GWh0SH@GyL&v5rG7c8QF
z=ls;HGNY6iyzwGXd3s%a?f7@;6jCO&)?DJ
zg>T`9hu(BNRW&^x#n)JqIVI=#Xkkr!n9xn&`b=7UcU92~%1
zUX=PS`wDltVDI~c$NT-)*WN`w@Z!{`3o+^D^_y^Z%mBShVun;c@U+nlV~E#Z9r3F9
za`DW#6wxib#33#o!_!P`G=^_JZ|P7<5id#2i=~L)u@tfK$7S7O&C3<5Do5t%;s{rg
z5WYe7-*paJ^QG|RCwTHt#}GUxGVkNrJ+o1R4dvCTQl1axMX!fHR)T&61mYK!D5i%m
zRKKCB{@U#QS3i%-UB)l>BvrrR*W3Qmyq?*)sH;{{R^a0$`g|3Yl@}lu?x_T12%V@w
z$x!s{le~S!V^A)vfc-)iO7sWyfT}71&CoP|P*t;!@%HAIqN-y2uuMb65(X(K1)zDN
z=|(V!D;^|XF?@Zv2&EN^BlO$713^U%`bh30h?wWCqK#c|$NgW<@BilL?LYkDZN0Kj
z^!7#lvuk~MzUoR;n7KImn*rIUsH$$ry5d(A{GmG&4?UYge%swxKzzh9R${F;{^O1ccV1zp1}pfoZ6u)GWU=5>Y#zElv~F;j6TDf%W!g|1G_@oJbeN&S{+o@`bNtB4l0Q)aPTYtb
z)TL}Hgv$7E7S?w`zbenpNie1v?W!gyVexCw;Cd;CgfJIt~
zA>C%v)`KJ?t%#&G#UhO#O-@?1NK@lAZCY(3ty-k5&uo!SIJkY%YLqlJd$adY<4L3Q
zSE8iVl%!AOy@z?3DCsPa){rFgo`e3YS*X<*)NMv>JxMaus)!n`)a*X+tFfq`Iq{35
z8qkqn$vg6?Z5{c<89V!;tII4vN2U1i-!?~qHGi?;ZVO0l{7rDDoL^t
zlw|8M)Wkz|*-zn!BEg?XnqR^1bK-F?Cljv`ZxQ$?YzZ20KIZ#Qye3{AUW{fair=r{
z%GS@zM&g~yvZ6rSO^4ze26X)Bg!s?LJ%1lG90QIWuw_HyUWXm11QV`l1a!O)dM3Xi
z!=9`g4a6Nu@S%xchH*_67OY`V3(9bQNbw=u!EHw`8)p~PyaO;I
zP1jXDp!pGlSp*XA6>kOYigjE;#ZcP_#M=o0O=>L^b3^VHFtyymk%LoO(uKqC`>OO1<-Q#z>&i#sMhacA(F{X5!fe0E%iJH!XBT%~_4w^%UspTh^;F$=F6AZ6_M5!S
zt!8|7aq?gO@dNR?#u=|SXD^ByS#^tj3A6p27u!D--St`X%GI}!d;CXgUTkFbdp@$d
zZqi|~-Mw3I36uTFaJ_cn*$h|65w4m~ZZC+%<%$hwBXbO4u0I*BUvKNjdi#H-#>Juu
zZe7i9jDI4|JGzCJFwY^&6ip~RnlShG@44!tvDAA({leBp1kP*oJ=0rU%1fB&FnRH4
z!rWL~QK1YWUV|L*YWU)V=iA@9
zVrJbzUK{S5H^2Lph(Qw039(nkkI8T=F)_~P#A14|Q!hEfAYw(1h_$}F=xw(W^NXQA
zBW5L%Jn1&p~$4y^Ty)^Ulh(j_J)M*OJ$j0`-cQEG!8x|)b#g^&YcD>nYR
z+O@>QV4oA?^{UKLs&zLn4G}R&nmQo1dG@3kiLTNOThh>tiSa%s7Sk%%9WeioL=3Cz
z91vSk_t{3*5)(syM$D`dYF7Drwa0tU60z$X5!*cOsX2+R(j_(~hWwnEIWNgw<>&eD
ze@m{iA~iAQDqEPVthskl{{&a*5@2G)p9I*rZhf{8Ff4IN)hGpE3u@;7BLSEj7Uhr3
z35WrI3^2YAx%QdY|3+;R>%>z5W3dIdu+1|cdnD0Qx`db*@QX#q&7H`hOIPL-F)WL7
zK&<|^ZChMRObqxrv6!Ov4?n$*c}i?ja6;^()_$hx7Hs>i%&ox0
zc%K6^SMl(g9e2+8^*m~mH>C!~Vhe6zv$srrF40rEVPy(>F)`j}CAO6q6Ek{8@`%_-
zN5pFGpS;Jl#KeFM_u*hFGTDX
zN5nRN|G>CJPw5h3V#LphnbSyIUiF35hrLF`{^f{R-OgWDB@%PP5|xoTK{4V_Laa~T
zqhp9zjVCoS7F=v+!A1QuyG~8;l`bJBM*K-y<+?-9`|uPHyVV9U<}5j}h8eGoO6>Yw
zf=mqgAsT=|elvjO(;^#B-ERvat96EK`mdW;n#dGZ2Op*IH%@g4GBM`o$YT2b#pf?+
zA!MVRA*;Lh<$BXy3dr2Bt_NM281-{xX6Mg`VE+^flo7Jg&XCpb-Uj_4Ho@yYWMbUU
zkeRZ421mBE>*6a3*==@_?O?tC+Ov}DI(^j}u0-3vbsI@PIP6%B;k88^anDTmoV3S
z0lpA+;q=Q-JRBud>OAe1we$b$-gkC&0@ksfY!q}G&`!a+eke$#VC$`_VNb201dR6U
z)oIv-tq0Wrw%%e(GAhKN;9RnLb)9hgDB%`Uuf!U>)|B;$)njtYn=ZWo57#AA7WA
zv1FZ`q-32c2essd;&dN2@?sx4en~<>15N((CwKhN6(sKPWZjzaZ@g~$LCr!D#@9~V
zai9j<8E2!pQSmh1r#G(HJ~?CeN$1y*0?&9Gr?LWf{M?qMyC8V@1FiM!Kz~l#a>H^z
zINxal7FD?8teRe~okA7vl$KxT(Oa5e@cavVm+2E!F18S{oOmF5u79e8mXZ4Xd>Y
znmA?z`HNsalRei_ErVi33Mq=C`C1gPhs)wFer+Y>G2x3tn7)k$ahmBS%b|
zM&rw>WIi#qV%37d)HtWwV8(hhZfDJt=A%AFk{&Jn7u8aaCXO9JvB0o~+Z4Y2F=nc{CYcXap1Eb!
zF~saXXUt|l^w?ncG83}_&dl7p$+LlD{#sN>jq`pR%viv|?W|?f=y^$fr#lQy%m#Q7
zGqYSc-ywYBEzh1n&}KS=*82HZJCZ_6|D~ho(PS|lV9?B+PX>?DMn81MZv^cDXVBKZ
zvVL(=-{}r$Vm`p3#fbkL`TR>hg4W;+TJw}oZ%yhu-2qL^2RJlyts3_O=TF;jCqaAA
z8MGDeuV3L}XktDP1C58abLB_>^#X=3X8{@&^Y06PV%7XCRyE(y
z0;s@GRc%o?KV{i`1>>zuOvtadU)Wz!Fz!NCiBlBz4@jspY@kdoijhjngWq+!^;|r`
zL*^5ZECGLjme<~nktHzI2T;jWMt$N$y#f>g^=pAwcHBH+46gUEc|EG@$+4_w@gG
zUH`Wip7d%H%0WM3-pw;q51QZ3yw6(_3;PF97dx@Azasw!3;RvC&;t1h<`Uza<+tiPW~Czkb(
z7Mr&U+}kRFo^WvcCD7Z-`V(if1bRYQf88gcc(f^?@H@)-Lx=(Ox7%u#zB!8nQ|Hca-%H_|TuV
z*|chrjxFooR%_mRkYqL;Q`R4LVAh&XUDjV9ZPl^|X}?9<`pg#TgoE2B-B#A0v@+T&
z>yNjfIVkII_8#$N{WJ9-Y#K^+Ru5uWq+_f4qZqX*3SvD-GSV?s{o4c-5~})_H!rw&
z3cB$}&2IdTwr>1|mW6%yQ4{@mUhqGzA~>pC2T?*FrPWcon#to?rK{!mBT?2L<=5FC
zEJu*qN}?w#;vYmkY%GrO8z?M~V%t`wcqOjLb|1V8l!#XqsPI6dCZJj6$m8l}Lx%@<
z$o{A(zOU0IPvj32S2>*u`oo#iUx@u-T}B{C@0yqRjyo2=4yhjY2ZA(^9KeB06(Q~}
zGof-rmy{uRfZIjo_Gl^c_*&?a55?)(11Y)+c^;|l(1b^VGLMW90)GQkKwj3Yt0oZ%
zu&x?Jkfg!nEKW(vwu3wSONHwpr)WJ>+?}Elc@QPjsp`IFK*RyM1MNyeKxBbKl2t$I
z!AE~Qv1B{&pbmNfZ`07B%@jQ{u?)LEh-Eg=OVV^I!;VPJD#IT4M45WqtCwl@3JF5+
z+UoI7q~nQn_80FQ>vvQ%Jg)M9y7N%K0XL{nY9ySXlA*Dy$nozQ0j6nb(23LyBU*MI
zQDsamcU+V-An*`?L)uD8`(@wJDF>N?PymV&9cv=bNH9HjBX$Iv&Y+78UdTzU9Cdv24OZ2U0&_a4%yjyNO2(?+vc9v{`I
zo^|<+#B8=RW{vmFzsbGKgb5La)$}(@nHhY2=;ZF#mDBjd<2IPFF;{M9&0o|kPda?)
z4nq?rL^w2!59;WAX8gEY&(0!fPdJ0t^i$oPNrw+z0ZkERL^w37t_@#;QS;-Y+X>nn
zXVBLEFzWrJ(9(Zb9lA6{m=ZCe#fX8P{Lhf(DG{?zLc!K=u5K5YXU3piR?z)}C|V@D)&o~~e~is?X%AKm#a&A#im
z#tb)O6LSKFO^&H|ytrT06@=}1XV_|LC)9LuGd3|Rh_N=l1K`=Zqb?yEqTnUa*0!j<$mR$^(y=?vvV~miAjGsB04w15V8>iNmQ)
zymWmrQG3xDwWi&>uXi&wF*9H_kmVR^+tx*fQWITbgBtUr+}5UlxO7Q-e$*AlCguhl
zo4IP1$GOXHsz05uz2ppA>&h=ix93OQfK3zA1CGt?+4yXMoclmFVSCvbw&rCk-fRzB
z+HVIy*QSXXLK18%TQ7f^ur)cuRs*aZ%n>*?bE1XMgNFCpdJ#F&S8ZTp;m1xEel)-E+MnHQY+{PQshJXpe82_N
z*j}?ijX6?It$yk5@$EU%wBJfWU7MI9#MDXVJ1KsByU%jMw$vH6rp901b2T4&pc}6<5F8B{}q$_M-V~%tebEFNk
zmuI%)NZmk97jpzo&8*YPQ>#vWUaX-)eJE6Wu(y9!m(cKVY2l!(t`*g#RYfHbWX@3Z
zaW&nqqe6Ws0Et3FjX|`DUPDxTy
zPr)uO?{QfNR8@qFO7n(PgbPEZWfc%B%rdYXT@Pru3Km&Qe!r}HeeDz!?2jUaSRSfk
zDXFFgu_8OjD+(T+)q`I*QeIX7%ZE{+zBgH+zW>cT
z`QLsGDAbp+^1>I6Qeiu9*=TXSFmT2j+~zx$^gHB)voAgt9N)Ha<{sL($WLuuvaDZ+
zE7SM$Z;|Z=D8FNaRi8FgkG#_lRr-#2rhGWPqIYea-1ds<$FAw;{pzj58gOQ_jWcgV
zkv{7k)xG^mzw9ERMjtiWxEg(ltu<+pzGTfbwo0G>=B}UA_YL4SSJ}ADEa5h5CT;{D
z?^VamSTJV4jBcxKoGv)fp$ri}w7%(5q
ztX@|Ccfz*D8MXy=Pj6}uTiP#5r8X)KFD1d&^0y_A5VjATVQc($``q?&Fjrs`$Co%Z
zGc4xG#5vz;_Yt-aonfn+v3Q58v5DhLG1z!E@szZ21z}t33|sTO#}>Djy1K&J#POvV
zY~0$^?+%(pj`SlN*jNe$<`V#L*?DG$O}1
zgxxQXC?ss_oMCGjJ-OA**bH%aiDQegwi(yl{|jOJ)ETz*W0y9Wel)?_(tg1u8mS@X
z2pk)7}qUcj=-^*OQQJ{^rav7D;~-Uc=ne(Yx9N6qukp4yHhbwwK$Qv^-f1LoUpWKUGU=uJJGF0d>Av=o`ZzIN;NZ;3Xg-HO
z{`EWWlL2m%4RCkRqVhe=r7n2k-L9_Y<`?sYB;10NGd?G7UpwR0^!kPi-OWu*7?N
zbiJmT{UGNKGg*c)i0(DtX{et(@@_z`LtOC?i**|*536kw$is7
zjHLT|U-WQ(F=^o7c&_E=W0nt_#=9STWFJ1&4dBGQfq}D*w#xZNH04_XVuf#ffPH=Vm6Ve6{q;7x>R0Za+BW_H;v$
zySa&318=B&(AO*)vSa7&5^AVhZE$132)DNdtG|D8A3oI$1}A0>toWjA?gil^BdZU;
zc?yARaR{z;(a-J%7ZB414$h3?c%J&yJ?H;S;I=sgx20~$K74ArFRMWp7ZCFX4$hp`
z;(6*3(@v@)aNC`MYiN04^gev58^DQqLlTSg9eCDR1a5~jaOz|wT=JwW+m6haQPE1QA^?nWSJ1hoKGl~Xb6gTzaQuMZa$hX)l$=(l^K
zvbn6I05%mV-v5-}YDkpiXY=++l!V{?2V3T+79~NkW89fvLQVB!yuAyHDzEkB`Kl|!
zLn3G6mojyfG*@*$3W`$!YXc?XdwILEql!vV|6j`P?d`_?Syqkb^+@j@UMjtcYWBaF
zck>)ov*)+-#?1TsSW+31Z%BD{t*3L4WCapMNo3H
zqC$3tmCHUHx9j^|vLqUe&2LEyL?s-4w=7AUghOgs5*VW;mSL%i;FbLcWl03m39=-Z
zxkf|!cgT_iVXqcwl-2L=?L{x|cgvErkxng3g6ilNX-kX)q*Y5S!}`pYSVqFZ?TckZ
zWl5MWN_(;-3Pd0N9kL{T2qOrxB#K4)_sf!K_GL-5Hqw?D2T1=1Wl12X)2=K@(5J%h
zm^Q826*WpGO9E8|OOyo0ZIS-{vLsPaqXb!!Gk!R|W(GR)Rs2&N21oDK!_1`?zR+S5BQd>X_Kbid!PX%IGR6ZVkf
zMGVl%T7QD742=pS3LsA1B1-`((wo6ZDMAC>uo|VBpts;B)h4Xf)*!YJ#Q#{&6fZUb
zSq%Q{&|JVKHUp|E=F=xCNF+S{i6+y8XTqJT0R^%mnj%9jB@oqcfNEE)FayF8x<$>e
z25D;?Vg~E+_jIu9f+G4{(z3(ZxfqHG%4NG<I!A?Wyo2amYfRTM1eEY*$_Qw|8kavr-Cp>*Lt2p!Wonv#R^(JmRopEcPxoc$m+|qrI0S)vB
zlP4*`jaTfl!Htc{a(kQo^^B$a7+Q1#IAQW68MxroKh+Yr-<*M~
zX{sB)kD*03fD>jAvXzJzP+jJTbwUlPBidp>ah;
z1a6cwaO=m6xqlx%)eYc;v>_R|GhW;~n81y82Ci||y02UgPRtuvQ?>F2K3#Xk!Wk6=
z?lx!O){fsaZy!$84HhTn4IG>q{_+ZyN4&WEXBFUXw*hWEtpKlO!AR?~PLxNw5G@G5-KV^cMdtFuAw
zPTG>jZEnS~1+zVFMNj^
zz?6zS>!Y7Pp%Q|3WVBa8P*p~K`NO*zz&kDjc81z@LIJ>@n5GZX7{02iL9qg$nydn#
zq3hmblT-lA#K^OOVOR|l7dm93gSvM>I|Tsyp#UJXq_JKDqS}~eP?g94)E-P$|1XF^
zcg$>I&IB^=7@}{l>K+*3uDl^CnOIpBE-A0|^v4VZeg!oTp<0e{bIevy@x$tM2UT?e
z2z2>18EO3_zw<4kfMBN=k9t(rU^#c%Xmq`!-PP=6Q+9a19^{*Q95_y}aq3=A^xWoc
zGk178TnWJ^*Il~R56*YlfR!0#Rrli~VCCs`-rd_zrMFaXO>o--hk?SmPcw{^b~%2uvkS94L`1mU_y&eUPLglsmIh1lyCmqho`I$;wC5AxXHf4
zO*YJ1Onkhx%dUHjT*uuu&KJCH<4-Gbx~oIIl&+(7Sr?tqO?HNE^VT<>BS+HC9(qK{
zX})u9WR4<9Fh5B+%^p_}3i-TIX)cDo*(IQ+!X#i-dWyl?C&gl?)`bffq<
z@yk!%zwZX?0;>}Tpwxkik+0L17!{M^MOAr~*BthTDXR$GG-v2G-}l0ou16;hK>?l0
z%U)(PnBP!LK9=Ch%8T3*f??4{jCzeVU~I77F7RpaxW+-`N^I26$N^K}}*
z(|u-MJL50Ij-YXf@q)k{kbjqQ>+8_$H{Q6NjNd&d68zFqF*6jmY#br{;RU4RS1(;5OI1Xzzl3``7eG
zCyT>SsH<({OMJlH?CD;r-#(tu&2)yYWoF|qu16>4609V-#?fiKZ10uVe0C9`d%zjG
z&2KKn_Q`#PCN8i#F_&QIRKDiK+*|+WYkXY@U4t`pE1stfl>0)L=4(08*U4fo!O$sY
z0X?2eEWS(YOXwbShOTB(>l>~|C*~3iowZS%H`soczaAn7JIe+-7Mk3~LX)-Ie$?D<
zbYd#O$(gk-xq}^)*>EY5d&n8NhF`zB$?fFCOoA7VHjAtBWr-IJxMmZQPgl?`gboHAX?=$^tJ64zGOFXH;
zim3!g7b6MMJoxJ%p_}InUBj#IeC>L4VlKgJ=bM%FxR;#ZTf7X8Jf7)
z>cm`vqvIpQI-f?IUOad^IoKy{pktxQ7#5l|tY322J{_zJ+N+pKaB{r8>O6sc_VfE+
z)j@874RRA`MK&kbx_(<>_8wYm(w@zwyJ`dVbz&}&gx&6Q7Y{PPZlMi!chSmX&aQsT
zo%1}dho_3!1S`(1#2DU%|LiRjyhYC7ZGO7urw#y5Oei=!bJ&TyiU&URzDDq#at3eh
zCwpGVaKGh=IR!5ZZ*D^2I~33T_R{MJ-qX(Dwbt*b?UV{$dM|cDpQnmh1&3#rvE`oj
z$y?7FLGYe&2yfw?n4d`DFw!5Mm{)LkX0zqv>_6_9G>x3?vo_!{XIp13o*1jSpPiUk
zFm~qlSc3;D-`v#KAa>6=W7jhF1)$_p&j*jVr~%wkNes`%s#!I;63jU-mUZNQu^BThbN{N9A1o)$RDzPJ7(a1
z8Qm5;gSUS7jCU<~H_m>)dgp$flaGL695nqIifIOCXb!3H*1Jf1<{mP@7o0Ji{?)rv
zQ@Y%=ho^~YMv~S$ZD7}0avLw&fX7_!Smtu4Z%mz{rT01>^n99_WpH-pG$WtzSwDSM
zA+cNHklj;_=cIDE>Ca9~GB`V)WSEm>3$MFyGd0?mY_MZ4m)l+Ays4PDND(KeKRhwX
zVD;>z807+A(_WULz3dEL>-YCy-?8iAiAe_F8Tl%orj@u}#-6jj|ClVV$r-$w`&PVT
zx!iVrZF;Z2qb^U(G62t?uaU2%3Y~0cz;lb1%_4ZOID@x#)V<47`r7n|C#D%e+L6MV
zEy_jnft3>mO+1uT2VMm{QZZnv19vcgx@6vKJyS>u$fz8w<3A{zi^8!e++PwYDXSO}
ziYgCe8mbT9sWNINqd06pr$V&*R20C5ca%&a|e-=i&`Q6v8f3s)y;ohF5k*eWk71#PInG!-z*#{{RDIjADyOBhS0sHAuomM%#
zc>!q|EH$r6S_b^8k3KT_Kdxm^t@a+a6eX}@Ke?IBw7ag
ze^ARH*j9}^UdsSQm4kZFsz&}hvos*Fe!rRlb$6^G;8zQZBIg&CUVC-7QfygZLWtnX=+j4055Uysrh7V9=@@uEbX(hpL8I
z`8^m!0eY04j~<#>dEYO=jpDDYs(ypCiOg#D@fr!48fu&e17MFr?9d88f$-zKnJBI9
zi!XccN2lzkzbNspE2>8Bo*kI@;{l1v&MOLP^pp0$KI)iT4-_y+L5T?|B(1NW32r1|
zfNBl876__=sOUp{E&D9!B`8Kv@}pwAjt(s-6rzu>qYlvltMfy3dg+TH60Q%`@h5mRg11egtpaTvns9Uf4@fs}naqk{)
zNtQ7gg!idvno2?z?N-dkL$klAR&P*XX1)3N>h>tukIs$#(NVfyQ`=-D;%e(tX?RFz
zD4GK0C=~4{%gZ(u0;rxn9zlZ;fo%A775`)XAiDL~68%A+0^J4`HH~l-o=1ru8CRn}
zQ-$)2pGqX@(Cz?0Ep?6j_{#Ut7=S(p!a)VU2JHpZ?r%>JpiLzL<(&|Ek~9{V{)p2f
zAVme!U^G?oL#)a&0=&qWXC+xKGFG%Fyf-uhDQYir5=8*Wh~TXSQN#&GCYV;@LuqLH
zFsmR5U_k4D1T9!2jC*AeZosdzzmT7RH3p!*Wj5@%-_MjGz-2&x3o>H5gf6B{`6BM3
znQh7=Dt1mm5Fz^baiCxokWC0kDix$sNG~KR@epSeWIv*41nD8PC%j?!Nbdq&iv}e#
zsEuyAiQDjGgqe?!ASo+Qk-!?
z9x3O`w3BloWDl!ULa!H$icadr#d{PeYy!$0W{PndDhB!Z{XDnh{HgSO}1
znte6FTj>m5smi2cyBm^SM&9>58Mw=
z7p7VWULaqgA#bQ7@qx$4!&PA_dH80#B@Z>ZbkES``vLGVXC)FD5Fh#TrT2Gv!(?H_aB>kJHUXJqa1r15k7
z&Nqm1hn+9jaqvx4NbemRO^yob@xFfcyzw7r^xGQQOXUvVwQ(vdcWC;Fr>1?J(c#J+
zzCOF4FO@rNwgHRE9dcInzdis~o?idmdD={RORH?0-S(Cm7CrWH#t8#=)jxn!SKBys
z1gj*-e|p9GC-KumgmQ<}9^=X#CN#w4sLe6|AFrhws_AOJ47g-{tf*
zjReShRG&-MlIM8e#`!|b((=g)oStIJD(&5FD$mi-(E1I*TjLDg>|N6yOgZb4{_uoZ
zdUCs%vZ8>UwZNOuJ
z%6Jy2w7fZ5P33IUpPe``#n`F5%hSzKJF4Y0u6O$E8?hYqO-9
zS&eSt&p#g*0K3m^u$xHRG`Zb1H9vwHO*I+4_O-_~-L;3%=NV#t!SV5J={gTpKEB0w
zJKAd1LU1<=DTUa1YV@(iy$x(chon5$K7TMiP3*&Am{eru&r*dd%-~+grbL
z?PNQCH|>p2%r#K;TF*DLT!Ux1|2+T2S%hzsGkh&e)~xFg_{3}@3BCc(emaNneeDe2
z+U=Vkw&QoxdgB~*ePY7F@x^p~SO4oTR}#L>&hXVw`}q5gfX^@H92{Rv)VKHScmA93
zed7$@g4XrV*$G?H-TM4u&cX00X1dEuIe&fC(!j=gvY+3*5v|JmGH|^<(
zDTio$W=*H7TKtC*y)Dk@)x9-pm<@X4|Ho1f{|8P)&*v924o=SunR&u?#+>DEQq%p;
z20a$Gn6|g<`9*g8ZrbA$GY(!$oY$H#I6mpKb9WKG@15aW@XE%W9Ri=2awNewb6|-^
z_WaE&cX4Sg#h^g{V6wkiV5FVXZTjMj=R&2-%ab)qty3_IS0dM
zPD`3~6i@qd>1e{&;tXHIlD9wX5ctHLgX7~<;yNE~ksGFcL5_Eu4SXzYxr>js-21{A
zHXLu-8=shR@V0Bx;{(lWHg*1vjBmRQdd%^1dTW=>tFz-d(j8wwOgUJoY}qU>#s{l%
z|GBG#@a=GhZ^ecu-?f7;t=CbZ>kEh}hghTC;3-Gnyf1eXz8{_8TRVI7R6CA0?eU2@
z2ghfQ7x2bgzsh)l@crZrU&E~CbsYkqm~?P_e3H%FwNbWQewOh4>|ow|@1UUfl8O?Up>j3~ySm38Jn~%sDtc
zv(A&bFr)r~pNL+oGkR;EUUO+jpeLproSs>%i0@XD%ly7&<(U@crft--1QIwR8l0K{4mx_;}c27Uk{!*}TID
z-)?94T3`I8(F|K$ZGAy8=ivCvV%mJP>hbzcuao25V*?)xTPCuwrDlBn(KZ}!y0_h+
zm~zClU7m8Be#(f$gP^z92E9qNw#B5k``K%=YiU{EzHKkP1uW3_1;w0$^D`$pcq#X#
z_g%U(AN*=Mr7lr$H?520{AREGbfKqX02H$hlq5&!!k5yUz26;Y^;ks!Z*>N+;hwtI
z4gye2JUGCZy5yH1WL!-EYn=gXT>l*;wmU=<7IO~{kcTgRK4|s!_N*>I22>g9GFq*eo@21u9BpD4RyBRgJT--
z)mI4b`1{j<7+hKEZF2W54ws4lH4Dv6*3rxqCWkK;b|fN@=?>|LHDbv{tA
z%Sc}5L-Pd!il6F*7^rK2MVaStuv
zGyJH6hZ=V@nPLMVOXBWkN!*I3KC#I!otK?b7buQn
zB{9GYkNw9+GQg=e_%SEU`8BS7bCSobCX{aV%7P_2;WJj-~nxK8JqD*prVXfYY1-
zY+2R%Q3nAij$;9!ny*n8sQDD$g0Ct)U={(K?hN4O&!0gBHCtU^x?7+)j>T#YOXP&3
zHHSm%?mvgRzRS`VakQ2RBgF=?G4
zV}dU{k4xWcUV6SKdhIvHuISmWR1kCn0;mHUQ2mexK;iKArGk0|g4hcVsUVE-D0=q3
zw1NT{*irp@3N0g47}tZ)H&CEH1zA
z=G{C;vJw1te);(q+X_H||C|Dl0b{TfJmM9AB&bOQqGBnj6@W}#8uK!OSW3XB8&sdy
z)JuUjfhm?E`6^2+nGWblMI}5WpIr%{s8~u$2_S!R2_U~EmSSn)*b+-gIJkYWl&A!d
zh&oOJNDfML-%`
z%G5QV7R2bJrA!mFNdJBXAl(v6NvQzj7i4KH0Ust`bF#M|-L>_ZEz$`Gw@~f
z4)gZmKgLv&OnhDC0HjAC-(ZqXtpH?_j#mK6)FI%)2CpsF7U@_4pf;@~>p?7g7iF!K
z7+IQtMLMMbP=IV9%iEJ4)RMgc>7ZrP)@QazCmh^9={5l%lC?UNlZJ>6S*uC93*_5;
zOag_O^zYFBk$upvYcuJ91^V~v|0ovd1pS|zwst#d20HSYW=DQkTSs13yE#h~Ou>l0
z7S#Nhlu(%TjwMVIC;F15_?U_egi=T(M#9vGb!4K5hCdjD!OzP&T(aIKC6i1X
z2)Bv9&^ZZ0P$#O360b}|sxv_@7z4=@#0^H!dMG(g7bO!4CbEj~1tE+M6(61A3$7y+
z|NFS8B@Kf6h31dmp0te8_H*%MAOi!=kimfv67EflK5h~j7!~q~P+VjUQ#Yy%A(Q(8*WU>&f^=v#@uqykPbS}g$CBwbRVmea0c6wH_Xfe4Zx
zfm#%3qPi3MVo1;6FKkVdRftPjuR2bx1`<3lN%j|tIRTiPW*w>|VqzvKTBRic6#w{o
zRE0~25E7()EKc!ySt^8W;Ck^_fwPvhs0vo1R*t3Xr9((bM;vZ(ic{*zf+iuf=-6Kc
z9Z0aPoxylw6mZByR7^=~Fiww}O6)x)*FOs4qG;-B5cfcr
zruWp2{1AkSnEaL>uRZ|Dr>Lk>ytGy(J_u+Uu1IoK3iOw(*90Z6*oPu(CRsBBB1mcw
z7Txyp<55XIC@KY@Nre{x!7kO%Rv!H1d`c-FKnI;l0!TWts)=1D{C@k{4Dw$uX(^D~^E#HXmuVva}6=v_iR;A1xU
zu^@)q-}+^*&$sD=(;c8NKck}ndfQ$LzE!I9uMZzU0B1V`*!1(GKXnj*!UT;8(5y(v
zT}c0xTf7ADac2OVznEfouSJ?$pdzFrCO}?ez~C0x^zK^{0er$4z!hK2v%A+Koj2^E
z4^+fFgg0Tn&dK2W^^W=Fq%8z+jx&JG56yhbW)PF^0L46n1B|i2`fs0ppPcYq8vt1l
zGnEA~^-C7|?K$Ce=O?BiOhiL6cPsM&h0(P+>xkbxXZ-4#UjJ9e;3sAwtf04KZdc~3
zbl#tG%Y=d8H{S-oDYRjf+uw?J-(K#?9&KmGZF{V4y2h#l(am)J=bEj!wVxuU10%oHK-L-)P?4VGxSB2}fuav*A(U
zN4Kp0ju19FL%9Bf)`j*%nl!gUF*gB1J>N?2cwgA^kIROULw?={LKf0YXCck{O^+XH
z+aaepK`}MqMre-mb0d85{DYSf!Ntx9*55XDkZpo#ybptVLNPPp1kL=Ik5UgkcThiS
z#xK|)$U+)!gAG%EecZl3PIH7}W`Z*6$d6-s!f*Hd_6#9>(HX*)XSTL<7=&VK!V$(;
z;onAoSw#q!I78U{_S)y{`{VSsLQTv~I6^a|;fs#0?)H2(A$-Xh!loy`h0M6!fvz-O
zxP`t@6LXU!4f&^rt9uc`mz^P;{YQ`yox5BNTHJj?f&j=I-Q|hZ|W)(_{l73u*3Q
zA&_6ijNdTH^vA}&R3@j4{x3&q@oBQ&c8@oAGsu8#ad4tco^
zge;_)!9tqm7vB04cgQx3Fue)tVrml8j`=v(hQt1L2bGLm(Qe5|-H=WhcJ5vrv$9k6
zuO9D!c8WgsL-`2W*Qj8x8SREt6;|`{sJNPsnW*`QjYfXSk4<6NMHP_St@x-*C`9QE
zQR&Lcs&Gj;3QNj9?1jWWI+-@s1vNufS;@yk#FCFabpyL*An#}BnqRE>Xe6u5s2D2h
zN*=)LSPrQO7lum9DoScSeKen{83uM%1~mMX?7!^p4J$gq*52rpwcC>>_x$gFStu9T
zrRU|+9u@ZRrj1tD3$wkv(KUVY`q@R-4I6zd7`|oW%so`uf}gti_IJ+JvEYp`8(ACJ809SU7m+aSo&INpxeF4{BBOyk_|c*OxN2u@&R
znvV&ZdI^)y`(PU(e9sxe1z)}MUWY*_4rp)(zq3k%?H|c$Z1Yc9MOs~LcSm8
zx&QJmIt>KZ+8{WUR$y|1t&h#RD*JX?jBckVOlJ`&>I=mott5S6>GS(F{s9a>vcYf~
zt>`zc@U_=wc{&b6aZrm%c*HuCI}UvLYJ&K&Gl(rSzuMe^Ad16U9HKe6%7?Y)oqOkl
z1o0DR5F3A*^khbdu|z{0+2RoSx&?EDy0QDF&j{i=XAtY(8TGRxi0Qi~7=5B4W+_P^
zp1S4@)+c`I3}Q>;t`|EwvBdO-DCQ{~qFGU#kFg$;CGREY{Fx1i%sJo3!kg9~-#@P7
zFccG&B!>9%n|e7h{M;GC1)q;A={O9md-}a^-A6s*dK(N`fMeR>f(cJM
z^3CbHu9bR3F;`*b?X?(SIpV?
ziX-2g-j*olD;%O3;BdF{%#p7xA&6f%gSff5_6|qBIlUo@`3i?szZ*PM5
zr89_44{crT2x9s!aHT#`%vTshYn3&3EPwgykvqsae`NzA3vljd0Z!eR#|Aoa&S}li
zFQzL=3~~9~<1QdW++>3xbIzP${avF*Ir1y%4N**2P$d&P7|kU$d^r2)t!q{h#IKz}
zY?%1N$_@lkOjkHWbM%3)(cJ#m;Tl2Q>
z5WjH-vE}IvA9NsyV#2~9nh6Wv(|vf$N#78}Z=FG`e|ynPN6tCDEm2HZI7D;wfj8%&
z(x(@abKYVDB6H3&nR9O3@pBg^&N;mqiunp-sG7wEc+p$$m>aLZ91Bjq149z@`3q)H
zIHpzKeArWz?5T{DRx!yQ!$1k_V7rPwSiq*M8hsr(@aTz)aqaKTYd=YAYWT%>Hm>W5
za@_mY>VbL({T6BU9PQ04uP7@p#aMh5m6aEqS?lSg1QZq8&5%J3q!jDXf>3QnZE+k6
zF^FTly}7VUs5}CB6MD2vL-kSVb?7IcYB}naYb}4fy%S0L{9rx+N&4{nf3J0BYDphM
z##%j8OIH;7xx%8#YkhgX>dNqt$l190Of2RLVCqKk8z@YVg6_fWUf!`BD1&eg7j89j{%lNRc
z+UhcXhmKFsr|61h&@gMurhmVVk7kih(DC`~z3K_~p!aSud+!Ntz4xpic4b-89El1d
zP`hA)7qS6i4o&h~x*&1N6q!)u!8?#iss_p^kg==+3B+IAMRW=A5KfELD_b8X9qmFy^63oYV;fR3z&Uex2k*tg9w!^+3D^8XHm&
zx;W6I(V|C&_LPAk8y1^U_kLlxVbnsIMmW%!6isFq)jd-LC6PeAasQF
z`?5~fEC(!(@tQF(d>z-qVqgt=D75n5Ix=20C6iQBpz8yH6B&o9MgYQsTnQ!a-Sm}6
zheknrf(V_C`&1!&XGzAyX|ZG~5TemE3E>O{E$BI-N5<(mWkT@;F&p~>RYMhy!Fo4w
z5=89X$T$$HInY7DrK9gAUSB0khOR4K4*iu>$qyYz>#A@*@d7KR4i728-~l19Vm(xx
z9Z_#0t5El$Ee_~XpxR=J`@|iHcWOYf1&t#J#VHj+Fwr9up9wl^iULriSwyj&bC+@9ef)g4#`vaXI8E&~GiCBZwxu-Uu-~R&iMg0
z(%Up?;-MXfVRYUm%FrzF!lxV_*|>2TF>H0laPyY2;f}*FI&EVyRAU019aCDHs4@R)
zgCPrWxE;<=BOva!0r4K%Tx9y@)eBI+uFn1rt#lR*p
z^v*dQ828w~IGy&nHLwt;=BFmP15pf$DGR5_yG4VKK&*T>G@dB#bw;uA@v%2_B#L6r
z!W9R2&SLK1pF4cyU#UsgbWS}M#+o#@#pw^ehgx`a0^_Im4K?yNu8?2F1jMV>H)F@d2?{eS>c#jJ3`%wk(+UR)@kU<}Mtg8Fq7w
zea^q;J;FH38OElDr|x#@r_p$#t
zO!p5bJ-c5+N|xo7hbt>1m3bv)h1JE8Jd_wKtsIQvVwHJi<^SJGGJ)QKC{N4G
zOD!qQOVul=RMJt%0rC_~foU=`Pr(MF(~e642yzoM^E9m4Z^tk(uxN6b9CES%1wcCy
zzCZ*(L&v}C%&8Y1%-;*t+ChZYX@~$gal7RPvs5d60>A{=>H`J<$j@NUgO6uREz$$i
z8k$^d-2%^oyx&QL9}wQ(aIxzK^Y@>-r(OkW?IJ?!I)wLsES~`M{#*jyhsSYhQL%n;
zQ8FUZK>{#eAR-#L0tq-G5A-Qm3hdP)U}7lB1B&IP7R7^wGxO5pbMo^GG{Ewpv;YKg
zSO8bhUX}n*uyhmQH$<>}cs=a~^Z!{Ve}ax?>mfob@MtzzmfbLQ9xzzmQaf0H%NW4H
z5|1Sqz@dypq(SqSAsAt~R!t4?r&IC&D;{OBQr&y1^XW>{0_w${pc5B?3ti2O56^uSE&xs?5@=2U
literal 0
HcmV?d00001
diff --git a/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt b/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt
new file mode 100644
index 000000000..67cb06e8a
--- /dev/null
+++ b/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt
@@ -0,0 +1,295 @@
+pydantic==2.10.4
+urllib3==2.3.0
+scipy==1.15.0
+myst-nb==1.1.2
+pure_eval==0.2.3
+wcwidth==0.2.13
+attr-dot-dict==0.1.0
+emoji==2.14.0
+mkl_random==1.2.8
+keras==3.8.0
+nvidia-cuda-runtime-cu12==12.4.127
+torchvision==0.20.1
+cocotb==1.8.0
+wheel==0.44.0
+imageio==2.36.1
+dill==0.3.8
+pydot==3.0.4
+transformers==4.47.1
+sphinx-book-theme==1.1.3
+myst-parser==4.0.0
+traitlets==5.14.3
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-curand-cu12==10.3.5.147
+kiwisolver==1.4.8
+pygame==2.6.1
+greenlet==3.1.1
+pytest-profiling==1.8.1
+requests==2.32.3
+aiosignal==1.2.0
+aiosignal==1.3.2
+Sphinx==8.1.3
+torch-summary==1.4.5
+Farama-Notifications==0.0.4
+sphinxcontrib-plantuml==0.30
+ptyprocess==0.7.0
+pexpect==4.9.0
+yarl==1.18.0
+yarl==1.18.3
+filelock==3.16.1
+filelock==3.13.1
+datasets==3.2.0
+datasets==3.3.2
+bitstring==4.3.0
+triton==3.1.0
+py4j==0.10.9.8
+pybind11==2.13.6
+pluggy==1.5.0
+regex==2024.11.6
+cvxpy==1.6.0
+sphinx-test-reports==1.1.0
+jsonschema-specifications==2024.10.1
+fastjsonschema==2.21.1
+pytest-xdist==3.6.1
+smmap==5.0.2
+onnx==1.17.0
+tornado==6.4.2
+GitPython==3.1.44
+sphinxcontrib-htmlhelp==2.1.0
+iniconfig==2.0.0
+threadpoolctl==3.5.0
+cycler==0.12.1
+tzdata==2024.2
+tzdata==2023.3
+certifi==2024.12.14
+certifi==2025.1.31
+numpy==1.26.4
+gast==0.6.0
+frozenlist==1.5.0
+opt_einsum==3.4.0
+astunparse==1.6.3
+colorlog==6.9.0
+grpcio==1.69.0
+jupyter_core==5.7.2
+torchmetrics==1.6.1
+gprof2dot==2024.6.6
+nvidia-ml-py==12.560.30
+multidict==6.1.0
+etils==1.11.0
+jupyter_client==8.6.3
+sphinxcontrib-jsmath==1.0.1
+tensorboard-plugin-profile==2.19.0
+clarabel==0.9.0
+idna==3.7
+idna==3.10
+pylance==0.21.0
+ipykernel==6.29.5
+matplotlib-inline==0.1.7
+jedi==0.19.2
+lightning-utilities==0.11.9
+namex==0.0.8
+kornia==0.7.4
+docker-pycreds==0.4.0
+mkl-service==2.4.0
+fonttools==4.55.3
+tensorboard-data-server==0.7.2
+beautifulsoup4==4.12.3
+Werkzeug==3.1.3
+Markdown==3.7
+asttokens==3.0.0
+huggingface-hub==0.27.1
+huggingface_hub==0.29.2
+pytest-sugar==1.0.0
+tensorflow==2.18.0
+pytest==8.3.4
+joblib==1.4.2
+ipython==8.31.0
+mdurl==0.1.2
+optimum==1.23.3
+pytest-metadata==3.1.1
+debugpy==1.8.11
+absl-py==2.1.0
+mkl_fft==1.3.11
+sphinxcontrib-serializinghtml==2.0.0
+MarkupSafe==3.0.2
+sympy==1.13.1
+six==1.16.0
+six==1.17.0
+multiprocess==0.70.15
+multiprocess==0.70.16
+snowballstemmer==2.2.0
+zipp==3.21.0
+ale-py==0.10.1
+scs==3.2.7.post2
+find_libpython==0.4.0
+sphinxcontrib-jquery==4.1
+decorator==5.1.1
+nvidia-nvtx-cu12==12.4.127
+prompt_toolkit==3.0.48
+charset-normalizer==3.4.1
+charset-normalizer==3.3.2
+nvidia-cuda-nvrtc-cu12==12.4.127
+evaluate==0.4.3
+tensorboard==2.18.0
+lightning==2.5.0.post0
+py-cpuinfo==9.0.0
+prettytable==3.12.0
+nbclient==0.10.2
+execnet==2.1.1
+torch-tb-profiler==0.4.3
+kornia_rs==0.1.8
+contourpy==1.3.1
+pydata-sphinx-theme==0.16.1
+pip==24.2
+requests-file==2.1.0
+jsonschema==4.23.0
+sphinx_glpi_theme==0.6
+imagesize==1.4.1
+osqp==0.6.7.post3
+importlib_resources==6.5.2
+termcolor==2.5.0
+importlib_metadata==8.5.0
+cocotb-bus==0.2.1
+future==1.0.0
+pyarrow==18.1.0
+pyarrow==19.0.0
+packaging==24.2
+sentry-sdk==2.19.2
+einops==0.8.0
+nvidia-cuda-cupti-cu12==12.4.127
+bitarray==3.0.0
+aiohttp==3.11.10
+aiohttp==3.11.11
+nvidia-cufft-cu12==11.2.1.3
+scikit-learn==1.6.0
+pyzmq==26.2.0
+Mako==1.3.8
+platformdirs==4.3.6
+nvidia-cusolver-cu12==11.6.1.9
+markdown-it-py==3.0.0
+wrapt==1.17.0
+tensorboardX==2.6.2.2
+protobuf==3.20.2
+propcache==0.2.1
+propcache==0.2.0
+pytz==2024.1
+pytz==2024.2
+wandb==0.19.1
+libclang==18.1.1
+nvidia-cublas-cu12==12.4.5.8
+alembic==1.14.0
+nvidia-nvjitlink-cu12==12.4.127
+click==8.1.8
+gymnasium==1.0.0
+Brotli==1.0.9
+lxml==5.3.0
+tensorflow-io-gcs-filesystem==0.37.1
+matplotlib==3.10.0
+tqdm==4.67.1
+annotated-types==0.7.0
+ghp-import==2.1.0
+pillow==10.4.0
+onnxconverter-common==1.14.0
+stable_baselines3==2.4.0
+imageio-ffmpeg==0.5.1
+onnxruntime==1.20.1
+typing_extensions==4.12.2
+Pygments==2.19.0
+coloredlogs==15.0.1
+sentencepiece==0.2.0
+torch==2.5.1
+timm==1.0.12
+mdit-py-plugins==0.4.2
+PyYAML==6.0.2
+gviz-api==1.10.0
+xxhash==3.5.0
+setuptools==75.1.0
+pytorch-nlp==0.5.0
+babel==2.16.0
+soupsieve==2.6
+ipdb==0.13.13
+python-dateutil==2.9.0.post0
+comm==0.2.2
+flatbuffers==24.12.23
+rpds-py==0.22.3
+psutil==6.1.1
+h5py==3.12.1
+numexpr==2.10.1
+optuna==4.1.0
+accessible-pygments==0.0.5
+tf_keras==2.18.0
+mypy-extensions==1.0.0
+pytest-html==4.1.1
+hyperopt==0.2.7
+tabulate==0.9.0
+fsspec==2024.12.0
+fsspec==2024.9.0
+parso==0.8.4
+sphinxcontrib-qthelp==2.0.0
+qdldl==0.1.7.post5
+nvidia-cusparse-cu12==12.3.1.170
+sphinx-data-viewer==0.1.5
+mase-cuda==0.0.1
+cloudpickle==3.1.0
+coverage==7.6.10
+pandas==2.2.3
+Jinja2==3.1.5
+black==24.10.0
+pathspec==0.12.1
+sphinxcontrib-devhelp==2.0.0
+mpmath==1.3.0
+pytorch-lightning==2.5.0.post0
+alabaster==1.0.0
+jupyter-cache==1.0.1
+stack-data==0.6.3
+sphinx-rtd-theme==3.0.2
+accelerate==1.2.1
+pyparsing==3.2.1
+docutils==0.21.2
+pytest-cov==6.0.0
+rich==13.9.4
+safetensors==0.5.3
+safetensors==0.5.0
+humanfriendly==10.0
+PySocks==1.7.1
+toml==0.10.2
+Bottleneck==1.4.2
+setproctitle==1.3.4
+opencv-python==4.10.0.84
+referencing==0.35.1
+nvidia-nccl-cu12==2.21.5
+tokenizers==0.21.0
+attrs==24.3.0
+aiohappyeyeballs==2.4.4
+optree==0.13.1
+networkx==3.4.2
+sphinx-needs==4.1.0
+nbformat==5.10.4
+gitdb==4.0.12
+SQLAlchemy==2.0.36
+executing==2.1.0
+google-pasta==0.2.0
+ml-dtypes==0.4.1
+pynvml==12.0.0
+nest-asyncio==1.6.0
+sphinxcontrib-applehelp==2.0.0
+pydantic_core==2.27.2
+transformers==4.47.1
+mase-tools==1.0.0
+more-itertools==10.3.0
+typing_extensions==4.12.2
+inflect==7.3.1
+typeguard==4.3.0
+tomli==2.0.1
+jaraco.context==5.3.0
+jaraco.functools==4.0.1
+platformdirs==4.2.2
+packaging==24.1
+autocommand==2.2.2
+jaraco.text==3.12.1
+zipp==3.19.2
+jaraco.collections==5.1.0
+importlib_metadata==8.0.0
+wheel==0.43.0
+backports.tarfile==1.2.0
+importlib_resources==6.4.0
diff --git a/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json b/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json
new file mode 100644
index 000000000..752811856
--- /dev/null
+++ b/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json
@@ -0,0 +1,60 @@
+{
+ "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34",
+ "python": "CPython 3.11.11",
+ "startedAt": "2025-03-10T21:22:29.549593Z",
+ "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py",
+ "codePath": "test/passes/module/transforms/optical/bert-finetune.py",
+ "git": {
+ "remote": "https://github.com/Johnny1882/mase.git",
+ "commit": "758710333d8ca4b7444930df91d86c7642652426"
+ },
+ "email": "jw3621@ic.ac.uk",
+ "root": "/home/jw3621/Projects/bert-onn",
+ "host": "ee-tarrasque",
+ "executable": "/home/jw3621/anaconda3/envs/mase/bin/python",
+ "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py",
+ "cpu_count": 16,
+ "cpu_count_logical": 32,
+ "gpu": "NVIDIA GeForce RTX 3090",
+ "gpu_count": 4,
+ "disk": {
+ "/": {
+ "total": "75125227520",
+ "used": "60982046720"
+ }
+ },
+ "memory": {
+ "total": "269555560448"
+ },
+ "cpu": {
+ "count": 16,
+ "countLogical": 32
+ },
+ "gpu_nvidia": [
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ },
+ {
+ "name": "NVIDIA GeForce RTX 3090",
+ "memoryTotal": "25769803776",
+ "cudaCores": 10496,
+ "architecture": "Ampere"
+ }
+ ],
+ "cudaVersion": "12.7"
+}
\ No newline at end of file
diff --git a/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb b/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb
new file mode 100644
index 0000000000000000000000000000000000000000..c7816c0c9959c866c7effc0863411253fd092912
GIT binary patch
literal 98304
zcmeIb37lm`dG~KK3&UlAnI6^#0yL<&(A;y*U41nw#%LnO-58hL-08kE)6mo1_QE1!
z$h;y-R73$+WN}x-U5ODi#w9VnNqmV)jKUvx;t~@z29r1bzfV;?b*zZ5!j5}|+-v=)Ezi;W@Z)n5NuA$9m92UjRu`2(I9#*f_
z8jVkHtW-9fS3UQ<9o^~XyNCW)rCm9lFWY0$Lu%W$^^&L=SL@YnN!;z#8WW9Lv$3sO
zt=DT4(L^<=wVO%3H@v<&R-?+{h0~(P3=M5KVAs$|qZ{Yrv(K40JC4p-Jg0qX#BqQN;VfbWLOt
zudj{Onq$!!2b?-|*2byH>BaL-JFnH~G-~f2dTQkXI~EpZ=O1?Z>Dwn4b}Vik8=u*E
z`tQ!{n4UgAYPI6icXsD{V|+DyYO~&IM%5%qCR*d&+SX>RR%<8KiEZs@qSY90HfnLB
z9`olRU)ucT;~B9v@iCR7DqBqFPutmB=uUJOx?^X}&rBaZbl&LEUy?Fw2bB5laZmoY
z5z4HKkE|SC*)-RC{^I0ZPvkW}ws79U(L)p6E53I<1=cUA8qdCPIU*e1Q(_)I@yahY
zj~d!=;F1k}@a1Fl_29YQ!s6U?XJT@EVb{>wg}KFENJL%CVKr?)&b-0y3JMnNGhvqH^%U%-r~n`SH2QS*mou;YsG>&cwp`
zvpt?Py*M>hIkK{GalY5tx@~HvyAU;e>7y%$E^O;e^|p7%&+kml%%|0Fscf8`Tb%Aq
zbawQ*6MT-QSB|S}S(xl~&YhW?==64O?M+NDDfrHI=lIG|<2x3o&+g1mzM$9H*5kL^
zX6DZA&Q0tZ8mjW+;mP?=$ZfemKhzG}XfR+l|_bEUGgdmciV-oCJde#X_xCO(>+_8-ysEtP{NW;*lTowHLs
z*2x;_^}2IY=QD0Iv$Kf8e{)#npy|b(ovpp@PQ)`5RgS0}rXO^+&&@2(+K-Q^Y@V1r
zr#Cl0xp01Gwm03KS~!2#&>Ge*o_O#=Zzs!17ro(`Ya%|}SXsxTI%m_&dX;s^{A<>I
zxw&!>V?H_Fo$BaEBbDK~-fVASQgkQt0IBQ9%AvN;!RI5D!>4CD^0V$jXL|cwcc&Rp
zGlrv;<5Y~LWsa;Ivb8&5lw`g$DXL?2Q0he{ZXV1t^{Gm<9q}AhImBuD{P@gVW(u5C
zIVv-8It_e{TXP1d$R3#OgebUjD!{!%fXKB4_h41-qRF2r%WB=%s
ze1HGSAzfzjAPi~tW@62}5THjKtot>N6
zI@R0R$*ze=y))h2+3U>A@dKIqTSjFkKW!`9<7tc2gM6Pq*{+-Ss--VFl*tOW~0S;ZI(&qzL@Hs
zPg__lqPQ|r*&^R$7E-sa^BvE1zv#;Dw0?AZZ)$OrpKPXqJvuZo$&7Q4Uz=2$F?-V%
z_mhdanOP=pM`Uc)uBdWK<+$`Y@;NJezSIeoqh#TooniHpzFxG%9;Q<*59VoS-YC-L3O8Q;Q3|
z(aLb;AlWL%y9^PvnvMF{%COOWluNcQ+ZEOOV=EQ6c1=vq?HbxRYKKtp;HcT+QH68o
znB1LgF}H;bS2nP(pgEJMiu-I+2`u@dR?Q2;(+x>@w#${kZVL<%t}WIf!@#&O_i
zZ+2#Uhk2|kErOMey)!dCMG5!QlPbruv#^|j%13M2L&Zg
zsvIpPw$CmW)L>`ki3r^l$<<&>&=Nl>y5r-EI~S+y{F+}dqsQhzCFaLPU)xI9uT2T2
zPM1!y!tfglgV;jZTFq!8C{v~V4O`9>ySG_O-F{NUE+}{&s^Gpd>a?a_txDg`XA3h6%*d(9os+UN80@WI&3Cu;K=Jc4vLfrcvf2Eq
zUu1rIALhOUVVX6?_bIA85SW01%Q`hRwY59W=4ZG=XWQJ&&dl!Of3mr9Ac$>zae^^(
z58xxS(#4RvdmdkjmQ
z)HylZ11h-B4y~-+x-G*)`1iKi%)gJS94i0r%*-yRCaxWQQZ!mQsdD6w?$kC}3&$tt
zWpR_v_oivs`cc`EM`a6;x;xt-KK(7>8&hlkn?ovV7dTSTA!9w6V#A%1y#Np
z8q8^`mru>y01`>hSygF@-8Uf}W;mUuP_GrEaLG(LtpPBweGaC=dI*1Hb+&WZ)IL(#
zG`Srh(Cb)jJ$o`;7SQKy^Ar4A2H_Ye1pnv==@V%xengntDhHdd#%HFtO)`kjykNAy
z6gXK>^LbNU{-5FP_Fo0e4=yiT$k2b
zvHkf_WtiQt$1c!iOp+*^h&7@cQfUa*^aZ(i%CNsJmO7d8#eeIl$IxJ{<)-9>k6}W4Onzp`)KOY8UdYGaPOR!7>M8
zP&Ao(=n*t+qT}X)(Ipmsv+6sw+-%m8MHViYKo`9OIIA3vFyGBFD2H}1
z*}09r?u5V20eb8-dtl}G?Dza~l9dC(;TH%e4S4Z9^TGXGmnyC1l*)*VzdIP}gOW4p
z9O4LSc5!}(XeXzP8EGEL?z9NA%56rrlMph>yZN3{_2ccw>5jWmRCA8w6*l5g7{@{K
zD?|Q74IcjZXPnV_`jejU*x&n2gFWQS*}2JcgsI42sYn@j`Kmj0Zuk88jx^Y}^Z}LQ
zy~g%2vW{XW+hD{oIXynLI3ec*oY&2p#mFu1zBcfemXfu~e=REvc*7!8_TU+)rJtcH
z@uMmm&EAyT&Q7RokzXzXZcYAkw@&%?gmC!>fo3llngpU|~
z_>-UUw8tajPFB+3shzWSg>+hdz^EY{vin-i&+-JPZY;%S*PW2xuz*g^r(1^}Sx%e}
z9rwd_I8%C5I9V`9wUBH!Mx(4+&&!IN
z5<%wZbZ;ByJw3qmB=fO5HM;}+W|fk0w<{lKBHLL=XgajOZ>^wWI}S$gyfV5rATjef
z%eXmH@r`4k%*>zMPQ!r|)4~GWkJzovkEFnMPE05BfJ3`3aAvxd$)V
z+^xqA4?TFapwnwNVA?wJDR+G5Q`EC@$$FmshX0Xn|1L(T4Bi4}$>{aw`VDILpMLDG
z{*rDSvUE4DUcGK?lF@q|#y?~KTWSE8*OOOY{koSwo}L`KWKSM^GED1m1ISPJfY>ef
z8*25WUX5F|bDr?rkDrKn^{^$YJL{K2!y7kTWelMI@wN~DcxdFVp+inQv3ma(o_KG0
zz3{~9=~3KBswY0DT76ix`tV0S^u$L_E}TC9ys!Q3b!-0&sc(X7oI3Qd_lYR3F{1dy
zPp*FRnvqi~hgu9Ld&M|d4RJP-v0BnhTFqM2sMVWMH1c3h(&lEy$86ED?m68_0P|Gu
ziD&E@I<#GFC9Sa*mfEP*tdE?;ex4~iH$N|>4vwP+Csek`gDYFIMf{V8t7ENNwNZ`h
zQ6kp#R=audt4@CR@6i6uw*8Mvn}7Q|Xg$
zGif~T;rIKKmr&oZt?xo#-*-Ryed>EqWedCR*<+`VEzYx#JlV^nh5Z|mP<0eF;zqMI
za&qO+Y|+Qj$xZdfShJn9>y2i;QEx}BkwZ95mHOm>b*MUWCj{_z0ZGe
z-Oz?DwlVKK*tg~#w_S4Gnh|VjoNwZtz+Zc4HiALJex%5oS|3_DBKz5+r{wrxL5@p?
zqj;=QZ8xf|cB55ox7zhKT^uPaea895`ddljTD#VaYOU4?^u`tyTOd1stuT$iRuuYOyIGG$j;I`JzuVTGeww9v2P8GL
z!u@1+e6n}ut`8B
z9~O-v0Qe)*340
zc)3D$wWLCBvgS37B3pG9)imlHr$Sy6N7WIG@uIQX3n$|>wQ60JIqnm2Rh5}X_K7%B
zWlr)UP6|YR*z#3}@@B*mNti}w9FB%F>Ml4~dnsL3_j&b;F`Jzcs
zE2t?-K&BC1g3Sp0lJ+r%)0+e{(X%05wq*09*r&PE=6Av|z3!)ke%@V{SyuTmSSp
ze7X@-#DgQsv4hsrxfW;%U>v|(t@0;s!3ApVwD;jaMPs1eYP|(SW?nU`k!nf4&bsh7
z{-AkPaB@2H@+~cUd=w1n=4dQ#H)REBp-!ktboLLgJ^JII{3Go8{C;6Q?=r0Cwg383
zVE+B{tY_G;o>m(bZ&n$|II1m8dun3>lzHT8)1GfU|D2x*?K#qRKxmI`|G(Y(r-ig9
zZZ;VCBHF`rs5YPXg&(|8XwOl$KA}CfzRPa^^gT{{hN`>nxSrkfXxo_A$?j=ebK4a+
z?2WVsgrOLw+K*vc@T-PQ%Vn~jI}p)j=~(woCIIc6&k4AtpPPDWi;Hm2-yZtchU~
z4uc0gXH7=V{WPboat>*Z$Kd+{%}H7NifB%jw1eRG(;Ry;vK~}Hkb|o6ifB#@a_|Qu
zo=NUGn)Adnzw)_vu}~jl7wT)$h5FmK{l{C^C>rDFr;p}<#)LD-J8oB#R=WmGa1bwq
zH~`Ex8!iRez^W045Dqm45C=4ZrxR~rh@k&C_B;`6@zbHHOLs!^`t3>hK;J&vauYUHd-S`J3_Y!*JVgv)PT)_7m08uBS{dKHw+8)OZ#kA06r?n94}p?yUy&U|Z}D-K}6W_$~Xp3z!1vbZJ0
z)k~s)Y^czxDhVYUL9LJtJ876-ak~Wq_1z6hA{jdqDF)>NVtBY8f@r10U546xAtK=c&DlWLDnEs?k<01XY
z9F14|&TaQ4`co^VKMChAz@cT+A7@Yt=?}+(&1S~R)2b;n?_)oym_zoH5py5=ftNaS
z+P=(w>UD2{SepId)nsinW;%FHUe_FJVW-^Bcv8m0@uq6-N*E8ehYaEIn!D#1&-UNC
z@z%>(rGLY&(wC&G^xMAr9T<;di~tP`z*%t~9OuBQ!x#nFDNiIE#t1=90giSh^UsHF
z*%#yl6Ntw+1w2m4nVi>xGdV3Ls26yJGekeiS`&|XXjuT(Ks7BGSv9G*sx_>)fc!$s
z6W$CtL+OdFF6$-(%
zAcUg>Wo_e=3T?W3lkMREHZ0Rsgiz=+*P+iR!CWt)x4UV{Ir=qDcQ@<1UhDCTC-in)ylK7
zVm5C|V*zSl3dOLeG6+QLwIG6G8$%R5%NSi5%O(U6LopUlhSGI9P-b6Cc0#$L`2i9N
z1`{>sJuu8x?ap_8;i2Njc7k1(jT_s|hT6ROvlj#Acf+Hs-1s9tQqmhNi49rI%6;Ar
z5!w>t+Ob^TQn~fT7YH*u(Y9agIkx%NeDX?~pLb)cVT3AnW0N!~)u;T