From 4d544b07a8436fb4e6a4361e45d1ab8f2988d333 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 23 Jan 2025 11:10:52 +0000 Subject: [PATCH 01/38] MORR ONN transform pass and test script implementation --- src/chop/nn/__init__.py | 1 + src/chop/nn/optical/__init__.py | 3 + src/chop/nn/optical/functional/__init__.py | 57 + src/chop/nn/optical/functional/compute.py | 1064 +++++++++++++++++ src/chop/nn/optical/functional/general.py | 411 +++++++ src/chop/nn/optical/functional/initializer.py | 152 +++ src/chop/nn/optical/functional/mrr.py | 112 ++ src/chop/nn/optical/functional/mrr_op.py | 404 +++++++ src/chop/nn/optical/functional/quantize.py | 575 +++++++++ src/chop/nn/optical/functional/torch_train.py | 857 +++++++++++++ src/chop/nn/optical/modules/__init__.py | 7 + src/chop/nn/optical/modules/base_layer.py | 71 ++ src/chop/nn/optical/modules/morr_conv2d.py | 458 +++++++ src/chop/nn/optical/modules/morr_linear.py | 442 +++++++ .../passes/module/module_transform_helper.py | 64 + src/chop/passes/module/transforms/__init__.py | 1 + .../module/transforms/optical/__init__.py | 1 + .../module/transforms/optical/optical.py | 103 ++ .../transforms/optical/test_optical_module.py | 192 +++ .../transforms/optical/train_mnist_cnn.py | 144 +++ 20 files changed, 5119 insertions(+) create mode 100644 src/chop/nn/optical/__init__.py create mode 100644 src/chop/nn/optical/functional/__init__.py create mode 100644 src/chop/nn/optical/functional/compute.py create mode 100644 src/chop/nn/optical/functional/general.py create mode 100644 src/chop/nn/optical/functional/initializer.py create mode 100644 src/chop/nn/optical/functional/mrr.py create mode 100644 src/chop/nn/optical/functional/mrr_op.py create mode 100644 src/chop/nn/optical/functional/quantize.py create mode 100644 src/chop/nn/optical/functional/torch_train.py create mode 100644 src/chop/nn/optical/modules/__init__.py create mode 100644 src/chop/nn/optical/modules/base_layer.py create mode 100644 src/chop/nn/optical/modules/morr_conv2d.py create mode 100644 src/chop/nn/optical/modules/morr_linear.py create mode 100644 src/chop/passes/module/module_transform_helper.py create mode 100644 src/chop/passes/module/transforms/optical/__init__.py create mode 100644 src/chop/passes/module/transforms/optical/optical.py create mode 100644 test/passes/module/transforms/optical/test_optical_module.py create mode 100644 test/passes/module/transforms/optical/train_mnist_cnn.py diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py index 7ab651324..5533a7608 100644 --- a/src/chop/nn/__init__.py +++ b/src/chop/nn/__init__.py @@ -1,3 +1,4 @@ from .quantized import quantized_module_map +from .optical import optical_module_map MASE_LEAF_LAYERS = tuple(quantized_module_map.values()) diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py new file mode 100644 index 000000000..e9c0423d5 --- /dev/null +++ b/src/chop/nn/optical/__init__.py @@ -0,0 +1,3 @@ +from .modules import ( + optical_module_map, +) \ No newline at end of file diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py new file mode 100644 index 000000000..11a435770 --- /dev/null +++ b/src/chop/nn/optical/functional/__init__.py @@ -0,0 +1,57 @@ +from .mrr import ( + MORRConfig_20um_MQ, + MRRConfig_5um_HQ, + MRRConfig_5um_MQ, + MRRConfig_5um_LQ, + MORRConfig_10um_MQ, +) + +from .compute import ( + im2col_2d, + toeplitz, +) + +from .general import ( + logger, +) + +from .initializer import ( + morr_uniform_, +) + +from .quantize import ( + input_quantize_fn, + weight_quantize_fn, +) + +from .mrr_op import ( + mrr_roundtrip_phase_to_tr_func, + mrr_roundtrip_phase_to_tr_fused, +) + + + + + +# """ +# Description: +# Author: Jiaqi Gu (jqgu@utexas.edu) +# Date: 2021-06-09 01:40:22 +# LastEditors: Jiaqi Gu (jqgu@utexas.edu) +# LastEditTime: 2021-06-09 01:40:22 +# """ + +# import importlib +# import os + +# # automatically import any Python files in this directory +# for file in sorted(os.listdir(os.path.dirname(__file__))): +# if file.endswith(".py") and not file.startswith("_"): +# source = file[: file.find(".py")] +# module = importlib.import_module("torchonn.layers." + source) +# if "__all__" in module.__dict__: +# names = module.__dict__["__all__"] +# else: +# # import all names that do not begin with _ +# names = [x for x in module.__dict__ if not x.startswith("_")] +# globals().update({k: getattr(module, k) for k in names}) diff --git a/src/chop/nn/optical/functional/compute.py b/src/chop/nn/optical/functional/compute.py new file mode 100644 index 000000000..c43ad9f2d --- /dev/null +++ b/src/chop/nn/optical/functional/compute.py @@ -0,0 +1,1064 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 02:17:08 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 02:17:08 +""" + +import contextlib +import logging +from functools import lru_cache +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from scipy.stats import truncnorm +from torch import Tensor, nn +from torch.autograd import grad +from torch.nn.modules.utils import _pair +from torch.types import Device, _size + +from .torch_train import set_torch_deterministic + +__all__ = [ + "shift", + "Krylov", + "circulant", + "toeplitz", + "complex_circulant", + "complex_mult", + "expi", + "complex_matvec_mult", + "complex_matmul", + "real_to_complex", + "get_complex_magnitude", + "get_complex_energy", + "complex_to_polar", + "polar_to_complex", + "absclamp", + "absclamp_", + "im2col_2d", + "check_identity_matrix", + "check_unitary_matrix", + "check_equal_tensor", + "batch_diag", + "batch_eye_cpu", + "batch_eye", + "merge_chunks", + "partition_chunks", + "clip_by_std", + "percentile", + "gen_boolean_mask_cpu", + "gen_boolean_mask", + "fftshift_cpu", + "ifftshift_cpu", + "gen_gaussian_noise", + "gen_gaussian_filter2d_cpu", + "gen_gaussian_filter2d", + "add_gaussian_noise_cpu", + "add_gaussian_noise", + "add_gaussian_noise_", + "circulant_multiply", + "calc_diagonal_hessian", + "calc_jacobian", + "polynomial", + "gaussian", + "lowrank_decompose", + "get_conv2d_flops", + "interp1d", +] + + +def shift(v: Tensor, f: float = 1) -> Tensor: + return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1) + + +def Krylov(linear_map: Callable, v: Tensor, n: Optional[int] = None) -> Tensor: + if n is None: + n = v.size(-1) + cols = [v] + for _ in range(n - 1): + v = linear_map(v) + cols.append(v) + return torch.stack(cols, dim=-2) + + +def circulant(eigens: Tensor) -> Tensor: + circ = Krylov(shift, eigens).transpose(-1, -2) + return circ + + +@lru_cache(maxsize=4) +def _get_toeplitz_indices(n: int, device: Device) -> Tensor: + # cached toeplitz indices. avoid repeatedly generate the indices. + indices = circulant(torch.arange(n, device=device)) + return indices + + +def toeplitz(col: Tensor) -> Tensor: + """ + Efficient Toeplitz matrix generation from the first column. The column vector must in the last dimension. Batch generation is supported. Suitable for AutoGrad. Circulant matrix multiplication is ~4x faster than rfft-based implementation!\\ + @col {torch.Tensor} (Batched) column vectors.\\ + return out {torch.Tensor} (Batched) circulant matrices + """ + n = col.size(-1) + indices = _get_toeplitz_indices(n, device=col.device) + return col[..., indices] + + +def complex_circulant(eigens: Tensor) -> Tensor: + circ = Krylov(shift, eigens).transpose(-1, -2) + return circ + + +def complex_mult(X: Tensor, Y: Tensor) -> Tensor: + """Complex-valued element-wise multiplication + + Args: + X (Tensor): Real tensor with last dim of 2 or complex tensor + Y (Tensor): Real tensor with last dim of 2 or complex tensor + + Returns: + Tensor: tensor with the same type as input + """ + if not torch.is_complex(X) and not torch.is_complex(Y): + assert ( + X.shape[-1] == 2 and Y.shape[-1] == 2 + ), "Last dimension of real-valued tensor must be 2" + if hasattr(torch, "view_as_complex"): + return torch.view_as_real( + torch.view_as_complex(X) * torch.view_as_complex(Y) + ) + else: + return torch.stack( + ( + X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1], + X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0], + ), + dim=-1, + ) + else: + return X.mul(Y) + + +def complex_matvec_mult(W: Tensor, X: Tensor) -> Tensor: + return torch.sum(complex_mult(W, X.unsqueeze(0).repeat(W.size(0), 1, 1)), dim=1) + + +def complex_matmul(X: Tensor, Y: Tensor) -> Tensor: + assert X.shape[-1] == 2 and Y.shape[-1] == 2, "Last dimension must be 2" + if torch.__version__ >= "1.8" or ( + torch.__version__ >= "1.7" and X.shape[:-3] == Y.shape[:-3] + ): + return torch.view_as_real( + torch.matmul(torch.view_as_complex(X), torch.view_as_complex(Y)) + ) + + return torch.stack( + [ + X[..., 0].matmul(Y[..., 0]) - X[..., 1].matmul(Y[..., 1]), + X[..., 0].matmul(Y[..., 1]) + X[..., 1].matmul(Y[..., 0]), + ], + dim=-1, + ) + + +def expi(x: Tensor) -> Tensor: + if torch.__version__ >= "1.8" or ( + torch.__version__ >= "1.7" and not x.requires_grad + ): + return torch.exp(1j * x) + else: + return x.cos().type(torch.cfloat) + 1j * x.sin().type(torch.cfloat) + + +def real_to_complex(x: Tensor) -> Tensor: + if torch.__version__ < "1.7": + return torch.stack((x, torch.zeros_like(x).to(x.device)), dim=-1) + else: + return torch.view_as_real(x.to(torch.complex64)) + + +def get_complex_magnitude(x: Tensor) -> Tensor: + assert x.size(-1) == 2, "[E] Input must be complex Tensor" + return torch.sqrt(x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]) + + +def complex_to_polar(x: Tensor) -> Tensor: + # real and imag to magnitude and angle + if isinstance(x, torch.Tensor): + mag = x.norm(p=2, dim=-1) + angle = torch.view_as_complex(x).angle() + x = torch.stack([mag, angle], dim=-1) + elif isinstance(x, np.ndarray): + x = x.astype(np.complex64) + mag = np.abs(x) + angle = np.angle(x) + x = np.stack([mag, angle], axis=-1) + else: + raise NotImplementedError + return x + + +def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor: + # magnitude and angle to real and imag + if angle is None: + return real_to_complex(angle) + if mag is None: + if isinstance(angle, torch.Tensor): + x = torch.stack([angle.cos(), angle.sin()], dim=-1) + elif isinstance(angle, np.ndarray): + x = np.stack([np.cos(angle), np.sin(angle)], axis=-1) + else: + raise NotImplementedError + else: + if isinstance(angle, torch.Tensor): + x = torch.stack([mag * angle.cos(), mag * angle.sin()], dim=-1) + elif isinstance(angle, np.ndarray): + x = np.stack([mag * np.cos(angle), mag * np.sin(angle)], axis=-1) + else: + raise NotImplementedError + return x + + +def get_complex_energy(x: Tensor) -> Tensor: + assert x.size(-1) == 2, "[E] Input must be complex Tensor" + return x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1] + + +def absclamp( + x: Tensor, min: Optional[float] = None, max: Optional[float] = None +) -> Tensor: + if isinstance(x, torch.Tensor): + mag = x.norm(p=2, dim=-1).clamp(min=min, max=max) + angle = torch.view_as_complex(x).angle() + x = polar_to_complex(mag, angle) + elif isinstance(x, np.ndarray): + x = x.astype(np.complex64) + mag = np.clip(np.abs(x), a_min=min, a_max=max) + angle = np.angle(x) + x = polar_to_complex(mag, angle) + else: + raise NotImplementedError + return x + + +def absclamp_( + x: Tensor, min: Optional[float] = None, max: Optional[float] = None +) -> Tensor: + if isinstance(x, torch.Tensor): + y = torch.view_as_complex(x) + mag = y.abs().clamp(min=min, max=max) + angle = y.angle() + x.data.copy_(polar_to_complex(mag, angle)) + elif isinstance(x, np.ndarray): + y = x.astype(np.complex64) + mag = np.clip(np.abs(y), a_min=min, a_max=max) + angle = np.angle(y) + x[:] = polar_to_complex(mag, angle) + else: + raise NotImplementedError + return x + + +def im2col_2d( + W: Optional[Tensor] = None, + X: Optional[Tensor] = None, + stride: int = 1, + padding: int = 0, + w_size: Optional[_size] = None, +) -> Tuple[Tensor, Tensor, int, int]: + if W is not None: + W_col = W.view(W.size(0), -1) + else: + W_col = None + + if X is not None: + n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size + n_x, d_x, h_x, w_x = X.size() + + h_out = (h_x - h_filter + 2 * padding) / stride + 1 + w_out = (w_x - w_filter + 2 * padding) / stride + 1 + + h_out, w_out = int(h_out), int(w_out) + X_col = torch.nn.functional.unfold( + X.view(1, -1, h_x, w_x), + h_filter, + dilation=1, + padding=padding, + stride=stride, + ).view(n_x, -1, h_out * w_out) + X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1) + else: + X_col, h_out, w_out = None, None, None + + return W_col, X_col, h_out, w_out + + +def check_identity_matrix(W: Tensor) -> bool: + if isinstance(W, np.ndarray): + W_numpy = W.copy().astype(np.float64) + elif isinstance(W, torch.Tensor): + W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) + else: + assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" + + return (W_numpy.shape[0] == W_numpy.shape[1]) and np.allclose( + W_numpy, np.eye(W_numpy.shape[0]) + ) + + +def check_unitary_matrix(W: Tensor) -> bool: + if isinstance(W, np.ndarray): + W_numpy = W.copy().astype(np.float64) + elif isinstance(W, torch.Tensor): + W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) + else: + assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" + M = np.dot(W_numpy, W_numpy.T) + # print(M) + return check_identity_matrix(M) + + +def check_equal_tensor(W1: Tensor, W2: Tensor) -> bool: + if isinstance(W1, np.ndarray): + W1_numpy = W1.copy().astype(np.float64) + elif isinstance(W1, torch.Tensor): + W1_numpy = W1.detach().cpu().numpy().copy().astype(np.float64) + else: + assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" + + if isinstance(W2, np.ndarray): + W2_numpy = W2.copy().astype(np.float64) + elif isinstance(W2, torch.Tensor): + W2_numpy = W2.detach().cpu().numpy().copy().astype(np.float64) + else: + assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" + return (W1_numpy.shape == W2_numpy.shape) and np.allclose(W1_numpy, W2_numpy) + + +def batch_diag(x: Tensor) -> Tensor: + # x[..., N, N] -> [..., N] + assert ( + len(x.shape) >= 2 + ), f"At least 2-D array/tensor is expected, but got shape {x.shape}" + if isinstance(x, np.ndarray): + size = list(x.shape) + x = x.reshape(size[:-2] + [size[-2] * size[-1]]) + x = x[..., :: size[-1] + 1] + elif isinstance(x, torch.Tensor): + size = list(x.size()) + x = x.flatten(-2, -1) + x = x[..., :: size[-1] + 1] + else: + raise NotImplementedError + return x + + +def batch_eye_cpu(N: int, batch_shape: List[int], dtype: np.dtype) -> np.ndarray: + x = np.zeros(list(batch_shape) + [N, N], dtype=dtype) + x.reshape(-1, N * N)[..., :: N + 1] = 1 + return x + + +def batch_eye( + N: int, + batch_shape: List[int], + dtype: torch.dtype, + device: Device = torch.device("cuda"), +) -> torch.Tensor: + x = torch.zeros(list(batch_shape) + [N, N], dtype=dtype, device=device) + x.view(-1, N * N)[..., :: N + 1] = 1 + return x + + +def merge_chunks(x: Tensor, complex: bool = False) -> Tensor: + """Merge a chunked/blocked tensors into a 2D matrix + + Args: + x (Tensor): Tensor of shape [h1, w1, h2, w2, ...., hk, wk] if complex=False; [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True + complex (bool, optional): True if the tensor x has a last dimension with size 2 for real/imag representation. Defaults to False. + + Returns: + Tensor: [h1*h2*...*hk, w1*w2*...*wk] or [h1*h2*...*hk, w1*w2*...*wk, 2] + """ + if isinstance(x, torch.Tensor): + permute = torch.permute + elif isinstance(x, np.ndarray): + permute = np.transpose + else: + raise NotImplementedError + + if not complex: + dim = len(x.shape) + x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2))) + x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1) + else: + dim = len(x.shape) - 1 + x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2) + [dim])) + x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1, 2) + + return x + + +def partition_chunks( + x: Tensor, out_shape: int | Tuple[int, ...], complex: bool = False +) -> Tensor: + """Partition a tensor into square chunks, similar to Rearrange in einops + + Args: + x (Tensor): 2D tensor of shape [h1*h2*...*hk, w1*w2*...*wk] or 3D tensor of shape [h1*h2*...*hk, w1*w2*...*wk, 2] if complex=True + out_shape (Tuple[int]): output blocked shape (h1, w1, h2, w2, ...); Do not include the last dimension even if complex=True + complex (bool, optional): whether x is complex tensor. Defaults to False. + + Returns: + [Tensor]: Tensor of shape [h1, w1, h2, w2, ...., hk, wk] or [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True + """ + if complex: + assert len(x.shape) == 3 + x_shape = (np.prod(out_shape[::2]), np.prod(out_shape[1::2])) + if isinstance(x, torch.Tensor): + permute = torch.permute + pad_fn = lambda x, padding: torch.nn.functional.pad(x[None, None], padding)[ + 0, 0 + ] + is_tensor = True + elif isinstance(x, np.ndarray): + permute = np.transpose + pad_fn = np.pad + is_tensor = False + else: + raise NotImplementedError + + if x_shape != x.shape[:2]: + ## if x cannot be partitioned into out_shape, we need to pad it + if is_tensor: + ## torch from the last dim + padding = (0, x_shape[1] - x.shape[1], 0, x_shape[0] - x.shape[0]) + if complex: + padding = (0, 0) + padding + else: + ## np from the first dim + padding = ((0, x_shape[0] - x.shape[0]), (0, x_shape[1] - x.shape[1])) + if complex: + padding = padding + (0, 0) + + x = pad_fn(x, padding) + + in_shape = list(out_shape[::2]) + list(out_shape[1::2]) + permute_shape = np.arange(len(out_shape)).reshape(2, -1).T.reshape(-1).tolist() + if complex: + in_shape.append(2) + permute_shape.append(len(permute_shape)) + x = x.reshape(in_shape) # [h1, h2, ..., hk, w1, w2, ..., wk] + + x = permute(x, permute_shape) # [h1, w1, h2, w2, ...., hk, wk] + + return x + + +def clip_by_std(x: Tensor, n_std_neg: float = 3.0, n_std_pos: float = 3.0) -> Tensor: + if isinstance(x, np.ndarray): + std = np.std(x) + mean = np.mean(x) + out = np.clip(x, a_min=mean - n_std_neg * std, a_max=mean + n_std_pos * std) + elif isinstance(x, torch.Tensor): + std = x.data.std() + mean = x.data.mean() + out = x.clamp(min=mean - n_std_neg * std, max=mean + n_std_pos * std) + else: + raise NotImplementedError + return out + + +def percentile(t: Tensor, q: float) -> Tensor: + """ + Return the ``q``-th percentile of the flattened input tensor's data. + + CAUTION: + * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. + * Values are not interpolated, which corresponds to + ``numpy.percentile(..., interpolation="nearest")``. + + :param t: Input tensor. + :param q: Percentile to compute, which must be between 0 and 100 inclusive. + :return: Resulting value (scalar). + """ + # Note that ``kthvalue()`` works one-based, i.e. the first sorted value + # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, + # so that ``round()`` returns an integer, even if q is a np.float32. + if isinstance(t, torch.Tensor): + k = 1 + round(0.01 * float(q) * (t.numel() - 1)) + result = t.view(-1).kthvalue(k).values.item() + elif isinstance(t, np.ndarray): + result = np.percentile(t, q=q) + else: + raise NotImplementedError + return result + + +def gen_boolean_mask_cpu(size: _size, true_prob: float) -> np.ndarray: + assert 0 <= true_prob <= 1, "[E] Wrong probability for True" + return np.random.choice(a=[False, True], size=size, p=[1 - true_prob, true_prob]) + + +def gen_boolean_mask( + size: _size, + true_prob: float, + random_state: Optional[int] = None, + device: Device = torch.device("cuda"), +) -> Tensor: + assert 0 <= true_prob <= 1, "[E] Wrong probability for True" + if true_prob > 1 - 1e-9: + return torch.ones(size, device=device, dtype=torch.bool) + elif true_prob < 1e-9: + return torch.zeros(size, device=device, dtype=torch.bool) + if random_state is not None: + with torch.random.fork_rng(): + torch.random.manual_seed(random_state) + return torch.empty(size, dtype=torch.bool, device=device).bernoulli_( + true_prob + ) + else: + return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(true_prob) + + +def fftshift_cpu( + x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None +) -> Union[Tensor, np.ndarray]: + if isinstance(x, np.ndarray): + if dim is None: + if batched: + dim = tuple(range(1, len(x.shape))) + else: + dim = tuple(range(0, len(x.shape))) + out = np.fft.fftshift(x, axes=dim) + elif isinstance(x, torch.Tensor): + device = x.device + x = x.cpu().detach().numpy() + if dim is None: + if batched: + dim = tuple(range(1, len(x.shape))) + else: + dim = tuple(range(0, len(x.shape))) + out = np.fft.fftshift(x, axes=dim) + out = torch.from_numpy(out).to(device) + return out + + +def ifftshift_cpu( + x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None +) -> Union[Tensor, np.ndarray]: + if isinstance(x, np.ndarray): + if dim is None: + if batched: + dim = tuple(range(1, len(x.shape))) + else: + dim = tuple(range(0, len(x.shape))) + out = np.fft.ifftshift(x, axes=dim) + elif isinstance(x, torch.Tensor): + device = x.device + x = x.cpu().detach().numpy() + if dim is None: + if batched: + dim = tuple(range(1, len(x.shape))) + else: + dim = tuple(range(0, len(x.shape))) + out = np.fft.ifftshift(x, axes=dim) + out = torch.from_numpy(out).to(device) + return out + + +def gen_gaussian_noise( + W: Union[Tensor, np.ndarray], + noise_mean: float = 0.0, + noise_std: float = 0.002, + trunc_range: Tuple = (), + random_state: Optional[int] = None, +) -> Union[Tensor, np.ndarray]: + if random_state is not None: + set_torch_deterministic(random_state) + if isinstance(W, np.ndarray): + if not trunc_range: + noises = np.random.normal(noise_mean, noise_std, W.shape) + else: + a = (trunc_range[0] - noise_mean) / noise_std + b = (trunc_range[1] - noise_mean) / noise_std + noises = truncnorm.rvs( + a, b, loc=noise_mean, scale=noise_std, size=W.shape, random_state=None + ) + elif isinstance(W, torch.Tensor): + if not trunc_range: + noises = torch.zeros_like(W).normal_(mean=noise_mean, std=noise_std) + else: + size = W.shape + tmp = W.new_empty(size + (4,)).normal_() + a = (trunc_range[0] - noise_mean) / noise_std + b = (trunc_range[1] - noise_mean) / noise_std + valid = (tmp < b) & (tmp > a) + ind = valid.max(-1, keepdim=True)[1] + noises = tmp.gather(-1, ind).squeeze(-1).mul_(noise_std).add_(noise_mean) + # noises = truncated_normal(W, mean=noise_mean, std=noise_std, a=trunc_range[0], b=trunc_range[1]) + else: + assert 0, logging.error( + f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}" + ) + return noises + + +def gen_gaussian_filter2d_cpu(size: int = 3, std: float = 0.286) -> np.ndarray: + assert ( + size % 2 == 1 + ), f"Gaussian filter can only be odd size, but size={size} is given." + ax = np.linspace(-(size - 1) / 2.0, (size - 1) / 2.0, size) + xx, yy = np.meshgrid(ax, ax) + kernel = np.exp(-0.5 / np.square(std) * (np.square(xx) + np.square(yy))) + kernel = kernel / np.sum(kernel) + kernel[size // 2, size // 2] = 1 + return kernel + + +def gen_gaussian_filter2d( + size: int = 3, + std: float = 0.286, + center_one: bool = True, + device: Device = torch.device("cuda"), +) -> Tensor: + assert ( + size % 2 == 1 + ), f"Gaussian filter can only be odd size, but size={size} is given." + if std > 1e-8: + ax = torch.linspace( + -(size - 1) / 2.0, + (size - 1) / 2.0, + size, + dtype=torch.float32, + device=device, + ) + xx, yy = torch.meshgrid(ax, ax) + kernel = torch.exp(-0.5 / (std**2) * (xx.square() + yy.square())) + kernel = kernel.div_(kernel.sum()) + if center_one: + kernel[size // 2, size // 2] = 1 + else: + kernel = torch.zeros(size, size, dtype=torch.float32, device=device) + kernel[size // 2, size // 2] = 1 + + return kernel + + +def add_gaussian_noise( + W: Union[Tensor, np.ndarray], + noise_mean: float = 0, + noise_std: float = 0.002, + trunc_range: Tuple = (), + random_state: Optional[int] = None, +) -> Union[Tensor, np.ndarray]: + noises = gen_gaussian_noise( + W, + noise_mean=noise_mean, + noise_std=noise_std, + trunc_range=trunc_range, + random_state=random_state, + ) + output = W + noises + return output + + +def add_gaussian_noise_( + W: Union[Tensor, np.ndarray], + noise_mean: float = 0, + noise_std: float = 0.002, + trunc_range: Tuple = (), + random_state: Optional[int] = None, +) -> Union[Tensor, np.ndarray]: + noises = gen_gaussian_noise( + W, + noise_mean=noise_mean, + noise_std=noise_std, + trunc_range=trunc_range, + random_state=random_state, + ) + if isinstance(W, np.ndarray): + W += noises + elif isinstance(W, torch.Tensor): + W.data += noises + else: + assert 0, logging.error( + f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}" + ) + return W + + +def add_gaussian_noise_cpu( + W: Union[Tensor, np.ndarray], + noise_mean: float = 0, + noise_std: float = 0.002, + trunc_range: Tuple = (), +) -> Union[Tensor, np.ndarray]: + if isinstance(W, np.ndarray): + W_numpy = W.copy().astype(np.float64) + elif isinstance(W, torch.Tensor): + W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) + else: + assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" + if not trunc_range: + noises = np.random.normal(noise_mean, noise_std, W_numpy.shape) + else: + a = (trunc_range[0] - noise_mean) / noise_std + b = (trunc_range[1] - noise_mean) / noise_std + noises = truncnorm.rvs( + a, b, loc=noise_mean, scale=noise_std, size=W_numpy.shape, random_state=None + ) + return W_numpy + noises + + +def circulant_multiply(c: Tensor, x: Tensor) -> Tensor: + """Multiply circulant matrix with first column c by x + Parameters: + c: (n, ) + x: (batch_size, n) or (n, ) + Return: + prod: (batch_size, n) or (n, ) + """ + return torch.irfft( + complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1],) + ) + + +def calc_diagonal_hessian(weight_dict, loss, model): + model.zero_grad() + hessian_dict = {} + for name, weight in weight_dict.items(): + first_gradient = grad(loss, weight, create_graph=True)[0] + second_gradient = grad(first_gradient.sum(), weight, create_graph=True)[0] + hessian_dict[name] = second_gradient.clone() + model.zero_grad() + return hessian_dict + + +def calc_jacobian( + weight_dict: Dict[str, Tensor], loss: Callable, model: nn.Module +) -> Dict[str, Tensor]: + model.zero_grad() + jacobian_dict = {} + for name, weight in weight_dict.items(): + first_gradient = grad(loss, weight, create_graph=True)[0] + jacobian_dict[name] = first_gradient.clone() + model.zero_grad() + return jacobian_dict + + +@lru_cache(maxsize=4) +def _polynomial_order_base(order: int, device: Device) -> Tensor: + return torch.arange(order - 1, -1, -1, device=device) + + +def polynomial(x: Tensor | np.ndarray, coeff: Tensor | np.ndarray) -> Tensor: + """calculate polynomial function of x given coefficient coeff + + Args: + x (Tensor): input tensor + coeff (Tensor): Tensor of shape [n], where n is the degree of polynomial. Orders: [n, n-1, ..., 2, 1, constant] + + Returns: + Tensor: output tensor coeff[0]*x^n + coeff[1]*x^{n-1} + ... + coeff[n-1]*x + coeff[n] + """ + # xs = [x] + # for i in range(2, coeff.size(0)): + # xs.append(xs[-1]*x) + # xs.reverse() + # x = torch.stack(xs, dim=-1) + + # Deprecated implementation + # x = torch.stack([x**i for i in range(coeff.size(0) - 1, 0, -1)], dim=-1) + # out = (x * coeff[:-1]).sum(dim=-1) + coeff[-1].data.item() + # return out + + ### x^n, x^{n-1}, ..., x^2, x, 1 + order = coeff.shape[0] # n+1 + if isinstance(x, Tensor): + ## torch from highest order to constant + x = x[..., None].expand([-1] * x.dim() + [order]) + order_base = _polynomial_order_base(order, x.device) + return x.pow(order_base).matmul(coeff) + elif isinstance(x, np.ndarray): + ## numpy polyval from constant to higher order + return np.polynomial.polynomial.polyval(x, coeff[::-1]) + else: + raise NotImplementedError + + +def gaussian(x: Tensor, coeff: Tensor) -> Tensor: + # coeff : [n, 3], includes a, b, c + ## a * exp(-((x-b)/c)^2) + ... + size = x.size() + x = x.view(-1).unsqueeze(0) + x = ( + (coeff[:, 0:1] * torch.exp(-((x - coeff[:, 1:2]) / coeff[:, 2:3]).square())) + .sum(dim=0) + .view(size) + ) + return x + + +def lowrank_decompose( + x: Tensor, + r: int, + u_ortho: bool = False, + out_u: Optional[Tensor] = None, + out_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + """low rank decomposition on x. x ~ uv. + + Args: + x (Tensor): tensor to decomplse + r (int): rank + u_ortho (bool, optional): whether u is orthogonal matrix. Defaults to False. + out_u (Optional[Tensor], optional): output buffer for u. Defaults to None. + out_v (Optional[Tensor], optional): output buffer for v. Defaults to None. + + Returns: + Tuple[Tensor, Tensor]: [description] + """ + ### x [..., m, n] + # r rank + u, s, v = x.data.svd(some=True) + v = v.transpose(-2, -1).contiguous() + u = u[..., :, :r] + s = s[..., :r] + v = v[..., :r, :] + if u_ortho == False: + u.mul_(s.unsqueeze(-2)) + else: + v.mul_(s.unsqueeze(-1)) + if out_u is not None: + out_u.data.copy_(u) + if out_v is not None: + out_v.data.copy_(v) + return u, v + + +def get_conv2d_flops( + input_shape: _size, + conv_filter: _size, + stride: _pair = (1, 1), + padding: _pair = (1, 1), +) -> float: + # input_shape = (4, 3,300,300) # Format:(batch, channels, rows,cols) + # conv_filter = (64,3,3,3) # Format: (num_filters, channels, rows, cols) + # stride = (1, 1) in (height, width) + # padding = (1, 1) in (height, width) + if type(stride) not in {list, tuple}: + stride = [stride, stride] + if type(padding) not in {list, tuple}: + padding = [padding, padding] + n = conv_filter[1] * conv_filter[2] * conv_filter[3] # vector_length + # general defination for number of flops (n: multiplications and n-1: additions) + flops_per_instance = n + 1 + + num_instances_per_filter = ( + (input_shape[2] - conv_filter[2] + 2 * padding[0]) / stride[0] + ) + 1 # for rows + # multiplying with cols + num_instances_per_filter *= ( + (input_shape[3] - conv_filter[3] + 2 * padding[1]) / stride[1] + ) + 1 + + flops_per_filter = num_instances_per_filter * flops_per_instance + # multiply with number of filters adn batch + total_flops_per_layer = flops_per_filter * conv_filter[0] * input_shape[0] + return total_flops_per_layer + + +class Interp1d(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y, xnew, out=None): + """ + Batched Linear 1D interpolation on the GPU for Pytorch. + This function returns interpolated values of a set of 1-D functions at + the desired query points `xnew`. Any point exceeds the border of [xmin, xmax] + will be filled with 0 and no grad. + This function is working similarly to Matlab™ or scipy functions with + the `linear` interpolation mode on, except that it parallelises over + any number of desired interpolation problems. + The code will run on GPU if all the tensors provided are on a cuda + device. + https://github.com/aliutkus/torchinterp1d + + Parameters + ---------- + x : (N, ) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. + y : (N,) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. The length of `y` along its + last dimension must be the same as that of `x` + xnew : (P,) or (D, P) Pytorch Tensor + A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if + _both_ `x` and `y` are 1-D. Otherwise, its length along the first + dimension must be the same as that of whichever `x` and `y` is 2-D. + out : Pytorch Tensor, same shape as `xnew` + Tensor for the output. If None: allocated automatically. + + """ + # making the vectors at least 2D + is_flat = {} + require_grad = {} + v = {} + device = [] + eps = torch.finfo(y.dtype).eps + for name, vec in {"x": x, "y": y, "xnew": xnew}.items(): + assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D." + if len(vec.shape) == 1: + v[name] = vec[None, :] + else: + v[name] = vec + is_flat[name] = v[name].shape[0] == 1 + require_grad[name] = vec.requires_grad + device = list(set(device + [str(vec.device)])) + assert len(device) == 1, "All parameters must be on the same device." + device = device[0] + + # Checking for the dimensions + assert v["x"].shape[1] == v["y"].shape[1] and ( + v["x"].shape[0] == v["y"].shape[0] + or v["x"].shape[0] == 1 + or v["y"].shape[0] == 1 + ), ( + "x and y must have the same number of columns, and either " + "the same number of row or one of them having only one " + "row." + ) + + reshaped_xnew = False + if ( + (v["x"].shape[0] == 1) + and (v["y"].shape[0] == 1) + and (v["xnew"].shape[0] > 1) + ): + # if there is only one row for both x and y, there is no need to + # loop over the rows of xnew because they will all have to face the + # same interpolation problem. We should just stack them together to + # call interp1d and put them back in place afterwards. + original_xnew_shape = v["xnew"].shape + v["xnew"] = v["xnew"].contiguous().view(1, -1) + reshaped_xnew = True + + # identify the dimensions of output and check if the one provided is ok + D = max(v["x"].shape[0], v["xnew"].shape[0]) + shape_ynew = (D, v["xnew"].shape[-1]) + if out is not None: + if out.numel() != shape_ynew[0] * shape_ynew[1]: + # The output provided is of incorrect shape. + # Going for a new one + out = None + else: + ynew = out.reshape(shape_ynew) + if out is None: + ynew = torch.zeros(*shape_ynew, device=device) + + # moving everything to the desired device in case it was not there + # already (not handling the case things do not fit entirely, user will + # do it if required.) + for name in v: + v[name] = v[name].to(device) + + # calling searchsorted on the x values. + ind = ynew.long() + + # expanding xnew to match the number of rows of x in case only one xnew is + # provided + if v["xnew"].shape[0] == 1: + v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1) + + # the squeeze is because torch.searchsorted does accept either a nd with + # matching shapes for x and xnew or a 1d vector for x. Here we would + # have (1,len) for x sometimes + torch.searchsorted( + v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind + ) + + # the `-1` is because searchsorted looks for the index where the values + # must be inserted to preserve order. And we want the index of the + # preceeding value. + ind -= 1 + # we clamp the index, because the number of intervals is x.shape-1, + # and the left neighbour should hence be at most number of intervals + # -1, i.e. number of columns in x -2 + ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1) + + # helper function to select stuff according to the found indices. + def sel(name): + if is_flat[name]: + return v[name].contiguous().view(-1)[ind] + return torch.gather(v[name], 1, ind) + + # activating gradient storing for everything now + enable_grad = False + saved_inputs = [] + for name in ["x", "y", "xnew"]: + if require_grad[name]: + enable_grad = True + saved_inputs += [v[name]] + else: + saved_inputs += [ + None, + ] + # assuming x are sorted in the dimension 1, computing the slopes for + # the segments + is_flat["slopes"] = is_flat["x"] + # now we have found the indices of the neighbors, we start building the + # output. Hence, we start also activating gradient tracking + with torch.enable_grad() if enable_grad else contextlib.suppress(): + v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / ( + eps + (v["x"][:, 1:] - v["x"][:, :-1]) + ) + + # now build the linear interpolation + ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x")) + + mask = (v["xnew"] > v["x"][:, -1:]) | ( + v["xnew"] < v["x"][:, :1] + ) # exceed left/right border + ynew = ynew.masked_fill(mask, 0) + + if reshaped_xnew: + ynew = ynew.view(original_xnew_shape) + + ctx.save_for_backward(ynew, *saved_inputs) + return ynew + + @staticmethod + def backward(ctx, grad_out): + inputs = ctx.saved_tensors[1:] + gradients = torch.autograd.grad( + ctx.saved_tensors[0], + [i for i in inputs if i is not None], + grad_out, + retain_graph=True, + ) + result = [ + None, + ] * 5 + pos = 0 + for index in range(len(inputs)): + if inputs[index] is not None: + result[index] = gradients[pos] + pos += 1 + return (*result,) + + +def interp1d(x: Tensor, y: Tensor, xnew: Tensor, out: Tensor | None = None) -> Tensor: + """numpy.interp for pytorch. Only 1D + + Args: + x (Tensor): input vector x coordinates + y (Tensor): input vector y coordinates + xnew (Tensor): new x coordinates to be interpolated + out (Tensor, optional): output tensor. Defaults to None. + + Returns: + Tensor: interpolated y coordinates + """ + return Interp1d.apply(x, y, xnew, out) diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py new file mode 100644 index 000000000..1d4f2990d --- /dev/null +++ b/src/chop/nn/optical/functional/general.py @@ -0,0 +1,411 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 01:55:29 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 01:55:30 +""" + +import os +import argparse +import json +import logging +import logging.handlers +import time +from collections import OrderedDict +from datetime import datetime +from pathlib import Path +from typing import Optional + +import numpy as np +import torch + + +__all__ = [ + "ensure_dir", + "read_json", + "write_json", + "profile", + "print_stat", + "Timer", + "TimerCtx", + "TorchTracemalloc", + "fullprint", + "setup_default_logging", + "Logger", + "logger", + "get_logger", + "ArgParser", + "disable_tf_warning", + "AverageMeter", +] + + +def ensure_dir(dirname, exist_ok: bool = True): + dirname = Path(dirname) + if not dirname.is_dir(): + dirname.mkdir(parents=True, exist_ok=exist_ok) + + +def read_json(fname): + with open(fname, "rt") as handle: + return json.load(handle, object_hook=OrderedDict) + + +def write_json(content, fname): + with open(fname, "wt") as handle: + json.dump(content, handle, indent=4, sort_keys=False) + + +def profile(func=None, timer=True): + from functools import wraps, partial + import time + + if func == None: + return partial(profile, timer=timer) + + @wraps(func) + def wrapper(*args, **kw): + if timer: + local_time = time.time() + res = func(*args, **kw) + end_time = time.time() + print("[I] <%s> runtime: %.3f ms" % (func.__name__, (end_time - local_time) * 1000)) + else: + res = func(*args, **kw) + return res + + return wrapper + + +def print_stat(x, message="", verbose=True): + if verbose: + if isinstance(x, torch.Tensor): + if torch.is_complex(x): + x = torch.view_as_real(x) + print( + message + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}" + ) + elif isinstance(x, np.ndarray): + print( + message + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}" + ) + + +class Timer(object): + def __init__(self): + self.cache = datetime.now() + + def check(self): + now = datetime.now() + duration = now - self.cache + self.cache = now + return duration.total_seconds() + + def reset(self): + self.cache = datetime.now() + + +class TimerCtx: + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.end = time.time() + self.interval = self.end - self.start + + +class TorchTracemalloc(object): + def __init__(self, verbose: bool = False) -> None: + super().__init__() + self.verbose = verbose + + def __enter__(self): + self.begin = self._b2mb(torch.cuda.memory_allocated()) + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + return self + + def _b2mb(self, x): + return x / 2 ** 20 + + def __exit__(self, *exc): + self.end = self._b2mb(torch.cuda.memory_allocated()) + self.peak = self._b2mb(torch.cuda.max_memory_allocated()) + self.used = self.end - self.begin + self.peaked = self.peak - self.begin + if self.verbose: + print(f"Delta used/peaked {self.used:.2f} MB / {self.peaked:.2f} MB") + print(f"Current used/peaked {self.end:.2f} MB / {self.peak:.2f} MB") + + +class fullprint: + "context manager for printing full numpy arrays" + + def __init__(self, **kwargs): + """linewidth=75; precision=8""" + kwargs.setdefault("threshold", np.inf) + self.opt = kwargs + + def __enter__(self): + self._opt = np.get_printoptions() + np.set_printoptions(**self.opt) + + def __exit__(self, type, value, traceback): + np.set_printoptions(**self._opt) + + +class CustomFormatter(logging.Formatter): + """Logging Formatter to add colors and count warning / errors""" + + grey = "\x1b[38;21m" + yellow = "\x1b[33;21m" + red = "\x1b[31;21m" + bold_red = "\x1b[31;1m" + green = "\x1b[32;21m" + reset = "\x1b[0m" + # format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" + format = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + + FORMATS = { + logging.DEBUG: grey + format + reset, + logging.INFO: grey + format + reset, + logging.WARNING: yellow + format + reset, + logging.ERROR: red + format + reset, + logging.CRITICAL: bold_red + format + reset, + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +def setup_default_logging(default_level=logging.INFO, default_file_level=logging.INFO, log_path=""): + console_handler = logging.StreamHandler() + console_handler.setFormatter(CustomFormatter()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) + if log_path: + file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_formatter = logging.Formatter( + "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + ) + file_handler.setFormatter(file_formatter) + file_handler.setLevel(default_file_level) + logging.root.addHandler(file_handler) + + +class Logger(object): + def __init__(self, console=True, logfile=None, console_level=logging.INFO, logfile_level=logging.INFO): + super().__init__() + self.logfile = logfile + self.console_level = console_level + self.logifle_level = logfile_level + assert ( + console == True or logfile is not None + ), "At least enable one from console or logfile for Logger" + # 第一步,创建一个logger + self.logger = logging.getLogger("my_logger") + self.logger.setLevel(logging.INFO) # Log等级总开关 + self.logger.propagate = False + + # formatter = logging.Formatter( + # "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") + formatter = CustomFormatter() + + # 第三步,再创建一个handler,用于输出到控制台 + if console: + ch = logging.StreamHandler() + ch.setLevel(self.console_level) # 输出到console的log等级的开关 + ch.setFormatter(formatter) + self.logger.addHandler(ch) + if self.logfile is not None: + fh = logging.FileHandler(self.logfile, mode="w") + fh.setLevel(self.logifle_level) # 输出到file的log等级的开关 + fh.setFormatter(formatter) + self.logger.addHandler(fh) + + def debug(self, message): + self.logger.debug(message) + + def info(self, message): + self.logger.info(message) + + def warning(self, message): + self.logger.warning(message) + + def error(self, message): + self.logger.error(message) + + def critical(self, message): + self.logger.critical(message) + + +def get_logger(name="default", default_level=logging.INFO, default_file_level=logging.INFO, log_path=""): + setup_default_logging( + default_level=default_level, default_file_level=default_file_level, log_path=log_path + ) + return logging.getLogger(name) + + +logger = get_logger() + + +class ArgParser(object): + def __init__(self, load_json=None, save_json=None): + super().__init__() + self.load_json = load_json + self.save_json = save_json + self.args = None + self.parser = argparse.ArgumentParser("Argument Parser") + + def add_arg(self, *args, **keywords): + self.parser.add_argument(*args, **keywords) + + def parse_args(self): + if self.load_json is not None: + assert os.path.exists(self.load_json), logging.error( + f"Configuration JSON {self.load_json} not found" + ) + json = read_json(self.load_json) + t_args = argparse.Namespace() + t_args.__dict__.update(json) + self.args = self.parser.parse_args(args=[], namespace=t_args) + else: + self.args = self.parser.parse_args() + return self.args + + def print_args(self): + # Print arguments to std out + # and save argument values to yaml file + print("Arguments:") + for p in vars(self.args).items(): + print(f"\t{p[0]:30}{str(p[1]):20}") + print("\n") + + def dump_args(self, json_file=None): + if json_file is None: + if self.save_json is None: + logging.error("Skip dump configuration JSON. Please specify json_file") + return False + else: + ensure_dir(os.path.dirname(self.save_json)) + logging.warning(f"Dump to the initialized JSON file {self.save_json}") + write_json(vars(self.args), self.save_json) + else: + ensure_dir(os.path.dirname(json_file)) + logging.info(f"Dump to JSON file {json_file}") + write_json(vars(self.args), json_file) + # with open(self.file, 'w') as f: + # yaml.dump(vars(self.args), f, default_flow_style=False) + # print(f"[I] Arguments dumped to {file}") + + +def disable_tf_warning(): + import os + + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + import warnings + + warnings.filterwarnings("ignore", category=FutureWarning) + warnings.filterwarnings("ignore", category=DeprecationWarning) + + import tensorflow as tf + + if hasattr(tf, "contrib") and type(tf.contrib) != type(tf): + tf.contrib._warning = None + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + # tf.logging.set_verbosity(tf.logging.ERROR) + + import logging + + logging.getLogger("tensorflow").setLevel(logging.ERROR) + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif torch is not None and torch.is_tensor(number) and number.numel() == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +def type_as(a, b): + if torch.is_tensor(a) and torch.is_tensor(b): + return a.to(b) + else: + return a + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, name: str, fmt: str = ":f", round: Optional[int] = None) -> None: + self.name = name + self.fmt = fmt + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + self.avg = 0 + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + self.avg = self.sum / self.count if self.count > 0 else self.val + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + def __str__(self) -> str: + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/functional/initializer.py new file mode 100644 index 000000000..53002c398 --- /dev/null +++ b/src/chop/nn/optical/functional/initializer.py @@ -0,0 +1,152 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 01:57:16 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 01:57:18 +""" +import numpy as np +import torch + +__all__ = [ + "quant_kaiming_uniform", + "quant_kaiming_uniform_", + "truncated_normal", + "truncated_normal_", + "morr_uniform_", + "morr_uniform", +] + + +def quant_kaiming_uniform(w, nbit, beta=1.5): + """https://arxiv.org/pdf/1802.04680.pdf""" + if w.dim() > 2: + receptive_field = w[0, 0, ...].numel() + else: + receptive_field = 1 + fan_in = w.size(1) * receptive_field + sigma = 2 ** (1 - nbit) + L_min = beta * sigma + L = max(np.sqrt(6 / fan_in), L_min) + return w.clone().uniform_(-L, L) + + +def quant_kaiming_uniform_(w, nbit, beta=1.5): + """https://arxiv.org/pdf/1802.04680.pdf""" + if w.dim() > 2: + receptive_field = w[0, 0, ...].numel() + else: + receptive_field = 1 + fan_in = w.size(1) * receptive_field + sigma = 2 ** (1 - nbit) + L = np.sqrt(6 / fan_in) + L_min = beta * sigma + scale = 2 ** round(np.log2(L_min / L)) + scale = max(scale, 1.0) + L = max(L, L_min) + + return torch.nn.init.uniform_(w, -L, L), scale + + +def truncated_normal(tensor, mean=0, std=1, a=-2, b=2): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + a = (a - mean) / std + b = (b - mean) / std + valid = (tmp < b) & (tmp > a) + ind = valid.max(-1, keepdim=True)[1] + output = tmp.gather(-1, ind).squeeze(-1).mul_(std).add_(mean) + return output + + +def truncated_normal_(tensor, mean=0, std=1, a=-2, b=2): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + a = (a - mean) / std + b = (b - mean) / std + valid = (tmp < b) & (tmp > a) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor + + +def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): + """ + description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]. We only consider how n_op influence one MORR's output. How to balance vector length should be considered in learnable balancing factor\\ + @tensor {torch.Tensor} weight tensor/parameter\\ + @MORRConfig {Config} MORR configuration defined in the onnlib/model/layer/device/mrr\\ + @n_op {int scalar} Number of operands on an MORR\\ + @biased {bool} biased=True, weight in [0, L]; otherwise in [-L/2, L/2].\\ + @gain {float} Gain due to activation. ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2\\ + return {} + """ + morr_fwhm = ( + -4 + * np.pi ** 2 + * MORRConfig.radius + * MORRConfig.effective_index + * ( + 1 / MORRConfig.resonance_wavelength + - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) + ) + ) + ### first we need to calculate the information gain of an MORR, estimated by linear estimation at 0 and FWHM + # t1 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([0]).float(), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # t2 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([morr_fwhm]).float(), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # g = (t2 - t1) / morr_fwhm + + ### calculate the variance of the weight + # var_phi = 1 ## assume the input is normalized to have variance 1 + # var_w = 1/(3/2*g**4*n_op*var_phi) + + ### calculate range of uniform distribution U(-L,L) + # L = ((3 * var_w)**0.5).item() + # return torch.nn.init.uniform_(tensor, -L, L) + + ## approximation by assuming 4*std(phi)= 3*FWHM, E[x]=0, D[x]=1, W ~ U[0, L] + L = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain + if biased: + return torch.nn.init.uniform_(tensor, 0, L) + else: + return torch.nn.init.uniform_(tensor, -L / 2, L / 2) + + +def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1): + """ + description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]\\ + @tensor {torch.Tensor} weight tensor/parameter\\ + @MORRConfig {Config} MORR configuration defined in the onnlib/model/layer/device/mrr\\ + @n_op {int scalar} Number of operands on an MORR\\ + @biased {bool} biased=True, weight in [0, L]; otherwise in [-L/2, L/2].\\ + @gain {float} Gain due to activation. ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2\\ + return {} + """ + morr_fwhm = ( + -4 + * np.pi ** 2 + * MORRConfig.radius + * MORRConfig.effective_index + * ( + 1 / MORRConfig.resonance_wavelength + - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) + ) + ) + ### first we need to calculate the information gain of an MORR, estimated by linear estimation at 0 and FWHM + # t1 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([0]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # t2 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([morr_fwhm]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # g = (t2 - t1) / morr_fwhm + + # var_phi = 1 ## assume the input is normalized to have variance 1 + # var_w = 1/(3/2*g**4*n_op*var_phi) + + # ### calculate range of uniform distribution U(-L,L) + # L = (3 * var_w)**0.5 + # return tensor.clone().uniform_(-L, L) + + ## approximation by assuming 4*std(phi)= 3*FWHM, E[x]=0, D[x]=1, W ~ U[0, L] + L = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain + if biased: + return tensor.clone().uniform_(0, L) + else: + return tensor.clone().uniform_(-L / 2, L / 2) diff --git a/src/chop/nn/optical/functional/mrr.py b/src/chop/nn/optical/functional/mrr.py new file mode 100644 index 000000000..53fd56fd9 --- /dev/null +++ b/src/chop/nn/optical/functional/mrr.py @@ -0,0 +1,112 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-07-18 00:03:04 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-07-18 00:03:05 +""" + +import numpy as np + + +__all__ = [ + "MORRConfig_20um_MQ", + "MRRConfig_5um_HQ", + "MRRConfig_5um_MQ", + "MRRConfig_5um_LQ", + "MORRConfig_10um_MQ", +] + + +class MORRConfig_20um_MQ: + attenuation_factor = 0.8578 + coupling_factor = 0.8985 + radius = 20000 # nm + group_index = 2.35316094 + effective_index = 2.35 + resonance_wavelength = 1554.252 # nm + bandwidth = 0.67908 # nm + quality_factor = 2288.7644639 + + +class MRRConfig_5um_HQ: + attenuation_factor = 0.987 + coupling_factor = 0.99 + radius = 5000 # nm + group_index = 2.35316094 + effective_index = 2.4 + resonance_wavelength = 1538.739 # nm + bandwidth = 0.2278 # nm + quality_factor = 6754.780509 + + +class MRRConfig_5um_MQ: + attenuation_factor = 0.925 + coupling_factor = 0.93 + radius = 5000 # nm + group_index = 2.35316094 + effective_index = 2.4 + resonance_wavelength = 1538.739 # nm + bandwidth = 1.5068 # nm + quality_factor = 1021.1965755 + + +class MRRConfig_5um_LQ: + attenuation_factor = 0.845 + coupling_factor = 0.85 + radius = 5000 # nm + group_index = 2.35316094 + effective_index = 2.4 + resonance_wavelength = 1538.739 # nm + bandwidth = 2.522 # nm + quality_factor = 610.1265 + + +class MORRConfig_10um_MQ: + attenuation_factor = 0.8578 + coupling_factor = 0.8985 + radius = 10000 # nm + group_index = 2.35316094 + effective_index = 2.4 + resonance_wavelength = 1538.739 # nm + bandwidth = 1.6702 # nm + quality_factor = 1213.047 + + +def plot_curve(config): + import matplotlib.pyplot as plt + + lambda0 = config.resonance_wavelength + lambda_vec = np.linspace(1546, lambda0, 9400) + aa = config.attenuation_factor # attenuation a + + t = config.coupling_factor # self-coupling + # r = np.sqrt(1 - t**2) # cross coupling coef + + R = config.radius # radius + neff = config.effective_index # refractive index + phi = -4 * np.pi * np.pi * R * neff / lambda_vec + + phase_shift = np.linspace(phi[0], phi[-1], len(phi)) + phase_shift = phase_shift - np.min(phase_shift) + print(phase_shift) + tr = (t - aa * np.exp(1j * phi)) / (1 - t * aa * np.exp(1j * phi)) + energy = abs(tr) ** 2 + print(energy) + plt.figure() + plt.plot(lambda_vec, energy) + plt.savefig("mrr_tr_wl.png") + plt.figure() + plt.plot(phase_shift, energy) + plt.savefig("mrr_tr_ps.png") + + for i, e in enumerate(energy[:-1]): + if energy[i] >= 0.5 and energy[i + 1] <= 0.5: + print(i, i + 1) + print(energy[i], energy[i + 1]) + print(lambda_vec[i], lambda_vec[i + 1]) + exit(1) + + +if __name__ == "__main__": + plot_curve(MRRConfig_5um_MQ) diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/functional/mrr_op.py new file mode 100644 index 000000000..42952971e --- /dev/null +++ b/src/chop/nn/optical/functional/mrr_op.py @@ -0,0 +1,404 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-07-18 00:01:34 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-07-18 00:01:36 +""" + +from .compute import ( + complex_mult, + polar_to_complex, + polynomial, +) +import logging + +import numpy as np +import torch + +torch._C._jit_set_profiling_executor(False) + + +__all__ = [ + "mrr_voltage_to_delta_lambda", + "mrr_tr_to_roundtrip_phase", + "mrr_roundtrip_phase_to_tr", + "mrr_roundtrip_phase_to_tr_fused", + "mrr_roundtrip_phase_to_tr_grad_fused", + "mrr_roundtrip_phase_to_tr_func", + "mrr_roundtrip_phase_to_out_phase", + "mrr_tr_to_out_phase", + "mrr_roundtrip_phase_to_tr_phase", + "mrr_roundtrip_phase_to_tr_phase_fused", + "mrr_modulator", + "mrr_filter", + "morr_filter", + "mrr_fwhm_to_ng", + "mrr_ng_to_fsr", + "mrr_finesse", +] + + +def mrr_voltage_to_delta_lambda(v, alpha, k, gamma, n_g, lambda_0): + """ + description: micro-ring resonator (MRR) wavelength modulation, \delta\lambda=\delta\n_eff\times\lambda/n_g, \deltan_eff=\gamma k \delta T=\gamma k \alpha v^2\\ + v {torch.Tensor ro np.ndarray} voltage \\ + alpha {scalar} voltage square to temperature change coefficient \\ + k {scalar} parameter \\ + gamma {scalar} power to phase shift coefficient \\ + n_g {scalar} group index, typically from 4 to 4.5\\ + lambda_0 {torch.Tensor or np.ndarray} central wavelength\\ + return delta_lambda {torch.Tensor or np.ndarray} resonance wavelength drift + """ + delta_neff = gamma * k * alpha * v * v + delta_lambda = delta_neff * lambda_0 / n_g + return delta_lambda + + +def mrr_tr_to_roundtrip_phase(t, a, r): + """ + description: field transmission to round trip phase shift + t {torch.Tensor or np.ndarray} field transmission from [0,1] \\ + a {scalar} attenuation coefficient\\ + r {scalar} coupling coefficient\\ + return phi {torch.Tensor or np.ndarray} roune trip phase shift (abs(phase lag))[0, pi], center is 0. phase lag is negative, the sign is moved to the equation + """ + # the curve has multiple valleies, thus given a t, there is infinite number of rt_phi, we only want [-pi, 0], thus the abs(phase lag) is in [0, pi], acos returns [0, pi], which matches our assumption + assert 0 <= a <= 1, logging.error(f"Expect a from [0,1] but got {a}") + assert 0 <= r <= 1, logging.error(f"Expect r from [0,1] but got {r}") + # given a and r, the curve is fixed, the max and min may not be 1 and 0 + cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp(0, 1) + + if isinstance(cos_phi, torch.Tensor): + return cos_phi.acos(), cos_phi + elif isinstance(cos_phi, np.ndarray): + return np.arccos(cos_phi), cos_phi + else: + raise NotImplementedError + + +def mrr_roundtrip_phase_to_tr( + rt_phi, a: float = 0.8, r: float = 0.9, poly_coeff=None, intensity: bool = False +): + """ + description: round trip phase shift to field transmission + rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ + a {scalar} attenuation coefficient\\ + r {scalar} self-coupling coefficient\\ + poly_coeff {Callable} polynomial coefficients of intensity tranmission-roundtrip phase curve. Default set to None. None for slow computation\\ + intensity {bool scalar} whether output intensity tranmission or field transmission + return t {torch.Tensor or np.ndarray} mrr through port field/intensity transmission + """ + if poly_coeff is not None: + # fast mode, use polynomial to predict the intensity transmission curve + # if using polynomial, we want fast intensity transmission estimation, instead of field + # if using coherent light, we will use complex output, we won't use polynomial fit + t = polynomial(rt_phi.clamp(0, np.pi), poly_coeff).clamp(1e-8, 1) + if not intensity: + # avoid NAN + t = (t + 1e-12).sqrt() + else: + # use slow but accurate mode from theoretical equation + # create e^(-j phi) first + # with torch.autograd.profiler.profile(use_cuda=True) as prof: + # ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi)) ## this sign is from the negativity of phase lag + # ### Jiaqi: Since PyTorch 1.7 rsub is not supported for autograd of complex, so have to use negate and add + # a_ephi = -a * ephi + # t = torch.view_as_real((r + a_ephi)/(1 + r * a_ephi)) + + # if(intensity): + # t = get_complex_energy(t) + # else: + # t = get_complex_magnitude(t) + # print(prof.key_averages(group_by_stack_n=5).table(sort_by='cuda_time', row_limit=5)) + ra_cosphi_by_n2 = -2 * r * a * rt_phi.cos() + t = (a * a + r * r + ra_cosphi_by_n2) / (1 + r * r * a * a + ra_cosphi_by_n2) + if not intensity: + # as long as a is not equal to r, t cannot be 0. + t = t.sqrt() + return t + + +@torch.jit.script +def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False): + """ + description: round trip phase shift to field transmission + rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ + a {scalar} attenuation coefficient\\ + r {scalar} self-coupling coefficient\\ + intensity {bool scalar} whether output intensity tranmission or field transmission\\ + return t {torch.Tensor or np.ndarray} mrr through port field/intensity transmission + """ + + # use slow but accurate mode from theoretical equation + # create e^(-j phi) first + + # angle = -rt_phi + # ephi = torch.view_as_complex(torch.stack([angle.cos(), angle.sin()], dim=-1)) ## this sign is from the negativity of phase lag + # a_ephi = -a * ephi + # t = torch.view_as_real((r + a_ephi).div(1 + r * a_ephi)) + # if(intensity): + # t = get_complex_energy(t) + # else: + # t = get_complex_magnitude(t) + ra_cosphi_by_n2 = -2 * r * a * rt_phi.cos() + t = (a * a + r * r + ra_cosphi_by_n2) / (1 + r * r * a * a + ra_cosphi_by_n2) + if not intensity: + # as long as a is not equal to r, t cannot be 0. + t = t.sqrt() + + return t + + +@torch.jit.script +def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False): + """ + description: round trip phase shift to the gradient of field transmission + rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ + a {scalar} attenuation coefficient\\ + r {scalar} self-coupling coefficient\\ + intensity {bool scalar} whether output intensity tranmission or field transmission\\ + return g {torch.Tensor or np.ndarray} the gradient of mrr through port field/intensity transmission + """ + if not intensity: + g = (a * r * (a ** 2 - 1) * (r ** 2 - 1) * rt_phi.sin()) / ( + (a ** 2 + r ** 2 - 2 * a * r * rt_phi.cos()) ** (1 / 2) + * (a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5 + ) + else: + g = ((a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r * rt_phi.sin()) / ( + a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos() + ) ** 2 + return g + + +def mrr_roundtrip_phase_to_tr_func(a: float = 0.8, r: float = 0.9, intensity: bool = False): + c1 = -2 * a * r + c2 = a * a + r * r + c3 = 1 + r * r * a * a - a * a - r * r + c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r + + class MRRRoundTripPhaseToTrFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + # ra_cosphi_by_n2 = input.cos().mul_(c1) + # numerator = ra_cosphi_by_n2.add_(c2) + # denominator = numerator.add(c3) + # t = numerator / denominator + t = input.cos().mul_(c1).add_(c2 + c3).reciprocal_().mul_(-c3).add_(1) + if not intensity: + # as long as a is not equal to r, t cannot be 0. + t.sqrt_() + return t + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + denominator = input.cos().mul_(c1).add_(c2 + c3) + + if intensity: + denominator.square_() + numerator = input.sin().mul_(c4) + else: + numerator = input.sin().mul_(c4 / 2) + denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + grad_input = numerator.div_(denominator).mul_(grad_output) + return grad_input + + return MRRRoundTripPhaseToTrFunction.apply + + +def mrr_roundtrip_phase_to_out_phase(rt_phi, a, r): + """ + description: from round trip phase to output phase response \\ + rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ + a {scalar} attenuation coefficient\\ + r {scalar} coupling coefficient\\ + return phase {torch.Tensor or np.ndarray} output phase response + """ + if isinstance(rt_phi, torch.Tensor): + arctan = torch.atan2 + sin = torch.sin + cos = torch.cos + elif isinstance(rt_phi, np.ndarray): + arctan = np.arctan2 + sin = np.sin + cos = np.cos + else: + raise NotImplementedError + sin_rt_phi = sin(rt_phi) + cos_rt_phi = cos(rt_phi) + # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi)) + phi = ( + np.pi + - rt_phi + - arctan(r * sin_rt_phi, a - r * cos_rt_phi) + - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi) + ) + return phi + + +def mrr_tr_to_out_phase(t, a, r, onesided=True): + """ + description: field transmission to round trip phase shift + t {torch.Tensor or np.ndarray} field transmission from [0,1] \\ + a {scalar} attenuation coefficient\\ + r {scalar} coupling coefficient\\ + onesided {bool scalar} True if only use half of the curve, output phase range [0, pi] + return phi {torch.Tensor or np.ndarray} roune trip phase shift + """ + rt_phi, cos_rt_phi = mrr_tr_to_roundtrip_phase(t, a, r) + if isinstance(t, torch.Tensor): + arctan = torch.atan2 + sin = torch.sin + elif isinstance(t, np.ndarray): + arctan = np.arctan2 + sin = np.sin + else: + raise NotImplementedError + sin_rt_phi = sin(rt_phi) + # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi)) + phi = ( + np.pi + - rt_phi + - arctan(r * sin_rt_phi, a - r * cos_rt_phi) + - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi) + ) + if onesided: + pass + return phi + + +def mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r): + """ + description: from round trip phase to output transmission with phase response \\ + rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ + a {scalar} attenuation coefficient\\ + r {scalar} coupling coefficient\\ + return output {torch.Tensor or np.ndarray} transmission with phase response + """ + # e^(-j phi) + ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi)) + a_ephi = -a * ephi + output = torch.view_as_real((r + a_ephi) / (1 + r * a_ephi)) + return output + + +@torch.jit.script +def mrr_roundtrip_phase_to_tr_phase_fused(rt_phi, a: float, r: float): + """ + description: from round trip phase to output transmission with phase response \\ + rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ + a {scalar} attenuation coefficient\\ + r {scalar} coupling coefficient\\ + return output {torch.Tensor or np.ndarray} transmission with phase response + """ + # e^(-j phi) + rt_phi = -rt_phi + rt_phi = torch.complex(rt_phi.cos(), rt_phi.sin()) + rt_phi = -a * rt_phi + output = torch.view_as_real((r + rt_phi) / (1 + r * rt_phi)) + return output + + +def mrr_modulator(t, a=0.9, r=0.8): + """ + @description: all-pass MRR as a modulator. Map from the field intensity of through port transmission to coherent light with phase reponse\\ + @t {torch.Tensor or np.ndarray} field intensity modulation factor\\ + @a {float} attenuation factor from [0,1]. Default: 0.9\\ + @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ + @return: complexed light signal + """ + phase = mrr_tr_to_out_phase(t, a, r) + cos_phase, sin_phase = torch.cos(phase), torch.sin(phase) + output_real = t * cos_phase + output_imag = t * sin_phase + output = torch.stack([output_real, output_imag], dim=-1) + return output + + +def mrr_filter(x, t, a=0.9, r=0.8): + """ + @description: all-pass MRR as a filter. Map from the input complex light signal to output signal with through port transmission\\ + @x {torch.Tensor or np.ndarray} complexed input light signal\\ + @t {torch.Tensor or np.ndarray} field intensity modulation factor\\ + @a {float} attenuation factor from [0,1]. Default: 0.9\\ + @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ + @return: complexed light signal + """ + phase = mrr_tr_to_out_phase(t, a, r) + cos_phase, sin_phase = torch.cos(phase), torch.sin(phase) + phase_shift = torch.complex(cos_phase, sin_phase) + out = t * complex_mult(x, phase_shift) + return out + + +def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False): + """ + description: from round trip phase shift to output signal \\ + rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\ + tr_poly_coeff {Callable} polynomial coefficients of tranmission-roundtrip phase curve. Default set to None. None for slow computation\\ + a {float} attenuation factor from [0,1]. Default: 0.9\\ + r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ + x {torch.Tensor or np.ndarray, Optional} input complex light signal {None, real tensor or complex tensor}. Default set to None\\ + coherent {bool scalar, Optional} coherent output or not. Default set to False\\ + intensity {bool scalar, Optional} whether use intensity or field transmission. Default set to False\\ + return output {torch.Tensor or np.ndarray} real tensor if incoherent, complex tensor if coherent + """ + if not coherent: + if x is None: + # unit laser input with incoherent light, 1e^j0 + t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity) + return t + else: + # incoherent light with non-unit input, input must be real number + t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity) + return x * t + else: + if x is None: + # coherent light with unit laser, 1e^j0, treat morr as a mrr modulator + phase = polar_to_complex(mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r)) + return phase + else: + # coherent light with complex input + return complex_mult(mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r), x) + + +def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm): + """ + description: from full-width half maximum (FWHM) and resonance wavelength to group index n_g (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(7))\\ + a {float} Attention coefficient\\ + r {float} Self-coupling coefficient\\ + radius {float} Radius of the MRR (unit: nm)\\ + lambda0 {float} Resonance wavelength (unit: nm)\\ + fwhm {float} bandwidth or full width half maximum (unit: nm)\\ + return n_g {float} Group index of the MRR + """ + n_g = (1 - r * a) * lambda0 ** 2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm) + return n_g + + +def mrr_ng_to_fsr(lambda0, n_g, radius): + """ + description: Calculate the free-spectral range (FSR) based on the central resonance wavelength, group index and MRR radius. + (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(9))\\ + lambda0 {float} Resonance wavelength (unit: nm)\\ + n_g {float} Group index\\ + radius {float} Radius of the MRR (unit: nm)\\ + return fsr {float} Free-spectral range + """ + fsr = lambda0 ** 2 / (n_g * 2 * np.pi * radius) + return fsr + + +def mrr_finesse(a, r): + """ + description: Calculate the finesse of the MRR, i.e., finesse=FSR/FWHM=pi*sqrt(ra)/(1-ra) (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(21))\\ + a {float} Attention coefficient\\ + r {float} Self-coupling coefficient\\ + return finesse {float} Finesse of the MRR + """ + ra = r * a + finesse = np.pi * ra ** 0.5 / (1 - ra) + return finesse diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/functional/quantize.py new file mode 100644 index 000000000..09b7ef28d --- /dev/null +++ b/src/chop/nn/optical/functional/quantize.py @@ -0,0 +1,575 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 03:15:00 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 03:15:00 +""" + +import numpy as np +import torch +import logging + + +__all__ = [ + "uniform_quantize_cpu", + "pact_quantize", + "PACT_Act", + "uniform_quantize", + "uniform_quantize_new", + "ewgs_quantize", + "input_quantize_fn", + "weight_quantize_fn", +] + + +class uniform_quantize_cpu(object): + def __init__(self, bits): + super(uniform_quantize_cpu).__init__() + self.bits = bits + + def __call__(self, input): + if self.bits == 32: + out = input + elif self.bits == 1: + out = np.sign(input) + else: + n = float(2**self.bits - 1) + out = np.round(input * n) / n + return out + + +def uniform_quantize(k, gradient_clip=False): + class qfn(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if k == 32: + out = input + elif k == 1: + out = torch.sign(input) + else: + n = float(2**k - 1) + out = torch.round(input * n) / n + return out + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + if gradient_clip: + grad_input.clamp_(-1, 1) + return grad_input + + return qfn.apply + + +############ add observer and new quant based on range and zeropoint for activation +def uniform_quantize_new(k, gradient_clip=False): + """ + Support uniform quantization with auto-adjusted input data range + args: + k: bitwidth + scale, zeropoint: obtained from observer + """ + + class qfn(torch.autograd.Function): + @staticmethod + def forward(ctx, input, scale, zero_point): + if k == 32: + out = input + elif k == 1: + out = torch.sign(input) + else: + n = float(2**k - 1) + # out = torch.round(input * n) / n + # out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale + out = input.div(scale).add_(zero_point).round_().clamp_(0, n).sub_(zero_point).mul_(scale) + return out + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + if gradient_clip: + grad_input.clamp_(-1, 1) + return grad_input, None, None + + return qfn.apply + + +def ewgs_quantize(num_levels, gradient_clip=False, scaling_factor: float = 1e-3): + class EWGS_quantizer(torch.autograd.Function): + """ + Network Quantization with Element-wise Gradient Scaling, CVPR 2021 + https://github.com/cvlab-yonsei/EWGS/blob/main/CIFAR10/custom_modules.py + x_in: continuous inputs within the range of [0,1] + num_levels: number of discrete levels + scaling_factor: backward scaling factor, typically fixed to 1e-3 + x_out: discretized version of x_in within the range of [0,1] + """ + + @staticmethod + def forward(ctx, input): + out = input.mul(num_levels - 1).round_().mul_(1/(num_levels - 1)) + + ctx._scaling_factor = scaling_factor + ctx.save_for_backward(input - out) + return out + + @staticmethod + def backward(ctx, grad_output): + diff = ctx.saved_tensors[0] + delta = ctx._scaling_factor + scale = diff.mul_(grad_output.sign()).mul_(delta).add_(1) + grad_input = grad_output * scale + if gradient_clip: + grad_input.clamp_(-1, 1) + return grad_input + + return EWGS_quantizer.apply + + +class input_quantize_fn(torch.nn.Module): + def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0): + """Input quantizer with Quant_Noise supported + Args: + in_bit (int): Input quantization bitwidth. + device (Device, optional): torch Device. Defaults to torch.device("cuda:0"). + quant_ratio (float, optional): Quantization ratio. Defaults to 1.0. + """ + super(input_quantize_fn, self).__init__() + assert 1 <= in_bit <= 32 + self.in_bit = in_bit + self.alg = alg + assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}" + self.quant_ratio = quant_ratio + assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + self.device = device + + # define quant style + # dorefa: clamp to 0-1 + # normal: obtain scale and zero_point via observer + + if self.alg == "dorefa": + self.uniform_q = uniform_quantize(k=in_bit) + elif self.alg == "normal": + self.uniform_q = uniform_quantize_new(k=in_bit) + self.scale = None + self.zero_point = None + ### select scale and zero-point using EMA: exponential moving averages + # AT: MovingAverageMinMaxObserver only support self-defined quant bitwidths for pytorch1.7 + # obs = torch.quantization.observer.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8, + # qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=0, quant_max=2**self.in_bit-1) + # Thus use our version + ### torch version must be higher than 1.7 + if 1 <= self.in_bit <= 8: # observer does not support higher than 8-bit + self.obs = torch.quantization.observer.MovingAverageMinMaxObserver( + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**self.in_bit - 1, + ).to(self.device) + else: + self.obs = None + + def set_bitwidth(self, bit: int) -> None: + ### regenerate quantizer without changing observation statistics + if bit != self.in_bit: + if self.alg == "dorefa": + self.uniform_q = uniform_quantize(k=bit) + elif self.alg == "normal": + self.uniform_q = uniform_quantize_new(k=bit) + self.in_bit = bit + + def set_alg(self, alg: str) -> None: + assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}" + if alg != self.alg: + if alg == "dorefa": + self.uniform_q = uniform_quantize(k=self.in_bit) + elif alg == "normal": + self.uniform_q = uniform_quantize_new(k=self.in_bit) + self.alg = alg + + def set_quant_ratio(self, quant_ratio=None): + if quant_ratio is None: + ### get recommended value + quant_ratio = [ + None, + 0.2, + 0.3, + 0.4, + 0.5, + 0.55, + 0.6, + 0.7, + 0.8, + 0.83, + 0.86, + 0.89, + 0.92, + 0.95, + 0.98, + 0.99, + 1, + ][min(self.in_bit, 16)] + assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + self.quant_ratio = quant_ratio + + def forward(self, x): + if self.quant_ratio < 1 and self.training: + ### implementation from fairseq + ### must fully quantize during inference + quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio) + else: + quant_noise_mask = None + + if self.in_bit == 32: + input_q = x + elif self.in_bit == 1: + x = x.clamp(0, 1) + input_q = (self.uniform_q(x - 0.5) + 1) / 2 + if quant_noise_mask is not None: + noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized inputs have to be clamped + input_q = x + noise + else: + ### dorefa-style clamp for input data + if self.alg == "dorefa": + x = x.clamp(0, 1) + input_q = self.uniform_q(x) + elif self.alg == "normal": + if self.obs is not None: + if self.training: + self.obs(x) + scale, zero_point = self.obs.calculate_qparams() + # convert scale and zero_point type from qint8 + self.scale = scale.to(x) + self.zero_point = zero_point.to(x) + input_q = self.uniform_q(x, self.scale, self.zero_point) + else: + input_q = x # if no observer (in_bit > 8), do not quantize + else: + raise NotImplementedError + + # add noise + if quant_noise_mask is not None: + noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized inputs have to be clamped + input_q = x + noise + + return input_q + + +class weight_quantize_fn(torch.nn.Module): + def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0): + """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization. + + Args: + w_bit (int): quantization bitwidth + mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv". + alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa". + quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0. + """ + super(weight_quantize_fn, self).__init__() + assert 1 <= w_bit <= 32, logging.error(f"Only support 1 - 32 bit quantization, but got {w_bit}") + self.w_bit = w_bit + self.alg = alg + self.mode = mode + assert alg in {"dorefa", "dorefa_sym", "qnn", "dorefa_pos"}, logging.error( + f"Only support (dorefa, dorefa_sym, qnn, dorefa_pos) algorithms, but got {alg}" + ) + self.quant_ratio = quant_ratio + assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + self.uniform_q = uniform_quantize(k=w_bit, gradient_clip=True) + + def set_quant_ratio(self, quant_ratio=None): + if quant_ratio is None: + ### get recommended value + quant_ratio = [ + None, + 0.2, + 0.3, + 0.4, + 0.5, + 0.55, + 0.6, + 0.7, + 0.8, + 0.83, + 0.86, + 0.89, + 0.92, + 0.95, + 0.98, + 0.99, + 1, + ][min(self.w_bit, 16)] + assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + self.quant_ratio = quant_ratio + + def set_bitwidth(self, bit: int) -> None: + ### regenerate quantizer without changing observation statistics + if bit != self.w_bit: + self.uniform_q = uniform_quantize(k=bit, gradient_clip=True) + self.w_bit = bit + + def forward(self, x): + if self.quant_ratio < 1 and self.training: + ### implementation from fairseq + ### must fully quantize during inference + quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio) + else: + quant_noise_mask = None + + if self.w_bit == 32: + weight_q = torch.tanh(x) + weight_q = weight_q / torch.max(torch.abs(weight_q)) + elif self.w_bit == 1: + if self.mode == "ringonn": + weight_q = (self.uniform_q(x) / 4) + 0.5 + else: + if self.alg == "dorefa": + E = x.data.abs().mean() + weight_q = (self.uniform_q(x / E) * E + E) / 2 # [0, E] + if quant_noise_mask is not None: + x = (x + E) / 2 + noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = x + noise + elif self.alg == "dorefa_sym": + E = x.data.abs().mean() + weight_q = self.uniform_q(x / E) * E # [-E, E] + if quant_noise_mask is not None: + noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = x + noise + else: + assert NotImplementedError + else: + if self.alg == "dorefa": + weight = torch.tanh(x) # [-1, 1] + weight = weight / 2 / torch.max(torch.abs(weight.data)) + 0.5 + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight) + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = weight + noise + + elif self.alg == "dorefa_sym": + weight = torch.tanh(x) # [-1, 1] + r = torch.max(torch.abs(weight.data)) + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight / (2 * r) + 0.5) * (2 * r) - r + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = weight + noise + elif self.alg == "dorefa_pos": + weight = torch.tanh(x) # [-1, 1] + r = torch.max(torch.abs(weight.data)) + weight = weight + r + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight / (2 * r)) * 2 * r + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = weight + noise + + elif self.alg == "qnn": + x_min = torch.min(x.data) + x_max = torch.max(x.data) + x_range = x_max - x_min + weight_q = self.uniform_q((x - x_min) / x_range) * x_range + x_min + if quant_noise_mask is not None: + noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = x + noise + else: + assert NotImplementedError + + return weight_q + + +# PACT activation: https://arxiv.org/pdf/1805.06085.pdf +class PACT_QuantFunc(torch.autograd.Function): + r"""PACT (PArametrized Clipping acTivation) quantization function for activations. + Implements a :py:class:`torch.autograd.Function` for quantizing activations in :math:`Q` bits using the PACT strategy. + In forward propagation, the function is defined as + + .. math:: + \mathbf{y} = f(\mathbf{x}) = 1/\varepsilon \cdot \left\lfloor\mathrm{clip}_{ [0,\alpha) } (\mathbf{x})\right\rfloor \cdot \varepsilon + + where :math:`\varepsilon` is the quantization precision: + + .. math:: + \varepsilon = \alpha / (2^Q - 1) + + In backward propagation, using the Straight-Through Estimator, the gradient of the function is defined as + + .. math:: + \mathbf{\nabla}_\mathbf{x} \mathcal{L} &\doteq \mathbf{\nabla}_\mathbf{y} \mathcal{L} + + It can be applied by using its static `.apply` method: + + :param input: the tensor containing :math:`x`, the activations to be quantized. + :type input: `torch.Tensor` + :param eps: the precomputed value of :math:`\varepsilon`. + :type eps: `torch.Tensor` or float + :param alpha: the value of :math:`\alpha`. + :type alpha: `torch.Tensor` or float + :param delta: constant to sum to `eps` for numerical stability (default unused, 0 ). + :type delta: `torch.Tensor` or float + + :return: The quantized input activations tensor. + :rtype: `torch.Tensor` + """ + + @staticmethod + def forward(ctx, input, eps, alpha): + where_input_clipped = (input < 0) | (input >= alpha) + where_input_ltalpha = input < alpha + ctx.save_for_backward(where_input_clipped, where_input_ltalpha) + return ((input / (eps)).floor() * eps).clamp(0.0, alpha.data[0] - eps.data[0]) + + @staticmethod + def backward(ctx, grad_output): + # see Hubara et al., Section 2.3 + where_input_clipped, where_input_ltalpha = ctx.saved_tensors + # zero = torch.zeros(1, device=where_input_nonclipped.device) + grad_input = grad_output.masked_fill(where_input_clipped, 0) + # grad_input = torch.where(where_input_nonclipped, grad_output, zero) + grad_alpha = grad_output.masked_fill(where_input_ltalpha, 0).sum().expand(1) + # grad_alpha = torch.where(where_input_gtalpha, grad_output, zero).sum().expand(1) + return grad_input, None, grad_alpha + + +pact_quantize = PACT_QuantFunc.apply + + +class PACT_Act(torch.nn.Module): + r"""PACT (PArametrized Clipping acTivation) activation. + Implements a :py:class:`torch.nn.Module` to implement PACT-style activations. It is meant to replace :py:class:`torch.nn.ReLU`, :py:class:`torch.nn.ReLU6` and + similar activations in a PACT-quantized network. + This layer can also operate in a special mode, defined by the `statistics_only` member, in which the layer runs in + forward-prop without quantization, collecting statistics on the activations that can then be + used to reset the value of :math:`\alpha`. + In this mode, the layer collects: + - tensor-wise maximum value ever seen + - running average with momentum 0.9 + - running variance with momentum 0.9 + """ + + def __init__( + self, + precision=None, + alpha=1.0, + backprop_alpha=True, + statistics_only=False, + leaky=None, + device=torch.device("cuda"), + ): + r"""Constructor. Initializes a :py:class:`torch.nn.Parameter` for :math:`\alpha` and sets + up the initial value of the `statistics_only` member. + :param precision: instance defining the current quantization level (default `None`). + :type precision: :py:class:`nemo.precision.Precision` + :param alpha: the value of :math:`\alpha`. + :type alpha: `torch.Tensor` or float + :param backprop_alpha: default `True`; if `False`, do not update the value of `\alpha` with backpropagation. + :type backprop_alpha: bool + :param statistics_only: initialization value of `statistics_only` member. + :type statistics_only: bool + """ + + super(PACT_Act, self).__init__() + self.precision = precision + self.device = device + self.alpha = torch.nn.Parameter(torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha) + self.alpha_p = alpha + self.statistics_only = statistics_only + self.deployment = False + self.eps_in = None + self.leaky = leaky + # self.requantization_factor = requantization_factor + + # these are only used to gather statistics + self.max = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) + self.min = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) + self.running_mean = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) + self.running_var = torch.nn.Parameter(torch.ones_like(self.alpha.data).to(device), requires_grad=False) + + self.precise = False + + def set_static_precision(self, limit_at_32_bits=True, **kwargs): + r"""Sets static parameters used only for deployment.""" + # item() --> conversion to float + # apparently causes a slight, but not invisibile, numerical divergence + # between FQ and QD stages + self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1) + self.alpha_static = self.alpha.clone().detach() + # D is selected as a power-of-two + D = 2.0 ** torch.ceil(torch.log2(self.requantization_factor * self.eps_static / self.eps_in)) + if not limit_at_32_bits: + self.D = D + else: + self.D = min(D, 2.0 ** (32 - 1 - (self.precision))) + + def get_output_eps(self, eps_in): + r"""Get the output quantum (:math:`\varepsilon`) given the input one. + :param eps_in: input quantum :math:`\varepsilon_{in}`. + :type eps_in: :py:class:`torch.Tensor` + :return: output quantum :math:`\varepsilon_{out}`. + :rtype: :py:class:`torch.Tensor` + """ + + return self.alpha / (2.0 ** (self.precision) - 1) + + def reset_alpha(self, use_max=True, nb_std=5.0): + r"""Reset the value of :math:`\alpha`. If `use_max` is `True`, then the highest tensor-wise value collected + in the statistics collection phase is used. If `False`, the collected standard deviation multiplied by + `nb_std` is used as a parameter + :param use_max: if True, use the tensor-wise maximum value collected in the statistics run as new :math:`\alpha` (default True). + :type use_max: bool + :param nb_std: number of standard deviations to be used to initialize :math:`\alpha` if `use_max` is False. + :type nb_std: float + """ + + if use_max: + self.alpha.data[0] = self.max.item() + else: + self.alpha.data[0] = nb_std * torch.sqrt(self.running_var).item() + + def get_statistics(self): + r"""Returns the statistics collected up to now. + + :return: The collected statistics (maximum, running average, running variance). + :rtype: tuple of floats + """ + return self.max.item(), self.running_mean.item(), self.running_var.item() + + def forward(self, x): + r"""Forward-prop function for PACT-quantized activations. + + See :py:class:`nemo.quant.pact_quant.PACT_QuantFunc` for details on the normal operation performed by this layer. + In statistics mode, it uses a normal ReLU and collects statistics in the background. + :param x: input activations tensor. + :type x: :py:class:`torch.Tensor` + + :return: output activations tensor. + :rtype: :py:class:`torch.Tensor` + """ + + if self.statistics_only: + if self.leaky is None: + x = torch.nn.functional.relu(x) + else: + x = torch.nn.functional.leaky_relu(x, self.leaky) + with torch.no_grad(): + self.max[:] = max(self.max.item(), x.max()) + self.min[:] = min(self.min.item(), x.min()) + self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean() + self.running_var[:] = 0.9 * self.running_var.item() + 0.1 * x.std() * x.std() + return x + else: + eps = self.alpha / (2.0 ** (self.precision) - 1) + return pact_quantize(x, eps, self.alpha + eps) diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/functional/torch_train.py new file mode 100644 index 000000000..c62ba3e91 --- /dev/null +++ b/src/chop/nn/optical/functional/torch_train.py @@ -0,0 +1,857 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 03:15:06 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 03:15:06 +""" + +import csv +import os +import random +import time +import traceback +from collections import OrderedDict + +import numpy as np +import torch +from scipy import interpolate +from torch.nn.modules.batchnorm import _BatchNorm +try: + from torchsummary import summary +except: + print("[W] Cannot import torchsummary") +from .general import ensure_dir + +__all__ = [ + "DeterministicCtx", + "set_torch_deterministic", + "set_torch_stochastic", + "get_random_state", + "summary_model", + "save_model", + "BestKModelSaver", + "load_model", + "count_parameters", + "check_converge", + "ThresholdScheduler", + "ThresholdScheduler_tf", + "ValueRegister", + "ValueTracer", + "EMA", + "SWA", + "export_traces_to_csv", + "set_learning_rate", + "get_learning_rate", + "apply_weight_decay", + "disable_bn", + "enable_bn", +] + +class DeterministicCtx: + def __init__(self, random_state: int | None = None) -> None: + self.random_state = random_state + + + def __enter__(self): + self.random_state = random.getstate() + self.numpy_random_state = np.random.get_state() + self.torch_random_state = torch.random.get_rng_state() + self.torch_cuda_random_state = torch.cuda.get_rng_state() + set_torch_deterministic(self.random_state) + return self + + + def __exit__(self, *args): + random.setstate(self.random_state) + np.random.seed(self.numpy_random_state) + np.random.set_state(self.numpy_random_state) + torch.random.set_rng_state(self.torch_random_state) + torch.cuda.set_rng_state(self.torch_cuda_random_state) + +def set_torch_deterministic(random_state: int = 0) -> None: + random_state = int(random_state) % (2**32) + torch.manual_seed(random_state) + np.random.seed(random_state) + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.cuda.manual_seed_all(random_state) + random.seed(random_state) + + +def set_torch_stochastic(): + seed = int(time.time() * 1000) % (2**32) + torch.manual_seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = False + torch.cuda.manual_seed_all(seed) + + +def get_random_state(): + return np.random.get_state()[1][0] + + +def summary_model(model, input): + summary(model, input) + + +def save_model(model, path="./checkpoint/model.pt", print_msg=True): + """Save PyTorch model in path + + Args: + model (PyTorch model): PyTorch model + path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". + print_msg (bool, optional): Control of message print. Defaults to True. + """ + dir = os.path.dirname(path) + if not os.path.exists(dir): + os.mkdir(dir) + try: + torch.save(model.state_dict(), path) + if print_msg: + print(f"[I] Model saved to {path}") + except Exception as e: + if print_msg: + print(f"[E] Model failed to be saved to {path}") + traceback.print_exc(e) + + +class BestKModelSaver(object): + def __init__( + self, + k: int = 1, + descend: bool = True, + truncate: int = 2, + metric_name: str = "acc", + format: str = "{:.2f}", + ): + super().__init__() + self.k = k + self.descend = descend + self.truncate = truncate + self.metric_name = metric_name + self.format = format + self.epsilon = 0.1**truncate + self.model_cache = OrderedDict() + + def better_op(self, a, b): + if self.descend: + return a >= b + self.epsilon + else: + return a <= b - self.epsilon + + def __insert_model_record(self, metric, dir, checkpoint_name, epoch=None): + metric = round(metric * 10**self.truncate) / 10**self.truncate + if len(self.model_cache) < self.k: + new_checkpoint_name = ( + f"{checkpoint_name}_{self.metric_name}-" + + self.format.format(metric) + + f"{'' if epoch is None else '_epoch-'+str(epoch)}" + ) + path = os.path.join(dir, new_checkpoint_name + ".pt") + self.model_cache[path] = (metric, epoch) + return path, None + else: + worst_metric, worst_epoch = sorted( + list(self.model_cache.values()), + key=lambda x: x[0], + reverse=False if self.descend else True, + )[0] + if self.better_op(metric, worst_metric): + del_checkpoint_name = ( + f"{checkpoint_name}_{self.metric_name}-" + + self.format.format(worst_metric) + + f"{'' if epoch is None else '_epoch-'+str(worst_epoch)}" + ) + del_path = os.path.join(dir, del_checkpoint_name + ".pt") + try: + del self.model_cache[del_path] + except: + print( + "[W] Cannot remove checkpoint: {} from cache".format(del_path), + flush=True, + ) + new_checkpoint_name = ( + f"{checkpoint_name}_{self.metric_name}-" + + self.format.format(metric) + + f"{'' if epoch is None else '_epoch-'+str(epoch)}" + ) + path = os.path.join(dir, new_checkpoint_name + ".pt") + self.model_cache[path] = (metric, epoch) + return path, del_path + # elif(acc == min_acc): + # new_checkpoint_name = f"{checkpoint_name}_acc-{acc:.2f}{'' if epoch is None else '_epoch-'+str(epoch)}" + # path = os.path.join(dir, new_checkpoint_name+".pt") + # self.model_cache[path] = (acc, epoch) + # return path, None + else: + return None, None + + def get_topk_model_path(self, topk: int = 1): + if topk <= 0: + return [] + if topk > len(self.model_cache): + topk = len(self.model_cache) + return [ + i[0] + for i in sorted( + self.model_cache.items(), key=lambda x: x[1][0], reverse=self.descend + )[:topk] + ] + + def save_model( + self, + model, + metric, + epoch=None, + path="./checkpoint/model.pt", + other_params=None, + save_model=False, + print_msg=True, + ): + """Save PyTorch model in path + + Args: + model (PyTorch model): PyTorch model + acc (scalar): accuracy + epoch (scalar, optional): epoch. Defaults to None + path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". + other_params (dict, optional): Other saved params. Defaults to None + save_model (bool, optional): whether save source code of nn.Module. Defaults to False + print_msg (bool, optional): Control of message print. Defaults to True. + """ + dir = os.path.dirname(path) + ensure_dir(dir) + checkpoint_name = os.path.splitext(os.path.basename(path))[0] + if isinstance(metric, torch.Tensor): + metric = metric.data.item() + new_path, del_path = self.__insert_model_record( + metric, dir, checkpoint_name, epoch + ) + + if del_path is not None: + try: + os.remove(del_path) + print(f"[I] Model {del_path} is removed", flush=True) + except Exception as e: + if print_msg: + print(f"[E] Model {del_path} failed to be removed", flush=True) + traceback.print_exc(e) + + if new_path is None: + if print_msg: + if self.descend: + best_list = list(reversed(sorted(list(self.model_cache.values())))) + else: + best_list = list(sorted(list(self.model_cache.values()))) + print( + f"[I] Not best {self.k}: {best_list}, skip this model (" + + self.format.format(metric) + + f"): {path}", + flush=True, + ) + else: + try: + # torch.save(model.state_dict(), new_path) + if other_params is not None: + saved_dict = other_params + else: + saved_dict = {} + if save_model: + saved_dict.update( + {"model": model, "state_dict": model.state_dict()} + ) + torch.save(saved_dict, new_path) + else: + saved_dict.update({"model": None, "state_dict": model.state_dict()}) + torch.save(saved_dict, new_path) + if print_msg: + if self.descend: + best_list = list( + reversed(sorted(list(self.model_cache.values()))) + ) + else: + best_list = list(sorted(list(self.model_cache.values()))) + + print( + f"[I] Model saved to {new_path}. Current best {self.k}: {best_list}", + flush=True, + ) + except Exception as e: + if print_msg: + print(f"[E] Model failed to be saved to {new_path}", flush=True) + traceback.print_exc(e) + return new_path + + +def load_model( + model, + path="./checkpoint/model.pt", + ignore_size_mismatch: bool = False, + print_msg=True, +): + """Load PyTorch model in path + + Args: + model (PyTorch model): PyTorch model + path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". + ignore_size_mismatch (bool, optional): Whether ignore tensor size mismatch. Defaults to False. + print_msg (bool, optional): Control of message print. Defaults to True. + """ + try: + raw_data = torch.load(path, map_location=lambda storage, location: storage) + if isinstance(raw_data, OrderedDict) and "state_dict" not in raw_data: + ### state_dict: OrderedDict + state_dict = raw_data + else: + ### {"state_dict": ..., "model": ...} + state_dict = raw_data["state_dict"] + load_keys = set(state_dict.keys()) + model_keys = set(model.state_dict().keys()) + common_dict = load_keys & model_keys + diff_dict = load_keys ^ model_keys + extra_keys = load_keys - model_keys + lack_keys = model_keys - load_keys + cur_state_dict = model.state_dict() + if ignore_size_mismatch: + size_mismatch_dict = set( + key + for key in common_dict + if model.state_dict()[key].size() != state_dict[key].size() + ) + print( + f"[W] {size_mismatch_dict} are ignored due to size mismatch", flush=True + ) + common_dict = common_dict - size_mismatch_dict + + cur_state_dict.update({key: state_dict[key] for key in common_dict}) + if len(diff_dict) > 0: + print( + f"[W] Warning! Model is not the same as the checkpoint. not found keys {lack_keys}. extra unused keys {extra_keys}" + ) + + model.load_state_dict(cur_state_dict) + if print_msg: + print(f"[I] Model loaded from {path}") + except Exception as e: + traceback.print_exc(e) + if print_msg: + print(f"[E] Model failed to be loaded from {path}") + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def check_converge(trace, epsilon=0.002): + if len(trace) <= 1: + return False + if np.abs(trace[-1] - trace[-2]) / (np.abs(trace[-1]) + 1e-8) < epsilon: + return True + return False + + +class ThresholdScheduler(object): + """Intepolation between begin point and end point. step must be within two endpoints""" + + def __init__(self, step_beg, step_end, thres_beg, thres_end, mode="tanh"): + assert mode in { + "linear", + "tanh", + }, "Threshold scheduler only supports linear and tanh modes" + self.mode = mode + self.step_beg = step_beg + self.step_end = step_end + self.thres_beg = thres_beg + self.thres_end = thres_end + self.func = self.createFunc() + + def normalize(self, step, factor=2): + return (step - self.step_beg) / (self.step_end - self.step_beg) * factor + + def createFunc(self): + if self.mode == "linear": + return lambda x: (self.thres_end - self.thres_beg) * x + self.thres_beg + elif self.mode == "tanh": + x = self.normalize( + np.arange(self.step_beg, self.step_end + 1).astype(np.float32) + ) + y = np.tanh(x) * (self.thres_end - self.thres_beg) + self.thres_beg + return interpolate.interp1d(x, y) + + def __call__(self, x): + return self.func(self.normalize(x)).tolist() + + +class ThresholdScheduler_tf(object): + """smooth increasing threshold with tensorflow model pruning scheduler""" + + def __init__(self, step_beg, step_end, thres_beg, thres_end): + import tensorflow as tf + import tensorflow_model_optimization as tfmot + + gpus = tf.config.list_physical_devices("GPU") + if gpus: + # Restrict TensorFlow to only allocate 1GB of memory on the first GPU + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + # Virtual devices must be set before GPUs have been initialized + print(e) + self.step_beg = step_beg + self.step_end = step_end + self.thres_beg = thres_beg + self.thres_end = thres_end + if thres_beg < thres_end: + self.thres_min = thres_beg + self.thres_range = thres_end - thres_beg + self.descend = False + + else: + self.thres_min = thres_end + self.thres_range = thres_beg - thres_end + self.descend = True + + self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=0, + final_sparsity=0.9999999, + begin_step=self.step_beg, + end_step=self.step_end, + ) + + def __call__(self, x): + if x < self.step_beg: + return self.thres_beg + elif x > self.step_end: + return self.thres_end + res_norm = self.pruning_schedule(x)[1].numpy() + if self.descend == False: + res = res_norm * self.thres_range + self.thres_beg + else: + res = self.thres_beg - res_norm * self.thres_range + + if np.abs(res - self.thres_end) <= 1e-6: + res = self.thres_end + return res + + +class ValueRegister(object): + def __init__(self, operator, name="", show=True): + self.op = operator + self.cache = None + self.show = show + self.name = name if len(name) > 0 else "value" + + def register_value(self, x): + self.cache = self.op(x, self.cache) if self.cache is not None else x + if self.show: + print(f"Recorded {self.name} is {self.cache}") + + +class ValueTracer(object): + def __init__(self, show=True): + self.cache = {} + self.show = show + + def add_value(self, name, value, step): + if name not in self.cache: + self.cache[name] = {} + self.cache[name][step] = value + if self.show: + print(f"Recorded {name}: step = {step}, value = {value}") + + def get_trace_by_name(self, name): + return self.cache.get(name, {}) + + def get_all_traces(self): + return self.cache + + def __len__(self): + return len(self.cache) + + def get_num_trace(self): + return len(self.cache) + + def get_len_trace_by_name(self, name): + return len(self.cache.get(name, {})) + + def dump_trace_to_file(self, name, file): + if name not in self.cache: + print(f"[W] Trace name '{name}' not found in tracer") + return + torch.save(self.cache[name], file) + print(f"[I] Trace {name} saved to {file}") + + def dump_all_traces_to_file(self, file): + torch.save(self.cache, file) + print(f"[I] All traces saved to {file}") + + def load_all_traces_from_file(self, file): + self.cache = torch.load(file) + return self.cache + + +class EMA(object): + def __init__(self, mu): + super().__init__() + self.mu = mu + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone().data + + def __call__(self, name, x, mask=None): + if name not in self.shadow: + self.register(name, x) + return x.data + + old_average = self.shadow[name] + new_average = (1 - self.mu) * x + self.mu * old_average + if mask is not None: + new_average[mask].copy_(old_average[mask]) + self.shadow[name] = new_average.clone() + return new_average.data + + +class SWA(torch.nn.Module): + """Stochastic Weight Averging. + + # Paper + title: Averaging Weights Leads to Wider Optima and Better Generalization + link: https://arxiv.org/abs/1803.05407 + + # Arguments + start_epoch: integer, epoch when swa should start. + lr_schedule: string, type of learning rate schedule. + swa_lr: float, learning rate for swa. + swa_lr2: float, upper bound of cyclic learning rate. + swa_freq: integer, length of learning rate cycle. + batch_size integer, batch size (for batch norm with generator) + verbose: integer, verbosity mode, 0 or 1. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + start_epoch: int, + epochs: int, # total epochs + steps, # total steps per epoch + lr_schedule="manual", + swa_lr="auto", + swa_lr2="auto", + swa_freq=1, + batch_size=None, + verbose=0, + ): + super().__init__() + self.model = model + self.optimizer = optimizer + self.start_epoch = start_epoch - 1 + self.epochs = epochs + self.steps = steps + self.lr_schedule = lr_schedule + self.swa_lr = swa_lr + + # if no user determined upper bound, make one based off of the lower bound + self.swa_lr2 = swa_lr2 if swa_lr2 is not None else 10 * swa_lr + self.swa_freq = swa_freq + self.batch_size = batch_size + self.verbose = verbose + + if start_epoch < 2: + raise ValueError('"swa_start" attribute cannot be lower than 2.') + + schedules = ["manual", "constant", "cyclic"] + + if self.lr_schedule not in schedules: + raise ValueError( + '"{}" is not a valid learning rate schedule'.format(self.lr_schedule) + ) + + if self.lr_schedule == "cyclic" and self.swa_freq < 2: + raise ValueError('"swa_freq" must be higher than 1 for cyclic schedule.') + + if self.swa_lr == "auto" and self.swa_lr2 != "auto": + raise ValueError( + '"swa_lr2" cannot be manually set if "swa_lr" is automatic.' + ) + + if ( + self.lr_schedule == "cyclic" + and self.swa_lr != "auto" + and self.swa_lr2 != "auto" + and self.swa_lr > self.swa_lr2 + ): + raise ValueError('"swa_lr" must be lower than "swa_lr2".') + + def on_train_begin(self): + self.lr_record = [] + + if self.start_epoch >= self.epochs - 1: + raise ValueError('"swa_start" attribute must be lower than "epochs".') + + self.init_lr = self.optimizer.param_groups[0]["lr"] + + # automatic swa_lr + if self.swa_lr == "auto": + self.swa_lr = 0.1 * self.init_lr + + if self.init_lr < self.swa_lr: + raise ValueError('"swa_lr" must be lower than rate set in optimizer.') + + # automatic swa_lr2 between initial lr and swa_lr + if self.lr_schedule == "cyclic" and self.swa_lr2 == "auto": + self.swa_lr2 = self.swa_lr + (self.init_lr - self.swa_lr) * 0.25 + + self._check_batch_norm() + + if self.has_batch_norm and self.batch_size is None: + raise ValueError( + '"batch_size" needs to be set for models with batch normalization layers.' + ) + + def on_epoch_begin(self, epoch): + # input epoch is from 0 to epochs-1 + + self.current_epoch = epoch + self._scheduler(epoch) + + # constant schedule is updated epoch-wise + if self.lr_schedule == "constant": + self._update_lr(epoch) + + if self.is_swa_start_epoch: + # self.swa_weights = self.model.get_weights() + self.swa_weights = { + name: p.data.clone() for name, p in self.model.named_parameters() + } + + if self.verbose > 0: + print( + "\nEpoch %05d: starting stochastic weight averaging" % (epoch + 1) + ) + + if self.is_batch_norm_epoch: + self._set_swa_weights(epoch) + + if self.verbose > 0: + print( + "\nEpoch %05d: reinitializing batch normalization layers" + % (epoch + 1) + ) + + self._reset_batch_norm() + + if self.verbose > 0: + print( + "\nEpoch %05d: running forward pass to adjust batch normalization" + % (epoch + 1) + ) + + def on_batch_begin(self, batch): + # update lr each batch for cyclic lr schedule + if self.lr_schedule == "cyclic": + self._update_lr(self.current_epoch, batch) + + if self.is_batch_norm_epoch: + batch_size = self.batch_size + # this is for tensorflow momentum, applied to the running stat + # momentum = batch_size / (batch * batch_size + batch_size) + + # we need to convert it to torch momentum, applied to the batch stat + momentum = 1 - batch_size / (batch * batch_size + batch_size) + + for layer in self.batch_norm_layers: + layer.momentum = momentum + + def on_epoch_end(self, epoch): + if self.is_swa_start_epoch: + self.swa_start_epoch = epoch + + if self.is_swa_epoch and not self.is_batch_norm_epoch: + self.swa_weights = self._average_weights(epoch) + + def on_train_end(self): + if not self.has_batch_norm: + self._set_swa_weights(self.epochs) + else: + self._restore_batch_norm() + + ## TODO: what is meaning here? + # for batch_lr in self.lr_record: + # self.model.history.history.setdefault("lr", []).append(batch_lr) + + def _scheduler(self, epoch): + swa_epoch = epoch - self.start_epoch + + self.is_swa_epoch = epoch >= self.start_epoch and swa_epoch % self.swa_freq == 0 + self.is_swa_start_epoch = epoch == self.start_epoch + self.is_batch_norm_epoch = epoch == self.epochs - 1 and self.has_batch_norm + + def _average_weights(self, epoch): + # return [ + # (swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w) + # / ((epoch - self.start_epoch) / self.swa_freq + 1) + # for swa_w, w in zip(self.swa_weights, self.model.get_weights()) + # ] + out = {} + with torch.no_grad(): + for name, w in self.model.named_parameters(): + swa_w = self.swa_weights[name] + out[name] = ( + swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w.data + ) / ((epoch - self.start_epoch) / self.swa_freq + 1) + return out + + def _update_lr(self, epoch, batch=None): + if self.is_batch_norm_epoch: + lr = 0 + # K.set_value(self.model.optimizer.lr, lr) + set_learning_rate(lr, self.optimizer) + elif self.lr_schedule == "constant": + lr = self._constant_schedule(epoch) + # K.set_value(self.model.optimizer.lr, lr) + set_learning_rate(lr, self.optimizer) + elif self.lr_schedule == "cyclic": + lr = self._cyclic_schedule(epoch, batch) + # K.set_value(self.model.optimizer.lr, lr) + set_learning_rate(lr, self.optimizer) + self.lr_record.append(lr) + + def _constant_schedule(self, epoch): + t = epoch / self.start_epoch + lr_ratio = self.swa_lr / self.init_lr + if t <= 0.5: + factor = 1.0 + elif t <= 0.9: + factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4 + else: + factor = lr_ratio + return self.init_lr * factor + + def _cyclic_schedule(self, epoch, batch): + """Designed after Section 3.1 of Averaging Weights Leads to + Wider Optima and Better Generalization(https://arxiv.org/abs/1803.05407) + """ + # steps are mini-batches per epoch, equal to training_samples / batch_size + steps = self.steps + + swa_epoch = (epoch - self.start_epoch) % self.swa_freq + cycle_length = self.swa_freq * steps + + # batch 0 indexed, so need to add 1 + i = (swa_epoch * steps) + (batch + 1) + if epoch >= self.start_epoch: + t = (((i - 1) % cycle_length) + 1) / cycle_length + return (1 - t) * self.swa_lr2 + t * self.swa_lr + else: + return self._constant_schedule(epoch) + + def _set_swa_weights(self, epoch): + # self.model.set_weights(self.swa_weights) + for name, p in self.model.named_parameters(): + p.data.copy_(self.swa_weights[name]) + + if self.verbose > 0: + print( + "\nEpoch %05d: final model weights set to stochastic weight average" + % (epoch + 1) + ) + + def _check_batch_norm(self): + self.batch_norm_momentums = [] + self.batch_norm_layers = [] + self.has_batch_norm = False + self.running_bn_epoch = False + + for layer in self.model.modules(): + if isinstance(layer, _BatchNorm): + self.has_batch_norm = True + self.batch_norm_momentums.append(layer.momentum) + self.batch_norm_layers.append(layer) + + if self.verbose > 0 and self.has_batch_norm: + print( + "Model uses batch normalization. SWA will require last epoch " + "to be a forward pass and will run with no learning rate" + ) + + def _reset_batch_norm(self): + for layer in self.batch_norm_layers: + # initialized moving mean and + # moving var weights + layer.reset_running_stats() + + def _restore_batch_norm(self): + for layer, momentum in zip(self.batch_norm_layers, self.batch_norm_momentums): + layer.momentum = momentum + + +def export_traces_to_csv(trace_file, csv_file, fieldnames=None): + traces = torch.load(trace_file) + + with open(csv_file, "w", newline="") as csvfile: + if fieldnames is None: + fieldnames = list(traces.keys()) + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + max_len = max([len(traces[field]) for field in fieldnames]) + + for idx in range(max_len): + row = {} + for field in fieldnames: + value = traces[field][idx] if idx < len(traces[field]) else "" + row[field] = ( + value.data.item() if isinstance(value, torch.Tensor) else value + ) + writer.writerow(row) + + +def set_learning_rate(lr, optimizer): + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def get_learning_rate(optimizer): + return optimizer.param_groups[0]["lr"] + + +def apply_weight_decay(W, decay_rate, learning_rate, mask=None): + # in mask, 1 represents fixed variables, 0 represents trainable variables + if mask is not None: + W[~mask] -= W[~mask] * decay_rate * learning_rate + else: + W -= W * decay_rate * learning_rate + + +def disable_bn(model: torch.nn.Module) -> None: + for m in model.modules(): + if isinstance( + m, + ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + ), + ): + m.eval() + + +def enable_bn(model: torch.nn.Module) -> None: + for m in model.modules(): + if isinstance( + m, + ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + ), + ): + m.train() diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py new file mode 100644 index 000000000..46b6f21e9 --- /dev/null +++ b/src/chop/nn/optical/modules/__init__.py @@ -0,0 +1,7 @@ +from .morr_linear import AllPassMORRCirculantLinear +from .morr_conv2d import AllPassMORRCirculantConv2d + +optical_module_map = { + "linear_morr": AllPassMORRCirculantLinear, + "conv2d_morr": AllPassMORRCirculantConv2d, +} \ No newline at end of file diff --git a/src/chop/nn/optical/modules/base_layer.py b/src/chop/nn/optical/modules/base_layer.py new file mode 100644 index 000000000..cb1fbe9dd --- /dev/null +++ b/src/chop/nn/optical/modules/base_layer.py @@ -0,0 +1,71 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-08 18:55:05 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-08 18:55:05 +""" +from typing import Any, Dict, Optional +import torch +from torch import nn +from torch.types import Device + +__all__ = ["ONNBaseLayer"] + + +class ONNBaseLayer(nn.Module): + def __init__(self, *args, device: Device = torch.device("cpu"), **kwargs) -> None: + super().__init__(*args, **kwargs) + # cuda or cpu, defaults to cpu + self.device = device + + def build_parameters(self) -> None: + raise NotImplementedError + + def reset_parameters(self) -> None: + raise NotImplementedError + + @classmethod + def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module: + raise NotImplementedError + + def get_num_parameters(self) -> int: + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_phase_variation(self, noise_std: float, random_state: Optional[int] = None) -> None: + self.phase_noise_std = noise_std + + def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + self.gamma_noise_std = noise_std + + def set_crosstalk_factor(self, crosstalk_factor: float) -> None: + self.crosstalk_factor = crosstalk_factor + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + + def load_parameters(self, param_dict: Dict[str, Any]) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {param_name: param_tensor, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def switch_mode_to(self, mode: str) -> None: + self.mode = mode + + def forward(self, x): + raise NotImplementedError + + def extra_repr(self) -> str: + return "" diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py new file mode 100644 index 000000000..83f0bf1fe --- /dev/null +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -0,0 +1,458 @@ +""" +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-01-27 01:08:44 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-07-18 00:40:18 +""" +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.fft +from ..functional import im2col_2d, toeplitz +from ..functional import logger +from ..functional import morr_uniform_ +from ..functional import input_quantize_fn, weight_quantize_fn +from torch import Tensor, nn +from torch.nn import Parameter, init +from torch.nn.modules.utils import _pair +from torch.types import Device, _size +from ..functional import MORRConfig_20um_MQ +from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused + + +from .base_layer import ONNBaseLayer + +__all__ = ["AllPassMORRCirculantConv2d"] + + +class AllPassMORRCirculantConv2d(ONNBaseLayer): + """ + All-pass MORR Conv2d layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = [ + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + "output_padding", + "in_channels", + "out_channels", + "kernel_size", + "miniblock", + ] + __annotations__ = {"bias": Optional[torch.Tensor]} + + _in_channels: int + out_channels: int + kernel_size: Tuple[int, ...] + stride: Tuple[int, ...] + padding: Tuple[int, ...] + dilation: Tuple[int, ...] + transposed: bool + output_padding: Tuple[int, ...] + groups: int + padding_mode: str + weight: Tensor + bias: Optional[Tensor] + miniblock: int + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size, + stride: _size = 1, + padding: _size = 0, + dilation: _size = 1, + groups: int = 1, + bias: bool = True, + miniblock: int = 4, + ### morr parameter + MORRConfig=MORRConfig_20um_MQ, + morr_init: bool = True, # whether to use initialization method customized for MORR + ### trainable MORR nonlinearity + trainable_morr_bias: bool = False, + trainable_morr_scale: bool = False, + device: Device = torch.device("cuda"), + ) -> None: + super(AllPassMORRCirculantConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + assert groups == 1, f"Currently group convolution is not supported, but got group: {groups}" + self.in_channels_flat = self.in_channels * self.kernel_size[0] * self.kernel_size[1] + self.grid_dim_x = int(np.ceil(self.in_channels_flat / miniblock)) + self.grid_dim_y = int(np.ceil(self.out_channels / miniblock)) + self.in_channels_pad = self.grid_dim_x * miniblock + self.out_channels_pad = self.grid_dim_y * miniblock + self.miniblock = miniblock + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi ** 2 + self.w_bit = 32 + self.in_bit = 32 + self.MORRConfig = MORRConfig + self.morr_init = morr_init + self.mrr_a = MORRConfig.attenuation_factor + self.mrr_r = MORRConfig.coupling_factor + self.trainable_morr_bias = trainable_morr_bias + self.trainable_morr_scale = trainable_morr_scale + self.device = device + + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi ** 2 + * MORRConfig.radius + * MORRConfig.effective_index + * ( + 1 / MORRConfig.resonance_wavelength + - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR + self.morr_gain = ( + 100 / (self.in_channels_flat // self.miniblock) + ) ** 0.5 ## set this TIA gain such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_channels).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init) + + # support fine-grained structured pruning for MORRs + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + ### MORR weights + self.weight = Parameter( + torch.ones( + self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float + ) + ) + ### learnable balancing factor achieved by MRRs (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter(torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device)) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + if morr_init: + ### nonlinear curve aware initialization + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + ### output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + + else: + nn.init.kaiming_normal_(self.weight) + nn.init.kaiming_normal_(self.morr_output_scale) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + init.zeros_(self.morr_input_bias.data) + if self.morr_input_scale is not None: + init.zeros_(self.morr_input_scale.data) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + else: + weight = self.weight.abs() ## have to be all positive + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + return weight + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### See SqueezeLight paper + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_scale(self) -> Tensor: + return torch.sigmoid(self.morr_input_scale.unsqueeze(0).unsqueeze(-1)) + 0.2 + + @property + def morr_bias(self) -> Tensor: + return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + + def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of morr. + @param weight {torch.Tensor} first column vectors in the block-circulant matrix + @param x {torch.Tensor} input + @return: y {torch.Tensor} output of MORR array + """ + ### weights: [p, q, k] + ### x: [ks*ks*inc, h_out*w_out*bs] + + x = x.t() # [h_out*w_out*bs, ks*ks*inc] + x = x.view(x.size(0), self.grid_dim_x, self.miniblock) # [h_out*w_out*bs, q, k] + + ### injecting crosstalk into weights is more efficient + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + + ### construct block-circulant matrix + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [h*w*bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [h*w*bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) # [h*w*bs, p, q, k] + + ### input scaling, learnable MORR nonlinearity + if self.trainable_morr_scale: + x = x * self.morr_scale # [h*w*bs, p, q, k] + ### input biasing, learnable MORR nonlinearity + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity + ### x is the phase detuning, x=0 means on-resonance + ### x: [h_out*w_out*bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) + + ### output scaling or learnable balancing factors + if self.w_bit < 16: + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if self.out_scale_quant_gain is None: + self.out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul( + self.out_scale_quant_gain + ) ### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + + scale = morr_output_scale[:-1] + scale_pad = morr_output_scale[-1:] + + ### differential rails + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=0) + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=0) + else: + scale = scale_pad + scale = scale.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, q] + + x = scale.matmul(x) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k] + x = x.view(x.size(0), -1).t() # [p*k, h_out*w_out*bs] + if self.out_channels_pad > self.out_channels: + x = x[: self.out_channels, :] # [outc, h_out*w_out*bs] + return x + + def morr_conv2d(self, X: Tensor, W: Tensor) -> Tensor: + ### W : [p, q, k] + n_x = X.size(0) + + _, X_col, h_out, w_out = im2col_2d( + None, + X, + stride=self.stride[0], + padding=self.padding[0], + w_size=(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]), + ) + ## zero-padding X_col + if self.in_channels_pad > self.in_channels_flat: + if self.x_zero_pad is None or self.x_zero_pad.size(1) != X_col.size(1): + self.x_zero_pad = torch.zeros( + self.in_channels_pad - self.in_channels_flat, + X_col.size(1), + dtype=torch.float32, + device=self.device, + ) + + X_col = torch.cat([X_col, self.x_zero_pad], dim=0) + # matmul + out = self.propagate_morr(W, X_col) # [outc, w_out] + out = out.view(self.out_channels, h_out, w_out, n_x) + out = out.permute(3, 0, 1, 2).contiguous() + + return out + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def get_output_dim(self, img_height: int, img_width: int) -> Tuple[int, int]: + """ + get the output features size + """ + h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1 + w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] + 1 + return (int(h_out), int(w_out)) + + def forward(self, x: Tensor) -> Tensor: + if self.in_bit < 16: + x = self.input_quantizer(x) + weight = self.build_weight() + x = self.input_modulator(x) + x = self.morr_conv2d(x, weight) + + if self.bias is not None: + x = x + self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + return x diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py new file mode 100644 index 000000000..4846fb763 --- /dev/null +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -0,0 +1,442 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" +from typing import Optional + +import numpy as np +import torch +import torch.fft +from ..functional import toeplitz +from ..functional import logger +from ..functional import morr_uniform_ +from ..functional import input_quantize_fn, weight_quantize_fn +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device +from ..functional import MORRConfig_20um_MQ +from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused + +from .base_layer import ONNBaseLayer + +__all__ = ["AllPassMORRCirculantLinear"] + + +class AllPassMORRCirculantLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config = None, + # miniblock: int = 4, + # ### mrr parameter + # MORRConfig=MORRConfig_20um_MQ, + # morr_init: bool = True, + # ### trainable MORR nonlinearity + # trainable_morr_bias: bool = False, + # trainable_morr_scale: bool = False, + device: Device = torch.device("cuda"), + ) -> None: + super(AllPassMORRCirculantLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi ** 2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get("trainable_morr_scale", MORRConfig_20um_MQ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi ** 2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = self.morr_output_scale - self.morr_output_scale.data.mean() + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + assert ( + x.size(-1) == self.in_features + ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" + if self.in_bit < 16: + x = self.input_quantizer(x) + + weight, morr_output_scale = self.build_weight() + if self.in_features_pad > self.in_features: + if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): + self.x_zero_pad = torch.zeros( + x.size(0), self.in_features_pad - self.in_features, device=x.device, dtype=x.dtype + ) + x = torch.cat([x, self.x_zero_pad], dim=1) + + x = x.view(-1, self.grid_dim_x, self.miniblock) + + ### modulation + ### x: [bs, q, k] -> [bs, q, k] + x = self.input_modulator(x) + + ### propagate through morr array + ### x: [bs, q, k] -> [bs, p*k] + x = self.propagate_morr(weight, x, morr_output_scale) + + if self.out_features < self.out_features_pad: + x = x[..., : self.out_features] + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + return x diff --git a/src/chop/passes/module/module_transform_helper.py b/src/chop/passes/module/module_transform_helper.py new file mode 100644 index 000000000..b644b087a --- /dev/null +++ b/src/chop/passes/module/module_transform_helper.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import numpy as np +from chop.nn.optical.modules import optical_module_map +from chop.passes.module.module_modify_helper import get_module_by_name, set_module_by_name + +def replace_by_name_optical( + network, + module_name: str, + new_module +): + + original = get_module_by_name(network, module_name) + updated_module = weight_replacement_optical(original, new_module) + network = set_module_by_name(network, module_name, updated_module) + + return network + + +def weight_replacement_optical(x, y): + """ + Replace the weights of AllPassMORRCirculantLinear (y) + with those from a standard nn.Linear (x). + Focuses only on weight copying (no bias copying). + """ + + # Fetch original linear weight [out_features, in_features] + W = x.weight.data # shape: (out_features, in_features) + + # Grab dimensions and zero-pad if needed + out_features_pad = y.out_features_pad # padded out_features in y + in_features_pad = y.in_features_pad # padded in_features in y + miniblock = y.miniblock + grid_dim_y = y.grid_dim_y + grid_dim_x = y.grid_dim_x + + # Construct padded weight tensor + W_padded = W.new_zeros((out_features_pad, in_features_pad)) + W_padded[: W.size(0), : W.size(1)] = W # copy original into top-left + + # Now we create a new tensor of shape [grid_dim_y, grid_dim_x, miniblock] + # by compressing each row-block [1 x miniblock] from W_padded into a single scalar. + # This is a simple example that takes the mean across the miniblock slice. + new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) + + # Fill new_weight by averaging the corresponding sub-blocks in W_padded + with torch.no_grad(): + for p in range(grid_dim_y): + for q in range(grid_dim_x): + for k in range(miniblock): + # The row in W_padded we look at: + row_idx = p * miniblock + k + # The columns we look at: + col_start = q * miniblock + col_end = (q + 1) * miniblock + + block = W_padded[row_idx, col_start:col_end] + new_weight[p, q, k] = block.mean() + + # Copy the result into y.weight + y.load_parameters({"weight": new_weight}) + # y.weight.data.copy_(new_weight) + + return y diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index efbb0ed14..8773589ec 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,3 +1,4 @@ from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass from .autosharding import resharding_transform_pass +from .optical import optical_module_transform_pass \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/__init__.py b/src/chop/passes/module/transforms/optical/__init__.py new file mode 100644 index 000000000..9b1840c4e --- /dev/null +++ b/src/chop/passes/module/transforms/optical/__init__.py @@ -0,0 +1 @@ +from .optical import optical_module_transform_pass diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py new file mode 100644 index 000000000..da475d778 --- /dev/null +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -0,0 +1,103 @@ +import torch + +from chop.nn.optical.modules import optical_module_map +from ...module_modify_helper import replace_by_name, instantiate_module +from ...module_transform_helper import replace_by_name_optical + + +def get_config(config: dict, name: str): + if name in config: + return config[name]["config"] + else: + return config["default"]["config"] + + +def optical_by_type(network, pass_args): + for type_name, config in pass_args.items(): + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + + if type_name == "linear": + module = torch.nn.Linear + elif type_name == "conv2d": + module = torch.nn.Conv2d + else: + raise ValueError(f"{type_name} is not supported!") + config = config["config"] + postfix = config.pop("name") + for n, m in n_m.items(): + if isinstance(m, module): + new_m = instantiate_module( + m, postfix, optical_module_map, {"config": config} + ) + network = replace_by_name_optical(network, n, new_m) + return network + + +def optical_by_name(network, pass_args): + quantize_names = pass_args.keys() + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + for n, m in n_m.items(): + if n in quantize_names: + quan_config = pass_args[n] + + quan_config = quan_config["config"] + postfix = quan_config.pop("name") + + new_m = instantiate_module( + m, postfix, optical_module_map, {"config": quan_config} + ) + network = replace_by_name_optical(network, n, new_m) + return network + + +def optical_module_transform_pass(network, pass_args): + """ + Apply optical transformation to the given nn.Module. + + :param network: The input network to be transformed. + :type network: torch.nn.Module + + :param pass_args: Additional arguments for the transformation. + :type pass_args: dict, optional + + Examples pass_args: + + .. code-block:: python + + pass_args = { + "by": "type", # quantize by type, name, or regex_name + "default": {"config": {"name": None}}, # default config, this would be used for any node that does not have a specific config + "linear": { + "config": { + "name": "integer", # quantization scheme name supported are ["integer", "fixed" (equivalent to integer), "lutnet" (dev mode), "logicnets" (dev mode), "binary", "binary_residual", "ternary", "minifloat_ieee", "minifloat_denorm", "log", "block_fp", "block_minifloat", "block_log"] + # data + "data_in_width": 8, + "data_in_frac_width": 4, + # weight + "weight_width": 8, + "weight_frac_width": 4, + # bias + "bias_width": 8, + "bias_frac_width": 4, + } + }, + } + + :return: The transformed torch.nn.Module. + :rtype: tuple + :raises ValueError: If the quantize "by" argument is unsupported. + + """ + by = pass_args.pop("by") + match by: + case "type": + network = optical_by_type(network, pass_args) + case "name": + network = optical_by_name(network, pass_args) + case _: + raise ValueError(f'Unsupported quantize "by": {by}') + return network, {} diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py new file mode 100644 index 000000000..5409e64f5 --- /dev/null +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# This example converts a simple MLP model to Verilog +import logging +import os +import sys + +import torch +import torch.nn as nn + +from pathlib import Path + +sys.path.append(Path(__file__).resolve().parents[5].as_posix()) + + +# from chop.passes.module.transforms import quantize_module_transform_pass +from chop.passes.module.transforms import optical_module_transform_pass +from chop.passes.module import report_trainable_parameters_analysis_pass + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +from train_mnist_cnn import test, train, Net + +# -------------------------------------------------- +# Model specifications +# -------------------------------------------------- +# class MLP(torch.nn.Module): +# """ +# Toy quantized FC model for digit recognition on MNIST +# """ + +# def __init__(self) -> None: +# super().__init__() + +# self.fc1 = nn.Linear(28 * 28, 28 * 28) +# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4) +# self.fc3 = nn.Linear(28 * 28 * 4, 10) + +# def forward(self, x): +# x = torch.flatten(x, start_dim=1, end_dim=-1) +# x = torch.nn.functional.relu(self.fc1(x)) +# # w = torch.randn((4, 28 * 28)) +# # x = torch.nn.functional.relu(nn.functional.linear(x, w)) +# x = torch.nn.functional.relu(self.fc2(x)) +# x = self.fc3(x) +# return x + +def load_my_model(model_path, device): + # Load the model from the .pt file + loaded_model = torch.load(model_path, map_location=device) + # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout) + loaded_model.eval() + return loaded_model + +def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): + pass_args = { + "by": "type", + "linear": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + onn_model, _ = optical_module_transform_pass(model, pass_args) + torch.save(onn_model.state_dict(), save_path) + return onn_model + +def test_optical_module_transform_pass(): + model_path = "mase_output/sample_mnist_cnn.pt" + mnist_cnn = load_my_model(model_path) + # Sanity check and report + pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args) + torch.save(onn_cnn, "mase_output/onn_cnn.pt") + + + +if __name__ == '__main__': + + if True: + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--no-mps', action='store_true', default=False, + help='disables macOS GPU training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=True, + help='For Saving the current Model') + parser.add_argument('--gpu-id', type=int, default=0, + help='Which GPU device to use [default: 0]') + + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if not args.no_cuda and torch.cuda.is_available(): + device = torch.device(f"cuda:{args.gpu_id}") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) + print("-------------- Testing the original cnn model -------------------") + test(cnn, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(cnn) + + # onn = load_my_model("mase_output/onn_cnn.pt", device) + onn_model = perform_optical_module_transform_pass(cnn) + onn_model.to(device) + print("-------------- Testing the transformed onn model -------------------") + test(onn_model, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(onn_model) + + + ################################################################## + ######### Training the onn model + ################################################################## + optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, onn_model, device, train_loader, optimizer, epoch) + test(onn_model, device, test_loader) + scheduler.step() + + + torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") + + print("-------------- Testing the trained onn model -------------------") + test(onn_model, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(onn_model) + + + + # test_optical_module_transform_pass() \ No newline at end of file diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py new file mode 100644 index 000000000..de3bbe385 --- /dev/null +++ b/test/passes/module/transforms/optical/train_mnist_cnn.py @@ -0,0 +1,144 @@ +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=29, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--no-mps', action='store_true', default=False, + help='disables macOS GPU training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=True, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model, "mase_output/sample_mnist_cnn.pt") + + +if __name__ == '__main__': + main() From 587154305012a1e220fe9ec6293093f903a2759c Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 26 Jan 2025 20:58:13 +0000 Subject: [PATCH 02/38] add support for onn conv2d layer --- src/chop/nn/optical/modules/morr_conv2d.py | 23 +- .../optical}/module_transform_helper.py | 54 +++- .../module/transforms/optical/optical.py | 8 +- .../transforms/optical/test_optical_module.py | 250 ++++++++++-------- 4 files changed, 209 insertions(+), 126 deletions(-) rename src/chop/passes/module/{ => transforms/optical}/module_transform_helper.py (50%) diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index 83f0bf1fe..87fd8ed42 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -71,16 +71,25 @@ def __init__( dilation: _size = 1, groups: int = 1, bias: bool = True, - miniblock: int = 4, - ### morr parameter - MORRConfig=MORRConfig_20um_MQ, - morr_init: bool = True, # whether to use initialization method customized for MORR - ### trainable MORR nonlinearity - trainable_morr_bias: bool = False, - trainable_morr_scale: bool = False, + padding_mode = None, + # miniblock: int = 4, + # ### morr parameter + # MORRConfig=MORRConfig_20um_MQ, + # morr_init: bool = True, # whether to use initialization method customized for MORR + # ### trainable MORR nonlinearity + # trainable_morr_bias: bool = False, + # trainable_morr_scale: bool = False, + config = None, device: Device = torch.device("cuda"), ) -> None: super(AllPassMORRCirculantConv2d, self).__init__() + miniblock = config.get("miniblock", 4) + MORRConfig = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init = config.get("morr_init", True) + trainable_morr_bias = config.get("trainable_morr_bias", False) + trainable_morr_scale = config.get("trainable_morr_scale", False) + + self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) diff --git a/src/chop/passes/module/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py similarity index 50% rename from src/chop/passes/module/module_transform_helper.py rename to src/chop/passes/module/transforms/optical/module_transform_helper.py index b644b087a..dc110d9b0 100644 --- a/src/chop/passes/module/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -16,8 +16,15 @@ def replace_by_name_optical( return network +def weight_replacement_optical(original, new_module): + if isinstance(original, nn.Linear): + return weight_replacement_linear_optical(original, new_module) + elif isinstance(original, nn.Conv2d): + return weight_replacement_conv2d_optical(original, new_module) + else: + raise NotImplementedError("weight replacement function for the optical module not implemented") -def weight_replacement_optical(x, y): +def weight_replacement_linear_optical(x, y): """ Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). @@ -62,3 +69,48 @@ def weight_replacement_optical(x, y): # y.weight.data.copy_(new_weight) return y + + + +def weight_replacement_conv2d_optical(x, y): + """ + Replace the weights (and bias, if present) of a standard nn.Conv2d (x) + into an AllPassMORRCirculantConv2d (y). + + Args: + x (nn.Conv2d): A standard PyTorch Conv2d module + y (AllPassMORRCirculantConv2d): An already-constructed optical Conv2d + module into which we copy weights/bias. + """ + with torch.no_grad(): + # 1) Copy bias (if both x and y actually have one). + if x.bias is not None and y.bias is not None: + y.bias.copy_(x.bias) + + # 2) Flatten nn.Conv2d's weight => shape [out_channels, in_channels*kernel_h*kernel_w] + w_flat = x.weight.data.view(x.out_channels, -1) + + # 3) Zero-pad to match (out_channels_pad, in_channels_pad) + outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock + inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock + + W = torch.zeros(outC_pad, inC_pad, device=w_flat.device, dtype=w_flat.dtype) + # Copy as many channels/elements as we have + W[: x.out_channels, : w_flat.size(1)] = w_flat + + # 4) Reshape into blocks => shape [p, miniblock, q, miniblock] + p = y.grid_dim_y + q = y.grid_dim_x + k = y.miniblock + W_blocks = W.view(p, k, q, k) # => [p, k, q, k] + + # 5) For each p,q block, extract the "first column" of size 'k' and place it in y.weight + # That is, for a k x k sub-block, we interpret sub_block[:,0] as the "circulant first column". + for i in range(p): + for j in range(q): + sub_block = W_blocks[i, :, j, :] # shape [k, k] + y.weight.data[i, j, :] = sub_block[:, 0] + + # Done. At this point, y.weight and y.bias (if present) have been overwritten + # with a simple block-circulant approximation of x's parameters. + return y \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index da475d778..a91e7403d 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -1,8 +1,8 @@ import torch from chop.nn.optical.modules import optical_module_map -from ...module_modify_helper import replace_by_name, instantiate_module -from ...module_transform_helper import replace_by_name_optical +from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module +from chop.passes.module.transforms.optical.module_transform_helper import replace_by_name_optical def get_config(config: dict, name: str): @@ -36,12 +36,12 @@ def optical_by_type(network, pass_args): def optical_by_name(network, pass_args): - quantize_names = pass_args.keys() + optical_names = pass_args.keys() n_m = {} for n, m in network.named_modules(): n_m[n] = m for n, m in n_m.items(): - if n in quantize_names: + if n in optical_names: quan_config = pass_args[n] quan_config = quan_config["config"] diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 5409e64f5..35c959a70 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -50,17 +50,21 @@ # x = self.fc3(x) # return x -def load_my_model(model_path, device): +def load_my_model(model_path, device="cpu"): # Load the model from the .pt file loaded_model = torch.load(model_path, map_location=device) # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout) loaded_model.eval() return loaded_model -def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): + +def test_optical_module_transform_pass(): + model_path = "mase_output/sample_mnist_cnn.pt" + mnist_cnn = load_my_model(model_path) + # Sanity check and report pass_args = { - "by": "type", - "linear": { + "by": "name", + "fc1": { "config": { "name": "morr", "miniblock": 4, @@ -69,18 +73,7 @@ def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn. "trainable_morr_scale": False, } }, - } - onn_model, _ = optical_module_transform_pass(model, pass_args) - torch.save(onn_model.state_dict(), save_path) - return onn_model - -def test_optical_module_transform_pass(): - model_path = "mase_output/sample_mnist_cnn.pt" - mnist_cnn = load_my_model(model_path) - # Sanity check and report - pass_args = { - "by": "name", - "fc1": { + "conv1": { "config": { "name": "morr", "miniblock": 4, @@ -90,103 +83,132 @@ def test_optical_module_transform_pass(): } }, } - onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args) - torch.save(onn_cnn, "mase_output/onn_cnn.pt") - - - -if __name__ == '__main__': - - if True: - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--no-mps', action='store_true', default=False, - help='disables macOS GPU training') - parser.add_argument('--dry-run', action='store_true', default=False, - help='quickly check a single pass') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=True, - help='For Saving the current Model') - parser.add_argument('--gpu-id', type=int, default=0, - help='Which GPU device to use [default: 0]') - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - use_mps = not args.no_mps and torch.backends.mps.is_available() - - torch.manual_seed(args.seed) - - if not args.no_cuda and torch.cuda.is_available(): - device = torch.device(f"cuda:{args.gpu_id}") - elif use_mps: - device = torch.device("mps") - else: - device = torch.device("cpu") - - train_kwargs = {'batch_size': args.batch_size} - test_kwargs = {'batch_size': args.test_batch_size} - if use_cuda: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('../data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('../data', train=False, - transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - - cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) - print("-------------- Testing the original cnn model -------------------") - test(cnn, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(cnn) - - # onn = load_my_model("mase_output/onn_cnn.pt", device) - onn_model = perform_optical_module_transform_pass(cnn) - onn_model.to(device) - print("-------------- Testing the transformed onn model -------------------") - test(onn_model, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(onn_model) - - - ################################################################## - ######### Training the onn model - ################################################################## - optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - for epoch in range(1, args.epochs + 1): - train(args, onn_model, device, train_loader, optimizer, epoch) - test(onn_model, device, test_loader) - scheduler.step() - - - torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") - - print("-------------- Testing the trained onn model -------------------") - test(onn_model, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(onn_model) + optical_module_transform_pass(mnist_cnn, pass_args) + # torch.save(onn_cnn, "mase_output/onn_cnn.pt") + +test_optical_module_transform_pass() + +# if __name__ == '__main__': +# finetune = False + +# if True: +# parser = argparse.ArgumentParser(description='PyTorch MNIST Example') +# parser.add_argument('--batch-size', type=int, default=64, metavar='N', +# help='input batch size for training (default: 64)') +# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', +# help='input batch size for testing (default: 1000)') +# parser.add_argument('--epochs', type=int, default=14, metavar='N', +# help='number of epochs to train (default: 14)') +# parser.add_argument('--lr', type=float, default=1.0, metavar='LR', +# help='learning rate (default: 1.0)') +# parser.add_argument('--gamma', type=float, default=0.7, metavar='M', +# help='Learning rate step gamma (default: 0.7)') +# parser.add_argument('--no-cuda', action='store_true', default=False, +# help='disables CUDA training') +# parser.add_argument('--no-mps', action='store_true', default=False, +# help='disables macOS GPU training') +# parser.add_argument('--dry-run', action='store_true', default=False, +# help='quickly check a single pass') +# parser.add_argument('--seed', type=int, default=1, metavar='S', +# help='random seed (default: 1)') +# parser.add_argument('--log-interval', type=int, default=10, metavar='N', +# help='how many batches to wait before logging training status') +# parser.add_argument('--save-model', action='store_true', default=True, +# help='For Saving the current Model') +# parser.add_argument('--gpu-id', type=int, default=0, +# help='Which GPU device to use [default: 0]') + +# args = parser.parse_args() +# use_cuda = not args.no_cuda and torch.cuda.is_available() +# use_mps = not args.no_mps and torch.backends.mps.is_available() + +# torch.manual_seed(args.seed) + +# if not args.no_cuda and torch.cuda.is_available(): +# device = torch.device(f"cuda:{args.gpu_id}") +# elif use_mps: +# device = torch.device("mps") +# else: +# device = torch.device("cpu") + +# train_kwargs = {'batch_size': args.batch_size} +# test_kwargs = {'batch_size': args.test_batch_size} +# if use_cuda: +# cuda_kwargs = {'num_workers': 1, +# 'pin_memory': True, +# 'shuffle': True} +# train_kwargs.update(cuda_kwargs) +# test_kwargs.update(cuda_kwargs) + +# transform=transforms.Compose([ +# transforms.ToTensor(), +# transforms.Normalize((0.1307,), (0.3081,)) +# ]) +# dataset1 = datasets.MNIST('../data', train=True, download=True, +# transform=transform) +# dataset2 = datasets.MNIST('../data', train=False, +# transform=transform) +# train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) +# test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + +# # load pre-trained cnn +# cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) +# print("-------------- Testing the original cnn model -------------------") +# _, _ = report_trainable_parameters_analysis_pass(cnn) +# test(cnn, device, test_loader) + +# ## transform cnn into onn + +# # onn = load_my_model("mase_output/onn_cnn.pt", device) +# onn_model = perform_optical_module_transform_pass(cnn) +# onn_model.to(device) +# print("-------------- Testing the transformed onn model -------------------") +# _, _ = report_trainable_parameters_analysis_pass(onn_model) +# test(onn_model, device, test_loader) + +# # Training the onn model +# if finetune: +# optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) +# scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) +# for epoch in range(1, args.epochs + 1): +# train(args, onn_model, device, train_loader, optimizer, epoch) +# test(onn_model, device, test_loader) +# scheduler.step() + + +# torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") + +# print("-------------- Testing the trained onn model -------------------") +# test(onn_model, device, test_loader) +# _, _ = report_trainable_parameters_analysis_pass(onn_model) - - # test_optical_module_transform_pass() \ No newline at end of file + +# def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): +# # pass_args = { +# # "by": "type", +# # "linear": { +# # "config": { +# # "name": "morr", +# # "miniblock": 4, +# # "morr_init": True, +# # "trainable_morr_bias": False, +# # "trainable_morr_scale": False, +# # } +# # }, +# # } +# pass_args = { +# "by": "type", +# "conv2d": { +# "config": { +# "name": "morr", +# "miniblock": 4, +# "morr_init": True, +# "trainable_morr_bias": False, +# "trainable_morr_scale": False, +# } +# }, +# } +# onn_model, _ = optical_module_transform_pass(model, pass_args) +# torch.save(onn_model.state_dict(), save_path) +# return onn_model From 5dc0974ccd54a673a125154f1e5809076bf44aa6 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 26 Jan 2025 21:00:10 +0000 Subject: [PATCH 03/38] black format --- src/chop/nn/optical/__init__.py | 2 +- src/chop/nn/optical/functional/__init__.py | 3 - src/chop/nn/optical/functional/general.py | 40 +++-- src/chop/nn/optical/functional/initializer.py | 5 +- src/chop/nn/optical/functional/mrr_op.py | 56 ++++--- src/chop/nn/optical/functional/quantize.py | 101 +++++++++--- src/chop/nn/optical/functional/torch_train.py | 37 ++--- src/chop/nn/optical/modules/__init__.py | 2 +- src/chop/nn/optical/modules/base_layer.py | 11 +- src/chop/nn/optical/modules/morr_conv2d.py | 100 +++++++++--- src/chop/nn/optical/modules/morr_linear.py | 79 +++++++--- src/chop/passes/module/transforms/__init__.py | 2 +- .../optical/module_transform_helper.py | 55 +++---- .../module/transforms/optical/optical.py | 4 +- .../transforms/optical/test_optical_module.py | 7 +- .../transforms/optical/train_mnist_cnn.py | 146 ++++++++++++------ 16 files changed, 450 insertions(+), 200 deletions(-) diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py index e9c0423d5..b74e7df6b 100644 --- a/src/chop/nn/optical/__init__.py +++ b/src/chop/nn/optical/__init__.py @@ -1,3 +1,3 @@ from .modules import ( optical_module_map, -) \ No newline at end of file +) diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py index 11a435770..84f94e7b6 100644 --- a/src/chop/nn/optical/functional/__init__.py +++ b/src/chop/nn/optical/functional/__init__.py @@ -30,9 +30,6 @@ ) - - - # """ # Description: # Author: Jiaqi Gu (jqgu@utexas.edu) diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py index 1d4f2990d..f5aa87e14 100644 --- a/src/chop/nn/optical/functional/general.py +++ b/src/chop/nn/optical/functional/general.py @@ -70,7 +70,10 @@ def wrapper(*args, **kw): local_time = time.time() res = func(*args, **kw) end_time = time.time() - print("[I] <%s> runtime: %.3f ms" % (func.__name__, (end_time - local_time) * 1000)) + print( + "[I] <%s> runtime: %.3f ms" + % (func.__name__, (end_time - local_time) * 1000) + ) else: res = func(*args, **kw) return res @@ -84,11 +87,13 @@ def print_stat(x, message="", verbose=True): if torch.is_complex(x): x = torch.view_as_real(x) print( - message + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}" + message + + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}" ) elif isinstance(x, np.ndarray): print( - message + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}" + message + + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}" ) @@ -127,7 +132,7 @@ def __enter__(self): return self def _b2mb(self, x): - return x / 2 ** 20 + return x / 2**20 def __exit__(self, *exc): self.end = self._b2mb(torch.cuda.memory_allocated()) @@ -181,13 +186,17 @@ def format(self, record): return formatter.format(record) -def setup_default_logging(default_level=logging.INFO, default_file_level=logging.INFO, log_path=""): +def setup_default_logging( + default_level=logging.INFO, default_file_level=logging.INFO, log_path="" +): console_handler = logging.StreamHandler() console_handler.setFormatter(CustomFormatter()) logging.root.addHandler(console_handler) logging.root.setLevel(default_level) if log_path: - file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_handler = logging.handlers.RotatingFileHandler( + log_path, maxBytes=(1024**2 * 2), backupCount=3 + ) file_formatter = logging.Formatter( "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" ) @@ -197,7 +206,13 @@ def setup_default_logging(default_level=logging.INFO, default_file_level=logging class Logger(object): - def __init__(self, console=True, logfile=None, console_level=logging.INFO, logfile_level=logging.INFO): + def __init__( + self, + console=True, + logfile=None, + console_level=logging.INFO, + logfile_level=logging.INFO, + ): super().__init__() self.logfile = logfile self.console_level = console_level @@ -242,9 +257,16 @@ def critical(self, message): self.logger.critical(message) -def get_logger(name="default", default_level=logging.INFO, default_file_level=logging.INFO, log_path=""): +def get_logger( + name="default", + default_level=logging.INFO, + default_file_level=logging.INFO, + log_path="", +): setup_default_logging( - default_level=default_level, default_file_level=default_file_level, log_path=log_path + default_level=default_level, + default_file_level=default_file_level, + log_path=log_path, ) return logging.getLogger(name) diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/functional/initializer.py index 53002c398..f0591c33a 100644 --- a/src/chop/nn/optical/functional/initializer.py +++ b/src/chop/nn/optical/functional/initializer.py @@ -5,6 +5,7 @@ LastEditors: Jiaqi Gu (jqgu@utexas.edu) LastEditTime: 2021-06-06 01:57:18 """ + import numpy as np import torch @@ -83,7 +84,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): """ morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * MORRConfig.radius * MORRConfig.effective_index * ( @@ -124,7 +125,7 @@ def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1): """ morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * MORRConfig.radius * MORRConfig.effective_index * ( diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/functional/mrr_op.py index 42952971e..e98dc4fd9 100644 --- a/src/chop/nn/optical/functional/mrr_op.py +++ b/src/chop/nn/optical/functional/mrr_op.py @@ -67,7 +67,9 @@ def mrr_tr_to_roundtrip_phase(t, a, r): assert 0 <= a <= 1, logging.error(f"Expect a from [0,1] but got {a}") assert 0 <= r <= 1, logging.error(f"Expect r from [0,1] but got {r}") # given a and r, the curve is fixed, the max and min may not be 1 and 0 - cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp(0, 1) + cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp( + 0, 1 + ) if isinstance(cos_phi, torch.Tensor): return cos_phi.acos(), cos_phi @@ -120,7 +122,9 @@ def mrr_roundtrip_phase_to_tr( @torch.jit.script -def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False): +def mrr_roundtrip_phase_to_tr_fused( + rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False +): """ description: round trip phase shift to field transmission rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ @@ -151,7 +155,9 @@ def mrr_roundtrip_phase_to_tr_fused(rt_phi, a: float = 0.8, r: float = 0.9, inte @torch.jit.script -def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False): +def mrr_roundtrip_phase_to_tr_grad_fused( + rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False +): """ description: round trip phase shift to the gradient of field transmission rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ @@ -161,22 +167,24 @@ def mrr_roundtrip_phase_to_tr_grad_fused(rt_phi, a: float = 0.8, r: float = 0.9, return g {torch.Tensor or np.ndarray} the gradient of mrr through port field/intensity transmission """ if not intensity: - g = (a * r * (a ** 2 - 1) * (r ** 2 - 1) * rt_phi.sin()) / ( - (a ** 2 + r ** 2 - 2 * a * r * rt_phi.cos()) ** (1 / 2) - * (a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5 + g = (a * r * (a**2 - 1) * (r**2 - 1) * rt_phi.sin()) / ( + (a**2 + r**2 - 2 * a * r * rt_phi.cos()) ** (1 / 2) + * (a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5 ) else: - g = ((a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r * rt_phi.sin()) / ( - a ** 2 * r ** 2 + 1 - 2 * a * r * rt_phi.cos() + g = ((a**2 - 1) * (r**2 - 1) * 2 * a * r * rt_phi.sin()) / ( + a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos() ) ** 2 return g -def mrr_roundtrip_phase_to_tr_func(a: float = 0.8, r: float = 0.9, intensity: bool = False): +def mrr_roundtrip_phase_to_tr_func( + a: float = 0.8, r: float = 0.9, intensity: bool = False +): c1 = -2 * a * r c2 = a * a + r * r c3 = 1 + r * r * a * a - a * a - r * r - c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r + c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r class MRRRoundTripPhaseToTrFunction(torch.autograd.Function): @staticmethod @@ -202,7 +210,9 @@ def backward(ctx, grad_output): numerator = input.sin().mul_(c4) else: numerator = input.sin().mul_(c4 / 2) - denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + denominator = ( + denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + ) grad_input = numerator.div_(denominator).mul_(grad_output) return grad_input @@ -334,7 +344,9 @@ def mrr_filter(x, t, a=0.9, r=0.8): return out -def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False): +def morr_filter( + rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False +): """ description: from round trip phase shift to output signal \\ rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\ @@ -349,16 +361,22 @@ def morr_filter(rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False if not coherent: if x is None: # unit laser input with incoherent light, 1e^j0 - t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity) + t = mrr_roundtrip_phase_to_tr( + rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity + ) return t else: # incoherent light with non-unit input, input must be real number - t = mrr_roundtrip_phase_to_tr(rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity) + t = mrr_roundtrip_phase_to_tr( + rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity + ) return x * t else: if x is None: # coherent light with unit laser, 1e^j0, treat morr as a mrr modulator - phase = polar_to_complex(mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r)) + phase = polar_to_complex( + mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r) + ) return phase else: # coherent light with complex input @@ -375,7 +393,9 @@ def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm): fwhm {float} bandwidth or full width half maximum (unit: nm)\\ return n_g {float} Group index of the MRR """ - n_g = (1 - r * a) * lambda0 ** 2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm) + n_g = ( + (1 - r * a) * lambda0**2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm) + ) return n_g @@ -388,7 +408,7 @@ def mrr_ng_to_fsr(lambda0, n_g, radius): radius {float} Radius of the MRR (unit: nm)\\ return fsr {float} Free-spectral range """ - fsr = lambda0 ** 2 / (n_g * 2 * np.pi * radius) + fsr = lambda0**2 / (n_g * 2 * np.pi * radius) return fsr @@ -400,5 +420,5 @@ def mrr_finesse(a, r): return finesse {float} Finesse of the MRR """ ra = r * a - finesse = np.pi * ra ** 0.5 / (1 - ra) + finesse = np.pi * ra**0.5 / (1 - ra) return finesse diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/functional/quantize.py index 09b7ef28d..f60e8b240 100644 --- a/src/chop/nn/optical/functional/quantize.py +++ b/src/chop/nn/optical/functional/quantize.py @@ -82,7 +82,14 @@ def forward(ctx, input, scale, zero_point): n = float(2**k - 1) # out = torch.round(input * n) / n # out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale - out = input.div(scale).add_(zero_point).round_().clamp_(0, n).sub_(zero_point).mul_(scale) + out = ( + input.div(scale) + .add_(zero_point) + .round_() + .clamp_(0, n) + .sub_(zero_point) + .mul_(scale) + ) return out @staticmethod @@ -108,7 +115,7 @@ class EWGS_quantizer(torch.autograd.Function): @staticmethod def forward(ctx, input): - out = input.mul(num_levels - 1).round_().mul_(1/(num_levels - 1)) + out = input.mul(num_levels - 1).round_().mul_(1 / (num_levels - 1)) ctx._scaling_factor = scaling_factor ctx.save_for_backward(input - out) @@ -128,7 +135,9 @@ def backward(ctx, grad_output): class input_quantize_fn(torch.nn.Module): - def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0): + def __init__( + self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0 + ): """Input quantizer with Quant_Noise supported Args: in_bit (int): Input quantization bitwidth. @@ -139,9 +148,14 @@ def __init__(self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ra assert 1 <= in_bit <= 32 self.in_bit = in_bit self.alg = alg - assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}" + assert alg in { + "dorefa", + "normal", + }, f"Only support (dorefa, normal), but got {alg}" self.quant_ratio = quant_ratio - assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}" + ) self.device = device # define quant style @@ -182,7 +196,10 @@ def set_bitwidth(self, bit: int) -> None: self.in_bit = bit def set_alg(self, alg: str) -> None: - assert alg in {"dorefa", "normal"}, f"Only support (dorefa, normal), but got {alg}" + assert alg in { + "dorefa", + "normal", + }, f"Only support (dorefa, normal), but got {alg}" if alg != self.alg: if alg == "dorefa": self.uniform_q = uniform_quantize(k=self.in_bit) @@ -212,14 +229,18 @@ def set_quant_ratio(self, quant_ratio=None): 0.99, 1, ][min(self.in_bit, 16)] - assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}" + ) self.quant_ratio = quant_ratio def forward(self, x): if self.quant_ratio < 1 and self.training: ### implementation from fairseq ### must fully quantize during inference - quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio) + quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_( + 1 - self.quant_ratio + ) else: quant_noise_mask = None @@ -271,7 +292,9 @@ def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0): quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0. """ super(weight_quantize_fn, self).__init__() - assert 1 <= w_bit <= 32, logging.error(f"Only support 1 - 32 bit quantization, but got {w_bit}") + assert 1 <= w_bit <= 32, logging.error( + f"Only support 1 - 32 bit quantization, but got {w_bit}" + ) self.w_bit = w_bit self.alg = alg self.mode = mode @@ -279,7 +302,9 @@ def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0): f"Only support (dorefa, dorefa_sym, qnn, dorefa_pos) algorithms, but got {alg}" ) self.quant_ratio = quant_ratio - assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}" + ) self.uniform_q = uniform_quantize(k=w_bit, gradient_clip=True) def set_quant_ratio(self, quant_ratio=None): @@ -304,7 +329,9 @@ def set_quant_ratio(self, quant_ratio=None): 0.99, 1, ][min(self.w_bit, 16)] - assert 0 <= quant_ratio <= 1, logging.error(f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}") + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}" + ) self.quant_ratio = quant_ratio def set_bitwidth(self, bit: int) -> None: @@ -317,7 +344,9 @@ def forward(self, x): if self.quant_ratio < 1 and self.training: ### implementation from fairseq ### must fully quantize during inference - quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(1 - self.quant_ratio) + quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_( + 1 - self.quant_ratio + ) else: quant_noise_mask = None @@ -333,14 +362,18 @@ def forward(self, x): weight_q = (self.uniform_q(x / E) * E + E) / 2 # [0, E] if quant_noise_mask is not None: x = (x + E) / 2 - noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + noise = weight_q.data.sub_(x.data).masked_fill_( + quant_noise_mask, 0 + ) ### unquantized weights have to follow reparameterization, i.e., tanh and scale weight_q = x + noise elif self.alg == "dorefa_sym": E = x.data.abs().mean() weight_q = self.uniform_q(x / E) * E # [-E, E] if quant_noise_mask is not None: - noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + noise = weight_q.data.sub_(x.data).masked_fill_( + quant_noise_mask, 0 + ) ### unquantized weights have to follow reparameterization, i.e., tanh and scale weight_q = x + noise else: @@ -352,7 +385,9 @@ def forward(self, x): # weight = weight / 2 + 0.5 weight_q = self.uniform_q(weight) if quant_noise_mask is not None: - noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) ### unquantized weights have to follow reparameterization, i.e., tanh and scale weight_q = weight + noise @@ -362,7 +397,9 @@ def forward(self, x): # weight = weight / 2 + 0.5 weight_q = self.uniform_q(weight / (2 * r) + 0.5) * (2 * r) - r if quant_noise_mask is not None: - noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) ### unquantized weights have to follow reparameterization, i.e., tanh weight_q = weight + noise elif self.alg == "dorefa_pos": @@ -372,7 +409,9 @@ def forward(self, x): # weight = weight / 2 + 0.5 weight_q = self.uniform_q(weight / (2 * r)) * 2 * r if quant_noise_mask is not None: - noise = weight_q.data.sub_(weight.data).masked_fill_(quant_noise_mask, 0) + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) ### unquantized weights have to follow reparameterization, i.e., tanh weight_q = weight + noise @@ -484,7 +523,9 @@ def __init__( super(PACT_Act, self).__init__() self.precision = precision self.device = device - self.alpha = torch.nn.Parameter(torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha) + self.alpha = torch.nn.Parameter( + torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha + ) self.alpha_p = alpha self.statistics_only = statistics_only self.deployment = False @@ -493,10 +534,18 @@ def __init__( # self.requantization_factor = requantization_factor # these are only used to gather statistics - self.max = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) - self.min = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) - self.running_mean = torch.nn.Parameter(torch.zeros_like(self.alpha.data).to(device), requires_grad=False) - self.running_var = torch.nn.Parameter(torch.ones_like(self.alpha.data).to(device), requires_grad=False) + self.max = torch.nn.Parameter( + torch.zeros_like(self.alpha.data).to(device), requires_grad=False + ) + self.min = torch.nn.Parameter( + torch.zeros_like(self.alpha.data).to(device), requires_grad=False + ) + self.running_mean = torch.nn.Parameter( + torch.zeros_like(self.alpha.data).to(device), requires_grad=False + ) + self.running_var = torch.nn.Parameter( + torch.ones_like(self.alpha.data).to(device), requires_grad=False + ) self.precise = False @@ -508,7 +557,9 @@ def set_static_precision(self, limit_at_32_bits=True, **kwargs): self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1) self.alpha_static = self.alpha.clone().detach() # D is selected as a power-of-two - D = 2.0 ** torch.ceil(torch.log2(self.requantization_factor * self.eps_static / self.eps_in)) + D = 2.0 ** torch.ceil( + torch.log2(self.requantization_factor * self.eps_static / self.eps_in) + ) if not limit_at_32_bits: self.D = D else: @@ -568,7 +619,9 @@ def forward(self, x): self.max[:] = max(self.max.item(), x.max()) self.min[:] = min(self.min.item(), x.min()) self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean() - self.running_var[:] = 0.9 * self.running_var.item() + 0.1 * x.std() * x.std() + self.running_var[:] = ( + 0.9 * self.running_var.item() + 0.1 * x.std() * x.std() + ) return x else: eps = self.alpha / (2.0 ** (self.precision) - 1) diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/functional/torch_train.py index c62ba3e91..41571effa 100644 --- a/src/chop/nn/optical/functional/torch_train.py +++ b/src/chop/nn/optical/functional/torch_train.py @@ -17,6 +17,7 @@ import torch from scipy import interpolate from torch.nn.modules.batchnorm import _BatchNorm + try: from torchsummary import summary except: @@ -48,26 +49,26 @@ "enable_bn", ] -class DeterministicCtx: - def __init__(self, random_state: int | None = None) -> None: - self.random_state = random_state - - - def __enter__(self): - self.random_state = random.getstate() - self.numpy_random_state = np.random.get_state() - self.torch_random_state = torch.random.get_rng_state() - self.torch_cuda_random_state = torch.cuda.get_rng_state() - set_torch_deterministic(self.random_state) - return self +class DeterministicCtx: + def __init__(self, random_state: int | None = None) -> None: + self.random_state = random_state + + def __enter__(self): + self.random_state = random.getstate() + self.numpy_random_state = np.random.get_state() + self.torch_random_state = torch.random.get_rng_state() + self.torch_cuda_random_state = torch.cuda.get_rng_state() + set_torch_deterministic(self.random_state) + return self + + def __exit__(self, *args): + random.setstate(self.random_state) + np.random.seed(self.numpy_random_state) + np.random.set_state(self.numpy_random_state) + torch.random.set_rng_state(self.torch_random_state) + torch.cuda.set_rng_state(self.torch_cuda_random_state) - def __exit__(self, *args): - random.setstate(self.random_state) - np.random.seed(self.numpy_random_state) - np.random.set_state(self.numpy_random_state) - torch.random.set_rng_state(self.torch_random_state) - torch.cuda.set_rng_state(self.torch_cuda_random_state) def set_torch_deterministic(random_state: int = 0) -> None: random_state = int(random_state) % (2**32) diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index 46b6f21e9..b1d7c5629 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -4,4 +4,4 @@ optical_module_map = { "linear_morr": AllPassMORRCirculantLinear, "conv2d_morr": AllPassMORRCirculantConv2d, -} \ No newline at end of file +} diff --git a/src/chop/nn/optical/modules/base_layer.py b/src/chop/nn/optical/modules/base_layer.py index cb1fbe9dd..4ae2b35bf 100644 --- a/src/chop/nn/optical/modules/base_layer.py +++ b/src/chop/nn/optical/modules/base_layer.py @@ -5,6 +5,7 @@ LastEditors: Jiaqi Gu (jqgu@utexas.edu) LastEditTime: 2021-06-08 18:55:05 """ + from typing import Any, Dict, Optional import torch from torch import nn @@ -24,7 +25,7 @@ def build_parameters(self) -> None: def reset_parameters(self) -> None: raise NotImplementedError - + @classmethod def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module: raise NotImplementedError @@ -38,10 +39,14 @@ def enable_fast_forward(self) -> None: def disable_fast_forward(self) -> None: self.fast_forward_flag = False - def set_phase_variation(self, noise_std: float, random_state: Optional[int] = None) -> None: + def set_phase_variation( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: self.phase_noise_std = noise_std - def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: self.gamma_noise_std = noise_std def set_crosstalk_factor(self, crosstalk_factor: float) -> None: diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index 87fd8ed42..a554bc6c2 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -4,6 +4,7 @@ LastEditors: Jiaqi Gu (jqgu@utexas.edu) LastEditTime: 2021-07-18 00:40:18 """ + from typing import Optional, Tuple import numpy as np @@ -71,7 +72,7 @@ def __init__( dilation: _size = 1, groups: int = 1, bias: bool = True, - padding_mode = None, + padding_mode=None, # miniblock: int = 4, # ### morr parameter # MORRConfig=MORRConfig_20um_MQ, @@ -79,7 +80,7 @@ def __init__( # ### trainable MORR nonlinearity # trainable_morr_bias: bool = False, # trainable_morr_scale: bool = False, - config = None, + config=None, device: Device = torch.device("cuda"), ) -> None: super(AllPassMORRCirculantConv2d, self).__init__() @@ -89,7 +90,6 @@ def __init__( trainable_morr_bias = config.get("trainable_morr_bias", False) trainable_morr_scale = config.get("trainable_morr_scale", False) - self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) @@ -97,8 +97,12 @@ def __init__( self.padding = _pair(padding) self.dilation = _pair(dilation) self.groups = groups - assert groups == 1, f"Currently group convolution is not supported, but got group: {groups}" - self.in_channels_flat = self.in_channels * self.kernel_size[0] * self.kernel_size[1] + assert ( + groups == 1 + ), f"Currently group convolution is not supported, but got group: {groups}" + self.in_channels_flat = ( + self.in_channels * self.kernel_size[0] * self.kernel_size[1] + ) self.grid_dim_x = int(np.ceil(self.in_channels_flat / miniblock)) self.grid_dim_y = int(np.ceil(self.out_channels / miniblock)) self.in_channels_pad = self.grid_dim_x * miniblock @@ -107,7 +111,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi ** 2 + self.gamma = np.pi / self.v_pi**2 self.w_bit = 32 self.in_bit = 32 self.MORRConfig = MORRConfig @@ -121,7 +125,7 @@ def __init__( ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * MORRConfig.radius * MORRConfig.effective_index * ( @@ -135,7 +139,9 @@ def __init__( self.x_zero_pad = None self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) self.morr_gain = ( 100 / (self.in_channels_flat // self.miniblock) ) ** 0.5 ## set this TIA gain such that output variance is around 1 @@ -178,21 +184,37 @@ def build_parameters(self) -> None: ### MORR weights self.weight = Parameter( torch.ones( - self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, ) ) ### learnable balancing factor achieved by MRRs (morr_output_scale) ### We use a single scaling factor for each block - self.morr_output_scale = Parameter(torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device)) + self.morr_output_scale = Parameter( + torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device) + ) if self.trainable_morr_bias: ### initialize with the finest-granularity, i.e., per mini-block self.morr_input_bias = Parameter( - torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) ) if self.trainable_morr_scale: ### initialize with the finest-granularity, i.e., per mini-block self.morr_input_scale = Parameter( - torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) ) def reset_parameters(self, morr_init: bool = False) -> None: @@ -212,11 +234,16 @@ def reset_parameters(self, morr_init: bool = False) -> None: torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True ) t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, ) - g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) @@ -260,7 +287,9 @@ def enable_fast_forward(self) -> None: def disable_fast_forward(self) -> None: self.fast_forward_flag = False - def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: self.gamma_noise_std = noise_std def load_parameters(self, param_dict) -> None: @@ -284,7 +313,9 @@ def input_modulator(self, x: Tensor) -> Tensor: ### voltage to power, which is proportional to the phase shift return x * x - def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None: + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. ### See SqueezeLight paper ### drop-perc is the pruning percentage. @@ -292,7 +323,9 @@ def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float f"Coupling factor must in [0,1], but got {coupling_factor}" ) - self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) def enable_crosstalk(self) -> None: self.enable_thermal_crosstalk = True @@ -327,7 +360,9 @@ def morr_scale(self) -> Tensor: @property def morr_bias(self) -> Tensor: - return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: """ @@ -352,7 +387,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: x = weight.matmul(x).squeeze(-1) # [h*w*bs, p, q, k] if self.enable_phase_noise and self.phase_noise_std > 1e-5: - x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) # [h*w*bs, p, q, k] + x = x + torch.zeros_like(x).normal_( + 0, self.phase_noise_std + ) # [h*w*bs, p, q, k] ### input scaling, learnable MORR nonlinearity if self.trainable_morr_scale: @@ -370,7 +407,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: if self.w_bit < 16: morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) if self.out_scale_quant_gain is None: - self.out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + self.out_scale_quant_gain = ( + self.sigma_out_scale / morr_output_scale.data.std().item() + ) morr_output_scale = morr_output_scale.mul( self.out_scale_quant_gain ) ### gain factor from Tanh used in quantization @@ -392,7 +431,9 @@ def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: scale = scale_pad scale = scale.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, q] - x = scale.matmul(x) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k] + x = scale.matmul( + x + ) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k] x = x.view(x.size(0), -1).t() # [p*k, h_out*w_out*bs] if self.out_channels_pad > self.out_channels: x = x[: self.out_channels, :] # [outc, h_out*w_out*bs] @@ -407,7 +448,12 @@ def morr_conv2d(self, X: Tensor, W: Tensor) -> Tensor: X, stride=self.stride[0], padding=self.padding[0], - w_size=(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]), + w_size=( + self.out_channels, + self.in_channels, + self.kernel_size[0], + self.kernel_size[1], + ), ) ## zero-padding X_col if self.in_channels_pad > self.in_channels_flat: @@ -450,8 +496,12 @@ def get_output_dim(self, img_height: int, img_width: int) -> Tuple[int, int]: """ get the output features size """ - h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1 - w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] + 1 + h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[ + 0 + ] + 1 + w_out = (img_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[ + 1 + ] + 1 return (int(h_out), int(w_out)) def forward(self, x: Tensor) -> Tensor: diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index 4846fb763..9a3f483e5 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -5,6 +5,7 @@ LastEditors: Jiaqi Gu (jqgu@utexas.edu) LastEditTime: 2022-04-18 16:21:37 """ + from typing import Optional import numpy as np @@ -43,7 +44,7 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - config = None, + config=None, # miniblock: int = 4, # ### mrr parameter # MORRConfig=MORRConfig_20um_MQ, @@ -66,7 +67,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi ** 2 + self.gamma = np.pi / self.v_pi**2 self.w_bit = 32 self.in_bit = 32 @@ -77,12 +78,14 @@ def __init__( self.mrr_a = morr_config.attenuation_factor self.mrr_r = morr_config.coupling_factor self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) - self.trainable_morr_scale = config.get("trainable_morr_scale", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) self.device = device ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * morr_config.radius * morr_config.effective_index * ( @@ -96,7 +99,9 @@ def __init__( self.x_zero_pad = None self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = None ## scaling factor for the round-trip phase shift within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) self.morr_gain = ( 100 / (self.in_features // self.miniblock) ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 @@ -137,7 +142,11 @@ def build_parameters(self) -> None: self.weight = Parameter( torch.ones( - self.grid_dim_y, self.grid_dim_x, self.miniblock, device=self.device, dtype=torch.float + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, ) ) ### Learnable balancing factor (morr_output_scale) @@ -148,12 +157,22 @@ def build_parameters(self) -> None: if self.trainable_morr_bias: ### initialize with the finest-granularity, i.e., per mini-block self.morr_input_bias = Parameter( - torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) ) if self.trainable_morr_scale: ### initialize with the finest-granularity, i.e., per mini-block self.morr_input_scale = Parameter( - torch.zeros(self.grid_dim_y, self.grid_dim_x, device=self.device, dtype=torch.float) + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) ) def reset_parameters(self, morr_init: bool = False) -> None: @@ -175,11 +194,16 @@ def reset_parameters(self, morr_init: bool = False) -> None: torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True ) t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, ) - g = ((t2 - t1) / (2.4 * self.morr_fwhm)).item() ## 0~2.4 FWHM slope as a linear approximation + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) else: @@ -218,13 +242,17 @@ def build_weight(self) -> Tensor: morr_scale = self.morr_scale * self.weight_quant_gain else: morr_scale = self.weight_quant_gain - weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization ### quantize learnable balancing factor morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) else: weight = self.weight.abs() # positive only - morr_output_scale = self.morr_output_scale - self.morr_output_scale.data.mean() + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) if self.finegrain_drop_mask is not None: weight = weight.mul(self.finegrain_drop_mask.float()) @@ -251,7 +279,9 @@ def enable_fast_forward(self) -> None: def disable_fast_forward(self) -> None: self.fast_forward_flag = False - def set_gamma_noise(self, noise_std: float, random_state: Optional[int] = None) -> None: + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: self.gamma_noise_std = noise_std def load_parameters(self, param_dict) -> None: @@ -275,14 +305,18 @@ def input_modulator(self, x: Tensor) -> Tensor: ### voltage to power, which is proportional to the phase shift return x * x - def set_crosstalk_coupling_matrix(self, coupling_factor: float, drop_perc: float = 0) -> None: + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. ### drop-perc is the pruning percentage. assert 0 <= coupling_factor <= 1, logger.error( f"Coupling factor must in [0,1], but got {coupling_factor}" ) - self.crosstalk_factor = 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) def enable_crosstalk(self) -> None: self.enable_thermal_crosstalk = True @@ -316,7 +350,9 @@ def morr_bias(self) -> Tensor: if self.morr_input_bias is None: return None # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) - return self.morr_fwhm * torch.tanh(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) @property def morr_scale(self) -> Tensor: @@ -324,7 +360,9 @@ def morr_scale(self) -> Tensor: return None return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] - def propagate_morr(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor: + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: """ @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators @@ -420,7 +458,10 @@ def forward(self, x: Tensor) -> Tensor: if self.in_features_pad > self.in_features: if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): self.x_zero_pad = torch.zeros( - x.size(0), self.in_features_pad - self.in_features, device=x.device, dtype=x.dtype + x.size(0), + self.in_features_pad - self.in_features, + device=x.device, + dtype=x.dtype, ) x = torch.cat([x, self.x_zero_pad], dim=1) diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 8773589ec..caaff1b65 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,4 +1,4 @@ from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass from .autosharding import resharding_transform_pass -from .optical import optical_module_transform_pass \ No newline at end of file +from .optical import optical_module_transform_pass diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index dc110d9b0..1cef06928 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -2,13 +2,13 @@ import torch.nn as nn import numpy as np from chop.nn.optical.modules import optical_module_map -from chop.passes.module.module_modify_helper import get_module_by_name, set_module_by_name +from chop.passes.module.module_modify_helper import ( + get_module_by_name, + set_module_by_name, +) -def replace_by_name_optical( - network, - module_name: str, - new_module -): + +def replace_by_name_optical(network, module_name: str, new_module): original = get_module_by_name(network, module_name) updated_module = weight_replacement_optical(original, new_module) @@ -16,60 +16,63 @@ def replace_by_name_optical( return network + def weight_replacement_optical(original, new_module): if isinstance(original, nn.Linear): return weight_replacement_linear_optical(original, new_module) elif isinstance(original, nn.Conv2d): return weight_replacement_conv2d_optical(original, new_module) else: - raise NotImplementedError("weight replacement function for the optical module not implemented") + raise NotImplementedError( + "weight replacement function for the optical module not implemented" + ) + def weight_replacement_linear_optical(x, y): """ - Replace the weights of AllPassMORRCirculantLinear (y) + Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). Focuses only on weight copying (no bias copying). """ # Fetch original linear weight [out_features, in_features] W = x.weight.data # shape: (out_features, in_features) - + # Grab dimensions and zero-pad if needed - out_features_pad = y.out_features_pad # padded out_features in y - in_features_pad = y.in_features_pad # padded in_features in y - miniblock = y.miniblock - grid_dim_y = y.grid_dim_y - grid_dim_x = y.grid_dim_x - + out_features_pad = y.out_features_pad # padded out_features in y + in_features_pad = y.in_features_pad # padded in_features in y + miniblock = y.miniblock + grid_dim_y = y.grid_dim_y + grid_dim_x = y.grid_dim_x + # Construct padded weight tensor W_padded = W.new_zeros((out_features_pad, in_features_pad)) W_padded[: W.size(0), : W.size(1)] = W # copy original into top-left - + # Now we create a new tensor of shape [grid_dim_y, grid_dim_x, miniblock] # by compressing each row-block [1 x miniblock] from W_padded into a single scalar. # This is a simple example that takes the mean across the miniblock slice. new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) - + # Fill new_weight by averaging the corresponding sub-blocks in W_padded with torch.no_grad(): for p in range(grid_dim_y): for q in range(grid_dim_x): for k in range(miniblock): # The row in W_padded we look at: - row_idx = p * miniblock + k + row_idx = p * miniblock + k # The columns we look at: col_start = q * miniblock - col_end = (q + 1) * miniblock - + col_end = (q + 1) * miniblock + block = W_padded[row_idx, col_start:col_end] new_weight[p, q, k] = block.mean() - + # Copy the result into y.weight y.load_parameters({"weight": new_weight}) # y.weight.data.copy_(new_weight) - - return y + return y def weight_replacement_conv2d_optical(x, y): @@ -91,8 +94,8 @@ def weight_replacement_conv2d_optical(x, y): w_flat = x.weight.data.view(x.out_channels, -1) # 3) Zero-pad to match (out_channels_pad, in_channels_pad) - outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock - inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock + outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock + inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock W = torch.zeros(outC_pad, inC_pad, device=w_flat.device, dtype=w_flat.dtype) # Copy as many channels/elements as we have @@ -113,4 +116,4 @@ def weight_replacement_conv2d_optical(x, y): # Done. At this point, y.weight and y.bias (if present) have been overwritten # with a simple block-circulant approximation of x's parameters. - return y \ No newline at end of file + return y diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index a91e7403d..b911c070d 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -2,7 +2,9 @@ from chop.nn.optical.modules import optical_module_map from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module -from chop.passes.module.transforms.optical.module_transform_helper import replace_by_name_optical +from chop.passes.module.transforms.optical.module_transform_helper import ( + replace_by_name_optical, +) def get_config(config: dict, name: str): diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 35c959a70..45b540273 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -50,6 +50,7 @@ # x = self.fc3(x) # return x + def load_my_model(model_path, device="cpu"): # Load the model from the .pt file loaded_model = torch.load(model_path, map_location=device) @@ -86,6 +87,7 @@ def test_optical_module_transform_pass(): optical_module_transform_pass(mnist_cnn, pass_args) # torch.save(onn_cnn, "mase_output/onn_cnn.pt") + test_optical_module_transform_pass() # if __name__ == '__main__': @@ -175,15 +177,14 @@ def test_optical_module_transform_pass(): # test(onn_model, device, test_loader) # scheduler.step() - + # torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") - + # print("-------------- Testing the trained onn model -------------------") # test(onn_model, device, test_loader) # _, _ = report_trainable_parameters_analysis_pass(onn_model) - # def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): # # pass_args = { # # "by": "type", diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py index de3bbe385..3d32593d5 100644 --- a/test/passes/module/transforms/optical/train_mnist_cnn.py +++ b/test/passes/module/transforms/optical/train_mnist_cnn.py @@ -43,9 +43,15 @@ def train(args, model, device, train_loader, optimizer, epoch): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) if args.dry_run: break @@ -58,42 +64,95 @@ def test(model, device, test_loader): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss( + output, target, reduction="sum" + ).item() # sum up batch loss + pred = output.argmax( + dim=1, keepdim=True + ) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=29, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--no-mps', action='store_true', default=False, - help='disables macOS GPU training') - parser.add_argument('--dry-run', action='store_true', default=False, - help='quickly check a single pass') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=True, - help='For Saving the current Model') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=29, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--no-mps", + action="store_true", + default=False, + help="disables macOS GPU training", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=True, + help="For Saving the current Model", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() use_mps = not args.no_mps and torch.backends.mps.is_available() @@ -107,24 +166,19 @@ def main(): else: device = torch.device("cpu") - train_kwargs = {'batch_size': args.batch_size} - test_kwargs = {'batch_size': args.test_batch_size} + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} if use_cuda: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True} + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('../data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('../data', train=False, - transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) model = Net().to(device) @@ -140,5 +194,5 @@ def main(): torch.save(model, "mase_output/sample_mnist_cnn.pt") -if __name__ == '__main__': +if __name__ == "__main__": main() From bd44a07035d04791254772ac7dae8c1cf54d762a Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 26 Jan 2025 22:13:58 +0000 Subject: [PATCH 04/38] fix optical test script --- .../transforms/optical/test_optical_module.py | 64 +++--- .../transforms/optical/train_mnist_cnn.py | 198 ------------------ 2 files changed, 37 insertions(+), 225 deletions(-) delete mode 100644 test/passes/module/transforms/optical/train_mnist_cnn.py diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 45b540273..f52acd559 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -26,29 +26,30 @@ from train_mnist_cnn import test, train, Net -# -------------------------------------------------- -# Model specifications -# -------------------------------------------------- -# class MLP(torch.nn.Module): -# """ -# Toy quantized FC model for digit recognition on MNIST -# """ - -# def __init__(self) -> None: -# super().__init__() - -# self.fc1 = nn.Linear(28 * 28, 28 * 28) -# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4) -# self.fc3 = nn.Linear(28 * 28 * 4, 10) - -# def forward(self, x): -# x = torch.flatten(x, start_dim=1, end_dim=-1) -# x = torch.nn.functional.relu(self.fc1(x)) -# # w = torch.randn((4, 28 * 28)) -# # x = torch.nn.functional.relu(nn.functional.linear(x, w)) -# x = torch.nn.functional.relu(self.fc2(x)) -# x = self.fc3(x) -# return x +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output def load_my_model(model_path, device="cpu"): @@ -60,8 +61,9 @@ def load_my_model(model_path, device="cpu"): def test_optical_module_transform_pass(): - model_path = "mase_output/sample_mnist_cnn.pt" - mnist_cnn = load_my_model(model_path) + # model_path = "mase_output/sample_mnist_cnn.pt" + # mnist_cnn = load_my_model(model_path) + model = Net() # Sanity check and report pass_args = { "by": "name", @@ -84,12 +86,20 @@ def test_optical_module_transform_pass(): } }, } - optical_module_transform_pass(mnist_cnn, pass_args) + optical_module_transform_pass(model, pass_args) # torch.save(onn_cnn, "mase_output/onn_cnn.pt") - test_optical_module_transform_pass() + + + + + + + + + # if __name__ == '__main__': # finetune = False diff --git a/test/passes/module/transforms/optical/train_mnist_cnn.py b/test/passes/module/transforms/optical/train_mnist_cnn.py deleted file mode 100644 index 3d32593d5..000000000 --- a/test/passes/module/transforms/optical/train_mnist_cnn.py +++ /dev/null @@ -1,198 +0,0 @@ -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -from torch.optim.lr_scheduler import StepLR - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - if args.dry_run: - break - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.nll_loss( - output, target, reduction="sum" - ).item() # sum up batch loss - pred = output.argmax( - dim=1, keepdim=True - ) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)", - ) - parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)", - ) - parser.add_argument( - "--epochs", - type=int, - default=29, - metavar="N", - help="number of epochs to train (default: 14)", - ) - parser.add_argument( - "--lr", - type=float, - default=1.0, - metavar="LR", - help="learning rate (default: 1.0)", - ) - parser.add_argument( - "--gamma", - type=float, - default=0.7, - metavar="M", - help="Learning rate step gamma (default: 0.7)", - ) - parser.add_argument( - "--no-cuda", action="store_true", default=False, help="disables CUDA training" - ) - parser.add_argument( - "--no-mps", - action="store_true", - default=False, - help="disables macOS GPU training", - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=False, - help="quickly check a single pass", - ) - parser.add_argument( - "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" - ) - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument( - "--save-model", - action="store_true", - default=True, - help="For Saving the current Model", - ) - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - use_mps = not args.no_mps and torch.backends.mps.is_available() - - torch.manual_seed(args.seed) - - if use_cuda: - device = torch.device("cuda") - elif use_mps: - device = torch.device("mps") - else: - device = torch.device("cpu") - - train_kwargs = {"batch_size": args.batch_size} - test_kwargs = {"batch_size": args.test_batch_size} - if use_cuda: - cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ) - dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) - dataset2 = datasets.MNIST("../data", train=False, transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - - model = Net().to(device) - optimizer = optim.Adadelta(model.parameters(), lr=args.lr) - - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - test(model, device, test_loader) - scheduler.step() - - if args.save_model: - torch.save(model, "mase_output/sample_mnist_cnn.pt") - - -if __name__ == "__main__": - main() From 913d8faeb8f8eb85340bd79b460c2995a7327108 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 26 Jan 2025 22:47:57 +0000 Subject: [PATCH 05/38] black format fixed --- .../module/transforms/optical/test_optical_module.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index f52acd559..2041b9ccf 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -26,6 +26,7 @@ from train_mnist_cnn import test, train, Net + class Net(nn.Module): def __init__(self): super(Net, self).__init__() @@ -89,15 +90,8 @@ def test_optical_module_transform_pass(): optical_module_transform_pass(model, pass_args) # torch.save(onn_cnn, "mase_output/onn_cnn.pt") -test_optical_module_transform_pass() - - - - - - - +test_optical_module_transform_pass() # if __name__ == '__main__': From b681e30fc5c6bec4641967fe5b82ec659d2152c5 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 26 Jan 2025 23:06:04 +0000 Subject: [PATCH 06/38] remove unnecessary import --- test/passes/module/transforms/optical/test_optical_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 2041b9ccf..e61489f55 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -24,7 +24,7 @@ from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR -from train_mnist_cnn import test, train, Net +# from train_mnist_cnn import test, train, Net class Net(nn.Module): From 3d99177d960bba882923931905932c8bc7b429db Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 28 Jan 2025 15:44:56 +0000 Subject: [PATCH 07/38] fix cuda error for morr_conv2d and morr_linear --- src/chop/nn/optical/modules/morr_conv2d.py | 2 +- src/chop/nn/optical/modules/morr_linear.py | 2 +- .../module/transforms/optical/test_optical_module.py | 11 ----------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index a554bc6c2..3cc56ad2f 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -81,7 +81,7 @@ def __init__( # trainable_morr_bias: bool = False, # trainable_morr_scale: bool = False, config=None, - device: Device = torch.device("cuda"), + device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), ) -> None: super(AllPassMORRCirculantConv2d, self).__init__() miniblock = config.get("miniblock", 4) diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index 9a3f483e5..af85357a2 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -52,7 +52,7 @@ def __init__( # ### trainable MORR nonlinearity # trainable_morr_bias: bool = False, # trainable_morr_scale: bool = False, - device: Device = torch.device("cuda"), + device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), ) -> None: super(AllPassMORRCirculantLinear, self).__init__() self.in_features = in_features diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index e61489f55..d05512cc1 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -14,17 +14,6 @@ # from chop.passes.module.transforms import quantize_module_transform_pass from chop.passes.module.transforms import optical_module_transform_pass -from chop.passes.module import report_trainable_parameters_analysis_pass - -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -from torch.optim.lr_scheduler import StepLR - -# from train_mnist_cnn import test, train, Net class Net(nn.Module): From 09da13d3c1445a2a9284c9dd8d3e2b7ca3d58da7 Mon Sep 17 00:00:00 2001 From: Cheng Zhang Date: Tue, 11 Feb 2025 16:19:00 +0000 Subject: [PATCH 08/38] reduce sphinx docstring error --- .../documentation/tutorials/advanced/cli.md | 18 +-- .../advanced/onnxrt_quantization_tutorial.md | 69 ++++----- .../tensorRT_quantization_tutorial.md | 137 +++++++++--------- .../mixed_precision_search_on_manual_model.md | 16 +- .../mixed_precision_search_on_mase_graph.md | 6 +- .../tutorials/cli/simple_train_flow.md | 8 +- .../developer/Add-model-to-machop.md | 6 +- .../modules/hardware/activations/gelu.md | 2 +- .../modules/hardware/activations/selu.md | 4 +- .../modules/hardware/activations/softplus.md | 2 +- .../modules/hardware/linear/fixed_linear.md | 4 +- .../systolic_modules/output_stationary.md | 2 +- .../modules/labs_2024/lab_0_introduction.rst | 2 +- 13 files changed, 137 insertions(+), 139 deletions(-) diff --git a/docs/source/modules/documentation/tutorials/advanced/cli.md b/docs/source/modules/documentation/tutorials/advanced/cli.md index 7e418f7a6..0b126316c 100644 --- a/docs/source/modules/documentation/tutorials/advanced/cli.md +++ b/docs/source/modules/documentation/tutorials/advanced/cli.md @@ -16,7 +16,7 @@ In this case, we can try a toymodel, the command looks like the following ```bash # assuming you you are at the our-stuff/mase directory -cd src +cd src ./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3 ``` @@ -24,7 +24,7 @@ cd src You can fetch all command-line arguments: -```bash +```text [nix-shell:~/Projects/mase/src]$ ./ch -help INFO Set logging level to debug WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda. @@ -180,7 +180,7 @@ This directory includes ### Training Logs -MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. +MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. Run Tensorboard to visualise the logs using: @@ -213,7 +213,7 @@ To test the model trained above you can use: ```bash # After training, you will have your checkpoint under mase-tools/mase_output -# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt +# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt ./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt``` @@ -383,7 +383,7 @@ cd machop ``` When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna. -```txt +```text Best trial(s): | | number | software_metrics | hardware_metrics | scaled_metrics | |----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------| @@ -401,7 +401,7 @@ Here is part of the `log.json` recording all search details. For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands. -```json +```text { "0":{ "number":0, @@ -481,13 +481,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for ### Commands -First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation. +First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to *Run the train action with the CLI* for more detailed explanation. The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search. ```bash -cd src +cd src ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0 ``` @@ -628,7 +628,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck Here is part of the `log.json` -```json +```text { "0":{ "number":0, diff --git a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md index e01c120b4..255697394 100644 --- a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md +++ b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md @@ -54,7 +54,7 @@ set_logging_verbosity("info") I0329 13:46:21.531338 140128553666368 logger.py:44] Set logging level to info -We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see [Section 2](#section-2-int8-quantization)) +We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see Section 2). ```python @@ -120,8 +120,8 @@ mg = MaseGraph(model=model) Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. You may want to switch to GPU for this task - it will not affect the cpu optimizations later on. -```python -!ch train --config {JSC_TOML_PATH} --accelerator gpu +```bash +ch train --config {JSC_TOML_PATH} --accelerator gpu ``` Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory. @@ -154,7 +154,7 @@ We then run the `onnx_runtime_interface_pass` which completes the optimizations - `onnx_path` (the optimized model) - `onnx_dynamic_quantized_path` (the dynamically ) -In this case, since we are not quantizing the model, only the `onnx_path` is available. +In this case, since we are not quantizing the model, only the `onnx_path` is available. The models are also stored in the directory: ``` @@ -178,7 +178,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config) I0327 14:20:12.539771 140012160939840 onnx_runtime.py:50] Project will be created at /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx I0327 14:20:12.751212 140012160939840 onnx_runtime.py:68] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ @@ -194,7 +194,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config) | 9 | /seq_blocks.9/BatchNormalization | BatchNormalization | /seq_blocks.8/Gemm_output_0, seq_blocks.9.weight, seq_blocks.9.bias, seq_blocks.9.running_mean, seq_blocks.9.running_var | /seq_blocks.9/BatchNormalization_output_0 | epsilon, momentum | | 10 | /seq_blocks.10/Relu | Relu | /seq_blocks.9/BatchNormalization_output_0 | 37 | | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ - I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary: + I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary: +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ @@ -237,7 +237,7 @@ _, _ = runtime_analysis_pass(mg_original, pass_args=runtime_analysis_config) | Average GPU Power Usage | 21.816 W | | Inference Energy Consumption | 0.0048292 mWh | +------------------------------+---------------+ - I0327 14:20:19.793779 140012160939840 analysis.py:398] + I0327 14:20:19.793779 140012160939840 analysis.py:398] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -280,7 +280,7 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_ | Average GPU Power Usage | 21.575 W | | Inference Energy Consumption | 0.0013275 mWh | +------------------------------+---------------+ - I0327 14:20:35.876071 140012160939840 analysis.py:398] + I0327 14:20:35.876071 140012160939840 analysis.py:398] Results jsc-toy-onnx: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -298,14 +298,14 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_ I0327 14:20:35.878773 140012160939840 analysis.py:84] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-27/onnx/version_0/model.json -As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT. +As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT. Lets now run the same optimzations, this time using a GPU and a larger model - the `vgg7`. We will also utilse the chop action from the terminal which runs the same `onnx_runtime_interface_pass` pass. First lets train the `vgg7` model using the machop `train` action with the config from the new toml file and then load the trained checkpoint it into the `transform` pass. -```python +```bash VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" # !ch train --config {VGG_TOML_PATH} @@ -313,7 +313,7 @@ VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" # Load in the checkpoint from the previous train - modify accordingly VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" -!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl +ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl ``` [2024-03-28 23:09:44,122] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) @@ -387,7 +387,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0328 23:10:35.119032 140014036379456 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx I0328 23:10:43.779212 140014036379456 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -414,7 +414,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | | | 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ - I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary: + I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -485,7 +485,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 50.352 W | | Inference Energy Consumption | 0.12032 mWh | +------------------------------+-------------+ - I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521] + I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -521,7 +521,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 55.829 W | | Inference Energy Consumption | 0.094099 mWh | +------------------------------+--------------+ - I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521] + I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -560,7 +560,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 54.555 W | | Inference Energy Consumption | 0.14388 mWh | +------------------------------+-------------+ - I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521] + I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -599,7 +599,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 22.86 W | | Inference Energy Consumption | 0.1748 mWh | +------------------------------+------------+ - I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521] + I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -635,7 +635,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 50.354 W | | Inference Energy Consumption | 0.076714 mWh | +------------------------------+--------------+ - I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521] + I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -653,13 +653,13 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0328 23:18:33.855542 140014036379456 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/onnx/version_12/model.json -As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT. +As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT. -We will now look at quantization to further speed up the model. +We will now look at quantization to further speed up the model. ## Section 2. Quantization -We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance. +We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance. There are three types of quantization for ONNXRT and can be set in `onnxruntime.default.config` under `quantization_types`. The differences of the first two are for how they calibrate i.e. set the scale and zero points which are only relevant for integer based quantization: - **Static Quantization**: @@ -670,26 +670,27 @@ There are three types of quantization for ONNXRT and can be set in `onnxruntime. - The scale and zero point of activations are calculated on-the-fly (online) and are specific for each forward pass. - This approach is more accurate but introduces extra computational overhead -The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32. +The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32. - **Auto Mixed Precision Quantization**: - Automatically adjusts between FP16 and FP32 precisions to retain certain level of accuracy - The `precision` parameter does not need to be set in the config since the whole process is automatic. - Unfortunately, this process is currently only supported on GPU. - This approach is most beneficial when INT8 or FP16 exclusive quantizations (static or dynamic) are giving poor results. -All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory. +All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory. -For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider). +For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider). We will also set the `precision_types` to `['static', 'dynamic', 'auto']` to compare all three quantization methods, whilst keeping the other settings the exact same for a fair comparison against the optimized `vgg7` model used in the previous section. -```python +```bash VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" !ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl ``` +```text [2024-03-29 13:49:26,029] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 WARNING: Logging before flag parsing goes to stderr. @@ -761,7 +762,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0329 13:50:32.508896 139783521261376 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx I0329 13:50:53.587861 139783521261376 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -788,7 +789,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | | | 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ - I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary: + I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -859,7 +860,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 52.579 W | | Inference Energy Consumption | 0.125 mWh | +------------------------------+-----------+ - I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521] + I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521] Results vgg7: +------------------------------+-----------+ | Metric (Per Batch) | Value | @@ -895,7 +896,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 53.26 W | | Inference Energy Consumption | 0.089695 mWh | +------------------------------+--------------+ - I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521] + I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -934,7 +935,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 58.211 W | | Inference Energy Consumption | 0.14364 mWh | +------------------------------+-------------+ - I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521] + I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -973,7 +974,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 22.924 W | | Inference Energy Consumption | 0.0742 mWh | +------------------------------+------------+ - I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521] + I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -1009,7 +1010,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 49.313 W | | Inference Energy Consumption | 0.067374 mWh | +------------------------------+--------------+ - I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521] + I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -1029,6 +1030,6 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0329 14:06:04.342311 139783521261376 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/vgg7_cls_cifar10_2024-03-29/software/transform/transformed_ckpt INFO  Transformation is completed I0329 14:06:04.342653 139783521261376 cli.py:388] Transformation is completed +``` - -As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly. +As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly. diff --git a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md index b208f6473..5de3d3c22 100644 --- a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md +++ b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md @@ -5,11 +5,11 @@ This notebook is designed to show the features of the TensorRT passes integrated ## Section 1. INT8 Quantization Firstly, we will show you how to do a int8 quantization of a simple model, `jsc-toy`, and compare the quantized model to the original model using the `Machop API`. The quantization process is split into the following stages, each using their own individual pass, and are explained in depth at each subsection: -1. [Fake quantization](#section-11-fake-quantization): `tensorrt_fake_quantize_transform_pass` -2. [Calibration](#section-12-calibration): `tensorrt_calibrate_transform_pass` -3. [Quantized Aware Training](#section-13-quantized-aware-training-qat): `tensorrt_fine_tune_transform_pass` -4. [Quantization](#section-14-tensorrt-quantization): `tensorrt_engine_interface_pass` -5. [Analysis](#section-15-performance-analysis): `tensorrt_analysis_pass` +1. Fake quantization: `tensorrt_fake_quantize_transform_pass` +2. Calibration: `tensorrt_calibrate_transform_pass` +3. Quantized Aware Training: `tensorrt_fine_tune_transform_pass` +4. Quantization: `tensorrt_engine_interface_pass` +5. Analysis: `tensorrt_analysis_pass` We start by loading in the required libraries and passes required for the notebook as well as ensuring the correct path is set for machop to be used. @@ -63,7 +63,7 @@ set_logging_verbosity("info") I0329 12:52:20.742465 139924298352448 logger.py:44] Set logging level to info -Next, we load in the toml file used for quantization. To view the configuration, click [here](../../../machop/configs/tensorrt/jsc_toy_INT8_quantization_by_type.toml). +Next, we load in the toml file used for quantization. ```python @@ -136,8 +136,8 @@ mg = MaseGraph(model=model) Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. -```python -!ch train --config {JSC_TOML_PATH} +```bash +ch train --config {JSC_TOML_PATH} ``` Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory. @@ -169,12 +169,12 @@ mg_original = deepcopy_mase_graph(mg) Firstly, we fake quantize the module in order to perform calibration and fine tuning before actually quantizing - this is only used if we have int8 calibration as other precisions are not currently supported within [pytorch-quantization](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#) library. -This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis. +This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis. Currently the quantizable layers are: - Linear -- Conv1d, Conv2d, Conv3d -- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d +- Conv1d, Conv2d, Conv3d +- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d - MaxPool1d, MaxPool2d, MaxPool3d - AvgPool1d, AvgPool2d, AvgPool3d - LSTM, LSTMCell @@ -202,7 +202,7 @@ summarize_quantization_analysis_pass(mg_original, mg) | ReLU | relu | 4 | 0 | 4 | | output | output | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0329 12:52:41.851252 139924298352448 summary.py:85] + I0329 12:52:41.851252 139924298352448 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm1d | batch_norm1d | 4 | 0 | 4 | @@ -212,16 +212,16 @@ summarize_quantization_analysis_pass(mg_original, mg) | x | placeholder | 1 | 0 | 1 | -As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See [Section 4](#section-4-layer-wise-mixed-precision) for how to apply quantization layerwise - i.e. only first and second layers for example. +As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See Section 4 for how to apply quantization layerwise - i.e. only first and second layers for example. ### Section 1.2 Calibration -Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations. +Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations. Calibrators can be added as a search space parameter to examine the best performing calibrator. The calibrators have been included in the toml as follows. For example: `calibrators = ["percentile", "mse", "entropy"]` -Note: +Note: - To use `percentile` calibration, a list of percentiles must be given - To use `max` calibration, the `histogram` weight and input calibrators must be removed and replaced with `max`. This will use global maximum absolute value to calibrate the model. - If `post_calibration_analysis` is set true the `tensorrt_analysis_pass` will be run for each calibrator tested to evaluate the most suitable calibrator for the model. @@ -300,7 +300,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.239 W | | Inference Energy Consumption | 0.018661 mWh | +------------------------------+--------------+ - I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521] + I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -354,7 +354,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.21 W | | Inference Energy Consumption | 0.018664 mWh | +------------------------------+--------------+ - I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521] + I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -408,7 +408,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.252 W | | Inference Energy Consumption | 0.017687 mWh | +------------------------------+--------------+ - I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521] + I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -462,7 +462,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.426 W | | Inference Energy Consumption | 0.018681 mWh | +------------------------------+--------------+ - I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521] + I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -516,7 +516,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.525 W | | Inference Energy Consumption | 0.018149 mWh | +------------------------------+--------------+ - I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521] + I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -538,7 +538,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) I0329 12:53:22.701544 139924298352448 calibrate.py:213] Succeeded in calibrating the model in PyTorch! -From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well. +From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well. ### Section 1.3 Quantized Aware Training (QAT) @@ -576,14 +576,14 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config) I0329 12:53:59.800536 139924298352448 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] - I0329 12:53:59.814722 139924298352448 model_summary.py:94] + I0329 12:53:59.814722 139924298352448 model_summary.py:94] | Name | Type | Params ------------------------------------------------- - 0 | model | GraphModule | 327 - 1 | loss_fn | CrossEntropyLoss | 0 - 2 | acc_train | MulticlassAccuracy | 0 - 3 | loss_val | MeanMetric | 0 - 4 | loss_test | MeanMetric | 0 + 0 | model | GraphModule | 327 + 1 | loss_fn | CrossEntropyLoss | 0 + 2 | acc_train | MulticlassAccuracy | 0 + 3 | loss_val | MeanMetric | 0 + 4 | loss_test | MeanMetric | 0 ------------------------------------------------- 327 Trainable params 0 Non-trainable params @@ -620,7 +620,7 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config) After QAT, we are now ready to convert the model to a tensorRT engine so that it can be run with the superior inference speeds. To do so, we use the `tensorrt_engine_interface_pass` which converts the `MaseGraph`'s model from a Pytorch one to an ONNX format as an intermediate stage of the conversion. -During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in [Section 1.3](#section-13-quantized-aware-training-qat). +During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in Section 1.3. This interface pass returns a dictionary containing the `onnx_path` and `trt_engine_path`. @@ -677,7 +677,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis | Average GPU Power Usage | 23.792 W | | Inference Energy Consumption | 0.0057535 mWh | +------------------------------+---------------+ - I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521] + I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -709,7 +709,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis | Average GPU Power Usage | 23.043 W | | Inference Energy Consumption | 0.00085532 mWh | +------------------------------+----------------+ - I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521] + I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521] Results jsc-toy-trt_quantized: +------------------------------+----------------+ | Metric (Per Batch) | Value | @@ -727,19 +727,18 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis I0329 13:03:34.506492 139924298352448 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-29/tensorrt/version_0/model.json -As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage. +As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage. ## Section 2. FP16 Quantization -We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in [Section 1](#section-1---int8-quantization). +We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in Section 1. -Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ). +Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ). -```python +```text JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantization_by_type.toml" !ch transform --config {JSC_FP16_BY_TYPE_TOML} --load {JSC_CHECKPOINT_PATH} --load-type pl -``` 8808.24s - pydevd: Sending message related to process being replaced timed-out after 5 seconds [2024-03-28 09:37:03,989] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) @@ -811,7 +810,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | ReLU | relu | 4 | 0 | 4 | | output | output | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 09:37:11.551648 140201001654080 summary.py:85] + I0328 09:37:11.551648 140201001654080 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm1d | batch_norm1d | 4 | 0 | 4 | @@ -849,7 +848,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | Average GPU Power Usage | 22.024 W | | Inference Energy Consumption | 0.0049148 mWh | +------------------------------+---------------+ - I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437] + I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -871,7 +870,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat ------|---------|----------|----------------------|----------------------|----------------------- 0 | Input | FLOAT | (64, 16) | (64, 16) | input 1 | Output | FLOAT | (64, 5) | (64, 5) | 37 - I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167] + I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167] TensorRT Engine Input/Output Information: Index | Type | DataType | Static Shape | Dynamic Shape | Name ------|---------|----------|----------------------|----------------------|----------------------- @@ -893,7 +892,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | Average GPU Power Usage | 21.706 W | | Inference Energy Consumption | 0.00055067 mWh | +------------------------------+----------------+ - I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437] + I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437] Results jsc-toy-trt_quantized: +------------------------------+----------------+ | Metric (Per Batch) | Value | @@ -913,7 +912,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat I0328 09:37:43.132117 140201001654080 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/jsc-toy_cls_jsc_2024-03-28/software/transform/transformed_ckpt INFO  Transformation is completed I0328 09:37:43.132461 140201001654080 cli.py:383] Transformation is completed - +``` As you can see, `fp16` acheives a slighty higher test accuracy but a slightly lower latency (~30%) from that of int8 quantization; it is still ~2.5x faster than the unquantized model. Now lets apply quantization to a more complicated model. @@ -925,31 +924,30 @@ In this case, we set: - The `by` parameter to `type` - The `quantize` parameter to true for `passes.tensorrt.conv2d.config` and `precision` parameter to 'int8'. - The `input` and `weight` quantize axis for the conv2d layers. -- The default `passes.tensorrt.default.config` precision to true. +- The default `passes.tensorrt.default.config` precision to true. During the TensorRT quantization, the model's conv2d layers will be converted to an int8 fake quantized form, whilst the linear layers are kept to their default 'fp16'. Calibration of the conv2d layers and then fine tuning will be undergone before quantization and inference. -You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below. +You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below. -```python +```bash VGG_TYPEWISE_TOML = "../../../machop/configs/tensorrt/vgg7_typewise_mixed_precision.toml" -!ch train --config {VGG_TYPEWISE_TOML} +ch train --config {VGG_TYPEWISE_TOML} ``` -We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in [Section 1.5](#section-15-performance-analysis) +We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in Section 1.5 -```python +```bash # Change this checkpoint path accordingly VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" ``` -```python -!ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl -``` +```text +ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl [2024-03-28 23:00:09,016] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 @@ -1033,7 +1031,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | output | output | 1 | 0 | 1 | | view | view | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 23:00:37.270473 139939454809920 summary.py:85] + I0328 23:00:37.270473 139939454809920 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm2d | batch_norm2d | 6 | 0 | 6 | @@ -1161,7 +1159,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 59.019 W | | Inference Energy Consumption | 0.24777 mWh | +------------------------------+-------------+ - I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521] + I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1233,7 +1231,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 59.653 W | | Inference Energy Consumption | 0.25317 mWh | +------------------------------+-------------+ - I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521] + I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1270,14 +1268,14 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck Files already downloaded and verified Files already downloaded and verified I0328 23:01:12.623627 139939454809920 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] - I0328 23:01:12.632704 139939454809920 model_summary.py:94] + I0328 23:01:12.632704 139939454809920 model_summary.py:94] | Name | Type | Params ------------------------------------------------- 0 | model | GraphModule | 14.0 M - 1 | loss_fn | CrossEntropyLoss | 0 - 2 | acc_train | MulticlassAccuracy | 0 - 3 | loss_val | MeanMetric | 0 - 4 | loss_test | MeanMetric | 0 + 1 | loss_fn | CrossEntropyLoss | 0 + 2 | acc_train | MulticlassAccuracy | 0 + 3 | loss_val | MeanMetric | 0 + 4 | loss_test | MeanMetric | 0 ------------------------------------------------- 14.0 M Trainable params 0 Non-trainable params @@ -1645,7 +1643,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 58.043 W | | Inference Energy Consumption | 0.1441 mWh | +------------------------------+------------+ - I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521] + I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -1677,7 +1675,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 52.687 W | | Inference Energy Consumption | 0.12441 mWh | +------------------------------+-------------+ - I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] + I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] Results vgg7-trt_quantized: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1693,13 +1691,13 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck +------------------------------+-------------+ INFO  Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json +``` - -By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model. +By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model. ## Section 4. Layer-wise Mixed Precision -So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized. +So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized. For this, we set: - The `by` parameter to `name` @@ -1708,11 +1706,10 @@ For this, we set: - The `precision` to 'int8' for `passes.tensorrt.feature_layers_2.config and passes.tensorrt.feature_layers_3.config` (although this is not necessary since the default is already set to 'int8') -```python +```text VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_precision.toml" -!ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl -``` +ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl [2024-03-28 23:25:51,157] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 @@ -1796,7 +1793,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec | output | output | 1 | 0 | 1 | | view | view | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 23:26:12.941653 140449214740288 summary.py:85] + I0328 23:26:12.941653 140449214740288 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm2d | batch_norm2d | 6 | 0 | 6 | @@ -1956,7 +1953,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec | Average GPU Power Usage | 57.532 W | | Inference Energy Consumption | 0.28607 mWh | +------------------------------+-------------+ - I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521] + I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2040,7 +2037,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec | Average GPU Power Usage | 57.867 W | | Inference Energy Consumption | 0.29145 mWh | +------------------------------+-------------+ - I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521] + I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2099,7 +2096,7 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec | Average GPU Power Usage | 55.687 W | | Inference Energy Consumption | 0.12102 mWh | +------------------------------+-------------+ - I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] + I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] Results vgg7-trt_quantized: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2115,6 +2112,6 @@ VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_prec +------------------------------+-------------+ INFO  Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json - +``` In this case, we can see through the quantized summary that one convolutional layer (feature_layers_1) has not been quantized as its precision will be configured to 'fp16' in the tensorrt engine conversion stage whilst the remaining convolutional and linear layers have been quantized. diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md index 8f4d69cf1..9fbb14c38 100644 --- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md +++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md @@ -82,7 +82,7 @@ cd machop ``` When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna. -```txt +```text Best trial(s): | | number | software_metrics | hardware_metrics | scaled_metrics | |----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------| @@ -100,7 +100,7 @@ Here is part of the `log.json` recording all search details. For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands. -```json +```text { "0":{ "number":0, @@ -120,8 +120,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "bias_width":2, "bias_frac_width":8 } - }, - ... + } + // ... }, "user_attrs_scaled_metrics":{ "accuracy":0.5, @@ -154,8 +154,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "bias_width":4, "bias_frac_width":3 } - }, - ... + } + // ... }, "user_attrs_scaled_metrics":{ "accuracy":0.5747232437, @@ -169,7 +169,7 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "datetime_start":1694095031290, "datetime_complete":1694095032462, "duration":1172 - }, - ... + } + // ... } ``` \ No newline at end of file diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md index e82d69d2d..10b309d90 100644 --- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md +++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md @@ -4,13 +4,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for ## Commands -First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation. +First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to Run the train action with the CLI for more detailed explanation. The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search. ```bash -cd src +cd src ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0 ``` @@ -151,7 +151,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck Here is part of the `log.json` -```json +```text { "0":{ "number":0, diff --git a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md index 8f9eb8cb1..52c21ca40 100644 --- a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md +++ b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md @@ -12,7 +12,7 @@ In this case, we can try a toymodel, the command looks like the following ```bash # assuming you you are at the our-stuff/mase directory -cd src +cd src ./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3 ``` @@ -20,7 +20,7 @@ cd src You can fetch all command-line arguments: -```bash +```text [nix-shell:~/Projects/mase/src]$ ./ch -help INFO Set logging level to debug WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda. @@ -176,7 +176,7 @@ This directory includes ## Training Logs -MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. +MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. Run Tensorboard to visualise the logs using: @@ -209,5 +209,5 @@ To test the model trained above you can use: ```bash # After training, you will have your checkpoint under mase-tools/mase_output -# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt +# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt ./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt``` diff --git a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md index a30a172fb..c7750e262 100644 --- a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md +++ b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md @@ -3,10 +3,10 @@ This document includes steps to add a new model into Machop ## Overall Structure ### Model -All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in [\_\_init\_\_](%2E%2E%5Cmachop%5Cchop%5Cmodels%5C%5F%5Finit%5F%5F.py) file. +All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in `__init__` file. ### Command Line Interface -[Command Line Interface (cli)](..\machop\chop\cli.py) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then. +Command Line Interface (cli) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then. ## What To Do 1. Find the GitHub repositories of the original paper, find the code that defines the models, and copy it into the right folder under **mase-tools\machop\chop\models** @@ -20,7 +20,7 @@ All models that Machop support are defined inside **mase-tools/machop/chop/model ## Get model function - **Info** should be used as one of the input variables. It is a dictionary that contains information about the dataset, e.g., number of classes; input image size. -- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in [cli.py](..\machop\chop\cli.py) for more detail. +- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in cli.py for more detail. - function name of get-function should be in smaller case - keys of the dictionary should also be in smaller case diff --git a/docs/source/modules/hardware/activations/gelu.md b/docs/source/modules/hardware/activations/gelu.md index cb2d094e8..25b2560b4 100644 --- a/docs/source/modules/hardware/activations/gelu.md +++ b/docs/source/modules/hardware/activations/gelu.md @@ -11,7 +11,7 @@ When the approximate argument is set to 'tanh', GELU is estimated with: `GELU(x) = 0.5 * x * (1 + Tanh(2/π * (x + 0.044715 * x^3)))` -### Parameters: +## Parameters: - `approximate` (str, optional): The GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'. diff --git a/docs/source/modules/hardware/activations/selu.md b/docs/source/modules/hardware/activations/selu.md index 8413420a1..882e5c8d5 100644 --- a/docs/source/modules/hardware/activations/selu.md +++ b/docs/source/modules/hardware/activations/selu.md @@ -8,7 +8,7 @@ where: - α = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 -### Parameters +## Parameters - `inplace` (bool, optional): Can optionally do the operation in-place. Default: False. @@ -30,7 +30,7 @@ A hybrid approach is used for implementing exponential function $e^{-|x|}$ for a 2. **Representation of Binary Number**: The N-bit binary number $a = b_{N-1}b_{N-2}...b_1b_0$ is represented, where $b_0$ is the least significant bit, and each bit $b_i$ has a place value $p_i$ given by $p_i = 2^{-P} \times 2^i$. - + 3. **Exponential Computation**: $e^{-a} = \prod e^{-p_i \times b_i}$ diff --git a/docs/source/modules/hardware/activations/softplus.md b/docs/source/modules/hardware/activations/softplus.md index ed0066cfb..0d7f3e77a 100644 --- a/docs/source/modules/hardware/activations/softplus.md +++ b/docs/source/modules/hardware/activations/softplus.md @@ -9,7 +9,7 @@ where: - Softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. - For numerical stability, the implementation reverts to the linear function when `input * β > threshold`. -### Parameters: +## Parameters: - `beta` (int): The β value for the Softplus formulation. Default: 1. - `threshold` (int): Values above this revert to a linear function. Default: 20. diff --git a/docs/source/modules/hardware/linear/fixed_linear.md b/docs/source/modules/hardware/linear/fixed_linear.md index 672ea765a..33ddf350e 100644 --- a/docs/source/modules/hardware/linear/fixed_linear.md +++ b/docs/source/modules/hardware/linear/fixed_linear.md @@ -32,7 +32,7 @@ The module has the following parameters, following the hardware metadata standar | Parameter | Default Value | Definition | |------------------------------ |-------------------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | DATA_IN_0_PARALLELISM_DIM_0 | 4 | Number of elements per transaction at the input interface. Dictates the number of transactions to compute the full layer. | -| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see [Latency Analysis](#latency-analysis) below) | +| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see Latency Analysis below) | | DATA_OUT_0_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the output interface. | | BIAS_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the bias interface. Dictates the number of fixed-point adders. | @@ -56,7 +56,7 @@ The same process is repeated with the second input sub-vector $X_2$ and weight s img

-## Latency Analysis +## Latency Analysis The time taken to compute a linear layer using the `fixed_linear` module, $L_{FL}$ can be broken down into 2 phases, the input driving phase $L_L$, and the pipeline unloading phase $L_U$ that begins after the last input beat is transferred. diff --git a/docs/source/modules/hardware/systolic_modules/output_stationary.md b/docs/source/modules/hardware/systolic_modules/output_stationary.md index d5c605a41..d0e5370b2 100644 --- a/docs/source/modules/hardware/systolic_modules/output_stationary.md +++ b/docs/source/modules/hardware/systolic_modules/output_stationary.md @@ -8,6 +8,6 @@ The MAC units in each PE perform the multiply-accumulate operation over 2 cycles ![Systolic Array](https://raw.githubusercontent.com/DeepWok/mase/main/docs/source/imgs/hardware/sys_array_pe.png) -### Systolic Module Driver +## Systolic Module Driver The Systolic Module Driver generates pulse signals in the format required to drive the read interface of an on-chip buffer such that data signals are made available with the required timing for the processing elements of a systolic module. This is achieved through a shift register of size BUFFER_SLOT_COUNT. After receiving a starting pulse, the least significant bit is set to 1. Subsequently, the register shifts after every shift pulse, up to a runtime-parametrizable pulse limit count parameter (this is set to the number of output features for the layer being executed). The driver should then pulse a subsequent BUFFER_SLOT_COUNT times until the register is flushed. \ No newline at end of file diff --git a/docs/source/modules/labs_2024/lab_0_introduction.rst b/docs/source/modules/labs_2024/lab_0_introduction.rst index 0bde49047..896c0c769 100644 --- a/docs/source/modules/labs_2024/lab_0_introduction.rst +++ b/docs/source/modules/labs_2024/lab_0_introduction.rst @@ -46,7 +46,7 @@ TroubleShooting You may find that you have to use `Python3.11` but Google Colab only provides `Python3.10`. In this case, you can use the following command to force the kernel ot use `Python3.11`: -.. code-block:: python +.. code-block:: text #The code below installs 3.11 (assuming you now have 3.10 in colab) and restarts environment, so you can run your cells. import sys #for version checker From aac9097c1b2f00de5a092bb31cbd4bf230f6f351 Mon Sep 17 00:00:00 2001 From: Cheng Zhang Date: Tue, 11 Feb 2025 16:45:31 +0000 Subject: [PATCH 09/38] undo docstrings changeRevert "reduce sphinx docstring error" This reverts commit 09da13d3c1445a2a9284c9dd8d3e2b7ca3d58da7. --- .../documentation/tutorials/advanced/cli.md | 18 +-- .../advanced/onnxrt_quantization_tutorial.md | 69 +++++---- .../tensorRT_quantization_tutorial.md | 137 +++++++++--------- .../mixed_precision_search_on_manual_model.md | 16 +- .../mixed_precision_search_on_mase_graph.md | 6 +- .../tutorials/cli/simple_train_flow.md | 8 +- .../developer/Add-model-to-machop.md | 6 +- .../modules/hardware/activations/gelu.md | 2 +- .../modules/hardware/activations/selu.md | 4 +- .../modules/hardware/activations/softplus.md | 2 +- .../modules/hardware/linear/fixed_linear.md | 4 +- .../systolic_modules/output_stationary.md | 2 +- .../modules/labs_2024/lab_0_introduction.rst | 2 +- 13 files changed, 139 insertions(+), 137 deletions(-) diff --git a/docs/source/modules/documentation/tutorials/advanced/cli.md b/docs/source/modules/documentation/tutorials/advanced/cli.md index 0b126316c..7e418f7a6 100644 --- a/docs/source/modules/documentation/tutorials/advanced/cli.md +++ b/docs/source/modules/documentation/tutorials/advanced/cli.md @@ -16,7 +16,7 @@ In this case, we can try a toymodel, the command looks like the following ```bash # assuming you you are at the our-stuff/mase directory -cd src +cd src ./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3 ``` @@ -24,7 +24,7 @@ cd src You can fetch all command-line arguments: -```text +```bash [nix-shell:~/Projects/mase/src]$ ./ch -help INFO Set logging level to debug WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda. @@ -180,7 +180,7 @@ This directory includes ### Training Logs -MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. +MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. Run Tensorboard to visualise the logs using: @@ -213,7 +213,7 @@ To test the model trained above you can use: ```bash # After training, you will have your checkpoint under mase-tools/mase_output -# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt +# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt ./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt``` @@ -383,7 +383,7 @@ cd machop ``` When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna. -```text +```txt Best trial(s): | | number | software_metrics | hardware_metrics | scaled_metrics | |----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------| @@ -401,7 +401,7 @@ Here is part of the `log.json` recording all search details. For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands. -```text +```json { "0":{ "number":0, @@ -481,13 +481,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for ### Commands -First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to *Run the train action with the CLI* for more detailed explanation. +First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation. The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search. ```bash -cd src +cd src ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0 ``` @@ -628,7 +628,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck Here is part of the `log.json` -```text +```json { "0":{ "number":0, diff --git a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md index 255697394..e01c120b4 100644 --- a/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md +++ b/docs/source/modules/documentation/tutorials/advanced/onnxrt_quantization_tutorial.md @@ -54,7 +54,7 @@ set_logging_verbosity("info") I0329 13:46:21.531338 140128553666368 logger.py:44] Set logging level to info -We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see Section 2). +We then load in a demonstration toml file and set the relevant pass arguments (this is all done automatically if we were to use the command line, see [Section 2](#section-2-int8-quantization)) ```python @@ -120,8 +120,8 @@ mg = MaseGraph(model=model) Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. You may want to switch to GPU for this task - it will not affect the cpu optimizations later on. -```bash -ch train --config {JSC_TOML_PATH} --accelerator gpu +```python +!ch train --config {JSC_TOML_PATH} --accelerator gpu ``` Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory. @@ -154,7 +154,7 @@ We then run the `onnx_runtime_interface_pass` which completes the optimizations - `onnx_path` (the optimized model) - `onnx_dynamic_quantized_path` (the dynamically ) -In this case, since we are not quantizing the model, only the `onnx_path` is available. +In this case, since we are not quantizing the model, only the `onnx_path` is available. The models are also stored in the directory: ``` @@ -178,7 +178,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config) I0327 14:20:12.539771 140012160939840 onnx_runtime.py:50] Project will be created at /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx I0327 14:20:12.751212 140012160939840 onnx_runtime.py:68] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/jsc-toy_cls_jsc_2024-03-27/optimized/version_1/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ @@ -194,7 +194,7 @@ mg, onnx_meta = onnx_runtime_interface_pass(mg, pass_args=onnx_config) | 9 | /seq_blocks.9/BatchNormalization | BatchNormalization | /seq_blocks.8/Gemm_output_0, seq_blocks.9.weight, seq_blocks.9.bias, seq_blocks.9.running_mean, seq_blocks.9.running_var | /seq_blocks.9/BatchNormalization_output_0 | epsilon, momentum | | 10 | /seq_blocks.10/Relu | Relu | /seq_blocks.9/BatchNormalization_output_0 | 37 | | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ - I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary: + I0327 14:20:12.757548 140012160939840 onnx_runtime.py:90] ONNX Model Summary: +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------+-------------------------------------------+---------------------+ @@ -237,7 +237,7 @@ _, _ = runtime_analysis_pass(mg_original, pass_args=runtime_analysis_config) | Average GPU Power Usage | 21.816 W | | Inference Energy Consumption | 0.0048292 mWh | +------------------------------+---------------+ - I0327 14:20:19.793779 140012160939840 analysis.py:398] + I0327 14:20:19.793779 140012160939840 analysis.py:398] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -280,7 +280,7 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_ | Average GPU Power Usage | 21.575 W | | Inference Energy Consumption | 0.0013275 mWh | +------------------------------+---------------+ - I0327 14:20:35.876071 140012160939840 analysis.py:398] + I0327 14:20:35.876071 140012160939840 analysis.py:398] Results jsc-toy-onnx: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -298,14 +298,14 @@ _, _ = runtime_analysis_pass(onnx_meta['onnx_path'], pass_args=runtime_analysis_ I0327 14:20:35.878773 140012160939840 analysis.py:84] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-27/onnx/version_0/model.json -As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT. +As shown above, the latency of the cpu inference is around 3.5x less with the `jsc-toy` model without compromising accuracy simply by using the optimizations of ONNXRT. Lets now run the same optimzations, this time using a GPU and a larger model - the `vgg7`. We will also utilse the chop action from the terminal which runs the same `onnx_runtime_interface_pass` pass. First lets train the `vgg7` model using the machop `train` action with the config from the new toml file and then load the trained checkpoint it into the `transform` pass. -```bash +```python VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" # !ch train --config {VGG_TOML_PATH} @@ -313,7 +313,7 @@ VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" # Load in the checkpoint from the previous train - modify accordingly VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" -ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl +!ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl ``` [2024-03-28 23:09:44,122] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) @@ -387,7 +387,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p I0328 23:10:35.119032 140014036379456 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx I0328 23:10:43.779212 140014036379456 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-28/optimized/version_3/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -414,7 +414,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | | | 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ - I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary: + I0328 23:10:43.897069 140014036379456 onnx_runtime.py:146] ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -485,7 +485,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | Average GPU Power Usage | 50.352 W | | Inference Energy Consumption | 0.12032 mWh | +------------------------------+-------------+ - I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521] + I0328 23:15:43.756563 140014036379456 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -521,7 +521,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | Average GPU Power Usage | 55.829 W | | Inference Energy Consumption | 0.094099 mWh | +------------------------------+--------------+ - I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521] + I0328 23:15:53.476423 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -560,7 +560,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | Average GPU Power Usage | 54.555 W | | Inference Energy Consumption | 0.14388 mWh | +------------------------------+-------------+ - I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521] + I0328 23:16:03.469463 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -599,7 +599,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | Average GPU Power Usage | 22.86 W | | Inference Energy Consumption | 0.1748 mWh | +------------------------------+------------+ - I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521] + I0328 23:18:23.964464 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -635,7 +635,7 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p | Average GPU Power Usage | 50.354 W | | Inference Energy Consumption | 0.076714 mWh | +------------------------------+--------------+ - I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521] + I0328 23:18:33.854084 140014036379456 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -653,13 +653,13 @@ ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type p I0328 23:18:33.855542 140014036379456 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/onnx/version_12/model.json -As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT. +As shown above, the latency of the gpu inference is 30% less with the `vgg7` model without compromising accuracy simply by using the optimizations of ONNXRT. -We will now look at quantization to further speed up the model. +We will now look at quantization to further speed up the model. ## Section 2. Quantization -We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance. +We may quantize either using FP16 or INT8 by setting the `precision` parameter in `passes.onnxruntime.default.config` to `'fp16'` or `'int8'` respectively. INT8 quantization will show the most notable latency improvements but is more likely to lower performance. There are three types of quantization for ONNXRT and can be set in `onnxruntime.default.config` under `quantization_types`. The differences of the first two are for how they calibrate i.e. set the scale and zero points which are only relevant for integer based quantization: - **Static Quantization**: @@ -670,27 +670,26 @@ There are three types of quantization for ONNXRT and can be set in `onnxruntime. - The scale and zero point of activations are calculated on-the-fly (online) and are specific for each forward pass. - This approach is more accurate but introduces extra computational overhead -The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32. +The `onnx_runtime_interface_pass` pass also supports mixed precision. This is an automatic only procedure, where ONNXRT finds a minimal set of ops to skip while retaining a certain level of accuracy, converting most of the ops to float16 but leaving some in float32. - **Auto Mixed Precision Quantization**: - Automatically adjusts between FP16 and FP32 precisions to retain certain level of accuracy - The `precision` parameter does not need to be set in the config since the whole process is automatic. - Unfortunately, this process is currently only supported on GPU. - This approach is most beneficial when INT8 or FP16 exclusive quantizations (static or dynamic) are giving poor results. -All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory. +All three methodolgies first pre-procsses the model before quantization adding further optimizations. This intermidate model is stored to the `pre-processed` directory. -For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider). +For this example, we will set the `precision` to `'uint8'` (since `ConvInteger` node is not currently supported for `'int8'` on ONNXRT GPU execution provider). We will also set the `precision_types` to `['static', 'dynamic', 'auto']` to compare all three quantization methods, whilst keeping the other settings the exact same for a fair comparison against the optimized `vgg7` model used in the previous section. -```bash +```python VGG_TOML_PATH = "../../../machop/configs/onnx/vgg7_gpu_quant.toml" VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" !ch transform --config {VGG_TOML_PATH} --load {VGG_CHECKPOINT_PATH} --load-type pl ``` -```text [2024-03-29 13:49:26,029] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 WARNING: Logging before flag parsing goes to stderr. @@ -762,7 +761,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0329 13:50:32.508896 139783521261376 onnx_runtime.py:90] Project will be created at /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29 INFO  ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx I0329 13:50:53.587861 139783521261376 onnx_runtime.py:108] ONNX Conversion Complete. Stored ONNX model to /root/mase/mase_output/onnxrt/vgg7_cls_cifar10_2024-03-29/optimized/version_0/model.onnx - INFO  ONNX Model Summary: + INFO  ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -789,7 +788,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | 20 | /classifier.3/Relu | Relu | /classifier.2/Gemm_output_0 | /classifier.3/Relu_output_0 | | | 21 | /last_layer/Gemm | Gemm | /classifier.3/Relu_output_0, last_layer.weight, last_layer.bias | 76 | alpha, beta, transB | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ - I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary: + I0329 13:50:53.719763 139783521261376 onnx_runtime.py:146] ONNX Model Summary: +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ | Index | Name | Type | Inputs | Outputs | Attributes | +-------+----------------------------+----------+---------------------------------------------------------------------+-------------------------------------+---------------------------------------------------+ @@ -860,7 +859,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 52.579 W | | Inference Energy Consumption | 0.125 mWh | +------------------------------+-----------+ - I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521] + I0329 13:55:58.685605 139783521261376 runtime_analysis.py:521] Results vgg7: +------------------------------+-----------+ | Metric (Per Batch) | Value | @@ -896,7 +895,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 53.26 W | | Inference Energy Consumption | 0.089695 mWh | +------------------------------+--------------+ - I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521] + I0329 13:56:20.459139 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -935,7 +934,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 58.211 W | | Inference Energy Consumption | 0.14364 mWh | +------------------------------+-------------+ - I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521] + I0329 13:56:42.742136 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -974,7 +973,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 22.924 W | | Inference Energy Consumption | 0.0742 mWh | +------------------------------+------------+ - I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521] + I0329 13:59:46.169317 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -1010,7 +1009,7 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck | Average GPU Power Usage | 49.313 W | | Inference Energy Consumption | 0.067374 mWh | +------------------------------+--------------+ - I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521] + I0329 14:00:07.487649 139783521261376 runtime_analysis.py:521] Results vgg7-onnx: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -1030,6 +1029,6 @@ VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ck I0329 14:06:04.342311 139783521261376 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/vgg7_cls_cifar10_2024-03-29/software/transform/transformed_ckpt INFO  Transformation is completed I0329 14:06:04.342653 139783521261376 cli.py:388] Transformation is completed -``` -As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly. + +As we can see, the optimized onnx model still outperforms Pytorch on the VGG model due to it's runtime optimizations. The static performs the best, then the automatic mixed precision which outperforms the dynamic quantization due to its requirement of calculating activations on-the-fly. diff --git a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md index 5de3d3c22..b208f6473 100644 --- a/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md +++ b/docs/source/modules/documentation/tutorials/advanced/tensorRT_quantization_tutorial.md @@ -5,11 +5,11 @@ This notebook is designed to show the features of the TensorRT passes integrated ## Section 1. INT8 Quantization Firstly, we will show you how to do a int8 quantization of a simple model, `jsc-toy`, and compare the quantized model to the original model using the `Machop API`. The quantization process is split into the following stages, each using their own individual pass, and are explained in depth at each subsection: -1. Fake quantization: `tensorrt_fake_quantize_transform_pass` -2. Calibration: `tensorrt_calibrate_transform_pass` -3. Quantized Aware Training: `tensorrt_fine_tune_transform_pass` -4. Quantization: `tensorrt_engine_interface_pass` -5. Analysis: `tensorrt_analysis_pass` +1. [Fake quantization](#section-11-fake-quantization): `tensorrt_fake_quantize_transform_pass` +2. [Calibration](#section-12-calibration): `tensorrt_calibrate_transform_pass` +3. [Quantized Aware Training](#section-13-quantized-aware-training-qat): `tensorrt_fine_tune_transform_pass` +4. [Quantization](#section-14-tensorrt-quantization): `tensorrt_engine_interface_pass` +5. [Analysis](#section-15-performance-analysis): `tensorrt_analysis_pass` We start by loading in the required libraries and passes required for the notebook as well as ensuring the correct path is set for machop to be used. @@ -63,7 +63,7 @@ set_logging_verbosity("info") I0329 12:52:20.742465 139924298352448 logger.py:44] Set logging level to info -Next, we load in the toml file used for quantization. +Next, we load in the toml file used for quantization. To view the configuration, click [here](../../../machop/configs/tensorrt/jsc_toy_INT8_quantization_by_type.toml). ```python @@ -136,8 +136,8 @@ mg = MaseGraph(model=model) Next, we train the `jsc-toy` model using the machop `train` action with the config from the toml file. -```bash -ch train --config {JSC_TOML_PATH} +```python +!ch train --config {JSC_TOML_PATH} ``` Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory. @@ -169,12 +169,12 @@ mg_original = deepcopy_mase_graph(mg) Firstly, we fake quantize the module in order to perform calibration and fine tuning before actually quantizing - this is only used if we have int8 calibration as other precisions are not currently supported within [pytorch-quantization](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#) library. -This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis. +This is acheived through the `tensorrt_fake_quantize_transform_pass` which goes through the model, either by type or by name, replaces each layer appropriately to a fake quantized form if the `quantize` parameter is set in the default config (`passes.tensorrt.default.config`) or on a per name or type basis. Currently the quantizable layers are: - Linear -- Conv1d, Conv2d, Conv3d -- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d +- Conv1d, Conv2d, Conv3d +- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d - MaxPool1d, MaxPool2d, MaxPool3d - AvgPool1d, AvgPool2d, AvgPool3d - LSTM, LSTMCell @@ -202,7 +202,7 @@ summarize_quantization_analysis_pass(mg_original, mg) | ReLU | relu | 4 | 0 | 4 | | output | output | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0329 12:52:41.851252 139924298352448 summary.py:85] + I0329 12:52:41.851252 139924298352448 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm1d | batch_norm1d | 4 | 0 | 4 | @@ -212,16 +212,16 @@ summarize_quantization_analysis_pass(mg_original, mg) | x | placeholder | 1 | 0 | 1 | -As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See Section 4 for how to apply quantization layerwise - i.e. only first and second layers for example. +As you can see we have succesfully fake quantized all linear layers inside `jsc-toy`. This means that we will be able to simulate a quantized model in order to calibrate and fine tune it. This fake quantization was done on typewise i.e. for linear layers only. See [Section 4](#section-4-layer-wise-mixed-precision) for how to apply quantization layerwise - i.e. only first and second layers for example. ### Section 1.2 Calibration -Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations. +Next, we perform calibration using the `tensorrt_calibrate_transform_pass`. Calibration is achieved by passing data samples to the quantizer and deciding the best amax for activations. Calibrators can be added as a search space parameter to examine the best performing calibrator. The calibrators have been included in the toml as follows. For example: `calibrators = ["percentile", "mse", "entropy"]` -Note: +Note: - To use `percentile` calibration, a list of percentiles must be given - To use `max` calibration, the `histogram` weight and input calibrators must be removed and replaced with `max`. This will use global maximum absolute value to calibrate the model. - If `post_calibration_analysis` is set true the `tensorrt_analysis_pass` will be run for each calibrator tested to evaluate the most suitable calibrator for the model. @@ -300,7 +300,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.239 W | | Inference Energy Consumption | 0.018661 mWh | +------------------------------+--------------+ - I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521] + I0329 12:52:49.573626 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -354,7 +354,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.21 W | | Inference Energy Consumption | 0.018664 mWh | +------------------------------+--------------+ - I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521] + I0329 12:52:53.233150 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -408,7 +408,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.252 W | | Inference Energy Consumption | 0.017687 mWh | +------------------------------+--------------+ - I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521] + I0329 12:52:56.441818 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -462,7 +462,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.426 W | | Inference Energy Consumption | 0.018681 mWh | +------------------------------+--------------+ - I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521] + I0329 12:53:05.428555 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -516,7 +516,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) | Average GPU Power Usage | 22.525 W | | Inference Energy Consumption | 0.018149 mWh | +------------------------------+--------------+ - I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521] + I0329 12:53:22.697756 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+--------------+ | Metric (Per Batch) | Value | @@ -538,7 +538,7 @@ mg, _ = tensorrt_calibrate_transform_pass(mg, pass_args=tensorrt_config) I0329 12:53:22.701544 139924298352448 calibrate.py:213] Succeeded in calibrating the model in PyTorch! -From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well. +From the results, the 99% `percentile` clips too many values during the amax calibration, compromising the loss. However 99.99% demonstrates higher validation accuracy alongside `mse` and `entropy` for `jsc-toy`. For such a small model, the methods are not highly distinguished, however for larger models this calibration process will be important for ensuring the quantized model still performs well. ### Section 1.3 Quantized Aware Training (QAT) @@ -576,14 +576,14 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config) I0329 12:53:59.800536 139924298352448 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] - I0329 12:53:59.814722 139924298352448 model_summary.py:94] + I0329 12:53:59.814722 139924298352448 model_summary.py:94] | Name | Type | Params ------------------------------------------------- - 0 | model | GraphModule | 327 - 1 | loss_fn | CrossEntropyLoss | 0 - 2 | acc_train | MulticlassAccuracy | 0 - 3 | loss_val | MeanMetric | 0 - 4 | loss_test | MeanMetric | 0 + 0 | model | GraphModule | 327 + 1 | loss_fn | CrossEntropyLoss | 0 + 2 | acc_train | MulticlassAccuracy | 0 + 3 | loss_val | MeanMetric | 0 + 4 | loss_test | MeanMetric | 0 ------------------------------------------------- 327 Trainable params 0 Non-trainable params @@ -620,7 +620,7 @@ mg, _ = tensorrt_fine_tune_transform_pass(mg, pass_args=tensorrt_config) After QAT, we are now ready to convert the model to a tensorRT engine so that it can be run with the superior inference speeds. To do so, we use the `tensorrt_engine_interface_pass` which converts the `MaseGraph`'s model from a Pytorch one to an ONNX format as an intermediate stage of the conversion. -During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in Section 1.3. +During the conversion process, the `.onnx` and `.trt` files are stored to their respective folders shown in [Section 1.3](#section-13-quantized-aware-training-qat). This interface pass returns a dictionary containing the `onnx_path` and `trt_engine_path`. @@ -677,7 +677,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis | Average GPU Power Usage | 23.792 W | | Inference Energy Consumption | 0.0057535 mWh | +------------------------------+---------------+ - I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521] + I0329 13:03:32.504800 139924298352448 runtime_analysis.py:521] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -709,7 +709,7 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis | Average GPU Power Usage | 23.043 W | | Inference Energy Consumption | 0.00085532 mWh | +------------------------------+----------------+ - I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521] + I0329 13:03:34.503784 139924298352448 runtime_analysis.py:521] Results jsc-toy-trt_quantized: +------------------------------+----------------+ | Metric (Per Batch) | Value | @@ -727,18 +727,19 @@ _, _ = runtime_analysis_pass(meta['trt_engine_path'], pass_args=runtime_analysis I0329 13:03:34.506492 139924298352448 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/jsc-toy_cls_jsc_2024-03-29/tensorrt/version_0/model.json -As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage. +As shown above, the latency has decreased around 6x with the `jsc-toy` model without compromising accuracy due to the well calibrated amax and quantization-aware fine tuning and additional runtime optimizations from TensorRT. The inference energy consumption has thus also dropped tremendously and this is an excellent demonstration for the need to quantize in industry especially for LLMs in order to reduce energy usage. ## Section 2. FP16 Quantization -We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in Section 1. +We will now load in a new toml configuration that uses fp16 instead of int8, whilst keeping the other settings the exact same for a fair comparison. This time however, we will use chop from the terminal which runs all the passes showcased in [Section 1](#section-1---int8-quantization). -Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ). +Since float quantization does not require calibration, nor is it supported by `pytorch-quantization`, the model will not undergo fake quantization; for the time being this unfortunately means QAT is unavailable and only undergoes Post Training Quantization (PTQ). -```text +```python JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantization_by_type.toml" !ch transform --config {JSC_FP16_BY_TYPE_TOML} --load {JSC_CHECKPOINT_PATH} --load-type pl +``` 8808.24s - pydevd: Sending message related to process being replaced timed-out after 5 seconds [2024-03-28 09:37:03,989] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) @@ -810,7 +811,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | ReLU | relu | 4 | 0 | 4 | | output | output | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 09:37:11.551648 140201001654080 summary.py:85] + I0328 09:37:11.551648 140201001654080 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm1d | batch_norm1d | 4 | 0 | 4 | @@ -848,7 +849,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | Average GPU Power Usage | 22.024 W | | Inference Energy Consumption | 0.0049148 mWh | +------------------------------+---------------+ - I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437] + I0328 09:37:36.951404 140201001654080 runtime_analysis.py:437] Results jsc-toy: +------------------------------+---------------+ | Metric (Per Batch) | Value | @@ -870,7 +871,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat ------|---------|----------|----------------------|----------------------|----------------------- 0 | Input | FLOAT | (64, 16) | (64, 16) | input 1 | Output | FLOAT | (64, 5) | (64, 5) | 37 - I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167] + I0328 09:37:36.960667 140201001654080 runtime_analysis.py:167] TensorRT Engine Input/Output Information: Index | Type | DataType | Static Shape | Dynamic Shape | Name ------|---------|----------|----------------------|----------------------|----------------------- @@ -892,7 +893,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat | Average GPU Power Usage | 21.706 W | | Inference Energy Consumption | 0.00055067 mWh | +------------------------------+----------------+ - I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437] + I0328 09:37:43.052305 140201001654080 runtime_analysis.py:437] Results jsc-toy-trt_quantized: +------------------------------+----------------+ | Metric (Per Batch) | Value | @@ -912,7 +913,7 @@ JSC_FP16_BY_TYPE_TOML = "../../../machop/configs/tensorrt/jsc_toy_FP16_quantizat I0328 09:37:43.132117 140201001654080 save_and_load.py:147] Saved mase graph to /root/mase/mase_output/jsc-toy_cls_jsc_2024-03-28/software/transform/transformed_ckpt INFO  Transformation is completed I0328 09:37:43.132461 140201001654080 cli.py:383] Transformation is completed -``` + As you can see, `fp16` acheives a slighty higher test accuracy but a slightly lower latency (~30%) from that of int8 quantization; it is still ~2.5x faster than the unquantized model. Now lets apply quantization to a more complicated model. @@ -924,30 +925,31 @@ In this case, we set: - The `by` parameter to `type` - The `quantize` parameter to true for `passes.tensorrt.conv2d.config` and `precision` parameter to 'int8'. - The `input` and `weight` quantize axis for the conv2d layers. -- The default `passes.tensorrt.default.config` precision to true. +- The default `passes.tensorrt.default.config` precision to true. During the TensorRT quantization, the model's conv2d layers will be converted to an int8 fake quantized form, whilst the linear layers are kept to their default 'fp16'. Calibration of the conv2d layers and then fine tuning will be undergone before quantization and inference. -You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below. +You may either download a pretrained model [here](https://imperiallondon-my.sharepoint.com/:f:/g/personal/zz7522_ic_ac_uk/Emh3VT7Q_qRFmnp8kDrcgDoBwGUuzLwwKNtX8ZAt368jJQ?e=gsKONa), otherwise train it yourself as shown below. -```bash +```python VGG_TYPEWISE_TOML = "../../../machop/configs/tensorrt/vgg7_typewise_mixed_precision.toml" -ch train --config {VGG_TYPEWISE_TOML} +!ch train --config {VGG_TYPEWISE_TOML} ``` -We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in Section 1.5 +We will now load the checkpoint in, quantize the model and compare it to the unquantized version as we did in [Section 1.5](#section-15-performance-analysis) -```bash +```python # Change this checkpoint path accordingly VGG_CHECKPOINT_PATH = "../../../mase_output/vgg7-pre-trained/test-accu-0.9332.ckpt" ``` -```text -ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl +```python +!ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl +``` [2024-03-28 23:00:09,016] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 @@ -1031,7 +1033,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty | output | output | 1 | 0 | 1 | | view | view | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 23:00:37.270473 139939454809920 summary.py:85] + I0328 23:00:37.270473 139939454809920 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm2d | batch_norm2d | 6 | 0 | 6 | @@ -1159,7 +1161,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty | Average GPU Power Usage | 59.019 W | | Inference Energy Consumption | 0.24777 mWh | +------------------------------+-------------+ - I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521] + I0328 23:00:55.766893 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1231,7 +1233,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty | Average GPU Power Usage | 59.653 W | | Inference Energy Consumption | 0.25317 mWh | +------------------------------+-------------+ - I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521] + I0328 23:01:07.450706 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1268,14 +1270,14 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty Files already downloaded and verified Files already downloaded and verified I0328 23:01:12.623627 139939454809920 cuda.py:61] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] - I0328 23:01:12.632704 139939454809920 model_summary.py:94] + I0328 23:01:12.632704 139939454809920 model_summary.py:94] | Name | Type | Params ------------------------------------------------- 0 | model | GraphModule | 14.0 M - 1 | loss_fn | CrossEntropyLoss | 0 - 2 | acc_train | MulticlassAccuracy | 0 - 3 | loss_val | MeanMetric | 0 - 4 | loss_test | MeanMetric | 0 + 1 | loss_fn | CrossEntropyLoss | 0 + 2 | acc_train | MulticlassAccuracy | 0 + 3 | loss_val | MeanMetric | 0 + 4 | loss_test | MeanMetric | 0 ------------------------------------------------- 14.0 M Trainable params 0 Non-trainable params @@ -1643,7 +1645,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty | Average GPU Power Usage | 58.043 W | | Inference Energy Consumption | 0.1441 mWh | +------------------------------+------------+ - I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521] + I0328 23:06:47.111017 139939454809920 runtime_analysis.py:521] Results vgg7: +------------------------------+------------+ | Metric (Per Batch) | Value | @@ -1675,7 +1677,7 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty | Average GPU Power Usage | 52.687 W | | Inference Energy Consumption | 0.12441 mWh | +------------------------------+-------------+ - I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] + I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] Results vgg7-trt_quantized: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -1691,13 +1693,13 @@ ch transform --config {VGG_TYPEWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-ty +------------------------------+-------------+ INFO  Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_7/model.json -``` -By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model. + +By quantizing all convolutional layers to INT8 and maintaining fp16 precision for the linear layers we see a marginal decrease in latency whilst maintaining a comparable accuracy. By experimenting with precisions on a per type basis, you may find insights that work best for your model. ## Section 4. Layer-wise Mixed Precision -So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized. +So far we have strictly quantized either in int8 or fp16. Now, we will show how to conduct layerwise mixed precision using the same `vgg7` model. In this case we will show how for instance, layer 0 and 1 can be set to fp16, while the remaining layers can be int8 quantized. For this, we set: - The `by` parameter to `name` @@ -1706,10 +1708,11 @@ For this, we set: - The `precision` to 'int8' for `passes.tensorrt.feature_layers_2.config and passes.tensorrt.feature_layers_3.config` (although this is not necessary since the default is already set to 'int8') -```text +```python VGG_LAYERWISE_TOML = "../../../machop/configs/tensorrt/vgg7_layerwise_mixed_precision.toml" -ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl +!ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-type pl +``` [2024-03-28 23:25:51,157] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect) INFO: Seed set to 0 @@ -1793,7 +1796,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t | output | output | 1 | 0 | 1 | | view | view | 1 | 0 | 1 | | x | placeholder | 1 | 0 | 1 | - I0328 23:26:12.941653 140449214740288 summary.py:85] + I0328 23:26:12.941653 140449214740288 summary.py:85] | Original type | OP | Total | Changed | Unchanged | |-----------------+--------------+---------+-----------+-------------| | BatchNorm2d | batch_norm2d | 6 | 0 | 6 | @@ -1953,7 +1956,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t | Average GPU Power Usage | 57.532 W | | Inference Energy Consumption | 0.28607 mWh | +------------------------------+-------------+ - I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521] + I0328 23:26:29.263397 140449214740288 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2037,7 +2040,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t | Average GPU Power Usage | 57.867 W | | Inference Energy Consumption | 0.29145 mWh | +------------------------------+-------------+ - I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521] + I0328 23:26:40.146152 140449214740288 runtime_analysis.py:521] Results vgg7: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2096,7 +2099,7 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t | Average GPU Power Usage | 55.687 W | | Inference Energy Consumption | 0.12102 mWh | +------------------------------+-------------+ - I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] + I0328 23:07:00.676242 139939454809920 runtime_analysis.py:521] Results vgg7-trt_quantized: +------------------------------+-------------+ | Metric (Per Batch) | Value | @@ -2112,6 +2115,6 @@ ch transform --config {VGG_LAYERWISE_TOML} --load {VGG_CHECKPOINT_PATH} --load-t +------------------------------+-------------+ INFO  Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json I0328 23:07:00.677799 139939454809920 runtime_analysis.py:143] Runtime analysis results saved to /root/mase_output/tensorrt/quantization/vgg7_cls_cifar10_2024-03-28/tensorrt/version_8/model.json -``` + In this case, we can see through the quantized summary that one convolutional layer (feature_layers_1) has not been quantized as its precision will be configured to 'fp16' in the tensorrt engine conversion stage whilst the remaining convolutional and linear layers have been quantized. diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md index 9fbb14c38..8f4d69cf1 100644 --- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md +++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_manual_model.md @@ -82,7 +82,7 @@ cd machop ``` When the search is done, the best quantization config will be printed out. Since we run multi-objective search. There may be multiple best trials found by Optuna. -```text +```txt Best trial(s): | | number | software_metrics | hardware_metrics | scaled_metrics | |----+----------+--------------------------------------+------------------------------------------------------+-------------------------------------------------| @@ -100,7 +100,7 @@ Here is part of the `log.json` recording all search details. For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization config of trial 0. Expand it and you will set the precision of each matmul/linear layer's operands. -```text +```json { "0":{ "number":0, @@ -120,8 +120,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "bias_width":2, "bias_frac_width":8 } - } - // ... + }, + ... }, "user_attrs_scaled_metrics":{ "accuracy":0.5, @@ -154,8 +154,8 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "bias_width":4, "bias_frac_width":3 } - } - // ... + }, + ... }, "user_attrs_scaled_metrics":{ "accuracy":0.5747232437, @@ -169,7 +169,7 @@ For example, `log["0"]["user_attrs_sampled_config"]` is the sampled quantization "datetime_start":1694095031290, "datetime_complete":1694095032462, "duration":1172 - } - // ... + }, + ... } ``` \ No newline at end of file diff --git a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md index 10b309d90..e82d69d2d 100644 --- a/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md +++ b/docs/source/modules/documentation/tutorials/cli/mixed_precision_search_on_mase_graph.md @@ -4,13 +4,13 @@ This tutorial shows how to search for mixed-precision quantization strategy for ## Commands -First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to Run the train action with the CLI for more detailed explanation. +First we train a model on the dataset. After training for some epochs, we get a model with some validation accuracy. The checkpoint is saved at an auto-created location. You can refer to [Run the train action with the CLI](../train/simple_train_flow.md) for more detailed explanation. The reason why we need a pre-trained model is because we would like to do a post-training-quantization (PTQ) search. This means the quantization happens on a pre-trained model. We then use the PTQ accuracy as a proxy signal for our search. ```bash -cd src +cd src ./ch train jsc-tiny jsc --max-epochs 3 --batch-size 256 --accelerator cpu --project tmp --debug --cpu 0 ``` @@ -151,7 +151,7 @@ The entire searching log is saved in `../mase_output/jsc-tiny/software/search_ck Here is part of the `log.json` -```text +```json { "0":{ "number":0, diff --git a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md index 52c21ca40..8f9eb8cb1 100644 --- a/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md +++ b/docs/source/modules/documentation/tutorials/cli/simple_train_flow.md @@ -12,7 +12,7 @@ In this case, we can try a toymodel, the command looks like the following ```bash # assuming you you are at the our-stuff/mase directory -cd src +cd src ./ch train toy toy_tiny --config ../configs/archive/test/train.toml --max-epochs 3 ``` @@ -20,7 +20,7 @@ cd src You can fetch all command-line arguments: -```text +```bash [nix-shell:~/Projects/mase/src]$ ./ch -help INFO Set logging level to debug WARNING TensorRT pass is unavailable because the following dependencies are not installed: pytorch_quantization, tensorrt, pycuda, cuda. @@ -176,7 +176,7 @@ This directory includes ## Training Logs -MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. +MASE creates [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) or [wandb](https://wandb.ai/site) logs for the training flow - allowing tracking and visualizing metrics such as loss and accuracy. The log files are in `/software/tensorboard/lightning_logs/version_`. Run Tensorboard to visualise the logs using: @@ -209,5 +209,5 @@ To test the model trained above you can use: ```bash # After training, you will have your checkpoint under mase-tools/mase_output -# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt +# For example, the checkpoint is under ../mase_output/toy_classification_toy-tiny_2023-07-03/software/training_ckpts/best.ckpt ./ch test toy toy_tiny --config ../configs/archive/test/train.toml --load ../mase_output/toy_classification_toy_tiny_2024-06-13/software/training_ckpts/best.ckpt``` diff --git a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md index c7750e262..a30a172fb 100644 --- a/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md +++ b/docs/source/modules/documentation/tutorials/developer/Add-model-to-machop.md @@ -3,10 +3,10 @@ This document includes steps to add a new model into Machop ## Overall Structure ### Model -All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in `__init__` file. +All models that Machop support are defined inside **mase-tools/machop/chop/models**. Each model has a unique get model function, which can be called to create the model. Those get model function will be exported into a dictionary in [\_\_init\_\_](%2E%2E%5Cmachop%5Cchop%5Cmodels%5C%5F%5Finit%5F%5F.py) file. ### Command Line Interface -Command Line Interface (cli) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then. +[Command Line Interface (cli)](..\machop\chop\cli.py) will take the input config, and perform the task defined inside the config. When training, cli will look into the dictionary contains the get funtions, use the get-function to create a model, and do training then. ## What To Do 1. Find the GitHub repositories of the original paper, find the code that defines the models, and copy it into the right folder under **mase-tools\machop\chop\models** @@ -20,7 +20,7 @@ Command Line Interface (cli) will take the input config, and perform the task de ## Get model function - **Info** should be used as one of the input variables. It is a dictionary that contains information about the dataset, e.g., number of classes; input image size. -- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in cli.py for more detail. +- Other then **Info**, Inputs for different types of models are different, you can check `_setup_model_and_dataset` function defined in [cli.py](..\machop\chop\cli.py) for more detail. - function name of get-function should be in smaller case - keys of the dictionary should also be in smaller case diff --git a/docs/source/modules/hardware/activations/gelu.md b/docs/source/modules/hardware/activations/gelu.md index 25b2560b4..cb2d094e8 100644 --- a/docs/source/modules/hardware/activations/gelu.md +++ b/docs/source/modules/hardware/activations/gelu.md @@ -11,7 +11,7 @@ When the approximate argument is set to 'tanh', GELU is estimated with: `GELU(x) = 0.5 * x * (1 + Tanh(2/π * (x + 0.044715 * x^3)))` -## Parameters: +### Parameters: - `approximate` (str, optional): The GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'. diff --git a/docs/source/modules/hardware/activations/selu.md b/docs/source/modules/hardware/activations/selu.md index 882e5c8d5..8413420a1 100644 --- a/docs/source/modules/hardware/activations/selu.md +++ b/docs/source/modules/hardware/activations/selu.md @@ -8,7 +8,7 @@ where: - α = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 -## Parameters +### Parameters - `inplace` (bool, optional): Can optionally do the operation in-place. Default: False. @@ -30,7 +30,7 @@ A hybrid approach is used for implementing exponential function $e^{-|x|}$ for a 2. **Representation of Binary Number**: The N-bit binary number $a = b_{N-1}b_{N-2}...b_1b_0$ is represented, where $b_0$ is the least significant bit, and each bit $b_i$ has a place value $p_i$ given by $p_i = 2^{-P} \times 2^i$. - + 3. **Exponential Computation**: $e^{-a} = \prod e^{-p_i \times b_i}$ diff --git a/docs/source/modules/hardware/activations/softplus.md b/docs/source/modules/hardware/activations/softplus.md index 0d7f3e77a..ed0066cfb 100644 --- a/docs/source/modules/hardware/activations/softplus.md +++ b/docs/source/modules/hardware/activations/softplus.md @@ -9,7 +9,7 @@ where: - Softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. - For numerical stability, the implementation reverts to the linear function when `input * β > threshold`. -## Parameters: +### Parameters: - `beta` (int): The β value for the Softplus formulation. Default: 1. - `threshold` (int): Values above this revert to a linear function. Default: 20. diff --git a/docs/source/modules/hardware/linear/fixed_linear.md b/docs/source/modules/hardware/linear/fixed_linear.md index 33ddf350e..672ea765a 100644 --- a/docs/source/modules/hardware/linear/fixed_linear.md +++ b/docs/source/modules/hardware/linear/fixed_linear.md @@ -32,7 +32,7 @@ The module has the following parameters, following the hardware metadata standar | Parameter | Default Value | Definition | |------------------------------ |-------------------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | DATA_IN_0_PARALLELISM_DIM_0 | 4 | Number of elements per transaction at the input interface. Dictates the number of transactions to compute the full layer. | -| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see Latency Analysis below) | +| WEIGHT_PARALLELISM_DIM_0 | 4 | Number of columns of the weights matrix per transaction at the weights interface. This is equivalent to the number of dot product modules. Also dictates the number of backpressure cycles on the input interface (see [Latency Analysis](#latency-analysis) below) | | DATA_OUT_0_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the output interface. | | BIAS_PARALLELISM_DIM_0 | WEIGHT_PARALLELISM_DIM_0 | Number of elements per transaction at the bias interface. Dictates the number of fixed-point adders. | @@ -56,7 +56,7 @@ The same process is repeated with the second input sub-vector $X_2$ and weight s img

-## Latency Analysis +## Latency Analysis The time taken to compute a linear layer using the `fixed_linear` module, $L_{FL}$ can be broken down into 2 phases, the input driving phase $L_L$, and the pipeline unloading phase $L_U$ that begins after the last input beat is transferred. diff --git a/docs/source/modules/hardware/systolic_modules/output_stationary.md b/docs/source/modules/hardware/systolic_modules/output_stationary.md index d0e5370b2..d5c605a41 100644 --- a/docs/source/modules/hardware/systolic_modules/output_stationary.md +++ b/docs/source/modules/hardware/systolic_modules/output_stationary.md @@ -8,6 +8,6 @@ The MAC units in each PE perform the multiply-accumulate operation over 2 cycles ![Systolic Array](https://raw.githubusercontent.com/DeepWok/mase/main/docs/source/imgs/hardware/sys_array_pe.png) -## Systolic Module Driver +### Systolic Module Driver The Systolic Module Driver generates pulse signals in the format required to drive the read interface of an on-chip buffer such that data signals are made available with the required timing for the processing elements of a systolic module. This is achieved through a shift register of size BUFFER_SLOT_COUNT. After receiving a starting pulse, the least significant bit is set to 1. Subsequently, the register shifts after every shift pulse, up to a runtime-parametrizable pulse limit count parameter (this is set to the number of output features for the layer being executed). The driver should then pulse a subsequent BUFFER_SLOT_COUNT times until the register is flushed. \ No newline at end of file diff --git a/docs/source/modules/labs_2024/lab_0_introduction.rst b/docs/source/modules/labs_2024/lab_0_introduction.rst index 896c0c769..0bde49047 100644 --- a/docs/source/modules/labs_2024/lab_0_introduction.rst +++ b/docs/source/modules/labs_2024/lab_0_introduction.rst @@ -46,7 +46,7 @@ TroubleShooting You may find that you have to use `Python3.11` but Google Colab only provides `Python3.10`. In this case, you can use the following command to force the kernel ot use `Python3.11`: -.. code-block:: text +.. code-block:: python #The code below installs 3.11 (assuming you now have 3.10 in colab) and restarts environment, so you can run your cells. import sys #for version checker From 9dba4fde4bb557c582f3b65b6d13892817f9c7a5 Mon Sep 17 00:00:00 2001 From: Cheng Zhang Date: Tue, 11 Feb 2025 20:19:06 +0000 Subject: [PATCH 10/38] rename files --- src/chop/nn/__init__.py | 1 - src/chop/nn/optical/__init__.py | 4 +- src/chop/nn/optical/functional/__init__.py | 54 --- src/chop/nn/optical/functional/general.py | 433 ------------------ src/chop/nn/optical/modules/morr_conv2d.py | 43 +- src/chop/nn/optical/modules/morr_linear.py | 23 +- src/chop/nn/optical/utils/__init__.py | 26 ++ .../optical/{functional => utils}/compute.py | 0 .../{functional => utils}/initializer.py | 0 .../nn/optical/{functional => utils}/mrr.py | 0 .../optical/{functional => utils}/mrr_op.py | 0 .../optical/{functional => utils}/quantize.py | 0 .../{functional => utils}/torch_train.py | 0 .../optical/module_transform_helper.py | 2 - .../module/transforms/optical/optical.py | 44 +- .../transforms/optical/test_optical_module.py | 144 +----- 16 files changed, 70 insertions(+), 704 deletions(-) delete mode 100644 src/chop/nn/optical/functional/__init__.py delete mode 100644 src/chop/nn/optical/functional/general.py create mode 100644 src/chop/nn/optical/utils/__init__.py rename src/chop/nn/optical/{functional => utils}/compute.py (100%) rename src/chop/nn/optical/{functional => utils}/initializer.py (100%) rename src/chop/nn/optical/{functional => utils}/mrr.py (100%) rename src/chop/nn/optical/{functional => utils}/mrr_op.py (100%) rename src/chop/nn/optical/{functional => utils}/quantize.py (100%) rename src/chop/nn/optical/{functional => utils}/torch_train.py (100%) diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py index 5533a7608..7ab651324 100644 --- a/src/chop/nn/__init__.py +++ b/src/chop/nn/__init__.py @@ -1,4 +1,3 @@ from .quantized import quantized_module_map -from .optical import optical_module_map MASE_LEAF_LAYERS = tuple(quantized_module_map.values()) diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py index b74e7df6b..0310afb71 100644 --- a/src/chop/nn/optical/__init__.py +++ b/src/chop/nn/optical/__init__.py @@ -1,3 +1 @@ -from .modules import ( - optical_module_map, -) +from .modules import optical_module_map diff --git a/src/chop/nn/optical/functional/__init__.py b/src/chop/nn/optical/functional/__init__.py deleted file mode 100644 index 84f94e7b6..000000000 --- a/src/chop/nn/optical/functional/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -from .mrr import ( - MORRConfig_20um_MQ, - MRRConfig_5um_HQ, - MRRConfig_5um_MQ, - MRRConfig_5um_LQ, - MORRConfig_10um_MQ, -) - -from .compute import ( - im2col_2d, - toeplitz, -) - -from .general import ( - logger, -) - -from .initializer import ( - morr_uniform_, -) - -from .quantize import ( - input_quantize_fn, - weight_quantize_fn, -) - -from .mrr_op import ( - mrr_roundtrip_phase_to_tr_func, - mrr_roundtrip_phase_to_tr_fused, -) - - -# """ -# Description: -# Author: Jiaqi Gu (jqgu@utexas.edu) -# Date: 2021-06-09 01:40:22 -# LastEditors: Jiaqi Gu (jqgu@utexas.edu) -# LastEditTime: 2021-06-09 01:40:22 -# """ - -# import importlib -# import os - -# # automatically import any Python files in this directory -# for file in sorted(os.listdir(os.path.dirname(__file__))): -# if file.endswith(".py") and not file.startswith("_"): -# source = file[: file.find(".py")] -# module = importlib.import_module("torchonn.layers." + source) -# if "__all__" in module.__dict__: -# names = module.__dict__["__all__"] -# else: -# # import all names that do not begin with _ -# names = [x for x in module.__dict__ if not x.startswith("_")] -# globals().update({k: getattr(module, k) for k in names}) diff --git a/src/chop/nn/optical/functional/general.py b/src/chop/nn/optical/functional/general.py deleted file mode 100644 index f5aa87e14..000000000 --- a/src/chop/nn/optical/functional/general.py +++ /dev/null @@ -1,433 +0,0 @@ -""" -Description: -Author: Jiaqi Gu (jqgu@utexas.edu) -Date: 2021-06-06 01:55:29 -LastEditors: Jiaqi Gu (jqgu@utexas.edu) -LastEditTime: 2021-06-06 01:55:30 -""" - -import os -import argparse -import json -import logging -import logging.handlers -import time -from collections import OrderedDict -from datetime import datetime -from pathlib import Path -from typing import Optional - -import numpy as np -import torch - - -__all__ = [ - "ensure_dir", - "read_json", - "write_json", - "profile", - "print_stat", - "Timer", - "TimerCtx", - "TorchTracemalloc", - "fullprint", - "setup_default_logging", - "Logger", - "logger", - "get_logger", - "ArgParser", - "disable_tf_warning", - "AverageMeter", -] - - -def ensure_dir(dirname, exist_ok: bool = True): - dirname = Path(dirname) - if not dirname.is_dir(): - dirname.mkdir(parents=True, exist_ok=exist_ok) - - -def read_json(fname): - with open(fname, "rt") as handle: - return json.load(handle, object_hook=OrderedDict) - - -def write_json(content, fname): - with open(fname, "wt") as handle: - json.dump(content, handle, indent=4, sort_keys=False) - - -def profile(func=None, timer=True): - from functools import wraps, partial - import time - - if func == None: - return partial(profile, timer=timer) - - @wraps(func) - def wrapper(*args, **kw): - if timer: - local_time = time.time() - res = func(*args, **kw) - end_time = time.time() - print( - "[I] <%s> runtime: %.3f ms" - % (func.__name__, (end_time - local_time) * 1000) - ) - else: - res = func(*args, **kw) - return res - - return wrapper - - -def print_stat(x, message="", verbose=True): - if verbose: - if isinstance(x, torch.Tensor): - if torch.is_complex(x): - x = torch.view_as_real(x) - print( - message - + f"min = {x.data.min().item():-15f} max = {x.data.max().item():-15f} mean = {x.data.mean().item():-15f} std = {x.data.std().item():-15f}" - ) - elif isinstance(x, np.ndarray): - print( - message - + f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}" - ) - - -class Timer(object): - def __init__(self): - self.cache = datetime.now() - - def check(self): - now = datetime.now() - duration = now - self.cache - self.cache = now - return duration.total_seconds() - - def reset(self): - self.cache = datetime.now() - - -class TimerCtx: - def __enter__(self): - self.start = time.time() - return self - - def __exit__(self, *args): - self.end = time.time() - self.interval = self.end - self.start - - -class TorchTracemalloc(object): - def __init__(self, verbose: bool = False) -> None: - super().__init__() - self.verbose = verbose - - def __enter__(self): - self.begin = self._b2mb(torch.cuda.memory_allocated()) - torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero - return self - - def _b2mb(self, x): - return x / 2**20 - - def __exit__(self, *exc): - self.end = self._b2mb(torch.cuda.memory_allocated()) - self.peak = self._b2mb(torch.cuda.max_memory_allocated()) - self.used = self.end - self.begin - self.peaked = self.peak - self.begin - if self.verbose: - print(f"Delta used/peaked {self.used:.2f} MB / {self.peaked:.2f} MB") - print(f"Current used/peaked {self.end:.2f} MB / {self.peak:.2f} MB") - - -class fullprint: - "context manager for printing full numpy arrays" - - def __init__(self, **kwargs): - """linewidth=75; precision=8""" - kwargs.setdefault("threshold", np.inf) - self.opt = kwargs - - def __enter__(self): - self._opt = np.get_printoptions() - np.set_printoptions(**self.opt) - - def __exit__(self, type, value, traceback): - np.set_printoptions(**self._opt) - - -class CustomFormatter(logging.Formatter): - """Logging Formatter to add colors and count warning / errors""" - - grey = "\x1b[38;21m" - yellow = "\x1b[33;21m" - red = "\x1b[31;21m" - bold_red = "\x1b[31;1m" - green = "\x1b[32;21m" - reset = "\x1b[0m" - # format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" - format = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" - - FORMATS = { - logging.DEBUG: grey + format + reset, - logging.INFO: grey + format + reset, - logging.WARNING: yellow + format + reset, - logging.ERROR: red + format + reset, - logging.CRITICAL: bold_red + format + reset, - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) - return formatter.format(record) - - -def setup_default_logging( - default_level=logging.INFO, default_file_level=logging.INFO, log_path="" -): - console_handler = logging.StreamHandler() - console_handler.setFormatter(CustomFormatter()) - logging.root.addHandler(console_handler) - logging.root.setLevel(default_level) - if log_path: - file_handler = logging.handlers.RotatingFileHandler( - log_path, maxBytes=(1024**2 * 2), backupCount=3 - ) - file_formatter = logging.Formatter( - "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" - ) - file_handler.setFormatter(file_formatter) - file_handler.setLevel(default_file_level) - logging.root.addHandler(file_handler) - - -class Logger(object): - def __init__( - self, - console=True, - logfile=None, - console_level=logging.INFO, - logfile_level=logging.INFO, - ): - super().__init__() - self.logfile = logfile - self.console_level = console_level - self.logifle_level = logfile_level - assert ( - console == True or logfile is not None - ), "At least enable one from console or logfile for Logger" - # 第一步,创建一个logger - self.logger = logging.getLogger("my_logger") - self.logger.setLevel(logging.INFO) # Log等级总开关 - self.logger.propagate = False - - # formatter = logging.Formatter( - # "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") - formatter = CustomFormatter() - - # 第三步,再创建一个handler,用于输出到控制台 - if console: - ch = logging.StreamHandler() - ch.setLevel(self.console_level) # 输出到console的log等级的开关 - ch.setFormatter(formatter) - self.logger.addHandler(ch) - if self.logfile is not None: - fh = logging.FileHandler(self.logfile, mode="w") - fh.setLevel(self.logifle_level) # 输出到file的log等级的开关 - fh.setFormatter(formatter) - self.logger.addHandler(fh) - - def debug(self, message): - self.logger.debug(message) - - def info(self, message): - self.logger.info(message) - - def warning(self, message): - self.logger.warning(message) - - def error(self, message): - self.logger.error(message) - - def critical(self, message): - self.logger.critical(message) - - -def get_logger( - name="default", - default_level=logging.INFO, - default_file_level=logging.INFO, - log_path="", -): - setup_default_logging( - default_level=default_level, - default_file_level=default_file_level, - log_path=log_path, - ) - return logging.getLogger(name) - - -logger = get_logger() - - -class ArgParser(object): - def __init__(self, load_json=None, save_json=None): - super().__init__() - self.load_json = load_json - self.save_json = save_json - self.args = None - self.parser = argparse.ArgumentParser("Argument Parser") - - def add_arg(self, *args, **keywords): - self.parser.add_argument(*args, **keywords) - - def parse_args(self): - if self.load_json is not None: - assert os.path.exists(self.load_json), logging.error( - f"Configuration JSON {self.load_json} not found" - ) - json = read_json(self.load_json) - t_args = argparse.Namespace() - t_args.__dict__.update(json) - self.args = self.parser.parse_args(args=[], namespace=t_args) - else: - self.args = self.parser.parse_args() - return self.args - - def print_args(self): - # Print arguments to std out - # and save argument values to yaml file - print("Arguments:") - for p in vars(self.args).items(): - print(f"\t{p[0]:30}{str(p[1]):20}") - print("\n") - - def dump_args(self, json_file=None): - if json_file is None: - if self.save_json is None: - logging.error("Skip dump configuration JSON. Please specify json_file") - return False - else: - ensure_dir(os.path.dirname(self.save_json)) - logging.warning(f"Dump to the initialized JSON file {self.save_json}") - write_json(vars(self.args), self.save_json) - else: - ensure_dir(os.path.dirname(json_file)) - logging.info(f"Dump to JSON file {json_file}") - write_json(vars(self.args), json_file) - # with open(self.file, 'w') as f: - # yaml.dump(vars(self.args), f, default_flow_style=False) - # print(f"[I] Arguments dumped to {file}") - - -def disable_tf_warning(): - import os - - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - import warnings - - warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings("ignore", category=DeprecationWarning) - - import tensorflow as tf - - if hasattr(tf, "contrib") and type(tf.contrib) != type(tf): - tf.contrib._warning = None - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - # tf.logging.set_verbosity(tf.logging.ERROR) - - import logging - - logging.getLogger("tensorflow").setLevel(logging.ERROR) - - -class Meter(object): - """Base class for Meters.""" - - def __init__(self): - pass - - def state_dict(self): - return {} - - def load_state_dict(self, state_dict): - pass - - def reset(self): - raise NotImplementedError - - @property - def smoothed_value(self) -> float: - """Smoothed value used for logging.""" - raise NotImplementedError - - -def safe_round(number, ndigits): - if hasattr(number, "__round__"): - return round(number, ndigits) - elif torch is not None and torch.is_tensor(number) and number.numel() == 1: - return safe_round(number.item(), ndigits) - elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): - return safe_round(number.item(), ndigits) - else: - return number - - -def type_as(a, b): - if torch.is_tensor(a) and torch.is_tensor(b): - return a.to(b) - else: - return a - - -class AverageMeter(Meter): - """Computes and stores the average and current value""" - - def __init__(self, name: str, fmt: str = ":f", round: Optional[int] = None) -> None: - self.name = name - self.fmt = fmt - self.round = round - self.reset() - - def reset(self): - self.val = None # most recent update - self.sum = 0 # sum from all updates - self.count = 0 # total n from all updates - self.avg = 0 - - def update(self, val, n=1): - if val is not None: - self.val = val - if n > 0: - self.sum = type_as(self.sum, val) + (val * n) - self.count = type_as(self.count, n) + n - self.avg = self.sum / self.count if self.count > 0 else self.val - - def state_dict(self): - return { - "val": self.val, - "sum": self.sum, - "count": self.count, - "round": self.round, - } - - def load_state_dict(self, state_dict): - self.val = state_dict["val"] - self.sum = state_dict["sum"] - self.count = state_dict["count"] - self.round = state_dict.get("round", None) - - @property - def smoothed_value(self) -> float: - val = self.avg - if self.round is not None and val is not None: - val = safe_round(val, self.round) - return val - - def __str__(self) -> str: - fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" - return fmtstr.format(**self.__dict__) diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index 3cc56ad2f..f104e397f 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -6,24 +6,26 @@ """ from typing import Optional, Tuple +import logging import numpy as np import torch import torch.fft -from ..functional import im2col_2d, toeplitz -from ..functional import logger -from ..functional import morr_uniform_ -from ..functional import input_quantize_fn, weight_quantize_fn from torch import Tensor, nn from torch.nn import Parameter, init from torch.nn.modules.utils import _pair from torch.types import Device, _size -from ..functional import MORRConfig_20um_MQ -from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import im2col_2d, toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn from .base_layer import ONNBaseLayer +logger = logging.getLogger(__name__) + __all__ = ["AllPassMORRCirculantConv2d"] @@ -72,18 +74,13 @@ def __init__( dilation: _size = 1, groups: int = 1, bias: bool = True, - padding_mode=None, - # miniblock: int = 4, - # ### morr parameter - # MORRConfig=MORRConfig_20um_MQ, - # morr_init: bool = True, # whether to use initialization method customized for MORR - # ### trainable MORR nonlinearity - # trainable_morr_bias: bool = False, - # trainable_morr_scale: bool = False, + padding_mode=None, # @johnny: unused argument config=None, - device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + device: Device = torch.device("cpu"), ) -> None: super(AllPassMORRCirculantConv2d, self).__init__() + assert config is not None + miniblock = config.get("miniblock", 4) MORRConfig = config.get("MORRConfig", MORRConfig_20um_MQ) morr_init = config.get("morr_init", True) @@ -365,14 +362,16 @@ def morr_bias(self) -> Tensor: ) def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: + """Propagate through the analytically calculated transfer matrix of MORR. + + :param weight: First column vectors in the block-circulant matrix. + :type weight: Tensor + :param x: Input tensor. + :type x: Tensor + + :return: Output of MORR array. + :rtype: Tensor """ - @description: propagate through the analytically calculated transfer matrix of morr. - @param weight {torch.Tensor} first column vectors in the block-circulant matrix - @param x {torch.Tensor} input - @return: y {torch.Tensor} output of MORR array - """ - ### weights: [p, q, k] - ### x: [ks*ks*inc, h_out*w_out*bs] x = x.t() # [h_out*w_out*bs, ks*ks*inc] x = x.view(x.size(0), self.grid_dim_x, self.miniblock) # [h_out*w_out*bs, q, k] diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index af85357a2..f93281af5 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -7,22 +7,24 @@ """ from typing import Optional +import logging import numpy as np import torch import torch.fft -from ..functional import toeplitz -from ..functional import logger -from ..functional import morr_uniform_ -from ..functional import input_quantize_fn, weight_quantize_fn from torch import Tensor from torch.nn import Parameter, init from torch.types import Device -from ..functional import MORRConfig_20um_MQ -from ..functional import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn from .base_layer import ONNBaseLayer +logger = logging.getLogger(__name__) + __all__ = ["AllPassMORRCirculantLinear"] @@ -45,14 +47,7 @@ def __init__( out_features: int, bias: bool = False, config=None, - # miniblock: int = 4, - # ### mrr parameter - # MORRConfig=MORRConfig_20um_MQ, - # morr_init: bool = True, - # ### trainable MORR nonlinearity - # trainable_morr_bias: bool = False, - # trainable_morr_scale: bool = False, - device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + device: Device = torch.device("cpu"), ) -> None: super(AllPassMORRCirculantLinear, self).__init__() self.in_features = in_features diff --git a/src/chop/nn/optical/utils/__init__.py b/src/chop/nn/optical/utils/__init__.py new file mode 100644 index 000000000..88b71f264 --- /dev/null +++ b/src/chop/nn/optical/utils/__init__.py @@ -0,0 +1,26 @@ +from .mrr import ( + MORRConfig_20um_MQ, + MRRConfig_5um_HQ, + MRRConfig_5um_MQ, + MRRConfig_5um_LQ, + MORRConfig_10um_MQ, +) + +from .compute import ( + im2col_2d, + toeplitz, +) + +from .initializer import ( + morr_uniform_, +) + +from .quantize import ( + input_quantize_fn, + weight_quantize_fn, +) + +from .mrr_op import ( + mrr_roundtrip_phase_to_tr_func, + mrr_roundtrip_phase_to_tr_fused, +) diff --git a/src/chop/nn/optical/functional/compute.py b/src/chop/nn/optical/utils/compute.py similarity index 100% rename from src/chop/nn/optical/functional/compute.py rename to src/chop/nn/optical/utils/compute.py diff --git a/src/chop/nn/optical/functional/initializer.py b/src/chop/nn/optical/utils/initializer.py similarity index 100% rename from src/chop/nn/optical/functional/initializer.py rename to src/chop/nn/optical/utils/initializer.py diff --git a/src/chop/nn/optical/functional/mrr.py b/src/chop/nn/optical/utils/mrr.py similarity index 100% rename from src/chop/nn/optical/functional/mrr.py rename to src/chop/nn/optical/utils/mrr.py diff --git a/src/chop/nn/optical/functional/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py similarity index 100% rename from src/chop/nn/optical/functional/mrr_op.py rename to src/chop/nn/optical/utils/mrr_op.py diff --git a/src/chop/nn/optical/functional/quantize.py b/src/chop/nn/optical/utils/quantize.py similarity index 100% rename from src/chop/nn/optical/functional/quantize.py rename to src/chop/nn/optical/utils/quantize.py diff --git a/src/chop/nn/optical/functional/torch_train.py b/src/chop/nn/optical/utils/torch_train.py similarity index 100% rename from src/chop/nn/optical/functional/torch_train.py rename to src/chop/nn/optical/utils/torch_train.py diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 1cef06928..3d5f9670b 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -import numpy as np -from chop.nn.optical.modules import optical_module_map from chop.passes.module.module_modify_helper import ( get_module_by_name, set_module_by_name, diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index b911c070d..550484880 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -1,7 +1,7 @@ import torch from chop.nn.optical.modules import optical_module_map -from chop.passes.module.module_modify_helper import replace_by_name, instantiate_module +from chop.passes.module.module_modify_helper import instantiate_module from chop.passes.module.transforms.optical.module_transform_helper import ( replace_by_name_optical, ) @@ -14,7 +14,7 @@ def get_config(config: dict, name: str): return config["default"]["config"] -def optical_by_type(network, pass_args): +def optical_transform_by_type(network, pass_args): for type_name, config in pass_args.items(): n_m = {} for n, m in network.named_modules(): @@ -37,20 +37,20 @@ def optical_by_type(network, pass_args): return network -def optical_by_name(network, pass_args): +def optical_transform_by_name(network, pass_args): optical_names = pass_args.keys() n_m = {} for n, m in network.named_modules(): n_m[n] = m for n, m in n_m.items(): if n in optical_names: - quan_config = pass_args[n] + optical_config = pass_args[n] - quan_config = quan_config["config"] - postfix = quan_config.pop("name") + optical_config = optical_config["config"] + postfix = optical_config.pop("name") new_m = instantiate_module( - m, postfix, optical_module_map, {"config": quan_config} + m, postfix, optical_module_map, {"config": optical_config} ) network = replace_by_name_optical(network, n, new_m) return network @@ -66,40 +66,16 @@ def optical_module_transform_pass(network, pass_args): :param pass_args: Additional arguments for the transformation. :type pass_args: dict, optional - Examples pass_args: - - .. code-block:: python - - pass_args = { - "by": "type", # quantize by type, name, or regex_name - "default": {"config": {"name": None}}, # default config, this would be used for any node that does not have a specific config - "linear": { - "config": { - "name": "integer", # quantization scheme name supported are ["integer", "fixed" (equivalent to integer), "lutnet" (dev mode), "logicnets" (dev mode), "binary", "binary_residual", "ternary", "minifloat_ieee", "minifloat_denorm", "log", "block_fp", "block_minifloat", "block_log"] - # data - "data_in_width": 8, - "data_in_frac_width": 4, - # weight - "weight_width": 8, - "weight_frac_width": 4, - # bias - "bias_width": 8, - "bias_frac_width": 4, - } - }, - } - :return: The transformed torch.nn.Module. :rtype: tuple - :raises ValueError: If the quantize "by" argument is unsupported. - + :raises ValueError: If the "by" argument is unsupported. """ by = pass_args.pop("by") match by: case "type": - network = optical_by_type(network, pass_args) + network = optical_transform_by_type(network, pass_args) case "name": - network = optical_by_name(network, pass_args) + network = optical_transform_by_name(network, pass_args) case _: raise ValueError(f'Unsupported quantize "by": {by}') return network, {} diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index d05512cc1..e65546822 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -1,19 +1,17 @@ #!/usr/bin/env python3 -# This example converts a simple MLP model to Verilog -import logging -import os +# This example converts a simple MLP model to an ONN model import sys import torch import torch.nn as nn +import torch.nn.functional as F from pathlib import Path sys.path.append(Path(__file__).resolve().parents[5].as_posix()) -# from chop.passes.module.transforms import quantize_module_transform_pass -from chop.passes.module.transforms import optical_module_transform_pass +from chop.passes.module.transforms.optical import optical_module_transform_pass class Net(nn.Module): @@ -42,17 +40,7 @@ def forward(self, x): return output -def load_my_model(model_path, device="cpu"): - # Load the model from the .pt file - loaded_model = torch.load(model_path, map_location=device) - # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout) - loaded_model.eval() - return loaded_model - - def test_optical_module_transform_pass(): - # model_path = "mase_output/sample_mnist_cnn.pt" - # mnist_cnn = load_my_model(model_path) model = Net() # Sanity check and report pass_args = { @@ -77,132 +65,6 @@ def test_optical_module_transform_pass(): }, } optical_module_transform_pass(model, pass_args) - # torch.save(onn_cnn, "mase_output/onn_cnn.pt") test_optical_module_transform_pass() - - -# if __name__ == '__main__': -# finetune = False - -# if True: -# parser = argparse.ArgumentParser(description='PyTorch MNIST Example') -# parser.add_argument('--batch-size', type=int, default=64, metavar='N', -# help='input batch size for training (default: 64)') -# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', -# help='input batch size for testing (default: 1000)') -# parser.add_argument('--epochs', type=int, default=14, metavar='N', -# help='number of epochs to train (default: 14)') -# parser.add_argument('--lr', type=float, default=1.0, metavar='LR', -# help='learning rate (default: 1.0)') -# parser.add_argument('--gamma', type=float, default=0.7, metavar='M', -# help='Learning rate step gamma (default: 0.7)') -# parser.add_argument('--no-cuda', action='store_true', default=False, -# help='disables CUDA training') -# parser.add_argument('--no-mps', action='store_true', default=False, -# help='disables macOS GPU training') -# parser.add_argument('--dry-run', action='store_true', default=False, -# help='quickly check a single pass') -# parser.add_argument('--seed', type=int, default=1, metavar='S', -# help='random seed (default: 1)') -# parser.add_argument('--log-interval', type=int, default=10, metavar='N', -# help='how many batches to wait before logging training status') -# parser.add_argument('--save-model', action='store_true', default=True, -# help='For Saving the current Model') -# parser.add_argument('--gpu-id', type=int, default=0, -# help='Which GPU device to use [default: 0]') - -# args = parser.parse_args() -# use_cuda = not args.no_cuda and torch.cuda.is_available() -# use_mps = not args.no_mps and torch.backends.mps.is_available() - -# torch.manual_seed(args.seed) - -# if not args.no_cuda and torch.cuda.is_available(): -# device = torch.device(f"cuda:{args.gpu_id}") -# elif use_mps: -# device = torch.device("mps") -# else: -# device = torch.device("cpu") - -# train_kwargs = {'batch_size': args.batch_size} -# test_kwargs = {'batch_size': args.test_batch_size} -# if use_cuda: -# cuda_kwargs = {'num_workers': 1, -# 'pin_memory': True, -# 'shuffle': True} -# train_kwargs.update(cuda_kwargs) -# test_kwargs.update(cuda_kwargs) - -# transform=transforms.Compose([ -# transforms.ToTensor(), -# transforms.Normalize((0.1307,), (0.3081,)) -# ]) -# dataset1 = datasets.MNIST('../data', train=True, download=True, -# transform=transform) -# dataset2 = datasets.MNIST('../data', train=False, -# transform=transform) -# train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) -# test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - -# # load pre-trained cnn -# cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) -# print("-------------- Testing the original cnn model -------------------") -# _, _ = report_trainable_parameters_analysis_pass(cnn) -# test(cnn, device, test_loader) - -# ## transform cnn into onn - -# # onn = load_my_model("mase_output/onn_cnn.pt", device) -# onn_model = perform_optical_module_transform_pass(cnn) -# onn_model.to(device) -# print("-------------- Testing the transformed onn model -------------------") -# _, _ = report_trainable_parameters_analysis_pass(onn_model) -# test(onn_model, device, test_loader) - -# # Training the onn model -# if finetune: -# optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) -# scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) -# for epoch in range(1, args.epochs + 1): -# train(args, onn_model, device, train_loader, optimizer, epoch) -# test(onn_model, device, test_loader) -# scheduler.step() - - -# torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") - -# print("-------------- Testing the trained onn model -------------------") -# test(onn_model, device, test_loader) -# _, _ = report_trainable_parameters_analysis_pass(onn_model) - - -# def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): -# # pass_args = { -# # "by": "type", -# # "linear": { -# # "config": { -# # "name": "morr", -# # "miniblock": 4, -# # "morr_init": True, -# # "trainable_morr_bias": False, -# # "trainable_morr_scale": False, -# # } -# # }, -# # } -# pass_args = { -# "by": "type", -# "conv2d": { -# "config": { -# "name": "morr", -# "miniblock": 4, -# "morr_init": True, -# "trainable_morr_bias": False, -# "trainable_morr_scale": False, -# } -# }, -# } -# onn_model, _ = optical_module_transform_pass(model, pass_args) -# torch.save(onn_model.state_dict(), save_path) -# return onn_model From 188688492889b8c677a788269b14314f1737ed2d Mon Sep 17 00:00:00 2001 From: Cheng Zhang Date: Tue, 11 Feb 2025 20:22:31 +0000 Subject: [PATCH 11/38] remove torch_train.py --- src/chop/nn/optical/utils/compute.py | 13 +- src/chop/nn/optical/utils/torch_train.py | 858 ----------------------- 2 files changed, 11 insertions(+), 860 deletions(-) delete mode 100644 src/chop/nn/optical/utils/torch_train.py diff --git a/src/chop/nn/optical/utils/compute.py b/src/chop/nn/optical/utils/compute.py index c43ad9f2d..d8d36a354 100644 --- a/src/chop/nn/optical/utils/compute.py +++ b/src/chop/nn/optical/utils/compute.py @@ -19,8 +19,6 @@ from torch.nn.modules.utils import _pair from torch.types import Device, _size -from .torch_train import set_torch_deterministic - __all__ = [ "shift", "Krylov", @@ -70,6 +68,17 @@ ] +def set_torch_deterministic(random_state: int = 0) -> None: + random_state = int(random_state) % (2**32) + torch.manual_seed(random_state) + np.random.seed(random_state) + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.cuda.manual_seed_all(random_state) + random.seed(random_state) + + def shift(v: Tensor, f: float = 1) -> Tensor: return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1) diff --git a/src/chop/nn/optical/utils/torch_train.py b/src/chop/nn/optical/utils/torch_train.py deleted file mode 100644 index 41571effa..000000000 --- a/src/chop/nn/optical/utils/torch_train.py +++ /dev/null @@ -1,858 +0,0 @@ -""" -Description: -Author: Jiaqi Gu (jqgu@utexas.edu) -Date: 2021-06-06 03:15:06 -LastEditors: Jiaqi Gu (jqgu@utexas.edu) -LastEditTime: 2021-06-06 03:15:06 -""" - -import csv -import os -import random -import time -import traceback -from collections import OrderedDict - -import numpy as np -import torch -from scipy import interpolate -from torch.nn.modules.batchnorm import _BatchNorm - -try: - from torchsummary import summary -except: - print("[W] Cannot import torchsummary") -from .general import ensure_dir - -__all__ = [ - "DeterministicCtx", - "set_torch_deterministic", - "set_torch_stochastic", - "get_random_state", - "summary_model", - "save_model", - "BestKModelSaver", - "load_model", - "count_parameters", - "check_converge", - "ThresholdScheduler", - "ThresholdScheduler_tf", - "ValueRegister", - "ValueTracer", - "EMA", - "SWA", - "export_traces_to_csv", - "set_learning_rate", - "get_learning_rate", - "apply_weight_decay", - "disable_bn", - "enable_bn", -] - - -class DeterministicCtx: - def __init__(self, random_state: int | None = None) -> None: - self.random_state = random_state - - def __enter__(self): - self.random_state = random.getstate() - self.numpy_random_state = np.random.get_state() - self.torch_random_state = torch.random.get_rng_state() - self.torch_cuda_random_state = torch.cuda.get_rng_state() - set_torch_deterministic(self.random_state) - return self - - def __exit__(self, *args): - random.setstate(self.random_state) - np.random.seed(self.numpy_random_state) - np.random.set_state(self.numpy_random_state) - torch.random.set_rng_state(self.torch_random_state) - torch.cuda.set_rng_state(self.torch_cuda_random_state) - - -def set_torch_deterministic(random_state: int = 0) -> None: - random_state = int(random_state) % (2**32) - torch.manual_seed(random_state) - np.random.seed(random_state) - if torch.cuda.is_available(): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.cuda.manual_seed_all(random_state) - random.seed(random_state) - - -def set_torch_stochastic(): - seed = int(time.time() * 1000) % (2**32) - torch.manual_seed(seed) - np.random.seed(seed) - if torch.cuda.is_available(): - torch.backends.cudnn.deterministic = False - torch.cuda.manual_seed_all(seed) - - -def get_random_state(): - return np.random.get_state()[1][0] - - -def summary_model(model, input): - summary(model, input) - - -def save_model(model, path="./checkpoint/model.pt", print_msg=True): - """Save PyTorch model in path - - Args: - model (PyTorch model): PyTorch model - path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". - print_msg (bool, optional): Control of message print. Defaults to True. - """ - dir = os.path.dirname(path) - if not os.path.exists(dir): - os.mkdir(dir) - try: - torch.save(model.state_dict(), path) - if print_msg: - print(f"[I] Model saved to {path}") - except Exception as e: - if print_msg: - print(f"[E] Model failed to be saved to {path}") - traceback.print_exc(e) - - -class BestKModelSaver(object): - def __init__( - self, - k: int = 1, - descend: bool = True, - truncate: int = 2, - metric_name: str = "acc", - format: str = "{:.2f}", - ): - super().__init__() - self.k = k - self.descend = descend - self.truncate = truncate - self.metric_name = metric_name - self.format = format - self.epsilon = 0.1**truncate - self.model_cache = OrderedDict() - - def better_op(self, a, b): - if self.descend: - return a >= b + self.epsilon - else: - return a <= b - self.epsilon - - def __insert_model_record(self, metric, dir, checkpoint_name, epoch=None): - metric = round(metric * 10**self.truncate) / 10**self.truncate - if len(self.model_cache) < self.k: - new_checkpoint_name = ( - f"{checkpoint_name}_{self.metric_name}-" - + self.format.format(metric) - + f"{'' if epoch is None else '_epoch-'+str(epoch)}" - ) - path = os.path.join(dir, new_checkpoint_name + ".pt") - self.model_cache[path] = (metric, epoch) - return path, None - else: - worst_metric, worst_epoch = sorted( - list(self.model_cache.values()), - key=lambda x: x[0], - reverse=False if self.descend else True, - )[0] - if self.better_op(metric, worst_metric): - del_checkpoint_name = ( - f"{checkpoint_name}_{self.metric_name}-" - + self.format.format(worst_metric) - + f"{'' if epoch is None else '_epoch-'+str(worst_epoch)}" - ) - del_path = os.path.join(dir, del_checkpoint_name + ".pt") - try: - del self.model_cache[del_path] - except: - print( - "[W] Cannot remove checkpoint: {} from cache".format(del_path), - flush=True, - ) - new_checkpoint_name = ( - f"{checkpoint_name}_{self.metric_name}-" - + self.format.format(metric) - + f"{'' if epoch is None else '_epoch-'+str(epoch)}" - ) - path = os.path.join(dir, new_checkpoint_name + ".pt") - self.model_cache[path] = (metric, epoch) - return path, del_path - # elif(acc == min_acc): - # new_checkpoint_name = f"{checkpoint_name}_acc-{acc:.2f}{'' if epoch is None else '_epoch-'+str(epoch)}" - # path = os.path.join(dir, new_checkpoint_name+".pt") - # self.model_cache[path] = (acc, epoch) - # return path, None - else: - return None, None - - def get_topk_model_path(self, topk: int = 1): - if topk <= 0: - return [] - if topk > len(self.model_cache): - topk = len(self.model_cache) - return [ - i[0] - for i in sorted( - self.model_cache.items(), key=lambda x: x[1][0], reverse=self.descend - )[:topk] - ] - - def save_model( - self, - model, - metric, - epoch=None, - path="./checkpoint/model.pt", - other_params=None, - save_model=False, - print_msg=True, - ): - """Save PyTorch model in path - - Args: - model (PyTorch model): PyTorch model - acc (scalar): accuracy - epoch (scalar, optional): epoch. Defaults to None - path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". - other_params (dict, optional): Other saved params. Defaults to None - save_model (bool, optional): whether save source code of nn.Module. Defaults to False - print_msg (bool, optional): Control of message print. Defaults to True. - """ - dir = os.path.dirname(path) - ensure_dir(dir) - checkpoint_name = os.path.splitext(os.path.basename(path))[0] - if isinstance(metric, torch.Tensor): - metric = metric.data.item() - new_path, del_path = self.__insert_model_record( - metric, dir, checkpoint_name, epoch - ) - - if del_path is not None: - try: - os.remove(del_path) - print(f"[I] Model {del_path} is removed", flush=True) - except Exception as e: - if print_msg: - print(f"[E] Model {del_path} failed to be removed", flush=True) - traceback.print_exc(e) - - if new_path is None: - if print_msg: - if self.descend: - best_list = list(reversed(sorted(list(self.model_cache.values())))) - else: - best_list = list(sorted(list(self.model_cache.values()))) - print( - f"[I] Not best {self.k}: {best_list}, skip this model (" - + self.format.format(metric) - + f"): {path}", - flush=True, - ) - else: - try: - # torch.save(model.state_dict(), new_path) - if other_params is not None: - saved_dict = other_params - else: - saved_dict = {} - if save_model: - saved_dict.update( - {"model": model, "state_dict": model.state_dict()} - ) - torch.save(saved_dict, new_path) - else: - saved_dict.update({"model": None, "state_dict": model.state_dict()}) - torch.save(saved_dict, new_path) - if print_msg: - if self.descend: - best_list = list( - reversed(sorted(list(self.model_cache.values()))) - ) - else: - best_list = list(sorted(list(self.model_cache.values()))) - - print( - f"[I] Model saved to {new_path}. Current best {self.k}: {best_list}", - flush=True, - ) - except Exception as e: - if print_msg: - print(f"[E] Model failed to be saved to {new_path}", flush=True) - traceback.print_exc(e) - return new_path - - -def load_model( - model, - path="./checkpoint/model.pt", - ignore_size_mismatch: bool = False, - print_msg=True, -): - """Load PyTorch model in path - - Args: - model (PyTorch model): PyTorch model - path (str, optional): Full path of PyTorch model. Defaults to "./checkpoint/model.pt". - ignore_size_mismatch (bool, optional): Whether ignore tensor size mismatch. Defaults to False. - print_msg (bool, optional): Control of message print. Defaults to True. - """ - try: - raw_data = torch.load(path, map_location=lambda storage, location: storage) - if isinstance(raw_data, OrderedDict) and "state_dict" not in raw_data: - ### state_dict: OrderedDict - state_dict = raw_data - else: - ### {"state_dict": ..., "model": ...} - state_dict = raw_data["state_dict"] - load_keys = set(state_dict.keys()) - model_keys = set(model.state_dict().keys()) - common_dict = load_keys & model_keys - diff_dict = load_keys ^ model_keys - extra_keys = load_keys - model_keys - lack_keys = model_keys - load_keys - cur_state_dict = model.state_dict() - if ignore_size_mismatch: - size_mismatch_dict = set( - key - for key in common_dict - if model.state_dict()[key].size() != state_dict[key].size() - ) - print( - f"[W] {size_mismatch_dict} are ignored due to size mismatch", flush=True - ) - common_dict = common_dict - size_mismatch_dict - - cur_state_dict.update({key: state_dict[key] for key in common_dict}) - if len(diff_dict) > 0: - print( - f"[W] Warning! Model is not the same as the checkpoint. not found keys {lack_keys}. extra unused keys {extra_keys}" - ) - - model.load_state_dict(cur_state_dict) - if print_msg: - print(f"[I] Model loaded from {path}") - except Exception as e: - traceback.print_exc(e) - if print_msg: - print(f"[E] Model failed to be loaded from {path}") - - -def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def check_converge(trace, epsilon=0.002): - if len(trace) <= 1: - return False - if np.abs(trace[-1] - trace[-2]) / (np.abs(trace[-1]) + 1e-8) < epsilon: - return True - return False - - -class ThresholdScheduler(object): - """Intepolation between begin point and end point. step must be within two endpoints""" - - def __init__(self, step_beg, step_end, thres_beg, thres_end, mode="tanh"): - assert mode in { - "linear", - "tanh", - }, "Threshold scheduler only supports linear and tanh modes" - self.mode = mode - self.step_beg = step_beg - self.step_end = step_end - self.thres_beg = thres_beg - self.thres_end = thres_end - self.func = self.createFunc() - - def normalize(self, step, factor=2): - return (step - self.step_beg) / (self.step_end - self.step_beg) * factor - - def createFunc(self): - if self.mode == "linear": - return lambda x: (self.thres_end - self.thres_beg) * x + self.thres_beg - elif self.mode == "tanh": - x = self.normalize( - np.arange(self.step_beg, self.step_end + 1).astype(np.float32) - ) - y = np.tanh(x) * (self.thres_end - self.thres_beg) + self.thres_beg - return interpolate.interp1d(x, y) - - def __call__(self, x): - return self.func(self.normalize(x)).tolist() - - -class ThresholdScheduler_tf(object): - """smooth increasing threshold with tensorflow model pruning scheduler""" - - def __init__(self, step_beg, step_end, thres_beg, thres_end): - import tensorflow as tf - import tensorflow_model_optimization as tfmot - - gpus = tf.config.list_physical_devices("GPU") - if gpus: - # Restrict TensorFlow to only allocate 1GB of memory on the first GPU - try: - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - except RuntimeError as e: - # Virtual devices must be set before GPUs have been initialized - print(e) - self.step_beg = step_beg - self.step_end = step_end - self.thres_beg = thres_beg - self.thres_end = thres_end - if thres_beg < thres_end: - self.thres_min = thres_beg - self.thres_range = thres_end - thres_beg - self.descend = False - - else: - self.thres_min = thres_end - self.thres_range = thres_beg - thres_end - self.descend = True - - self.pruning_schedule = tfmot.sparsity.keras.PolynomialDecay( - initial_sparsity=0, - final_sparsity=0.9999999, - begin_step=self.step_beg, - end_step=self.step_end, - ) - - def __call__(self, x): - if x < self.step_beg: - return self.thres_beg - elif x > self.step_end: - return self.thres_end - res_norm = self.pruning_schedule(x)[1].numpy() - if self.descend == False: - res = res_norm * self.thres_range + self.thres_beg - else: - res = self.thres_beg - res_norm * self.thres_range - - if np.abs(res - self.thres_end) <= 1e-6: - res = self.thres_end - return res - - -class ValueRegister(object): - def __init__(self, operator, name="", show=True): - self.op = operator - self.cache = None - self.show = show - self.name = name if len(name) > 0 else "value" - - def register_value(self, x): - self.cache = self.op(x, self.cache) if self.cache is not None else x - if self.show: - print(f"Recorded {self.name} is {self.cache}") - - -class ValueTracer(object): - def __init__(self, show=True): - self.cache = {} - self.show = show - - def add_value(self, name, value, step): - if name not in self.cache: - self.cache[name] = {} - self.cache[name][step] = value - if self.show: - print(f"Recorded {name}: step = {step}, value = {value}") - - def get_trace_by_name(self, name): - return self.cache.get(name, {}) - - def get_all_traces(self): - return self.cache - - def __len__(self): - return len(self.cache) - - def get_num_trace(self): - return len(self.cache) - - def get_len_trace_by_name(self, name): - return len(self.cache.get(name, {})) - - def dump_trace_to_file(self, name, file): - if name not in self.cache: - print(f"[W] Trace name '{name}' not found in tracer") - return - torch.save(self.cache[name], file) - print(f"[I] Trace {name} saved to {file}") - - def dump_all_traces_to_file(self, file): - torch.save(self.cache, file) - print(f"[I] All traces saved to {file}") - - def load_all_traces_from_file(self, file): - self.cache = torch.load(file) - return self.cache - - -class EMA(object): - def __init__(self, mu): - super().__init__() - self.mu = mu - self.shadow = {} - - def register(self, name, val): - self.shadow[name] = val.clone().data - - def __call__(self, name, x, mask=None): - if name not in self.shadow: - self.register(name, x) - return x.data - - old_average = self.shadow[name] - new_average = (1 - self.mu) * x + self.mu * old_average - if mask is not None: - new_average[mask].copy_(old_average[mask]) - self.shadow[name] = new_average.clone() - return new_average.data - - -class SWA(torch.nn.Module): - """Stochastic Weight Averging. - - # Paper - title: Averaging Weights Leads to Wider Optima and Better Generalization - link: https://arxiv.org/abs/1803.05407 - - # Arguments - start_epoch: integer, epoch when swa should start. - lr_schedule: string, type of learning rate schedule. - swa_lr: float, learning rate for swa. - swa_lr2: float, upper bound of cyclic learning rate. - swa_freq: integer, length of learning rate cycle. - batch_size integer, batch size (for batch norm with generator) - verbose: integer, verbosity mode, 0 or 1. - """ - - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - start_epoch: int, - epochs: int, # total epochs - steps, # total steps per epoch - lr_schedule="manual", - swa_lr="auto", - swa_lr2="auto", - swa_freq=1, - batch_size=None, - verbose=0, - ): - super().__init__() - self.model = model - self.optimizer = optimizer - self.start_epoch = start_epoch - 1 - self.epochs = epochs - self.steps = steps - self.lr_schedule = lr_schedule - self.swa_lr = swa_lr - - # if no user determined upper bound, make one based off of the lower bound - self.swa_lr2 = swa_lr2 if swa_lr2 is not None else 10 * swa_lr - self.swa_freq = swa_freq - self.batch_size = batch_size - self.verbose = verbose - - if start_epoch < 2: - raise ValueError('"swa_start" attribute cannot be lower than 2.') - - schedules = ["manual", "constant", "cyclic"] - - if self.lr_schedule not in schedules: - raise ValueError( - '"{}" is not a valid learning rate schedule'.format(self.lr_schedule) - ) - - if self.lr_schedule == "cyclic" and self.swa_freq < 2: - raise ValueError('"swa_freq" must be higher than 1 for cyclic schedule.') - - if self.swa_lr == "auto" and self.swa_lr2 != "auto": - raise ValueError( - '"swa_lr2" cannot be manually set if "swa_lr" is automatic.' - ) - - if ( - self.lr_schedule == "cyclic" - and self.swa_lr != "auto" - and self.swa_lr2 != "auto" - and self.swa_lr > self.swa_lr2 - ): - raise ValueError('"swa_lr" must be lower than "swa_lr2".') - - def on_train_begin(self): - self.lr_record = [] - - if self.start_epoch >= self.epochs - 1: - raise ValueError('"swa_start" attribute must be lower than "epochs".') - - self.init_lr = self.optimizer.param_groups[0]["lr"] - - # automatic swa_lr - if self.swa_lr == "auto": - self.swa_lr = 0.1 * self.init_lr - - if self.init_lr < self.swa_lr: - raise ValueError('"swa_lr" must be lower than rate set in optimizer.') - - # automatic swa_lr2 between initial lr and swa_lr - if self.lr_schedule == "cyclic" and self.swa_lr2 == "auto": - self.swa_lr2 = self.swa_lr + (self.init_lr - self.swa_lr) * 0.25 - - self._check_batch_norm() - - if self.has_batch_norm and self.batch_size is None: - raise ValueError( - '"batch_size" needs to be set for models with batch normalization layers.' - ) - - def on_epoch_begin(self, epoch): - # input epoch is from 0 to epochs-1 - - self.current_epoch = epoch - self._scheduler(epoch) - - # constant schedule is updated epoch-wise - if self.lr_schedule == "constant": - self._update_lr(epoch) - - if self.is_swa_start_epoch: - # self.swa_weights = self.model.get_weights() - self.swa_weights = { - name: p.data.clone() for name, p in self.model.named_parameters() - } - - if self.verbose > 0: - print( - "\nEpoch %05d: starting stochastic weight averaging" % (epoch + 1) - ) - - if self.is_batch_norm_epoch: - self._set_swa_weights(epoch) - - if self.verbose > 0: - print( - "\nEpoch %05d: reinitializing batch normalization layers" - % (epoch + 1) - ) - - self._reset_batch_norm() - - if self.verbose > 0: - print( - "\nEpoch %05d: running forward pass to adjust batch normalization" - % (epoch + 1) - ) - - def on_batch_begin(self, batch): - # update lr each batch for cyclic lr schedule - if self.lr_schedule == "cyclic": - self._update_lr(self.current_epoch, batch) - - if self.is_batch_norm_epoch: - batch_size = self.batch_size - # this is for tensorflow momentum, applied to the running stat - # momentum = batch_size / (batch * batch_size + batch_size) - - # we need to convert it to torch momentum, applied to the batch stat - momentum = 1 - batch_size / (batch * batch_size + batch_size) - - for layer in self.batch_norm_layers: - layer.momentum = momentum - - def on_epoch_end(self, epoch): - if self.is_swa_start_epoch: - self.swa_start_epoch = epoch - - if self.is_swa_epoch and not self.is_batch_norm_epoch: - self.swa_weights = self._average_weights(epoch) - - def on_train_end(self): - if not self.has_batch_norm: - self._set_swa_weights(self.epochs) - else: - self._restore_batch_norm() - - ## TODO: what is meaning here? - # for batch_lr in self.lr_record: - # self.model.history.history.setdefault("lr", []).append(batch_lr) - - def _scheduler(self, epoch): - swa_epoch = epoch - self.start_epoch - - self.is_swa_epoch = epoch >= self.start_epoch and swa_epoch % self.swa_freq == 0 - self.is_swa_start_epoch = epoch == self.start_epoch - self.is_batch_norm_epoch = epoch == self.epochs - 1 and self.has_batch_norm - - def _average_weights(self, epoch): - # return [ - # (swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w) - # / ((epoch - self.start_epoch) / self.swa_freq + 1) - # for swa_w, w in zip(self.swa_weights, self.model.get_weights()) - # ] - out = {} - with torch.no_grad(): - for name, w in self.model.named_parameters(): - swa_w = self.swa_weights[name] - out[name] = ( - swa_w * ((epoch - self.start_epoch) / self.swa_freq) + w.data - ) / ((epoch - self.start_epoch) / self.swa_freq + 1) - return out - - def _update_lr(self, epoch, batch=None): - if self.is_batch_norm_epoch: - lr = 0 - # K.set_value(self.model.optimizer.lr, lr) - set_learning_rate(lr, self.optimizer) - elif self.lr_schedule == "constant": - lr = self._constant_schedule(epoch) - # K.set_value(self.model.optimizer.lr, lr) - set_learning_rate(lr, self.optimizer) - elif self.lr_schedule == "cyclic": - lr = self._cyclic_schedule(epoch, batch) - # K.set_value(self.model.optimizer.lr, lr) - set_learning_rate(lr, self.optimizer) - self.lr_record.append(lr) - - def _constant_schedule(self, epoch): - t = epoch / self.start_epoch - lr_ratio = self.swa_lr / self.init_lr - if t <= 0.5: - factor = 1.0 - elif t <= 0.9: - factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4 - else: - factor = lr_ratio - return self.init_lr * factor - - def _cyclic_schedule(self, epoch, batch): - """Designed after Section 3.1 of Averaging Weights Leads to - Wider Optima and Better Generalization(https://arxiv.org/abs/1803.05407) - """ - # steps are mini-batches per epoch, equal to training_samples / batch_size - steps = self.steps - - swa_epoch = (epoch - self.start_epoch) % self.swa_freq - cycle_length = self.swa_freq * steps - - # batch 0 indexed, so need to add 1 - i = (swa_epoch * steps) + (batch + 1) - if epoch >= self.start_epoch: - t = (((i - 1) % cycle_length) + 1) / cycle_length - return (1 - t) * self.swa_lr2 + t * self.swa_lr - else: - return self._constant_schedule(epoch) - - def _set_swa_weights(self, epoch): - # self.model.set_weights(self.swa_weights) - for name, p in self.model.named_parameters(): - p.data.copy_(self.swa_weights[name]) - - if self.verbose > 0: - print( - "\nEpoch %05d: final model weights set to stochastic weight average" - % (epoch + 1) - ) - - def _check_batch_norm(self): - self.batch_norm_momentums = [] - self.batch_norm_layers = [] - self.has_batch_norm = False - self.running_bn_epoch = False - - for layer in self.model.modules(): - if isinstance(layer, _BatchNorm): - self.has_batch_norm = True - self.batch_norm_momentums.append(layer.momentum) - self.batch_norm_layers.append(layer) - - if self.verbose > 0 and self.has_batch_norm: - print( - "Model uses batch normalization. SWA will require last epoch " - "to be a forward pass and will run with no learning rate" - ) - - def _reset_batch_norm(self): - for layer in self.batch_norm_layers: - # initialized moving mean and - # moving var weights - layer.reset_running_stats() - - def _restore_batch_norm(self): - for layer, momentum in zip(self.batch_norm_layers, self.batch_norm_momentums): - layer.momentum = momentum - - -def export_traces_to_csv(trace_file, csv_file, fieldnames=None): - traces = torch.load(trace_file) - - with open(csv_file, "w", newline="") as csvfile: - if fieldnames is None: - fieldnames = list(traces.keys()) - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - - writer.writeheader() - max_len = max([len(traces[field]) for field in fieldnames]) - - for idx in range(max_len): - row = {} - for field in fieldnames: - value = traces[field][idx] if idx < len(traces[field]) else "" - row[field] = ( - value.data.item() if isinstance(value, torch.Tensor) else value - ) - writer.writerow(row) - - -def set_learning_rate(lr, optimizer): - for param_group in optimizer.param_groups: - param_group["lr"] = lr - - -def get_learning_rate(optimizer): - return optimizer.param_groups[0]["lr"] - - -def apply_weight_decay(W, decay_rate, learning_rate, mask=None): - # in mask, 1 represents fixed variables, 0 represents trainable variables - if mask is not None: - W[~mask] -= W[~mask] * decay_rate * learning_rate - else: - W -= W * decay_rate * learning_rate - - -def disable_bn(model: torch.nn.Module) -> None: - for m in model.modules(): - if isinstance( - m, - ( - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - torch.nn.SyncBatchNorm, - ), - ): - m.eval() - - -def enable_bn(model: torch.nn.Module) -> None: - for m in model.modules(): - if isinstance( - m, - ( - torch.nn.BatchNorm1d, - torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d, - torch.nn.SyncBatchNorm, - ), - ): - m.train() From 6d9dddd8c1fc23d86633413ff67520c24f2f001a Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sat, 15 Feb 2025 19:05:08 +0000 Subject: [PATCH 12/38] remove redundent onn functions --- scripts/mase-hls.py | 30 +- scripts/stat-to-conf.py | 7 +- setup.py | 4 +- src/chop/actions/emit.py | 5 +- src/chop/actions/search/search.py | 4 +- .../actions/search/search_space/nas_bert.py | 8 +- .../quantization/manual_hf_module.py | 3 +- src/chop/actions/simulate.py | 4 +- src/chop/actions/train.py | 3 +- src/chop/dataset/nerf/blender.py | 4 +- src/chop/dataset/nlp/language_modeling.py | 10 +- src/chop/dataset/vision/transforms/cifar.py | 5 +- src/chop/distributed/launcher.py | 16 +- src/chop/distributed/tensor/__init__.py | 18 +- src/chop/distributed/tensor/_dispatch.py | 16 +- src/chop/distributed/tensor/_redistribute.py | 17 +- src/chop/distributed/tensor/_sharding_prop.py | 15 +- src/chop/distributed/tensor/api.py | 56 +- .../distributed/tensor/ops/basic_strategy.py | 5 +- .../distributed/tensor/ops/common_rules.py | 15 +- src/chop/distributed/tensor/ops/conv_ops.py | 23 +- src/chop/distributed/tensor/ops/math_ops.py | 34 +- src/chop/distributed/tensor/ops/matrix_ops.py | 10 +- .../distributed/tensor/ops/pointwise_ops.py | 9 +- src/chop/distributed/tensor/ops/tensor_ops.py | 16 +- src/chop/distributed/tensor/ops/utils.py | 4 +- src/chop/distributed/tensor/ops/view_ops.py | 9 +- src/chop/ir/graph/mase_graph.py | 49 +- src/chop/ir/graph/mase_metadata.py | 4 +- src/chop/ir/onnx/mase_onnx_graph.py | 5 +- src/chop/ir/onnx/utils.py | 9 +- .../models/bert/modeling_bert_quantized.py | 8 +- src/chop/models/bert/quant_config_bert.py | 5 +- src/chop/models/cnv/cnv.py | 17 +- src/chop/models/cswin/cswintransformer.py | 8 +- src/chop/models/deit/deit_v2.py | 2 +- src/chop/models/efficientnet/efficientnet.py | 23 +- src/chop/models/lfc/lfc.py | 4 +- src/chop/models/llama/modeling_llama_llora.py | 5 +- .../models/llama/modeling_llama_sparse.py | 5 +- src/chop/models/mobilenet_v2/mobilenet_v2.py | 3 +- src/chop/models/nerf/nerf_vision.py | 4 +- src/chop/models/opt/modeling_opt.py | 2 +- src/chop/models/opt/modeling_opt_lora.py | 2 +- src/chop/models/opt/modeling_opt_quantized.py | 2 +- src/chop/models/opt/modeling_opt_sparse.py | 2 +- .../models/opt/quant_config_opt_quantized.py | 4 +- src/chop/models/pvt/pvt.py | 9 +- src/chop/models/pvt/pvt_v2.py | 2 +- src/chop/models/repvgg/repvgg.py | 4 +- src/chop/models/resnet/resnet.py | 56 +- src/chop/models/toy/toy.py | 15 +- src/chop/models/vgg_cifar/vgg_orig.py | 61 +- src/chop/models/vision/snn/snn_toy.py | 4 +- .../models/vision/snn/spikingResformer.py | 42 +- src/chop/nn/backward/modules/__init__.py | 4 +- src/chop/nn/functional/softermax.py | 2 +- src/chop/nn/modules/gqa.py | 11 +- src/chop/nn/modules/lora.py | 13 +- src/chop/nn/modules/sparse.py | 19 +- src/chop/nn/mx/activations.py | 5 +- src/chop/nn/mx/bmm.py | 12 +- src/chop/nn/mx/convolution.py | 32 +- src/chop/nn/mx/elemwise_ops.py | 8 +- src/chop/nn/mx/formats.py | 8 +- src/chop/nn/mx/linear.py | 32 +- src/chop/nn/mx/matmul.py | 12 +- src/chop/nn/mx/mx_ops.py | 9 +- src/chop/nn/mx/quantize.py | 4 +- src/chop/nn/mx/simd_ops.py | 2 +- src/chop/nn/mx/transpose_convolution.py | 32 +- src/chop/nn/optical/modules/morr_conv2d.py | 7 +- src/chop/nn/optical/modules/morr_linear.py | 6 +- src/chop/nn/optical/utils/__init__.py | 4 +- src/chop/nn/optical/utils/compute.py | 950 +----------------- src/chop/nn/optical/utils/initializer.py | 105 +- src/chop/nn/optical/utils/mrr.py | 39 - src/chop/nn/optical/utils/mrr_op.py | 341 +------ src/chop/nn/optical/utils/quantize.py | 264 +---- src/chop/nn/quantized/functional/gelu.py | 4 +- src/chop/nn/quantized/functional/linear.py | 83 +- src/chop/nn/quantized/functional/matmul.py | 8 +- src/chop/nn/quantized/functional/relu.py | 4 +- src/chop/nn/quantized/functional/selu.py | 4 +- src/chop/nn/quantized/functional/softplus.py | 4 +- src/chop/nn/quantized/functional/softsign.py | 4 +- src/chop/nn/quantized/functional/tanh.py | 4 +- src/chop/nn/quantized/modules/__init__.py | 12 +- src/chop/nn/quantized/modules/attention.py | 4 +- .../nn/quantized/modules/attention_head.py | 9 +- src/chop/nn/quantized/modules/batch_norm1d.py | 4 +- src/chop/nn/quantized/modules/conv1d.py | 12 +- src/chop/nn/quantized/modules/conv2d.py | 39 +- src/chop/nn/quantized/modules/gelu.py | 8 +- src/chop/nn/quantized/modules/gqa.py | 5 +- src/chop/nn/quantized/modules/group_norm.py | 4 +- .../nn/quantized/modules/instance_norm2d.py | 4 +- src/chop/nn/quantized/modules/layer_norm.py | 4 +- src/chop/nn/quantized/modules/linear.py | 18 +- src/chop/nn/quantized/modules/relu.py | 8 +- src/chop/nn/quantized/modules/rms_norm.py | 4 +- src/chop/nn/quantized/modules/selu.py | 8 +- src/chop/nn/quantized/modules/silu.py | 8 +- src/chop/nn/quantized/modules/softplus.py | 8 +- src/chop/nn/quantized/modules/softsign.py | 8 +- src/chop/nn/quantized/modules/tanh.py | 8 +- .../nn/quantizers/LUTNet/BaseInitializer.py | 5 +- src/chop/nn/quantizers/LUTNet/BaseTrainer.py | 2 +- src/chop/nn/quantizers/block_fp.py | 17 +- src/chop/nn/quantizers/block_log.py | 8 +- src/chop/nn/quantizers/block_minifloat.py | 9 +- src/chop/nn/quantizers/integer.py | 8 +- src/chop/nn/quantizers/log.py | 14 +- src/chop/nn/quantizers/minifloat.py | 34 +- src/chop/nn/quantizers/mxint_hardware.py | 6 +- src/chop/nn/quantizers/quantizers_for_hw.py | 14 +- src/chop/nn/quantizers/ternary.py | 3 +- src/chop/nn/quantizers/utils.py | 32 +- src/chop/nn/snn/auto_cuda/generator.py | 5 +- src/chop/nn/snn/modules/neuron/ifnode.py | 20 +- src/chop/nn/snn/modules/neuron/lifnode.py | 112 ++- .../nn/snn/modules/spiking_self_attention.py | 8 +- src/chop/passes/__init__.py | 4 +- src/chop/passes/graph/__init__.py | 10 +- .../add_metadata/add_common_metadata.py | 12 +- .../add_metadata/add_hardware_metadata.py | 3 +- .../add_metadata/common_metadata_layers.py | 57 +- .../add_metadata/hardware_metadata_layers.py | 37 +- .../add_metadata/software_metadata_layers.py | 16 +- .../autosharding/alpa_cost_modelling.py | 10 +- .../analysis/autosharding/autosharding.py | 44 +- .../graph/analysis/autosharding/megatron.py | 4 +- .../autosharding/strategies/basic_strategy.py | 5 +- .../autosharding/strategies/common.py | 6 +- .../autosharding/strategies/matrix_ops.py | 38 +- .../autosharding/strategies/pointwise_ops.py | 4 +- .../autosharding/strategies/tensor_ops.py | 4 +- .../autosharding/strategies/view_ops.py | 7 +- .../flop_estimator/calculator/calc_modules.py | 2 +- .../passes/graph/analysis/plot/plot_graph.py | 5 +- .../graph/interface/tensorrt/quantize.py | 1 + .../passes/graph/transforms/dse/run_dse.py | 2 +- src/chop/passes/graph/transforms/lora.py | 6 +- .../graph/transforms/onnxrt/quantize.py | 4 +- .../transforms/pruning/pruning_methods.py | 28 +- .../quant_parsers/parse_quant_config.py | 82 +- .../graph/transforms/training/modify.py | 5 +- .../transforms/utils/logicnets_fusion.py | 8 +- .../graph/transforms/verilog/emit_bram.py | 8 +- .../graph/transforms/verilog/emit_hls.py | 7 +- .../verilog/logicnets/emit_linear.py | 4 +- src/chop/passes/module/analysis/report.py | 3 +- src/chop/passes/utils.py | 8 +- src/chop/pipelines/auto_pipeline.py | 22 +- src/chop/tools/check_dependency.py | 7 +- src/chop/tools/huggingface.py | 14 +- .../tools/plt_wrapper/nlp/classification.py | 4 +- src/chop/tools/plt_wrapper/nlp/lm.py | 4 +- src/chop/tools/utils.py | 6 +- src/mase_cocotb/interfaces/streaming.py | 24 +- src/mase_cocotb/runner.py | 13 +- src/mase_cocotb/testbench.py | 7 +- src/mase_cocotb/utils.py | 4 +- src/mase_cocotb/z_qlayers/tensor_cast.py | 6 +- .../activation_layers/test/fixed_elu_tb.py | 14 +- .../activation_layers/test/fixed_gelu_tb.py | 14 +- .../test/fixed_hardshrink_tb.py | 6 +- .../test/fixed_hardswish_tb.py | 10 +- .../test/fixed_leaky_relu_tb.py | 6 +- .../test/fixed_logsigmoid_tb.py | 6 +- .../activation_layers/test/fixed_relu_tb.py | 6 +- .../activation_layers/test/fixed_selu_tb.py | 10 +- .../test/fixed_sigmoid_tb.py | 6 +- .../activation_layers/test/fixed_silu_tb.py | 6 +- .../test/fixed_softermax_1d_tb.py | 23 +- .../test/fixed_softermax_tb.py | 4 +- .../test/fixed_softmax_tb.py | 6 +- .../test/fixed_softplus_tb.py | 10 +- .../test/fixed_softshrink_tb.py | 6 +- .../test/fixed_softsign_tb.py | 10 +- .../activation_layers/test/fixed_tanh_tb.py | 10 +- .../activation_layers/test/softermax.py | 2 +- .../test/softermax_global_norm_tb.py | 26 +- .../test/softermax_local_window_tb.py | 8 +- .../test/softermax_lpw_pow2_tb.py | 51 +- .../test/softermax_lpw_reciprocal_tb.py | 40 +- .../cast/test/fixed_rounding_tb.py | 2 +- .../cast/test/fixed_signed_cast_tb.py | 13 +- .../cast/test/fixed_unsigned_cast_tb.py | 8 +- .../common/test/comparator_accumulator_tb.py | 8 +- .../common/test/comparator_tree_tb.py | 5 +- .../common/test/register_slice_tb.py | 4 +- .../common/test/single_element_repeat_tb.py | 6 +- ...binary_activation_binary_convolution_tb.py | 4 +- .../convolution_layers/test/convolution_tb.py | 34 +- .../convolution_layers/test/padding_tb.py | 10 +- .../convolution_layers/test/roller_tb.py | 4 +- .../test/sliding_window_tb.py | 6 +- src/mase_components/deps.py | 12 +- src/mase_components/helper/generate_memory.py | 8 +- .../hls/bfp_arith/bfp_adder.py | 6 +- .../hls/bfp_arith/bfp_multiplier.py | 6 +- src/mase_components/hls/elastic/buffer.py | 8 +- src/mase_components/hls/hls_regression.py | 5 +- .../hls/int_arith/int_layernorm.py | 8 +- src/mase_components/hls/int_arith/int_relu.py | 8 +- src/mase_components/hls/int_arith/int_silu.py | 8 +- .../hls/int_arith/int_softmax.py | 8 +- .../hls/int_arith/int_transpose.py | 8 +- .../hls/regression_gen/bfp_add_dse.py | 11 +- .../hls/regression_gen/bfp_linear2d_dse.py | 8 +- .../hls/regression_gen/bfp_mult_dse.py | 11 +- .../hls/regression_gen/buffer_dse.py | 11 +- .../hls/regression_gen/fork_dse.py | 11 +- .../hls/regression_gen/int_add_dse.py | 11 +- .../hls/regression_gen/int_layernorm_dse.py | 11 +- .../hls/regression_gen/int_linear2d_dse.py | 8 +- .../hls/regression_gen/int_matmul_dse.py | 8 +- .../hls/regression_gen/int_mult_dse.py | 11 +- .../hls/regression_gen/int_relu_dse.py | 11 +- .../hls/regression_gen/int_rmsnorm_dse.py | 11 +- .../hls/regression_gen/int_rope_dse.py | 11 +- .../hls/regression_gen/int_silu_dse.py | 11 +- .../hls/regression_gen/int_softmax_dse.py | 11 +- .../hls/regression_gen/int_transpose_dse.py | 11 +- src/mase_components/hls/scripts/bl_bfp.py | 15 +- ...y_activation_binary_adder_tree_layer_tb.py | 3 +- .../fixed_operators/test/fixed_isqrt_tb.py | 18 +- .../test/fixed_lut_index_tb.py | 3 +- .../fixed_operators/test/fixed_nr_stage_tb.py | 4 +- .../test/fixed_range_augmentation_tb.py | 2 +- .../test/fixed_range_reduction_tb.py | 3 +- .../fixed_operators/test/isqrt_sw.py | 10 +- .../matmul/test/fixed_matmul_tb.py | 8 +- .../matmul/test/simple_matmul_tb.py | 8 +- .../linear_layers/matmul/test/transpose_tb.py | 6 +- .../mxint_operators/test/mxint_cast_tb.py | 13 +- .../test/mxint_dot_product_tb.py | 9 +- .../mxint_operators/test/mxint_matmul_tb.py | 10 +- .../test/mxint_vector_mult_tb.py | 9 +- .../mxint_operators/test/test.py | 16 +- .../mxint_operators/test/utils.py | 6 +- src/mase_components/memory/test/fifo_tb.py | 9 +- .../memory/test/repeat_circular_buffer_tb.py | 2 +- .../memory/test/unpacked_fifo_tb.py | 2 +- .../process_synth_impl.py | 11 +- .../test/batch_norm_2d_tb.py | 6 +- .../test/channel_selection_tb.py | 1 - .../test/group_norm_2d_tb.py | 13 +- .../normalization_layers/test/models.py | 14 +- .../test/rms_norm_2d_tb.py | 11 +- .../fixed/test/fixed_isqrt_tb.py | 18 +- .../fixed/test/fixed_nr_stage_tb.py | 4 +- .../scalar_operators/fixed/test/isqrt_sw.py | 10 +- ...ixed_grouped_query_attention_wrapper_tb.py | 67 +- .../test/fixed_self_attention_head_tb.py | 32 +- .../vision_models/vit/test/fixed_mlp_tb.py | 6 +- .../vit/test/fixed_patch_embed_tb.py | 9 +- .../vision_models/vit/test/fixed_pvt_tb.py | 12 +- .../vision_models/vit/test/hash_exp_tb.py | 4 +- .../vision_models/vit/test/hash_softmax_tb.py | 6 +- .../vit/test/helpers/ha_softmax.py | 11 +- .../vit/test/helpers/pvt_quant.py | 2 +- test/nn/quantized/modules/attention_head.py | 4 +- test/nn/snn/test_ann2snn.py | 8 +- .../add_metadata/test_add_common_metadata.py | 16 +- .../analysis/pruning/test_hook_inspect.py | 12 +- .../test_statistic_profiler.py | 16 +- .../graph/transforms/prune/test_prune.py | 12 +- .../prune/test_prune_detach_hook.py | 12 +- .../quantize/test_quantize_lutnet_conv2d.py | 3 +- .../quantize/test_quantize_lutnet_linear_2.py | 4 +- .../training/test_training_base_pass.py | 10 +- .../verilog/test_emit_activation_gelu.py | 6 +- .../verilog/test_emit_activation_selu.py | 6 +- .../verilog/test_emit_activation_softplus.py | 6 +- .../verilog/test_emit_activation_softsign.py | 6 +- .../verilog/test_emit_activation_tanh.py | 6 +- .../verilog/test_emit_verilog_bert.py | 11 +- .../verilog/test_emit_verilog_llama.py | 11 +- .../verilog/test_emit_verilog_mistral.py | 11 +- .../verilog/test_emit_verilog_norm.py | 27 +- .../onnx/analysis/test_export_fx_graph.py | 7 +- test/self/test_optical_module.py | 252 +++++ test/self/train_mnist_cnn.py | 247 +++++ test/tools/test_onnx_operators.py | 57 +- 286 files changed, 1477 insertions(+), 4061 deletions(-) create mode 100644 test/self/test_optical_module.py create mode 100644 test/self/train_mnist_cnn.py diff --git a/scripts/mase-hls.py b/scripts/mase-hls.py index f2097bcd1..925ec2300 100755 --- a/scripts/mase-hls.py +++ b/scripts/mase-hls.py @@ -44,33 +44,17 @@ def build(self): def quick_test(self): shutil.copy( - os.path.join( - self.root, - "test", - "test_in.mlir", - ), - os.path.join( - self.root, - "test", - "test.mlir", - ), + os.path.join(self.root, "test", "test_in.mlir",), + os.path.join(self.root, "test", "test.mlir",), ) result = False cmd = [ "mase-opt", "--preprocess-func=func-name=relu", "--canonicalize", - os.path.join( - self.root, - "test", - "test.mlir", - ), + os.path.join(self.root, "test", "test.mlir",), "-o", - os.path.join( - self.root, - "test", - "test1.mlir", - ), + os.path.join(self.root, "test", "test1.mlir",), ] result |= self.execute(cmd, log_output=True, cwd=self.root) @@ -83,11 +67,7 @@ def quick_test(self): # "test", # "test.cpp", # ), - os.path.join( - self.root, - "test", - "test1.mlir", - ), + os.path.join(self.root, "test", "test1.mlir",), "--debug", ] result |= self.execute(cmd, log_output=True, cwd=self.root) diff --git a/scripts/stat-to-conf.py b/scripts/stat-to-conf.py index 01e0d0cbc..182e66d64 100755 --- a/scripts/stat-to-conf.py +++ b/scripts/stat-to-conf.py @@ -41,12 +41,7 @@ } -def set_stat( - entry_name: str, - mean=None, - median=None, - max=None, -) -> dict[str, Any]: +def set_stat(entry_name: str, mean=None, median=None, max=None,) -> dict[str, Any]: """Return a dictionary containing the format of the stats required to use ternary quantiser. If statistics are not specified, "NA" will be set as the value, this interally is being interpreted as None when the .toml is loaded""" diff --git a/setup.py b/setup.py index 27d865de1..e0e91be1d 100644 --- a/setup.py +++ b/setup.py @@ -105,9 +105,7 @@ def get_system(): author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk", license_files=("LICENSE",), python_requires=">=3.11.9", - package_dir={ - "": "src", - }, + package_dir={"": "src",}, packages=find_packages("src"), install_requires=requirements, ) diff --git a/src/chop/actions/emit.py b/src/chop/actions/emit.py index 7e326bc7b..9318b4476 100644 --- a/src/chop/actions/emit.py +++ b/src/chop/actions/emit.py @@ -41,10 +41,7 @@ def emit( data_module.prepare_data() data_module.setup() dummy_in = get_dummy_input( - model_info=model_info, - data_module=data_module, - task=task, - device="cpu", + model_info=model_info, data_module=data_module, task=task, device="cpu", ) mg, _ = add_common_metadata_analysis_pass( mg, {"dummy_in": dummy_in, "add_value": False} diff --git a/src/chop/actions/search/search.py b/src/chop/actions/search/search.py index fe82d3574..48c2d21d7 100644 --- a/src/chop/actions/search/search.py +++ b/src/chop/actions/search/search.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) -def parse_search_config( - search_config: dict, -): +def parse_search_config(search_config: dict,): """ Parse search config from a dict or a toml file and do sanity check. The search config must consist of two parts: strategy and search_space. diff --git a/src/chop/actions/search/search_space/nas_bert.py b/src/chop/actions/search/search_space/nas_bert.py index 3c4838ea1..760d3b27f 100644 --- a/src/chop/actions/search/search_space/nas_bert.py +++ b/src/chop/actions/search/search_space/nas_bert.py @@ -1199,9 +1199,11 @@ def forward( # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = ( - encoder_hidden_states.size() - ) + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) diff --git a/src/chop/actions/search/search_space/quantization/manual_hf_module.py b/src/chop/actions/search/search_space/quantization/manual_hf_module.py index 094330da1..a9c9f1bfb 100644 --- a/src/chop/actions/search/search_space/quantization/manual_hf_module.py +++ b/src/chop/actions/search/search_space/quantization/manual_hf_module.py @@ -61,8 +61,7 @@ def rebuild_model(self, sampled_config: dict, is_eval_mode: bool): with init_empty_weights(): model = self.model_cls(config) device_map = infer_auto_device_map( - model, - no_split_module_classes=model._no_split_modules, + model, no_split_module_classes=model._no_split_modules, ) model = load_checkpoint_and_dispatch( model, checkpoint=self.model_name, device_map=device_map diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py index e56a512d8..8a0e93486 100644 --- a/src/chop/actions/simulate.py +++ b/src/chop/actions/simulate.py @@ -68,9 +68,7 @@ def simulate( else: raise ValueError(f"Unrecognized simulator: {simulator}") - includes = [ - project_dir / "hardware" / "rtl", - ] + [ + includes = [project_dir / "hardware" / "rtl",] + [ Path(mase_components.__file__).parent / module / "rtl" for module in get_modules() ] diff --git a/src/chop/actions/train.py b/src/chop/actions/train.py index 5daa9b49a..93524e756 100644 --- a/src/chop/actions/train.py +++ b/src/chop/actions/train.py @@ -109,8 +109,7 @@ def train( trainer = pl.Trainer(**plt_trainer_args) trainer.fit( - pl_model, - datamodule=data_module, + pl_model, datamodule=data_module, ) # Save the trained model along with relevant metadata in the training_ckpts folder. diff --git a/src/chop/dataset/nerf/blender.py b/src/chop/dataset/nerf/blender.py index 53865d621..748dacca7 100644 --- a/src/chop/dataset/nerf/blender.py +++ b/src/chop/dataset/nerf/blender.py @@ -34,9 +34,7 @@ def _download_lego_dataset(path: Path) -> None: # Unzip the file subprocess.run( - f"unzip {folder_path.as_posix()} -d {path.as_posix()}", - shell=True, - check=True, + f"unzip {folder_path.as_posix()} -d {path.as_posix()}", shell=True, check=True, ) diff --git a/src/chop/dataset/nlp/language_modeling.py b/src/chop/dataset/nlp/language_modeling.py index ba2bdee55..8d17b9466 100644 --- a/src/chop/dataset/nlp/language_modeling.py +++ b/src/chop/dataset/nlp/language_modeling.py @@ -281,10 +281,7 @@ def _tokenize(text, tokenizer, max_length): prompt_len = prompt_tokenized.ne(tokenizer.pad_token_id).sum().item() target_tokenized[:prompt_len] = ignore_id - return dict( - input_ids=input_ids, - labels=target_tokenized, - ) + return dict(input_ids=input_ids, labels=target_tokenized,) def prepare_data(self): dataset_dict = self._download_dataset() @@ -316,10 +313,7 @@ def setup(self): dataset_dict = self._download_dataset() dataset_dict = dataset_dict["train"].train_test_split(test_size=0.1, seed=42) dataset_dict = hf_datasets.DatasetDict( - { - "train": dataset_dict["train"], - "validation": dataset_dict["test"], - } + {"train": dataset_dict["train"], "validation": dataset_dict["test"],} ) dataset_dict = dataset_dict.map( function=partial( diff --git a/src/chop/dataset/vision/transforms/cifar.py b/src/chop/dataset/vision/transforms/cifar.py index ef44d3721..84e7e1287 100644 --- a/src/chop/dataset/vision/transforms/cifar.py +++ b/src/chop/dataset/vision/transforms/cifar.py @@ -35,10 +35,7 @@ def _get_cifar_default_transform(train: bool, mean: tuple[float], std: tuple[float]): if train: - transform = create_transform( - **DEFAULT_CIFAR_PREPROCESS_ARGS, - is_training=True, - ) + transform = create_transform(**DEFAULT_CIFAR_PREPROCESS_ARGS, is_training=True,) transform.transforms[0] = tv_transforms.RandomCrop( DEFAULT_CIFAR_PREPROCESS_ARGS["input_size"], padding=4 ) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 712bb026f..21337be8c 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -33,10 +33,7 @@ def distributed_average_timing(fn, repeat, args): times = [] for itr in range(repeat): rlog( - logger, - dist.get_rank(), - f"Running teration {itr}", - "debug", + logger, dist.get_rank(), f"Running teration {itr}", "debug", ) dist.barrier(async_op=True) start = time() @@ -45,10 +42,7 @@ def distributed_average_timing(fn, repeat, args): end = time() times.append(end - start) rlog( - logger, - dist.get_rank(), - f"Time taken: {end - start}s", - "debug", + logger, dist.get_rank(), f"Time taken: {end - start}s", "debug", ) return result, sum(times[2:]) / len(times[2:]) @@ -138,11 +132,7 @@ def device_fn( distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs ] - _, time_taken = distributed_average_timing( - fn=model, - repeat=10, - args=inputs, - ) + _, time_taken = distributed_average_timing(fn=model, repeat=10, args=inputs,) rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info") dist.destroy_process_group() diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index 826ed05cc..b9aa9c272 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -34,11 +34,7 @@ def _dtensor_init_helper( - init_op, - size: torch.Size, - device_mesh=None, - placements=None, - **kwargs, + init_op, size: torch.Size, device_mesh=None, placements=None, **kwargs, ) -> DTensor: from torch.distributed.tensor.placement_types import _DTensorSpec, TensorMeta @@ -82,18 +78,10 @@ def _dtensor_init_helper( spec = _DTensorSpec( device_mesh, tuple(placements), - tensor_meta=TensorMeta( - size, - torch_stride, - local_tensor.dtype, - ), + tensor_meta=TensorMeta(size, torch_stride, local_tensor.dtype,), ) - return DTensor( - local_tensor, - spec, - requires_grad=kwargs["requires_grad"], - ) + return DTensor(local_tensor, spec, requires_grad=kwargs["requires_grad"],) def ones( diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 9aa96ae68..dcc74852f 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -42,9 +42,7 @@ def decompose_handler( - op_call: torch._ops.OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], + op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object], ) -> object: """ Decomposes a op to core ATen op, this handler is mostly here @@ -58,9 +56,7 @@ def decompose_handler( def is_same_size_handler( - op_call: torch._ops.OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], + op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object], ) -> bool: lhs = cast(torch.Tensor, args[0]) rhs = cast(torch.Tensor, args[1]) @@ -201,8 +197,9 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor: # did not already construct one random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) - first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - torch.Tensor, local_tensor_args[0] + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), ) rng_context = ( random._rng_tracker._distribute_region(first_arg._spec) @@ -254,8 +251,7 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor: @staticmethod def redistribute_local_args( - op_info: OpInfo, - suggested_input_schema: OpSchema, + op_info: OpInfo, suggested_input_schema: OpSchema, ) -> None: # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py index e04495126..e65766eb7 100644 --- a/src/chop/distributed/tensor/_redistribute.py +++ b/src/chop/distributed/tensor/_redistribute.py @@ -48,8 +48,7 @@ def _replicate_then_shard(val: _TransformInfo) -> int: @lru_cache(maxsize=None) def _gen_transform_infos( - src_spec: _DTensorSpec, - dst_spec: _DTensorSpec, + src_spec: _DTensorSpec, dst_spec: _DTensorSpec, ) -> List[_TransformInfo]: """ Generate the transform infos from the source placements to the target placements. @@ -88,9 +87,7 @@ def _gen_transform_infos( # calculate and save the logical shape for this sharding mesh_dim_size = device_mesh.size(mesh_dim=i) local_shard_size, _ = src._local_shard_size_on_dim( - current_logical_shape[src.dim], - mesh_dim_size, - my_coordinate[i], + current_logical_shape[src.dim], mesh_dim_size, my_coordinate[i], ) new_logical_shape = list(current_logical_shape) new_logical_shape[src.dim] = local_shard_size @@ -288,11 +285,7 @@ def forward( # type: ignore[override] output = input._local_tensor target_spec = current_spec - return dtensor.DTensor( - output, - target_spec, - requires_grad=input.requires_grad, - ) + return dtensor.DTensor(output, target_spec, requires_grad=input.requires_grad,) @staticmethod def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] @@ -327,9 +320,7 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] ), ) output_dtensor = dtensor.DTensor( - output, - spec, - requires_grad=grad_output.requires_grad, + output, spec, requires_grad=grad_output.requires_grad, ) return ( diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 7351b7895..8e1e26db1 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -48,8 +48,7 @@ class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} self.op_strategy_funcs: Dict[ - OpOverload, - Callable[[DeviceMesh, OpSchema], StrategyType], + OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType], ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} @@ -344,10 +343,8 @@ def spec_to_strategy(spec: object) -> object: expected_input_spec = selected_strategies[idx].input_spec( tensor_or_list_tensor_arg_idx ) - expected_input_spec = ( - expected_input_spec.shallow_copy_with_tensor_meta( - arg_spec.tensor_meta - ) + expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta ) if arg_spec.placements != expected_input_spec.placements: needs_redistribute = True @@ -363,10 +360,8 @@ def spec_to_strategy(spec: object) -> object: expected_input_spec = selected_strategies[0].input_spec( tensor_or_list_tensor_arg_idx ) - expected_input_spec = ( - expected_input_spec.shallow_copy_with_tensor_meta( - arg.tensor_meta - ) + expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta ) if arg.placements != expected_input_spec.placements: needs_redistribute = True diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index dcef33b24..a27c848f4 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -64,9 +64,7 @@ class _ToTorchTensor(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] - ctx, - input: "DTensor", - grad_placements: Optional[Sequence[Placement]], + ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]], ): ctx.dtensor_spec = input._spec ctx.grad_placements = grad_placements @@ -100,11 +98,7 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) return ( - DTensor( - grad_output, - grad_spec, - requires_grad=grad_output.requires_grad, - ), + DTensor(grad_output, grad_spec, requires_grad=grad_output.requires_grad,), None, ) @@ -160,11 +154,7 @@ def forward( # type: ignore[override] dist_spec = _DTensorSpec( device_mesh, placements, - tensor_meta=TensorMeta( - tensor_shape, - tensor_stride, - input.dtype, - ), + tensor_meta=TensorMeta(tensor_shape, tensor_stride, input.dtype,), ) # We want a fresh Tensor object that shares memory with the input tensor @@ -217,11 +207,7 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ @staticmethod @torch._disable_dynamo def __new__( - cls, - local_tensor: torch.Tensor, - spec: _DTensorSpec, - *, - requires_grad: bool, + cls, local_tensor: torch.Tensor, spec: _DTensorSpec, *, requires_grad: bool, ) -> "DTensor": """ Construct a DTensor from a local tensor, device mesh, and placement and @@ -277,20 +263,12 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): local_tensor = inner_tensors["_local_tensor"] spec, requires_grad = flatten_spec unflatten_tensor_meta = TensorMeta( - shape=outer_size, - stride=outer_stride, - dtype=spec.tensor_meta.dtype, + shape=outer_size, stride=outer_stride, dtype=spec.tensor_meta.dtype, ) unflatten_spec = _DTensorSpec( - spec.mesh, - spec.placements, - tensor_meta=unflatten_tensor_meta, - ) - return DTensor( - local_tensor, - unflatten_spec, - requires_grad=requires_grad, + spec.mesh, spec.placements, tensor_meta=unflatten_tensor_meta, ) + return DTensor(local_tensor, unflatten_spec, requires_grad=requires_grad,) def __coerce_tangent_metadata__(self): if not any(isinstance(p, Partial) for p in self.placements): @@ -303,8 +281,7 @@ def __coerce_tangent_metadata__(self): def __coerce_same_metadata_as_tangent__(self, flatten_spec): (spec, _) = flatten_spec # Result of tensor_flatten() return self.redistribute( - device_mesh=self.device_mesh, - placements=spec.placements, + device_mesh=self.device_mesh, placements=spec.placements, ) @classmethod @@ -312,11 +289,7 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - return DTensor._op_dispatcher.dispatch( - func, - args, - kwargs or {}, - ) + return DTensor._op_dispatcher.dispatch(func, args, kwargs or {},) @staticmethod def from_local( @@ -394,12 +367,7 @@ def from_local( # created should flow back the gradients to the local_tensor, so we call an autograd # function to construct the dist tensor instead. return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, - device_mesh, - tuple(placements), - run_check, - shape, - stride, + local_tensor, device_mesh, tuple(placements), run_check, shape, stride, ) def to_local( @@ -699,9 +667,7 @@ def distribute_tensor( mesh=device_mesh, placements=placements, tensor_meta=TensorMeta( - shape=tensor.size(), - stride=tensor.stride(), - dtype=tensor.dtype, + shape=tensor.size(), stride=tensor.stride(), dtype=tensor.dtype, ), ) return DTensor( diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py index 70ba54416..3a9e81ae2 100644 --- a/src/chop/distributed/tensor/ops/basic_strategy.py +++ b/src/chop/distributed/tensor/ops/basic_strategy.py @@ -84,10 +84,7 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": def gen_einsum_strategies( - equation: str, - mesh: DeviceMesh, - *, - linearity: bool = False, + equation: str, mesh: DeviceMesh, *, linearity: bool = False, ) -> OpStrategy: """ Generate a strategy list for the ops that follow einsum style notation. diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py index e0e9a6e34..1324f7806 100644 --- a/src/chop/distributed/tensor/ops/common_rules.py +++ b/src/chop/distributed/tensor/ops/common_rules.py @@ -37,10 +37,7 @@ def _gen_reshard_suggestions( ) suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {}) suggested_schema._inplace_rewrap_schema_suggestion(op_schema) - return OutputSharding( - None, - redistribute_schema=suggested_schema, - ) + return OutputSharding(None, redistribute_schema=suggested_schema,) def einop_rule( @@ -218,10 +215,7 @@ def merge_sharding(dim: str, a: int, b: int) -> int: ) return OutputSharding( _DTensorSpec.from_dim_map( - input_specs[0].mesh, - output_dim_map, - pending_sums, - tensor_meta=tensor_meta, + input_specs[0].mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta, ) ) @@ -281,8 +275,5 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi enforce_sharding[out_dimchar] = mesh_dim return einop_rule( - fmt, - op_schema, - linearity=linearity, - enforce_sharding=enforce_sharding, + fmt, op_schema, linearity=linearity, enforce_sharding=enforce_sharding, ) diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py index 707f391d6..9a4d5f425 100644 --- a/src/chop/distributed/tensor/ops/conv_ops.py +++ b/src/chop/distributed/tensor/ops/conv_ops.py @@ -50,16 +50,11 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding: pending_sums = input_spec.sums tensor_meta = TensorMeta( - torch.Size(output_shape), - output_stride, - input_spec.tensor_meta.dtype, + torch.Size(output_shape), output_stride, input_spec.tensor_meta.dtype, ) return OutputSharding( _DTensorSpec.from_dim_map( - input_spec.mesh, - output_dim_map, - pending_sums, - tensor_meta=tensor_meta, + input_spec.mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta, ) ) @@ -88,22 +83,14 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: assert input_spec.tensor_meta is not None weight_tensor_meta = weight_spec.tensor_meta bias_tensor_meta = TensorMeta( - torch.Size(bias_shape_opt), - (1,), - input_spec.tensor_meta.dtype, + torch.Size(bias_shape_opt), (1,), input_spec.tensor_meta.dtype, ) grad_input_spec = input_spec grad_weight_spec = _DTensorSpec.from_dim_map( - input_spec.mesh, - [-1, -1, -1, -1], - [0], - tensor_meta=weight_tensor_meta, + input_spec.mesh, [-1, -1, -1, -1], [0], tensor_meta=weight_tensor_meta, ) grad_bias_spec = _DTensorSpec.from_dim_map( - input_spec.mesh, - [-1], - [0], - tensor_meta=bias_tensor_meta, + input_spec.mesh, [-1], [0], tensor_meta=bias_tensor_meta, ) return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index 770514219..da8dad8d5 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -128,7 +128,7 @@ def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: if self.reduce_op == "sum": assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" if self.norm_type != 0 and self.norm_type != 1: - return tensor**self.norm_type + return tensor ** self.norm_type return tensor def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: @@ -289,10 +289,7 @@ def common_reduction_strategy( redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] reduction_strategy.strategies.append( PlacementStrategy( - output_specs=_DTensorSpec( - mesh=mesh, - placements=out_placements, - ), + output_specs=_DTensorSpec(mesh=mesh, placements=out_placements,), input_specs=(input_spec,), redistribute_cost=redistribute_cost, ) @@ -478,10 +475,7 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [ - aten._log_softmax_backward_data.default, - aten._softmax_backward_data.default, - ], + [aten._log_softmax_backward_data.default, aten._softmax_backward_data.default,], schema_info=RuntimeSchemaInfo(2), ) def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @@ -614,21 +608,14 @@ def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate reduce_dims_map, reduction_op, ) - output_expected_spec = _DTensorSpec( - mesh=mesh, - placements=out_placements, - ) + output_expected_spec = _DTensorSpec(mesh=mesh, placements=out_placements,) # whether reduction is sum or mean, the total weight has to be summed up if not replicated total_weight_placements = map_placements_after_reduction( - target_expected_spec.placements, - reduce_dims, - reduce_dims_map, - "sum", + target_expected_spec.placements, reduce_dims, reduce_dims_map, "sum", ) total_weight_expected_spec = _DTensorSpec( - mesh=mesh, - placements=total_weight_placements, + mesh=mesh, placements=total_weight_placements, ) output_strategy.strategies.append( @@ -761,8 +748,7 @@ def rlog(msg): @register_op_strategy( - [aten.native_layer_norm.default], - schema_info=RuntimeSchemaInfo(1), + [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1), ) def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: # args must be: input, normalized_shape, weight, bias, eps @@ -856,8 +842,7 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [aten.native_layer_norm_backward.default], - schema_info=RuntimeSchemaInfo(2), + [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2), ) def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: # args must be: grad_out, input, normalized_shape, mean, rstd, @@ -1014,8 +999,7 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy @register_op_strategy( - [aten.topk.default], - schema_info=RuntimeSchemaInfo(2), + [aten.topk.default], schema_info=RuntimeSchemaInfo(2), ) def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: input_strategy = cast(OpStrategy, op_schema.args_schema[0]) diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py index f76dca190..029a795df 100644 --- a/src/chop/distributed/tensor/ops/matrix_ops.py +++ b/src/chop/distributed/tensor/ops/matrix_ops.py @@ -384,10 +384,7 @@ def scaled_dot_product_efficient_attention_strategy( single_mesh_dim_strategies.append(num_heads_dim_sharding) return expand_to_full_mesh_op_strategy( - mesh, - op_schema, - single_mesh_dim_strategies, - input_index=4, + mesh, op_schema, single_mesh_dim_strategies, input_index=4, ) @@ -452,8 +449,5 @@ def scaled_dot_product_efficient_attention_backward_strategy( single_mesh_dim_strategies.append(num_heads_dim_sharding) return expand_to_full_mesh_op_strategy( - mesh, - op_schema, - single_mesh_dim_strategies, - input_index=4, + mesh, op_schema, single_mesh_dim_strategies, input_index=4, ) diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py index 221001f01..656fd2996 100644 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -483,9 +483,7 @@ def common_pointwise_strategy( common_shape, input_arg_spec.shape ) input_target_placements = map_placements_after_broadcast( - tuple(out_placements), - common_shape, - input_arg_dims_map, + tuple(out_placements), common_shape, input_arg_dims_map, ) input_arg_target_spec = _DTensorSpec( mesh=mesh, @@ -499,10 +497,7 @@ def common_pointwise_strategy( pointwise_strategy.strategies.append( PlacementStrategy( - output_specs=_DTensorSpec( - mesh=mesh, - placements=tuple(out_placements), - ), + output_specs=_DTensorSpec(mesh=mesh, placements=tuple(out_placements),), input_specs=input_specs, redistribute_cost=redistribute_costs, ) diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py index dcddcb98c..f54640e8a 100644 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ b/src/chop/distributed/tensor/ops/tensor_ops.py @@ -76,10 +76,7 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: @register_op_strategy( - [ - aten.equal.default, - aten.is_same_size.default, - ] + [aten.equal.default, aten.is_same_size.default,] ) def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: # equal_strategy deals with ops that comparing two tensor, we need to make sure @@ -128,8 +125,7 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: schema_info=RuntimeSchemaInfo(1, ["dtype"]), ) @register_op_strategy( - [aten.full_like.default], - schema_info=RuntimeSchemaInfo(2, ["dtype"]), + [aten.full_like.default], schema_info=RuntimeSchemaInfo(2, ["dtype"]), ) @register_op_strategy( [ @@ -696,8 +692,7 @@ def place(vp: Placement, ip: Placement) -> Placement: ) result = OutputSharding( output_spec=_DTensorSpec( - mesh=values_spec.mesh, - placements=value_placements, + mesh=values_spec.mesh, placements=value_placements, ) ) return result @@ -784,10 +779,7 @@ def size_split(N, i): else split_size_or_sections ) output_spec_list = [ - _DTensorSpec( - mesh=input_spec.mesh, - placements=input_spec.placements, - ) + _DTensorSpec(mesh=input_spec.mesh, placements=input_spec.placements,) for _ in range(len(output_size_list)) ] return OutputSharding(output_spec_list) diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py index 28e5245c3..e005501d6 100644 --- a/src/chop/distributed/tensor/ops/utils.py +++ b/src/chop/distributed/tensor/ops/utils.py @@ -196,9 +196,7 @@ def infer_broadcast_dims_map( def map_placements_after_broadcast( - placements: Tuple[Placement, ...], - shape: torch.Size, - broadcast_dims_map: List[int], + placements: Tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: List[int], ) -> Tuple[Placement, ...]: """Map each placement based on the output shape after broadcast.""" new_placements: List[Placement] = [] diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py index 5c91f6d64..6b89771ec 100644 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -238,9 +238,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: def dim_movedim( - ndim: int, - input: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]], + ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]], ) -> DimMap: input = normalize_dims(input, ndim) destination = normalize_dims(destination, ndim) @@ -606,10 +604,7 @@ def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_src_spec = input_placement_strategy.output_spec input_tgt_placements, output_placements = propagate_shape_and_sharding( - input_src_spec.placements, - tuple(global_in_shape), - rules, - mesh.shape, + input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape, ) # TODO: optimize this. we shouldn't simply blindly replicate diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py index 322195869..5ee6b8bb3 100644 --- a/src/chop/ir/graph/mase_graph.py +++ b/src/chop/ir/graph/mase_graph.py @@ -76,13 +76,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool is_fx_built_in_leaf_module = super().is_leaf_module(m, module_qualified_name) is_mase_leaf_layers = isinstance(m, MASE_LEAF_LAYERS) is_custom_layer = isinstance(m, self.custom_leaf_layers) - return any( - ( - is_fx_built_in_leaf_module, - is_mase_leaf_layers, - is_custom_layer, - ) - ) + return any((is_fx_built_in_leaf_module, is_mase_leaf_layers, is_custom_layer,)) def trace_torch_module( @@ -128,19 +122,13 @@ def is_leaf_module( self, m: torch.nn.Module, module_qualified_name: str ) -> bool: is_hf_built_in_leaf_module = hf_is_leaf_module( - self, - m, - module_qualified_name, + self, m, module_qualified_name, ) is_custom_module = isinstance(m, custom_modules) is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS) return any( - ( - is_hf_built_in_leaf_module, - is_custom_module, - is_mase_leaf_layer, - ) + (is_hf_built_in_leaf_module, is_custom_module, is_mase_leaf_layer,) ) return is_leaf_module @@ -152,9 +140,7 @@ def is_leaf_module( ) graph_module = hf_symbolic_trace( - model, - tracer_cls=tracer_cls, - input_names=hf_input_names, + model, tracer_cls=tracer_cls, input_names=hf_input_names, ) graph_module.custom_ops = custom_ops @@ -307,10 +293,7 @@ def __init__( self.model.additional_inputs = [] elif isinstance(model, torch.nn.Module): self.model = trace_torch_module( - model, - cf_args, - custom_ops, - hf_input_names=hf_input_names, + model, cf_args, custom_ops, hf_input_names=hf_input_names, ) else: raise ValueError( @@ -349,16 +332,11 @@ def from_module( ), f"model must be a torch.nn.Module. Received: {type(model)}" graph_module = trace_torch_module(model, cf_args, custom_ops) - return cls( - model=graph_module, - cf_args=cf_args, - ) + return cls(model=graph_module, cf_args=cf_args,) @classmethod def from_checkpoint( - cls, - checkpoint: str, - propagate_missing_metadata: bool = True, + cls, checkpoint: str, propagate_missing_metadata: bool = True, ): """ Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two files: @@ -393,18 +371,12 @@ def from_checkpoint( for node in mg.nodes: if node.name in loaded_meta.keys(): parameters = loaded_meta[node.name] - node.meta["mase"] = MaseMetadata( - node=node, - model=loaded_model, - ) + node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,) node.meta["mase"].parameters = parameters else: # todo: propagate metadata for missing nodes logger.warning(f"Node {node.name} not found in loaded metadata.") - node.meta["mase"] = MaseMetadata( - node=node, - model=loaded_model, - ) + node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,) for attr in [ "class_for_deserialization", @@ -417,8 +389,7 @@ def from_checkpoint( return mg def export( - self, - fname: str = "masegraph", + self, fname: str = "masegraph", ): """ Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz. diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py index 6f01d83c8..52fcf0127 100644 --- a/src/chop/ir/graph/mase_metadata.py +++ b/src/chop/ir/graph/mase_metadata.py @@ -100,9 +100,7 @@ class MaseMetadata: known_storage = ["BRAM"] def __init__( - self, - node=None, - model=None, + self, node=None, model=None, ): # Top-level model self.model = model diff --git a/src/chop/ir/onnx/mase_onnx_graph.py b/src/chop/ir/onnx/mase_onnx_graph.py index 40bb7915c..e4a8f52ed 100644 --- a/src/chop/ir/onnx/mase_onnx_graph.py +++ b/src/chop/ir/onnx/mase_onnx_graph.py @@ -9,11 +9,8 @@ class MaseOnnxGraph: - def __init__( - self, - model_proto: onnx.onnx_ml_pb2.ModelProto, - model_name: str = None, + self, model_proto: onnx.onnx_ml_pb2.ModelProto, model_name: str = None, ): self.model_proto = model_proto self.graph = model_proto.graph diff --git a/src/chop/ir/onnx/utils.py b/src/chop/ir/onnx/utils.py index bb65aa8b2..c5250bba0 100644 --- a/src/chop/ir/onnx/utils.py +++ b/src/chop/ir/onnx/utils.py @@ -188,10 +188,7 @@ def onnx_to_torch_dtype(dtype): "target": torch.mean, "input_mapping": ["input"], "attribute_mapping": {"keepdims": "", "axes": ""}, - "attribute_transform": { - "keepdims": None, - "axes": None, - }, + "attribute_transform": {"keepdims": None, "axes": None,}, "attribute_default": {"keepdims": 1, "axes": None}, }, "Expand": { @@ -335,9 +332,7 @@ def onnx_to_torch_dtype(dtype): "input_mapping": ["input"], "attribute_mapping": {"perm": "dims"}, "attribute_transform": {"perm": lambda x: [i for i in x]}, - "attribute_default": { - "perm": None, - }, + "attribute_default": {"perm": None,}, }, "Max": { "fx_op": "call_function", diff --git a/src/chop/models/bert/modeling_bert_quantized.py b/src/chop/models/bert/modeling_bert_quantized.py index 8061d2593..c6747f0e9 100644 --- a/src/chop/models/bert/modeling_bert_quantized.py +++ b/src/chop/models/bert/modeling_bert_quantized.py @@ -561,9 +561,7 @@ def __init__(self, config, quant_config: dict): super().__init__() # self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = get_quantized_cls("linear", quant_config["dense"])( - config.hidden_size, - config.intermediate_size, - config=quant_config["dense"], + config.hidden_size, config.intermediate_size, config=quant_config["dense"], ) self.quant_config = quant_config if isinstance(config.hidden_act, str): @@ -582,9 +580,7 @@ def __init__(self, config, quant_config): super().__init__() # self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = get_quantized_cls("linear", quant_config["dense"])( - config.intermediate_size, - config.hidden_size, - config=quant_config["dense"], + config.intermediate_size, config.hidden_size, config=quant_config["dense"], ) self.quant_config = quant_config self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/src/chop/models/bert/quant_config_bert.py b/src/chop/models/bert/quant_config_bert.py index f66c45a4c..4dadb97f0 100644 --- a/src/chop/models/bert/quant_config_bert.py +++ b/src/chop/models/bert/quant_config_bert.py @@ -88,10 +88,7 @@ def create_a_layer_config( return qc -def _parse_and_complete_config( - config: dict, - num_hidden_layers: int, -) -> dict: +def _parse_and_complete_config(config: dict, num_hidden_layers: int,) -> dict: assert "default" in config, "Must provide a default config" default_qc: dict = config["default"] linear_qc: dict = parse_node_config( diff --git a/src/chop/models/cnv/cnv.py b/src/chop/models/cnv/cnv.py index 00df962a6..ee1912f18 100644 --- a/src/chop/models/cnv/cnv.py +++ b/src/chop/models/cnv/cnv.py @@ -4,12 +4,8 @@ from typing import Any import numpy as np -from chop.nn.quantized.modules.conv2d import ( - Conv2dBinaryResidualSign, -) -from chop.nn.quantized.modules.linear import ( - LinearBinaryResidualSign, -) +from chop.nn.quantized.modules.conv2d import Conv2dBinaryResidualSign +from chop.nn.quantized.modules.linear import LinearBinaryResidualSign from chop.models.utils import register_mase_model, register_mase_checkpoint """ @@ -206,8 +202,7 @@ def forward(self, x: Tensor) -> Tensor: # Getters ------------------------------------------------------------------------------ @register_mase_checkpoint("cnv-toy") def get_cnv_toy( - pretrained=False, - **kwargs: Any, + pretrained=False, **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] @@ -217,8 +212,7 @@ def get_cnv_toy( @register_mase_checkpoint("cnv") def get_cnv( - pretrained=False, - **kwargs: Any, + pretrained=False, **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] @@ -228,8 +222,7 @@ def get_cnv( @register_mase_checkpoint("cnv_residual") def get_cnv_residual( - pretrained=False, - **kwargs: Any, + pretrained=False, **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] diff --git a/src/chop/models/cswin/cswintransformer.py b/src/chop/models/cswin/cswintransformer.py index cbd85cda6..449a22d30 100644 --- a/src/chop/models/cswin/cswintransformer.py +++ b/src/chop/models/cswin/cswintransformer.py @@ -90,7 +90,7 @@ def __init__( self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights - self.scale = qk_scale or head_dim**-0.5 + self.scale = qk_scale or head_dim ** -0.5 if idx == -1: H_sp, W_sp = self.resolution, self.resolution elif idx == 0: @@ -362,9 +362,9 @@ def __init__( super().__init__() self.use_chk = use_chk self.num_classes = num_classes - self.num_features = self.embed_dim = ( - embed_dim # num_features for consistency with other models - ) + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models heads = num_heads self.stage1_conv_embed = nn.Sequential( diff --git a/src/chop/models/deit/deit_v2.py b/src/chop/models/deit/deit_v2.py index 4f0128216..21b2a5ad4 100644 --- a/src/chop/models/deit/deit_v2.py +++ b/src/chop/models/deit/deit_v2.py @@ -28,7 +28,7 @@ def __init__( super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 + self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) diff --git a/src/chop/models/efficientnet/efficientnet.py b/src/chop/models/efficientnet/efficientnet.py index 5d2f7730a..b3af62d05 100644 --- a/src/chop/models/efficientnet/efficientnet.py +++ b/src/chop/models/efficientnet/efficientnet.py @@ -499,8 +499,7 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet_conf( - arch: str, - **kwargs: Any, + arch: str, **kwargs: Any, ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] if arch.startswith("efficientnet_b"): @@ -592,9 +591,7 @@ def _efficientnet( def get_efficientnet_b0( - info: Dict, - pretrained: bool = False, - **kwargs: Any, + info: Dict, pretrained: bool = False, **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -616,9 +613,7 @@ def get_efficientnet_b0( def get_efficientnet_b3( - info: Dict, - pretrained: bool = False, - **kwargs: Any, + info: Dict, pretrained: bool = False, **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -641,9 +636,7 @@ def get_efficientnet_b3( def get_efficientnet_v2_s( - info: Dict, - pretrained: bool = False, - **kwargs: Any, + info: Dict, pretrained: bool = False, **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -663,9 +656,7 @@ def get_efficientnet_v2_s( def get_efficientnet_v2_m( - info: Dict, - pretrained: bool = False, - **kwargs: Any, + info: Dict, pretrained: bool = False, **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -685,9 +676,7 @@ def get_efficientnet_v2_m( def get_efficientnet_v2_l( - info: Dict, - pretrained: bool = False, - **kwargs: Any, + info: Dict, pretrained: bool = False, **kwargs: Any, ): num_classes = info.num_classes if pretrained: diff --git a/src/chop/models/lfc/lfc.py b/src/chop/models/lfc/lfc.py index 70da9df45..410359a25 100644 --- a/src/chop/models/lfc/lfc.py +++ b/src/chop/models/lfc/lfc.py @@ -42,9 +42,7 @@ def forward(self, x): # Getters ------------------------------------------------------------------------------ def get_lfc( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): image_size = info["image_size"] num_classes = info.num_classes diff --git a/src/chop/models/llama/modeling_llama_llora.py b/src/chop/models/llama/modeling_llama_llora.py index b45dd8494..45eb2f210 100644 --- a/src/chop/models/llama/modeling_llama_llora.py +++ b/src/chop/models/llama/modeling_llama_llora.py @@ -195,10 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, + self, hidden_size: int, intermediate_size: int, hidden_act: str, ): super().__init__() # fmt: off diff --git a/src/chop/models/llama/modeling_llama_sparse.py b/src/chop/models/llama/modeling_llama_sparse.py index f5f37eba0..21359ea34 100644 --- a/src/chop/models/llama/modeling_llama_sparse.py +++ b/src/chop/models/llama/modeling_llama_sparse.py @@ -195,10 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, + self, hidden_size: int, intermediate_size: int, hidden_act: str, ): super().__init__() # fmt: off diff --git a/src/chop/models/mobilenet_v2/mobilenet_v2.py b/src/chop/models/mobilenet_v2/mobilenet_v2.py index d8f4020dd..0980d58dd 100644 --- a/src/chop/models/mobilenet_v2/mobilenet_v2.py +++ b/src/chop/models/mobilenet_v2/mobilenet_v2.py @@ -288,8 +288,7 @@ def __init__( # building classifier self.classifier = nn.Sequential( - nn.Dropout(p=dropout), - nn.Linear(self.last_channel, num_classes), + nn.Dropout(p=dropout), nn.Linear(self.last_channel, num_classes), ) # weight initialization diff --git a/src/chop/models/nerf/nerf_vision.py b/src/chop/models/nerf/nerf_vision.py index f8591f1a4..18338940b 100644 --- a/src/chop/models/nerf/nerf_vision.py +++ b/src/chop/models/nerf/nerf_vision.py @@ -139,9 +139,7 @@ def load_weights_from_keras(self, weights): # Getters ------------------------------------------------------------------------------ def get_nerf( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): # image_size = info["image_size"] num_classes = info.num_classes diff --git a/src/chop/models/opt/modeling_opt.py b/src/chop/models/opt/modeling_opt.py index e88db8c1f..430f77df8 100644 --- a/src/chop/models/opt/modeling_opt.py +++ b/src/chop/models/opt/modeling_opt.py @@ -132,7 +132,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.is_decoder = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/chop/models/opt/modeling_opt_lora.py b/src/chop/models/opt/modeling_opt_lora.py index 810cfe0d6..30bb21f20 100644 --- a/src/chop/models/opt/modeling_opt_lora.py +++ b/src/chop/models/opt/modeling_opt_lora.py @@ -130,7 +130,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.is_decoder = is_decoder lora_config = config.lora_config[f"model_layer_{layer_id}"]["self_attn"] diff --git a/src/chop/models/opt/modeling_opt_quantized.py b/src/chop/models/opt/modeling_opt_quantized.py index d1086f413..d6eb94343 100644 --- a/src/chop/models/opt/modeling_opt_quantized.py +++ b/src/chop/models/opt/modeling_opt_quantized.py @@ -168,7 +168,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.is_decoder = is_decoder # fmt:off diff --git a/src/chop/models/opt/modeling_opt_sparse.py b/src/chop/models/opt/modeling_opt_sparse.py index f39a29cc9..4addd6525 100644 --- a/src/chop/models/opt/modeling_opt_sparse.py +++ b/src/chop/models/opt/modeling_opt_sparse.py @@ -130,7 +130,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.is_decoder = is_decoder sparse_config = config.sparse_config[f"model_layer_{layer_id}"]["self_attn"] diff --git a/src/chop/models/opt/quant_config_opt_quantized.py b/src/chop/models/opt/quant_config_opt_quantized.py index a76f8adb8..150a1b2d6 100644 --- a/src/chop/models/opt/quant_config_opt_quantized.py +++ b/src/chop/models/opt/quant_config_opt_quantized.py @@ -32,9 +32,7 @@ def create_a_layer_config( - linear_qc: dict = None, - bmm_qc: dict = None, - layer_qc=None, + linear_qc: dict = None, bmm_qc: dict = None, layer_qc=None, ) -> dict: if (layer_qc is None and bmm_qc is None) and layer_qc is None: raise ValueError("Must provide either (linear_qc & bmm_qc ) or layer_qc") diff --git a/src/chop/models/pvt/pvt.py b/src/chop/models/pvt/pvt.py index ac3a18708..2474fb388 100644 --- a/src/chop/models/pvt/pvt.py +++ b/src/chop/models/pvt/pvt.py @@ -58,7 +58,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 + self.scale = qk_scale or head_dim ** -0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -182,12 +182,7 @@ def forward(self, x): @register_mase_model( "pvt", - checkpoints=[ - "pvt_tiny", - "pvt_small", - "pvt_medium", - "pvt_large", - ], + checkpoints=["pvt_tiny", "pvt_small", "pvt_medium", "pvt_large",], model_source="vision_others", task_type="vision", image_classification=True, diff --git a/src/chop/models/pvt/pvt_v2.py b/src/chop/models/pvt/pvt_v2.py index 3ffb45f6b..b70acb6d1 100644 --- a/src/chop/models/pvt/pvt_v2.py +++ b/src/chop/models/pvt/pvt_v2.py @@ -82,7 +82,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 + self.scale = qk_scale or head_dim ** -0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) diff --git a/src/chop/models/repvgg/repvgg.py b/src/chop/models/repvgg/repvgg.py index 086501fc5..5ef7998e0 100644 --- a/src/chop/models/repvgg/repvgg.py +++ b/src/chop/models/repvgg/repvgg.py @@ -143,14 +143,14 @@ def get_custom_L2(self): .detach() ) - l2_loss_circle = (K3**2).sum() - ( + l2_loss_circle = (K3 ** 2).sum() - ( K3[:, :, 1:2, 1:2] ** 2 ).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. eq_kernel = ( K3[:, :, 1:2, 1:2] * t3 + K1 * t1 ) # The equivalent resultant central point of 3x3 kernel. l2_loss_eq_kernel = ( - eq_kernel**2 / (t3**2 + t1**2) + eq_kernel ** 2 / (t3 ** 2 + t1 ** 2) ).sum() # Normalize for an L2 coefficient comparable to regular L2. return l2_loss_eq_kernel + l2_loss_circle diff --git a/src/chop/models/resnet/resnet.py b/src/chop/models/resnet/resnet.py index 313c8ce36..b339b72a7 100644 --- a/src/chop/models/resnet/resnet.py +++ b/src/chop/models/resnet/resnet.py @@ -157,13 +157,7 @@ def forward(self, x: Tensor) -> Tensor: @register_mase_model( name="resnet", - checkpoints=[ - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "wide_resnet50_2", - ], + checkpoints=["resnet18", "resnet34", "resnet50", "resnet101", "wide_resnet50_2",], model_source="torchvision", task_type="vision", image_classification=True, @@ -337,10 +331,7 @@ def _resnet( @register_mase_checkpoint("resnet18") -def get_resnet18( - pretrained: bool = False, - **kwargs: Any, -) -> ResNet: +def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet: """ResNet-18 from `Deep Residual Learning for Image Recognition `__.""" if pretrained: pretrained_weight_cls = ResNet18_Weights.IMAGENET1K_V1 @@ -348,18 +339,12 @@ def get_resnet18( pretrained_weight_cls = None return _resnet( - BasicBlock, - [2, 2, 2, 2], - pretrained_weight_cls=pretrained_weight_cls, - **kwargs, + BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, ) @register_mase_checkpoint("resnet34") -def get_resnet34( - pretrained: bool = False, - **kwargs: Any, -) -> ResNet: +def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet: """ResNet-34 from `Deep Residual Learning for Image Recognition `__.""" if pretrained: pretrained_weight_cls = ResNet34_Weights.IMAGENET1K_V1 @@ -367,18 +352,12 @@ def get_resnet34( pretrained_weight_cls = None return _resnet( - BasicBlock, - [2, 2, 2, 2], - pretrained_weight_cls=pretrained_weight_cls, - **kwargs, + BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, ) @register_mase_checkpoint("resnet50") -def get_resnet50( - pretrained: bool = False, - **kwargs: Any, -) -> ResNet: +def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet: """ResNet-50 from `Deep Residual Learning for Image Recognition `__.""" info = kwargs["dataset_info"] if pretrained: @@ -387,18 +366,12 @@ def get_resnet50( pretrained_weight_cls = None return _resnet( - Bottleneck, - [3, 4, 6, 3], - pretrained_weight_cls=pretrained_weight_cls, - **kwargs, + Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs, ) @register_mase_checkpoint("resnet101") -def get_resnet101( - pretrained: bool = False, - **kwargs: Any, -) -> ResNet: +def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet: """ResNet-101 from `Deep Residual Learning for Image Recognition `__.""" info = kwargs["dataset_info"] if pretrained: @@ -407,17 +380,13 @@ def get_resnet101( pretrained_weight_cls = None return _resnet( - BasicBlock, - [2, 2, 2, 2], - pretrained_weight_cls=pretrained_weight_cls, - **kwargs, + BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, ) @register_mase_checkpoint("wide_resnet50_2") def get_wide_resnet50_2( - pretrained: bool = False, - **kwargs, + pretrained: bool = False, **kwargs, ): """ `Wide Residual Networks `_. @@ -435,8 +404,5 @@ def get_wide_resnet50_2( _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet( - Bottleneck, - [3, 4, 6, 3], - pretrained_weight_cls=pretrained_weight_cls, - **kwargs, + Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs, ) diff --git a/src/chop/models/toy/toy.py b/src/chop/models/toy/toy.py index e88e030f6..3890c4a0f 100644 --- a/src/chop/models/toy/toy.py +++ b/src/chop/models/toy/toy.py @@ -160,8 +160,7 @@ def _conv_block(self, conv_class, *args): # Getters ------------------------------------------------------------------------------ @register_mase_checkpoint("toy") def get_toynet( - pretrained=False, - **kwargs: Any, + pretrained=False, **kwargs: Any, ): info = kwargs["dataset_info"] image_size = info.image_size @@ -170,9 +169,7 @@ def get_toynet( def get_toy_tiny( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): image_size = info.image_size num_classes = info.num_classes @@ -180,9 +177,7 @@ def get_toy_tiny( def get_toy_testmodel( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): image_size = info["image_size"] num_classes = info.num_classes @@ -191,9 +186,7 @@ def get_toy_testmodel( def get_toy_convnet( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): # NOTE: The model isn't configurable through the CLI or a configuration file yet. num_classes = info.num_classes diff --git a/src/chop/models/vgg_cifar/vgg_orig.py b/src/chop/models/vgg_cifar/vgg_orig.py index cf906e888..f744c4404 100644 --- a/src/chop/models/vgg_cifar/vgg_orig.py +++ b/src/chop/models/vgg_cifar/vgg_orig.py @@ -189,12 +189,7 @@ class VGG11_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 132863336, - "_metrics": { - "ImageNet-1K": { - "acc@1": 69.020, - "acc@5": 88.628, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 69.020, "acc@5": 88.628,}}, "_ops": 7.609, "_file_size": 506.84, }, @@ -209,12 +204,7 @@ class VGG11_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 132868840, - "_metrics": { - "ImageNet-1K": { - "acc@1": 70.370, - "acc@5": 89.810, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 70.370, "acc@5": 89.810,}}, "_ops": 7.609, "_file_size": 506.881, }, @@ -229,12 +219,7 @@ class VGG13_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 133047848, - "_metrics": { - "ImageNet-1K": { - "acc@1": 69.928, - "acc@5": 89.246, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 69.928, "acc@5": 89.246,}}, "_ops": 11.308, "_file_size": 507.545, }, @@ -249,12 +234,7 @@ class VGG13_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 133053736, - "_metrics": { - "ImageNet-1K": { - "acc@1": 71.586, - "acc@5": 90.374, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 71.586, "acc@5": 90.374,}}, "_ops": 11.308, "_file_size": 507.59, }, @@ -269,12 +249,7 @@ class VGG16_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 138357544, - "_metrics": { - "ImageNet-1K": { - "acc@1": 71.592, - "acc@5": 90.382, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 71.592, "acc@5": 90.382,}}, "_ops": 15.47, "_file_size": 527.796, }, @@ -294,10 +269,7 @@ class VGG16_Weights(WeightsEnum): "categories": None, "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", "_metrics": { - "ImageNet-1K": { - "acc@1": float("nan"), - "acc@5": float("nan"), - } + "ImageNet-1K": {"acc@1": float("nan"), "acc@5": float("nan"),} }, "_ops": 15.47, "_file_size": 527.802, @@ -318,12 +290,7 @@ class VGG16_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 138365992, - "_metrics": { - "ImageNet-1K": { - "acc@1": 73.360, - "acc@5": 91.516, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 73.360, "acc@5": 91.516,}}, "_ops": 15.47, "_file_size": 527.866, }, @@ -338,12 +305,7 @@ class VGG19_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 143667240, - "_metrics": { - "ImageNet-1K": { - "acc@1": 72.376, - "acc@5": 90.876, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 72.376, "acc@5": 90.876,}}, "_ops": 19.632, "_file_size": 548.051, }, @@ -358,12 +320,7 @@ class VGG19_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 143678248, - "_metrics": { - "ImageNet-1K": { - "acc@1": 74.218, - "acc@5": 91.842, - } - }, + "_metrics": {"ImageNet-1K": {"acc@1": 74.218, "acc@5": 91.842,}}, "_ops": 19.632, "_file_size": 548.143, }, diff --git a/src/chop/models/vision/snn/snn_toy.py b/src/chop/models/vision/snn/snn_toy.py index 5915f1ec6..52c86d821 100644 --- a/src/chop/models/vision/snn/snn_toy.py +++ b/src/chop/models/vision/snn/snn_toy.py @@ -26,9 +26,7 @@ def forward(self, x: torch.Tensor): # Getters ------------------------------------------------------------------------------ def get_snn_toy( - info, - pretrained=False, - **kwargs: Any, + info, pretrained=False, **kwargs: Any, ): tau = info["tau"] num_classes = info.num_classes diff --git a/src/chop/models/vision/snn/spikingResformer.py b/src/chop/models/vision/snn/spikingResformer.py index ff74d76df..153cae155 100644 --- a/src/chop/models/vision/snn/spikingResformer.py +++ b/src/chop/models/vision/snn/spikingResformer.py @@ -124,11 +124,7 @@ def no_weight_decay(self): @register_model def spikingresformer_ti(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [64, 192, 384], [1, 3, 6], [4, 2, 1], @@ -140,11 +136,7 @@ def spikingresformer_ti(**kwargs): @register_model def spikingresformer_s(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [64, 256, 512], [1, 4, 8], [4, 2, 1], @@ -156,11 +148,7 @@ def spikingresformer_s(**kwargs): @register_model def spikingresformer_m(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [64, 384, 768], [1, 6, 12], [4, 2, 1], @@ -172,11 +160,7 @@ def spikingresformer_m(**kwargs): @register_model def spikingresformer_l(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [128, 512, 1024], [2, 8, 16], [4, 2, 1], @@ -188,18 +172,13 @@ def spikingresformer_l(**kwargs): @register_model def spikingresformer_dvsg(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [32, 96, 192], [1, 3, 6], [4, 2, 1], in_channels=3, prologue=nn.Sequential( - Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), - BN(32), + Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), BN(32), ), group_size=32, activation=PLIF, @@ -210,18 +189,13 @@ def spikingresformer_dvsg(**kwargs): @register_model def spikingresformer_cifar(**kwargs): return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], + [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], [64, 192, 384], [1, 3, 6], [4, 2, 1], in_channels=3, prologue=nn.Sequential( - Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), - BN(64), + Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), BN(64), ), **kwargs, ) diff --git a/src/chop/nn/backward/modules/__init__.py b/src/chop/nn/backward/modules/__init__.py index 1dcc39e65..279e0c867 100644 --- a/src/chop/nn/backward/modules/__init__.py +++ b/src/chop/nn/backward/modules/__init__.py @@ -1,6 +1,4 @@ -from .linear import ( - CustomLinear, -) +from .linear import CustomLinear custom_module_map = { diff --git a/src/chop/nn/functional/softermax.py b/src/chop/nn/functional/softermax.py index 8653a6aa0..348d7ebea 100644 --- a/src/chop/nn/functional/softermax.py +++ b/src/chop/nn/functional/softermax.py @@ -11,7 +11,7 @@ def softermax(input: Tensor, dim: int) -> Tensor: Tensor: Output tensor """ out = input - input.max(dim=dim, keepdim=True).values.floor() - out = 2**out + out = 2 ** out row_sum = out.sum(dim=dim, keepdim=True) # Elementwise division out = out / row_sum diff --git a/src/chop/nn/modules/gqa.py b/src/chop/nn/modules/gqa.py index f30b7a8d8..38188d51d 100644 --- a/src/chop/nn/modules/gqa.py +++ b/src/chop/nn/modules/gqa.py @@ -127,12 +127,7 @@ def _qkv_states(self, x: Tensor, batch_size: int, seq_len: int): # return x def _attention_mechanism( - self, - query: Tensor, - key: Tensor, - value: Tensor, - batch_size: int, - seq_len: int, + self, query: Tensor, key: Tensor, value: Tensor, batch_size: int, seq_len: int, ): key = repeat_kv(key, n_rep=self.group_size) value = repeat_kv(value, n_rep=self.group_size) @@ -169,9 +164,7 @@ def forward(self, x: Tensor): GROUPS = 4 gqa_module = GroupedQueryAttention( - embed_dim=EMBED_DIM, - num_heads=NUM_HEADS, - num_kv_heads=GROUPS, + embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_kv_heads=GROUPS, ) x_in = torch.rand(BATCH, SEQ_LEN, EMBED_DIM) diff --git a/src/chop/nn/modules/lora.py b/src/chop/nn/modules/lora.py index bbd2ea023..f101d0c03 100644 --- a/src/chop/nn/modules/lora.py +++ b/src/chop/nn/modules/lora.py @@ -93,11 +93,7 @@ def reset_lora_parameters(self, adapter_name): class LinearLora(nn.Linear, LoraLayer): # Lora implemented in a dense layer def __init__( - self, - in_features: int, - out_features: int, - config: dict = None, - **kwargs, + self, in_features: int, out_features: int, config: dict = None, **kwargs, ): self.config = config init_lora_weights = self.config.get("init_lora_weights", True) @@ -222,12 +218,7 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: # Simple Lora implementation from https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html class LoRALinear(nn.Module): def __init__( - self, - in_dim: int, - out_dim: int, - rank: int, - alpha: float, - dropout: float, + self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float, ): super().__init__() # These are the weights from the original pretrained model diff --git a/src/chop/nn/modules/sparse.py b/src/chop/nn/modules/sparse.py index 67508cf3d..5a3f749a1 100644 --- a/src/chop/nn/modules/sparse.py +++ b/src/chop/nn/modules/sparse.py @@ -92,11 +92,7 @@ def reset_sparse_parameters(self, adapter_name): class LinearSparse(nn.Linear, SparseLayer): def __init__( - self, - in_features: int, - out_features: int, - config: dict = None, - **kwargs, + self, in_features: int, out_features: int, config: dict = None, **kwargs, ): self.config = config init_sparse_weights = self.config.get("init_sparse_weights", True) @@ -159,9 +155,7 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor: def update_weight_selection(self, k): w_flat = self.weight.flatten() _, self.idx = torch.topk( - self.index_method(w_flat, self.idx_method), - k, - sorted=True, + self.index_method(w_flat, self.idx_method), k, sorted=True, ) self.selected_weights = torch.gather(w_flat, dim=0, index=self.idx) @@ -192,10 +186,7 @@ def forward(self, x: torch.Tensor): # Scatter adapted values into weight tensor adapted_weights = torch.scatter( - self.zero_tensor.to(x.device), - dim=0, - index=self.idx, - src=scaled_output, + self.zero_tensor.to(x.device), dim=0, index=self.idx, src=scaled_output, ).view(self.unflattened_size) self.step += 1 @@ -203,9 +194,7 @@ def forward(self, x: torch.Tensor): x = x.to(sparse.weight.dtype) result = F.linear( - dropout(x), - transpose(new_weight, self.fan_in_fan_out), - bias=self.bias, + dropout(x), transpose(new_weight, self.fan_in_fan_out), bias=self.bias, ) else: diff --git a/src/chop/nn/mx/activations.py b/src/chop/nn/mx/activations.py index 39702cc29..27fafc352 100644 --- a/src/chop/nn/mx/activations.py +++ b/src/chop/nn/mx/activations.py @@ -434,10 +434,7 @@ def forward(ctx, input, inplace=False, mx_specs=None, name=None): @staticmethod def backward(ctx, grad_output): - ( - y, - sig_x, - ) = ctx.saved_tensors + (y, sig_x,) = ctx.saved_tensors grad_output = vec_quantize(grad_output, mx_specs=ctx.mx_specs) temp = vec_sub(1.0, sig_x, mx_specs=ctx.mx_specs) diff --git a/src/chop/nn/mx/bmm.py b/src/chop/nn/mx/bmm.py index 6fac4eef8..50071d06e 100644 --- a/src/chop/nn/mx/bmm.py +++ b/src/chop/nn/mx/bmm.py @@ -67,9 +67,7 @@ def backward(ctx, grad_out): in1, in2 = ctx.saved_tensors grad_out = quantize_elemwise_op( - grad_out, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -111,14 +109,10 @@ def backward(ctx, grad_out): # element-wise quantize for grad_in1 and grad_in2 grad_in1 = quantize_elemwise_op( - grad_in1, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) grad_in2 = quantize_elemwise_op( - grad_in2, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) return (grad_in1, grad_in2, None, None) diff --git a/src/chop/nn/mx/convolution.py b/src/chop/nn/mx/convolution.py index 62b8dd2d2..9d2aaed17 100644 --- a/src/chop/nn/mx/convolution.py +++ b/src/chop/nn/mx/convolution.py @@ -180,16 +180,10 @@ def forward( # weight is (out_channels, in_channels/groups, ..) # quantize along in_channels qid_input = quantize_mx_op( - bf_in, - mx_specs, - elem_format=mx_specs["a_elem_format"], - axes=[1], + bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1], ) qid_weight = quantize_mx_op( - bf_weight, - mx_specs, - elem_format=mx_specs["w_elem_format"], - axes=[1], + bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[1], ) # compute output @@ -213,9 +207,7 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_output = quantize_elemwise_op( - grad_output, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -225,10 +217,7 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # quantize along the batch dim qex_input = quantize_mx_op( - input, - ctx.mx_specs, - elem_format=ctx.mx_specs["a_elem_format"], - axes=[0], + input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0], ) qex_grad_output = quantize_mx_op( grad_output, @@ -251,9 +240,7 @@ def backward(ctx, grad_output): # element-wise quantize for grad_weight grad_weight = quantize_elemwise_op( - grad_weight, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_weight"], + grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -264,10 +251,7 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # reduction dim is out_channels qod_weight = quantize_mx_op( - weight, - ctx.mx_specs, - elem_format=ctx.mx_specs["w_elem_format"], - axes=[0], + weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[0], ) qod_grad_output = quantize_mx_op( grad_output, @@ -289,9 +273,7 @@ def backward(ctx, grad_output): # element-wise quantize for grad_input grad_input = quantize_elemwise_op( - grad_input, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/mx/elemwise_ops.py b/src/chop/nn/mx/elemwise_ops.py index 5d2b4b8ec..21067bba1 100644 --- a/src/chop/nn/mx/elemwise_ops.py +++ b/src/chop/nn/mx/elemwise_ops.py @@ -33,16 +33,16 @@ # exponents smaller than -126 def _safe_lshift(x, bits, exp): if exp is None: - return x * (2**bits) + return x * (2 ** bits) else: - return x / (2**exp) * (2**bits) + return x / (2 ** exp) * (2 ** bits) def _safe_rshift(x, bits, exp): if exp is None: - return x / (2**bits) + return x / (2 ** bits) else: - return x / (2**bits) * (2**exp) + return x / (2 ** bits) * (2 ** exp) def _round_mantissa(A, bits, round, clamp=False): diff --git a/src/chop/nn/mx/formats.py b/src/chop/nn/mx/formats.py index 9a2c35189..ff2a9d726 100644 --- a/src/chop/nn/mx/formats.py +++ b/src/chop/nn/mx/formats.py @@ -52,14 +52,14 @@ def from_str(s): def _get_min_norm(ebits): """Valid for all float formats""" emin = 2 - (2 ** (ebits - 1)) - return 0 if ebits == 0 else 2**emin + return 0 if ebits == 0 else 2 ** emin def _get_max_norm(ebits, mbits): """Valid only for floats that define NaN""" assert ebits >= 5, "invalid for floats that don't define NaN" emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1 - return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) + return 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) _FORMAT_CACHE = {} @@ -121,9 +121,9 @@ def _get_format_params(fmt): raise Exception("Unknown element format %s" % fmt) if fmt != ElemFormat.fp8_e4m3: - max_norm = 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) + max_norm = 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) else: - max_norm = 2**emax * 1.75 # FP8 has custom max_norm + max_norm = 2 ** emax * 1.75 # FP8 has custom max_norm min_norm = _get_min_norm(ebits) diff --git a/src/chop/nn/mx/linear.py b/src/chop/nn/mx/linear.py index 64f761c86..ef78fcedc 100644 --- a/src/chop/nn/mx/linear.py +++ b/src/chop/nn/mx/linear.py @@ -18,12 +18,7 @@ class LinearFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - input, - weight, - bias=None, - mx_specs=None, - name=None, + ctx, input, weight, bias=None, mx_specs=None, name=None, ): # element-wise quantize for input bf_in = quantize_elemwise_op( @@ -90,9 +85,7 @@ def backward(ctx, grad_output): in_dim = weight.shape[1] grad_output = quantize_elemwise_op( - grad_output, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -122,9 +115,7 @@ def backward(ctx, grad_output): # Compute grad_weight grad_weight = torch_matmul(qex_grad_output.transpose(0, 1), qex_input) grad_weight = quantize_elemwise_op( - grad_weight, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_weight"], + grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -150,9 +141,7 @@ def backward(ctx, grad_output): # Compute grad_input grad_input = torch_matmul(qos_grad_output, qos_weight) grad_input = quantize_elemwise_op( - grad_input, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -172,11 +161,7 @@ def backward(ctx, grad_output): def linear( - input, - weight, - bias=None, - mx_specs=None, - name=None, + input, weight, bias=None, mx_specs=None, name=None, ): mx_assert_test(mx_specs) if mx_specs is None: @@ -189,12 +174,7 @@ def linear( class Linear(torch.nn.Linear): def __init__( - self, - in_features, - out_features, - bias=True, - mx_specs=None, - name=None, + self, in_features, out_features, bias=True, mx_specs=None, name=None, ): mx_assert_test(mx_specs) self.mx_none = mx_specs is None diff --git a/src/chop/nn/mx/matmul.py b/src/chop/nn/mx/matmul.py index 8c18914b7..3c2e0ffb4 100644 --- a/src/chop/nn/mx/matmul.py +++ b/src/chop/nn/mx/matmul.py @@ -114,9 +114,7 @@ def backward(ctx, grad_out): in1, in2 = ctx.saved_tensors grad_out = quantize_elemwise_op( - grad_out, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -160,14 +158,10 @@ def backward(ctx, grad_out): # element-wise quantize for grad_in1 and grad_in2 grad_in1 = quantize_elemwise_op( - grad_in1, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) grad_in2 = quantize_elemwise_op( - grad_in2, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/mx/mx_ops.py b/src/chop/nn/mx/mx_ops.py index 136441bea..aad12a52f 100644 --- a/src/chop/nn/mx/mx_ops.py +++ b/src/chop/nn/mx/mx_ops.py @@ -281,10 +281,7 @@ def _quantize_mx( else: # Get shared exponents shared_exp = _shared_exponents( - A, - method=shared_exp_method, - axes=shared_exp_axes, - ebits=0, + A, method=shared_exp_method, axes=shared_exp_axes, ebits=0, ) # Flush subnormal FP32 inputs to zero @@ -299,7 +296,7 @@ def _quantize_mx( shared_exp[shared_exp > scale_emax] = float("NaN") shared_exp[shared_exp < -scale_emax] = -scale_emax - A = A / (2**shared_exp) + A = A / (2 ** shared_exp) A = _quantize_elemwise_core( A, @@ -312,7 +309,7 @@ def _quantize_mx( custom_cuda=custom_cuda, ) - A = A * (2**shared_exp) + A = A * (2 ** shared_exp) # Undo tile reshaping if block_size: diff --git a/src/chop/nn/mx/quantize.py b/src/chop/nn/mx/quantize.py index 7e0b907fa..b979b18d7 100644 --- a/src/chop/nn/mx/quantize.py +++ b/src/chop/nn/mx/quantize.py @@ -42,9 +42,7 @@ def forward(ctx, x, mx_specs, round=None): @staticmethod def backward(ctx, grad_output): grad_input = quantize_elemwise_op( - grad_output, - mx_specs=ctx.mx_specs, - round=ctx.round, + grad_output, mx_specs=ctx.mx_specs, round=ctx.round, ) return (grad_input, None, None) diff --git a/src/chop/nn/mx/simd_ops.py b/src/chop/nn/mx/simd_ops.py index 475b47ef7..28180b758 100644 --- a/src/chop/nn/mx/simd_ops.py +++ b/src/chop/nn/mx/simd_ops.py @@ -307,7 +307,7 @@ def forward(ctx, in1, mx_specs=None): else: ctx.save_for_backward(in1) - return vec_quantize(qin1**2, mx_specs=mx_specs) + return vec_quantize(qin1 ** 2, mx_specs=mx_specs) @staticmethod def backward(ctx, g): diff --git a/src/chop/nn/mx/transpose_convolution.py b/src/chop/nn/mx/transpose_convolution.py index af570025b..5d6c5da8a 100644 --- a/src/chop/nn/mx/transpose_convolution.py +++ b/src/chop/nn/mx/transpose_convolution.py @@ -75,16 +75,10 @@ def forward( # weight is (in_channels, out_channels/groups, ...) # quantize along in_channels qid_input = quantize_mx_op( - bf_in, - mx_specs, - elem_format=mx_specs["a_elem_format"], - axes=[1], + bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1], ) qid_weight = quantize_mx_op( - bf_weight, - mx_specs, - elem_format=mx_specs["w_elem_format"], - axes=[0], + bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[0], ) # compute output @@ -114,9 +108,7 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_output = quantize_elemwise_op( - grad_output, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -126,10 +118,7 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # quantize along the batch dim qex_input = quantize_mx_op( - input, - ctx.mx_specs, - elem_format=ctx.mx_specs["a_elem_format"], - axes=[0], + input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0], ) qex_grad_output = quantize_mx_op( grad_output, @@ -150,9 +139,7 @@ def backward(ctx, grad_output): # element-wise quantize for grad_weight grad_weight = quantize_elemwise_op( - grad_weight, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_weight"], + grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -162,10 +149,7 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # reduction dim is out_channels qod_weight = quantize_mx_op( - weight, - ctx.mx_specs, - elem_format=ctx.mx_specs["w_elem_format"], - axes=[1], + weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[1], ) qod_grad_output = quantize_mx_op( grad_output, @@ -187,9 +171,7 @@ def backward(ctx, grad_output): # element-wise quantize for grad_input grad_input = quantize_elemwise_op( - grad_input, - mx_specs=ctx.mx_specs, - round=ctx.mx_specs["round_grad_input"], + grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index f104e397f..c95cd1bb3 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -86,6 +86,7 @@ def __init__( morr_init = config.get("morr_init", True) trainable_morr_bias = config.get("trainable_morr_bias", False) trainable_morr_scale = config.get("trainable_morr_scale", False) + device = config.get("device", device) self.in_channels = in_channels self.out_channels = out_channels @@ -108,7 +109,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 + self.gamma = np.pi / self.v_pi ** 2 self.w_bit = 32 self.in_bit = 32 self.MORRConfig = MORRConfig @@ -122,7 +123,7 @@ def __init__( ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi**2 + * np.pi ** 2 * MORRConfig.radius * MORRConfig.effective_index * ( @@ -240,7 +241,7 @@ def reset_parameters(self, morr_init: bool = False) -> None: (t2 - t1) / (2.4 * self.morr_fwhm) ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index f93281af5..c1181cfc1 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -62,7 +62,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 + self.gamma = np.pi / self.v_pi ** 2 self.w_bit = 32 self.in_bit = 32 @@ -80,7 +80,7 @@ def __init__( ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi**2 + * np.pi ** 2 * morr_config.radius * morr_config.effective_index * ( @@ -198,7 +198,7 @@ def reset_parameters(self, morr_init: bool = False) -> None: (t2 - t1) / (2.4 * self.morr_fwhm) ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) else: diff --git a/src/chop/nn/optical/utils/__init__.py b/src/chop/nn/optical/utils/__init__.py index 88b71f264..248e214da 100644 --- a/src/chop/nn/optical/utils/__init__.py +++ b/src/chop/nn/optical/utils/__init__.py @@ -11,9 +11,7 @@ toeplitz, ) -from .initializer import ( - morr_uniform_, -) +from .initializer import morr_uniform_ from .quantize import ( input_quantize_fn, diff --git a/src/chop/nn/optical/utils/compute.py b/src/chop/nn/optical/utils/compute.py index d8d36a354..8ae1b3279 100644 --- a/src/chop/nn/optical/utils/compute.py +++ b/src/chop/nn/optical/utils/compute.py @@ -20,65 +20,11 @@ from torch.types import Device, _size __all__ = [ - "shift", - "Krylov", - "circulant", "toeplitz", - "complex_circulant", - "complex_mult", - "expi", - "complex_matvec_mult", - "complex_matmul", - "real_to_complex", - "get_complex_magnitude", - "get_complex_energy", - "complex_to_polar", - "polar_to_complex", - "absclamp", - "absclamp_", "im2col_2d", - "check_identity_matrix", - "check_unitary_matrix", - "check_equal_tensor", - "batch_diag", - "batch_eye_cpu", - "batch_eye", - "merge_chunks", - "partition_chunks", - "clip_by_std", - "percentile", - "gen_boolean_mask_cpu", - "gen_boolean_mask", - "fftshift_cpu", - "ifftshift_cpu", - "gen_gaussian_noise", - "gen_gaussian_filter2d_cpu", - "gen_gaussian_filter2d", - "add_gaussian_noise_cpu", - "add_gaussian_noise", - "add_gaussian_noise_", - "circulant_multiply", - "calc_diagonal_hessian", - "calc_jacobian", - "polynomial", - "gaussian", - "lowrank_decompose", - "get_conv2d_flops", - "interp1d", ] -def set_torch_deterministic(random_state: int = 0) -> None: - random_state = int(random_state) % (2**32) - torch.manual_seed(random_state) - np.random.seed(random_state) - if torch.cuda.is_available(): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.cuda.manual_seed_all(random_state) - random.seed(random_state) - - def shift(v: Tensor, f: float = 1) -> Tensor: return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1) @@ -116,9 +62,38 @@ def toeplitz(col: Tensor) -> Tensor: return col[..., indices] -def complex_circulant(eigens: Tensor) -> Tensor: - circ = Krylov(shift, eigens).transpose(-1, -2) - return circ +def im2col_2d( + W: Optional[Tensor] = None, + X: Optional[Tensor] = None, + stride: int = 1, + padding: int = 0, + w_size: Optional[_size] = None, +) -> Tuple[Tensor, Tensor, int, int]: + if W is not None: + W_col = W.view(W.size(0), -1) + else: + W_col = None + + if X is not None: + n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size + n_x, d_x, h_x, w_x = X.size() + + h_out = (h_x - h_filter + 2 * padding) / stride + 1 + w_out = (w_x - w_filter + 2 * padding) / stride + 1 + + h_out, w_out = int(h_out), int(w_out) + X_col = torch.nn.functional.unfold( + X.view(1, -1, h_x, w_x), + h_filter, + dilation=1, + padding=padding, + stride=stride, + ).view(n_x, -1, h_out * w_out) + X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1) + else: + X_col, h_out, w_out = None, None, None + + return W_col, X_col, h_out, w_out def complex_mult(X: Tensor, Y: Tensor) -> Tensor: @@ -151,65 +126,6 @@ def complex_mult(X: Tensor, Y: Tensor) -> Tensor: return X.mul(Y) -def complex_matvec_mult(W: Tensor, X: Tensor) -> Tensor: - return torch.sum(complex_mult(W, X.unsqueeze(0).repeat(W.size(0), 1, 1)), dim=1) - - -def complex_matmul(X: Tensor, Y: Tensor) -> Tensor: - assert X.shape[-1] == 2 and Y.shape[-1] == 2, "Last dimension must be 2" - if torch.__version__ >= "1.8" or ( - torch.__version__ >= "1.7" and X.shape[:-3] == Y.shape[:-3] - ): - return torch.view_as_real( - torch.matmul(torch.view_as_complex(X), torch.view_as_complex(Y)) - ) - - return torch.stack( - [ - X[..., 0].matmul(Y[..., 0]) - X[..., 1].matmul(Y[..., 1]), - X[..., 0].matmul(Y[..., 1]) + X[..., 1].matmul(Y[..., 0]), - ], - dim=-1, - ) - - -def expi(x: Tensor) -> Tensor: - if torch.__version__ >= "1.8" or ( - torch.__version__ >= "1.7" and not x.requires_grad - ): - return torch.exp(1j * x) - else: - return x.cos().type(torch.cfloat) + 1j * x.sin().type(torch.cfloat) - - -def real_to_complex(x: Tensor) -> Tensor: - if torch.__version__ < "1.7": - return torch.stack((x, torch.zeros_like(x).to(x.device)), dim=-1) - else: - return torch.view_as_real(x.to(torch.complex64)) - - -def get_complex_magnitude(x: Tensor) -> Tensor: - assert x.size(-1) == 2, "[E] Input must be complex Tensor" - return torch.sqrt(x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1]) - - -def complex_to_polar(x: Tensor) -> Tensor: - # real and imag to magnitude and angle - if isinstance(x, torch.Tensor): - mag = x.norm(p=2, dim=-1) - angle = torch.view_as_complex(x).angle() - x = torch.stack([mag, angle], dim=-1) - elif isinstance(x, np.ndarray): - x = x.astype(np.complex64) - mag = np.abs(x) - angle = np.angle(x) - x = np.stack([mag, angle], axis=-1) - else: - raise NotImplementedError - return x - - def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor: # magnitude and angle to real and imag if angle is None: @@ -231,534 +147,6 @@ def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor: return x -def get_complex_energy(x: Tensor) -> Tensor: - assert x.size(-1) == 2, "[E] Input must be complex Tensor" - return x[..., 0] * x[..., 0] + x[..., 1] * x[..., 1] - - -def absclamp( - x: Tensor, min: Optional[float] = None, max: Optional[float] = None -) -> Tensor: - if isinstance(x, torch.Tensor): - mag = x.norm(p=2, dim=-1).clamp(min=min, max=max) - angle = torch.view_as_complex(x).angle() - x = polar_to_complex(mag, angle) - elif isinstance(x, np.ndarray): - x = x.astype(np.complex64) - mag = np.clip(np.abs(x), a_min=min, a_max=max) - angle = np.angle(x) - x = polar_to_complex(mag, angle) - else: - raise NotImplementedError - return x - - -def absclamp_( - x: Tensor, min: Optional[float] = None, max: Optional[float] = None -) -> Tensor: - if isinstance(x, torch.Tensor): - y = torch.view_as_complex(x) - mag = y.abs().clamp(min=min, max=max) - angle = y.angle() - x.data.copy_(polar_to_complex(mag, angle)) - elif isinstance(x, np.ndarray): - y = x.astype(np.complex64) - mag = np.clip(np.abs(y), a_min=min, a_max=max) - angle = np.angle(y) - x[:] = polar_to_complex(mag, angle) - else: - raise NotImplementedError - return x - - -def im2col_2d( - W: Optional[Tensor] = None, - X: Optional[Tensor] = None, - stride: int = 1, - padding: int = 0, - w_size: Optional[_size] = None, -) -> Tuple[Tensor, Tensor, int, int]: - if W is not None: - W_col = W.view(W.size(0), -1) - else: - W_col = None - - if X is not None: - n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size - n_x, d_x, h_x, w_x = X.size() - - h_out = (h_x - h_filter + 2 * padding) / stride + 1 - w_out = (w_x - w_filter + 2 * padding) / stride + 1 - - h_out, w_out = int(h_out), int(w_out) - X_col = torch.nn.functional.unfold( - X.view(1, -1, h_x, w_x), - h_filter, - dilation=1, - padding=padding, - stride=stride, - ).view(n_x, -1, h_out * w_out) - X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1) - else: - X_col, h_out, w_out = None, None, None - - return W_col, X_col, h_out, w_out - - -def check_identity_matrix(W: Tensor) -> bool: - if isinstance(W, np.ndarray): - W_numpy = W.copy().astype(np.float64) - elif isinstance(W, torch.Tensor): - W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) - else: - assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" - - return (W_numpy.shape[0] == W_numpy.shape[1]) and np.allclose( - W_numpy, np.eye(W_numpy.shape[0]) - ) - - -def check_unitary_matrix(W: Tensor) -> bool: - if isinstance(W, np.ndarray): - W_numpy = W.copy().astype(np.float64) - elif isinstance(W, torch.Tensor): - W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) - else: - assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" - M = np.dot(W_numpy, W_numpy.T) - # print(M) - return check_identity_matrix(M) - - -def check_equal_tensor(W1: Tensor, W2: Tensor) -> bool: - if isinstance(W1, np.ndarray): - W1_numpy = W1.copy().astype(np.float64) - elif isinstance(W1, torch.Tensor): - W1_numpy = W1.detach().cpu().numpy().copy().astype(np.float64) - else: - assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" - - if isinstance(W2, np.ndarray): - W2_numpy = W2.copy().astype(np.float64) - elif isinstance(W2, torch.Tensor): - W2_numpy = W2.detach().cpu().numpy().copy().astype(np.float64) - else: - assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" - return (W1_numpy.shape == W2_numpy.shape) and np.allclose(W1_numpy, W2_numpy) - - -def batch_diag(x: Tensor) -> Tensor: - # x[..., N, N] -> [..., N] - assert ( - len(x.shape) >= 2 - ), f"At least 2-D array/tensor is expected, but got shape {x.shape}" - if isinstance(x, np.ndarray): - size = list(x.shape) - x = x.reshape(size[:-2] + [size[-2] * size[-1]]) - x = x[..., :: size[-1] + 1] - elif isinstance(x, torch.Tensor): - size = list(x.size()) - x = x.flatten(-2, -1) - x = x[..., :: size[-1] + 1] - else: - raise NotImplementedError - return x - - -def batch_eye_cpu(N: int, batch_shape: List[int], dtype: np.dtype) -> np.ndarray: - x = np.zeros(list(batch_shape) + [N, N], dtype=dtype) - x.reshape(-1, N * N)[..., :: N + 1] = 1 - return x - - -def batch_eye( - N: int, - batch_shape: List[int], - dtype: torch.dtype, - device: Device = torch.device("cuda"), -) -> torch.Tensor: - x = torch.zeros(list(batch_shape) + [N, N], dtype=dtype, device=device) - x.view(-1, N * N)[..., :: N + 1] = 1 - return x - - -def merge_chunks(x: Tensor, complex: bool = False) -> Tensor: - """Merge a chunked/blocked tensors into a 2D matrix - - Args: - x (Tensor): Tensor of shape [h1, w1, h2, w2, ...., hk, wk] if complex=False; [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True - complex (bool, optional): True if the tensor x has a last dimension with size 2 for real/imag representation. Defaults to False. - - Returns: - Tensor: [h1*h2*...*hk, w1*w2*...*wk] or [h1*h2*...*hk, w1*w2*...*wk, 2] - """ - if isinstance(x, torch.Tensor): - permute = torch.permute - elif isinstance(x, np.ndarray): - permute = np.transpose - else: - raise NotImplementedError - - if not complex: - dim = len(x.shape) - x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2))) - x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1) - else: - dim = len(x.shape) - 1 - x = permute(x, list(range(0, dim, 2)) + list(range(1, dim + 1, 2) + [dim])) - x = x.reshape(np.prod([x.shape[i] for i in range(dim // 2)]), -1, 2) - - return x - - -def partition_chunks( - x: Tensor, out_shape: int | Tuple[int, ...], complex: bool = False -) -> Tensor: - """Partition a tensor into square chunks, similar to Rearrange in einops - - Args: - x (Tensor): 2D tensor of shape [h1*h2*...*hk, w1*w2*...*wk] or 3D tensor of shape [h1*h2*...*hk, w1*w2*...*wk, 2] if complex=True - out_shape (Tuple[int]): output blocked shape (h1, w1, h2, w2, ...); Do not include the last dimension even if complex=True - complex (bool, optional): whether x is complex tensor. Defaults to False. - - Returns: - [Tensor]: Tensor of shape [h1, w1, h2, w2, ...., hk, wk] or [h1, w1, h2, w2, ...., hk, wk, 2] if complex=True - """ - if complex: - assert len(x.shape) == 3 - x_shape = (np.prod(out_shape[::2]), np.prod(out_shape[1::2])) - if isinstance(x, torch.Tensor): - permute = torch.permute - pad_fn = lambda x, padding: torch.nn.functional.pad(x[None, None], padding)[ - 0, 0 - ] - is_tensor = True - elif isinstance(x, np.ndarray): - permute = np.transpose - pad_fn = np.pad - is_tensor = False - else: - raise NotImplementedError - - if x_shape != x.shape[:2]: - ## if x cannot be partitioned into out_shape, we need to pad it - if is_tensor: - ## torch from the last dim - padding = (0, x_shape[1] - x.shape[1], 0, x_shape[0] - x.shape[0]) - if complex: - padding = (0, 0) + padding - else: - ## np from the first dim - padding = ((0, x_shape[0] - x.shape[0]), (0, x_shape[1] - x.shape[1])) - if complex: - padding = padding + (0, 0) - - x = pad_fn(x, padding) - - in_shape = list(out_shape[::2]) + list(out_shape[1::2]) - permute_shape = np.arange(len(out_shape)).reshape(2, -1).T.reshape(-1).tolist() - if complex: - in_shape.append(2) - permute_shape.append(len(permute_shape)) - x = x.reshape(in_shape) # [h1, h2, ..., hk, w1, w2, ..., wk] - - x = permute(x, permute_shape) # [h1, w1, h2, w2, ...., hk, wk] - - return x - - -def clip_by_std(x: Tensor, n_std_neg: float = 3.0, n_std_pos: float = 3.0) -> Tensor: - if isinstance(x, np.ndarray): - std = np.std(x) - mean = np.mean(x) - out = np.clip(x, a_min=mean - n_std_neg * std, a_max=mean + n_std_pos * std) - elif isinstance(x, torch.Tensor): - std = x.data.std() - mean = x.data.mean() - out = x.clamp(min=mean - n_std_neg * std, max=mean + n_std_pos * std) - else: - raise NotImplementedError - return out - - -def percentile(t: Tensor, q: float) -> Tensor: - """ - Return the ``q``-th percentile of the flattened input tensor's data. - - CAUTION: - * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. - * Values are not interpolated, which corresponds to - ``numpy.percentile(..., interpolation="nearest")``. - - :param t: Input tensor. - :param q: Percentile to compute, which must be between 0 and 100 inclusive. - :return: Resulting value (scalar). - """ - # Note that ``kthvalue()`` works one-based, i.e. the first sorted value - # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, - # so that ``round()`` returns an integer, even if q is a np.float32. - if isinstance(t, torch.Tensor): - k = 1 + round(0.01 * float(q) * (t.numel() - 1)) - result = t.view(-1).kthvalue(k).values.item() - elif isinstance(t, np.ndarray): - result = np.percentile(t, q=q) - else: - raise NotImplementedError - return result - - -def gen_boolean_mask_cpu(size: _size, true_prob: float) -> np.ndarray: - assert 0 <= true_prob <= 1, "[E] Wrong probability for True" - return np.random.choice(a=[False, True], size=size, p=[1 - true_prob, true_prob]) - - -def gen_boolean_mask( - size: _size, - true_prob: float, - random_state: Optional[int] = None, - device: Device = torch.device("cuda"), -) -> Tensor: - assert 0 <= true_prob <= 1, "[E] Wrong probability for True" - if true_prob > 1 - 1e-9: - return torch.ones(size, device=device, dtype=torch.bool) - elif true_prob < 1e-9: - return torch.zeros(size, device=device, dtype=torch.bool) - if random_state is not None: - with torch.random.fork_rng(): - torch.random.manual_seed(random_state) - return torch.empty(size, dtype=torch.bool, device=device).bernoulli_( - true_prob - ) - else: - return torch.empty(size, dtype=torch.bool, device=device).bernoulli_(true_prob) - - -def fftshift_cpu( - x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None -) -> Union[Tensor, np.ndarray]: - if isinstance(x, np.ndarray): - if dim is None: - if batched: - dim = tuple(range(1, len(x.shape))) - else: - dim = tuple(range(0, len(x.shape))) - out = np.fft.fftshift(x, axes=dim) - elif isinstance(x, torch.Tensor): - device = x.device - x = x.cpu().detach().numpy() - if dim is None: - if batched: - dim = tuple(range(1, len(x.shape))) - else: - dim = tuple(range(0, len(x.shape))) - out = np.fft.fftshift(x, axes=dim) - out = torch.from_numpy(out).to(device) - return out - - -def ifftshift_cpu( - x: Union[Tensor, np.ndarray], batched: bool = True, dim: Optional[Tuple[int]] = None -) -> Union[Tensor, np.ndarray]: - if isinstance(x, np.ndarray): - if dim is None: - if batched: - dim = tuple(range(1, len(x.shape))) - else: - dim = tuple(range(0, len(x.shape))) - out = np.fft.ifftshift(x, axes=dim) - elif isinstance(x, torch.Tensor): - device = x.device - x = x.cpu().detach().numpy() - if dim is None: - if batched: - dim = tuple(range(1, len(x.shape))) - else: - dim = tuple(range(0, len(x.shape))) - out = np.fft.ifftshift(x, axes=dim) - out = torch.from_numpy(out).to(device) - return out - - -def gen_gaussian_noise( - W: Union[Tensor, np.ndarray], - noise_mean: float = 0.0, - noise_std: float = 0.002, - trunc_range: Tuple = (), - random_state: Optional[int] = None, -) -> Union[Tensor, np.ndarray]: - if random_state is not None: - set_torch_deterministic(random_state) - if isinstance(W, np.ndarray): - if not trunc_range: - noises = np.random.normal(noise_mean, noise_std, W.shape) - else: - a = (trunc_range[0] - noise_mean) / noise_std - b = (trunc_range[1] - noise_mean) / noise_std - noises = truncnorm.rvs( - a, b, loc=noise_mean, scale=noise_std, size=W.shape, random_state=None - ) - elif isinstance(W, torch.Tensor): - if not trunc_range: - noises = torch.zeros_like(W).normal_(mean=noise_mean, std=noise_std) - else: - size = W.shape - tmp = W.new_empty(size + (4,)).normal_() - a = (trunc_range[0] - noise_mean) / noise_std - b = (trunc_range[1] - noise_mean) / noise_std - valid = (tmp < b) & (tmp > a) - ind = valid.max(-1, keepdim=True)[1] - noises = tmp.gather(-1, ind).squeeze(-1).mul_(noise_std).add_(noise_mean) - # noises = truncated_normal(W, mean=noise_mean, std=noise_std, a=trunc_range[0], b=trunc_range[1]) - else: - assert 0, logging.error( - f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}" - ) - return noises - - -def gen_gaussian_filter2d_cpu(size: int = 3, std: float = 0.286) -> np.ndarray: - assert ( - size % 2 == 1 - ), f"Gaussian filter can only be odd size, but size={size} is given." - ax = np.linspace(-(size - 1) / 2.0, (size - 1) / 2.0, size) - xx, yy = np.meshgrid(ax, ax) - kernel = np.exp(-0.5 / np.square(std) * (np.square(xx) + np.square(yy))) - kernel = kernel / np.sum(kernel) - kernel[size // 2, size // 2] = 1 - return kernel - - -def gen_gaussian_filter2d( - size: int = 3, - std: float = 0.286, - center_one: bool = True, - device: Device = torch.device("cuda"), -) -> Tensor: - assert ( - size % 2 == 1 - ), f"Gaussian filter can only be odd size, but size={size} is given." - if std > 1e-8: - ax = torch.linspace( - -(size - 1) / 2.0, - (size - 1) / 2.0, - size, - dtype=torch.float32, - device=device, - ) - xx, yy = torch.meshgrid(ax, ax) - kernel = torch.exp(-0.5 / (std**2) * (xx.square() + yy.square())) - kernel = kernel.div_(kernel.sum()) - if center_one: - kernel[size // 2, size // 2] = 1 - else: - kernel = torch.zeros(size, size, dtype=torch.float32, device=device) - kernel[size // 2, size // 2] = 1 - - return kernel - - -def add_gaussian_noise( - W: Union[Tensor, np.ndarray], - noise_mean: float = 0, - noise_std: float = 0.002, - trunc_range: Tuple = (), - random_state: Optional[int] = None, -) -> Union[Tensor, np.ndarray]: - noises = gen_gaussian_noise( - W, - noise_mean=noise_mean, - noise_std=noise_std, - trunc_range=trunc_range, - random_state=random_state, - ) - output = W + noises - return output - - -def add_gaussian_noise_( - W: Union[Tensor, np.ndarray], - noise_mean: float = 0, - noise_std: float = 0.002, - trunc_range: Tuple = (), - random_state: Optional[int] = None, -) -> Union[Tensor, np.ndarray]: - noises = gen_gaussian_noise( - W, - noise_mean=noise_mean, - noise_std=noise_std, - trunc_range=trunc_range, - random_state=random_state, - ) - if isinstance(W, np.ndarray): - W += noises - elif isinstance(W, torch.Tensor): - W.data += noises - else: - assert 0, logging.error( - f"Array type not supported, must be numpy.ndarray or torch.Tensor, but got {type(W)}" - ) - return W - - -def add_gaussian_noise_cpu( - W: Union[Tensor, np.ndarray], - noise_mean: float = 0, - noise_std: float = 0.002, - trunc_range: Tuple = (), -) -> Union[Tensor, np.ndarray]: - if isinstance(W, np.ndarray): - W_numpy = W.copy().astype(np.float64) - elif isinstance(W, torch.Tensor): - W_numpy = W.detach().cpu().numpy().copy().astype(np.float64) - else: - assert 0, "[E] Array type not supported, must be numpy.ndarray or torch.Tensor" - if not trunc_range: - noises = np.random.normal(noise_mean, noise_std, W_numpy.shape) - else: - a = (trunc_range[0] - noise_mean) / noise_std - b = (trunc_range[1] - noise_mean) / noise_std - noises = truncnorm.rvs( - a, b, loc=noise_mean, scale=noise_std, size=W_numpy.shape, random_state=None - ) - return W_numpy + noises - - -def circulant_multiply(c: Tensor, x: Tensor) -> Tensor: - """Multiply circulant matrix with first column c by x - Parameters: - c: (n, ) - x: (batch_size, n) or (n, ) - Return: - prod: (batch_size, n) or (n, ) - """ - return torch.irfft( - complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1],) - ) - - -def calc_diagonal_hessian(weight_dict, loss, model): - model.zero_grad() - hessian_dict = {} - for name, weight in weight_dict.items(): - first_gradient = grad(loss, weight, create_graph=True)[0] - second_gradient = grad(first_gradient.sum(), weight, create_graph=True)[0] - hessian_dict[name] = second_gradient.clone() - model.zero_grad() - return hessian_dict - - -def calc_jacobian( - weight_dict: Dict[str, Tensor], loss: Callable, model: nn.Module -) -> Dict[str, Tensor]: - model.zero_grad() - jacobian_dict = {} - for name, weight in weight_dict.items(): - first_gradient = grad(loss, weight, create_graph=True)[0] - jacobian_dict[name] = first_gradient.clone() - model.zero_grad() - return jacobian_dict - - @lru_cache(maxsize=4) def _polynomial_order_base(order: int, device: Device) -> Tensor: return torch.arange(order - 1, -1, -1, device=device) @@ -797,277 +185,3 @@ def polynomial(x: Tensor | np.ndarray, coeff: Tensor | np.ndarray) -> Tensor: return np.polynomial.polynomial.polyval(x, coeff[::-1]) else: raise NotImplementedError - - -def gaussian(x: Tensor, coeff: Tensor) -> Tensor: - # coeff : [n, 3], includes a, b, c - ## a * exp(-((x-b)/c)^2) + ... - size = x.size() - x = x.view(-1).unsqueeze(0) - x = ( - (coeff[:, 0:1] * torch.exp(-((x - coeff[:, 1:2]) / coeff[:, 2:3]).square())) - .sum(dim=0) - .view(size) - ) - return x - - -def lowrank_decompose( - x: Tensor, - r: int, - u_ortho: bool = False, - out_u: Optional[Tensor] = None, - out_v: Optional[Tensor] = None, -) -> Tuple[Tensor, Tensor]: - """low rank decomposition on x. x ~ uv. - - Args: - x (Tensor): tensor to decomplse - r (int): rank - u_ortho (bool, optional): whether u is orthogonal matrix. Defaults to False. - out_u (Optional[Tensor], optional): output buffer for u. Defaults to None. - out_v (Optional[Tensor], optional): output buffer for v. Defaults to None. - - Returns: - Tuple[Tensor, Tensor]: [description] - """ - ### x [..., m, n] - # r rank - u, s, v = x.data.svd(some=True) - v = v.transpose(-2, -1).contiguous() - u = u[..., :, :r] - s = s[..., :r] - v = v[..., :r, :] - if u_ortho == False: - u.mul_(s.unsqueeze(-2)) - else: - v.mul_(s.unsqueeze(-1)) - if out_u is not None: - out_u.data.copy_(u) - if out_v is not None: - out_v.data.copy_(v) - return u, v - - -def get_conv2d_flops( - input_shape: _size, - conv_filter: _size, - stride: _pair = (1, 1), - padding: _pair = (1, 1), -) -> float: - # input_shape = (4, 3,300,300) # Format:(batch, channels, rows,cols) - # conv_filter = (64,3,3,3) # Format: (num_filters, channels, rows, cols) - # stride = (1, 1) in (height, width) - # padding = (1, 1) in (height, width) - if type(stride) not in {list, tuple}: - stride = [stride, stride] - if type(padding) not in {list, tuple}: - padding = [padding, padding] - n = conv_filter[1] * conv_filter[2] * conv_filter[3] # vector_length - # general defination for number of flops (n: multiplications and n-1: additions) - flops_per_instance = n + 1 - - num_instances_per_filter = ( - (input_shape[2] - conv_filter[2] + 2 * padding[0]) / stride[0] - ) + 1 # for rows - # multiplying with cols - num_instances_per_filter *= ( - (input_shape[3] - conv_filter[3] + 2 * padding[1]) / stride[1] - ) + 1 - - flops_per_filter = num_instances_per_filter * flops_per_instance - # multiply with number of filters adn batch - total_flops_per_layer = flops_per_filter * conv_filter[0] * input_shape[0] - return total_flops_per_layer - - -class Interp1d(torch.autograd.Function): - @staticmethod - def forward(ctx, x, y, xnew, out=None): - """ - Batched Linear 1D interpolation on the GPU for Pytorch. - This function returns interpolated values of a set of 1-D functions at - the desired query points `xnew`. Any point exceeds the border of [xmin, xmax] - will be filled with 0 and no grad. - This function is working similarly to Matlab™ or scipy functions with - the `linear` interpolation mode on, except that it parallelises over - any number of desired interpolation problems. - The code will run on GPU if all the tensors provided are on a cuda - device. - https://github.com/aliutkus/torchinterp1d - - Parameters - ---------- - x : (N, ) or (D, N) Pytorch Tensor - A 1-D or 2-D tensor of real values. - y : (N,) or (D, N) Pytorch Tensor - A 1-D or 2-D tensor of real values. The length of `y` along its - last dimension must be the same as that of `x` - xnew : (P,) or (D, P) Pytorch Tensor - A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if - _both_ `x` and `y` are 1-D. Otherwise, its length along the first - dimension must be the same as that of whichever `x` and `y` is 2-D. - out : Pytorch Tensor, same shape as `xnew` - Tensor for the output. If None: allocated automatically. - - """ - # making the vectors at least 2D - is_flat = {} - require_grad = {} - v = {} - device = [] - eps = torch.finfo(y.dtype).eps - for name, vec in {"x": x, "y": y, "xnew": xnew}.items(): - assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D." - if len(vec.shape) == 1: - v[name] = vec[None, :] - else: - v[name] = vec - is_flat[name] = v[name].shape[0] == 1 - require_grad[name] = vec.requires_grad - device = list(set(device + [str(vec.device)])) - assert len(device) == 1, "All parameters must be on the same device." - device = device[0] - - # Checking for the dimensions - assert v["x"].shape[1] == v["y"].shape[1] and ( - v["x"].shape[0] == v["y"].shape[0] - or v["x"].shape[0] == 1 - or v["y"].shape[0] == 1 - ), ( - "x and y must have the same number of columns, and either " - "the same number of row or one of them having only one " - "row." - ) - - reshaped_xnew = False - if ( - (v["x"].shape[0] == 1) - and (v["y"].shape[0] == 1) - and (v["xnew"].shape[0] > 1) - ): - # if there is only one row for both x and y, there is no need to - # loop over the rows of xnew because they will all have to face the - # same interpolation problem. We should just stack them together to - # call interp1d and put them back in place afterwards. - original_xnew_shape = v["xnew"].shape - v["xnew"] = v["xnew"].contiguous().view(1, -1) - reshaped_xnew = True - - # identify the dimensions of output and check if the one provided is ok - D = max(v["x"].shape[0], v["xnew"].shape[0]) - shape_ynew = (D, v["xnew"].shape[-1]) - if out is not None: - if out.numel() != shape_ynew[0] * shape_ynew[1]: - # The output provided is of incorrect shape. - # Going for a new one - out = None - else: - ynew = out.reshape(shape_ynew) - if out is None: - ynew = torch.zeros(*shape_ynew, device=device) - - # moving everything to the desired device in case it was not there - # already (not handling the case things do not fit entirely, user will - # do it if required.) - for name in v: - v[name] = v[name].to(device) - - # calling searchsorted on the x values. - ind = ynew.long() - - # expanding xnew to match the number of rows of x in case only one xnew is - # provided - if v["xnew"].shape[0] == 1: - v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1) - - # the squeeze is because torch.searchsorted does accept either a nd with - # matching shapes for x and xnew or a 1d vector for x. Here we would - # have (1,len) for x sometimes - torch.searchsorted( - v["x"].contiguous().squeeze(), v["xnew"].contiguous(), out=ind - ) - - # the `-1` is because searchsorted looks for the index where the values - # must be inserted to preserve order. And we want the index of the - # preceeding value. - ind -= 1 - # we clamp the index, because the number of intervals is x.shape-1, - # and the left neighbour should hence be at most number of intervals - # -1, i.e. number of columns in x -2 - ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1) - - # helper function to select stuff according to the found indices. - def sel(name): - if is_flat[name]: - return v[name].contiguous().view(-1)[ind] - return torch.gather(v[name], 1, ind) - - # activating gradient storing for everything now - enable_grad = False - saved_inputs = [] - for name in ["x", "y", "xnew"]: - if require_grad[name]: - enable_grad = True - saved_inputs += [v[name]] - else: - saved_inputs += [ - None, - ] - # assuming x are sorted in the dimension 1, computing the slopes for - # the segments - is_flat["slopes"] = is_flat["x"] - # now we have found the indices of the neighbors, we start building the - # output. Hence, we start also activating gradient tracking - with torch.enable_grad() if enable_grad else contextlib.suppress(): - v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / ( - eps + (v["x"][:, 1:] - v["x"][:, :-1]) - ) - - # now build the linear interpolation - ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x")) - - mask = (v["xnew"] > v["x"][:, -1:]) | ( - v["xnew"] < v["x"][:, :1] - ) # exceed left/right border - ynew = ynew.masked_fill(mask, 0) - - if reshaped_xnew: - ynew = ynew.view(original_xnew_shape) - - ctx.save_for_backward(ynew, *saved_inputs) - return ynew - - @staticmethod - def backward(ctx, grad_out): - inputs = ctx.saved_tensors[1:] - gradients = torch.autograd.grad( - ctx.saved_tensors[0], - [i for i in inputs if i is not None], - grad_out, - retain_graph=True, - ) - result = [ - None, - ] * 5 - pos = 0 - for index in range(len(inputs)): - if inputs[index] is not None: - result[index] = gradients[pos] - pos += 1 - return (*result,) - - -def interp1d(x: Tensor, y: Tensor, xnew: Tensor, out: Tensor | None = None) -> Tensor: - """numpy.interp for pytorch. Only 1D - - Args: - x (Tensor): input vector x coordinates - y (Tensor): input vector y coordinates - xnew (Tensor): new x coordinates to be interpolated - out (Tensor, optional): output tensor. Defaults to None. - - Returns: - Tensor: interpolated y coordinates - """ - return Interp1d.apply(x, y, xnew, out) diff --git a/src/chop/nn/optical/utils/initializer.py b/src/chop/nn/optical/utils/initializer.py index f0591c33a..19c97b6dd 100644 --- a/src/chop/nn/optical/utils/initializer.py +++ b/src/chop/nn/optical/utils/initializer.py @@ -10,68 +10,15 @@ import torch __all__ = [ - "quant_kaiming_uniform", - "quant_kaiming_uniform_", - "truncated_normal", - "truncated_normal_", + # "quant_kaiming_uniform", + # "quant_kaiming_uniform_", + # "truncated_normal", + # "truncated_normal_", "morr_uniform_", - "morr_uniform", + # "morr_uniform", ] -def quant_kaiming_uniform(w, nbit, beta=1.5): - """https://arxiv.org/pdf/1802.04680.pdf""" - if w.dim() > 2: - receptive_field = w[0, 0, ...].numel() - else: - receptive_field = 1 - fan_in = w.size(1) * receptive_field - sigma = 2 ** (1 - nbit) - L_min = beta * sigma - L = max(np.sqrt(6 / fan_in), L_min) - return w.clone().uniform_(-L, L) - - -def quant_kaiming_uniform_(w, nbit, beta=1.5): - """https://arxiv.org/pdf/1802.04680.pdf""" - if w.dim() > 2: - receptive_field = w[0, 0, ...].numel() - else: - receptive_field = 1 - fan_in = w.size(1) * receptive_field - sigma = 2 ** (1 - nbit) - L = np.sqrt(6 / fan_in) - L_min = beta * sigma - scale = 2 ** round(np.log2(L_min / L)) - scale = max(scale, 1.0) - L = max(L, L_min) - - return torch.nn.init.uniform_(w, -L, L), scale - - -def truncated_normal(tensor, mean=0, std=1, a=-2, b=2): - size = tensor.shape - tmp = tensor.new_empty(size + (4,)).normal_() - a = (a - mean) / std - b = (b - mean) / std - valid = (tmp < b) & (tmp > a) - ind = valid.max(-1, keepdim=True)[1] - output = tmp.gather(-1, ind).squeeze(-1).mul_(std).add_(mean) - return output - - -def truncated_normal_(tensor, mean=0, std=1, a=-2, b=2): - size = tensor.shape - tmp = tensor.new_empty(size + (4,)).normal_() - a = (a - mean) / std - b = (b - mean) / std - valid = (tmp < b) & (tmp > a) - ind = valid.max(-1, keepdim=True)[1] - tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) - tensor.data.mul_(std).add_(mean) - return tensor - - def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): """ description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]. We only consider how n_op influence one MORR's output. How to balance vector length should be considered in learnable balancing factor\\ @@ -84,7 +31,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): """ morr_fwhm = ( -4 - * np.pi**2 + * np.pi ** 2 * MORRConfig.radius * MORRConfig.effective_index * ( @@ -111,43 +58,3 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): return torch.nn.init.uniform_(tensor, 0, L) else: return torch.nn.init.uniform_(tensor, -L / 2, L / 2) - - -def morr_uniform(tensor, MORRConfig, n_op=4, biased=False, gain=1): - """ - description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]\\ - @tensor {torch.Tensor} weight tensor/parameter\\ - @MORRConfig {Config} MORR configuration defined in the onnlib/model/layer/device/mrr\\ - @n_op {int scalar} Number of operands on an MORR\\ - @biased {bool} biased=True, weight in [0, L]; otherwise in [-L/2, L/2].\\ - @gain {float} Gain due to activation. ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2\\ - return {} - """ - morr_fwhm = ( - -4 - * np.pi**2 - * MORRConfig.radius - * MORRConfig.effective_index - * ( - 1 / MORRConfig.resonance_wavelength - - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) - ) - ) - ### first we need to calculate the information gain of an MORR, estimated by linear estimation at 0 and FWHM - # t1 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([0]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) - # t2 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([morr_fwhm]), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) - # g = (t2 - t1) / morr_fwhm - - # var_phi = 1 ## assume the input is normalized to have variance 1 - # var_w = 1/(3/2*g**4*n_op*var_phi) - - # ### calculate range of uniform distribution U(-L,L) - # L = (3 * var_w)**0.5 - # return tensor.clone().uniform_(-L, L) - - ## approximation by assuming 4*std(phi)= 3*FWHM, E[x]=0, D[x]=1, W ~ U[0, L] - L = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain - if biased: - return tensor.clone().uniform_(0, L) - else: - return tensor.clone().uniform_(-L / 2, L / 2) diff --git a/src/chop/nn/optical/utils/mrr.py b/src/chop/nn/optical/utils/mrr.py index 53fd56fd9..826e8ad64 100644 --- a/src/chop/nn/optical/utils/mrr.py +++ b/src/chop/nn/optical/utils/mrr.py @@ -71,42 +71,3 @@ class MORRConfig_10um_MQ: resonance_wavelength = 1538.739 # nm bandwidth = 1.6702 # nm quality_factor = 1213.047 - - -def plot_curve(config): - import matplotlib.pyplot as plt - - lambda0 = config.resonance_wavelength - lambda_vec = np.linspace(1546, lambda0, 9400) - aa = config.attenuation_factor # attenuation a - - t = config.coupling_factor # self-coupling - # r = np.sqrt(1 - t**2) # cross coupling coef - - R = config.radius # radius - neff = config.effective_index # refractive index - phi = -4 * np.pi * np.pi * R * neff / lambda_vec - - phase_shift = np.linspace(phi[0], phi[-1], len(phi)) - phase_shift = phase_shift - np.min(phase_shift) - print(phase_shift) - tr = (t - aa * np.exp(1j * phi)) / (1 - t * aa * np.exp(1j * phi)) - energy = abs(tr) ** 2 - print(energy) - plt.figure() - plt.plot(lambda_vec, energy) - plt.savefig("mrr_tr_wl.png") - plt.figure() - plt.plot(phase_shift, energy) - plt.savefig("mrr_tr_ps.png") - - for i, e in enumerate(energy[:-1]): - if energy[i] >= 0.5 and energy[i + 1] <= 0.5: - print(i, i + 1) - print(energy[i], energy[i + 1]) - print(lambda_vec[i], lambda_vec[i + 1]) - exit(1) - - -if __name__ == "__main__": - plot_curve(MRRConfig_5um_MQ) diff --git a/src/chop/nn/optical/utils/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py index e98dc4fd9..6c397b3f5 100644 --- a/src/chop/nn/optical/utils/mrr_op.py +++ b/src/chop/nn/optical/utils/mrr_op.py @@ -20,107 +20,25 @@ __all__ = [ - "mrr_voltage_to_delta_lambda", - "mrr_tr_to_roundtrip_phase", - "mrr_roundtrip_phase_to_tr", + # "mrr_voltage_to_delta_lambda", + # "mrr_tr_to_roundtrip_phase", + # "mrr_roundtrip_phase_to_tr", "mrr_roundtrip_phase_to_tr_fused", - "mrr_roundtrip_phase_to_tr_grad_fused", + # "mrr_roundtrip_phase_to_tr_grad_fused", "mrr_roundtrip_phase_to_tr_func", - "mrr_roundtrip_phase_to_out_phase", - "mrr_tr_to_out_phase", - "mrr_roundtrip_phase_to_tr_phase", - "mrr_roundtrip_phase_to_tr_phase_fused", - "mrr_modulator", - "mrr_filter", - "morr_filter", - "mrr_fwhm_to_ng", - "mrr_ng_to_fsr", - "mrr_finesse", + # "mrr_roundtrip_phase_to_out_phase", + # "mrr_tr_to_out_phase", + # "mrr_roundtrip_phase_to_tr_phase", + # "mrr_roundtrip_phase_to_tr_phase_fused", + # "mrr_modulator", + # "mrr_filter", + # "morr_filter", + # "mrr_fwhm_to_ng", + # "mrr_ng_to_fsr", + # "mrr_finesse", ] -def mrr_voltage_to_delta_lambda(v, alpha, k, gamma, n_g, lambda_0): - """ - description: micro-ring resonator (MRR) wavelength modulation, \delta\lambda=\delta\n_eff\times\lambda/n_g, \deltan_eff=\gamma k \delta T=\gamma k \alpha v^2\\ - v {torch.Tensor ro np.ndarray} voltage \\ - alpha {scalar} voltage square to temperature change coefficient \\ - k {scalar} parameter \\ - gamma {scalar} power to phase shift coefficient \\ - n_g {scalar} group index, typically from 4 to 4.5\\ - lambda_0 {torch.Tensor or np.ndarray} central wavelength\\ - return delta_lambda {torch.Tensor or np.ndarray} resonance wavelength drift - """ - delta_neff = gamma * k * alpha * v * v - delta_lambda = delta_neff * lambda_0 / n_g - return delta_lambda - - -def mrr_tr_to_roundtrip_phase(t, a, r): - """ - description: field transmission to round trip phase shift - t {torch.Tensor or np.ndarray} field transmission from [0,1] \\ - a {scalar} attenuation coefficient\\ - r {scalar} coupling coefficient\\ - return phi {torch.Tensor or np.ndarray} roune trip phase shift (abs(phase lag))[0, pi], center is 0. phase lag is negative, the sign is moved to the equation - """ - # the curve has multiple valleies, thus given a t, there is infinite number of rt_phi, we only want [-pi, 0], thus the abs(phase lag) is in [0, pi], acos returns [0, pi], which matches our assumption - assert 0 <= a <= 1, logging.error(f"Expect a from [0,1] but got {a}") - assert 0 <= r <= 1, logging.error(f"Expect r from [0,1] but got {r}") - # given a and r, the curve is fixed, the max and min may not be 1 and 0 - cos_phi = ((a * a + r * r - t * (1 + r * r * a * a)) / (2 * (1 - t) * a * r)).clamp( - 0, 1 - ) - - if isinstance(cos_phi, torch.Tensor): - return cos_phi.acos(), cos_phi - elif isinstance(cos_phi, np.ndarray): - return np.arccos(cos_phi), cos_phi - else: - raise NotImplementedError - - -def mrr_roundtrip_phase_to_tr( - rt_phi, a: float = 0.8, r: float = 0.9, poly_coeff=None, intensity: bool = False -): - """ - description: round trip phase shift to field transmission - rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ - a {scalar} attenuation coefficient\\ - r {scalar} self-coupling coefficient\\ - poly_coeff {Callable} polynomial coefficients of intensity tranmission-roundtrip phase curve. Default set to None. None for slow computation\\ - intensity {bool scalar} whether output intensity tranmission or field transmission - return t {torch.Tensor or np.ndarray} mrr through port field/intensity transmission - """ - if poly_coeff is not None: - # fast mode, use polynomial to predict the intensity transmission curve - # if using polynomial, we want fast intensity transmission estimation, instead of field - # if using coherent light, we will use complex output, we won't use polynomial fit - t = polynomial(rt_phi.clamp(0, np.pi), poly_coeff).clamp(1e-8, 1) - if not intensity: - # avoid NAN - t = (t + 1e-12).sqrt() - else: - # use slow but accurate mode from theoretical equation - # create e^(-j phi) first - # with torch.autograd.profiler.profile(use_cuda=True) as prof: - # ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi)) ## this sign is from the negativity of phase lag - # ### Jiaqi: Since PyTorch 1.7 rsub is not supported for autograd of complex, so have to use negate and add - # a_ephi = -a * ephi - # t = torch.view_as_real((r + a_ephi)/(1 + r * a_ephi)) - - # if(intensity): - # t = get_complex_energy(t) - # else: - # t = get_complex_magnitude(t) - # print(prof.key_averages(group_by_stack_n=5).table(sort_by='cuda_time', row_limit=5)) - ra_cosphi_by_n2 = -2 * r * a * rt_phi.cos() - t = (a * a + r * r + ra_cosphi_by_n2) / (1 + r * r * a * a + ra_cosphi_by_n2) - if not intensity: - # as long as a is not equal to r, t cannot be 0. - t = t.sqrt() - return t - - @torch.jit.script def mrr_roundtrip_phase_to_tr_fused( rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False @@ -154,37 +72,13 @@ def mrr_roundtrip_phase_to_tr_fused( return t -@torch.jit.script -def mrr_roundtrip_phase_to_tr_grad_fused( - rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False -): - """ - description: round trip phase shift to the gradient of field transmission - rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\ - a {scalar} attenuation coefficient\\ - r {scalar} self-coupling coefficient\\ - intensity {bool scalar} whether output intensity tranmission or field transmission\\ - return g {torch.Tensor or np.ndarray} the gradient of mrr through port field/intensity transmission - """ - if not intensity: - g = (a * r * (a**2 - 1) * (r**2 - 1) * rt_phi.sin()) / ( - (a**2 + r**2 - 2 * a * r * rt_phi.cos()) ** (1 / 2) - * (a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos()) ** 1.5 - ) - else: - g = ((a**2 - 1) * (r**2 - 1) * 2 * a * r * rt_phi.sin()) / ( - a**2 * r**2 + 1 - 2 * a * r * rt_phi.cos() - ) ** 2 - return g - - def mrr_roundtrip_phase_to_tr_func( a: float = 0.8, r: float = 0.9, intensity: bool = False ): c1 = -2 * a * r c2 = a * a + r * r c3 = 1 + r * r * a * a - a * a - r * r - c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r + c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r class MRRRoundTripPhaseToTrFunction(torch.autograd.Function): @staticmethod @@ -217,208 +111,3 @@ def backward(ctx, grad_output): return grad_input return MRRRoundTripPhaseToTrFunction.apply - - -def mrr_roundtrip_phase_to_out_phase(rt_phi, a, r): - """ - description: from round trip phase to output phase response \\ - rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ - a {scalar} attenuation coefficient\\ - r {scalar} coupling coefficient\\ - return phase {torch.Tensor or np.ndarray} output phase response - """ - if isinstance(rt_phi, torch.Tensor): - arctan = torch.atan2 - sin = torch.sin - cos = torch.cos - elif isinstance(rt_phi, np.ndarray): - arctan = np.arctan2 - sin = np.sin - cos = np.cos - else: - raise NotImplementedError - sin_rt_phi = sin(rt_phi) - cos_rt_phi = cos(rt_phi) - # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi)) - phi = ( - np.pi - - rt_phi - - arctan(r * sin_rt_phi, a - r * cos_rt_phi) - - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi) - ) - return phi - - -def mrr_tr_to_out_phase(t, a, r, onesided=True): - """ - description: field transmission to round trip phase shift - t {torch.Tensor or np.ndarray} field transmission from [0,1] \\ - a {scalar} attenuation coefficient\\ - r {scalar} coupling coefficient\\ - onesided {bool scalar} True if only use half of the curve, output phase range [0, pi] - return phi {torch.Tensor or np.ndarray} roune trip phase shift - """ - rt_phi, cos_rt_phi = mrr_tr_to_roundtrip_phase(t, a, r) - if isinstance(t, torch.Tensor): - arctan = torch.atan2 - sin = torch.sin - elif isinstance(t, np.ndarray): - arctan = np.arctan2 - sin = np.sin - else: - raise NotImplementedError - sin_rt_phi = sin(rt_phi) - # phi = np.pi + rt_phi + arctan(r*sin_rt_phi-2*r*r*a*sin_rt_phi*cos_rt_phi+r*a*a*sin_rt_phi, (a-r*cos_rt_phi)*(1-r*a*cos_rt_phi)) - phi = ( - np.pi - - rt_phi - - arctan(r * sin_rt_phi, a - r * cos_rt_phi) - - arctan(r * a * sin_rt_phi, 1 - r * a * cos_rt_phi) - ) - if onesided: - pass - return phi - - -def mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r): - """ - description: from round trip phase to output transmission with phase response \\ - rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ - a {scalar} attenuation coefficient\\ - r {scalar} coupling coefficient\\ - return output {torch.Tensor or np.ndarray} transmission with phase response - """ - # e^(-j phi) - ephi = torch.view_as_complex(polar_to_complex(mag=None, angle=-rt_phi)) - a_ephi = -a * ephi - output = torch.view_as_real((r + a_ephi) / (1 + r * a_ephi)) - return output - - -@torch.jit.script -def mrr_roundtrip_phase_to_tr_phase_fused(rt_phi, a: float, r: float): - """ - description: from round trip phase to output transmission with phase response \\ - rt_phi {torch.Tensor or np.ndarray} round trip phase shift\\ - a {scalar} attenuation coefficient\\ - r {scalar} coupling coefficient\\ - return output {torch.Tensor or np.ndarray} transmission with phase response - """ - # e^(-j phi) - rt_phi = -rt_phi - rt_phi = torch.complex(rt_phi.cos(), rt_phi.sin()) - rt_phi = -a * rt_phi - output = torch.view_as_real((r + rt_phi) / (1 + r * rt_phi)) - return output - - -def mrr_modulator(t, a=0.9, r=0.8): - """ - @description: all-pass MRR as a modulator. Map from the field intensity of through port transmission to coherent light with phase reponse\\ - @t {torch.Tensor or np.ndarray} field intensity modulation factor\\ - @a {float} attenuation factor from [0,1]. Default: 0.9\\ - @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ - @return: complexed light signal - """ - phase = mrr_tr_to_out_phase(t, a, r) - cos_phase, sin_phase = torch.cos(phase), torch.sin(phase) - output_real = t * cos_phase - output_imag = t * sin_phase - output = torch.stack([output_real, output_imag], dim=-1) - return output - - -def mrr_filter(x, t, a=0.9, r=0.8): - """ - @description: all-pass MRR as a filter. Map from the input complex light signal to output signal with through port transmission\\ - @x {torch.Tensor or np.ndarray} complexed input light signal\\ - @t {torch.Tensor or np.ndarray} field intensity modulation factor\\ - @a {float} attenuation factor from [0,1]. Default: 0.9\\ - @r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ - @return: complexed light signal - """ - phase = mrr_tr_to_out_phase(t, a, r) - cos_phase, sin_phase = torch.cos(phase), torch.sin(phase) - phase_shift = torch.complex(cos_phase, sin_phase) - out = t * complex_mult(x, phase_shift) - return out - - -def morr_filter( - rt_phi, tr_poly_coeff=None, a=0.9, r=0.8, x=None, coherent=False, intensity=False -): - """ - description: from round trip phase shift to output signal \\ - rt_phi {torch.Tensor or np.ndarray, Optional} round trip phase shift. Default set to None \\ - tr_poly_coeff {Callable} polynomial coefficients of tranmission-roundtrip phase curve. Default set to None. None for slow computation\\ - a {float} attenuation factor from [0,1]. Default: 0.9\\ - r {float} transmission/self-coupling factor from [0,1]. Default: 0.8\\ - x {torch.Tensor or np.ndarray, Optional} input complex light signal {None, real tensor or complex tensor}. Default set to None\\ - coherent {bool scalar, Optional} coherent output or not. Default set to False\\ - intensity {bool scalar, Optional} whether use intensity or field transmission. Default set to False\\ - return output {torch.Tensor or np.ndarray} real tensor if incoherent, complex tensor if coherent - """ - if not coherent: - if x is None: - # unit laser input with incoherent light, 1e^j0 - t = mrr_roundtrip_phase_to_tr( - rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity - ) - return t - else: - # incoherent light with non-unit input, input must be real number - t = mrr_roundtrip_phase_to_tr( - rt_phi, a=a, r=r, poly_coeff=tr_poly_coeff, intensity=intensity - ) - return x * t - else: - if x is None: - # coherent light with unit laser, 1e^j0, treat morr as a mrr modulator - phase = polar_to_complex( - mag=None, angle=mrr_roundtrip_phase_to_out_phase(rt_phi, a, r) - ) - return phase - else: - # coherent light with complex input - return complex_mult(mrr_roundtrip_phase_to_tr_phase(rt_phi, a, r), x) - - -def mrr_fwhm_to_ng(a, r, radius, lambda0, fwhm): - """ - description: from full-width half maximum (FWHM) and resonance wavelength to group index n_g (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(7))\\ - a {float} Attention coefficient\\ - r {float} Self-coupling coefficient\\ - radius {float} Radius of the MRR (unit: nm)\\ - lambda0 {float} Resonance wavelength (unit: nm)\\ - fwhm {float} bandwidth or full width half maximum (unit: nm)\\ - return n_g {float} Group index of the MRR - """ - n_g = ( - (1 - r * a) * lambda0**2 / (2 * np.pi * np.pi * radius * (r * a) ** 0.5 * fwhm) - ) - return n_g - - -def mrr_ng_to_fsr(lambda0, n_g, radius): - """ - description: Calculate the free-spectral range (FSR) based on the central resonance wavelength, group index and MRR radius. - (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(9))\\ - lambda0 {float} Resonance wavelength (unit: nm)\\ - n_g {float} Group index\\ - radius {float} Radius of the MRR (unit: nm)\\ - return fsr {float} Free-spectral range - """ - fsr = lambda0**2 / (n_g * 2 * np.pi * radius) - return fsr - - -def mrr_finesse(a, r): - """ - description: Calculate the finesse of the MRR, i.e., finesse=FSR/FWHM=pi*sqrt(ra)/(1-ra) (Bogaerts et al., Silicon microring resonators, Laser and Photonics Review 2011, Eq.(21))\\ - a {float} Attention coefficient\\ - r {float} Self-coupling coefficient\\ - return finesse {float} Finesse of the MRR - """ - ra = r * a - finesse = np.pi * ra**0.5 / (1 - ra) - return finesse diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py index f60e8b240..828da142d 100644 --- a/src/chop/nn/optical/utils/quantize.py +++ b/src/chop/nn/optical/utils/quantize.py @@ -12,33 +12,17 @@ __all__ = [ - "uniform_quantize_cpu", - "pact_quantize", - "PACT_Act", - "uniform_quantize", - "uniform_quantize_new", - "ewgs_quantize", + # "uniform_quantize_cpu", + # "pact_quantize", + # "PACT_Act", + # "uniform_quantize", + # "uniform_quantize_new", + # "ewgs_quantize", "input_quantize_fn", "weight_quantize_fn", ] -class uniform_quantize_cpu(object): - def __init__(self, bits): - super(uniform_quantize_cpu).__init__() - self.bits = bits - - def __call__(self, input): - if self.bits == 32: - out = input - elif self.bits == 1: - out = np.sign(input) - else: - n = float(2**self.bits - 1) - out = np.round(input * n) / n - return out - - def uniform_quantize(k, gradient_clip=False): class qfn(torch.autograd.Function): @staticmethod @@ -48,7 +32,7 @@ def forward(ctx, input): elif k == 1: out = torch.sign(input) else: - n = float(2**k - 1) + n = float(2 ** k - 1) out = torch.round(input * n) / n return out @@ -79,7 +63,7 @@ def forward(ctx, input, scale, zero_point): elif k == 1: out = torch.sign(input) else: - n = float(2**k - 1) + n = float(2 ** k - 1) # out = torch.round(input * n) / n # out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale out = ( @@ -102,38 +86,6 @@ def backward(ctx, grad_output): return qfn.apply -def ewgs_quantize(num_levels, gradient_clip=False, scaling_factor: float = 1e-3): - class EWGS_quantizer(torch.autograd.Function): - """ - Network Quantization with Element-wise Gradient Scaling, CVPR 2021 - https://github.com/cvlab-yonsei/EWGS/blob/main/CIFAR10/custom_modules.py - x_in: continuous inputs within the range of [0,1] - num_levels: number of discrete levels - scaling_factor: backward scaling factor, typically fixed to 1e-3 - x_out: discretized version of x_in within the range of [0,1] - """ - - @staticmethod - def forward(ctx, input): - out = input.mul(num_levels - 1).round_().mul_(1 / (num_levels - 1)) - - ctx._scaling_factor = scaling_factor - ctx.save_for_backward(input - out) - return out - - @staticmethod - def backward(ctx, grad_output): - diff = ctx.saved_tensors[0] - delta = ctx._scaling_factor - scale = diff.mul_(grad_output.sign()).mul_(delta).add_(1) - grad_input = grad_output * scale - if gradient_clip: - grad_input.clamp_(-1, 1) - return grad_input - - return EWGS_quantizer.apply - - class input_quantize_fn(torch.nn.Module): def __init__( self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0 @@ -181,7 +133,7 @@ def __init__( qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=0, - quant_max=2**self.in_bit - 1, + quant_max=2 ** self.in_bit - 1, ).to(self.device) else: self.obs = None @@ -428,201 +380,3 @@ def forward(self, x): assert NotImplementedError return weight_q - - -# PACT activation: https://arxiv.org/pdf/1805.06085.pdf -class PACT_QuantFunc(torch.autograd.Function): - r"""PACT (PArametrized Clipping acTivation) quantization function for activations. - Implements a :py:class:`torch.autograd.Function` for quantizing activations in :math:`Q` bits using the PACT strategy. - In forward propagation, the function is defined as - - .. math:: - \mathbf{y} = f(\mathbf{x}) = 1/\varepsilon \cdot \left\lfloor\mathrm{clip}_{ [0,\alpha) } (\mathbf{x})\right\rfloor \cdot \varepsilon - - where :math:`\varepsilon` is the quantization precision: - - .. math:: - \varepsilon = \alpha / (2^Q - 1) - - In backward propagation, using the Straight-Through Estimator, the gradient of the function is defined as - - .. math:: - \mathbf{\nabla}_\mathbf{x} \mathcal{L} &\doteq \mathbf{\nabla}_\mathbf{y} \mathcal{L} - - It can be applied by using its static `.apply` method: - - :param input: the tensor containing :math:`x`, the activations to be quantized. - :type input: `torch.Tensor` - :param eps: the precomputed value of :math:`\varepsilon`. - :type eps: `torch.Tensor` or float - :param alpha: the value of :math:`\alpha`. - :type alpha: `torch.Tensor` or float - :param delta: constant to sum to `eps` for numerical stability (default unused, 0 ). - :type delta: `torch.Tensor` or float - - :return: The quantized input activations tensor. - :rtype: `torch.Tensor` - """ - - @staticmethod - def forward(ctx, input, eps, alpha): - where_input_clipped = (input < 0) | (input >= alpha) - where_input_ltalpha = input < alpha - ctx.save_for_backward(where_input_clipped, where_input_ltalpha) - return ((input / (eps)).floor() * eps).clamp(0.0, alpha.data[0] - eps.data[0]) - - @staticmethod - def backward(ctx, grad_output): - # see Hubara et al., Section 2.3 - where_input_clipped, where_input_ltalpha = ctx.saved_tensors - # zero = torch.zeros(1, device=where_input_nonclipped.device) - grad_input = grad_output.masked_fill(where_input_clipped, 0) - # grad_input = torch.where(where_input_nonclipped, grad_output, zero) - grad_alpha = grad_output.masked_fill(where_input_ltalpha, 0).sum().expand(1) - # grad_alpha = torch.where(where_input_gtalpha, grad_output, zero).sum().expand(1) - return grad_input, None, grad_alpha - - -pact_quantize = PACT_QuantFunc.apply - - -class PACT_Act(torch.nn.Module): - r"""PACT (PArametrized Clipping acTivation) activation. - Implements a :py:class:`torch.nn.Module` to implement PACT-style activations. It is meant to replace :py:class:`torch.nn.ReLU`, :py:class:`torch.nn.ReLU6` and - similar activations in a PACT-quantized network. - This layer can also operate in a special mode, defined by the `statistics_only` member, in which the layer runs in - forward-prop without quantization, collecting statistics on the activations that can then be - used to reset the value of :math:`\alpha`. - In this mode, the layer collects: - - tensor-wise maximum value ever seen - - running average with momentum 0.9 - - running variance with momentum 0.9 - """ - - def __init__( - self, - precision=None, - alpha=1.0, - backprop_alpha=True, - statistics_only=False, - leaky=None, - device=torch.device("cuda"), - ): - r"""Constructor. Initializes a :py:class:`torch.nn.Parameter` for :math:`\alpha` and sets - up the initial value of the `statistics_only` member. - :param precision: instance defining the current quantization level (default `None`). - :type precision: :py:class:`nemo.precision.Precision` - :param alpha: the value of :math:`\alpha`. - :type alpha: `torch.Tensor` or float - :param backprop_alpha: default `True`; if `False`, do not update the value of `\alpha` with backpropagation. - :type backprop_alpha: bool - :param statistics_only: initialization value of `statistics_only` member. - :type statistics_only: bool - """ - - super(PACT_Act, self).__init__() - self.precision = precision - self.device = device - self.alpha = torch.nn.Parameter( - torch.Tensor((alpha,)).to(device), requires_grad=backprop_alpha - ) - self.alpha_p = alpha - self.statistics_only = statistics_only - self.deployment = False - self.eps_in = None - self.leaky = leaky - # self.requantization_factor = requantization_factor - - # these are only used to gather statistics - self.max = torch.nn.Parameter( - torch.zeros_like(self.alpha.data).to(device), requires_grad=False - ) - self.min = torch.nn.Parameter( - torch.zeros_like(self.alpha.data).to(device), requires_grad=False - ) - self.running_mean = torch.nn.Parameter( - torch.zeros_like(self.alpha.data).to(device), requires_grad=False - ) - self.running_var = torch.nn.Parameter( - torch.ones_like(self.alpha.data).to(device), requires_grad=False - ) - - self.precise = False - - def set_static_precision(self, limit_at_32_bits=True, **kwargs): - r"""Sets static parameters used only for deployment.""" - # item() --> conversion to float - # apparently causes a slight, but not invisibile, numerical divergence - # between FQ and QD stages - self.eps_static = self.alpha.clone().detach() / (2.0 ** (self.precision) - 1) - self.alpha_static = self.alpha.clone().detach() - # D is selected as a power-of-two - D = 2.0 ** torch.ceil( - torch.log2(self.requantization_factor * self.eps_static / self.eps_in) - ) - if not limit_at_32_bits: - self.D = D - else: - self.D = min(D, 2.0 ** (32 - 1 - (self.precision))) - - def get_output_eps(self, eps_in): - r"""Get the output quantum (:math:`\varepsilon`) given the input one. - :param eps_in: input quantum :math:`\varepsilon_{in}`. - :type eps_in: :py:class:`torch.Tensor` - :return: output quantum :math:`\varepsilon_{out}`. - :rtype: :py:class:`torch.Tensor` - """ - - return self.alpha / (2.0 ** (self.precision) - 1) - - def reset_alpha(self, use_max=True, nb_std=5.0): - r"""Reset the value of :math:`\alpha`. If `use_max` is `True`, then the highest tensor-wise value collected - in the statistics collection phase is used. If `False`, the collected standard deviation multiplied by - `nb_std` is used as a parameter - :param use_max: if True, use the tensor-wise maximum value collected in the statistics run as new :math:`\alpha` (default True). - :type use_max: bool - :param nb_std: number of standard deviations to be used to initialize :math:`\alpha` if `use_max` is False. - :type nb_std: float - """ - - if use_max: - self.alpha.data[0] = self.max.item() - else: - self.alpha.data[0] = nb_std * torch.sqrt(self.running_var).item() - - def get_statistics(self): - r"""Returns the statistics collected up to now. - - :return: The collected statistics (maximum, running average, running variance). - :rtype: tuple of floats - """ - return self.max.item(), self.running_mean.item(), self.running_var.item() - - def forward(self, x): - r"""Forward-prop function for PACT-quantized activations. - - See :py:class:`nemo.quant.pact_quant.PACT_QuantFunc` for details on the normal operation performed by this layer. - In statistics mode, it uses a normal ReLU and collects statistics in the background. - :param x: input activations tensor. - :type x: :py:class:`torch.Tensor` - - :return: output activations tensor. - :rtype: :py:class:`torch.Tensor` - """ - - if self.statistics_only: - if self.leaky is None: - x = torch.nn.functional.relu(x) - else: - x = torch.nn.functional.leaky_relu(x, self.leaky) - with torch.no_grad(): - self.max[:] = max(self.max.item(), x.max()) - self.min[:] = min(self.min.item(), x.min()) - self.running_mean[:] = 0.9 * self.running_mean.item() + 0.1 * x.mean() - self.running_var[:] = ( - 0.9 * self.running_var.item() + 0.1 * x.std() * x.std() - ) - return x - else: - eps = self.alpha / (2.0 ** (self.precision) - 1) - return pact_quantize(x, eps, self.alpha + eps) diff --git a/src/chop/nn/quantized/functional/gelu.py b/src/chop/nn/quantized/functional/gelu.py index cee5e3317..225ff70ce 100644 --- a/src/chop/nn/quantized/functional/gelu.py +++ b/src/chop/nn/quantized/functional/gelu.py @@ -107,9 +107,7 @@ def gelu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.gelu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index 5fda700de..f518f4968 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -72,10 +72,7 @@ def linearInteger( def linearMinifloatDenorm( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -122,10 +119,7 @@ def linearMinifloatDenorm( def linearMinifloatIEEE( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -172,10 +166,7 @@ def linearMinifloatIEEE( def linearMinifloatIEEE( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -222,10 +213,7 @@ def linearMinifloatIEEE( def linearLog( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_width, w_exponent_bias = ( config["weight_width"], @@ -240,23 +228,11 @@ def linearLog( config["bias_exponent_bias"], ) - w_quantizer = partial( - log_quantizer, - width=w_width, - exponent_bias=w_exponent_bias, - ) + w_quantizer = partial(log_quantizer, width=w_width, exponent_bias=w_exponent_bias,) - x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, - ) + x_quantizer = partial(log_quantizer, width=x_width, exponent_bias=x_exponent_bias,) - b_quantizer = partial( - log_quantizer, - width=b_width, - exponent_bias=b_exponent_bias, - ) + b_quantizer = partial(log_quantizer, width=b_width, exponent_bias=b_exponent_bias,) x = x_quantizer(x) weight = w_quantizer(weight) @@ -266,10 +242,7 @@ def linearLog( def linearBlockFP( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): # establish quantizers w_width, w_exponent_width, w_exponent_bias, w_block_size = ( @@ -327,10 +300,7 @@ def linearBlockFP( def linearBlockMinifloat( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): # establish quantizers w_width, w_exponent_width, w_exponent_bias_width, w_block_size = ( @@ -388,10 +358,7 @@ def linearBlockMinifloat( def linearBlockLog( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): # establish quantizers w_width, w_exponent_bias_width, w_block_size = ( @@ -443,10 +410,7 @@ def linearBlockLog( def linearBinary( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_stochastic = config["weight_stochastic"] w_bipolar = config["weight_bipolar"] @@ -462,10 +426,7 @@ def linearBinary( def linearBinaryScaling( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): """ Binary scaling variant of the linear transformation layer. @@ -513,10 +474,7 @@ def linearBinaryScaling( def linearTernary( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): w_scaling_factor = config["weight_scaling_factor"] w_mean = get_stats(config, "weight_mean") @@ -540,28 +498,19 @@ def linearTernary( def linearBinaryResidualSign( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): raise NotImplementedError def linearLUT( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): raise NotImplementedError def linearLogicNets( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, + x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, ): raise NotImplementedError diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py index d06eb1ece..f6487dc8f 100644 --- a/src/chop/nn/quantized/functional/matmul.py +++ b/src/chop/nn/quantized/functional/matmul.py @@ -176,14 +176,10 @@ def generic_matmul_log(x, y, config, style="matmul"): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) y_quantizer = partial( - log_quantizer, - width=y_width, - exponent_bias=y_exponent_bias, + log_quantizer, width=y_width, exponent_bias=y_exponent_bias, ) x = x_quantizer(x) y = y_quantizer(y) diff --git a/src/chop/nn/quantized/functional/relu.py b/src/chop/nn/quantized/functional/relu.py index 57daed04a..cb57d078e 100644 --- a/src/chop/nn/quantized/functional/relu.py +++ b/src/chop/nn/quantized/functional/relu.py @@ -107,9 +107,7 @@ def relu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.relu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/selu.py b/src/chop/nn/quantized/functional/selu.py index 12956c392..b1edebdbf 100644 --- a/src/chop/nn/quantized/functional/selu.py +++ b/src/chop/nn/quantized/functional/selu.py @@ -107,9 +107,7 @@ def selu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.selu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/softplus.py b/src/chop/nn/quantized/functional/softplus.py index a9bd7dafc..c873e7104 100644 --- a/src/chop/nn/quantized/functional/softplus.py +++ b/src/chop/nn/quantized/functional/softplus.py @@ -107,9 +107,7 @@ def softplus_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.softplus(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/softsign.py b/src/chop/nn/quantized/functional/softsign.py index 3eaab47fa..c60f5f757 100644 --- a/src/chop/nn/quantized/functional/softsign.py +++ b/src/chop/nn/quantized/functional/softsign.py @@ -107,9 +107,7 @@ def softsign_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.softsign(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/tanh.py b/src/chop/nn/quantized/functional/tanh.py index 7d3c67c31..8b1009ac0 100644 --- a/src/chop/nn/quantized/functional/tanh.py +++ b/src/chop/nn/quantized/functional/tanh.py @@ -107,9 +107,7 @@ def tanh_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) return F.tanh(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index 4219d6da9..6759fc67f 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -65,9 +65,7 @@ BatchNorm2dInteger, BatchNorm2dBinary, ) -from .layer_norm import ( - LayerNormInteger, -) +from .layer_norm import LayerNormInteger from .group_norm import GroupNormInteger from .instance_norm2d import InstanceNorm2dInteger @@ -144,12 +142,8 @@ SoftplusBinary, SoftplusTernary, ) -from .batch_norm1d import ( - BatchNorm1dInteger, -) -from .gqa import ( - GroupedQueryAttentionInteger, -) +from .batch_norm1d import BatchNorm1dInteger +from .gqa import GroupedQueryAttentionInteger quantized_module_map = { "conv1d_block_minifloat": Conv1dBlockMinifloat, diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py index 45819db75..315a74ec3 100644 --- a/src/chop/nn/quantized/modules/attention.py +++ b/src/chop/nn/quantized/modules/attention.py @@ -6,9 +6,7 @@ from transformers.models.bert.modeling_bert import BertSelfAttention -from chop.nn.quantized.modules.linear import ( - LinearInteger, -) +from chop.nn.quantized.modules.linear import LinearInteger from chop.nn.quantized.functional import fixed_softermax from chop.nn.quantized.functional import matmul_integer diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py index 8f9ea5969..8176f4d52 100644 --- a/src/chop/nn/quantized/modules/attention_head.py +++ b/src/chop/nn/quantized/modules/attention_head.py @@ -6,9 +6,7 @@ from typing import Optional, Tuple from functools import partial -from chop.nn.quantized.functional.matmul import ( - generic_matmul_integer, -) +from chop.nn.quantized.functional.matmul import generic_matmul_integer from chop.nn.quantizers.integer import integer_quantizer @@ -65,10 +63,7 @@ class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase): def __init__(self, config, q_config: dict = None) -> None: super().__init__(config) - self.query_quantizer = partial( - integer_quantizer, - **q_config, - ) + self.query_quantizer = partial(integer_quantizer, **q_config,) self.key_quantizer = partial(integer_quantizer, **q_config) self.value_quantizer = partial(integer_quantizer, **q_config) diff --git a/src/chop/nn/quantized/modules/batch_norm1d.py b/src/chop/nn/quantized/modules/batch_norm1d.py index b84c0d131..bafb96f9b 100644 --- a/src/chop/nn/quantized/modules/batch_norm1d.py +++ b/src/chop/nn/quantized/modules/batch_norm1d.py @@ -4,9 +4,7 @@ from torch import Tensor from torch.nn import functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _BatchNorm1dBase(torch.nn.BatchNorm1d): diff --git a/src/chop/nn/quantized/modules/conv1d.py b/src/chop/nn/quantized/modules/conv1d.py index 67654d917..03662097a 100644 --- a/src/chop/nn/quantized/modules/conv1d.py +++ b/src/chop/nn/quantized/modules/conv1d.py @@ -274,21 +274,15 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, - width=w_width, - exponent_bias=w_exponent_bias, + log_quantizer, width=w_width, exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, - width=b_width, - exponent_bias=b_exponent_bias, + log_quantizer, width=b_width, exponent_bias=b_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/conv2d.py b/src/chop/nn/quantized/modules/conv2d.py index cc8cd982a..3d297842e 100644 --- a/src/chop/nn/quantized/modules/conv2d.py +++ b/src/chop/nn/quantized/modules/conv2d.py @@ -363,21 +363,15 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, - width=w_width, - exponent_bias=w_exponent_bias, + log_quantizer, width=w_width, exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, - width=b_width, - exponent_bias=b_exponent_bias, + log_quantizer, width=b_width, exponent_bias=b_exponent_bias, ) @@ -430,21 +424,15 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, - width=w_width, - exponent_bias=w_exponent_bias, + log_quantizer, width=w_width, exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, - width=b_width, - exponent_bias=b_exponent_bias, + log_quantizer, width=b_width, exponent_bias=b_exponent_bias, ) @@ -1140,10 +1128,7 @@ def __init__( ) self.unfold = torch.nn.Unfold( - kernel_size=kernel_size, - dilation=dilation, - padding=padding, - stride=stride, + kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride, ) self.fold = torch.nn.Fold( @@ -1243,11 +1228,7 @@ def forward( expanded_input, targets, initalize ).squeeze() # [10, 589824] output = output.view( - batch_size, - self.out_channels, - self._out_dim(0), - self._out_dim(1), - -1, + batch_size, self.out_channels, self._out_dim(0), self._out_dim(1), -1, ).sum( -1 ) # [10, 256, 1, 1, 2304] -> [10, 256, 1, 1] @@ -1430,10 +1411,10 @@ def forward(self, x: Tensor) -> Tensor: return self.decode(self.lut_forward(x)) def encode(self, input: Tensor) -> Tensor: - return input * 2**self.x_frac_width + return input * 2 ** self.x_frac_width def decode(self, input: Tensor) -> Tensor: - return input / 2**self.x_frac_width + return input / 2 ** self.x_frac_width def math_forward(self, input: Tensor) -> Tensor: return self.y_quantizer( diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py index 4e579efd1..ace6f5510 100644 --- a/src/chop/nn/quantized/modules/gelu.py +++ b/src/chop/nn/quantized/modules/gelu.py @@ -123,9 +123,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -143,9 +141,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/gqa.py b/src/chop/nn/quantized/modules/gqa.py index 1445b05ec..6cac74dda 100644 --- a/src/chop/nn/quantized/modules/gqa.py +++ b/src/chop/nn/quantized/modules/gqa.py @@ -89,10 +89,7 @@ def __init__( ) self.v_matmul_func = partial( - matmul_integer, - config=config, - out_config=out_config, - floor=floor, + matmul_integer, config=config, out_config=out_config, floor=floor, ) o_projection_q_config = { diff --git a/src/chop/nn/quantized/modules/group_norm.py b/src/chop/nn/quantized/modules/group_norm.py index a90e5b651..25721aae9 100644 --- a/src/chop/nn/quantized/modules/group_norm.py +++ b/src/chop/nn/quantized/modules/group_norm.py @@ -7,9 +7,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer from mase_components.scalar_operators.fixed.test.isqrt_sw import isqrt_sw2 diff --git a/src/chop/nn/quantized/modules/instance_norm2d.py b/src/chop/nn/quantized/modules/instance_norm2d.py index 0a7260443..d3946401f 100644 --- a/src/chop/nn/quantized/modules/instance_norm2d.py +++ b/src/chop/nn/quantized/modules/instance_norm2d.py @@ -4,9 +4,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _InstanceNorm2dBase(nn.InstanceNorm2d): diff --git a/src/chop/nn/quantized/modules/layer_norm.py b/src/chop/nn/quantized/modules/layer_norm.py index 2ca5c6068..0d7e4d413 100644 --- a/src/chop/nn/quantized/modules/layer_norm.py +++ b/src/chop/nn/quantized/modules/layer_norm.py @@ -4,9 +4,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _LayerNormBase(nn.LayerNorm): diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index 5d8d389a5..b49c86c54 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -66,11 +66,7 @@ def __init__( dtype=None, ) -> None: super().__init__( - in_features, - out_features, - bias, - device, - dtype, + in_features, out_features, bias, device, dtype, ) self.bypass = False self.pruning_masks = None @@ -448,9 +444,7 @@ def forward(self, x: Tensor) -> Tensor: if self.binary_training: w = self.w_quantizer(self.weight) return F.linear( - x_expanded, - w * self.gamma.abs() * self.pruning_masks, - self.bias, + x_expanded, w * self.gamma.abs() * self.pruning_masks, self.bias, ) else: self.weigh = self.weight.data.clamp_(-1, 1) @@ -560,9 +554,7 @@ def forward( output = output.view(batch_size, -1) assert output.shape[-1] == self.tables_count output = output.view( - batch_size, - self.out_features, - int(self.tables_count / self.out_features), + batch_size, self.out_features, int(self.tables_count / self.out_features), ) output = output.sum(-1) if self.bias is not None: @@ -773,10 +765,10 @@ def run_layers(self, input: Tensor, layers) -> Tensor: return y def encode(self, input: Tensor) -> Tensor: - return input * 2**self.x_frac_width + return input * 2 ** self.x_frac_width def decode(self, input: Tensor) -> Tensor: - return input / 2**self.x_frac_width + return input / 2 ** self.x_frac_width def forward(self, x: Tensor) -> Tensor: if self.is_lut_inference: diff --git a/src/chop/nn/quantized/modules/relu.py b/src/chop/nn/quantized/modules/relu.py index 2bc527161..a4d1acfbf 100644 --- a/src/chop/nn/quantized/modules/relu.py +++ b/src/chop/nn/quantized/modules/relu.py @@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/rms_norm.py b/src/chop/nn/quantized/modules/rms_norm.py index 91dd9d9d6..a6b893b32 100644 --- a/src/chop/nn/quantized/modules/rms_norm.py +++ b/src/chop/nn/quantized/modules/rms_norm.py @@ -5,9 +5,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer def _rms_norm(x: Tensor, eps, scale: Tensor | None): diff --git a/src/chop/nn/quantized/modules/selu.py b/src/chop/nn/quantized/modules/selu.py index 066ffc0b7..482c7c5d1 100644 --- a/src/chop/nn/quantized/modules/selu.py +++ b/src/chop/nn/quantized/modules/selu.py @@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/silu.py b/src/chop/nn/quantized/modules/silu.py index 07f18ef6e..30dac9e5d 100644 --- a/src/chop/nn/quantized/modules/silu.py +++ b/src/chop/nn/quantized/modules/silu.py @@ -113,9 +113,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -133,9 +131,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/softplus.py b/src/chop/nn/quantized/modules/softplus.py index 4e8465c56..458558b63 100644 --- a/src/chop/nn/quantized/modules/softplus.py +++ b/src/chop/nn/quantized/modules/softplus.py @@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/softsign.py b/src/chop/nn/quantized/modules/softsign.py index 5497426aa..fe3e53b62 100644 --- a/src/chop/nn/quantized/modules/softsign.py +++ b/src/chop/nn/quantized/modules/softsign.py @@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/tanh.py b/src/chop/nn/quantized/modules/tanh.py index fce343489..3378a612f 100644 --- a/src/chop/nn/quantized/modules/tanh.py +++ b/src/chop/nn/quantized/modules/tanh.py @@ -121,9 +121,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) @@ -141,9 +139,7 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, - width=x_width, - exponent_bias=x_exponent_bias, + log_quantizer, width=x_width, exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py index 14cf53839..72478f46d 100644 --- a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py +++ b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py @@ -62,10 +62,7 @@ def update_luts_weights(self) -> torch.Tensor: key = row.detach().cpu().flatten().sign().numpy().tolist() new_weights.append(key) new_weights = torch.tensor( - new_weights, - dtype=torch.float32, - requires_grad=True, - device=self.device, + new_weights, dtype=torch.float32, requires_grad=True, device=self.device, ).view(-1, self.kk) return new_weights diff --git a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py index ff48436e0..45d95805c 100644 --- a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py +++ b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py @@ -38,7 +38,7 @@ def __init__( levels (int): Number of residual level to use. """ self.k = k - self.kk = 2**k + self.kk = 2 ** k self.binarization_level = binarization_level self.input_expanded = input_expanded self.tables_count = tables_count diff --git a/src/chop/nn/quantizers/block_fp.py b/src/chop/nn/quantizers/block_fp.py index bfe8534e3..826133862 100644 --- a/src/chop/nn/quantizers/block_fp.py +++ b/src/chop/nn/quantizers/block_fp.py @@ -47,10 +47,10 @@ def _block_fp_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 - exponent_max = 2**exponent_width - 1 - exponent_bias + exponent_max = 2 ** exponent_width - 1 - exponent_bias exponent_min = -exponent_bias - mantissa_integer_max = 2**mantissa_bits - 1 + mantissa_integer_max = 2 ** mantissa_bits - 1 # sign per_block_sign = torch.sign(blocked_x + 1e-9) # exponent @@ -58,14 +58,14 @@ def _block_fp_quantize( per_block_exponent = torch.ceil(torch.log2(per_block_max)) per_block_exponent = my_clamp(per_block_exponent, exponent_min, exponent_max) # mantissa - per_block_mantissa = per_block_value / 2**per_block_exponent - shift = 2**mantissa_bits + per_block_mantissa = per_block_value / 2 ** per_block_exponent + shift = 2 ** mantissa_bits per_block_mantissa_integer = my_clamp( my_round(per_block_mantissa * shift), 0, mantissa_integer_max ) per_block_mantissa = per_block_mantissa_integer / shift - per_block_msfp = per_block_sign * (2**per_block_exponent) * per_block_mantissa + per_block_msfp = per_block_sign * (2 ** per_block_exponent) * per_block_mantissa msfp_x = unblock( per_block_msfp, x_shape_before_blocking=x_shape_before_blocking, @@ -133,10 +133,5 @@ def block_fp_quantizer( """ return BlockFPQuantize.apply( - x, - width, - exponent_width, - exponent_bias, - block_size, - skip_first_dim, + x, width, exponent_width, exponent_bias, block_size, skip_first_dim, ) diff --git a/src/chop/nn/quantizers/block_log.py b/src/chop/nn/quantizers/block_log.py index 8e65c91ea..1773b42b8 100644 --- a/src/chop/nn/quantizers/block_log.py +++ b/src/chop/nn/quantizers/block_log.py @@ -40,7 +40,7 @@ def _block_log_quantize( per_block_max_exponent = torch.ceil(torch.log2(per_block_max)) per_block_bias = my_clamp( - 2**exponent_bits - 1 - per_block_max_exponent, 0, 2**exponent_bias_width - 1 + 2 ** exponent_bits - 1 - per_block_max_exponent, 0, 2 ** exponent_bias_width - 1 ) per_block_lq_x = _log_quantize(blocked_x, width=width, exponent_bias=per_block_bias) @@ -98,9 +98,5 @@ def block_log_quantizer( - `block_size`: a list of integers where each integer is the block size along the corresponding dim """ return BlockLogQuantize.apply( - x, - width, - exponent_bias_width, - block_size, - skip_first_dim, + x, width, exponent_bias_width, block_size, skip_first_dim, ) diff --git a/src/chop/nn/quantizers/block_minifloat.py b/src/chop/nn/quantizers/block_minifloat.py index 34e00bbcb..ccef6649d 100644 --- a/src/chop/nn/quantizers/block_minifloat.py +++ b/src/chop/nn/quantizers/block_minifloat.py @@ -41,7 +41,7 @@ def _block_minifloat_quantize( per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min() per_block_exponent_bias = my_clamp( - torch.floor(torch.log2(per_block_max)), 0, 2**exponent_bias_width - 1 + torch.floor(torch.log2(per_block_max)), 0, 2 ** exponent_bias_width - 1 ) per_block_bm_x = _minifloat_ieee_quantize( blocked_x, @@ -118,10 +118,5 @@ def block_minifloat_quantizer( """ return BlockMinifloatQuantize.apply( - x, - width, - exponent_width, - exponent_bias_width, - block_size, - skip_first_dim, + x, width, exponent_width, exponent_bias_width, block_size, skip_first_dim, ) diff --git a/src/chop/nn/quantizers/integer.py b/src/chop/nn/quantizers/integer.py index 9f3ffec1e..8c2573f22 100644 --- a/src/chop/nn/quantizers/integer.py +++ b/src/chop/nn/quantizers/integer.py @@ -34,9 +34,9 @@ def _integer_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2**width - 1 + int_max = 2 ** width - 1 # thresh = 2 ** (width - 1) - scale = 2**frac_width + scale = 2 ** frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale) @@ -57,8 +57,8 @@ def _integer_floor_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2**width - 1 - scale = 2**frac_width + int_max = 2 ** width - 1 + scale = 2 ** frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_floor(x.mul(scale)), int_min, int_max).div(scale) diff --git a/src/chop/nn/quantizers/log.py b/src/chop/nn/quantizers/log.py index 98ab45e76..926e33cdb 100644 --- a/src/chop/nn/quantizers/log.py +++ b/src/chop/nn/quantizers/log.py @@ -8,9 +8,7 @@ def _log_quantize( - x: Tensor | ndarray, - width: int, - exponent_bias: int | Tensor | ndarray | None, + x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None, ): """ - Use non-uniform, base-2 logarithmic representation to encode IEEE FP32/64 @@ -32,16 +30,16 @@ def _log_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_bits - 1) - 1 - exponent_max = 2**exponent_bits - 1 - exponent_bias + exponent_max = 2 ** exponent_bits - 1 - exponent_bias exponent_min = -exponent_bias - min_pos = 2**exponent_min + min_pos = 2 ** exponent_min sign = torch.sign(x + min_pos * 0.1) value = torch.abs(x) + min_pos * 0.1 exponent = my_clamp(my_round(torch.log2(value)), exponent_min, exponent_max) - return sign * (2**exponent) + return sign * (2 ** exponent) class LogQuantize(torch.autograd.Function): @@ -58,9 +56,7 @@ def backward(ctx, grad_output): def log_quantizer( - x: Tensor | ndarray, - width: int, - exponent_bias: int | Tensor | ndarray | None, + x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None, ): """ Convert IEEE FP32/64 to base-2 log quantized values diff --git a/src/chop/nn/quantizers/minifloat.py b/src/chop/nn/quantizers/minifloat.py index 2d6f23103..f19097fde 100644 --- a/src/chop/nn/quantizers/minifloat.py +++ b/src/chop/nn/quantizers/minifloat.py @@ -5,10 +5,7 @@ def _minifloat_denorm_quantize( - x: Tensor, - width: int, - exponent_width: int, - exponent_bias: int = None, + x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, ): """ - Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas. @@ -37,10 +34,10 @@ def _minifloat_denorm_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 - exponent_max = 2**exponent_width - 1 - exponent_bias + exponent_max = 2 ** exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # if the mantissa is an integer, the max mantissa value will be (2**mantissa_bits -1) - shifted_mantissa_max = 2**mantissa_bits - 1 + shifted_mantissa_max = 2 ** mantissa_bits - 1 shifted_mantissa_min = 0 sign = torch.sign(x + 1e-9) @@ -52,8 +49,8 @@ def _minifloat_denorm_quantize( # divide value by clipped exponent. this ensures the simulated minifloat value is correct # when x is too large (minifloat will saturate) or too close to 0. - mantissa = value / 2**exponent - shift = 2**mantissa_bits + mantissa = value / 2 ** exponent + shift = 2 ** mantissa_bits shifted_mantissa = my_round(mantissa * shift) # clip the integer mantissa. shifted_mantissa = my_clamp( @@ -71,11 +68,7 @@ def _minifloat_denorm_quantize( class MinifloatDenormQuantize(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - width: int, - exponent_width: int, - exponent_bias: int = None, + ctx, x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, ): return _minifloat_denorm_quantize( x, width=width, exponent_width=exponent_width, exponent_bias=exponent_bias @@ -88,10 +81,7 @@ def backward(ctx, grad_output): def minifloat_denorm_quantizer( - x: Tensor, - width: int, - exponent_width: int, - exponent_bias: int = None, + x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, ): """ - Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas. @@ -148,11 +138,11 @@ def _minifloat_ieee_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 # upper and lower bound of shifted exponent - exponent_max = 2**exponent_width - 1 - exponent_bias + exponent_max = 2 ** exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # upper and lower bound of shifted minifloat mantissa - shift = 2**mantissa_bits - shifted_mantissa_max = 2**mantissa_bits - 1 + shift = 2 ** mantissa_bits + shifted_mantissa_max = 2 ** mantissa_bits - 1 shifted_mantissa_min = 0 sign = torch.sign(x + 1e-9) @@ -162,9 +152,9 @@ def _minifloat_ieee_quantize( exponent = torch.floor(torch.log2(value + 1e-9)) exponent = my_clamp(exponent, exponent_min, exponent_max) - mantissa = value / 2**exponent + mantissa = value / 2 ** exponent - shift = 2**mantissa_bits + shift = 2 ** mantissa_bits # fmt: off # if the clipped exponent is zero, the minifloat is in a subnormal form # this `is_normal` also help the grad keeps 1 if input x is 0, or the zero-initialized value will be trapped in 0 diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py index 0c3e06130..1a61fb70c 100644 --- a/src/chop/nn/quantizers/mxint_hardware.py +++ b/src/chop/nn/quantizers/mxint_hardware.py @@ -19,7 +19,7 @@ def mxint_quant_block( """ exponent_bias = 2 ** (exponent_width - 1) - exponent_max = 2**exponent_width - 1 - exponent_bias + exponent_max = 2 ** exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # exponent @@ -29,9 +29,9 @@ def mxint_quant_block( # mantissa int_min = -(2 ** (width - 1)) int_max = 2 ** (width - 1) - 1 - mantissa = x / 2**exponent + mantissa = x / 2 ** exponent mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - q_x = (2**exponent) * mantissa + q_x = (2 ** exponent) * mantissa return q_x diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py index d5ca3d8cf..ccf57319f 100644 --- a/src/chop/nn/quantizers/quantizers_for_hw.py +++ b/src/chop/nn/quantizers/quantizers_for_hw.py @@ -9,31 +9,31 @@ def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) - scale = 2**frac_width + scale = 2 ** frac_width fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2**width) + fixed_point_value = fixed_point_value % (2 ** width) return fixed_point_value def unsigned_integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): - thresh = 2**width - 1 - scale = 2**frac_width + thresh = 2 ** width - 1 + scale = 2 ** frac_width fixed_point_value = my_clamp(my_floor(x.mul(scale)), 0, thresh) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2**width) + fixed_point_value = fixed_point_value % (2 ** width) return fixed_point_value def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) - scale = 2**frac_width + scale = 2 ** frac_width fixed_point_value = my_clamp(my_floor(x.mul(scale)), -thresh, thresh - 1) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2**width) + fixed_point_value = fixed_point_value % (2 ** width) return fixed_point_value diff --git a/src/chop/nn/quantizers/ternary.py b/src/chop/nn/quantizers/ternary.py index a27951c23..89fcfe75b 100644 --- a/src/chop/nn/quantizers/ternary.py +++ b/src/chop/nn/quantizers/ternary.py @@ -45,8 +45,7 @@ def ternary_quantizer( # ) if scaling_factor: x = ternarised_scaled_op( - x, - threshold, # abs_mean=mean + x, threshold, # abs_mean=mean ) # [mean, 0 ,-mean] # this function determines the mean on the fly, maybe we could make an alternative which uses the metadata? else: x = ternarised_op(x, threshold) # [1, 0 ,-1] diff --git a/src/chop/nn/quantizers/utils.py b/src/chop/nn/quantizers/utils.py index ae881328d..9efc8b512 100644 --- a/src/chop/nn/quantizers/utils.py +++ b/src/chop/nn/quantizers/utils.py @@ -224,16 +224,8 @@ def forward(ctx, input, _threshold): alpha = TernaryScaled.alpha(input, delta) output = torch.zeros_like(input) - pos_one = torch.where( - input > delta, - 1.0, - 0.0, - ) - neg_one = torch.where( - input < -delta, - -1.0, - 0.0, - ) + pos_one = torch.where(input > delta, 1.0, 0.0,) + neg_one = torch.where(input < -delta, -1.0, 0.0,) output = (pos_one + neg_one) * alpha.view(-1, 1, 1, 1).expand( -1, input.size()[1], input.size()[2], input.size()[3] ) @@ -295,16 +287,8 @@ def forward(ctx, input, _threshold): alpha = TernaryScaled.alpha(input, delta) output = torch.zeros_like(input) - pos_one = torch.where( - input > delta, - 1.0, - 0.0, - ) - neg_one = torch.where( - input < -delta, - -1.0, - 0.0, - ) + pos_one = torch.where(input > delta, 1.0, 0.0,) + neg_one = torch.where(input < -delta, -1.0, 0.0,) output = pos_one + neg_one return output @@ -413,8 +397,7 @@ def _block_1d_bias(x: Tensor, block_shape: List[int]): def _unblock_to_1d_bias( - blocked_x: Tensor, - x_shape_before_blocking: List[int], + blocked_x: Tensor, x_shape_before_blocking: List[int], ): """ blocked bias shape: [num_blocks, block_size] -> [output_features] @@ -609,10 +592,7 @@ def unblock( return _unblock_to_2d_activation(blocked_x, x_shape_before_blocking) else: return _unblock_to_2d_weight( - blocked_x, - x_shape_before_blocking, - padded_x_shape, - block_shape, + blocked_x, x_shape_before_blocking, padded_x_shape, block_shape, ) elif len(x_shape_before_blocking) == 3: if skipped_first_dim_when_blocking: diff --git a/src/chop/nn/snn/auto_cuda/generator.py b/src/chop/nn/snn/auto_cuda/generator.py index 639cf0379..f1637ddf9 100644 --- a/src/chop/nn/snn/auto_cuda/generator.py +++ b/src/chop/nn/snn/auto_cuda/generator.py @@ -310,10 +310,7 @@ def gen_forward_codes( params.append(("v_reset", "const float &")) params.extend( - [ - ("neuron_num", "const int &"), - ("numel", "const int &"), - ] + [("neuron_num", "const int &"), ("numel", "const int &"),] ) params_name = [] for item in params: diff --git a/src/chop/nn/snn/modules/neuron/ifnode.py b/src/chop/nn/snn/modules/neuron/ifnode.py index e5e3c373f..c2376d310 100644 --- a/src/chop/nn/snn/modules/neuron/ifnode.py +++ b/src/chop/nn/snn/modules/neuron/ifnode.py @@ -244,10 +244,12 @@ def multi_step_forward(self, x_seq: torch.Tensor): self.v_float_to_tensor(x_seq[0]) if self.v_reset is None: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_with_v_seq( - x_seq, self.v, self.v_threshold - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_with_v_seq( + x_seq, self.v, self.v_threshold ) else: spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset( @@ -255,10 +257,12 @@ def multi_step_forward(self, x_seq: torch.Tensor): ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset ) else: spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset( diff --git a/src/chop/nn/snn/modules/neuron/lifnode.py b/src/chop/nn/snn/modules/neuron/lifnode.py index 90b02289c..0f0522f4d 100644 --- a/src/chop/nn/snn/modules/neuron/lifnode.py +++ b/src/chop/nn/snn/modules/neuron/lifnode.py @@ -417,29 +417,33 @@ def single_step_forward(self, x: torch.Tensor): self.v_float_to_tensor(x) if self.v_reset is None: if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_decay_input( - x, self.v, self.v_threshold, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_soft_reset_decay_input( + x, self.v, self.v_threshold, self.tau ) else: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_no_decay_input( - x, self.v, self.v_threshold, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_soft_reset_no_decay_input( + x, self.v, self.v_threshold, self.tau ) else: if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_hard_reset_decay_input( + x, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_no_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_hard_reset_no_decay_input( + x, self.v, self.v_threshold, self.v_reset, self.tau ) return spike @@ -514,56 +518,68 @@ def multi_step_forward(self, x_seq: torch.Tensor): if self.v_reset is None: if self.decay_input: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_soft_reset_decay_input( + x_seq, self.v, self.v_threshold, self.tau ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input( + x_seq, self.v, self.v_threshold, self.tau ) else: if self.decay_input: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_hard_reset_decay_input( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) return spike_seq diff --git a/src/chop/nn/snn/modules/spiking_self_attention.py b/src/chop/nn/snn/modules/spiking_self_attention.py index 87ee98a30..3fa47d1d8 100644 --- a/src/chop/nn/snn/modules/spiking_self_attention.py +++ b/src/chop/nn/snn/modules/spiking_self_attention.py @@ -145,9 +145,7 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L super().__init__() inner_channels = in_channels * ratio self.up = nn.Sequential( - activation(), - Conv1x1(in_channels, inner_channels), - BN(inner_channels), + activation(), Conv1x1(in_channels, inner_channels), BN(inner_channels), ) self.conv = nn.ModuleList() for _ in range(num_conv): @@ -163,9 +161,7 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L ) ) self.down = nn.Sequential( - activation(), - Conv1x1(inner_channels, in_channels), - BN(in_channels), + activation(), Conv1x1(inner_channels, in_channels), BN(in_channels), ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index bfc69b47e..d1d63d8f5 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -40,8 +40,6 @@ from .module.analysis import calculate_avg_bits_module_analysis_pass from .module.transforms import quantize_module_transform_pass, resharding_transform_pass -from .onnx.analysis import ( - export_fx_graph_analysis_pass, -) +from .onnx.analysis import export_fx_graph_analysis_pass from .graph.analysis.autosharding import autosharding_analysis_pass diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 188038a14..3274c9143 100644 --- a/src/chop/passes/graph/__init__.py +++ b/src/chop/passes/graph/__init__.py @@ -46,9 +46,7 @@ from .transforms.quantize.quant_parsers import parse_node_config -from chop.passes.graph.analysis.runtime.runtime_analysis import ( - runtime_analysis_pass, -) +from chop.passes.graph.analysis.runtime.runtime_analysis import runtime_analysis_pass from .interface import tensorrt_engine_interface_pass @@ -151,6 +149,6 @@ if check_dependencies("tensorrt_fake_quantize_transform_pass"): TRANSFORM_PASSES.append("tensorrt_fake_quantize_transform_pass") - PASSES["tensorrt_fake_quantize_transform_pass"] = ( - tensorrt_fake_quantize_transform_pass - ) + PASSES[ + "tensorrt_fake_quantize_transform_pass" + ] = tensorrt_fake_quantize_transform_pass diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index 66fe9cc28..9fde9d7cb 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -203,10 +203,7 @@ def graph_iterator_for_mase_ops(graph): def graph_iterator_for_metadata( - graph, - dummy_in=None, - add_value=True, - force_device_meta=False, + graph, dummy_in=None, add_value=True, force_device_meta=False, ): """ largely adapted from https://pytorch.org/docs/stable/fx.html @@ -291,12 +288,7 @@ def _add_graph_metadata(graph): def add_common_metadata_analysis_pass( - graph, - pass_args={ - "dummy_in": None, - "add_value": True, - "force_device_meta": False, - }, + graph, pass_args={"dummy_in": None, "add_value": True, "force_device_meta": False,}, ): """add common metadata diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 39c8ac690..9fd2e9466 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -39,8 +39,7 @@ def add_component_source(node): if mase_op == "user_defined_module": for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items(): if isinstance( - deepgetattr(node.meta["mase"].model, node.target), - custom_op, + deepgetattr(node.meta["mase"].model, node.target), custom_op, ): node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" node.meta["mase"]["hardware"]["module"] = op_info["module"] diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 6f5ee0147..ed3249c9c 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -231,10 +231,7 @@ }, # Inserted ops from the replace_method_with_function pass "torch_size": {"input": "data_in", "dim": "config"}, - "torch_contiguous": { - "input": "data_in", - "memory_format": "config", - }, + "torch_contiguous": {"input": "data_in", "memory_format": "config",}, # arbitrary length - support up to 4 "torch_expand": { "input": "data_in", @@ -257,11 +254,7 @@ "shape_2": "config", "shape_3": "config", }, - "torch_split": { - "input": "data_in", - "split_size": "config", - "dim": "config", - }, + "torch_split": {"input": "data_in", "split_size": "config", "dim": "config",}, "torch_permute": { "input": "data_in", "dim_0": "config", @@ -269,11 +262,7 @@ "dim_2": "config", "dim_3": "config", }, - "torch_transpose": { - "input": "data_in", - "dim0": "config", - "dim1": "config", - }, + "torch_transpose": {"input": "data_in", "dim0": "config", "dim1": "config",}, # DTensor ops "dtensor_arange": { "device_mesh": "config", @@ -287,9 +276,7 @@ "requires_grad": "config", }, # tensor constructor - "tensor": { - "data": "data_in", - }, + "tensor": {"data": "data_in",}, # https://pytorch.org/docs/stable/generated/torch.nn.functional.dropout.html "dropout": { "input": "data_in", @@ -353,10 +340,7 @@ "softmax": {"input": "data_in"}, "gelu": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html - "crossentropyloss": { - "input": "data_in", - "target": "data_in", - }, + "crossentropyloss": {"input": "data_in", "target": "data_in",}, # chop.nn.modules.lora.LoRALinear "loralinear": {"input": "data_in"}, "grouped_query_attention": {"input": "data_in"}, @@ -405,24 +389,15 @@ "size_3": "config", }, # Tensor.max(dim=None, keepdim=False) - "max": { - "dim": "config", - "keepdim": "config", - }, + "max": {"dim": "config", "keepdim": "config",}, # https://pytorch.org/docs/stable/generated/torch.Tensor.sum.html - "sum": { - "dim": "config", - "keepdim": "config", - }, + "sum": {"dim": "config", "keepdim": "config",}, # https://pytorch.org/docs/stable/generated/torch.Tensor.round.html "round": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.floor.html "floor": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.clamp.html - "clamp": { - "min": "config", - "max": "config", - }, + "clamp": {"min": "config", "max": "config",}, # https://pytorch.org/docs/stable/generated/torch.Tensor.dim.html "dim": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html#torch.Tensor.permute @@ -451,11 +426,7 @@ # https://pytorch.org/docs/stable/generated/torch.Tensor.type_as.html "type_as": {"tensor": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.index_select.html - "index_select": { - "input": "data_in", - "dim": "config", - "index": "data_in", - }, + "index_select": {"input": "data_in", "dim": "config", "index": "data_in",}, # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html "detach": {"input": "data_in"}, } @@ -508,11 +479,7 @@ def deepgetattr(obj, attr): def _annotate_arg_metadata( - meta: MaseMetadata, - args: list, - kwargs: dict, - func_data: dict, - add_value: bool, + meta: MaseMetadata, args: list, kwargs: dict, func_data: dict, add_value: bool, ): """ Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"] @@ -621,9 +588,7 @@ def _annotate_arg_metadata( def _annotate_result_metadata( - meta: MaseMetadata, - result, - add_value: bool, + meta: MaseMetadata, result, add_value: bool, ) -> MaseMetadata: """ Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata. diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py index b084821ae..a92958cff 100644 --- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py @@ -64,17 +64,13 @@ "relu": [ { "name": "fixed_relu", - "dependence_files": [ - "activation_layers/rtl/fixed_relu.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_relu.sv",], }, ], "hardshrink": [ { "name": "fixed_hardshrink", - "dependence_files": [ - "activation_layers/rtl/fixed_hardshrink.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_hardshrink.sv",], }, ], "silu": [ @@ -107,9 +103,7 @@ "softshrink": [ { "name": "fixed_softshrink", - "dependence_files": [ - "activation_layers/rtl/fixed_softshrink.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_softshrink.sv",], }, ], "logsigmoid": [ @@ -138,17 +132,13 @@ "selu": [ { "name": "fixed_selu", - "dependence_files": [ - "activation_layers/rtl/fixed_selu.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_selu.sv",], }, ], "tanh": [ { "name": "fixed_tanh", - "dependence_files": [ - "activation_layers/rtl/fixed_tanh.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_tanh.sv",], }, ], "gelu": [ @@ -172,17 +162,13 @@ "softplus": [ { "name": "fixed_softplus", - "dependence_files": [ - "activation_layers/rtl/fixed_softplus.sv", - ], + "dependence_files": ["activation_layers/rtl/fixed_softplus.sv",], }, ], "add": [ { "name": "fixed_adder", - "dependence_files": [ - "linear_layers/fixed_operators/rtl/fixed_adder.sv", - ], + "dependence_files": ["linear_layers/fixed_operators/rtl/fixed_adder.sv",], } ], "mul": [ @@ -199,14 +185,7 @@ "dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"], } ], - "getitem": [ - { - "name": "buffer", - "dependence_files": [ - "memory/rtl/buffer.sv", - ], - } - ], + "getitem": [{"name": "buffer", "dependence_files": ["memory/rtl/buffer.sv",],}], "grouped_query_attention": [ { "name": "fixed_gqa_wrapper", diff --git a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py index d7f4fd8ed..1e6ad2a57 100644 --- a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py @@ -243,16 +243,8 @@ def analyze_software_meta_param_patched_func_default(meta): "getitem": analyze_software_meta_param_implicit_func_default, "getattr": analyze_software_meta_param_implicit_func_default, }, - "placeholder": { - "placeholder": analyze_software_meta_param_placeholder, - }, - "get_attr": { - "get_attr": analyze_software_meta_param_get_attr, - }, - "output": { - "output": analyze_software_meta_param_output, - }, - "patched_func": { - "default": analyze_software_meta_param_patched_func_default, - }, + "placeholder": {"placeholder": analyze_software_meta_param_placeholder,}, + "get_attr": {"get_attr": analyze_software_meta_param_get_attr,}, + "output": {"output": analyze_software_meta_param_output,}, + "patched_func": {"default": analyze_software_meta_param_patched_func_default,}, } diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py index d9fea1aa5..30216eea3 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ -74,10 +74,7 @@ def get_resharding_cost( and (dest[0] == SpmdShard.R) ): ag_dim = 1 if src[0] == dest[0] else 0 - return mesh.all_gather_cost( - num_bytes=num_bytes, - mesh_dim=ag_dim, - ) + return mesh.all_gather_cost(num_bytes=num_bytes, mesh_dim=ag_dim,) # All-to-all # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) @@ -85,10 +82,7 @@ def get_resharding_cost( # all to all a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value try: - return mesh.all_to_all_cost( - num_bytes=num_bytes, - mesh_dim=a2a_dim, - ) + return mesh.all_to_all_cost(num_bytes=num_bytes, mesh_dim=a2a_dim,) except: assert False diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index f1ad3a4fe..b49154c20 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -23,10 +23,7 @@ def deepgetattr(obj, attr, default=None): def _import_solution( - mg, - solution: dict, - mesh: MeshModel, - extrapolate_sharding: bool = True, + mg, solution: dict, mesh: MeshModel, extrapolate_sharding: bool = True, ): """Import an autosharding solution into the metadata of the MaseGraph. @@ -71,18 +68,14 @@ def _import_solution( # Annotate the metadata for each argument for arg, arg_spec in solution[node.name].get("args", {}).items(): node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = _DTensorSpec( - mesh=mesh, - placements=arg_spec, + mesh=mesh, placements=arg_spec, ) # Annotate the metadata for each result for result, result_spec in solution[node.name].get("results", {}).items(): - node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = ( - _DTensorSpec( - mesh=mesh, - placements=result_spec, - ) - ) + node.meta["mase"]["common"]["results"][result][ + "dtensor_spec" + ] = _DTensorSpec(mesh=mesh, placements=result_spec,) return mg, {} @@ -114,10 +107,7 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): logger.warning( f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution." ) - spec = _DTensorSpec( - None, - (Replicate(), Replicate()), - ) + spec = _DTensorSpec(None, (Replicate(), Replicate()),) else: spec = arg_info["dtensor_spec"] @@ -132,10 +122,7 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): logger.warning( f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution." ) - spec = _DTensorSpec( - None, - (Replicate(), Replicate()), - ) + spec = _DTensorSpec(None, (Replicate(), Replicate()),) else: spec = result_info["dtensor_spec"] out_dict[node_name]["results"][result] = spec.placements @@ -197,9 +184,7 @@ def _get_sharding_map(mg): if module not in tensor_sharding_map: tensor_sharding_map[module] = { "node": node.name, - "sharding": { - attr: out_specs, - }, + "sharding": {attr: out_specs,}, } else: tensor_sharding_map[module]["sharding"][attr] = out_specs @@ -302,8 +287,11 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): if not pass_args.get(f"skip_forward", False): tensor_sharding_map = _get_sharding_map(mg) - return mg, { - "autosharding_time": autosharding_time, - "tensor_sharding_map": tensor_sharding_map, - **pass_outs, - } + return ( + mg, + { + "autosharding_time": autosharding_time, + "tensor_sharding_map": tensor_sharding_map, + **pass_outs, + }, + ) diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py index 30cd36f7e..ab9134a36 100644 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ b/src/chop/passes/graph/analysis/autosharding/megatron.py @@ -3,9 +3,7 @@ def megatron_autosharding_pass( - mg: MaseGraph, - mesh: MeshModel, - pass_args: dict, + mg: MaseGraph, mesh: MeshModel, pass_args: dict, ): for node in mg.fx_graph.nodes: meta = node.meta["mase"]["common"] diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py index 40b704ca8..d68f61cb5 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py @@ -86,10 +86,7 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": def gen_einsum_strategies( - equation: str, - mesh: tuple, - *, - linearity: bool = False, + equation: str, mesh: tuple, *, linearity: bool = False, ) -> OpStrategy: """ Generate a strategy list for the ops that follow einsum style notation. diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 2635a5587..e58980c60 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -81,11 +81,7 @@ def fully_replicated_strategy(meta, mesh): in_spec = _DTensorSpec( mesh, sharding, - tensor_meta=TensorMeta( - shape=in_shape, - stride=None, - dtype=in_dtype, - ), + tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype,), ) dtype_key = ( diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index 78328a95d..60a3de959 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -24,10 +24,7 @@ from chop.ir.graph import MaseMetadata -def transpose_strategy( - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: parent_node = meta.node.args[0] self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] @@ -59,11 +56,7 @@ def transpose_strategy( return OpStrategy(strategies=transpose_strategies) -def _mm_like_strategy( - mm_equation: str, - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpStrategy: self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()] # generate all possible strategies for mm mm_strategy = gen_einsum_strategies(mm_equation, mesh) @@ -102,9 +95,7 @@ def _mm_like_strategy( def _addmm_like_strategy( - mm_equation: str, - meta: MaseMetadata, - mesh: tuple, + mm_equation: str, meta: MaseMetadata, mesh: tuple, ) -> OpStrategy: self_shape, mat1_shape, mat2_shape = [ @@ -170,37 +161,24 @@ def _addmm_like_strategy( return mm_strategy -def mm_strategy( - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def mm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: return _mm_like_strategy("mk,kn->mn", meta, mesh) -def addmm_strategy( - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def addmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: return _addmm_like_strategy("mk,kn->mn", meta, mesh) -def bmm_strategy( - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def bmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: return _mm_like_strategy("bmk,bkn->bmn", meta, mesh) -def baddmm_strategy( - meta: MaseMetadata, - mesh: tuple, -) -> OpStrategy: +def baddmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh) def scaled_dot_product_flash_attention_strategy( - meta: MaseMetadata, - mesh: tuple, + meta: MaseMetadata, mesh: tuple, ) -> OpStrategy: # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index c3c06d35e..a7d8a739b 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -125,9 +125,7 @@ def common_pointwise_strategy( arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"], ) input_target_placements = map_placements_after_broadcast( - tuple(out_placements), - common_shape, - input_arg_dims_map, + tuple(out_placements), common_shape, input_arg_dims_map, ) input_arg_target_spec = _DTensorSpec( mesh=mesh, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py index 05c76160a..96cbfa445 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py @@ -6,9 +6,7 @@ PlacementStrategy, StrategyType, ) -from torch.distributed.tensor.ops.utils import ( - is_tensor_partial, -) +from torch.distributed.tensor.ops.utils import is_tensor_partial from torch.distributed.tensor.placement_types import ( _DTensorSpec, Partial, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 505a2440b..6cd423580 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -236,9 +236,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: def dim_movedim( - ndim: int, - input: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]], + ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]], ) -> DimMap: input = normalize_dims(input, ndim) destination = normalize_dims(destination, ndim) @@ -622,8 +620,7 @@ def reshape_strategy(meta, mesh): ) output_strategy.strategies.append( PlacementStrategy( - output_specs=output_spec, - input_specs=(input_tgt_spec,), + output_specs=output_spec, input_specs=(input_tgt_spec,), ) ) diff --git a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py index b8ca08310..04dde1876 100644 --- a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py +++ b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py @@ -36,7 +36,7 @@ def calculate_modules(module, in_data, out_data): # Kernel size here can be either a single int for square kernel # or a tuple (see # https://pytorch.org/docs/stable/nn.html#torch.nn.MaxPool2d ) - window_size = module.kernel_size**2 + window_size = module.kernel_size ** 2 else: window_size = module.kernel_size[0] * module.kernel_size[1] diff --git a/src/chop/passes/graph/analysis/plot/plot_graph.py b/src/chop/passes/graph/analysis/plot/plot_graph.py index 8af151b36..adcc93177 100644 --- a/src/chop/passes/graph/analysis/plot/plot_graph.py +++ b/src/chop/passes/graph/analysis/plot/plot_graph.py @@ -4,10 +4,7 @@ def plot_graph_analysis_pass( - graph, - pass_args={ - "file_name": None, - }, + graph, pass_args={"file_name": None,}, ): graph.draw(pass_args["file_name"]) # nx_graph = nx.DiGraph() diff --git a/src/chop/passes/graph/interface/tensorrt/quantize.py b/src/chop/passes/graph/interface/tensorrt/quantize.py index 0de0cda60..0a6773754 100644 --- a/src/chop/passes/graph/interface/tensorrt/quantize.py +++ b/src/chop/passes/graph/interface/tensorrt/quantize.py @@ -21,6 +21,7 @@ def Quantizer(config): "pytorch_quantization is not installed. Cannot use tensorrt quantize pass." ) + else: import tensorrt as trt from pytorch_quantization import quant_modules, calib diff --git a/src/chop/passes/graph/transforms/dse/run_dse.py b/src/chop/passes/graph/transforms/dse/run_dse.py index 28446c4ca..991f9e1bd 100644 --- a/src/chop/passes/graph/transforms/dse/run_dse.py +++ b/src/chop/passes/graph/transforms/dse/run_dse.py @@ -16,7 +16,7 @@ def get_factors(n): set( functools.reduce( list.__add__, - ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), + ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0), ) ) ) diff --git a/src/chop/passes/graph/transforms/lora.py b/src/chop/passes/graph/transforms/lora.py index 9c4559661..87fb2362f 100644 --- a/src/chop/passes/graph/transforms/lora.py +++ b/src/chop/passes/graph/transforms/lora.py @@ -10,8 +10,7 @@ def insert_lora_adapter_transform_pass( - mg: MaseGraph, - pass_args={}, + mg: MaseGraph, pass_args={}, ): rank = pass_args.get("rank", 0) @@ -42,8 +41,7 @@ def insert_lora_adapter_transform_pass( def fuse_lora_weights_transform_pass( - mg: MaseGraph, - pass_args={}, + mg: MaseGraph, pass_args={}, ): for node in mg.nodes: target = ( diff --git a/src/chop/passes/graph/transforms/onnxrt/quantize.py b/src/chop/passes/graph/transforms/onnxrt/quantize.py index cacbb6272..39d568236 100644 --- a/src/chop/passes/graph/transforms/onnxrt/quantize.py +++ b/src/chop/passes/graph/transforms/onnxrt/quantize.py @@ -54,9 +54,7 @@ def quantize_dynamic(self, model_path: PosixPath, quantized_model_path: PosixPat ) quantized_model = quantize_dynamic( - model_path, - quantized_model_path, - weight_type=precision, + model_path, quantized_model_path, weight_type=precision, ) self.logger.info("Quantization complete. Model is now dynamically quantized.") diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py b/src/chop/passes/graph/transforms/pruning/pruning_methods.py index 665abc17e..aca81dc51 100644 --- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py +++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py @@ -128,31 +128,11 @@ def neurons_random_fan_in( weight_criteria_map = { - "local": { - "elementwise": { - "random": random, - "l1-norm": l1, - } - }, - "global": { - "elementwise": { - "random": random, - "l1-norm": global_weight_l1, - } - }, + "local": {"elementwise": {"random": random, "l1-norm": l1,}}, + "global": {"elementwise": {"random": random, "l1-norm": global_weight_l1,}}, } activation_criteria_map = { - "local": { - "elementwise": { - "random": random, - "l1-norm": l1, - } - }, - "global": { - "elementwise": { - "random": random, - "l1-norm": global_activation_l1, - } - }, + "local": {"elementwise": {"random": random, "l1-norm": l1,}}, + "global": {"elementwise": {"random": random, "l1-norm": global_activation_l1,}}, } diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py index dea246fc2..890497a95 100644 --- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py +++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py @@ -27,10 +27,7 @@ "weight_entries": ("weight_width", "weight_frac_width"), "data_in_entries": ("data_in_width", "data_in_frac_width"), "bias_entries": ("bias_width", "bias_frac_width"), - "data_out_entries": ( - "data_out_width", - "data_out_frac_width", - ), + "data_out_entries": ("data_out_width", "data_out_frac_width",), "additional_layers_entries": ("floor"), }, "lutnet": { @@ -65,18 +62,9 @@ "weight_width", "weight_frac_width", ), - "bias_entries": ( - "bias_width", - "bias_frac_width", - ), - "data_in_entries": ( - "data_in_width", - "data_in_frac_width", - ), - "data_out_entries": ( - "data_out_width", - "data_out_frac_width", - ), + "bias_entries": ("bias_width", "bias_frac_width",), + "data_in_entries": ("data_in_width", "data_in_frac_width",), + "data_out_entries": ("data_out_width", "data_out_frac_width",), "additional_layers_entries": { "additional_layers_inputs", "additional_layers_outputs", @@ -84,21 +72,9 @@ }, }, "binary": { - "weight_entries": ( - "weight_width", - "weight_stochastic", - "weight_bipolar", - ), - "data_in_entries": ( - "data_in_width", - "data_in_stochastic", - "data_in_bipolar", - ), - "bias_entries": ( - "bias_width", - "bias_stochastic", - "bias_bipolar", - ), + "weight_entries": ("weight_width", "weight_stochastic", "weight_bipolar",), + "data_in_entries": ("data_in_width", "data_in_stochastic", "data_in_bipolar",), + "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), }, "binary_residual": { "weight_entries": ( @@ -114,11 +90,7 @@ "data_in_residual_sign", "data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet ), - "bias_entries": ( - "bias_width", - "bias_stochastic", - "bias_bipolar", - ), + "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), }, "binary_residual": { "weight_entries": ( @@ -134,11 +106,7 @@ "data_in_residual_sign", "data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet ), - "bias_entries": ( - "bias_width", - "bias_stochastic", - "bias_bipolar", - ), + "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), }, "ternary": { "weight_entries": ( @@ -245,11 +213,7 @@ "data_in_exponent_bias_width", "data_in_block_size", ), - "bias_entries": ( - "bias_width", - "bias_exponent_bias_width", - "bias_block_size", - ), + "bias_entries": ("bias_width", "bias_exponent_bias_width", "bias_block_size",), }, "mxint_hardware": { "weight_entries": ( @@ -262,11 +226,7 @@ "data_in_exponent_width", "data_in_parallelism", ), - "bias_entries": ( - "bias_width", - "bias_exponent_width", - "bias_parallelism", - ), + "bias_entries": ("bias_width", "bias_exponent_width", "bias_parallelism",), }, } @@ -389,22 +349,10 @@ def cp_data_out_entries( ("name", "data_in_entries"), ("weight_entries", "bias_entries", "bypass"), ), - "layer_norm": ( - ("name", "data_in_entries"), - ("bypass",), - ), - "group_norm": ( - ("name", "data_in_entries"), - ("bypass",), - ), - "instance_norm2d": ( - ("name", "data_in_entries"), - ("bypass",), - ), - "rms_norm": ( - ("name", "data_in_entries"), - ("bypass",), - ), + "layer_norm": (("name", "data_in_entries"), ("bypass",),), + "group_norm": (("name", "data_in_entries"), ("bypass",),), + "instance_norm2d": (("name", "data_in_entries"), ("bypass",),), + "rms_norm": (("name", "data_in_entries"), ("bypass",),), "grouped_query_attention": ( ("name", "data_in_entries", "weight_entries"), ("bypass", "bias_entries"), diff --git a/src/chop/passes/graph/transforms/training/modify.py b/src/chop/passes/graph/transforms/training/modify.py index eff5583d0..bf84faeee 100644 --- a/src/chop/passes/graph/transforms/training/modify.py +++ b/src/chop/passes/graph/transforms/training/modify.py @@ -65,10 +65,7 @@ def attach_backward_fn(q_fn: torch.autograd.Function, mase_op: str, q_fn_cfg: di def create_new_module( - mase_op: str, - original_module: nn.Module, - config: dict, - node_meta: dict, + mase_op: str, original_module: nn.Module, config: dict, node_meta: dict, ): original_module_cls = type(original_module) diff --git a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py index 5bd5977b3..772206878 100644 --- a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py +++ b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py @@ -12,12 +12,8 @@ matches_module_pattern, replace_node_module, ) -from chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) -from chop.nn.quantized.modules.conv2d import ( - Conv2DLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets +from chop.nn.quantized.modules.conv2d import Conv2DLogicNets # Housekeeping ------------------------------------------------------------------------- logger = logging.getLogger(__file__) diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py index a6a000c79..e562a5fb1 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py @@ -245,8 +245,8 @@ def emit_parameters_in_dat_internal(node, param_name, file_name): else: base_quantizer = integer_quantizer_for_hw - scale = 2**frac_width - thresh = 2**width + scale = 2 ** frac_width + thresh = 2 ** width for i in range(0, out_depth): line_buff = "" for j in range(0, out_size): @@ -301,8 +301,8 @@ def emit_parameters_in_dat_hls(node, param_name, file_name): "precision" ][1] - scale = 2**frac_width - thresh = 2**width + scale = 2 ** frac_width + thresh = 2 ** width for i in range(0, out_depth): line_buff = "" for j in range(0, out_size): diff --git a/src/chop/passes/graph/transforms/verilog/emit_hls.py b/src/chop/passes/graph/transforms/verilog/emit_hls.py index 3efe2c037..ce348c43d 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_hls.py +++ b/src/chop/passes/graph/transforms/verilog/emit_hls.py @@ -121,12 +121,7 @@ def _call_hls_flow(node, node_dir): # Call Vitis HLS for synthesis vitis_hls = os.path.abspath( os.path.join( - os.path.dirname(__file__), - "..", - "..", - "..", - "scripts", - "run-vitis-hls.sh", + os.path.dirname(__file__), "..", "..", "..", "scripts", "run-vitis-hls.sh", ) ) assert os.path.isfile( diff --git a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py index 3d8c0a16e..f0a52f75a 100644 --- a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py +++ b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py @@ -4,9 +4,7 @@ import torch.nn as nn from chop.passes.graph.utils import init_project -from chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets from .util import ( generate_lut_verilog, diff --git a/src/chop/passes/module/analysis/report.py b/src/chop/passes/module/analysis/report.py index 514f75ce8..c384ceafd 100644 --- a/src/chop/passes/module/analysis/report.py +++ b/src/chop/passes/module/analysis/report.py @@ -26,8 +26,7 @@ def get_submodule_summary(name: str, module: nn.Module, level: int = 0): def report_trainable_parameters_analysis_pass( - module: torch.nn.Module, - pass_args: dict = {}, + module: torch.nn.Module, pass_args: dict = {}, ): submodule_summary, total_params = get_submodule_summary("", module) table = [(name, params) for _, name, params in submodule_summary] diff --git a/src/chop/passes/utils.py b/src/chop/passes/utils.py index 912250077..8d7ea71eb 100644 --- a/src/chop/passes/utils.py +++ b/src/chop/passes/utils.py @@ -23,9 +23,7 @@ def _nightly_torch_installed(): return False -def find_missing_dependencies( - pass_name: str, -): +def find_missing_dependencies(pass_name: str,): dependencies = PassFactory._dependencies_dict.get(pass_name, None) if dependencies is None: @@ -40,9 +38,7 @@ def find_missing_dependencies( def register_mase_pass( - name: str, - dependencies: list = [], - requires_nightly_torch: bool = False, + name: str, dependencies: list = [], requires_nightly_torch: bool = False, ): """This decorator registers a mase pass as PassFactory class attributes which can be used globally.""" diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index c9b784795..1bd382983 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -12,11 +12,7 @@ class AutoPipeline: The output of each pass is stored in a dictionary and can be accessed by the next pass. """ - def __init__( - self, - pass_groups=None, - run_training: bool = False, - ) -> None: + def __init__(self, pass_groups=None, run_training: bool = False,) -> None: """Initializes the AutoPipeline. Args: @@ -26,11 +22,7 @@ def __init__( self.pass_outputs = [{}] * len(pass_groups) def _run_pass_group( - self, - mg: MaseGraph, - pass_group: list, - pass_args: dict, - skip_passes: list = [], + self, mg: MaseGraph, pass_group: list, pass_args: dict, skip_passes: list = [], ): pass_outputs = {} @@ -56,10 +48,7 @@ def _run_pass_group( return mg, pass_outputs def __call__( - self, - mg: MaseGraph, - pass_args: dict, - skip_passes: list = [], + self, mg: MaseGraph, pass_args: dict, skip_passes: list = [], ): for idx, pass_group in enumerate(self.pass_groups): @@ -70,10 +59,7 @@ def __call__( ) mg, pass_outputs = self._run_pass_group( - mg, - pass_group, - pass_args, - skip_passes, + mg, pass_group, pass_args, skip_passes, ) self.pass_outputs[idx] = pass_outputs diff --git a/src/chop/tools/check_dependency.py b/src/chop/tools/check_dependency.py index dbf13cd12..c71ca088e 100644 --- a/src/chop/tools/check_dependency.py +++ b/src/chop/tools/check_dependency.py @@ -23,9 +23,7 @@ def check_deps_tensorRT_pass(silent: bool = True): return all(availabilities) -def find_missing_dependencies( - pass_name: str, -): +def find_missing_dependencies(pass_name: str,): dependencies = PassFactory._dependencies_dict.get(pass_name, None) if dependencies is None: @@ -40,8 +38,7 @@ def find_missing_dependencies( def check_dependencies( - pass_name: str, - silent: bool = True, + pass_name: str, silent: bool = True, ): unavailable_deps = find_missing_dependencies(pass_name) diff --git a/src/chop/tools/huggingface.py b/src/chop/tools/huggingface.py index 7cf675460..efa0fb1f6 100644 --- a/src/chop/tools/huggingface.py +++ b/src/chop/tools/huggingface.py @@ -39,10 +39,7 @@ def get_hf_dummy_in(model): tokenizer = AutoTokenizer.from_pretrained(checkpoint) dummy_input = tokenizer( - [ - "AI may take over the world one day", - "This is why you should learn ADLS", - ], + ["AI may take over the world one day", "This is why you should learn ADLS",], return_tensors="pt", ) @@ -53,9 +50,7 @@ def get_hf_dummy_in(model): def get_tokenized_dataset( - dataset: str, - checkpoint: str, - return_tokenizer: bool = False, + dataset: str, checkpoint: str, return_tokenizer: bool = False, ): """ Tokenizes a dataset using the AutoTokenizer from Huggingface. @@ -81,10 +76,7 @@ def get_tokenized_dataset( tokenizer = AutoTokenizer.from_pretrained(checkpoint) def tokenize_function(example): - return tokenizer( - example["text"], - truncation=True, - ) + return tokenizer(example["text"], truncation=True,) # Tokenize tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) diff --git a/src/chop/tools/plt_wrapper/nlp/classification.py b/src/chop/tools/plt_wrapper/nlp/classification.py index f96227cba..a69b2ac0f 100644 --- a/src/chop/tools/plt_wrapper/nlp/classification.py +++ b/src/chop/tools/plt_wrapper/nlp/classification.py @@ -41,9 +41,7 @@ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None): ) else: outputs = self.model( - input_ids, - attention_mask=attention_mask, - labels=labels, + input_ids, attention_mask=attention_mask, labels=labels, ) return outputs diff --git a/src/chop/tools/plt_wrapper/nlp/lm.py b/src/chop/tools/plt_wrapper/nlp/lm.py index 60739da47..a2b669e87 100644 --- a/src/chop/tools/plt_wrapper/nlp/lm.py +++ b/src/chop/tools/plt_wrapper/nlp/lm.py @@ -46,9 +46,7 @@ def training_step(self, batch, batch_idx): self.log("train_loss_step", loss, prog_bar=True) self.log( - "train_perplexity_step", - perplexity, - prog_bar=True, + "train_perplexity_step", perplexity, prog_bar=True, ) return loss diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py index 2ae59855e..d070eff92 100644 --- a/src/chop/tools/utils.py +++ b/src/chop/tools/utils.py @@ -92,7 +92,7 @@ def get_factors(n): set( functools.reduce( list.__add__, - ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), + ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0), ) ) ) @@ -184,9 +184,7 @@ def init_Conv2dLUT_weight( # Initialize the weight based on the trained binaried network # weight shape of the lagrange trainer [tables_count, self.kk] input_mask = new_module.input_mask.reshape( - -1, - in_channels * kernel_size[0] * kernel_size[1] * k, - 3, + -1, in_channels * kernel_size[0] * kernel_size[1] * k, 3, ) # [oc, k * kh * kw * ic ,3[ic,kh,kw]] expanded_original_weight = original_weight[ np.arange(out_channels)[:, np.newaxis], diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py index 10ca26820..711d2760f 100644 --- a/src/mase_cocotb/interfaces/streaming.py +++ b/src/mase_cocotb/interfaces/streaming.py @@ -17,14 +17,7 @@ def _sign_extend(value: int, bits: int): class StreamDriver(Driver): - def __init__( - self, - clk, - data, - valid, - ready, - record_num_beats=False, - ) -> None: + def __init__(self, clk, data, valid, ready, record_num_beats=False,) -> None: super().__init__() self.clk = clk self.data = data @@ -74,14 +67,7 @@ async def _driver_send(self, transaction) -> None: class StreamMonitor(Monitor): def __init__( - self, - clk, - data, - valid, - ready, - check=True, - name=None, - unsigned=False, + self, clk, data, valid, ready, check=True, name=None, unsigned=False, ): super().__init__(clk, check=check, name=name) self.clk = clk @@ -165,9 +151,9 @@ def __init__(self, clk, data, valid, ready, data_width, frac_width, check=True): def _check(self, got, exp): if self.check: - float_got = [x * 2**-self.frac_width for x in got] - float_exp = [x * 2**-self.frac_width for x in exp] - if not np.isclose(float_got, float_exp, atol=2**-self.frac_width).all(): + float_got = [x * 2 ** -self.frac_width for x in got] + float_exp = [x * 2 ** -self.frac_width for x in exp] + if not np.isclose(float_got, float_exp, atol=2 ** -self.frac_width).all(): # raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp)) raise TestFailure( f"\nGot int \n{got}, \nExpected int \n{exp} \nGot float \n{float_got}, \nExpected float \n{float_exp}" diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py index 21490555a..a74bff0d4 100644 --- a/src/mase_cocotb/runner.py +++ b/src/mase_cocotb/runner.py @@ -110,10 +110,7 @@ def _single_test( verilog_sources=sources, includes=includes, hdl_toplevel=module, - build_args=[ - *tool_args, - *extra_build_args, - ], + build_args=[*tool_args, *extra_build_args,], # Do not use params in hierarchical verilation parameters=module_params if not hierarchical else {}, build_dir=test_work_dir, @@ -312,14 +309,10 @@ def simulate_pass( verilog_sources=[rtl_dir / "top.sv"], includes=[rtl_dir], hdl_toplevel="top", - build_args=[ - *_verilator_args(False, trace) * extra_build_args, - ], + build_args=[*_verilator_args(False, trace) * extra_build_args,], parameters=module_params, build_dir=sim_dir, ) runner.test( - hdl_toplevel="top", - test_module="test", - results_xml="results.xml", + hdl_toplevel="top", test_module="test", results_xml="results.xml", ) diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py index 0ad58ba1b..d746f74b5 100644 --- a/src/mase_cocotb/testbench.py +++ b/src/mase_cocotb/testbench.py @@ -8,12 +8,7 @@ class Testbench: __test__ = False # so pytest doesn't confuse this with a test def __init__( - self, - dut, - clk=None, - rst=None, - fail_on_checks=True, - clk_period_ns=20, + self, dut, clk=None, rst=None, fail_on_checks=True, clk_period_ns=20, ) -> None: self.dut = dut self.clk = clk diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py index 7271389f2..25b6d83d8 100644 --- a/src/mase_cocotb/utils.py +++ b/src/mase_cocotb/utils.py @@ -78,8 +78,8 @@ def int_floor_quantizer(x: Tensor, width: int, frac_width: int, signed=True): int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2**width - 1 - scale = 2**frac_width + int_max = 2 ** width - 1 + scale = 2 ** frac_width return torch.clamp(torch.floor(x.mul(scale)), int_min, int_max).div(scale) diff --git a/src/mase_cocotb/z_qlayers/tensor_cast.py b/src/mase_cocotb/z_qlayers/tensor_cast.py index ce651bc5c..cf212cf79 100644 --- a/src/mase_cocotb/z_qlayers/tensor_cast.py +++ b/src/mase_cocotb/z_qlayers/tensor_cast.py @@ -46,9 +46,9 @@ def _integer_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2**width - 1 + int_max = 2 ** width - 1 # thresh = 2 ** (width - 1) - scale = 2**frac_width + scale = 2 ** frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale) @@ -59,7 +59,7 @@ def _integer_quantize( def quantize_to_int(x: Tensor, width: int, frac_width: int): - x = (_integer_quantize(x, width, frac_width) * (2**frac_width)).int() + x = (_integer_quantize(x, width, frac_width) * (2 ** frac_width)).int() return x diff --git a/src/mase_components/activation_layers/test/fixed_elu_tb.py b/src/mase_components/activation_layers/test/fixed_elu_tb.py index 87ff69fd6..1c7e850e4 100644 --- a/src/mase_components/activation_layers/test/fixed_elu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_elu_tb.py @@ -79,10 +79,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"): count += 1 iarr.append(i) val = quanter(f(torch.tensor(i))) # entry in the lookup table - lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = ( - doubletofx( - data_width=data_width, f_width=f_width, num=val.item(), type=type - ) + lut[ + doubletofx(data_width=data_width, f_width=f_width, num=i, type=type) + ] = doubletofx( + data_width=data_width, f_width=f_width, num=val.item(), type=type ) i += 2 ** -(f_width) return lut @@ -444,8 +444,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -456,7 +456,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_gelu_tb.py b/src/mase_components/activation_layers/test/fixed_gelu_tb.py index ab7d75f4c..2a8e79787 100644 --- a/src/mase_components/activation_layers/test/fixed_gelu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_gelu_tb.py @@ -9,9 +9,7 @@ from mase_cocotb.runner import mase_runner -from mase_components.helper.generate_memory import ( - generate_sv_lut, -) +from mase_components.helper.generate_memory import generate_sv_lut DATA_IN_0_PRECISION_1 = 8 @@ -31,13 +29,13 @@ async def cocotb_test_fixed_gelu(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -48,9 +46,9 @@ async def cocotb_test_fixed_gelu(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2**DATA_IN_0_PRECISION_1)] + a = [b / (2 ** DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -58,7 +56,7 @@ async def cocotb_test_fixed_gelu(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py index efb6f0045..57f3940fb 100644 --- a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py +++ b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py @@ -62,11 +62,11 @@ def exp(self, inputs): # cond = torch.logical_not(torch.logical_and(inputs <= self.thresh*2**self.fracw, inputs >= -1 * self.thresh *2**self.fracw)) # out = torch.where(cond, inputs, torch.tensor(0)) # unsignedout = torch.where(out < 0, torch.tensor(out % (2**self.width)), out) - m = torch.nn.Hardshrink(self.thresh * 2**self.fracw)(inputs.to(torch.float)) + m = torch.nn.Hardshrink(self.thresh * 2 ** self.fracw)(inputs.to(torch.float)) mout = m.clamp( min=-1 * 2 ** (self.outputwidth - 1), max=2 ** (self.outputwidth - 1) - 1 ) - m2 = torch.where(mout < 0, torch.tensor(mout % (2**self.outputwidth)), mout) + m2 = torch.where(mout < 0, torch.tensor(mout % (2 ** self.outputwidth)), mout) return m2.to(torch.int32).tolist() def generate_inputs(self, w, fracw): @@ -75,7 +75,7 @@ def generate_inputs(self, w, fracw): ) realinp = torch.randn(self.samples) inputs = self.dquantizer(realinp) - intinp = (inputs * 2**self.fracw).to(torch.int64) + intinp = (inputs * 2 ** self.fracw).to(torch.int64) intinp.clamp( min=-(2 ** (self.width - self.fracw - 1)), max=2 ** (self.width - self.fracw - 1) - 1, diff --git a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py index ceb194e35..53f79755e 100644 --- a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py +++ b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py @@ -56,12 +56,14 @@ def __init__(self, dut) -> None: def exp(self, inputs): # Run the model with the provided inputs and return the outputs - tmp0 = 3 * 2**self.fracw + tmp0 = 3 * 2 ** self.fracw tmp1 = inputs + tmp0 - tmp2 = tmp1 * (2**-3) + tmp1 * (2**-4) + tmp2 = tmp1 * (2 ** -3) + tmp1 * (2 ** -4) # qtmps = self.dquantizer(tmp2) tmp3 = tmp2 * inputs - unsignedout = torch.where(tmp3 < 0, torch.tensor(tmp3 % (2**self.width)), tmp3) + unsignedout = torch.where( + tmp3 < 0, torch.tensor(tmp3 % (2 ** self.width)), tmp3 + ) # return unsignedout.tolist() return unsignedout @@ -71,7 +73,7 @@ def generate_inputs(self, w, fracw): ) realinp = torch.randn(self.samples) inputs = self.dquantizer(realinp) - intinp = (inputs * 2**self.fracw).to(torch.int64) + intinp = (inputs * 2 ** self.fracw).to(torch.int64) intinp.clamp( min=-(2 ** (self.width - self.fracw - 1)), max=2 ** (self.width - self.fracw - 1) - 1, diff --git a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py index d0cd8796e..a44efe0b8 100644 --- a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py @@ -28,9 +28,9 @@ def get_in_and_out(x, fn, width, frac_width): ins = integer_quantizer(x, width=width, frac_width=frac_width) y = fn(x) outs = integer_quantizer(y, width=width, frac_width=frac_width) - outs = outs * 2**frac_width + outs = outs * 2 ** frac_width outs = outs.int() - ins = ins * 2**frac_width + ins = ins * 2 ** frac_width ins = ins.int() return (ins, outs) @@ -78,7 +78,7 @@ async def cocotb_test(dut): logger.info(f"Reset finished") tb.data_out_0_monitor.ready.value = 1 - inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2**-4) + inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2 ** -4) tb.data_in_0_driver.append(inputs.tolist()) diff --git a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py index 570ddc328..9b9a0a766 100644 --- a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py +++ b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py @@ -142,8 +142,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -154,7 +154,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_relu_tb.py b/src/mase_components/activation_layers/test/fixed_relu_tb.py index 118296c48..dbe8ac146 100644 --- a/src/mase_components/activation_layers/test/fixed_relu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_relu_tb.py @@ -47,7 +47,7 @@ def backward(ctx, grad_output): def quantize(x, bits, bias): # bits = 32 """Do linear quantization to input according to a scale and number of bits""" thresh = 2 ** (bits - 1) - scale = 2**bias + scale = 2 ** bias return my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1).div(scale) @@ -83,7 +83,7 @@ def get_dut_parameters(self): def get_dut_input(self, i): inputs = self.inputs[i] - shifted_integers = (inputs * (2**self.bias)).int() + shifted_integers = (inputs * (2 ** self.bias)).int() return shifted_integers.numpy().tolist() def get_dut_output(self, i): @@ -92,7 +92,7 @@ def get_dut_output(self, i): return shifted_integers def convert_to_fixed(self, x): - return (x * (2**self.bias)).int().numpy().tolist() + return (x * (2 ** self.bias)).int().numpy().tolist() @cocotb.test() diff --git a/src/mase_components/activation_layers/test/fixed_selu_tb.py b/src/mase_components/activation_layers/test/fixed_selu_tb.py index 647c4fff1..459334c6c 100644 --- a/src/mase_components/activation_layers/test/fixed_selu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_selu_tb.py @@ -25,13 +25,13 @@ async def cocotb_test_fixed_selu(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -42,9 +42,9 @@ async def cocotb_test_fixed_selu(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2**DATA_IN_0_PRECISION_1)] + a = [b / (2 ** DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -52,7 +52,7 @@ async def cocotb_test_fixed_selu(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py index e2a794d22..427220089 100644 --- a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py +++ b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py @@ -128,8 +128,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -140,7 +140,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_silu_tb.py b/src/mase_components/activation_layers/test/fixed_silu_tb.py index 944a83e5e..19c4ac3a2 100644 --- a/src/mase_components/activation_layers/test/fixed_silu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_silu_tb.py @@ -143,8 +143,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -155,7 +155,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py index d48214384..41e67e205 100644 --- a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py @@ -19,9 +19,7 @@ from chop.nn.quantized.functional import fixed_softermax -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class SoftermaxTB(Testbench): @@ -66,10 +64,7 @@ def __init__(self, dut) -> None: self.model = partial( fixed_softermax, dim=0, - q_config={ - "width": self.IN_WIDTH, - "frac_width": self.IN_FRAC_WIDTH, - }, + q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,}, ) # Set verbosity of driver and monitor loggers to debug @@ -77,9 +72,7 @@ def __init__(self, dut) -> None: # self.out_data_monitor.log.setLevel(logging.DEBUG) def generate_inputs(self, batches): - return torch.randn( - (batches, self.TOTAL_DIM), - ) + return torch.randn((batches, self.TOTAL_DIM),) async def run_test(self, batches, us): await self.reset() @@ -95,10 +88,7 @@ async def run_test(self, batches, us): self.log.debug(f"Processing inputs: {batch}") driver_input = fixed_preprocess_tensor( tensor=batch, - q_config={ - "width": self.IN_WIDTH, - "frac_width": self.IN_FRAC_WIDTH, - }, + q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,}, parallelism=[self.PARALLELISM], ) self.in_data_driver.load_driver(driver_input) @@ -107,10 +97,7 @@ async def run_test(self, batches, us): self.log.debug(f"Processing outputs: {exp_out}") outs = fixed_preprocess_tensor( tensor=exp_out, - q_config={ - "width": self.OUT_WIDTH, - "frac_width": self.OUT_FRAC_WIDTH, - }, + q_config={"width": self.OUT_WIDTH, "frac_width": self.OUT_FRAC_WIDTH,}, parallelism=[self.PARALLELISM], ) self.out_data_monitor.load_monitor(outs) diff --git a/src/mase_components/activation_layers/test/fixed_softermax_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_tb.py index 9b7d0aee1..84c6aa4e5 100644 --- a/src/mase_components/activation_layers/test/fixed_softermax_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softermax_tb.py @@ -133,9 +133,7 @@ def test_fixed_softermax_smoke(): """ mase_runner( trace=True, - module_param_list=[ - get_fixed_softermax_config(), - ], + module_param_list=[get_fixed_softermax_config(),], # skip_build=True, ) diff --git a/src/mase_components/activation_layers/test/fixed_softmax_tb.py b/src/mase_components/activation_layers/test/fixed_softmax_tb.py index 62f37012f..afc8876ac 100644 --- a/src/mase_components/activation_layers/test/fixed_softmax_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softmax_tb.py @@ -128,8 +128,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -140,7 +140,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softplus_tb.py b/src/mase_components/activation_layers/test/fixed_softplus_tb.py index 121fa5c60..f39467ee7 100644 --- a/src/mase_components/activation_layers/test/fixed_softplus_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softplus_tb.py @@ -25,13 +25,13 @@ async def cocotb_test_fixed_softplus(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -42,9 +42,9 @@ async def cocotb_test_fixed_softplus(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2**DATA_IN_0_PRECISION_1)] + a = [b / (2 ** DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -52,7 +52,7 @@ async def cocotb_test_fixed_softplus(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py index b79ccdaf5..9c738602e 100644 --- a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py @@ -142,8 +142,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2**self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2**self.outputwidth) + m2 = (m * 2 ** self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2 ** self.outputwidth) return m2 @@ -154,7 +154,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2**self.frac_width).to(torch.int64) + intinp = (inputs * 2 ** self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softsign_tb.py b/src/mase_components/activation_layers/test/fixed_softsign_tb.py index 5fec6b341..c8776ac32 100644 --- a/src/mase_components/activation_layers/test/fixed_softsign_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softsign_tb.py @@ -24,13 +24,13 @@ async def cocotb_test_fixed_softsign(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -41,9 +41,9 @@ async def cocotb_test_fixed_softsign(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2**DATA_IN_0_PRECISION_1)] + a = [b / (2 ** DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -51,7 +51,7 @@ async def cocotb_test_fixed_softsign(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_tanh_tb.py b/src/mase_components/activation_layers/test/fixed_tanh_tb.py index 4e8c6d268..937fb930c 100644 --- a/src/mase_components/activation_layers/test/fixed_tanh_tb.py +++ b/src/mase_components/activation_layers/test/fixed_tanh_tb.py @@ -24,13 +24,13 @@ async def cocotb_test_fixed_tanh(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -41,9 +41,9 @@ async def cocotb_test_fixed_tanh(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2**DATA_IN_0_PRECISION_1)] + a = [b / (2 ** DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -51,7 +51,7 @@ async def cocotb_test_fixed_tanh(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/softermax.py b/src/mase_components/activation_layers/test/softermax.py index a1b3271f5..7c30c2986 100644 --- a/src/mase_components/activation_layers/test/softermax.py +++ b/src/mase_components/activation_layers/test/softermax.py @@ -53,7 +53,7 @@ def _softmax_model(l: list[int], parallelism: int, pow2=False): for diff, vals in zip(local_max_diff, local_values_buffer): if pow2: - adj = [x * (2**-diff) for x in vals] + adj = [x * (2 ** -diff) for x in vals] else: adj = [x * exp(-diff) for x in vals] norm += sum(adj) diff --git a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py index bc74e42ee..6ebbfc131 100644 --- a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py +++ b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py @@ -46,7 +46,7 @@ def __init__(self, dut) -> None: # Specify Error Threshold self.percentage_error = 0.05 # 5% - self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) + self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH)) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -63,15 +63,15 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=10): # TODO: Take a look at all zero case again local_vals = torch.randint( - 1, 2**self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM) + 1, 2 ** self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM) ) local_max = torch.randint( - 0, 2**self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1) + 0, 2 ** self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1) ) logger.debug("local_vals: %s" % (local_vals)) logger.debug( - "local_vals (float): %s" % (local_vals / (2**self.IN_VALUE_FRAC_WIDTH)) + "local_vals (float): %s" % (local_vals / (2 ** self.IN_VALUE_FRAC_WIDTH)) ) logger.debug("local_max: %s" % (local_max)) logger.debug( @@ -87,7 +87,7 @@ def model(self, inputs): for batch in batched_in: local_vals, local_max = list(zip(*batch)) local_vals = torch.tensor(list(local_vals), dtype=torch.float) / ( - 2**self.IN_VALUE_FRAC_WIDTH + 2 ** self.IN_VALUE_FRAC_WIDTH ) local_max = torch.tensor(list(local_max), dtype=torch.float) local_max = sign_extend_t( @@ -97,7 +97,7 @@ def model(self, inputs): global_max = local_max.max() adj_amt = global_max - local_max.reshape(self.DEPTH, 1) adj_values = integer_floor_quantizer( - x=local_vals / (2**adj_amt), + x=local_vals / (2 ** adj_amt), width=self.IN_VALUE_WIDTH, frac_width=self.IN_VALUE_FRAC_WIDTH, is_signed=False, @@ -226,10 +226,7 @@ def in_value_cfgs(cfgs: list): for cfg in cfgs: for in_width in [4, 7, 10]: new_cfgs.append( - { - **cfg, - "IN_VALUE_WIDTH": in_width, - } + {**cfg, "IN_VALUE_WIDTH": in_width,} ) return new_cfgs @@ -238,10 +235,7 @@ def in_max_cfgs(cfgs: list): for cfg in cfgs: for in_max in [2, 3, 4]: new_cfgs.append( - { - **cfg, - "IN_MAX_WIDTH": in_max, - } + {**cfg, "IN_MAX_WIDTH": in_max,} ) return new_cfgs @@ -256,7 +250,5 @@ def in_max_cfgs(cfgs: list): # cfgs = [{'TOTAL_DIM': 32, 'PARALLELISM': 4, 'IN_VALUE_WIDTH': 16, 'IN_MAX_WIDTH': 2}] mase_runner( - module_param_list=cfgs, - trace=True, - jobs=12, + module_param_list=cfgs, trace=True, jobs=12, ) diff --git a/src/mase_components/activation_layers/test/softermax_local_window_tb.py b/src/mase_components/activation_layers/test/softermax_local_window_tb.py index 10b11a569..c52b7c647 100644 --- a/src/mase_components/activation_layers/test/softermax_local_window_tb.py +++ b/src/mase_components/activation_layers/test/softermax_local_window_tb.py @@ -53,7 +53,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=10): return [ - [randint(0, 2**self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)] + [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)] for _ in range(batches) ] @@ -80,7 +80,7 @@ def model(self, inputs): sign_ext = sign_extend_t( torch.tensor(inputs, dtype=torch.float), bits=self.IN_WIDTH ) - float_inputs = sign_ext / (2**self.IN_FRAC_WIDTH) + float_inputs = sign_ext / (2 ** self.IN_FRAC_WIDTH) # float_inputs = torch.tensor([[-31.5, -32]]) rounded_inputs_float, rounded_inputs_uint = _fixed_signed_cast_model( float_inputs, self.MAX_WIDTH, 0, False, "floor" @@ -89,9 +89,9 @@ def model(self, inputs): local_max_uint = signed_to_unsigned(local_max.int(), self.MAX_WIDTH) difference = float_inputs - local_max - pow2 = 2**difference + pow2 = 2 ** difference res = torch.clamp( - (pow2 * 2**self.OUT_FRAC_WIDTH).int(), 0, 2**self.OUT_WIDTH - 1 + (pow2 * 2 ** self.OUT_FRAC_WIDTH).int(), 0, 2 ** self.OUT_WIDTH - 1 ) logger.debug("float_inputs: %s" % float_inputs) diff --git a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py index 058fa10b0..0a4b52af5 100644 --- a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py +++ b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py @@ -37,7 +37,7 @@ def __init__(self, dut) -> None: self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) # 0.1% bit error - self.error_threshold_bits = ceil((2**self.IN_WIDTH) * 0.001) + self.error_threshold_bits = ceil((2 ** self.IN_WIDTH) * 0.001) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -53,10 +53,10 @@ def __init__(self, dut) -> None: def sweep_inputs(self): negative_nums = torch.arange( - start=2 ** (self.IN_WIDTH - 1), end=2**self.IN_WIDTH, dtype=torch.int32 + start=2 ** (self.IN_WIDTH - 1), end=2 ** self.IN_WIDTH, dtype=torch.int32 ) zero_to_one = torch.arange( - start=0, end=2**self.IN_FRAC_WIDTH, dtype=torch.int32 # one + start=0, end=2 ** self.IN_FRAC_WIDTH, dtype=torch.int32 # one ) return torch.cat((negative_nums, zero_to_one)).tolist() @@ -68,14 +68,14 @@ def generate_inputs(self, batches=1): # Negative Numbers torch.randint( low=2 ** (self.IN_WIDTH - 1), - high=2**self.IN_WIDTH, + high=2 ** self.IN_WIDTH, size=(negative_nums,), dtype=torch.int32, ), # Numbers between 0 and 1 torch.randint( low=0, - high=2**self.IN_FRAC_WIDTH, + high=2 ** self.IN_FRAC_WIDTH, size=(zero_to_one_nums,), dtype=torch.int32, ), @@ -85,10 +85,10 @@ def generate_inputs(self, batches=1): def model(self, inputs): in_t = torch.tensor(inputs) - num = sign_extend_t(in_t, self.IN_WIDTH) / (2**self.IN_FRAC_WIDTH) - res = 2**num - res = (res * 2**self.OUT_FRAC_WIDTH).int() - res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) + num = sign_extend_t(in_t, self.IN_WIDTH) / (2 ** self.IN_FRAC_WIDTH) + res = 2 ** num + res = (res * 2 ** self.OUT_FRAC_WIDTH).int() + res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1) return res.tolist() async def run_test(self, batches, us): @@ -128,7 +128,7 @@ async def sweep(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(exp_out) - ns = ((2**tb.IN_WIDTH) * 1000) // 5 + ns = ((2 ** tb.IN_WIDTH) * 1000) // 5 logger.info("Waiting %d ns..." % ns) await Timer(ns, "ns") assert tb.output_monitor.exp_queue.empty() @@ -137,14 +137,14 @@ async def sweep(dut): recv_log = tb.output_monitor.recv_log assert len(exp_out) == len(recv_log) - x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2**tb.IN_FRAC_WIDTH) - ref = 2**x - ref *= 2**tb.OUT_FRAC_WIDTH # scale up - ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) + x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2 ** tb.IN_FRAC_WIDTH) + ref = 2 ** x + ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1) - software_ref = ref / (2**tb.OUT_FRAC_WIDTH) - software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out] - hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log] + software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH) + software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out] + hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log] data = pd.DataFrame( { @@ -174,10 +174,7 @@ async def sweep(dut): ), color=alt.Color("Type"), ) - .properties( - width=600, - height=220, - ) + .properties(width=600, height=220,) ) error_data = data[["x", "hardware error"]] @@ -188,10 +185,7 @@ async def sweep(dut): x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"), y=alt.Y("hardware error").title(f"Error"), ) - .properties( - width=600, - height=100, - ) + .properties(width=600, height=100,) ) (curve_fig & error_fig).save( @@ -302,12 +296,7 @@ def test_high_width(): def test_smoke(): mase_runner( module_param_list=[ - { - "IN_WIDTH": 8, - "IN_FRAC_WIDTH": 4, - "OUT_WIDTH": 8, - "OUT_FRAC_WIDTH": 4, - } + {"IN_WIDTH": 8, "IN_FRAC_WIDTH": 4, "OUT_WIDTH": 8, "OUT_FRAC_WIDTH": 4,} ] ) diff --git a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py index f2c70b67c..842881dc9 100644 --- a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py +++ b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py @@ -38,7 +38,7 @@ def __init__(self, dut) -> None: # Specify Error Threshold self.percentage_error = 0.05 # 5% - self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) + self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH)) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -53,17 +53,17 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, batches=100): - return [randint(0, 2**self.IN_WIDTH - 1) for _ in range(batches)] + return [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(batches)] def sweep_input(self): - return list(range(2**self.IN_WIDTH)) + return list(range(2 ** self.IN_WIDTH)) def model(self, inputs): - in_t = torch.tensor(inputs) / (2**self.IN_FRAC_WIDTH) + in_t = torch.tensor(inputs) / (2 ** self.IN_FRAC_WIDTH) recip = 1.0 / in_t - res = torch.floor(recip * 2**self.OUT_FRAC_WIDTH) + res = torch.floor(recip * 2 ** self.OUT_FRAC_WIDTH) res = torch.nan_to_num(res) - res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) + res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1) res = res.int() return res.tolist() @@ -107,14 +107,14 @@ async def sweep(dut): recv_log = tb.output_monitor.recv_log assert len(exp_out) == len(recv_log) - x = torch.tensor(inputs) / (2**tb.IN_FRAC_WIDTH) + x = torch.tensor(inputs) / (2 ** tb.IN_FRAC_WIDTH) ref = 1.0 / x - ref *= 2**tb.OUT_FRAC_WIDTH # scale up - ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) + ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1) - software_ref = ref / (2**tb.OUT_FRAC_WIDTH) - software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out] - hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log] + software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH) + software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out] + hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log] data = pd.DataFrame( { @@ -150,10 +150,7 @@ async def sweep(dut): ), color=alt.Color("Type"), ) - .properties( - width=600, - height=300, - ) + .properties(width=600, height=300,) ) error_data = data[["x", "hardware error"]] @@ -164,10 +161,7 @@ async def sweep(dut): x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"), y=alt.Y("hardware error").title(f"Error"), ) - .properties( - width=600, - height=100, - ) + .properties(width=600, height=100,) ) (curve_fig & error_fig).save( @@ -222,7 +216,7 @@ def width_cfgs(): for width in range(2, 16 + 1): frac_width = width // 2 if frac_width < 3: - entries = 2**frac_width + entries = 2 ** frac_width else: entries = 8 cfgs.append( @@ -260,9 +254,7 @@ def test_width_cfgs(): "OUT_FRAC_WIDTH": 4, } ] - mase_runner( - module_param_list=cfgs, - ) + mase_runner(module_param_list=cfgs,) def test_smoke(): diff --git a/src/mase_components/cast/test/fixed_rounding_tb.py b/src/mase_components/cast/test/fixed_rounding_tb.py index a7f2f935c..300283856 100644 --- a/src/mase_components/cast/test/fixed_rounding_tb.py +++ b/src/mase_components/cast/test/fixed_rounding_tb.py @@ -44,7 +44,7 @@ def single_run(self): def sw_cast(self, inputs): outputs = ( integer_floor_quantizer(inputs, self.out_width, self.out_frac_width) - * 2**self.out_frac_width + * 2 ** self.out_frac_width ) # breakpoint() return outputs diff --git a/src/mase_components/cast/test/fixed_signed_cast_tb.py b/src/mase_components/cast/test/fixed_signed_cast_tb.py index 2ce3bbe5f..9f5a6956d 100644 --- a/src/mase_components/cast/test/fixed_signed_cast_tb.py +++ b/src/mase_components/cast/test/fixed_signed_cast_tb.py @@ -18,7 +18,7 @@ def _fixed_signed_cast_model( float_input, out_width, out_frac_width, symmetric, rounding_mode ): - scaled_float = float_input * (2**out_frac_width) + scaled_float = float_input * (2 ** out_frac_width) if rounding_mode == "floor": out_int = my_floor(scaled_float) elif rounding_mode == "round_nearest_half_even": @@ -30,7 +30,7 @@ def _fixed_signed_cast_model( -(2 ** (out_width - 1)) + 1 if symmetric else -(2 ** (out_width - 1)), (2 ** (out_width - 1)) - 1, ) - out_float = out_int / (2**out_frac_width) + out_float = out_int / (2 ** out_frac_width) # out_uint is a non-differentiable path out_uint = signed_to_unsigned(out_int.int(), out_width) return out_float, out_uint @@ -58,9 +58,9 @@ def __init__(self, dut) -> None: ) def generate_inputs(self): - uints = torch.arange(2**self.IN_WIDTH) + uints = torch.arange(2 ** self.IN_WIDTH) num_int = sign_extend_t(uints, self.IN_WIDTH) - num_float = num_int / (2**self.IN_FRAC_WIDTH) + num_float = num_int / (2 ** self.IN_FRAC_WIDTH) return num_int, num_float def rounding_mode(self): @@ -150,10 +150,7 @@ def gen_symmetric(cfg_list): l = list() for cfg in cfg_list: l.extend( - [ - {**cfg, "SYMMETRIC": 0}, - {**cfg, "SYMMETRIC": 1}, - ] + [{**cfg, "SYMMETRIC": 0}, {**cfg, "SYMMETRIC": 1},] ) return l diff --git a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py index cfeca8318..75910ed9e 100644 --- a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py +++ b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py @@ -51,7 +51,7 @@ def __init__(self, dut) -> None: ) def generate_inputs(self): - return torch.arange(2**self.IN_WIDTH) + return torch.arange(2 ** self.IN_WIDTH) def rounding_mode(self): if self.ROUND_FLOOR: @@ -64,10 +64,10 @@ def rounding_mode(self): raise Exception("Rounding mode not recognised.") def model(self, inputs): - float_input = inputs / (2**self.IN_FRAC_WIDTH) - scaled_float = float_input * (2**self.OUT_FRAC_WIDTH) + float_input = inputs / (2 ** self.IN_FRAC_WIDTH) + scaled_float = float_input * (2 ** self.OUT_FRAC_WIDTH) rounded = torch.floor(scaled_float) - model_out = torch.clamp(rounded, 0, (2**self.OUT_WIDTH - 1)) + model_out = torch.clamp(rounded, 0, (2 ** self.OUT_WIDTH - 1)) return model_out async def run_test(self): diff --git a/src/mase_components/common/test/comparator_accumulator_tb.py b/src/mase_components/common/test/comparator_accumulator_tb.py index fbd8af937..d05f41a09 100644 --- a/src/mase_components/common/test/comparator_accumulator_tb.py +++ b/src/mase_components/common/test/comparator_accumulator_tb.py @@ -33,7 +33,9 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, batches=3): - return [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)] + return [ + randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches) + ] def model(self, inputs): @@ -152,7 +154,5 @@ def signed_max_min_cfgs(cfglist: list): cfgs = signed_max_min_cfgs(cfgs) mase_runner( - module_param_list=cfgs, - trace=True, - jobs=12, + module_param_list=cfgs, trace=True, jobs=12, ) diff --git a/src/mase_components/common/test/comparator_tree_tb.py b/src/mase_components/common/test/comparator_tree_tb.py index 5637791b3..ae2a7fc61 100644 --- a/src/mase_components/common/test/comparator_tree_tb.py +++ b/src/mase_components/common/test/comparator_tree_tb.py @@ -34,7 +34,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=3): return [ - [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)] + [randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)] for _ in range(batches) ] @@ -136,6 +136,5 @@ def signed_max_min_cfgs(cfglist: list): cfgs = signed_max_min_cfgs(cfgs) mase_runner( - module_param_list=cfgs, - trace=True, + module_param_list=cfgs, trace=True, ) diff --git a/src/mase_components/common/test/register_slice_tb.py b/src/mase_components/common/test/register_slice_tb.py index 638d53da7..ffcec4f3e 100644 --- a/src/mase_components/common/test/register_slice_tb.py +++ b/src/mase_components/common/test/register_slice_tb.py @@ -60,9 +60,7 @@ def in_out_wave(dut, name): ) logger.debug( "{} State: (shift_reg, buffer) = ({},{})".format( - name, - int(dut.shift_reg.value), - int(dut.buffer.value), + name, int(dut.shift_reg.value), int(dut.buffer.value), ) ) diff --git a/src/mase_components/common/test/single_element_repeat_tb.py b/src/mase_components/common/test/single_element_repeat_tb.py index 247cc2b27..6d39b0c35 100644 --- a/src/mase_components/common/test/single_element_repeat_tb.py +++ b/src/mase_components/common/test/single_element_repeat_tb.py @@ -26,7 +26,7 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, num=10): - return [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(num)] + return [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(num)] def model(self, inputs): exp_out = [] @@ -103,7 +103,5 @@ def generate_random_params(): ] mase_runner( - module_param_list=cfgs, - trace=True, - jobs=8, + module_param_list=cfgs, trace=True, jobs=8, ) diff --git a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py index 8091990e4..1481671a3 100644 --- a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py +++ b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py @@ -244,9 +244,7 @@ def sw_compute(self): bias = \n\ {} \n\ ".format( - data, - weight, - bias, + data, weight, bias, ) ) for i in range(self.samples): diff --git a/src/mase_components/convolution_layers/test/convolution_tb.py b/src/mase_components/convolution_layers/test/convolution_tb.py index 1a3789ed9..ec3c07a37 100644 --- a/src/mase_components/convolution_layers/test/convolution_tb.py +++ b/src/mase_components/convolution_layers/test/convolution_tb.py @@ -130,11 +130,7 @@ def get_manual_result( # out2 = get_manual_result(x, w, b, 2,1,2,2,4,4,0,0,12,4) # data_in_pack - x = q2i( - x, - config["data_in_width"], - config["data_in_frac_width"], - ) + x = q2i(x, config["data_in_width"], config["data_in_frac_width"],) self.log.info(f"x = {x}") # from (samples, c, h, w) to (samples*h*w*c/unroll_in_c, unroll_in_c) @@ -144,16 +140,8 @@ def get_manual_result( self.log.info(f"weight = {w}") self.log.info(f"bias = {b}") - w = q2i( - w, - config["weight_width"], - config["weight_frac_width"], - ) - b = q2i( - b, - config["bias_width"], - config["bias_frac_width"], - ) + w = q2i(w, config["weight_width"], config["weight_frac_width"],) + b = q2i(b, config["bias_width"], config["bias_frac_width"],) self.log.info(f"weight = {w}") self.log.info(f"bias = {b}") hw_w, hw_b = self.conv_pack( @@ -169,11 +157,7 @@ def get_manual_result( unroll_kernel_out=self.get_parameter("UNROLL_KERNEL_OUT"), unroll_out_channels=self.get_parameter("UNROLL_OUT_C"), ) - exp_out = q2i( - out, - config["out_width"], - config["out_frac_width"], - ) + exp_out = q2i(out, config["out_width"], config["out_frac_width"],) exp_out = ( exp_out.reshape( -1, self.get_parameter("OUT_C"), self.get_parameter("SLIDING_NUM") @@ -223,10 +207,7 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape( - -1, - unroll_out_channels * unroll_kernel_out, - ) + w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) w_in = w_tensor.type(torch.int).tolist() # bias_pack bias_tensor = ( @@ -332,10 +313,7 @@ def test_fixed_linear_smoke(): Some quick tests to check if the module is working. """ mase_runner( - trace=True, - module_param_list=[ - get_fixed_conv_config(), - ], + trace=True, module_param_list=[get_fixed_conv_config(),], ) diff --git a/src/mase_components/convolution_layers/test/padding_tb.py b/src/mase_components/convolution_layers/test/padding_tb.py index 6cb470069..d0504ef5d 100644 --- a/src/mase_components/convolution_layers/test/padding_tb.py +++ b/src/mase_components/convolution_layers/test/padding_tb.py @@ -96,9 +96,9 @@ def data_pack(self): for j in range(in_channels): for k in range(in_height): for s in range(in_width): - re_data_tensor[i][j][k + padding_height][s + padding_width] = ( - data_tensor[i][k][s][j] - ) + re_data_tensor[i][j][k + padding_height][ + s + padding_width + ] = data_tensor[i][k][s][j] return re_data_tensor @@ -248,9 +248,7 @@ def runner(): print(extra_args) runner = get_runner(sim) runner.build( - verilog_sources=verilog_sources, - hdl_toplevel="padding", - build_args=extra_args, + verilog_sources=verilog_sources, hdl_toplevel="padding", build_args=extra_args, ) runner.test(hdl_toplevel="padding", test_module="padding_tb") diff --git a/src/mase_components/convolution_layers/test/roller_tb.py b/src/mase_components/convolution_layers/test/roller_tb.py index 4bf56b273..8c7a24d02 100644 --- a/src/mase_components/convolution_layers/test/roller_tb.py +++ b/src/mase_components/convolution_layers/test/roller_tb.py @@ -195,9 +195,7 @@ def runner(): print(extra_args) runner = get_runner(sim)() runner.build( - verilog_sources=verilog_sources, - toplevel="roller", - extra_args=extra_args, + verilog_sources=verilog_sources, toplevel="roller", extra_args=extra_args, ) runner.test(toplevel="roller", py_module="roller_tb") diff --git a/src/mase_components/convolution_layers/test/sliding_window_tb.py b/src/mase_components/convolution_layers/test/sliding_window_tb.py index f893d6f68..73facf247 100644 --- a/src/mase_components/convolution_layers/test/sliding_window_tb.py +++ b/src/mase_components/convolution_layers/test/sliding_window_tb.py @@ -119,9 +119,9 @@ def data_pack(self): for j in range(in_channels): for k in range(in_height): for s in range(in_width): - re_data_tensor[i][j][k + padding_height][s + padding_width] = ( - data_tensor[i][k][s][j] - ) + re_data_tensor[i][j][k + padding_height][ + s + padding_width + ] = data_tensor[i][k][s][j] return re_data_tensor diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py index a06529344..9596adedc 100644 --- a/src/mase_components/deps.py +++ b/src/mase_components/deps.py @@ -17,11 +17,7 @@ "activation_layers", "scalar_operators/fixed", ], - "activation_layers/fixed_gelu": [ - "common", - "memory", - "activation_layers", - ], + "activation_layers/fixed_gelu": ["common", "memory", "activation_layers",], "activation_layers/fixed_softsign": [ "common", "activation_layers", @@ -127,11 +123,7 @@ "common", "cast", ], - "language_models/llmint8/scatter": [ - "language_models/llmint8", - "memory", - "common", - ], + "language_models/llmint8/scatter": ["language_models/llmint8", "memory", "common",], # Linear "linear_layers/fixed_linear_layer/fixed_linear": [ "cast", diff --git a/src/mase_components/helper/generate_memory.py b/src/mase_components/helper/generate_memory.py index e46c99604..c780d6c7f 100644 --- a/src/mase_components/helper/generate_memory.py +++ b/src/mase_components/helper/generate_memory.py @@ -60,10 +60,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"): count += 1 iarr.append(i) val = quanter(f(torch.tensor(i))) # entry in the lookup table - lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = ( - doubletofx( - data_width=data_width, f_width=f_width, num=val.item(), type=type - ) + lut[ + doubletofx(data_width=data_width, f_width=f_width, num=i, type=type) + ] = doubletofx( + data_width=data_width, f_width=f_width, num=val.item(), type=type ) i += 2 ** -(f_width) return lut diff --git a/src/mase_components/hls/bfp_arith/bfp_adder.py b/src/mase_components/hls/bfp_arith/bfp_adder.py index 5b2a401d9..0ba95741e 100644 --- a/src/mase_components/hls/bfp_arith/bfp_adder.py +++ b/src/mase_components/hls/bfp_arith/bfp_adder.py @@ -2,11 +2,7 @@ def bfp_adder_gen( - writer, - x_exp_width=16, - x_man_width=8, - w_exp_width=16, - w_man_width=8, + writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8, ): """ This script generates a element-level bfp add in HLS diff --git a/src/mase_components/hls/bfp_arith/bfp_multiplier.py b/src/mase_components/hls/bfp_arith/bfp_multiplier.py index e3695295e..fd588ea3c 100644 --- a/src/mase_components/hls/bfp_arith/bfp_multiplier.py +++ b/src/mase_components/hls/bfp_arith/bfp_multiplier.py @@ -1,9 +1,5 @@ def bfp_multiplier_gen( - writer, - x_exp_width=16, - x_man_width=8, - w_exp_width=16, - w_man_width=8, + writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8, ): """ This script generates a element-level bfp mult in HLS diff --git a/src/mase_components/hls/elastic/buffer.py b/src/mase_components/hls/elastic/buffer.py index d25c1fd56..a0fda6b90 100644 --- a/src/mase_components/hls/elastic/buffer.py +++ b/src/mase_components/hls/elastic/buffer.py @@ -2,13 +2,7 @@ def buffer_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a buffer in HLS diff --git a/src/mase_components/hls/hls_regression.py b/src/mase_components/hls/hls_regression.py index a02dca60a..719b20380 100755 --- a/src/mase_components/hls/hls_regression.py +++ b/src/mase_components/hls/hls_regression.py @@ -106,10 +106,7 @@ def main(): parser = ArgumentParser(usage=USAGE) parser.add_argument( - "--op", - dest="op", - default=None, - help="Op name to explore", + "--op", dest="op", default=None, help="Op name to explore", ) parser.add_argument( "--dir", diff --git a/src/mase_components/hls/int_arith/int_layernorm.py b/src/mase_components/hls/int_arith/int_layernorm.py index 059b15d4d..9521a68c0 100644 --- a/src/mase_components/hls/int_arith/int_layernorm.py +++ b/src/mase_components/hls/int_arith/int_layernorm.py @@ -2,13 +2,7 @@ def int_layernorm_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a fixed-point layernorm in HLS diff --git a/src/mase_components/hls/int_arith/int_relu.py b/src/mase_components/hls/int_arith/int_relu.py index f934640c9..09f7462c2 100644 --- a/src/mase_components/hls/int_arith/int_relu.py +++ b/src/mase_components/hls/int_arith/int_relu.py @@ -2,13 +2,7 @@ def int_relu_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a fixed-point relu in HLS diff --git a/src/mase_components/hls/int_arith/int_silu.py b/src/mase_components/hls/int_arith/int_silu.py index c4fc32abe..c90511c47 100644 --- a/src/mase_components/hls/int_arith/int_silu.py +++ b/src/mase_components/hls/int_arith/int_silu.py @@ -2,13 +2,7 @@ def int_silu_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a fixed-point silu in HLS diff --git a/src/mase_components/hls/int_arith/int_softmax.py b/src/mase_components/hls/int_arith/int_softmax.py index c0b4ca5b1..8ad5901cf 100644 --- a/src/mase_components/hls/int_arith/int_softmax.py +++ b/src/mase_components/hls/int_arith/int_softmax.py @@ -2,13 +2,7 @@ def int_softmax_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a fixed-point softmax in HLS diff --git a/src/mase_components/hls/int_arith/int_transpose.py b/src/mase_components/hls/int_arith/int_transpose.py index 6b7fd0a15..ee9c6d2d7 100644 --- a/src/mase_components/hls/int_arith/int_transpose.py +++ b/src/mase_components/hls/int_arith/int_transpose.py @@ -2,13 +2,7 @@ def int_transpose_gen( - writer, - x_width=16, - x_frac_width=8, - x_row=3, - x_col=2, - x_row_depth=3, - x_col_depth=2, + writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, ): """ This script generates a fixed-point transpose in HLS diff --git a/src/mase_components/hls/regression_gen/bfp_add_dse.py b/src/mase_components/hls/regression_gen/bfp_add_dse.py index 834644b24..c283301d2 100644 --- a/src/mase_components/hls/regression_gen/bfp_add_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_add_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -157,8 +151,7 @@ def bfp_add_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "bfp_add_2" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py index d6f78f30c..9a894b143 100644 --- a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, diff --git a/src/mase_components/hls/regression_gen/bfp_mult_dse.py b/src/mase_components/hls/regression_gen/bfp_mult_dse.py index b4b5fa58c..1510ea1f3 100644 --- a/src/mase_components/hls/regression_gen/bfp_mult_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_mult_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -157,8 +151,7 @@ def bfp_mult_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "bfp_mult_2" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) if hr is None: continue diff --git a/src/mase_components/hls/regression_gen/buffer_dse.py b/src/mase_components/hls/regression_gen/buffer_dse.py index 08cc6d56e..bcc003c6d 100644 --- a/src/mase_components/hls/regression_gen/buffer_dse.py +++ b/src/mase_components/hls/regression_gen/buffer_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -129,8 +123,7 @@ def buffer_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "buffer_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/fork_dse.py b/src/mase_components/hls/regression_gen/fork_dse.py index f045cd8a3..c8a5b3c0b 100644 --- a/src/mase_components/hls/regression_gen/fork_dse.py +++ b/src/mase_components/hls/regression_gen/fork_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -139,8 +133,7 @@ def fork_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "fork_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_add_dse.py b/src/mase_components/hls/regression_gen/int_add_dse.py index d4c4d9ccc..a88247d24 100644 --- a/src/mase_components/hls/regression_gen/int_add_dse.py +++ b/src/mase_components/hls/regression_gen/int_add_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -157,8 +151,7 @@ def int_add_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_add_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_layernorm_dse.py b/src/mase_components/hls/regression_gen/int_layernorm_dse.py index 3e6a61aff..d0ef2afdb 100644 --- a/src/mase_components/hls/regression_gen/int_layernorm_dse.py +++ b/src/mase_components/hls/regression_gen/int_layernorm_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -130,8 +124,7 @@ def int_layernorm_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_layernorm_1" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_linear2d_dse.py b/src/mase_components/hls/regression_gen/int_linear2d_dse.py index 1c89861eb..d75122306 100644 --- a/src/mase_components/hls/regression_gen/int_linear2d_dse.py +++ b/src/mase_components/hls/regression_gen/int_linear2d_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, diff --git a/src/mase_components/hls/regression_gen/int_matmul_dse.py b/src/mase_components/hls/regression_gen/int_matmul_dse.py index 1d1f35fd2..ac09aaa5c 100644 --- a/src/mase_components/hls/regression_gen/int_matmul_dse.py +++ b/src/mase_components/hls/regression_gen/int_matmul_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import DSE_MODES, get_tcl_buff from hls.int_arith import int_matmul_gen diff --git a/src/mase_components/hls/regression_gen/int_mult_dse.py b/src/mase_components/hls/regression_gen/int_mult_dse.py index 275ffb163..f824ec99e 100644 --- a/src/mase_components/hls/regression_gen/int_mult_dse.py +++ b/src/mase_components/hls/regression_gen/int_mult_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -157,8 +151,7 @@ def int_mult_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_mult_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_relu_dse.py b/src/mase_components/hls/regression_gen/int_relu_dse.py index 7224b1444..788e3f593 100644 --- a/src/mase_components/hls/regression_gen/int_relu_dse.py +++ b/src/mase_components/hls/regression_gen/int_relu_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -131,8 +125,7 @@ def int_relu_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_relu_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py index 262794af2..c559ecee8 100644 --- a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py +++ b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -151,8 +145,7 @@ def int_rmsnorm_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_rmsnorm_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_rope_dse.py b/src/mase_components/hls/regression_gen/int_rope_dse.py index 421a37e21..bfc2bc385 100644 --- a/src/mase_components/hls/regression_gen/int_rope_dse.py +++ b/src/mase_components/hls/regression_gen/int_rope_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -152,8 +146,7 @@ def int_rope_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_rope_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_silu_dse.py b/src/mase_components/hls/regression_gen/int_silu_dse.py index 0d80397cd..e3ce1bd11 100644 --- a/src/mase_components/hls/regression_gen/int_silu_dse.py +++ b/src/mase_components/hls/regression_gen/int_silu_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -131,8 +125,7 @@ def int_silu_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_silu_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_softmax_dse.py b/src/mase_components/hls/regression_gen/int_softmax_dse.py index c10c091ef..796829278 100644 --- a/src/mase_components/hls/regression_gen/int_softmax_dse.py +++ b/src/mase_components/hls/regression_gen/int_softmax_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -131,8 +125,7 @@ def int_softmax_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_softmax_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_transpose_dse.py b/src/mase_components/hls/regression_gen/int_transpose_dse.py index 7b5c2b715..4563a9eca 100644 --- a/src/mase_components/hls/regression_gen/int_transpose_dse.py +++ b/src/mase_components/hls/regression_gen/int_transpose_dse.py @@ -1,13 +1,7 @@ # TODO: Temporary working solution import sys, os -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.regression_gen.utils import ( DSE_MODES, @@ -131,8 +125,7 @@ def int_transpose_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_transpose_0" hr = get_hls_results( - project=os.path.join(top, file_name), - top=top_name, + project=os.path.join(top, file_name), top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/scripts/bl_bfp.py b/src/mase_components/hls/scripts/bl_bfp.py index beff47050..29f34c04d 100644 --- a/src/mase_components/hls/scripts/bl_bfp.py +++ b/src/mase_components/hls/scripts/bl_bfp.py @@ -1,12 +1,6 @@ import os, sys -sys.path.append( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - ) -) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) from hls.bfp_arith import bfp_mm_gen from hls.bfp_arith import bfp_add_gen @@ -14,12 +8,7 @@ def get_big_little_bfp( - HIGH_MAN_WIDTH=7, - LOW_MAN_WIDTH=3, - X_ROW=1, - X_COL=4096, - W_COL=11008, - A_COL=32, + HIGH_MAN_WIDTH=7, LOW_MAN_WIDTH=3, X_ROW=1, X_COL=4096, W_COL=11008, A_COL=32, ): W_ROW = X_COL A_ROW = X_COL diff --git a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py index 91552b68a..2cdb58b59 100644 --- a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py +++ b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py @@ -95,8 +95,7 @@ def runner(): extra_args.append(f"-G{k}={v}") mase_runner( - trace=True, - module_param_list=[test_case.get_dut_parameters()], + trace=True, module_param_list=[test_case.get_dut_parameters()], ) diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py index 0aa8f5239..27e729200 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py @@ -24,31 +24,23 @@ class VerificationCase(Testbench): def __init__(self, dut): super().__init__(dut, dut.clk, dut.rst) self.assign_self_params( - [ - "IN_WIDTH", - "IN_FRAC_WIDTH", - "LUT_POW", - ] + ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",] ) self.input_driver = StreamDriver( dut.clk, dut.in_data, dut.in_valid, dut.in_ready ) self.output_monitor = StreamMonitor( - dut.clk, - dut.out_data, - dut.out_valid, - dut.out_ready, - name="Output ISQRT", + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT", ) def generate_inputs(self, num=10000): - maxnum = (2**self.IN_WIDTH) - 1 + maxnum = (2 ** self.IN_WIDTH) - 1 return [random.randint(0, maxnum) for _ in range(num)], num def model(self, data_in): ref = [] - lut_size = 2**self.LUT_POW + lut_size = 2 ** self.LUT_POW lut = make_lut(lut_size, self.IN_WIDTH) for x in data_in: expected = isqrt_sw2( @@ -127,7 +119,7 @@ async def valid_backpressure(dut): makedirs(mem_dir, exist_ok=True) def single_cfg(width, frac_width, lut_pow, str_id): - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow lut = make_lut(lut_size, width) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, width) diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py index 42bcb4edc..84fde6666 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py @@ -22,7 +22,7 @@ def __init__(self, dut): self.assign_self_params(["WIDTH", "LUT_POW"]) def generate_inputs(self): - samples = 2**self.WIDTH + samples = 2 ** self.WIDTH data_x = [] msb_indices = [] for x in range(samples): @@ -94,7 +94,6 @@ async def cocotb_test_fixed_lut_index(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_fixed_lut_index(): - def full_sweep(): parameter_list = [] lut_pow = 5 diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py index 8c0b70ed1..07a199810 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py @@ -24,12 +24,12 @@ def __init__(self, dut): self.assign_self_params(["WIDTH"]) def generate_inputs(self, lut_pow): - samples = 2**self.WIDTH + samples = 2 ** self.WIDTH int_width = 1 frac_width = self.WIDTH - 1 data_x = [] initial_guesses = [] - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow lut = make_lut(lut_size, self.WIDTH) # NOTE: since negative values are not supported by fixed formats since # isqrt only outputs positive results we cannot test every single com- diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py index cbf9cf776..a0801d23a 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py @@ -22,7 +22,7 @@ def __init__(self, dut) -> None: self.assign_self_params(["WIDTH", "FRAC_WIDTH"]) def generate_inputs(self): - samples = 2**self.WIDTH + samples = 2 ** self.WIDTH data_x = [] msb_indices = [] for x in range(samples): diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py index 1b54ea013..c4d00af2a 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py @@ -20,7 +20,7 @@ def __init__(self, dut) -> None: self.assign_self_params(["WIDTH"]) def generate_inputs(self): - samples = 2**self.WIDTH + samples = 2 ** self.WIDTH return [val for val in range(0, samples)], samples def model(self, inputs): @@ -71,7 +71,6 @@ async def cocotb_test_fixed_range_reduction(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_fixed_range_reduction(): - def full_sweep(): parameter_list = [] for width in range(1, 17): diff --git a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py index c4723f971..a001b27bc 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py +++ b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py @@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int: def float_to_int(x: float, int_width: int, frac_width: int) -> int: integer = int(x) x -= integer - res = integer * (2**frac_width) + res = integer * (2 ** frac_width) for i in range(1, frac_width + 1): power = 2 ** (-i) if power <= x: @@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int: def int_to_float(x: int, int_width: int, frac_width: int) -> float: - integer = x / (2**frac_width) - fraction = x - integer * 2**frac_width + integer = x / (2 ** frac_width) + fraction = x - integer * 2 ** frac_width res = integer for i in range(1, frac_width + 1): @@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int: res = 0 else: res = x_red - 2 ** (width - 1) - res = res * 2**lut_pow + res = res * 2 ** lut_pow res = res / 2 ** (width - 1) # FORMAT OUTPUT: Q(WIDTH).0 return int(res) @@ -258,7 +258,7 @@ def test_sw_model(): def debug_single(): lut_pow = 5 - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow int_width = 2 frac_width = 1 width = int_width + frac_width diff --git a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py index 815e9ed07..5734fb891 100644 --- a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py +++ b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py @@ -147,10 +147,7 @@ def data_generate(self): weight_tensor = {} \n\ data_in = {} \n\ weight_in = {} ".format( - data_tensor, - weight_tensor, - data_in, - weight_in, + data_tensor, weight_tensor, data_in, weight_in, ) ) data_in.reverse() @@ -381,8 +378,7 @@ def runner(): build_args=extra_args, ) runner.test( - hdl_toplevel="fixed_matmul", - test_module="fixed_matmul_tb", + hdl_toplevel="fixed_matmul", test_module="fixed_matmul_tb", ) diff --git a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py index 9f058695c..0cd5b07b9 100644 --- a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py +++ b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py @@ -40,8 +40,8 @@ def __init__(self, dut) -> None: ] ) - self.X_MAX = 2**self.X_WIDTH - 1 - self.Y_MAX = 2**self.Y_WIDTH - 1 + self.X_MAX = 2 ** self.X_WIDTH - 1 + self.Y_MAX = 2 ** self.Y_WIDTH - 1 self.x_driver = StreamDriver(dut.clk, dut.x_data, dut.x_valid, dut.x_ready) self.y_driver = StreamDriver(dut.clk, dut.y_data, dut.y_valid, dut.y_ready) @@ -87,10 +87,10 @@ def model(self, X, Y): logger.debug("Sign Extended & Scaled") X_input = sign_extend_t(X_input, self.X_WIDTH).type(torch.float32) / ( - 2**self.X_FRAC_WIDTH + 2 ** self.X_FRAC_WIDTH ) Y_input = sign_extend_t(Y_input, self.Y_WIDTH).type(torch.float32) / ( - 2**self.Y_FRAC_WIDTH + 2 ** self.Y_FRAC_WIDTH ) logger.debug(X_input) logger.debug(Y_input) diff --git a/src/mase_components/linear_layers/matmul/test/transpose_tb.py b/src/mase_components/linear_layers/matmul/test/transpose_tb.py index c859749ed..003382ff6 100644 --- a/src/mase_components/linear_layers/matmul/test/transpose_tb.py +++ b/src/mase_components/linear_layers/matmul/test/transpose_tb.py @@ -56,11 +56,7 @@ def generate_random_params(num=3): cfgs = list() for _ in range(num): cfgs.append( - { - "WIDTH": randint(1, 16), - "DIM0": randint(2, 12), - "DIM1": randint(2, 12), - } + {"WIDTH": randint(1, 16), "DIM0": randint(2, 12), "DIM1": randint(2, 12),} ) return cfgs diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py index 963ae40df..ce742183b 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py @@ -29,10 +29,7 @@ def __init__(self, dut, num) -> None: self.log = SimLog("%s" % (type(self).__qualname__)) self.data_in_0_driver = MultiSignalStreamDriver( - dut.clk, - (dut.mdata_in, dut.edata_in), - dut.data_in_valid, - dut.data_in_ready, + dut.clk, (dut.mdata_in, dut.edata_in), dut.data_in_valid, dut.data_in_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -51,14 +48,10 @@ def generate_inputs(self): for _ in range(self.num): data = 20 * torch.rand(int(self.dut.BLOCK_SIZE)) (data_in, mdata_in, edata_in) = mxint_quantize( - data, - int(self.dut.IN_MAN_WIDTH), - int(self.dut.IN_EXP_WIDTH), + data, int(self.dut.IN_MAN_WIDTH), int(self.dut.IN_EXP_WIDTH), ) exp_out, mexp_out, eexp_out = mxint_quantize( - data_in, - int(self.dut.OUT_MAN_WIDTH), - int(self.dut.OUT_EXP_WIDTH), + data_in, int(self.dut.OUT_MAN_WIDTH), int(self.dut.OUT_EXP_WIDTH), ) breakpoint() inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py index 43e70adcf..cb1b9638f 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py @@ -40,10 +40,7 @@ def __init__(self, dut, num) -> None: dut.data_in_0_ready, ) self.weight_driver = MultiSignalStreamDriver( - dut.clk, - (dut.mweight, dut.eweight), - dut.weight_valid, - dut.weight_ready, + dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -67,9 +64,7 @@ def generate_inputs(self): ) w = torch.rand(int(self.dut.BLOCK_SIZE)) (weight, mweight, eweight) = mxint_quantize( - w, - int(self.dut.WEIGHT_PRECISION_0), - int(self.dut.WEIGHT_PRECISION_1), + w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1), ) mdp_out = mdata_in @ mweight edp_out = edata_in + eweight diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py index f31d1ab61..e01e056e1 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py @@ -48,16 +48,10 @@ def __init__(self, dut) -> None: self.log.setLevel(logging.DEBUG) self.a_driver = MultiSignalStreamDriver( - dut.clk, - (dut.ma_data, dut.ea_data), - dut.a_valid, - dut.a_ready, + dut.clk, (dut.ma_data, dut.ea_data), dut.a_valid, dut.a_ready, ) self.b_driver = MultiSignalStreamDriver( - dut.clk, - (dut.mb_data, dut.eb_data), - dut.b_valid, - dut.b_ready, + dut.clk, (dut.mb_data, dut.eb_data), dut.b_valid, dut.b_ready, ) self.output_monitor = MultiSignalStreamMonitor( diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py index 58fc472ba..06f970039 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py @@ -35,10 +35,7 @@ def __init__(self, dut, num) -> None: dut.data_in_0_ready, ) self.weight_driver = MultiSignalStreamDriver( - dut.clk, - (dut.mweight, dut.eweight), - dut.weight_valid, - dut.weight_ready, + dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -66,9 +63,7 @@ def generate_inputs(self): w = 20 * torch.rand(int(self.dut.BLOCK_SIZE)) (weight, mweight, eweight) = mxint_quantize( - w, - int(self.dut.WEIGHT_PRECISION_0), - int(self.dut.WEIGHT_PRECISION_1), + w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1), ) exp_out, mexp_out, eexp_out = mxint_quantize( data_in * weight, diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.py b/src/mase_components/linear_layers/mxint_operators/test/test.py index f58f382cf..85cdaa995 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/test.py +++ b/src/mase_components/linear_layers/mxint_operators/test/test.py @@ -8,16 +8,8 @@ d_man_width = 12 w_man_width = 8 e_width = 4 -(data_in, mdata_in, edata_in) = mxint_quantize( - data, - d_man_width, - e_width, -) -(weight, mweight, eweight) = mxint_quantize( - w, - w_man_width, - e_width, -) +(data_in, mdata_in, edata_in) = mxint_quantize(data, d_man_width, e_width,) +(weight, mweight, eweight) = mxint_quantize(w, w_man_width, e_width,) linear = torch.nn.Linear(10, 10, bias=False) linear.weight = torch.nn.Parameter(weight) target_x = linear(data_in) @@ -36,7 +28,7 @@ def hardware_quant(hardware_in, be_value, e_width, width): exponent_bias = 2 ** (e_width - 1) - 1 # exponent - exponent_max = 2**e_width - 1 - exponent_bias + exponent_max = 2 ** e_width - 1 - exponent_bias exponent_min = -exponent_bias exponent = ( torch.ceil(torch.log2(hardware_in.abs().max())) + be_value - exponent_bias @@ -48,7 +40,7 @@ def hardware_quant(hardware_in, be_value, e_width, width): breakpoint() mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2**exponent) * mantissa + msfp_x = (2 ** exponent) * mantissa return msfp_x, mantissa, exponent diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py index 7edb9f6ed..43b3b0b87 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/utils.py +++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py @@ -24,7 +24,7 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = """ exponent_bias = 2 ** (exponent_width - 1) - exponent_max = 2**exponent_width - 1 - exponent_bias + exponent_max = 2 ** exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # exponent @@ -34,9 +34,9 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = # mantissa int_min = -(2 ** (width - 1)) int_max = 2 ** (width - 1) - 1 - mantissa = x / 2**exponent + mantissa = x / 2 ** exponent mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2**exponent) * mantissa + msfp_x = (2 ** exponent) * mantissa return msfp_x, mantissa, exponent diff --git a/src/mase_components/memory/test/fifo_tb.py b/src/mase_components/memory/test/fifo_tb.py index 9dc4f54ff..5f247bede 100644 --- a/src/mase_components/memory/test/fifo_tb.py +++ b/src/mase_components/memory/test/fifo_tb.py @@ -27,7 +27,7 @@ def __init__(self, dut) -> None: # self.output_monitor.log.setLevel("DEBUG") def generate_inputs(self, num=20): - return [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(num)] + return [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(num)] @cocotb.test() @@ -120,12 +120,7 @@ async def cocotb_test_soak(dut): @pytest.mark.dev def test_fifo(): mase_runner( - module_param_list=[ - {"DEPTH": 1}, - {"DEPTH": 7}, - {"DEPTH": 8}, - {"DEPTH": 81}, - ], + module_param_list=[{"DEPTH": 1}, {"DEPTH": 7}, {"DEPTH": 8}, {"DEPTH": 81},], trace=True, ) diff --git a/src/mase_components/memory/test/repeat_circular_buffer_tb.py b/src/mase_components/memory/test/repeat_circular_buffer_tb.py index 11d7d85b1..db6bfc7b8 100644 --- a/src/mase_components/memory/test/repeat_circular_buffer_tb.py +++ b/src/mase_components/memory/test/repeat_circular_buffer_tb.py @@ -30,7 +30,7 @@ def generate_inputs(self, num=10): inputs = [] for _ in range(num): inputs.extend( - [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)] + [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)] ) return inputs diff --git a/src/mase_components/memory/test/unpacked_fifo_tb.py b/src/mase_components/memory/test/unpacked_fifo_tb.py index cf0a99592..734a2edcb 100644 --- a/src/mase_components/memory/test/unpacked_fifo_tb.py +++ b/src/mase_components/memory/test/unpacked_fifo_tb.py @@ -26,7 +26,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=20): return [ - [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)] + [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)] for _ in range(batches) ] diff --git a/src/mase_components/normalization_layers/process_synth_impl.py b/src/mase_components/normalization_layers/process_synth_impl.py index c65d1b8a9..5a12ee170 100644 --- a/src/mase_components/normalization_layers/process_synth_impl.py +++ b/src/mase_components/normalization_layers/process_synth_impl.py @@ -103,19 +103,14 @@ def gather_data(build_dir: Path): if __name__ == "__main__": data = gather_data(Path("build")) data["ns"] = data["clk_period"] - data["wns"] - data["fmax"] = 1 / (data["ns"] * (10**-9)) + data["fmax"] = 1 / (data["ns"] * (10 ** -9)) data["fmax_mhz"] = data["fmax"] / 1_000_000 print(data) def plot(col): - alt.Chart(data).mark_line().encode( - x="width", - y=col, - color="norm", - ).properties( - width=400, - height=200, + alt.Chart(data).mark_line().encode(x="width", y=col, color="norm",).properties( + width=400, height=200, ).save(f"{col}_plot.png", scale_factor=3) plot("wns") diff --git a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py index 02297dd3b..d43976845 100644 --- a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py @@ -158,7 +158,7 @@ def model(self, inputs): -1, self.NUM_CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0 ) x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / ( - 2**self.IN_FRAC_WIDTH + 2 ** self.IN_FRAC_WIDTH ) # Float Model @@ -344,7 +344,6 @@ async def valid_backpressure(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_batch_norm_2d(): - def gen_cfg( total_dim0: int = 4, total_dim1: int = 4, @@ -398,8 +397,7 @@ def gen_cfg( ] mase_runner( - module_param_list=test_cfgs, - trace=True, + module_param_list=test_cfgs, trace=True, ) diff --git a/src/mase_components/normalization_layers/test/channel_selection_tb.py b/src/mase_components/normalization_layers/test/channel_selection_tb.py index ceb65df8a..0fb603971 100644 --- a/src/mase_components/normalization_layers/test/channel_selection_tb.py +++ b/src/mase_components/normalization_layers/test/channel_selection_tb.py @@ -61,7 +61,6 @@ async def basic(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_channel_selection(): - def gen_cfg(num_channels, num_blocks): return {"NUM_CHANNELS": num_channels, "NUM_SPATIAL_BLOCKS": num_blocks} diff --git a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py index e31203c25..cffd16822 100644 --- a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py @@ -79,7 +79,7 @@ def __init__(self, dut) -> None: self.out_width_tup = self.OUT_WIDTH, self.OUT_FRAC_WIDTH # Inverse Square Root LUT - self.isqrt_lut = make_lut(2**5, 16) + self.isqrt_lut = make_lut(2 ** 5, 16) self.num_groups = randint(2, 3) self.total_channels = self.GROUP_CHANNELS * self.num_groups @@ -141,7 +141,7 @@ def model(self, inputs): -1, self.total_channels, self.TOTAL_DIM1, self.TOTAL_DIM0 ) x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / ( - 2**self.IN_FRAC_WIDTH + 2 ** self.IN_FRAC_WIDTH ) # Float Model @@ -239,12 +239,7 @@ def test_group_norm_2d(): makedirs(mem_dir, exist_ok=True) def isqrt_width( - total_dim0, - total_dim1, - compute_dim0, - compute_dim1, - group_channels, - in_width, + total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width, ): depth_dim0 = total_dim0 // compute_dim0 depth_dim1 = total_dim1 // compute_dim1 @@ -274,7 +269,7 @@ def gen_cfg( isqrt_w = isqrt_width( total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width ) - lut = make_lut(2**LUT_POW, isqrt_w) + lut = make_lut(2 ** LUT_POW, isqrt_w) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, isqrt_w) params = { diff --git a/src/mase_components/normalization_layers/test/models.py b/src/mase_components/normalization_layers/test/models.py index 0f10b228e..dfeadac98 100644 --- a/src/mase_components/normalization_layers/test/models.py +++ b/src/mase_components/normalization_layers/test/models.py @@ -56,11 +56,11 @@ def _fixed_group_norm_2d_model( logger.debug("Diff:") logger.debug(diff[0]) - squares = diff**2 + squares = diff ** 2 logger.debug("Squares:") logger.debug(squares[0]) - squares_int = (squares * (2**square_frac_width)).int() - logger.debug(squares * (2**square_frac_width)) + squares_int = (squares * (2 ** square_frac_width)).int() + logger.debug(squares * (2 ** square_frac_width)) sum_squares = torch.sum(squares, dim=(1, 2, 3), keepdim=True) sum_squares = integer_floor_quantizer( @@ -81,10 +81,12 @@ def _fixed_group_norm_2d_model( logger.debug(f"{var[0]}") # Clamp down variance to isqrt width - var_clamp = torch.clamp(var, 0.0, ((2**isqrt_width) - 1) / (2**isqrt_frac_width)) + var_clamp = torch.clamp( + var, 0.0, ((2 ** isqrt_width) - 1) / (2 ** isqrt_frac_width) + ) logger.debug("Variance Clamped:") logger.debug(f"{var_clamp[0]}") - var_clamp_int = (var_clamp * (2**isqrt_frac_width)).int() + var_clamp_int = (var_clamp * (2 ** isqrt_frac_width)).int() # Inverse Square Root calculation lut_pow = ceil(log2(len(isqrt_lut))) @@ -104,7 +106,7 @@ def _fixed_group_norm_2d_model( logger.debug("INV SQRT INT:") logger.debug(f"{inv_sqrt_int[0]}") - inv_sqrt = inv_sqrt_int / (2**isqrt_frac_width) + inv_sqrt = inv_sqrt_int / (2 ** isqrt_frac_width) logger.debug("Inverse SQRT:") logger.debug(f"{inv_sqrt[0]}") diff --git a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py index 22ef8bdd1..0c998f83d 100644 --- a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py @@ -151,7 +151,7 @@ def reconstruct_tensor(self, x, width, frac_width): x = torch.stack(matrix_list).reshape( -1, self.CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0 ) - x = sign_extend_t(x, width).to(dtype=torch.float32) / (2**frac_width) + x = sign_extend_t(x, width).to(dtype=torch.float32) / (2 ** frac_width) return x def output_monitor_split(self, x, width, frac_width): @@ -248,12 +248,7 @@ def test_rms_norm_2d(): makedirs(mem_dir, exist_ok=True) def isqrt_width( - total_dim0, - total_dim1, - compute_dim0, - compute_dim1, - group_channels, - in_width, + total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width, ): depth_dim0 = total_dim0 // compute_dim0 depth_dim1 = total_dim1 // compute_dim1 @@ -282,7 +277,7 @@ def gen_cfg( isqrt_w = isqrt_width( total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width ) - lut = make_lut(2**LUT_POW, isqrt_w) + lut = make_lut(2 ** LUT_POW, isqrt_w) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, isqrt_w) params = { diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py index 6ee7a93dd..f0f5cafb1 100644 --- a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py +++ b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py @@ -24,31 +24,23 @@ class VerificationCase(Testbench): def __init__(self, dut): super().__init__(dut, dut.clk, dut.rst) self.assign_self_params( - [ - "IN_WIDTH", - "IN_FRAC_WIDTH", - "LUT_POW", - ] + ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",] ) self.input_driver = StreamDriver( dut.clk, dut.in_data, dut.in_valid, dut.in_ready ) self.output_monitor = StreamMonitor( - dut.clk, - dut.out_data, - dut.out_valid, - dut.out_ready, - name="Output ISQRT", + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT", ) def generate_inputs(self, num=10000): - maxnum = (2**self.IN_WIDTH) - 1 + maxnum = (2 ** self.IN_WIDTH) - 1 return [random.randint(0, maxnum) for _ in range(num)], num def model(self, data_in): ref = [] - lut_size = 2**self.LUT_POW + lut_size = 2 ** self.LUT_POW lut = make_lut(lut_size, self.IN_WIDTH) for x in data_in: expected = isqrt_sw2( @@ -131,7 +123,7 @@ def test_fixed_isqrt(): makedirs(mem_dir, exist_ok=True) def single_cfg(width, frac_width, lut_pow, str_id): - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow lut = make_lut(lut_size, width) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, width) diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py index d09305ef0..5a3019f8e 100644 --- a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py +++ b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py @@ -24,12 +24,12 @@ def __init__(self, dut): self.assign_self_params(["WIDTH"]) def generate_inputs(self, lut_pow): - samples = 2**self.WIDTH + samples = 2 ** self.WIDTH int_width = 1 frac_width = self.WIDTH - 1 data_x = [] initial_guesses = [] - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow lut = make_lut(lut_size, self.WIDTH) # NOTE: since negative values are not supported by fixed formats since # isqrt only outputs positive results we cannot test every single com- diff --git a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py index f6fc798a4..901ba964d 100644 --- a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py +++ b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py @@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int: def float_to_int(x: float, int_width: int, frac_width: int) -> int: integer = int(x) x -= integer - res = integer * (2**frac_width) + res = integer * (2 ** frac_width) for i in range(1, frac_width + 1): power = 2 ** (-i) if power <= x: @@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int: def int_to_float(x: int, int_width: int, frac_width: int) -> float: - integer = x / (2**frac_width) - fraction = x - integer * 2**frac_width + integer = x / (2 ** frac_width) + fraction = x - integer * 2 ** frac_width res = integer for i in range(1, frac_width + 1): @@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int: res = 0 else: res = x_red - 2 ** (width - 1) - res = res * 2**lut_pow + res = res * 2 ** lut_pow res = res / 2 ** (width - 1) # FORMAT OUTPUT: Q(WIDTH).0 return int(res) @@ -258,7 +258,7 @@ def test_isqrt_sw_model(): def debug_single(): lut_pow = 5 - lut_size = 2**lut_pow + lut_size = 2 ** lut_pow int_width = 2 frac_width = 1 width = int_width + frac_width diff --git a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py index cc7d23428..a334a1469 100644 --- a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py +++ b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py @@ -89,13 +89,16 @@ def forward(self, x: Tensor): out = self.o_projection(attn_output) - return out, { - "query": query, - "key": key.transpose(1, 2), # Key is transposed in hardware - "value": value, - "heads_out": heads_out, - "attn_output": attn_output, - } + return ( + out, + { + "query": query, + "key": key.transpose(1, 2), # Key is transposed in hardware + "value": value, + "heads_out": heads_out, + "attn_output": attn_output, + }, + ) class FixedGroupedQueryAttentionTB(Testbench): @@ -192,28 +195,16 @@ def __init__(self, dut) -> None: if self.HAS_BIAS == 1: self.bias_q_driver = StreamDriver( - dut.clk, - dut.bias_query, - dut.bias_query_valid, - dut.bias_query_ready, + dut.clk, dut.bias_query, dut.bias_query_valid, dut.bias_query_ready, ) self.bias_k_driver = StreamDriver( - dut.clk, - dut.bias_key, - dut.bias_key_valid, - dut.bias_key_ready, + dut.clk, dut.bias_key, dut.bias_key_valid, dut.bias_key_ready, ) self.bias_v_driver = StreamDriver( - dut.clk, - dut.bias_value, - dut.bias_value_valid, - dut.bias_value_ready, + dut.clk, dut.bias_value, dut.bias_value_valid, dut.bias_value_ready, ) self.bias_o_driver = StreamDriver( - dut.clk, - dut.bias_output, - dut.bias_output_valid, - dut.bias_output_ready, + dut.clk, dut.bias_output, dut.bias_output_valid, dut.bias_output_ready, ) self.error_threshold = 2 @@ -478,11 +469,11 @@ async def run_memory_bandwidth_test(self, us: int = 500): num_v_weight_beats_sent = self.weight_v_driver.num_beats num_o_weight_beats_sent = self.weight_o_driver.num_beats - input_beats_per_sec = num_input_beats_sent / (nanosec * (10**-9)) - num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10**-9)) - num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10**-9)) - num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10**-9)) - num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10**-9)) + input_beats_per_sec = num_input_beats_sent / (nanosec * (10 ** -9)) + num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10 ** -9)) + num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10 ** -9)) + num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10 ** -9)) + num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10 ** -9)) self.log.info("Test length (ns): %.4f" % nanosec) @@ -600,9 +591,7 @@ def test_fixed_linear_smoke(): ] mase_runner( - module_param_list=cfgs, - hierarchical=True, - template=True, + module_param_list=cfgs, hierarchical=True, template=True, ) @@ -614,9 +603,7 @@ def test_parallelism_sweep(): cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par)) mase_runner( - module_param_list=cfgs, - hierarchical=True, - template=True, + module_param_list=cfgs, hierarchical=True, template=True, ) @@ -628,9 +615,7 @@ def test_small_parallelism(): cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par)) mase_runner( - module_param_list=cfgs, - hierarchical=True, - template=True, + module_param_list=cfgs, hierarchical=True, template=True, ) @@ -640,9 +625,7 @@ def test_heads_sweep(): cfgs.append(get_config(256, 256, 16, kv_heads, 16, 1)) mase_runner( - module_param_list=cfgs, - hierarchical=True, - template=True, + module_param_list=cfgs, hierarchical=True, template=True, ) @@ -654,9 +637,7 @@ def test_bitwidth_sweep(): ) mase_runner( - module_param_list=cfgs, - hierarchical=True, - template=True, + module_param_list=cfgs, hierarchical=True, template=True, ) diff --git a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py index a46032646..70df56d43 100644 --- a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py +++ b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py @@ -44,11 +44,7 @@ def __init__(self, dut) -> None: ) self.out_monitor = StreamMonitor( - dut.clk, - dut.out, - dut.out_valid, - dut.out_ready, - check=False, + dut.clk, dut.out, dut.out_valid, dut.out_ready, check=False, ) # Model @@ -60,8 +56,7 @@ def __init__(self, dut) -> None: "frac_width": self.get_parameter("IN_DATA_PRECISION_1"), } self.model = BertSelfAttentionHeadInteger( - config=self.config, - q_config=self.q_config, + config=self.config, q_config=self.q_config, ) # Set verbosity of driver and monitor loggers to debug @@ -119,27 +114,21 @@ async def run_test(self): # * Load the query driver self.log.info(f"Processing query inputs: {inputs['query_layer']}") query_inputs = self.preprocess_tensor( - tensor=inputs["query_layer"], - config=self.q_config, - parallelism=parallelism, + tensor=inputs["query_layer"], config=self.q_config, parallelism=parallelism, ) self.query_driver.load_driver(query_inputs) # * Load the key driver self.log.info(f"Processing key inputs: {inputs['key_layer']}") key_inputs = self.preprocess_tensor( - tensor=inputs["key_layer"], - config=self.q_config, - parallelism=parallelism, + tensor=inputs["key_layer"], config=self.q_config, parallelism=parallelism, ) self.key_driver.load_driver(key_inputs) # * Load the value driver self.log.info(f"Processing value inputs: {inputs['value_layer']}") value_inputs = self.preprocess_tensor( - tensor=inputs["value_layer"], - config=self.q_config, - parallelism=parallelism, + tensor=inputs["value_layer"], config=self.q_config, parallelism=parallelism, ) self.value_driver.load_driver(value_inputs) @@ -198,18 +187,11 @@ def test_fixed_self_attention_head_smoke(): # * Generate exponential LUT for softmax generate_memory.generate_sv_lut( - "exp", - 16, - 3, - 16, - 3, - path=Path(__file__).parents[1] / "rtl", + "exp", 16, 3, 16, 3, path=Path(__file__).parents[1] / "rtl", ) mase_runner( trace=True, - module_param_list=[ - get_fixed_self_attention_head_config(), - ], + module_param_list=[get_fixed_self_attention_head_config(),], skip_build=False, ) diff --git a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py index 2e19bed4b..823c65db4 100644 --- a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py @@ -108,8 +108,7 @@ def __init__(self, samples=1): depth_in_num = int(self.in_num / self.tile_in_num) depth_out_features = int(self.out_features / self.tile_out_features) self.outputs = RandomSink( - samples=samples * depth_out_features * depth_in_num, - debug=debug, + samples=samples * depth_out_features * depth_in_num, debug=debug, ) self.ref = self.sw_compute() @@ -471,8 +470,7 @@ def runner(): ) for _ in range(1): runner.test( - hdl_toplevel="fixed_mlp", - test_module="fixed_mlp_tb", + hdl_toplevel="fixed_mlp", test_module="fixed_mlp_tb", ) diff --git a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py index 05cf0f2cf..b92e6dc4f 100644 --- a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py @@ -50,14 +50,14 @@ def __init__(self, samples=1): self.pe_unroll_kernel_out = 3 self.pe_unroll_in_c = 3 self.pe_unroll_embed_dim = 8 - self.num_patch = int(self.in_y * self.in_x // (self.patch_size**2)) + self.num_patch = int(self.in_y * self.in_x // (self.patch_size ** 2)) # self.num_classes = 10 # self.head_unroll_out_x = 5 self.samples = samples self.pe_iter_weight = int( - (self.patch_size**2) + (self.patch_size ** 2) * self.in_c * self.embed_dim / self.pe_unroll_kernel_out @@ -247,10 +247,7 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape( - -1, - unroll_out_channels * unroll_kernel_out, - ) + w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) w_in = w_tensor.type(torch.int).flip(0).tolist() # bias_pack bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels) diff --git a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py index 06b921e81..532e3fb07 100644 --- a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py @@ -49,7 +49,7 @@ def __init__(self, samples=1): self.in_x = 224 self.embed_dim = 384 self.patch_size = 16 - self.num_patch = self.in_y * self.in_x // (self.patch_size**2) + self.num_patch = self.in_y * self.in_x // (self.patch_size ** 2) self.num_heads = 6 self.mlp_ratio = 2 @@ -64,7 +64,7 @@ def __init__(self, samples=1): self.head_unroll_out_x = 1 self.pe_iter_weight = int( - (self.patch_size**2) + (self.patch_size ** 2) * self.in_c * self.embed_dim / self.pe_unroll_kernel_out @@ -1148,10 +1148,7 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape( - -1, - unroll_out_channels * unroll_kernel_out, - ) + w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) w_in = w_tensor.type(torch.int).flip(0).tolist() # bias_pack bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels) @@ -1514,8 +1511,7 @@ def runner(): build_args=extra_args, ) runner.test( - hdl_toplevel="fixed_pvt", - test_module="fixed_pvt_tb", + hdl_toplevel="fixed_pvt", test_module="fixed_pvt_tb", ) diff --git a/src/mase_components/vision_models/vit/test/hash_exp_tb.py b/src/mase_components/vision_models/vit/test/hash_exp_tb.py index e8b107503..719ee89dd 100644 --- a/src/mase_components/vision_models/vit/test/hash_exp_tb.py +++ b/src/mase_components/vision_models/vit/test/hash_exp_tb.py @@ -147,9 +147,7 @@ def runner(): print(extra_args) runner = get_runner(sim) runner.build( - verilog_sources=verilog_sources, - hdl_toplevel="hash_exp", - build_args=extra_args, + verilog_sources=verilog_sources, hdl_toplevel="hash_exp", build_args=extra_args, ) runner.test(hdl_toplevel="hash_exp", test_module="hash_exp_tb") diff --git a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py index 98ceea63f..5c0522fb9 100644 --- a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py +++ b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py @@ -45,11 +45,7 @@ def __init__(self, samples=1): }, } self.d_config = { - "softmax": { - "in_size": 1, - "out_size": 1, - "in_depth": 4, - }, + "softmax": {"in_size": 1, "out_size": 1, "in_depth": 4,}, } in_size = self.d_config["softmax"]["in_size"] out_size = self.d_config["softmax"]["out_size"] diff --git a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py index 62e5b840d..8dac8afee 100644 --- a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py +++ b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py @@ -5,8 +5,8 @@ def quantize_to_int(x: Tensor, width: int, frac_width: int): - x = _integer_quantize(x, width, frac_width) * (2**frac_width) - x = x.int() & (2**width - 1) + x = _integer_quantize(x, width, frac_width) * (2 ** frac_width) + x = x.int() & (2 ** width - 1) return x @@ -25,7 +25,7 @@ def twos_complement_to_float(binary_string: str, width: int, frac_width: int): integer_magnitude = -(2 ** (width - 1)) + integer_magnitude # Calculate scaling factor - scaling_factor = 2**frac_width + scaling_factor = 2 ** frac_width # Calculate floating-point value float_value = integer_magnitude / scaling_factor @@ -79,8 +79,7 @@ def generate_table_div_software(width, out_width, out_frac_width): class QHashSoftmax(torch.nn.Module): def __init__( - self, - config, + self, config, ): super(QHashSoftmax, self).__init__() self.in_width = config["data_in_width"] @@ -109,7 +108,7 @@ def forward(self, x, scale): # quantize to div_width one_over_div = _integer_quantize(exp_sum // exp, self.div_width + 1, 0) one_over_div = torch.where( - exp == 0, torch.tensor(2**self.div_width - 1), one_over_div + exp == 0, torch.tensor(2 ** self.div_width - 1), one_over_div ) one_over_div = torch.tensor(one_over_div, dtype=int) diff --git a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py index d7f4e8464..4863f2ca9 100644 --- a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py +++ b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py @@ -50,7 +50,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 + self.scale = qk_scale or head_dim ** -0.5 self.q = get_quantized_cls("linear", config["q_proj"])( dim, dim, bias=qkv_bias, config=config["q_proj"] diff --git a/test/nn/quantized/modules/attention_head.py b/test/nn/quantized/modules/attention_head.py index e0f52a61e..72b625f90 100644 --- a/test/nn/quantized/modules/attention_head.py +++ b/test/nn/quantized/modules/attention_head.py @@ -1,6 +1,4 @@ -from chop.nn.quantized.modules.attention_head import ( - BertSelfAttentionHeadInteger, -) +from chop.nn.quantized.modules.attention_head import BertSelfAttentionHeadInteger from transformers import AutoConfig import torch diff --git a/test/nn/snn/test_ann2snn.py b/test/nn/snn/test_ann2snn.py index 1b8bbbcfa..3a5d2b7e2 100644 --- a/test/nn/snn/test_ann2snn.py +++ b/test/nn/snn/test_ann2snn.py @@ -191,13 +191,7 @@ def val(net, device, data_loader, T=None): "by": "type", "default": {"config": {"name": None}}, "fuse": True, - "relu": { - "config": { - "name": "IFNode", - "mode": "99.9%", - "momentum": 0.1, - } - }, + "relu": {"config": {"name": "IFNode", "mode": "99.9%", "momentum": 0.1,}}, "train_data_loader": input_generator, "device": "cpu", # "device": "cuda", } diff --git a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py index 9740bdd38..a7bce45cd 100644 --- a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py +++ b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py @@ -27,22 +27,10 @@ def add_common_metadata(model_cls_name: str) -> MaseGraph: # mg.fx_graph.print_tabular() input_ids = torch.randint( - 0, - config.vocab_size, - ( - 1, - 128, - config.hidden_size, - ), - device="meta", + 0, config.vocab_size, (1, 128, config.hidden_size,), device="meta", ) mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": input_ids, - }, - }, + mg, pass_args={"dummy_in": {"input_ids": input_ids,},}, ) return mg diff --git a/test/passes/graph/analysis/pruning/test_hook_inspect.py b/test/passes/graph/analysis/pruning/test_hook_inspect.py index dfa7fa0fd..b34b579d3 100644 --- a/test/passes/graph/analysis/pruning/test_hook_inspect.py +++ b/test/passes/graph/analysis/pruning/test_hook_inspect.py @@ -111,15 +111,9 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": [ - "conv2d", - ], - "target_activation_nodes": [ - "conv2d", - ], - "weight_statistics": { - "variance_precise": {"device": "cpu", "dims": "all"}, - }, + "target_weight_nodes": ["conv2d",], + "target_activation_nodes": ["conv2d",], + "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py index 2b9828ace..084170c70 100644 --- a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py +++ b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py @@ -52,9 +52,7 @@ def test_statistic_profiler(): dataset_info = get_dataset_info("cifar10") model = get_model( - checkpoint="resnet18", - pretrained=False, - dataset_info=dataset_info, + checkpoint="resnet18", pretrained=False, dataset_info=dataset_info, ) dummy_in = {"x": next(iter(datamodule.train_dataloader()))[0]} @@ -68,15 +66,9 @@ def test_statistic_profiler(): pass_arg = { "by": "type", - "target_weight_nodes": [ - "conv2d", - ], - "target_activation_nodes": [ - "relu", - ], - "weight_statistics": { - "variance_precise": {"device": "cpu", "dims": "all"}, - }, + "target_weight_nodes": ["conv2d",], + "target_activation_nodes": ["relu",], + "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/prune/test_prune.py b/test/passes/graph/transforms/prune/test_prune.py index 8bddad906..8c6de1d2a 100644 --- a/test/passes/graph/transforms/prune/test_prune.py +++ b/test/passes/graph/transforms/prune/test_prune.py @@ -106,15 +106,9 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": [ - "conv2d", - ], - "target_activation_nodes": [ - "conv2d", - ], - "weight_statistics": { - "variance_precise": {"device": "cpu", "dims": "all"}, - }, + "target_weight_nodes": ["conv2d",], + "target_activation_nodes": ["conv2d",], + "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/prune/test_prune_detach_hook.py b/test/passes/graph/transforms/prune/test_prune_detach_hook.py index bf15caa86..f7afe9a8d 100644 --- a/test/passes/graph/transforms/prune/test_prune_detach_hook.py +++ b/test/passes/graph/transforms/prune/test_prune_detach_hook.py @@ -105,15 +105,9 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": [ - "conv2d", - ], - "target_activation_nodes": [ - "conv2d", - ], - "weight_statistics": { - "variance_precise": {"device": "cpu", "dims": "all"}, - }, + "target_weight_nodes": ["conv2d",], + "target_activation_nodes": ["conv2d",], + "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py index 89327f8f0..5ca6f6b27 100644 --- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py +++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py @@ -181,8 +181,7 @@ def test_quantize_lutnet_conv2d(): first_table = connection[0] assert any(initialized_weight[0, :] == first_table) and any( initialized_weight[ - input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), - :, + input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), :, ] == first_table ) diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py index f1115cf56..72a0ce71d 100644 --- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py +++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py @@ -10,9 +10,7 @@ from pathlib import Path sys.path.append(Path(__file__).resolve().parents[5].as_posix()) -from chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets def generate_input_tensor(batch_size, input_features, min_val, max_val): diff --git a/test/passes/graph/transforms/training/test_training_base_pass.py b/test/passes/graph/transforms/training/test_training_base_pass.py index 04ed2167c..e46777de0 100644 --- a/test/passes/graph/transforms/training/test_training_base_pass.py +++ b/test/passes/graph/transforms/training/test_training_base_pass.py @@ -17,9 +17,7 @@ verify_common_metadata_analysis_pass, ) from chop.ir.graph.mase_graph import MaseGraph -from chop.passes.graph.transforms import ( - training_base_pass, -) +from chop.passes.graph.transforms import training_base_pass from chop.passes.graph.utils import deepcopy_mase_graph from chop.tools.logger import set_logging_verbosity @@ -172,11 +170,7 @@ def test_training_base_backward_only(): "default": {"config": {"name": None}}, "linear": { "config": { - "forward": { - "bypass": True, - "pass": "quantize", - "name": "integer", - }, + "forward": {"bypass": True, "pass": "quantize", "name": "integer",}, "backward": { "pass": "quantize", "name": "integer", diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py index d78ac8525..68d405e5a 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py @@ -76,11 +76,7 @@ def test_emit_activation_gelu(): ) config_file = os.path.join( - os.path.abspath(""), - "configs", - "tests", - "quantize", - "fixed.toml", + os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py index ea1a11928..8a59b2e44 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py @@ -77,11 +77,7 @@ def test_emit_activation_selu(): ) config_file = os.path.join( - os.path.abspath(""), - "configs", - "tests", - "quantize", - "fixed.toml", + os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py index 8c7a1e107..9e60ba7a3 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py @@ -76,11 +76,7 @@ def test_emit_activation_softplus(): ) config_file = os.path.join( - os.path.abspath(""), - "configs", - "tests", - "quantize", - "fixed.toml", + os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py index f50f5fce5..3de23e480 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py @@ -76,11 +76,7 @@ def test_emit_activation_softsign(): ) config_file = os.path.join( - os.path.abspath(""), - "configs", - "tests", - "quantize", - "fixed.toml", + os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py index 205211ea6..116d53b36 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py @@ -75,11 +75,7 @@ def test_emit_activation_tanh(): mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) config_file = os.path.join( - os.path.abspath(""), - "configs", - "tests", - "quantize", - "fixed.toml", + os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py index b50fb3776..3e0f16599 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py @@ -174,10 +174,7 @@ def emit_verilog_bert( mg, _ = bert_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, + mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, ) # * Save the metadata to a file for debugging @@ -193,11 +190,7 @@ def emit_verilog_bert( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, + mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py index f268184ce..5447a3782 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py @@ -86,10 +86,7 @@ def emit_verilog_llama( mg, _ = llama_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, + mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, ) # * Save the metadata to a file for debugging @@ -105,11 +102,7 @@ def emit_verilog_llama( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, + mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py index 64625495e..51ae9dc7d 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py @@ -87,10 +87,7 @@ def emit_verilog_mistral( mg, _ = mistral_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, + mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, ) # * Save the metadata to a file for debugging @@ -106,11 +103,7 @@ def emit_verilog_mistral( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, + mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py index daf78fbee..acc659331 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py @@ -11,9 +11,7 @@ from mase_components.scalar_operators.fixed.test.isqrt_sw import make_lut from mase_components.common.test.lut_tb import write_memb from chop.passes.graph.utils import get_module_by_name -from chop.nn.quantizers.quantizers_for_hw import ( - integer_quantizer_for_hw, -) +from chop.nn.quantizers.quantizers_for_hw import integer_quantizer_for_hw # import chop.models.manual.rms_norm as rms @@ -118,7 +116,7 @@ def add_norm_metadata_gen_lut_analysis_pass(mg, config={}): mem_dir = Path(__file__).parent / "build" / "norm" / "mem" makedirs(mem_dir, exist_ok=True) - lut = make_lut(2**LUT_POW, ISQRT_WIDTH) + lut = make_lut(2 ** LUT_POW, ISQRT_WIDTH) mem_path = mem_dir / f"norm_isqrt_lut.mem" write_memb(mem_path, lut, ISQRT_WIDTH) mem_id = 0 @@ -226,23 +224,10 @@ def test_emit_verilog_norm(): shape = [10, 4, 8, 8] normalizations = [ - nn.BatchNorm2d( - num_features=shape[1], - affine=False, - ), - nn.LayerNorm( - normalized_shape=shape[1:], - elementwise_affine=False, - ), - nn.GroupNorm( - num_groups=2, - num_channels=shape[1], - affine=False, - ), - nn.InstanceNorm2d( - num_features=shape[1], - affine=False, - ), + nn.BatchNorm2d(num_features=shape[1], affine=False,), + nn.LayerNorm(normalized_shape=shape[1:], elementwise_affine=False,), + nn.GroupNorm(num_groups=2, num_channels=shape[1], affine=False,), + nn.InstanceNorm2d(num_features=shape[1], affine=False,), # rms.RMSNorm( # normalized_shape=shape[1:], # ), diff --git a/test/passes/onnx/analysis/test_export_fx_graph.py b/test/passes/onnx/analysis/test_export_fx_graph.py index 395991cda..fc081f1ee 100644 --- a/test/passes/onnx/analysis/test_export_fx_graph.py +++ b/test/passes/onnx/analysis/test_export_fx_graph.py @@ -78,16 +78,13 @@ def test_export_fx_graph_bert(): @pytest.mark.skip def test_export_fx_graph_mistral(): - export_fx_graph_model( - "mistral-community/Mistral-7B-v0.2", - ) + export_fx_graph_model("mistral-community/Mistral-7B-v0.2",) @pytest.mark.skip def test_export_fx_graph_whisper(): export_fx_graph_model( - "openai/whisper-tiny", - skip_export=True, + "openai/whisper-tiny", skip_export=True, ) diff --git a/test/self/test_optical_module.py b/test/self/test_optical_module.py new file mode 100644 index 000000000..e56921881 --- /dev/null +++ b/test/self/test_optical_module.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +# This example converts a simple MLP model to Verilog +import logging +import os +import sys + +import torch +import torch.nn as nn + +from torch.profiler import profile, record_function, ProfilerActivity +import torchvision.models as models +import torchvision.transforms as transforms +import torch.utils.data as data +from pathlib import Path + +sys.path.append(Path(__file__).resolve().parents[5].as_posix()) + + +# from chop.passes.module.transforms import quantize_module_transform_pass +from chop.passes.module.transforms import optical_module_transform_pass +from chop.passes.module import report_trainable_parameters_analysis_pass + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +from train_mnist_cnn import test, train, Net, test_memory_detailed + +# -------------------------------------------------- +# Model specifications +# -------------------------------------------------- +# class MLP(torch.nn.Module): +# """ +# Toy quantized FC model for digit recognition on MNIST +# """ + +# def __init__(self) -> None: +# super().__init__() + +# self.fc1 = nn.Linear(28 * 28, 28 * 28) +# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4) +# self.fc3 = nn.Linear(28 * 28 * 4, 10) + +# def forward(self, x): +# x = torch.flatten(x, start_dim=1, end_dim=-1) +# x = torch.nn.functional.relu(self.fc1(x)) +# # w = torch.randn((4, 28 * 28)) +# # x = torch.nn.functional.relu(nn.functional.linear(x, w)) +# x = torch.nn.functional.relu(self.fc2(x)) +# x = self.fc3(x) +# return x + + +def load_my_model(model_path, device): + # Load the model from the .pt file + loaded_model = torch.load(model_path, map_location=device) + # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout) + loaded_model.eval() + return loaded_model + + +def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): + pass_args = { + "by": "type", + "linear": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + "device": device, + } + }, + "conv2d": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + "device": device, + } + }, + } + onn_model, _ = optical_module_transform_pass(model, pass_args) + torch.save(onn_model.state_dict(), save_path) + return onn_model + + +def test_optical_module_transform_pass(): + model_path = "mase_output/sample_mnist_cnn.pt" + mnist_cnn = load_my_model(model_path) + # Sanity check and report + pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + "device": device, + } + }, + } + onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args) + torch.save(onn_cnn, "mase_output/onn_cnn.pt") + + +if __name__ == "__main__": + finetune = True + if True: + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=100, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", + action="store_true", + default=False, + help="disables CUDA training", + ) + parser.add_argument( + "--no-mps", + action="store_true", + default=False, + help="disables macOS GPU training", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=True, + help="For Saving the current Model", + ) + parser.add_argument( + "--gpu-id", type=int, default=1, help="Which GPU device to use [default: 0]" + ) + + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if not args.no_cuda and torch.cuda.is_available(): + device = torch.device(f"cuda:{args.gpu_id}") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if use_cuda: + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST( + "../data", train=True, download=True, transform=transform + ) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) + print("-------------- Testing the original cnn model -------------------") + test(cnn, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(cnn) + + # onn = load_my_model("mase_output/onn_cnn.pt", device) + onn_model = perform_optical_module_transform_pass(cnn) + onn_model.to(device) + + print("-------------- Testing the transformed onn model -------------------") + test(onn_model, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(onn_model) + + ######### Training the onn model + if finetune: + optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + + for epoch in range(1, args.epochs + 1): + train(args, onn_model, device, train_loader, optimizer, epoch) + test(onn_model, device, test_loader) + scheduler.step() + + torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") + + print("-------------- Testing the trained onn model -------------------") + test(onn_model, device, test_loader) + _, _ = report_trainable_parameters_analysis_pass(onn_model) + + # test_optical_module_transform_pass() diff --git a/test/self/train_mnist_cnn.py b/test/self/train_mnist_cnn.py new file mode 100644 index 000000000..a955087da --- /dev/null +++ b/test/self/train_mnist_cnn.py @@ -0,0 +1,247 @@ +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss( + output, target, reduction="sum" + ).item() # sum up batch loss + pred = output.argmax( + dim=1, keepdim=True + ) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) + + +# Custom Function +from torch.profiler import profile, record_function, ProfilerActivity + + +def test_memory_detailed(model, device, test_loader): + """ + Use PyTorch Profiler to record detailed memory usage on the first batch, + so you can see exactly which ops consume how much memory. + """ + + # Put the model in eval mode + model.eval() + + # Get just 1 batch (or a few) for profiling + data_iter = iter(test_loader) + try: + data, target = next(data_iter) + except StopIteration: + print("test_loader is empty. Cannot profile.") + return + + # Move to device + data, target = data.to(device), target.to(device) + + # Now start the profiler + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, # to see tensor shapes + profile_memory=True, # track memory usage + ) as prof: + with record_function("test_first_batch"): + # Forward pass on just this batch + output = model(data) + # Suppose you also compute a loss for some reason + loss = F.nll_loss(output, target, reduction="sum") + # If purely inference, you might not do backward + # but let's illustrate: + loss.backward() # If you want to see backward pass memory usage + + # Print the summarized table + # Sort by self_cuda_memory_usage to see which ops used the most GPU memory + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_memory_usage", + row_limit=200, # show as many rows as you need + ) + ) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=29, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--no-mps", + action="store_true", + default=False, + help="disables macOS GPU training", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=True, + help="For Saving the current Model", + ) + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if use_cuda: + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model, "mase_output/sample_mnist_cnn.pt") + + +if __name__ == "__main__": + main() diff --git a/test/tools/test_onnx_operators.py b/test/tools/test_onnx_operators.py index c15a44d89..7393cb41f 100644 --- a/test/tools/test_onnx_operators.py +++ b/test/tools/test_onnx_operators.py @@ -14,58 +14,20 @@ def excepthook(exc_type, exc_value, exc_traceback): def test_gather(): - data1 = torch.Tensor( - [ - [1.0, 1.2, 1.9], - [2.3, 3.4, 3.9], - [4.5, 5.7, 5.9], - ] - ) + data1 = torch.Tensor([[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9],]) - data2 = torch.Tensor( - [ - [1.0, 1.2], - [2.3, 3.4], - [4.5, 5.7], - ] - ) + data2 = torch.Tensor([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7],]) - indices1 = torch.Tensor( - [ - [0, 2], - ] - ).to(torch.int64) + indices1 = torch.Tensor([[0, 2],]).to(torch.int64) - indices2 = torch.Tensor( - [ - [0, 1], - [1, 2], - ] - ).to(torch.int64) + indices2 = torch.Tensor([[0, 1], [1, 2],]).to(torch.int64) obs_out1 = onnx_gather(data1, 1, indices1) obs_out2 = onnx_gather(data2, 0, indices2) - exp_out1 = torch.Tensor( - [ - [[1.0, 1.9]], - [[2.3, 3.9]], - [[4.5, 5.9]], - ] - ) + exp_out1 = torch.Tensor([[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]],]) - exp_out2 = torch.Tensor( - [ - [ - [1.0, 1.2], - [2.3, 3.4], - ], - [ - [2.3, 3.4], - [4.5, 5.7], - ], - ] - ) + exp_out2 = torch.Tensor([[[1.0, 1.2], [2.3, 3.4],], [[2.3, 3.4], [4.5, 5.7],],]) print(obs_out2) print(exp_out2) @@ -75,12 +37,7 @@ def test_gather(): def test_slice(): - data = torch.Tensor( - [ - [1, 2, 3, 4], - [5, 6, 7, 8], - ] - ) + data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8],]) test1 = onnx_slice( data, From 8aa1dbddbcacef0c8e369b29bc7f3bdc124d844b Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 16 Feb 2025 22:51:29 +0000 Subject: [PATCH 13/38] . --- scripts/mase-hls.py | 30 +++++-- scripts/stat-to-conf.py | 7 +- setup.py | 4 +- src/chop/actions/emit.py | 5 +- src/chop/actions/search/search.py | 4 +- .../quantization/manual_hf_module.py | 3 +- src/chop/actions/simulate.py | 4 +- src/chop/actions/train.py | 3 +- src/chop/dataset/nerf/blender.py | 4 +- src/chop/dataset/nlp/language_modeling.py | 10 ++- src/chop/dataset/vision/transforms/cifar.py | 5 +- src/chop/distributed/launcher.py | 16 +++- src/chop/distributed/tensor/__init__.py | 18 +++- src/chop/distributed/tensor/_dispatch.py | 11 ++- src/chop/distributed/tensor/_redistribute.py | 17 +++- src/chop/distributed/tensor/_sharding_prop.py | 15 ++-- src/chop/distributed/tensor/api.py | 56 ++++++++++--- .../distributed/tensor/ops/basic_strategy.py | 5 +- .../distributed/tensor/ops/common_rules.py | 15 +++- src/chop/distributed/tensor/ops/conv_ops.py | 23 +++-- src/chop/distributed/tensor/ops/math_ops.py | 34 ++++++-- src/chop/distributed/tensor/ops/matrix_ops.py | 10 ++- .../distributed/tensor/ops/pointwise_ops.py | 9 +- src/chop/distributed/tensor/ops/tensor_ops.py | 16 +++- src/chop/distributed/tensor/ops/utils.py | 4 +- src/chop/distributed/tensor/ops/view_ops.py | 9 +- src/chop/ir/graph/mase_graph.py | 49 ++++++++--- src/chop/ir/graph/mase_metadata.py | 4 +- src/chop/ir/onnx/mase_onnx_graph.py | 4 +- src/chop/ir/onnx/utils.py | 9 +- .../models/bert/modeling_bert_quantized.py | 8 +- src/chop/models/bert/quant_config_bert.py | 5 +- src/chop/models/cnv/cnv.py | 9 +- src/chop/models/cswin/cswintransformer.py | 8 +- src/chop/models/deit/deit_v2.py | 2 +- src/chop/models/efficientnet/efficientnet.py | 23 +++-- src/chop/models/lfc/lfc.py | 4 +- src/chop/models/llama/modeling_llama_llora.py | 5 +- .../models/llama/modeling_llama_sparse.py | 5 +- src/chop/models/mobilenet_v2/mobilenet_v2.py | 3 +- src/chop/models/nerf/nerf_vision.py | 4 +- src/chop/models/opt/modeling_opt.py | 2 +- src/chop/models/opt/modeling_opt_lora.py | 2 +- src/chop/models/opt/modeling_opt_quantized.py | 2 +- src/chop/models/opt/modeling_opt_sparse.py | 2 +- .../models/opt/quant_config_opt_quantized.py | 4 +- src/chop/models/pvt/pvt.py | 9 +- src/chop/models/pvt/pvt_v2.py | 2 +- src/chop/models/repvgg/repvgg.py | 4 +- src/chop/models/resnet/resnet.py | 56 ++++++++++--- src/chop/models/toy/toy.py | 15 +++- src/chop/models/vgg_cifar/vgg_orig.py | 61 ++++++++++++-- src/chop/models/vision/snn/snn_toy.py | 4 +- .../models/vision/snn/spikingResformer.py | 42 ++++++++-- src/chop/nn/functional/softermax.py | 2 +- src/chop/nn/modules/gqa.py | 11 ++- src/chop/nn/modules/lora.py | 13 ++- src/chop/nn/modules/sparse.py | 19 ++++- src/chop/nn/mx/activations.py | 5 +- src/chop/nn/mx/bmm.py | 12 ++- src/chop/nn/mx/convolution.py | 32 +++++-- src/chop/nn/mx/elemwise_ops.py | 8 +- src/chop/nn/mx/formats.py | 8 +- src/chop/nn/mx/linear.py | 32 +++++-- src/chop/nn/mx/matmul.py | 12 ++- src/chop/nn/mx/mx_ops.py | 9 +- src/chop/nn/mx/quantize.py | 4 +- src/chop/nn/mx/simd_ops.py | 2 +- src/chop/nn/mx/transpose_convolution.py | 32 +++++-- src/chop/nn/optical/modules/morr_conv2d.py | 6 +- src/chop/nn/optical/modules/morr_linear.py | 6 +- src/chop/nn/optical/utils/initializer.py | 2 +- src/chop/nn/optical/utils/mrr_op.py | 2 +- src/chop/nn/optical/utils/quantize.py | 6 +- src/chop/nn/quantized/functional/gelu.py | 4 +- src/chop/nn/quantized/functional/linear.py | 83 +++++++++++++++---- src/chop/nn/quantized/functional/matmul.py | 8 +- src/chop/nn/quantized/functional/relu.py | 4 +- src/chop/nn/quantized/functional/selu.py | 4 +- src/chop/nn/quantized/functional/softplus.py | 4 +- src/chop/nn/quantized/functional/softsign.py | 4 +- src/chop/nn/quantized/functional/tanh.py | 4 +- .../nn/quantized/modules/attention_head.py | 5 +- src/chop/nn/quantized/modules/conv1d.py | 12 ++- src/chop/nn/quantized/modules/conv2d.py | 39 ++++++--- src/chop/nn/quantized/modules/gelu.py | 8 +- src/chop/nn/quantized/modules/gqa.py | 5 +- src/chop/nn/quantized/modules/linear.py | 18 ++-- src/chop/nn/quantized/modules/relu.py | 8 +- src/chop/nn/quantized/modules/selu.py | 8 +- src/chop/nn/quantized/modules/silu.py | 8 +- src/chop/nn/quantized/modules/softplus.py | 8 +- src/chop/nn/quantized/modules/softsign.py | 8 +- src/chop/nn/quantized/modules/tanh.py | 8 +- .../nn/quantizers/LUTNet/BaseInitializer.py | 5 +- src/chop/nn/quantizers/LUTNet/BaseTrainer.py | 2 +- src/chop/nn/quantizers/block_fp.py | 17 ++-- src/chop/nn/quantizers/block_log.py | 8 +- src/chop/nn/quantizers/block_minifloat.py | 9 +- src/chop/nn/quantizers/integer.py | 8 +- src/chop/nn/quantizers/log.py | 14 ++-- src/chop/nn/quantizers/minifloat.py | 34 +++++--- src/chop/nn/quantizers/mxint_hardware.py | 6 +- src/chop/nn/quantizers/quantizers_for_hw.py | 14 ++-- src/chop/nn/quantizers/ternary.py | 3 +- src/chop/nn/quantizers/utils.py | 32 +++++-- src/chop/nn/snn/auto_cuda/generator.py | 5 +- .../nn/snn/modules/spiking_self_attention.py | 8 +- src/chop/passes/graph/__init__.py | 6 +- .../add_metadata/add_common_metadata.py | 12 ++- .../add_metadata/add_hardware_metadata.py | 3 +- .../add_metadata/common_metadata_layers.py | 57 ++++++++++--- .../add_metadata/hardware_metadata_layers.py | 37 +++++++-- .../add_metadata/software_metadata_layers.py | 16 +++- .../autosharding/alpa_cost_modelling.py | 10 ++- .../analysis/autosharding/autosharding.py | 31 +++++-- .../graph/analysis/autosharding/megatron.py | 4 +- .../autosharding/strategies/basic_strategy.py | 5 +- .../autosharding/strategies/common.py | 6 +- .../autosharding/strategies/matrix_ops.py | 38 +++++++-- .../autosharding/strategies/pointwise_ops.py | 4 +- .../autosharding/strategies/view_ops.py | 7 +- .../flop_estimator/calculator/calc_modules.py | 2 +- .../passes/graph/analysis/plot/plot_graph.py | 5 +- .../graph/interface/tensorrt/quantize.py | 1 - .../passes/graph/transforms/dse/run_dse.py | 2 +- src/chop/passes/graph/transforms/lora.py | 6 +- .../graph/transforms/onnxrt/quantize.py | 4 +- .../transforms/pruning/pruning_methods.py | 28 ++++++- .../quant_parsers/parse_quant_config.py | 82 ++++++++++++++---- .../graph/transforms/training/modify.py | 5 +- .../graph/transforms/verilog/emit_bram.py | 8 +- .../graph/transforms/verilog/emit_hls.py | 7 +- src/chop/passes/module/analysis/report.py | 3 +- src/chop/passes/utils.py | 8 +- src/chop/pipelines/auto_pipeline.py | 22 ++++- src/chop/tools/check_dependency.py | 7 +- src/chop/tools/huggingface.py | 14 +++- .../tools/plt_wrapper/nlp/classification.py | 4 +- src/chop/tools/plt_wrapper/nlp/lm.py | 4 +- src/chop/tools/utils.py | 6 +- src/mase_cocotb/interfaces/streaming.py | 24 ++++-- src/mase_cocotb/runner.py | 13 ++- src/mase_cocotb/testbench.py | 7 +- src/mase_cocotb/utils.py | 4 +- src/mase_cocotb/z_qlayers/tensor_cast.py | 6 +- .../activation_layers/test/fixed_elu_tb.py | 14 ++-- .../activation_layers/test/fixed_gelu_tb.py | 10 +-- .../test/fixed_hardshrink_tb.py | 6 +- .../test/fixed_hardswish_tb.py | 10 +-- .../test/fixed_leaky_relu_tb.py | 6 +- .../test/fixed_logsigmoid_tb.py | 6 +- .../activation_layers/test/fixed_relu_tb.py | 6 +- .../activation_layers/test/fixed_selu_tb.py | 10 +-- .../test/fixed_sigmoid_tb.py | 6 +- .../activation_layers/test/fixed_silu_tb.py | 6 +- .../test/fixed_softermax_1d_tb.py | 19 ++++- .../test/fixed_softermax_tb.py | 4 +- .../test/fixed_softmax_tb.py | 6 +- .../test/fixed_softplus_tb.py | 10 +-- .../test/fixed_softshrink_tb.py | 6 +- .../test/fixed_softsign_tb.py | 10 +-- .../activation_layers/test/fixed_tanh_tb.py | 10 +-- .../activation_layers/test/softermax.py | 2 +- .../test/softermax_global_norm_tb.py | 26 ++++-- .../test/softermax_local_window_tb.py | 8 +- .../test/softermax_lpw_pow2_tb.py | 51 +++++++----- .../test/softermax_lpw_reciprocal_tb.py | 40 +++++---- .../cast/test/fixed_rounding_tb.py | 2 +- .../cast/test/fixed_signed_cast_tb.py | 13 +-- .../cast/test/fixed_unsigned_cast_tb.py | 8 +- .../common/test/comparator_accumulator_tb.py | 8 +- .../common/test/comparator_tree_tb.py | 5 +- .../common/test/register_slice_tb.py | 4 +- .../common/test/single_element_repeat_tb.py | 6 +- ...binary_activation_binary_convolution_tb.py | 4 +- .../convolution_layers/test/convolution_tb.py | 34 ++++++-- .../convolution_layers/test/padding_tb.py | 10 ++- .../convolution_layers/test/roller_tb.py | 4 +- .../test/sliding_window_tb.py | 6 +- src/mase_components/deps.py | 12 ++- src/mase_components/helper/generate_memory.py | 8 +- .../hls/bfp_arith/bfp_adder.py | 6 +- .../hls/bfp_arith/bfp_multiplier.py | 6 +- src/mase_components/hls/elastic/buffer.py | 8 +- src/mase_components/hls/hls_regression.py | 5 +- .../hls/int_arith/int_layernorm.py | 8 +- src/mase_components/hls/int_arith/int_relu.py | 8 +- src/mase_components/hls/int_arith/int_silu.py | 8 +- .../hls/int_arith/int_softmax.py | 8 +- .../hls/int_arith/int_transpose.py | 8 +- .../hls/regression_gen/bfp_add_dse.py | 11 ++- .../hls/regression_gen/bfp_linear2d_dse.py | 8 +- .../hls/regression_gen/bfp_mult_dse.py | 11 ++- .../hls/regression_gen/buffer_dse.py | 11 ++- .../hls/regression_gen/fork_dse.py | 11 ++- .../hls/regression_gen/int_add_dse.py | 11 ++- .../hls/regression_gen/int_layernorm_dse.py | 11 ++- .../hls/regression_gen/int_linear2d_dse.py | 8 +- .../hls/regression_gen/int_matmul_dse.py | 8 +- .../hls/regression_gen/int_mult_dse.py | 11 ++- .../hls/regression_gen/int_relu_dse.py | 11 ++- .../hls/regression_gen/int_rmsnorm_dse.py | 11 ++- .../hls/regression_gen/int_rope_dse.py | 11 ++- .../hls/regression_gen/int_silu_dse.py | 11 ++- .../hls/regression_gen/int_softmax_dse.py | 11 ++- .../hls/regression_gen/int_transpose_dse.py | 11 ++- src/mase_components/hls/scripts/bl_bfp.py | 15 +++- ...y_activation_binary_adder_tree_layer_tb.py | 3 +- .../fixed_operators/test/fixed_isqrt_tb.py | 18 ++-- .../test/fixed_lut_index_tb.py | 2 +- .../fixed_operators/test/fixed_nr_stage_tb.py | 4 +- .../test/fixed_range_augmentation_tb.py | 2 +- .../test/fixed_range_reduction_tb.py | 2 +- .../fixed_operators/test/isqrt_sw.py | 10 +-- .../matmul/test/fixed_matmul_tb.py | 8 +- .../matmul/test/simple_matmul_tb.py | 8 +- .../linear_layers/matmul/test/transpose_tb.py | 6 +- .../mxint_operators/test/mxint_cast_tb.py | 13 ++- .../test/mxint_dot_product_tb.py | 9 +- .../mxint_operators/test/mxint_matmul_tb.py | 10 ++- .../test/mxint_vector_mult_tb.py | 9 +- .../mxint_operators/test/test.py | 16 +++- .../mxint_operators/test/utils.py | 6 +- src/mase_components/memory/test/fifo_tb.py | 9 +- .../memory/test/repeat_circular_buffer_tb.py | 2 +- .../memory/test/unpacked_fifo_tb.py | 2 +- .../process_synth_impl.py | 11 ++- .../test/batch_norm_2d_tb.py | 5 +- .../test/group_norm_2d_tb.py | 13 ++- .../normalization_layers/test/models.py | 14 ++-- .../test/rms_norm_2d_tb.py | 11 ++- .../fixed/test/fixed_isqrt_tb.py | 18 ++-- .../fixed/test/fixed_nr_stage_tb.py | 4 +- .../scalar_operators/fixed/test/isqrt_sw.py | 10 +-- ...ixed_grouped_query_attention_wrapper_tb.py | 50 +++++++---- .../test/fixed_self_attention_head_tb.py | 32 +++++-- .../vision_models/vit/test/fixed_mlp_tb.py | 6 +- .../vit/test/fixed_patch_embed_tb.py | 9 +- .../vision_models/vit/test/fixed_pvt_tb.py | 12 ++- .../vision_models/vit/test/hash_exp_tb.py | 4 +- .../vision_models/vit/test/hash_softmax_tb.py | 6 +- .../vit/test/helpers/ha_softmax.py | 11 +-- .../vit/test/helpers/pvt_quant.py | 2 +- test/nn/snn/test_ann2snn.py | 8 +- .../add_metadata/test_add_common_metadata.py | 16 +++- .../analysis/pruning/test_hook_inspect.py | 12 ++- .../test_statistic_profiler.py | 16 +++- .../graph/transforms/prune/test_prune.py | 12 ++- .../prune/test_prune_detach_hook.py | 12 ++- .../quantize/test_quantize_lutnet_conv2d.py | 3 +- .../training/test_training_base_pass.py | 6 +- .../verilog/test_emit_activation_gelu.py | 6 +- .../verilog/test_emit_activation_selu.py | 6 +- .../verilog/test_emit_activation_softplus.py | 6 +- .../verilog/test_emit_activation_softsign.py | 6 +- .../verilog/test_emit_activation_tanh.py | 6 +- .../verilog/test_emit_verilog_bert.py | 11 ++- .../verilog/test_emit_verilog_llama.py | 11 ++- .../verilog/test_emit_verilog_mistral.py | 11 ++- .../verilog/test_emit_verilog_norm.py | 23 +++-- .../onnx/analysis/test_export_fx_graph.py | 7 +- test/tools/test_onnx_operators.py | 57 +++++++++++-- 263 files changed, 2273 insertions(+), 792 deletions(-) diff --git a/scripts/mase-hls.py b/scripts/mase-hls.py index 925ec2300..f2097bcd1 100755 --- a/scripts/mase-hls.py +++ b/scripts/mase-hls.py @@ -44,17 +44,33 @@ def build(self): def quick_test(self): shutil.copy( - os.path.join(self.root, "test", "test_in.mlir",), - os.path.join(self.root, "test", "test.mlir",), + os.path.join( + self.root, + "test", + "test_in.mlir", + ), + os.path.join( + self.root, + "test", + "test.mlir", + ), ) result = False cmd = [ "mase-opt", "--preprocess-func=func-name=relu", "--canonicalize", - os.path.join(self.root, "test", "test.mlir",), + os.path.join( + self.root, + "test", + "test.mlir", + ), "-o", - os.path.join(self.root, "test", "test1.mlir",), + os.path.join( + self.root, + "test", + "test1.mlir", + ), ] result |= self.execute(cmd, log_output=True, cwd=self.root) @@ -67,7 +83,11 @@ def quick_test(self): # "test", # "test.cpp", # ), - os.path.join(self.root, "test", "test1.mlir",), + os.path.join( + self.root, + "test", + "test1.mlir", + ), "--debug", ] result |= self.execute(cmd, log_output=True, cwd=self.root) diff --git a/scripts/stat-to-conf.py b/scripts/stat-to-conf.py index 182e66d64..01e0d0cbc 100755 --- a/scripts/stat-to-conf.py +++ b/scripts/stat-to-conf.py @@ -41,7 +41,12 @@ } -def set_stat(entry_name: str, mean=None, median=None, max=None,) -> dict[str, Any]: +def set_stat( + entry_name: str, + mean=None, + median=None, + max=None, +) -> dict[str, Any]: """Return a dictionary containing the format of the stats required to use ternary quantiser. If statistics are not specified, "NA" will be set as the value, this interally is being interpreted as None when the .toml is loaded""" diff --git a/setup.py b/setup.py index e0e91be1d..27d865de1 100644 --- a/setup.py +++ b/setup.py @@ -105,7 +105,9 @@ def get_system(): author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk", license_files=("LICENSE",), python_requires=">=3.11.9", - package_dir={"": "src",}, + package_dir={ + "": "src", + }, packages=find_packages("src"), install_requires=requirements, ) diff --git a/src/chop/actions/emit.py b/src/chop/actions/emit.py index 9318b4476..7e326bc7b 100644 --- a/src/chop/actions/emit.py +++ b/src/chop/actions/emit.py @@ -41,7 +41,10 @@ def emit( data_module.prepare_data() data_module.setup() dummy_in = get_dummy_input( - model_info=model_info, data_module=data_module, task=task, device="cpu", + model_info=model_info, + data_module=data_module, + task=task, + device="cpu", ) mg, _ = add_common_metadata_analysis_pass( mg, {"dummy_in": dummy_in, "add_value": False} diff --git a/src/chop/actions/search/search.py b/src/chop/actions/search/search.py index 48c2d21d7..fe82d3574 100644 --- a/src/chop/actions/search/search.py +++ b/src/chop/actions/search/search.py @@ -15,7 +15,9 @@ logger = logging.getLogger(__name__) -def parse_search_config(search_config: dict,): +def parse_search_config( + search_config: dict, +): """ Parse search config from a dict or a toml file and do sanity check. The search config must consist of two parts: strategy and search_space. diff --git a/src/chop/actions/search/search_space/quantization/manual_hf_module.py b/src/chop/actions/search/search_space/quantization/manual_hf_module.py index a9c9f1bfb..094330da1 100644 --- a/src/chop/actions/search/search_space/quantization/manual_hf_module.py +++ b/src/chop/actions/search/search_space/quantization/manual_hf_module.py @@ -61,7 +61,8 @@ def rebuild_model(self, sampled_config: dict, is_eval_mode: bool): with init_empty_weights(): model = self.model_cls(config) device_map = infer_auto_device_map( - model, no_split_module_classes=model._no_split_modules, + model, + no_split_module_classes=model._no_split_modules, ) model = load_checkpoint_and_dispatch( model, checkpoint=self.model_name, device_map=device_map diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py index 8a0e93486..e56a512d8 100644 --- a/src/chop/actions/simulate.py +++ b/src/chop/actions/simulate.py @@ -68,7 +68,9 @@ def simulate( else: raise ValueError(f"Unrecognized simulator: {simulator}") - includes = [project_dir / "hardware" / "rtl",] + [ + includes = [ + project_dir / "hardware" / "rtl", + ] + [ Path(mase_components.__file__).parent / module / "rtl" for module in get_modules() ] diff --git a/src/chop/actions/train.py b/src/chop/actions/train.py index 93524e756..5daa9b49a 100644 --- a/src/chop/actions/train.py +++ b/src/chop/actions/train.py @@ -109,7 +109,8 @@ def train( trainer = pl.Trainer(**plt_trainer_args) trainer.fit( - pl_model, datamodule=data_module, + pl_model, + datamodule=data_module, ) # Save the trained model along with relevant metadata in the training_ckpts folder. diff --git a/src/chop/dataset/nerf/blender.py b/src/chop/dataset/nerf/blender.py index 748dacca7..53865d621 100644 --- a/src/chop/dataset/nerf/blender.py +++ b/src/chop/dataset/nerf/blender.py @@ -34,7 +34,9 @@ def _download_lego_dataset(path: Path) -> None: # Unzip the file subprocess.run( - f"unzip {folder_path.as_posix()} -d {path.as_posix()}", shell=True, check=True, + f"unzip {folder_path.as_posix()} -d {path.as_posix()}", + shell=True, + check=True, ) diff --git a/src/chop/dataset/nlp/language_modeling.py b/src/chop/dataset/nlp/language_modeling.py index 8d17b9466..ba2bdee55 100644 --- a/src/chop/dataset/nlp/language_modeling.py +++ b/src/chop/dataset/nlp/language_modeling.py @@ -281,7 +281,10 @@ def _tokenize(text, tokenizer, max_length): prompt_len = prompt_tokenized.ne(tokenizer.pad_token_id).sum().item() target_tokenized[:prompt_len] = ignore_id - return dict(input_ids=input_ids, labels=target_tokenized,) + return dict( + input_ids=input_ids, + labels=target_tokenized, + ) def prepare_data(self): dataset_dict = self._download_dataset() @@ -313,7 +316,10 @@ def setup(self): dataset_dict = self._download_dataset() dataset_dict = dataset_dict["train"].train_test_split(test_size=0.1, seed=42) dataset_dict = hf_datasets.DatasetDict( - {"train": dataset_dict["train"], "validation": dataset_dict["test"],} + { + "train": dataset_dict["train"], + "validation": dataset_dict["test"], + } ) dataset_dict = dataset_dict.map( function=partial( diff --git a/src/chop/dataset/vision/transforms/cifar.py b/src/chop/dataset/vision/transforms/cifar.py index 84e7e1287..ef44d3721 100644 --- a/src/chop/dataset/vision/transforms/cifar.py +++ b/src/chop/dataset/vision/transforms/cifar.py @@ -35,7 +35,10 @@ def _get_cifar_default_transform(train: bool, mean: tuple[float], std: tuple[float]): if train: - transform = create_transform(**DEFAULT_CIFAR_PREPROCESS_ARGS, is_training=True,) + transform = create_transform( + **DEFAULT_CIFAR_PREPROCESS_ARGS, + is_training=True, + ) transform.transforms[0] = tv_transforms.RandomCrop( DEFAULT_CIFAR_PREPROCESS_ARGS["input_size"], padding=4 ) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 21337be8c..712bb026f 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -33,7 +33,10 @@ def distributed_average_timing(fn, repeat, args): times = [] for itr in range(repeat): rlog( - logger, dist.get_rank(), f"Running teration {itr}", "debug", + logger, + dist.get_rank(), + f"Running teration {itr}", + "debug", ) dist.barrier(async_op=True) start = time() @@ -42,7 +45,10 @@ def distributed_average_timing(fn, repeat, args): end = time() times.append(end - start) rlog( - logger, dist.get_rank(), f"Time taken: {end - start}s", "debug", + logger, + dist.get_rank(), + f"Time taken: {end - start}s", + "debug", ) return result, sum(times[2:]) / len(times[2:]) @@ -132,7 +138,11 @@ def device_fn( distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs ] - _, time_taken = distributed_average_timing(fn=model, repeat=10, args=inputs,) + _, time_taken = distributed_average_timing( + fn=model, + repeat=10, + args=inputs, + ) rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info") dist.destroy_process_group() diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index b9aa9c272..826ed05cc 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -34,7 +34,11 @@ def _dtensor_init_helper( - init_op, size: torch.Size, device_mesh=None, placements=None, **kwargs, + init_op, + size: torch.Size, + device_mesh=None, + placements=None, + **kwargs, ) -> DTensor: from torch.distributed.tensor.placement_types import _DTensorSpec, TensorMeta @@ -78,10 +82,18 @@ def _dtensor_init_helper( spec = _DTensorSpec( device_mesh, tuple(placements), - tensor_meta=TensorMeta(size, torch_stride, local_tensor.dtype,), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), ) - return DTensor(local_tensor, spec, requires_grad=kwargs["requires_grad"],) + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) def ones( diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index dcc74852f..f5e08f714 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -42,7 +42,9 @@ def decompose_handler( - op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object], + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], ) -> object: """ Decomposes a op to core ATen op, this handler is mostly here @@ -56,7 +58,9 @@ def decompose_handler( def is_same_size_handler( - op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object], + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], ) -> bool: lhs = cast(torch.Tensor, args[0]) rhs = cast(torch.Tensor, args[1]) @@ -251,7 +255,8 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor: @staticmethod def redistribute_local_args( - op_info: OpInfo, suggested_input_schema: OpSchema, + op_info: OpInfo, + suggested_input_schema: OpSchema, ) -> None: # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py index e65766eb7..e04495126 100644 --- a/src/chop/distributed/tensor/_redistribute.py +++ b/src/chop/distributed/tensor/_redistribute.py @@ -48,7 +48,8 @@ def _replicate_then_shard(val: _TransformInfo) -> int: @lru_cache(maxsize=None) def _gen_transform_infos( - src_spec: _DTensorSpec, dst_spec: _DTensorSpec, + src_spec: _DTensorSpec, + dst_spec: _DTensorSpec, ) -> List[_TransformInfo]: """ Generate the transform infos from the source placements to the target placements. @@ -87,7 +88,9 @@ def _gen_transform_infos( # calculate and save the logical shape for this sharding mesh_dim_size = device_mesh.size(mesh_dim=i) local_shard_size, _ = src._local_shard_size_on_dim( - current_logical_shape[src.dim], mesh_dim_size, my_coordinate[i], + current_logical_shape[src.dim], + mesh_dim_size, + my_coordinate[i], ) new_logical_shape = list(current_logical_shape) new_logical_shape[src.dim] = local_shard_size @@ -285,7 +288,11 @@ def forward( # type: ignore[override] output = input._local_tensor target_spec = current_spec - return dtensor.DTensor(output, target_spec, requires_grad=input.requires_grad,) + return dtensor.DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) @staticmethod def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] @@ -320,7 +327,9 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] ), ) output_dtensor = dtensor.DTensor( - output, spec, requires_grad=grad_output.requires_grad, + output, + spec, + requires_grad=grad_output.requires_grad, ) return ( diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 8e1e26db1..7351b7895 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -48,7 +48,8 @@ class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} self.op_strategy_funcs: Dict[ - OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType], + OpOverload, + Callable[[DeviceMesh, OpSchema], StrategyType], ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} @@ -343,8 +344,10 @@ def spec_to_strategy(spec: object) -> object: expected_input_spec = selected_strategies[idx].input_spec( tensor_or_list_tensor_arg_idx ) - expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta( - arg_spec.tensor_meta + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) ) if arg_spec.placements != expected_input_spec.placements: needs_redistribute = True @@ -360,8 +363,10 @@ def spec_to_strategy(spec: object) -> object: expected_input_spec = selected_strategies[0].input_spec( tensor_or_list_tensor_arg_idx ) - expected_input_spec = expected_input_spec.shallow_copy_with_tensor_meta( - arg.tensor_meta + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) ) if arg.placements != expected_input_spec.placements: needs_redistribute = True diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index a27c848f4..dcef33b24 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -64,7 +64,9 @@ class _ToTorchTensor(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] - ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]], + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], ): ctx.dtensor_spec = input._spec ctx.grad_placements = grad_placements @@ -98,7 +100,11 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) return ( - DTensor(grad_output, grad_spec, requires_grad=grad_output.requires_grad,), + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), None, ) @@ -154,7 +160,11 @@ def forward( # type: ignore[override] dist_spec = _DTensorSpec( device_mesh, placements, - tensor_meta=TensorMeta(tensor_shape, tensor_stride, input.dtype,), + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), ) # We want a fresh Tensor object that shares memory with the input tensor @@ -207,7 +217,11 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ @staticmethod @torch._disable_dynamo def __new__( - cls, local_tensor: torch.Tensor, spec: _DTensorSpec, *, requires_grad: bool, + cls, + local_tensor: torch.Tensor, + spec: _DTensorSpec, + *, + requires_grad: bool, ) -> "DTensor": """ Construct a DTensor from a local tensor, device mesh, and placement and @@ -263,12 +277,20 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): local_tensor = inner_tensors["_local_tensor"] spec, requires_grad = flatten_spec unflatten_tensor_meta = TensorMeta( - shape=outer_size, stride=outer_stride, dtype=spec.tensor_meta.dtype, + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, ) unflatten_spec = _DTensorSpec( - spec.mesh, spec.placements, tensor_meta=unflatten_tensor_meta, + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, ) - return DTensor(local_tensor, unflatten_spec, requires_grad=requires_grad,) def __coerce_tangent_metadata__(self): if not any(isinstance(p, Partial) for p in self.placements): @@ -281,7 +303,8 @@ def __coerce_tangent_metadata__(self): def __coerce_same_metadata_as_tangent__(self, flatten_spec): (spec, _) = flatten_spec # Result of tensor_flatten() return self.redistribute( - device_mesh=self.device_mesh, placements=spec.placements, + device_mesh=self.device_mesh, + placements=spec.placements, ) @classmethod @@ -289,7 +312,11 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - return DTensor._op_dispatcher.dispatch(func, args, kwargs or {},) + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) @staticmethod def from_local( @@ -367,7 +394,12 @@ def from_local( # created should flow back the gradients to the local_tensor, so we call an autograd # function to construct the dist tensor instead. return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, device_mesh, tuple(placements), run_check, shape, stride, + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, ) def to_local( @@ -667,7 +699,9 @@ def distribute_tensor( mesh=device_mesh, placements=placements, tensor_meta=TensorMeta( - shape=tensor.size(), stride=tensor.stride(), dtype=tensor.dtype, + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, ), ) return DTensor( diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py index 3a9e81ae2..70ba54416 100644 --- a/src/chop/distributed/tensor/ops/basic_strategy.py +++ b/src/chop/distributed/tensor/ops/basic_strategy.py @@ -84,7 +84,10 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": def gen_einsum_strategies( - equation: str, mesh: DeviceMesh, *, linearity: bool = False, + equation: str, + mesh: DeviceMesh, + *, + linearity: bool = False, ) -> OpStrategy: """ Generate a strategy list for the ops that follow einsum style notation. diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py index 1324f7806..e0e9a6e34 100644 --- a/src/chop/distributed/tensor/ops/common_rules.py +++ b/src/chop/distributed/tensor/ops/common_rules.py @@ -37,7 +37,10 @@ def _gen_reshard_suggestions( ) suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {}) suggested_schema._inplace_rewrap_schema_suggestion(op_schema) - return OutputSharding(None, redistribute_schema=suggested_schema,) + return OutputSharding( + None, + redistribute_schema=suggested_schema, + ) def einop_rule( @@ -215,7 +218,10 @@ def merge_sharding(dim: str, a: int, b: int) -> int: ) return OutputSharding( _DTensorSpec.from_dim_map( - input_specs[0].mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta, + input_specs[0].mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, ) ) @@ -275,5 +281,8 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi enforce_sharding[out_dimchar] = mesh_dim return einop_rule( - fmt, op_schema, linearity=linearity, enforce_sharding=enforce_sharding, + fmt, + op_schema, + linearity=linearity, + enforce_sharding=enforce_sharding, ) diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py index 9a4d5f425..707f391d6 100644 --- a/src/chop/distributed/tensor/ops/conv_ops.py +++ b/src/chop/distributed/tensor/ops/conv_ops.py @@ -50,11 +50,16 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding: pending_sums = input_spec.sums tensor_meta = TensorMeta( - torch.Size(output_shape), output_stride, input_spec.tensor_meta.dtype, + torch.Size(output_shape), + output_stride, + input_spec.tensor_meta.dtype, ) return OutputSharding( _DTensorSpec.from_dim_map( - input_spec.mesh, output_dim_map, pending_sums, tensor_meta=tensor_meta, + input_spec.mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, ) ) @@ -83,14 +88,22 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: assert input_spec.tensor_meta is not None weight_tensor_meta = weight_spec.tensor_meta bias_tensor_meta = TensorMeta( - torch.Size(bias_shape_opt), (1,), input_spec.tensor_meta.dtype, + torch.Size(bias_shape_opt), + (1,), + input_spec.tensor_meta.dtype, ) grad_input_spec = input_spec grad_weight_spec = _DTensorSpec.from_dim_map( - input_spec.mesh, [-1, -1, -1, -1], [0], tensor_meta=weight_tensor_meta, + input_spec.mesh, + [-1, -1, -1, -1], + [0], + tensor_meta=weight_tensor_meta, ) grad_bias_spec = _DTensorSpec.from_dim_map( - input_spec.mesh, [-1], [0], tensor_meta=bias_tensor_meta, + input_spec.mesh, + [-1], + [0], + tensor_meta=bias_tensor_meta, ) return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index da8dad8d5..770514219 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -128,7 +128,7 @@ def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: if self.reduce_op == "sum": assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" if self.norm_type != 0 and self.norm_type != 1: - return tensor ** self.norm_type + return tensor**self.norm_type return tensor def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: @@ -289,7 +289,10 @@ def common_reduction_strategy( redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] reduction_strategy.strategies.append( PlacementStrategy( - output_specs=_DTensorSpec(mesh=mesh, placements=out_placements,), + output_specs=_DTensorSpec( + mesh=mesh, + placements=out_placements, + ), input_specs=(input_spec,), redistribute_cost=redistribute_cost, ) @@ -475,7 +478,10 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [aten._log_softmax_backward_data.default, aten._softmax_backward_data.default,], + [ + aten._log_softmax_backward_data.default, + aten._softmax_backward_data.default, + ], schema_info=RuntimeSchemaInfo(2), ) def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @@ -608,14 +614,21 @@ def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate reduce_dims_map, reduction_op, ) - output_expected_spec = _DTensorSpec(mesh=mesh, placements=out_placements,) + output_expected_spec = _DTensorSpec( + mesh=mesh, + placements=out_placements, + ) # whether reduction is sum or mean, the total weight has to be summed up if not replicated total_weight_placements = map_placements_after_reduction( - target_expected_spec.placements, reduce_dims, reduce_dims_map, "sum", + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + "sum", ) total_weight_expected_spec = _DTensorSpec( - mesh=mesh, placements=total_weight_placements, + mesh=mesh, + placements=total_weight_placements, ) output_strategy.strategies.append( @@ -748,7 +761,8 @@ def rlog(msg): @register_op_strategy( - [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1), + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), ) def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: # args must be: input, normalized_shape, weight, bias, eps @@ -842,7 +856,8 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2), + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), ) def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: # args must be: grad_out, input, normalized_shape, mean, rstd, @@ -999,7 +1014,8 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy @register_op_strategy( - [aten.topk.default], schema_info=RuntimeSchemaInfo(2), + [aten.topk.default], + schema_info=RuntimeSchemaInfo(2), ) def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: input_strategy = cast(OpStrategy, op_schema.args_schema[0]) diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py index 029a795df..f76dca190 100644 --- a/src/chop/distributed/tensor/ops/matrix_ops.py +++ b/src/chop/distributed/tensor/ops/matrix_ops.py @@ -384,7 +384,10 @@ def scaled_dot_product_efficient_attention_strategy( single_mesh_dim_strategies.append(num_heads_dim_sharding) return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=4, + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, ) @@ -449,5 +452,8 @@ def scaled_dot_product_efficient_attention_backward_strategy( single_mesh_dim_strategies.append(num_heads_dim_sharding) return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=4, + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, ) diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py index 656fd2996..221001f01 100644 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -483,7 +483,9 @@ def common_pointwise_strategy( common_shape, input_arg_spec.shape ) input_target_placements = map_placements_after_broadcast( - tuple(out_placements), common_shape, input_arg_dims_map, + tuple(out_placements), + common_shape, + input_arg_dims_map, ) input_arg_target_spec = _DTensorSpec( mesh=mesh, @@ -497,7 +499,10 @@ def common_pointwise_strategy( pointwise_strategy.strategies.append( PlacementStrategy( - output_specs=_DTensorSpec(mesh=mesh, placements=tuple(out_placements),), + output_specs=_DTensorSpec( + mesh=mesh, + placements=tuple(out_placements), + ), input_specs=input_specs, redistribute_cost=redistribute_costs, ) diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py index f54640e8a..dcddcb98c 100644 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ b/src/chop/distributed/tensor/ops/tensor_ops.py @@ -76,7 +76,10 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: @register_op_strategy( - [aten.equal.default, aten.is_same_size.default,] + [ + aten.equal.default, + aten.is_same_size.default, + ] ) def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: # equal_strategy deals with ops that comparing two tensor, we need to make sure @@ -125,7 +128,8 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: schema_info=RuntimeSchemaInfo(1, ["dtype"]), ) @register_op_strategy( - [aten.full_like.default], schema_info=RuntimeSchemaInfo(2, ["dtype"]), + [aten.full_like.default], + schema_info=RuntimeSchemaInfo(2, ["dtype"]), ) @register_op_strategy( [ @@ -692,7 +696,8 @@ def place(vp: Placement, ip: Placement) -> Placement: ) result = OutputSharding( output_spec=_DTensorSpec( - mesh=values_spec.mesh, placements=value_placements, + mesh=values_spec.mesh, + placements=value_placements, ) ) return result @@ -779,7 +784,10 @@ def size_split(N, i): else split_size_or_sections ) output_spec_list = [ - _DTensorSpec(mesh=input_spec.mesh, placements=input_spec.placements,) + _DTensorSpec( + mesh=input_spec.mesh, + placements=input_spec.placements, + ) for _ in range(len(output_size_list)) ] return OutputSharding(output_spec_list) diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py index e005501d6..28e5245c3 100644 --- a/src/chop/distributed/tensor/ops/utils.py +++ b/src/chop/distributed/tensor/ops/utils.py @@ -196,7 +196,9 @@ def infer_broadcast_dims_map( def map_placements_after_broadcast( - placements: Tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: List[int], + placements: Tuple[Placement, ...], + shape: torch.Size, + broadcast_dims_map: List[int], ) -> Tuple[Placement, ...]: """Map each placement based on the output shape after broadcast.""" new_placements: List[Placement] = [] diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py index 6b89771ec..5c91f6d64 100644 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -238,7 +238,9 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: def dim_movedim( - ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]], + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], ) -> DimMap: input = normalize_dims(input, ndim) destination = normalize_dims(destination, ndim) @@ -604,7 +606,10 @@ def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_src_spec = input_placement_strategy.output_spec input_tgt_placements, output_placements = propagate_shape_and_sharding( - input_src_spec.placements, tuple(global_in_shape), rules, mesh.shape, + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh.shape, ) # TODO: optimize this. we shouldn't simply blindly replicate diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py index 5ee6b8bb3..322195869 100644 --- a/src/chop/ir/graph/mase_graph.py +++ b/src/chop/ir/graph/mase_graph.py @@ -76,7 +76,13 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool is_fx_built_in_leaf_module = super().is_leaf_module(m, module_qualified_name) is_mase_leaf_layers = isinstance(m, MASE_LEAF_LAYERS) is_custom_layer = isinstance(m, self.custom_leaf_layers) - return any((is_fx_built_in_leaf_module, is_mase_leaf_layers, is_custom_layer,)) + return any( + ( + is_fx_built_in_leaf_module, + is_mase_leaf_layers, + is_custom_layer, + ) + ) def trace_torch_module( @@ -122,13 +128,19 @@ def is_leaf_module( self, m: torch.nn.Module, module_qualified_name: str ) -> bool: is_hf_built_in_leaf_module = hf_is_leaf_module( - self, m, module_qualified_name, + self, + m, + module_qualified_name, ) is_custom_module = isinstance(m, custom_modules) is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS) return any( - (is_hf_built_in_leaf_module, is_custom_module, is_mase_leaf_layer,) + ( + is_hf_built_in_leaf_module, + is_custom_module, + is_mase_leaf_layer, + ) ) return is_leaf_module @@ -140,7 +152,9 @@ def is_leaf_module( ) graph_module = hf_symbolic_trace( - model, tracer_cls=tracer_cls, input_names=hf_input_names, + model, + tracer_cls=tracer_cls, + input_names=hf_input_names, ) graph_module.custom_ops = custom_ops @@ -293,7 +307,10 @@ def __init__( self.model.additional_inputs = [] elif isinstance(model, torch.nn.Module): self.model = trace_torch_module( - model, cf_args, custom_ops, hf_input_names=hf_input_names, + model, + cf_args, + custom_ops, + hf_input_names=hf_input_names, ) else: raise ValueError( @@ -332,11 +349,16 @@ def from_module( ), f"model must be a torch.nn.Module. Received: {type(model)}" graph_module = trace_torch_module(model, cf_args, custom_ops) - return cls(model=graph_module, cf_args=cf_args,) + return cls( + model=graph_module, + cf_args=cf_args, + ) @classmethod def from_checkpoint( - cls, checkpoint: str, propagate_missing_metadata: bool = True, + cls, + checkpoint: str, + propagate_missing_metadata: bool = True, ): """ Load a MaseGraph from a checkpoint. A MaseGraph checkpoint consists of two files: @@ -371,12 +393,18 @@ def from_checkpoint( for node in mg.nodes: if node.name in loaded_meta.keys(): parameters = loaded_meta[node.name] - node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,) + node.meta["mase"] = MaseMetadata( + node=node, + model=loaded_model, + ) node.meta["mase"].parameters = parameters else: # todo: propagate metadata for missing nodes logger.warning(f"Node {node.name} not found in loaded metadata.") - node.meta["mase"] = MaseMetadata(node=node, model=loaded_model,) + node.meta["mase"] = MaseMetadata( + node=node, + model=loaded_model, + ) for attr in [ "class_for_deserialization", @@ -389,7 +417,8 @@ def from_checkpoint( return mg def export( - self, fname: str = "masegraph", + self, + fname: str = "masegraph", ): """ Export the MaseGraph to a pair of files: {fname}.pt and {fname}.mz. diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py index 52fcf0127..6f01d83c8 100644 --- a/src/chop/ir/graph/mase_metadata.py +++ b/src/chop/ir/graph/mase_metadata.py @@ -100,7 +100,9 @@ class MaseMetadata: known_storage = ["BRAM"] def __init__( - self, node=None, model=None, + self, + node=None, + model=None, ): # Top-level model self.model = model diff --git a/src/chop/ir/onnx/mase_onnx_graph.py b/src/chop/ir/onnx/mase_onnx_graph.py index e4a8f52ed..f6e1d31e6 100644 --- a/src/chop/ir/onnx/mase_onnx_graph.py +++ b/src/chop/ir/onnx/mase_onnx_graph.py @@ -10,7 +10,9 @@ class MaseOnnxGraph: def __init__( - self, model_proto: onnx.onnx_ml_pb2.ModelProto, model_name: str = None, + self, + model_proto: onnx.onnx_ml_pb2.ModelProto, + model_name: str = None, ): self.model_proto = model_proto self.graph = model_proto.graph diff --git a/src/chop/ir/onnx/utils.py b/src/chop/ir/onnx/utils.py index c5250bba0..bb65aa8b2 100644 --- a/src/chop/ir/onnx/utils.py +++ b/src/chop/ir/onnx/utils.py @@ -188,7 +188,10 @@ def onnx_to_torch_dtype(dtype): "target": torch.mean, "input_mapping": ["input"], "attribute_mapping": {"keepdims": "", "axes": ""}, - "attribute_transform": {"keepdims": None, "axes": None,}, + "attribute_transform": { + "keepdims": None, + "axes": None, + }, "attribute_default": {"keepdims": 1, "axes": None}, }, "Expand": { @@ -332,7 +335,9 @@ def onnx_to_torch_dtype(dtype): "input_mapping": ["input"], "attribute_mapping": {"perm": "dims"}, "attribute_transform": {"perm": lambda x: [i for i in x]}, - "attribute_default": {"perm": None,}, + "attribute_default": { + "perm": None, + }, }, "Max": { "fx_op": "call_function", diff --git a/src/chop/models/bert/modeling_bert_quantized.py b/src/chop/models/bert/modeling_bert_quantized.py index c6747f0e9..8061d2593 100644 --- a/src/chop/models/bert/modeling_bert_quantized.py +++ b/src/chop/models/bert/modeling_bert_quantized.py @@ -561,7 +561,9 @@ def __init__(self, config, quant_config: dict): super().__init__() # self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = get_quantized_cls("linear", quant_config["dense"])( - config.hidden_size, config.intermediate_size, config=quant_config["dense"], + config.hidden_size, + config.intermediate_size, + config=quant_config["dense"], ) self.quant_config = quant_config if isinstance(config.hidden_act, str): @@ -580,7 +582,9 @@ def __init__(self, config, quant_config): super().__init__() # self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = get_quantized_cls("linear", quant_config["dense"])( - config.intermediate_size, config.hidden_size, config=quant_config["dense"], + config.intermediate_size, + config.hidden_size, + config=quant_config["dense"], ) self.quant_config = quant_config self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/src/chop/models/bert/quant_config_bert.py b/src/chop/models/bert/quant_config_bert.py index 4dadb97f0..f66c45a4c 100644 --- a/src/chop/models/bert/quant_config_bert.py +++ b/src/chop/models/bert/quant_config_bert.py @@ -88,7 +88,10 @@ def create_a_layer_config( return qc -def _parse_and_complete_config(config: dict, num_hidden_layers: int,) -> dict: +def _parse_and_complete_config( + config: dict, + num_hidden_layers: int, +) -> dict: assert "default" in config, "Must provide a default config" default_qc: dict = config["default"] linear_qc: dict = parse_node_config( diff --git a/src/chop/models/cnv/cnv.py b/src/chop/models/cnv/cnv.py index ee1912f18..8bb9ad468 100644 --- a/src/chop/models/cnv/cnv.py +++ b/src/chop/models/cnv/cnv.py @@ -202,7 +202,8 @@ def forward(self, x: Tensor) -> Tensor: # Getters ------------------------------------------------------------------------------ @register_mase_checkpoint("cnv-toy") def get_cnv_toy( - pretrained=False, **kwargs: Any, + pretrained=False, + **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] @@ -212,7 +213,8 @@ def get_cnv_toy( @register_mase_checkpoint("cnv") def get_cnv( - pretrained=False, **kwargs: Any, + pretrained=False, + **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] @@ -222,7 +224,8 @@ def get_cnv( @register_mase_checkpoint("cnv_residual") def get_cnv_residual( - pretrained=False, **kwargs: Any, + pretrained=False, + **kwargs: Any, ): # image_size = info["image_size"] info = kwargs["dataset_info"] diff --git a/src/chop/models/cswin/cswintransformer.py b/src/chop/models/cswin/cswintransformer.py index 449a22d30..cbd85cda6 100644 --- a/src/chop/models/cswin/cswintransformer.py +++ b/src/chop/models/cswin/cswintransformer.py @@ -90,7 +90,7 @@ def __init__( self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 if idx == -1: H_sp, W_sp = self.resolution, self.resolution elif idx == 0: @@ -362,9 +362,9 @@ def __init__( super().__init__() self.use_chk = use_chk self.num_classes = num_classes - self.num_features = ( - self.embed_dim - ) = embed_dim # num_features for consistency with other models + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) heads = num_heads self.stage1_conv_embed = nn.Sequential( diff --git a/src/chop/models/deit/deit_v2.py b/src/chop/models/deit/deit_v2.py index 21b2a5ad4..4f0128216 100644 --- a/src/chop/models/deit/deit_v2.py +++ b/src/chop/models/deit/deit_v2.py @@ -28,7 +28,7 @@ def __init__( super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) diff --git a/src/chop/models/efficientnet/efficientnet.py b/src/chop/models/efficientnet/efficientnet.py index b3af62d05..5d2f7730a 100644 --- a/src/chop/models/efficientnet/efficientnet.py +++ b/src/chop/models/efficientnet/efficientnet.py @@ -499,7 +499,8 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet_conf( - arch: str, **kwargs: Any, + arch: str, + **kwargs: Any, ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] if arch.startswith("efficientnet_b"): @@ -591,7 +592,9 @@ def _efficientnet( def get_efficientnet_b0( - info: Dict, pretrained: bool = False, **kwargs: Any, + info: Dict, + pretrained: bool = False, + **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -613,7 +616,9 @@ def get_efficientnet_b0( def get_efficientnet_b3( - info: Dict, pretrained: bool = False, **kwargs: Any, + info: Dict, + pretrained: bool = False, + **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -636,7 +641,9 @@ def get_efficientnet_b3( def get_efficientnet_v2_s( - info: Dict, pretrained: bool = False, **kwargs: Any, + info: Dict, + pretrained: bool = False, + **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -656,7 +663,9 @@ def get_efficientnet_v2_s( def get_efficientnet_v2_m( - info: Dict, pretrained: bool = False, **kwargs: Any, + info: Dict, + pretrained: bool = False, + **kwargs: Any, ): num_classes = info.num_classes if pretrained: @@ -676,7 +685,9 @@ def get_efficientnet_v2_m( def get_efficientnet_v2_l( - info: Dict, pretrained: bool = False, **kwargs: Any, + info: Dict, + pretrained: bool = False, + **kwargs: Any, ): num_classes = info.num_classes if pretrained: diff --git a/src/chop/models/lfc/lfc.py b/src/chop/models/lfc/lfc.py index 410359a25..70da9df45 100644 --- a/src/chop/models/lfc/lfc.py +++ b/src/chop/models/lfc/lfc.py @@ -42,7 +42,9 @@ def forward(self, x): # Getters ------------------------------------------------------------------------------ def get_lfc( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): image_size = info["image_size"] num_classes = info.num_classes diff --git a/src/chop/models/llama/modeling_llama_llora.py b/src/chop/models/llama/modeling_llama_llora.py index 45eb2f210..b45dd8494 100644 --- a/src/chop/models/llama/modeling_llama_llora.py +++ b/src/chop/models/llama/modeling_llama_llora.py @@ -195,7 +195,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Module): def __init__( - self, hidden_size: int, intermediate_size: int, hidden_act: str, + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, ): super().__init__() # fmt: off diff --git a/src/chop/models/llama/modeling_llama_sparse.py b/src/chop/models/llama/modeling_llama_sparse.py index 21359ea34..f5f37eba0 100644 --- a/src/chop/models/llama/modeling_llama_sparse.py +++ b/src/chop/models/llama/modeling_llama_sparse.py @@ -195,7 +195,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Module): def __init__( - self, hidden_size: int, intermediate_size: int, hidden_act: str, + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, ): super().__init__() # fmt: off diff --git a/src/chop/models/mobilenet_v2/mobilenet_v2.py b/src/chop/models/mobilenet_v2/mobilenet_v2.py index 0980d58dd..d8f4020dd 100644 --- a/src/chop/models/mobilenet_v2/mobilenet_v2.py +++ b/src/chop/models/mobilenet_v2/mobilenet_v2.py @@ -288,7 +288,8 @@ def __init__( # building classifier self.classifier = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(self.last_channel, num_classes), + nn.Dropout(p=dropout), + nn.Linear(self.last_channel, num_classes), ) # weight initialization diff --git a/src/chop/models/nerf/nerf_vision.py b/src/chop/models/nerf/nerf_vision.py index 18338940b..f8591f1a4 100644 --- a/src/chop/models/nerf/nerf_vision.py +++ b/src/chop/models/nerf/nerf_vision.py @@ -139,7 +139,9 @@ def load_weights_from_keras(self, weights): # Getters ------------------------------------------------------------------------------ def get_nerf( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): # image_size = info["image_size"] num_classes = info.num_classes diff --git a/src/chop/models/opt/modeling_opt.py b/src/chop/models/opt/modeling_opt.py index 430f77df8..e88db8c1f 100644 --- a/src/chop/models/opt/modeling_opt.py +++ b/src/chop/models/opt/modeling_opt.py @@ -132,7 +132,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/chop/models/opt/modeling_opt_lora.py b/src/chop/models/opt/modeling_opt_lora.py index 30bb21f20..810cfe0d6 100644 --- a/src/chop/models/opt/modeling_opt_lora.py +++ b/src/chop/models/opt/modeling_opt_lora.py @@ -130,7 +130,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder lora_config = config.lora_config[f"model_layer_{layer_id}"]["self_attn"] diff --git a/src/chop/models/opt/modeling_opt_quantized.py b/src/chop/models/opt/modeling_opt_quantized.py index d6eb94343..d1086f413 100644 --- a/src/chop/models/opt/modeling_opt_quantized.py +++ b/src/chop/models/opt/modeling_opt_quantized.py @@ -168,7 +168,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder # fmt:off diff --git a/src/chop/models/opt/modeling_opt_sparse.py b/src/chop/models/opt/modeling_opt_sparse.py index 4addd6525..f39a29cc9 100644 --- a/src/chop/models/opt/modeling_opt_sparse.py +++ b/src/chop/models/opt/modeling_opt_sparse.py @@ -130,7 +130,7 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder sparse_config = config.sparse_config[f"model_layer_{layer_id}"]["self_attn"] diff --git a/src/chop/models/opt/quant_config_opt_quantized.py b/src/chop/models/opt/quant_config_opt_quantized.py index 150a1b2d6..a76f8adb8 100644 --- a/src/chop/models/opt/quant_config_opt_quantized.py +++ b/src/chop/models/opt/quant_config_opt_quantized.py @@ -32,7 +32,9 @@ def create_a_layer_config( - linear_qc: dict = None, bmm_qc: dict = None, layer_qc=None, + linear_qc: dict = None, + bmm_qc: dict = None, + layer_qc=None, ) -> dict: if (layer_qc is None and bmm_qc is None) and layer_qc is None: raise ValueError("Must provide either (linear_qc & bmm_qc ) or layer_qc") diff --git a/src/chop/models/pvt/pvt.py b/src/chop/models/pvt/pvt.py index 2474fb388..ac3a18708 100644 --- a/src/chop/models/pvt/pvt.py +++ b/src/chop/models/pvt/pvt.py @@ -58,7 +58,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -182,7 +182,12 @@ def forward(self, x): @register_mase_model( "pvt", - checkpoints=["pvt_tiny", "pvt_small", "pvt_medium", "pvt_large",], + checkpoints=[ + "pvt_tiny", + "pvt_small", + "pvt_medium", + "pvt_large", + ], model_source="vision_others", task_type="vision", image_classification=True, diff --git a/src/chop/models/pvt/pvt_v2.py b/src/chop/models/pvt/pvt_v2.py index b70acb6d1..3ffb45f6b 100644 --- a/src/chop/models/pvt/pvt_v2.py +++ b/src/chop/models/pvt/pvt_v2.py @@ -82,7 +82,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) diff --git a/src/chop/models/repvgg/repvgg.py b/src/chop/models/repvgg/repvgg.py index 5ef7998e0..086501fc5 100644 --- a/src/chop/models/repvgg/repvgg.py +++ b/src/chop/models/repvgg/repvgg.py @@ -143,14 +143,14 @@ def get_custom_L2(self): .detach() ) - l2_loss_circle = (K3 ** 2).sum() - ( + l2_loss_circle = (K3**2).sum() - ( K3[:, :, 1:2, 1:2] ** 2 ).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. eq_kernel = ( K3[:, :, 1:2, 1:2] * t3 + K1 * t1 ) # The equivalent resultant central point of 3x3 kernel. l2_loss_eq_kernel = ( - eq_kernel ** 2 / (t3 ** 2 + t1 ** 2) + eq_kernel**2 / (t3**2 + t1**2) ).sum() # Normalize for an L2 coefficient comparable to regular L2. return l2_loss_eq_kernel + l2_loss_circle diff --git a/src/chop/models/resnet/resnet.py b/src/chop/models/resnet/resnet.py index b339b72a7..313c8ce36 100644 --- a/src/chop/models/resnet/resnet.py +++ b/src/chop/models/resnet/resnet.py @@ -157,7 +157,13 @@ def forward(self, x: Tensor) -> Tensor: @register_mase_model( name="resnet", - checkpoints=["resnet18", "resnet34", "resnet50", "resnet101", "wide_resnet50_2",], + checkpoints=[ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "wide_resnet50_2", + ], model_source="torchvision", task_type="vision", image_classification=True, @@ -331,7 +337,10 @@ def _resnet( @register_mase_checkpoint("resnet18") -def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet: +def get_resnet18( + pretrained: bool = False, + **kwargs: Any, +) -> ResNet: """ResNet-18 from `Deep Residual Learning for Image Recognition `__.""" if pretrained: pretrained_weight_cls = ResNet18_Weights.IMAGENET1K_V1 @@ -339,12 +348,18 @@ def get_resnet18(pretrained: bool = False, **kwargs: Any,) -> ResNet: pretrained_weight_cls = None return _resnet( - BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, + BasicBlock, + [2, 2, 2, 2], + pretrained_weight_cls=pretrained_weight_cls, + **kwargs, ) @register_mase_checkpoint("resnet34") -def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet: +def get_resnet34( + pretrained: bool = False, + **kwargs: Any, +) -> ResNet: """ResNet-34 from `Deep Residual Learning for Image Recognition `__.""" if pretrained: pretrained_weight_cls = ResNet34_Weights.IMAGENET1K_V1 @@ -352,12 +367,18 @@ def get_resnet34(pretrained: bool = False, **kwargs: Any,) -> ResNet: pretrained_weight_cls = None return _resnet( - BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, + BasicBlock, + [2, 2, 2, 2], + pretrained_weight_cls=pretrained_weight_cls, + **kwargs, ) @register_mase_checkpoint("resnet50") -def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet: +def get_resnet50( + pretrained: bool = False, + **kwargs: Any, +) -> ResNet: """ResNet-50 from `Deep Residual Learning for Image Recognition `__.""" info = kwargs["dataset_info"] if pretrained: @@ -366,12 +387,18 @@ def get_resnet50(pretrained: bool = False, **kwargs: Any,) -> ResNet: pretrained_weight_cls = None return _resnet( - Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs, + Bottleneck, + [3, 4, 6, 3], + pretrained_weight_cls=pretrained_weight_cls, + **kwargs, ) @register_mase_checkpoint("resnet101") -def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet: +def get_resnet101( + pretrained: bool = False, + **kwargs: Any, +) -> ResNet: """ResNet-101 from `Deep Residual Learning for Image Recognition `__.""" info = kwargs["dataset_info"] if pretrained: @@ -380,13 +407,17 @@ def get_resnet101(pretrained: bool = False, **kwargs: Any,) -> ResNet: pretrained_weight_cls = None return _resnet( - BasicBlock, [2, 2, 2, 2], pretrained_weight_cls=pretrained_weight_cls, **kwargs, + BasicBlock, + [2, 2, 2, 2], + pretrained_weight_cls=pretrained_weight_cls, + **kwargs, ) @register_mase_checkpoint("wide_resnet50_2") def get_wide_resnet50_2( - pretrained: bool = False, **kwargs, + pretrained: bool = False, + **kwargs, ): """ `Wide Residual Networks `_. @@ -404,5 +435,8 @@ def get_wide_resnet50_2( _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet( - Bottleneck, [3, 4, 6, 3], pretrained_weight_cls=pretrained_weight_cls, **kwargs, + Bottleneck, + [3, 4, 6, 3], + pretrained_weight_cls=pretrained_weight_cls, + **kwargs, ) diff --git a/src/chop/models/toy/toy.py b/src/chop/models/toy/toy.py index 3890c4a0f..e88e030f6 100644 --- a/src/chop/models/toy/toy.py +++ b/src/chop/models/toy/toy.py @@ -160,7 +160,8 @@ def _conv_block(self, conv_class, *args): # Getters ------------------------------------------------------------------------------ @register_mase_checkpoint("toy") def get_toynet( - pretrained=False, **kwargs: Any, + pretrained=False, + **kwargs: Any, ): info = kwargs["dataset_info"] image_size = info.image_size @@ -169,7 +170,9 @@ def get_toynet( def get_toy_tiny( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): image_size = info.image_size num_classes = info.num_classes @@ -177,7 +180,9 @@ def get_toy_tiny( def get_toy_testmodel( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): image_size = info["image_size"] num_classes = info.num_classes @@ -186,7 +191,9 @@ def get_toy_testmodel( def get_toy_convnet( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): # NOTE: The model isn't configurable through the CLI or a configuration file yet. num_classes = info.num_classes diff --git a/src/chop/models/vgg_cifar/vgg_orig.py b/src/chop/models/vgg_cifar/vgg_orig.py index f744c4404..cf906e888 100644 --- a/src/chop/models/vgg_cifar/vgg_orig.py +++ b/src/chop/models/vgg_cifar/vgg_orig.py @@ -189,7 +189,12 @@ class VGG11_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 132863336, - "_metrics": {"ImageNet-1K": {"acc@1": 69.020, "acc@5": 88.628,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.020, + "acc@5": 88.628, + } + }, "_ops": 7.609, "_file_size": 506.84, }, @@ -204,7 +209,12 @@ class VGG11_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 132868840, - "_metrics": {"ImageNet-1K": {"acc@1": 70.370, "acc@5": 89.810,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 70.370, + "acc@5": 89.810, + } + }, "_ops": 7.609, "_file_size": 506.881, }, @@ -219,7 +229,12 @@ class VGG13_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 133047848, - "_metrics": {"ImageNet-1K": {"acc@1": 69.928, "acc@5": 89.246,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.928, + "acc@5": 89.246, + } + }, "_ops": 11.308, "_file_size": 507.545, }, @@ -234,7 +249,12 @@ class VGG13_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 133053736, - "_metrics": {"ImageNet-1K": {"acc@1": 71.586, "acc@5": 90.374,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.586, + "acc@5": 90.374, + } + }, "_ops": 11.308, "_file_size": 507.59, }, @@ -249,7 +269,12 @@ class VGG16_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 138357544, - "_metrics": {"ImageNet-1K": {"acc@1": 71.592, "acc@5": 90.382,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.592, + "acc@5": 90.382, + } + }, "_ops": 15.47, "_file_size": 527.796, }, @@ -269,7 +294,10 @@ class VGG16_Weights(WeightsEnum): "categories": None, "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", "_metrics": { - "ImageNet-1K": {"acc@1": float("nan"), "acc@5": float("nan"),} + "ImageNet-1K": { + "acc@1": float("nan"), + "acc@5": float("nan"), + } }, "_ops": 15.47, "_file_size": 527.802, @@ -290,7 +318,12 @@ class VGG16_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 138365992, - "_metrics": {"ImageNet-1K": {"acc@1": 73.360, "acc@5": 91.516,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.360, + "acc@5": 91.516, + } + }, "_ops": 15.47, "_file_size": 527.866, }, @@ -305,7 +338,12 @@ class VGG19_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 143667240, - "_metrics": {"ImageNet-1K": {"acc@1": 72.376, "acc@5": 90.876,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.376, + "acc@5": 90.876, + } + }, "_ops": 19.632, "_file_size": 548.051, }, @@ -320,7 +358,12 @@ class VGG19_BN_Weights(WeightsEnum): meta={ **_COMMON_META, "num_params": 143678248, - "_metrics": {"ImageNet-1K": {"acc@1": 74.218, "acc@5": 91.842,}}, + "_metrics": { + "ImageNet-1K": { + "acc@1": 74.218, + "acc@5": 91.842, + } + }, "_ops": 19.632, "_file_size": 548.143, }, diff --git a/src/chop/models/vision/snn/snn_toy.py b/src/chop/models/vision/snn/snn_toy.py index 52c86d821..5915f1ec6 100644 --- a/src/chop/models/vision/snn/snn_toy.py +++ b/src/chop/models/vision/snn/snn_toy.py @@ -26,7 +26,9 @@ def forward(self, x: torch.Tensor): # Getters ------------------------------------------------------------------------------ def get_snn_toy( - info, pretrained=False, **kwargs: Any, + info, + pretrained=False, + **kwargs: Any, ): tau = info["tau"] num_classes = info.num_classes diff --git a/src/chop/models/vision/snn/spikingResformer.py b/src/chop/models/vision/snn/spikingResformer.py index 153cae155..ff74d76df 100644 --- a/src/chop/models/vision/snn/spikingResformer.py +++ b/src/chop/models/vision/snn/spikingResformer.py @@ -124,7 +124,11 @@ def no_weight_decay(self): @register_model def spikingresformer_ti(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [64, 192, 384], [1, 3, 6], [4, 2, 1], @@ -136,7 +140,11 @@ def spikingresformer_ti(**kwargs): @register_model def spikingresformer_s(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [64, 256, 512], [1, 4, 8], [4, 2, 1], @@ -148,7 +156,11 @@ def spikingresformer_s(**kwargs): @register_model def spikingresformer_m(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [64, 384, 768], [1, 6, 12], [4, 2, 1], @@ -160,7 +172,11 @@ def spikingresformer_m(**kwargs): @register_model def spikingresformer_l(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [128, 512, 1024], [2, 8, 16], [4, 2, 1], @@ -172,13 +188,18 @@ def spikingresformer_l(**kwargs): @register_model def spikingresformer_dvsg(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [32, 96, 192], [1, 3, 6], [4, 2, 1], in_channels=3, prologue=nn.Sequential( - Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), BN(32), + Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), + BN(32), ), group_size=32, activation=PLIF, @@ -189,13 +210,18 @@ def spikingresformer_dvsg(**kwargs): @register_model def spikingresformer_cifar(**kwargs): return SpikingResformer( - [["DSSA", "GWFFN"] * 1, ["DSSA", "GWFFN"] * 2, ["DSSA", "GWFFN"] * 3,], + [ + ["DSSA", "GWFFN"] * 1, + ["DSSA", "GWFFN"] * 2, + ["DSSA", "GWFFN"] * 3, + ], [64, 192, 384], [1, 3, 6], [4, 2, 1], in_channels=3, prologue=nn.Sequential( - Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), BN(64), + Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), + BN(64), ), **kwargs, ) diff --git a/src/chop/nn/functional/softermax.py b/src/chop/nn/functional/softermax.py index 348d7ebea..8653a6aa0 100644 --- a/src/chop/nn/functional/softermax.py +++ b/src/chop/nn/functional/softermax.py @@ -11,7 +11,7 @@ def softermax(input: Tensor, dim: int) -> Tensor: Tensor: Output tensor """ out = input - input.max(dim=dim, keepdim=True).values.floor() - out = 2 ** out + out = 2**out row_sum = out.sum(dim=dim, keepdim=True) # Elementwise division out = out / row_sum diff --git a/src/chop/nn/modules/gqa.py b/src/chop/nn/modules/gqa.py index 38188d51d..f30b7a8d8 100644 --- a/src/chop/nn/modules/gqa.py +++ b/src/chop/nn/modules/gqa.py @@ -127,7 +127,12 @@ def _qkv_states(self, x: Tensor, batch_size: int, seq_len: int): # return x def _attention_mechanism( - self, query: Tensor, key: Tensor, value: Tensor, batch_size: int, seq_len: int, + self, + query: Tensor, + key: Tensor, + value: Tensor, + batch_size: int, + seq_len: int, ): key = repeat_kv(key, n_rep=self.group_size) value = repeat_kv(value, n_rep=self.group_size) @@ -164,7 +169,9 @@ def forward(self, x: Tensor): GROUPS = 4 gqa_module = GroupedQueryAttention( - embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_kv_heads=GROUPS, + embed_dim=EMBED_DIM, + num_heads=NUM_HEADS, + num_kv_heads=GROUPS, ) x_in = torch.rand(BATCH, SEQ_LEN, EMBED_DIM) diff --git a/src/chop/nn/modules/lora.py b/src/chop/nn/modules/lora.py index f101d0c03..bbd2ea023 100644 --- a/src/chop/nn/modules/lora.py +++ b/src/chop/nn/modules/lora.py @@ -93,7 +93,11 @@ def reset_lora_parameters(self, adapter_name): class LinearLora(nn.Linear, LoraLayer): # Lora implemented in a dense layer def __init__( - self, in_features: int, out_features: int, config: dict = None, **kwargs, + self, + in_features: int, + out_features: int, + config: dict = None, + **kwargs, ): self.config = config init_lora_weights = self.config.get("init_lora_weights", True) @@ -218,7 +222,12 @@ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: # Simple Lora implementation from https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html class LoRALinear(nn.Module): def __init__( - self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float, + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float, ): super().__init__() # These are the weights from the original pretrained model diff --git a/src/chop/nn/modules/sparse.py b/src/chop/nn/modules/sparse.py index 5a3f749a1..67508cf3d 100644 --- a/src/chop/nn/modules/sparse.py +++ b/src/chop/nn/modules/sparse.py @@ -92,7 +92,11 @@ def reset_sparse_parameters(self, adapter_name): class LinearSparse(nn.Linear, SparseLayer): def __init__( - self, in_features: int, out_features: int, config: dict = None, **kwargs, + self, + in_features: int, + out_features: int, + config: dict = None, + **kwargs, ): self.config = config init_sparse_weights = self.config.get("init_sparse_weights", True) @@ -155,7 +159,9 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor: def update_weight_selection(self, k): w_flat = self.weight.flatten() _, self.idx = torch.topk( - self.index_method(w_flat, self.idx_method), k, sorted=True, + self.index_method(w_flat, self.idx_method), + k, + sorted=True, ) self.selected_weights = torch.gather(w_flat, dim=0, index=self.idx) @@ -186,7 +192,10 @@ def forward(self, x: torch.Tensor): # Scatter adapted values into weight tensor adapted_weights = torch.scatter( - self.zero_tensor.to(x.device), dim=0, index=self.idx, src=scaled_output, + self.zero_tensor.to(x.device), + dim=0, + index=self.idx, + src=scaled_output, ).view(self.unflattened_size) self.step += 1 @@ -194,7 +203,9 @@ def forward(self, x: torch.Tensor): x = x.to(sparse.weight.dtype) result = F.linear( - dropout(x), transpose(new_weight, self.fan_in_fan_out), bias=self.bias, + dropout(x), + transpose(new_weight, self.fan_in_fan_out), + bias=self.bias, ) else: diff --git a/src/chop/nn/mx/activations.py b/src/chop/nn/mx/activations.py index 27fafc352..39702cc29 100644 --- a/src/chop/nn/mx/activations.py +++ b/src/chop/nn/mx/activations.py @@ -434,7 +434,10 @@ def forward(ctx, input, inplace=False, mx_specs=None, name=None): @staticmethod def backward(ctx, grad_output): - (y, sig_x,) = ctx.saved_tensors + ( + y, + sig_x, + ) = ctx.saved_tensors grad_output = vec_quantize(grad_output, mx_specs=ctx.mx_specs) temp = vec_sub(1.0, sig_x, mx_specs=ctx.mx_specs) diff --git a/src/chop/nn/mx/bmm.py b/src/chop/nn/mx/bmm.py index 50071d06e..6fac4eef8 100644 --- a/src/chop/nn/mx/bmm.py +++ b/src/chop/nn/mx/bmm.py @@ -67,7 +67,9 @@ def backward(ctx, grad_out): in1, in2 = ctx.saved_tensors grad_out = quantize_elemwise_op( - grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_out, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -109,10 +111,14 @@ def backward(ctx, grad_out): # element-wise quantize for grad_in1 and grad_in2 grad_in1 = quantize_elemwise_op( - grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_in1, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) grad_in2 = quantize_elemwise_op( - grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_in2, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) return (grad_in1, grad_in2, None, None) diff --git a/src/chop/nn/mx/convolution.py b/src/chop/nn/mx/convolution.py index 9d2aaed17..62b8dd2d2 100644 --- a/src/chop/nn/mx/convolution.py +++ b/src/chop/nn/mx/convolution.py @@ -180,10 +180,16 @@ def forward( # weight is (out_channels, in_channels/groups, ..) # quantize along in_channels qid_input = quantize_mx_op( - bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1], + bf_in, + mx_specs, + elem_format=mx_specs["a_elem_format"], + axes=[1], ) qid_weight = quantize_mx_op( - bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[1], + bf_weight, + mx_specs, + elem_format=mx_specs["w_elem_format"], + axes=[1], ) # compute output @@ -207,7 +213,9 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_output = quantize_elemwise_op( - grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_output, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -217,7 +225,10 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # quantize along the batch dim qex_input = quantize_mx_op( - input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0], + input, + ctx.mx_specs, + elem_format=ctx.mx_specs["a_elem_format"], + axes=[0], ) qex_grad_output = quantize_mx_op( grad_output, @@ -240,7 +251,9 @@ def backward(ctx, grad_output): # element-wise quantize for grad_weight grad_weight = quantize_elemwise_op( - grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], + grad_weight, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -251,7 +264,10 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # reduction dim is out_channels qod_weight = quantize_mx_op( - weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[0], + weight, + ctx.mx_specs, + elem_format=ctx.mx_specs["w_elem_format"], + axes=[0], ) qod_grad_output = quantize_mx_op( grad_output, @@ -273,7 +289,9 @@ def backward(ctx, grad_output): # element-wise quantize for grad_input grad_input = quantize_elemwise_op( - grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_input, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/mx/elemwise_ops.py b/src/chop/nn/mx/elemwise_ops.py index 21067bba1..5d2b4b8ec 100644 --- a/src/chop/nn/mx/elemwise_ops.py +++ b/src/chop/nn/mx/elemwise_ops.py @@ -33,16 +33,16 @@ # exponents smaller than -126 def _safe_lshift(x, bits, exp): if exp is None: - return x * (2 ** bits) + return x * (2**bits) else: - return x / (2 ** exp) * (2 ** bits) + return x / (2**exp) * (2**bits) def _safe_rshift(x, bits, exp): if exp is None: - return x / (2 ** bits) + return x / (2**bits) else: - return x / (2 ** bits) * (2 ** exp) + return x / (2**bits) * (2**exp) def _round_mantissa(A, bits, round, clamp=False): diff --git a/src/chop/nn/mx/formats.py b/src/chop/nn/mx/formats.py index ff2a9d726..9a2c35189 100644 --- a/src/chop/nn/mx/formats.py +++ b/src/chop/nn/mx/formats.py @@ -52,14 +52,14 @@ def from_str(s): def _get_min_norm(ebits): """Valid for all float formats""" emin = 2 - (2 ** (ebits - 1)) - return 0 if ebits == 0 else 2 ** emin + return 0 if ebits == 0 else 2**emin def _get_max_norm(ebits, mbits): """Valid only for floats that define NaN""" assert ebits >= 5, "invalid for floats that don't define NaN" emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1 - return 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) + return 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) _FORMAT_CACHE = {} @@ -121,9 +121,9 @@ def _get_format_params(fmt): raise Exception("Unknown element format %s" % fmt) if fmt != ElemFormat.fp8_e4m3: - max_norm = 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) + max_norm = 2**emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2) else: - max_norm = 2 ** emax * 1.75 # FP8 has custom max_norm + max_norm = 2**emax * 1.75 # FP8 has custom max_norm min_norm = _get_min_norm(ebits) diff --git a/src/chop/nn/mx/linear.py b/src/chop/nn/mx/linear.py index ef78fcedc..64f761c86 100644 --- a/src/chop/nn/mx/linear.py +++ b/src/chop/nn/mx/linear.py @@ -18,7 +18,12 @@ class LinearFunction(torch.autograd.Function): @staticmethod def forward( - ctx, input, weight, bias=None, mx_specs=None, name=None, + ctx, + input, + weight, + bias=None, + mx_specs=None, + name=None, ): # element-wise quantize for input bf_in = quantize_elemwise_op( @@ -85,7 +90,9 @@ def backward(ctx, grad_output): in_dim = weight.shape[1] grad_output = quantize_elemwise_op( - grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_output, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -115,7 +122,9 @@ def backward(ctx, grad_output): # Compute grad_weight grad_weight = torch_matmul(qex_grad_output.transpose(0, 1), qex_input) grad_weight = quantize_elemwise_op( - grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], + grad_weight, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -141,7 +150,9 @@ def backward(ctx, grad_output): # Compute grad_input grad_input = torch_matmul(qos_grad_output, qos_weight) grad_input = quantize_elemwise_op( - grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_input, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -161,7 +172,11 @@ def backward(ctx, grad_output): def linear( - input, weight, bias=None, mx_specs=None, name=None, + input, + weight, + bias=None, + mx_specs=None, + name=None, ): mx_assert_test(mx_specs) if mx_specs is None: @@ -174,7 +189,12 @@ def linear( class Linear(torch.nn.Linear): def __init__( - self, in_features, out_features, bias=True, mx_specs=None, name=None, + self, + in_features, + out_features, + bias=True, + mx_specs=None, + name=None, ): mx_assert_test(mx_specs) self.mx_none = mx_specs is None diff --git a/src/chop/nn/mx/matmul.py b/src/chop/nn/mx/matmul.py index 3c2e0ffb4..8c18914b7 100644 --- a/src/chop/nn/mx/matmul.py +++ b/src/chop/nn/mx/matmul.py @@ -114,7 +114,9 @@ def backward(ctx, grad_out): in1, in2 = ctx.saved_tensors grad_out = quantize_elemwise_op( - grad_out, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_out, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -158,10 +160,14 @@ def backward(ctx, grad_out): # element-wise quantize for grad_in1 and grad_in2 grad_in1 = quantize_elemwise_op( - grad_in1, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_in1, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) grad_in2 = quantize_elemwise_op( - grad_in2, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_in2, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/mx/mx_ops.py b/src/chop/nn/mx/mx_ops.py index aad12a52f..136441bea 100644 --- a/src/chop/nn/mx/mx_ops.py +++ b/src/chop/nn/mx/mx_ops.py @@ -281,7 +281,10 @@ def _quantize_mx( else: # Get shared exponents shared_exp = _shared_exponents( - A, method=shared_exp_method, axes=shared_exp_axes, ebits=0, + A, + method=shared_exp_method, + axes=shared_exp_axes, + ebits=0, ) # Flush subnormal FP32 inputs to zero @@ -296,7 +299,7 @@ def _quantize_mx( shared_exp[shared_exp > scale_emax] = float("NaN") shared_exp[shared_exp < -scale_emax] = -scale_emax - A = A / (2 ** shared_exp) + A = A / (2**shared_exp) A = _quantize_elemwise_core( A, @@ -309,7 +312,7 @@ def _quantize_mx( custom_cuda=custom_cuda, ) - A = A * (2 ** shared_exp) + A = A * (2**shared_exp) # Undo tile reshaping if block_size: diff --git a/src/chop/nn/mx/quantize.py b/src/chop/nn/mx/quantize.py index b979b18d7..7e0b907fa 100644 --- a/src/chop/nn/mx/quantize.py +++ b/src/chop/nn/mx/quantize.py @@ -42,7 +42,9 @@ def forward(ctx, x, mx_specs, round=None): @staticmethod def backward(ctx, grad_output): grad_input = quantize_elemwise_op( - grad_output, mx_specs=ctx.mx_specs, round=ctx.round, + grad_output, + mx_specs=ctx.mx_specs, + round=ctx.round, ) return (grad_input, None, None) diff --git a/src/chop/nn/mx/simd_ops.py b/src/chop/nn/mx/simd_ops.py index 28180b758..475b47ef7 100644 --- a/src/chop/nn/mx/simd_ops.py +++ b/src/chop/nn/mx/simd_ops.py @@ -307,7 +307,7 @@ def forward(ctx, in1, mx_specs=None): else: ctx.save_for_backward(in1) - return vec_quantize(qin1 ** 2, mx_specs=mx_specs) + return vec_quantize(qin1**2, mx_specs=mx_specs) @staticmethod def backward(ctx, g): diff --git a/src/chop/nn/mx/transpose_convolution.py b/src/chop/nn/mx/transpose_convolution.py index 5d6c5da8a..af570025b 100644 --- a/src/chop/nn/mx/transpose_convolution.py +++ b/src/chop/nn/mx/transpose_convolution.py @@ -75,10 +75,16 @@ def forward( # weight is (in_channels, out_channels/groups, ...) # quantize along in_channels qid_input = quantize_mx_op( - bf_in, mx_specs, elem_format=mx_specs["a_elem_format"], axes=[1], + bf_in, + mx_specs, + elem_format=mx_specs["a_elem_format"], + axes=[1], ) qid_weight = quantize_mx_op( - bf_weight, mx_specs, elem_format=mx_specs["w_elem_format"], axes=[0], + bf_weight, + mx_specs, + elem_format=mx_specs["w_elem_format"], + axes=[0], ) # compute output @@ -108,7 +114,9 @@ def backward(ctx, grad_output): input, weight = ctx.saved_tensors grad_output = quantize_elemwise_op( - grad_output, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_output, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### @@ -118,7 +126,10 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # quantize along the batch dim qex_input = quantize_mx_op( - input, ctx.mx_specs, elem_format=ctx.mx_specs["a_elem_format"], axes=[0], + input, + ctx.mx_specs, + elem_format=ctx.mx_specs["a_elem_format"], + axes=[0], ) qex_grad_output = quantize_mx_op( grad_output, @@ -139,7 +150,9 @@ def backward(ctx, grad_output): # element-wise quantize for grad_weight grad_weight = quantize_elemwise_op( - grad_weight, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_weight"], + grad_weight, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_weight"], ) ##################################################### @@ -149,7 +162,10 @@ def backward(ctx, grad_output): # output is (batch, out_channels, ...) # reduction dim is out_channels qod_weight = quantize_mx_op( - weight, ctx.mx_specs, elem_format=ctx.mx_specs["w_elem_format"], axes=[1], + weight, + ctx.mx_specs, + elem_format=ctx.mx_specs["w_elem_format"], + axes=[1], ) qod_grad_output = quantize_mx_op( grad_output, @@ -171,7 +187,9 @@ def backward(ctx, grad_output): # element-wise quantize for grad_input grad_input = quantize_elemwise_op( - grad_input, mx_specs=ctx.mx_specs, round=ctx.mx_specs["round_grad_input"], + grad_input, + mx_specs=ctx.mx_specs, + round=ctx.mx_specs["round_grad_input"], ) ##################################################### diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py index c95cd1bb3..13f9532c3 100644 --- a/src/chop/nn/optical/modules/morr_conv2d.py +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -109,7 +109,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi ** 2 + self.gamma = np.pi / self.v_pi**2 self.w_bit = 32 self.in_bit = 32 self.MORRConfig = MORRConfig @@ -123,7 +123,7 @@ def __init__( ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * MORRConfig.radius * MORRConfig.effective_index * ( @@ -241,7 +241,7 @@ def reset_parameters(self, morr_init: bool = False) -> None: (t2 - t1) / (2.4 * self.morr_fwhm) ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index c1181cfc1..f93281af5 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -62,7 +62,7 @@ def __init__( self.v_max = 10.8 self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi ** 2 + self.gamma = np.pi / self.v_pi**2 self.w_bit = 32 self.in_bit = 32 @@ -80,7 +80,7 @@ def __init__( ### calculate FWHM (rad) self.morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * morr_config.radius * morr_config.effective_index * ( @@ -198,7 +198,7 @@ def reset_parameters(self, morr_init: bool = False) -> None: (t2 - t1) / (2.4 * self.morr_fwhm) ).item() ## 0~2.4 FWHM slope as a linear approximation - self.sigma_out_scale = 4 / (3 * self.grid_dim_x ** 0.5 * g * self.morr_fwhm) + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) self.out_scale_quant_gain = None init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) else: diff --git a/src/chop/nn/optical/utils/initializer.py b/src/chop/nn/optical/utils/initializer.py index 19c97b6dd..cbdd1f83c 100644 --- a/src/chop/nn/optical/utils/initializer.py +++ b/src/chop/nn/optical/utils/initializer.py @@ -31,7 +31,7 @@ def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): """ morr_fwhm = ( -4 - * np.pi ** 2 + * np.pi**2 * MORRConfig.radius * MORRConfig.effective_index * ( diff --git a/src/chop/nn/optical/utils/mrr_op.py b/src/chop/nn/optical/utils/mrr_op.py index 6c397b3f5..189db8761 100644 --- a/src/chop/nn/optical/utils/mrr_op.py +++ b/src/chop/nn/optical/utils/mrr_op.py @@ -78,7 +78,7 @@ def mrr_roundtrip_phase_to_tr_func( c1 = -2 * a * r c2 = a * a + r * r c3 = 1 + r * r * a * a - a * a - r * r - c4 = (a ** 2 - 1) * (r ** 2 - 1) * 2 * a * r + c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r class MRRRoundTripPhaseToTrFunction(torch.autograd.Function): @staticmethod diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py index 828da142d..84372c8c7 100644 --- a/src/chop/nn/optical/utils/quantize.py +++ b/src/chop/nn/optical/utils/quantize.py @@ -32,7 +32,7 @@ def forward(ctx, input): elif k == 1: out = torch.sign(input) else: - n = float(2 ** k - 1) + n = float(2**k - 1) out = torch.round(input * n) / n return out @@ -63,7 +63,7 @@ def forward(ctx, input, scale, zero_point): elif k == 1: out = torch.sign(input) else: - n = float(2 ** k - 1) + n = float(2**k - 1) # out = torch.round(input * n) / n # out = (torch.clamp(torch.round(input / scale + zero_point), 0, n) - zero_point) * scale out = ( @@ -133,7 +133,7 @@ def __init__( qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=0, - quant_max=2 ** self.in_bit - 1, + quant_max=2**self.in_bit - 1, ).to(self.device) else: self.obs = None diff --git a/src/chop/nn/quantized/functional/gelu.py b/src/chop/nn/quantized/functional/gelu.py index 225ff70ce..cee5e3317 100644 --- a/src/chop/nn/quantized/functional/gelu.py +++ b/src/chop/nn/quantized/functional/gelu.py @@ -107,7 +107,9 @@ def gelu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.gelu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index f518f4968..5fda700de 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -72,7 +72,10 @@ def linearInteger( def linearMinifloatDenorm( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -119,7 +122,10 @@ def linearMinifloatDenorm( def linearMinifloatIEEE( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -166,7 +172,10 @@ def linearMinifloatIEEE( def linearMinifloatIEEE( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_width, w_exponent_width, w_exponent_bias = ( config["weight_width"], @@ -213,7 +222,10 @@ def linearMinifloatIEEE( def linearLog( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_width, w_exponent_bias = ( config["weight_width"], @@ -228,11 +240,23 @@ def linearLog( config["bias_exponent_bias"], ) - w_quantizer = partial(log_quantizer, width=w_width, exponent_bias=w_exponent_bias,) + w_quantizer = partial( + log_quantizer, + width=w_width, + exponent_bias=w_exponent_bias, + ) - x_quantizer = partial(log_quantizer, width=x_width, exponent_bias=x_exponent_bias,) + x_quantizer = partial( + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, + ) - b_quantizer = partial(log_quantizer, width=b_width, exponent_bias=b_exponent_bias,) + b_quantizer = partial( + log_quantizer, + width=b_width, + exponent_bias=b_exponent_bias, + ) x = x_quantizer(x) weight = w_quantizer(weight) @@ -242,7 +266,10 @@ def linearLog( def linearBlockFP( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): # establish quantizers w_width, w_exponent_width, w_exponent_bias, w_block_size = ( @@ -300,7 +327,10 @@ def linearBlockFP( def linearBlockMinifloat( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): # establish quantizers w_width, w_exponent_width, w_exponent_bias_width, w_block_size = ( @@ -358,7 +388,10 @@ def linearBlockMinifloat( def linearBlockLog( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): # establish quantizers w_width, w_exponent_bias_width, w_block_size = ( @@ -410,7 +443,10 @@ def linearBlockLog( def linearBinary( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_stochastic = config["weight_stochastic"] w_bipolar = config["weight_bipolar"] @@ -426,7 +462,10 @@ def linearBinary( def linearBinaryScaling( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): """ Binary scaling variant of the linear transformation layer. @@ -474,7 +513,10 @@ def linearBinaryScaling( def linearTernary( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): w_scaling_factor = config["weight_scaling_factor"] w_mean = get_stats(config, "weight_mean") @@ -498,19 +540,28 @@ def linearTernary( def linearBinaryResidualSign( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): raise NotImplementedError def linearLUT( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): raise NotImplementedError def linearLogicNets( - x: Tensor, weight: Tensor, bias: Tensor = None, config: dict = None, + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, ): raise NotImplementedError diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py index f6487dc8f..d06eb1ece 100644 --- a/src/chop/nn/quantized/functional/matmul.py +++ b/src/chop/nn/quantized/functional/matmul.py @@ -176,10 +176,14 @@ def generic_matmul_log(x, y, config, style="matmul"): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) y_quantizer = partial( - log_quantizer, width=y_width, exponent_bias=y_exponent_bias, + log_quantizer, + width=y_width, + exponent_bias=y_exponent_bias, ) x = x_quantizer(x) y = y_quantizer(y) diff --git a/src/chop/nn/quantized/functional/relu.py b/src/chop/nn/quantized/functional/relu.py index cb57d078e..57daed04a 100644 --- a/src/chop/nn/quantized/functional/relu.py +++ b/src/chop/nn/quantized/functional/relu.py @@ -107,7 +107,9 @@ def relu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.relu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/selu.py b/src/chop/nn/quantized/functional/selu.py index b1edebdbf..12956c392 100644 --- a/src/chop/nn/quantized/functional/selu.py +++ b/src/chop/nn/quantized/functional/selu.py @@ -107,7 +107,9 @@ def selu_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.selu(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/softplus.py b/src/chop/nn/quantized/functional/softplus.py index c873e7104..a9bd7dafc 100644 --- a/src/chop/nn/quantized/functional/softplus.py +++ b/src/chop/nn/quantized/functional/softplus.py @@ -107,7 +107,9 @@ def softplus_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.softplus(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/softsign.py b/src/chop/nn/quantized/functional/softsign.py index c60f5f757..3eaab47fa 100644 --- a/src/chop/nn/quantized/functional/softsign.py +++ b/src/chop/nn/quantized/functional/softsign.py @@ -107,7 +107,9 @@ def softsign_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.softsign(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/functional/tanh.py b/src/chop/nn/quantized/functional/tanh.py index 8b1009ac0..7d3c67c31 100644 --- a/src/chop/nn/quantized/functional/tanh.py +++ b/src/chop/nn/quantized/functional/tanh.py @@ -107,7 +107,9 @@ def tanh_log(x, inplace=False, config=None): ) x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) return F.tanh(x_quantizer(x), inplace=inplace) diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py index 8176f4d52..93f4801e8 100644 --- a/src/chop/nn/quantized/modules/attention_head.py +++ b/src/chop/nn/quantized/modules/attention_head.py @@ -63,7 +63,10 @@ class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase): def __init__(self, config, q_config: dict = None) -> None: super().__init__(config) - self.query_quantizer = partial(integer_quantizer, **q_config,) + self.query_quantizer = partial( + integer_quantizer, + **q_config, + ) self.key_quantizer = partial(integer_quantizer, **q_config) self.value_quantizer = partial(integer_quantizer, **q_config) diff --git a/src/chop/nn/quantized/modules/conv1d.py b/src/chop/nn/quantized/modules/conv1d.py index 03662097a..67654d917 100644 --- a/src/chop/nn/quantized/modules/conv1d.py +++ b/src/chop/nn/quantized/modules/conv1d.py @@ -274,15 +274,21 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, width=w_width, exponent_bias=w_exponent_bias, + log_quantizer, + width=w_width, + exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, width=b_width, exponent_bias=b_exponent_bias, + log_quantizer, + width=b_width, + exponent_bias=b_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/conv2d.py b/src/chop/nn/quantized/modules/conv2d.py index 3d297842e..cc8cd982a 100644 --- a/src/chop/nn/quantized/modules/conv2d.py +++ b/src/chop/nn/quantized/modules/conv2d.py @@ -363,15 +363,21 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, width=w_width, exponent_bias=w_exponent_bias, + log_quantizer, + width=w_width, + exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, width=b_width, exponent_bias=b_exponent_bias, + log_quantizer, + width=b_width, + exponent_bias=b_exponent_bias, ) @@ -424,15 +430,21 @@ def __init__( ) self.w_quantizer = partial( - log_quantizer, width=w_width, exponent_bias=w_exponent_bias, + log_quantizer, + width=w_width, + exponent_bias=w_exponent_bias, ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) self.b_quantizer = partial( - log_quantizer, width=b_width, exponent_bias=b_exponent_bias, + log_quantizer, + width=b_width, + exponent_bias=b_exponent_bias, ) @@ -1128,7 +1140,10 @@ def __init__( ) self.unfold = torch.nn.Unfold( - kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, ) self.fold = torch.nn.Fold( @@ -1228,7 +1243,11 @@ def forward( expanded_input, targets, initalize ).squeeze() # [10, 589824] output = output.view( - batch_size, self.out_channels, self._out_dim(0), self._out_dim(1), -1, + batch_size, + self.out_channels, + self._out_dim(0), + self._out_dim(1), + -1, ).sum( -1 ) # [10, 256, 1, 1, 2304] -> [10, 256, 1, 1] @@ -1411,10 +1430,10 @@ def forward(self, x: Tensor) -> Tensor: return self.decode(self.lut_forward(x)) def encode(self, input: Tensor) -> Tensor: - return input * 2 ** self.x_frac_width + return input * 2**self.x_frac_width def decode(self, input: Tensor) -> Tensor: - return input / 2 ** self.x_frac_width + return input / 2**self.x_frac_width def math_forward(self, input: Tensor) -> Tensor: return self.y_quantizer( diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py index ace6f5510..4e579efd1 100644 --- a/src/chop/nn/quantized/modules/gelu.py +++ b/src/chop/nn/quantized/modules/gelu.py @@ -123,7 +123,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -141,7 +143,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/gqa.py b/src/chop/nn/quantized/modules/gqa.py index 6cac74dda..1445b05ec 100644 --- a/src/chop/nn/quantized/modules/gqa.py +++ b/src/chop/nn/quantized/modules/gqa.py @@ -89,7 +89,10 @@ def __init__( ) self.v_matmul_func = partial( - matmul_integer, config=config, out_config=out_config, floor=floor, + matmul_integer, + config=config, + out_config=out_config, + floor=floor, ) o_projection_q_config = { diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index b49c86c54..5d8d389a5 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -66,7 +66,11 @@ def __init__( dtype=None, ) -> None: super().__init__( - in_features, out_features, bias, device, dtype, + in_features, + out_features, + bias, + device, + dtype, ) self.bypass = False self.pruning_masks = None @@ -444,7 +448,9 @@ def forward(self, x: Tensor) -> Tensor: if self.binary_training: w = self.w_quantizer(self.weight) return F.linear( - x_expanded, w * self.gamma.abs() * self.pruning_masks, self.bias, + x_expanded, + w * self.gamma.abs() * self.pruning_masks, + self.bias, ) else: self.weigh = self.weight.data.clamp_(-1, 1) @@ -554,7 +560,9 @@ def forward( output = output.view(batch_size, -1) assert output.shape[-1] == self.tables_count output = output.view( - batch_size, self.out_features, int(self.tables_count / self.out_features), + batch_size, + self.out_features, + int(self.tables_count / self.out_features), ) output = output.sum(-1) if self.bias is not None: @@ -765,10 +773,10 @@ def run_layers(self, input: Tensor, layers) -> Tensor: return y def encode(self, input: Tensor) -> Tensor: - return input * 2 ** self.x_frac_width + return input * 2**self.x_frac_width def decode(self, input: Tensor) -> Tensor: - return input / 2 ** self.x_frac_width + return input / 2**self.x_frac_width def forward(self, x: Tensor) -> Tensor: if self.is_lut_inference: diff --git a/src/chop/nn/quantized/modules/relu.py b/src/chop/nn/quantized/modules/relu.py index a4d1acfbf..2bc527161 100644 --- a/src/chop/nn/quantized/modules/relu.py +++ b/src/chop/nn/quantized/modules/relu.py @@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/selu.py b/src/chop/nn/quantized/modules/selu.py index 482c7c5d1..066ffc0b7 100644 --- a/src/chop/nn/quantized/modules/selu.py +++ b/src/chop/nn/quantized/modules/selu.py @@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/silu.py b/src/chop/nn/quantized/modules/silu.py index 30dac9e5d..07f18ef6e 100644 --- a/src/chop/nn/quantized/modules/silu.py +++ b/src/chop/nn/quantized/modules/silu.py @@ -113,7 +113,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -131,7 +133,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/softplus.py b/src/chop/nn/quantized/modules/softplus.py index 458558b63..4e8465c56 100644 --- a/src/chop/nn/quantized/modules/softplus.py +++ b/src/chop/nn/quantized/modules/softplus.py @@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/softsign.py b/src/chop/nn/quantized/modules/softsign.py index fe3e53b62..5497426aa 100644 --- a/src/chop/nn/quantized/modules/softsign.py +++ b/src/chop/nn/quantized/modules/softsign.py @@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantized/modules/tanh.py b/src/chop/nn/quantized/modules/tanh.py index 3378a612f..fce343489 100644 --- a/src/chop/nn/quantized/modules/tanh.py +++ b/src/chop/nn/quantized/modules/tanh.py @@ -121,7 +121,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) @@ -139,7 +141,9 @@ def __init__(self, inplace: bool = False, config: dict = None): config["data_in_exponent_bias"], ) self.x_quantizer = partial( - log_quantizer, width=x_width, exponent_bias=x_exponent_bias, + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, ) diff --git a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py index 72478f46d..14cf53839 100644 --- a/src/chop/nn/quantizers/LUTNet/BaseInitializer.py +++ b/src/chop/nn/quantizers/LUTNet/BaseInitializer.py @@ -62,7 +62,10 @@ def update_luts_weights(self) -> torch.Tensor: key = row.detach().cpu().flatten().sign().numpy().tolist() new_weights.append(key) new_weights = torch.tensor( - new_weights, dtype=torch.float32, requires_grad=True, device=self.device, + new_weights, + dtype=torch.float32, + requires_grad=True, + device=self.device, ).view(-1, self.kk) return new_weights diff --git a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py index 45d95805c..ff48436e0 100644 --- a/src/chop/nn/quantizers/LUTNet/BaseTrainer.py +++ b/src/chop/nn/quantizers/LUTNet/BaseTrainer.py @@ -38,7 +38,7 @@ def __init__( levels (int): Number of residual level to use. """ self.k = k - self.kk = 2 ** k + self.kk = 2**k self.binarization_level = binarization_level self.input_expanded = input_expanded self.tables_count = tables_count diff --git a/src/chop/nn/quantizers/block_fp.py b/src/chop/nn/quantizers/block_fp.py index 826133862..bfe8534e3 100644 --- a/src/chop/nn/quantizers/block_fp.py +++ b/src/chop/nn/quantizers/block_fp.py @@ -47,10 +47,10 @@ def _block_fp_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 - exponent_max = 2 ** exponent_width - 1 - exponent_bias + exponent_max = 2**exponent_width - 1 - exponent_bias exponent_min = -exponent_bias - mantissa_integer_max = 2 ** mantissa_bits - 1 + mantissa_integer_max = 2**mantissa_bits - 1 # sign per_block_sign = torch.sign(blocked_x + 1e-9) # exponent @@ -58,14 +58,14 @@ def _block_fp_quantize( per_block_exponent = torch.ceil(torch.log2(per_block_max)) per_block_exponent = my_clamp(per_block_exponent, exponent_min, exponent_max) # mantissa - per_block_mantissa = per_block_value / 2 ** per_block_exponent - shift = 2 ** mantissa_bits + per_block_mantissa = per_block_value / 2**per_block_exponent + shift = 2**mantissa_bits per_block_mantissa_integer = my_clamp( my_round(per_block_mantissa * shift), 0, mantissa_integer_max ) per_block_mantissa = per_block_mantissa_integer / shift - per_block_msfp = per_block_sign * (2 ** per_block_exponent) * per_block_mantissa + per_block_msfp = per_block_sign * (2**per_block_exponent) * per_block_mantissa msfp_x = unblock( per_block_msfp, x_shape_before_blocking=x_shape_before_blocking, @@ -133,5 +133,10 @@ def block_fp_quantizer( """ return BlockFPQuantize.apply( - x, width, exponent_width, exponent_bias, block_size, skip_first_dim, + x, + width, + exponent_width, + exponent_bias, + block_size, + skip_first_dim, ) diff --git a/src/chop/nn/quantizers/block_log.py b/src/chop/nn/quantizers/block_log.py index 1773b42b8..8e65c91ea 100644 --- a/src/chop/nn/quantizers/block_log.py +++ b/src/chop/nn/quantizers/block_log.py @@ -40,7 +40,7 @@ def _block_log_quantize( per_block_max_exponent = torch.ceil(torch.log2(per_block_max)) per_block_bias = my_clamp( - 2 ** exponent_bits - 1 - per_block_max_exponent, 0, 2 ** exponent_bias_width - 1 + 2**exponent_bits - 1 - per_block_max_exponent, 0, 2**exponent_bias_width - 1 ) per_block_lq_x = _log_quantize(blocked_x, width=width, exponent_bias=per_block_bias) @@ -98,5 +98,9 @@ def block_log_quantizer( - `block_size`: a list of integers where each integer is the block size along the corresponding dim """ return BlockLogQuantize.apply( - x, width, exponent_bias_width, block_size, skip_first_dim, + x, + width, + exponent_bias_width, + block_size, + skip_first_dim, ) diff --git a/src/chop/nn/quantizers/block_minifloat.py b/src/chop/nn/quantizers/block_minifloat.py index ccef6649d..34e00bbcb 100644 --- a/src/chop/nn/quantizers/block_minifloat.py +++ b/src/chop/nn/quantizers/block_minifloat.py @@ -41,7 +41,7 @@ def _block_minifloat_quantize( per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min() per_block_exponent_bias = my_clamp( - torch.floor(torch.log2(per_block_max)), 0, 2 ** exponent_bias_width - 1 + torch.floor(torch.log2(per_block_max)), 0, 2**exponent_bias_width - 1 ) per_block_bm_x = _minifloat_ieee_quantize( blocked_x, @@ -118,5 +118,10 @@ def block_minifloat_quantizer( """ return BlockMinifloatQuantize.apply( - x, width, exponent_width, exponent_bias_width, block_size, skip_first_dim, + x, + width, + exponent_width, + exponent_bias_width, + block_size, + skip_first_dim, ) diff --git a/src/chop/nn/quantizers/integer.py b/src/chop/nn/quantizers/integer.py index 8c2573f22..9f3ffec1e 100644 --- a/src/chop/nn/quantizers/integer.py +++ b/src/chop/nn/quantizers/integer.py @@ -34,9 +34,9 @@ def _integer_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2 ** width - 1 + int_max = 2**width - 1 # thresh = 2 ** (width - 1) - scale = 2 ** frac_width + scale = 2**frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale) @@ -57,8 +57,8 @@ def _integer_floor_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2 ** width - 1 - scale = 2 ** frac_width + int_max = 2**width - 1 + scale = 2**frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_floor(x.mul(scale)), int_min, int_max).div(scale) diff --git a/src/chop/nn/quantizers/log.py b/src/chop/nn/quantizers/log.py index 926e33cdb..98ab45e76 100644 --- a/src/chop/nn/quantizers/log.py +++ b/src/chop/nn/quantizers/log.py @@ -8,7 +8,9 @@ def _log_quantize( - x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None, + x: Tensor | ndarray, + width: int, + exponent_bias: int | Tensor | ndarray | None, ): """ - Use non-uniform, base-2 logarithmic representation to encode IEEE FP32/64 @@ -30,16 +32,16 @@ def _log_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_bits - 1) - 1 - exponent_max = 2 ** exponent_bits - 1 - exponent_bias + exponent_max = 2**exponent_bits - 1 - exponent_bias exponent_min = -exponent_bias - min_pos = 2 ** exponent_min + min_pos = 2**exponent_min sign = torch.sign(x + min_pos * 0.1) value = torch.abs(x) + min_pos * 0.1 exponent = my_clamp(my_round(torch.log2(value)), exponent_min, exponent_max) - return sign * (2 ** exponent) + return sign * (2**exponent) class LogQuantize(torch.autograd.Function): @@ -56,7 +58,9 @@ def backward(ctx, grad_output): def log_quantizer( - x: Tensor | ndarray, width: int, exponent_bias: int | Tensor | ndarray | None, + x: Tensor | ndarray, + width: int, + exponent_bias: int | Tensor | ndarray | None, ): """ Convert IEEE FP32/64 to base-2 log quantized values diff --git a/src/chop/nn/quantizers/minifloat.py b/src/chop/nn/quantizers/minifloat.py index f19097fde..2d6f23103 100644 --- a/src/chop/nn/quantizers/minifloat.py +++ b/src/chop/nn/quantizers/minifloat.py @@ -5,7 +5,10 @@ def _minifloat_denorm_quantize( - x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, + x: Tensor, + width: int, + exponent_width: int, + exponent_bias: int = None, ): """ - Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas. @@ -34,10 +37,10 @@ def _minifloat_denorm_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 - exponent_max = 2 ** exponent_width - 1 - exponent_bias + exponent_max = 2**exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # if the mantissa is an integer, the max mantissa value will be (2**mantissa_bits -1) - shifted_mantissa_max = 2 ** mantissa_bits - 1 + shifted_mantissa_max = 2**mantissa_bits - 1 shifted_mantissa_min = 0 sign = torch.sign(x + 1e-9) @@ -49,8 +52,8 @@ def _minifloat_denorm_quantize( # divide value by clipped exponent. this ensures the simulated minifloat value is correct # when x is too large (minifloat will saturate) or too close to 0. - mantissa = value / 2 ** exponent - shift = 2 ** mantissa_bits + mantissa = value / 2**exponent + shift = 2**mantissa_bits shifted_mantissa = my_round(mantissa * shift) # clip the integer mantissa. shifted_mantissa = my_clamp( @@ -68,7 +71,11 @@ def _minifloat_denorm_quantize( class MinifloatDenormQuantize(torch.autograd.Function): @staticmethod def forward( - ctx, x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, + ctx, + x: Tensor, + width: int, + exponent_width: int, + exponent_bias: int = None, ): return _minifloat_denorm_quantize( x, width=width, exponent_width=exponent_width, exponent_bias=exponent_bias @@ -81,7 +88,10 @@ def backward(ctx, grad_output): def minifloat_denorm_quantizer( - x: Tensor, width: int, exponent_width: int, exponent_bias: int = None, + x: Tensor, + width: int, + exponent_width: int, + exponent_bias: int = None, ): """ - Converts IEEE FP32/64 to minifloat without the implicit leading bit in mantissas. @@ -138,11 +148,11 @@ def _minifloat_ieee_quantize( if exponent_bias in (None, "none", "None"): exponent_bias = 2 ** (exponent_width - 1) - 1 # upper and lower bound of shifted exponent - exponent_max = 2 ** exponent_width - 1 - exponent_bias + exponent_max = 2**exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # upper and lower bound of shifted minifloat mantissa - shift = 2 ** mantissa_bits - shifted_mantissa_max = 2 ** mantissa_bits - 1 + shift = 2**mantissa_bits + shifted_mantissa_max = 2**mantissa_bits - 1 shifted_mantissa_min = 0 sign = torch.sign(x + 1e-9) @@ -152,9 +162,9 @@ def _minifloat_ieee_quantize( exponent = torch.floor(torch.log2(value + 1e-9)) exponent = my_clamp(exponent, exponent_min, exponent_max) - mantissa = value / 2 ** exponent + mantissa = value / 2**exponent - shift = 2 ** mantissa_bits + shift = 2**mantissa_bits # fmt: off # if the clipped exponent is zero, the minifloat is in a subnormal form # this `is_normal` also help the grad keeps 1 if input x is 0, or the zero-initialized value will be trapped in 0 diff --git a/src/chop/nn/quantizers/mxint_hardware.py b/src/chop/nn/quantizers/mxint_hardware.py index 1a61fb70c..0c3e06130 100644 --- a/src/chop/nn/quantizers/mxint_hardware.py +++ b/src/chop/nn/quantizers/mxint_hardware.py @@ -19,7 +19,7 @@ def mxint_quant_block( """ exponent_bias = 2 ** (exponent_width - 1) - exponent_max = 2 ** exponent_width - 1 - exponent_bias + exponent_max = 2**exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # exponent @@ -29,9 +29,9 @@ def mxint_quant_block( # mantissa int_min = -(2 ** (width - 1)) int_max = 2 ** (width - 1) - 1 - mantissa = x / 2 ** exponent + mantissa = x / 2**exponent mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - q_x = (2 ** exponent) * mantissa + q_x = (2**exponent) * mantissa return q_x diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py index ccf57319f..d5ca3d8cf 100644 --- a/src/chop/nn/quantizers/quantizers_for_hw.py +++ b/src/chop/nn/quantizers/quantizers_for_hw.py @@ -9,31 +9,31 @@ def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) - scale = 2 ** frac_width + scale = 2**frac_width fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2 ** width) + fixed_point_value = fixed_point_value % (2**width) return fixed_point_value def unsigned_integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): - thresh = 2 ** width - 1 - scale = 2 ** frac_width + thresh = 2**width - 1 + scale = 2**frac_width fixed_point_value = my_clamp(my_floor(x.mul(scale)), 0, thresh) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2 ** width) + fixed_point_value = fixed_point_value % (2**width) return fixed_point_value def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) - scale = 2 ** frac_width + scale = 2**frac_width fixed_point_value = my_clamp(my_floor(x.mul(scale)), -thresh, thresh - 1) fixed_point_value = fixed_point_value.to(torch.int) - fixed_point_value = fixed_point_value % (2 ** width) + fixed_point_value = fixed_point_value % (2**width) return fixed_point_value diff --git a/src/chop/nn/quantizers/ternary.py b/src/chop/nn/quantizers/ternary.py index 89fcfe75b..a27951c23 100644 --- a/src/chop/nn/quantizers/ternary.py +++ b/src/chop/nn/quantizers/ternary.py @@ -45,7 +45,8 @@ def ternary_quantizer( # ) if scaling_factor: x = ternarised_scaled_op( - x, threshold, # abs_mean=mean + x, + threshold, # abs_mean=mean ) # [mean, 0 ,-mean] # this function determines the mean on the fly, maybe we could make an alternative which uses the metadata? else: x = ternarised_op(x, threshold) # [1, 0 ,-1] diff --git a/src/chop/nn/quantizers/utils.py b/src/chop/nn/quantizers/utils.py index 9efc8b512..ae881328d 100644 --- a/src/chop/nn/quantizers/utils.py +++ b/src/chop/nn/quantizers/utils.py @@ -224,8 +224,16 @@ def forward(ctx, input, _threshold): alpha = TernaryScaled.alpha(input, delta) output = torch.zeros_like(input) - pos_one = torch.where(input > delta, 1.0, 0.0,) - neg_one = torch.where(input < -delta, -1.0, 0.0,) + pos_one = torch.where( + input > delta, + 1.0, + 0.0, + ) + neg_one = torch.where( + input < -delta, + -1.0, + 0.0, + ) output = (pos_one + neg_one) * alpha.view(-1, 1, 1, 1).expand( -1, input.size()[1], input.size()[2], input.size()[3] ) @@ -287,8 +295,16 @@ def forward(ctx, input, _threshold): alpha = TernaryScaled.alpha(input, delta) output = torch.zeros_like(input) - pos_one = torch.where(input > delta, 1.0, 0.0,) - neg_one = torch.where(input < -delta, -1.0, 0.0,) + pos_one = torch.where( + input > delta, + 1.0, + 0.0, + ) + neg_one = torch.where( + input < -delta, + -1.0, + 0.0, + ) output = pos_one + neg_one return output @@ -397,7 +413,8 @@ def _block_1d_bias(x: Tensor, block_shape: List[int]): def _unblock_to_1d_bias( - blocked_x: Tensor, x_shape_before_blocking: List[int], + blocked_x: Tensor, + x_shape_before_blocking: List[int], ): """ blocked bias shape: [num_blocks, block_size] -> [output_features] @@ -592,7 +609,10 @@ def unblock( return _unblock_to_2d_activation(blocked_x, x_shape_before_blocking) else: return _unblock_to_2d_weight( - blocked_x, x_shape_before_blocking, padded_x_shape, block_shape, + blocked_x, + x_shape_before_blocking, + padded_x_shape, + block_shape, ) elif len(x_shape_before_blocking) == 3: if skipped_first_dim_when_blocking: diff --git a/src/chop/nn/snn/auto_cuda/generator.py b/src/chop/nn/snn/auto_cuda/generator.py index f1637ddf9..639cf0379 100644 --- a/src/chop/nn/snn/auto_cuda/generator.py +++ b/src/chop/nn/snn/auto_cuda/generator.py @@ -310,7 +310,10 @@ def gen_forward_codes( params.append(("v_reset", "const float &")) params.extend( - [("neuron_num", "const int &"), ("numel", "const int &"),] + [ + ("neuron_num", "const int &"), + ("numel", "const int &"), + ] ) params_name = [] for item in params: diff --git a/src/chop/nn/snn/modules/spiking_self_attention.py b/src/chop/nn/snn/modules/spiking_self_attention.py index 3fa47d1d8..87ee98a30 100644 --- a/src/chop/nn/snn/modules/spiking_self_attention.py +++ b/src/chop/nn/snn/modules/spiking_self_attention.py @@ -145,7 +145,9 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L super().__init__() inner_channels = in_channels * ratio self.up = nn.Sequential( - activation(), Conv1x1(in_channels, inner_channels), BN(inner_channels), + activation(), + Conv1x1(in_channels, inner_channels), + BN(inner_channels), ) self.conv = nn.ModuleList() for _ in range(num_conv): @@ -161,7 +163,9 @@ def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=L ) ) self.down = nn.Sequential( - activation(), Conv1x1(inner_channels, in_channels), BN(in_channels), + activation(), + Conv1x1(inner_channels, in_channels), + BN(in_channels), ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 3274c9143..4b1eb96ef 100644 --- a/src/chop/passes/graph/__init__.py +++ b/src/chop/passes/graph/__init__.py @@ -149,6 +149,6 @@ if check_dependencies("tensorrt_fake_quantize_transform_pass"): TRANSFORM_PASSES.append("tensorrt_fake_quantize_transform_pass") - PASSES[ - "tensorrt_fake_quantize_transform_pass" - ] = tensorrt_fake_quantize_transform_pass + PASSES["tensorrt_fake_quantize_transform_pass"] = ( + tensorrt_fake_quantize_transform_pass + ) diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index 9fde9d7cb..66fe9cc28 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -203,7 +203,10 @@ def graph_iterator_for_mase_ops(graph): def graph_iterator_for_metadata( - graph, dummy_in=None, add_value=True, force_device_meta=False, + graph, + dummy_in=None, + add_value=True, + force_device_meta=False, ): """ largely adapted from https://pytorch.org/docs/stable/fx.html @@ -288,7 +291,12 @@ def _add_graph_metadata(graph): def add_common_metadata_analysis_pass( - graph, pass_args={"dummy_in": None, "add_value": True, "force_device_meta": False,}, + graph, + pass_args={ + "dummy_in": None, + "add_value": True, + "force_device_meta": False, + }, ): """add common metadata diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 9fd2e9466..39c8ac690 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -39,7 +39,8 @@ def add_component_source(node): if mase_op == "user_defined_module": for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items(): if isinstance( - deepgetattr(node.meta["mase"].model, node.target), custom_op, + deepgetattr(node.meta["mase"].model, node.target), + custom_op, ): node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" node.meta["mase"]["hardware"]["module"] = op_info["module"] diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index ed3249c9c..6f5ee0147 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -231,7 +231,10 @@ }, # Inserted ops from the replace_method_with_function pass "torch_size": {"input": "data_in", "dim": "config"}, - "torch_contiguous": {"input": "data_in", "memory_format": "config",}, + "torch_contiguous": { + "input": "data_in", + "memory_format": "config", + }, # arbitrary length - support up to 4 "torch_expand": { "input": "data_in", @@ -254,7 +257,11 @@ "shape_2": "config", "shape_3": "config", }, - "torch_split": {"input": "data_in", "split_size": "config", "dim": "config",}, + "torch_split": { + "input": "data_in", + "split_size": "config", + "dim": "config", + }, "torch_permute": { "input": "data_in", "dim_0": "config", @@ -262,7 +269,11 @@ "dim_2": "config", "dim_3": "config", }, - "torch_transpose": {"input": "data_in", "dim0": "config", "dim1": "config",}, + "torch_transpose": { + "input": "data_in", + "dim0": "config", + "dim1": "config", + }, # DTensor ops "dtensor_arange": { "device_mesh": "config", @@ -276,7 +287,9 @@ "requires_grad": "config", }, # tensor constructor - "tensor": {"data": "data_in",}, + "tensor": { + "data": "data_in", + }, # https://pytorch.org/docs/stable/generated/torch.nn.functional.dropout.html "dropout": { "input": "data_in", @@ -340,7 +353,10 @@ "softmax": {"input": "data_in"}, "gelu": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html - "crossentropyloss": {"input": "data_in", "target": "data_in",}, + "crossentropyloss": { + "input": "data_in", + "target": "data_in", + }, # chop.nn.modules.lora.LoRALinear "loralinear": {"input": "data_in"}, "grouped_query_attention": {"input": "data_in"}, @@ -389,15 +405,24 @@ "size_3": "config", }, # Tensor.max(dim=None, keepdim=False) - "max": {"dim": "config", "keepdim": "config",}, + "max": { + "dim": "config", + "keepdim": "config", + }, # https://pytorch.org/docs/stable/generated/torch.Tensor.sum.html - "sum": {"dim": "config", "keepdim": "config",}, + "sum": { + "dim": "config", + "keepdim": "config", + }, # https://pytorch.org/docs/stable/generated/torch.Tensor.round.html "round": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.floor.html "floor": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.clamp.html - "clamp": {"min": "config", "max": "config",}, + "clamp": { + "min": "config", + "max": "config", + }, # https://pytorch.org/docs/stable/generated/torch.Tensor.dim.html "dim": {}, # https://pytorch.org/docs/stable/generated/torch.Tensor.permute.html#torch.Tensor.permute @@ -426,7 +451,11 @@ # https://pytorch.org/docs/stable/generated/torch.Tensor.type_as.html "type_as": {"tensor": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.index_select.html - "index_select": {"input": "data_in", "dim": "config", "index": "data_in",}, + "index_select": { + "input": "data_in", + "dim": "config", + "index": "data_in", + }, # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html "detach": {"input": "data_in"}, } @@ -479,7 +508,11 @@ def deepgetattr(obj, attr): def _annotate_arg_metadata( - meta: MaseMetadata, args: list, kwargs: dict, func_data: dict, add_value: bool, + meta: MaseMetadata, + args: list, + kwargs: dict, + func_data: dict, + add_value: bool, ): """ Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"] @@ -588,7 +621,9 @@ def _annotate_arg_metadata( def _annotate_result_metadata( - meta: MaseMetadata, result, add_value: bool, + meta: MaseMetadata, + result, + add_value: bool, ) -> MaseMetadata: """ Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata. diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py index a92958cff..b084821ae 100644 --- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py @@ -64,13 +64,17 @@ "relu": [ { "name": "fixed_relu", - "dependence_files": ["activation_layers/rtl/fixed_relu.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_relu.sv", + ], }, ], "hardshrink": [ { "name": "fixed_hardshrink", - "dependence_files": ["activation_layers/rtl/fixed_hardshrink.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_hardshrink.sv", + ], }, ], "silu": [ @@ -103,7 +107,9 @@ "softshrink": [ { "name": "fixed_softshrink", - "dependence_files": ["activation_layers/rtl/fixed_softshrink.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_softshrink.sv", + ], }, ], "logsigmoid": [ @@ -132,13 +138,17 @@ "selu": [ { "name": "fixed_selu", - "dependence_files": ["activation_layers/rtl/fixed_selu.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_selu.sv", + ], }, ], "tanh": [ { "name": "fixed_tanh", - "dependence_files": ["activation_layers/rtl/fixed_tanh.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_tanh.sv", + ], }, ], "gelu": [ @@ -162,13 +172,17 @@ "softplus": [ { "name": "fixed_softplus", - "dependence_files": ["activation_layers/rtl/fixed_softplus.sv",], + "dependence_files": [ + "activation_layers/rtl/fixed_softplus.sv", + ], }, ], "add": [ { "name": "fixed_adder", - "dependence_files": ["linear_layers/fixed_operators/rtl/fixed_adder.sv",], + "dependence_files": [ + "linear_layers/fixed_operators/rtl/fixed_adder.sv", + ], } ], "mul": [ @@ -185,7 +199,14 @@ "dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"], } ], - "getitem": [{"name": "buffer", "dependence_files": ["memory/rtl/buffer.sv",],}], + "getitem": [ + { + "name": "buffer", + "dependence_files": [ + "memory/rtl/buffer.sv", + ], + } + ], "grouped_query_attention": [ { "name": "fixed_gqa_wrapper", diff --git a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py index 1e6ad2a57..d7f4fd8ed 100644 --- a/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/software_metadata_layers.py @@ -243,8 +243,16 @@ def analyze_software_meta_param_patched_func_default(meta): "getitem": analyze_software_meta_param_implicit_func_default, "getattr": analyze_software_meta_param_implicit_func_default, }, - "placeholder": {"placeholder": analyze_software_meta_param_placeholder,}, - "get_attr": {"get_attr": analyze_software_meta_param_get_attr,}, - "output": {"output": analyze_software_meta_param_output,}, - "patched_func": {"default": analyze_software_meta_param_patched_func_default,}, + "placeholder": { + "placeholder": analyze_software_meta_param_placeholder, + }, + "get_attr": { + "get_attr": analyze_software_meta_param_get_attr, + }, + "output": { + "output": analyze_software_meta_param_output, + }, + "patched_func": { + "default": analyze_software_meta_param_patched_func_default, + }, } diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py index 30216eea3..d9fea1aa5 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ -74,7 +74,10 @@ def get_resharding_cost( and (dest[0] == SpmdShard.R) ): ag_dim = 1 if src[0] == dest[0] else 0 - return mesh.all_gather_cost(num_bytes=num_bytes, mesh_dim=ag_dim,) + return mesh.all_gather_cost( + num_bytes=num_bytes, + mesh_dim=ag_dim, + ) # All-to-all # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) @@ -82,7 +85,10 @@ def get_resharding_cost( # all to all a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value try: - return mesh.all_to_all_cost(num_bytes=num_bytes, mesh_dim=a2a_dim,) + return mesh.all_to_all_cost( + num_bytes=num_bytes, + mesh_dim=a2a_dim, + ) except: assert False diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index b49154c20..b149edf4a 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -23,7 +23,10 @@ def deepgetattr(obj, attr, default=None): def _import_solution( - mg, solution: dict, mesh: MeshModel, extrapolate_sharding: bool = True, + mg, + solution: dict, + mesh: MeshModel, + extrapolate_sharding: bool = True, ): """Import an autosharding solution into the metadata of the MaseGraph. @@ -68,14 +71,18 @@ def _import_solution( # Annotate the metadata for each argument for arg, arg_spec in solution[node.name].get("args", {}).items(): node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = _DTensorSpec( - mesh=mesh, placements=arg_spec, + mesh=mesh, + placements=arg_spec, ) # Annotate the metadata for each result for result, result_spec in solution[node.name].get("results", {}).items(): - node.meta["mase"]["common"]["results"][result][ - "dtensor_spec" - ] = _DTensorSpec(mesh=mesh, placements=result_spec,) + node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = ( + _DTensorSpec( + mesh=mesh, + placements=result_spec, + ) + ) return mg, {} @@ -107,7 +114,10 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): logger.warning( f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution." ) - spec = _DTensorSpec(None, (Replicate(), Replicate()),) + spec = _DTensorSpec( + None, + (Replicate(), Replicate()), + ) else: spec = arg_info["dtensor_spec"] @@ -122,7 +132,10 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): logger.warning( f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution." ) - spec = _DTensorSpec(None, (Replicate(), Replicate()),) + spec = _DTensorSpec( + None, + (Replicate(), Replicate()), + ) else: spec = result_info["dtensor_spec"] out_dict[node_name]["results"][result] = spec.placements @@ -184,7 +197,9 @@ def _get_sharding_map(mg): if module not in tensor_sharding_map: tensor_sharding_map[module] = { "node": node.name, - "sharding": {attr: out_specs,}, + "sharding": { + attr: out_specs, + }, } else: tensor_sharding_map[module]["sharding"][attr] = out_specs diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py index ab9134a36..30cd36f7e 100644 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ b/src/chop/passes/graph/analysis/autosharding/megatron.py @@ -3,7 +3,9 @@ def megatron_autosharding_pass( - mg: MaseGraph, mesh: MeshModel, pass_args: dict, + mg: MaseGraph, + mesh: MeshModel, + pass_args: dict, ): for node in mg.fx_graph.nodes: meta = node.meta["mase"]["common"] diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py index d68f61cb5..40b704ca8 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py @@ -86,7 +86,10 @@ def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": def gen_einsum_strategies( - equation: str, mesh: tuple, *, linearity: bool = False, + equation: str, + mesh: tuple, + *, + linearity: bool = False, ) -> OpStrategy: """ Generate a strategy list for the ops that follow einsum style notation. diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index e58980c60..2635a5587 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -81,7 +81,11 @@ def fully_replicated_strategy(meta, mesh): in_spec = _DTensorSpec( mesh, sharding, - tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype,), + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), ) dtype_key = ( diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index 60a3de959..78328a95d 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -24,7 +24,10 @@ from chop.ir.graph import MaseMetadata -def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def transpose_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: parent_node = meta.node.args[0] self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] @@ -56,7 +59,11 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: return OpStrategy(strategies=transpose_strategies) -def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def _mm_like_strategy( + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()] # generate all possible strategies for mm mm_strategy = gen_einsum_strategies(mm_equation, mesh) @@ -95,7 +102,9 @@ def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple,) -> OpS def _addmm_like_strategy( - mm_equation: str, meta: MaseMetadata, mesh: tuple, + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, ) -> OpStrategy: self_shape, mat1_shape, mat2_shape = [ @@ -161,24 +170,37 @@ def _addmm_like_strategy( return mm_strategy -def mm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def mm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _mm_like_strategy("mk,kn->mn", meta, mesh) -def addmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def addmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _addmm_like_strategy("mk,kn->mn", meta, mesh) -def bmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def bmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _mm_like_strategy("bmk,bkn->bmn", meta, mesh) -def baddmm_strategy(meta: MaseMetadata, mesh: tuple,) -> OpStrategy: +def baddmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh) def scaled_dot_product_flash_attention_strategy( - meta: MaseMetadata, mesh: tuple, + meta: MaseMetadata, + mesh: tuple, ) -> OpStrategy: # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index a7d8a739b..c3c06d35e 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -125,7 +125,9 @@ def common_pointwise_strategy( arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"], ) input_target_placements = map_placements_after_broadcast( - tuple(out_placements), common_shape, input_arg_dims_map, + tuple(out_placements), + common_shape, + input_arg_dims_map, ) input_arg_target_spec = _DTensorSpec( mesh=mesh, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 6cd423580..505a2440b 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -236,7 +236,9 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: def dim_movedim( - ndim: int, input: Union[int, Sequence[int]], destination: Union[int, Sequence[int]], + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], ) -> DimMap: input = normalize_dims(input, ndim) destination = normalize_dims(destination, ndim) @@ -620,7 +622,8 @@ def reshape_strategy(meta, mesh): ) output_strategy.strategies.append( PlacementStrategy( - output_specs=output_spec, input_specs=(input_tgt_spec,), + output_specs=output_spec, + input_specs=(input_tgt_spec,), ) ) diff --git a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py index 04dde1876..b8ca08310 100644 --- a/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py +++ b/src/chop/passes/graph/analysis/flop_estimator/calculator/calc_modules.py @@ -36,7 +36,7 @@ def calculate_modules(module, in_data, out_data): # Kernel size here can be either a single int for square kernel # or a tuple (see # https://pytorch.org/docs/stable/nn.html#torch.nn.MaxPool2d ) - window_size = module.kernel_size ** 2 + window_size = module.kernel_size**2 else: window_size = module.kernel_size[0] * module.kernel_size[1] diff --git a/src/chop/passes/graph/analysis/plot/plot_graph.py b/src/chop/passes/graph/analysis/plot/plot_graph.py index adcc93177..8af151b36 100644 --- a/src/chop/passes/graph/analysis/plot/plot_graph.py +++ b/src/chop/passes/graph/analysis/plot/plot_graph.py @@ -4,7 +4,10 @@ def plot_graph_analysis_pass( - graph, pass_args={"file_name": None,}, + graph, + pass_args={ + "file_name": None, + }, ): graph.draw(pass_args["file_name"]) # nx_graph = nx.DiGraph() diff --git a/src/chop/passes/graph/interface/tensorrt/quantize.py b/src/chop/passes/graph/interface/tensorrt/quantize.py index 0a6773754..0de0cda60 100644 --- a/src/chop/passes/graph/interface/tensorrt/quantize.py +++ b/src/chop/passes/graph/interface/tensorrt/quantize.py @@ -21,7 +21,6 @@ def Quantizer(config): "pytorch_quantization is not installed. Cannot use tensorrt quantize pass." ) - else: import tensorrt as trt from pytorch_quantization import quant_modules, calib diff --git a/src/chop/passes/graph/transforms/dse/run_dse.py b/src/chop/passes/graph/transforms/dse/run_dse.py index 991f9e1bd..28446c4ca 100644 --- a/src/chop/passes/graph/transforms/dse/run_dse.py +++ b/src/chop/passes/graph/transforms/dse/run_dse.py @@ -16,7 +16,7 @@ def get_factors(n): set( functools.reduce( list.__add__, - ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0), + ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), ) ) ) diff --git a/src/chop/passes/graph/transforms/lora.py b/src/chop/passes/graph/transforms/lora.py index 87fb2362f..9c4559661 100644 --- a/src/chop/passes/graph/transforms/lora.py +++ b/src/chop/passes/graph/transforms/lora.py @@ -10,7 +10,8 @@ def insert_lora_adapter_transform_pass( - mg: MaseGraph, pass_args={}, + mg: MaseGraph, + pass_args={}, ): rank = pass_args.get("rank", 0) @@ -41,7 +42,8 @@ def insert_lora_adapter_transform_pass( def fuse_lora_weights_transform_pass( - mg: MaseGraph, pass_args={}, + mg: MaseGraph, + pass_args={}, ): for node in mg.nodes: target = ( diff --git a/src/chop/passes/graph/transforms/onnxrt/quantize.py b/src/chop/passes/graph/transforms/onnxrt/quantize.py index 39d568236..cacbb6272 100644 --- a/src/chop/passes/graph/transforms/onnxrt/quantize.py +++ b/src/chop/passes/graph/transforms/onnxrt/quantize.py @@ -54,7 +54,9 @@ def quantize_dynamic(self, model_path: PosixPath, quantized_model_path: PosixPat ) quantized_model = quantize_dynamic( - model_path, quantized_model_path, weight_type=precision, + model_path, + quantized_model_path, + weight_type=precision, ) self.logger.info("Quantization complete. Model is now dynamically quantized.") diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py b/src/chop/passes/graph/transforms/pruning/pruning_methods.py index aca81dc51..665abc17e 100644 --- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py +++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py @@ -128,11 +128,31 @@ def neurons_random_fan_in( weight_criteria_map = { - "local": {"elementwise": {"random": random, "l1-norm": l1,}}, - "global": {"elementwise": {"random": random, "l1-norm": global_weight_l1,}}, + "local": { + "elementwise": { + "random": random, + "l1-norm": l1, + } + }, + "global": { + "elementwise": { + "random": random, + "l1-norm": global_weight_l1, + } + }, } activation_criteria_map = { - "local": {"elementwise": {"random": random, "l1-norm": l1,}}, - "global": {"elementwise": {"random": random, "l1-norm": global_activation_l1,}}, + "local": { + "elementwise": { + "random": random, + "l1-norm": l1, + } + }, + "global": { + "elementwise": { + "random": random, + "l1-norm": global_activation_l1, + } + }, } diff --git a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py index 890497a95..dea246fc2 100644 --- a/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py +++ b/src/chop/passes/graph/transforms/quantize/quant_parsers/parse_quant_config.py @@ -27,7 +27,10 @@ "weight_entries": ("weight_width", "weight_frac_width"), "data_in_entries": ("data_in_width", "data_in_frac_width"), "bias_entries": ("bias_width", "bias_frac_width"), - "data_out_entries": ("data_out_width", "data_out_frac_width",), + "data_out_entries": ( + "data_out_width", + "data_out_frac_width", + ), "additional_layers_entries": ("floor"), }, "lutnet": { @@ -62,9 +65,18 @@ "weight_width", "weight_frac_width", ), - "bias_entries": ("bias_width", "bias_frac_width",), - "data_in_entries": ("data_in_width", "data_in_frac_width",), - "data_out_entries": ("data_out_width", "data_out_frac_width",), + "bias_entries": ( + "bias_width", + "bias_frac_width", + ), + "data_in_entries": ( + "data_in_width", + "data_in_frac_width", + ), + "data_out_entries": ( + "data_out_width", + "data_out_frac_width", + ), "additional_layers_entries": { "additional_layers_inputs", "additional_layers_outputs", @@ -72,9 +84,21 @@ }, }, "binary": { - "weight_entries": ("weight_width", "weight_stochastic", "weight_bipolar",), - "data_in_entries": ("data_in_width", "data_in_stochastic", "data_in_bipolar",), - "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), + "weight_entries": ( + "weight_width", + "weight_stochastic", + "weight_bipolar", + ), + "data_in_entries": ( + "data_in_width", + "data_in_stochastic", + "data_in_bipolar", + ), + "bias_entries": ( + "bias_width", + "bias_stochastic", + "bias_bipolar", + ), }, "binary_residual": { "weight_entries": ( @@ -90,7 +114,11 @@ "data_in_residual_sign", "data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet ), - "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), + "bias_entries": ( + "bias_width", + "bias_stochastic", + "bias_bipolar", + ), }, "binary_residual": { "weight_entries": ( @@ -106,7 +134,11 @@ "data_in_residual_sign", "data_in_levels", # data_in_levels (int): number of residual levels to use in lutnet ), - "bias_entries": ("bias_width", "bias_stochastic", "bias_bipolar",), + "bias_entries": ( + "bias_width", + "bias_stochastic", + "bias_bipolar", + ), }, "ternary": { "weight_entries": ( @@ -213,7 +245,11 @@ "data_in_exponent_bias_width", "data_in_block_size", ), - "bias_entries": ("bias_width", "bias_exponent_bias_width", "bias_block_size",), + "bias_entries": ( + "bias_width", + "bias_exponent_bias_width", + "bias_block_size", + ), }, "mxint_hardware": { "weight_entries": ( @@ -226,7 +262,11 @@ "data_in_exponent_width", "data_in_parallelism", ), - "bias_entries": ("bias_width", "bias_exponent_width", "bias_parallelism",), + "bias_entries": ( + "bias_width", + "bias_exponent_width", + "bias_parallelism", + ), }, } @@ -349,10 +389,22 @@ def cp_data_out_entries( ("name", "data_in_entries"), ("weight_entries", "bias_entries", "bypass"), ), - "layer_norm": (("name", "data_in_entries"), ("bypass",),), - "group_norm": (("name", "data_in_entries"), ("bypass",),), - "instance_norm2d": (("name", "data_in_entries"), ("bypass",),), - "rms_norm": (("name", "data_in_entries"), ("bypass",),), + "layer_norm": ( + ("name", "data_in_entries"), + ("bypass",), + ), + "group_norm": ( + ("name", "data_in_entries"), + ("bypass",), + ), + "instance_norm2d": ( + ("name", "data_in_entries"), + ("bypass",), + ), + "rms_norm": ( + ("name", "data_in_entries"), + ("bypass",), + ), "grouped_query_attention": ( ("name", "data_in_entries", "weight_entries"), ("bypass", "bias_entries"), diff --git a/src/chop/passes/graph/transforms/training/modify.py b/src/chop/passes/graph/transforms/training/modify.py index bf84faeee..eff5583d0 100644 --- a/src/chop/passes/graph/transforms/training/modify.py +++ b/src/chop/passes/graph/transforms/training/modify.py @@ -65,7 +65,10 @@ def attach_backward_fn(q_fn: torch.autograd.Function, mase_op: str, q_fn_cfg: di def create_new_module( - mase_op: str, original_module: nn.Module, config: dict, node_meta: dict, + mase_op: str, + original_module: nn.Module, + config: dict, + node_meta: dict, ): original_module_cls = type(original_module) diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py index e562a5fb1..a6a000c79 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py @@ -245,8 +245,8 @@ def emit_parameters_in_dat_internal(node, param_name, file_name): else: base_quantizer = integer_quantizer_for_hw - scale = 2 ** frac_width - thresh = 2 ** width + scale = 2**frac_width + thresh = 2**width for i in range(0, out_depth): line_buff = "" for j in range(0, out_size): @@ -301,8 +301,8 @@ def emit_parameters_in_dat_hls(node, param_name, file_name): "precision" ][1] - scale = 2 ** frac_width - thresh = 2 ** width + scale = 2**frac_width + thresh = 2**width for i in range(0, out_depth): line_buff = "" for j in range(0, out_size): diff --git a/src/chop/passes/graph/transforms/verilog/emit_hls.py b/src/chop/passes/graph/transforms/verilog/emit_hls.py index ce348c43d..3efe2c037 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_hls.py +++ b/src/chop/passes/graph/transforms/verilog/emit_hls.py @@ -121,7 +121,12 @@ def _call_hls_flow(node, node_dir): # Call Vitis HLS for synthesis vitis_hls = os.path.abspath( os.path.join( - os.path.dirname(__file__), "..", "..", "..", "scripts", "run-vitis-hls.sh", + os.path.dirname(__file__), + "..", + "..", + "..", + "scripts", + "run-vitis-hls.sh", ) ) assert os.path.isfile( diff --git a/src/chop/passes/module/analysis/report.py b/src/chop/passes/module/analysis/report.py index c384ceafd..514f75ce8 100644 --- a/src/chop/passes/module/analysis/report.py +++ b/src/chop/passes/module/analysis/report.py @@ -26,7 +26,8 @@ def get_submodule_summary(name: str, module: nn.Module, level: int = 0): def report_trainable_parameters_analysis_pass( - module: torch.nn.Module, pass_args: dict = {}, + module: torch.nn.Module, + pass_args: dict = {}, ): submodule_summary, total_params = get_submodule_summary("", module) table = [(name, params) for _, name, params in submodule_summary] diff --git a/src/chop/passes/utils.py b/src/chop/passes/utils.py index 8d7ea71eb..912250077 100644 --- a/src/chop/passes/utils.py +++ b/src/chop/passes/utils.py @@ -23,7 +23,9 @@ def _nightly_torch_installed(): return False -def find_missing_dependencies(pass_name: str,): +def find_missing_dependencies( + pass_name: str, +): dependencies = PassFactory._dependencies_dict.get(pass_name, None) if dependencies is None: @@ -38,7 +40,9 @@ def find_missing_dependencies(pass_name: str,): def register_mase_pass( - name: str, dependencies: list = [], requires_nightly_torch: bool = False, + name: str, + dependencies: list = [], + requires_nightly_torch: bool = False, ): """This decorator registers a mase pass as PassFactory class attributes which can be used globally.""" diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index 1bd382983..c9b784795 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -12,7 +12,11 @@ class AutoPipeline: The output of each pass is stored in a dictionary and can be accessed by the next pass. """ - def __init__(self, pass_groups=None, run_training: bool = False,) -> None: + def __init__( + self, + pass_groups=None, + run_training: bool = False, + ) -> None: """Initializes the AutoPipeline. Args: @@ -22,7 +26,11 @@ def __init__(self, pass_groups=None, run_training: bool = False,) -> None: self.pass_outputs = [{}] * len(pass_groups) def _run_pass_group( - self, mg: MaseGraph, pass_group: list, pass_args: dict, skip_passes: list = [], + self, + mg: MaseGraph, + pass_group: list, + pass_args: dict, + skip_passes: list = [], ): pass_outputs = {} @@ -48,7 +56,10 @@ def _run_pass_group( return mg, pass_outputs def __call__( - self, mg: MaseGraph, pass_args: dict, skip_passes: list = [], + self, + mg: MaseGraph, + pass_args: dict, + skip_passes: list = [], ): for idx, pass_group in enumerate(self.pass_groups): @@ -59,7 +70,10 @@ def __call__( ) mg, pass_outputs = self._run_pass_group( - mg, pass_group, pass_args, skip_passes, + mg, + pass_group, + pass_args, + skip_passes, ) self.pass_outputs[idx] = pass_outputs diff --git a/src/chop/tools/check_dependency.py b/src/chop/tools/check_dependency.py index c71ca088e..dbf13cd12 100644 --- a/src/chop/tools/check_dependency.py +++ b/src/chop/tools/check_dependency.py @@ -23,7 +23,9 @@ def check_deps_tensorRT_pass(silent: bool = True): return all(availabilities) -def find_missing_dependencies(pass_name: str,): +def find_missing_dependencies( + pass_name: str, +): dependencies = PassFactory._dependencies_dict.get(pass_name, None) if dependencies is None: @@ -38,7 +40,8 @@ def find_missing_dependencies(pass_name: str,): def check_dependencies( - pass_name: str, silent: bool = True, + pass_name: str, + silent: bool = True, ): unavailable_deps = find_missing_dependencies(pass_name) diff --git a/src/chop/tools/huggingface.py b/src/chop/tools/huggingface.py index efa0fb1f6..7cf675460 100644 --- a/src/chop/tools/huggingface.py +++ b/src/chop/tools/huggingface.py @@ -39,7 +39,10 @@ def get_hf_dummy_in(model): tokenizer = AutoTokenizer.from_pretrained(checkpoint) dummy_input = tokenizer( - ["AI may take over the world one day", "This is why you should learn ADLS",], + [ + "AI may take over the world one day", + "This is why you should learn ADLS", + ], return_tensors="pt", ) @@ -50,7 +53,9 @@ def get_hf_dummy_in(model): def get_tokenized_dataset( - dataset: str, checkpoint: str, return_tokenizer: bool = False, + dataset: str, + checkpoint: str, + return_tokenizer: bool = False, ): """ Tokenizes a dataset using the AutoTokenizer from Huggingface. @@ -76,7 +81,10 @@ def get_tokenized_dataset( tokenizer = AutoTokenizer.from_pretrained(checkpoint) def tokenize_function(example): - return tokenizer(example["text"], truncation=True,) + return tokenizer( + example["text"], + truncation=True, + ) # Tokenize tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) diff --git a/src/chop/tools/plt_wrapper/nlp/classification.py b/src/chop/tools/plt_wrapper/nlp/classification.py index a69b2ac0f..f96227cba 100644 --- a/src/chop/tools/plt_wrapper/nlp/classification.py +++ b/src/chop/tools/plt_wrapper/nlp/classification.py @@ -41,7 +41,9 @@ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None): ) else: outputs = self.model( - input_ids, attention_mask=attention_mask, labels=labels, + input_ids, + attention_mask=attention_mask, + labels=labels, ) return outputs diff --git a/src/chop/tools/plt_wrapper/nlp/lm.py b/src/chop/tools/plt_wrapper/nlp/lm.py index a2b669e87..60739da47 100644 --- a/src/chop/tools/plt_wrapper/nlp/lm.py +++ b/src/chop/tools/plt_wrapper/nlp/lm.py @@ -46,7 +46,9 @@ def training_step(self, batch, batch_idx): self.log("train_loss_step", loss, prog_bar=True) self.log( - "train_perplexity_step", perplexity, prog_bar=True, + "train_perplexity_step", + perplexity, + prog_bar=True, ) return loss diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py index d070eff92..2ae59855e 100644 --- a/src/chop/tools/utils.py +++ b/src/chop/tools/utils.py @@ -92,7 +92,7 @@ def get_factors(n): set( functools.reduce( list.__add__, - ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0), + ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), ) ) ) @@ -184,7 +184,9 @@ def init_Conv2dLUT_weight( # Initialize the weight based on the trained binaried network # weight shape of the lagrange trainer [tables_count, self.kk] input_mask = new_module.input_mask.reshape( - -1, in_channels * kernel_size[0] * kernel_size[1] * k, 3, + -1, + in_channels * kernel_size[0] * kernel_size[1] * k, + 3, ) # [oc, k * kh * kw * ic ,3[ic,kh,kw]] expanded_original_weight = original_weight[ np.arange(out_channels)[:, np.newaxis], diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py index 711d2760f..10ca26820 100644 --- a/src/mase_cocotb/interfaces/streaming.py +++ b/src/mase_cocotb/interfaces/streaming.py @@ -17,7 +17,14 @@ def _sign_extend(value: int, bits: int): class StreamDriver(Driver): - def __init__(self, clk, data, valid, ready, record_num_beats=False,) -> None: + def __init__( + self, + clk, + data, + valid, + ready, + record_num_beats=False, + ) -> None: super().__init__() self.clk = clk self.data = data @@ -67,7 +74,14 @@ async def _driver_send(self, transaction) -> None: class StreamMonitor(Monitor): def __init__( - self, clk, data, valid, ready, check=True, name=None, unsigned=False, + self, + clk, + data, + valid, + ready, + check=True, + name=None, + unsigned=False, ): super().__init__(clk, check=check, name=name) self.clk = clk @@ -151,9 +165,9 @@ def __init__(self, clk, data, valid, ready, data_width, frac_width, check=True): def _check(self, got, exp): if self.check: - float_got = [x * 2 ** -self.frac_width for x in got] - float_exp = [x * 2 ** -self.frac_width for x in exp] - if not np.isclose(float_got, float_exp, atol=2 ** -self.frac_width).all(): + float_got = [x * 2**-self.frac_width for x in got] + float_exp = [x * 2**-self.frac_width for x in exp] + if not np.isclose(float_got, float_exp, atol=2**-self.frac_width).all(): # raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp)) raise TestFailure( f"\nGot int \n{got}, \nExpected int \n{exp} \nGot float \n{float_got}, \nExpected float \n{float_exp}" diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py index a74bff0d4..21490555a 100644 --- a/src/mase_cocotb/runner.py +++ b/src/mase_cocotb/runner.py @@ -110,7 +110,10 @@ def _single_test( verilog_sources=sources, includes=includes, hdl_toplevel=module, - build_args=[*tool_args, *extra_build_args,], + build_args=[ + *tool_args, + *extra_build_args, + ], # Do not use params in hierarchical verilation parameters=module_params if not hierarchical else {}, build_dir=test_work_dir, @@ -309,10 +312,14 @@ def simulate_pass( verilog_sources=[rtl_dir / "top.sv"], includes=[rtl_dir], hdl_toplevel="top", - build_args=[*_verilator_args(False, trace) * extra_build_args,], + build_args=[ + *_verilator_args(False, trace) * extra_build_args, + ], parameters=module_params, build_dir=sim_dir, ) runner.test( - hdl_toplevel="top", test_module="test", results_xml="results.xml", + hdl_toplevel="top", + test_module="test", + results_xml="results.xml", ) diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py index d746f74b5..0ad58ba1b 100644 --- a/src/mase_cocotb/testbench.py +++ b/src/mase_cocotb/testbench.py @@ -8,7 +8,12 @@ class Testbench: __test__ = False # so pytest doesn't confuse this with a test def __init__( - self, dut, clk=None, rst=None, fail_on_checks=True, clk_period_ns=20, + self, + dut, + clk=None, + rst=None, + fail_on_checks=True, + clk_period_ns=20, ) -> None: self.dut = dut self.clk = clk diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py index 25b6d83d8..7271389f2 100644 --- a/src/mase_cocotb/utils.py +++ b/src/mase_cocotb/utils.py @@ -78,8 +78,8 @@ def int_floor_quantizer(x: Tensor, width: int, frac_width: int, signed=True): int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2 ** width - 1 - scale = 2 ** frac_width + int_max = 2**width - 1 + scale = 2**frac_width return torch.clamp(torch.floor(x.mul(scale)), int_min, int_max).div(scale) diff --git a/src/mase_cocotb/z_qlayers/tensor_cast.py b/src/mase_cocotb/z_qlayers/tensor_cast.py index cf212cf79..ce651bc5c 100644 --- a/src/mase_cocotb/z_qlayers/tensor_cast.py +++ b/src/mase_cocotb/z_qlayers/tensor_cast.py @@ -46,9 +46,9 @@ def _integer_quantize( int_max = 2 ** (width - 1) - 1 else: int_min = 0 - int_max = 2 ** width - 1 + int_max = 2**width - 1 # thresh = 2 ** (width - 1) - scale = 2 ** frac_width + scale = 2**frac_width if isinstance(x, (Tensor, ndarray)): return my_clamp(my_round(x.mul(scale)), int_min, int_max).div(scale) @@ -59,7 +59,7 @@ def _integer_quantize( def quantize_to_int(x: Tensor, width: int, frac_width: int): - x = (_integer_quantize(x, width, frac_width) * (2 ** frac_width)).int() + x = (_integer_quantize(x, width, frac_width) * (2**frac_width)).int() return x diff --git a/src/mase_components/activation_layers/test/fixed_elu_tb.py b/src/mase_components/activation_layers/test/fixed_elu_tb.py index 1c7e850e4..87ff69fd6 100644 --- a/src/mase_components/activation_layers/test/fixed_elu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_elu_tb.py @@ -79,10 +79,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"): count += 1 iarr.append(i) val = quanter(f(torch.tensor(i))) # entry in the lookup table - lut[ - doubletofx(data_width=data_width, f_width=f_width, num=i, type=type) - ] = doubletofx( - data_width=data_width, f_width=f_width, num=val.item(), type=type + lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = ( + doubletofx( + data_width=data_width, f_width=f_width, num=val.item(), type=type + ) ) i += 2 ** -(f_width) return lut @@ -444,8 +444,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -456,7 +456,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_gelu_tb.py b/src/mase_components/activation_layers/test/fixed_gelu_tb.py index 2a8e79787..1a7d760b6 100644 --- a/src/mase_components/activation_layers/test/fixed_gelu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_gelu_tb.py @@ -29,13 +29,13 @@ async def cocotb_test_fixed_gelu(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -46,9 +46,9 @@ async def cocotb_test_fixed_gelu(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2 ** DATA_IN_0_PRECISION_1)] + a = [b / (2**DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -56,7 +56,7 @@ async def cocotb_test_fixed_gelu(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py index 57f3940fb..efb6f0045 100644 --- a/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py +++ b/src/mase_components/activation_layers/test/fixed_hardshrink_tb.py @@ -62,11 +62,11 @@ def exp(self, inputs): # cond = torch.logical_not(torch.logical_and(inputs <= self.thresh*2**self.fracw, inputs >= -1 * self.thresh *2**self.fracw)) # out = torch.where(cond, inputs, torch.tensor(0)) # unsignedout = torch.where(out < 0, torch.tensor(out % (2**self.width)), out) - m = torch.nn.Hardshrink(self.thresh * 2 ** self.fracw)(inputs.to(torch.float)) + m = torch.nn.Hardshrink(self.thresh * 2**self.fracw)(inputs.to(torch.float)) mout = m.clamp( min=-1 * 2 ** (self.outputwidth - 1), max=2 ** (self.outputwidth - 1) - 1 ) - m2 = torch.where(mout < 0, torch.tensor(mout % (2 ** self.outputwidth)), mout) + m2 = torch.where(mout < 0, torch.tensor(mout % (2**self.outputwidth)), mout) return m2.to(torch.int32).tolist() def generate_inputs(self, w, fracw): @@ -75,7 +75,7 @@ def generate_inputs(self, w, fracw): ) realinp = torch.randn(self.samples) inputs = self.dquantizer(realinp) - intinp = (inputs * 2 ** self.fracw).to(torch.int64) + intinp = (inputs * 2**self.fracw).to(torch.int64) intinp.clamp( min=-(2 ** (self.width - self.fracw - 1)), max=2 ** (self.width - self.fracw - 1) - 1, diff --git a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py index 53f79755e..ceb194e35 100644 --- a/src/mase_components/activation_layers/test/fixed_hardswish_tb.py +++ b/src/mase_components/activation_layers/test/fixed_hardswish_tb.py @@ -56,14 +56,12 @@ def __init__(self, dut) -> None: def exp(self, inputs): # Run the model with the provided inputs and return the outputs - tmp0 = 3 * 2 ** self.fracw + tmp0 = 3 * 2**self.fracw tmp1 = inputs + tmp0 - tmp2 = tmp1 * (2 ** -3) + tmp1 * (2 ** -4) + tmp2 = tmp1 * (2**-3) + tmp1 * (2**-4) # qtmps = self.dquantizer(tmp2) tmp3 = tmp2 * inputs - unsignedout = torch.where( - tmp3 < 0, torch.tensor(tmp3 % (2 ** self.width)), tmp3 - ) + unsignedout = torch.where(tmp3 < 0, torch.tensor(tmp3 % (2**self.width)), tmp3) # return unsignedout.tolist() return unsignedout @@ -73,7 +71,7 @@ def generate_inputs(self, w, fracw): ) realinp = torch.randn(self.samples) inputs = self.dquantizer(realinp) - intinp = (inputs * 2 ** self.fracw).to(torch.int64) + intinp = (inputs * 2**self.fracw).to(torch.int64) intinp.clamp( min=-(2 ** (self.width - self.fracw - 1)), max=2 ** (self.width - self.fracw - 1) - 1, diff --git a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py index a44efe0b8..d0cd8796e 100644 --- a/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_leaky_relu_tb.py @@ -28,9 +28,9 @@ def get_in_and_out(x, fn, width, frac_width): ins = integer_quantizer(x, width=width, frac_width=frac_width) y = fn(x) outs = integer_quantizer(y, width=width, frac_width=frac_width) - outs = outs * 2 ** frac_width + outs = outs * 2**frac_width outs = outs.int() - ins = ins * 2 ** frac_width + ins = ins * 2**frac_width ins = ins.int() return (ins, outs) @@ -78,7 +78,7 @@ async def cocotb_test(dut): logger.info(f"Reset finished") tb.data_out_0_monitor.ready.value = 1 - inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2 ** -4) + inputs, exp_outs = tb.generate_inputs_outputs(8, 4, 2**-4) tb.data_in_0_driver.append(inputs.tolist()) diff --git a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py index 9b9a0a766..570ddc328 100644 --- a/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py +++ b/src/mase_components/activation_layers/test/fixed_logsigmoid_tb.py @@ -142,8 +142,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -154,7 +154,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_relu_tb.py b/src/mase_components/activation_layers/test/fixed_relu_tb.py index dbe8ac146..118296c48 100644 --- a/src/mase_components/activation_layers/test/fixed_relu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_relu_tb.py @@ -47,7 +47,7 @@ def backward(ctx, grad_output): def quantize(x, bits, bias): # bits = 32 """Do linear quantization to input according to a scale and number of bits""" thresh = 2 ** (bits - 1) - scale = 2 ** bias + scale = 2**bias return my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1).div(scale) @@ -83,7 +83,7 @@ def get_dut_parameters(self): def get_dut_input(self, i): inputs = self.inputs[i] - shifted_integers = (inputs * (2 ** self.bias)).int() + shifted_integers = (inputs * (2**self.bias)).int() return shifted_integers.numpy().tolist() def get_dut_output(self, i): @@ -92,7 +92,7 @@ def get_dut_output(self, i): return shifted_integers def convert_to_fixed(self, x): - return (x * (2 ** self.bias)).int().numpy().tolist() + return (x * (2**self.bias)).int().numpy().tolist() @cocotb.test() diff --git a/src/mase_components/activation_layers/test/fixed_selu_tb.py b/src/mase_components/activation_layers/test/fixed_selu_tb.py index 459334c6c..647c4fff1 100644 --- a/src/mase_components/activation_layers/test/fixed_selu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_selu_tb.py @@ -25,13 +25,13 @@ async def cocotb_test_fixed_selu(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -42,9 +42,9 @@ async def cocotb_test_fixed_selu(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2 ** DATA_IN_0_PRECISION_1)] + a = [b / (2**DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -52,7 +52,7 @@ async def cocotb_test_fixed_selu(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py index 427220089..e2a794d22 100644 --- a/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py +++ b/src/mase_components/activation_layers/test/fixed_sigmoid_tb.py @@ -128,8 +128,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -140,7 +140,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_silu_tb.py b/src/mase_components/activation_layers/test/fixed_silu_tb.py index 19c4ac3a2..944a83e5e 100644 --- a/src/mase_components/activation_layers/test/fixed_silu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_silu_tb.py @@ -143,8 +143,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -155,7 +155,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py index 41e67e205..fa2b2aa75 100644 --- a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py @@ -64,7 +64,10 @@ def __init__(self, dut) -> None: self.model = partial( fixed_softermax, dim=0, - q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,}, + q_config={ + "width": self.IN_WIDTH, + "frac_width": self.IN_FRAC_WIDTH, + }, ) # Set verbosity of driver and monitor loggers to debug @@ -72,7 +75,9 @@ def __init__(self, dut) -> None: # self.out_data_monitor.log.setLevel(logging.DEBUG) def generate_inputs(self, batches): - return torch.randn((batches, self.TOTAL_DIM),) + return torch.randn( + (batches, self.TOTAL_DIM), + ) async def run_test(self, batches, us): await self.reset() @@ -88,7 +93,10 @@ async def run_test(self, batches, us): self.log.debug(f"Processing inputs: {batch}") driver_input = fixed_preprocess_tensor( tensor=batch, - q_config={"width": self.IN_WIDTH, "frac_width": self.IN_FRAC_WIDTH,}, + q_config={ + "width": self.IN_WIDTH, + "frac_width": self.IN_FRAC_WIDTH, + }, parallelism=[self.PARALLELISM], ) self.in_data_driver.load_driver(driver_input) @@ -97,7 +105,10 @@ async def run_test(self, batches, us): self.log.debug(f"Processing outputs: {exp_out}") outs = fixed_preprocess_tensor( tensor=exp_out, - q_config={"width": self.OUT_WIDTH, "frac_width": self.OUT_FRAC_WIDTH,}, + q_config={ + "width": self.OUT_WIDTH, + "frac_width": self.OUT_FRAC_WIDTH, + }, parallelism=[self.PARALLELISM], ) self.out_data_monitor.load_monitor(outs) diff --git a/src/mase_components/activation_layers/test/fixed_softermax_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_tb.py index 84c6aa4e5..9b7d0aee1 100644 --- a/src/mase_components/activation_layers/test/fixed_softermax_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softermax_tb.py @@ -133,7 +133,9 @@ def test_fixed_softermax_smoke(): """ mase_runner( trace=True, - module_param_list=[get_fixed_softermax_config(),], + module_param_list=[ + get_fixed_softermax_config(), + ], # skip_build=True, ) diff --git a/src/mase_components/activation_layers/test/fixed_softmax_tb.py b/src/mase_components/activation_layers/test/fixed_softmax_tb.py index afc8876ac..62f37012f 100644 --- a/src/mase_components/activation_layers/test/fixed_softmax_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softmax_tb.py @@ -128,8 +128,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -140,7 +140,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softplus_tb.py b/src/mase_components/activation_layers/test/fixed_softplus_tb.py index f39467ee7..121fa5c60 100644 --- a/src/mase_components/activation_layers/test/fixed_softplus_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softplus_tb.py @@ -25,13 +25,13 @@ async def cocotb_test_fixed_softplus(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -42,9 +42,9 @@ async def cocotb_test_fixed_softplus(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2 ** DATA_IN_0_PRECISION_1)] + a = [b / (2**DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -52,7 +52,7 @@ async def cocotb_test_fixed_softplus(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py index 9c738602e..b79ccdaf5 100644 --- a/src/mase_components/activation_layers/test/fixed_softshrink_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softshrink_tb.py @@ -142,8 +142,8 @@ def exp(self): ) # match output logger.info(f"EXP - FLOAT OUTPUT: \n{m}") m = self.out_dquantizer(m) - m2 = (m * 2 ** self.outputfracw).to(torch.int64) - m2 = m2.clone().detach() % (2 ** self.outputwidth) + m2 = (m * 2**self.outputfracw).to(torch.int64) + m2 = m2.clone().detach() % (2**self.outputwidth) return m2 @@ -154,7 +154,7 @@ def generate_inputs(self): ) logger.info(f"FLOAT INPUT: \n{inputs}") inputs = self.in_dquantizer(inputs) - intinp = (inputs * 2 ** self.frac_width).to(torch.int64) + intinp = (inputs * 2**self.frac_width).to(torch.int64) return intinp, inputs def doubletofx(self, num, data_width, f_width, type="bin"): diff --git a/src/mase_components/activation_layers/test/fixed_softsign_tb.py b/src/mase_components/activation_layers/test/fixed_softsign_tb.py index c8776ac32..5fec6b341 100644 --- a/src/mase_components/activation_layers/test/fixed_softsign_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softsign_tb.py @@ -24,13 +24,13 @@ async def cocotb_test_fixed_softsign(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -41,9 +41,9 @@ async def cocotb_test_fixed_softsign(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2 ** DATA_IN_0_PRECISION_1)] + a = [b / (2**DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -51,7 +51,7 @@ async def cocotb_test_fixed_softsign(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/fixed_tanh_tb.py b/src/mase_components/activation_layers/test/fixed_tanh_tb.py index 937fb930c..4e8c6d268 100644 --- a/src/mase_components/activation_layers/test/fixed_tanh_tb.py +++ b/src/mase_components/activation_layers/test/fixed_tanh_tb.py @@ -24,13 +24,13 @@ async def cocotb_test_fixed_tanh(dut): resolution = (max_value - min_value) / (num_values - 1) # Convert the resolution into fixed-point format - resolution_fixed_point = int(resolution * (2 ** DATA_IN_0_PRECISION_1)) + resolution_fixed_point = int(resolution * (2**DATA_IN_0_PRECISION_1)) # Generate the equidistant values values = np.linspace(min_value, max_value, num_values) # Convert values to fixed-point format - values_fixed_point = np.round(values * (2 ** DATA_IN_0_PRECISION_1)).astype(int) + values_fixed_point = np.round(values * (2**DATA_IN_0_PRECISION_1)).astype(int) tensor_tanh = torch.Tensor(values) @@ -41,9 +41,9 @@ async def cocotb_test_fixed_tanh(dut): for i in range(87): # a = tanh_values[i] - b = max_value * (2 ** DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) + b = max_value * (2**DATA_IN_0_PRECISION_1) - resolution_fixed_point * (i - 1) - a = [b / (2 ** DATA_IN_0_PRECISION_1)] + a = [b / (2**DATA_IN_0_PRECISION_1)] tensor_tanh = torch.Tensor(a) c = model(tensor_tanh) @@ -51,7 +51,7 @@ async def cocotb_test_fixed_tanh(dut): # Scale Tanh output to fixed-point range and convert to integers scaled_value = np.round( - tanh_value_numpy * (2 ** DATA_OUT_0_PRECISION_1) + tanh_value_numpy * (2**DATA_OUT_0_PRECISION_1) ).astype(int) dut.data_in_0[0].value = b diff --git a/src/mase_components/activation_layers/test/softermax.py b/src/mase_components/activation_layers/test/softermax.py index 7c30c2986..a1b3271f5 100644 --- a/src/mase_components/activation_layers/test/softermax.py +++ b/src/mase_components/activation_layers/test/softermax.py @@ -53,7 +53,7 @@ def _softmax_model(l: list[int], parallelism: int, pow2=False): for diff, vals in zip(local_max_diff, local_values_buffer): if pow2: - adj = [x * (2 ** -diff) for x in vals] + adj = [x * (2**-diff) for x in vals] else: adj = [x * exp(-diff) for x in vals] norm += sum(adj) diff --git a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py index 6ebbfc131..bc74e42ee 100644 --- a/src/mase_components/activation_layers/test/softermax_global_norm_tb.py +++ b/src/mase_components/activation_layers/test/softermax_global_norm_tb.py @@ -46,7 +46,7 @@ def __init__(self, dut) -> None: # Specify Error Threshold self.percentage_error = 0.05 # 5% - self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH)) + self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -63,15 +63,15 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=10): # TODO: Take a look at all zero case again local_vals = torch.randint( - 1, 2 ** self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM) + 1, 2**self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM) ) local_max = torch.randint( - 0, 2 ** self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1) + 0, 2**self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1) ) logger.debug("local_vals: %s" % (local_vals)) logger.debug( - "local_vals (float): %s" % (local_vals / (2 ** self.IN_VALUE_FRAC_WIDTH)) + "local_vals (float): %s" % (local_vals / (2**self.IN_VALUE_FRAC_WIDTH)) ) logger.debug("local_max: %s" % (local_max)) logger.debug( @@ -87,7 +87,7 @@ def model(self, inputs): for batch in batched_in: local_vals, local_max = list(zip(*batch)) local_vals = torch.tensor(list(local_vals), dtype=torch.float) / ( - 2 ** self.IN_VALUE_FRAC_WIDTH + 2**self.IN_VALUE_FRAC_WIDTH ) local_max = torch.tensor(list(local_max), dtype=torch.float) local_max = sign_extend_t( @@ -97,7 +97,7 @@ def model(self, inputs): global_max = local_max.max() adj_amt = global_max - local_max.reshape(self.DEPTH, 1) adj_values = integer_floor_quantizer( - x=local_vals / (2 ** adj_amt), + x=local_vals / (2**adj_amt), width=self.IN_VALUE_WIDTH, frac_width=self.IN_VALUE_FRAC_WIDTH, is_signed=False, @@ -226,7 +226,10 @@ def in_value_cfgs(cfgs: list): for cfg in cfgs: for in_width in [4, 7, 10]: new_cfgs.append( - {**cfg, "IN_VALUE_WIDTH": in_width,} + { + **cfg, + "IN_VALUE_WIDTH": in_width, + } ) return new_cfgs @@ -235,7 +238,10 @@ def in_max_cfgs(cfgs: list): for cfg in cfgs: for in_max in [2, 3, 4]: new_cfgs.append( - {**cfg, "IN_MAX_WIDTH": in_max,} + { + **cfg, + "IN_MAX_WIDTH": in_max, + } ) return new_cfgs @@ -250,5 +256,7 @@ def in_max_cfgs(cfgs: list): # cfgs = [{'TOTAL_DIM': 32, 'PARALLELISM': 4, 'IN_VALUE_WIDTH': 16, 'IN_MAX_WIDTH': 2}] mase_runner( - module_param_list=cfgs, trace=True, jobs=12, + module_param_list=cfgs, + trace=True, + jobs=12, ) diff --git a/src/mase_components/activation_layers/test/softermax_local_window_tb.py b/src/mase_components/activation_layers/test/softermax_local_window_tb.py index c52b7c647..10b11a569 100644 --- a/src/mase_components/activation_layers/test/softermax_local_window_tb.py +++ b/src/mase_components/activation_layers/test/softermax_local_window_tb.py @@ -53,7 +53,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=10): return [ - [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)] + [randint(0, 2**self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)] for _ in range(batches) ] @@ -80,7 +80,7 @@ def model(self, inputs): sign_ext = sign_extend_t( torch.tensor(inputs, dtype=torch.float), bits=self.IN_WIDTH ) - float_inputs = sign_ext / (2 ** self.IN_FRAC_WIDTH) + float_inputs = sign_ext / (2**self.IN_FRAC_WIDTH) # float_inputs = torch.tensor([[-31.5, -32]]) rounded_inputs_float, rounded_inputs_uint = _fixed_signed_cast_model( float_inputs, self.MAX_WIDTH, 0, False, "floor" @@ -89,9 +89,9 @@ def model(self, inputs): local_max_uint = signed_to_unsigned(local_max.int(), self.MAX_WIDTH) difference = float_inputs - local_max - pow2 = 2 ** difference + pow2 = 2**difference res = torch.clamp( - (pow2 * 2 ** self.OUT_FRAC_WIDTH).int(), 0, 2 ** self.OUT_WIDTH - 1 + (pow2 * 2**self.OUT_FRAC_WIDTH).int(), 0, 2**self.OUT_WIDTH - 1 ) logger.debug("float_inputs: %s" % float_inputs) diff --git a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py index 0a4b52af5..058fa10b0 100644 --- a/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py +++ b/src/mase_components/activation_layers/test/softermax_lpw_pow2_tb.py @@ -37,7 +37,7 @@ def __init__(self, dut) -> None: self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) # 0.1% bit error - self.error_threshold_bits = ceil((2 ** self.IN_WIDTH) * 0.001) + self.error_threshold_bits = ceil((2**self.IN_WIDTH) * 0.001) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -53,10 +53,10 @@ def __init__(self, dut) -> None: def sweep_inputs(self): negative_nums = torch.arange( - start=2 ** (self.IN_WIDTH - 1), end=2 ** self.IN_WIDTH, dtype=torch.int32 + start=2 ** (self.IN_WIDTH - 1), end=2**self.IN_WIDTH, dtype=torch.int32 ) zero_to_one = torch.arange( - start=0, end=2 ** self.IN_FRAC_WIDTH, dtype=torch.int32 # one + start=0, end=2**self.IN_FRAC_WIDTH, dtype=torch.int32 # one ) return torch.cat((negative_nums, zero_to_one)).tolist() @@ -68,14 +68,14 @@ def generate_inputs(self, batches=1): # Negative Numbers torch.randint( low=2 ** (self.IN_WIDTH - 1), - high=2 ** self.IN_WIDTH, + high=2**self.IN_WIDTH, size=(negative_nums,), dtype=torch.int32, ), # Numbers between 0 and 1 torch.randint( low=0, - high=2 ** self.IN_FRAC_WIDTH, + high=2**self.IN_FRAC_WIDTH, size=(zero_to_one_nums,), dtype=torch.int32, ), @@ -85,10 +85,10 @@ def generate_inputs(self, batches=1): def model(self, inputs): in_t = torch.tensor(inputs) - num = sign_extend_t(in_t, self.IN_WIDTH) / (2 ** self.IN_FRAC_WIDTH) - res = 2 ** num - res = (res * 2 ** self.OUT_FRAC_WIDTH).int() - res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1) + num = sign_extend_t(in_t, self.IN_WIDTH) / (2**self.IN_FRAC_WIDTH) + res = 2**num + res = (res * 2**self.OUT_FRAC_WIDTH).int() + res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) return res.tolist() async def run_test(self, batches, us): @@ -128,7 +128,7 @@ async def sweep(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(exp_out) - ns = ((2 ** tb.IN_WIDTH) * 1000) // 5 + ns = ((2**tb.IN_WIDTH) * 1000) // 5 logger.info("Waiting %d ns..." % ns) await Timer(ns, "ns") assert tb.output_monitor.exp_queue.empty() @@ -137,14 +137,14 @@ async def sweep(dut): recv_log = tb.output_monitor.recv_log assert len(exp_out) == len(recv_log) - x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2 ** tb.IN_FRAC_WIDTH) - ref = 2 ** x - ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up - ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1) + x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2**tb.IN_FRAC_WIDTH) + ref = 2**x + ref *= 2**tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) - software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH) - software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out] - hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log] + software_ref = ref / (2**tb.OUT_FRAC_WIDTH) + software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out] + hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log] data = pd.DataFrame( { @@ -174,7 +174,10 @@ async def sweep(dut): ), color=alt.Color("Type"), ) - .properties(width=600, height=220,) + .properties( + width=600, + height=220, + ) ) error_data = data[["x", "hardware error"]] @@ -185,7 +188,10 @@ async def sweep(dut): x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"), y=alt.Y("hardware error").title(f"Error"), ) - .properties(width=600, height=100,) + .properties( + width=600, + height=100, + ) ) (curve_fig & error_fig).save( @@ -296,7 +302,12 @@ def test_high_width(): def test_smoke(): mase_runner( module_param_list=[ - {"IN_WIDTH": 8, "IN_FRAC_WIDTH": 4, "OUT_WIDTH": 8, "OUT_FRAC_WIDTH": 4,} + { + "IN_WIDTH": 8, + "IN_FRAC_WIDTH": 4, + "OUT_WIDTH": 8, + "OUT_FRAC_WIDTH": 4, + } ] ) diff --git a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py index 842881dc9..f2c70b67c 100644 --- a/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py +++ b/src/mase_components/activation_layers/test/softermax_lpw_reciprocal_tb.py @@ -38,7 +38,7 @@ def __init__(self, dut) -> None: # Specify Error Threshold self.percentage_error = 0.05 # 5% - self.error_threshold_bits = ceil(self.percentage_error * (2 ** self.OUT_WIDTH)) + self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) self.output_monitor = ErrorThresholdStreamMonitor( dut.clk, @@ -53,17 +53,17 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, batches=100): - return [randint(0, 2 ** self.IN_WIDTH - 1) for _ in range(batches)] + return [randint(0, 2**self.IN_WIDTH - 1) for _ in range(batches)] def sweep_input(self): - return list(range(2 ** self.IN_WIDTH)) + return list(range(2**self.IN_WIDTH)) def model(self, inputs): - in_t = torch.tensor(inputs) / (2 ** self.IN_FRAC_WIDTH) + in_t = torch.tensor(inputs) / (2**self.IN_FRAC_WIDTH) recip = 1.0 / in_t - res = torch.floor(recip * 2 ** self.OUT_FRAC_WIDTH) + res = torch.floor(recip * 2**self.OUT_FRAC_WIDTH) res = torch.nan_to_num(res) - res = torch.clamp(res, 0, 2 ** self.OUT_WIDTH - 1) + res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) res = res.int() return res.tolist() @@ -107,14 +107,14 @@ async def sweep(dut): recv_log = tb.output_monitor.recv_log assert len(exp_out) == len(recv_log) - x = torch.tensor(inputs) / (2 ** tb.IN_FRAC_WIDTH) + x = torch.tensor(inputs) / (2**tb.IN_FRAC_WIDTH) ref = 1.0 / x - ref *= 2 ** tb.OUT_FRAC_WIDTH # scale up - ref = torch.clamp(ref, 0, 2 ** tb.OUT_WIDTH - 1) + ref *= 2**tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) - software_ref = ref / (2 ** tb.OUT_FRAC_WIDTH) - software_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in exp_out] - hardware_res = [x / (2 ** tb.OUT_FRAC_WIDTH) for x in recv_log] + software_ref = ref / (2**tb.OUT_FRAC_WIDTH) + software_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in exp_out] + hardware_res = [x / (2**tb.OUT_FRAC_WIDTH) for x in recv_log] data = pd.DataFrame( { @@ -150,7 +150,10 @@ async def sweep(dut): ), color=alt.Color("Type"), ) - .properties(width=600, height=300,) + .properties( + width=600, + height=300, + ) ) error_data = data[["x", "hardware error"]] @@ -161,7 +164,10 @@ async def sweep(dut): x=alt.X("x").title(f"x (Q{tb.IN_WIDTH}.{tb.IN_FRAC_WIDTH} Fixed-point)"), y=alt.Y("hardware error").title(f"Error"), ) - .properties(width=600, height=100,) + .properties( + width=600, + height=100, + ) ) (curve_fig & error_fig).save( @@ -216,7 +222,7 @@ def width_cfgs(): for width in range(2, 16 + 1): frac_width = width // 2 if frac_width < 3: - entries = 2 ** frac_width + entries = 2**frac_width else: entries = 8 cfgs.append( @@ -254,7 +260,9 @@ def test_width_cfgs(): "OUT_FRAC_WIDTH": 4, } ] - mase_runner(module_param_list=cfgs,) + mase_runner( + module_param_list=cfgs, + ) def test_smoke(): diff --git a/src/mase_components/cast/test/fixed_rounding_tb.py b/src/mase_components/cast/test/fixed_rounding_tb.py index 300283856..a7f2f935c 100644 --- a/src/mase_components/cast/test/fixed_rounding_tb.py +++ b/src/mase_components/cast/test/fixed_rounding_tb.py @@ -44,7 +44,7 @@ def single_run(self): def sw_cast(self, inputs): outputs = ( integer_floor_quantizer(inputs, self.out_width, self.out_frac_width) - * 2 ** self.out_frac_width + * 2**self.out_frac_width ) # breakpoint() return outputs diff --git a/src/mase_components/cast/test/fixed_signed_cast_tb.py b/src/mase_components/cast/test/fixed_signed_cast_tb.py index 9f5a6956d..2ce3bbe5f 100644 --- a/src/mase_components/cast/test/fixed_signed_cast_tb.py +++ b/src/mase_components/cast/test/fixed_signed_cast_tb.py @@ -18,7 +18,7 @@ def _fixed_signed_cast_model( float_input, out_width, out_frac_width, symmetric, rounding_mode ): - scaled_float = float_input * (2 ** out_frac_width) + scaled_float = float_input * (2**out_frac_width) if rounding_mode == "floor": out_int = my_floor(scaled_float) elif rounding_mode == "round_nearest_half_even": @@ -30,7 +30,7 @@ def _fixed_signed_cast_model( -(2 ** (out_width - 1)) + 1 if symmetric else -(2 ** (out_width - 1)), (2 ** (out_width - 1)) - 1, ) - out_float = out_int / (2 ** out_frac_width) + out_float = out_int / (2**out_frac_width) # out_uint is a non-differentiable path out_uint = signed_to_unsigned(out_int.int(), out_width) return out_float, out_uint @@ -58,9 +58,9 @@ def __init__(self, dut) -> None: ) def generate_inputs(self): - uints = torch.arange(2 ** self.IN_WIDTH) + uints = torch.arange(2**self.IN_WIDTH) num_int = sign_extend_t(uints, self.IN_WIDTH) - num_float = num_int / (2 ** self.IN_FRAC_WIDTH) + num_float = num_int / (2**self.IN_FRAC_WIDTH) return num_int, num_float def rounding_mode(self): @@ -150,7 +150,10 @@ def gen_symmetric(cfg_list): l = list() for cfg in cfg_list: l.extend( - [{**cfg, "SYMMETRIC": 0}, {**cfg, "SYMMETRIC": 1},] + [ + {**cfg, "SYMMETRIC": 0}, + {**cfg, "SYMMETRIC": 1}, + ] ) return l diff --git a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py index 75910ed9e..cfeca8318 100644 --- a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py +++ b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py @@ -51,7 +51,7 @@ def __init__(self, dut) -> None: ) def generate_inputs(self): - return torch.arange(2 ** self.IN_WIDTH) + return torch.arange(2**self.IN_WIDTH) def rounding_mode(self): if self.ROUND_FLOOR: @@ -64,10 +64,10 @@ def rounding_mode(self): raise Exception("Rounding mode not recognised.") def model(self, inputs): - float_input = inputs / (2 ** self.IN_FRAC_WIDTH) - scaled_float = float_input * (2 ** self.OUT_FRAC_WIDTH) + float_input = inputs / (2**self.IN_FRAC_WIDTH) + scaled_float = float_input * (2**self.OUT_FRAC_WIDTH) rounded = torch.floor(scaled_float) - model_out = torch.clamp(rounded, 0, (2 ** self.OUT_WIDTH - 1)) + model_out = torch.clamp(rounded, 0, (2**self.OUT_WIDTH - 1)) return model_out async def run_test(self): diff --git a/src/mase_components/common/test/comparator_accumulator_tb.py b/src/mase_components/common/test/comparator_accumulator_tb.py index d05f41a09..fbd8af937 100644 --- a/src/mase_components/common/test/comparator_accumulator_tb.py +++ b/src/mase_components/common/test/comparator_accumulator_tb.py @@ -33,9 +33,7 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, batches=3): - return [ - randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches) - ] + return [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)] def model(self, inputs): @@ -154,5 +152,7 @@ def signed_max_min_cfgs(cfglist: list): cfgs = signed_max_min_cfgs(cfgs) mase_runner( - module_param_list=cfgs, trace=True, jobs=12, + module_param_list=cfgs, + trace=True, + jobs=12, ) diff --git a/src/mase_components/common/test/comparator_tree_tb.py b/src/mase_components/common/test/comparator_tree_tb.py index ae2a7fc61..5637791b3 100644 --- a/src/mase_components/common/test/comparator_tree_tb.py +++ b/src/mase_components/common/test/comparator_tree_tb.py @@ -34,7 +34,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=3): return [ - [randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)] + [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)] for _ in range(batches) ] @@ -136,5 +136,6 @@ def signed_max_min_cfgs(cfglist: list): cfgs = signed_max_min_cfgs(cfgs) mase_runner( - module_param_list=cfgs, trace=True, + module_param_list=cfgs, + trace=True, ) diff --git a/src/mase_components/common/test/register_slice_tb.py b/src/mase_components/common/test/register_slice_tb.py index ffcec4f3e..638d53da7 100644 --- a/src/mase_components/common/test/register_slice_tb.py +++ b/src/mase_components/common/test/register_slice_tb.py @@ -60,7 +60,9 @@ def in_out_wave(dut, name): ) logger.debug( "{} State: (shift_reg, buffer) = ({},{})".format( - name, int(dut.shift_reg.value), int(dut.buffer.value), + name, + int(dut.shift_reg.value), + int(dut.buffer.value), ) ) diff --git a/src/mase_components/common/test/single_element_repeat_tb.py b/src/mase_components/common/test/single_element_repeat_tb.py index 6d39b0c35..247cc2b27 100644 --- a/src/mase_components/common/test/single_element_repeat_tb.py +++ b/src/mase_components/common/test/single_element_repeat_tb.py @@ -26,7 +26,7 @@ def __init__(self, dut) -> None: ) def generate_inputs(self, num=10): - return [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(num)] + return [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(num)] def model(self, inputs): exp_out = [] @@ -103,5 +103,7 @@ def generate_random_params(): ] mase_runner( - module_param_list=cfgs, trace=True, jobs=8, + module_param_list=cfgs, + trace=True, + jobs=8, ) diff --git a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py index 1481671a3..8091990e4 100644 --- a/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py +++ b/src/mase_components/convolution_layers/test/binary_activation_binary_convolution_tb.py @@ -244,7 +244,9 @@ def sw_compute(self): bias = \n\ {} \n\ ".format( - data, weight, bias, + data, + weight, + bias, ) ) for i in range(self.samples): diff --git a/src/mase_components/convolution_layers/test/convolution_tb.py b/src/mase_components/convolution_layers/test/convolution_tb.py index ec3c07a37..1a3789ed9 100644 --- a/src/mase_components/convolution_layers/test/convolution_tb.py +++ b/src/mase_components/convolution_layers/test/convolution_tb.py @@ -130,7 +130,11 @@ def get_manual_result( # out2 = get_manual_result(x, w, b, 2,1,2,2,4,4,0,0,12,4) # data_in_pack - x = q2i(x, config["data_in_width"], config["data_in_frac_width"],) + x = q2i( + x, + config["data_in_width"], + config["data_in_frac_width"], + ) self.log.info(f"x = {x}") # from (samples, c, h, w) to (samples*h*w*c/unroll_in_c, unroll_in_c) @@ -140,8 +144,16 @@ def get_manual_result( self.log.info(f"weight = {w}") self.log.info(f"bias = {b}") - w = q2i(w, config["weight_width"], config["weight_frac_width"],) - b = q2i(b, config["bias_width"], config["bias_frac_width"],) + w = q2i( + w, + config["weight_width"], + config["weight_frac_width"], + ) + b = q2i( + b, + config["bias_width"], + config["bias_frac_width"], + ) self.log.info(f"weight = {w}") self.log.info(f"bias = {b}") hw_w, hw_b = self.conv_pack( @@ -157,7 +169,11 @@ def get_manual_result( unroll_kernel_out=self.get_parameter("UNROLL_KERNEL_OUT"), unroll_out_channels=self.get_parameter("UNROLL_OUT_C"), ) - exp_out = q2i(out, config["out_width"], config["out_frac_width"],) + exp_out = q2i( + out, + config["out_width"], + config["out_frac_width"], + ) exp_out = ( exp_out.reshape( -1, self.get_parameter("OUT_C"), self.get_parameter("SLIDING_NUM") @@ -207,7 +223,10 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) + w_tensor = w_tensor.reshape( + -1, + unroll_out_channels * unroll_kernel_out, + ) w_in = w_tensor.type(torch.int).tolist() # bias_pack bias_tensor = ( @@ -313,7 +332,10 @@ def test_fixed_linear_smoke(): Some quick tests to check if the module is working. """ mase_runner( - trace=True, module_param_list=[get_fixed_conv_config(),], + trace=True, + module_param_list=[ + get_fixed_conv_config(), + ], ) diff --git a/src/mase_components/convolution_layers/test/padding_tb.py b/src/mase_components/convolution_layers/test/padding_tb.py index d0504ef5d..6cb470069 100644 --- a/src/mase_components/convolution_layers/test/padding_tb.py +++ b/src/mase_components/convolution_layers/test/padding_tb.py @@ -96,9 +96,9 @@ def data_pack(self): for j in range(in_channels): for k in range(in_height): for s in range(in_width): - re_data_tensor[i][j][k + padding_height][ - s + padding_width - ] = data_tensor[i][k][s][j] + re_data_tensor[i][j][k + padding_height][s + padding_width] = ( + data_tensor[i][k][s][j] + ) return re_data_tensor @@ -248,7 +248,9 @@ def runner(): print(extra_args) runner = get_runner(sim) runner.build( - verilog_sources=verilog_sources, hdl_toplevel="padding", build_args=extra_args, + verilog_sources=verilog_sources, + hdl_toplevel="padding", + build_args=extra_args, ) runner.test(hdl_toplevel="padding", test_module="padding_tb") diff --git a/src/mase_components/convolution_layers/test/roller_tb.py b/src/mase_components/convolution_layers/test/roller_tb.py index 8c7a24d02..4bf56b273 100644 --- a/src/mase_components/convolution_layers/test/roller_tb.py +++ b/src/mase_components/convolution_layers/test/roller_tb.py @@ -195,7 +195,9 @@ def runner(): print(extra_args) runner = get_runner(sim)() runner.build( - verilog_sources=verilog_sources, toplevel="roller", extra_args=extra_args, + verilog_sources=verilog_sources, + toplevel="roller", + extra_args=extra_args, ) runner.test(toplevel="roller", py_module="roller_tb") diff --git a/src/mase_components/convolution_layers/test/sliding_window_tb.py b/src/mase_components/convolution_layers/test/sliding_window_tb.py index 73facf247..f893d6f68 100644 --- a/src/mase_components/convolution_layers/test/sliding_window_tb.py +++ b/src/mase_components/convolution_layers/test/sliding_window_tb.py @@ -119,9 +119,9 @@ def data_pack(self): for j in range(in_channels): for k in range(in_height): for s in range(in_width): - re_data_tensor[i][j][k + padding_height][ - s + padding_width - ] = data_tensor[i][k][s][j] + re_data_tensor[i][j][k + padding_height][s + padding_width] = ( + data_tensor[i][k][s][j] + ) return re_data_tensor diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py index 9596adedc..a06529344 100644 --- a/src/mase_components/deps.py +++ b/src/mase_components/deps.py @@ -17,7 +17,11 @@ "activation_layers", "scalar_operators/fixed", ], - "activation_layers/fixed_gelu": ["common", "memory", "activation_layers",], + "activation_layers/fixed_gelu": [ + "common", + "memory", + "activation_layers", + ], "activation_layers/fixed_softsign": [ "common", "activation_layers", @@ -123,7 +127,11 @@ "common", "cast", ], - "language_models/llmint8/scatter": ["language_models/llmint8", "memory", "common",], + "language_models/llmint8/scatter": [ + "language_models/llmint8", + "memory", + "common", + ], # Linear "linear_layers/fixed_linear_layer/fixed_linear": [ "cast", diff --git a/src/mase_components/helper/generate_memory.py b/src/mase_components/helper/generate_memory.py index c780d6c7f..e46c99604 100644 --- a/src/mase_components/helper/generate_memory.py +++ b/src/mase_components/helper/generate_memory.py @@ -60,10 +60,10 @@ def generate_lookup(data_width: int, f_width: int, function: str, type="hex"): count += 1 iarr.append(i) val = quanter(f(torch.tensor(i))) # entry in the lookup table - lut[ - doubletofx(data_width=data_width, f_width=f_width, num=i, type=type) - ] = doubletofx( - data_width=data_width, f_width=f_width, num=val.item(), type=type + lut[doubletofx(data_width=data_width, f_width=f_width, num=i, type=type)] = ( + doubletofx( + data_width=data_width, f_width=f_width, num=val.item(), type=type + ) ) i += 2 ** -(f_width) return lut diff --git a/src/mase_components/hls/bfp_arith/bfp_adder.py b/src/mase_components/hls/bfp_arith/bfp_adder.py index 0ba95741e..5b2a401d9 100644 --- a/src/mase_components/hls/bfp_arith/bfp_adder.py +++ b/src/mase_components/hls/bfp_arith/bfp_adder.py @@ -2,7 +2,11 @@ def bfp_adder_gen( - writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8, + writer, + x_exp_width=16, + x_man_width=8, + w_exp_width=16, + w_man_width=8, ): """ This script generates a element-level bfp add in HLS diff --git a/src/mase_components/hls/bfp_arith/bfp_multiplier.py b/src/mase_components/hls/bfp_arith/bfp_multiplier.py index fd588ea3c..e3695295e 100644 --- a/src/mase_components/hls/bfp_arith/bfp_multiplier.py +++ b/src/mase_components/hls/bfp_arith/bfp_multiplier.py @@ -1,5 +1,9 @@ def bfp_multiplier_gen( - writer, x_exp_width=16, x_man_width=8, w_exp_width=16, w_man_width=8, + writer, + x_exp_width=16, + x_man_width=8, + w_exp_width=16, + w_man_width=8, ): """ This script generates a element-level bfp mult in HLS diff --git a/src/mase_components/hls/elastic/buffer.py b/src/mase_components/hls/elastic/buffer.py index a0fda6b90..d25c1fd56 100644 --- a/src/mase_components/hls/elastic/buffer.py +++ b/src/mase_components/hls/elastic/buffer.py @@ -2,7 +2,13 @@ def buffer_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a buffer in HLS diff --git a/src/mase_components/hls/hls_regression.py b/src/mase_components/hls/hls_regression.py index 719b20380..a02dca60a 100755 --- a/src/mase_components/hls/hls_regression.py +++ b/src/mase_components/hls/hls_regression.py @@ -106,7 +106,10 @@ def main(): parser = ArgumentParser(usage=USAGE) parser.add_argument( - "--op", dest="op", default=None, help="Op name to explore", + "--op", + dest="op", + default=None, + help="Op name to explore", ) parser.add_argument( "--dir", diff --git a/src/mase_components/hls/int_arith/int_layernorm.py b/src/mase_components/hls/int_arith/int_layernorm.py index 9521a68c0..059b15d4d 100644 --- a/src/mase_components/hls/int_arith/int_layernorm.py +++ b/src/mase_components/hls/int_arith/int_layernorm.py @@ -2,7 +2,13 @@ def int_layernorm_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a fixed-point layernorm in HLS diff --git a/src/mase_components/hls/int_arith/int_relu.py b/src/mase_components/hls/int_arith/int_relu.py index 09f7462c2..f934640c9 100644 --- a/src/mase_components/hls/int_arith/int_relu.py +++ b/src/mase_components/hls/int_arith/int_relu.py @@ -2,7 +2,13 @@ def int_relu_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a fixed-point relu in HLS diff --git a/src/mase_components/hls/int_arith/int_silu.py b/src/mase_components/hls/int_arith/int_silu.py index c90511c47..c4fc32abe 100644 --- a/src/mase_components/hls/int_arith/int_silu.py +++ b/src/mase_components/hls/int_arith/int_silu.py @@ -2,7 +2,13 @@ def int_silu_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a fixed-point silu in HLS diff --git a/src/mase_components/hls/int_arith/int_softmax.py b/src/mase_components/hls/int_arith/int_softmax.py index 8ad5901cf..c0b4ca5b1 100644 --- a/src/mase_components/hls/int_arith/int_softmax.py +++ b/src/mase_components/hls/int_arith/int_softmax.py @@ -2,7 +2,13 @@ def int_softmax_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a fixed-point softmax in HLS diff --git a/src/mase_components/hls/int_arith/int_transpose.py b/src/mase_components/hls/int_arith/int_transpose.py index ee9c6d2d7..6b7fd0a15 100644 --- a/src/mase_components/hls/int_arith/int_transpose.py +++ b/src/mase_components/hls/int_arith/int_transpose.py @@ -2,7 +2,13 @@ def int_transpose_gen( - writer, x_width=16, x_frac_width=8, x_row=3, x_col=2, x_row_depth=3, x_col_depth=2, + writer, + x_width=16, + x_frac_width=8, + x_row=3, + x_col=2, + x_row_depth=3, + x_col_depth=2, ): """ This script generates a fixed-point transpose in HLS diff --git a/src/mase_components/hls/regression_gen/bfp_add_dse.py b/src/mase_components/hls/regression_gen/bfp_add_dse.py index c283301d2..834644b24 100644 --- a/src/mase_components/hls/regression_gen/bfp_add_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_add_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -151,7 +157,8 @@ def bfp_add_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "bfp_add_2" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py index 9a894b143..d6f78f30c 100644 --- a/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_linear2d_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, diff --git a/src/mase_components/hls/regression_gen/bfp_mult_dse.py b/src/mase_components/hls/regression_gen/bfp_mult_dse.py index 1510ea1f3..b4b5fa58c 100644 --- a/src/mase_components/hls/regression_gen/bfp_mult_dse.py +++ b/src/mase_components/hls/regression_gen/bfp_mult_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -151,7 +157,8 @@ def bfp_mult_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "bfp_mult_2" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) if hr is None: continue diff --git a/src/mase_components/hls/regression_gen/buffer_dse.py b/src/mase_components/hls/regression_gen/buffer_dse.py index bcc003c6d..08cc6d56e 100644 --- a/src/mase_components/hls/regression_gen/buffer_dse.py +++ b/src/mase_components/hls/regression_gen/buffer_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -123,7 +129,8 @@ def buffer_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "buffer_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/fork_dse.py b/src/mase_components/hls/regression_gen/fork_dse.py index c8a5b3c0b..f045cd8a3 100644 --- a/src/mase_components/hls/regression_gen/fork_dse.py +++ b/src/mase_components/hls/regression_gen/fork_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -133,7 +139,8 @@ def fork_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "fork_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_add_dse.py b/src/mase_components/hls/regression_gen/int_add_dse.py index a88247d24..d4c4d9ccc 100644 --- a/src/mase_components/hls/regression_gen/int_add_dse.py +++ b/src/mase_components/hls/regression_gen/int_add_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -151,7 +157,8 @@ def int_add_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_add_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_layernorm_dse.py b/src/mase_components/hls/regression_gen/int_layernorm_dse.py index d0ef2afdb..3e6a61aff 100644 --- a/src/mase_components/hls/regression_gen/int_layernorm_dse.py +++ b/src/mase_components/hls/regression_gen/int_layernorm_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -124,7 +130,8 @@ def int_layernorm_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_layernorm_1" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_linear2d_dse.py b/src/mase_components/hls/regression_gen/int_linear2d_dse.py index d75122306..1c89861eb 100644 --- a/src/mase_components/hls/regression_gen/int_linear2d_dse.py +++ b/src/mase_components/hls/regression_gen/int_linear2d_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, diff --git a/src/mase_components/hls/regression_gen/int_matmul_dse.py b/src/mase_components/hls/regression_gen/int_matmul_dse.py index ac09aaa5c..1d1f35fd2 100644 --- a/src/mase_components/hls/regression_gen/int_matmul_dse.py +++ b/src/mase_components/hls/regression_gen/int_matmul_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import DSE_MODES, get_tcl_buff from hls.int_arith import int_matmul_gen diff --git a/src/mase_components/hls/regression_gen/int_mult_dse.py b/src/mase_components/hls/regression_gen/int_mult_dse.py index f824ec99e..275ffb163 100644 --- a/src/mase_components/hls/regression_gen/int_mult_dse.py +++ b/src/mase_components/hls/regression_gen/int_mult_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -151,7 +157,8 @@ def int_mult_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_mult_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_relu_dse.py b/src/mase_components/hls/regression_gen/int_relu_dse.py index 788e3f593..7224b1444 100644 --- a/src/mase_components/hls/regression_gen/int_relu_dse.py +++ b/src/mase_components/hls/regression_gen/int_relu_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -125,7 +131,8 @@ def int_relu_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_relu_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py index c559ecee8..262794af2 100644 --- a/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py +++ b/src/mase_components/hls/regression_gen/int_rmsnorm_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -145,7 +151,8 @@ def int_rmsnorm_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_rmsnorm_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_rope_dse.py b/src/mase_components/hls/regression_gen/int_rope_dse.py index bfc2bc385..421a37e21 100644 --- a/src/mase_components/hls/regression_gen/int_rope_dse.py +++ b/src/mase_components/hls/regression_gen/int_rope_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -146,7 +152,8 @@ def int_rope_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_rope_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_silu_dse.py b/src/mase_components/hls/regression_gen/int_silu_dse.py index e3ce1bd11..0d80397cd 100644 --- a/src/mase_components/hls/regression_gen/int_silu_dse.py +++ b/src/mase_components/hls/regression_gen/int_silu_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -125,7 +131,8 @@ def int_silu_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_silu_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_softmax_dse.py b/src/mase_components/hls/regression_gen/int_softmax_dse.py index 796829278..c10c091ef 100644 --- a/src/mase_components/hls/regression_gen/int_softmax_dse.py +++ b/src/mase_components/hls/regression_gen/int_softmax_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -125,7 +131,8 @@ def int_softmax_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_softmax_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/regression_gen/int_transpose_dse.py b/src/mase_components/hls/regression_gen/int_transpose_dse.py index 4563a9eca..7b5c2b715 100644 --- a/src/mase_components/hls/regression_gen/int_transpose_dse.py +++ b/src/mase_components/hls/regression_gen/int_transpose_dse.py @@ -1,7 +1,13 @@ # TODO: Temporary working solution import sys, os -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.regression_gen.utils import ( DSE_MODES, @@ -125,7 +131,8 @@ def int_transpose_dse(mode=None, top=None, threads=16): if mode in ["report", "all"]: top_name = "int_transpose_0" hr = get_hls_results( - project=os.path.join(top, file_name), top=top_name, + project=os.path.join(top, file_name), + top=top_name, ) data_points.append( [ diff --git a/src/mase_components/hls/scripts/bl_bfp.py b/src/mase_components/hls/scripts/bl_bfp.py index 29f34c04d..beff47050 100644 --- a/src/mase_components/hls/scripts/bl_bfp.py +++ b/src/mase_components/hls/scripts/bl_bfp.py @@ -1,6 +1,12 @@ import os, sys -sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..",)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..", + "..", + ) +) from hls.bfp_arith import bfp_mm_gen from hls.bfp_arith import bfp_add_gen @@ -8,7 +14,12 @@ def get_big_little_bfp( - HIGH_MAN_WIDTH=7, LOW_MAN_WIDTH=3, X_ROW=1, X_COL=4096, W_COL=11008, A_COL=32, + HIGH_MAN_WIDTH=7, + LOW_MAN_WIDTH=3, + X_ROW=1, + X_COL=4096, + W_COL=11008, + A_COL=32, ): W_ROW = X_COL A_ROW = X_COL diff --git a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py index 2cdb58b59..91552b68a 100644 --- a/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py +++ b/src/mase_components/linear_layers/binarized_operators/test/binary_activation_binary_adder_tree_layer_tb.py @@ -95,7 +95,8 @@ def runner(): extra_args.append(f"-G{k}={v}") mase_runner( - trace=True, module_param_list=[test_case.get_dut_parameters()], + trace=True, + module_param_list=[test_case.get_dut_parameters()], ) diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py index 27e729200..0aa8f5239 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_isqrt_tb.py @@ -24,23 +24,31 @@ class VerificationCase(Testbench): def __init__(self, dut): super().__init__(dut, dut.clk, dut.rst) self.assign_self_params( - ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",] + [ + "IN_WIDTH", + "IN_FRAC_WIDTH", + "LUT_POW", + ] ) self.input_driver = StreamDriver( dut.clk, dut.in_data, dut.in_valid, dut.in_ready ) self.output_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT", + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + name="Output ISQRT", ) def generate_inputs(self, num=10000): - maxnum = (2 ** self.IN_WIDTH) - 1 + maxnum = (2**self.IN_WIDTH) - 1 return [random.randint(0, maxnum) for _ in range(num)], num def model(self, data_in): ref = [] - lut_size = 2 ** self.LUT_POW + lut_size = 2**self.LUT_POW lut = make_lut(lut_size, self.IN_WIDTH) for x in data_in: expected = isqrt_sw2( @@ -119,7 +127,7 @@ async def valid_backpressure(dut): makedirs(mem_dir, exist_ok=True) def single_cfg(width, frac_width, lut_pow, str_id): - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow lut = make_lut(lut_size, width) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, width) diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py index 84fde6666..89c80aa83 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py @@ -22,7 +22,7 @@ def __init__(self, dut): self.assign_self_params(["WIDTH", "LUT_POW"]) def generate_inputs(self): - samples = 2 ** self.WIDTH + samples = 2**self.WIDTH data_x = [] msb_indices = [] for x in range(samples): diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py index 07a199810..8c0b70ed1 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_nr_stage_tb.py @@ -24,12 +24,12 @@ def __init__(self, dut): self.assign_self_params(["WIDTH"]) def generate_inputs(self, lut_pow): - samples = 2 ** self.WIDTH + samples = 2**self.WIDTH int_width = 1 frac_width = self.WIDTH - 1 data_x = [] initial_guesses = [] - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow lut = make_lut(lut_size, self.WIDTH) # NOTE: since negative values are not supported by fixed formats since # isqrt only outputs positive results we cannot test every single com- diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py index a0801d23a..cbf9cf776 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_augmentation_tb.py @@ -22,7 +22,7 @@ def __init__(self, dut) -> None: self.assign_self_params(["WIDTH", "FRAC_WIDTH"]) def generate_inputs(self): - samples = 2 ** self.WIDTH + samples = 2**self.WIDTH data_x = [] msb_indices = [] for x in range(samples): diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py index c4d00af2a..ed155214a 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py @@ -20,7 +20,7 @@ def __init__(self, dut) -> None: self.assign_self_params(["WIDTH"]) def generate_inputs(self): - samples = 2 ** self.WIDTH + samples = 2**self.WIDTH return [val for val in range(0, samples)], samples def model(self, inputs): diff --git a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py index a001b27bc..c4723f971 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py +++ b/src/mase_components/linear_layers/fixed_operators/test/isqrt_sw.py @@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int: def float_to_int(x: float, int_width: int, frac_width: int) -> int: integer = int(x) x -= integer - res = integer * (2 ** frac_width) + res = integer * (2**frac_width) for i in range(1, frac_width + 1): power = 2 ** (-i) if power <= x: @@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int: def int_to_float(x: int, int_width: int, frac_width: int) -> float: - integer = x / (2 ** frac_width) - fraction = x - integer * 2 ** frac_width + integer = x / (2**frac_width) + fraction = x - integer * 2**frac_width res = integer for i in range(1, frac_width + 1): @@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int: res = 0 else: res = x_red - 2 ** (width - 1) - res = res * 2 ** lut_pow + res = res * 2**lut_pow res = res / 2 ** (width - 1) # FORMAT OUTPUT: Q(WIDTH).0 return int(res) @@ -258,7 +258,7 @@ def test_sw_model(): def debug_single(): lut_pow = 5 - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow int_width = 2 frac_width = 1 width = int_width + frac_width diff --git a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py index 5734fb891..815e9ed07 100644 --- a/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py +++ b/src/mase_components/linear_layers/matmul/test/fixed_matmul_tb.py @@ -147,7 +147,10 @@ def data_generate(self): weight_tensor = {} \n\ data_in = {} \n\ weight_in = {} ".format( - data_tensor, weight_tensor, data_in, weight_in, + data_tensor, + weight_tensor, + data_in, + weight_in, ) ) data_in.reverse() @@ -378,7 +381,8 @@ def runner(): build_args=extra_args, ) runner.test( - hdl_toplevel="fixed_matmul", test_module="fixed_matmul_tb", + hdl_toplevel="fixed_matmul", + test_module="fixed_matmul_tb", ) diff --git a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py index 0cd5b07b9..9f058695c 100644 --- a/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py +++ b/src/mase_components/linear_layers/matmul/test/simple_matmul_tb.py @@ -40,8 +40,8 @@ def __init__(self, dut) -> None: ] ) - self.X_MAX = 2 ** self.X_WIDTH - 1 - self.Y_MAX = 2 ** self.Y_WIDTH - 1 + self.X_MAX = 2**self.X_WIDTH - 1 + self.Y_MAX = 2**self.Y_WIDTH - 1 self.x_driver = StreamDriver(dut.clk, dut.x_data, dut.x_valid, dut.x_ready) self.y_driver = StreamDriver(dut.clk, dut.y_data, dut.y_valid, dut.y_ready) @@ -87,10 +87,10 @@ def model(self, X, Y): logger.debug("Sign Extended & Scaled") X_input = sign_extend_t(X_input, self.X_WIDTH).type(torch.float32) / ( - 2 ** self.X_FRAC_WIDTH + 2**self.X_FRAC_WIDTH ) Y_input = sign_extend_t(Y_input, self.Y_WIDTH).type(torch.float32) / ( - 2 ** self.Y_FRAC_WIDTH + 2**self.Y_FRAC_WIDTH ) logger.debug(X_input) logger.debug(Y_input) diff --git a/src/mase_components/linear_layers/matmul/test/transpose_tb.py b/src/mase_components/linear_layers/matmul/test/transpose_tb.py index 003382ff6..c859749ed 100644 --- a/src/mase_components/linear_layers/matmul/test/transpose_tb.py +++ b/src/mase_components/linear_layers/matmul/test/transpose_tb.py @@ -56,7 +56,11 @@ def generate_random_params(num=3): cfgs = list() for _ in range(num): cfgs.append( - {"WIDTH": randint(1, 16), "DIM0": randint(2, 12), "DIM1": randint(2, 12),} + { + "WIDTH": randint(1, 16), + "DIM0": randint(2, 12), + "DIM1": randint(2, 12), + } ) return cfgs diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py index ce742183b..963ae40df 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py @@ -29,7 +29,10 @@ def __init__(self, dut, num) -> None: self.log = SimLog("%s" % (type(self).__qualname__)) self.data_in_0_driver = MultiSignalStreamDriver( - dut.clk, (dut.mdata_in, dut.edata_in), dut.data_in_valid, dut.data_in_ready, + dut.clk, + (dut.mdata_in, dut.edata_in), + dut.data_in_valid, + dut.data_in_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -48,10 +51,14 @@ def generate_inputs(self): for _ in range(self.num): data = 20 * torch.rand(int(self.dut.BLOCK_SIZE)) (data_in, mdata_in, edata_in) = mxint_quantize( - data, int(self.dut.IN_MAN_WIDTH), int(self.dut.IN_EXP_WIDTH), + data, + int(self.dut.IN_MAN_WIDTH), + int(self.dut.IN_EXP_WIDTH), ) exp_out, mexp_out, eexp_out = mxint_quantize( - data_in, int(self.dut.OUT_MAN_WIDTH), int(self.dut.OUT_EXP_WIDTH), + data_in, + int(self.dut.OUT_MAN_WIDTH), + int(self.dut.OUT_EXP_WIDTH), ) breakpoint() inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py index cb1b9638f..43e70adcf 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py @@ -40,7 +40,10 @@ def __init__(self, dut, num) -> None: dut.data_in_0_ready, ) self.weight_driver = MultiSignalStreamDriver( - dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready, + dut.clk, + (dut.mweight, dut.eweight), + dut.weight_valid, + dut.weight_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -64,7 +67,9 @@ def generate_inputs(self): ) w = torch.rand(int(self.dut.BLOCK_SIZE)) (weight, mweight, eweight) = mxint_quantize( - w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1), + w, + int(self.dut.WEIGHT_PRECISION_0), + int(self.dut.WEIGHT_PRECISION_1), ) mdp_out = mdata_in @ mweight edp_out = edata_in + eweight diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py index e01e056e1..f31d1ab61 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py @@ -48,10 +48,16 @@ def __init__(self, dut) -> None: self.log.setLevel(logging.DEBUG) self.a_driver = MultiSignalStreamDriver( - dut.clk, (dut.ma_data, dut.ea_data), dut.a_valid, dut.a_ready, + dut.clk, + (dut.ma_data, dut.ea_data), + dut.a_valid, + dut.a_ready, ) self.b_driver = MultiSignalStreamDriver( - dut.clk, (dut.mb_data, dut.eb_data), dut.b_valid, dut.b_ready, + dut.clk, + (dut.mb_data, dut.eb_data), + dut.b_valid, + dut.b_ready, ) self.output_monitor = MultiSignalStreamMonitor( diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py index 06f970039..58fc472ba 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_vector_mult_tb.py @@ -35,7 +35,10 @@ def __init__(self, dut, num) -> None: dut.data_in_0_ready, ) self.weight_driver = MultiSignalStreamDriver( - dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready, + dut.clk, + (dut.mweight, dut.eweight), + dut.weight_valid, + dut.weight_ready, ) self.data_out_0_monitor = MultiSignalStreamMonitor( @@ -63,7 +66,9 @@ def generate_inputs(self): w = 20 * torch.rand(int(self.dut.BLOCK_SIZE)) (weight, mweight, eweight) = mxint_quantize( - w, int(self.dut.WEIGHT_PRECISION_0), int(self.dut.WEIGHT_PRECISION_1), + w, + int(self.dut.WEIGHT_PRECISION_0), + int(self.dut.WEIGHT_PRECISION_1), ) exp_out, mexp_out, eexp_out = mxint_quantize( data_in * weight, diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.py b/src/mase_components/linear_layers/mxint_operators/test/test.py index 85cdaa995..f58f382cf 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/test.py +++ b/src/mase_components/linear_layers/mxint_operators/test/test.py @@ -8,8 +8,16 @@ d_man_width = 12 w_man_width = 8 e_width = 4 -(data_in, mdata_in, edata_in) = mxint_quantize(data, d_man_width, e_width,) -(weight, mweight, eweight) = mxint_quantize(w, w_man_width, e_width,) +(data_in, mdata_in, edata_in) = mxint_quantize( + data, + d_man_width, + e_width, +) +(weight, mweight, eweight) = mxint_quantize( + w, + w_man_width, + e_width, +) linear = torch.nn.Linear(10, 10, bias=False) linear.weight = torch.nn.Parameter(weight) target_x = linear(data_in) @@ -28,7 +36,7 @@ def hardware_quant(hardware_in, be_value, e_width, width): exponent_bias = 2 ** (e_width - 1) - 1 # exponent - exponent_max = 2 ** e_width - 1 - exponent_bias + exponent_max = 2**e_width - 1 - exponent_bias exponent_min = -exponent_bias exponent = ( torch.ceil(torch.log2(hardware_in.abs().max())) + be_value - exponent_bias @@ -40,7 +48,7 @@ def hardware_quant(hardware_in, be_value, e_width, width): breakpoint() mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2 ** exponent) * mantissa + msfp_x = (2**exponent) * mantissa return msfp_x, mantissa, exponent diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py index 43b3b0b87..7edb9f6ed 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/utils.py +++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py @@ -24,7 +24,7 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = """ exponent_bias = 2 ** (exponent_width - 1) - exponent_max = 2 ** exponent_width - 1 - exponent_bias + exponent_max = 2**exponent_width - 1 - exponent_bias exponent_min = -exponent_bias # exponent @@ -34,9 +34,9 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = # mantissa int_min = -(2 ** (width - 1)) int_max = 2 ** (width - 1) - 1 - mantissa = x / 2 ** exponent + mantissa = x / 2**exponent mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2 ** exponent) * mantissa + msfp_x = (2**exponent) * mantissa return msfp_x, mantissa, exponent diff --git a/src/mase_components/memory/test/fifo_tb.py b/src/mase_components/memory/test/fifo_tb.py index 5f247bede..9dc4f54ff 100644 --- a/src/mase_components/memory/test/fifo_tb.py +++ b/src/mase_components/memory/test/fifo_tb.py @@ -27,7 +27,7 @@ def __init__(self, dut) -> None: # self.output_monitor.log.setLevel("DEBUG") def generate_inputs(self, num=20): - return [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(num)] + return [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(num)] @cocotb.test() @@ -120,7 +120,12 @@ async def cocotb_test_soak(dut): @pytest.mark.dev def test_fifo(): mase_runner( - module_param_list=[{"DEPTH": 1}, {"DEPTH": 7}, {"DEPTH": 8}, {"DEPTH": 81},], + module_param_list=[ + {"DEPTH": 1}, + {"DEPTH": 7}, + {"DEPTH": 8}, + {"DEPTH": 81}, + ], trace=True, ) diff --git a/src/mase_components/memory/test/repeat_circular_buffer_tb.py b/src/mase_components/memory/test/repeat_circular_buffer_tb.py index db6bfc7b8..11d7d85b1 100644 --- a/src/mase_components/memory/test/repeat_circular_buffer_tb.py +++ b/src/mase_components/memory/test/repeat_circular_buffer_tb.py @@ -30,7 +30,7 @@ def generate_inputs(self, num=10): inputs = [] for _ in range(num): inputs.extend( - [random.randint(0, 2 ** self.DATA_WIDTH - 1) for _ in range(self.SIZE)] + [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)] ) return inputs diff --git a/src/mase_components/memory/test/unpacked_fifo_tb.py b/src/mase_components/memory/test/unpacked_fifo_tb.py index 734a2edcb..cf0a99592 100644 --- a/src/mase_components/memory/test/unpacked_fifo_tb.py +++ b/src/mase_components/memory/test/unpacked_fifo_tb.py @@ -26,7 +26,7 @@ def __init__(self, dut) -> None: def generate_inputs(self, batches=20): return [ - [randint(0, (2 ** self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)] + [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(self.IN_NUM)] for _ in range(batches) ] diff --git a/src/mase_components/normalization_layers/process_synth_impl.py b/src/mase_components/normalization_layers/process_synth_impl.py index 5a12ee170..c65d1b8a9 100644 --- a/src/mase_components/normalization_layers/process_synth_impl.py +++ b/src/mase_components/normalization_layers/process_synth_impl.py @@ -103,14 +103,19 @@ def gather_data(build_dir: Path): if __name__ == "__main__": data = gather_data(Path("build")) data["ns"] = data["clk_period"] - data["wns"] - data["fmax"] = 1 / (data["ns"] * (10 ** -9)) + data["fmax"] = 1 / (data["ns"] * (10**-9)) data["fmax_mhz"] = data["fmax"] / 1_000_000 print(data) def plot(col): - alt.Chart(data).mark_line().encode(x="width", y=col, color="norm",).properties( - width=400, height=200, + alt.Chart(data).mark_line().encode( + x="width", + y=col, + color="norm", + ).properties( + width=400, + height=200, ).save(f"{col}_plot.png", scale_factor=3) plot("wns") diff --git a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py index d43976845..5d9da4bc6 100644 --- a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py @@ -158,7 +158,7 @@ def model(self, inputs): -1, self.NUM_CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0 ) x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / ( - 2 ** self.IN_FRAC_WIDTH + 2**self.IN_FRAC_WIDTH ) # Float Model @@ -397,7 +397,8 @@ def gen_cfg( ] mase_runner( - module_param_list=test_cfgs, trace=True, + module_param_list=test_cfgs, + trace=True, ) diff --git a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py index cffd16822..e31203c25 100644 --- a/src/mase_components/normalization_layers/test/group_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/group_norm_2d_tb.py @@ -79,7 +79,7 @@ def __init__(self, dut) -> None: self.out_width_tup = self.OUT_WIDTH, self.OUT_FRAC_WIDTH # Inverse Square Root LUT - self.isqrt_lut = make_lut(2 ** 5, 16) + self.isqrt_lut = make_lut(2**5, 16) self.num_groups = randint(2, 3) self.total_channels = self.GROUP_CHANNELS * self.num_groups @@ -141,7 +141,7 @@ def model(self, inputs): -1, self.total_channels, self.TOTAL_DIM1, self.TOTAL_DIM0 ) x = sign_extend_t(x, self.IN_WIDTH).to(dtype=torch.float32) / ( - 2 ** self.IN_FRAC_WIDTH + 2**self.IN_FRAC_WIDTH ) # Float Model @@ -239,7 +239,12 @@ def test_group_norm_2d(): makedirs(mem_dir, exist_ok=True) def isqrt_width( - total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width, + total_dim0, + total_dim1, + compute_dim0, + compute_dim1, + group_channels, + in_width, ): depth_dim0 = total_dim0 // compute_dim0 depth_dim1 = total_dim1 // compute_dim1 @@ -269,7 +274,7 @@ def gen_cfg( isqrt_w = isqrt_width( total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width ) - lut = make_lut(2 ** LUT_POW, isqrt_w) + lut = make_lut(2**LUT_POW, isqrt_w) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, isqrt_w) params = { diff --git a/src/mase_components/normalization_layers/test/models.py b/src/mase_components/normalization_layers/test/models.py index dfeadac98..0f10b228e 100644 --- a/src/mase_components/normalization_layers/test/models.py +++ b/src/mase_components/normalization_layers/test/models.py @@ -56,11 +56,11 @@ def _fixed_group_norm_2d_model( logger.debug("Diff:") logger.debug(diff[0]) - squares = diff ** 2 + squares = diff**2 logger.debug("Squares:") logger.debug(squares[0]) - squares_int = (squares * (2 ** square_frac_width)).int() - logger.debug(squares * (2 ** square_frac_width)) + squares_int = (squares * (2**square_frac_width)).int() + logger.debug(squares * (2**square_frac_width)) sum_squares = torch.sum(squares, dim=(1, 2, 3), keepdim=True) sum_squares = integer_floor_quantizer( @@ -81,12 +81,10 @@ def _fixed_group_norm_2d_model( logger.debug(f"{var[0]}") # Clamp down variance to isqrt width - var_clamp = torch.clamp( - var, 0.0, ((2 ** isqrt_width) - 1) / (2 ** isqrt_frac_width) - ) + var_clamp = torch.clamp(var, 0.0, ((2**isqrt_width) - 1) / (2**isqrt_frac_width)) logger.debug("Variance Clamped:") logger.debug(f"{var_clamp[0]}") - var_clamp_int = (var_clamp * (2 ** isqrt_frac_width)).int() + var_clamp_int = (var_clamp * (2**isqrt_frac_width)).int() # Inverse Square Root calculation lut_pow = ceil(log2(len(isqrt_lut))) @@ -106,7 +104,7 @@ def _fixed_group_norm_2d_model( logger.debug("INV SQRT INT:") logger.debug(f"{inv_sqrt_int[0]}") - inv_sqrt = inv_sqrt_int / (2 ** isqrt_frac_width) + inv_sqrt = inv_sqrt_int / (2**isqrt_frac_width) logger.debug("Inverse SQRT:") logger.debug(f"{inv_sqrt[0]}") diff --git a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py index 0c998f83d..22ef8bdd1 100644 --- a/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/rms_norm_2d_tb.py @@ -151,7 +151,7 @@ def reconstruct_tensor(self, x, width, frac_width): x = torch.stack(matrix_list).reshape( -1, self.CHANNELS, self.TOTAL_DIM1, self.TOTAL_DIM0 ) - x = sign_extend_t(x, width).to(dtype=torch.float32) / (2 ** frac_width) + x = sign_extend_t(x, width).to(dtype=torch.float32) / (2**frac_width) return x def output_monitor_split(self, x, width, frac_width): @@ -248,7 +248,12 @@ def test_rms_norm_2d(): makedirs(mem_dir, exist_ok=True) def isqrt_width( - total_dim0, total_dim1, compute_dim0, compute_dim1, group_channels, in_width, + total_dim0, + total_dim1, + compute_dim0, + compute_dim1, + group_channels, + in_width, ): depth_dim0 = total_dim0 // compute_dim0 depth_dim1 = total_dim1 // compute_dim1 @@ -277,7 +282,7 @@ def gen_cfg( isqrt_w = isqrt_width( total_dim0, total_dim1, compute_dim0, compute_dim1, channels, in_width ) - lut = make_lut(2 ** LUT_POW, isqrt_w) + lut = make_lut(2**LUT_POW, isqrt_w) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, isqrt_w) params = { diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py index f0f5cafb1..6ee7a93dd 100644 --- a/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py +++ b/src/mase_components/scalar_operators/fixed/test/fixed_isqrt_tb.py @@ -24,23 +24,31 @@ class VerificationCase(Testbench): def __init__(self, dut): super().__init__(dut, dut.clk, dut.rst) self.assign_self_params( - ["IN_WIDTH", "IN_FRAC_WIDTH", "LUT_POW",] + [ + "IN_WIDTH", + "IN_FRAC_WIDTH", + "LUT_POW", + ] ) self.input_driver = StreamDriver( dut.clk, dut.in_data, dut.in_valid, dut.in_ready ) self.output_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready, name="Output ISQRT", + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + name="Output ISQRT", ) def generate_inputs(self, num=10000): - maxnum = (2 ** self.IN_WIDTH) - 1 + maxnum = (2**self.IN_WIDTH) - 1 return [random.randint(0, maxnum) for _ in range(num)], num def model(self, data_in): ref = [] - lut_size = 2 ** self.LUT_POW + lut_size = 2**self.LUT_POW lut = make_lut(lut_size, self.IN_WIDTH) for x in data_in: expected = isqrt_sw2( @@ -123,7 +131,7 @@ def test_fixed_isqrt(): makedirs(mem_dir, exist_ok=True) def single_cfg(width, frac_width, lut_pow, str_id): - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow lut = make_lut(lut_size, width) mem_path = mem_dir / f"lutmem-{str_id}.mem" write_memb(mem_path, lut, width) diff --git a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py index 5a3019f8e..d09305ef0 100644 --- a/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py +++ b/src/mase_components/scalar_operators/fixed/test/fixed_nr_stage_tb.py @@ -24,12 +24,12 @@ def __init__(self, dut): self.assign_self_params(["WIDTH"]) def generate_inputs(self, lut_pow): - samples = 2 ** self.WIDTH + samples = 2**self.WIDTH int_width = 1 frac_width = self.WIDTH - 1 data_x = [] initial_guesses = [] - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow lut = make_lut(lut_size, self.WIDTH) # NOTE: since negative values are not supported by fixed formats since # isqrt only outputs positive results we cannot test every single com- diff --git a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py index 901ba964d..f6fc798a4 100644 --- a/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py +++ b/src/mase_components/scalar_operators/fixed/test/isqrt_sw.py @@ -13,7 +13,7 @@ def find_msb(x: int, width: int) -> int: def float_to_int(x: float, int_width: int, frac_width: int) -> int: integer = int(x) x -= integer - res = integer * (2 ** frac_width) + res = integer * (2**frac_width) for i in range(1, frac_width + 1): power = 2 ** (-i) if power <= x: @@ -23,8 +23,8 @@ def float_to_int(x: float, int_width: int, frac_width: int) -> int: def int_to_float(x: int, int_width: int, frac_width: int) -> float: - integer = x / (2 ** frac_width) - fraction = x - integer * 2 ** frac_width + integer = x / (2**frac_width) + fraction = x - integer * 2**frac_width res = integer for i in range(1, frac_width + 1): @@ -85,7 +85,7 @@ def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int: res = 0 else: res = x_red - 2 ** (width - 1) - res = res * 2 ** lut_pow + res = res * 2**lut_pow res = res / 2 ** (width - 1) # FORMAT OUTPUT: Q(WIDTH).0 return int(res) @@ -258,7 +258,7 @@ def test_isqrt_sw_model(): def debug_single(): lut_pow = 5 - lut_size = 2 ** lut_pow + lut_size = 2**lut_pow int_width = 2 frac_width = 1 width = int_width + frac_width diff --git a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py index a334a1469..d999fdc59 100644 --- a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py +++ b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py @@ -195,16 +195,28 @@ def __init__(self, dut) -> None: if self.HAS_BIAS == 1: self.bias_q_driver = StreamDriver( - dut.clk, dut.bias_query, dut.bias_query_valid, dut.bias_query_ready, + dut.clk, + dut.bias_query, + dut.bias_query_valid, + dut.bias_query_ready, ) self.bias_k_driver = StreamDriver( - dut.clk, dut.bias_key, dut.bias_key_valid, dut.bias_key_ready, + dut.clk, + dut.bias_key, + dut.bias_key_valid, + dut.bias_key_ready, ) self.bias_v_driver = StreamDriver( - dut.clk, dut.bias_value, dut.bias_value_valid, dut.bias_value_ready, + dut.clk, + dut.bias_value, + dut.bias_value_valid, + dut.bias_value_ready, ) self.bias_o_driver = StreamDriver( - dut.clk, dut.bias_output, dut.bias_output_valid, dut.bias_output_ready, + dut.clk, + dut.bias_output, + dut.bias_output_valid, + dut.bias_output_ready, ) self.error_threshold = 2 @@ -469,11 +481,11 @@ async def run_memory_bandwidth_test(self, us: int = 500): num_v_weight_beats_sent = self.weight_v_driver.num_beats num_o_weight_beats_sent = self.weight_o_driver.num_beats - input_beats_per_sec = num_input_beats_sent / (nanosec * (10 ** -9)) - num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10 ** -9)) - num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10 ** -9)) - num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10 ** -9)) - num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10 ** -9)) + input_beats_per_sec = num_input_beats_sent / (nanosec * (10**-9)) + num_q_beats_per_sec = num_q_weight_beats_sent / (nanosec * (10**-9)) + num_k_beats_per_sec = num_k_weight_beats_sent / (nanosec * (10**-9)) + num_v_beats_per_sec = num_v_weight_beats_sent / (nanosec * (10**-9)) + num_o_beats_per_sec = num_o_weight_beats_sent / (nanosec * (10**-9)) self.log.info("Test length (ns): %.4f" % nanosec) @@ -591,7 +603,9 @@ def test_fixed_linear_smoke(): ] mase_runner( - module_param_list=cfgs, hierarchical=True, template=True, + module_param_list=cfgs, + hierarchical=True, + template=True, ) @@ -603,7 +617,9 @@ def test_parallelism_sweep(): cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par)) mase_runner( - module_param_list=cfgs, hierarchical=True, template=True, + module_param_list=cfgs, + hierarchical=True, + template=True, ) @@ -615,7 +631,9 @@ def test_small_parallelism(): cfgs.append(get_config(16, 128, 8, 4, embedding_par, seq_par)) mase_runner( - module_param_list=cfgs, hierarchical=True, template=True, + module_param_list=cfgs, + hierarchical=True, + template=True, ) @@ -625,7 +643,9 @@ def test_heads_sweep(): cfgs.append(get_config(256, 256, 16, kv_heads, 16, 1)) mase_runner( - module_param_list=cfgs, hierarchical=True, template=True, + module_param_list=cfgs, + hierarchical=True, + template=True, ) @@ -637,7 +657,9 @@ def test_bitwidth_sweep(): ) mase_runner( - module_param_list=cfgs, hierarchical=True, template=True, + module_param_list=cfgs, + hierarchical=True, + template=True, ) diff --git a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py index 70df56d43..a46032646 100644 --- a/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py +++ b/src/mase_components/transformer_layers/test/fixed_self_attention_head_tb.py @@ -44,7 +44,11 @@ def __init__(self, dut) -> None: ) self.out_monitor = StreamMonitor( - dut.clk, dut.out, dut.out_valid, dut.out_ready, check=False, + dut.clk, + dut.out, + dut.out_valid, + dut.out_ready, + check=False, ) # Model @@ -56,7 +60,8 @@ def __init__(self, dut) -> None: "frac_width": self.get_parameter("IN_DATA_PRECISION_1"), } self.model = BertSelfAttentionHeadInteger( - config=self.config, q_config=self.q_config, + config=self.config, + q_config=self.q_config, ) # Set verbosity of driver and monitor loggers to debug @@ -114,21 +119,27 @@ async def run_test(self): # * Load the query driver self.log.info(f"Processing query inputs: {inputs['query_layer']}") query_inputs = self.preprocess_tensor( - tensor=inputs["query_layer"], config=self.q_config, parallelism=parallelism, + tensor=inputs["query_layer"], + config=self.q_config, + parallelism=parallelism, ) self.query_driver.load_driver(query_inputs) # * Load the key driver self.log.info(f"Processing key inputs: {inputs['key_layer']}") key_inputs = self.preprocess_tensor( - tensor=inputs["key_layer"], config=self.q_config, parallelism=parallelism, + tensor=inputs["key_layer"], + config=self.q_config, + parallelism=parallelism, ) self.key_driver.load_driver(key_inputs) # * Load the value driver self.log.info(f"Processing value inputs: {inputs['value_layer']}") value_inputs = self.preprocess_tensor( - tensor=inputs["value_layer"], config=self.q_config, parallelism=parallelism, + tensor=inputs["value_layer"], + config=self.q_config, + parallelism=parallelism, ) self.value_driver.load_driver(value_inputs) @@ -187,11 +198,18 @@ def test_fixed_self_attention_head_smoke(): # * Generate exponential LUT for softmax generate_memory.generate_sv_lut( - "exp", 16, 3, 16, 3, path=Path(__file__).parents[1] / "rtl", + "exp", + 16, + 3, + 16, + 3, + path=Path(__file__).parents[1] / "rtl", ) mase_runner( trace=True, - module_param_list=[get_fixed_self_attention_head_config(),], + module_param_list=[ + get_fixed_self_attention_head_config(), + ], skip_build=False, ) diff --git a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py index 823c65db4..2e19bed4b 100644 --- a/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_mlp_tb.py @@ -108,7 +108,8 @@ def __init__(self, samples=1): depth_in_num = int(self.in_num / self.tile_in_num) depth_out_features = int(self.out_features / self.tile_out_features) self.outputs = RandomSink( - samples=samples * depth_out_features * depth_in_num, debug=debug, + samples=samples * depth_out_features * depth_in_num, + debug=debug, ) self.ref = self.sw_compute() @@ -470,7 +471,8 @@ def runner(): ) for _ in range(1): runner.test( - hdl_toplevel="fixed_mlp", test_module="fixed_mlp_tb", + hdl_toplevel="fixed_mlp", + test_module="fixed_mlp_tb", ) diff --git a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py index b92e6dc4f..05cf0f2cf 100644 --- a/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_patch_embed_tb.py @@ -50,14 +50,14 @@ def __init__(self, samples=1): self.pe_unroll_kernel_out = 3 self.pe_unroll_in_c = 3 self.pe_unroll_embed_dim = 8 - self.num_patch = int(self.in_y * self.in_x // (self.patch_size ** 2)) + self.num_patch = int(self.in_y * self.in_x // (self.patch_size**2)) # self.num_classes = 10 # self.head_unroll_out_x = 5 self.samples = samples self.pe_iter_weight = int( - (self.patch_size ** 2) + (self.patch_size**2) * self.in_c * self.embed_dim / self.pe_unroll_kernel_out @@ -247,7 +247,10 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) + w_tensor = w_tensor.reshape( + -1, + unroll_out_channels * unroll_kernel_out, + ) w_in = w_tensor.type(torch.int).flip(0).tolist() # bias_pack bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels) diff --git a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py index 532e3fb07..06b921e81 100644 --- a/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py +++ b/src/mase_components/vision_models/vit/test/fixed_pvt_tb.py @@ -49,7 +49,7 @@ def __init__(self, samples=1): self.in_x = 224 self.embed_dim = 384 self.patch_size = 16 - self.num_patch = self.in_y * self.in_x // (self.patch_size ** 2) + self.num_patch = self.in_y * self.in_x // (self.patch_size**2) self.num_heads = 6 self.mlp_ratio = 2 @@ -64,7 +64,7 @@ def __init__(self, samples=1): self.head_unroll_out_x = 1 self.pe_iter_weight = int( - (self.patch_size ** 2) + (self.patch_size**2) * self.in_c * self.embed_dim / self.pe_unroll_kernel_out @@ -1148,7 +1148,10 @@ def conv_pack( unroll_out_channels, ).permute(0, 3, 1, 4, 2) - w_tensor = w_tensor.reshape(-1, unroll_out_channels * unroll_kernel_out,) + w_tensor = w_tensor.reshape( + -1, + unroll_out_channels * unroll_kernel_out, + ) w_in = w_tensor.type(torch.int).flip(0).tolist() # bias_pack bias_tensor = bias.repeat(samples, 1).reshape(-1, unroll_out_channels) @@ -1511,7 +1514,8 @@ def runner(): build_args=extra_args, ) runner.test( - hdl_toplevel="fixed_pvt", test_module="fixed_pvt_tb", + hdl_toplevel="fixed_pvt", + test_module="fixed_pvt_tb", ) diff --git a/src/mase_components/vision_models/vit/test/hash_exp_tb.py b/src/mase_components/vision_models/vit/test/hash_exp_tb.py index 719ee89dd..e8b107503 100644 --- a/src/mase_components/vision_models/vit/test/hash_exp_tb.py +++ b/src/mase_components/vision_models/vit/test/hash_exp_tb.py @@ -147,7 +147,9 @@ def runner(): print(extra_args) runner = get_runner(sim) runner.build( - verilog_sources=verilog_sources, hdl_toplevel="hash_exp", build_args=extra_args, + verilog_sources=verilog_sources, + hdl_toplevel="hash_exp", + build_args=extra_args, ) runner.test(hdl_toplevel="hash_exp", test_module="hash_exp_tb") diff --git a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py index 5c0522fb9..98ceea63f 100644 --- a/src/mase_components/vision_models/vit/test/hash_softmax_tb.py +++ b/src/mase_components/vision_models/vit/test/hash_softmax_tb.py @@ -45,7 +45,11 @@ def __init__(self, samples=1): }, } self.d_config = { - "softmax": {"in_size": 1, "out_size": 1, "in_depth": 4,}, + "softmax": { + "in_size": 1, + "out_size": 1, + "in_depth": 4, + }, } in_size = self.d_config["softmax"]["in_size"] out_size = self.d_config["softmax"]["out_size"] diff --git a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py index 8dac8afee..62e5b840d 100644 --- a/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py +++ b/src/mase_components/vision_models/vit/test/helpers/ha_softmax.py @@ -5,8 +5,8 @@ def quantize_to_int(x: Tensor, width: int, frac_width: int): - x = _integer_quantize(x, width, frac_width) * (2 ** frac_width) - x = x.int() & (2 ** width - 1) + x = _integer_quantize(x, width, frac_width) * (2**frac_width) + x = x.int() & (2**width - 1) return x @@ -25,7 +25,7 @@ def twos_complement_to_float(binary_string: str, width: int, frac_width: int): integer_magnitude = -(2 ** (width - 1)) + integer_magnitude # Calculate scaling factor - scaling_factor = 2 ** frac_width + scaling_factor = 2**frac_width # Calculate floating-point value float_value = integer_magnitude / scaling_factor @@ -79,7 +79,8 @@ def generate_table_div_software(width, out_width, out_frac_width): class QHashSoftmax(torch.nn.Module): def __init__( - self, config, + self, + config, ): super(QHashSoftmax, self).__init__() self.in_width = config["data_in_width"] @@ -108,7 +109,7 @@ def forward(self, x, scale): # quantize to div_width one_over_div = _integer_quantize(exp_sum // exp, self.div_width + 1, 0) one_over_div = torch.where( - exp == 0, torch.tensor(2 ** self.div_width - 1), one_over_div + exp == 0, torch.tensor(2**self.div_width - 1), one_over_div ) one_over_div = torch.tensor(one_over_div, dtype=int) diff --git a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py index 4863f2ca9..d7f4e8464 100644 --- a/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py +++ b/src/mase_components/vision_models/vit/test/helpers/pvt_quant.py @@ -50,7 +50,7 @@ def __init__( self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = get_quantized_cls("linear", config["q_proj"])( dim, dim, bias=qkv_bias, config=config["q_proj"] diff --git a/test/nn/snn/test_ann2snn.py b/test/nn/snn/test_ann2snn.py index 3a5d2b7e2..1b8bbbcfa 100644 --- a/test/nn/snn/test_ann2snn.py +++ b/test/nn/snn/test_ann2snn.py @@ -191,7 +191,13 @@ def val(net, device, data_loader, T=None): "by": "type", "default": {"config": {"name": None}}, "fuse": True, - "relu": {"config": {"name": "IFNode", "mode": "99.9%", "momentum": 0.1,}}, + "relu": { + "config": { + "name": "IFNode", + "mode": "99.9%", + "momentum": 0.1, + } + }, "train_data_loader": input_generator, "device": "cpu", # "device": "cuda", } diff --git a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py index a7bce45cd..9740bdd38 100644 --- a/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py +++ b/test/passes/graph/analysis/add_metadata/test_add_common_metadata.py @@ -27,10 +27,22 @@ def add_common_metadata(model_cls_name: str) -> MaseGraph: # mg.fx_graph.print_tabular() input_ids = torch.randint( - 0, config.vocab_size, (1, 128, config.hidden_size,), device="meta", + 0, + config.vocab_size, + ( + 1, + 128, + config.hidden_size, + ), + device="meta", ) mg, _ = passes.add_common_metadata_analysis_pass( - mg, pass_args={"dummy_in": {"input_ids": input_ids,},}, + mg, + pass_args={ + "dummy_in": { + "input_ids": input_ids, + }, + }, ) return mg diff --git a/test/passes/graph/analysis/pruning/test_hook_inspect.py b/test/passes/graph/analysis/pruning/test_hook_inspect.py index b34b579d3..dfa7fa0fd 100644 --- a/test/passes/graph/analysis/pruning/test_hook_inspect.py +++ b/test/passes/graph/analysis/pruning/test_hook_inspect.py @@ -111,9 +111,15 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": ["conv2d",], - "target_activation_nodes": ["conv2d",], - "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, + "target_weight_nodes": [ + "conv2d", + ], + "target_activation_nodes": [ + "conv2d", + ], + "weight_statistics": { + "variance_precise": {"device": "cpu", "dims": "all"}, + }, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py index 084170c70..2b9828ace 100644 --- a/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py +++ b/test/passes/graph/analysis/statistic_profiler/test_statistic_profiler.py @@ -52,7 +52,9 @@ def test_statistic_profiler(): dataset_info = get_dataset_info("cifar10") model = get_model( - checkpoint="resnet18", pretrained=False, dataset_info=dataset_info, + checkpoint="resnet18", + pretrained=False, + dataset_info=dataset_info, ) dummy_in = {"x": next(iter(datamodule.train_dataloader()))[0]} @@ -66,9 +68,15 @@ def test_statistic_profiler(): pass_arg = { "by": "type", - "target_weight_nodes": ["conv2d",], - "target_activation_nodes": ["relu",], - "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, + "target_weight_nodes": [ + "conv2d", + ], + "target_activation_nodes": [ + "relu", + ], + "weight_statistics": { + "variance_precise": {"device": "cpu", "dims": "all"}, + }, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/prune/test_prune.py b/test/passes/graph/transforms/prune/test_prune.py index 8c6de1d2a..8bddad906 100644 --- a/test/passes/graph/transforms/prune/test_prune.py +++ b/test/passes/graph/transforms/prune/test_prune.py @@ -106,9 +106,15 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": ["conv2d",], - "target_activation_nodes": ["conv2d",], - "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, + "target_weight_nodes": [ + "conv2d", + ], + "target_activation_nodes": [ + "conv2d", + ], + "weight_statistics": { + "variance_precise": {"device": "cpu", "dims": "all"}, + }, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/prune/test_prune_detach_hook.py b/test/passes/graph/transforms/prune/test_prune_detach_hook.py index f7afe9a8d..bf15caa86 100644 --- a/test/passes/graph/transforms/prune/test_prune_detach_hook.py +++ b/test/passes/graph/transforms/prune/test_prune_detach_hook.py @@ -105,9 +105,15 @@ def run_with_config(config_file): profile_pass_arg = { "by": "type", - "target_weight_nodes": ["conv2d",], - "target_activation_nodes": ["conv2d",], - "weight_statistics": {"variance_precise": {"device": "cpu", "dims": "all"},}, + "target_weight_nodes": [ + "conv2d", + ], + "target_activation_nodes": [ + "conv2d", + ], + "weight_statistics": { + "variance_precise": {"device": "cpu", "dims": "all"}, + }, "activation_statistics": { "variance_precise": {"device": "cpu", "dims": "all"}, }, diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py index 5ca6f6b27..89327f8f0 100644 --- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py +++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_conv2d.py @@ -181,7 +181,8 @@ def test_quantize_lutnet_conv2d(): first_table = connection[0] assert any(initialized_weight[0, :] == first_table) and any( initialized_weight[ - input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), :, + input_c * k_w * k_h * output_c * (lutnet_config["data_in_levels"] - 1), + :, ] == first_table ) diff --git a/test/passes/graph/transforms/training/test_training_base_pass.py b/test/passes/graph/transforms/training/test_training_base_pass.py index e46777de0..06d8371dd 100644 --- a/test/passes/graph/transforms/training/test_training_base_pass.py +++ b/test/passes/graph/transforms/training/test_training_base_pass.py @@ -170,7 +170,11 @@ def test_training_base_backward_only(): "default": {"config": {"name": None}}, "linear": { "config": { - "forward": {"bypass": True, "pass": "quantize", "name": "integer",}, + "forward": { + "bypass": True, + "pass": "quantize", + "name": "integer", + }, "backward": { "pass": "quantize", "name": "integer", diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py index 68d405e5a..d78ac8525 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_gelu.py @@ -76,7 +76,11 @@ def test_emit_activation_gelu(): ) config_file = os.path.join( - os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", + os.path.abspath(""), + "configs", + "tests", + "quantize", + "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py index 8a59b2e44..ea1a11928 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_selu.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_selu.py @@ -77,7 +77,11 @@ def test_emit_activation_selu(): ) config_file = os.path.join( - os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", + os.path.abspath(""), + "configs", + "tests", + "quantize", + "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py index 9e60ba7a3..8c7a1e107 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_softplus.py @@ -76,7 +76,11 @@ def test_emit_activation_softplus(): ) config_file = os.path.join( - os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", + os.path.abspath(""), + "configs", + "tests", + "quantize", + "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py index 3de23e480..f50f5fce5 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_softsign.py @@ -76,7 +76,11 @@ def test_emit_activation_softsign(): ) config_file = os.path.join( - os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", + os.path.abspath(""), + "configs", + "tests", + "quantize", + "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py index 116d53b36..205211ea6 100644 --- a/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py +++ b/test/passes/graph/transforms/verilog/test_emit_activation_tanh.py @@ -75,7 +75,11 @@ def test_emit_activation_tanh(): mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in}) config_file = os.path.join( - os.path.abspath(""), "configs", "tests", "quantize", "fixed.toml", + os.path.abspath(""), + "configs", + "tests", + "quantize", + "fixed.toml", ) with open(config_file, "r") as f: quan_args = toml.load(f)["passes"]["quantize"] diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py index 3e0f16599..b50fb3776 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py @@ -174,7 +174,10 @@ def emit_verilog_bert( mg, _ = bert_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, + mg, + pass_args={ + "max_parallelism": [max_parallelism] * 4, + }, ) # * Save the metadata to a file for debugging @@ -190,7 +193,11 @@ def emit_verilog_bert( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, + mg, + pass_args={ + "wait_time": wait_count, + "wait_unit": wait_unit, + }, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py index 5447a3782..f268184ce 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py @@ -86,7 +86,10 @@ def emit_verilog_llama( mg, _ = llama_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, + mg, + pass_args={ + "max_parallelism": [max_parallelism] * 4, + }, ) # * Save the metadata to a file for debugging @@ -102,7 +105,11 @@ def emit_verilog_llama( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, + mg, + pass_args={ + "wait_time": wait_count, + "wait_unit": wait_unit, + }, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py index 51ae9dc7d..64625495e 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py @@ -87,7 +87,10 @@ def emit_verilog_mistral( mg, _ = mistral_update_metadata(mg, q_config) mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, pass_args={"max_parallelism": [max_parallelism] * 4,}, + mg, + pass_args={ + "max_parallelism": [max_parallelism] * 4, + }, ) # * Save the metadata to a file for debugging @@ -103,7 +106,11 @@ def emit_verilog_mistral( mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, pass_args={"wait_time": wait_count, "wait_unit": wait_unit,}, + mg, + pass_args={ + "wait_time": wait_count, + "wait_unit": wait_unit, + }, ) mg, _ = passes.emit_vivado_project_transform_pass(mg) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py index acc659331..324a2e29d 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py @@ -116,7 +116,7 @@ def add_norm_metadata_gen_lut_analysis_pass(mg, config={}): mem_dir = Path(__file__).parent / "build" / "norm" / "mem" makedirs(mem_dir, exist_ok=True) - lut = make_lut(2 ** LUT_POW, ISQRT_WIDTH) + lut = make_lut(2**LUT_POW, ISQRT_WIDTH) mem_path = mem_dir / f"norm_isqrt_lut.mem" write_memb(mem_path, lut, ISQRT_WIDTH) mem_id = 0 @@ -224,10 +224,23 @@ def test_emit_verilog_norm(): shape = [10, 4, 8, 8] normalizations = [ - nn.BatchNorm2d(num_features=shape[1], affine=False,), - nn.LayerNorm(normalized_shape=shape[1:], elementwise_affine=False,), - nn.GroupNorm(num_groups=2, num_channels=shape[1], affine=False,), - nn.InstanceNorm2d(num_features=shape[1], affine=False,), + nn.BatchNorm2d( + num_features=shape[1], + affine=False, + ), + nn.LayerNorm( + normalized_shape=shape[1:], + elementwise_affine=False, + ), + nn.GroupNorm( + num_groups=2, + num_channels=shape[1], + affine=False, + ), + nn.InstanceNorm2d( + num_features=shape[1], + affine=False, + ), # rms.RMSNorm( # normalized_shape=shape[1:], # ), diff --git a/test/passes/onnx/analysis/test_export_fx_graph.py b/test/passes/onnx/analysis/test_export_fx_graph.py index fc081f1ee..395991cda 100644 --- a/test/passes/onnx/analysis/test_export_fx_graph.py +++ b/test/passes/onnx/analysis/test_export_fx_graph.py @@ -78,13 +78,16 @@ def test_export_fx_graph_bert(): @pytest.mark.skip def test_export_fx_graph_mistral(): - export_fx_graph_model("mistral-community/Mistral-7B-v0.2",) + export_fx_graph_model( + "mistral-community/Mistral-7B-v0.2", + ) @pytest.mark.skip def test_export_fx_graph_whisper(): export_fx_graph_model( - "openai/whisper-tiny", skip_export=True, + "openai/whisper-tiny", + skip_export=True, ) diff --git a/test/tools/test_onnx_operators.py b/test/tools/test_onnx_operators.py index 7393cb41f..c15a44d89 100644 --- a/test/tools/test_onnx_operators.py +++ b/test/tools/test_onnx_operators.py @@ -14,20 +14,58 @@ def excepthook(exc_type, exc_value, exc_traceback): def test_gather(): - data1 = torch.Tensor([[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9],]) + data1 = torch.Tensor( + [ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ] + ) - data2 = torch.Tensor([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7],]) + data2 = torch.Tensor( + [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + ) - indices1 = torch.Tensor([[0, 2],]).to(torch.int64) + indices1 = torch.Tensor( + [ + [0, 2], + ] + ).to(torch.int64) - indices2 = torch.Tensor([[0, 1], [1, 2],]).to(torch.int64) + indices2 = torch.Tensor( + [ + [0, 1], + [1, 2], + ] + ).to(torch.int64) obs_out1 = onnx_gather(data1, 1, indices1) obs_out2 = onnx_gather(data2, 0, indices2) - exp_out1 = torch.Tensor([[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]],]) + exp_out1 = torch.Tensor( + [ + [[1.0, 1.9]], + [[2.3, 3.9]], + [[4.5, 5.9]], + ] + ) - exp_out2 = torch.Tensor([[[1.0, 1.2], [2.3, 3.4],], [[2.3, 3.4], [4.5, 5.7],],]) + exp_out2 = torch.Tensor( + [ + [ + [1.0, 1.2], + [2.3, 3.4], + ], + [ + [2.3, 3.4], + [4.5, 5.7], + ], + ] + ) print(obs_out2) print(exp_out2) @@ -37,7 +75,12 @@ def test_gather(): def test_slice(): - data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8],]) + data = torch.Tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ) test1 = onnx_slice( data, From d9e15b3aeb70b1ca1e2ffcd799fb09789b50fcf2 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 20 Feb 2025 23:39:48 +0000 Subject: [PATCH 14/38] remove optical test to prevent doc generation error --- .../transforms/optical/test_optical_module.py | 70 ------------------- 1 file changed, 70 deletions(-) delete mode 100644 test/passes/module/transforms/optical/test_optical_module.py diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py deleted file mode 100644 index e65546822..000000000 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# This example converts a simple MLP model to an ONN model -import sys - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from pathlib import Path - -sys.path.append(Path(__file__).resolve().parents[5].as_posix()) - - -from chop.passes.module.transforms.optical import optical_module_transform_pass - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -def test_optical_module_transform_pass(): - model = Net() - # Sanity check and report - pass_args = { - "by": "name", - "fc1": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - "conv1": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - } - optical_module_transform_pass(model, pass_args) - - -test_optical_module_transform_pass() From ea77e742714472ca427677adc6298c4c3f640cd6 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 20 Feb 2025 23:54:01 +0000 Subject: [PATCH 15/38] reformat --- src/chop/passes/module/transforms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 05df7efe2..54f545bfd 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,4 +1,4 @@ from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass from .autosharding import resharding_transform_pass -from .snn import ann2snn_module_transform_pass \ No newline at end of file +from .snn import ann2snn_module_transform_pass From 6be34c906ebb61b8e0af6fe299c1be9de71b07a9 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Fri, 21 Feb 2025 00:25:38 +0000 Subject: [PATCH 16/38] remove unorgainzed test --- test/self/test_optical_module.py | 252 ------------------------------- test/self/train_mnist_cnn.py | 247 ------------------------------ 2 files changed, 499 deletions(-) delete mode 100644 test/self/test_optical_module.py delete mode 100644 test/self/train_mnist_cnn.py diff --git a/test/self/test_optical_module.py b/test/self/test_optical_module.py deleted file mode 100644 index e56921881..000000000 --- a/test/self/test_optical_module.py +++ /dev/null @@ -1,252 +0,0 @@ -#!/usr/bin/env python3 -# This example converts a simple MLP model to Verilog -import logging -import os -import sys - -import torch -import torch.nn as nn - -from torch.profiler import profile, record_function, ProfilerActivity -import torchvision.models as models -import torchvision.transforms as transforms -import torch.utils.data as data -from pathlib import Path - -sys.path.append(Path(__file__).resolve().parents[5].as_posix()) - - -# from chop.passes.module.transforms import quantize_module_transform_pass -from chop.passes.module.transforms import optical_module_transform_pass -from chop.passes.module import report_trainable_parameters_analysis_pass - -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -from torch.optim.lr_scheduler import StepLR - -from train_mnist_cnn import test, train, Net, test_memory_detailed - -# -------------------------------------------------- -# Model specifications -# -------------------------------------------------- -# class MLP(torch.nn.Module): -# """ -# Toy quantized FC model for digit recognition on MNIST -# """ - -# def __init__(self) -> None: -# super().__init__() - -# self.fc1 = nn.Linear(28 * 28, 28 * 28) -# self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4) -# self.fc3 = nn.Linear(28 * 28 * 4, 10) - -# def forward(self, x): -# x = torch.flatten(x, start_dim=1, end_dim=-1) -# x = torch.nn.functional.relu(self.fc1(x)) -# # w = torch.randn((4, 28 * 28)) -# # x = torch.nn.functional.relu(nn.functional.linear(x, w)) -# x = torch.nn.functional.relu(self.fc2(x)) -# x = self.fc3(x) -# return x - - -def load_my_model(model_path, device): - # Load the model from the .pt file - loaded_model = torch.load(model_path, map_location=device) - # Set it to evaluation mode (important if it contains layers like BatchNorm or Dropout) - loaded_model.eval() - return loaded_model - - -def perform_optical_module_transform_pass(model, save_path="mase_output/onn_cnn.pt"): - pass_args = { - "by": "type", - "linear": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - "device": device, - } - }, - "conv2d": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - "device": device, - } - }, - } - onn_model, _ = optical_module_transform_pass(model, pass_args) - torch.save(onn_model.state_dict(), save_path) - return onn_model - - -def test_optical_module_transform_pass(): - model_path = "mase_output/sample_mnist_cnn.pt" - mnist_cnn = load_my_model(model_path) - # Sanity check and report - pass_args = { - "by": "name", - "fc1": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - "device": device, - } - }, - } - onn_cnn, _ = optical_module_transform_pass(mnist_cnn, pass_args) - torch.save(onn_cnn, "mase_output/onn_cnn.pt") - - -if __name__ == "__main__": - finetune = True - if True: - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)", - ) - parser.add_argument( - "--test-batch-size", - type=int, - default=100, - metavar="N", - help="input batch size for testing (default: 1000)", - ) - parser.add_argument( - "--epochs", - type=int, - default=1, - metavar="N", - help="number of epochs to train (default: 14)", - ) - parser.add_argument( - "--lr", - type=float, - default=1.0, - metavar="LR", - help="learning rate (default: 1.0)", - ) - parser.add_argument( - "--gamma", - type=float, - default=0.7, - metavar="M", - help="Learning rate step gamma (default: 0.7)", - ) - parser.add_argument( - "--no-cuda", - action="store_true", - default=False, - help="disables CUDA training", - ) - parser.add_argument( - "--no-mps", - action="store_true", - default=False, - help="disables macOS GPU training", - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=False, - help="quickly check a single pass", - ) - parser.add_argument( - "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" - ) - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument( - "--save-model", - action="store_true", - default=True, - help="For Saving the current Model", - ) - parser.add_argument( - "--gpu-id", type=int, default=1, help="Which GPU device to use [default: 0]" - ) - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - use_mps = not args.no_mps and torch.backends.mps.is_available() - - torch.manual_seed(args.seed) - - if not args.no_cuda and torch.cuda.is_available(): - device = torch.device(f"cuda:{args.gpu_id}") - elif use_mps: - device = torch.device("mps") - else: - device = torch.device("cpu") - - train_kwargs = {"batch_size": args.batch_size} - test_kwargs = {"batch_size": args.test_batch_size} - if use_cuda: - cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ) - dataset1 = datasets.MNIST( - "../data", train=True, download=True, transform=transform - ) - dataset2 = datasets.MNIST("../data", train=False, transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - - cnn = load_my_model("mase_output/sample_mnist_cnn.pt", device) - print("-------------- Testing the original cnn model -------------------") - test(cnn, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(cnn) - - # onn = load_my_model("mase_output/onn_cnn.pt", device) - onn_model = perform_optical_module_transform_pass(cnn) - onn_model.to(device) - - print("-------------- Testing the transformed onn model -------------------") - test(onn_model, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(onn_model) - - ######### Training the onn model - if finetune: - optimizer = optim.Adadelta(onn_model.parameters(), lr=args.lr) - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - - for epoch in range(1, args.epochs + 1): - train(args, onn_model, device, train_loader, optimizer, epoch) - test(onn_model, device, test_loader) - scheduler.step() - - torch.save(onn_model.state_dict(), "mase_output/trained_onn.pt") - - print("-------------- Testing the trained onn model -------------------") - test(onn_model, device, test_loader) - _, _ = report_trainable_parameters_analysis_pass(onn_model) - - # test_optical_module_transform_pass() diff --git a/test/self/train_mnist_cnn.py b/test/self/train_mnist_cnn.py deleted file mode 100644 index a955087da..000000000 --- a/test/self/train_mnist_cnn.py +++ /dev/null @@ -1,247 +0,0 @@ -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -from torch.optim.lr_scheduler import StepLR - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - if args.dry_run: - break - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.nll_loss( - output, target, reduction="sum" - ).item() # sum up batch loss - pred = output.argmax( - dim=1, keepdim=True - ) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - -# Custom Function -from torch.profiler import profile, record_function, ProfilerActivity - - -def test_memory_detailed(model, device, test_loader): - """ - Use PyTorch Profiler to record detailed memory usage on the first batch, - so you can see exactly which ops consume how much memory. - """ - - # Put the model in eval mode - model.eval() - - # Get just 1 batch (or a few) for profiling - data_iter = iter(test_loader) - try: - data, target = next(data_iter) - except StopIteration: - print("test_loader is empty. Cannot profile.") - return - - # Move to device - data, target = data.to(device), target.to(device) - - # Now start the profiler - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, # to see tensor shapes - profile_memory=True, # track memory usage - ) as prof: - with record_function("test_first_batch"): - # Forward pass on just this batch - output = model(data) - # Suppose you also compute a loss for some reason - loss = F.nll_loss(output, target, reduction="sum") - # If purely inference, you might not do backward - # but let's illustrate: - loss.backward() # If you want to see backward pass memory usage - - # Print the summarized table - # Sort by self_cuda_memory_usage to see which ops used the most GPU memory - print( - prof.key_averages(group_by_input_shape=True).table( - sort_by="self_cuda_memory_usage", - row_limit=200, # show as many rows as you need - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)", - ) - parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)", - ) - parser.add_argument( - "--epochs", - type=int, - default=29, - metavar="N", - help="number of epochs to train (default: 14)", - ) - parser.add_argument( - "--lr", - type=float, - default=1.0, - metavar="LR", - help="learning rate (default: 1.0)", - ) - parser.add_argument( - "--gamma", - type=float, - default=0.7, - metavar="M", - help="Learning rate step gamma (default: 0.7)", - ) - parser.add_argument( - "--no-cuda", action="store_true", default=False, help="disables CUDA training" - ) - parser.add_argument( - "--no-mps", - action="store_true", - default=False, - help="disables macOS GPU training", - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=False, - help="quickly check a single pass", - ) - parser.add_argument( - "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" - ) - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument( - "--save-model", - action="store_true", - default=True, - help="For Saving the current Model", - ) - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - use_mps = not args.no_mps and torch.backends.mps.is_available() - - torch.manual_seed(args.seed) - - if use_cuda: - device = torch.device("cuda") - elif use_mps: - device = torch.device("mps") - else: - device = torch.device("cpu") - - train_kwargs = {"batch_size": args.batch_size} - test_kwargs = {"batch_size": args.test_batch_size} - if use_cuda: - cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ) - dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) - dataset2 = datasets.MNIST("../data", train=False, transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - - model = Net().to(device) - optimizer = optim.Adadelta(model.parameters(), lr=args.lr) - - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - test(model, device, test_loader) - scheduler.step() - - if args.save_model: - torch.save(model, "mase_output/sample_mnist_cnn.pt") - - -if __name__ == "__main__": - main() From 87a222b7d0d90fb082dd96548dd4bbd75644a7e3 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Fri, 21 Feb 2025 00:29:46 +0000 Subject: [PATCH 17/38] remove self test file --- .gitignore | 4 +- src/chop/nn/optical/utils/quantize.py | 54 +++++++++++++-------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index acd171321..e7118d94c 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,6 @@ prof/ # HuggingFace trainer cache mase-trainer/ -test-trainer/ \ No newline at end of file +test-trainer/ + +test/self \ No newline at end of file diff --git a/src/chop/nn/optical/utils/quantize.py b/src/chop/nn/optical/utils/quantize.py index 84372c8c7..0bb53f90c 100644 --- a/src/chop/nn/optical/utils/quantize.py +++ b/src/chop/nn/optical/utils/quantize.py @@ -1,10 +1,10 @@ -""" -Description: -Author: Jiaqi Gu (jqgu@utexas.edu) -Date: 2021-06-06 03:15:00 -LastEditors: Jiaqi Gu (jqgu@utexas.edu) -LastEditTime: 2021-06-06 03:15:00 -""" +# """ +# Description: +# Author: Jiaqi Gu (jqgu@utexas.edu) +# Date: 2021-06-06 03:15:00 +# LastEditors: Jiaqi Gu (jqgu@utexas.edu) +# LastEditTime: 2021-06-06 03:15:00 +# """ import numpy as np import torch @@ -48,12 +48,12 @@ def backward(ctx, grad_output): ############ add observer and new quant based on range and zeropoint for activation def uniform_quantize_new(k, gradient_clip=False): - """ - Support uniform quantization with auto-adjusted input data range - args: - k: bitwidth - scale, zeropoint: obtained from observer - """ + # """ + # Support uniform quantization with auto-adjusted input data range + # args: + # k: bitwidth + # scale, zeropoint: obtained from observer + # """ class qfn(torch.autograd.Function): @staticmethod @@ -90,12 +90,12 @@ class input_quantize_fn(torch.nn.Module): def __init__( self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0 ): - """Input quantizer with Quant_Noise supported - Args: - in_bit (int): Input quantization bitwidth. - device (Device, optional): torch Device. Defaults to torch.device("cuda:0"). - quant_ratio (float, optional): Quantization ratio. Defaults to 1.0. - """ + # """Input quantizer with Quant_Noise supported + # Args: + # in_bit (int): Input quantization bitwidth. + # device (Device, optional): torch Device. Defaults to torch.device("cuda:0"). + # quant_ratio (float, optional): Quantization ratio. Defaults to 1.0. + # """ super(input_quantize_fn, self).__init__() assert 1 <= in_bit <= 32 self.in_bit = in_bit @@ -235,14 +235,14 @@ def forward(self, x): class weight_quantize_fn(torch.nn.Module): def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0): - """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization. - - Args: - w_bit (int): quantization bitwidth - mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv". - alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa". - quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0. - """ + # """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization. + + # Args: + # w_bit (int): quantization bitwidth + # mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv". + # alg (str, optional): Quantization algorithms. [dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa". + # quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0. + # """ super(weight_quantize_fn, self).__init__() assert 1 <= w_bit <= 32, logging.error( f"Only support 1 - 32 bit quantization, but got {w_bit}" From 67b91b505a75ee70a9eced9f7eda95df7c5fbff4 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Mon, 10 Mar 2025 19:35:28 +0000 Subject: [PATCH 18/38] add back test_optical_module --- .../passes/module/module_modify_helper.py | 2 +- .../transforms/optical/test_optical_module.py | 70 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 test/passes/module/transforms/optical/test_optical_module.py diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index 2e6a577e0..73806b2ac 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -131,7 +131,7 @@ def instantiate_conv2d(module, postfix, module_map, additional_module_args): has_bias = not (module.bias is None) # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. # Need to handle this better - if "config" in inspect.signature(conv2d.__init__).parameters: + if "config" in inspect.signature(conv2d_cls.__init__).parameters: conv2d = conv2d_cls( in_channels=module.in_channels, out_channels=module.out_channels, diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py new file mode 100644 index 000000000..e65546822 --- /dev/null +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# This example converts a simple MLP model to an ONN model +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pathlib import Path + +sys.path.append(Path(__file__).resolve().parents[5].as_posix()) + + +from chop.passes.module.transforms.optical import optical_module_transform_pass + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def test_optical_module_transform_pass(): + model = Net() + # Sanity check and report + pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + "conv1": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + optical_module_transform_pass(model, pass_args) + + +test_optical_module_transform_pass() From ce0f533953e61f60f1caafa58d270925b087a368 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 11 Mar 2025 19:06:57 +0000 Subject: [PATCH 19/38] add bert transform file --- ...tfevents.1741637561.ee-tarrasque.3914313.0 | Bin 0 -> 5028 bytes ...tfevents.1741637649.ee-tarrasque.3918878.0 | Bin 0 -> 6507 bytes ...tfevents.1741641748.ee-tarrasque.4023347.0 | Bin 0 -> 5038 bytes .../transforms/optical/bert-finetune.py | 71 ++ .../module/transforms/optical/run_glue.py | 637 ++++++++++++++++++ wandb/latest-run | 1 + .../files/config.yaml | 497 ++++++++++++++ .../files/requirements.txt | 295 ++++++++ .../files/wandb-metadata.json | 60 ++ .../files/wandb-summary.json | 1 + .../run-60c8phhh.wandb | Bin 0 -> 164998 bytes .../files/requirements.txt | 295 ++++++++ .../files/wandb-metadata.json | 60 ++ .../run-x0wxhan7.wandb | Bin 0 -> 98304 bytes 14 files changed, 1917 insertions(+) create mode 100644 model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 create mode 100644 model_sst2/runs/Mar10_20-14-07_ee-tarrasque/events.out.tfevents.1741637649.ee-tarrasque.3918878.0 create mode 100644 model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 create mode 100644 test/passes/module/transforms/optical/bert-finetune.py create mode 100644 test/passes/module/transforms/optical/run_glue.py create mode 120000 wandb/latest-run create mode 100644 wandb/run-20250310_202153-60c8phhh/files/config.yaml create mode 100644 wandb/run-20250310_202153-60c8phhh/files/requirements.txt create mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json create mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json create mode 100644 wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb create mode 100644 wandb/run-20250310_212229-x0wxhan7/files/requirements.txt create mode 100644 wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json create mode 100644 wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb diff --git a/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 b/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 new file mode 100644 index 0000000000000000000000000000000000000000..21fed95b763c01ecfbd0a4f987eaf9cc96de2140 GIT binary patch literal 5028 zcmaJ_zmFwH5eDlP?w*7M&LCR7(DJ@}yL;<9TNV;0**X~?M#hd#R;y>GXI^_}rpMhq zyZ3CJ2%I`1frN~Rgb<0pfrJ1dNRbFhkPsrU_^P_6XL|P?8|-SktE=nBS6@}nlW&Bd zuYd96*Y|$=$MTpdR_cE9&Xwx@TW23`R2J7WtYZMAV@*GaqITwCpx z-6)b0wQiLND_=eM=-=;tH|Rb3qAAfcH#p?KYZAL$gUR2ty-#T7*rc$kN&Y!Jc=iMu1`7)8o)v|Sp1u7#v z`GCXOqPmhT*1R2)++bCuOP@)Sm~hbGRkJ4p)L_b8xI4;g(~H(PCrsax$b8M9tf*BW6w;SF;q>PUa*x=)g$&>I5|L&Y59#Q~woPhC=C+O`(pXt(%8)F}u^U_I$#Un_ z;9$)utKs7Z_mdhHK@ueSy5&>eSh%IS!S6*W%biME)2UWT+v~m&6iFPsP$%09K76fo zvqOGW88mK!PFcxL6w1p}x@Pkx%cjM%u=wU;8E)v-P@YerI|R3r&jq9tAfj|X+(=6B3!Nq>72e&@rr3jST(e|lq3Ni6wz4%*;I%=oUO>n1a0VT zDD<~V=)N6$cmJuhr(1EhJvlu)Ie#RSI`Pt4={_6h+xh&m41(qC;-tRT)}fGrCtN-k z!L&X3dXb!ErM>Kh?rWkAaLCeZC$9)dbLlPajC%pP7u%qc7rpVQ*j9I%3OV5~9FVmN z{6!f`@!`E2!JF-myX{^u_hi0V<$!$!(6?ZvyhL^=*V#s=ywxQ|wip|1&Uu(to_&%c+ZtIBjaxH_R=PZw z%!L!M(;bwMc#LegHuh3kvLc#1ZzF0zBuciHC4jUv6!q^>7!@ZfQW45k^s-|Hy4eUT z?PP(@>p!cA>1Lm->I@2L?2-5i0cpps&*0z60xgyFW|1NfK)|pfTqV5-X`NMvK#JPf zh2YYO#wyf<71|M%hx$%o(hXQyi2&bAowG*mDq4t%p#Mat;XUFxNGrQJoV-_>zSfjW z8W}J&9G1n#^ERuzup(H;6=5wox7^_tf?7_=$0dQ@`5-8&cL+iUPtg|fyKJ5C22+qE zCz8i#tO{L3BGpJlZm|Qe`!FDI9+rk6w*yZOJRD?cf|)dSEe=k!h8-$(A_fMAlih&J zVb4OMr-DUUT~Z^-y+1rIr2a-ur7%Goir(gSkjTgvu>?;mRcvYEnv%C(uncBMr$IxQ?Pju>BjiwAQ33ui_bn9>N?O&(a!)Ky6L6SlOBm2Nf`o{hdFh70&{I9EL%L`d2t|{dv8Wqdq|`3Qt1#ar!9EhpYr~?gl(jGfZSV2j>ZX;l!gH$Yp^3 z4?nMO@_NGP1PYEQaEi9!6hJ^FGDl^SqQ=K43l-sr*Yh`RuWCQ`3e&$vLCbd7QKw7S z1xm$a-++(M(ea$g+I`3hgYgiAc>~P?1FlVNkb}+ip&OWwDW~Vh%WOU$BZ)@(9Yjp= zsk-sPM!8$rU~r*IcX*sydP|3AGg-*WRxjSWyZYAZ4t* z?^OJ?;s?5r16}(tA8Ws5N%r&rnN*Rj0_!RD;|+ zP$9rJ_n38|tUoEFQzs+p7Nb>!IeHjAhLnB`gJO*3 zLOAVcaQM(W!WdHlr7O`C@Cpe9NE{ z+?K*{e%plf=9h)T@Qkk)Z0};|a{>^iw-1H(!REokO?(-Vt>OxgmST!T@RbUJzmwoV XLRcnuzP&g*TfJD#sK5B|=gp(M_Mw$6779hPJ}rcX-M-y+yGw1?Wrb2LrRAgifsx6)d*{7(+55%J zz5CHpgNg#7wXSF}L@J1ym>?=9g28|we;^tqM(`t|#)guZ7S>=fA!7ZVnfva&v+u3@ z$8L7!oSE}G=XXBlZeL=4{`=+a5A>|rw}07h-~8j{7rdt~xfc#xIVHkm^U!qzl}i&5 zsz}x&F5Ll@1_76|1Jlw=gd7kg#TuS3m|3~$wc35V&o<{CJ>w6L_gwoObMCt5`YO-# zReHH>s`ZJONmzEx{&zH8;Om8YG#bmDN|zDi$THA<5>O_=M+>gK*`5V^u|f`&dh4n#;vfxSo?(k|f3yPDn&^mJASa z(|JZN(v;lwC_>zE(jfJDLYAvU#0u67>SMHe*Zuq$ErK9$!dlZ>NhASF1aUIU9PYG4 zerm4e3ob)1Y%=mB6o*BBZCEQ$3$NKie1&699582cmj|pa5vcv)!~ zFg!n^I1l1%k|{uPQyU$vo85&pulhed$i0s>M&~X{0F`@P%tkyCEAVtgVs|;j6YBC>5M7*rVTq$o#WLLCz zm=v+^Ih-VL26?oQKr{us53nv7nNu6*Hp%rJDYRaup<2I_%fT8Os`U?U?5~Y5A^H<8 zC0Dnm#qC~tZLGsNSbb>7L*|7IDjTRFWwRBZA|e&3 zSGqIXGt?cfnIrXL5+%r3-wQlS0{Xc>%>+(C3G z*49QK66rbQ*?QW$|paSd)AM*OJ!U}G4r{yW2tLlo7njs6p&1fY?zMZq>yAq zW^%C&uYr+B*#>t&q_UyN{{)HQ>Of_s2d(iXcKtr6DNb!Zm&xgjOQ>ypHBdH&{L+EGWgW$8=wwFcaY7Xya1Wh-68y z45ujtGD9*?QBgOUo;%EIqCN{|h&<`p^`oI&TF62||FOV>AK=-TOCA)C)q749Hav8V6$VkjQO zxE%JBBtlAEB-kXi6dQf_!6EN-1S)rfD^QbTQ=f9QhF2uIC0P`gwwK>Pq%ENS*?2d`SHH>d>4LqWonSCpN?Lkh{29e_W&3YHHdmiVH{omr~q zGz**AHnKA^PF0IY-AV~aaXLzE8i+2Xz+$CfE(Tc3gSQSg{*RZwdv5<|7KAbu#y-`K)4Owq!MRHAU`OT&CHevv$(I> zWuu|$f8yY++o2*i`6tK2%M(XeuyCW4(0>C+RuC zB`p+hPuWdO*9LbMP0EV1T`8rUC%k%CWZDWc#5{61V}(BQh22;MU79h%w&Y3x^BO2L(xdm zP~aWlA)&b+x2V(Q*O`)v8NZ2ogpQ7fjn(c=R5Tg8QCM%FSzy$ac^lNu*5$)%!hFm- zp?<~`=4&;QX!4{5k*I!)Sn$?Hx=Y@~;6j5>s*Qc6nTr$Dx*1i7(RcEYO6SUR;;r`lyd5$_5LY3F>;KEFz9NX^os_i=+rIzQq z7!xpXV!7HBewtNLv%bi+nO64cVq}wfUn4yekNfzZA!#gcQZAvVn|BEwKFDCS5vt3X z;)dFY*)-D;JPza6NsJTIGJ^)Y#IezIbG9CjXD0ZhEul#w&99)s=HA?=|v79sQkT z)oV|@Jbv%%H>|iMQyMwNmh^n)%3SHf$_15kvP;UJbyh7~v3vC0Yfscam?MhC}M3|ny4UYIL(fitmnbpOAWSqd-x;@MAY zIER;Z;5q!*fw3Le&s}s;hF|SEx}@jWuL^~(fYvxaAHVUg+a6h$X+3&y?qtvMUl&@7 zB6{V?t4IC2zxGTu)12wu`c6;fH-%;gq`~>Qt$R-Fx!Z#D+<^yg*N~p?gmmPMJI4FGzEo%}3h9NFJ4TPav+iOG(%GS`Nbb#rLRUy1dSdnETeojIWkGt26KyF){I4a2!vN<-t<-T(*NQdqy6uLsXc=bKwC+EMp$$~Wh`9f<^NQchbHu{5B^BWeV Ukvj{`4oF8Q){Pzb^R>PI1Jnp%`Tzg` literal 0 HcmV?d00001 diff --git a/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 b/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 new file mode 100644 index 0000000000000000000000000000000000000000..0c1232b6c0a3604c85c2345800ba755b7517067a GIT binary patch literal 5038 zcmaJ_&yOTG6^1Aic3U99-jdU(jwsVR-Mf=bq9_q7C0YTS1V|1DAXzuGp@g z=|p>h|9}HG4oKWNaz^4W;J~jd5)xc_w#!wX>CGHwH1_lJ`|-W+J=c@3g`fZZ z>d9~J-h1zzlfQiN_wT*(tD6TQ`@)ZEl-J9Zs6hxL__PMEbI zDb`3OSXlX@x%khA-wt~J`O@G1aQEdu2fg3@)9LDur>pO9)7tw^Tssz3J^tlcaUidy#_D-A z08+F(T@_qbx@8JQM7S&lZmrylc*V5Cj2M)xI7vVUMRb-xHU**&XKOMtK^uBo3jM7T zx^Kqb-hax?=A3OeXWQ-B_7M}}%yDD5{dAmf_ve>o5G?mDPU;(JED9NT!sT-jOxu&M z7s*+#wC6pOeM7VX4q2LQ`JR9@m)_#exEG*%aUN9qMW-DqR>@9MAt(HW1H2J|zbHc~ zK7Q{;@Mintt~xN!J(+JtSYTfO^h%Yyx`eraw`Gmd9Rp~p%R$RNXOv#EYOE=P^Btpw zj$ypD)(tiiNVa8#sIl;H>*HT@$5~uT{(d5zx7QOwc9`yxgSgHX%VzUhw$Kbm7MLE~ zjL;7@0eU;G8U`+`q$NV3IEbF?EBNSikRko70a@L)hi-b#mTMkRp(m%5Mxaqk{UugI z&g*(u+?XWMG8h$kh3rtSvyDzTBP)t*IkwoG^DwRa?B^-6O7og%+~|QRZu4AvE}Vd! z?x2LkV`RgXHkZPX7188*8&LxyQL+uM0Hmd%sDFpTsC}X$6`@>N&pXdRr7$Qq^g$-B z(L-bhW)HEhrR61hu=}(o&YQh6qSGj;u?GZdghq{BpA~}QC7LVUkj0FA0TsiZ(0#dM zA;Gii5K`7?vk+Xm(;9)wFict^^-$?4V6p{23l@NUDRcg4Y)uO>CG;QbB;tU`4&3l= z4y@m+bl*tIDh&@9ACAjn<631k7*?3qu_x3e=axIZnV_l@3~)&xdVdgP)msFigQsYa z_+2T>oW?xF$%*7KTC6}n;Yd0Xky~uR@GguBJ`YPnklVq}5PnR^(gZVU09zcK2o5_` z^h6OJ7?yVfE{8n}iJnRq1$Id-CHMaL;E*~UIhDdt3vluf_$hfCaAiBK-HsYjLi^MX z!7GxmT2O<(LclOj1Tgy@7JF(^2(hFig{NqZMo%pdps2oeyQLRJSc;%eMHB^nv7@tU zloS{XU<=q;%(~th1oDdaby;9LG6c}rph~Br@`yx_@#UV(^&~Zd79BbT;9KH~eAA%B zDysN$Y?_;=7kmNWcc@}Rv)7coV!=0H0xLmCnDS=YnImLTTu}l3cotkLB9ti6^6Ef* zPP1^BZFOFe!)SK|>fK0=%-TV_`$pHzdi2`M(#l{G#^m{GtFFG@q=Y#c4lwF_1vj%fFF>GGjk@wGVFWw z$i7lzjn3NRv4Q~rbv*}o4$&l*??-M6lQqTz-B;k)xx`bZeXl&)VCn2QP4x4~w0y11 zSv#78MOK+ME$`)LUf(jBOPM(zAAbCkCrAjpnU`(|3_aCD5z<9dK`5H+tuqJ7&j~f1 z0(ncDH!;r*^#M&P%d_FYOQqv8*iolT*9A(&jNgKf(9!X*$=ZF$iU;E{2>T5*3yiuZwLu?jPam>{`IvWl|BRW< z_s2-0t+dS$C;-@mhZY4sv~HTd__Z(jfF zh46i$`{nTE-Jkw9eqZ><$?C5stCz#P5*lp&o#FQ<={v*kPW*QU#+}1cgjyrbQlXXz zt5%SksRn)XK!pI?zQ+_aiJg_)in9@Qi_t2=96b!5MJhLj5j6Ic$B=$co2T4fiu#97 z=T?I2aV-(wO6{y*G*z7EgP3PY%sjYJ5e#4@+X`D%)t&(coTHO2%CoLDYe zp~lFfVf{JR##8ar#iQ54`x@oxBv<&JVQ4IGsg}^w!@C3zA5<{f2=Q_}_~86e*bLJV zJz|hv!!y2P_<$~kKFa`MdIeEvAFdxhTE{mM*~<2KzLZlW fg0EB%{)-6?B!p#h56&;P+tstxjQU^j_kZ+%3`T29 literal 0 HcmV?d00001 diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py new file mode 100644 index 000000000..e2ac5438c --- /dev/null +++ b/test/passes/module/transforms/optical/bert-finetune.py @@ -0,0 +1,71 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + +import numpy as np +import evaluate +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + Trainer, + TrainingArguments, + DataCollatorWithPadding, +) +from chop.passes.module.transforms.optical import optical_module_transform_pass + +def bert_onn_transform(model): + pass_args = { + "by": "type", + "linear": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + model, _ = optical_module_transform_pass(model, pass_args) + return model + +def main(): + model_name = "bert-base-uncased" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) + print(model) + + # Placeholder for modifications + model = bert_onn_transform(model) + + dataset = load_dataset("glue", "sst2") + def preprocess(examples): + return tokenizer(examples["sentence"], truncation=True, padding=True) + dataset = dataset.map(preprocess, batched=True) + + data_collator = DataCollatorWithPadding(tokenizer) + metric = evaluate.load("accuracy") + def compute_metrics(eval_pred): + logits, labels = eval_pred + return metric.compute(predictions=np.argmax(logits, axis=1), references=labels) + + training_args = TrainingArguments( + output_dir="model_sst2", + run_name="bert_sst2_experiment", + evaluation_strategy="epoch", + num_train_epochs=3, + logging_steps=50 + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["validation"], + data_collator=data_collator, + compute_metrics=compute_metrics + ) + trainer.train() + +if __name__ == "__main__": + main() diff --git a/test/passes/module/transforms/optical/run_glue.py b/test/passes/module/transforms/optical/run_glue.py new file mode 100644 index 000000000..b5d5aeb77 --- /dev/null +++ b/test/passes/module/transforms/optical/run_glue.py @@ -0,0 +1,637 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +import evaluate +import numpy as np +from datasets import load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.50.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": ( + "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) + elif self.dataset_name is not None: + pass + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + ignore_mismatched_sizes: bool = field( + default=False, + metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_glue", model_args, data_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " + + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + "nyu-mll/glue", + data_args.task_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + elif data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. + data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + raw_datasets = load_dataset( + "csv", + data_files=data_files, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + else: + # Loading a dataset from local json files + raw_datasets = load_dataset( + "json", + data_files=data_files, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = raw_datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.unique + label_list = raw_datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, + ) + + # Preprocessing the raw_datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. + non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if sorted(label_name_to_id.keys()) == sorted(label_list): + label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} + else: + logger.warning( + "Your model seems to have been trained with labels, but they don't match the dataset: " + f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}." + "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if label_to_id is not None: + model.config.label2id = label_to_id + model.config.id2label = {id: label for label, id in config.label2id.items()} + elif data_args.task_name is not None and not is_regression: + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {id: label for label, id in config.label2id.items()} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + with training_args.main_process_first(desc="dataset map pre-processing"): + raw_datasets = raw_datasets.map( + preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + + if training_args.do_eval: + if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + + if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: + if "test" not in raw_datasets and "test_matched" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir) + elif is_regression: + metric = evaluate.load("mse", cache_dir=model_args.cache_dir) + else: + metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + + # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if + # we already did the padding. + if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + processing_class=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + valid_mm_dataset = raw_datasets["validation_mismatched"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples) + valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples)) + eval_datasets.append(valid_mm_dataset) + combined = {} + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_eval_samples = ( + data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + if task == "mnli-mm": + metrics = {k + "_mm": v for k, v in metrics.items()} + if task is not None and "mnli" in task: + combined.update(metrics) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + predict_datasets = [predict_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + predict_datasets.append(raw_datasets["test_mismatched"]) + + for predict_dataset, task in zip(predict_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. + predict_dataset = predict_dataset.remove_columns("label") + predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions + predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + + output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_predict_file, "w") as writer: + logger.info(f"***** Predict results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} + if data_args.task_name is not None: + kwargs["language"] = "en" + kwargs["dataset_tags"] = "glue" + kwargs["dataset_args"] = data_args.task_name + kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/wandb/latest-run b/wandb/latest-run new file mode 120000 index 000000000..d0ad236b2 --- /dev/null +++ b/wandb/latest-run @@ -0,0 +1 @@ +run-20250310_212229-x0wxhan7 \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/files/config.yaml b/wandb/run-20250310_202153-60c8phhh/files/config.yaml new file mode 100644 index 000000000..4b489b501 --- /dev/null +++ b/wandb/run-20250310_202153-60c8phhh/files/config.yaml @@ -0,0 +1,497 @@ +_attn_implementation_autoset: + value: true +_name_or_path: + value: bert-base-uncased +_wandb: + value: + cli_version: 0.19.1 + m: + - "1": train/global_step + "6": + - 3 + "7": [] + - "1": train/learning_rate + "5": 1 + "6": + - 1 + - 3 + "7": [] + - "1": train/epoch + "5": 1 + "6": + - 1 + - 3 + "7": [] + - "1": train/loss + "5": 1 + "6": + - 1 + - 3 + "7": [] + - "1": train/grad_norm + "5": 1 + "6": + - 1 + - 3 + "7": [] + python_version: 3.11.11 + t: + "1": + - 1 + - 2 + - 3 + - 5 + - 11 + - 41 + - 49 + - 51 + - 53 + - 55 + - 71 + - 100 + "2": + - 1 + - 2 + - 3 + - 5 + - 11 + - 41 + - 49 + - 51 + - 53 + - 55 + - 71 + - 100 + "3": + - 7 + - 13 + - 19 + - 23 + - 55 + - 66 + "4": 3.11.11 + "5": 0.19.1 + "6": 4.47.1 + "8": + - 5 + "9": + "1": transformers_trainer + "12": 0.19.1 + "13": linux-x86_64 +accelerator_config: + value: + dispatch_batches: null + even_batches: true + gradient_accumulation_kwargs: null + non_blocking: false + split_batches: false + use_seedable_sampler: true +adafactor: + value: false +adam_beta1: + value: 0.9 +adam_beta2: + value: 0.999 +adam_epsilon: + value: 1e-08 +add_cross_attention: + value: false +architectures: + value: + - BertForMaskedLM +attention_probs_dropout_prob: + value: 0.1 +auto_find_batch_size: + value: false +average_tokens_across_devices: + value: false +bad_words_ids: + value: null +batch_eval_metrics: + value: false +begin_suppress_tokens: + value: null +bf16: + value: false +bf16_full_eval: + value: false +bos_token_id: + value: null +chunk_size_feed_forward: + value: 0 +classifier_dropout: + value: null +cross_attention_hidden_size: + value: null +data_seed: + value: null +dataloader_drop_last: + value: false +dataloader_num_workers: + value: 0 +dataloader_persistent_workers: + value: false +dataloader_pin_memory: + value: true +dataloader_prefetch_factor: + value: null +ddp_backend: + value: null +ddp_broadcast_buffers: + value: null +ddp_bucket_cap_mb: + value: null +ddp_find_unused_parameters: + value: null +ddp_timeout: + value: 1800 +debug: + value: [] +decoder_start_token_id: + value: null +deepspeed: + value: null +disable_tqdm: + value: false +dispatch_batches: + value: null +diversity_penalty: + value: 0 +do_eval: + value: true +do_predict: + value: false +do_sample: + value: false +do_train: + value: false +early_stopping: + value: false +encoder_no_repeat_ngram_size: + value: 0 +eos_token_id: + value: null +eval_accumulation_steps: + value: null +eval_delay: + value: 0 +eval_do_concat_batches: + value: true +eval_on_start: + value: false +eval_steps: + value: null +eval_strategy: + value: epoch +eval_use_gather_object: + value: false +evaluation_strategy: + value: epoch +exponential_decay_length_penalty: + value: null +finetuning_task: + value: null +forced_bos_token_id: + value: null +forced_eos_token_id: + value: null +fp16: + value: false +fp16_backend: + value: auto +fp16_full_eval: + value: false +fp16_opt_level: + value: O1 +fsdp: + value: [] +fsdp_config: + value: + min_num_params: 0 + xla: false + xla_fsdp_grad_ckpt: false + xla_fsdp_v2: false +fsdp_min_num_params: + value: 0 +fsdp_transformer_layer_cls_to_wrap: + value: null +full_determinism: + value: false +gradient_accumulation_steps: + value: 1 +gradient_checkpointing: + value: false +gradient_checkpointing_kwargs: + value: null +greater_is_better: + value: null +group_by_length: + value: false +half_precision_backend: + value: auto +hidden_act: + value: gelu +hidden_dropout_prob: + value: 0.1 +hidden_size: + value: 768 +hub_always_push: + value: false +hub_model_id: + value: null +hub_private_repo: + value: null +hub_strategy: + value: every_save +hub_token: + value: +id2label: + value: + "0": LABEL_0 + "1": LABEL_1 +ignore_data_skip: + value: false +include_for_metrics: + value: [] +include_inputs_for_metrics: + value: false +include_num_input_tokens_seen: + value: false +include_tokens_per_second: + value: false +initializer_range: + value: 0.02 +intermediate_size: + value: 3072 +is_decoder: + value: false +is_encoder_decoder: + value: false +jit_mode_eval: + value: false +label_names: + value: null +label_smoothing_factor: + value: 0 +label2id: + value: + LABEL_0: 0 + LABEL_1: 1 +layer_norm_eps: + value: 1e-12 +learning_rate: + value: 5e-05 +length_column_name: + value: length +length_penalty: + value: 1 +load_best_model_at_end: + value: false +local_rank: + value: 0 +log_level: + value: passive +log_level_replica: + value: warning +log_on_each_node: + value: true +logging_dir: + value: model_sst2/runs/Mar10_20-14-07_ee-tarrasque +logging_first_step: + value: false +logging_nan_inf_filter: + value: true +logging_steps: + value: 50 +logging_strategy: + value: steps +lr_scheduler_type: + value: linear +max_grad_norm: + value: 1 +max_length: + value: 20 +max_position_embeddings: + value: 512 +max_steps: + value: -1 +metric_for_best_model: + value: null +min_length: + value: 0 +model/num_parameters: + value: 109483778 +model_type: + value: bert +mp_parameters: + value: "" +neftune_noise_alpha: + value: null +no_cuda: + value: false +no_repeat_ngram_size: + value: 0 +num_attention_heads: + value: 12 +num_beam_groups: + value: 1 +num_beams: + value: 1 +num_hidden_layers: + value: 12 +num_return_sequences: + value: 1 +num_train_epochs: + value: 3 +optim: + value: adamw_torch +optim_args: + value: null +optim_target_modules: + value: null +output_attentions: + value: false +output_dir: + value: model_sst2 +output_hidden_states: + value: false +output_scores: + value: false +overwrite_output_dir: + value: false +pad_token_id: + value: 0 +past_index: + value: -1 +per_device_eval_batch_size: + value: 8 +per_device_train_batch_size: + value: 8 +per_gpu_eval_batch_size: + value: null +per_gpu_train_batch_size: + value: null +position_embedding_type: + value: absolute +prediction_loss_only: + value: false +prefix: + value: null +problem_type: + value: null +push_to_hub: + value: false +push_to_hub_model_id: + value: null +push_to_hub_organization: + value: null +push_to_hub_token: + value: +ray_scope: + value: last +remove_invalid_values: + value: false +remove_unused_columns: + value: true +repetition_penalty: + value: 1 +report_to: + value: + - tensorboard + - wandb +restore_callback_states_from_checkpoint: + value: false +resume_from_checkpoint: + value: null +return_dict: + value: true +return_dict_in_generate: + value: false +run_name: + value: bert_sst2_experiment +save_on_each_node: + value: false +save_only_model: + value: false +save_safetensors: + value: true +save_steps: + value: 500 +save_strategy: + value: steps +save_total_limit: + value: null +seed: + value: 42 +sep_token_id: + value: null +skip_memory_metrics: + value: true +split_batches: + value: null +suppress_tokens: + value: null +task_specific_params: + value: null +temperature: + value: 1 +tf_legacy_loss: + value: false +tf32: + value: null +tie_encoder_decoder: + value: false +tie_word_embeddings: + value: true +tokenizer_class: + value: null +top_k: + value: 50 +top_p: + value: 1 +torch_compile: + value: false +torch_compile_backend: + value: null +torch_compile_mode: + value: null +torch_dtype: + value: null +torch_empty_cache_steps: + value: null +torchdynamo: + value: null +torchscript: + value: false +tpu_metrics_debug: + value: false +tpu_num_cores: + value: null +transformers_version: + value: 4.47.1 +type_vocab_size: + value: 2 +typical_p: + value: 1 +use_bfloat16: + value: false +use_cache: + value: true +use_cpu: + value: false +use_ipex: + value: false +use_legacy_prediction_loop: + value: false +use_liger_kernel: + value: false +use_mps_device: + value: false +vocab_size: + value: 30522 +warmup_ratio: + value: 0 +warmup_steps: + value: 0 +weight_decay: + value: 0 diff --git a/wandb/run-20250310_202153-60c8phhh/files/requirements.txt b/wandb/run-20250310_202153-60c8phhh/files/requirements.txt new file mode 100644 index 000000000..67cb06e8a --- /dev/null +++ b/wandb/run-20250310_202153-60c8phhh/files/requirements.txt @@ -0,0 +1,295 @@ +pydantic==2.10.4 +urllib3==2.3.0 +scipy==1.15.0 +myst-nb==1.1.2 +pure_eval==0.2.3 +wcwidth==0.2.13 +attr-dot-dict==0.1.0 +emoji==2.14.0 +mkl_random==1.2.8 +keras==3.8.0 +nvidia-cuda-runtime-cu12==12.4.127 +torchvision==0.20.1 +cocotb==1.8.0 +wheel==0.44.0 +imageio==2.36.1 +dill==0.3.8 +pydot==3.0.4 +transformers==4.47.1 +sphinx-book-theme==1.1.3 +myst-parser==4.0.0 +traitlets==5.14.3 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-curand-cu12==10.3.5.147 +kiwisolver==1.4.8 +pygame==2.6.1 +greenlet==3.1.1 +pytest-profiling==1.8.1 +requests==2.32.3 +aiosignal==1.2.0 +aiosignal==1.3.2 +Sphinx==8.1.3 +torch-summary==1.4.5 +Farama-Notifications==0.0.4 +sphinxcontrib-plantuml==0.30 +ptyprocess==0.7.0 +pexpect==4.9.0 +yarl==1.18.0 +yarl==1.18.3 +filelock==3.16.1 +filelock==3.13.1 +datasets==3.2.0 +datasets==3.3.2 +bitstring==4.3.0 +triton==3.1.0 +py4j==0.10.9.8 +pybind11==2.13.6 +pluggy==1.5.0 +regex==2024.11.6 +cvxpy==1.6.0 +sphinx-test-reports==1.1.0 +jsonschema-specifications==2024.10.1 +fastjsonschema==2.21.1 +pytest-xdist==3.6.1 +smmap==5.0.2 +onnx==1.17.0 +tornado==6.4.2 +GitPython==3.1.44 +sphinxcontrib-htmlhelp==2.1.0 +iniconfig==2.0.0 +threadpoolctl==3.5.0 +cycler==0.12.1 +tzdata==2024.2 +tzdata==2023.3 +certifi==2024.12.14 +certifi==2025.1.31 +numpy==1.26.4 +gast==0.6.0 +frozenlist==1.5.0 +opt_einsum==3.4.0 +astunparse==1.6.3 +colorlog==6.9.0 +grpcio==1.69.0 +jupyter_core==5.7.2 +torchmetrics==1.6.1 +gprof2dot==2024.6.6 +nvidia-ml-py==12.560.30 +multidict==6.1.0 +etils==1.11.0 +jupyter_client==8.6.3 +sphinxcontrib-jsmath==1.0.1 +tensorboard-plugin-profile==2.19.0 +clarabel==0.9.0 +idna==3.7 +idna==3.10 +pylance==0.21.0 +ipykernel==6.29.5 +matplotlib-inline==0.1.7 +jedi==0.19.2 +lightning-utilities==0.11.9 +namex==0.0.8 +kornia==0.7.4 +docker-pycreds==0.4.0 +mkl-service==2.4.0 +fonttools==4.55.3 +tensorboard-data-server==0.7.2 +beautifulsoup4==4.12.3 +Werkzeug==3.1.3 +Markdown==3.7 +asttokens==3.0.0 +huggingface-hub==0.27.1 +huggingface_hub==0.29.2 +pytest-sugar==1.0.0 +tensorflow==2.18.0 +pytest==8.3.4 +joblib==1.4.2 +ipython==8.31.0 +mdurl==0.1.2 +optimum==1.23.3 +pytest-metadata==3.1.1 +debugpy==1.8.11 +absl-py==2.1.0 +mkl_fft==1.3.11 +sphinxcontrib-serializinghtml==2.0.0 +MarkupSafe==3.0.2 +sympy==1.13.1 +six==1.16.0 +six==1.17.0 +multiprocess==0.70.15 +multiprocess==0.70.16 +snowballstemmer==2.2.0 +zipp==3.21.0 +ale-py==0.10.1 +scs==3.2.7.post2 +find_libpython==0.4.0 +sphinxcontrib-jquery==4.1 +decorator==5.1.1 +nvidia-nvtx-cu12==12.4.127 +prompt_toolkit==3.0.48 +charset-normalizer==3.4.1 +charset-normalizer==3.3.2 +nvidia-cuda-nvrtc-cu12==12.4.127 +evaluate==0.4.3 +tensorboard==2.18.0 +lightning==2.5.0.post0 +py-cpuinfo==9.0.0 +prettytable==3.12.0 +nbclient==0.10.2 +execnet==2.1.1 +torch-tb-profiler==0.4.3 +kornia_rs==0.1.8 +contourpy==1.3.1 +pydata-sphinx-theme==0.16.1 +pip==24.2 +requests-file==2.1.0 +jsonschema==4.23.0 +sphinx_glpi_theme==0.6 +imagesize==1.4.1 +osqp==0.6.7.post3 +importlib_resources==6.5.2 +termcolor==2.5.0 +importlib_metadata==8.5.0 +cocotb-bus==0.2.1 +future==1.0.0 +pyarrow==18.1.0 +pyarrow==19.0.0 +packaging==24.2 +sentry-sdk==2.19.2 +einops==0.8.0 +nvidia-cuda-cupti-cu12==12.4.127 +bitarray==3.0.0 +aiohttp==3.11.10 +aiohttp==3.11.11 +nvidia-cufft-cu12==11.2.1.3 +scikit-learn==1.6.0 +pyzmq==26.2.0 +Mako==1.3.8 +platformdirs==4.3.6 +nvidia-cusolver-cu12==11.6.1.9 +markdown-it-py==3.0.0 +wrapt==1.17.0 +tensorboardX==2.6.2.2 +protobuf==3.20.2 +propcache==0.2.1 +propcache==0.2.0 +pytz==2024.1 +pytz==2024.2 +wandb==0.19.1 +libclang==18.1.1 +nvidia-cublas-cu12==12.4.5.8 +alembic==1.14.0 +nvidia-nvjitlink-cu12==12.4.127 +click==8.1.8 +gymnasium==1.0.0 +Brotli==1.0.9 +lxml==5.3.0 +tensorflow-io-gcs-filesystem==0.37.1 +matplotlib==3.10.0 +tqdm==4.67.1 +annotated-types==0.7.0 +ghp-import==2.1.0 +pillow==10.4.0 +onnxconverter-common==1.14.0 +stable_baselines3==2.4.0 +imageio-ffmpeg==0.5.1 +onnxruntime==1.20.1 +typing_extensions==4.12.2 +Pygments==2.19.0 +coloredlogs==15.0.1 +sentencepiece==0.2.0 +torch==2.5.1 +timm==1.0.12 +mdit-py-plugins==0.4.2 +PyYAML==6.0.2 +gviz-api==1.10.0 +xxhash==3.5.0 +setuptools==75.1.0 +pytorch-nlp==0.5.0 +babel==2.16.0 +soupsieve==2.6 +ipdb==0.13.13 +python-dateutil==2.9.0.post0 +comm==0.2.2 +flatbuffers==24.12.23 +rpds-py==0.22.3 +psutil==6.1.1 +h5py==3.12.1 +numexpr==2.10.1 +optuna==4.1.0 +accessible-pygments==0.0.5 +tf_keras==2.18.0 +mypy-extensions==1.0.0 +pytest-html==4.1.1 +hyperopt==0.2.7 +tabulate==0.9.0 +fsspec==2024.12.0 +fsspec==2024.9.0 +parso==0.8.4 +sphinxcontrib-qthelp==2.0.0 +qdldl==0.1.7.post5 +nvidia-cusparse-cu12==12.3.1.170 +sphinx-data-viewer==0.1.5 +mase-cuda==0.0.1 +cloudpickle==3.1.0 +coverage==7.6.10 +pandas==2.2.3 +Jinja2==3.1.5 +black==24.10.0 +pathspec==0.12.1 +sphinxcontrib-devhelp==2.0.0 +mpmath==1.3.0 +pytorch-lightning==2.5.0.post0 +alabaster==1.0.0 +jupyter-cache==1.0.1 +stack-data==0.6.3 +sphinx-rtd-theme==3.0.2 +accelerate==1.2.1 +pyparsing==3.2.1 +docutils==0.21.2 +pytest-cov==6.0.0 +rich==13.9.4 +safetensors==0.5.3 +safetensors==0.5.0 +humanfriendly==10.0 +PySocks==1.7.1 +toml==0.10.2 +Bottleneck==1.4.2 +setproctitle==1.3.4 +opencv-python==4.10.0.84 +referencing==0.35.1 +nvidia-nccl-cu12==2.21.5 +tokenizers==0.21.0 +attrs==24.3.0 +aiohappyeyeballs==2.4.4 +optree==0.13.1 +networkx==3.4.2 +sphinx-needs==4.1.0 +nbformat==5.10.4 +gitdb==4.0.12 +SQLAlchemy==2.0.36 +executing==2.1.0 +google-pasta==0.2.0 +ml-dtypes==0.4.1 +pynvml==12.0.0 +nest-asyncio==1.6.0 +sphinxcontrib-applehelp==2.0.0 +pydantic_core==2.27.2 +transformers==4.47.1 +mase-tools==1.0.0 +more-itertools==10.3.0 +typing_extensions==4.12.2 +inflect==7.3.1 +typeguard==4.3.0 +tomli==2.0.1 +jaraco.context==5.3.0 +jaraco.functools==4.0.1 +platformdirs==4.2.2 +packaging==24.1 +autocommand==2.2.2 +jaraco.text==3.12.1 +zipp==3.19.2 +jaraco.collections==5.1.0 +importlib_metadata==8.0.0 +wheel==0.43.0 +backports.tarfile==1.2.0 +importlib_resources==6.4.0 diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json new file mode 100644 index 000000000..2bfa1fc7b --- /dev/null +++ b/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json @@ -0,0 +1,60 @@ +{ + "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34", + "python": "CPython 3.11.11", + "startedAt": "2025-03-10T20:21:53.597208Z", + "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py", + "codePath": "test/passes/module/transforms/optical/bert-finetune.py", + "git": { + "remote": "https://github.com/Johnny1882/mase.git", + "commit": "758710333d8ca4b7444930df91d86c7642652426" + }, + "email": "jw3621@ic.ac.uk", + "root": "/home/jw3621/Projects/bert-onn", + "host": "ee-tarrasque", + "executable": "/home/jw3621/anaconda3/envs/mase/bin/python", + "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py", + "cpu_count": 16, + "cpu_count_logical": 32, + "gpu": "NVIDIA GeForce RTX 3090", + "gpu_count": 4, + "disk": { + "/": { + "total": "75125227520", + "used": "61026897920" + } + }, + "memory": { + "total": "269555560448" + }, + "cpu": { + "count": 16, + "countLogical": 32 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + } + ], + "cudaVersion": "12.7" +} \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json new file mode 100644 index 000000000..76b497321 --- /dev/null +++ b/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json @@ -0,0 +1 @@ +{"train/global_step":350,"_wandb":{"runtime":63},"train/learning_rate":4.722882026920032e-05,"_timestamp":1.7416381701494613e+09,"train/epoch":0.166270783847981,"train/grad_norm":4.34669828414917,"train/loss":0.2263,"_runtime":56.552776547,"_step":6} \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb b/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb new file mode 100644 index 0000000000000000000000000000000000000000..28b610420d3100fe8e76d63abfdf1515169b8a57 GIT binary patch literal 164998 zcmeFa31Adew*PMjWN8*j5Y)KExG!{S@9v1>=sd?|+{X3IsMFIV4Ft29MMTkw0VN1x zKsHe^;)b#axJ49=iY%@HMNv@!6;V_IxI_^6f6u*Do!d+3s^|RXy{`OuJl~_cs=De^ z=iYP9UCuZE%Hz)5@X~&tHdIdv@8{{}sr6)@d4R0=eG>jh&R2EG&|dB7^>#C)fq z;gbGCkNK}ZdOY1SYCT8g_Pnk#P*OH**!AJ-)q)c|`~KZKJYzH3?jpO#(CQ&Wib{tJ z4i`jn5AgQ#b_op+mlh7H^<><1^O*+@id0mEDl4m$P~?X4NJUXeq_k>S&q!oIRk)%e zTzOq}>J{dMD59PM${m z?d)mj;>qqd)YIspf9G~peX@+dC*ipWgwjRxX?~x4X2uDgVLgkBN~>=ea6`Zd8QK$` zOT5Pot*R=o%+JdkQdBjxdXTT6tR(NevZ1A=Bji9p$tww0Mtt}x>jb|Z@XL~_s)d1q zur|oAXxR79?;E~?OioER2&%+ zE*KFiE~{+2d|z+R@`~!xNMUGbBwUEkVc_1wynU;RBB9}B6@{Tl$)HGKA=(1IlDvNH6ga$`&?%=YD;o*wHT8~G<$yr5}p-5>#Sz)9iR2X6ZH+m-rdcF8I z{?lW;hZa=ewn9+=QB?R1>*)^n_6-*nhT@KlUK(H1UAOnknO<*6_y#~I9a1$Ee> zKKQ7pH2M*4zOT1?VOgj$TvA?)pEX&7UL;&mJOb~nth^j3qTd|g?Os}45*ic4Yu|zu2QHQi5J247pjfE*zPg?y}bs73z?BrhKk5^xDC>E!NhC_ zm1y+16j?j)9OUg~T6$$cS%uI9j`AKPG@KA#%8*DYF==}XFiMzN0jzCMS!JlI>{>LW zBA6Wi)wXXVNxnVO+YcR2+4T`z`TB5iQDF%GtrmX&L~jnd^0HEDec%fF45OhiQ)$D- zf9YWF0hQI|<+ywEE}~C`cwt*B5b4iAbH*Lrep%9V2Sb8^o=-(*D5+A2Z?#bN3p_~Rbn z?Sf|=Ty%r=K|gQLWG^5K9SO{#?Y*Ei;m~An10PpL%7sJuFCFRaZ}tn+LBhkfU5g)Z zxc5M!5vnYY6ci0EDhQQ_VYSga#bwJY$_5ojNG#bg2R>jB z9OHoM(gOUS!d&mM-ovQ(flX1{CnptVjVWARRaO~^_T@+69xB46l`zE;wC50O)Mcf$ zo^H9^N9THTy#3Ld3X5=wP{Ghh!L{XOMKCL&+vtaW2fbsld7oX_{p7-Xc&{%j2oEyb z;r^Cp{t^l2;g6Hqe9y(Nt z_YjIuH@m+a8Gwr&=RGL#mrT&!-0+~vvf}EhNUk@_+nroxK^U(HS7na+P;VBqeHa(H zF8(N%-5=`pntfMcQAMq%XD)w*6b9xp7Y|byUV+9P5)U(7NS3!7{0c0m2&Q6w*2miw zpGD9US-%ehz$a*KJbEyXJH*?I4nm`#{%{B!N7vj)d0D|w_Os^E1eP3osI0UYN0_G` zM>8MFkQ5^PxfEs)C_bS08}2W($X~ABsETDDS~^#E|mpBv;@qGY8PkLlUmQ zkHDQo8%4OFpt_{Gn76OWD)#o^9B>hKa@*S;~|UyIO-qBDxz=%K(kOO`u5Vnp}|GP=5@dx&OWOut3rz`E-EP^&%k1D>#NG};0U7l$};Mat$rnwofV5L+Pn`izeI$| z`jqH1$=>4;Ccwa{pDHdM6fS`CW4S|Ua79^(;9a7pWO} z5xP(MsHB`iaq1hQ-}LZyu0%u+Ytj_zBPrzz@lT@dX;@JeyHT_6GJghL99^@hJc3Zc z{H(XP^Ps^(9)f=lE*JiNkheGeJ5*L)Wm#h9+>7K~?@`|VL&L>`sV^)js-(`19zIfv z8|#uwE;*N609`vY7>Q5po!}dG)x>Xlc{^8OB*A)R=ue8_aEr;eAjm?0g0!lzu)OWZ z_%kEaMfw(^_G^Z%{BtC3^f^-FilPESIjFiVH$V(;evon1HRK=&CGs+s6~VQH5M{OH zbA8RO&9q&9{lgFuwVgyikCg0M=IU{;M|-n{Zz)KNP|uDr3s{Wxh{$@-va4YEh4fyi ze>vLQ4NZ~7M75s&EcS>W%*dtQt+1#vjA$xUbzNae^o=6MH6!WjQoL?tWALXXl&^AA z2S65cvdhz_tZg79q9;eMid*4+6G;bqou;MWx*Q|rlEQJ;6%Y&zsX(TJ{t)>idS~8- z+10bXeTs%45Qv1hwO(7))-AxE2Uiy2-}Dj=K?;F?ged7FnpXS>n1j7N*jEK*rGty` z60f)^H})vR$rLr;P#nhp$MSamuMj)!dL`MX%m*O;tFpel*?Y1#mz^L4kuU^=Z-}}n zp9l@F2;*)!wii6!EO^5Ryg(T56s-BpJ>96xv&G~#o1Q` z3E!Znr)Grj*rycn;8SX|)V|rP=810wE@ysQN-Ym<HA>IR6<{cVjWpGjBJrW*-=XBg^%{Bnn%0A6S@Yo$5uZ-m$wVs2o z%HHc66CKH2zq!kJKlS^_&v@nLeGJOBiccjG>{ipqCJ?W zTV0iPnPa@!^#09(p_ovztgQ`U2&%lgawu6R#uo*e{o$Reky#};BmR&`$Z)*b-rG|3 z!}-T;9=9i%W(9^<;D~dPadfA%Sn@|^Fz}qq&kS95(O=HG@DD8ZpfAfSimsq=Z8`;$?LMMAO#Y0TO{c;CheEFs40C7kpjd-f0ocEBXRcGX~goC^nk|^n10s!{AlkXmI*M5 zus8JJl^E-!2*qT9HfkM2&W2eOH|t93aoZ9>w9(SYV2t-f5KI@LJ%)?Rha!ICMv3>% zdp?Yb3_(YNTSrxk;Y3b2m(4jqyxAW-+eQ zGv#&uKH2l1ecMJnRv%|1N#xg~gjUa22o+P~ug`p_Gjj@dmn!I$c# z*?Wd_7s?aP#WVCy`3xr=<3*Rs;_kd?+-vsg-DN4L1ysL3aO>iummP!I(*sgoZ`dAB zR*!BQ+5F~$iBl`GfAI9m$&rp7nd7cMBXguYLzVTMt0XC3k_MiBO3vv;Re6=ygm%9% z(Fat0QlqNN$m-ec3x;aNicLReWdGvH8(LNp$s0CYH5557TpA{?5mxgerPo)o+1k9~ zqCt7(BdUg$mC{6T9>z^01Im&AqQ9-oV~#AZv^0+*&f?-oaUN3O(gGS_SdeSq|^dbC8IZ4#_E`nQO|fDYY3?U4@T}bC{pXF-I_R zhDRzQIhDwf!xfdM<_s@l2O!i87v^9%2Xi}-8zVWG8qeX=H#y;)>)|NMDtz9Y9CO5| z(ntT#3Eri<)@(f%&@xj)+v|x!Tl@HLA9r$>fsM?O1+$U|s`-+ZuLVxY$&r12qcX3k z>YD3zu6dT)VpeKatSx@Y46Nn0DSvVsDkTP$oUiEllEP3a8f%N)w_P%Y%&Ko{RLrb6 zshy2IlxRyjE%b5wzrGEi6Z70xQ?XRLfDe|`1wV0B<>R<}~W#aT5^esq;5yT$C| zH%`6F9liX>94&^H8pv0u5h^~v%4i+#f2ovc^>apR@#g6nZltBh(vtFJKWM2~nZ#&) zcIuSdiB^9{v>I+*{kbb?88NibYcX1yPc|5>)s0>6C0YkLqBVW;;zv8V(YE}tw3K{_ zY)kX`bw+FC4SnxF4zvzVO{*pYq3Ku59d4fdKpz3t1h;fsfmm1wO{tUDeF}#)__w8> z5v)TT!D^Vh{0cY13dX>a1No}KU>O|Np(p)$GquC))UcRG25 zJL*8$pRdV`R*=V9lYN^?h}NNwXtnjC5G{gmZXA~VfFDB?A#I|T8BBJHT$tC z^F@z%rz^TG*p?W{gOeOkF1{hpOSBGmL~H%rIlsA*mKe!1TDs0`OOcpItld># zNVJY{M62<|4__DK2iMq^7|Cv2`db&hD0)QK&=TW#(9-jDjnOiAul3>2uO3abavjlXdFzu}v5$8REisU1 zv@~v8W^8)Wfg6Vqt)m^$s@pwdr7LNPkvylx`*?%*@mGFu9viJW#u2Ti59Zw^dc<)q z>9-Uyl4rD7-edUuyx;1%>cG>(CC309bcRX{u=n`E+OAO?7a)~N; ziKc1I|G|)m)(MViHEdk|wn*y^w{%-#AkS&>w8tEh|Jxt){zbG-bVRGJ`OV3qOLPq_ zF_Mp=#XI?}T{347t&<$ls@b^yV^`7=LwQb%#!cZ8c_+W^*C9)Y*2#`&ZEjxtpy(1^ zV_RY z_+ZpfH^LGFdB8IAb((g?uqmHN?Ecf7=cyg~Qo~~LgK1eiC%h$kM7NlhDhBe#Kuln; zNow6^a9St5r&bWHJV&(Fel>5o*vGqumKewz0~I-+4asZ1V1SKnd0y^*0MU{h(Q4fE z+H$jxmyaE}acZ4Qx-C_V?YxaV_M(h>uC|3Ep&SBvOAjju#{bHleyCxVuin${>$bRUxn$`UmbryVHV9G8-fQvPt|x$(KQ*sf zT95QCbBwcJZaUhXxWtGaxQu+2XVxkohq!mf*?%Kk0Y|uM?s|GyEG}0pv%#`cF{Ed> zWIiOQaqIeZ;nXXsH3n1TVy%%|*Yw3J7JA&tOAP7*Yz~%>&g;DQnx5DFN8*+5jMwI; zTV2Xa4C}!Qv#=&Fo&@zjf6O-Gb*dv?&7UuRInFz}r9Fy~J>#X(>QlpK@b2s4v7N6c zUIU%+nloceoOg7^3KH~RVrTCJFLK5LJsIpVc``t8jwa*uyV&5K17Tli4wil_5qyL-2oml)i0Tzn45;O2GJxCv)bb38pYF6JINuDaSM zZjI~i-NH)@?t|3bvqih;zW7#yt{WQvM7;j!h}YUr$FGRvHQo)wUg*BW;GXeP&F-GN z$K^*X$st~6IO5eZWygd#_vjX0Vsy`WnRBwd`}%rK)&a!pOh>%xKAE$|oxF50yyv{k zc}5;iEcah=nSp)Ve;S11j=Y>eOoh5+yY4f(dGhx=5C7eJdhS{uPX4b!C!ugGt3|+T z0Vdh1A3Z7egu2=RIO(iGCwjf+FQQi9K7Z(lfyet4ob%^FKGX)~MGvf!cKbZ>n(ICr zdl*hUJNsa7|17g`c2-xEoT6$TD8wX7H*MLld_P#sIoZ2B-ea@kFAqDyq^vAS_}_pL z2&lTIdwp5SSfPLw#`XPD)5T?raNS&?N;l2PI8 z;H*S$xJ6WsJ37l7q8mopj1rW>IS^ZOHQ7)DGA^R1h8#Ir3VIL8>K2MVSWg+%x)r~y z;&u#gkF5R80(Gb^)h|oHjw&Z8R~`)LJ^#6BZU!)%o3-2Xzk4VDt1lsIzZ;%7FhhmO zoR_)>>5b7%?$BG7KK@ZguPI+#O`CuJlKM=x`In#CJm#s7GWh0SH@GyL&v5rG7c8QF z=ls;HGNY6iyzwGXd3s%a?f7@;6jCO&)?DJ zg>T`9hu(BNRW&^x#n)JqIVI=#Xkkr!n9xn&`b=7UcU92~%1 zUX=PS`wDltVDI~c$NT-)*WN`w@Z!{`3o+^D^_y^Z%mBShVun;c@U+nlV~E#Z9r3F9 za`DW#6wxib#33#o!_!P`G=^_JZ|P7<5id#2i=~L)u@tfK$7S7O&C3<5Do5t%;s{rg z5WYe7-*paJ^QG|RCwTHt#}GUxGVkNrJ+o1R4dvCTQl1axMX!fHR)T&61mYK!D5i%m zRKKCB{@U#QS3i%-UB)l>BvrrR*W3Qmyq?*)sH;{{R^a0$`g|3Yl@}lu?x_T12%V@w z$x!s{le~S!V^A)vfc-)iO7sWyfT}71&CoP|P*t;!@%HAIqN-y2uuMb65(X(K1)zDN z=|(V!D;^|XF?@Zv2&EN^BlO$713^U%`bh30h?wWCqK#c|$NgW<@BilL?LYkDZN0Kj z^!7#lvuk~MzUoR;n7KImn*rIUsH$$ry5d(A{GmG&4?UYge%swxKzzh9R${F;{^O1ccV1zp1}pfoZ6u)GWU=5>Y#zElv~F;j6TDf%W!g|1G_@oJbeN&S{+o@`bNtB4l0Q)aPTYtb z)TL}Hgv$7E7S?w`zbenpNie1v?W!gyVexCw;Cd;CgfJIt~ zA>C%v)`KJ?t%#&G#UhO#O-@?1NK@lAZCY(3ty-k5&uo!SIJkY%YLqlJd$adY<4L3Q zSE8iVl%!AOy@z?3DCsPa){rFgo`e3YS*X<*)NMv>JxMaus)!n`)a*X+tFfq`Iq{35 z8qkqn$vg6?Z5{c<89V!;tII4vN2U1i-!?~qHGi?;ZVO0l{7rDDoL^t zlw|8M)Wkz|*-zn!BEg?XnqR^1bK-F?Cljv`ZxQ$?YzZ20KIZ#Qye3{AUW{fair=r{ z%GS@zM&g~yvZ6rSO^4ze26X)Bg!s?LJ%1lG90QIWuw_HyUWXm11QV`l1a!O)dM3Xi z!=9`g4a6Nu@S%xchH*_67OY`V3(9bQNbw=u!EHw`8)p~PyaO;I zP1jXDp!pGlSp*XA6>kOYigjE;#ZcP_#M=o0O=>L^b3^VHFtyymk%LoO(uKqC`>OO1<-Q#z>&i#sMhacA(F{X5!fe0E%iJH!XBT%~_4w^%UspTh^;F$=F6AZ6_M5!S zt!8|7aq?gO@dNR?#u=|SXD^ByS#^tj3A6p27u!D--St`X%GI}!d;CXgUTkFbdp@$d zZqi|~-Mw3I36uTFaJ_cn*$h|65w4m~ZZC+%<%$hwBXbO4u0I*BUvKNjdi#H-#>Juu zZe7i9jDI4|JGzCJFwY^&6ip~RnlShG@44!tvDAA({leBp1kP*oJ=0rU%1fB&FnRH4 z!rWL~QK1YWUV|L*YWU)V=iA@9 zVrJbzUK{S5H^2Lph(Qw039(nkkI8T=F)_~P#A14|Q!hEfAYw(1h_$}F=xw(W^NXQA zBW5L%Jn1&p~$4y^Ty)^Ulh(j_J)M*OJ$j0`-cQEG!8x|)b#g^&YcD>nYR z+O@>QV4oA?^{UKLs&zLn4G}R&nmQo1dG@3kiLTNOThh>tiSa%s7Sk%%9WeioL=3Cz z91vSk_t{3*5)(syM$D`dYF7Drwa0tU60z$X5!*cOsX2+R(j_(~hWwnEIWNgw<>&eD ze@m{iA~iAQDqEPVthskl{{&a*5@2G)p9I*rZhf{8Ff4IN)hGpE3u@;7BLSEj7Uhr3 z35WrI3^2YAx%QdY|3+;R>%>z5W3dIdu+1|cdnD0Qx`db*@QX#q&7H`hOIPL-F)WL7 zK&<|^ZChMRObqxrv6!Ov4?n$*c}i?ja6;^()_$hx7Hs>i%&ox0 zc%K6^SMl(g9e2+8^*m~mH>C!~Vhe6zv$srrF40rEVPy(>F)`j}CAO6q6Ek{8@`%_- zN5pFGpS;Jl#KeFM_u*hFGTDX zN5nRN|G>CJPw5h3V#LphnbSyIUiF35hrLF`{^f{R-OgWDB@%PP5|xoTK{4V_Laa~T zqhp9zjVCoS7F=v+!A1QuyG~8;l`bJBM*K-y<+?-9`|uPHyVV9U<}5j}h8eGoO6>Yw zf=mqgAsT=|elvjO(;^#B-ERvat96EK`mdW;n#dGZ2Op*IH%@g4GBM`o$YT2b#pf?+ zA!MVRA*;Lh<$BXy3dr2Bt_NM281-{xX6Mg`VE+^flo7Jg&XCpb-Uj_4Ho@yYWMbUU zkeRZ421mBE>*6a3*==@_?O?tC+Ov}DI(^j}u0-3vbsI@PIP6%B;k88^anDTmoV3S z0lpA+;q=Q-JRBud>OAe1we$b$-gkC&0@ksfY!q}G&`!a+eke$#VC$`_VNb201dR6U z)oIv-tq0Wrw%%e(GAhKN;9RnLb)9hgDB%`Uu&#f!U>)|B;$)njtYn=ZWo57#AA7WA zv1FZ`q-32c2essd;&dN2@?sx4en~<>15N((CwKhN6(sKPWZjzaZ@g~$LCr!D#@9~V zai9j<8E2!pQSmh1r#G(HJ~?CeN$1y*0?&9Gr?LWf{M?qMyC8V@1FiM!Kz~l#a>H^z zINxal7FD?8teRe~okA7vl$KxT(Oa5e@cavVm+2E!F18S{oOmF5u79e8mXZ4Xd>Y znmA?z`HNsalRei_ErVi33Mq=C`C1gPhs)wFer+Y>G2x3tn7)k$ahmBS%b| zM&rw>WIi#qV%37d)HtWwV8(hhZfDJt=A%AFk{&Jn7u8aaCXO9JvB0o~+Z4Y2F=nc{CYcXap1Eb! zF~saXXUt|l^w?ncG83}_&dl7p$+LlD{#sN>jq`pR%viv|?W|?f=y^$fr#lQy%m#Q7 zGqYSc-ywYBEzh1n&}KS=*82HZJCZ_6|D~ho(PS|lV9?B+PX>?DMn81MZv^cDXVBKZ zvVL(=-{}r$Vm`p3#fbkL`TR>hg4W;+TJw}oZ%yhu-2qL^2RJlyts3_O=TF;jCqaAA z8MGDeuV3L}XktDP1C58abLB_>^#X=3X8{@&^Y06PV%7XCRyE(y z0;s@GRc%o?KV{i`1>>zuOvtadU)Wz!Fz!NCiBlBz4@jspY@kdoijhjngWq+!^;|r` zL*^5ZECGLjme<~nktHzI2T;jWMt$N$y#f>g^=pAwcHBH+46gUEc|EG@&#$+4_w@gG zUH`Wip7d%H%0WM3-pw;q51QZ3yw6(_3;PF97dx@Azasw!3;RvC&;t1h<`Uza<+tiPW~Czkb( z7Mr&U+}kRFo^WvcCD7Z-`V(if1bRYQf88gcc(f^?@H@)-Lx=(Ox7%u#zB!8nQ|Hca-%H_|TuV z*|chrjxFooR%_mRkYqL;Q`R4LVAh&XUDjV9ZPl^|X}?9<`pg#TgoE2B-B#A0v@+T& z>yNjfIVkII_8#$N{WJ9-Y#K^+Ru5uWq+_f4qZqX*3SvD-GSV?s{o4c-5~})_H!rw& z3cB$}&2IdTwr>1|mW6%yQ4{@mUhqGzA~>pC2T?*FrPWcon#to?rK{!mBT?2L<=5FC zEJu*qN}?w#;vYmkY%GrO8z?M~V%t`wcqOjLb|1V8l!#XqsPI6dCZJj6$m8l}Lx%@< z$o{A(zOU0IPvj32S2>*u`oo#iUx@u-T}B{C@0yqRjyo2=4yhjY2ZA(^9KeB06(Q~} zGof-rmy{uRfZIjo_Gl^c_*&?a55?)(11Y)+c^;|l(1b^VGLMW90)GQkKwj3Yt0oZ% zu&x?Jkfg!nEKW(vwu3wSONHwpr)WJ>+?}Elc@QPjsp`IFK*RyM1MNyeKxBbKl2t$I z!AE~Qv1B{&pbmNfZ`07B%@jQ{u?)LEh-Eg=OVV^I!;VPJD#IT4M45WqtCwl@3JF5+ z+UoI7q~nQn_80FQ>vvQ%Jg)M9y7N%K0XL{nY9ySXlA*Dy$nozQ0j6nb(23LyBU*MI zQDsamcU+V-An*`?L)uD8`(@wJDF>N?PymV&9cv=bNH9HjBX$Iv&Y+78UdTzU9Cdv24OZ2U0&_a4%yjyNO2(?+vc9v{`I zo^|<+#B8=RW{vmFzsbGKgb5La)$}(@nHhY2=;ZF#mDBjd<2IPFF;{M9&0o|kPda?) z4nq?rL^w2!59;WAX8gEY&(0!fPdJ0t^i$oPNrw+z0ZkERL^w37t_@#;QS;-Y+X>nn zXVBLEFzWrJ(9(Zb9lA6{m=ZCe#fX8P{Lhf(DG{?zLc!K=u5K5YXU3piR?z)}C|V@D)&o~~e~is?X%AKm#a&A#im z#tb)O6LSKFO^&H|ytrT06@=}1XV_|LC)9LuGd3|Rh_N=l1K`=Zqb?yEqTnUa*0!j<$mR$^(y=?vvV~miAjGsB04w15V8>iNmQ) zymWmrQG3xDwWi&>uXi&wF*9H_kmVR^+tx*fQWITbgBtUr+}5UlxO7Q-e$*AlCguhl zo4IP1$GOXHsz05uz2ppA>&h=ix93OQfK3zA1CGt?+4yXMoclmFVSCvbw&rCk-fRzB z+HVIy*QSXXLK18%TQ7f^ur)cuRs*aZ%n>*?bE1XMgNFCpdJ#F&S8ZTp;m1xEel)-E+MnHQY+{PQshJXpe82_N z*j}?ijX6?It$yk5@$EU%wBJfWU7MI9#MDXVJ1KsByU%jMw$vH6rp901b2T4&pc}6<5F8B{}q$_M-V~%tebEFNk zmuI%)NZmk97jpzo&8*YPQ>#vWUaX-)eJE6Wu(y9!m(cKVY2l!(t`*g#RYfHbWX@3Z zaW&nqqe6Ws0Et3FjX|`DUPDxTy zPr)uO?{QfNR8@qFO7n(PgbPEZWfc%B%rdYXT@Pru3Km&Qe!r}HeeDz!?2jUaSRSfk zDXFFgu_8OjD+(T+)q`I*QeIX7%ZE{+zBgH+zW>cT z`QLsGDAbp+^1>I6Qeiu9*=TXSFmT2j+~zx$^gHB)voAgt9N)Ha<{sL($WLuuvaDZ+ zE7SM$Z;|Z=D8FNaRi8FgkG#_lRr-#2rhGWPqIYea-1ds<$FAw;{pzj58gOQ_jWcgV zkv{7k)xG^mzw9ERMjtiWxEg(ltu<+pzGTfbwo0G>=B}UA_YL4SSJ}ADEa5h5CT;{D z?^VamSTJV4jBcxKoGv)fp$ri}w7%(5q ztX@|Ccfz*D8MXy=Pj6}uTiP#5r8X)KFD1d&^0y_A5VjATVQc($``q?&Fjrs`$Co%Z zGc4xG#5vz;_Yt-aonfn+v3Q58v5DhLG1z!E@szZ21z}t33|sTO#}>Djy1K&J#POvV zY~0$^?+%(pj`SlN*jNe$<`V#L*?DG$O}1 zgxxQXC?ss_oMCGjJ-OA**bH%aiDQegwi(yl{|jOJ)ETz*W0y9Wel)?_(tg1u8mS@X z2pk)7}qUcj=-^*OQQJ{^rav7D-f1LoUpWKUGU=uJJGF0d>Av=o`ZzIN;NZ;3Xg-HO z{`EWWlL2m%4RCkRqVhe=r7n2k-L9_Y<`?sYB;10NGd?G7UpwR0^!kPi-OWu*7?NJmT{UGNKGg*c)i0(DtX{et(@@_z`LtOC?i**|*536kw$is7 zjHLT|U-WQ(F=^o7c&_E=W0nt_#=9STWFJ1&4dBGQfq}D*w#xZNH04_XVuf#ffPH=Vm6Ve6{q;7x>R0Za+BW_H;v$ zySa&318=B&(AO*)vSa7&5^AVhZE$132)DNdtG|D8A3oI$1}A0>toWjA?gil^BdZU; zc?yARaR{z;(a-J%7ZB414$h3?c%J&yJ?H;S;I=sgx20~$K74ArFRMWp7ZCFX4$hp` z;(6*3(@v@)aNC`MYiN04^gev58^DQqLlTSg9eCDR1a5~jaOz|wT=JwW+m6haQPE1QA^?nWSJ1hoKGl~Xb6gTzaQuMZa$hX)l$=(l^K zvbn6I05%mV-v5-}YDkpiXY=++l!V{?2V3T+79~NkW89fvLQVB!yuAyHDzEkB`Kl|! zLn3G6mojyfG*@*$3W`$!YXc?XdwILEql!vV|6j`P?d`_?Syqkb^+@j@UMjtcYWBaF zck>)ov*)+-#?1TsSW+31Z%BD{t*3L4WCapMNo3H zqC$3tmCHUHx9j^|vLqUe&2LEyL?s-4w=7AUghOgs5*VW;mSL%i;FbLcWl03m39=-Z zxkf|!cgT_iVXqcwl-2L=?L{x|cgvErkxng3g6ilNX-kX)q*Y5S!}`pYSVqFZ?TckZ zWl5MWN_(;-3Pd0N9kL{T2qOrxB#K4)_sf!K_GL-5Hqw?D2T1=1Wl12X)2=K@(5J%h zm^Q826*WpGO9E8|OOyo0ZIS-{vLsPaqXb!!Gk!R|W(GR)Rs2&N21oDK!_1`?zR+S5BQd>X_Kbid!PX%IGR6ZVkf zMGVl%T7QD742=pS3LsA1B1-`((wo6ZDMAC>uo|VBpts;B)h4Xf)*!YJ#Q#{&6fZUb zSq%Q{&|JVKHUp|E=F=xCNF+S{i6+y8XTqJT0R^%mnj%9jB@oqcfNEE)FayF8x<$>e z25D;?Vg~E+_jIu9f+G4{(z3(ZxfqHG%4NG<I!A?Wyo2amYfRTM1eEY*$_Qw|8kavr-Cp>*Lt2p!Wonv#R^(JmRopEcPxoc$m+|qrI0S)vB zlP4*`jaTfl!Htc{a(kQo^^B$a7+Q1#IAQW68MxroKh+Yr-<*M~ zX{sB)kD*03fD>jAvXzJzP+jJTbwUlPBidp>ah; z1a6cwaO=m6xqlx%)eYc;v>_R|GhW;~n81y82Ci||y02UgPRtuvQ?>F2K3#Xk!Wk6= z?lx!O){fsaZy!$84HhTn4IG>q{_+ZyN4&WEXBFUXw*hWEtpKlO!AR?~PLxNw5G@G5-KV^cMdtFuAw zPTG>jZEnS~1+zVFMNj^ zz?6zS>!Y7Pp%Q|3WVBa8P*p~K`NO*zz&kDjc81z@LIJ>@n5GZX7{02iL9qg$nydn# zq3hmblT-lA#K^OOVOR|l7dm93gSvM>I|Tsyp#UJXq_JKDqS}~eP?g94)E-P$|1XF^ zcg$>I&IB^=7@}{l>K+*3uDl^CnOIpBE-A0|^v4VZeg!oTp<0e{bIevy@x$tM2UT?e z2z2>18EO3_zw<4kfMBN=k9t(rU^#c%Xmq`!-PP=6Q+9a19^{*Q95_y}aq3=A^xWoc zGk178TnWJ^*Il~R56*YlfR!0#Rrli~VCCs`-rd_zrMFaXO>o--hk?SmPcw{^b~%2uvkS94L`1mU_y&eUPLglsmIh1lyCmqho`I$;wC5AxXHf4 zO*YJ1Onkhx%dUHjT*uuu&KJCH<4-Gbx~oIIl&+(7Sr?tqO?HNE^VT<>BS+HC9(qK{ zX})u9WR4<9Fh5B+%^p_}3i-TIX)cDo*(IQ+!X#i-dWyl?C&gl?)`bffq< z@yk!%zwZX?0;>}Tpwxkik+0L17!{M^MOAr~*BthTDXR$GG-v2G-}l0ou16;hK>?l0 z%U)(PnBP!LK9=Ch%8T3*f??4{jCzeVU~I77F7RpaxW+-`N^I26$N^K}}* z(|u-MJL50Ij-YXf@q)k{kbjqQ>+8_$H{Q6NjNd&d68zFqF*6jmY#br{;RU4RS1(;5OI1Xzzl3``7eG zCyT>SsH<({OMJlH?CD;r-#(tu&2)yYWoF|qu16>4609V-#?fiKZ10uVe0C9`d%zjG z&2KKn_Q`#PCN8i#F_&QIRKDiK+*|+WYkXY@U4t`pE1stfl>0)L=4(08*U4fo!O$sY z0X?2eEWS(YOXwbShOTB(>l>~|C*~3iowZS%H`soczaAn7JIe+-7Mk3~LX)-Ie$?D< zbYd#O$(gk-xq}^)*>EY5d&n8NhF`zB$?fFCOoA7VHjAtBWr-IJxMmZQPgl?`gboHAX?=$^tJ64zGOFXH; zim3!g7b6MMJoxJ%p_}InUBj#IeC>L4VlKgJ=bM%FxR;#ZTf7X8Jf7) z>cm`vqvIpQI-f?IUOad^IoKy{pktxQ7#5l|tY322J{_zJ+N+pKaB{r8>O6sc_VfE+ z)j@874RRA`MK&kbx_(<>_8wYm(w@zwyJ`dVbz&}&gx&6Q7Y{PPZlMi!chSmX&aQsT zo%1}dho_3!1S`(1#2DU%|LiRjyhYC7ZGO7urw#y5Oei=!bJ&TyiU&URzDDq#at3eh zCwpGVaKGh=IR!5ZZ*D^2I~33T_R{MJ-qX(Dwbt*b?UV{$dM|cDpQnmh1&3#rvE`oj z$y?7FLGYe&2yfw?n4d`DFw!5Mm{)LkX0zqv>_6_9G>x3?vo_!{XIp13o*1jSpPiUk zFm~qlSc3;D-`v#KAa>6=W7jhF1)$_p&j*jVr~%wkNes`%s#!I;63jU-mUZNQu^BThbN{N9A1o)$RDzPJ7(a1 z8Qm5;gSUS7jCU<~H_m>)dgp$flaGL695nqIifIOCXb!3H*1Jf1<{mP@7o0Ji{?)rv zQ@Y%=ho^~YMv~S$ZD7}0avLw&fX7_!Smtu4Z%mz{rT01>^n99_WpH-pG$WtzSwDSM zA+cNHklj;_=cIDE>Ca9~GB`V)WSEm>3$MFyGd0?mY_MZ4m)l+Ays4PDND(KeKRhwX zVD;>z807+A(_WULz3dEL>-YCy-?8iAiAe_F8Tl%orj@u}#-6jj|ClVV$r-$w`&PVT zx!iVrZF;Z2qb^U(G62t?uaU2%3Y~0cz;lb1%_4ZOID@x#)V<47`r7n|C#D%e+L6MV zEy_jnft3>mO+1uT2VMm{QZZnv19vcgx@6vKJyS>u$fz8w<3A{zi^8!e++PwYDXSO} ziYgCe8mbT9sWNINqd06pr$V&*R20C5ca%&a|e-=i&`Q6v8f3s)y;ohF5k*eWk71#PInG!-z*#{{RDIjADyOBhS0sHAuomM%# zc>!q|EH$r6S_b^8k3KT_Kdxm^t@a+a6eX}@Ke?IBw7ag ze^ARH*j9}^UdsSQm4kZFsz&}hvos*Fe!rRlb$6^G;8zQZBIg&CUVC-7QfygZLWtnX=+j4055Uysrh7V9=@@uEbX(hpL8I z`8^m!0eY04j~<#>dEYO=jpDDYs(ypCiOg#D@fr!48fu&e17MFr?9d88f$-zKnJBI9 zi!XccN2lzkzbNspE2>8Bo*kI@;{l1v&MOLP^pp0$KI)iT4-_y+L5T?|B(1NW32r1| zfNBl876__=sOUp{E&D9!B`8Kv@}pwAjt(s-6rzu>qYlvltMfy3dg+TH60Q%`@h5mRg11egtpaTvns9Uf4@fs}naqk{) zNtQ7gg!idvno2?z?N-dkL$klAR&P*XX1)3N>h>tukIs$#(NVfyQ`=-D;%e(tX?RFz zD4GK0C=~4{%gZ(u0;rxn9zlZ;fo%A775`)XAiDL~68%A+0^J4`HH~l-o=1ru8CRn} zQ-$)2pGqX@(Cz?0Ep?6j_{#Ut7=S(p!a)VU2JHpZ?r%>JpiLzL<(&|Ek~9{V{)p2f zAVme!U^G?oL#)a&0=&qWXC+xKGFG%Fyf-uhDQYir5=8*Wh~TXSQN#&GCYV;@LuqLH zFsmR5U_k4D1T9!2jC*AeZosdzzmT7RH3p!*Wj5@%-_MjGz-2&x3o>H5gf6B{`6BM3 znQh7=Dt1mm5Fz^baiCxokWC0kDix$sNG~KR@epSeWIv*41nD8PC%j?!Nbdq&iv}e# zsEuyAiQDjGgqe?!ASo+Qk-!? z9x3O`w3BloWDl!ULa!H$icadr#d{PeYy!$0W{PndDhB!Z{XDnh{HgSO}1 znte6FTj>m5smi2cyBm^SM&9>58Mw= z7p7VWULaqgA#bQ7@qx$4!&PA_dH80#B@Z>ZbkES``vLGVXC)FD5Fh#TrT2Gv!(?H_aB>kJHUXJqa1r15k7 z&Nqm1hn+9jaqvx4NbemRO^yob@xFfcyzw7r^xGQQOXUvVwQ(vdcWC;Fr>1?J(c#J+ zzCOF4FO@rNwgHRE9dcInzdis~o?idmdD={RORH?0-S(Cm7CrWH#t8#=)jxn!SKBys z1gj*-e|p9GC-KumgmQ<}9^=X#CN#w4sLe6|AFrhws_AOJ47g-{tf* zjReShRG&-MlIM8e#`!|b((=g)oStIJD(&5FD$mi-(E1I*TjLDg>|N6yOgZb4{_uoZ zdUCs%vZ8>UwZNOuJ z%6Jy2w7fZ5P33IUpPe``#n`F5%hSzKJF4Y0u6O$E8?hYqO-9 zS&eSt&p#g*0K3m^u$xHRG`Zb1H9vwHO*I+4_O-_~-L;3%=NV#t!SV5J={gTpKEB0w zJKAd1LU1<=DTUa1YV@(iy$x(chon5$K7TMiP3*&Am{eru&r*dd%-~+grbL z?PNQCH|>p2%r#K;TF*DLT!Ux1|2+T2S%hzsGkh&e)~xFg_{3}@3BCc(emaNneeDe2 z+U=Vkw&QoxdgB~*ePY7F@x^p~SO4oTR}#L>&hXVw`}q5gfX^@H92{Rv)VKHScmA93 zed7$@g4XrV*$G?H-TM4u&cX00X1dEuIe&fC(!j=gvY+3*5v|JmGH|^<( zDTio$W=*H7TKtC*y)Dk@)x9-pm<@X4|Ho1f{|8P)&*v924o=SunR&u?#+>DEQq%p; z20a$Gn6|g<`9*g8ZrbA$GY(!$oY$H#I6mpKb9WKG@15aW@XE%W9Ri=2awNewb6|-^ z_WaE&cX4Sg#h^g{V6wkiV5FVXZTjMj=R&2-%ab)qty3_IS0dM zPD`3~6i@qd>1e{&;tXHIlD9wX5ctHLgX7~<;yNE~ksGFcL5_Eu4SXzYxr>js-21{A zHXLu-8=shR@V0Bx;{(lWHg*1vjBmRQdd%^1dTW=>tFz-d(j8wwOgUJoY}qU>#s{l% z|GBG#@a=GhZ^ecu-?f7;t=CbZ>kEh}hghTC;3-Gnyf1eXz8{_8TRVI7R6CA0?eU2@ z2ghfQ7x2bgzsh)l@crZrU&E~CbsYkqm~?P_e3H%FwNbWQewOh4>|ow|@1UUfl8O?Up>j3~ySm38Jn~%sDtc zv(A&bFr)r~pNL+oGkR;EUUO+jpeLproSs>%i0@XD%ly7&<(U@crft--1QIwR8l0K{4mx_;}c27Uk{!*}TID z-)?94T3`I8(F|K$ZGAy8=ivCvV%mJP>hbzcuao25V*?)xTPCuwrDlBn(KZ}!y0_h+ zm~zClU7m8Be#(f$gP^z92E9qNw#B5k``K%=YiU{EzHKkP1uW3_1;w0$^D`$pcq#X# z_g%U(AN*=Mr7lr$H?520{AREGbfKqX02H$hlq5&!!k5yUz26;Y^;ks!Z*>N+;hwtI z4gye2JUGCZy5yH1WL!-EYn=gXT>l*;wmU=<7IO~{kcTgRK4|s!_N*>I22>g9GFq*eo@21u9BpD4RyBRgJT-- z)mI4b`1{j<7+hKEZF2W54ws4lH4Dv6*3rxqCWkK;b|fN@=?>|LHDbv{tA z%Sc}5L-Pd!il6F*7^rK2MVaStuv zGyJH6hZ=V@nPLMVOXBWkN!*I3KC#I!otK?b7buQn zB{9GYkNw9+GQg=e_%SEU`8BS7bCS&#obCX{aV%7P_2;WJj-~nxK8JqD*prVXfYY1- zY+2R%Q3nAij$;9!ny*n8sQDD$g0Ct)U={(K?hN4O&!0gBHCtU^x?7+)j>T#YOXP&3 zHHSm%?mvgRzRS`VakQ2RBgF=?G4 zV}dU{k4xWcUV6SKdhIvHuISmWR1kCn0;mHUQ2mexK;iKArGk0|g4hcVsUVE-D0=q3 zw1NT{*irp@3N0g47}tZ)H&CEH1zA z=G{C;vJw1te);(q+X_H||C|Dl0b{TfJmM9AB&bOQqGBnj6@W}#8uK!OSW3XB8&sdy z)JuUjfhm?E`6^2+nGWblMI}5WpIr%{s8~u$2_S!R2_U~EmSSn)*b+-gIJkYWl&A!d zh&oOJNDfML-%` z%G5QV7R2bJrA!mFNdJBXAl(v6NvQzj7i4KH0Ust`bF#M|-L>_ZEz$`Gw@~f z4)gZmKgLv&OnhDC0HjAC-(ZqXtpH?_j#mK6)FI%)2CpsF7U@_4pf;@~>p?7g7iF!K z7+IQtMLMMbP=IV9%iEJ4)RMgc>7ZrP)@QazCmh^9={5l%lC?UNlZJ>6S*uC93*_5; zOag_O^zYFBk$upvYcuJ91^V~v|0ovd1pS|zwst#d20HSYW=DQkTSs13yE#h~Ou>l0 z7S#Nhlu(%TjwMVIC;F15_?U_egi=T(M#9vGb!4K5hCdjD!OzP&T(aIKC6i1X z2)Bv9&^ZZ0P$#O360b}|sxv_@7z4=@#0^H!dMG(g7bO!4CbEj~1tE+M6(61A3$7y+ z|NFS8B@Kf6h31dmp0te8_H*%MAOi!=kimfv67EflK5h~j7!~q~P+VjUQ#Yy%A(Q(8*WU>&f^=v#@uqykPbS}g$CBwbRVmea0c6wH_Xfe4Zx zfm#%3qPi3MVo1;6FKkVdRftPjuR2bx1`<3lN%j|tIRTiPW*w>|VqzvKTBRic6#w{o zRE0~25E7()EKc!ySt^8W;Ck^_fwPvhs0vo1R*t3Xr9((bM;vZ(ic{*zf+iuf=-6Kc z9Z0aPoxylw6mZByR7^=~Fiww}O6)x)*FOs4qG;-B5cfcr zruWp2{1AkSnEaL>uRZ|Dr>Lk>ytGy(J_u+Uu1IoK3iOw(*90Z6*oPu(CRsBBB1mcw z7Txyp<55XIC@KY@Nre{x!7kO%Rv!H1d`c-FKnI;l0!TWts)=1D{C@k{4Dw$uX(^D~^E#HXmuVva}6=v_iR;A1xU zu^@)q-}+^*&$sD=(;c8NKck}ndfQ$LzE!I9uMZzU0B1V`*!1(GKXnj*!UT;8(5y(v zT}c0xTf7ADac2OVznEfouSJ?$pdzFrCO}?ez~C0x^zK^{0er$4z!hK2v%A+Koj2^E z4^+fFgg0Tn&dK2W^^W=Fq%8z+jx&JG56yhbW)PF^0L46n1B|i2`fs0ppPcYq8vt1l zGnEA~^-C7|?K$Ce=O?BiOhiL6cPsM&h0(P+>xkbxXZ-4#UjJ9e;3sAwtf04KZdc~3 zbl#tG%Y=d8H{S-oDYRjf+uw?J-(K#?9&KmGZF{V4y2h#l(am)J=bEj!wVxuU10%oHK-L-)P?4VGxSB2}fuav*A(U zN4Kp0ju19FL%9Bf)`j*%nl!gUF*gB1J>N?2cwgA^kIROULw?={LKf0YXCck{O^+XH z+aaepK`}MqMre-mb0d85{DYSf!Ntx9*55XDkZpo#ybptVLNPPp1kL=Ik5UgkcThiS z#xK|)$U+)!gAG%EecZl3PIH7}W`Z*6$d6-s!f*Hd_6#9>(HX*)XSTL<7=&VK!V$(; z;onAoSw#q!I78U{_S)y{`{VSsLQTv~I6^a|;fs#0?)H2(A$-Xh!loy`h0M6!fvz-O zxP`t@6LXU!4f&^rt9uc`mz^P;{Y&#Q`yox5BNTHJj?f&j=I-Q|hZ|W)(_{l73u*3Q zA&_6ijNdTH^vA}&R3@j4{x3&q@oBQ&c8@oAGsu8#ad4tco^ zge;_)!9tqm7vB04cgQx3Fue)tVrml8j`=v(hQt1L2bGLm(Qe5|-H=WhcJ5vrv$9k6 zuO9D!c8WgsL-`2W*Qj8x8SREt6;|`{sJNPsnW*`QjYfXSk4<6NMHP_St@x-*C`9QE zQR&Lcs&Gj;3QNj9?1jWWI+-@s1vNufS;@yk#FCFabpyL*An#}BnqRE>Xe6u5s2D2h zN*=)LSPrQO7lum9DoScSeKen{83uM%1~mMX?7!^p4J$gq*52rpwcC>>_x$gFStu9T zrRU|+9u@ZRrj1tD3$wkv(KUVY`q@R-4I6zd7`|oW%so`uf}gti_IJ+JvEYp`8(ACJ809SU7m+aSo&INpxeF4{BBOyk_|c*OxN2u@&R znvV&ZdI^)y`(PU(e9sxe1z)}MUWY*_4rp)(zq3k%?H|c$Z1Yc9MOs~LcSm8 zx&QJmIt>KZ+8{WUR$y|1t&h#RD*JX?jBckVOlJ`&>I=mott5S6>GS(F{s9a>vcYf~ zt>`zc@U_=wc{&b6aZrm%c*HuCI}UvLYJ&K&Gl(rSzuMe^Ad16U9HKe6%7?Y)oqOkl z1o0DR5F3A*^khbdu|z{0+2RoSx&?EDy0QDF&j{i=XAtY(8TGRxi0Qi~7=5B4W+_P^ zp1S4@)+c`I3}Q>;t`|EwvBdO-DCQ{~qFGU#kFg$;CGREY{Fx1i%sJo3!kg9~-#@P7 zFccG&B!>9%n|e7h{M;GC1)q;A={O9md-}a^-A6s*dK(N`fMeR>f(cJM z^3CbHu9bR3F;`*b?X?(SIpV? ziX-2g-j*olD;%O3;BdF{%#p7xA&6f%gSff5_6|qBIlUo@`3i?szZ*PM5 zr89_44{crT2x9s!aHT#`%vTshYn3&3EPwgykvqsae`NzA3vljd0Z!eR#|Aoa&S}li zFQzL=3~~9~<1QdW++>3xbIzP${avF*Ir1y%4N**2P$d&P7|kU$d^r2)t!q{h#IKz} zY?%1N$_@lkOjkHWbM%3)(cJ#m;Tl2Q> z5WjH-vE}IvA9NsyV#2~9nh6Wv(|vf$N#78}Z=FG`e|ynPN6tCDEm2HZI7D;wfj8%& z(x(@abKYVDB6H3&nR9O3@pBg^&N;mqiunp-sG7wEc+p$$m>aLZ91Bjq149z@`3q)H zIHpzKeArWz?5T{DRx!yQ!$1k_V7rPwSiq*M8hsr(@aTz)aqaKTYd=YAYWT%>Hm>W5 za@_mY>VbL({T6BU9PQ04uP7@p#aMh5m6aEqS?lSg1QZq8&5%J3q!jDXf>3QnZE+k6 zF^FTly}7VUs5}CB6MD2vL-kSVb?7IcYB}naYb}4fy%S0L{9rx+N&4{nf3J0BYDphM z##%j8OIH;7xx%8#YkhgX>dNqt$l190Of2RLVCqKk8z@YVg6_fWUf!`BD1&eg7j89j{%lNRc z+UhcXhmKFsr|61h&@gMurhmVVk7kih(DC`~z3K_~p!aSud+!Ntz4xpic4b-89El1d zP`hA)7qS6i4o&h~x*&1N6q!)u!8?#iss_p^kg==+3B+IAMRW=A5KfELD_b8X9qmFy^63oYV;fR3z&Uex2k*tg9w!^+3D^8XHm& zx;W6I(V|C&_LPAk8y1^U_kLlxVbnsIMmW%!6isFq)jd-LC6PeAasQF z`?5~fEC(!(@tQF(d>z-qVqgt=D75n5Ix=20C6iQBpz8yH6B&o9MgYQsTnQ!a-Sm}6 zheknrf(V_C`&1!&XGzAyX|ZG~5TemE3E>O{E$BI-N5<(mWkT@;F&p~>RYMhy!Fo4w z5=89X$T$$HInY7DrK9gAUSB0khOR4K4*iu>$qyYz>#A@*@d7KR4i728-~l19Vm(xx z9Z_#0t5El$Ee_~XpxR=J`@|iHcWOYf1&t#J#VHj+Fwr9up9wl^iULriSwyj&bC+@9ef)g4#`vaXI8E&~GiCBZwxu-Uu-~R&iMg0 z(%Up?;-MXfVRYUm%FrzF!lxV_*|>2TF>H0laPyY2;f}*FI&EVyRAU019aCDHs4@R) zgCPrWxE;<=BOva!0r4K%Tx9y@)eBI+uFn1rt#lR*p z^v*dQ828w~IGy&nHLwt;=BFmP15pf$DGR5_yG4VKK&*T>G@dB#bw;uA@v%2_B#L6r z!W9R2&SLK1pF4cyU#UsgbWS}M#+o#@#pw^ehgx`a0^_Im4K?yNu8?2F1jMV>H)F@d2?{eS>c#jJ3`%wk(+UR)@kU<}Mtg8Fq7w zea^q;J;FH38OElDr|x#@r_p$#t zO!p5bJ-c5+N|xo7hbt>1m3bv)h1JE8Jd_wKtsIQvVwHJi<^SJGGJ)QKC{N4G zOD!qQOVul=RMJt%0rC_~foU=`Pr(MF(~e642yzoM^E9m4Z^tk(uxN6b9CES%1wcCy zzCZ*(L&v}C%&8Y1%-;*t+ChZYX@~$gal7RPvs5d60>A{=>H`J<$j@NUgO6uREz$$i z8k$^d-2%^oyx&QL9}wQ(aIxzK^Y@>-r(OkW?IJ?!I)wLsES~`M{#*jyhsSYhQL%n; zQ8FUZK>{#eAR-#L0tq-G5A-Qm3hdP)U}7lB1B&IP7R7^wGxO5pbMo^GG{Ewpv;YKg zSO8bhUX}n*uyhmQH$<>}cs=a~^Z!{Ve}ax?>mfob@MtzzmfbLQ9xzzmQaf0H%NW4H z5|1Sqz@dypq(SqSAsAt~R!t4?r&IC&D;{OBQr&y1^XW>{0_w${pc5B?3ti2O56^uSE&xs?5@=2U literal 0 HcmV?d00001 diff --git a/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt b/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt new file mode 100644 index 000000000..67cb06e8a --- /dev/null +++ b/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt @@ -0,0 +1,295 @@ +pydantic==2.10.4 +urllib3==2.3.0 +scipy==1.15.0 +myst-nb==1.1.2 +pure_eval==0.2.3 +wcwidth==0.2.13 +attr-dot-dict==0.1.0 +emoji==2.14.0 +mkl_random==1.2.8 +keras==3.8.0 +nvidia-cuda-runtime-cu12==12.4.127 +torchvision==0.20.1 +cocotb==1.8.0 +wheel==0.44.0 +imageio==2.36.1 +dill==0.3.8 +pydot==3.0.4 +transformers==4.47.1 +sphinx-book-theme==1.1.3 +myst-parser==4.0.0 +traitlets==5.14.3 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-curand-cu12==10.3.5.147 +kiwisolver==1.4.8 +pygame==2.6.1 +greenlet==3.1.1 +pytest-profiling==1.8.1 +requests==2.32.3 +aiosignal==1.2.0 +aiosignal==1.3.2 +Sphinx==8.1.3 +torch-summary==1.4.5 +Farama-Notifications==0.0.4 +sphinxcontrib-plantuml==0.30 +ptyprocess==0.7.0 +pexpect==4.9.0 +yarl==1.18.0 +yarl==1.18.3 +filelock==3.16.1 +filelock==3.13.1 +datasets==3.2.0 +datasets==3.3.2 +bitstring==4.3.0 +triton==3.1.0 +py4j==0.10.9.8 +pybind11==2.13.6 +pluggy==1.5.0 +regex==2024.11.6 +cvxpy==1.6.0 +sphinx-test-reports==1.1.0 +jsonschema-specifications==2024.10.1 +fastjsonschema==2.21.1 +pytest-xdist==3.6.1 +smmap==5.0.2 +onnx==1.17.0 +tornado==6.4.2 +GitPython==3.1.44 +sphinxcontrib-htmlhelp==2.1.0 +iniconfig==2.0.0 +threadpoolctl==3.5.0 +cycler==0.12.1 +tzdata==2024.2 +tzdata==2023.3 +certifi==2024.12.14 +certifi==2025.1.31 +numpy==1.26.4 +gast==0.6.0 +frozenlist==1.5.0 +opt_einsum==3.4.0 +astunparse==1.6.3 +colorlog==6.9.0 +grpcio==1.69.0 +jupyter_core==5.7.2 +torchmetrics==1.6.1 +gprof2dot==2024.6.6 +nvidia-ml-py==12.560.30 +multidict==6.1.0 +etils==1.11.0 +jupyter_client==8.6.3 +sphinxcontrib-jsmath==1.0.1 +tensorboard-plugin-profile==2.19.0 +clarabel==0.9.0 +idna==3.7 +idna==3.10 +pylance==0.21.0 +ipykernel==6.29.5 +matplotlib-inline==0.1.7 +jedi==0.19.2 +lightning-utilities==0.11.9 +namex==0.0.8 +kornia==0.7.4 +docker-pycreds==0.4.0 +mkl-service==2.4.0 +fonttools==4.55.3 +tensorboard-data-server==0.7.2 +beautifulsoup4==4.12.3 +Werkzeug==3.1.3 +Markdown==3.7 +asttokens==3.0.0 +huggingface-hub==0.27.1 +huggingface_hub==0.29.2 +pytest-sugar==1.0.0 +tensorflow==2.18.0 +pytest==8.3.4 +joblib==1.4.2 +ipython==8.31.0 +mdurl==0.1.2 +optimum==1.23.3 +pytest-metadata==3.1.1 +debugpy==1.8.11 +absl-py==2.1.0 +mkl_fft==1.3.11 +sphinxcontrib-serializinghtml==2.0.0 +MarkupSafe==3.0.2 +sympy==1.13.1 +six==1.16.0 +six==1.17.0 +multiprocess==0.70.15 +multiprocess==0.70.16 +snowballstemmer==2.2.0 +zipp==3.21.0 +ale-py==0.10.1 +scs==3.2.7.post2 +find_libpython==0.4.0 +sphinxcontrib-jquery==4.1 +decorator==5.1.1 +nvidia-nvtx-cu12==12.4.127 +prompt_toolkit==3.0.48 +charset-normalizer==3.4.1 +charset-normalizer==3.3.2 +nvidia-cuda-nvrtc-cu12==12.4.127 +evaluate==0.4.3 +tensorboard==2.18.0 +lightning==2.5.0.post0 +py-cpuinfo==9.0.0 +prettytable==3.12.0 +nbclient==0.10.2 +execnet==2.1.1 +torch-tb-profiler==0.4.3 +kornia_rs==0.1.8 +contourpy==1.3.1 +pydata-sphinx-theme==0.16.1 +pip==24.2 +requests-file==2.1.0 +jsonschema==4.23.0 +sphinx_glpi_theme==0.6 +imagesize==1.4.1 +osqp==0.6.7.post3 +importlib_resources==6.5.2 +termcolor==2.5.0 +importlib_metadata==8.5.0 +cocotb-bus==0.2.1 +future==1.0.0 +pyarrow==18.1.0 +pyarrow==19.0.0 +packaging==24.2 +sentry-sdk==2.19.2 +einops==0.8.0 +nvidia-cuda-cupti-cu12==12.4.127 +bitarray==3.0.0 +aiohttp==3.11.10 +aiohttp==3.11.11 +nvidia-cufft-cu12==11.2.1.3 +scikit-learn==1.6.0 +pyzmq==26.2.0 +Mako==1.3.8 +platformdirs==4.3.6 +nvidia-cusolver-cu12==11.6.1.9 +markdown-it-py==3.0.0 +wrapt==1.17.0 +tensorboardX==2.6.2.2 +protobuf==3.20.2 +propcache==0.2.1 +propcache==0.2.0 +pytz==2024.1 +pytz==2024.2 +wandb==0.19.1 +libclang==18.1.1 +nvidia-cublas-cu12==12.4.5.8 +alembic==1.14.0 +nvidia-nvjitlink-cu12==12.4.127 +click==8.1.8 +gymnasium==1.0.0 +Brotli==1.0.9 +lxml==5.3.0 +tensorflow-io-gcs-filesystem==0.37.1 +matplotlib==3.10.0 +tqdm==4.67.1 +annotated-types==0.7.0 +ghp-import==2.1.0 +pillow==10.4.0 +onnxconverter-common==1.14.0 +stable_baselines3==2.4.0 +imageio-ffmpeg==0.5.1 +onnxruntime==1.20.1 +typing_extensions==4.12.2 +Pygments==2.19.0 +coloredlogs==15.0.1 +sentencepiece==0.2.0 +torch==2.5.1 +timm==1.0.12 +mdit-py-plugins==0.4.2 +PyYAML==6.0.2 +gviz-api==1.10.0 +xxhash==3.5.0 +setuptools==75.1.0 +pytorch-nlp==0.5.0 +babel==2.16.0 +soupsieve==2.6 +ipdb==0.13.13 +python-dateutil==2.9.0.post0 +comm==0.2.2 +flatbuffers==24.12.23 +rpds-py==0.22.3 +psutil==6.1.1 +h5py==3.12.1 +numexpr==2.10.1 +optuna==4.1.0 +accessible-pygments==0.0.5 +tf_keras==2.18.0 +mypy-extensions==1.0.0 +pytest-html==4.1.1 +hyperopt==0.2.7 +tabulate==0.9.0 +fsspec==2024.12.0 +fsspec==2024.9.0 +parso==0.8.4 +sphinxcontrib-qthelp==2.0.0 +qdldl==0.1.7.post5 +nvidia-cusparse-cu12==12.3.1.170 +sphinx-data-viewer==0.1.5 +mase-cuda==0.0.1 +cloudpickle==3.1.0 +coverage==7.6.10 +pandas==2.2.3 +Jinja2==3.1.5 +black==24.10.0 +pathspec==0.12.1 +sphinxcontrib-devhelp==2.0.0 +mpmath==1.3.0 +pytorch-lightning==2.5.0.post0 +alabaster==1.0.0 +jupyter-cache==1.0.1 +stack-data==0.6.3 +sphinx-rtd-theme==3.0.2 +accelerate==1.2.1 +pyparsing==3.2.1 +docutils==0.21.2 +pytest-cov==6.0.0 +rich==13.9.4 +safetensors==0.5.3 +safetensors==0.5.0 +humanfriendly==10.0 +PySocks==1.7.1 +toml==0.10.2 +Bottleneck==1.4.2 +setproctitle==1.3.4 +opencv-python==4.10.0.84 +referencing==0.35.1 +nvidia-nccl-cu12==2.21.5 +tokenizers==0.21.0 +attrs==24.3.0 +aiohappyeyeballs==2.4.4 +optree==0.13.1 +networkx==3.4.2 +sphinx-needs==4.1.0 +nbformat==5.10.4 +gitdb==4.0.12 +SQLAlchemy==2.0.36 +executing==2.1.0 +google-pasta==0.2.0 +ml-dtypes==0.4.1 +pynvml==12.0.0 +nest-asyncio==1.6.0 +sphinxcontrib-applehelp==2.0.0 +pydantic_core==2.27.2 +transformers==4.47.1 +mase-tools==1.0.0 +more-itertools==10.3.0 +typing_extensions==4.12.2 +inflect==7.3.1 +typeguard==4.3.0 +tomli==2.0.1 +jaraco.context==5.3.0 +jaraco.functools==4.0.1 +platformdirs==4.2.2 +packaging==24.1 +autocommand==2.2.2 +jaraco.text==3.12.1 +zipp==3.19.2 +jaraco.collections==5.1.0 +importlib_metadata==8.0.0 +wheel==0.43.0 +backports.tarfile==1.2.0 +importlib_resources==6.4.0 diff --git a/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json b/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json new file mode 100644 index 000000000..752811856 --- /dev/null +++ b/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json @@ -0,0 +1,60 @@ +{ + "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34", + "python": "CPython 3.11.11", + "startedAt": "2025-03-10T21:22:29.549593Z", + "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py", + "codePath": "test/passes/module/transforms/optical/bert-finetune.py", + "git": { + "remote": "https://github.com/Johnny1882/mase.git", + "commit": "758710333d8ca4b7444930df91d86c7642652426" + }, + "email": "jw3621@ic.ac.uk", + "root": "/home/jw3621/Projects/bert-onn", + "host": "ee-tarrasque", + "executable": "/home/jw3621/anaconda3/envs/mase/bin/python", + "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py", + "cpu_count": 16, + "cpu_count_logical": 32, + "gpu": "NVIDIA GeForce RTX 3090", + "gpu_count": 4, + "disk": { + "/": { + "total": "75125227520", + "used": "60982046720" + } + }, + "memory": { + "total": "269555560448" + }, + "cpu": { + "count": 16, + "countLogical": 32 + }, + "gpu_nvidia": [ + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + }, + { + "name": "NVIDIA GeForce RTX 3090", + "memoryTotal": "25769803776", + "cudaCores": 10496, + "architecture": "Ampere" + } + ], + "cudaVersion": "12.7" +} \ No newline at end of file diff --git a/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb b/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb new file mode 100644 index 0000000000000000000000000000000000000000..c7816c0c9959c866c7effc0863411253fd092912 GIT binary patch literal 98304 zcmeIb37lm`dG~KK3&UlAnI6^#0yL<&(A;y*U41nw#%LnO-58hL-08kE)6mo1_QE1! z$h;y-R73$+WN}x-U5ODi#w9VnNqmV)jKUvx;t~@z29r1bzfV;?b*zZ5!j5}|+-v=)Ezi;W@Z)n5NuA$9m92UjRu`2(I9#*f_ z8jVkHtW-9fS3UQ<9o^~XyNCW)rCm9lFWY0$Lu%W$^^&L=SL@YnN!;z#8WW9Lv$3sO zt=DT4(L^<=wVO%3H@v<&R-?+{h0~(P3=M5KVAs$|qZ{Yrv(K40JC4p-Jg0qX#BqQN;VfbWLOt zudj{Onq$!!2b?-|*2byH>BaL-JFnH~G-~f2dTQkXI~EpZ=O1?Z>Dwn4b}Vik8=u*E z`tQ!{n4UgAYPI6icXsD{V|+DyYO~&IM%5%qCR*d&+SX>RR%<8KiEZs@qSY90HfnLB z9`olRU)ucT;~B9v@iCR7DqBqFPutmB=uUJOx?^X}&rBaZbl&LEUy?Fw2bB5laZmoY z5z4HKkE|SC*)-RC{^I0ZPvkW}ws79U(L)p6E53I<1=cUA8qdCPIU*e1Q(_)I@yahY zj~d!=;F1k}@a1Fl_29YQ!s6U?XJT@EVb{>wg}KFENJL%CVKr?)&b-0y3JMnNGhvqH^%U%-r~n`SH2QS*mou;YsG>&cwp` zvpt?Py*M>hIkK{GalY5tx@~HvyAU;e>7y%$E^O;e^|p7%&+kml%%|0Fscf8`Tb%Aq zbawQ*6MT-QSB|S}S(xl~&YhW?==64O?M+NDDfrHI=lIG|<2x3o&+g1mzM$9H*5kL^ zX6DZA&Q0tZ8mjW+;mP?=$ZfemKhzG}XfR+l|_bEUGgdmciV-oCJde#X_xCO(>+_8-ysEtP{NW;*lTowHLs z*2x;_^}2IY=QD0Iv$Kf8e{)#npy|b(ovpp@PQ)`5RgS0}rXO^+&&@2(+K-Q^Y@V1r zr#Cl0xp01Gwm03KS~!2#&>Ge*o_O#=Zzs!17ro(`Ya%|}SXsxTI%m_&dX;s^{A<>I zxw&!>V?H_Fo$BaEBbDK~-fVASQgkQt0IBQ9%AvN;!RI5D!>4CD^0V$jXL|cwcc&Rp zGlrv;<5Y~LWsa;Ivb8&5lw`g$DXL?2Q0he{ZXV1t^{Gm<9q}AhImBuD{P@gVW(u5C zIVv-8It_e{TXP1d$R3#OgebUjD!{!%fXKB4_h41-qRF2r%WB=%s ze1HGSAzfzjAPi~tW@62}5THjKtot>N6 zI@R0R$*ze=y))h2+3U>A@dKIqTSjFkKW!`9<7tc2gM6Pq*{+-Ss--VFl*tOW~0S;ZI(&qzL@Hs zPg__lqPQ|r*&^R$7E-sa^BvE1zv#;Dw0?AZZ)$OrpKPXqJvuZo$&7Q4Uz=2$F?-V% z_mhdanOP=pM`Uc)uBdWK<+$`Y@;NJezSIeoqh#TooniHpzFxG%9;Q<*59VoS-YC-L3O8Q;Q3| z(aLb;AlWL%y9^PvnvMF{%COOWluNcQ+ZEOOV=EQ6c1=vq?HbxRYKKtp;HcT+QH68o znB1LgF}H;bS2nP(pgEJMiu-I+2`u@dR?Q2;(+x>@w#${kZVL<%t}WIf!@#&O_i zZ+2#Uhk2|kErOMey)!dCMG5!QlPbruv#^|j%13M2L&Zg zsvIpPw$CmW)L>`ki3r^l$<<&>&=Nl>y5r-EI~S+y{F+}dqsQhzCFaLPU)xI9uT2T2 zPM1!y!tfglgV;jZTFq!8C{v~V4O`9>ySG_O-F{NUE+}{&s^Gpd>a?a_txDg`XA3h6%*d(9os+UN80@WI&3Cu;K=Jc4vLfrcvf2Eq zUu1rIALhOUVVX6?_bIA85SW01%Q`hRwY59W=4ZG=XWQJ&&dl!Of3mr9Ac$>zae^^( z58xxS(#4RvdmdkjmQ z)HylZ11h-B4y~-+x-G*)`1iKi%)gJS94i0r%*-yRCaxWQQZ!mQsdD6w?$kC}3&$tt zWpR_v_oivs`cc`EM`a6;x;xt-KK(7>8&hlkn?ovV7dTSTA!9w6V#A%1y#Np z8q8^`mru>y01`>hSygF@-8Uf}W;mUuP_GrEaLG(LtpPBweGaC=dI*1Hb+&WZ)IL(# zG`Srh(Cb)jJ$o`;7SQKy^Ar4A2H_Ye1pnv==@V%xengntDhHdd#%HFtO)`kjykNAy z6gXK>^LbNU{-5FP_Fo0e4=yiT$k2b zvHkf_WtiQt$1c!iOp+*^h&7@cQfUa*^aZ(i%CNsJmO7d8#eeIl$IxJ{<)-9>k6}W4Onzp`)KOY8UdYGaPOR!7>M8 zP&Ao(=n*t+qT}X)(Ipmsv+6sw+-%m8MHViYKo`9OIIA3vFyGBFD2H}1 z*}09r?u5V20eb8-dtl}G?Dza~l9dC(;TH%e4S4Z9^TGXGmnyC1l*)*VzdIP}gOW4p z9O4LSc5!}(XeXzP8EGEL?z9NA%56rrlMph>yZN3{_2ccw>5jWmRCA8w6*l5g7{@{K zD?|Q74IcjZXPnV_`jejU*x&n2gFWQS*}2JcgsI42sYn@j`Kmj0Zuk88jx^Y}^Z}LQ zy~g%2vW{XW+hD{oIXynLI3ec*oY&2p#mFu1zBcfemXfu~e=REvc*7!8_TU+)rJtcH z@uMmm&EAyT&Q7RokzXzXZcYAkw@&%?gmC!>fo3llngpU|~ z_>-UUw8tajPFB+3shzWSg>+hdz^EY{vin-i&+-JPZY;%S*PW2xuz*g^r(1^}Sx%e} z9rwd_I8%C5I9V`9wUBH!Mx(4+&&!IN z5<%wZbZ;ByJw3qmB=fO5HM;}+W|fk0w<{lKBHLL=XgajOZ>^wWI}S$gyfV5rATjef z%eXmH@r`4k%*>zMPQ!r|)4~GWkJzovkEFnMPE05BfJ3`3aAvxd$)V z+^xqA4?TFapwnwNVA?wJDR+G5Q`EC@$$FmshX0Xn|1L(T4Bi4}$>{aw`VDILpMLDG z{*rDSvUE4DUcGK?lF@q|#y?~KTWSE8*OOOY{koSwo}L`KWKSM^GED1m1ISPJfY>ef z8*25WUX5F|bDr?rkDrKn^{^$YJL{K2!y7kTWelMI@wN~DcxdFVp+inQv3ma(o_KG0 zz3{~9=~3KBswY0DT76ix`tV0S^u$L_E}TC9ys!Q3b!-0&sc(X7oI3Qd_lYR3F{1dy zPp*FRnvqi~hgu9Ld&M|d4RJP-v0BnhTFqM2sMVWMH1c3h(&lEy$86ED?m68_0P|Gu ziD&E@I<#GFC9Sa*mfEP*tdE?;ex4~iH$N|>4vwP+Csek`gDYFIMf{V8t7ENNwNZ`h zQ6kp#R=audt4@CR@6i6uw*8Mvn}7Q|Xg$ zGif~T;rIKKmr&oZt?xo#-*-Ryed>EqWedCR*<+`VEzYx#JlV^nh5Z|mP<0eF;zqMI za&qO+Y|+Qj$xZdfShJn9>y2i;QEx}BkwZ95mHOm>b*MUWCj{_z0ZGe z-Oz?DwlVKK*tg~#w_S4Gnh|VjoNwZtz+Zc4HiALJex%5oS|3_DBKz5+r{wrxL5@p? zqj;=QZ8xf|cB55ox7zhKT^uPaea895`ddljTD#VaYOU4?^u`tyTOd1stuT$iRuuYOyIGG$j;I`JzuVTGeww9v2P8GL z!u@1+e6n}ut`8B z9~O-v0Qe)*340 zc)3D$wWLCBvgS37B3pG9)imlHr$Sy6N7WIG@uIQX3n$|>wQ60JIqnm2Rh5}X_K7%B zWlr)UP6|YR*z#3}@@B*mNti}w9FB%F>Ml4~dnsL3_j&b;F`Jzcs zE2t?-K&BC1g3Sp0lJ+r%)0+e{(X%05wq*09*r&PE=6Av|z3!)ke%@V{SyuTmSSp ze7X@-#DgQsv4hsrxfW;%U>v|(t@0;s!3ApVwD;jaMPs1eYP|(SW?nU`k!nf4&bsh7 z{-AkPaB@2H@+~cUd=w1n=4dQ#H)REBp-!ktboLLgJ^JII{3Go8{C;6Q?=r0Cwg383 zVE+B{tY_G;o>m(bZ&n$|II1m8dun3>lzHT8)1GfU|D2x*?K#qRKxmI`|G(Y(r-ig9 zZZ;VCBHF`rs5YPXg&(|8XwOl$KA}CfzRPa^^gT{{hN`>nxSrkfXxo_A$?j=ebK4a+ z?2WVsgrOLw+K*vc@T-PQ%Vn~jI}p)j=~(woCIIc6&k4AtpPPDWi;Hm2-yZtchU~ z4uc0gXH7=V{WPboat>*Z$Kd+{%}H7NifB%jw1eRG(;Ry;vK~}Hkb|o6ifB#@a_|Qu zo=NUGn)Adnzw)_vu}~jl7wT)$h5FmK{l{C^C>rDFr;p}<#)LD-J8oB#R=WmGa1bwq zH~`Ex8!iRez^W045Dqm45C=4ZrxR~rh@k&C_B;`6@zbHHOLs!^`t3>hK;J&vauYUHd-S`J3_Y!*JVgv)PT)_7m08uBS{dKHw+8)OZ#kA06r?n94}p?yUy&U|Z}D-K}6W_$~Xp3z!1vbZJ0 z)k~s)Y^czxDhVYUL9LJtJ876-ak~Wq_1z6hA{jdqDF)>NVtBY8f@r10U546xAtK=c&DlWLDnEs?k<01XY z9F14|&TaQ4`co^VKMChAz@cT+A7@Yt=?}+(&1S~R)2b;n?_)oym_zoH5py5=ftNaS z+P=(w>UD2{SepId)nsinW;%FHUe_FJVW-^Bcv8m0@uq6-N*E8ehYaEIn!D#1&-UNC z@z%>(rGLY&(wC&G^xMAr9T<;di~tP`z*%t~9OuBQ!x#nFDNiIE#t1=90giSh^UsHF z*%#yl6Ntw+1w2m4nVi>xGdV3Ls26yJGekeiS`&|XXjuT(Ks7BGSv9G*sx_>)fc!$s z6W$CtL+OdFF6$-(% zAcUg>Wo_e=3T?W3lkMREHZ0Rsgiz=+*P+iR!CWt)x4UV{Ir=qDcQ@<1UhDCTC-in)ylK7 zVm5C|V*zSl3dOLeG6+QLwIG6G8$%R5%NSi5%O(U6LopUlhSGI9P-b6Cc0#$L`2i9N z1`{>sJuu8x?ap_8;i2Njc7k1(jT_s|hT6ROvlj#Acf+Hs-1s9tQqmhNi49rI%6;Ar z5!w>t+Ob^TQn~fT7YH*u(Y9agIkx%NeDX?~pLb)cVT3AnW0N!~)u;TgN{Sp1>{D~vQn>+rI9lR-0M{q%ZAnYTytHj+0J*_g zG(8pf9P++IZX8SR_h4Jc2{v5MDe}_X#!0xyfsKRVj3%oh-$xwS67j?=bztKhUPZhj zo)cju@Ir2=i1!@Nd1db{8{WW*eX?D#UzM)dcijAml;`9Sv_jI%dA7oaJkqt~CBmH> zUmKv2_>q^n?Qsd`^J;$5t z0xdDmvn``+4Y_#_&U`0wCMnkmcFDfcEMhr2X5sT?rrH4}Sb;4Plf2HpJjuz=R5woo zDHbZJN<81#+(c!QF^3013ovLY;^%~bBlEPafj-lKJl9rf1!b+Y43CI`Vq&T)=b0)j z(KifLp2h@q7Czt2(qIAfrwwt{8*Fv3@}jc6O3ift@9ii5oe+joQo`_g!*_1K{q==} zfoW2Ynpq~{;$C9>R-?uNzep|+0(E~|pCM4T!msUmyCG2JY)2w|ScdTzWMVUKRyV?@ z)QVKYZGz>ENeS8d;dhgBgz!ATwqFR3ZT_8?eU0W1w8J!T2W>Q~;%o~gE@L#U2E-5l zVP zT+Vpnx+j$@VLUj36fG8ej`38Uy0HCyEZL{pCHrmZlKsh--JJ1d3()AG4L=X$TJEg2 zMX(6Fh<9cSp@E{9biY@KTVOk>i~md&xAX=v4uoI~!g22T=~;2fVYNAe2Xx3*2H9)_LKt5vCq8ypyYhO!*@tLA|%&Eua-~sO=XNhD0 zqyUmy7RZO|S$e9SA6G4z9>Ql%wFwHsqei?~uspB?xV8n%$ae!L4+%aZwki$12+ATz z1a4n-KxRAf2D2pShQz3FelKt67S^Yl_~?|zUm2+dkea6{XCr(>T)#Y{4$XWr!l4(W zj_`8ELKP>Zf@*Ec;9q4R!OXFf45Th51}u%$Rwja_VD1>}5feff8!<0_LIf?4(NT$e z&2a-6NZb~sG;i;Fra~wvTY6X!w6?GbDurOCQC4j|#K#b^T}zcIG!18lmryR~wuvG# ziVz5)NSAcPia4X#TzIQHs4|v75qLfpAM?U|M7Vq6ZFpg&O1!K16~^Q!$k=D|IvdiU zUmdV^?j~UWgY2q(qj;cw$?%|?Zg?fQzsv&-Llmy3#MTjEB#K{7Zp4&>g_4jtjx1K2 z+kAW5#~vXp?!mVCV)n7Ef8QN%?q_k1tR;A8a!!`F2Qk2S!k+xSlXkq6`W|BI`(t0< zpZ(ciQ(ur?fPN3~PQx7)H)#UUIK*g+Q7p9c4C4*g;_RXxS6RrhZ4z|e?+$m=S01rb zXwXA#a|{i#{rkda{%)V6K{$&mcw7+;N|F>i)|;zEgM``_CNJ=@CVNjz&p;YPjHHTr zMRp$!@lufP+2(#4gx!ZEB`V?-8Gbmi&d69bc|Q%pOQS~oiDQ~8p+WdK6p47x(V#ay ze&e~4a^^IyHl${+J891E%y*i;a1V-K1GIA)=)g)M41aSW*lwpj2;=m`ITe>0ZUW6C z@?OV_gc@-8G=46ePq1HC6vG*2He$3~rAg2)vKd(R3u7SbfK3NdL~RNOOBFV#lp(NI zwmJthHOXz}pC1~zeRx$Glc<`N7ppR72y{$LUe$Yv7T9;K-!BeughACfNAupNVX&ce z%mlP$%3k#LgU%N!F_uz^e=QaN6p7D09f0w+LbGQ*^T;D%2Nv;?L1nY7x5E zP$6;zshC$ng&L%2^wx?rjA}m>5*Ii^IG1e^Dv$5MmvcY;QRhwQ52QfFx-|W9Vht0w zts|{U;=`u)9Ob#>&aLa-%+kyS8?rQCkuJ^O`o)L(DG&QHSc)JLxQXN*DF+iuPe?o` z2N53n5gXa>vZG*zqkM2$jZ9uzYq;wJQK;3v)SJlxj{=ygoFm0(ij?!nOQE`tLDy;n zkmLas_{;n3m0D$JfqjK>B{HNPZMXPpaFQyxl|Pa&n$<9q0^o1FI~jO)5|i)z6?0 zJ{E;Ddh*fXujZ95jN9V*Z4<~(WCv-vIHFIzH1pIXKa5V^dCpf<)KbYnuC%~)9}d+Y zwi4$@JZ~W>Ww|B9-rSRAs0a>0oNc6$FQH+qa2%Ck<$J76OO3BiIQi+IDo%biOr`Ka zrFenJjq%1!GAuGp>|MHggtQGYiVF1w*(qxaYAx5~Fn7YhnH!8FWs7*ewbKT=oNvMS z)0*FW5_BZ7i}Y^;pTA@1%_YCM7C6uC9Q}EoH{#KRNu(L~9)k2T>_R-aD^zRY~pSZh?$OVb`F>5%qp zhW2o%>!W>|_dp`IjaSW_EQ#EjcI7sS-0aGI%bPz21#(a+*rkCK7GUSy)QV!dP0F`| zq5)EpTFYW$0(*c!3U_Fndc+**(t{mJ>J*Gq9Oc3yO?GJm=ks7o&(m|1CTJK>qq}T$ znYd#;P*3tm6l^WVvA(oc7T|^uA4e!XB_}5mHX46U4{`{Z60LPFQlHT=@7+`!i~@8n zU>P|U$~isd%$uMbJd4sM1cN>Z#)ULA591V_7P&E&r0)x*l?JYmxE})~(sJqU zzU+NQ2@0TTPLCYR3slZ`N!$``Zncu+)Jp8LR*Eu$qYBRv0Ey%X(@-`Ts1O(krI8*@0iDq};Wx=*5+znwhfqtx2Awca4kbhg@` z-+2qwmYRHUhbOOMf*WZI?{B%wJafTM-(^4*>rJ-@vbnu_BK}7S0D72hyaWK*wtw$^ zH_OVx)FOuT1_llIZwBcTz&7Z)Bi>?9{exc8k6zBHs`JX zwpV665OYsSHa8N4c~rAB6Ji==Qo{KhSi^O0m&hV?+!czlS^pXt4}{+DE1quysSr1G zDSBL*3O%qmoAoBKr)DpdQ7P|dLSpTydb7le9N$PjT7b{*HB4wF6o|7RRpOOUAQB#= z-kYFNw^r;qE6`Er{^hL_|MPIWVw?CMyJBCp>!T?J3Xp_F2L?O((_)84LmVM#W`oH0 zQd(7JkSREUE;4Z#n3)7vinHPz&2S9}OM)`Pm?8zF0gw^URYunGI3^DWwc3E^6jfUe z@)aSrxZ#4mS;2FGDTBlT=1pE_i9Yfqr<6Novr(x#Qh+;gN5aR-ytxaB+e>O}r1r|8 z;Y3nEaor!;03>S#);$_$7)T&Tk4Q5h&!h0*ffhU>%WUHmJmxto1p&o4%i%Idcts|i zU%qUf3TwTAiqHne-MmuwLmkZg9LdSOlnA&{4w@$Vprfp&PbbJga|(r{c8;K71$f<+ z%w5a{nK5bTQ`YAmzZ*)PBi*>cb9)iP2r$Rzs*2{Q2*bsry#IyDVPzHCU8DRU=Gm1S z@B*7<>3|m>_SiP)Xlb>D@X_9u_sBt+;DcGVUlbhjZwe{=iK z`!iZQyN}#y1bQ#B7sWNUoTf0O_dIINd&DgCo3>VC7P8e|{wzSV z2T8%@LQP2zQ%-|S#ss)6WO5Hsq1&&#XsuA8N808K6|$|r{rexaMy)WEM-p4_7QfjI zKK0C(FZ`ePvavqOw&pG3USXT_#!v2zo^RQNPI=EaxshtwRv~6{85c^jS;_Wuq5fNw z64)p%&`!=vb0K$dW3o1jRWl2D55s-gEX1{KD&iH9A-1lR*!w(&`^gZ; zgH?REoffZz3}IovXUWjZ9`*8ru4T=Bv|Y2WO4sa5-u?r~P$_)Vtvio2^3YPTP9E!t z4?d@D_-0Dp*i0})0E-I2o(qCu4e2wTlm)@K)B6CA7eGVT+%TyD&{U7)!$A6CBZtLB z8gs9Cu7lI_G%s60emg`75E5^m{@Z&3csq>#VIv^+1TQauW-YxRvw-hV9q$6Kx;((N zj8EX8#xyuETsS->STtX44%{gk4p}cEm(KUOGU@09a?FfQ!s#tnKl;3VMFAMLt^8Hn zT87ENE=m7H4lVBN;MyS{brT?3d7?5S0y*TG6eV&+_}e53=+B2?J)#1Iq*}}LGX_+h zYxb~1&Z`~@_n4V#mX>*O~|Ng+ux2)gvk=x|{ zol=7hrxKV>Nen>>#-*<@t+p`bad@@*`43+4)Vp3F^zAXW147?y`@izFt4i!x2`<02 zGCP(8LRK4Zc-i}o7q__IvbBm^oUQfkZ_d>V0!30R+4VO)T9564+evLAvn8t_auQc>W;@R#M zwjl0wS1~Wme+Jlsh$Hl3_Sk%-wxD|j{b9FL39pF$fK)U5M-3JG=@0aW0A&^Nis(;Y z;z&>8_na;0y7hOwT2}1e&aK!lxkYS2x}Ra-GH!ca;jD@c9<^1t-tpkvFLx7> zfe8_-tO|&fqh?jSWdT!M`$xTb% z$^^LsZ?@S}!G#$+LIHD9^5!qJbr3p?7qJ6mI9WD?%owi%4W+0h9N7zP9Wo9H2i7v< zXNJ*3gZCA~2Slo{R;NtC#4Bv88_S?plu39yP_Z)Nx9pQA=(>}yHAmvKXkGxi@Of%jcJTC;*_;1REH;m`E4H!NyuvV{Yv1+; zF#l48s0(MRmD>3D;C^K(3$q>|Z`hc+lf*^Mwc5*1zFOX?>Bk zs>aIQY?b))VQ3~x^~Q(4_D11Bzhi3^9%O61>_^|A*4@hnS0{Jw%6QRh-ur^{WLN#& z%&uzN^YJ&{yzlX#G>6vGPH_Wx5W957mq_8`eTfHUGEL;egUBqe5OQfAY<2XXA0Su9k}9Cs#DUZ5E9MC%Ot< z^VO;?fkRlIam}8q=Nl^+UEY~1ByG`v24E91j#TxU9O%LX{5>!Q+#(N%ouMg!P6dW0 z?Nxx>21+2Pm9linywdE|inH@Lmw_Tdb&agbl2%RbCpBg2on|!Mj?MU}UA%FiDo4*B7x8j&ng|czJP&CV8>kaD_S`RM5Ui z4j@VO2zM~-A)q9iQc)#QQ|>5oA`Q3`Bv67&CS%QnrM4bH(V3y-)-9t8Dvj`12A!dc_T^6{l=K9uP&65|XrK&( ziiRDGQpZZM(MU9wW0K(!j$b-_9=Z=7e96;=Mx2?_h}RktboKxDK>&X~3j(PLYuG1? z_yspQlXhSD#g(tzb(my=d!nt?WP-ER-h7$(ldZE zHdJ4xLF(MgYM&Jv^n14bLW6AcUwP5@{s(B#kIp;d;v3mipJbc!R`IQ{?fL3I{>wf` zgV-_@+N^{IVVTe1b8o0vDGid#<@ObG&_Ei56-ULq5*pNClTT@o7xR6I2Dz&``!7RX zFR?*Q&or0jK`it|SrUdNw{~A_kdIusj~Ilq#$%V+Ad?mj_@ZjOG!b$#4)Y@nlOF*O zfQ%m3_B<2%;mk+=Nuq?n z5=Lx;JS+~&)0+VpxDH|jw3b0uwVY&Jkrem@0?7ER;9EwRaMg3B0((QA1gHc8By3?C zqRqAYhz?gI?g$r2`od6+e&#p-ctX zJWr~%x%qlk?NbjlEgMLTHc+p7sSBKz%}>3&$lO2#5*sjNL8X*jt_NCyW~pTJgPS*R z>MUJGQ%jkX>uvD3qTXuD0$oPl`LWTKO zs_lT-g>3upc)Jj~BHo56z1}jst$^?0y~@o;Pe17R|C0Eir`cL1KFHSkkx%`IT2~|g z0UP3=tA(=p7LJ_$>VHan(9>;eOni`S&fVYt{yxWt>bP92hI<@aM|#w|YFXp%HGGJh zcvZ|R;X|bJO{JXP#k8Lf!H4khRuQj=50O|VgVyhBd}#1xsrd7#jF;v^&bk@$q0R82 z;^c>W&M@@#^*>!Fw*)=IuG!|6AiHLN_R1Sm!%(pEdu+;vu4%CyoCqP&fV&}sq5>ye z5C#892-08$E=4bjNrJOn3myVWrDcO&3@RHs=6Fz2KvcqLkRs`569T|SwWW|Xq;-YO zEv2h@%(#5K=Sl`p911A;*9hn0I9uX{$k(Q5FBnQBkRwjY>hc_mrrKBF=og0BivXPe zB&HpLp~%9a=l}Sq_63P$#!GG)ZN~%@Wc};_pk%>p;i(sq%>k>f4p5#`&=d1co#$NH z0%IWs?JKl7BzbBDcJB8xQwimZ<{Z7W1$phaQ;4&7ak4a7u$+pN70re$xq%Oq)a!GG ziMmP)rInLFxl>-QsI`|BkcTp_w4epFRfXlROLdezs1)}-arl-usXnmTc_UnyF`>Bs zwx!`5(ZhwlOof)_6ch~;Y3i`%7b-nb@u@9e_O$LP!ib(}7jDCd-eVZiUtInUu>UIi zzLl~Z@~B~NCyO>Kwm2Ry!NY1}Isvte$uwDohppzS+-VTM7Gvn{<|d7UOt&w74kIwhmZQ;yV+g;z_#XEu@Kqj z{PeEd_f|$!nmmpbbQK6u%+^t0x!7+YB3IXXwh0)uig{^5R30zX#3|Yn&sDPz`KYXY zz)h~)K2+NcKIE{vvofU_DXwUW`2o&_wkd0>$Hbi3I2o(P0nRJ}x~fM(t;k^UFg)AXeg`tgN`g8Ibw&FXXh5c4O)j00tV6!7IbI-V8}>yE2Vq|( zkAemJ`5B63h&M(M!n-6jI1(`3$~QpyFxS1}7srdG)D#W44%*5Xm=}4nnJjWAt0@>v zPv5B`AHJYc%H|g+9jrhJ#D^2dNu2aBabSIIHN8=|AoNh%nwls}4|i1Q^YeqYA#ls@XEJ|OWSkof^A&dK4NIl&weI`l_dxXR^2@oE>w2q5RrX&HnJ8pa8f{ z>vHmKqCXlq+~X6bMdLh{;N+4wc7PUf2H1~%(Ga}h$xq@wSb@0$h-Rln!v+`v0f6f} zDlY@uIcFj5R)EPeV4qV?OdlENnLO6b^Sjc+dKzlp*;`d@V?z;Wuswn7(4vGC2P!R= z>ORX0aWDwTzHJClQtp08q`7?oJVqnWT5XjF}BgH-~dor>YIV4uVtXdp@Enp$aoz7XqgUffs0 z{eo#`fGB*vT;FQ-wOvafk!{aQ;(SpVCX-{L!R?mK)Vx)29FaE|xoMc^Te)<2j+X>e zYqQ+f<$6_5#2M2|Dqp91g&Fyihf5Dm!h{#_lKimh+L9M*PFAH{!}T`5JZW|bpJLvX z3}zhqu22rQwg#~*K2Jz^_F?!^Y_eu>p%{iIQdVS+K8WFI3Ehjyl}VwNDouWk2^ZPx zK6c#uC27JR+2z_KP57i?K`;J)zaY(Z85LpCYv9$v)Ps&NNae;G6RSfOUhcJCeIL39 zKKLiE7ZTL9%@-16TmRxBf3+twHoWSjHN&%AdZBtiZNeQ6Tp zpkq<)xN57&N^xJ~K#ptnb0Dz=C6PZyUl|F?^`m+Uv*`9VmK9w7%<8KhEos4@p?%8{B@ zs@bU>1qdEbQ}t9f5!*nk&Q;k;Lc+sd)2;rdDy+^|3s&Quwege!FnwN#tBfZPWD8gX zpNKM)BgBTnkF>*~6!X})jCPhkU#?M=l+k{~0~l+uY0zIcHX$^{iH?+lRf$E2p)$nu ztH8q}b9yD`+>#udJO(u)t~`y%gH;}G3u~N`LgRd|l$(&y7--Ht$2y7u>!epeEb~h5 z3RWOmw zjN`tg1R+)HWJHU7?uW1>nH|MwD3=6-k8vpUj)|dB)$0nv$DgSs=MSm@1)MK;eOHz* z)F2Rzh;SOxKetF?;#4j_aW(xL3>PFowrlNipvd&NRFMLbM^Tf&qQ>mT|On!VjN`$FI9&-};tY4xck)|*W&=qyO4(Tusl zy49==AU`p8J3t5}uxbhN^MWluf69bTf8@Bzri|^z z0l@<~qZ#8wKmpCt?1!nAsm41H4YMv{2lp;oG3n8ctQ=yC&hLE8)XezVyN1?AGTI}W z`oP{*rQhL(EjP*^YJan>JN>lTUQfUZhgc=q5|WY4m4l>wG{)auLu&-p4p%nXk8Jh3 zhPbohp_L=DpFMht_H-Ad_F;Bk_AhauZjinMJFJ+UM+!@yasEONx3&tCj14RzH^|7A zO2rmk=rErZNTnU;r5%rDr2QZ(ZT&1G?ZeO#)=QRJLLww>|EX5gxuQyBm^)L2y0J^@;^hZV={;iG+nvHU!&%K;O79yUn>d`w|TW7&o%{#DFqb4#Dz> zhPG)ZLYFQ~7yQ<|S#-%E(C{ULG22|n9)uL}#eo^d63?!J33wDGjmiM!@e-@Wsrj;i z;(KzYGscm#kVPRwx@y3Bok-<`KqfgVH0&q<^_$adbCOaVx3F+ifUZn>eFzS3-jWj# z1BxH0prxBhd*>L9T$+tVRIAL(NtvmD&mmM%Z<+KL#5@2%7g$Av1@KWPl=RZ=Y(3EYzP$Ap=fBMP~?r|#gkhKfXk|YT` zZDUN51lyYLU3|yhNQGJhs8Gy^?FJPpGXy30k!Iuf8>`2JyrUM@r#@`HmolMpcevdk zLgi6G4JAM3AYGy3N)18x4gys74u>lvKv|E&{ZT>Ls33x{mCdgxNdhh{sU^re9PT+o z(35X?&tF}|nhmFy4e;ecfb5!m{bxU5qJj$HlpOVQ(SqE3+Nxo};Pk_pVS@eK0%E`% z?=a0ZsJ`F5Dm9CEvxEblAp-%C>rr&CQ#7b z;&$*7^Y)?Y%1Z(a4=*w9ha6`>b|fB6r5fO3`fv3JM#7bmWQtB%S-wss8=RNn&F2w( z9F+9k&*ncU<=iFEzTmbTm3BbFtbPS1DpzYb!}VgN3_EwV7Sj-8B8*}V{v_?PPKMAg zQs#{_eEC%3t^q+?VJyhoyb}0QyOD;$g_)~urM9e6#!MD(F1TvOkW)f_;4_1ce^kM_ z6U&|y>6#X_4NcUxxtOQ;i+8?I*vU+WoqWSkpVwS^3Gn_NN`!#%h9g>OQwzmo5SAc2 zH97sqZdSjnd%}Uj+&B!E%|)1-ZT=^I@tS_-Hs70G*fq2<9&5(XS$ktphPkE0j`J~_ z@cZ|V`|_70VfOQF-I6fd*8TZE{aZ`iCMM_49y@((alX5~_c(5!9j-SdEe8F-nV6zU#Mo;sZBln`3<7Y^V zP)bF-Lh6&TOZc1<_YohsLAD=p-SUgYuTOnANX9Imx*2*A?>Xx8fA4Pu z)4PTG*d_ai-}-#&0~dmH;EsiX*)#(g*Oj=%aelkTVJOGvI2)Abd2?sl0VL+6mTM>N zt0A%1sjVY`xgZp_LqR@i;#mdHA@EY1%?~t4+aaWYuTZH)q?u&guF}w=P=t^MVnjm& zQ`aCZ8?>NMz={$BVOi{1Qg48@rL2={nq$ogAwbadOt}VO4|tQ+J%A5c2}qyAVBDiX z7Nn2EwFJ~)h+r^UgPIyLVO!Qkq0e8rFA>Ue~+k!H^N1c`pWNQbv zL}&8#=BaEcWz*$|i;S=`83tgqE9h6I0ja!nbc3=cn$dA9%(`vXQcva!r?@Sv!*KFr zqCXX41ZVJOo9`K7^s~Ra9Q3~mj%#&NT%*xQoqGnPapk0kEZjuR;adi!Sl->Nr@Z2< z^Mn}TrliEkHveO{y}q9qdFDfYxQ27*%Mc^aoA4^+OhimtC(Qr#8znaA99y@<2HCp5 z^xt2m?!lA>%W#b>W-K|;px7Ys1g%2V2!UIILVf=qE`EUANpY@ijk%M;Hs`yazH%?* zMJ5)a%sq}VPFe~le|=(vO$bA53U=zNXA=^%QFxcoUTPC64G$vjOa;9g-f=l<1lLcQ z(dy;S_D+cVu3H?oPmrZ!pT9meq8j$5RG0lVH~Xm(0m%`86{`{+geSTe@;#6v;hg6l zdFeY?wa>Gw_SHg-?5cgyTfYr8%AA9OdGV4bWijV!WX18jcpgj@;V$bdlP4?c)$cEL zSI`rnREs$R{O^EFfR6@X&(OLwsA#^eN-F&Q@8uXVA9SK;!2v$Kcf@k+?gW6ZC`a-p9;XDmt8U_M2R#-q zYvnHi-2%Ip#(4#_2)X9cHMpi0lqA)%q^yp_JquSgHJJrcrOq6EDrMZ(EF^R{kzDC0 zpmr+ID{*CW%F1=4eK9||Fsle^jQ>}JK^x-_23u*WUCHtqsI_u=u}m)QR8Y_es>P8D z3FmO6B-5F9(+yU<0#*IgO&|ONVH4-4Y~o)H8T!m8F6(C#jxHo@+AukrV<5vOJlSBf zA}e7NKb&oSU(zqUz}9WjFWAaI_J6KlC8mYrW(!ky9Sf+*{87%d#Fws-j7bhR2^d(~ z{`IpD9k)}+&;_>nLWXSXKl0hz?|Cxxl23m58{#1MLfabSAZMHNmCwBR`ZXF-QO$=` zRI?!!E`UWp%Y;-62w=&TE)QU-`mhR#6)Mgam(zMnPwLc3js zn$~e|5so_%_JtDDmBz`WtVQ8gh#^hJ2P+#_OOmP-6SF~i^CWl#{e4BTAFI#-i6YYO@h|?vRMnyQfx3!Y5PGdhrHU0 zJQi;{)Sk2xyreKwIeZ@uc$%sk?;CPIr{kkNq?Vix3GlRdbi7L2RP3K!lg64s$F+vgiz(_xV2d91l9>+fTSg4Eej< zLjBfsq5kgYK0c5LS%0$Z1sw`|0(??%w&_7}069FA5fo4MswSCUec1p_6+)Se#~y@6 zD#2-`EvMCSna7?tpX32^Su9QgMabJkoHRld3iHVURux>(!y>_HtpP)TcsQ(tcNh7* zNOGf)JwTRJTIK{73v#4ih-3d_+x(UQ`k`PZX~Q8R>7Shb=>eg z@ynOgN<9m4l@{=YFudGzn}GKY*P$UK$>L2W+y<#OQjQovRYnT4jd|C|hJibiH;TnU zTfN?930gS0X!UamyB-YTPGA=M=LBTsnN;NM-GT4 z{~?2yXM!wc{WRH_5G}9|5}xOXnBjN*9KiJ|P-}XXcT-b!$v42V#sV+IP+dum=8hC_ zE5tTLB$?J4$kk{zNgA|hPPk|7p7ZbE(leDKnV-SRm%per&XTYnCdc(JS7+|z0+zxMKvQ+Y}91TLg);Z#bJYw=bUzt(bcBd&Fg`k18lTc*Wl zci7jZ2R7~&E_AW$fp8()fgio%#n;^9zH&RZy!dj7A$o~x&ULcI+V+Fk!?H9N8jwDLJ=U9tyb*3c6_OYt+?rPGC`UEm ztpXA6r$t~S+%#3hOVc8^R)n-je2DsJk=Noq$BN$Z!MhKaCHs>6l6~#_gcUik5Fn7E zty&>S0kSGPHjeXntqk23gRsI}!8r>^K_rO6r2w!2m()Dvz2Ovv1|kq-1+0>{KMhx4 zlLtZhli*ULMro!nA$al@NJ}b5#k~10Yc&7!2)C}t(GakX^gvghikQZhu8yF)5o$(x&~oFqCfx84TETlq#|z+?9Y;A zte1mnOT9ftRFKVVyR@U)?$DO6c6K~NY+EmLJrGi7J8;A2K2}2N$b>?Y9%I}pB6Z@( zYe?Nc9QU;!2qn7I)h(3B*8T1;N^VqaX@k<)}ek!E`RpBLPms@za(~OUvrp~Ys#FM zmu5r`OuGEs{VBCM%=FZA)vQBuylg_W_TAVa$vV0l)*&b2&^okP{O)S|DUnIVpHHbR zm)S2ncIen*OVL^^MfbpMLal$i=*(+brC;t==_}J!`Y+y;VL`!`UkD{iuy&0Tb~4QY z++$+w*@cxZg^qU_IGM{$9UK`KK40NU0A;8aECtqvJp@79MFIIH`&jnaUJ+6uO0&qz zdm5FRPGC#sg-0+3K-};YqXpR4iMwKHxr+2xVl0f`o1tXx%dy6@cYpdykfIh|{PLhwX@xSJQ;-Ag`tZCt_8g!pF3HLvFKVy@X8 zIobxzgpi#pI&=~S+ygtCDxk}c?>R!~XeLAFW=ookK=rzC5~8aRMh-z)c+>RMg)%l3 z`_-j{aZ3t7`YXv38pK9&=ce+{;PFsp@`spr3+rLb)S{$`CkhV|aJ19I0Zn_Usv8`y z6!rhcA78runPB@@xMfB}zB7eQ)F0YzQ=jE)*ME}65m@~_Z{@EQ; z_3U>&>UfD1y3Dm-Qfb@f|L~@-)BM4=3UT`i%LI3AlMH+SMdFfgk^o{KZY+^D;Z4my ze#{l@sIPLZd712}wmCn2&0d*H8+9)U6@vHRBdl0vCB&yDt5J%{SIHD~?;$?;DyW!O zLVU2e6^VJJwxD|l@xcpZ*|wlEw>XFo>%@MyxbDJ&7%{PU;0d5&UE2OLfb84@CZBc3 zeszZg1YPcyW|Pv!F3m5$;ucIk4i5xC8>|EXGcU3P$&HHsR*e82Qqn@_Oa?aHTQ@&a~JN8+b6=noDi$hb*`2=<`Dd-eBxkb1};bIOur)UWPRwq71vI-Oet0I}Zq6PLiO~0%lNu*I10iYtju|d6}N&+nfcn7J`LN+YbI;~6N`kjIXLrJZw zA0FB|{*tFfMT2fc5OfQ~PE5afgQ~H?q)EHOMD1Pz*>E*g(4g7^RnpN4W{B+-9m|lS zrMP8)&uN&ti;>ztN2(eJi;T$1(0*Z45g|XXcFVSqpSKwDbL-Xb zE+IeIc599DWYpxzp!Gt2u*CE}PUWw0m79BnZ0-N>+V>73Kir0db+lGPt0d^NocxF< zT(c%lBrr14Gq0{m>5n#i`z^wLu5cX?_G8=szW@A|ZGV_fd$UcaO{~>2InDAZ)MMH= zz3V!**ehLYt`hcRn{)BE{%UWu`V3}21h=k&*{3Gi2)#kVAM38v?9)U@&&{oZfjOEOvM0`+PMeq@yG~#A(;EyGvT2b5KmmNyga_vEJN04-&w}kv zZZ~~zAMuCUF-gOlk~WAq7H&6`^i~e+1h7O9v%9EL$Orc{4)WD9e5%iq;2=qXKT;kD zD~JnFWCj{$;B{fvC@E^(jKIxYM9#5)YB~v!&Y+@sXsGx;gG~*5$|Eh{o4CnAQ}yZ< zC)fbP*j_tPZ_6wo#yn4mgZ`LgJhefiPSvVpeRdV`Y< z#|A@Nk7t8vK0FmdpaX@iRecGx03j%q$=n8JhG98v9Y8Ye2}EAz1=G{64|xU2xglzY z1N|%# Date: Tue, 11 Mar 2025 20:15:26 +0000 Subject: [PATCH 20/38] add debug file --- .../transforms/optical/bert-finetune.py | 54 ++++++++++++++++++- .../module/transforms/optical/playground.py | 22 ++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 test/passes/module/transforms/optical/playground.py diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py index e2ac5438c..15001b2cd 100644 --- a/test/passes/module/transforms/optical/bert-finetune.py +++ b/test/passes/module/transforms/optical/bert-finetune.py @@ -26,9 +26,50 @@ def bert_onn_transform(model): } }, } - model, _ = optical_module_transform_pass(model, pass_args) + + name_args = { + "by": "name", + "bert.encoder.layer.0.attention.self.query": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + "bert.encoder.layer.0.attention.self.key": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + "bert.encoder.layer.0.attention.self.value": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + model, _ = optical_module_transform_pass(model, name_args) return model +def test_bert_inference(model, text="This is a test."): + """ + Passes a sample string through the model for quick debugging. + """ + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + inputs = tokenizer(text, return_tensors="pt") + outputs = model(**inputs) + + return outputs + def main(): model_name = "bert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -68,4 +109,13 @@ def compute_metrics(eval_pred): trainer.train() if __name__ == "__main__": - main() + model_name = "bert-base-uncased" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) + + test_bert_inference(model) + + model = bert_onn_transform(model) + + test_bert_inference(model) + # main() diff --git a/test/passes/module/transforms/optical/playground.py b/test/passes/module/transforms/optical/playground.py new file mode 100644 index 000000000..d0d83d6bf --- /dev/null +++ b/test/passes/module/transforms/optical/playground.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + +def test_linear_out_shape(hidden_size=768, out_size=1024): + """ + Passes a [1, 7, hidden_size] tensor through nn.Linear + and prints input/output shapes. + """ + # Sample input tensor + x = torch.randn(1, 7, hidden_size) + + # Linear layer: change dims if needed + linear_layer = nn.Linear(hidden_size, out_size) + + # Forward pass + y = linear_layer(x) + + # Print shapes for quick verification + print("Input shape:", x.shape) + print("Output shape:", y.shape) + +test_linear_out_shape() \ No newline at end of file From fe9fff6c96448b9db0bf3b0e3a9f27af872a7176 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Wed, 12 Mar 2025 21:54:28 +0000 Subject: [PATCH 21/38] switch branch --- .gitignore | 3 +- src/chop/nn/optical/modules/morr_linear.py | 4 +- .../module/transforms/optical/optical.py | 33 +- .../transforms/optical/bert-finetune.py | 47 +- wandb/latest-run | 1 - .../files/config.yaml | 497 ------------------ .../files/requirements.txt | 295 ----------- .../files/wandb-metadata.json | 60 --- .../files/wandb-summary.json | 1 - .../run-60c8phhh.wandb | Bin 164998 -> 0 bytes .../files/requirements.txt | 295 ----------- .../files/wandb-metadata.json | 60 --- .../run-x0wxhan7.wandb | Bin 98304 -> 0 bytes 13 files changed, 70 insertions(+), 1226 deletions(-) delete mode 120000 wandb/latest-run delete mode 100644 wandb/run-20250310_202153-60c8phhh/files/config.yaml delete mode 100644 wandb/run-20250310_202153-60c8phhh/files/requirements.txt delete mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json delete mode 100644 wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json delete mode 100644 wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb delete mode 100644 wandb/run-20250310_212229-x0wxhan7/files/requirements.txt delete mode 100644 wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json delete mode 100644 wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb diff --git a/.gitignore b/.gitignore index e7118d94c..01935c9af 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,5 @@ prof/ mase-trainer/ test-trainer/ -test/self \ No newline at end of file +test/self +model_sst2/ \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index f93281af5..dcb35752c 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -443,6 +443,7 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) def forward(self, x: Tensor) -> Tensor: + B, N, D = x.shape assert ( x.size(-1) == self.in_features ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" @@ -474,5 +475,6 @@ def forward(self, x: Tensor) -> Tensor: x = x[..., : self.out_features] if self.bias is not None: x = x + self.bias.unsqueeze(0) - + + x = x.view(B, N, self.out_features) return x diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index 550484880..a82ea72fc 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -5,7 +5,7 @@ from chop.passes.module.transforms.optical.module_transform_helper import ( replace_by_name_optical, ) - +from ...state_dict_map import match_a_pattern, check_is_huggingface_model def get_config(config: dict, name: str): if name in config: @@ -29,6 +29,7 @@ def optical_transform_by_type(network, pass_args): config = config["config"] postfix = config.pop("name") for n, m in n_m.items(): + print(f"processing {n}...") if isinstance(m, module): new_m = instantiate_module( m, postfix, optical_module_map, {"config": config} @@ -55,6 +56,34 @@ def optical_transform_by_name(network, pass_args): network = replace_by_name_optical(network, n, new_m) return network +def optical_transform_by_regex_name(network, pass_args): + is_huggingface_model = check_is_huggingface_model(network) + + patterns = list(pass_args.keys()) + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + + for n, m in n_m.items(): + matched_pattern = match_a_pattern(n, patterns) + if not matched_pattern: + continue + + optical_config = pass_args[matched_pattern]["config"] + postfix = optical_config["name"] + + additional_module_args = ( + {"config": optical_config, "network_config": network.config} + if is_huggingface_model + else {"config": optical_config} + ) + + new_m = instantiate_module( + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical(network, n, new_m) + + return network def optical_module_transform_pass(network, pass_args): """ @@ -76,6 +105,8 @@ def optical_module_transform_pass(network, pass_args): network = optical_transform_by_type(network, pass_args) case "name": network = optical_transform_by_name(network, pass_args) + case "regex_name": + network = optical_transform_by_regex_name(network, pass_args) case _: raise ValueError(f'Unsupported quantize "by": {by}') return network, {} diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py index 15001b2cd..bab5eecbb 100644 --- a/test/passes/module/transforms/optical/bert-finetune.py +++ b/test/passes/module/transforms/optical/bert-finetune.py @@ -1,9 +1,11 @@ import os os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" - +import torch import numpy as np import evaluate from datasets import load_dataset +import dill +from pathlib import Path from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, @@ -14,7 +16,7 @@ from chop.passes.module.transforms.optical import optical_module_transform_pass def bert_onn_transform(model): - pass_args = { + type_args = { "by": "type", "linear": { "config": { @@ -57,7 +59,22 @@ def bert_onn_transform(model): } }, } - model, _ = optical_module_transform_pass(model, name_args) + + pattern = r"^bert\.encoder\.layer\.\d+\.attention\.self\.(key|query|value)$" + regex_args = { + "by": "regex_name", + pattern: { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + + model, _ = optical_module_transform_pass(model, regex_args) return model def test_bert_inference(model, text="This is a test."): @@ -70,14 +87,9 @@ def test_bert_inference(model, text="This is a test."): return outputs -def main(): +def finetune_bert(model): model_name = "bert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) - print(model) - - # Placeholder for modifications - model = bert_onn_transform(model) dataset = load_dataset("glue", "sst2") def preprocess(examples): @@ -94,8 +106,11 @@ def compute_metrics(eval_pred): output_dir="model_sst2", run_name="bert_sst2_experiment", evaluation_strategy="epoch", - num_train_epochs=3, - logging_steps=50 + report_to=["none"], + num_train_epochs=2, + logging_steps=1000, + per_device_train_batch_size=2, # set training batch size + per_device_eval_batch_size=2, # set evaluation batch size ) trainer = Trainer( @@ -107,15 +122,19 @@ def compute_metrics(eval_pred): compute_metrics=compute_metrics ) trainer.train() + return model if __name__ == "__main__": model_name = "bert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) - test_bert_inference(model) - model = bert_onn_transform(model) + print(model) - test_bert_inference(model) + model = finetune_bert(model) + with open(f"{Path.home()}/bert-onn-2epoch", "wb") as f: + dill.dump(model, f) + # print(1) + # test_bert_inference(model) # main() diff --git a/wandb/latest-run b/wandb/latest-run deleted file mode 120000 index d0ad236b2..000000000 --- a/wandb/latest-run +++ /dev/null @@ -1 +0,0 @@ -run-20250310_212229-x0wxhan7 \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/files/config.yaml b/wandb/run-20250310_202153-60c8phhh/files/config.yaml deleted file mode 100644 index 4b489b501..000000000 --- a/wandb/run-20250310_202153-60c8phhh/files/config.yaml +++ /dev/null @@ -1,497 +0,0 @@ -_attn_implementation_autoset: - value: true -_name_or_path: - value: bert-base-uncased -_wandb: - value: - cli_version: 0.19.1 - m: - - "1": train/global_step - "6": - - 3 - "7": [] - - "1": train/learning_rate - "5": 1 - "6": - - 1 - - 3 - "7": [] - - "1": train/epoch - "5": 1 - "6": - - 1 - - 3 - "7": [] - - "1": train/loss - "5": 1 - "6": - - 1 - - 3 - "7": [] - - "1": train/grad_norm - "5": 1 - "6": - - 1 - - 3 - "7": [] - python_version: 3.11.11 - t: - "1": - - 1 - - 2 - - 3 - - 5 - - 11 - - 41 - - 49 - - 51 - - 53 - - 55 - - 71 - - 100 - "2": - - 1 - - 2 - - 3 - - 5 - - 11 - - 41 - - 49 - - 51 - - 53 - - 55 - - 71 - - 100 - "3": - - 7 - - 13 - - 19 - - 23 - - 55 - - 66 - "4": 3.11.11 - "5": 0.19.1 - "6": 4.47.1 - "8": - - 5 - "9": - "1": transformers_trainer - "12": 0.19.1 - "13": linux-x86_64 -accelerator_config: - value: - dispatch_batches: null - even_batches: true - gradient_accumulation_kwargs: null - non_blocking: false - split_batches: false - use_seedable_sampler: true -adafactor: - value: false -adam_beta1: - value: 0.9 -adam_beta2: - value: 0.999 -adam_epsilon: - value: 1e-08 -add_cross_attention: - value: false -architectures: - value: - - BertForMaskedLM -attention_probs_dropout_prob: - value: 0.1 -auto_find_batch_size: - value: false -average_tokens_across_devices: - value: false -bad_words_ids: - value: null -batch_eval_metrics: - value: false -begin_suppress_tokens: - value: null -bf16: - value: false -bf16_full_eval: - value: false -bos_token_id: - value: null -chunk_size_feed_forward: - value: 0 -classifier_dropout: - value: null -cross_attention_hidden_size: - value: null -data_seed: - value: null -dataloader_drop_last: - value: false -dataloader_num_workers: - value: 0 -dataloader_persistent_workers: - value: false -dataloader_pin_memory: - value: true -dataloader_prefetch_factor: - value: null -ddp_backend: - value: null -ddp_broadcast_buffers: - value: null -ddp_bucket_cap_mb: - value: null -ddp_find_unused_parameters: - value: null -ddp_timeout: - value: 1800 -debug: - value: [] -decoder_start_token_id: - value: null -deepspeed: - value: null -disable_tqdm: - value: false -dispatch_batches: - value: null -diversity_penalty: - value: 0 -do_eval: - value: true -do_predict: - value: false -do_sample: - value: false -do_train: - value: false -early_stopping: - value: false -encoder_no_repeat_ngram_size: - value: 0 -eos_token_id: - value: null -eval_accumulation_steps: - value: null -eval_delay: - value: 0 -eval_do_concat_batches: - value: true -eval_on_start: - value: false -eval_steps: - value: null -eval_strategy: - value: epoch -eval_use_gather_object: - value: false -evaluation_strategy: - value: epoch -exponential_decay_length_penalty: - value: null -finetuning_task: - value: null -forced_bos_token_id: - value: null -forced_eos_token_id: - value: null -fp16: - value: false -fp16_backend: - value: auto -fp16_full_eval: - value: false -fp16_opt_level: - value: O1 -fsdp: - value: [] -fsdp_config: - value: - min_num_params: 0 - xla: false - xla_fsdp_grad_ckpt: false - xla_fsdp_v2: false -fsdp_min_num_params: - value: 0 -fsdp_transformer_layer_cls_to_wrap: - value: null -full_determinism: - value: false -gradient_accumulation_steps: - value: 1 -gradient_checkpointing: - value: false -gradient_checkpointing_kwargs: - value: null -greater_is_better: - value: null -group_by_length: - value: false -half_precision_backend: - value: auto -hidden_act: - value: gelu -hidden_dropout_prob: - value: 0.1 -hidden_size: - value: 768 -hub_always_push: - value: false -hub_model_id: - value: null -hub_private_repo: - value: null -hub_strategy: - value: every_save -hub_token: - value: -id2label: - value: - "0": LABEL_0 - "1": LABEL_1 -ignore_data_skip: - value: false -include_for_metrics: - value: [] -include_inputs_for_metrics: - value: false -include_num_input_tokens_seen: - value: false -include_tokens_per_second: - value: false -initializer_range: - value: 0.02 -intermediate_size: - value: 3072 -is_decoder: - value: false -is_encoder_decoder: - value: false -jit_mode_eval: - value: false -label_names: - value: null -label_smoothing_factor: - value: 0 -label2id: - value: - LABEL_0: 0 - LABEL_1: 1 -layer_norm_eps: - value: 1e-12 -learning_rate: - value: 5e-05 -length_column_name: - value: length -length_penalty: - value: 1 -load_best_model_at_end: - value: false -local_rank: - value: 0 -log_level: - value: passive -log_level_replica: - value: warning -log_on_each_node: - value: true -logging_dir: - value: model_sst2/runs/Mar10_20-14-07_ee-tarrasque -logging_first_step: - value: false -logging_nan_inf_filter: - value: true -logging_steps: - value: 50 -logging_strategy: - value: steps -lr_scheduler_type: - value: linear -max_grad_norm: - value: 1 -max_length: - value: 20 -max_position_embeddings: - value: 512 -max_steps: - value: -1 -metric_for_best_model: - value: null -min_length: - value: 0 -model/num_parameters: - value: 109483778 -model_type: - value: bert -mp_parameters: - value: "" -neftune_noise_alpha: - value: null -no_cuda: - value: false -no_repeat_ngram_size: - value: 0 -num_attention_heads: - value: 12 -num_beam_groups: - value: 1 -num_beams: - value: 1 -num_hidden_layers: - value: 12 -num_return_sequences: - value: 1 -num_train_epochs: - value: 3 -optim: - value: adamw_torch -optim_args: - value: null -optim_target_modules: - value: null -output_attentions: - value: false -output_dir: - value: model_sst2 -output_hidden_states: - value: false -output_scores: - value: false -overwrite_output_dir: - value: false -pad_token_id: - value: 0 -past_index: - value: -1 -per_device_eval_batch_size: - value: 8 -per_device_train_batch_size: - value: 8 -per_gpu_eval_batch_size: - value: null -per_gpu_train_batch_size: - value: null -position_embedding_type: - value: absolute -prediction_loss_only: - value: false -prefix: - value: null -problem_type: - value: null -push_to_hub: - value: false -push_to_hub_model_id: - value: null -push_to_hub_organization: - value: null -push_to_hub_token: - value: -ray_scope: - value: last -remove_invalid_values: - value: false -remove_unused_columns: - value: true -repetition_penalty: - value: 1 -report_to: - value: - - tensorboard - - wandb -restore_callback_states_from_checkpoint: - value: false -resume_from_checkpoint: - value: null -return_dict: - value: true -return_dict_in_generate: - value: false -run_name: - value: bert_sst2_experiment -save_on_each_node: - value: false -save_only_model: - value: false -save_safetensors: - value: true -save_steps: - value: 500 -save_strategy: - value: steps -save_total_limit: - value: null -seed: - value: 42 -sep_token_id: - value: null -skip_memory_metrics: - value: true -split_batches: - value: null -suppress_tokens: - value: null -task_specific_params: - value: null -temperature: - value: 1 -tf_legacy_loss: - value: false -tf32: - value: null -tie_encoder_decoder: - value: false -tie_word_embeddings: - value: true -tokenizer_class: - value: null -top_k: - value: 50 -top_p: - value: 1 -torch_compile: - value: false -torch_compile_backend: - value: null -torch_compile_mode: - value: null -torch_dtype: - value: null -torch_empty_cache_steps: - value: null -torchdynamo: - value: null -torchscript: - value: false -tpu_metrics_debug: - value: false -tpu_num_cores: - value: null -transformers_version: - value: 4.47.1 -type_vocab_size: - value: 2 -typical_p: - value: 1 -use_bfloat16: - value: false -use_cache: - value: true -use_cpu: - value: false -use_ipex: - value: false -use_legacy_prediction_loop: - value: false -use_liger_kernel: - value: false -use_mps_device: - value: false -vocab_size: - value: 30522 -warmup_ratio: - value: 0 -warmup_steps: - value: 0 -weight_decay: - value: 0 diff --git a/wandb/run-20250310_202153-60c8phhh/files/requirements.txt b/wandb/run-20250310_202153-60c8phhh/files/requirements.txt deleted file mode 100644 index 67cb06e8a..000000000 --- a/wandb/run-20250310_202153-60c8phhh/files/requirements.txt +++ /dev/null @@ -1,295 +0,0 @@ -pydantic==2.10.4 -urllib3==2.3.0 -scipy==1.15.0 -myst-nb==1.1.2 -pure_eval==0.2.3 -wcwidth==0.2.13 -attr-dot-dict==0.1.0 -emoji==2.14.0 -mkl_random==1.2.8 -keras==3.8.0 -nvidia-cuda-runtime-cu12==12.4.127 -torchvision==0.20.1 -cocotb==1.8.0 -wheel==0.44.0 -imageio==2.36.1 -dill==0.3.8 -pydot==3.0.4 -transformers==4.47.1 -sphinx-book-theme==1.1.3 -myst-parser==4.0.0 -traitlets==5.14.3 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-curand-cu12==10.3.5.147 -kiwisolver==1.4.8 -pygame==2.6.1 -greenlet==3.1.1 -pytest-profiling==1.8.1 -requests==2.32.3 -aiosignal==1.2.0 -aiosignal==1.3.2 -Sphinx==8.1.3 -torch-summary==1.4.5 -Farama-Notifications==0.0.4 -sphinxcontrib-plantuml==0.30 -ptyprocess==0.7.0 -pexpect==4.9.0 -yarl==1.18.0 -yarl==1.18.3 -filelock==3.16.1 -filelock==3.13.1 -datasets==3.2.0 -datasets==3.3.2 -bitstring==4.3.0 -triton==3.1.0 -py4j==0.10.9.8 -pybind11==2.13.6 -pluggy==1.5.0 -regex==2024.11.6 -cvxpy==1.6.0 -sphinx-test-reports==1.1.0 -jsonschema-specifications==2024.10.1 -fastjsonschema==2.21.1 -pytest-xdist==3.6.1 -smmap==5.0.2 -onnx==1.17.0 -tornado==6.4.2 -GitPython==3.1.44 -sphinxcontrib-htmlhelp==2.1.0 -iniconfig==2.0.0 -threadpoolctl==3.5.0 -cycler==0.12.1 -tzdata==2024.2 -tzdata==2023.3 -certifi==2024.12.14 -certifi==2025.1.31 -numpy==1.26.4 -gast==0.6.0 -frozenlist==1.5.0 -opt_einsum==3.4.0 -astunparse==1.6.3 -colorlog==6.9.0 -grpcio==1.69.0 -jupyter_core==5.7.2 -torchmetrics==1.6.1 -gprof2dot==2024.6.6 -nvidia-ml-py==12.560.30 -multidict==6.1.0 -etils==1.11.0 -jupyter_client==8.6.3 -sphinxcontrib-jsmath==1.0.1 -tensorboard-plugin-profile==2.19.0 -clarabel==0.9.0 -idna==3.7 -idna==3.10 -pylance==0.21.0 -ipykernel==6.29.5 -matplotlib-inline==0.1.7 -jedi==0.19.2 -lightning-utilities==0.11.9 -namex==0.0.8 -kornia==0.7.4 -docker-pycreds==0.4.0 -mkl-service==2.4.0 -fonttools==4.55.3 -tensorboard-data-server==0.7.2 -beautifulsoup4==4.12.3 -Werkzeug==3.1.3 -Markdown==3.7 -asttokens==3.0.0 -huggingface-hub==0.27.1 -huggingface_hub==0.29.2 -pytest-sugar==1.0.0 -tensorflow==2.18.0 -pytest==8.3.4 -joblib==1.4.2 -ipython==8.31.0 -mdurl==0.1.2 -optimum==1.23.3 -pytest-metadata==3.1.1 -debugpy==1.8.11 -absl-py==2.1.0 -mkl_fft==1.3.11 -sphinxcontrib-serializinghtml==2.0.0 -MarkupSafe==3.0.2 -sympy==1.13.1 -six==1.16.0 -six==1.17.0 -multiprocess==0.70.15 -multiprocess==0.70.16 -snowballstemmer==2.2.0 -zipp==3.21.0 -ale-py==0.10.1 -scs==3.2.7.post2 -find_libpython==0.4.0 -sphinxcontrib-jquery==4.1 -decorator==5.1.1 -nvidia-nvtx-cu12==12.4.127 -prompt_toolkit==3.0.48 -charset-normalizer==3.4.1 -charset-normalizer==3.3.2 -nvidia-cuda-nvrtc-cu12==12.4.127 -evaluate==0.4.3 -tensorboard==2.18.0 -lightning==2.5.0.post0 -py-cpuinfo==9.0.0 -prettytable==3.12.0 -nbclient==0.10.2 -execnet==2.1.1 -torch-tb-profiler==0.4.3 -kornia_rs==0.1.8 -contourpy==1.3.1 -pydata-sphinx-theme==0.16.1 -pip==24.2 -requests-file==2.1.0 -jsonschema==4.23.0 -sphinx_glpi_theme==0.6 -imagesize==1.4.1 -osqp==0.6.7.post3 -importlib_resources==6.5.2 -termcolor==2.5.0 -importlib_metadata==8.5.0 -cocotb-bus==0.2.1 -future==1.0.0 -pyarrow==18.1.0 -pyarrow==19.0.0 -packaging==24.2 -sentry-sdk==2.19.2 -einops==0.8.0 -nvidia-cuda-cupti-cu12==12.4.127 -bitarray==3.0.0 -aiohttp==3.11.10 -aiohttp==3.11.11 -nvidia-cufft-cu12==11.2.1.3 -scikit-learn==1.6.0 -pyzmq==26.2.0 -Mako==1.3.8 -platformdirs==4.3.6 -nvidia-cusolver-cu12==11.6.1.9 -markdown-it-py==3.0.0 -wrapt==1.17.0 -tensorboardX==2.6.2.2 -protobuf==3.20.2 -propcache==0.2.1 -propcache==0.2.0 -pytz==2024.1 -pytz==2024.2 -wandb==0.19.1 -libclang==18.1.1 -nvidia-cublas-cu12==12.4.5.8 -alembic==1.14.0 -nvidia-nvjitlink-cu12==12.4.127 -click==8.1.8 -gymnasium==1.0.0 -Brotli==1.0.9 -lxml==5.3.0 -tensorflow-io-gcs-filesystem==0.37.1 -matplotlib==3.10.0 -tqdm==4.67.1 -annotated-types==0.7.0 -ghp-import==2.1.0 -pillow==10.4.0 -onnxconverter-common==1.14.0 -stable_baselines3==2.4.0 -imageio-ffmpeg==0.5.1 -onnxruntime==1.20.1 -typing_extensions==4.12.2 -Pygments==2.19.0 -coloredlogs==15.0.1 -sentencepiece==0.2.0 -torch==2.5.1 -timm==1.0.12 -mdit-py-plugins==0.4.2 -PyYAML==6.0.2 -gviz-api==1.10.0 -xxhash==3.5.0 -setuptools==75.1.0 -pytorch-nlp==0.5.0 -babel==2.16.0 -soupsieve==2.6 -ipdb==0.13.13 -python-dateutil==2.9.0.post0 -comm==0.2.2 -flatbuffers==24.12.23 -rpds-py==0.22.3 -psutil==6.1.1 -h5py==3.12.1 -numexpr==2.10.1 -optuna==4.1.0 -accessible-pygments==0.0.5 -tf_keras==2.18.0 -mypy-extensions==1.0.0 -pytest-html==4.1.1 -hyperopt==0.2.7 -tabulate==0.9.0 -fsspec==2024.12.0 -fsspec==2024.9.0 -parso==0.8.4 -sphinxcontrib-qthelp==2.0.0 -qdldl==0.1.7.post5 -nvidia-cusparse-cu12==12.3.1.170 -sphinx-data-viewer==0.1.5 -mase-cuda==0.0.1 -cloudpickle==3.1.0 -coverage==7.6.10 -pandas==2.2.3 -Jinja2==3.1.5 -black==24.10.0 -pathspec==0.12.1 -sphinxcontrib-devhelp==2.0.0 -mpmath==1.3.0 -pytorch-lightning==2.5.0.post0 -alabaster==1.0.0 -jupyter-cache==1.0.1 -stack-data==0.6.3 -sphinx-rtd-theme==3.0.2 -accelerate==1.2.1 -pyparsing==3.2.1 -docutils==0.21.2 -pytest-cov==6.0.0 -rich==13.9.4 -safetensors==0.5.3 -safetensors==0.5.0 -humanfriendly==10.0 -PySocks==1.7.1 -toml==0.10.2 -Bottleneck==1.4.2 -setproctitle==1.3.4 -opencv-python==4.10.0.84 -referencing==0.35.1 -nvidia-nccl-cu12==2.21.5 -tokenizers==0.21.0 -attrs==24.3.0 -aiohappyeyeballs==2.4.4 -optree==0.13.1 -networkx==3.4.2 -sphinx-needs==4.1.0 -nbformat==5.10.4 -gitdb==4.0.12 -SQLAlchemy==2.0.36 -executing==2.1.0 -google-pasta==0.2.0 -ml-dtypes==0.4.1 -pynvml==12.0.0 -nest-asyncio==1.6.0 -sphinxcontrib-applehelp==2.0.0 -pydantic_core==2.27.2 -transformers==4.47.1 -mase-tools==1.0.0 -more-itertools==10.3.0 -typing_extensions==4.12.2 -inflect==7.3.1 -typeguard==4.3.0 -tomli==2.0.1 -jaraco.context==5.3.0 -jaraco.functools==4.0.1 -platformdirs==4.2.2 -packaging==24.1 -autocommand==2.2.2 -jaraco.text==3.12.1 -zipp==3.19.2 -jaraco.collections==5.1.0 -importlib_metadata==8.0.0 -wheel==0.43.0 -backports.tarfile==1.2.0 -importlib_resources==6.4.0 diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json deleted file mode 100644 index 2bfa1fc7b..000000000 --- a/wandb/run-20250310_202153-60c8phhh/files/wandb-metadata.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34", - "python": "CPython 3.11.11", - "startedAt": "2025-03-10T20:21:53.597208Z", - "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py", - "codePath": "test/passes/module/transforms/optical/bert-finetune.py", - "git": { - "remote": "https://github.com/Johnny1882/mase.git", - "commit": "758710333d8ca4b7444930df91d86c7642652426" - }, - "email": "jw3621@ic.ac.uk", - "root": "/home/jw3621/Projects/bert-onn", - "host": "ee-tarrasque", - "executable": "/home/jw3621/anaconda3/envs/mase/bin/python", - "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py", - "cpu_count": 16, - "cpu_count_logical": 32, - "gpu": "NVIDIA GeForce RTX 3090", - "gpu_count": 4, - "disk": { - "/": { - "total": "75125227520", - "used": "61026897920" - } - }, - "memory": { - "total": "269555560448" - }, - "cpu": { - "count": 16, - "countLogical": 32 - }, - "gpu_nvidia": [ - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - } - ], - "cudaVersion": "12.7" -} \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json b/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json deleted file mode 100644 index 76b497321..000000000 --- a/wandb/run-20250310_202153-60c8phhh/files/wandb-summary.json +++ /dev/null @@ -1 +0,0 @@ -{"train/global_step":350,"_wandb":{"runtime":63},"train/learning_rate":4.722882026920032e-05,"_timestamp":1.7416381701494613e+09,"train/epoch":0.166270783847981,"train/grad_norm":4.34669828414917,"train/loss":0.2263,"_runtime":56.552776547,"_step":6} \ No newline at end of file diff --git a/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb b/wandb/run-20250310_202153-60c8phhh/run-60c8phhh.wandb deleted file mode 100644 index 28b610420d3100fe8e76d63abfdf1515169b8a57..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 164998 zcmeFa31Adew*PMjWN8*j5Y)KExG!{S@9v1>=sd?|+{X3IsMFIV4Ft29MMTkw0VN1x zKsHe^;)b#axJ49=iY%@HMNv@!6;V_IxI_^6f6u*Do!d+3s^|RXy{`OuJl~_cs=De^ z=iYP9UCuZE%Hz)5@X~&tHdIdv@8{{}sr6)@d4R0=eG>jh&R2EG&|dB7^>#C)fq z;gbGCkNK}ZdOY1SYCT8g_Pnk#P*OH**!AJ-)q)c|`~KZKJYzH3?jpO#(CQ&Wib{tJ z4i`jn5AgQ#b_op+mlh7H^<><1^O*+@id0mEDl4m$P~?X4NJUXeq_k>S&q!oIRk)%e zTzOq}>J{dMD59PM${m z?d)mj;>qqd)YIspf9G~peX@+dC*ipWgwjRxX?~x4X2uDgVLgkBN~>=ea6`Zd8QK$` zOT5Pot*R=o%+JdkQdBjxdXTT6tR(NevZ1A=Bji9p$tww0Mtt}x>jb|Z@XL~_s)d1q zur|oAXxR79?;E~?OioER2&%+ zE*KFiE~{+2d|z+R@`~!xNMUGbBwUEkVc_1wynU;RBB9}B6@{Tl$)HGKA=(1IlDvNH6ga$`&?%=YD;o*wHT8~G<$yr5}p-5>#Sz)9iR2X6ZH+m-rdcF8I z{?lW;hZa=ewn9+=QB?R1>*)^n_6-*nhT@KlUK(H1UAOnknO<*6_y#~I9a1$Ee> zKKQ7pH2M*4zOT1?VOgj$TvA?)pEX&7UL;&mJOb~nth^j3qTd|g?Os}45*ic4Yu|zu2QHQi5J247pjfE*zPg?y}bs73z?BrhKk5^xDC>E!NhC_ zm1y+16j?j)9OUg~T6$$cS%uI9j`AKPG@KA#%8*DYF==}XFiMzN0jzCMS!JlI>{>LW zBA6Wi)wXXVNxnVO+YcR2+4T`z`TB5iQDF%GtrmX&L~jnd^0HEDec%fF45OhiQ)$D- zf9YWF0hQI|<+ywEE}~C`cwt*B5b4iAbH*Lrep%9V2Sb8^o=-(*D5+A2Z?#bN3p_~Rbn z?Sf|=Ty%r=K|gQLWG^5K9SO{#?Y*Ei;m~An10PpL%7sJuFCFRaZ}tn+LBhkfU5g)Z zxc5M!5vnYY6ci0EDhQQ_VYSga#bwJY$_5ojNG#bg2R>jB z9OHoM(gOUS!d&mM-ovQ(flX1{CnptVjVWARRaO~^_T@+69xB46l`zE;wC50O)Mcf$ zo^H9^N9THTy#3Ld3X5=wP{Ghh!L{XOMKCL&+vtaW2fbsld7oX_{p7-Xc&{%j2oEyb z;r^Cp{t^l2;g6Hqe9y(Nt z_YjIuH@m+a8Gwr&=RGL#mrT&!-0+~vvf}EhNUk@_+nroxK^U(HS7na+P;VBqeHa(H zF8(N%-5=`pntfMcQAMq%XD)w*6b9xp7Y|byUV+9P5)U(7NS3!7{0c0m2&Q6w*2miw zpGD9US-%ehz$a*KJbEyXJH*?I4nm`#{%{B!N7vj)d0D|w_Os^E1eP3osI0UYN0_G` zM>8MFkQ5^PxfEs)C_bS08}2W($X~ABsETDDS~^#E|mpBv;@qGY8PkLlUmQ zkHDQo8%4OFpt_{Gn76OWD)#o^9B>hKa@*S;~|UyIO-qBDxz=%K(kOO`u5Vnp}|GP=5@dx&OWOut3rz`E-EP^&%k1D>#NG};0U7l$};Mat$rnwofV5L+Pn`izeI$| z`jqH1$=>4;Ccwa{pDHdM6fS`CW4S|Ua79^(;9a7pWO} z5xP(MsHB`iaq1hQ-}LZyu0%u+Ytj_zBPrzz@lT@dX;@JeyHT_6GJghL99^@hJc3Zc z{H(XP^Ps^(9)f=lE*JiNkheGeJ5*L)Wm#h9+>7K~?@`|VL&L>`sV^)js-(`19zIfv z8|#uwE;*N609`vY7>Q5po!}dG)x>Xlc{^8OB*A)R=ue8_aEr;eAjm?0g0!lzu)OWZ z_%kEaMfw(^_G^Z%{BtC3^f^-FilPESIjFiVH$V(;evon1HRK=&CGs+s6~VQH5M{OH zbA8RO&9q&9{lgFuwVgyikCg0M=IU{;M|-n{Zz)KNP|uDr3s{Wxh{$@-va4YEh4fyi ze>vLQ4NZ~7M75s&EcS>W%*dtQt+1#vjA$xUbzNae^o=6MH6!WjQoL?tWALXXl&^AA z2S65cvdhz_tZg79q9;eMid*4+6G;bqou;MWx*Q|rlEQJ;6%Y&zsX(TJ{t)>idS~8- z+10bXeTs%45Qv1hwO(7))-AxE2Uiy2-}Dj=K?;F?ged7FnpXS>n1j7N*jEK*rGty` z60f)^H})vR$rLr;P#nhp$MSamuMj)!dL`MX%m*O;tFpel*?Y1#mz^L4kuU^=Z-}}n zp9l@F2;*)!wii6!EO^5Ryg(T56s-BpJ>96xv&G~#o1Q` z3E!Znr)Grj*rycn;8SX|)V|rP=810wE@ysQN-Ym<HA>IR6<{cVjWpGjBJrW*-=XBg^%{Bnn%0A6S@Yo$5uZ-m$wVs2o z%HHc66CKH2zq!kJKlS^_&v@nLeGJOBiccjG>{ipqCJ?W zTV0iPnPa@!^#09(p_ovztgQ`U2&%lgawu6R#uo*e{o$Reky#};BmR&`$Z)*b-rG|3 z!}-T;9=9i%W(9^<;D~dPadfA%Sn@|^Fz}qq&kS95(O=HG@DD8ZpfAfSimsq=Z8`;$?LMMAO#Y0TO{c;CheEFs40C7kpjd-f0ocEBXRcGX~goC^nk|^n10s!{AlkXmI*M5 zus8JJl^E-!2*qT9HfkM2&W2eOH|t93aoZ9>w9(SYV2t-f5KI@LJ%)?Rha!ICMv3>% zdp?Yb3_(YNTSrxk;Y3b2m(4jqyxAW-+eQ zGv#&uKH2l1ecMJnRv%|1N#xg~gjUa22o+P~ug`p_Gjj@dmn!I$c# z*?Wd_7s?aP#WVCy`3xr=<3*Rs;_kd?+-vsg-DN4L1ysL3aO>iummP!I(*sgoZ`dAB zR*!BQ+5F~$iBl`GfAI9m$&rp7nd7cMBXguYLzVTMt0XC3k_MiBO3vv;Re6=ygm%9% z(Fat0QlqNN$m-ec3x;aNicLReWdGvH8(LNp$s0CYH5557TpA{?5mxgerPo)o+1k9~ zqCt7(BdUg$mC{6T9>z^01Im&AqQ9-oV~#AZv^0+*&f?-oaUN3O(gGS_SdeSq|^dbC8IZ4#_E`nQO|fDYY3?U4@T}bC{pXF-I_R zhDRzQIhDwf!xfdM<_s@l2O!i87v^9%2Xi}-8zVWG8qeX=H#y;)>)|NMDtz9Y9CO5| z(ntT#3Eri<)@(f%&@xj)+v|x!Tl@HLA9r$>fsM?O1+$U|s`-+ZuLVxY$&r12qcX3k z>YD3zu6dT)VpeKatSx@Y46Nn0DSvVsDkTP$oUiEllEP3a8f%N)w_P%Y%&Ko{RLrb6 zshy2IlxRyjE%b5wzrGEi6Z70xQ?XRLfDe|`1wV0B<>R<}~W#aT5^esq;5yT$C| zH%`6F9liX>94&^H8pv0u5h^~v%4i+#f2ovc^>apR@#g6nZltBh(vtFJKWM2~nZ#&) zcIuSdiB^9{v>I+*{kbb?88NibYcX1yPc|5>)s0>6C0YkLqBVW;;zv8V(YE}tw3K{_ zY)kX`bw+FC4SnxF4zvzVO{*pYq3Ku59d4fdKpz3t1h;fsfmm1wO{tUDeF}#)__w8> z5v)TT!D^Vh{0cY13dX>a1No}KU>O|Np(p)$GquC))UcRG25 zJL*8$pRdV`R*=V9lYN^?h}NNwXtnjC5G{gmZXA~VfFDB?A#I|T8BBJHT$tC z^F@z%rz^TG*p?W{gOeOkF1{hpOSBGmL~H%rIlsA*mKe!1TDs0`OOcpItld># zNVJY{M62<|4__DK2iMq^7|Cv2`db&hD0)QK&=TW#(9-jDjnOiAul3>2uO3abavjlXdFzu}v5$8REisU1 zv@~v8W^8)Wfg6Vqt)m^$s@pwdr7LNPkvylx`*?%*@mGFu9viJW#u2Ti59Zw^dc<)q z>9-Uyl4rD7-edUuyx;1%>cG>(CC309bcRX{u=n`E+OAO?7a)~N; ziKc1I|G|)m)(MViHEdk|wn*y^w{%-#AkS&>w8tEh|Jxt){zbG-bVRGJ`OV3qOLPq_ zF_Mp=#XI?}T{347t&<$ls@b^yV^`7=LwQb%#!cZ8c_+W^*C9)Y*2#`&ZEjxtpy(1^ zV_RY z_+ZpfH^LGFdB8IAb((g?uqmHN?Ecf7=cyg~Qo~~LgK1eiC%h$kM7NlhDhBe#Kuln; zNow6^a9St5r&bWHJV&(Fel>5o*vGqumKewz0~I-+4asZ1V1SKnd0y^*0MU{h(Q4fE z+H$jxmyaE}acZ4Qx-C_V?YxaV_M(h>uC|3Ep&SBvOAjju#{bHleyCxVuin${>$bRUxn$`UmbryVHV9G8-fQvPt|x$(KQ*sf zT95QCbBwcJZaUhXxWtGaxQu+2XVxkohq!mf*?%Kk0Y|uM?s|GyEG}0pv%#`cF{Ed> zWIiOQaqIeZ;nXXsH3n1TVy%%|*Yw3J7JA&tOAP7*Yz~%>&g;DQnx5DFN8*+5jMwI; zTV2Xa4C}!Qv#=&Fo&@zjf6O-Gb*dv?&7UuRInFz}r9Fy~J>#X(>QlpK@b2s4v7N6c zUIU%+nloceoOg7^3KH~RVrTCJFLK5LJsIpVc``t8jwa*uyV&5K17Tli4wil_5qyL-2oml)i0Tzn45;O2GJxCv)bb38pYF6JINuDaSM zZjI~i-NH)@?t|3bvqih;zW7#yt{WQvM7;j!h}YUr$FGRvHQo)wUg*BW;GXeP&F-GN z$K^*X$st~6IO5eZWygd#_vjX0Vsy`WnRBwd`}%rK)&a!pOh>%xKAE$|oxF50yyv{k zc}5;iEcah=nSp)Ve;S11j=Y>eOoh5+yY4f(dGhx=5C7eJdhS{uPX4b!C!ugGt3|+T z0Vdh1A3Z7egu2=RIO(iGCwjf+FQQi9K7Z(lfyet4ob%^FKGX)~MGvf!cKbZ>n(ICr zdl*hUJNsa7|17g`c2-xEoT6$TD8wX7H*MLld_P#sIoZ2B-ea@kFAqDyq^vAS_}_pL z2&lTIdwp5SSfPLw#`XPD)5T?raNS&?N;l2PI8 z;H*S$xJ6WsJ37l7q8mopj1rW>IS^ZOHQ7)DGA^R1h8#Ir3VIL8>K2MVSWg+%x)r~y z;&u#gkF5R80(Gb^)h|oHjw&Z8R~`)LJ^#6BZU!)%o3-2Xzk4VDt1lsIzZ;%7FhhmO zoR_)>>5b7%?$BG7KK@ZguPI+#O`CuJlKM=x`In#CJm#s7GWh0SH@GyL&v5rG7c8QF z=ls;HGNY6iyzwGXd3s%a?f7@;6jCO&)?DJ zg>T`9hu(BNRW&^x#n)JqIVI=#Xkkr!n9xn&`b=7UcU92~%1 zUX=PS`wDltVDI~c$NT-)*WN`w@Z!{`3o+^D^_y^Z%mBShVun;c@U+nlV~E#Z9r3F9 za`DW#6wxib#33#o!_!P`G=^_JZ|P7<5id#2i=~L)u@tfK$7S7O&C3<5Do5t%;s{rg z5WYe7-*paJ^QG|RCwTHt#}GUxGVkNrJ+o1R4dvCTQl1axMX!fHR)T&61mYK!D5i%m zRKKCB{@U#QS3i%-UB)l>BvrrR*W3Qmyq?*)sH;{{R^a0$`g|3Yl@}lu?x_T12%V@w z$x!s{le~S!V^A)vfc-)iO7sWyfT}71&CoP|P*t;!@%HAIqN-y2uuMb65(X(K1)zDN z=|(V!D;^|XF?@Zv2&EN^BlO$713^U%`bh30h?wWCqK#c|$NgW<@BilL?LYkDZN0Kj z^!7#lvuk~MzUoR;n7KImn*rIUsH$$ry5d(A{GmG&4?UYge%swxKzzh9R${F;{^O1ccV1zp1}pfoZ6u)GWU=5>Y#zElv~F;j6TDf%W!g|1G_@oJbeN&S{+o@`bNtB4l0Q)aPTYtb z)TL}Hgv$7E7S?w`zbenpNie1v?W!gyVexCw;Cd;CgfJIt~ zA>C%v)`KJ?t%#&G#UhO#O-@?1NK@lAZCY(3ty-k5&uo!SIJkY%YLqlJd$adY<4L3Q zSE8iVl%!AOy@z?3DCsPa){rFgo`e3YS*X<*)NMv>JxMaus)!n`)a*X+tFfq`Iq{35 z8qkqn$vg6?Z5{c<89V!;tII4vN2U1i-!?~qHGi?;ZVO0l{7rDDoL^t zlw|8M)Wkz|*-zn!BEg?XnqR^1bK-F?Cljv`ZxQ$?YzZ20KIZ#Qye3{AUW{fair=r{ z%GS@zM&g~yvZ6rSO^4ze26X)Bg!s?LJ%1lG90QIWuw_HyUWXm11QV`l1a!O)dM3Xi z!=9`g4a6Nu@S%xchH*_67OY`V3(9bQNbw=u!EHw`8)p~PyaO;I zP1jXDp!pGlSp*XA6>kOYigjE;#ZcP_#M=o0O=>L^b3^VHFtyymk%LoO(uKqC`>OO1<-Q#z>&i#sMhacA(F{X5!fe0E%iJH!XBT%~_4w^%UspTh^;F$=F6AZ6_M5!S zt!8|7aq?gO@dNR?#u=|SXD^ByS#^tj3A6p27u!D--St`X%GI}!d;CXgUTkFbdp@$d zZqi|~-Mw3I36uTFaJ_cn*$h|65w4m~ZZC+%<%$hwBXbO4u0I*BUvKNjdi#H-#>Juu zZe7i9jDI4|JGzCJFwY^&6ip~RnlShG@44!tvDAA({leBp1kP*oJ=0rU%1fB&FnRH4 z!rWL~QK1YWUV|L*YWU)V=iA@9 zVrJbzUK{S5H^2Lph(Qw039(nkkI8T=F)_~P#A14|Q!hEfAYw(1h_$}F=xw(W^NXQA zBW5L%Jn1&p~$4y^Ty)^Ulh(j_J)M*OJ$j0`-cQEG!8x|)b#g^&YcD>nYR z+O@>QV4oA?^{UKLs&zLn4G}R&nmQo1dG@3kiLTNOThh>tiSa%s7Sk%%9WeioL=3Cz z91vSk_t{3*5)(syM$D`dYF7Drwa0tU60z$X5!*cOsX2+R(j_(~hWwnEIWNgw<>&eD ze@m{iA~iAQDqEPVthskl{{&a*5@2G)p9I*rZhf{8Ff4IN)hGpE3u@;7BLSEj7Uhr3 z35WrI3^2YAx%QdY|3+;R>%>z5W3dIdu+1|cdnD0Qx`db*@QX#q&7H`hOIPL-F)WL7 zK&<|^ZChMRObqxrv6!Ov4?n$*c}i?ja6;^()_$hx7Hs>i%&ox0 zc%K6^SMl(g9e2+8^*m~mH>C!~Vhe6zv$srrF40rEVPy(>F)`j}CAO6q6Ek{8@`%_- zN5pFGpS;Jl#KeFM_u*hFGTDX zN5nRN|G>CJPw5h3V#LphnbSyIUiF35hrLF`{^f{R-OgWDB@%PP5|xoTK{4V_Laa~T zqhp9zjVCoS7F=v+!A1QuyG~8;l`bJBM*K-y<+?-9`|uPHyVV9U<}5j}h8eGoO6>Yw zf=mqgAsT=|elvjO(;^#B-ERvat96EK`mdW;n#dGZ2Op*IH%@g4GBM`o$YT2b#pf?+ zA!MVRA*;Lh<$BXy3dr2Bt_NM281-{xX6Mg`VE+^flo7Jg&XCpb-Uj_4Ho@yYWMbUU zkeRZ421mBE>*6a3*==@_?O?tC+Ov}DI(^j}u0-3vbsI@PIP6%B;k88^anDTmoV3S z0lpA+;q=Q-JRBud>OAe1we$b$-gkC&0@ksfY!q}G&`!a+eke$#VC$`_VNb201dR6U z)oIv-tq0Wrw%%e(GAhKN;9RnLb)9hgDB%`Uu&#f!U>)|B;$)njtYn=ZWo57#AA7WA zv1FZ`q-32c2essd;&dN2@?sx4en~<>15N((CwKhN6(sKPWZjzaZ@g~$LCr!D#@9~V zai9j<8E2!pQSmh1r#G(HJ~?CeN$1y*0?&9Gr?LWf{M?qMyC8V@1FiM!Kz~l#a>H^z zINxal7FD?8teRe~okA7vl$KxT(Oa5e@cavVm+2E!F18S{oOmF5u79e8mXZ4Xd>Y znmA?z`HNsalRei_ErVi33Mq=C`C1gPhs)wFer+Y>G2x3tn7)k$ahmBS%b| zM&rw>WIi#qV%37d)HtWwV8(hhZfDJt=A%AFk{&Jn7u8aaCXO9JvB0o~+Z4Y2F=nc{CYcXap1Eb! zF~saXXUt|l^w?ncG83}_&dl7p$+LlD{#sN>jq`pR%viv|?W|?f=y^$fr#lQy%m#Q7 zGqYSc-ywYBEzh1n&}KS=*82HZJCZ_6|D~ho(PS|lV9?B+PX>?DMn81MZv^cDXVBKZ zvVL(=-{}r$Vm`p3#fbkL`TR>hg4W;+TJw}oZ%yhu-2qL^2RJlyts3_O=TF;jCqaAA z8MGDeuV3L}XktDP1C58abLB_>^#X=3X8{@&^Y06PV%7XCRyE(y z0;s@GRc%o?KV{i`1>>zuOvtadU)Wz!Fz!NCiBlBz4@jspY@kdoijhjngWq+!^;|r` zL*^5ZECGLjme<~nktHzI2T;jWMt$N$y#f>g^=pAwcHBH+46gUEc|EG@&#$+4_w@gG zUH`Wip7d%H%0WM3-pw;q51QZ3yw6(_3;PF97dx@Azasw!3;RvC&;t1h<`Uza<+tiPW~Czkb( z7Mr&U+}kRFo^WvcCD7Z-`V(if1bRYQf88gcc(f^?@H@)-Lx=(Ox7%u#zB!8nQ|Hca-%H_|TuV z*|chrjxFooR%_mRkYqL;Q`R4LVAh&XUDjV9ZPl^|X}?9<`pg#TgoE2B-B#A0v@+T& z>yNjfIVkII_8#$N{WJ9-Y#K^+Ru5uWq+_f4qZqX*3SvD-GSV?s{o4c-5~})_H!rw& z3cB$}&2IdTwr>1|mW6%yQ4{@mUhqGzA~>pC2T?*FrPWcon#to?rK{!mBT?2L<=5FC zEJu*qN}?w#;vYmkY%GrO8z?M~V%t`wcqOjLb|1V8l!#XqsPI6dCZJj6$m8l}Lx%@< z$o{A(zOU0IPvj32S2>*u`oo#iUx@u-T}B{C@0yqRjyo2=4yhjY2ZA(^9KeB06(Q~} zGof-rmy{uRfZIjo_Gl^c_*&?a55?)(11Y)+c^;|l(1b^VGLMW90)GQkKwj3Yt0oZ% zu&x?Jkfg!nEKW(vwu3wSONHwpr)WJ>+?}Elc@QPjsp`IFK*RyM1MNyeKxBbKl2t$I z!AE~Qv1B{&pbmNfZ`07B%@jQ{u?)LEh-Eg=OVV^I!;VPJD#IT4M45WqtCwl@3JF5+ z+UoI7q~nQn_80FQ>vvQ%Jg)M9y7N%K0XL{nY9ySXlA*Dy$nozQ0j6nb(23LyBU*MI zQDsamcU+V-An*`?L)uD8`(@wJDF>N?PymV&9cv=bNH9HjBX$Iv&Y+78UdTzU9Cdv24OZ2U0&_a4%yjyNO2(?+vc9v{`I zo^|<+#B8=RW{vmFzsbGKgb5La)$}(@nHhY2=;ZF#mDBjd<2IPFF;{M9&0o|kPda?) z4nq?rL^w2!59;WAX8gEY&(0!fPdJ0t^i$oPNrw+z0ZkERL^w37t_@#;QS;-Y+X>nn zXVBLEFzWrJ(9(Zb9lA6{m=ZCe#fX8P{Lhf(DG{?zLc!K=u5K5YXU3piR?z)}C|V@D)&o~~e~is?X%AKm#a&A#im z#tb)O6LSKFO^&H|ytrT06@=}1XV_|LC)9LuGd3|Rh_N=l1K`=Zqb?yEqTnUa*0!j<$mR$^(y=?vvV~miAjGsB04w15V8>iNmQ) zymWmrQG3xDwWi&>uXi&wF*9H_kmVR^+tx*fQWITbgBtUr+}5UlxO7Q-e$*AlCguhl zo4IP1$GOXHsz05uz2ppA>&h=ix93OQfK3zA1CGt?+4yXMoclmFVSCvbw&rCk-fRzB z+HVIy*QSXXLK18%TQ7f^ur)cuRs*aZ%n>*?bE1XMgNFCpdJ#F&S8ZTp;m1xEel)-E+MnHQY+{PQshJXpe82_N z*j}?ijX6?It$yk5@$EU%wBJfWU7MI9#MDXVJ1KsByU%jMw$vH6rp901b2T4&pc}6<5F8B{}q$_M-V~%tebEFNk zmuI%)NZmk97jpzo&8*YPQ>#vWUaX-)eJE6Wu(y9!m(cKVY2l!(t`*g#RYfHbWX@3Z zaW&nqqe6Ws0Et3FjX|`DUPDxTy zPr)uO?{QfNR8@qFO7n(PgbPEZWfc%B%rdYXT@Pru3Km&Qe!r}HeeDz!?2jUaSRSfk zDXFFgu_8OjD+(T+)q`I*QeIX7%ZE{+zBgH+zW>cT z`QLsGDAbp+^1>I6Qeiu9*=TXSFmT2j+~zx$^gHB)voAgt9N)Ha<{sL($WLuuvaDZ+ zE7SM$Z;|Z=D8FNaRi8FgkG#_lRr-#2rhGWPqIYea-1ds<$FAw;{pzj58gOQ_jWcgV zkv{7k)xG^mzw9ERMjtiWxEg(ltu<+pzGTfbwo0G>=B}UA_YL4SSJ}ADEa5h5CT;{D z?^VamSTJV4jBcxKoGv)fp$ri}w7%(5q ztX@|Ccfz*D8MXy=Pj6}uTiP#5r8X)KFD1d&^0y_A5VjATVQc($``q?&Fjrs`$Co%Z zGc4xG#5vz;_Yt-aonfn+v3Q58v5DhLG1z!E@szZ21z}t33|sTO#}>Djy1K&J#POvV zY~0$^?+%(pj`SlN*jNe$<`V#L*?DG$O}1 zgxxQXC?ss_oMCGjJ-OA**bH%aiDQegwi(yl{|jOJ)ETz*W0y9Wel)?_(tg1u8mS@X z2pk)7}qUcj=-^*OQQJ{^rav7D-f1LoUpWKUGU=uJJGF0d>Av=o`ZzIN;NZ;3Xg-HO z{`EWWlL2m%4RCkRqVhe=r7n2k-L9_Y<`?sYB;10NGd?G7UpwR0^!kPi-OWu*7?NJmT{UGNKGg*c)i0(DtX{et(@@_z`LtOC?i**|*536kw$is7 zjHLT|U-WQ(F=^o7c&_E=W0nt_#=9STWFJ1&4dBGQfq}D*w#xZNH04_XVuf#ffPH=Vm6Ve6{q;7x>R0Za+BW_H;v$ zySa&318=B&(AO*)vSa7&5^AVhZE$132)DNdtG|D8A3oI$1}A0>toWjA?gil^BdZU; zc?yARaR{z;(a-J%7ZB414$h3?c%J&yJ?H;S;I=sgx20~$K74ArFRMWp7ZCFX4$hp` z;(6*3(@v@)aNC`MYiN04^gev58^DQqLlTSg9eCDR1a5~jaOz|wT=JwW+m6haQPE1QA^?nWSJ1hoKGl~Xb6gTzaQuMZa$hX)l$=(l^K zvbn6I05%mV-v5-}YDkpiXY=++l!V{?2V3T+79~NkW89fvLQVB!yuAyHDzEkB`Kl|! zLn3G6mojyfG*@*$3W`$!YXc?XdwILEql!vV|6j`P?d`_?Syqkb^+@j@UMjtcYWBaF zck>)ov*)+-#?1TsSW+31Z%BD{t*3L4WCapMNo3H zqC$3tmCHUHx9j^|vLqUe&2LEyL?s-4w=7AUghOgs5*VW;mSL%i;FbLcWl03m39=-Z zxkf|!cgT_iVXqcwl-2L=?L{x|cgvErkxng3g6ilNX-kX)q*Y5S!}`pYSVqFZ?TckZ zWl5MWN_(;-3Pd0N9kL{T2qOrxB#K4)_sf!K_GL-5Hqw?D2T1=1Wl12X)2=K@(5J%h zm^Q826*WpGO9E8|OOyo0ZIS-{vLsPaqXb!!Gk!R|W(GR)Rs2&N21oDK!_1`?zR+S5BQd>X_Kbid!PX%IGR6ZVkf zMGVl%T7QD742=pS3LsA1B1-`((wo6ZDMAC>uo|VBpts;B)h4Xf)*!YJ#Q#{&6fZUb zSq%Q{&|JVKHUp|E=F=xCNF+S{i6+y8XTqJT0R^%mnj%9jB@oqcfNEE)FayF8x<$>e z25D;?Vg~E+_jIu9f+G4{(z3(ZxfqHG%4NG<I!A?Wyo2amYfRTM1eEY*$_Qw|8kavr-Cp>*Lt2p!Wonv#R^(JmRopEcPxoc$m+|qrI0S)vB zlP4*`jaTfl!Htc{a(kQo^^B$a7+Q1#IAQW68MxroKh+Yr-<*M~ zX{sB)kD*03fD>jAvXzJzP+jJTbwUlPBidp>ah; z1a6cwaO=m6xqlx%)eYc;v>_R|GhW;~n81y82Ci||y02UgPRtuvQ?>F2K3#Xk!Wk6= z?lx!O){fsaZy!$84HhTn4IG>q{_+ZyN4&WEXBFUXw*hWEtpKlO!AR?~PLxNw5G@G5-KV^cMdtFuAw zPTG>jZEnS~1+zVFMNj^ zz?6zS>!Y7Pp%Q|3WVBa8P*p~K`NO*zz&kDjc81z@LIJ>@n5GZX7{02iL9qg$nydn# zq3hmblT-lA#K^OOVOR|l7dm93gSvM>I|Tsyp#UJXq_JKDqS}~eP?g94)E-P$|1XF^ zcg$>I&IB^=7@}{l>K+*3uDl^CnOIpBE-A0|^v4VZeg!oTp<0e{bIevy@x$tM2UT?e z2z2>18EO3_zw<4kfMBN=k9t(rU^#c%Xmq`!-PP=6Q+9a19^{*Q95_y}aq3=A^xWoc zGk178TnWJ^*Il~R56*YlfR!0#Rrli~VCCs`-rd_zrMFaXO>o--hk?SmPcw{^b~%2uvkS94L`1mU_y&eUPLglsmIh1lyCmqho`I$;wC5AxXHf4 zO*YJ1Onkhx%dUHjT*uuu&KJCH<4-Gbx~oIIl&+(7Sr?tqO?HNE^VT<>BS+HC9(qK{ zX})u9WR4<9Fh5B+%^p_}3i-TIX)cDo*(IQ+!X#i-dWyl?C&gl?)`bffq< z@yk!%zwZX?0;>}Tpwxkik+0L17!{M^MOAr~*BthTDXR$GG-v2G-}l0ou16;hK>?l0 z%U)(PnBP!LK9=Ch%8T3*f??4{jCzeVU~I77F7RpaxW+-`N^I26$N^K}}* z(|u-MJL50Ij-YXf@q)k{kbjqQ>+8_$H{Q6NjNd&d68zFqF*6jmY#br{;RU4RS1(;5OI1Xzzl3``7eG zCyT>SsH<({OMJlH?CD;r-#(tu&2)yYWoF|qu16>4609V-#?fiKZ10uVe0C9`d%zjG z&2KKn_Q`#PCN8i#F_&QIRKDiK+*|+WYkXY@U4t`pE1stfl>0)L=4(08*U4fo!O$sY z0X?2eEWS(YOXwbShOTB(>l>~|C*~3iowZS%H`soczaAn7JIe+-7Mk3~LX)-Ie$?D< zbYd#O$(gk-xq}^)*>EY5d&n8NhF`zB$?fFCOoA7VHjAtBWr-IJxMmZQPgl?`gboHAX?=$^tJ64zGOFXH; zim3!g7b6MMJoxJ%p_}InUBj#IeC>L4VlKgJ=bM%FxR;#ZTf7X8Jf7) z>cm`vqvIpQI-f?IUOad^IoKy{pktxQ7#5l|tY322J{_zJ+N+pKaB{r8>O6sc_VfE+ z)j@874RRA`MK&kbx_(<>_8wYm(w@zwyJ`dVbz&}&gx&6Q7Y{PPZlMi!chSmX&aQsT zo%1}dho_3!1S`(1#2DU%|LiRjyhYC7ZGO7urw#y5Oei=!bJ&TyiU&URzDDq#at3eh zCwpGVaKGh=IR!5ZZ*D^2I~33T_R{MJ-qX(Dwbt*b?UV{$dM|cDpQnmh1&3#rvE`oj z$y?7FLGYe&2yfw?n4d`DFw!5Mm{)LkX0zqv>_6_9G>x3?vo_!{XIp13o*1jSpPiUk zFm~qlSc3;D-`v#KAa>6=W7jhF1)$_p&j*jVr~%wkNes`%s#!I;63jU-mUZNQu^BThbN{N9A1o)$RDzPJ7(a1 z8Qm5;gSUS7jCU<~H_m>)dgp$flaGL695nqIifIOCXb!3H*1Jf1<{mP@7o0Ji{?)rv zQ@Y%=ho^~YMv~S$ZD7}0avLw&fX7_!Smtu4Z%mz{rT01>^n99_WpH-pG$WtzSwDSM zA+cNHklj;_=cIDE>Ca9~GB`V)WSEm>3$MFyGd0?mY_MZ4m)l+Ays4PDND(KeKRhwX zVD;>z807+A(_WULz3dEL>-YCy-?8iAiAe_F8Tl%orj@u}#-6jj|ClVV$r-$w`&PVT zx!iVrZF;Z2qb^U(G62t?uaU2%3Y~0cz;lb1%_4ZOID@x#)V<47`r7n|C#D%e+L6MV zEy_jnft3>mO+1uT2VMm{QZZnv19vcgx@6vKJyS>u$fz8w<3A{zi^8!e++PwYDXSO} ziYgCe8mbT9sWNINqd06pr$V&*R20C5ca%&a|e-=i&`Q6v8f3s)y;ohF5k*eWk71#PInG!-z*#{{RDIjADyOBhS0sHAuomM%# zc>!q|EH$r6S_b^8k3KT_Kdxm^t@a+a6eX}@Ke?IBw7ag ze^ARH*j9}^UdsSQm4kZFsz&}hvos*Fe!rRlb$6^G;8zQZBIg&CUVC-7QfygZLWtnX=+j4055Uysrh7V9=@@uEbX(hpL8I z`8^m!0eY04j~<#>dEYO=jpDDYs(ypCiOg#D@fr!48fu&e17MFr?9d88f$-zKnJBI9 zi!XccN2lzkzbNspE2>8Bo*kI@;{l1v&MOLP^pp0$KI)iT4-_y+L5T?|B(1NW32r1| zfNBl876__=sOUp{E&D9!B`8Kv@}pwAjt(s-6rzu>qYlvltMfy3dg+TH60Q%`@h5mRg11egtpaTvns9Uf4@fs}naqk{) zNtQ7gg!idvno2?z?N-dkL$klAR&P*XX1)3N>h>tukIs$#(NVfyQ`=-D;%e(tX?RFz zD4GK0C=~4{%gZ(u0;rxn9zlZ;fo%A775`)XAiDL~68%A+0^J4`HH~l-o=1ru8CRn} zQ-$)2pGqX@(Cz?0Ep?6j_{#Ut7=S(p!a)VU2JHpZ?r%>JpiLzL<(&|Ek~9{V{)p2f zAVme!U^G?oL#)a&0=&qWXC+xKGFG%Fyf-uhDQYir5=8*Wh~TXSQN#&GCYV;@LuqLH zFsmR5U_k4D1T9!2jC*AeZosdzzmT7RH3p!*Wj5@%-_MjGz-2&x3o>H5gf6B{`6BM3 znQh7=Dt1mm5Fz^baiCxokWC0kDix$sNG~KR@epSeWIv*41nD8PC%j?!Nbdq&iv}e# zsEuyAiQDjGgqe?!ASo+Qk-!? z9x3O`w3BloWDl!ULa!H$icadr#d{PeYy!$0W{PndDhB!Z{XDnh{HgSO}1 znte6FTj>m5smi2cyBm^SM&9>58Mw= z7p7VWULaqgA#bQ7@qx$4!&PA_dH80#B@Z>ZbkES``vLGVXC)FD5Fh#TrT2Gv!(?H_aB>kJHUXJqa1r15k7 z&Nqm1hn+9jaqvx4NbemRO^yob@xFfcyzw7r^xGQQOXUvVwQ(vdcWC;Fr>1?J(c#J+ zzCOF4FO@rNwgHRE9dcInzdis~o?idmdD={RORH?0-S(Cm7CrWH#t8#=)jxn!SKBys z1gj*-e|p9GC-KumgmQ<}9^=X#CN#w4sLe6|AFrhws_AOJ47g-{tf* zjReShRG&-MlIM8e#`!|b((=g)oStIJD(&5FD$mi-(E1I*TjLDg>|N6yOgZb4{_uoZ zdUCs%vZ8>UwZNOuJ z%6Jy2w7fZ5P33IUpPe``#n`F5%hSzKJF4Y0u6O$E8?hYqO-9 zS&eSt&p#g*0K3m^u$xHRG`Zb1H9vwHO*I+4_O-_~-L;3%=NV#t!SV5J={gTpKEB0w zJKAd1LU1<=DTUa1YV@(iy$x(chon5$K7TMiP3*&Am{eru&r*dd%-~+grbL z?PNQCH|>p2%r#K;TF*DLT!Ux1|2+T2S%hzsGkh&e)~xFg_{3}@3BCc(emaNneeDe2 z+U=Vkw&QoxdgB~*ePY7F@x^p~SO4oTR}#L>&hXVw`}q5gfX^@H92{Rv)VKHScmA93 zed7$@g4XrV*$G?H-TM4u&cX00X1dEuIe&fC(!j=gvY+3*5v|JmGH|^<( zDTio$W=*H7TKtC*y)Dk@)x9-pm<@X4|Ho1f{|8P)&*v924o=SunR&u?#+>DEQq%p; z20a$Gn6|g<`9*g8ZrbA$GY(!$oY$H#I6mpKb9WKG@15aW@XE%W9Ri=2awNewb6|-^ z_WaE&cX4Sg#h^g{V6wkiV5FVXZTjMj=R&2-%ab)qty3_IS0dM zPD`3~6i@qd>1e{&;tXHIlD9wX5ctHLgX7~<;yNE~ksGFcL5_Eu4SXzYxr>js-21{A zHXLu-8=shR@V0Bx;{(lWHg*1vjBmRQdd%^1dTW=>tFz-d(j8wwOgUJoY}qU>#s{l% z|GBG#@a=GhZ^ecu-?f7;t=CbZ>kEh}hghTC;3-Gnyf1eXz8{_8TRVI7R6CA0?eU2@ z2ghfQ7x2bgzsh)l@crZrU&E~CbsYkqm~?P_e3H%FwNbWQewOh4>|ow|@1UUfl8O?Up>j3~ySm38Jn~%sDtc zv(A&bFr)r~pNL+oGkR;EUUO+jpeLproSs>%i0@XD%ly7&<(U@crft--1QIwR8l0K{4mx_;}c27Uk{!*}TID z-)?94T3`I8(F|K$ZGAy8=ivCvV%mJP>hbzcuao25V*?)xTPCuwrDlBn(KZ}!y0_h+ zm~zClU7m8Be#(f$gP^z92E9qNw#B5k``K%=YiU{EzHKkP1uW3_1;w0$^D`$pcq#X# z_g%U(AN*=Mr7lr$H?520{AREGbfKqX02H$hlq5&!!k5yUz26;Y^;ks!Z*>N+;hwtI z4gye2JUGCZy5yH1WL!-EYn=gXT>l*;wmU=<7IO~{kcTgRK4|s!_N*>I22>g9GFq*eo@21u9BpD4RyBRgJT-- z)mI4b`1{j<7+hKEZF2W54ws4lH4Dv6*3rxqCWkK;b|fN@=?>|LHDbv{tA z%Sc}5L-Pd!il6F*7^rK2MVaStuv zGyJH6hZ=V@nPLMVOXBWkN!*I3KC#I!otK?b7buQn zB{9GYkNw9+GQg=e_%SEU`8BS7bCS&#obCX{aV%7P_2;WJj-~nxK8JqD*prVXfYY1- zY+2R%Q3nAij$;9!ny*n8sQDD$g0Ct)U={(K?hN4O&!0gBHCtU^x?7+)j>T#YOXP&3 zHHSm%?mvgRzRS`VakQ2RBgF=?G4 zV}dU{k4xWcUV6SKdhIvHuISmWR1kCn0;mHUQ2mexK;iKArGk0|g4hcVsUVE-D0=q3 zw1NT{*irp@3N0g47}tZ)H&CEH1zA z=G{C;vJw1te);(q+X_H||C|Dl0b{TfJmM9AB&bOQqGBnj6@W}#8uK!OSW3XB8&sdy z)JuUjfhm?E`6^2+nGWblMI}5WpIr%{s8~u$2_S!R2_U~EmSSn)*b+-gIJkYWl&A!d zh&oOJNDfML-%` z%G5QV7R2bJrA!mFNdJBXAl(v6NvQzj7i4KH0Ust`bF#M|-L>_ZEz$`Gw@~f z4)gZmKgLv&OnhDC0HjAC-(ZqXtpH?_j#mK6)FI%)2CpsF7U@_4pf;@~>p?7g7iF!K z7+IQtMLMMbP=IV9%iEJ4)RMgc>7ZrP)@QazCmh^9={5l%lC?UNlZJ>6S*uC93*_5; zOag_O^zYFBk$upvYcuJ91^V~v|0ovd1pS|zwst#d20HSYW=DQkTSs13yE#h~Ou>l0 z7S#Nhlu(%TjwMVIC;F15_?U_egi=T(M#9vGb!4K5hCdjD!OzP&T(aIKC6i1X z2)Bv9&^ZZ0P$#O360b}|sxv_@7z4=@#0^H!dMG(g7bO!4CbEj~1tE+M6(61A3$7y+ z|NFS8B@Kf6h31dmp0te8_H*%MAOi!=kimfv67EflK5h~j7!~q~P+VjUQ#Yy%A(Q(8*WU>&f^=v#@uqykPbS}g$CBwbRVmea0c6wH_Xfe4Zx zfm#%3qPi3MVo1;6FKkVdRftPjuR2bx1`<3lN%j|tIRTiPW*w>|VqzvKTBRic6#w{o zRE0~25E7()EKc!ySt^8W;Ck^_fwPvhs0vo1R*t3Xr9((bM;vZ(ic{*zf+iuf=-6Kc z9Z0aPoxylw6mZByR7^=~Fiww}O6)x)*FOs4qG;-B5cfcr zruWp2{1AkSnEaL>uRZ|Dr>Lk>ytGy(J_u+Uu1IoK3iOw(*90Z6*oPu(CRsBBB1mcw z7Txyp<55XIC@KY@Nre{x!7kO%Rv!H1d`c-FKnI;l0!TWts)=1D{C@k{4Dw$uX(^D~^E#HXmuVva}6=v_iR;A1xU zu^@)q-}+^*&$sD=(;c8NKck}ndfQ$LzE!I9uMZzU0B1V`*!1(GKXnj*!UT;8(5y(v zT}c0xTf7ADac2OVznEfouSJ?$pdzFrCO}?ez~C0x^zK^{0er$4z!hK2v%A+Koj2^E z4^+fFgg0Tn&dK2W^^W=Fq%8z+jx&JG56yhbW)PF^0L46n1B|i2`fs0ppPcYq8vt1l zGnEA~^-C7|?K$Ce=O?BiOhiL6cPsM&h0(P+>xkbxXZ-4#UjJ9e;3sAwtf04KZdc~3 zbl#tG%Y=d8H{S-oDYRjf+uw?J-(K#?9&KmGZF{V4y2h#l(am)J=bEj!wVxuU10%oHK-L-)P?4VGxSB2}fuav*A(U zN4Kp0ju19FL%9Bf)`j*%nl!gUF*gB1J>N?2cwgA^kIROULw?={LKf0YXCck{O^+XH z+aaepK`}MqMre-mb0d85{DYSf!Ntx9*55XDkZpo#ybptVLNPPp1kL=Ik5UgkcThiS z#xK|)$U+)!gAG%EecZl3PIH7}W`Z*6$d6-s!f*Hd_6#9>(HX*)XSTL<7=&VK!V$(; z;onAoSw#q!I78U{_S)y{`{VSsLQTv~I6^a|;fs#0?)H2(A$-Xh!loy`h0M6!fvz-O zxP`t@6LXU!4f&^rt9uc`mz^P;{Y&#Q`yox5BNTHJj?f&j=I-Q|hZ|W)(_{l73u*3Q zA&_6ijNdTH^vA}&R3@j4{x3&q@oBQ&c8@oAGsu8#ad4tco^ zge;_)!9tqm7vB04cgQx3Fue)tVrml8j`=v(hQt1L2bGLm(Qe5|-H=WhcJ5vrv$9k6 zuO9D!c8WgsL-`2W*Qj8x8SREt6;|`{sJNPsnW*`QjYfXSk4<6NMHP_St@x-*C`9QE zQR&Lcs&Gj;3QNj9?1jWWI+-@s1vNufS;@yk#FCFabpyL*An#}BnqRE>Xe6u5s2D2h zN*=)LSPrQO7lum9DoScSeKen{83uM%1~mMX?7!^p4J$gq*52rpwcC>>_x$gFStu9T zrRU|+9u@ZRrj1tD3$wkv(KUVY`q@R-4I6zd7`|oW%so`uf}gti_IJ+JvEYp`8(ACJ809SU7m+aSo&INpxeF4{BBOyk_|c*OxN2u@&R znvV&ZdI^)y`(PU(e9sxe1z)}MUWY*_4rp)(zq3k%?H|c$Z1Yc9MOs~LcSm8 zx&QJmIt>KZ+8{WUR$y|1t&h#RD*JX?jBckVOlJ`&>I=mott5S6>GS(F{s9a>vcYf~ zt>`zc@U_=wc{&b6aZrm%c*HuCI}UvLYJ&K&Gl(rSzuMe^Ad16U9HKe6%7?Y)oqOkl z1o0DR5F3A*^khbdu|z{0+2RoSx&?EDy0QDF&j{i=XAtY(8TGRxi0Qi~7=5B4W+_P^ zp1S4@)+c`I3}Q>;t`|EwvBdO-DCQ{~qFGU#kFg$;CGREY{Fx1i%sJo3!kg9~-#@P7 zFccG&B!>9%n|e7h{M;GC1)q;A={O9md-}a^-A6s*dK(N`fMeR>f(cJM z^3CbHu9bR3F;`*b?X?(SIpV? ziX-2g-j*olD;%O3;BdF{%#p7xA&6f%gSff5_6|qBIlUo@`3i?szZ*PM5 zr89_44{crT2x9s!aHT#`%vTshYn3&3EPwgykvqsae`NzA3vljd0Z!eR#|Aoa&S}li zFQzL=3~~9~<1QdW++>3xbIzP${avF*Ir1y%4N**2P$d&P7|kU$d^r2)t!q{h#IKz} zY?%1N$_@lkOjkHWbM%3)(cJ#m;Tl2Q> z5WjH-vE}IvA9NsyV#2~9nh6Wv(|vf$N#78}Z=FG`e|ynPN6tCDEm2HZI7D;wfj8%& z(x(@abKYVDB6H3&nR9O3@pBg^&N;mqiunp-sG7wEc+p$$m>aLZ91Bjq149z@`3q)H zIHpzKeArWz?5T{DRx!yQ!$1k_V7rPwSiq*M8hsr(@aTz)aqaKTYd=YAYWT%>Hm>W5 za@_mY>VbL({T6BU9PQ04uP7@p#aMh5m6aEqS?lSg1QZq8&5%J3q!jDXf>3QnZE+k6 zF^FTly}7VUs5}CB6MD2vL-kSVb?7IcYB}naYb}4fy%S0L{9rx+N&4{nf3J0BYDphM z##%j8OIH;7xx%8#YkhgX>dNqt$l190Of2RLVCqKk8z@YVg6_fWUf!`BD1&eg7j89j{%lNRc z+UhcXhmKFsr|61h&@gMurhmVVk7kih(DC`~z3K_~p!aSud+!Ntz4xpic4b-89El1d zP`hA)7qS6i4o&h~x*&1N6q!)u!8?#iss_p^kg==+3B+IAMRW=A5KfELD_b8X9qmFy^63oYV;fR3z&Uex2k*tg9w!^+3D^8XHm& zx;W6I(V|C&_LPAk8y1^U_kLlxVbnsIMmW%!6isFq)jd-LC6PeAasQF z`?5~fEC(!(@tQF(d>z-qVqgt=D75n5Ix=20C6iQBpz8yH6B&o9MgYQsTnQ!a-Sm}6 zheknrf(V_C`&1!&XGzAyX|ZG~5TemE3E>O{E$BI-N5<(mWkT@;F&p~>RYMhy!Fo4w z5=89X$T$$HInY7DrK9gAUSB0khOR4K4*iu>$qyYz>#A@*@d7KR4i728-~l19Vm(xx z9Z_#0t5El$Ee_~XpxR=J`@|iHcWOYf1&t#J#VHj+Fwr9up9wl^iULriSwyj&bC+@9ef)g4#`vaXI8E&~GiCBZwxu-Uu-~R&iMg0 z(%Up?;-MXfVRYUm%FrzF!lxV_*|>2TF>H0laPyY2;f}*FI&EVyRAU019aCDHs4@R) zgCPrWxE;<=BOva!0r4K%Tx9y@)eBI+uFn1rt#lR*p z^v*dQ828w~IGy&nHLwt;=BFmP15pf$DGR5_yG4VKK&*T>G@dB#bw;uA@v%2_B#L6r z!W9R2&SLK1pF4cyU#UsgbWS}M#+o#@#pw^ehgx`a0^_Im4K?yNu8?2F1jMV>H)F@d2?{eS>c#jJ3`%wk(+UR)@kU<}Mtg8Fq7w zea^q;J;FH38OElDr|x#@r_p$#t zO!p5bJ-c5+N|xo7hbt>1m3bv)h1JE8Jd_wKtsIQvVwHJi<^SJGGJ)QKC{N4G zOD!qQOVul=RMJt%0rC_~foU=`Pr(MF(~e642yzoM^E9m4Z^tk(uxN6b9CES%1wcCy zzCZ*(L&v}C%&8Y1%-;*t+ChZYX@~$gal7RPvs5d60>A{=>H`J<$j@NUgO6uREz$$i z8k$^d-2%^oyx&QL9}wQ(aIxzK^Y@>-r(OkW?IJ?!I)wLsES~`M{#*jyhsSYhQL%n; zQ8FUZK>{#eAR-#L0tq-G5A-Qm3hdP)U}7lB1B&IP7R7^wGxO5pbMo^GG{Ewpv;YKg zSO8bhUX}n*uyhmQH$<>}cs=a~^Z!{Ve}ax?>mfob@MtzzmfbLQ9xzzmQaf0H%NW4H z5|1Sqz@dypq(SqSAsAt~R!t4?r&IC&D;{OBQr&y1^XW>{0_w${pc5B?3ti2O56^uSE&xs?5@=2U diff --git a/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt b/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt deleted file mode 100644 index 67cb06e8a..000000000 --- a/wandb/run-20250310_212229-x0wxhan7/files/requirements.txt +++ /dev/null @@ -1,295 +0,0 @@ -pydantic==2.10.4 -urllib3==2.3.0 -scipy==1.15.0 -myst-nb==1.1.2 -pure_eval==0.2.3 -wcwidth==0.2.13 -attr-dot-dict==0.1.0 -emoji==2.14.0 -mkl_random==1.2.8 -keras==3.8.0 -nvidia-cuda-runtime-cu12==12.4.127 -torchvision==0.20.1 -cocotb==1.8.0 -wheel==0.44.0 -imageio==2.36.1 -dill==0.3.8 -pydot==3.0.4 -transformers==4.47.1 -sphinx-book-theme==1.1.3 -myst-parser==4.0.0 -traitlets==5.14.3 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-curand-cu12==10.3.5.147 -kiwisolver==1.4.8 -pygame==2.6.1 -greenlet==3.1.1 -pytest-profiling==1.8.1 -requests==2.32.3 -aiosignal==1.2.0 -aiosignal==1.3.2 -Sphinx==8.1.3 -torch-summary==1.4.5 -Farama-Notifications==0.0.4 -sphinxcontrib-plantuml==0.30 -ptyprocess==0.7.0 -pexpect==4.9.0 -yarl==1.18.0 -yarl==1.18.3 -filelock==3.16.1 -filelock==3.13.1 -datasets==3.2.0 -datasets==3.3.2 -bitstring==4.3.0 -triton==3.1.0 -py4j==0.10.9.8 -pybind11==2.13.6 -pluggy==1.5.0 -regex==2024.11.6 -cvxpy==1.6.0 -sphinx-test-reports==1.1.0 -jsonschema-specifications==2024.10.1 -fastjsonschema==2.21.1 -pytest-xdist==3.6.1 -smmap==5.0.2 -onnx==1.17.0 -tornado==6.4.2 -GitPython==3.1.44 -sphinxcontrib-htmlhelp==2.1.0 -iniconfig==2.0.0 -threadpoolctl==3.5.0 -cycler==0.12.1 -tzdata==2024.2 -tzdata==2023.3 -certifi==2024.12.14 -certifi==2025.1.31 -numpy==1.26.4 -gast==0.6.0 -frozenlist==1.5.0 -opt_einsum==3.4.0 -astunparse==1.6.3 -colorlog==6.9.0 -grpcio==1.69.0 -jupyter_core==5.7.2 -torchmetrics==1.6.1 -gprof2dot==2024.6.6 -nvidia-ml-py==12.560.30 -multidict==6.1.0 -etils==1.11.0 -jupyter_client==8.6.3 -sphinxcontrib-jsmath==1.0.1 -tensorboard-plugin-profile==2.19.0 -clarabel==0.9.0 -idna==3.7 -idna==3.10 -pylance==0.21.0 -ipykernel==6.29.5 -matplotlib-inline==0.1.7 -jedi==0.19.2 -lightning-utilities==0.11.9 -namex==0.0.8 -kornia==0.7.4 -docker-pycreds==0.4.0 -mkl-service==2.4.0 -fonttools==4.55.3 -tensorboard-data-server==0.7.2 -beautifulsoup4==4.12.3 -Werkzeug==3.1.3 -Markdown==3.7 -asttokens==3.0.0 -huggingface-hub==0.27.1 -huggingface_hub==0.29.2 -pytest-sugar==1.0.0 -tensorflow==2.18.0 -pytest==8.3.4 -joblib==1.4.2 -ipython==8.31.0 -mdurl==0.1.2 -optimum==1.23.3 -pytest-metadata==3.1.1 -debugpy==1.8.11 -absl-py==2.1.0 -mkl_fft==1.3.11 -sphinxcontrib-serializinghtml==2.0.0 -MarkupSafe==3.0.2 -sympy==1.13.1 -six==1.16.0 -six==1.17.0 -multiprocess==0.70.15 -multiprocess==0.70.16 -snowballstemmer==2.2.0 -zipp==3.21.0 -ale-py==0.10.1 -scs==3.2.7.post2 -find_libpython==0.4.0 -sphinxcontrib-jquery==4.1 -decorator==5.1.1 -nvidia-nvtx-cu12==12.4.127 -prompt_toolkit==3.0.48 -charset-normalizer==3.4.1 -charset-normalizer==3.3.2 -nvidia-cuda-nvrtc-cu12==12.4.127 -evaluate==0.4.3 -tensorboard==2.18.0 -lightning==2.5.0.post0 -py-cpuinfo==9.0.0 -prettytable==3.12.0 -nbclient==0.10.2 -execnet==2.1.1 -torch-tb-profiler==0.4.3 -kornia_rs==0.1.8 -contourpy==1.3.1 -pydata-sphinx-theme==0.16.1 -pip==24.2 -requests-file==2.1.0 -jsonschema==4.23.0 -sphinx_glpi_theme==0.6 -imagesize==1.4.1 -osqp==0.6.7.post3 -importlib_resources==6.5.2 -termcolor==2.5.0 -importlib_metadata==8.5.0 -cocotb-bus==0.2.1 -future==1.0.0 -pyarrow==18.1.0 -pyarrow==19.0.0 -packaging==24.2 -sentry-sdk==2.19.2 -einops==0.8.0 -nvidia-cuda-cupti-cu12==12.4.127 -bitarray==3.0.0 -aiohttp==3.11.10 -aiohttp==3.11.11 -nvidia-cufft-cu12==11.2.1.3 -scikit-learn==1.6.0 -pyzmq==26.2.0 -Mako==1.3.8 -platformdirs==4.3.6 -nvidia-cusolver-cu12==11.6.1.9 -markdown-it-py==3.0.0 -wrapt==1.17.0 -tensorboardX==2.6.2.2 -protobuf==3.20.2 -propcache==0.2.1 -propcache==0.2.0 -pytz==2024.1 -pytz==2024.2 -wandb==0.19.1 -libclang==18.1.1 -nvidia-cublas-cu12==12.4.5.8 -alembic==1.14.0 -nvidia-nvjitlink-cu12==12.4.127 -click==8.1.8 -gymnasium==1.0.0 -Brotli==1.0.9 -lxml==5.3.0 -tensorflow-io-gcs-filesystem==0.37.1 -matplotlib==3.10.0 -tqdm==4.67.1 -annotated-types==0.7.0 -ghp-import==2.1.0 -pillow==10.4.0 -onnxconverter-common==1.14.0 -stable_baselines3==2.4.0 -imageio-ffmpeg==0.5.1 -onnxruntime==1.20.1 -typing_extensions==4.12.2 -Pygments==2.19.0 -coloredlogs==15.0.1 -sentencepiece==0.2.0 -torch==2.5.1 -timm==1.0.12 -mdit-py-plugins==0.4.2 -PyYAML==6.0.2 -gviz-api==1.10.0 -xxhash==3.5.0 -setuptools==75.1.0 -pytorch-nlp==0.5.0 -babel==2.16.0 -soupsieve==2.6 -ipdb==0.13.13 -python-dateutil==2.9.0.post0 -comm==0.2.2 -flatbuffers==24.12.23 -rpds-py==0.22.3 -psutil==6.1.1 -h5py==3.12.1 -numexpr==2.10.1 -optuna==4.1.0 -accessible-pygments==0.0.5 -tf_keras==2.18.0 -mypy-extensions==1.0.0 -pytest-html==4.1.1 -hyperopt==0.2.7 -tabulate==0.9.0 -fsspec==2024.12.0 -fsspec==2024.9.0 -parso==0.8.4 -sphinxcontrib-qthelp==2.0.0 -qdldl==0.1.7.post5 -nvidia-cusparse-cu12==12.3.1.170 -sphinx-data-viewer==0.1.5 -mase-cuda==0.0.1 -cloudpickle==3.1.0 -coverage==7.6.10 -pandas==2.2.3 -Jinja2==3.1.5 -black==24.10.0 -pathspec==0.12.1 -sphinxcontrib-devhelp==2.0.0 -mpmath==1.3.0 -pytorch-lightning==2.5.0.post0 -alabaster==1.0.0 -jupyter-cache==1.0.1 -stack-data==0.6.3 -sphinx-rtd-theme==3.0.2 -accelerate==1.2.1 -pyparsing==3.2.1 -docutils==0.21.2 -pytest-cov==6.0.0 -rich==13.9.4 -safetensors==0.5.3 -safetensors==0.5.0 -humanfriendly==10.0 -PySocks==1.7.1 -toml==0.10.2 -Bottleneck==1.4.2 -setproctitle==1.3.4 -opencv-python==4.10.0.84 -referencing==0.35.1 -nvidia-nccl-cu12==2.21.5 -tokenizers==0.21.0 -attrs==24.3.0 -aiohappyeyeballs==2.4.4 -optree==0.13.1 -networkx==3.4.2 -sphinx-needs==4.1.0 -nbformat==5.10.4 -gitdb==4.0.12 -SQLAlchemy==2.0.36 -executing==2.1.0 -google-pasta==0.2.0 -ml-dtypes==0.4.1 -pynvml==12.0.0 -nest-asyncio==1.6.0 -sphinxcontrib-applehelp==2.0.0 -pydantic_core==2.27.2 -transformers==4.47.1 -mase-tools==1.0.0 -more-itertools==10.3.0 -typing_extensions==4.12.2 -inflect==7.3.1 -typeguard==4.3.0 -tomli==2.0.1 -jaraco.context==5.3.0 -jaraco.functools==4.0.1 -platformdirs==4.2.2 -packaging==24.1 -autocommand==2.2.2 -jaraco.text==3.12.1 -zipp==3.19.2 -jaraco.collections==5.1.0 -importlib_metadata==8.0.0 -wheel==0.43.0 -backports.tarfile==1.2.0 -importlib_resources==6.4.0 diff --git a/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json b/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json deleted file mode 100644 index 752811856..000000000 --- a/wandb/run-20250310_212229-x0wxhan7/files/wandb-metadata.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "os": "Linux-5.14.0-427.28.1.el9_4.x86_64-x86_64-with-glibc2.34", - "python": "CPython 3.11.11", - "startedAt": "2025-03-10T21:22:29.549593Z", - "program": "/home/jw3621/Projects/bert-onn/test/passes/module/transforms/optical/bert-finetune.py", - "codePath": "test/passes/module/transforms/optical/bert-finetune.py", - "git": { - "remote": "https://github.com/Johnny1882/mase.git", - "commit": "758710333d8ca4b7444930df91d86c7642652426" - }, - "email": "jw3621@ic.ac.uk", - "root": "/home/jw3621/Projects/bert-onn", - "host": "ee-tarrasque", - "executable": "/home/jw3621/anaconda3/envs/mase/bin/python", - "codePathLocal": "test/passes/module/transforms/optical/bert-finetune.py", - "cpu_count": 16, - "cpu_count_logical": 32, - "gpu": "NVIDIA GeForce RTX 3090", - "gpu_count": 4, - "disk": { - "/": { - "total": "75125227520", - "used": "60982046720" - } - }, - "memory": { - "total": "269555560448" - }, - "cpu": { - "count": 16, - "countLogical": 32 - }, - "gpu_nvidia": [ - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - }, - { - "name": "NVIDIA GeForce RTX 3090", - "memoryTotal": "25769803776", - "cudaCores": 10496, - "architecture": "Ampere" - } - ], - "cudaVersion": "12.7" -} \ No newline at end of file diff --git a/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb b/wandb/run-20250310_212229-x0wxhan7/run-x0wxhan7.wandb deleted file mode 100644 index c7816c0c9959c866c7effc0863411253fd092912..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 98304 zcmeIb37lm`dG~KK3&UlAnI6^#0yL<&(A;y*U41nw#%LnO-58hL-08kE)6mo1_QE1! z$h;y-R73$+WN}x-U5ODi#w9VnNqmV)jKUvx;t~@z29r1bzfV;?b*zZ5!j5}|+-v=)Ezi;W@Z)n5NuA$9m92UjRu`2(I9#*f_ z8jVkHtW-9fS3UQ<9o^~XyNCW)rCm9lFWY0$Lu%W$^^&L=SL@YnN!;z#8WW9Lv$3sO zt=DT4(L^<=wVO%3H@v<&R-?+{h0~(P3=M5KVAs$|qZ{Yrv(K40JC4p-Jg0qX#BqQN;VfbWLOt zudj{Onq$!!2b?-|*2byH>BaL-JFnH~G-~f2dTQkXI~EpZ=O1?Z>Dwn4b}Vik8=u*E z`tQ!{n4UgAYPI6icXsD{V|+DyYO~&IM%5%qCR*d&+SX>RR%<8KiEZs@qSY90HfnLB z9`olRU)ucT;~B9v@iCR7DqBqFPutmB=uUJOx?^X}&rBaZbl&LEUy?Fw2bB5laZmoY z5z4HKkE|SC*)-RC{^I0ZPvkW}ws79U(L)p6E53I<1=cUA8qdCPIU*e1Q(_)I@yahY zj~d!=;F1k}@a1Fl_29YQ!s6U?XJT@EVb{>wg}KFENJL%CVKr?)&b-0y3JMnNGhvqH^%U%-r~n`SH2QS*mou;YsG>&cwp` zvpt?Py*M>hIkK{GalY5tx@~HvyAU;e>7y%$E^O;e^|p7%&+kml%%|0Fscf8`Tb%Aq zbawQ*6MT-QSB|S}S(xl~&YhW?==64O?M+NDDfrHI=lIG|<2x3o&+g1mzM$9H*5kL^ zX6DZA&Q0tZ8mjW+;mP?=$ZfemKhzG}XfR+l|_bEUGgdmciV-oCJde#X_xCO(>+_8-ysEtP{NW;*lTowHLs z*2x;_^}2IY=QD0Iv$Kf8e{)#npy|b(ovpp@PQ)`5RgS0}rXO^+&&@2(+K-Q^Y@V1r zr#Cl0xp01Gwm03KS~!2#&>Ge*o_O#=Zzs!17ro(`Ya%|}SXsxTI%m_&dX;s^{A<>I zxw&!>V?H_Fo$BaEBbDK~-fVASQgkQt0IBQ9%AvN;!RI5D!>4CD^0V$jXL|cwcc&Rp zGlrv;<5Y~LWsa;Ivb8&5lw`g$DXL?2Q0he{ZXV1t^{Gm<9q}AhImBuD{P@gVW(u5C zIVv-8It_e{TXP1d$R3#OgebUjD!{!%fXKB4_h41-qRF2r%WB=%s ze1HGSAzfzjAPi~tW@62}5THjKtot>N6 zI@R0R$*ze=y))h2+3U>A@dKIqTSjFkKW!`9<7tc2gM6Pq*{+-Ss--VFl*tOW~0S;ZI(&qzL@Hs zPg__lqPQ|r*&^R$7E-sa^BvE1zv#;Dw0?AZZ)$OrpKPXqJvuZo$&7Q4Uz=2$F?-V% z_mhdanOP=pM`Uc)uBdWK<+$`Y@;NJezSIeoqh#TooniHpzFxG%9;Q<*59VoS-YC-L3O8Q;Q3| z(aLb;AlWL%y9^PvnvMF{%COOWluNcQ+ZEOOV=EQ6c1=vq?HbxRYKKtp;HcT+QH68o znB1LgF}H;bS2nP(pgEJMiu-I+2`u@dR?Q2;(+x>@w#${kZVL<%t}WIf!@#&O_i zZ+2#Uhk2|kErOMey)!dCMG5!QlPbruv#^|j%13M2L&Zg zsvIpPw$CmW)L>`ki3r^l$<<&>&=Nl>y5r-EI~S+y{F+}dqsQhzCFaLPU)xI9uT2T2 zPM1!y!tfglgV;jZTFq!8C{v~V4O`9>ySG_O-F{NUE+}{&s^Gpd>a?a_txDg`XA3h6%*d(9os+UN80@WI&3Cu;K=Jc4vLfrcvf2Eq zUu1rIALhOUVVX6?_bIA85SW01%Q`hRwY59W=4ZG=XWQJ&&dl!Of3mr9Ac$>zae^^( z58xxS(#4RvdmdkjmQ z)HylZ11h-B4y~-+x-G*)`1iKi%)gJS94i0r%*-yRCaxWQQZ!mQsdD6w?$kC}3&$tt zWpR_v_oivs`cc`EM`a6;x;xt-KK(7>8&hlkn?ovV7dTSTA!9w6V#A%1y#Np z8q8^`mru>y01`>hSygF@-8Uf}W;mUuP_GrEaLG(LtpPBweGaC=dI*1Hb+&WZ)IL(# zG`Srh(Cb)jJ$o`;7SQKy^Ar4A2H_Ye1pnv==@V%xengntDhHdd#%HFtO)`kjykNAy z6gXK>^LbNU{-5FP_Fo0e4=yiT$k2b zvHkf_WtiQt$1c!iOp+*^h&7@cQfUa*^aZ(i%CNsJmO7d8#eeIl$IxJ{<)-9>k6}W4Onzp`)KOY8UdYGaPOR!7>M8 zP&Ao(=n*t+qT}X)(Ipmsv+6sw+-%m8MHViYKo`9OIIA3vFyGBFD2H}1 z*}09r?u5V20eb8-dtl}G?Dza~l9dC(;TH%e4S4Z9^TGXGmnyC1l*)*VzdIP}gOW4p z9O4LSc5!}(XeXzP8EGEL?z9NA%56rrlMph>yZN3{_2ccw>5jWmRCA8w6*l5g7{@{K zD?|Q74IcjZXPnV_`jejU*x&n2gFWQS*}2JcgsI42sYn@j`Kmj0Zuk88jx^Y}^Z}LQ zy~g%2vW{XW+hD{oIXynLI3ec*oY&2p#mFu1zBcfemXfu~e=REvc*7!8_TU+)rJtcH z@uMmm&EAyT&Q7RokzXzXZcYAkw@&%?gmC!>fo3llngpU|~ z_>-UUw8tajPFB+3shzWSg>+hdz^EY{vin-i&+-JPZY;%S*PW2xuz*g^r(1^}Sx%e} z9rwd_I8%C5I9V`9wUBH!Mx(4+&&!IN z5<%wZbZ;ByJw3qmB=fO5HM;}+W|fk0w<{lKBHLL=XgajOZ>^wWI}S$gyfV5rATjef z%eXmH@r`4k%*>zMPQ!r|)4~GWkJzovkEFnMPE05BfJ3`3aAvxd$)V z+^xqA4?TFapwnwNVA?wJDR+G5Q`EC@$$FmshX0Xn|1L(T4Bi4}$>{aw`VDILpMLDG z{*rDSvUE4DUcGK?lF@q|#y?~KTWSE8*OOOY{koSwo}L`KWKSM^GED1m1ISPJfY>ef z8*25WUX5F|bDr?rkDrKn^{^$YJL{K2!y7kTWelMI@wN~DcxdFVp+inQv3ma(o_KG0 zz3{~9=~3KBswY0DT76ix`tV0S^u$L_E}TC9ys!Q3b!-0&sc(X7oI3Qd_lYR3F{1dy zPp*FRnvqi~hgu9Ld&M|d4RJP-v0BnhTFqM2sMVWMH1c3h(&lEy$86ED?m68_0P|Gu ziD&E@I<#GFC9Sa*mfEP*tdE?;ex4~iH$N|>4vwP+Csek`gDYFIMf{V8t7ENNwNZ`h zQ6kp#R=audt4@CR@6i6uw*8Mvn}7Q|Xg$ zGif~T;rIKKmr&oZt?xo#-*-Ryed>EqWedCR*<+`VEzYx#JlV^nh5Z|mP<0eF;zqMI za&qO+Y|+Qj$xZdfShJn9>y2i;QEx}BkwZ95mHOm>b*MUWCj{_z0ZGe z-Oz?DwlVKK*tg~#w_S4Gnh|VjoNwZtz+Zc4HiALJex%5oS|3_DBKz5+r{wrxL5@p? zqj;=QZ8xf|cB55ox7zhKT^uPaea895`ddljTD#VaYOU4?^u`tyTOd1stuT$iRuuYOyIGG$j;I`JzuVTGeww9v2P8GL z!u@1+e6n}ut`8B z9~O-v0Qe)*340 zc)3D$wWLCBvgS37B3pG9)imlHr$Sy6N7WIG@uIQX3n$|>wQ60JIqnm2Rh5}X_K7%B zWlr)UP6|YR*z#3}@@B*mNti}w9FB%F>Ml4~dnsL3_j&b;F`Jzcs zE2t?-K&BC1g3Sp0lJ+r%)0+e{(X%05wq*09*r&PE=6Av|z3!)ke%@V{SyuTmSSp ze7X@-#DgQsv4hsrxfW;%U>v|(t@0;s!3ApVwD;jaMPs1eYP|(SW?nU`k!nf4&bsh7 z{-AkPaB@2H@+~cUd=w1n=4dQ#H)REBp-!ktboLLgJ^JII{3Go8{C;6Q?=r0Cwg383 zVE+B{tY_G;o>m(bZ&n$|II1m8dun3>lzHT8)1GfU|D2x*?K#qRKxmI`|G(Y(r-ig9 zZZ;VCBHF`rs5YPXg&(|8XwOl$KA}CfzRPa^^gT{{hN`>nxSrkfXxo_A$?j=ebK4a+ z?2WVsgrOLw+K*vc@T-PQ%Vn~jI}p)j=~(woCIIc6&k4AtpPPDWi;Hm2-yZtchU~ z4uc0gXH7=V{WPboat>*Z$Kd+{%}H7NifB%jw1eRG(;Ry;vK~}Hkb|o6ifB#@a_|Qu zo=NUGn)Adnzw)_vu}~jl7wT)$h5FmK{l{C^C>rDFr;p}<#)LD-J8oB#R=WmGa1bwq zH~`Ex8!iRez^W045Dqm45C=4ZrxR~rh@k&C_B;`6@zbHHOLs!^`t3>hK;J&vauYUHd-S`J3_Y!*JVgv)PT)_7m08uBS{dKHw+8)OZ#kA06r?n94}p?yUy&U|Z}D-K}6W_$~Xp3z!1vbZJ0 z)k~s)Y^czxDhVYUL9LJtJ876-ak~Wq_1z6hA{jdqDF)>NVtBY8f@r10U546xAtK=c&DlWLDnEs?k<01XY z9F14|&TaQ4`co^VKMChAz@cT+A7@Yt=?}+(&1S~R)2b;n?_)oym_zoH5py5=ftNaS z+P=(w>UD2{SepId)nsinW;%FHUe_FJVW-^Bcv8m0@uq6-N*E8ehYaEIn!D#1&-UNC z@z%>(rGLY&(wC&G^xMAr9T<;di~tP`z*%t~9OuBQ!x#nFDNiIE#t1=90giSh^UsHF z*%#yl6Ntw+1w2m4nVi>xGdV3Ls26yJGekeiS`&|XXjuT(Ks7BGSv9G*sx_>)fc!$s z6W$CtL+OdFF6$-(% zAcUg>Wo_e=3T?W3lkMREHZ0Rsgiz=+*P+iR!CWt)x4UV{Ir=qDcQ@<1UhDCTC-in)ylK7 zVm5C|V*zSl3dOLeG6+QLwIG6G8$%R5%NSi5%O(U6LopUlhSGI9P-b6Cc0#$L`2i9N z1`{>sJuu8x?ap_8;i2Njc7k1(jT_s|hT6ROvlj#Acf+Hs-1s9tQqmhNi49rI%6;Ar z5!w>t+Ob^TQn~fT7YH*u(Y9agIkx%NeDX?~pLb)cVT3AnW0N!~)u;TgN{Sp1>{D~vQn>+rI9lR-0M{q%ZAnYTytHj+0J*_g zG(8pf9P++IZX8SR_h4Jc2{v5MDe}_X#!0xyfsKRVj3%oh-$xwS67j?=bztKhUPZhj zo)cju@Ir2=i1!@Nd1db{8{WW*eX?D#UzM)dcijAml;`9Sv_jI%dA7oaJkqt~CBmH> zUmKv2_>q^n?Qsd`^J;$5t z0xdDmvn``+4Y_#_&U`0wCMnkmcFDfcEMhr2X5sT?rrH4}Sb;4Plf2HpJjuz=R5woo zDHbZJN<81#+(c!QF^3013ovLY;^%~bBlEPafj-lKJl9rf1!b+Y43CI`Vq&T)=b0)j z(KifLp2h@q7Czt2(qIAfrwwt{8*Fv3@}jc6O3ift@9ii5oe+joQo`_g!*_1K{q==} zfoW2Ynpq~{;$C9>R-?uNzep|+0(E~|pCM4T!msUmyCG2JY)2w|ScdTzWMVUKRyV?@ z)QVKYZGz>ENeS8d;dhgBgz!ATwqFR3ZT_8?eU0W1w8J!T2W>Q~;%o~gE@L#U2E-5l zVP zT+Vpnx+j$@VLUj36fG8ej`38Uy0HCyEZL{pCHrmZlKsh--JJ1d3()AG4L=X$TJEg2 zMX(6Fh<9cSp@E{9biY@KTVOk>i~md&xAX=v4uoI~!g22T=~;2fVYNAe2Xx3*2H9)_LKt5vCq8ypyYhO!*@tLA|%&Eua-~sO=XNhD0 zqyUmy7RZO|S$e9SA6G4z9>Ql%wFwHsqei?~uspB?xV8n%$ae!L4+%aZwki$12+ATz z1a4n-KxRAf2D2pShQz3FelKt67S^Yl_~?|zUm2+dkea6{XCr(>T)#Y{4$XWr!l4(W zj_`8ELKP>Zf@*Ec;9q4R!OXFf45Th51}u%$Rwja_VD1>}5feff8!<0_LIf?4(NT$e z&2a-6NZb~sG;i;Fra~wvTY6X!w6?GbDurOCQC4j|#K#b^T}zcIG!18lmryR~wuvG# ziVz5)NSAcPia4X#TzIQHs4|v75qLfpAM?U|M7Vq6ZFpg&O1!K16~^Q!$k=D|IvdiU zUmdV^?j~UWgY2q(qj;cw$?%|?Zg?fQzsv&-Llmy3#MTjEB#K{7Zp4&>g_4jtjx1K2 z+kAW5#~vXp?!mVCV)n7Ef8QN%?q_k1tR;A8a!!`F2Qk2S!k+xSlXkq6`W|BI`(t0< zpZ(ciQ(ur?fPN3~PQx7)H)#UUIK*g+Q7p9c4C4*g;_RXxS6RrhZ4z|e?+$m=S01rb zXwXA#a|{i#{rkda{%)V6K{$&mcw7+;N|F>i)|;zEgM``_CNJ=@CVNjz&p;YPjHHTr zMRp$!@lufP+2(#4gx!ZEB`V?-8Gbmi&d69bc|Q%pOQS~oiDQ~8p+WdK6p47x(V#ay ze&e~4a^^IyHl${+J891E%y*i;a1V-K1GIA)=)g)M41aSW*lwpj2;=m`ITe>0ZUW6C z@?OV_gc@-8G=46ePq1HC6vG*2He$3~rAg2)vKd(R3u7SbfK3NdL~RNOOBFV#lp(NI zwmJthHOXz}pC1~zeRx$Glc<`N7ppR72y{$LUe$Yv7T9;K-!BeughACfNAupNVX&ce z%mlP$%3k#LgU%N!F_uz^e=QaN6p7D09f0w+LbGQ*^T;D%2Nv;?L1nY7x5E zP$6;zshC$ng&L%2^wx?rjA}m>5*Ii^IG1e^Dv$5MmvcY;QRhwQ52QfFx-|W9Vht0w zts|{U;=`u)9Ob#>&aLa-%+kyS8?rQCkuJ^O`o)L(DG&QHSc)JLxQXN*DF+iuPe?o` z2N53n5gXa>vZG*zqkM2$jZ9uzYq;wJQK;3v)SJlxj{=ygoFm0(ij?!nOQE`tLDy;n zkmLas_{;n3m0D$JfqjK>B{HNPZMXPpaFQyxl|Pa&n$<9q0^o1FI~jO)5|i)z6?0 zJ{E;Ddh*fXujZ95jN9V*Z4<~(WCv-vIHFIzH1pIXKa5V^dCpf<)KbYnuC%~)9}d+Y zwi4$@JZ~W>Ww|B9-rSRAs0a>0oNc6$FQH+qa2%Ck<$J76OO3BiIQi+IDo%biOr`Ka zrFenJjq%1!GAuGp>|MHggtQGYiVF1w*(qxaYAx5~Fn7YhnH!8FWs7*ewbKT=oNvMS z)0*FW5_BZ7i}Y^;pTA@1%_YCM7C6uC9Q}EoH{#KRNu(L~9)k2T>_R-aD^zRY~pSZh?$OVb`F>5%qp zhW2o%>!W>|_dp`IjaSW_EQ#EjcI7sS-0aGI%bPz21#(a+*rkCK7GUSy)QV!dP0F`| zq5)EpTFYW$0(*c!3U_Fndc+**(t{mJ>J*Gq9Oc3yO?GJm=ks7o&(m|1CTJK>qq}T$ znYd#;P*3tm6l^WVvA(oc7T|^uA4e!XB_}5mHX46U4{`{Z60LPFQlHT=@7+`!i~@8n zU>P|U$~isd%$uMbJd4sM1cN>Z#)ULA591V_7P&E&r0)x*l?JYmxE})~(sJqU zzU+NQ2@0TTPLCYR3slZ`N!$``Zncu+)Jp8LR*Eu$qYBRv0Ey%X(@-`Ts1O(krI8*@0iDq};Wx=*5+znwhfqtx2Awca4kbhg@` z-+2qwmYRHUhbOOMf*WZI?{B%wJafTM-(^4*>rJ-@vbnu_BK}7S0D72hyaWK*wtw$^ zH_OVx)FOuT1_llIZwBcTz&7Z)Bi>?9{exc8k6zBHs`JX zwpV665OYsSHa8N4c~rAB6Ji==Qo{KhSi^O0m&hV?+!czlS^pXt4}{+DE1quysSr1G zDSBL*3O%qmoAoBKr)DpdQ7P|dLSpTydb7le9N$PjT7b{*HB4wF6o|7RRpOOUAQB#= z-kYFNw^r;qE6`Er{^hL_|MPIWVw?CMyJBCp>!T?J3Xp_F2L?O((_)84LmVM#W`oH0 zQd(7JkSREUE;4Z#n3)7vinHPz&2S9}OM)`Pm?8zF0gw^URYunGI3^DWwc3E^6jfUe z@)aSrxZ#4mS;2FGDTBlT=1pE_i9Yfqr<6Novr(x#Qh+;gN5aR-ytxaB+e>O}r1r|8 z;Y3nEaor!;03>S#);$_$7)T&Tk4Q5h&!h0*ffhU>%WUHmJmxto1p&o4%i%Idcts|i zU%qUf3TwTAiqHne-MmuwLmkZg9LdSOlnA&{4w@$Vprfp&PbbJga|(r{c8;K71$f<+ z%w5a{nK5bTQ`YAmzZ*)PBi*>cb9)iP2r$Rzs*2{Q2*bsry#IyDVPzHCU8DRU=Gm1S z@B*7<>3|m>_SiP)Xlb>D@X_9u_sBt+;DcGVUlbhjZwe{=iK z`!iZQyN}#y1bQ#B7sWNUoTf0O_dIINd&DgCo3>VC7P8e|{wzSV z2T8%@LQP2zQ%-|S#ss)6WO5Hsq1&&#XsuA8N808K6|$|r{rexaMy)WEM-p4_7QfjI zKK0C(FZ`ePvavqOw&pG3USXT_#!v2zo^RQNPI=EaxshtwRv~6{85c^jS;_Wuq5fNw z64)p%&`!=vb0K$dW3o1jRWl2D55s-gEX1{KD&iH9A-1lR*!w(&`^gZ; zgH?REoffZz3}IovXUWjZ9`*8ru4T=Bv|Y2WO4sa5-u?r~P$_)Vtvio2^3YPTP9E!t z4?d@D_-0Dp*i0})0E-I2o(qCu4e2wTlm)@K)B6CA7eGVT+%TyD&{U7)!$A6CBZtLB z8gs9Cu7lI_G%s60emg`75E5^m{@Z&3csq>#VIv^+1TQauW-YxRvw-hV9q$6Kx;((N zj8EX8#xyuETsS->STtX44%{gk4p}cEm(KUOGU@09a?FfQ!s#tnKl;3VMFAMLt^8Hn zT87ENE=m7H4lVBN;MyS{brT?3d7?5S0y*TG6eV&+_}e53=+B2?J)#1Iq*}}LGX_+h zYxb~1&Z`~@_n4V#mX>*O~|Ng+ux2)gvk=x|{ zol=7hrxKV>Nen>>#-*<@t+p`bad@@*`43+4)Vp3F^zAXW147?y`@izFt4i!x2`<02 zGCP(8LRK4Zc-i}o7q__IvbBm^oUQfkZ_d>V0!30R+4VO)T9564+evLAvn8t_auQc>W;@R#M zwjl0wS1~Wme+Jlsh$Hl3_Sk%-wxD|j{b9FL39pF$fK)U5M-3JG=@0aW0A&^Nis(;Y z;z&>8_na;0y7hOwT2}1e&aK!lxkYS2x}Ra-GH!ca;jD@c9<^1t-tpkvFLx7> zfe8_-tO|&fqh?jSWdT!M`$xTb% z$^^LsZ?@S}!G#$+LIHD9^5!qJbr3p?7qJ6mI9WD?%owi%4W+0h9N7zP9Wo9H2i7v< zXNJ*3gZCA~2Slo{R;NtC#4Bv88_S?plu39yP_Z)Nx9pQA=(>}yHAmvKXkGxi@Of%jcJTC;*_;1REH;m`E4H!NyuvV{Yv1+; zF#l48s0(MRmD>3D;C^K(3$q>|Z`hc+lf*^Mwc5*1zFOX?>Bk zs>aIQY?b))VQ3~x^~Q(4_D11Bzhi3^9%O61>_^|A*4@hnS0{Jw%6QRh-ur^{WLN#& z%&uzN^YJ&{yzlX#G>6vGPH_Wx5W957mq_8`eTfHUGEL;egUBqe5OQfAY<2XXA0Su9k}9Cs#DUZ5E9MC%Ot< z^VO;?fkRlIam}8q=Nl^+UEY~1ByG`v24E91j#TxU9O%LX{5>!Q+#(N%ouMg!P6dW0 z?Nxx>21+2Pm9linywdE|inH@Lmw_Tdb&agbl2%RbCpBg2on|!Mj?MU}UA%FiDo4*B7x8j&ng|czJP&CV8>kaD_S`RM5Ui z4j@VO2zM~-A)q9iQc)#QQ|>5oA`Q3`Bv67&CS%QnrM4bH(V3y-)-9t8Dvj`12A!dc_T^6{l=K9uP&65|XrK&( ziiRDGQpZZM(MU9wW0K(!j$b-_9=Z=7e96;=Mx2?_h}RktboKxDK>&X~3j(PLYuG1? z_yspQlXhSD#g(tzb(my=d!nt?WP-ER-h7$(ldZE zHdJ4xLF(MgYM&Jv^n14bLW6AcUwP5@{s(B#kIp;d;v3mipJbc!R`IQ{?fL3I{>wf` zgV-_@+N^{IVVTe1b8o0vDGid#<@ObG&_Ei56-ULq5*pNClTT@o7xR6I2Dz&``!7RX zFR?*Q&or0jK`it|SrUdNw{~A_kdIusj~Ilq#$%V+Ad?mj_@ZjOG!b$#4)Y@nlOF*O zfQ%m3_B<2%;mk+=Nuq?n z5=Lx;JS+~&)0+VpxDH|jw3b0uwVY&Jkrem@0?7ER;9EwRaMg3B0((QA1gHc8By3?C zqRqAYhz?gI?g$r2`od6+e&#p-ctX zJWr~%x%qlk?NbjlEgMLTHc+p7sSBKz%}>3&$lO2#5*sjNL8X*jt_NCyW~pTJgPS*R z>MUJGQ%jkX>uvD3qTXuD0$oPl`LWTKO zs_lT-g>3upc)Jj~BHo56z1}jst$^?0y~@o;Pe17R|C0Eir`cL1KFHSkkx%`IT2~|g z0UP3=tA(=p7LJ_$>VHan(9>;eOni`S&fVYt{yxWt>bP92hI<@aM|#w|YFXp%HGGJh zcvZ|R;X|bJO{JXP#k8Lf!H4khRuQj=50O|VgVyhBd}#1xsrd7#jF;v^&bk@$q0R82 z;^c>W&M@@#^*>!Fw*)=IuG!|6AiHLN_R1Sm!%(pEdu+;vu4%CyoCqP&fV&}sq5>ye z5C#892-08$E=4bjNrJOn3myVWrDcO&3@RHs=6Fz2KvcqLkRs`569T|SwWW|Xq;-YO zEv2h@%(#5K=Sl`p911A;*9hn0I9uX{$k(Q5FBnQBkRwjY>hc_mrrKBF=og0BivXPe zB&HpLp~%9a=l}Sq_63P$#!GG)ZN~%@Wc};_pk%>p;i(sq%>k>f4p5#`&=d1co#$NH z0%IWs?JKl7BzbBDcJB8xQwimZ<{Z7W1$phaQ;4&7ak4a7u$+pN70re$xq%Oq)a!GG ziMmP)rInLFxl>-QsI`|BkcTp_w4epFRfXlROLdezs1)}-arl-usXnmTc_UnyF`>Bs zwx!`5(ZhwlOof)_6ch~;Y3i`%7b-nb@u@9e_O$LP!ib(}7jDCd-eVZiUtInUu>UIi zzLl~Z@~B~NCyO>Kwm2Ry!NY1}Isvte$uwDohppzS+-VTM7Gvn{<|d7UOt&w74kIwhmZQ;yV+g;z_#XEu@Kqj z{PeEd_f|$!nmmpbbQK6u%+^t0x!7+YB3IXXwh0)uig{^5R30zX#3|Yn&sDPz`KYXY zz)h~)K2+NcKIE{vvofU_DXwUW`2o&_wkd0>$Hbi3I2o(P0nRJ}x~fM(t;k^UFg)AXeg`tgN`g8Ibw&FXXh5c4O)j00tV6!7IbI-V8}>yE2Vq|( zkAemJ`5B63h&M(M!n-6jI1(`3$~QpyFxS1}7srdG)D#W44%*5Xm=}4nnJjWAt0@>v zPv5B`AHJYc%H|g+9jrhJ#D^2dNu2aBabSIIHN8=|AoNh%nwls}4|i1Q^YeqYA#ls@XEJ|OWSkof^A&dK4NIl&weI`l_dxXR^2@oE>w2q5RrX&HnJ8pa8f{ z>vHmKqCXlq+~X6bMdLh{;N+4wc7PUf2H1~%(Ga}h$xq@wSb@0$h-Rln!v+`v0f6f} zDlY@uIcFj5R)EPeV4qV?OdlENnLO6b^Sjc+dKzlp*;`d@V?z;Wuswn7(4vGC2P!R= z>ORX0aWDwTzHJClQtp08q`7?oJVqnWT5XjF}BgH-~dor>YIV4uVtXdp@Enp$aoz7XqgUffs0 z{eo#`fGB*vT;FQ-wOvafk!{aQ;(SpVCX-{L!R?mK)Vx)29FaE|xoMc^Te)<2j+X>e zYqQ+f<$6_5#2M2|Dqp91g&Fyihf5Dm!h{#_lKimh+L9M*PFAH{!}T`5JZW|bpJLvX z3}zhqu22rQwg#~*K2Jz^_F?!^Y_eu>p%{iIQdVS+K8WFI3Ehjyl}VwNDouWk2^ZPx zK6c#uC27JR+2z_KP57i?K`;J)zaY(Z85LpCYv9$v)Ps&NNae;G6RSfOUhcJCeIL39 zKKLiE7ZTL9%@-16TmRxBf3+twHoWSjHN&%AdZBtiZNeQ6Tp zpkq<)xN57&N^xJ~K#ptnb0Dz=C6PZyUl|F?^`m+Uv*`9VmK9w7%<8KhEos4@p?%8{B@ zs@bU>1qdEbQ}t9f5!*nk&Q;k;Lc+sd)2;rdDy+^|3s&Quwege!FnwN#tBfZPWD8gX zpNKM)BgBTnkF>*~6!X})jCPhkU#?M=l+k{~0~l+uY0zIcHX$^{iH?+lRf$E2p)$nu ztH8q}b9yD`+>#udJO(u)t~`y%gH;}G3u~N`LgRd|l$(&y7--Ht$2y7u>!epeEb~h5 z3RWOmw zjN`tg1R+)HWJHU7?uW1>nH|MwD3=6-k8vpUj)|dB)$0nv$DgSs=MSm@1)MK;eOHz* z)F2Rzh;SOxKetF?;#4j_aW(xL3>PFowrlNipvd&NRFMLbM^Tf&qQ>mT|On!VjN`$FI9&-};tY4xck)|*W&=qyO4(Tusl zy49==AU`p8J3t5}uxbhN^MWluf69bTf8@Bzri|^z z0l@<~qZ#8wKmpCt?1!nAsm41H4YMv{2lp;oG3n8ctQ=yC&hLE8)XezVyN1?AGTI}W z`oP{*rQhL(EjP*^YJan>JN>lTUQfUZhgc=q5|WY4m4l>wG{)auLu&-p4p%nXk8Jh3 zhPbohp_L=DpFMht_H-Ad_F;Bk_AhauZjinMJFJ+UM+!@yasEONx3&tCj14RzH^|7A zO2rmk=rErZNTnU;r5%rDr2QZ(ZT&1G?ZeO#)=QRJLLww>|EX5gxuQyBm^)L2y0J^@;^hZV={;iG+nvHU!&%K;O79yUn>d`w|TW7&o%{#DFqb4#Dz> zhPG)ZLYFQ~7yQ<|S#-%E(C{ULG22|n9)uL}#eo^d63?!J33wDGjmiM!@e-@Wsrj;i z;(KzYGscm#kVPRwx@y3Bok-<`KqfgVH0&q<^_$adbCOaVx3F+ifUZn>eFzS3-jWj# z1BxH0prxBhd*>L9T$+tVRIAL(NtvmD&mmM%Z<+KL#5@2%7g$Av1@KWPl=RZ=Y(3EYzP$Ap=fBMP~?r|#gkhKfXk|YT` zZDUN51lyYLU3|yhNQGJhs8Gy^?FJPpGXy30k!Iuf8>`2JyrUM@r#@`HmolMpcevdk zLgi6G4JAM3AYGy3N)18x4gys74u>lvKv|E&{ZT>Ls33x{mCdgxNdhh{sU^re9PT+o z(35X?&tF}|nhmFy4e;ecfb5!m{bxU5qJj$HlpOVQ(SqE3+Nxo};Pk_pVS@eK0%E`% z?=a0ZsJ`F5Dm9CEvxEblAp-%C>rr&CQ#7b z;&$*7^Y)?Y%1Z(a4=*w9ha6`>b|fB6r5fO3`fv3JM#7bmWQtB%S-wss8=RNn&F2w( z9F+9k&*ncU<=iFEzTmbTm3BbFtbPS1DpzYb!}VgN3_EwV7Sj-8B8*}V{v_?PPKMAg zQs#{_eEC%3t^q+?VJyhoyb}0QyOD;$g_)~urM9e6#!MD(F1TvOkW)f_;4_1ce^kM_ z6U&|y>6#X_4NcUxxtOQ;i+8?I*vU+WoqWSkpVwS^3Gn_NN`!#%h9g>OQwzmo5SAc2 zH97sqZdSjnd%}Uj+&B!E%|)1-ZT=^I@tS_-Hs70G*fq2<9&5(XS$ktphPkE0j`J~_ z@cZ|V`|_70VfOQF-I6fd*8TZE{aZ`iCMM_49y@((alX5~_c(5!9j-SdEe8F-nV6zU#Mo;sZBln`3<7Y^V zP)bF-Lh6&TOZc1<_YohsLAD=p-SUgYuTOnANX9Imx*2*A?>Xx8fA4Pu z)4PTG*d_ai-}-#&0~dmH;EsiX*)#(g*Oj=%aelkTVJOGvI2)Abd2?sl0VL+6mTM>N zt0A%1sjVY`xgZp_LqR@i;#mdHA@EY1%?~t4+aaWYuTZH)q?u&guF}w=P=t^MVnjm& zQ`aCZ8?>NMz={$BVOi{1Qg48@rL2={nq$ogAwbadOt}VO4|tQ+J%A5c2}qyAVBDiX z7Nn2EwFJ~)h+r^UgPIyLVO!Qkq0e8rFA>Ue~+k!H^N1c`pWNQbv zL}&8#=BaEcWz*$|i;S=`83tgqE9h6I0ja!nbc3=cn$dA9%(`vXQcva!r?@Sv!*KFr zqCXX41ZVJOo9`K7^s~Ra9Q3~mj%#&NT%*xQoqGnPapk0kEZjuR;adi!Sl->Nr@Z2< z^Mn}TrliEkHveO{y}q9qdFDfYxQ27*%Mc^aoA4^+OhimtC(Qr#8znaA99y@<2HCp5 z^xt2m?!lA>%W#b>W-K|;px7Ys1g%2V2!UIILVf=qE`EUANpY@ijk%M;Hs`yazH%?* zMJ5)a%sq}VPFe~le|=(vO$bA53U=zNXA=^%QFxcoUTPC64G$vjOa;9g-f=l<1lLcQ z(dy;S_D+cVu3H?oPmrZ!pT9meq8j$5RG0lVH~Xm(0m%`86{`{+geSTe@;#6v;hg6l zdFeY?wa>Gw_SHg-?5cgyTfYr8%AA9OdGV4bWijV!WX18jcpgj@;V$bdlP4?c)$cEL zSI`rnREs$R{O^EFfR6@X&(OLwsA#^eN-F&Q@8uXVA9SK;!2v$Kcf@k+?gW6ZC`a-p9;XDmt8U_M2R#-q zYvnHi-2%Ip#(4#_2)X9cHMpi0lqA)%q^yp_JquSgHJJrcrOq6EDrMZ(EF^R{kzDC0 zpmr+ID{*CW%F1=4eK9||Fsle^jQ>}JK^x-_23u*WUCHtqsI_u=u}m)QR8Y_es>P8D z3FmO6B-5F9(+yU<0#*IgO&|ONVH4-4Y~o)H8T!m8F6(C#jxHo@+AukrV<5vOJlSBf zA}e7NKb&oSU(zqUz}9WjFWAaI_J6KlC8mYrW(!ky9Sf+*{87%d#Fws-j7bhR2^d(~ z{`IpD9k)}+&;_>nLWXSXKl0hz?|Cxxl23m58{#1MLfabSAZMHNmCwBR`ZXF-QO$=` zRI?!!E`UWp%Y;-62w=&TE)QU-`mhR#6)Mgam(zMnPwLc3js zn$~e|5so_%_JtDDmBz`WtVQ8gh#^hJ2P+#_OOmP-6SF~i^CWl#{e4BTAFI#-i6YYO@h|?vRMnyQfx3!Y5PGdhrHU0 zJQi;{)Sk2xyreKwIeZ@uc$%sk?;CPIr{kkNq?Vix3GlRdbi7L2RP3K!lg64s$F+vgiz(_xV2d91l9>+fTSg4Eej< zLjBfsq5kgYK0c5LS%0$Z1sw`|0(??%w&_7}069FA5fo4MswSCUec1p_6+)Se#~y@6 zD#2-`EvMCSna7?tpX32^Su9QgMabJkoHRld3iHVURux>(!y>_HtpP)TcsQ(tcNh7* zNOGf)JwTRJTIK{73v#4ih-3d_+x(UQ`k`PZX~Q8R>7Shb=>eg z@ynOgN<9m4l@{=YFudGzn}GKY*P$UK$>L2W+y<#OQjQovRYnT4jd|C|hJibiH;TnU zTfN?930gS0X!UamyB-YTPGA=M=LBTsnN;NM-GT4 z{~?2yXM!wc{WRH_5G}9|5}xOXnBjN*9KiJ|P-}XXcT-b!$v42V#sV+IP+dum=8hC_ zE5tTLB$?J4$kk{zNgA|hPPk|7p7ZbE(leDKnV-SRm%per&XTYnCdc(JS7+|z0+zxMKvQ+Y}91TLg);Z#bJYw=bUzt(bcBd&Fg`k18lTc*Wl zci7jZ2R7~&E_AW$fp8()fgio%#n;^9zH&RZy!dj7A$o~x&ULcI+V+Fk!?H9N8jwDLJ=U9tyb*3c6_OYt+?rPGC`UEm ztpXA6r$t~S+%#3hOVc8^R)n-je2DsJk=Noq$BN$Z!MhKaCHs>6l6~#_gcUik5Fn7E zty&>S0kSGPHjeXntqk23gRsI}!8r>^K_rO6r2w!2m()Dvz2Ovv1|kq-1+0>{KMhx4 zlLtZhli*ULMro!nA$al@NJ}b5#k~10Yc&7!2)C}t(GakX^gvghikQZhu8yF)5o$(x&~oFqCfx84TETlq#|z+?9Y;A zte1mnOT9ftRFKVVyR@U)?$DO6c6K~NY+EmLJrGi7J8;A2K2}2N$b>?Y9%I}pB6Z@( zYe?Nc9QU;!2qn7I)h(3B*8T1;N^VqaX@k<)}ek!E`RpBLPms@za(~OUvrp~Ys#FM zmu5r`OuGEs{VBCM%=FZA)vQBuylg_W_TAVa$vV0l)*&b2&^okP{O)S|DUnIVpHHbR zm)S2ncIen*OVL^^MfbpMLal$i=*(+brC;t==_}J!`Y+y;VL`!`UkD{iuy&0Tb~4QY z++$+w*@cxZg^qU_IGM{$9UK`KK40NU0A;8aECtqvJp@79MFIIH`&jnaUJ+6uO0&qz zdm5FRPGC#sg-0+3K-};YqXpR4iMwKHxr+2xVl0f`o1tXx%dy6@cYpdykfIh|{PLhwX@xSJQ;-Ag`tZCt_8g!pF3HLvFKVy@X8 zIobxzgpi#pI&=~S+ygtCDxk}c?>R!~XeLAFW=ookK=rzC5~8aRMh-z)c+>RMg)%l3 z`_-j{aZ3t7`YXv38pK9&=ce+{;PFsp@`spr3+rLb)S{$`CkhV|aJ19I0Zn_Usv8`y z6!rhcA78runPB@@xMfB}zB7eQ)F0YzQ=jE)*ME}65m@~_Z{@EQ; z_3U>&>UfD1y3Dm-Qfb@f|L~@-)BM4=3UT`i%LI3AlMH+SMdFfgk^o{KZY+^D;Z4my ze#{l@sIPLZd712}wmCn2&0d*H8+9)U6@vHRBdl0vCB&yDt5J%{SIHD~?;$?;DyW!O zLVU2e6^VJJwxD|l@xcpZ*|wlEw>XFo>%@MyxbDJ&7%{PU;0d5&UE2OLfb84@CZBc3 zeszZg1YPcyW|Pv!F3m5$;ucIk4i5xC8>|EXGcU3P$&HHsR*e82Qqn@_Oa?aHTQ@&a~JN8+b6=noDi$hb*`2=<`Dd-eBxkb1};bIOur)UWPRwq71vI-Oet0I}Zq6PLiO~0%lNu*I10iYtju|d6}N&+nfcn7J`LN+YbI;~6N`kjIXLrJZw zA0FB|{*tFfMT2fc5OfQ~PE5afgQ~H?q)EHOMD1Pz*>E*g(4g7^RnpN4W{B+-9m|lS zrMP8)&uN&ti;>ztN2(eJi;T$1(0*Z45g|XXcFVSqpSKwDbL-Xb zE+IeIc599DWYpxzp!Gt2u*CE}PUWw0m79BnZ0-N>+V>73Kir0db+lGPt0d^NocxF< zT(c%lBrr14Gq0{m>5n#i`z^wLu5cX?_G8=szW@A|ZGV_fd$UcaO{~>2InDAZ)MMH= zz3V!**ehLYt`hcRn{)BE{%UWu`V3}21h=k&*{3Gi2)#kVAM38v?9)U@&&{oZfjOEOvM0`+PMeq@yG~#A(;EyGvT2b5KmmNyga_vEJN04-&w}kv zZZ~~zAMuCUF-gOlk~WAq7H&6`^i~e+1h7O9v%9EL$Orc{4)WD9e5%iq;2=qXKT;kD zD~JnFWCj{$;B{fvC@E^(jKIxYM9#5)YB~v!&Y+@sXsGx;gG~*5$|Eh{o4CnAQ}yZ< zC)fbP*j_tPZ_6wo#yn4mgZ`LgJhefiPSvVpeRdV`Y< z#|A@Nk7t8vK0FmdpaX@iRecGx03j%q$=n8JhG98v9Y8Ye2}EAz1=G{64|xU2xglzY z1N|%# Date: Sat, 15 Mar 2025 16:48:10 +0000 Subject: [PATCH 22/38] . --- ...ut.tfevents.1741637561.ee-tarrasque.3914313.0 | Bin 5028 -> 0 bytes ...ut.tfevents.1741637649.ee-tarrasque.3918878.0 | Bin 6507 -> 0 bytes ...ut.tfevents.1741641748.ee-tarrasque.4023347.0 | Bin 5038 -> 0 bytes .../module/transforms/optical/bert-finetune.py | 2 +- 4 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 delete mode 100644 model_sst2/runs/Mar10_20-14-07_ee-tarrasque/events.out.tfevents.1741637649.ee-tarrasque.3918878.0 delete mode 100644 model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 diff --git a/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 b/model_sst2/runs/Mar10_20-12-38_ee-tarrasque/events.out.tfevents.1741637561.ee-tarrasque.3914313.0 deleted file mode 100644 index 21fed95b763c01ecfbd0a4f987eaf9cc96de2140..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5028 zcmaJ_zmFwH5eDlP?w*7M&LCR7(DJ@}yL;<9TNV;0**X~?M#hd#R;y>GXI^_}rpMhq zyZ3CJ2%I`1frN~Rgb<0pfrJ1dNRbFhkPsrU_^P_6XL|P?8|-SktE=nBS6@}nlW&Bd zuYd96*Y|$=$MTpdR_cE9&Xwx@TW23`R2J7WtYZMAV@*GaqITwCpx z-6)b0wQiLND_=eM=-=;tH|Rb3qAAfcH#p?KYZAL$gUR2ty-#T7*rc$kN&Y!Jc=iMu1`7)8o)v|Sp1u7#v z`GCXOqPmhT*1R2)++bCuOP@)Sm~hbGRkJ4p)L_b8xI4;g(~H(PCrsax$b8M9tf*BW6w;SF;q>PUa*x=)g$&>I5|L&Y59#Q~woPhC=C+O`(pXt(%8)F}u^U_I$#Un_ z;9$)utKs7Z_mdhHK@ueSy5&>eSh%IS!S6*W%biME)2UWT+v~m&6iFPsP$%09K76fo zvqOGW88mK!PFcxL6w1p}x@Pkx%cjM%u=wU;8E)v-P@YerI|R3r&jq9tAfj|X+(=6B3!Nq>72e&@rr3jST(e|lq3Ni6wz4%*;I%=oUO>n1a0VT zDD<~V=)N6$cmJuhr(1EhJvlu)Ie#RSI`Pt4={_6h+xh&m41(qC;-tRT)}fGrCtN-k z!L&X3dXb!ErM>Kh?rWkAaLCeZC$9)dbLlPajC%pP7u%qc7rpVQ*j9I%3OV5~9FVmN z{6!f`@!`E2!JF-myX{^u_hi0V<$!$!(6?ZvyhL^=*V#s=ywxQ|wip|1&Uu(to_&%c+ZtIBjaxH_R=PZw z%!L!M(;bwMc#LegHuh3kvLc#1ZzF0zBuciHC4jUv6!q^>7!@ZfQW45k^s-|Hy4eUT z?PP(@>p!cA>1Lm->I@2L?2-5i0cpps&*0z60xgyFW|1NfK)|pfTqV5-X`NMvK#JPf zh2YYO#wyf<71|M%hx$%o(hXQyi2&bAowG*mDq4t%p#Mat;XUFxNGrQJoV-_>zSfjW z8W}J&9G1n#^ERuzup(H;6=5wox7^_tf?7_=$0dQ@`5-8&cL+iUPtg|fyKJ5C22+qE zCz8i#tO{L3BGpJlZm|Qe`!FDI9+rk6w*yZOJRD?cf|)dSEe=k!h8-$(A_fMAlih&J zVb4OMr-DUUT~Z^-y+1rIr2a-ur7%Goir(gSkjTgvu>?;mRcvYEnv%C(uncBMr$IxQ?Pju>BjiwAQ33ui_bn9>N?O&(a!)Ky6L6SlOBm2Nf`o{hdFh70&{I9EL%L`d2t|{dv8Wqdq|`3Qt1#ar!9EhpYr~?gl(jGfZSV2j>ZX;l!gH$Yp^3 z4?nMO@_NGP1PYEQaEi9!6hJ^FGDl^SqQ=K43l-sr*Yh`RuWCQ`3e&$vLCbd7QKw7S z1xm$a-++(M(ea$g+I`3hgYgiAc>~P?1FlVNkb}+ip&OWwDW~Vh%WOU$BZ)@(9Yjp= zsk-sPM!8$rU~r*IcX*sydP|3AGg-*WRxjSWyZYAZ4t* z?^OJ?;s?5r16}(tA8Ws5N%r&rnN*Rj0_!RD;|+ zP$9rJ_n38|tUoEFQzs+p7Nb>!IeHjAhLnB`gJO*3 zLOAVcaQM(W!WdHlr7O`C@Cpe9NE{ z+?K*{e%plf=9h)T@Qkk)Z0};|a{>^iw-1H(!REokO?(-Vt>OxgmST!T@RbUJzmwoV XLRcnuzP&g*TfJD#sK5B|=gp(M_Mw$6779hPJ}rcX-M-y+yGw1?Wrb2LrRAgifsx6)d*{7(+55%J zz5CHpgNg#7wXSF}L@J1ym>?=9g28|we;^tqM(`t|#)guZ7S>=fA!7ZVnfva&v+u3@ z$8L7!oSE}G=XXBlZeL=4{`=+a5A>|rw}07h-~8j{7rdt~xfc#xIVHkm^U!qzl}i&5 zsz}x&F5Ll@1_76|1Jlw=gd7kg#TuS3m|3~$wc35V&o<{CJ>w6L_gwoObMCt5`YO-# zReHH>s`ZJONmzEx{&zH8;Om8YG#bmDN|zDi$THA<5>O_=M+>gK*`5V^u|f`&dh4n#;vfxSo?(k|f3yPDn&^mJASa z(|JZN(v;lwC_>zE(jfJDLYAvU#0u67>SMHe*Zuq$ErK9$!dlZ>NhASF1aUIU9PYG4 zerm4e3ob)1Y%=mB6o*BBZCEQ$3$NKie1&699582cmj|pa5vcv)!~ zFg!n^I1l1%k|{uPQyU$vo85&pulhed$i0s>M&~X{0F`@P%tkyCEAVtgVs|;j6YBC>5M7*rVTq$o#WLLCz zm=v+^Ih-VL26?oQKr{us53nv7nNu6*Hp%rJDYRaup<2I_%fT8Os`U?U?5~Y5A^H<8 zC0Dnm#qC~tZLGsNSbb>7L*|7IDjTRFWwRBZA|e&3 zSGqIXGt?cfnIrXL5+%r3-wQlS0{Xc>%>+(C3G z*49QK66rbQ*?QW$|paSd)AM*OJ!U}G4r{yW2tLlo7njs6p&1fY?zMZq>yAq zW^%C&uYr+B*#>t&q_UyN{{)HQ>Of_s2d(iXcKtr6DNb!Zm&xgjOQ>ypHBdH&{L+EGWgW$8=wwFcaY7Xya1Wh-68y z45ujtGD9*?QBgOUo;%EIqCN{|h&<`p^`oI&TF62||FOV>AK=-TOCA)C)q749Hav8V6$VkjQO zxE%JBBtlAEB-kXi6dQf_!6EN-1S)rfD^QbTQ=f9QhF2uIC0P`gwwK>Pq%ENS*?2d`SHH>d>4LqWonSCpN?Lkh{29e_W&3YHHdmiVH{omr~q zGz**AHnKA^PF0IY-AV~aaXLzE8i+2Xz+$CfE(Tc3gSQSg{*RZwdv5<|7KAbu#y-`K)4Owq!MRHAU`OT&CHevv$(I> zWuu|$f8yY++o2*i`6tK2%M(XeuyCW4(0>C+RuC zB`p+hPuWdO*9LbMP0EV1T`8rUC%k%CWZDWc#5{61V}(BQh22;MU79h%w&Y3x^BO2L(xdm zP~aWlA)&b+x2V(Q*O`)v8NZ2ogpQ7fjn(c=R5Tg8QCM%FSzy$ac^lNu*5$)%!hFm- zp?<~`=4&;QX!4{5k*I!)Sn$?Hx=Y@~;6j5>s*Qc6nTr$Dx*1i7(RcEYO6SUR;;r`lyd5$_5LY3F>;KEFz9NX^os_i=+rIzQq z7!xpXV!7HBewtNLv%bi+nO64cVq}wfUn4yekNfzZA!#gcQZAvVn|BEwKFDCS5vt3X z;)dFY*)-D;JPza6NsJTIGJ^)Y#IezIbG9CjXD0ZhEul#w&99)s=HA?=|v79sQkT z)oV|@Jbv%%H>|iMQyMwNmh^n)%3SHf$_15kvP;UJbyh7~v3vC0Yfscam?MhC}M3|ny4UYIL(fitmnbpOAWSqd-x;@MAY zIER;Z;5q!*fw3Le&s}s;hF|SEx}@jWuL^~(fYvxaAHVUg+a6h$X+3&y?qtvMUl&@7 zB6{V?t4IC2zxGTu)12wu`c6;fH-%;gq`~>Qt$R-Fx!Z#D+<^yg*N~p?gmmPMJI4FGzEo%}3h9NFJ4TPav+iOG(%GS`Nbb#rLRUy1dSdnETeojIWkGt26KyF){I4a2!vN<-t<-T(*NQdqy6uLsXc=bKwC+EMp$$~Wh`9f<^NQchbHu{5B^BWeV Ukvj{`4oF8Q){Pzb^R>PI1Jnp%`Tzg` diff --git a/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 b/model_sst2/runs/Mar10_21-22-27_ee-tarrasque/events.out.tfevents.1741641748.ee-tarrasque.4023347.0 deleted file mode 100644 index 0c1232b6c0a3604c85c2345800ba755b7517067a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5038 zcmaJ_&yOTG6^1Aic3U99-jdU(jwsVR-Mf=bq9_q7C0YTS1V|1DAXzuGp@g z=|p>h|9}HG4oKWNaz^4W;J~jd5)xc_w#!wX>CGHwH1_lJ`|-W+J=c@3g`fZZ z>d9~J-h1zzlfQiN_wT*(tD6TQ`@)ZEl-J9Zs6hxL__PMEbI zDb`3OSXlX@x%khA-wt~J`O@G1aQEdu2fg3@)9LDur>pO9)7tw^Tssz3J^tlcaUidy#_D-A z08+F(T@_qbx@8JQM7S&lZmrylc*V5Cj2M)xI7vVUMRb-xHU**&XKOMtK^uBo3jM7T zx^Kqb-hax?=A3OeXWQ-B_7M}}%yDD5{dAmf_ve>o5G?mDPU;(JED9NT!sT-jOxu&M z7s*+#wC6pOeM7VX4q2LQ`JR9@m)_#exEG*%aUN9qMW-DqR>@9MAt(HW1H2J|zbHc~ zK7Q{;@Mintt~xN!J(+JtSYTfO^h%Yyx`eraw`Gmd9Rp~p%R$RNXOv#EYOE=P^Btpw zj$ypD)(tiiNVa8#sIl;H>*HT@$5~uT{(d5zx7QOwc9`yxgSgHX%VzUhw$Kbm7MLE~ zjL;7@0eU;G8U`+`q$NV3IEbF?EBNSikRko70a@L)hi-b#mTMkRp(m%5Mxaqk{UugI z&g*(u+?XWMG8h$kh3rtSvyDzTBP)t*IkwoG^DwRa?B^-6O7og%+~|QRZu4AvE}Vd! z?x2LkV`RgXHkZPX7188*8&LxyQL+uM0Hmd%sDFpTsC}X$6`@>N&pXdRr7$Qq^g$-B z(L-bhW)HEhrR61hu=}(o&YQh6qSGj;u?GZdghq{BpA~}QC7LVUkj0FA0TsiZ(0#dM zA;Gii5K`7?vk+Xm(;9)wFict^^-$?4V6p{23l@NUDRcg4Y)uO>CG;QbB;tU`4&3l= z4y@m+bl*tIDh&@9ACAjn<631k7*?3qu_x3e=axIZnV_l@3~)&xdVdgP)msFigQsYa z_+2T>oW?xF$%*7KTC6}n;Yd0Xky~uR@GguBJ`YPnklVq}5PnR^(gZVU09zcK2o5_` z^h6OJ7?yVfE{8n}iJnRq1$Id-CHMaL;E*~UIhDdt3vluf_$hfCaAiBK-HsYjLi^MX z!7GxmT2O<(LclOj1Tgy@7JF(^2(hFig{NqZMo%pdps2oeyQLRJSc;%eMHB^nv7@tU zloS{XU<=q;%(~th1oDdaby;9LG6c}rph~Br@`yx_@#UV(^&~Zd79BbT;9KH~eAA%B zDysN$Y?_;=7kmNWcc@}Rv)7coV!=0H0xLmCnDS=YnImLTTu}l3cotkLB9ti6^6Ef* zPP1^BZFOFe!)SK|>fK0=%-TV_`$pHzdi2`M(#l{G#^m{GtFFG@q=Y#c4lwF_1vj%fFF>GGjk@wGVFWw z$i7lzjn3NRv4Q~rbv*}o4$&l*??-M6lQqTz-B;k)xx`bZeXl&)VCn2QP4x4~w0y11 zSv#78MOK+ME$`)LUf(jBOPM(zAAbCkCrAjpnU`(|3_aCD5z<9dK`5H+tuqJ7&j~f1 z0(ncDH!;r*^#M&P%d_FYOQqv8*iolT*9A(&jNgKf(9!X*$=ZF$iU;E{2>T5*3yiuZwLu?jPam>{`IvWl|BRW< z_s2-0t+dS$C;-@mhZY4sv~HTd__Z(jfF zh46i$`{nTE-Jkw9eqZ><$?C5stCz#P5*lp&o#FQ<={v*kPW*QU#+}1cgjyrbQlXXz zt5%SksRn)XK!pI?zQ+_aiJg_)in9@Qi_t2=96b!5MJhLj5j6Ic$B=$co2T4fiu#97 z=T?I2aV-(wO6{y*G*z7EgP3PY%sjYJ5e#4@+X`D%)t&(coTHO2%CoLDYe zp~lFfVf{JR##8ar#iQ54`x@oxBv<&JVQ4IGsg}^w!@C3zA5<{f2=Q_}_~86e*bLJV zJz|hv!!y2P_<$~kKFa`MdIeEvAFdxhTE{mM*~<2KzLZlW fg0EB%{)-6?B!p#h56&;P+tstxjQU^j_kZ+%3`T29 diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py index bab5eecbb..db4d6fa84 100644 --- a/test/passes/module/transforms/optical/bert-finetune.py +++ b/test/passes/module/transforms/optical/bert-finetune.py @@ -108,7 +108,7 @@ def compute_metrics(eval_pred): evaluation_strategy="epoch", report_to=["none"], num_train_epochs=2, - logging_steps=1000, + logging_steps=25000, per_device_train_batch_size=2, # set training batch size per_device_eval_batch_size=2, # set evaluation batch size ) From bee7f6d73e31daa546be83f20beb4f9d7d3b10a8 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Wed, 2 Apr 2025 13:52:44 +0100 Subject: [PATCH 23/38] modify morr_linear, implement it inside onn-transform-pass --- src/chop/nn/optical/modules/__init__.py | 2 + .../nn/optical/modules/morr_custom_linear.py | 488 ++++++++++++++++++ .../nn/optical/modules/morr_transformer.py | 474 +++++++++++++++++ .../optical/module_transform_helper.py | 41 +- 4 files changed, 1004 insertions(+), 1 deletion(-) create mode 100644 src/chop/nn/optical/modules/morr_custom_linear.py create mode 100644 src/chop/nn/optical/modules/morr_transformer.py diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index b1d7c5629..bf8ecbce6 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -1,7 +1,9 @@ from .morr_linear import AllPassMORRCirculantLinear from .morr_conv2d import AllPassMORRCirculantConv2d +from .morr_custom_linear import AllPassMORRLinear optical_module_map = { "linear_morr": AllPassMORRCirculantLinear, "conv2d_morr": AllPassMORRCirculantConv2d, + "linear_morr_full": AllPassMORRLinear, } diff --git a/src/chop/nn/optical/modules/morr_custom_linear.py b/src/chop/nn/optical/modules/morr_custom_linear.py new file mode 100644 index 000000000..1a035b88b --- /dev/null +++ b/src/chop/nn/optical/modules/morr_custom_linear.py @@ -0,0 +1,488 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" + +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device +import torch.nn.functional as F + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from .base_layer import ONNBaseLayer + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantLinear"] + + +class AllPassMORRLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config=None, + device: Device = torch.device("cpu"), + ) -> None: + super(AllPassMORRLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + # M * N/k MORR grid + self.grid_dim_x = int(np.ceil(self.in_features / (self.miniblock))) + self.grid_dim_y = out_features + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + # we init this to be identical to retrive the original output distribution + self.morr_output_scale = Parameter( + torch.ones(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + + # instead of unique scaling factor each column, we apply a uniform scale for all rows. + self.uniform_scale = Parameter(torch.tensor(1.0, device=self.device)) + + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + # morr_output_scale is removed here + + return weight, None + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def _compute_single_pass(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor: + """Helper method to compute a single pass through the MORR.""" + ### x : [bs, N/k, k] + ### weights: [M, N/k, k] + + weight = weight.unsqueeze(0).unsqueeze(-2) # [1, M, N/k, 1, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, N/k, k, 1] + x = weight.matmul(x) # [bs, M, N/k, 1, 1] + x = x.squeeze(-1).squeeze(-1) # [bs, M, N/k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + # Input scaling/biasing if enabled + if self.trainable_morr_scale: + x = x * self.morr_scale + + if self.trainable_morr_bias: + x = x - self.morr_bias + + # Apply MORR transmission function + x = self.mrr_roundtrip_phase_to_tr(x) + + # Flatten output + x = morr_output_scale.matmul(x) # [1, 1, 1, N/k] x [bs, M, N/k, k] = [bs, M, 1, k] + x = x.flatten(1) # [bs, M*k] + + return x + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Optional[Tensor] = None + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + + Noted Here, we use 4-pass matmals to preserve pre-trained weight losslessly. + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + + # Split weights and inputs into positive and negative parts + pos_weight = F.relu(weight) + neg_weight = -F.relu(-weight) # |W-| + + x = x.view(-1, self.grid_dim_x, self.miniblock) # [bs, q, k] + pos_x = F.relu(x) + neg_x = -F.relu(-x) # |X-| + + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + # weight = weight * self.crosstalk_factor + pos_weight = pos_weight * self.crosstalk_factor + neg_weight = neg_weight * self.crosstalk_factor + + # Compute the four passes + result_pp = self._compute_single_pass(pos_weight, pos_x, morr_output_scale)# 1. W+X+ + result_np = self._compute_single_pass(neg_weight, pos_x, morr_output_scale)# 2. |W-|X+ + result_pn = self._compute_single_pass(pos_weight, neg_x, morr_output_scale)# 3. W+|X-| + result_nn = self._compute_single_pass(neg_weight, neg_x, morr_output_scale)# 4. W-X- + + x = result_pp - result_np - result_pn + result_nn + + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + B, N, D = x.shape + assert ( + x.size(-1) == self.in_features + ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" + if self.in_bit < 16: + x = self.input_quantizer(x) + + # ignore morr_output_scale, as we apply a uniform scale for all rows. + weight, morr_output_scale = self.build_weight() + + if self.in_features_pad > self.in_features: + if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): + self.x_zero_pad = torch.zeros( + x.size(0), + self.in_features_pad - self.in_features, + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([x, self.x_zero_pad], dim=1) + + # Find max values for uniform scaling + w_max = weight.abs().max() + x_max = x.abs().max(dim=1, keepdim=True)[0] + + x = x.view(-1, self.grid_dim_x, self.miniblock) + + ### modulation + ### x: [bs, q, k] -> [bs, q, k] + x = self.input_modulator(x) + + ### propagate through morr array + ### x: [bs, q, k] -> [bs, p*k] + x = self.propagate_morr(weight, x, morr_output_scale) + + # Apply uniform scaling + out = out * x_max * w_max + + if self.out_features < self.out_features_pad: + x = x[..., : self.out_features] + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + x = x.view(B, N, self.out_features) + return x diff --git a/src/chop/nn/optical/modules/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer.py new file mode 100644 index 000000000..ef1b86807 --- /dev/null +++ b/src/chop/nn/optical/modules/morr_transformer.py @@ -0,0 +1,474 @@ +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from .base_layer import ONNBaseLayer + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantMatMals"] + + +class AllPassMORRCirculantMatMals(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config=None, + device: Device = torch.device("cpu"), + ) -> None: + super(AllPassMORRCirculantMatMals, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self, ) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + self.weight = y # load y as weight + + B, N, D = x.shape + assert ( + x.size(-1) == self.in_features + ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" + if self.in_bit < 16: + x = self.input_quantizer(x) + + weight, morr_output_scale = self.build_weight() + if self.in_features_pad > self.in_features: + if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): + self.x_zero_pad = torch.zeros( + x.size(0), + self.in_features_pad - self.in_features, + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([x, self.x_zero_pad], dim=1) + + x = x.view(-1, self.grid_dim_x, self.miniblock) + + ### modulation + ### x: [bs, q, k] -> [bs, q, k] + x = self.input_modulator(x) + + ### propagate through morr array + ### x: [bs, q, k] -> [bs, p*k] + x = self.propagate_morr(weight, x, morr_output_scale) + + if self.out_features < self.out_features_pad: + x = x[..., : self.out_features] + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + x = x.view(B, N, self.out_features) + return x diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 3d5f9670b..1520c205f 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -25,8 +25,47 @@ def weight_replacement_optical(original, new_module): "weight replacement function for the optical module not implemented" ) +def weight_replacement_linear_optical(linear_layer, morr_layer): + """ + Replace the weights of AllPassMORRLinear (morr_layer) with those from a standard nn.Linear (linear_layer). + Focuses only on weight copying (no bias copying). + """ + # Extract dimensions + out_features = morr_layer.out_features + in_features = morr_layer.in_features + miniblock = morr_layer.miniblock + grid_dim_x = morr_layer.grid_dim_x + grid_dim_y = morr_layer.grid_dim_y + in_features_pad = morr_layer.in_features_pad + + # Get the weights from the standard linear layer + standard_weights = linear_layer.weight.data # [out_features, in_features] + + # Ensure the shapes match + assert standard_weights.shape[0] == out_features, "Output feature dimensions don't match" + assert standard_weights.shape[1] == in_features, "Input feature dimensions don't match" + + # Pad the standard weights to match in_features_pad + if in_features_pad > in_features: + padded_weights = torch.zeros(out_features, in_features_pad, + device=standard_weights.device, + dtype=standard_weights.dtype) + padded_weights[:, :in_features] = standard_weights + standard_weights = padded_weights # [out_features, in_features_pad] + + # Reshape to match the MORR structure [grid_dim_y, grid_dim_x, miniblock] + assert grid_dim_y == out_features, "grid_dim_y does not match out_features" + assert grid_dim_x * miniblock == in_features_pad, "grid_dim_x * miniblock does not match in_features_pad" + + reshaped_weights = standard_weights.reshape(grid_dim_y, grid_dim_x, miniblock) + + # Copy the weights to the MORR layer + with torch.no_grad(): + morr_layer.weight.data.copy_(reshaped_weights) + + return morr_layer -def weight_replacement_linear_optical(x, y): +def weight_replacement_circulant_linear_optical(x, y): """ Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). From 7c1c7ae7ddc8e94cfa1e7e76d9055757305252aa Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 6 Apr 2025 20:42:44 +0100 Subject: [PATCH 24/38] add morr_full layer, perfrom testing on bert model --- .gitignore | 3 +- .../nn/optical/modules/morr_custom_linear.py | 36 +++--- src/chop/nn/optical/modules/morr_linear.py | 12 +- .../nn/optical/modules/morr_transformer.py | 110 ++++++++++++++++-- .../optical/module_transform_helper.py | 38 +++--- .../module/transforms/optical/optical.py | 14 ++- 6 files changed, 167 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 01935c9af..9fd4b8488 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,5 @@ mase-trainer/ test-trainer/ test/self -model_sst2/ \ No newline at end of file +model_sst2/ +runs \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_custom_linear.py b/src/chop/nn/optical/modules/morr_custom_linear.py index 1a035b88b..9e9f5272a 100644 --- a/src/chop/nn/optical/modules/morr_custom_linear.py +++ b/src/chop/nn/optical/modules/morr_custom_linear.py @@ -148,13 +148,13 @@ def build_parameters(self) -> None: ) ### Learnable balancing factor (morr_output_scale) ### We use a single scaling factor for each block - # we init this to be identical to retrive the original output distribution + + # Init this to ones and non-trainable + # TODO: Verify the effectiveness of making this trainable self.morr_output_scale = Parameter( - torch.ones(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + torch.ones(1, 1, max(1, self.grid_dim_x), 1, device=self.device) ) - - # instead of unique scaling factor each column, we apply a uniform scale for all rows. - self.uniform_scale = Parameter(torch.tensor(1.0, device=self.device)) + # self.morr_output_scale.requires_grad = False if self.trainable_morr_bias: ### initialize with the finest-granularity, i.e., per mini-block @@ -259,9 +259,9 @@ def build_weight(self) -> Tensor: if self.finegrain_drop_mask is not None: weight = weight.mul(self.finegrain_drop_mask.float()) - # morr_output_scale is removed here + # morr_output_scale processing is removed here - return weight, None + return weight, self.morr_output_scale.squeeze(-1).unsqueeze(0) def enable_fast_forward(self) -> None: self.fast_forward_flag = True @@ -358,7 +358,7 @@ def _compute_single_pass(self, weight: Tensor, x: Tensor, morr_output_scale: Ten weight = weight.unsqueeze(0).unsqueeze(-2) # [1, M, N/k, 1, k] x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, N/k, k, 1] x = weight.matmul(x) # [bs, M, N/k, 1, 1] - x = x.squeeze(-1).squeeze(-1) # [bs, M, N/k] + x = x.squeeze(-1) # [bs, M, N/k, 1] if self.enable_phase_noise and self.phase_noise_std > 1e-5: x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) @@ -374,8 +374,8 @@ def _compute_single_pass(self, weight: Tensor, x: Tensor, morr_output_scale: Ten x = self.mrr_roundtrip_phase_to_tr(x) # Flatten output - x = morr_output_scale.matmul(x) # [1, 1, 1, N/k] x [bs, M, N/k, k] = [bs, M, 1, k] - x = x.flatten(1) # [bs, M*k] + x = morr_output_scale.matmul(x) # [1, 1, 1, N/k] x [bs, M, N/k, 1] = [bs, M, 1, 1] + x = x.squeeze(-1).squeeze(-1) # [bs, M] return x @@ -442,7 +442,11 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) def forward(self, x: Tensor) -> Tensor: - B, N, D = x.shape + # adjust output shape if used in transformer + is_transformer = len(x.shape) == 3 + if is_transformer: + B, N, D = x.shape + assert ( x.size(-1) == self.in_features ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" @@ -477,12 +481,14 @@ def forward(self, x: Tensor) -> Tensor: x = self.propagate_morr(weight, x, morr_output_scale) # Apply uniform scaling - out = out * x_max * w_max + # x = x * x_max * w_max if self.out_features < self.out_features_pad: x = x[..., : self.out_features] if self.bias is not None: x = x + self.bias.unsqueeze(0) - - x = x.view(B, N, self.out_features) - return x + + # adjust output shape if used in transformer + if is_transformer: + x = x.view(B, N, self.out_features) + return x \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index dcb35752c..50f28345e 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -443,7 +443,11 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) def forward(self, x: Tensor) -> Tensor: - B, N, D = x.shape + # if used in transformer + is_transformer = len(x.shape) == 3 + if is_transformer: + B, N, D = x.shape + assert ( x.size(-1) == self.in_features ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" @@ -475,6 +479,8 @@ def forward(self, x: Tensor) -> Tensor: x = x[..., : self.out_features] if self.bias is not None: x = x + self.bias.unsqueeze(0) - - x = x.view(B, N, self.out_features) + + # adjust output shape if used in transformer + if is_transformer: + x = x.view(B, N, self.out_features) return x diff --git a/src/chop/nn/optical/modules/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer.py index ef1b86807..dfe2fec86 100644 --- a/src/chop/nn/optical/modules/morr_transformer.py +++ b/src/chop/nn/optical/modules/morr_transformer.py @@ -2,8 +2,11 @@ import logging import numpy as np +import math import torch +import torch.nn as nn import torch.fft +import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter, init from torch.types import Device @@ -14,12 +17,13 @@ from ..utils import morr_uniform_ from ..utils import input_quantize_fn, weight_quantize_fn from .base_layer import ONNBaseLayer +from .morr_custom_linear import AllPassMORRLinear +from .morr_linear import AllPassMORRCirculantLinear logger = logging.getLogger(__name__) __all__ = ["AllPassMORRCirculantMatMals"] - class AllPassMORRCirculantMatMals(ONNBaseLayer): """ All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. @@ -217,7 +221,7 @@ def sync_parameters(self, src: str = "weight") -> None: raise NotImplementedError - def build_weight(self, ) -> Tensor: + def build_weight(self, y: Tensor) -> Tensor: if self.w_bit < 16: ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) weight = self.weight_quantizer(self.weight) @@ -435,9 +439,15 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) def forward(self, x: Tensor, y: Tensor) -> Tensor: - self.weight = y # load y as weight - - B, N, D = x.shape + + # load y as weight: + self.weight.data.copy_(y) + + # if used in transformer + is_transformer = len(x.shape) == 3 + if is_transformer: + B, N, D = x.shape + assert ( x.size(-1) == self.in_features ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" @@ -469,6 +479,92 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: x = x[..., : self.out_features] if self.bias is not None: x = x + self.bias.unsqueeze(0) - - x = x.view(B, N, self.out_features) + + # adjust output shape if used in transformer + if is_transformer: + x = x.view(B, N, self.out_features) return x + +class MORRMHA(nn.Module): + def __init__(self, embed_dim, heads): + super(MORRMHA, self).__init__() + assert embed_dim % heads == 0 + self.n_heads = heads + self.Wq = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.Wk = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.Wv = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.qmm1 = AllPassMORRCirculantMatMals() + self.dropout_wq = nn.Dropout(0.1) + self.dropout_wk = nn.Dropout(0.1) + self.dropout_wv = nn.Dropout(0.1) + self.qmm2 = AllPassMORRCirculantMatMals() + self.Wout = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + + def forward(self, x, mask): + b = x.size(0) + n = x.size(1) + h = self.n_heads + d = x.size(2) + + def arrange_heads(acts): + # incoming shape of b, n, d, want b, h, n, d/h + return acts.view(b, n, h, -1).transpose(1, 2) + + q = arrange_heads(self.dropout_wq(self.Wq(x))) + k = arrange_heads(self.dropout_wk(self.Wk(x))) + v = arrange_heads(self.dropout_wv(self.Wv(x))) + + attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n + masked = attn.masked_fill(mask, float("-inf")) + softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) + out = self.qmm2(softmax_attn, v) # b, h, n, d/h + + out = out.transpose(1, 2).reshape(b, n, -1) + out = self.dropout2(out) + out = self.Wout(out) + return out + + + +class MORRFF(nn.Module): + def __init__(self, embed_dim, expansion_dim): + super(MORRFF, self).__init__() + self.first_drop = nn.Dropout(0.1) + self.layer1 = AllPassMORRCirculantLinear(embed_dim, expansion_dim, use_noise=True) + self.act = nn.ReLU6(inplace=True) + self.dropout = nn.Dropout(0.1) + self.layer2 = AllPassMORRCirculantLinear(expansion_dim, embed_dim, use_noise=True) + + def forward(self, x): + out = self.first_drop(x) + out = self.layer1(out) + out = self.act(out) + out = self.dropout(out) + out = self.layer2(out) + return out + +class MORRDecoderLayer(nn.Module): + def __init__(self, features, heads): + super(MORRDecoderLayer, self).__init__() + self.norm1 = nn.LayerNorm(features) + self.attn = MORRMHA(features, heads) + self.drop1 = nn.Dropout(0.1) + self.norm2 = nn.LayerNorm(features) + self.ff = MORRFF(features, features * 4) + self.drop2 = nn.Dropout(0.1) + + def forward(self, x, attn_mask): + # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right + identity = x + out = self.norm1(x) + out = self.attn(out, attn_mask) + out = self.drop1(out) + out = out + identity + identity = out + out = self.norm2(out) + out = self.ff(out) + out = self.drop2(out) + out = out + identity + return out \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 1520c205f..635d3d8f9 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -1,21 +1,26 @@ import torch import torch.nn as nn +import math from chop.passes.module.module_modify_helper import ( get_module_by_name, set_module_by_name, ) -def replace_by_name_optical(network, module_name: str, new_module): +def replace_by_name_optical(network, module_name: str, new_module, target_name): original = get_module_by_name(network, module_name) - updated_module = weight_replacement_optical(original, new_module) + if target_name == "linear_morr_full": + updated_module = weight_replacement_full_linear_optical(original, new_module) + elif target_name == "linear_morr": + updated_module = weight_replacement_circulant_linear_optical(original, new_module) + network = set_module_by_name(network, module_name, updated_module) return network -def weight_replacement_optical(original, new_module): +def weight_replacement_full_linear_optical(original, new_module): if isinstance(original, nn.Linear): return weight_replacement_linear_optical(original, new_module) elif isinstance(original, nn.Conv2d): @@ -67,13 +72,12 @@ def weight_replacement_linear_optical(linear_layer, morr_layer): def weight_replacement_circulant_linear_optical(x, y): """ - Replace the weights of AllPassMORRCirculantLinear (y) - with those from a standard nn.Linear (x). + Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). Focuses only on weight copying (no bias copying). """ # Fetch original linear weight [out_features, in_features] - W = x.weight.data # shape: (out_features, in_features) + W = x.weight.data # [out_features, in_features] # Grab dimensions and zero-pad if needed out_features_pad = y.out_features_pad # padded out_features in y @@ -84,30 +88,30 @@ def weight_replacement_circulant_linear_optical(x, y): # Construct padded weight tensor W_padded = W.new_zeros((out_features_pad, in_features_pad)) - W_padded[: W.size(0), : W.size(1)] = W # copy original into top-left + W_padded[: W.size(0), : W.size(1)] = W - # Now we create a new tensor of shape [grid_dim_y, grid_dim_x, miniblock] - # by compressing each row-block [1 x miniblock] from W_padded into a single scalar. - # This is a simple example that takes the mean across the miniblock slice. - new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) + # Takes the mean across the miniblock slice. + new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) # [grid_dim_y, grid_dim_x, miniblock] # Fill new_weight by averaging the corresponding sub-blocks in W_padded + # original miniblock: [k, k] new miniblock: [k, 1] with torch.no_grad(): for p in range(grid_dim_y): for q in range(grid_dim_x): for k in range(miniblock): - # The row in W_padded we look at: - row_idx = p * miniblock + k - # The columns we look at: - col_start = q * miniblock + row_idx = p * miniblock + k # The row in W_padded: + col_start = q * miniblock # The columns in W_padded: col_end = (q + 1) * miniblock - block = W_padded[row_idx, col_start:col_end] + new_weight[p, q, k] = block.mean() + bound = 1 / math.sqrt(miniblock) + new_weight = torch.rand((grid_dim_y, grid_dim_x, miniblock), + device=W.device, + dtype=W.dtype) * 2 * bound - bound # Copy the result into y.weight y.load_parameters({"weight": new_weight}) - # y.weight.data.copy_(new_weight) return y diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index a82ea72fc..d8981bb10 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -29,12 +29,12 @@ def optical_transform_by_type(network, pass_args): config = config["config"] postfix = config.pop("name") for n, m in n_m.items(): - print(f"processing {n}...") if isinstance(m, module): + print(f"processing {n}") new_m = instantiate_module( m, postfix, optical_module_map, {"config": config} ) - network = replace_by_name_optical(network, n, new_m) + network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) return network @@ -68,6 +68,7 @@ def optical_transform_by_regex_name(network, pass_args): matched_pattern = match_a_pattern(n, patterns) if not matched_pattern: continue + print(f"processing {n}") optical_config = pass_args[matched_pattern]["config"] postfix = optical_config["name"] @@ -78,10 +79,17 @@ def optical_transform_by_regex_name(network, pass_args): else {"config": optical_config} ) + if isinstance(m, torch.nn.Linear): + type_name = "linear" + elif isinstance(m, torch.nn.Conv2d): + type_name = "conv2d" + else: + raise ValueError(f"{type_name} is not supported!") + new_m = instantiate_module( m, postfix, optical_module_map, additional_module_args ) - network = replace_by_name_optical(network, n, new_m) + network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) return network From c2813ee65d7441c356a9d32bfef46f3f4358610d Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 10 Apr 2025 13:46:31 +0100 Subject: [PATCH 25/38] update morr_transformer --- src/chop/nn/optical/modules/morr_linear.py | 2 +- .../nn/optical/modules/morr_transformer.py | 320 +++++++++++++++++- 2 files changed, 318 insertions(+), 4 deletions(-) diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index 50f28345e..e409070ea 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -46,7 +46,7 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - config=None, + config={}, device: Device = torch.device("cpu"), ) -> None: super(AllPassMORRCirculantLinear, self).__init__() diff --git a/src/chop/nn/optical/modules/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer.py index dfe2fec86..bb1fe9e12 100644 --- a/src/chop/nn/optical/modules/morr_transformer.py +++ b/src/chop/nn/optical/modules/morr_transformer.py @@ -10,6 +10,10 @@ from torch import Tensor from torch.nn import Parameter, init from torch.types import Device +import pytorch_lightning as pl +import torchmetrics +import transformers +from transformers import GPT2TokenizerFast from ..utils import MORRConfig_20um_MQ from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused @@ -20,10 +24,134 @@ from .morr_custom_linear import AllPassMORRLinear from .morr_linear import AllPassMORRCirculantLinear +from transformers import BertModel, BertForSequenceClassification +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2Attention, + GPT2MLP, + GPT2Block, + Conv1D, +) + logger = logging.getLogger(__name__) __all__ = ["AllPassMORRCirculantMatMals"] +def make_autoregressive_mask_for(x): + length = x.size(1) + ones = x.new_ones((length, length)) + mask = torch.triu(ones, diagonal=1) != 0.0 + return mask + + +def make_position_indices_for(x): + length = x.size(1) + batch_size = x.size(0) + indices = torch.arange(length, device=x.device).repeat(batch_size, 1) + return indices + + +def load_lookup_table(file, device): + data = torch.from_numpy(numpy.genfromtxt(file, delimiter='\t')).float() + levels = data.size(0) + lower_bound = data[0,1].item() + weight = data[:,1].unsqueeze(1).cuda(device) + return weight, lower_bound, levels + + +def apply_lut_to_normalized(x, lut, bit_degredation=0): + lut_weight, lut_lb, lut_levels = lut + deg_factor = 2**bit_degredation + x = x.mul(lut_levels - deg_factor).div(deg_factor).round().mul(deg_factor).to(dtype=torch.long) + x = F.embedding(x, lut_weight).squeeze(-1) + return x + + +class QuantizeValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x, quant_levels, min_val, max_val, quant_mode, lut_min=None): + with torch.no_grad(): + diff = max_val - min_val + x = x.clamp(min_val, max_val).add(-1.0 * min_val).div(diff + 1e-8).mul(quant_levels - 1) + + if quant_mode == 'det': + x = x.round() + x = x.div(quant_levels - 1).mul(diff).add(min_val) + elif quant_mode == 'rand': + x = x.add(torch.rand_like(x).add(-0.5)).round() # randn* 0.288 gives same std as 0-1 rand(), if want to use normal dist. + x = x.div(quant_levels - 1).mul(diff).add(min_val) + + if lut_min is not None: + pos_x = torch.relu(x) + neg_x = x - pos_x + lms = lut_min * max_val + pos_x[pos_x < lms] = lms + lms = lut_min * torch.abs(min_val) + neg_x[neg_x > -lms] = -lms + x = pos_x + neg_x + + return x + + @staticmethod + def backward(ctx, grad_output): + # STE + return grad_output, None, None, None, None, None + +class QuantizeStats(nn.Module): + def __init__(self, percentile, use_clipping=True): + super(QuantizeStats, self).__init__() + self.register_buffer('running_min', torch.tensor(0.0)) + self.register_buffer('running_max', torch.tensor(0.0)) + self.max_calibration_steps = 1 + self.initial_calibration_steps = 0 + #self.register_buffer('calibration_done', torch.tensor(False)) + self.calibration_done = torch.tensor(False) + self.activations = [] + self.percentile = percentile + self.use_clipping = use_clipping + + def update(self, tensor): + if self.use_clipping: + if not self.calibration_done.item(): + self.initial_calibration_steps += 1 + finished = False + + if self.initial_calibration_steps >= self.max_calibration_steps: + finished = True + self.calibration_done = torch.tensor(True) + + with torch.no_grad(): + self.activations.extend(tensor.detach().cpu().tolist()) + + if finished: + maximum = numpy.percentile(self.activations, self.percentile) + self.running_max = torch.tensor(maximum, device=tensor.device, dtype=tensor.dtype) + minimum = tensor.min() + minimum = minimum if minimum >= 0.0 else -maximum + self.running_min = torch.tensor(minimum, device=tensor.device, dtype=tensor.dtype) + self.activations.clear() # free the memory + else: + self.running_min = tensor.min() + self.running_max = tensor.max() + + else: + alpha = 0.999 + with torch.no_grad(): + cur_min = tensor.min() + cur_max = tensor.max() + + if self.initial_calibration_steps == 0: + self.initial_calibration_steps += 1 + self.running_min = cur_min + self.running_max = cur_max + else: + self.running_min = alpha * self.running_min + (1.0 - alpha) * cur_min + self.running_max = alpha * self.running_max + (1.0 - alpha) * cur_max + + + + def get(self): + return self.running_min, self.running_max + class AllPassMORRCirculantMatMals(ONNBaseLayer): """ All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. @@ -525,8 +653,6 @@ def arrange_heads(acts): out = self.dropout2(out) out = self.Wout(out) return out - - class MORRFF(nn.Module): def __init__(self, embed_dim, expansion_dim): @@ -567,4 +693,192 @@ def forward(self, x, attn_mask): out = self.ff(out) out = self.drop2(out) out = out + identity - return out \ No newline at end of file + return out + +class _MORRGPT(nn.Module): + def __init__(self, features, heads, tokenizer, layers, max_length): + super(_MORRGPT, self).__init__() + vocab_size = len(tokenizer) + 8 - len(tokenizer) % 8 # pad vocab size to 8-multiple for tensor core acceleration + assert vocab_size % 8 == 0 + self.pos_embedding = nn.Embedding(max_length, features) + self.word_embedding = nn.Embedding(vocab_size, features, padding_idx = tokenizer.pad_token_id) + self.embedding_dropout = nn.Dropout(0.1) + self.decoders = nn.ModuleList([MORRDecoderLayer(features, heads) for _ in range(layers)]) + self.norm = nn.LayerNorm(features) + self.output_head = nn.Linear(features, vocab_size) + nn.init.normal_(self.word_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embedding.weight, std=0.02) + + def forward_embedding(self, x): + embedded = self.word_embedding(x) + return embedded + + def forward_attn(self, x): + mask = make_autoregressive_mask_for(x) + pos = make_position_indices_for(x) + pos_embed = self.embedding_dropout(self.pos_embedding(pos) + x) + decoded = pos_embed + for layer in self.decoders: + decoded = layer(decoded, mask) + + out = self.norm(decoded) + return out + + def forward(self, x): + embedded = self.forward_embedding(x) + decoded = self.forward_attn(embedded) + out = self.output_head(decoded) + return out + + +class MORRGPT(pl.LightningModule): + def __init__(self, features, heads, layers=6, max_length=1024): + super().__init__() + self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) + self.transformer = _MORRGPT(features, heads, self.tokenizer, layers, max_length) + self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) + self.val_loss = torchmetrics.MeanMetric() + self.test_loss = torchmetrics.MeanMetric() + self.lr = 0.0005 + self.photon_target = 0 + self.training_steps = 100000 + self.extracting = False + self.use_adam = True + + def get_tokenizer(self): + return self.tokenizer + + def forward(self, x): + return self.transformer(x) + + def training_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.log('train loss', loss) + return loss + + def validation_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.val_loss.update(loss) + + def validation_epoch_end(self, outputs): + self.log('validation loss', self.val_loss) + + def test_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.test_loss.update(loss) + if self.extracting: + raise ValueError("Extraction done, aborting") + + def test_epoch_end(self, outputs): + self.log('test loss', self.test_loss) + self.log('photon target', self.photon_target) + + def configure_optimizers(self): + if self.use_adam: + decay = set() + no_decay = set() + blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(recurse=False): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if 'bias' in pn: + no_decay.add(fpn) + elif 'weight' in pn and not isinstance(m, blacklist_weight_modules): + decay.add(fpn) + else: + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.02}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + optimizer = torch.optim.AdamW(optim_groups, lr=self.lr) + scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'step', + 'name': 'Cosine LR scheduler' + } + } + else: + optimizer = torch.optim.RMSprop(self.parameters(), lr=self.lr, weight_decay=1e-5) + scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'step', + 'name': 'Cosine LR scheduler' + } + } + + def replace_output_head(self, module): + self.transformer.output_head = module + + def enable_quantization(self): + for m in self.transformer.modules(): + if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): + m.enable_quantization() + + def set_photon_target(self, n_photons): + self.photon_target = n_photons + for m in self.transformer.modules(): + if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): + m.set_photon_target(n_photons) + + def set_quantized_eval(self, value=True): + for m in self.transformer.modules(): + if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): + print("setting quantized eval") + m.force_quantized_eval = value + + def save(self, fname): + torch.save(self.transformer.state_dict(), fname) + + def load(self, fname): + self.transformer.load_state_dict(torch.load(fname)) + + def enable_extraction(self): + lin1 = self.transformer.decoders[0].ff.layer2 + lin1.extract_simulated = True + lin1.extract_name = 'first_linear' + lin2 = self.transformer.decoders[-1].ff.layer2 + lin2.extract_simulated = True + lin2.extract_name = 'last_linear' + attn1 = self.transformer.decoders[0].attn.qmm1 + attn1.extract_simulated = True + attn1.extract_name = 'first_attn' + attn2 = self.transformer.decoders[-1].attn.qmm1 + attn2.extract_simulated = True + attn2.extract_name = 'last_attn' + self.extracting = True + \ No newline at end of file From 8a27d0d93e1dcd589b3e12fb83a966bac5fde18e Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 11 May 2025 22:24:20 +0100 Subject: [PATCH 26/38] update onn kernel and testing script --- src/chop/nn/optical/modules/QuantizedGPT.py | 642 +++++++++++++ src/chop/nn/optical/modules/__init__.py | 7 + .../nn/optical/modules/morr_transformer.py | 884 ------------------ .../modules/morr_transformer/morr_bert.py | 143 +++ .../modules/morr_transformer/morr_matmul.py | 535 +++++++++++ .../morr_transformer/morr_transformer.py | 168 ++++ src/chop/nn/optical/triton_modules/dtype.py | 17 + .../nn/optical/triton_modules/morr_linear.py | 473 ++++++++++ .../triton_modules/morr_linear_kernel.py | 733 +++++++++++++++ .../triton_modules/morr_linear_kernel_mem.py | 720 ++++++++++++++ .../optical/triton_modules/morr_linear_mem.py | 473 ++++++++++ .../nn/optical/triton_modules/quantize.py | 104 +++ .../passes/module/module_modify_helper.py | 34 + .../optical/module_transform_helper.py | 17 +- .../module/transforms/optical/optical.py | 5 + 15 files changed, 4070 insertions(+), 885 deletions(-) create mode 100644 src/chop/nn/optical/modules/QuantizedGPT.py delete mode 100644 src/chop/nn/optical/modules/morr_transformer.py create mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_bert.py create mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_matmul.py create mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_transformer.py create mode 100644 src/chop/nn/optical/triton_modules/dtype.py create mode 100644 src/chop/nn/optical/triton_modules/morr_linear.py create mode 100644 src/chop/nn/optical/triton_modules/morr_linear_kernel.py create mode 100644 src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py create mode 100644 src/chop/nn/optical/triton_modules/morr_linear_mem.py create mode 100644 src/chop/nn/optical/triton_modules/quantize.py diff --git a/src/chop/nn/optical/modules/QuantizedGPT.py b/src/chop/nn/optical/modules/QuantizedGPT.py new file mode 100644 index 000000000..424570db1 --- /dev/null +++ b/src/chop/nn/optical/modules/QuantizedGPT.py @@ -0,0 +1,642 @@ +import sys +from numpy import outer +import torch +import torch.nn as nn +import pytorch_lightning as pl +import torchmetrics +from transformers import GPT2TokenizerFast +import transformers +import torch.nn.functional as F +import math +import numpy +sys.path.append('...') + + +def make_autoregressive_mask_for(x): + length = x.size(1) + ones = x.new_ones((length, length)) + mask = torch.triu(ones, diagonal=1) != 0.0 + return mask + + +def make_position_indices_for(x): + length = x.size(1) + batch_size = x.size(0) + indices = torch.arange(length, device=x.device).repeat(batch_size, 1) + return indices + + +def load_lookup_table(file, device): + data = torch.from_numpy(numpy.genfromtxt(file, delimiter='\t')).float() + levels = data.size(0) + lower_bound = data[0,1].item() + weight = data[:,1].unsqueeze(1).cuda(device) + return weight, lower_bound, levels + + +def apply_lut_to_normalized(x, lut, bit_degredation=0): + lut_weight, lut_lb, lut_levels = lut + deg_factor = 2**bit_degredation + x = x.mul(lut_levels - deg_factor).div(deg_factor).round().mul(deg_factor).to(dtype=torch.long) + x = F.embedding(x, lut_weight).squeeze(-1) + return x + + +class QuantizeValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x, quant_levels, min_val, max_val, quant_mode, lut_min=None): + with torch.no_grad(): + diff = max_val - min_val + x = x.clamp(min_val, max_val).add(-1.0 * min_val).div(diff + 1e-8).mul(quant_levels - 1) + + if quant_mode == 'det': + x = x.round() + x = x.div(quant_levels - 1).mul(diff).add(min_val) + elif quant_mode == 'rand': + x = x.add(torch.rand_like(x).add(-0.5)).round() # randn* 0.288 gives same std as 0-1 rand(), if want to use normal dist. + x = x.div(quant_levels - 1).mul(diff).add(min_val) + + if lut_min is not None: + pos_x = torch.relu(x) + neg_x = x - pos_x + lms = lut_min * max_val + pos_x[pos_x < lms] = lms + lms = lut_min * torch.abs(min_val) + neg_x[neg_x > -lms] = -lms + x = pos_x + neg_x + + return x + + @staticmethod + def backward(ctx, grad_output): + # STE + return grad_output, None, None, None, None, None + + +class QuantizeStats(nn.Module): + def __init__(self, percentile, use_clipping=True): + super(QuantizeStats, self).__init__() + self.register_buffer('running_min', torch.tensor(0.0)) + self.register_buffer('running_max', torch.tensor(0.0)) + self.max_calibration_steps = 1 + self.initial_calibration_steps = 0 + #self.register_buffer('calibration_done', torch.tensor(False)) + self.calibration_done = torch.tensor(False) + self.activations = [] + self.percentile = percentile + self.use_clipping = use_clipping + + def update(self, tensor): + if self.use_clipping: + if not self.calibration_done.item(): + self.initial_calibration_steps += 1 + finished = False + + if self.initial_calibration_steps >= self.max_calibration_steps: + finished = True + self.calibration_done = torch.tensor(True) + + with torch.no_grad(): + self.activations.extend(tensor.detach().cpu().tolist()) + + if finished: + maximum = numpy.percentile(self.activations, self.percentile) + self.running_max = torch.tensor(maximum, device=tensor.device, dtype=tensor.dtype) + minimum = tensor.min() + minimum = minimum if minimum >= 0.0 else -maximum + self.running_min = torch.tensor(minimum, device=tensor.device, dtype=tensor.dtype) + self.activations.clear() # free the memory + else: + self.running_min = tensor.min() + self.running_max = tensor.max() + + else: + alpha = 0.999 + with torch.no_grad(): + cur_min = tensor.min() + cur_max = tensor.max() + + if self.initial_calibration_steps == 0: + self.initial_calibration_steps += 1 + self.running_min = cur_min + self.running_max = cur_max + else: + self.running_min = alpha * self.running_min + (1.0 - alpha) * cur_min + self.running_max = alpha * self.running_max + (1.0 - alpha) * cur_max + + + + def get(self): + return self.running_min, self.running_max + + +def shot_noise_linear(w, x, n_photons_target, phone_lut=None, slm_lut=None, extract=False, extract_name=None): + noise_level = 0.021 + + if n_photons_target != 0: + quantize = QuantizeValue.apply + use_lut = (phone_lut is not None) and (slm_lut is not None) + w_max = torch.max(w) + w_norm = apply_lut_to_normalized(w / (1e-8 + w_max), slm_lut) if use_lut else w / (1e-8 + w_max) + x_max = torch.max(x, dim=2, keepdim=True)[0] + x_norm = apply_lut_to_normalized(x / (1e-8 + x_max), phone_lut, bit_degredation=0) if use_lut else x / (1e-8 + x_max) + + out_opt = F.linear(x_norm, w_norm, bias=None) + photons_per_act = n_photons_target * x_norm.size(2) / (x_norm.sum(dim=2, keepdim=True) + 1e-8) + fluence_Wx = out_opt * photons_per_act + noise_Wx = torch.poisson(fluence_Wx) + out = noise_Wx / photons_per_act + + random_noise = noise_level * out.mean() + out = torch.normal(out, random_noise) + + out = x_max * out * w_max + else: + out = F.linear(x, w, bias=None) + + if extract and n_photons_target != 0: + torch.save({'x': x_norm[1, :512, :].detach().clone(), + 'w': w_norm[:512].detach().clone(), + 'out': out_opt[1, :512, :512].detach().clone(), + 'noise_level': noise_level}, + #'noise_value': random_noise}, + extract_name) + + return out + + +def shot_noise_bhmm(x, y, n_photons_target, phone_lut=None, slm_lut=None, extract=False, extract_name=None): + # perform xy matrix-multiply like matrix vector, where matrix "slices" in y are like W and x is the vectors. Thus take max over whole matrices in y as we would for W + noise_level = 0.0565 + + if n_photons_target != 0: + quantize = QuantizeValue.apply + use_lut = (phone_lut is not None) and (slm_lut is not None) + x_max = torch.max(x, dim=3, keepdim=True)[0] + x_norm = apply_lut_to_normalized(x / (1e-8 + x_max), phone_lut, bit_degredation=0) if use_lut else x / (1e-8 + x_max) + y_max = torch.amax(y, dim=(2, 3), keepdim=True) + y_norm = apply_lut_to_normalized(y / (1e-8 + y_max), slm_lut) if use_lut else y / (1e-8 + y_max) + + out_opt = torch.matmul(x_norm, y_norm) + photons_per_act = n_photons_target * x_norm.size(3) / (x_norm.sum(dim=3, keepdim=True) + 1e-8) + fluence_mm = out_opt * photons_per_act + noise_Wx = torch.poisson(fluence_mm) + out = noise_Wx / photons_per_act + + random_noise = noise_level * out.mean() + out = torch.normal(out, random_noise) + + out = x_max * out * y_max + else: + out = torch.matmul(x, y) + + if extract and n_photons_target != 0: + torch.save({'x': x_norm[0, 0, :, :].detach().clone(), + 'y': y_norm[0, 0, :, :].detach().clone(), + 'out': out_opt[0, 0, :, :].detach().clone(), + 'noise_level': noise_level}, + #'noise_value': random_noise}, + extract_name) + + return out + + +class QuantizedLinear(nn.Module): + def __init__(self, in_feats, out_feats, use_noise=True): + super(QuantizedLinear, self).__init__() + self.weight = nn.Parameter(torch.zeros(out_feats, in_feats)) + self.input_stats = QuantizeStats(99.99) + self.output_stats = QuantizeStats(99.9999) + nn.init.xavier_uniform_(self.weight) + self.quantize = False + self.photon_target = 0 + self.slm_lut = load_lookup_table('LUTs/SLM_AmpLUT.txt', device=torch.device('cuda:0')) + self.phone_lut = load_lookup_table('LUTs/PhoneLUT.txt', device=torch.device('cuda:0')) + self.use_lut = (self.slm_lut is not None) and (self.phone_lut is not None) + if self.use_lut: + _, self.slm_cutoff, _ = self.slm_lut + else: + self.slm_cutoff = None + self.force_quantized_eval = False + self.extract_simulated = False + self.extract_name = '' + self.noise = use_noise + #print('L module using LUT: {}'.format(self.use_lut)) + + def _weight_min(self): + with torch.no_grad(): + return self.weight_min + + def _weight_max(self): + with torch.no_grad(): + return self.weight_max + + def enable_quantization(self, clipping=True): + with torch.no_grad(): + if clipping: + weight_values = self.weight.detach().cpu().tolist() + maximum = numpy.percentile(weight_values, 99).item() + self.weight_max = torch.tensor(maximum, dtype=self.weight.dtype, device=self.weight.device) + self.weight_min = torch.tensor(-maximum, dtype=self.weight.dtype, device=self.weight.device) + else: + self.weight_min = self.weight.min() + self.weight_max = self.weight.max() + self.quantize = True + + def set_photon_target(self, n_photons): + self.photon_target = n_photons + + def forward(self, x): + if self.quantize: + quantize = QuantizeValue.apply + if self.training or self.force_quantized_eval: + # QAT for activations + if self.training: + self.input_stats.update(x) + input_min, input_max = self.input_stats.get() + quantized_x = quantize(x, 256, input_min, input_max, 'det') + quantized_weights = quantize(self.weight, 256, self._weight_min(), self._weight_max(), 'det', self.slm_cutoff) # 160 + out = F.linear(quantized_x, quantized_weights, bias=None) + if self.training: + self.output_stats.update(out) + output_min, output_max = self.output_stats.get() + quantized_out = quantize(out, 256, output_min, output_max, 'rand') + return quantized_out + else: + # shot noise simulation for linear layer, per-token + input_min, input_max = self.input_stats.get() + weight_min, weight_max = self._weight_min(), self._weight_max() + + if self.use_lut: + w = self.weight.clamp(weight_min, weight_max) + x = x.clamp(input_min, input_max) + else: + quantize = QuantizeValue.apply + x = quantize(x, 256, input_min, input_max, 'det') + w = quantize(self.weight, 256, weight_min, weight_max, 'det', self.slm_cutoff) + + pos_x = F.relu(x) + neg_x = torch.abs(x - pos_x) + pos_w = F.relu(w) + neg_w = torch.abs(w - pos_w) + out = shot_noise_linear(pos_w, pos_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_0.pt') \ + + shot_noise_linear(neg_w, neg_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_1.pt') \ + - shot_noise_linear(pos_w, neg_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_2.pt') \ + - shot_noise_linear(neg_w, pos_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_3.pt') + + output_min, output_max = self.output_stats.get() + out = out.clamp(output_min, output_max) + #out = quantize(out, 64, output_min, output_max, 'det') + return out + else: + out = F.linear(x, self.weight, bias=None) + return out + + +class QuantizedMatmul(nn.Module): + def __init__(self): + super(QuantizedMatmul, self).__init__() + self.input1_stats = QuantizeStats(99.99) + self.input2_stats = QuantizeStats(98) + self.output_stats = QuantizeStats(99.9999) + self.quantize = False + self.photon_target = 0 + self.slm_lut = load_lookup_table('LUTs/SLM_AmpLUT.txt', device=torch.device('cuda:0')) + self.phone_lut = load_lookup_table('LUTs/PhoneLUT.txt', device=torch.device('cuda:0')) + self.use_lut = (self.slm_lut is not None) and (self.phone_lut is not None) + if self.use_lut: + _, self.slm_cutoff, _ = self.slm_lut + else: + self.slm_cutoff = None + self.force_quantized_eval = False + self.extract_simulated = False + self.extract_name = '' + #print('MM module using LUT: {}'.format(self.use_lut)) + + def enable_quantization(self): + self.quantize = True + + def set_photon_target(self, n_photons): + self.photon_target = n_photons + + def forward(self, x, y): + if self.quantize: + quantize = QuantizeValue.apply + if self.training or self.force_quantized_eval: + # QAT for activations + if self.training: + self.input1_stats.update(x) + self.input2_stats.update(y) + x_min, x_max = self.input1_stats.get() + y_min, y_max = self.input2_stats.get() + xq = quantize(x, 256, x_min, x_max, 'det') + yq = quantize(y, 256, y_min, y_max, 'det', self.slm_cutoff) + out = torch.matmul(xq, yq) + if self.training: + self.output_stats.update(out) + out_min, out_max = self.output_stats.get() + outq = quantize(out, 256, out_min, out_max, 'rand') + return outq + else: + # Shot noise simulation for broadcasted matrix-matrix multiply + x_min, x_max = self.input1_stats.get() + y_min, y_max = self.input2_stats.get() + + if self.use_lut: + x = x.clamp(x_min, x_max) + y = y.clamp(y_min, y_max) + else: + quantize = QuantizeValue.apply + x = quantize(x, 256, x_min, x_max, 'det') + y = quantize(y, 256, y_min, y_max, 'det', self.slm_cutoff) + + pos_x = F.relu(x) + neg_x = torch.abs(x - pos_x) + pos_y = F.relu(y) + neg_y = torch.abs(y - pos_y) + out = shot_noise_bhmm(pos_x, pos_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_0.pt') \ + + shot_noise_bhmm(neg_x, neg_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_1.pt') \ + - shot_noise_bhmm(pos_x, neg_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_2.pt') \ + - shot_noise_bhmm(neg_x, pos_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_3.pt') + + output_min, output_max = self.output_stats.get() + out = out.clamp(output_min, output_max) + #out = quantize(out, 64, output_min, output_max, 'det') + return out + else: + out = torch.matmul(x, y) + return out + + +class QuantizedMHA(nn.Module): + def __init__(self, embed_dim, heads): + super(QuantizedMHA, self).__init__() + assert embed_dim % heads == 0 + self.n_heads = heads + self.Wq = QuantizedLinear(embed_dim, embed_dim) + self.Wk = QuantizedLinear(embed_dim, embed_dim) + self.Wv = QuantizedLinear(embed_dim, embed_dim) + self.qmm1 = QuantizedMatmul() + self.dropout_wq = nn.Dropout(0.1) + self.dropout_wk = nn.Dropout(0.1) + self.dropout_wv = nn.Dropout(0.1) + self.qmm2 = QuantizedMatmul() + self.Wout = QuantizedLinear(embed_dim, embed_dim) + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + + def forward(self, x, mask): + b = x.size(0) + n = x.size(1) + h = self.n_heads + d = x.size(2) + + def arrange_heads(acts): + # incoming shape of b, n, d, want b, h, n, d/h + return acts.view(b, n, h, -1).transpose(1, 2) + + q = arrange_heads(self.dropout_wq(self.Wq(x))) + k = arrange_heads(self.dropout_wk(self.Wk(x))) + v = arrange_heads(self.dropout_wv(self.Wv(x))) + + attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n + masked = attn.masked_fill(mask, float("-inf")) + softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) + out = self.qmm2(softmax_attn, v) # b, h, n, d/h + + out = out.transpose(1, 2).reshape(b, n, -1) + out = self.dropout2(out) + out = self.Wout(out) + return out + + +class QuantizedFF(nn.Module): + def __init__(self, embed_dim, expansion_dim): + super(QuantizedFF, self).__init__() + self.first_drop = nn.Dropout(0.1) + self.layer1 = QuantizedLinear(embed_dim, expansion_dim, use_noise=True) + self.act = nn.ReLU6(inplace=True) + self.dropout = nn.Dropout(0.1) + self.layer2 = QuantizedLinear(expansion_dim, embed_dim, use_noise=True) + + def forward(self, x): + out = self.first_drop(x) + out = self.layer1(out) + out = self.act(out) + out = self.dropout(out) + out = self.layer2(out) + return out + + +class QuantizedDecoderLayer(nn.Module): + def __init__(self, features, heads): + super(QuantizedDecoderLayer, self).__init__() + self.norm1 = nn.LayerNorm(features) + self.attn = QuantizedMHA(features, heads) + self.drop1 = nn.Dropout(0.1) + self.norm2 = nn.LayerNorm(features) + self.ff = QuantizedFF(features, features * 4) + self.drop2 = nn.Dropout(0.1) + + def forward(self, x, attn_mask): + # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right + identity = x + out = self.norm1(x) + out = self.attn(out, attn_mask) + out = self.drop1(out) + out = out + identity + identity = out + out = self.norm2(out) + out = self.ff(out) + out = self.drop2(out) + out = out + identity + return out + + +class _QuantizedGPT(nn.Module): + def __init__(self, features, heads, tokenizer, layers, max_length): + super(_QuantizedGPT, self).__init__() + vocab_size = len(tokenizer) + 8 - len(tokenizer) % 8 # pad vocab size to 8-multiple for tensor core acceleration + assert vocab_size % 8 == 0 + self.pos_embedding = nn.Embedding(max_length, features) + self.word_embedding = nn.Embedding(vocab_size, features, padding_idx = tokenizer.pad_token_id) + self.embedding_dropout = nn.Dropout(0.1) + self.decoders = nn.ModuleList([QuantizedDecoderLayer(features, heads) for _ in range(layers)]) + self.norm = nn.LayerNorm(features) + self.output_head = nn.Linear(features, vocab_size) + nn.init.normal_(self.word_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embedding.weight, std=0.02) + + def forward_embedding(self, x): + embedded = self.word_embedding(x) + return embedded + + def forward_attn(self, x): + mask = make_autoregressive_mask_for(x) + pos = make_position_indices_for(x) + pos_embed = self.embedding_dropout(self.pos_embedding(pos) + x) + decoded = pos_embed + for layer in self.decoders: + decoded = layer(decoded, mask) + + out = self.norm(decoded) + return out + + def forward(self, x): + embedded = self.forward_embedding(x) + decoded = self.forward_attn(embedded) + out = self.output_head(decoded) + return out + + +class QuantizedGPT(pl.LightningModule): + def __init__(self, features, heads, layers=6, max_length=1024): + super().__init__() + self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) + self.transformer = _QuantizedGPT(features, heads, self.tokenizer, layers, max_length) + self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) + self.val_loss = torchmetrics.MeanMetric() + self.test_loss = torchmetrics.MeanMetric() + self.lr = 0.0005 + self.photon_target = 0 + self.training_steps = 100000 + self.extracting = False + self.use_adam = True + + def get_tokenizer(self): + return self.tokenizer + + def forward(self, x): + return self.transformer(x) + + def training_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.log('train loss', loss) + return loss + + def validation_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.val_loss.update(loss) + + def validation_epoch_end(self, outputs): + self.log('validation loss', self.val_loss) + + def test_step(self, batch, batch_idx): + xs, ys = batch + preds = self(xs) + features = preds.size(2) + preds = preds.view(-1, features) + ys = ys.view(-1) + loss = self.loss(preds, ys) + self.test_loss.update(loss) + if self.extracting: + raise ValueError("Extraction done, aborting") + + def test_epoch_end(self, outputs): + self.log('test loss', self.test_loss) + self.log('photon target', self.photon_target) + + def configure_optimizers(self): + if self.use_adam: + decay = set() + no_decay = set() + blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(recurse=False): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if 'bias' in pn: + no_decay.add(fpn) + elif 'weight' in pn and not isinstance(m, blacklist_weight_modules): + decay.add(fpn) + else: + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.02}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + optimizer = torch.optim.AdamW(optim_groups, lr=self.lr) + scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'step', + 'name': 'Cosine LR scheduler' + } + } + else: + optimizer = torch.optim.RMSprop(self.parameters(), lr=self.lr, weight_decay=1e-5) + scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'step', + 'name': 'Cosine LR scheduler' + } + } + + def replace_output_head(self, module): + self.transformer.output_head = module + + def enable_quantization(self): + for m in self.transformer.modules(): + if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): + m.enable_quantization() + + def set_photon_target(self, n_photons): + self.photon_target = n_photons + for m in self.transformer.modules(): + if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): + m.set_photon_target(n_photons) + + def set_quantized_eval(self, value=True): + for m in self.transformer.modules(): + if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): + print("setting quantized eval") + m.force_quantized_eval = value + + def save(self, fname): + torch.save(self.transformer.state_dict(), fname) + + def load(self, fname): + self.transformer.load_state_dict(torch.load(fname)) + + def enable_extraction(self): + lin1 = self.transformer.decoders[0].ff.layer2 + lin1.extract_simulated = True + lin1.extract_name = 'first_linear' + lin2 = self.transformer.decoders[-1].ff.layer2 + lin2.extract_simulated = True + lin2.extract_name = 'last_linear' + attn1 = self.transformer.decoders[0].attn.qmm1 + attn1.extract_simulated = True + attn1.extract_name = 'first_attn' + attn2 = self.transformer.decoders[-1].attn.qmm1 + attn2.extract_simulated = True + attn2.extract_name = 'last_attn' + self.extracting = True + \ No newline at end of file diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index bf8ecbce6..cddc777fd 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -1,9 +1,16 @@ from .morr_linear import AllPassMORRCirculantLinear from .morr_conv2d import AllPassMORRCirculantConv2d from .morr_custom_linear import AllPassMORRLinear +from ..triton_modules.morr_linear import TritonMORRLinear +from ..triton_modules.morr_linear_mem import TritonMemMORRLinear +from .morr_transformer.morr_bert import BertMORRSelfAttention + optical_module_map = { "linear_morr": AllPassMORRCirculantLinear, "conv2d_morr": AllPassMORRCirculantConv2d, "linear_morr_full": AllPassMORRLinear, + "linear_morr_triton": TritonMORRLinear, + "linear_morr_triton_mem": TritonMemMORRLinear, + "bert_self_attention_morr": BertMORRSelfAttention, } diff --git a/src/chop/nn/optical/modules/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer.py deleted file mode 100644 index bb1fe9e12..000000000 --- a/src/chop/nn/optical/modules/morr_transformer.py +++ /dev/null @@ -1,884 +0,0 @@ -from typing import Optional -import logging - -import numpy as np -import math -import torch -import torch.nn as nn -import torch.fft -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device -import pytorch_lightning as pl -import torchmetrics -import transformers -from transformers import GPT2TokenizerFast - -from ..utils import MORRConfig_20um_MQ -from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ..utils import toeplitz -from ..utils import morr_uniform_ -from ..utils import input_quantize_fn, weight_quantize_fn -from .base_layer import ONNBaseLayer -from .morr_custom_linear import AllPassMORRLinear -from .morr_linear import AllPassMORRCirculantLinear - -from transformers import BertModel, BertForSequenceClassification -from transformers.models.gpt2.modeling_gpt2 import ( - GPT2Attention, - GPT2MLP, - GPT2Block, - Conv1D, -) - -logger = logging.getLogger(__name__) - -__all__ = ["AllPassMORRCirculantMatMals"] - -def make_autoregressive_mask_for(x): - length = x.size(1) - ones = x.new_ones((length, length)) - mask = torch.triu(ones, diagonal=1) != 0.0 - return mask - - -def make_position_indices_for(x): - length = x.size(1) - batch_size = x.size(0) - indices = torch.arange(length, device=x.device).repeat(batch_size, 1) - return indices - - -def load_lookup_table(file, device): - data = torch.from_numpy(numpy.genfromtxt(file, delimiter='\t')).float() - levels = data.size(0) - lower_bound = data[0,1].item() - weight = data[:,1].unsqueeze(1).cuda(device) - return weight, lower_bound, levels - - -def apply_lut_to_normalized(x, lut, bit_degredation=0): - lut_weight, lut_lb, lut_levels = lut - deg_factor = 2**bit_degredation - x = x.mul(lut_levels - deg_factor).div(deg_factor).round().mul(deg_factor).to(dtype=torch.long) - x = F.embedding(x, lut_weight).squeeze(-1) - return x - - -class QuantizeValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x, quant_levels, min_val, max_val, quant_mode, lut_min=None): - with torch.no_grad(): - diff = max_val - min_val - x = x.clamp(min_val, max_val).add(-1.0 * min_val).div(diff + 1e-8).mul(quant_levels - 1) - - if quant_mode == 'det': - x = x.round() - x = x.div(quant_levels - 1).mul(diff).add(min_val) - elif quant_mode == 'rand': - x = x.add(torch.rand_like(x).add(-0.5)).round() # randn* 0.288 gives same std as 0-1 rand(), if want to use normal dist. - x = x.div(quant_levels - 1).mul(diff).add(min_val) - - if lut_min is not None: - pos_x = torch.relu(x) - neg_x = x - pos_x - lms = lut_min * max_val - pos_x[pos_x < lms] = lms - lms = lut_min * torch.abs(min_val) - neg_x[neg_x > -lms] = -lms - x = pos_x + neg_x - - return x - - @staticmethod - def backward(ctx, grad_output): - # STE - return grad_output, None, None, None, None, None - -class QuantizeStats(nn.Module): - def __init__(self, percentile, use_clipping=True): - super(QuantizeStats, self).__init__() - self.register_buffer('running_min', torch.tensor(0.0)) - self.register_buffer('running_max', torch.tensor(0.0)) - self.max_calibration_steps = 1 - self.initial_calibration_steps = 0 - #self.register_buffer('calibration_done', torch.tensor(False)) - self.calibration_done = torch.tensor(False) - self.activations = [] - self.percentile = percentile - self.use_clipping = use_clipping - - def update(self, tensor): - if self.use_clipping: - if not self.calibration_done.item(): - self.initial_calibration_steps += 1 - finished = False - - if self.initial_calibration_steps >= self.max_calibration_steps: - finished = True - self.calibration_done = torch.tensor(True) - - with torch.no_grad(): - self.activations.extend(tensor.detach().cpu().tolist()) - - if finished: - maximum = numpy.percentile(self.activations, self.percentile) - self.running_max = torch.tensor(maximum, device=tensor.device, dtype=tensor.dtype) - minimum = tensor.min() - minimum = minimum if minimum >= 0.0 else -maximum - self.running_min = torch.tensor(minimum, device=tensor.device, dtype=tensor.dtype) - self.activations.clear() # free the memory - else: - self.running_min = tensor.min() - self.running_max = tensor.max() - - else: - alpha = 0.999 - with torch.no_grad(): - cur_min = tensor.min() - cur_max = tensor.max() - - if self.initial_calibration_steps == 0: - self.initial_calibration_steps += 1 - self.running_min = cur_min - self.running_max = cur_max - else: - self.running_min = alpha * self.running_min + (1.0 - alpha) * cur_min - self.running_max = alpha * self.running_max + (1.0 - alpha) * cur_max - - - - def get(self): - return self.running_min, self.running_max - -class AllPassMORRCirculantMatMals(ONNBaseLayer): - """ - All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. - J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" - https://doi.org/10.23919/DATE51398.2021.9474147 - """ - - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - miniblock: int - weight: Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - config=None, - device: Device = torch.device("cpu"), - ) -> None: - super(AllPassMORRCirculantMatMals, self).__init__() - self.in_features = in_features - self.out_features = out_features - - miniblock_size = config.get("miniblock", 4) - self.miniblock = miniblock_size - self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) - self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) - self.in_features_pad = self.grid_dim_x * miniblock_size - self.out_features_pad = self.grid_dim_y * miniblock_size - - self.v_max = 10.8 - self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 - self.w_bit = 32 - self.in_bit = 32 - - morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) - morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) - self.MORRConfig = morr_config - self.morr_init = morr_init_val - self.mrr_a = morr_config.attenuation_factor - self.mrr_r = morr_config.coupling_factor - self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) - self.trainable_morr_scale = config.get( - "trainable_morr_scale", MORRConfig_20um_MQ - ) - self.device = device - ### calculate FWHM (rad) - self.morr_fwhm = ( - -4 - * np.pi**2 - * morr_config.radius - * morr_config.effective_index - * ( - 1 / morr_config.resonance_wavelength - - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) - ) - ) - - ### allocate parameters - self.weight = None - self.x_zero_pad = None - self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs - self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = ( - None ## scaling factor for the round-trip phase shift within MORR - ) - self.morr_gain = ( - 100 / (self.in_features // self.miniblock) - ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 - ### build trainable parameters - self.build_parameters() - - ### quantization tool - self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) - self.weight_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_pos" - ) ## [0-1] positive only, maintain the original scale - self.morr_output_scale_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_sym" - ) ## [-1,1] full-range - - self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( - a=self.mrr_a, r=self.mrr_r, intensity=True - ) - - ### default set to slow forward - self.disable_fast_forward() - ### default set no gamma noise - self.set_gamma_noise(0) - ### default set no crosstalk - self.disable_crosstalk() - ### default set no phase variation - self.disable_phase_variation() - - if bias: - self.bias = Parameter(torch.Tensor(out_features).to(self.device)) - else: - self.register_parameter("bias", None) - - self.reset_parameters(morr_init=morr_init_val) - self.finegrain_drop_mask = None - - def build_parameters(self) -> None: - - self.weight = Parameter( - torch.ones( - self.grid_dim_y, - self.grid_dim_x, - self.miniblock, - device=self.device, - dtype=torch.float, - ) - ) - ### Learnable balancing factor (morr_output_scale) - ### We use a single scaling factor for each block - self.morr_output_scale = Parameter( - torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) - ) - if self.trainable_morr_bias: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_bias = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - if self.trainable_morr_scale: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_scale = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - - def reset_parameters(self, morr_init: bool = False) -> None: - ### nonlinear curve aware initialization - if morr_init: - ## initialize weight - morr_uniform_( - self.weight, - MORRConfig=self.MORRConfig, - n_op=self.miniblock, - biased=self.w_bit >= 16, - gain=2 if self.in_bit < 16 else 1, - ) # quantization needs zero-center - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - - ## output distribution aware initialization to output scaling factor - t1 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True - ) - t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), - a=self.mrr_a, - r=self.mrr_r, - intensity=True, - ) - g = ( - (t2 - t1) / (2.4 * self.morr_fwhm) - ).item() ## 0~2.4 FWHM slope as a linear approximation - - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) - self.out_scale_quant_gain = None - init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) - else: - init.kaiming_normal_(self.weight.data) - init.kaiming_normal_(self.morr_output_scale.data) - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - self.sigma_out_scale = self.morr_output_scale.data.std().item() - self.out_scale_quant_gain = None - - if self.morr_input_bias is not None: - self.morr_input_bias.data.zero_() - if self.morr_input_scale is not None: - ### after sigmoid, it cooresponds to 1 scale - init.normal_(self.morr_input_scale.data, 2, 0.1) - - if self.bias is not None: - init.uniform_(self.bias, 0, 0) - - def sync_parameters(self, src: str = "weight") -> None: - """ - description: synchronize all parameters from the source parameters - """ - - raise NotImplementedError - - def build_weight(self, y: Tensor) -> Tensor: - if self.w_bit < 16: - ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) - weight = self.weight_quantizer(self.weight) - - ## rescale weights after quantization can maintain the initialization distribution - if self.weight_quant_gain is None: - self.weight_quant_gain = self.sigma_weight / weight.data.std() - if self.trainable_morr_scale: - morr_scale = self.morr_scale * self.weight_quant_gain - else: - morr_scale = self.weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization - - ### quantize learnable balancing factor - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - else: - weight = self.weight.abs() # positive only - morr_output_scale = ( - self.morr_output_scale - self.morr_output_scale.data.mean() - ) - - if self.finegrain_drop_mask is not None: - weight = weight.mul(self.finegrain_drop_mask.float()) - - ## differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if self.grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if self.grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - - return weight, morr_output_scale - - def enable_fast_forward(self) -> None: - self.fast_forward_flag = True - - def disable_fast_forward(self) -> None: - self.fast_forward_flag = False - - def set_gamma_noise( - self, noise_std: float, random_state: Optional[int] = None - ) -> None: - self.gamma_noise_std = noise_std - - def load_parameters(self, param_dict) -> None: - """ - description: update parameters based on this parameter dictionary\\ - param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} - """ - for name, param in param_dict.items(): - getattr(self, name).data.copy_(param) - - def set_weight_bitwidth(self, w_bit: int) -> None: - self.w_bit = w_bit - self.weight_quantizer.set_bitwidth(w_bit) - self.morr_output_scale_quantizer.set_bitwidth(w_bit) - - def set_input_bitwidth(self, in_bit: int) -> None: - self.in_bit = in_bit - self.input_quantizer.set_bitwidth(in_bit) - - def input_modulator(self, x: Tensor) -> Tensor: - ### voltage to power, which is proportional to the phase shift - return x * x - - def set_crosstalk_coupling_matrix( - self, coupling_factor: float, drop_perc: float = 0 - ) -> None: - ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. - ### drop-perc is the pruning percentage. - assert 0 <= coupling_factor <= 1, logger.error( - f"Coupling factor must in [0,1], but got {coupling_factor}" - ) - - self.crosstalk_factor = ( - 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor - ) - - def enable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = True - - def disable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = False - - def set_phase_variation(self, phase_noise_std: float = 0) -> None: - self.phase_noise_std = phase_noise_std - - def enable_phase_variation(self) -> None: - self.enable_phase_noise = True - - def disable_phase_variation(self) -> None: - self.enable_phase_noise = False - - def enable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = True - - def disable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = False - - def enable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = True - - def disable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = False - - @property - def morr_bias(self) -> Tensor: - if self.morr_input_bias is None: - return None - # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) - return self.morr_fwhm * torch.tanh( - self.morr_input_bias.unsqueeze(0).unsqueeze(-1) - ) - - @property - def morr_scale(self) -> Tensor: - if self.morr_input_scale is None: - return None - return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] - - def propagate_morr( - self, weight: Tensor, x: Tensor, morr_output_scale: Tensor - ) -> Tensor: - """ - @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul - @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators - @param x {torch.Tensor} complex-valued input - @param morr_output_scale {torch.Tensor} learnable balancing factors - @return: y {torch.Tensor} output of attenuators - """ - ### x : [bs, q, k] - ### weights: [p, q, k] - ### morr_output_scale: [1, 1, 1, q] - - ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable - ## build circulant weight matrix - # crosstalk on the weights are much cheaper to compute than on the phase shift - if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: - weight = weight * self.crosstalk_factor - weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] - x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] - x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] - - if self.enable_phase_noise and self.phase_noise_std > 1e-5: - x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) - - ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] - if self.trainable_morr_bias: - x = x - self.morr_bias - - ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] - ### x is the phase detuning, x=0 means on-resonance - ### phase: [bs, p, q, k] - x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd - - ## implement balancing factor as dot-product - """ - if(self.w_bit < 16): - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - if(self.sigma_out_scale_quant_gain is None): - self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() - morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization - else: - morr_output_scale = self.morr_output_scale - # morr_output_scale = morr_output_scale * self.morr_gain - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - - # print("morr diff transmission:", end=", ") - # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] - # print_stat(diff) - if(self.grid_dim_x % 2 == 0): - #even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if(self.grid_dim_x > 1): - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - # print("output scale Q:", end=", ") - # print_stat(scale[..., :scale.size(-1)//2]) - """ - x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - x = x.flatten(1) # [bs, p*k] - return x - - def get_finegrain_drop_mask(self, topk: int) -> Tensor: - if self.w_bit < 16: - weight = self.weight_quantizer(self.weight.data) # [p, q, k] - else: - weight = self.weight.data.abs() - indices = weight.argsort(dim=-1) - mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) - - drop_indices = indices[:, :, 0:-topk] - mask.scatter_(2, drop_indices, 0) - self.finegrain_drop_mask = mask - return mask - - def apply_finegrain_drop_mask(self, mask: Tensor) -> None: - if self.w_bit < 16: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) - else: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) - - def forward(self, x: Tensor, y: Tensor) -> Tensor: - - # load y as weight: - self.weight.data.copy_(y) - - # if used in transformer - is_transformer = len(x.shape) == 3 - if is_transformer: - B, N, D = x.shape - - assert ( - x.size(-1) == self.in_features - ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" - if self.in_bit < 16: - x = self.input_quantizer(x) - - weight, morr_output_scale = self.build_weight() - if self.in_features_pad > self.in_features: - if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): - self.x_zero_pad = torch.zeros( - x.size(0), - self.in_features_pad - self.in_features, - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([x, self.x_zero_pad], dim=1) - - x = x.view(-1, self.grid_dim_x, self.miniblock) - - ### modulation - ### x: [bs, q, k] -> [bs, q, k] - x = self.input_modulator(x) - - ### propagate through morr array - ### x: [bs, q, k] -> [bs, p*k] - x = self.propagate_morr(weight, x, morr_output_scale) - - if self.out_features < self.out_features_pad: - x = x[..., : self.out_features] - if self.bias is not None: - x = x + self.bias.unsqueeze(0) - - # adjust output shape if used in transformer - if is_transformer: - x = x.view(B, N, self.out_features) - return x - -class MORRMHA(nn.Module): - def __init__(self, embed_dim, heads): - super(MORRMHA, self).__init__() - assert embed_dim % heads == 0 - self.n_heads = heads - self.Wq = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.Wk = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.Wv = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.qmm1 = AllPassMORRCirculantMatMals() - self.dropout_wq = nn.Dropout(0.1) - self.dropout_wk = nn.Dropout(0.1) - self.dropout_wv = nn.Dropout(0.1) - self.qmm2 = AllPassMORRCirculantMatMals() - self.Wout = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.dropout1 = nn.Dropout(0.1) - self.dropout2 = nn.Dropout(0.1) - - def forward(self, x, mask): - b = x.size(0) - n = x.size(1) - h = self.n_heads - d = x.size(2) - - def arrange_heads(acts): - # incoming shape of b, n, d, want b, h, n, d/h - return acts.view(b, n, h, -1).transpose(1, 2) - - q = arrange_heads(self.dropout_wq(self.Wq(x))) - k = arrange_heads(self.dropout_wk(self.Wk(x))) - v = arrange_heads(self.dropout_wv(self.Wv(x))) - - attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n - masked = attn.masked_fill(mask, float("-inf")) - softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) - out = self.qmm2(softmax_attn, v) # b, h, n, d/h - - out = out.transpose(1, 2).reshape(b, n, -1) - out = self.dropout2(out) - out = self.Wout(out) - return out - -class MORRFF(nn.Module): - def __init__(self, embed_dim, expansion_dim): - super(MORRFF, self).__init__() - self.first_drop = nn.Dropout(0.1) - self.layer1 = AllPassMORRCirculantLinear(embed_dim, expansion_dim, use_noise=True) - self.act = nn.ReLU6(inplace=True) - self.dropout = nn.Dropout(0.1) - self.layer2 = AllPassMORRCirculantLinear(expansion_dim, embed_dim, use_noise=True) - - def forward(self, x): - out = self.first_drop(x) - out = self.layer1(out) - out = self.act(out) - out = self.dropout(out) - out = self.layer2(out) - return out - -class MORRDecoderLayer(nn.Module): - def __init__(self, features, heads): - super(MORRDecoderLayer, self).__init__() - self.norm1 = nn.LayerNorm(features) - self.attn = MORRMHA(features, heads) - self.drop1 = nn.Dropout(0.1) - self.norm2 = nn.LayerNorm(features) - self.ff = MORRFF(features, features * 4) - self.drop2 = nn.Dropout(0.1) - - def forward(self, x, attn_mask): - # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right - identity = x - out = self.norm1(x) - out = self.attn(out, attn_mask) - out = self.drop1(out) - out = out + identity - identity = out - out = self.norm2(out) - out = self.ff(out) - out = self.drop2(out) - out = out + identity - return out - -class _MORRGPT(nn.Module): - def __init__(self, features, heads, tokenizer, layers, max_length): - super(_MORRGPT, self).__init__() - vocab_size = len(tokenizer) + 8 - len(tokenizer) % 8 # pad vocab size to 8-multiple for tensor core acceleration - assert vocab_size % 8 == 0 - self.pos_embedding = nn.Embedding(max_length, features) - self.word_embedding = nn.Embedding(vocab_size, features, padding_idx = tokenizer.pad_token_id) - self.embedding_dropout = nn.Dropout(0.1) - self.decoders = nn.ModuleList([MORRDecoderLayer(features, heads) for _ in range(layers)]) - self.norm = nn.LayerNorm(features) - self.output_head = nn.Linear(features, vocab_size) - nn.init.normal_(self.word_embedding.weight, std=0.02) - nn.init.normal_(self.pos_embedding.weight, std=0.02) - - def forward_embedding(self, x): - embedded = self.word_embedding(x) - return embedded - - def forward_attn(self, x): - mask = make_autoregressive_mask_for(x) - pos = make_position_indices_for(x) - pos_embed = self.embedding_dropout(self.pos_embedding(pos) + x) - decoded = pos_embed - for layer in self.decoders: - decoded = layer(decoded, mask) - - out = self.norm(decoded) - return out - - def forward(self, x): - embedded = self.forward_embedding(x) - decoded = self.forward_attn(embedded) - out = self.output_head(decoded) - return out - - -class MORRGPT(pl.LightningModule): - def __init__(self, features, heads, layers=6, max_length=1024): - super().__init__() - self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") - self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) - self.transformer = _MORRGPT(features, heads, self.tokenizer, layers, max_length) - self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) - self.val_loss = torchmetrics.MeanMetric() - self.test_loss = torchmetrics.MeanMetric() - self.lr = 0.0005 - self.photon_target = 0 - self.training_steps = 100000 - self.extracting = False - self.use_adam = True - - def get_tokenizer(self): - return self.tokenizer - - def forward(self, x): - return self.transformer(x) - - def training_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.log('train loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.val_loss.update(loss) - - def validation_epoch_end(self, outputs): - self.log('validation loss', self.val_loss) - - def test_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.test_loss.update(loss) - if self.extracting: - raise ValueError("Extraction done, aborting") - - def test_epoch_end(self, outputs): - self.log('test loss', self.test_loss) - self.log('photon target', self.photon_target) - - def configure_optimizers(self): - if self.use_adam: - decay = set() - no_decay = set() - blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) - - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(recurse=False): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - - if 'bias' in pn: - no_decay.add(fpn) - elif 'weight' in pn and not isinstance(m, blacklist_weight_modules): - decay.add(fpn) - else: - no_decay.add(fpn) - - param_dict = {pn: p for pn, p in self.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) - - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.02}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - - optimizer = torch.optim.AdamW(optim_groups, lr=self.lr) - scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) - return { - 'optimizer': optimizer, - 'lr_scheduler': { - 'scheduler': scheduler, - 'interval': 'step', - 'name': 'Cosine LR scheduler' - } - } - else: - optimizer = torch.optim.RMSprop(self.parameters(), lr=self.lr, weight_decay=1e-5) - scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) - return { - 'optimizer': optimizer, - 'lr_scheduler': { - 'scheduler': scheduler, - 'interval': 'step', - 'name': 'Cosine LR scheduler' - } - } - - def replace_output_head(self, module): - self.transformer.output_head = module - - def enable_quantization(self): - for m in self.transformer.modules(): - if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): - m.enable_quantization() - - def set_photon_target(self, n_photons): - self.photon_target = n_photons - for m in self.transformer.modules(): - if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): - m.set_photon_target(n_photons) - - def set_quantized_eval(self, value=True): - for m in self.transformer.modules(): - if isinstance(m, AllPassMORRCirculantLinear) or isinstance(m, AllPassMORRCirculantMatMals): - print("setting quantized eval") - m.force_quantized_eval = value - - def save(self, fname): - torch.save(self.transformer.state_dict(), fname) - - def load(self, fname): - self.transformer.load_state_dict(torch.load(fname)) - - def enable_extraction(self): - lin1 = self.transformer.decoders[0].ff.layer2 - lin1.extract_simulated = True - lin1.extract_name = 'first_linear' - lin2 = self.transformer.decoders[-1].ff.layer2 - lin2.extract_simulated = True - lin2.extract_name = 'last_linear' - attn1 = self.transformer.decoders[0].attn.qmm1 - attn1.extract_simulated = True - attn1.extract_name = 'first_attn' - attn2 = self.transformer.decoders[-1].attn.qmm1 - attn2.extract_simulated = True - attn2.extract_name = 'last_attn' - self.extracting = True - \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_bert.py b/src/chop/nn/optical/modules/morr_transformer/morr_bert.py new file mode 100644 index 000000000..36aff15fc --- /dev/null +++ b/src/chop/nn/optical/modules/morr_transformer/morr_bert.py @@ -0,0 +1,143 @@ +from typing import Optional +import logging + +import numpy as np +import math +import torch +import torch.nn as nn +import torch.fft +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device +import pytorch_lightning as pl +import torchmetrics +import transformers +from transformers import GPT2TokenizerFast +from packaging import version +from typing import List, Optional, Tuple, Union + +from ...utils import MORRConfig_20um_MQ +from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ...utils import toeplitz +from ...utils import morr_uniform_ +from ...utils import input_quantize_fn, weight_quantize_fn +from ..base_layer import ONNBaseLayer +from ..morr_custom_linear import AllPassMORRLinear +from ..morr_linear import AllPassMORRCirculantLinear +from .morr_matmul import AllPassMORRCirculantMatMuls +from .morr_transformer import MORRSdpa + +from transformers.models.bert.modeling_bert import BertSelfAttention +from transformers.utils import ( + get_torch_version, +) + +class BertMORRSelfAttention(BertSelfAttention): + def __init__(self, config, position_embedding_type=None, morr_config=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + # define MORR object to perform SDPA + self.morr_spda = None + self.morr_config = morr_config + + # Adapted from BertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + # logger.warning_once( + # "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + # "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + # "the manual attention implementation, but specifying the manual implementation will be required from " + # "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + # '`attn_implementation="eager"` when loading the model.' + # ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + self.morr_spda = MORRSdpa( + self.attention_head_size, # Dh + self.num_attention_heads, # H + hidden_states.shape[1], # N + dropout_p=self.dropout_prob, + use_morr=True, + morr_config=self.morr_config, + ) + attn_output = self.morr_spda( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py b/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py new file mode 100644 index 000000000..3102b6c37 --- /dev/null +++ b/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py @@ -0,0 +1,535 @@ +from typing import Optional +import logging + +import numpy as np +import math +import torch +import torch.nn as nn +import torch.fft +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device +import pytorch_lightning as pl +import torchmetrics +import transformers +from transformers import GPT2TokenizerFast + +from ...utils import MORRConfig_20um_MQ +from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ...utils import toeplitz +from ...utils import morr_uniform_ +from ...utils import input_quantize_fn, weight_quantize_fn +from ..base_layer import ONNBaseLayer +from ..morr_custom_linear import AllPassMORRLinear +from ..morr_linear import AllPassMORRCirculantLinear + +from transformers import BertModel, BertForSequenceClassification +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2Attention, + GPT2MLP, + GPT2Block, + Conv1D, +) + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantMatMuls"] + + +class AllPassMORRCirculantMatMuls(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config=None, + device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) -> None: + super(AllPassMORRCirculantMatMuls, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def load_compressed_weight(self, weight: Tensor) -> None: + """ + load weight data from torch.linear module.weight.data + """ + assert weight.shape == (self.out_features, self.in_features), ( + f"Expected {(self.out_features, self.in_features)}, got {weight.shape}" + ) + + W_padded = weight.new_zeros((self.out_features_pad, self.in_features_pad)) + W_padded[: weight.size(0), : weight.size(1)] = weight + + new_weight = weight.new_zeros((self.grid_dim_y, self.grid_dim_x, self.miniblock)) + for p in range(self.grid_dim_y): + for q in range(self.grid_dim_x): + for k in range(self.miniblock): + row_idx = p * self.miniblock + k # The row in W_padded: + col_start = q * self.miniblock # The columns in W_padded: + col_end = (q + 1) * self.miniblock + block = W_padded[row_idx, col_start:col_end] + + new_weight[p, q, k] = block.mean() + bound = 1 / math.sqrt(self.miniblock) + new_weight = torch.rand( + (self.grid_dim_y, self.grid_dim_x, self.miniblock), + device=weight.device, + dtype=weight.dtype + ) * 2 * bound - bound + + self.load_parameters({"weight": new_weight}) + + def forward(self, X: Tensor, Y: Tensor) -> Tensor: + """ + this module currently support 4-D multi-head attn MatMul only + - x: input, [B, H, N, D] + - y: weight, [B, H, D, N] + """ + assert len(X.shape) == 4, f"Expected a 4-D tensor, got shape {X.shape}" + B, H, N, D = X.shape + out_rows = [] + + for b in range(B): + for h in range(H): + self.load_compressed_weight(Y[b, h].t()) + x = X[b, h] + + assert ( + x.size(-1) == self.in_features + ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" + if self.in_bit < 16: + x = self.input_quantizer(x) + + weight, morr_output_scale = self.build_weight() + if self.in_features_pad > self.in_features: + if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): + self.x_zero_pad = torch.zeros( + x.size(0), + self.in_features_pad - self.in_features, + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([x, self.x_zero_pad], dim=1) + + x = x.view(-1, self.grid_dim_x, self.miniblock) + + ### modulation + ### x: [bs, q, k] -> [bs, q, k] + x = self.input_modulator(x) + + ### propagate through morr array + ### x: [bs, q, k] -> [bs, p*k] + x = self.propagate_morr(weight, x, morr_output_scale) + + if self.out_features < self.out_features_pad: + x = x[..., : self.out_features] + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + out_rows.append(x) + + out = torch.stack(out_rows, dim=0) # (B·H, N, N) + out = out.view(B, H, N, self.out_features) + return out \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py new file mode 100644 index 000000000..a9da28085 --- /dev/null +++ b/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py @@ -0,0 +1,168 @@ +from typing import Optional +import logging + +import numpy as np +import math +import torch +import torch.nn as nn +import torch.fft +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device +import pytorch_lightning as pl +import torchmetrics +import transformers +from transformers import GPT2TokenizerFast + +from ...utils import MORRConfig_20um_MQ +from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ...utils import toeplitz +from ...utils import morr_uniform_ +from ...utils import input_quantize_fn, weight_quantize_fn +from ..base_layer import ONNBaseLayer +from ..morr_custom_linear import AllPassMORRLinear +from ..morr_linear import AllPassMORRCirculantLinear +from .morr_matmul import AllPassMORRCirculantMatMuls + +from transformers import BertModel, BertForSequenceClassification +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2Attention, + GPT2MLP, + GPT2Block, + Conv1D, +) + +logger = logging.getLogger(__name__) + +__all__ = [""] + + + +class MORRMHA(nn.Module): + def __init__(self, embed_dim, heads): + super(MORRMHA, self).__init__() + assert embed_dim % heads == 0 + self.n_heads = heads + self.Wq = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.Wk = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.Wv = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.qmm1 = AllPassMORRCirculantMatMuls() + self.dropout_wq = nn.Dropout(0.1) + self.dropout_wk = nn.Dropout(0.1) + self.dropout_wv = nn.Dropout(0.1) + self.qmm2 = AllPassMORRCirculantMatMuls() + self.Wout = AllPassMORRCirculantLinear(embed_dim, embed_dim) + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + + def forward(self, x, mask): + b = x.size(0) + n = x.size(1) + h = self.n_heads + d = x.size(2) + + def arrange_heads(acts): + # incoming shape of b, n, d, want b, h, n, d/h + return acts.view(b, n, h, -1).transpose(1, 2) + + q = arrange_heads(self.dropout_wq(self.Wq(x))) + k = arrange_heads(self.dropout_wk(self.Wk(x))) + v = arrange_heads(self.dropout_wv(self.Wv(x))) + + attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n + masked = attn.masked_fill(mask, float("-inf")) + softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) + out = self.qmm2(softmax_attn, v) # b, h, n, d/h + + out = out.transpose(1, 2).reshape(b, n, -1) + out = self.dropout2(out) + out = self.Wout(out) + return out + + +class MORRFF(nn.Module): + def __init__(self, embed_dim, expansion_dim): + super(MORRFF, self).__init__() + self.first_drop = nn.Dropout(0.1) + self.layer1 = AllPassMORRCirculantLinear(embed_dim, expansion_dim, use_noise=True) + self.act = nn.ReLU6(inplace=True) + self.dropout = nn.Dropout(0.1) + self.layer2 = AllPassMORRCirculantLinear(expansion_dim, embed_dim, use_noise=True) + + def forward(self, x): + out = self.first_drop(x) + out = self.layer1(out) + out = self.act(out) + out = self.dropout(out) + out = self.layer2(out) + return out + +class MORRDecoderLayer(nn.Module): + def __init__(self, features, heads): + super(MORRDecoderLayer, self).__init__() + self.norm1 = nn.LayerNorm(features) + self.attn = MORRMHA(features, heads) + self.drop1 = nn.Dropout(0.1) + self.norm2 = nn.LayerNorm(features) + self.ff = MORRFF(features, features * 4) + self.drop2 = nn.Dropout(0.1) + + def forward(self, x, attn_mask): + # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right + identity = x + out = self.norm1(x) + out = self.attn(out, attn_mask) + out = self.drop1(out) + out = out + identity + identity = out + out = self.norm2(out) + out = self.ff(out) + out = self.drop2(out) + out = out + identity + return out + + +class MORRSdpa(nn.Module): + def __init__(self, attn_head_size, num_heads, seq_length, dropout_p, use_morr = False, morr_config = None): + super(MORRSdpa, self).__init__() + self.attn_head_size = attn_head_size + self.num_heads = num_heads + self.use_morr = use_morr + self.qmm1 = AllPassMORRCirculantMatMuls( + in_features=attn_head_size, # Dh + out_features=seq_length, # N + config = morr_config + ) + self.qmm1.disable_trainable_morr_scale() + self.qmm1.disable_trainable_morr_bias() + + self.qmm2 = AllPassMORRCirculantMatMuls( + in_features=seq_length, # D + out_features=attn_head_size, # N + config = morr_config + ) + self.qmm2.disable_trainable_morr_scale() + self.qmm2.disable_trainable_morr_bias() + self.dropout = nn.Dropout(dropout_p) + + def forward(self, query, key, value, attn_mask): + attn_head_size = self.attn_head_size + + if self.use_morr: + attn_scores = self.qmm1(query, key.transpose(2, 3)) # yields b, h, n, n + else: + attn_scores = torch.matmul(query, key.transpose(2, 3)) + + attn_scores = attn_scores / math.sqrt(attn_head_size) + if attn_mask is not None: + attn_scores = attn_scores + attn_mask + attn_probs = nn.functional.softmax(attn_scores, dim=-1) + attn_probs = self.dropout(attn_probs) + + if self.use_morr: + out = self.qmm2(attn_probs, value) # [B, H, N, N] * [B, H, N, Dh] -> [b, h, n, Dh] + else: + out = torch.matmul(attn_probs, value) + + return out \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/dtype.py b/src/chop/nn/optical/triton_modules/dtype.py new file mode 100644 index 000000000..caaa77e69 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/dtype.py @@ -0,0 +1,17 @@ +import torch +import triton.language as tl + + +TORCH_DTYPE_TO_TRITON = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16, + torch.int8: tl.int8, + torch.uint8: tl.uint8, + torch.int16: tl.int16, + torch.uint16: tl.uint16, + torch.int32: tl.int32, + torch.uint32: tl.uint32, + torch.float8_e4m3fn: tl.float8e4nv, + torch.float8_e5m2: tl.float8e5, +} diff --git a/src/chop/nn/optical/triton_modules/morr_linear.py b/src/chop/nn/optical/triton_modules/morr_linear.py new file mode 100644 index 000000000..5c8aa4986 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/morr_linear.py @@ -0,0 +1,473 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" + +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from ..modules.base_layer import ONNBaseLayer +from .morr_linear_kernel import morr_linear_fn + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantLinear"] + + +class TritonMORRLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config={}, + device: Device = torch.device("cpu"), + ) -> None: + super(TritonMORRLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + output, *_ = morr_linear_fn( + x, + self.weight, + morr_input_bias = self.morr_input_bias, + morr_output_scale = self.morr_output_scale, + bias = None, + morr_bias = self.morr_bias, + grid_dim_x = self.grid_dim_x, + grid_dim_y = self.grid_dim_y, + miniblock = self.miniblock, + enable_thermal_crosstalk=self.enable_thermal_crosstalk, + crosstalk_factor=None if not self.enable_thermal_crosstalk else self.crosstalk_factor, + enable_phase_noise=self.enable_phase_noise, + phase_noise_std=None if not self.enable_phase_noise else self.phase_noise_std, + trainable_morr_bias=self.trainable_morr_bias, + mrr_a=self.mrr_a, + mrr_r=self.mrr_r, + finegrain_drop_mask=None, + in_features = self.in_features, + in_features_pad = self.in_features_pad, + out_features = self.out_features, + out_features_pad = self.out_features_pad, + in_bit = self.in_bit, + w_bit = self.w_bit, + morr_fwhm = self.morr_fwhm, + ) + return output \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py new file mode 100644 index 000000000..a6852e4e5 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py @@ -0,0 +1,733 @@ +import os +# os.environ["TRITON_INTERPRET"] = "1" + +import torch +from torch import Tensor +import triton +import triton.language as tl +import pdb + +from .dtype import TORCH_DTYPE_TO_TRITON +PACKAGE_NAME = "mase_triton" +from ..utils import toeplitz +from .quantize import _input_quantize_fn, _weight_quantize_fn + + +def _get_autotune_configs(): + configs = [] + for _M in [1, 2, 4, 8]: + for _P in [1, 2, 4, 8]: + for _Q in [1, 2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_SIZE_M": _M, + "BLOCK_SIZE_P": _P, + "BLOCK_SIZE_Q": _Q, + # "BLOCK_SIZE_K1": 4, + "BLOCK_SIZE_K2": 1, + }, + num_stages=3, + num_warps=8, + ) + ) + return configs + +@triton.jit +def _mrr_roundtrip_phase_to_tr_func( + x: tl.tensor, + a: tl.constexpr = 0.8, + r: tl.constexpr = 0.9, + intensity: tl.constexpr = False, +): + """ + Applies a round-trip phase correction to the input tensor. + """ + c1 = -2.0 * a * r + c2 = a * a + r * r + c3 = 1.0 + r * r * a * a - a * a - r * r + + cos_x = tl.cos(x) + numerator = cos_x * c1 + c2 + denominator = numerator + c3 + x = numerator / denominator + if not intensity: + x = tl.sqrt(x) + return x + +# @triton.autotune( +# configs = _get_autotune_configs(), +# key=["M", "P", "Q", "K"], +# ) +@triton.autotune( + configs= [ + triton.Config( + { + "BLOCK_SIZE_M": 1, + "BLOCK_SIZE_P": 1, + "BLOCK_SIZE_Q": 1, + # "BLOCK_SIZE_K1": 4, + "BLOCK_SIZE_K2": 1, + }, + num_stages=3, + num_warps=8, + ),], + key=["M", "P", "Q", "K"], +) +@triton.jit +def morr_propagate_kernel( + x_ptr, + w_ptr, + o_ptr, + b_ptr, + M, + P, + Q, + K, + grid_dim_q, + grid_dim_p, + miniblock, + crosstalk_factor, + phase_noise_std, + mrr_a, + mrr_r, + in_bit, + w_bit, + seed, + # stride + stride_wm, stride_wp, stride_wq, stride_wk1, stride_wk2, + stride_xm, stride_xp, stride_xq, stride_xk1, stride_xk2, + stride_bm, stride_bp, stride_bq, stride_bk1, + stride_om, stride_op, stride_oq, stride_ok1, stride_ok2, + finegrain_drop_mask, + ENABLE_PHASE_NOISE: tl.constexpr, + ENABLE_THERMAL_CROSSTALK: tl.constexpr, + TRAINABLE_MORR_BIAS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_P: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K1: tl.constexpr, + BLOCK_SIZE_K2: tl.constexpr, + INPUT_DTYPE: tl.constexpr, +): + + # Program ID for block-based processing + # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block + pid = tl.program_id(axis=0) + pid_m = pid // (grid_dim_q * grid_dim_p) + pid_p = (pid // grid_dim_q) % grid_dim_p + pid_q = pid % grid_dim_q + + # starting element's m, p, q coordinates in the global tensor + start_m = pid_m * BLOCK_SIZE_M + start_p = pid_p * BLOCK_SIZE_P + start_q = pid_q * BLOCK_SIZE_Q + + # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] + offs_wm = tl.arange(0, 1) + offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_wq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_wk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_wk2 = tl.arange(0, BLOCK_SIZE_K1) + + offs_xm = pid_m * BLOCK_SIZE_M + tl.arange(0, 1) + offs_xp = tl.arange(0, 1) + offs_xq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_xk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_xk2 = tl.arange(0, BLOCK_SIZE_K2) + # morr_bias: [1, p, q, 1] + offs_bm = tl.arange(0, 1) + offs_bp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_bq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_bk1 = tl.arange(0, 1) + + w_ptrs = w_ptr + ( + offs_wm[:, None, None, None, None] * stride_wm + + offs_wp[None, :, None, None, None] * stride_wp + + offs_wq[None, None, :, None, None] * stride_wq + + offs_wk1[None, None, None, :, None] * stride_wk1 + + offs_wk2[None, None, None, None, :] * stride_wk2 + ) + x_ptrs = x_ptr + ( + offs_xm[:, None, None, None, None] * stride_xm + + offs_xp[None, :, None, None, None] * stride_xp + + offs_xq[None, None, :, None, None] * stride_xq + + offs_xk1[None, None, None, :, None] * stride_xk1 + + offs_xk2[None, None, None, None, :] * stride_xk2 + ) + b_ptrs = b_ptr + ( + offs_bm[:, None, None, None, None] * stride_bm + + offs_bp[None, :, None, None, None] * stride_bp + + offs_bq[None, None, :, None, None] * stride_bq + + offs_bk1[None, None, None, :, None] * stride_bk1 + ) + + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), dtype=tl.float32) + m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] + p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] + q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] + + for m_local in range(BLOCK_SIZE_M): + m = start_m + m_local + for p_local in range(BLOCK_SIZE_P): + p = start_p + p_local + for q_local in range(BLOCK_SIZE_Q): + q = start_q + q_local + + w_mask = (p < P) & (q < Q) + x_mask = (m < M) & (q < Q) + b_mask = (p < P) & (q < Q) + + w = tl.load(w_ptrs, mask=w_mask, other=0.0) + x = tl.load(x_ptrs, mask=x_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + # TODO: Test Quantization Function + # if in_bit < 16: + # x = _input_quantize_fn(x) + + # ----- build_weight() ----- + # TODO: fix quantization func + # if w_bit < 16: + # w = _weight_quantize_fn(w) + # else: + # w = tl.abs(w) + + w = tl.abs(w).reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] + + if finegrain_drop_mask is not None: + w *= tl.cast(finegrain_drop_mask, tl.float32) + + x = x * x # input_modulator() + # ----- propagate_morr() ----- + + # apply thermal crosstalk noise + if ENABLE_THERMAL_CROSSTALK: + w = w * crosstalk_factor + + # MatMals + # TODO: tl.dot requires 16*16 matrix at least, this is a workaround + x = tl.trans(x) + x = tl.broadcast_to(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K1)) + x = tl.sum(w * x, axis=1) + x = tl.reshape(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + + # apply phase noise + if ENABLE_PHASE_NOISE: + block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 + offs = tl.reshape(block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2) , (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + noise = tl.randn(seed, offs) * phase_noise_std + x = x + noise + + # add trainable bias + b = b.reshape(1, 1) + # pdb.set_trace() + if TRAINABLE_MORR_BIAS: + x = x - b + + # mrr_roundtrip_phase_to_tr + x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) + + # store the value in acc using mask + res = x + condition_mask = (m_indices == m_local) & (p_indices == p_local) & (q_indices == q_local) + res = res[None, None, None, :, :] + acc = tl.where(condition_mask, res, acc) + + # propagate pointer along Q dimension + w_ptrs += stride_wq + x_ptrs += stride_xq + b_ptrs += stride_bq + + # Q loop end + # reset pointer along Q dimension + w_ptrs -= stride_wq * (BLOCK_SIZE_Q) + x_ptrs -= stride_xq * (BLOCK_SIZE_Q) + b_ptrs -= stride_bq * (BLOCK_SIZE_Q) + # propagate pointer along P dimension + w_ptrs += stride_wp + b_ptrs += stride_bp + # x_ptrs += stride_xp # x has P dimension = 1 + + # P loop end + # reset pointer along P dimension + w_ptrs -= stride_wp * (BLOCK_SIZE_P) + # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 + + # propagate pointer along M dimension + # w_ptrs += stride_wp # weight has M dimension = 1 + x_ptrs += stride_xm + + + out = acc.to(INPUT_DTYPE) + out = out.reshape(BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1) # [1, 1, q, k, 1] -> [1, 1, q, k] + + offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) + offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + offs_ok1 = tl.arange(0, BLOCK_SIZE_K1) + # offs_ok2 = tl.arange(0, BLOCK_SIZE_K2) + o_ptrs = o_ptr + ( + stride_om * offs_om[:, None, None, None] + + stride_op * offs_op[None, :, None, None] + + stride_oq * offs_oq[None, None, :, None] + + stride_ok1 * offs_ok1[None, None, None, :] + ) + + m_valid = offs_om[:, None, None, None] < M + p_valid = offs_op[None, :, None, None] < P + q_valid = offs_oq[None, None, :, None] < Q + k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 + o_mask = m_valid & p_valid & q_valid & k_valid + tl.store(o_ptrs, out, mask=o_mask) + +@torch.library.custom_op( + f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", mutates_args={}, +) +def morr_linear_fn( + x: Tensor, + weight: Tensor, + morr_input_bias: Tensor, + morr_output_scale: Tensor, + bias: Tensor | None, + morr_bias: Tensor | None, + grid_dim_x: int, + grid_dim_y: int, + miniblock: int, + enable_thermal_crosstalk: bool, + crosstalk_factor: float | None, + enable_phase_noise: bool, + phase_noise_std: float | None, + trainable_morr_bias: bool, + mrr_a: float, + mrr_r: float, + finegrain_drop_mask: Tensor | None, + in_features: int, + in_features_pad: int, + out_features: int, + out_features_pad: int, + in_bit: int, + w_bit: int, + morr_fwhm: float, + seed: int=42, +) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor]: + + assert x.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {x.dtype}" + assert x.is_contiguous(), "Input tensor must be contiguous" + assert weight.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {weight.dtype}" + + # Handle transformer vs non-transformer inputs + ori_x_shape = x.shape + is_transformer = len(ori_x_shape) == 3 + + if is_transformer: + in_B, in_N, in_D = x.shape + M = in_B * in_N + x = x.reshape(M, in_D) + else: + M = x.shape[0] + + # Get dimensions + M, D = x.shape + P, Q, K = weight.shape + + if in_features_pad > D: + x_pad = torch.zeros(M, in_features_pad - D, device=x.device, dtype=x.dtype) + x = torch.cat([x, x_pad], dim=1) + + assert Q * K == in_features_pad, "input and weight dimension mismatch" + assert P * K == out_features_pad, "weight and output dimension mismatch" + + # Reshape x and weight + x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] + x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] + weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] + + x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] + w_ctx = weight.clone() + + # Allocate output + output = torch.empty((M, P, Q, K, 1), device=x.device, dtype=x.dtype) + # Launch the Triton kernel + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), + ) + morr_propagate_kernel[grid]( + x_ptr = x, + w_ptr = weight, + o_ptr = output, + b_ptr = morr_bias, + M=M, + P=P, + Q=Q, + K=K, + grid_dim_q=grid_dim_x, + grid_dim_p=grid_dim_y, + miniblock=miniblock, + crosstalk_factor=crosstalk_factor, + phase_noise_std=phase_noise_std, + mrr_a=mrr_a, + mrr_r=mrr_r, + in_bit=in_bit, + w_bit=w_bit, + seed=seed, + finegrain_drop_mask=finegrain_drop_mask, + stride_wm=weight.stride(0), + stride_wp=weight.stride(1), + stride_wq=weight.stride(2), + stride_wk1=weight.stride(3), + stride_wk2=weight.stride(4), + stride_xm=x.stride(0), + stride_xp=x.stride(1), + stride_xq=x.stride(2), + stride_xk1=x.stride(3), + stride_xk2=x.stride(4), + stride_bm=morr_bias.stride(0) if morr_bias is not None else 0, + stride_bp=morr_bias.stride(1) if morr_bias is not None else 0, + stride_bq=morr_bias.stride(2) if morr_bias is not None else 0, + stride_bk1=morr_bias.stride(3) if morr_bias is not None else 0, + stride_om=output.stride(0), + stride_op=output.stride(1), + stride_oq=output.stride(2), + stride_ok1=output.stride(3), + stride_ok2=output.stride(4), + ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, + ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, + TRAINABLE_MORR_BIAS = trainable_morr_bias, + INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], + BLOCK_SIZE_K1 = K, + ) + + # ----- build_weight() morr_output_scale part ----- + if w_bit < 16: + morr_output_scale = _weight_quantize_fn(morr_output_scale) + else: + morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Apply output scale + output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] + ctx_x_scalematmul = output.clone() # record x input for matmul + output = morr_output_scale.matmul(output) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + output = output.flatten(1) # [bs, p*k] + + # Trim output if needed + if out_features < out_features_pad: + output = output[:, :out_features] + if bias is not None: + x = x + bias.unsqueeze(0) + # Reshape back for transformer + if is_transformer: + output = output.view(in_B, in_N, out_features) + + # aux_tensor = ( + # torch.abs(w_ctx), # w_morr: weight in propagate_morr matmul + # x_ctx, # x_modulator: x before x^2 + # ) + + return output, seed, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul + + + +def _morr_linear_setup_context(ctx, inputs, output): + """ + Save for backward only what the backward routine really needs. + """ + ( + x, # 0 Tensor – input + weight, # 1 Tensor – learnable weight + morr_input_bias, # 23 Tensor + _, # 3 morr_output_scale + bias, # 4 Tensor | None – bias + morr_bias, # 2 Tensor | None + grid_dim_x, # 5 int + grid_dim_y, # 6 int + miniblock, # 7 int (== K) + enable_thermal_crosstalk,# 8 bool + crosstalk_factor, # 9 float + enable_phase_noise, # 10 bool + phase_noise_std, # 11 float + trainable_morr_bias, # 12 bool + mrr_a, # 13 float + mrr_r, # 14 float + finegrain_drop_mask, # 15 Tensor | None + in_features, # 16 int + in_features_pad, # 17 int + out_features, # 18 int + out_features_pad, # 19 int + in_bit, # 20 int + w_bit, # 21 int + morr_fwhm, # 22 float + seed, + ) = inputs + + output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul = output + # ( + # w_morr, + # x_modulator, + # ) = aux_tensor + + device, dtype = x.device, x.dtype + + # ----- Tensor meta-data that backward needs ----- + # Shapes + M = x.shape[0] if x.dim() == 2 else x.shape[0] * x.shape[1] + P, Q, K = weight.shape + tensor_shape = (M, P, Q, K) + + # mrr_para: para for mrr_roundtrip_phase_to_tr() + c1 = -2.0 * mrr_a * mrr_r + c2 = mrr_a * mrr_a + mrr_r * mrr_r + c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + intensity = True + mrr_para = (c1, c2, c3, c4, intensity) + + # x_morr: x input of matmal in propagate_morr() + x_morr = x_modulator ** 2 # [m, q, k] + x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] + + # x_mrr: x input of mrr_roundtrip_phase_to_tr() + x_mrr = w_morr.matmul(x_morr).squeeze(-1) + if enable_phase_noise and phase_noise_std > 1e-5: + x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) + if trainable_morr_bias: + x_mrr = x_mrr - morr_bias + + tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + + # 3. stash tensors + ctx.save_for_backward( + x, # original input + weight.sign(), # original weight's sign + # TODO: complete self.tensor + bias if bias is not None else torch.tensor([], device=device, dtype=dtype), + morr_output_scale, # original morr_output_scale + x_mrr, # x input for mrr_roundtrip_phase_to_tr() + x_morr, + w_morr, # w input for propagate_morr() matmul + # morr_bias, + x_modulator, # x input for input_modulator() + # morr_input_bias, + x_scalematmul, # x input for morr_output_scale.matmul + tanh_input_bias, + ) + ctx.tensor_shape = tensor_shape + ctx.mrr_para = mrr_para + ctx.in_features = in_features + ctx.in_features_pad = in_features_pad + ctx.out_features = out_features + ctx.out_features_pad = out_features_pad + ctx.morr_fwhm = morr_fwhm + ctx.grid_dim_x = grid_dim_x + ctx.grid_dim_y = grid_dim_y + ctx.w_bit = w_bit + ctx.x_input_shape = x.shape + ctx.device = x.device + ctx.w_input_shape = weight.shape + ctx.morr_fwhm = morr_fwhm + ctx.enable_phase_noise = enable_phase_noise + ctx.phase_noise_std = phase_noise_std + ctx.trainable_morr_bias = trainable_morr_bias + + + +def _morr_linear_backward(ctx, grad_output, *ignored): + """ + Backward pass for morr_linear_fn. + """ + ( + x, + w_input_sign, + bias, + morr_output_scale, + x_mrr, + x_morr, + w_morr, + # morr_bias, + x_modulator, + # morr_input_bias, + x_scalematmul, + tanh_input_bias + + ) = ctx.saved_tensors + + M, P, Q, K = ctx.tensor_shape + c1, c2, c3, c4, intensity = ctx.mrr_para + in_features = ctx.in_features + in_features_pad = ctx.in_features_pad + out_features = ctx.out_features + out_features_pad = ctx.out_features_pad + x_input_shape = ctx.x_input_shape + w_input_shape = ctx.w_input_shape + DEVICE = ctx.device + + # --- calculate intermediate activation on the fly --- + # x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] + + # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + # morr_bias = ctx.morr_fwhm * tanh_input_bias + + # # x_mrr: x input of mrr_roundtrip_phase_to_tr() + # x_mrr = w_morr.matmul(x_morr).squeeze(-1) + # if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) + # if ctx.trainable_morr_bias: + # x_mrr = x_mrr - morr_bias + + # ----- backward prop ----- + # Reshape + grad_out = grad_output.view( + x_input_shape[0], + w_input_shape[1], + w_input_shape[2], + -1 + ) # [M, P, Q, K] + + # ----- Gradient w.r.t input x ----- + if ctx.needs_input_grad[0]: + # 1. reshape + grad_out = grad_out.view(M, -1) # [m, out_features] + + if ctx.needs_input_grad[4] and bias: + grad_bias = grad_out.sum(dim=0) # [out_features] + else: + grad_bias = None + + out_pad = torch.zeros(grad_out.shape[0], out_features_pad-out_features, device = DEVICE) # [m, out_features_pad - out_features] + grad_out = torch.cat([grad_out, out_pad], dim=1) # [m * out_features_pad] = [m, p*k] + + # 2. x=x.flatten(1) + # input: [m, p**k] + grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] + + # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + # dL/d(morr_output_scale) + if ctx.needs_input_grad[3]: + grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] + grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] + grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale + + t = ctx.grid_dim_x // 2 + grad_scale = grad_s.new_zeros((1, 1, t+1, 1)) + + if ctx.grid_dim_x % 2 == 0: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] + elif ctx.grid_dim_x == 1: + grad_scale = grad_s + else: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] + grad_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] + + if ctx.w_bit < 16: + # TODO: backprop of weight_quantizer + raise NotImplementedError("quantization not supported") + else: + grad_scale = None + + # dL/dx + grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] + + # 4. x = mrr_roundtrip_phase_to_tr(x) + denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) + if intensity: + denominator.square_() + numerator = x_mrr.sin().mul_(c4) + else: + numerator = x_mrr.sin().mul_(c4 / 2) + denominator = ( + denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + ) + grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] + + # 5. x += phase_noise and morr_bias + if ctx.needs_input_grad[2]: + grad_inputbias = - grad_x # [bs, p, q, k] + grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] + grad_inputbias = grad_inputbias - tanh_input_bias * tanh_input_bias # [bs, p, q, k] + grad_inputbias = grad_inputbias.sum(dim=(0, -1)) + else: + grad_inputbias = None + + # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] + grad_morr_matmul = grad_x # stash for weight gradient + + # dL/dx + grad_x = torch.matmul(w_morr.transpose(-1, -2), grad_x) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] + grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] + + # 7. input modulator + grad_x = grad_x * 2 * x_modulator # [bs, q, k] + + # 8. input reshape + grad_x = grad_x.view(x_input_shape) + grad_x = grad_x[:, :in_features] + + + + # ----- Gradient w.r.t weight ----- + if ctx.needs_input_grad[1]: + + # 0. gradient after x = weight.matmul(x) + # grad_morr_matmul # [bs, p, q, k, 1] + + # 1. x = weight.matmul(x) + grad_w = torch.matmul(grad_morr_matmul, x_morr.transpose(-1,-2)) # [bs,p,q,k,k] + grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] + + # 2. weight = toeplitz(weight) + k = grad_w.size(-1) + row = torch.arange(k)[:, None] # (k,1) + col = torch.arange(k)[None, :] # (1,k) + idx = (row - col) & (k - 1) if (k & (k-1)) == 0 else (row - col + k) % k + + idx = idx.expand(grad_w.shape).to(DEVICE) + buffer = torch.zeros_like(grad_w, device=DEVICE) + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] + grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) + + # 3. build_weight() weight = self.weight.abs() + grad_w = grad_w * w_input_sign + + return ( + grad_x, # ∂L/∂x + grad_w, # ∂L/∂w + grad_inputbias, # ∂L/∂morr_input_bias + grad_scale, # ∂L/∂morr_output_scale + grad_bias, # ∂L/∂bias + None, None, None, None, None, None, None, None, None, + None, None, None, + None, None, None, None, None, None, None, + None, + ) + + +morr_linear_fn.register_autograd( + _morr_linear_backward, setup_context=_morr_linear_setup_context, +) \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py new file mode 100644 index 000000000..b14e8dbf8 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -0,0 +1,720 @@ +import os +# os.environ["TRITON_INTERPRET"] = "1" + +import torch +from torch import Tensor +import triton +import triton.language as tl +import pdb + +from .dtype import TORCH_DTYPE_TO_TRITON +PACKAGE_NAME = "mase_triton" +from ..utils import toeplitz +from .quantize import _input_quantize_fn, _weight_quantize_fn + + +def _get_autotune_configs(): + configs = [] + for _M in [1, 2, 4, 8]: + for _P in [1, 2, 4, 8]: + for _Q in [1, 2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_SIZE_M": _M, + "BLOCK_SIZE_P": _P, + "BLOCK_SIZE_Q": _Q, + # "BLOCK_SIZE_K1": 4, + "BLOCK_SIZE_K2": 1, + }, + num_stages=3, + num_warps=8, + ) + ) + return configs + +@triton.jit +def _mrr_roundtrip_phase_to_tr_func( + x: tl.tensor, + a: tl.constexpr = 0.8, + r: tl.constexpr = 0.9, + intensity: tl.constexpr = False, +): + """ + Applies a round-trip phase correction to the input tensor. + """ + c1 = -2.0 * a * r + c2 = a * a + r * r + c3 = 1.0 + r * r * a * a - a * a - r * r + + cos_x = tl.cos(x) + numerator = cos_x * c1 + c2 + denominator = numerator + c3 + x = numerator / denominator + if not intensity: + x = tl.sqrt(x) + return x + +# @triton.autotune( +# configs= [ +# triton.Config( +# { +# "BLOCK_SIZE_M": 1, +# "BLOCK_SIZE_P": 1, +# "BLOCK_SIZE_Q": 1, +# # "BLOCK_SIZE_K1": 2, +# "BLOCK_SIZE_K2": 1, +# }, +# num_stages=3, +# num_warps=8, +# ),], +# key=["M", "P", "Q", "K"], +# ) +@triton.autotune( + configs = _get_autotune_configs(), + key=["M", "P", "Q", "K"], +) +@triton.jit +def morr_propagate_kernel( + x_ptr, + w_ptr, + o_ptr, + b_ptr, + M, + P, + Q, + K, + grid_dim_q, + grid_dim_p, + miniblock, + crosstalk_factor, + phase_noise_std, + mrr_a, + mrr_r, + in_bit, + w_bit, + # stride + stride_wm, stride_wp, stride_wq, stride_wk1, stride_wk2, + stride_xm, stride_xp, stride_xq, stride_xk1, stride_xk2, + stride_bm, stride_bp, stride_bq, stride_bk1, + stride_om, stride_op, stride_oq, stride_ok1, stride_ok2, + finegrain_drop_mask, + ENABLE_PHASE_NOISE: tl.constexpr, + ENABLE_THERMAL_CROSSTALK: tl.constexpr, + TRAINABLE_MORR_BIAS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_P: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K1: tl.constexpr, + BLOCK_SIZE_K2: tl.constexpr, + INPUT_DTYPE: tl.constexpr, +): + + # Program ID for block-based processing + # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block + pid = tl.program_id(axis=0) + pid_m = pid // (grid_dim_q * grid_dim_p) + pid_p = (pid // grid_dim_q) % grid_dim_p + pid_q = pid % grid_dim_q + + # starting element's m, p, q coordinates in the global tensor + start_m = pid_m * BLOCK_SIZE_M + start_p = pid_p * BLOCK_SIZE_P + start_q = pid_q * BLOCK_SIZE_Q + + # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] + offs_wm = tl.arange(0, 1) + offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_wq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_wk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_wk2 = tl.arange(0, BLOCK_SIZE_K1) + + offs_xm = pid_m * BLOCK_SIZE_M + tl.arange(0, 1) + offs_xp = tl.arange(0, 1) + offs_xq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_xk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_xk2 = tl.arange(0, BLOCK_SIZE_K2) + # morr_bias: [1, p, q, 1] + offs_bm = tl.arange(0, 1) + offs_bp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_bq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_bk1 = tl.arange(0, 1) + + w_ptrs = w_ptr + ( + offs_wm[:, None, None, None, None] * stride_wm + + offs_wp[None, :, None, None, None] * stride_wp + + offs_wq[None, None, :, None, None] * stride_wq + + offs_wk1[None, None, None, :, None] * stride_wk1 + + offs_wk2[None, None, None, None, :] * stride_wk2 + ) + x_ptrs = x_ptr + ( + offs_xm[:, None, None, None, None] * stride_xm + + offs_xp[None, :, None, None, None] * stride_xp + + offs_xq[None, None, :, None, None] * stride_xq + + offs_xk1[None, None, None, :, None] * stride_xk1 + + offs_xk2[None, None, None, None, :] * stride_xk2 + ) + b_ptrs = b_ptr + ( + offs_bm[:, None, None, None, None] * stride_bm + + offs_bp[None, :, None, None, None] * stride_bp + + offs_bq[None, None, :, None, None] * stride_bq + + offs_bk1[None, None, None, :, None] * stride_bk1 + ) + + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), dtype=tl.float32) + m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] + p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] + q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] + + for m_local in range(BLOCK_SIZE_M): + m = start_m + m_local + for p_local in range(BLOCK_SIZE_P): + p = start_p + p_local + for q_local in range(BLOCK_SIZE_Q): + q = start_q + q_local + + w_mask = (p < P) & (q < Q) + x_mask = (m < M) & (q < Q) + b_mask = (p < P) & (q < Q) + + w = tl.load(w_ptrs, mask=w_mask, other=0.0) + x = tl.load(x_ptrs, mask=x_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + # TODO: Test Quantization Function + # if in_bit < 16: + # x = _input_quantize_fn(x) + + # ----- build_weight() ----- + # TODO: fix quantization func + # if w_bit < 16: + # w = _weight_quantize_fn(w) + # else: + # w = tl.abs(w) + + w = tl.abs(w).reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] + + if finegrain_drop_mask is not None: + w *= tl.cast(finegrain_drop_mask, tl.float32) + + x = x * x # input_modulator() + # ----- propagate_morr() ----- + + # apply thermal crosstalk noise + if ENABLE_THERMAL_CROSSTALK: + w = w * crosstalk_factor + + # MatMals + # TODO: tl.dot requires 16*16 matrix at least, this is a workaround + x = tl.trans(x) + x = tl.broadcast_to(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K1)) + x = tl.sum(w * x, axis=1) + x = tl.reshape(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + + # apply phase noise + if ENABLE_PHASE_NOISE: + noise = tl.zeros_like(x) + tl.randn(x.shape) * phase_noise_std + x = x + noise + + # add trainable bias + b = b.reshape(1, 1) + # pdb.set_trace() + if TRAINABLE_MORR_BIAS: + x = x - b + + # mrr_roundtrip_phase_to_tr + x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) + + # store the value in acc using mask + res = x + condition_mask = (m_indices == m_local) & (p_indices == p_local) & (q_indices == q_local) + res = res[None, None, None, :, :] + acc = tl.where(condition_mask, res, acc) + + # propagate pointer along Q dimension + w_ptrs += stride_wq + x_ptrs += stride_xq + b_ptrs += stride_bq + + # Q loop end + # reset pointer along Q dimension + w_ptrs -= stride_wq * (BLOCK_SIZE_Q) + x_ptrs -= stride_xq * (BLOCK_SIZE_Q) + b_ptrs -= stride_bq * (BLOCK_SIZE_Q) + # propagate pointer along P dimension + w_ptrs += stride_wp + b_ptrs += stride_bp + # x_ptrs += stride_xp # x has P dimension = 1 + + # P loop end + # reset pointer along P dimension + w_ptrs -= stride_wp * (BLOCK_SIZE_P) + # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 + + # propagate pointer along M dimension + # w_ptrs += stride_wp # weight has M dimension = 1 + x_ptrs += stride_xm + + + out = acc.to(INPUT_DTYPE) + out = out.reshape(BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1) # [1, 1, q, k, 1] -> [1, 1, q, k] + + offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) + offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + offs_ok1 = tl.arange(0, BLOCK_SIZE_K1) + # offs_ok2 = tl.arange(0, BLOCK_SIZE_K2) + o_ptrs = o_ptr + ( + stride_om * offs_om[:, None, None, None] + + stride_op * offs_op[None, :, None, None] + + stride_oq * offs_oq[None, None, :, None] + + stride_ok1 * offs_ok1[None, None, None, :] + ) + + m_valid = offs_om[:, None, None, None] < M + p_valid = offs_op[None, :, None, None] < P + q_valid = offs_oq[None, None, :, None] < Q + k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 + o_mask = m_valid & p_valid & q_valid & k_valid + tl.store(o_ptrs, out, mask=o_mask) + +@torch.library.custom_op( + f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", mutates_args={}, +) +def morr_linear_fn( + x: Tensor, + weight: Tensor, + morr_input_bias: Tensor, + morr_output_scale: Tensor, + bias: Tensor | None, + morr_bias: Tensor | None, + grid_dim_x: int, + grid_dim_y: int, + miniblock: int, + enable_thermal_crosstalk: bool, + crosstalk_factor: float | None, + enable_phase_noise: bool, + phase_noise_std: float | None, + trainable_morr_bias: bool, + mrr_a: float, + mrr_r: float, + finegrain_drop_mask: Tensor | None, + in_features: int, + in_features_pad: int, + out_features: int, + out_features_pad: int, + in_bit: int, + w_bit: int, + morr_fwhm: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + + assert x.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {x.dtype}" + assert x.is_contiguous(), "Input tensor must be contiguous" + assert weight.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {weight.dtype}" + + # Handle transformer vs non-transformer inputs + ori_x_shape = x.shape + is_transformer = len(ori_x_shape) == 3 + + if is_transformer: + in_B, in_N, in_D = x.shape + M = in_B * in_N + x = x.reshape(M, in_D) + else: + M = x.shape[0] + + # Get dimensions + M, D = x.shape + P, Q, K = weight.shape + + if in_features_pad > D: + x_pad = torch.zeros(M, in_features_pad - D, device=x.device, dtype=x.dtype) + x = torch.cat([x, x_pad], dim=1) + + assert Q * K == in_features_pad, "input and weight dimension mismatch" + assert P * K == out_features_pad, "weight and output dimension mismatch" + + # Reshape x and weight + x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] + x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] + weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] + + x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] + w_ctx = weight.clone() + + # Allocate output + output = torch.empty((M, P, Q, K, 1), device=x.device, dtype=x.dtype) + # Launch the Triton kernel + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), + ) + morr_propagate_kernel[grid]( + x_ptr = x, + w_ptr = weight, + o_ptr = output, + b_ptr = morr_bias, + M=M, + P=P, + Q=Q, + K=K, + grid_dim_q=grid_dim_x, + grid_dim_p=grid_dim_y, + miniblock=miniblock, + crosstalk_factor=crosstalk_factor, + phase_noise_std=phase_noise_std, + mrr_a=mrr_a, + mrr_r=mrr_r, + in_bit=in_bit, + w_bit=w_bit, + finegrain_drop_mask=finegrain_drop_mask, + stride_wm=weight.stride(0), + stride_wp=weight.stride(1), + stride_wq=weight.stride(2), + stride_wk1=weight.stride(3), + stride_wk2=weight.stride(4), + stride_xm=x.stride(0), + stride_xp=x.stride(1), + stride_xq=x.stride(2), + stride_xk1=x.stride(3), + stride_xk2=x.stride(4), + stride_bm=morr_bias.stride(0) if morr_bias is not None else 0, + stride_bp=morr_bias.stride(1) if morr_bias is not None else 0, + stride_bq=morr_bias.stride(2) if morr_bias is not None else 0, + stride_bk1=morr_bias.stride(3) if morr_bias is not None else 0, + stride_om=output.stride(0), + stride_op=output.stride(1), + stride_oq=output.stride(2), + stride_ok1=output.stride(3), + stride_ok2=output.stride(4), + ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, + ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1, + TRAINABLE_MORR_BIAS = trainable_morr_bias, + INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], + BLOCK_SIZE_K1=K, + ) + + # ----- build_weight() morr_output_scale part ----- + if w_bit < 16: + morr_output_scale = _weight_quantize_fn(morr_output_scale) + else: + morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Apply output scale + output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] + ctx_x_scalematmul = output.clone() # record x input for matmul + output = morr_output_scale.matmul(output) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + output = output.flatten(1) # [bs, p*k] + + # Trim output if needed + if out_features < out_features_pad: + output = output[:, :out_features] + if bias is not None: + x = x + bias.unsqueeze(0) + # Reshape back for transformer + if is_transformer: + output = output.view(in_B, in_N, out_features) + + # aux_tensor = ( + # torch.abs(w_ctx), # w_morr: weight in propagate_morr matmul + # x_ctx, # x_modulator: x before x^2 + # ) + + return output, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul + + + +def _morr_linear_setup_context(ctx, inputs, output): + """ + Save for backward only what the backward routine really needs. + """ + ( + x, # 0 Tensor – input + weight, # 1 Tensor – learnable weight + morr_input_bias, # 23 Tensor + _, # 3 morr_output_scale + bias, # 4 Tensor | None – bias + morr_bias, # 2 Tensor | None + grid_dim_x, # 5 int + grid_dim_y, # 6 int + miniblock, # 7 int (== K) + enable_thermal_crosstalk,# 8 bool + crosstalk_factor, # 9 float + enable_phase_noise, # 10 bool + phase_noise_std, # 11 float + trainable_morr_bias, # 12 bool + mrr_a, # 13 float + mrr_r, # 14 float + finegrain_drop_mask, # 15 Tensor | None + in_features, # 16 int + in_features_pad, # 17 int + out_features, # 18 int + out_features_pad, # 19 int + in_bit, # 20 int + w_bit, # 21 int + morr_fwhm, # 22 float + ) = inputs + + output, w_morr, x_modulator, morr_output_scale, x_scalematmul = output + # ( + # w_morr, + # x_modulator, + # ) = aux_tensor + + device, dtype = x.device, x.dtype + + # ----- Tensor meta-data that backward needs ----- + # Shapes + M = x.shape[0] if x.dim() == 2 else x.shape[0] * x.shape[1] + P, Q, K = weight.shape + tensor_shape = (M, P, Q, K) + + # mrr_para: para for mrr_roundtrip_phase_to_tr() + c1 = -2.0 * mrr_a * mrr_r + c2 = mrr_a * mrr_a + mrr_r * mrr_r + c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + intensity = True + mrr_para = (c1, c2, c3, c4, intensity) + + # x_morr: x input of matmal in propagate_morr() + x_morr = x_modulator ** 2 # [m, q, k] + x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] + + # x_mrr: x input of mrr_roundtrip_phase_to_tr() + x_mrr = w_morr.matmul(x_morr).squeeze(-1) + if enable_phase_noise and phase_noise_std > 1e-5: + x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) + if trainable_morr_bias: + x_mrr = x_mrr - morr_bias + + # 3. stash tensors + ctx.save_for_backward( + # x, # original input + weight.sign(), # original weight's sign + # TODO: complete self.tensor + bias if bias is not None else torch.tensor([], device=device, dtype=dtype), + morr_output_scale, # original morr_output_scale + # x_mrr, # x input for mrr_roundtrip_phase_to_tr() + w_morr, # w input for propagate_morr() matmul + x_modulator, # x input for input_modulator() + morr_input_bias, + x_scalematmul, # x input for morr_output_scale.matmul + ) + ctx.tensor_shape = tensor_shape + ctx.mrr_para = mrr_para + ctx.in_features = in_features + ctx.in_features_pad = in_features_pad + ctx.out_features = out_features + ctx.out_features_pad = out_features_pad + ctx.morr_fwhm = morr_fwhm + ctx.grid_dim_x = grid_dim_x + ctx.grid_dim_y = grid_dim_y + ctx.w_bit = w_bit + ctx.x_input_shape = x.shape + ctx.device = x.device + ctx.w_input_shape = weight.shape + ctx.morr_fwhm = morr_fwhm + ctx.enable_phase_noise = enable_phase_noise + ctx.phase_noise_std = phase_noise_std + ctx.trainable_morr_bias = trainable_morr_bias + + + +def _morr_linear_backward(ctx, grad_output, *ignored): + """ + Backward pass for morr_linear_fn. + """ + ( + # x, + w_input_sign, + bias, + morr_output_scale, + # x_mrr, + w_morr, + x_modulator, + morr_input_bias, + x_scalematmul, + ) = ctx.saved_tensors + + M, P, Q, K = ctx.tensor_shape + c1, c2, c3, c4, intensity = ctx.mrr_para + in_features = ctx.in_features + in_features_pad = ctx.in_features_pad + out_features = ctx.out_features + out_features_pad = ctx.out_features_pad + x_input_shape = ctx.x_input_shape + w_input_shape = ctx.w_input_shape + DEVICE = ctx.device + + # --- calculate intermediate activation on the fly --- + x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] + + tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + morr_bias = ctx.morr_fwhm * tanh_input_bias + + # x_mrr: x input of mrr_roundtrip_phase_to_tr() + x_mrr = w_morr.matmul(x_morr).squeeze(-1) + if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) + if ctx.trainable_morr_bias: + x_mrr = x_mrr - morr_bias + + + + # ----- backward prop ----- + # Reshape + grad_out = grad_output.view( + x_input_shape[0], + w_input_shape[1], + w_input_shape[2], + -1 + ) # [M, P, Q, K] + + # ----- Gradient w.r.t input x ----- + if ctx.needs_input_grad[0]: + # 1. reshape + grad_out = grad_out.view(M, -1) # [m, out_features] + + if ctx.needs_input_grad[4] and bias: + grad_bias = grad_out.sum(dim=0) # [out_features] + else: + grad_bias = None + + out_pad = torch.zeros(grad_out.shape[0], out_features_pad-out_features, device = DEVICE) # [m, out_features_pad - out_features] + grad_out = torch.cat([grad_out, out_pad], dim=1) # [m * out_features_pad] = [m, p*k] + + # 2. x=x.flatten(1) + # input: [m, p**k] + grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] + + # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + # dL/d(morr_output_scale) + if ctx.needs_input_grad[3]: + grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] + grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] + grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale + + t = ctx.grid_dim_x // 2 + grad_scale = grad_s.new_zeros((1, 1, t+1, 1)) + + if ctx.grid_dim_x % 2 == 0: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] + elif ctx.grid_dim_x == 1: + grad_scale = grad_s + else: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] + grad_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] + + if ctx.w_bit < 16: + # TODO: backprop of weight_quantizer + raise NotImplementedError("quantization not supported") + else: + grad_scale = None + + # dL/dx + grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] + + # 4. x = mrr_roundtrip_phase_to_tr(x) + denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) + if intensity: + denominator.square_() + numerator = x_mrr.sin().mul_(c4) + else: + numerator = x_mrr.sin().mul_(c4 / 2) + denominator = ( + denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + ) + grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] + + # 5. x += phase_noise and morr_bias + if ctx.needs_input_grad[2]: + grad_inputbias = - grad_x # [bs, p, q, k] + grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] + grad_inputbias = grad_inputbias - tanh_input_bias * tanh_input_bias # [bs, p, q, k] + grad_inputbias = grad_inputbias.sum(dim=(0, -1)) + else: + grad_inputbias = None + + # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] + grad_morr_matmul = grad_x # stash for weight gradient + + # dL/dx + grad_x = torch.matmul(w_morr.transpose(-1, -2), grad_x) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] + grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] + + # 7. input modulator + grad_x = grad_x * 2 * x_modulator # [bs, q, k] + + # 8. input reshape + grad_x = grad_x.view(x_input_shape) + grad_x = grad_x[:, :in_features] + + + + # ----- Gradient w.r.t weight ----- + if ctx.needs_input_grad[1]: + + # 0. gradient after x = weight.matmul(x) + # grad_morr_matmul # [bs, p, q, k, 1] + + # 1. x = weight.matmul(x) + grad_w = torch.matmul(grad_morr_matmul, x_morr.transpose(-1,-2)) # [bs,p,q,k,k] + grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] + + # 2. weight = toeplitz(weight) + k = grad_w.size(-1) + row = torch.arange(k)[:, None] # (k,1) + col = torch.arange(k)[None, :] # (1,k) + idx = (row - col) & (k - 1) if (k & (k-1)) == 0 else (row - col + k) % k + + idx = idx.expand(grad_w.shape).to(DEVICE) + buffer = torch.zeros_like(grad_w, device=DEVICE) + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] + grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) + + # 3. build_weight() weight = self.weight.abs() + grad_w = grad_w * w_input_sign + + return ( + grad_x, # ∂L/∂x + grad_w, # ∂L/∂w + grad_inputbias, # ∂L/∂morr_input_bias + grad_scale, # ∂L/∂morr_output_scale + grad_bias, # ∂L/∂bias + None, None, None, None, None, None, None, None, None, + None, None, None, + None, None, None, None, None, None, None, + None, + ) + + +morr_linear_fn.register_autograd( + _morr_linear_backward, setup_context=_morr_linear_setup_context, +) \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_mem.py new file mode 100644 index 000000000..221024364 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/morr_linear_mem.py @@ -0,0 +1,473 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" + +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from ..modules.base_layer import ONNBaseLayer +from .morr_linear_kernel_mem import morr_linear_fn + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantLinear"] + + +class TritonMemMORRLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config={}, + device: Device = torch.device("cpu"), + ) -> None: + super(TritonMemMORRLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + output, *_ = morr_linear_fn( + x, + self.weight, + morr_input_bias = self.morr_input_bias, + morr_output_scale = self.morr_output_scale, + bias = None, + morr_bias = self.morr_bias, + grid_dim_x = self.grid_dim_x, + grid_dim_y = self.grid_dim_y, + miniblock = self.miniblock, + enable_thermal_crosstalk=self.enable_thermal_crosstalk, + crosstalk_factor=None if not self.enable_thermal_crosstalk else self.crosstalk_factor, + enable_phase_noise=self.enable_phase_noise, + phase_noise_std=None if not self.enable_phase_noise else self.phase_noise_std, + trainable_morr_bias=self.trainable_morr_bias, + mrr_a=self.mrr_a, + mrr_r=self.mrr_r, + finegrain_drop_mask=None, + in_features = self.in_features, + in_features_pad = self.in_features_pad, + out_features = self.out_features, + out_features_pad = self.out_features_pad, + in_bit = self.in_bit, + w_bit = self.w_bit, + morr_fwhm = self.morr_fwhm, + ) + return output \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/quantize.py b/src/chop/nn/optical/triton_modules/quantize.py new file mode 100644 index 000000000..738b1a925 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/quantize.py @@ -0,0 +1,104 @@ +import torch +from torch import Tensor +import triton +import triton.language as tl + + +@triton.jit +def uniform_quantize(x: tl.tensor, k, gradient_clip=False): + if k == 32: + out = input + elif k == 1: + out = tl.where(x >= 0, 1.0, -1.0) + else: + n = float(2 ** k - 1) + out = tl.extra.cuda.libdevice.rint(x * n) / n + + return out + + +def uniform_quantize_new(x: tl.tensor, k, scale, zero_point, gradient_clip=False): + if k == 32: + out = x + elif k == 1: + out = tl.where(x > 0, 1.0, tl.where(x < 0, -1.0, 0.0)) + else: + n = float(2 ** k - 1) + out = tl.div(x, scale) + out = out + zero_point + out = tl.extra.cuda.libdevice.rint(out) + out = tl.clamp(out, 0.0, n) + out = out - zero_point + out = out * scale + return out + + +@triton.jit +def _input_quantize_fn( + x: tl.tensor, quant_ratio, training, in_bit, alg, # self.training +): + # init + if alg == "dorefa": + uniform_q = uniform_quantize(k=in_bit) + elif alg == "normal": + uniform_q = uniform_quantize_new(k=in_bit) + scale = None + zero_point = None + # TODO: fix for triton + if 1 <= in_bit <= 8: # observer does not support higher than 8-bit + obs = torch.quantization.observer.MovingAverageMinMaxObserver( + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2 ** in_bit - 1, + ) + else: + obs = None + + if quant_ratio > 1.0 and training: + rand_vals = tl.random(x.shape) + quant_noise_mask = tl.where(rand_vals > quant_ratio, 1, 0) + else: + quant_noise_mask = None + + if in_bit == 32: + input_q = x + elif in_bit == 1: + x = tl.clamp(x, 0.0, 1.0) + input_q = (uniform_q(x - 0.5) + 1) / 2 + if quant_noise_mask is not None: + noise = input_q - x + masked_noise = tl.where(quant_noise_mask, 0.0, noise) + input_q = x + masked_noise + else: + ### dorefa-style clamp for input data + if alg == "dorefa": + x = tl.clamp(x, 0.0, 1.0) + input_q = uniform_q(x) + elif alg == "normal": + if obs is not None: + if training: + obs(x) + scale, zero_point = obs.calculate_qparams() + # convert scale and zero_point type from qint8 + scale = scale.to(x.dtype) + zero_point = zero_point.to(x.dtype) + input_q = uniform_q(x, scale, zero_point) + else: + input_q = x # if no observer (in_bit > 8), do not quantize + else: + # raise NotImplementedError + input_q = tl.zeros_like(x) + # add noise + if quant_noise_mask is not None: + noise = input_q - x + masked_noise = tl.where(quant_noise_mask, 0.0, noise) + input_q = x + masked_noise + + return input_q + + +def _weight_quantize_fn(w: tl.tensor): + pass diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index 73806b2ac..c35cc67ee 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -19,6 +19,13 @@ LlamaAttention, ) +from transformers.models.bert.modeling_bert import ( + BertSdpaSelfAttention, + BertSelfAttention, +) + +from transformers.models.bert.configuration_bert import BertConfig + roberta_prefix_map = { RobertaSdpaSelfAttention: "roberta_self_attention", RobertaSelfAttention: "roberta_self_attention", @@ -32,6 +39,11 @@ LlamaAttention: "llama_self_attention", } +bert_prefix_map = { + BertSdpaSelfAttention: "bert_self_attention", + BertSelfAttention: "bert_self_attention", +} + def check_module_instance(module, prefix_map): """ @@ -212,10 +224,28 @@ def instantiate_llama_module( ) return llama_module +def instantiate_bert_module( + module, postfix, prefix, module_map, module_args, +): + bert_cls = module_map[f"{prefix}_{postfix}"] + + bert_module = bert_cls( + config=BertConfig( + hidden_size=module.query.in_features, + num_attention_heads=module.num_attention_heads, + attention_head_size=module.attention_head_size, + attention_probs_dropout_prob=module.dropout_prob, + is_decoder=False, + ), + morr_config=module_args, + ) + return bert_module + def instantiate_module(module, postfix, module_map, additional_module_args): is_roberta, roberta_layer_name = check_module_instance(module, roberta_prefix_map) is_llama, llama_layer_name = check_module_instance(module, llama_prefix_map) + is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) module_args = additional_module_args["config"] network_args = additional_module_args.get("network_config", None) @@ -236,6 +266,10 @@ def instantiate_module(module, postfix, module_map, additional_module_args): module = instantiate_llama_module( module, postfix, llama_layer_name, module_map, module_args, network_args ) + elif is_bert: + module = instantiate_bert_module( + module, postfix, bert_layer_name, module_map, module_args, + ) else: raise ValueError(f"{module} is not supported.") return module diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 635d3d8f9..9a7c4be34 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -12,8 +12,12 @@ def replace_by_name_optical(network, module_name: str, new_module, target_name): original = get_module_by_name(network, module_name) if target_name == "linear_morr_full": updated_module = weight_replacement_full_linear_optical(original, new_module) - elif target_name == "linear_morr": + elif target_name in ["linear_morr", "linear_morr_triton", "linear_morr_triton_mem"]: updated_module = weight_replacement_circulant_linear_optical(original, new_module) + elif target_name in ["bert_self_attention_morr"]: + updated_module = weight_replacement_circulant_bert_attention(original, new_module) + else: + raise NotImplementedError(f"weight replacement function for the optical module {target_name} not implemented") network = set_module_by_name(network, module_name, updated_module) @@ -158,3 +162,14 @@ def weight_replacement_conv2d_optical(x, y): # Done. At this point, y.weight and y.bias (if present) have been overwritten # with a simple block-circulant approximation of x's parameters. return y + +def weight_replacement_circulant_bert_attention(original, new_module): + for name in ("query", "key", "value"): + src_linear = getattr(original, name) + dst_linear = getattr(new_module, name) + with torch.no_grad(): + dst_linear.weight.copy_(src_linear.weight) + if src_linear.bias is not None: + dst_linear.bias.copy_(src_linear.bias) + + return new_module \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index d8981bb10..5fd5ec5eb 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -1,4 +1,5 @@ import torch +from transformers.models.bert.modeling_bert import BertSdpaSelfAttention from chop.nn.optical.modules import optical_module_map from chop.passes.module.module_modify_helper import instantiate_module @@ -24,6 +25,8 @@ def optical_transform_by_type(network, pass_args): module = torch.nn.Linear elif type_name == "conv2d": module = torch.nn.Conv2d + elif isinstance(m, BertSdpaSelfAttention): + type_name = "bert_self_attention" else: raise ValueError(f"{type_name} is not supported!") config = config["config"] @@ -83,6 +86,8 @@ def optical_transform_by_regex_name(network, pass_args): type_name = "linear" elif isinstance(m, torch.nn.Conv2d): type_name = "conv2d" + elif isinstance(m, BertSdpaSelfAttention): + type_name = "bert_self_attention" else: raise ValueError(f"{type_name} is not supported!") From dddc817df25406255d458b0ceef729dad3e136a7 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 11 May 2025 23:58:39 +0100 Subject: [PATCH 27/38] complete transform pass --- .../passes/module/module_modify_helper.py | 28 ++- .../optical/module_transform_helper.py | 166 +++++++++++++++++- .../module/transforms/optical/optical.py | 14 +- 3 files changed, 200 insertions(+), 8 deletions(-) diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index 16d233b26..129438275 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -20,8 +20,8 @@ ) from transformers.models.bert.modeling_bert import ( - BertSelfAttention, BertSdpaSelfAttention, + BertSelfAttention, ) from transformers.models.bert.configuration_bert import BertConfig @@ -39,6 +39,11 @@ LlamaAttention: "llama_self_attention", } +bert_prefix_map = { + BertSdpaSelfAttention: "bert_self_attention", + BertSelfAttention: "bert_self_attention", +} + def check_module_instance(module, prefix_map): """ @@ -219,11 +224,28 @@ def instantiate_llama_module( ) return llama_module +def instantiate_bert_module( + module, postfix, prefix, module_map, module_args, +): + bert_cls = module_map[f"{prefix}_{postfix}"] + + bert_module = bert_cls( + config=BertConfig( + hidden_size=module.query.in_features, + num_attention_heads=module.num_attention_heads, + attention_head_size=module.attention_head_size, + attention_probs_dropout_prob=module.dropout_prob, + is_decoder=False, + ), + morr_config=module_args, + ) + return bert_module def instantiate_module(module, postfix, module_map, additional_module_args): is_roberta, roberta_layer_name = check_module_instance(module, roberta_prefix_map) is_llama, llama_layer_name = check_module_instance(module, llama_prefix_map) is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) + is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) module_args = additional_module_args["config"] network_args = additional_module_args.get("network_config", None) @@ -244,6 +266,10 @@ def instantiate_module(module, postfix, module_map, additional_module_args): module = instantiate_llama_module( module, postfix, llama_layer_name, module_map, module_args, network_args ) + elif is_bert: + module = instantiate_bert_module( + module, postfix, bert_layer_name, module_map, module_args, + ) else: raise ValueError(f"{module} is not supported.") return module diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 9a7c4be34..bb18e7c30 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -1,11 +1,58 @@ import torch import torch.nn as nn import math +from functools import reduce, partial +from copy import deepcopy +import logging +import inspect + from chop.passes.module.module_modify_helper import ( get_module_by_name, set_module_by_name, ) +from chop.passes.module.state_dict_map import SPECIAL_CONVERT_PATTERNS + +from transformers.models.roberta.modeling_roberta import ( + RobertaSelfAttention, + RobertaSdpaSelfAttention, + RobertaClassificationHead, + RobertaIntermediate, + RobertaOutput, + RobertaSelfOutput, +) + +from transformers.models.llama.modeling_llama import ( + LlamaAttention, +) + +from transformers.models.bert.modeling_bert import ( + BertSdpaSelfAttention, + BertSelfAttention, +) + +from transformers.models.bert.configuration_bert import BertConfig + + +bert_prefix_map = { + BertSdpaSelfAttention: "bert_self_attention", + BertSelfAttention: "bert_self_attention", +} + +def check_module_instance(module, prefix_map): + """ + Check if the given module is an instance of any class in the prefix_map. If it is, return the corresponding prefix. + Args: + module (object): The module to check. + prefix_map (dict): A dictionary where keys are classes and values are prefixes. + Returns: + tuple: A tuple containing a boolean indicating if the module is an instance of any class in the prefix_map, + and the corresponding prefix if it is an instance, otherwise None. + """ + for cls, name in prefix_map.items(): + if isinstance(module, cls): + return True, name + return False, None def replace_by_name_optical(network, module_name: str, new_module, target_name): @@ -172,4 +219,121 @@ def weight_replacement_circulant_bert_attention(original, new_module): if src_linear.bias is not None: dst_linear.bias.copy_(src_linear.bias) - return new_module \ No newline at end of file + return new_module + + +def instantiate_optical_module(module, postfix, module_map, additional_module_args): + is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) + + module_args = additional_module_args["config"] + additional_args = additional_module_args["additional"] + network_args = additional_module_args.get("network_config", None) + + if isinstance(module, torch.nn.Linear): + module = instantiate_optical_linear(module, postfix, module_map, module_args, additional_args) + elif isinstance(module, torch.nn.Conv2d): + module = instantiate_optical_conv2d(module, postfix, module_map, module_args) + elif is_bert: + module = instantiate_optical_bert_module( + module, postfix, bert_layer_name, module_map, module_args, + ) + else: + raise ValueError(f"{module} is not supported.") + return module + +def instantiate_optical_linear(module, postfix, module_map, additional_module_args, additional_args): + linear_cls = module_map[f"linear_{postfix}"] + has_bias = not (module.bias is None) + + # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. + # Need to handle this better + if "config" in inspect.signature(linear_cls.__init__).parameters: + linear = linear_cls( + in_features=module.in_features, + out_features=module.out_features, + bias=has_bias, + config=additional_module_args, + ) + else: + linear = linear_cls( + in_features=module.in_features, + out_features=module.out_features, + bias=has_bias, + **additional_module_args, + ) + + # extra handling for morr optical module + enable_thermal_crosstalk = additional_args.get("thermal_crosstalk", False) + enable_phase_noise = additional_args.get("phase_noise", False) + enable_trainable_morr_scale = additional_args.get("trainable_morr_scale", False) + enable_trainable_morr_bias = additional_args.get("trainable_morr_bias", False) + + if enable_thermal_crosstalk: + linear.enable_crosstalk() + linear.set_crosstalk_coupling_matrix( + additional_args.get("coupling_factor", 0.04), + additional_args.get("drop_perc", 0.0), + ) + + if enable_phase_noise: + linear.enable_phase_variation() + phase_noise_std = additional_args.get("phase_noise_std", 0.04) + linear.set_phase_variation(phase_noise_std) + + if enable_trainable_morr_scale: + linear.enable_trainable_morr_scale() + + if enable_trainable_morr_bias: + linear.enable_trainable_morr_bias() + + return linear + +def instantiate_optical_conv2d(module, postfix, module_map, additional_module_args): + conv2d_cls = module_map[f"conv2d_{postfix}"] + has_bias = not (module.bias is None) + # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. + # Need to handle this better + if "config" in inspect.signature(conv2d_cls.__init__).parameters: + conv2d = conv2d_cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=has_bias, + padding_mode=module.padding_mode, + config=additional_module_args, + ) + else: + conv2d = conv2d_cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=has_bias, + padding_mode=module.padding_mode, + **additional_module_args, + ) + return conv2d + +def instantiate_optical_bert_module( + module, postfix, prefix, module_map, module_args, +): + bert_cls = module_map[f"{prefix}_{postfix}"] + + bert_module = bert_cls( + config=BertConfig( + hidden_size=module.query.in_features, + num_attention_heads=module.num_attention_heads, + attention_head_size=module.attention_head_size, + attention_probs_dropout_prob=module.dropout_prob, + is_decoder=False, + ), + morr_config=module_args, + ) + return bert_module \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index 5fd5ec5eb..6f74a82d5 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -5,6 +5,7 @@ from chop.passes.module.module_modify_helper import instantiate_module from chop.passes.module.transforms.optical.module_transform_helper import ( replace_by_name_optical, + instantiate_optical_module, ) from ...state_dict_map import match_a_pattern, check_is_huggingface_model @@ -34,7 +35,7 @@ def optical_transform_by_type(network, pass_args): for n, m in n_m.items(): if isinstance(m, module): print(f"processing {n}") - new_m = instantiate_module( + new_m = instantiate_optical_module( m, postfix, optical_module_map, {"config": config} ) network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) @@ -53,7 +54,7 @@ def optical_transform_by_name(network, pass_args): optical_config = optical_config["config"] postfix = optical_config.pop("name") - new_m = instantiate_module( + new_m = instantiate_optical_module( m, postfix, optical_module_map, {"config": optical_config} ) network = replace_by_name_optical(network, n, new_m) @@ -74,12 +75,13 @@ def optical_transform_by_regex_name(network, pass_args): print(f"processing {n}") optical_config = pass_args[matched_pattern]["config"] + optial_additional_config = pass_args[matched_pattern]["additional"] postfix = optical_config["name"] additional_module_args = ( - {"config": optical_config, "network_config": network.config} - if is_huggingface_model - else {"config": optical_config} + {"config": optical_config, "additional": optial_additional_config} + # if is_huggingface_model + # else {"config": optical_config} ) if isinstance(m, torch.nn.Linear): @@ -91,7 +93,7 @@ def optical_transform_by_regex_name(network, pass_args): else: raise ValueError(f"{type_name} is not supported!") - new_m = instantiate_module( + new_m = instantiate_optical_module( m, postfix, optical_module_map, additional_module_args ) network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) From 9461079d5fc3c023360c4b329d1b32a0598e07f6 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 15 May 2025 22:05:53 +0100 Subject: [PATCH 28/38] added quantization in custom kernel --- .../nn/optical/triton_modules/morr_linear.py | 8 +- .../triton_modules/morr_linear_kernel.py | 296 ++++++++++++------ .../triton_modules/morr_linear_kernel_mem.py | 195 +++++++----- .../optical/triton_modules/morr_linear_mem.py | 2 +- .../optical/module_transform_helper.py | 13 +- .../module/transforms/optical/optical.py | 7 +- 6 files changed, 356 insertions(+), 165 deletions(-) diff --git a/src/chop/nn/optical/triton_modules/morr_linear.py b/src/chop/nn/optical/triton_modules/morr_linear.py index 5c8aa4986..a27c9bd37 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear.py +++ b/src/chop/nn/optical/triton_modules/morr_linear.py @@ -450,7 +450,8 @@ def forward(self, x: Tensor) -> Tensor: morr_input_bias = self.morr_input_bias, morr_output_scale = self.morr_output_scale, bias = None, - morr_bias = self.morr_bias, + morr_input_scale = self.morr_input_scale, + morr_bias = self.morr_bias.detach(), grid_dim_x = self.grid_dim_x, grid_dim_y = self.grid_dim_y, miniblock = self.miniblock, @@ -469,5 +470,10 @@ def forward(self, x: Tensor) -> Tensor: in_bit = self.in_bit, w_bit = self.w_bit, morr_fwhm = self.morr_fwhm, + sigma_weight=self.sigma_weight, + trainable_morr_scale=self.trainable_morr_scale, # bool + morr_scale=self.morr_scale, + weight_quant_gain=self.weight_quant_gain, + seed = 42, ) return output \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py index a6852e4e5..e40af6f47 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py @@ -9,7 +9,11 @@ from .dtype import TORCH_DTYPE_TO_TRITON PACKAGE_NAME = "mase_triton" -from ..utils import toeplitz +from ..utils import ( + toeplitz, + input_quantize_fn, + weight_quantize_fn, +) from .quantize import _input_quantize_fn, _weight_quantize_fn @@ -56,22 +60,22 @@ def _mrr_roundtrip_phase_to_tr_func( return x # @triton.autotune( -# configs = _get_autotune_configs(), +# configs= [ +# triton.Config( +# { +# "BLOCK_SIZE_M": 1, +# "BLOCK_SIZE_P": 1, +# "BLOCK_SIZE_Q": 1, +# # "BLOCK_SIZE_K1": 4, +# "BLOCK_SIZE_K2": 1, +# }, +# num_stages=3, +# num_warps=8, +# ),], # key=["M", "P", "Q", "K"], # ) @triton.autotune( - configs= [ - triton.Config( - { - "BLOCK_SIZE_M": 1, - "BLOCK_SIZE_P": 1, - "BLOCK_SIZE_Q": 1, - # "BLOCK_SIZE_K1": 4, - "BLOCK_SIZE_K2": 1, - }, - num_stages=3, - num_warps=8, - ),], + configs = _get_autotune_configs(), key=["M", "P", "Q", "K"], ) @triton.jit @@ -114,9 +118,14 @@ def morr_propagate_kernel( # Program ID for block-based processing # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block pid = tl.program_id(axis=0) - pid_m = pid // (grid_dim_q * grid_dim_p) - pid_p = (pid // grid_dim_q) % grid_dim_p - pid_q = pid % grid_dim_q + # number of blocks (each program needs to handle) along M, P, Q dimension + pnum_m = grid_dim_p * grid_dim_q + pnum_p = grid_dim_p // BLOCK_SIZE_P + pnum_q = grid_dim_q // BLOCK_SIZE_Q + # block dimension of current program + pid_m = pid // (pnum_q * pnum_p) + pid_p = (pid // pnum_q) % pnum_p + pid_q = pid % pnum_q # starting element's m, p, q coordinates in the global tensor start_m = pid_m * BLOCK_SIZE_M @@ -129,7 +138,7 @@ def morr_propagate_kernel( offs_wq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) offs_wk1 = tl.arange(0, BLOCK_SIZE_K1) offs_wk2 = tl.arange(0, BLOCK_SIZE_K1) - + # x [m, 1, q, k, 1] offs_xm = pid_m * BLOCK_SIZE_M + tl.arange(0, 1) offs_xp = tl.arange(0, 1) offs_xq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) @@ -183,23 +192,9 @@ def morr_propagate_kernel( x = tl.load(x_ptrs, mask=x_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) - # TODO: Test Quantization Function - # if in_bit < 16: - # x = _input_quantize_fn(x) - - # ----- build_weight() ----- - # TODO: fix quantization func - # if w_bit < 16: - # w = _weight_quantize_fn(w) - # else: - # w = tl.abs(w) - - w = tl.abs(w).reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] - if finegrain_drop_mask is not None: - w *= tl.cast(finegrain_drop_mask, tl.float32) - x = x * x # input_modulator() # ----- propagate_morr() ----- @@ -254,16 +249,16 @@ def morr_propagate_kernel( # P loop end # reset pointer along P dimension w_ptrs -= stride_wp * (BLOCK_SIZE_P) + b_ptrs -= stride_bp * (BLOCK_SIZE_P) # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 # propagate pointer along M dimension # w_ptrs += stride_wp # weight has M dimension = 1 x_ptrs += stride_xm - out = acc.to(INPUT_DTYPE) out = out.reshape(BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1) # [1, 1, q, k, 1] -> [1, 1, q, k] - + offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) @@ -282,9 +277,10 @@ def morr_propagate_kernel( k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 o_mask = m_valid & p_valid & q_valid & k_valid tl.store(o_ptrs, out, mask=o_mask) + # pdb.set_trace() @torch.library.custom_op( - f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", mutates_args={}, + f"{PACKAGE_NAME}::optical_morr_linear_fn", mutates_args={}, ) def morr_linear_fn( x: Tensor, @@ -292,6 +288,7 @@ def morr_linear_fn( morr_input_bias: Tensor, morr_output_scale: Tensor, bias: Tensor | None, + morr_input_scale: Tensor, morr_bias: Tensor | None, grid_dim_x: int, grid_dim_y: int, @@ -311,9 +308,17 @@ def morr_linear_fn( in_bit: int, w_bit: int, morr_fwhm: float, + sigma_weight: float, + trainable_morr_scale: bool, + morr_scale: Tensor, + weight_quant_gain: float | None = None, + in_quant_alg: str = "dorefa", + w_quant_alg: str = "dorefa_pos", + morr_output_scale_quant_alg: str = "dorefa_sym", seed: int=42, -) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor]: - +) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float, Tensor, Tensor]: + Device = x.device + Dtype = x.dtype assert x.dtype in ( torch.bfloat16, torch.float16, @@ -342,12 +347,64 @@ def morr_linear_fn( P, Q, K = weight.shape if in_features_pad > D: - x_pad = torch.zeros(M, in_features_pad - D, device=x.device, dtype=x.dtype) + x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) x = torch.cat([x, x_pad], dim=1) assert Q * K == in_features_pad, "input and weight dimension mismatch" assert P * K == out_features_pad, "weight and output dimension mismatch" + # Quantize input + ctx_x_quant = torch.empty(0, device=Device, dtype=Dtype) + if in_bit < 16: + input_quantizer = input_quantize_fn(in_bit, device=Device) + input_quantizer.set_bitwidth(in_bit) + ctx_x_quant = x.clone() + x = input_quantizer(x) + + # Build weight + ctx_w_quant = torch.empty(0, device=Device, dtype=Dtype) + if w_bit < 16: + weight_quantizer = weight_quantize_fn(w_bit, alg="dorefa_pos") + weight_quantizer.set_bitwidth(w_bit) + ctx_w_quant = weight.clone() + weight = weight_quantizer(weight) + + ## rescale weights after quantization can maintain the initialization distribution + if weight_quant_gain is None: + weight_quant_gain = sigma_weight / weight.data.std() + if trainable_morr_scale: + morr_scale = morr_scale * weight_quant_gain + else: + morr_scale = weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + ### quantize learnable balancing factor + morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") + morr_output_scale = morr_output_scale_quantizer(morr_output_scale) + else: + weight = weight.abs() # positive only + morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) + + if finegrain_drop_mask is not None: + weight = weight.mul(finegrain_drop_mask.float()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Reshape x and weight x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] @@ -357,7 +414,7 @@ def morr_linear_fn( w_ctx = weight.clone() # Allocate output - output = torch.empty((M, P, Q, K, 1), device=x.device, dtype=x.dtype) + output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) # Launch the Triton kernel grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), @@ -408,27 +465,6 @@ def morr_linear_fn( BLOCK_SIZE_K1 = K, ) - # ----- build_weight() morr_output_scale part ----- - if w_bit < 16: - morr_output_scale = _weight_quantize_fn(morr_output_scale) - else: - morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) - - # differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - ctx_morr_output_scale = morr_output_scale.clone() - # Apply output scale output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] ctx_x_scalematmul = output.clone() # record x input for matmul @@ -439,7 +475,7 @@ def morr_linear_fn( if out_features < out_features_pad: output = output[:, :out_features] if bias is not None: - x = x + bias.unsqueeze(0) + output = output + bias.unsqueeze(0) # Reshape back for transformer if is_transformer: output = output.view(in_B, in_N, out_features) @@ -449,7 +485,18 @@ def morr_linear_fn( # x_ctx, # x_modulator: x before x^2 # ) - return output, seed, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul + return ( + output, + seed, + torch.abs(w_ctx), + x_ctx, + ctx_morr_output_scale, + ctx_x_scalematmul, + morr_scale.clone(), + weight_quant_gain if weight_quant_gain is not None else 0.0, + ctx_x_quant, + ctx_w_quant, + ) @@ -461,8 +508,9 @@ def _morr_linear_setup_context(ctx, inputs, output): x, # 0 Tensor – input weight, # 1 Tensor – learnable weight morr_input_bias, # 23 Tensor - _, # 3 morr_output_scale + origin_morr_output_scale, # 3 Original input morr_output_scale bias, # 4 Tensor | None – bias + morr_input_scale, # 5 Tensor morr_bias, # 2 Tensor | None grid_dim_x, # 5 int grid_dim_y, # 6 int @@ -482,10 +530,28 @@ def _morr_linear_setup_context(ctx, inputs, output): in_bit, # 20 int w_bit, # 21 int morr_fwhm, # 22 float + sigma_weight, + trainable_morr_scale, # bool + _morr_scale, + _weight_quant_gain, + in_quant_alg, + w_quant_alg, + morr_output_scale_quant_alg, seed, ) = inputs - output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul = output + ( + output, + seed, + w_morr, + x_modulator, + morr_output_scale, + x_scalematmul, + morr_scale, + weight_quant_gain, + x_quant, + w_quant + ) = output # ( # w_morr, # x_modulator, @@ -523,10 +589,9 @@ def _morr_linear_setup_context(ctx, inputs, output): # 3. stash tensors ctx.save_for_backward( x, # original input - weight.sign(), # original weight's sign - # TODO: complete self.tensor + weight, # original weight bias if bias is not None else torch.tensor([], device=device, dtype=dtype), - morr_output_scale, # original morr_output_scale + morr_output_scale, # morr_output_scale after modification in build_weight() x_mrr, # x input for mrr_roundtrip_phase_to_tr() x_morr, w_morr, # w input for propagate_morr() matmul @@ -535,6 +600,11 @@ def _morr_linear_setup_context(ctx, inputs, output): # morr_input_bias, x_scalematmul, # x input for morr_output_scale.matmul tanh_input_bias, + morr_input_scale, + morr_scale, # morr_scale after modification in build_weight() + x_quant, # x input for input_quantize_fn() + w_quant, # w input for weight_quantize_fn() + origin_morr_output_scale, # original morr_output_scale ) ctx.tensor_shape = tensor_shape ctx.mrr_para = mrr_para @@ -545,14 +615,19 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.morr_fwhm = morr_fwhm ctx.grid_dim_x = grid_dim_x ctx.grid_dim_y = grid_dim_y + ctx.in_bit = in_bit ctx.w_bit = w_bit ctx.x_input_shape = x.shape ctx.device = x.device ctx.w_input_shape = weight.shape - ctx.morr_fwhm = morr_fwhm ctx.enable_phase_noise = enable_phase_noise ctx.phase_noise_std = phase_noise_std ctx.trainable_morr_bias = trainable_morr_bias + ctx.trainable_morr_scale = trainable_morr_scale + ctx.weight_quant_gain = weight_quant_gain + ctx.in_quant_alg = in_quant_alg + ctx.w_quant_alg = w_quant_alg + ctx.morr_output_scale_quant_alg = morr_output_scale_quant_alg @@ -562,7 +637,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): """ ( x, - w_input_sign, + weight, bias, morr_output_scale, x_mrr, @@ -572,8 +647,12 @@ def _morr_linear_backward(ctx, grad_output, *ignored): x_modulator, # morr_input_bias, x_scalematmul, - tanh_input_bias - + tanh_input_bias, + morr_input_scale, + morr_scale, + x_quant, + w_quant, + origin_morr_output_scale, ) = ctx.saved_tensors M, P, Q, K = ctx.tensor_shape @@ -633,21 +712,28 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale t = ctx.grid_dim_x // 2 - grad_scale = grad_s.new_zeros((1, 1, t+1, 1)) + grad_output_scale = grad_s.new_zeros((1, 1, t+1, 1)) if ctx.grid_dim_x % 2 == 0: - grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] + grad_output_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] elif ctx.grid_dim_x == 1: - grad_scale = grad_s + grad_output_scale = grad_s else: - grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] - grad_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] - + grad_output_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] + grad_output_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] + # build_weight() if ctx.w_bit < 16: - # TODO: backprop of weight_quantizer - raise NotImplementedError("quantization not supported") + # morr_output_scale_quantizer() + if ctx.morr_output_scale_quant_alg == "dorefa_sym": + # local recompute: + w_in = torch.tanh(origin_morr_output_scale) # [-1, 1] + # ignore gradient for r here + grad_output_scale = grad_output_scale * (1.0 - w_in.pow(2)) + grad_output_scale = grad_output_scale.clamp_(-1, 1) + else: + raise NotImplementedError else: - grad_scale = None + grad_output_scale = None # dL/dx grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] @@ -682,14 +768,20 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] - # 7. input modulator + # 7. input modulator(x) grad_x = grad_x * 2 * x_modulator # [bs, q, k] # 8. input reshape grad_x = grad_x.view(x_input_shape) grad_x = grad_x[:, :in_features] - + # 9.input quantization + if ctx.in_bit >= 16 or ctx.in_quant_alg is None: + pass + elif ctx.in_quant_alg == "dorefa": + grad_x = grad_x * ((x_quant > 0) & (x_quant < 1)) + else: + raise NotImplementedError # ----- Gradient w.r.t weight ----- if ctx.needs_input_grad[1]: @@ -710,21 +802,47 @@ def _morr_linear_backward(ctx, grad_output, *ignored): idx = idx.expand(grad_w.shape).to(DEVICE) buffer = torch.zeros_like(grad_w, device=DEVICE) buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] - grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) - - # 3. build_weight() weight = self.weight.abs() - grad_w = grad_w * w_input_sign + grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # [p, q, k] + + # 3. build_weight() + # morr_scale: [p, q, 1] + grad_morr_input_scale = None + if ctx.w_bit < 16: + # grad w.r.t morr_scale + if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: + grad_morr_scale = (grad_w * weight).sum(dim=2, keepdim=True) # [p, q, 1] + grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] + # ∂L/∂self.morr_input_scale + sigmoid_scale = torch.sigmoid(morr_input_scale) + grad_morr_input_scale = (grad_morr_scale * sigmoid_scale * (1-sigmoid_scale)).squeeze(-1) # [p, q] + + # grad w.r.t weight + grad_w = grad_w * morr_scale # weight.mul(morr_scale) + # weight_quantizer() + if ctx.w_quant_alg is None: + pass + elif ctx.w_quant_alg == "dorefa_pos": + # local recompute: + w_in = torch.tanh(w_quant) # [-1, 1] + # ignore gradient for r here + grad_w = grad_w * (1.0 - w_in.pow(2)) + grad_w = grad_w.clamp_(-1, 1) + else: + raise NotImplementedError + else: + grad_w = grad_w * weight.sign() return ( grad_x, # ∂L/∂x grad_w, # ∂L/∂w grad_inputbias, # ∂L/∂morr_input_bias - grad_scale, # ∂L/∂morr_output_scale + grad_output_scale, # ∂L/∂morr_output_scale grad_bias, # ∂L/∂bias + grad_morr_input_scale, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, - None, + None, None, None, None ) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py index b14e8dbf8..973a9d523 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -93,6 +93,7 @@ def morr_propagate_kernel( mrr_r, in_bit, w_bit, + seed, # stride stride_wm, stride_wp, stride_wq, stride_wk1, stride_wk2, stride_xm, stride_xp, stride_xq, stride_xk1, stride_xk2, @@ -113,9 +114,14 @@ def morr_propagate_kernel( # Program ID for block-based processing # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block pid = tl.program_id(axis=0) - pid_m = pid // (grid_dim_q * grid_dim_p) - pid_p = (pid // grid_dim_q) % grid_dim_p - pid_q = pid % grid_dim_q + # number of blocks (each program needs to handle) along M, P, Q dimension + pnum_m = grid_dim_p * grid_dim_q + pnum_p = grid_dim_p // BLOCK_SIZE_P + pnum_q = grid_dim_q // BLOCK_SIZE_Q + # block dimension of current program + pid_m = pid // (pnum_q * pnum_p) + pid_p = (pid // pnum_q) % pnum_p + pid_q = pid % pnum_q # starting element's m, p, q coordinates in the global tensor start_m = pid_m * BLOCK_SIZE_M @@ -181,24 +187,12 @@ def morr_propagate_kernel( w = tl.load(w_ptrs, mask=w_mask, other=0.0) x = tl.load(x_ptrs, mask=x_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) - - # TODO: Test Quantization Function - # if in_bit < 16: - # x = _input_quantize_fn(x) - - # ----- build_weight() ----- - # TODO: fix quantization func - # if w_bit < 16: - # w = _weight_quantize_fn(w) - # else: - # w = tl.abs(w) - - w = tl.abs(w).reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + + + w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] - if finegrain_drop_mask is not None: - w *= tl.cast(finegrain_drop_mask, tl.float32) - + x = x * x # input_modulator() # ----- propagate_morr() ----- @@ -215,7 +209,9 @@ def morr_propagate_kernel( # apply phase noise if ENABLE_PHASE_NOISE: - noise = tl.zeros_like(x) + tl.randn(x.shape) * phase_noise_std + block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 + offs = tl.reshape(block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2) , (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + noise = tl.randn(seed, offs) * phase_noise_std x = x + noise # add trainable bias @@ -251,6 +247,7 @@ def morr_propagate_kernel( # P loop end # reset pointer along P dimension w_ptrs -= stride_wp * (BLOCK_SIZE_P) + b_ptrs -= stride_bp * (BLOCK_SIZE_P) # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 # propagate pointer along M dimension @@ -283,12 +280,13 @@ def morr_propagate_kernel( @torch.library.custom_op( f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", mutates_args={}, ) -def morr_linear_fn( +def morr_linear_fn_mem( x: Tensor, weight: Tensor, morr_input_bias: Tensor, morr_output_scale: Tensor, bias: Tensor | None, + morr_input_scale: Tensor, morr_bias: Tensor | None, grid_dim_x: int, grid_dim_y: int, @@ -308,8 +306,13 @@ def morr_linear_fn( in_bit: int, w_bit: int, morr_fwhm: float, -) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: - + sigma_weight: float, + trainable_morr_scale: bool, + morr_scale: Tensor, + weight_quant_gain: float | None = None, + seed: int=42, +) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float]: + Device = x.device assert x.dtype in ( torch.bfloat16, torch.float16, @@ -338,12 +341,59 @@ def morr_linear_fn( P, Q, K = weight.shape if in_features_pad > D: - x_pad = torch.zeros(M, in_features_pad - D, device=x.device, dtype=x.dtype) + x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) x = torch.cat([x, x_pad], dim=1) assert Q * K == in_features_pad, "input and weight dimension mismatch" assert P * K == out_features_pad, "weight and output dimension mismatch" + # Quantize input + if in_bit < 16: + input_quantizer = input_quantize_fn(in_bit, device=Device) + input_quantizer.set_bitwidth(in_bit) + x = input_quantizer(x) + + # Build weight + if w_bit < 16: + weight_quantizer = weight_quantize_fn(w_bit, alg="dorefa_pos") + weight_quantizer.set_bitwidth(w_bit) + weight = weight_quantizer(weight) + + ## rescale weights after quantization can maintain the initialization distribution + if weight_quant_gain is None: + weight_quant_gain = sigma_weight / weight.data.std() + if trainable_morr_scale: + morr_scale = morr_scale * weight_quant_gain + else: + morr_scale = weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + ### quantize learnable balancing factor + morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") + morr_output_scale = morr_output_scale_quantizer(morr_output_scale) + else: + weight = weight.abs() # positive only + morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) + + if finegrain_drop_mask is not None: + weight = weight.mul(finegrain_drop_mask.float()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + # Reshape x and weight x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] @@ -353,7 +403,7 @@ def morr_linear_fn( w_ctx = weight.clone() # Allocate output - output = torch.empty((M, P, Q, K, 1), device=x.device, dtype=x.dtype) + output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) # Launch the Triton kernel grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), @@ -376,6 +426,7 @@ def morr_linear_fn( mrr_r=mrr_r, in_bit=in_bit, w_bit=w_bit, + seed=seed, finegrain_drop_mask=finegrain_drop_mask, stride_wm=weight.stride(0), stride_wp=weight.stride(1), @@ -397,33 +448,12 @@ def morr_linear_fn( stride_ok1=output.stride(3), stride_ok2=output.stride(4), ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, - ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1, + ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, TRAINABLE_MORR_BIAS = trainable_morr_bias, INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], BLOCK_SIZE_K1=K, ) - # ----- build_weight() morr_output_scale part ----- - if w_bit < 16: - morr_output_scale = _weight_quantize_fn(morr_output_scale) - else: - morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) - - # differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - ctx_morr_output_scale = morr_output_scale.clone() - # Apply output scale output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] ctx_x_scalematmul = output.clone() # record x input for matmul @@ -434,17 +464,12 @@ def morr_linear_fn( if out_features < out_features_pad: output = output[:, :out_features] if bias is not None: - x = x + bias.unsqueeze(0) + output = output + bias.unsqueeze(0) # Reshape back for transformer if is_transformer: output = output.view(in_B, in_N, out_features) - # aux_tensor = ( - # torch.abs(w_ctx), # w_morr: weight in propagate_morr matmul - # x_ctx, # x_modulator: x before x^2 - # ) - - return output, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul + return output, seed, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul, morr_scale.clone(), weight_quant_gain if weight_quant_gain is not None else 0.0 @@ -456,8 +481,9 @@ def _morr_linear_setup_context(ctx, inputs, output): x, # 0 Tensor – input weight, # 1 Tensor – learnable weight morr_input_bias, # 23 Tensor - _, # 3 morr_output_scale + _, # 3 morr_output_scale (original) bias, # 4 Tensor | None – bias + morr_input_scale, morr_bias, # 2 Tensor | None grid_dim_x, # 5 int grid_dim_y, # 6 int @@ -477,9 +503,14 @@ def _morr_linear_setup_context(ctx, inputs, output): in_bit, # 20 int w_bit, # 21 int morr_fwhm, # 22 float + sigma_weight, + trainable_morr_scale, # bool + _morr_scale, + _weight_quant_gain, + seed, # 23 int ) = inputs - output, w_morr, x_modulator, morr_output_scale, x_scalematmul = output + output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul, morr_scale, weight_quant_gain = output # ( # w_morr, # x_modulator, @@ -510,20 +541,25 @@ def _morr_linear_setup_context(ctx, inputs, output): if enable_phase_noise and phase_noise_std > 1e-5: x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) if trainable_morr_bias: - x_mrr = x_mrr - morr_bias + x_mrr = x_mrr - morr_bias # morr_bias here is the detached one from forward + + tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) # Added from linear.py # 3. stash tensors ctx.save_for_backward( - # x, # original input - weight.sign(), # original weight's sign - # TODO: complete self.tensor + x, # original input (stashing x for mem version, might need re-evaluation for pure mem-saving) + weight, # original weight (stashing weight for mem version) bias if bias is not None else torch.tensor([], device=device, dtype=dtype), morr_output_scale, # original morr_output_scale # x_mrr, # x input for mrr_roundtrip_phase_to_tr() + # x_morr, w_morr, # w input for propagate_morr() matmul x_modulator, # x input for input_modulator() morr_input_bias, + # x_scalematmul, x_scalematmul, # x input for morr_output_scale.matmul + morr_input_scale, # morr input scale at input + morr_scale, # morr_scale after modification in build_weight() ) ctx.tensor_shape = tensor_shape ctx.mrr_para = mrr_para @@ -538,10 +574,12 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.x_input_shape = x.shape ctx.device = x.device ctx.w_input_shape = weight.shape - ctx.morr_fwhm = morr_fwhm + # ctx.morr_fwhm = morr_fwhm # Already exists ctx.enable_phase_noise = enable_phase_noise ctx.phase_noise_std = phase_noise_std ctx.trainable_morr_bias = trainable_morr_bias + ctx.trainable_morr_scale = trainable_morr_scale + ctx.weight_quant_gain = weight_quant_gain @@ -550,15 +588,18 @@ def _morr_linear_backward(ctx, grad_output, *ignored): Backward pass for morr_linear_fn. """ ( - # x, - w_input_sign, + x, + weight, bias, morr_output_scale, # x_mrr, + # x_morr, w_morr, x_modulator, morr_input_bias, x_scalematmul, + morr_input_scale, + morr_scale, ) = ctx.saved_tensors M, P, Q, K = ctx.tensor_shape @@ -629,10 +670,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): else: grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] grad_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] - - if ctx.w_bit < 16: - # TODO: backprop of weight_quantizer - raise NotImplementedError("quantization not supported") + else: grad_scale = None @@ -699,8 +737,22 @@ def _morr_linear_backward(ctx, grad_output, *ignored): buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) - # 3. build_weight() weight = self.weight.abs() - grad_w = grad_w * w_input_sign + # 3. build_weight() + # morr_scale: [p, q, 1] + grad_morr_input_scale = None + if ctx.w_bit < 16: + # grad w.r.t morr_scale + if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: + grad_morr_scale = (grad_w * weight).sum(dim=2, keepdim=True) # [p, q, 1] + grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] + # ∂L/∂self.morr_input_scale + sigmoid_scale = torch.sigmoid(morr_input_scale) + grad_morr_input_scale = (grad_morr_scale * sigmoid_scale * (1-sigmoid_scale)).squeeze(-1) # [p, q] + + # grad w.r.t weight + grad_w = grad_w * morr_scale + else: + grad_w = grad_w * weight.sign() return ( grad_x, # ∂L/∂x @@ -708,13 +760,14 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_inputbias, # ∂L/∂morr_input_bias grad_scale, # ∂L/∂morr_output_scale grad_bias, # ∂L/∂bias + grad_morr_input_scale, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, - None, + None, None, None, None ) -morr_linear_fn.register_autograd( +morr_linear_fn_mem.register_autograd( _morr_linear_backward, setup_context=_morr_linear_setup_context, ) \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_mem.py index 221024364..85b047122 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_mem.py @@ -22,7 +22,7 @@ from ..utils import morr_uniform_ from ..utils import input_quantize_fn, weight_quantize_fn from ..modules.base_layer import ONNBaseLayer -from .morr_linear_kernel_mem import morr_linear_fn +from .morr_linear_kernel_mem import morr_linear_fn_mem logger = logging.getLogger(__name__) diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index bb18e7c30..1c7faa482 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -261,7 +261,9 @@ def instantiate_optical_linear(module, postfix, module_map, additional_module_ar bias=has_bias, **additional_module_args, ) - + if additional_args is None: + return linear + # extra handling for morr optical module enable_thermal_crosstalk = additional_args.get("thermal_crosstalk", False) enable_phase_noise = additional_args.get("phase_noise", False) @@ -282,10 +284,19 @@ def instantiate_optical_linear(module, postfix, module_map, additional_module_ar if enable_trainable_morr_scale: linear.enable_trainable_morr_scale() + else: + linear.disable_trainable_morr_scale() if enable_trainable_morr_bias: linear.enable_trainable_morr_bias() + else: + linear.disable_trainable_morr_bias() + if "in_bit" in additional_args: + linear.set_input_bitwidth(in_bit = additional_args["in_bit"]) + if "w_bit" in additional_args: + linear.set_weight_bitwidth(w_bit = additional_args["w_bit"]) + return linear def instantiate_optical_conv2d(module, postfix, module_map, additional_module_args): diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index 6f74a82d5..0ec241e32 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -75,11 +75,14 @@ def optical_transform_by_regex_name(network, pass_args): print(f"processing {n}") optical_config = pass_args[matched_pattern]["config"] - optial_additional_config = pass_args[matched_pattern]["additional"] + optial_additional_config = pass_args[matched_pattern].get("additional", None) postfix = optical_config["name"] additional_module_args = ( - {"config": optical_config, "additional": optial_additional_config} + { + "config": optical_config, + "additional": optial_additional_config + } # if is_huggingface_model # else {"config": optical_config} ) From c9974ab7301e3ee00c04d51fdc99e526dc0db0a4 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 20 May 2025 12:19:24 +0100 Subject: [PATCH 29/38] fix memory-efficient kernel --- .../triton_modules/morr_linear_kernel.py | 8 +- .../triton_modules/morr_linear_kernel_mem.py | 272 +++++++++++++++--- .../optical/triton_modules/morr_linear_mem.py | 12 +- 3 files changed, 244 insertions(+), 48 deletions(-) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py index e40af6f47..e2a4aa751 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py @@ -772,8 +772,9 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_x = grad_x * 2 * x_modulator # [bs, q, k] # 8. input reshape - grad_x = grad_x.view(x_input_shape) - grad_x = grad_x[:, :in_features] + B, N, D = x_input_shape + grad_x = grad_x.view(-1, in_features_pad) # [b*n, in_features_pad] + grad_x = grad_x[:, :in_features] # [b*n, in_features = D] # 9.input quantization if ctx.in_bit >= 16 or ctx.in_quant_alg is None: @@ -783,6 +784,9 @@ def _morr_linear_backward(ctx, grad_output, *ignored): else: raise NotImplementedError + # 10. input reshape + grad_x = grad_x.view(B, N, D) # [b, n, d] + # ----- Gradient w.r.t weight ----- if ctx.needs_input_grad[1]: diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py index 973a9d523..c928453cd 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -9,7 +9,12 @@ from .dtype import TORCH_DTYPE_TO_TRITON PACKAGE_NAME = "mase_triton" -from ..utils import toeplitz +from ..utils import ( + toeplitz, + input_quantize_fn, + weight_quantize_fn, + mrr_roundtrip_phase_to_tr_func +) from .quantize import _input_quantize_fn, _weight_quantize_fn @@ -506,15 +511,11 @@ def _morr_linear_setup_context(ctx, inputs, output): sigma_weight, trainable_morr_scale, # bool _morr_scale, - _weight_quant_gain, + weight_quant_gain, seed, # 23 int ) = inputs - output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul, morr_scale, weight_quant_gain = output - # ( - # w_morr, - # x_modulator, - # ) = aux_tensor + output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul, morr_scale, _weight_quant_gain = output device, dtype = x.device, x.dtype @@ -525,25 +526,25 @@ def _morr_linear_setup_context(ctx, inputs, output): tensor_shape = (M, P, Q, K) # mrr_para: para for mrr_roundtrip_phase_to_tr() - c1 = -2.0 * mrr_a * mrr_r - c2 = mrr_a * mrr_a + mrr_r * mrr_r - c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r - c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r - intensity = True - mrr_para = (c1, c2, c3, c4, intensity) + # c1 = -2.0 * mrr_a * mrr_r + # c2 = mrr_a * mrr_a + mrr_r * mrr_r + # c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + # c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + # intensity = True + # mrr_para = (c1, c2, c3, c4, intensity) - # x_morr: x input of matmal in propagate_morr() - x_morr = x_modulator ** 2 # [m, q, k] - x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] + # # x_morr: x input of matmal in propagate_morr() + # x_morr = x_modulator ** 2 # [m, q, k] + # x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] - # x_mrr: x input of mrr_roundtrip_phase_to_tr() - x_mrr = w_morr.matmul(x_morr).squeeze(-1) - if enable_phase_noise and phase_noise_std > 1e-5: - x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) - if trainable_morr_bias: - x_mrr = x_mrr - morr_bias # morr_bias here is the detached one from forward + # # x_mrr: x input of mrr_roundtrip_phase_to_tr() + # x_mrr = w_morr.matmul(x_morr).squeeze(-1) + # if enable_phase_noise and phase_noise_std > 1e-5: + # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) + # if trainable_morr_bias: + # x_mrr = x_mrr - morr_bias # morr_bias here is the detached one from forward - tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) # Added from linear.py + # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) # Added from linear.py # 3. stash tensors ctx.save_for_backward( @@ -553,16 +554,17 @@ def _morr_linear_setup_context(ctx, inputs, output): morr_output_scale, # original morr_output_scale # x_mrr, # x input for mrr_roundtrip_phase_to_tr() # x_morr, - w_morr, # w input for propagate_morr() matmul - x_modulator, # x input for input_modulator() + # w_morr, # w input for propagate_morr() matmul + # x_modulator, # x input for input_modulator() morr_input_bias, # x_scalematmul, - x_scalematmul, # x input for morr_output_scale.matmul + # x_scalematmul, # x input for morr_output_scale.matmul morr_input_scale, # morr input scale at input - morr_scale, # morr_scale after modification in build_weight() + # morr_scale, # morr_scale after modification in build_weight() + finegrain_drop_mask, ) ctx.tensor_shape = tensor_shape - ctx.mrr_para = mrr_para + # ctx.mrr_para = mrr_para ctx.in_features = in_features ctx.in_features_pad = in_features_pad ctx.out_features = out_features @@ -570,6 +572,7 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.morr_fwhm = morr_fwhm ctx.grid_dim_x = grid_dim_x ctx.grid_dim_y = grid_dim_y + ctx.in_bit = in_bit ctx.w_bit = w_bit ctx.x_input_shape = x.shape ctx.device = x.device @@ -580,8 +583,163 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.trainable_morr_bias = trainable_morr_bias ctx.trainable_morr_scale = trainable_morr_scale ctx.weight_quant_gain = weight_quant_gain + ctx.miniblock = miniblock + ctx.crosstalk_factor = crosstalk_factor + ctx.sigma_weight = sigma_weight + ctx.enable_thermal_crosstalk = enable_thermal_crosstalk + ctx.mrr_a = mrr_a + ctx.mrr_r = mrr_r + +def recompute_activations( + ctx, + x: Tensor, + weight: Tensor, + bias: Tensor | None, + morr_output_scale: Tensor, + finegrain_drop_mask, + morr_input_bias: Tensor, + morr_input_scale: Tensor, +): + """ + Recompute activations for morr_linear_fn. + """ + Device = x.device + Dtype = x.dtype + + ctx_morr_scale = None + ctx_tanh_input_bias = None + + # Handle transformer vs non-transformer inputs + ori_x_shape = x.shape + is_transformer = len(ori_x_shape) == 3 + + if is_transformer: + in_B, in_N, in_D = x.shape + M = in_B * in_N + x = x.reshape(M, in_D) + else: + M = x.shape[0] + + # Get dimensions + M, D = x.shape + P, Q, K = weight.shape + + if ctx.in_features_pad > D: + x_pad = torch.zeros(M, ctx.in_features_pad - D, device=Device, dtype=x.dtype) + x = torch.cat([x, x_pad], dim=1) + + # Quantize input + if ctx.in_bit < 16: + input_quantizer = input_quantize_fn(ctx.in_bit, device=Device) + input_quantizer.set_bitwidth(ctx.in_bit) + x = input_quantizer(x) + + ################# Build weight ################# + if ctx.w_bit < 16: + weight_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_pos") + weight_quantizer.set_bitwidth(ctx.w_bit) + weight = weight_quantizer(weight) + + # Calculate morr_scale + if morr_input_scale is None: + return None + morr_scale = torch.sigmoid(morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + ## rescale weights after quantization can maintain the initialization distribution + weight_quant_gain = ctx.weight_quant_gain + if weight_quant_gain is None: + weight_quant_gain = ctx.sigma_weight / weight.data.std() + if ctx.trainable_morr_scale: + morr_scale = morr_scale * weight_quant_gain + else: + morr_scale = weight_quant_gain + + ctx_morr_scale = morr_scale.clone() + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + ### quantize learnable balancing factor + morr_output_scale_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_sym") + morr_output_scale = morr_output_scale_quantizer(morr_output_scale) + else: + weight = weight.abs() # positive only + morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) + + if finegrain_drop_mask is not None: + weight = weight.mul(finegrain_drop_mask.float()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if ctx.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if ctx.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Reshape x and weight + x = x.view(-1, ctx.grid_dim_x, ctx.miniblock) # [M, q, k] + # input_modulator() + ctx_x_modulator = x.clone() + x = x ** 2 + + + ################# propagate_morr() ################# + if ctx.enable_thermal_crosstalk and ctx.crosstalk_factor > 1: + weight = weight * ctx.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + + ctx_x_morr = x.clone() + ctx_w_morr = weight.clone() + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, ctx.phase_noise_std) + + if ctx.trainable_morr_bias: + ctx_tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + morr_bias = ctx.morr_fwhm * ctx_tanh_input_bias + x = x - morr_bias + + ctx_x_mrr = x.clone() + + mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func(a=ctx.mrr_a, r=ctx.mrr_r, intensity=True) + x = mrr_roundtrip_phase_to_tr(x) + ctx_x_scalematmul = x.clone() + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + + # ------------------------------------------------------ + + # # Trim output if needed + # if ctx.out_features < ctx.out_features_pad: + # output = output[:, :ctx.out_features] + # if bias is not None: + # output = output + bias.unsqueeze(0) + # # Reshape back for transformer + # if is_transformer: + # output = output.view(in_B, in_N, ctx.out_features) + + return ( + # x, weight, bias, morr_output_scale, + # output, + ctx_x_modulator, # x input for input_modulator() + ctx_x_morr, # x input for propagate_morr() matmul + ctx_w_morr, # w input for propagate_morr() matmul + ctx_x_mrr, # x input for mrr_roundtrip_phase_to_tr() + ctx_x_scalematmul, # x input for morr_output_scale.matmul + ctx_tanh_input_bias, # input_bias after tanh() + ctx_morr_scale, # morr_scale after modification in build_weight() + ) def _morr_linear_backward(ctx, grad_output, *ignored): """ @@ -594,16 +752,17 @@ def _morr_linear_backward(ctx, grad_output, *ignored): morr_output_scale, # x_mrr, # x_morr, - w_morr, - x_modulator, + # w_morr, + # x_modulator, morr_input_bias, - x_scalematmul, + # x_scalematmul, morr_input_scale, - morr_scale, + # morr_scale, + finegrain_drop_mask ) = ctx.saved_tensors M, P, Q, K = ctx.tensor_shape - c1, c2, c3, c4, intensity = ctx.mrr_para + # c1, c2, c3, c4, intensity = ctx.mrr_para in_features = ctx.in_features in_features_pad = ctx.in_features_pad out_features = ctx.out_features @@ -613,17 +772,36 @@ def _morr_linear_backward(ctx, grad_output, *ignored): DEVICE = ctx.device # --- calculate intermediate activation on the fly --- - x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] + ( + x_modulator, # x input for input_modulator() + x_morr, # x input for propagate_morr() matmul + w_morr, # w input for propagate_morr() matmul + x_mrr, # x input for mrr_roundtrip_phase_to_tr() + x_scalematmul, # x input for morr_output_scale.matmul + tanh_input_bias, # input_bias after tanh() + morr_scale, # morr_scale after modificaiton in build_weight() + ) = recompute_activations( + ctx, + x, + weight, + bias, + morr_output_scale, + finegrain_drop_mask, + morr_input_bias, + morr_input_scale + ) - tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) - morr_bias = ctx.morr_fwhm * tanh_input_bias + # x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] - # x_mrr: x input of mrr_roundtrip_phase_to_tr() - x_mrr = w_morr.matmul(x_morr).squeeze(-1) - if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: - x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) - if ctx.trainable_morr_bias: - x_mrr = x_mrr - morr_bias + # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + # morr_bias = ctx.morr_fwhm * tanh_input_bias + + # # x_mrr: x input of mrr_roundtrip_phase_to_tr() + # x_mrr = w_morr.matmul(x_morr).squeeze(-1) + # if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) + # if ctx.trainable_morr_bias: + # x_mrr = x_mrr - morr_bias @@ -678,6 +856,12 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] # 4. x = mrr_roundtrip_phase_to_tr(x) + mrr_a, mrr_r = ctx.mrr_a, ctx.mrr_r + c1 = -2.0 * mrr_a * mrr_r + c2 = mrr_a * mrr_a + mrr_r * mrr_r + c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + intensity = True denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) if intensity: denominator.square_() @@ -690,7 +874,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] # 5. x += phase_noise and morr_bias - if ctx.needs_input_grad[2]: + if ctx.trainable_morr_bias and ctx.needs_input_grad[2]: grad_inputbias = - grad_x # [bs, p, q, k] grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] grad_inputbias = grad_inputbias - tanh_input_bias * tanh_input_bias # [bs, p, q, k] @@ -738,6 +922,8 @@ def _morr_linear_backward(ctx, grad_output, *ignored): grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # 3. build_weight() + if finegrain_drop_mask is not None: + grad_w = grad_w * finegrain_drop_mask.float() # morr_scale: [p, q, 1] grad_morr_input_scale = None if ctx.w_bit < 16: diff --git a/src/chop/nn/optical/triton_modules/morr_linear_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_mem.py index 85b047122..74f731d08 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_mem.py @@ -444,13 +444,14 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) def forward(self, x: Tensor) -> Tensor: - output, *_ = morr_linear_fn( - x, + output, *_ = morr_linear_fn_mem( + x, self.weight, morr_input_bias = self.morr_input_bias, morr_output_scale = self.morr_output_scale, bias = None, - morr_bias = self.morr_bias, + morr_input_scale = self.morr_input_scale, + morr_bias = self.morr_bias.detach(), grid_dim_x = self.grid_dim_x, grid_dim_y = self.grid_dim_y, miniblock = self.miniblock, @@ -469,5 +470,10 @@ def forward(self, x: Tensor) -> Tensor: in_bit = self.in_bit, w_bit = self.w_bit, morr_fwhm = self.morr_fwhm, + sigma_weight=self.sigma_weight, + trainable_morr_scale=self.trainable_morr_scale, # bool + morr_scale=self.morr_scale, + weight_quant_gain=self.weight_quant_gain, + seed = 42, ) return output \ No newline at end of file From 38d5da428003e2aeca890eb3924f3b1ae5290aea Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Wed, 21 May 2025 14:09:11 +0100 Subject: [PATCH 30/38] new memory-saving kernel --- src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py index c928453cd..2f4866ecf 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -873,7 +873,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): ) grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] - # 5. x += phase_noise and morr_bias + # 5. x += phase_noise and x -= morr_bias if ctx.trainable_morr_bias and ctx.needs_input_grad[2]: grad_inputbias = - grad_x # [bs, p, q, k] grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] From b59e8d15e54f92bac1113ff7900027a85796f6f8 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Sun, 25 May 2025 13:08:32 +0100 Subject: [PATCH 31/38] fix weight loading function --- .../optical/module_transform_helper.py | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 1c7faa482..52a086497 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -60,7 +60,8 @@ def replace_by_name_optical(network, module_name: str, new_module, target_name): if target_name == "linear_morr_full": updated_module = weight_replacement_full_linear_optical(original, new_module) elif target_name in ["linear_morr", "linear_morr_triton", "linear_morr_triton_mem"]: - updated_module = weight_replacement_circulant_linear_optical(original, new_module) + # updated_module = weight_replacement_circulant_linear_optical(original, new_module) + updated_module = weight_randominit_circulant_linear_optical(original, new_module) elif target_name in ["bert_self_attention_morr"]: updated_module = weight_replacement_circulant_bert_attention(original, new_module) else: @@ -127,6 +128,52 @@ def weight_replacement_circulant_linear_optical(x, y): Focuses only on weight copying (no bias copying). """ + # Dense weight + W = x.weight.data # [out_features, in_features] + + # Dimensions defined by the MORR layer + k = y.miniblock # miniblock size + grid_dim_y = y.grid_dim_y # #block-rows (p) + grid_dim_x = y.grid_dim_x # #block-cols (q) + out_features_p = y.out_features_pad + in_features_p = y.in_features_pad + + # Zero-pad so every block is k×k + W_padded = W.new_zeros((out_features_p, in_features_p)) + W_padded[: W.size(0), : W.size(1)] = W + + new_weight = W.new_zeros((grid_dim_y, grid_dim_x, k)) # [p, q, k] + + idx = torch.arange(k, device=W.device) # 0 … k-1, reused in every block + + with torch.no_grad(): + for p in range(grid_dim_y): + row_slice = slice(p * k, (p + 1) * k) + + for q in range(grid_dim_x): + col_slice = slice(q * k, (q + 1) * k) + block = W_padded[row_slice, col_slice] # shape (k, k) + + # Frobenius-projection onto the circulant subspace: + # c_j = mean of { block[i, (i+j) mod k], i=0…k-1 } + c = torch.stack([ + block[idx, (idx + j) % k].mean() + for j in range(k) + ]) + + new_weight[p, q, :] = c # first row + + # Save back into the MORR layer + y.load_parameters({"weight": new_weight}) + + return y + +def weight_randominit_circulant_linear_optical(x, y): + """ + Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). + Focuses only on weight copying (no bias copying). + """ + # Fetch original linear weight [out_features, in_features] W = x.weight.data # [out_features, in_features] @@ -166,7 +213,6 @@ def weight_replacement_circulant_linear_optical(x, y): return y - def weight_replacement_conv2d_optical(x, y): """ Replace the weights (and bias, if present) of a standard nn.Conv2d (x) From bd5d90b03a0c9685c1afe17075508c608df280dc Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Wed, 28 May 2025 21:47:42 +0100 Subject: [PATCH 32/38] add warning in random weight --- .../transforms/optical/module_transform_helper.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index 52a086497..d0387b9d2 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -5,6 +5,7 @@ from copy import deepcopy import logging import inspect +import warnings from chop.passes.module.module_modify_helper import ( get_module_by_name, @@ -60,8 +61,8 @@ def replace_by_name_optical(network, module_name: str, new_module, target_name): if target_name == "linear_morr_full": updated_module = weight_replacement_full_linear_optical(original, new_module) elif target_name in ["linear_morr", "linear_morr_triton", "linear_morr_triton_mem"]: - # updated_module = weight_replacement_circulant_linear_optical(original, new_module) - updated_module = weight_randominit_circulant_linear_optical(original, new_module) + updated_module = weight_replacement_circulant_linear_optical(original, new_module) + # updated_module = weight_randominit_circulant_linear_optical(original, new_module) elif target_name in ["bert_self_attention_morr"]: updated_module = weight_replacement_circulant_bert_attention(original, new_module) else: @@ -126,6 +127,7 @@ def weight_replacement_circulant_linear_optical(x, y): """ Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). Focuses only on weight copying (no bias copying). + take mean value along diagonal """ # Dense weight @@ -173,7 +175,11 @@ def weight_randominit_circulant_linear_optical(x, y): Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). Focuses only on weight copying (no bias copying). """ - + warnings.warn( + "Random weight initiator is being used!", + category=RuntimeWarning, + stacklevel=2, # point the warning at the caller + ) # Fetch original linear weight [out_features, in_features] W = x.weight.data # [out_features, in_features] From 6e82a677adc50ae9da1ee8e1105a696aae111091 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Thu, 5 Jun 2025 19:51:16 +0100 Subject: [PATCH 33/38] fix kernel bugs to enable memory and training speed evaliation to be run --- .../nn/optical/triton_modules/morr_linear.py | 2 +- .../triton_modules/morr_linear_kernel.py | 26 +++++-- .../triton_modules/morr_linear_kernel_mem.py | 2 +- .../optical/module_transform_helper.py | 75 ++++++++++--------- 4 files changed, 59 insertions(+), 46 deletions(-) diff --git a/src/chop/nn/optical/triton_modules/morr_linear.py b/src/chop/nn/optical/triton_modules/morr_linear.py index a27c9bd37..e10c088e8 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear.py +++ b/src/chop/nn/optical/triton_modules/morr_linear.py @@ -451,7 +451,7 @@ def forward(self, x: Tensor) -> Tensor: morr_output_scale = self.morr_output_scale, bias = None, morr_input_scale = self.morr_input_scale, - morr_bias = self.morr_bias.detach(), + morr_bias = self.morr_bias.detach() if self.morr_bias is not None else None, grid_dim_x = self.grid_dim_x, grid_dim_y = self.grid_dim_y, miniblock = self.miniblock, diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py index e2a4aa751..05c0577e1 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py @@ -605,6 +605,7 @@ def _morr_linear_setup_context(ctx, inputs, output): x_quant, # x input for input_quantize_fn() w_quant, # w input for weight_quantize_fn() origin_morr_output_scale, # original morr_output_scale + finegrain_drop_mask, ) ctx.tensor_shape = tensor_shape ctx.mrr_para = mrr_para @@ -653,6 +654,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): x_quant, w_quant, origin_morr_output_scale, + finegrain_drop_mask ) = ctx.saved_tensors M, P, Q, K = ctx.tensor_shape @@ -688,7 +690,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): ) # [M, P, Q, K] # ----- Gradient w.r.t input x ----- - if ctx.needs_input_grad[0]: + if True or ctx.needs_input_grad[0]: # 1. reshape grad_out = grad_out.view(M, -1) # [m, out_features] @@ -727,9 +729,13 @@ def _morr_linear_backward(ctx, grad_output, *ignored): if ctx.morr_output_scale_quant_alg == "dorefa_sym": # local recompute: w_in = torch.tanh(origin_morr_output_scale) # [-1, 1] + r = torch.max(w_in.abs()).detach() + # ignore gradient for r here + grad_output_scale = (grad_output_scale * 2 * r).clamp_(-1.0, 1.0) + grad_output_scale = grad_output_scale * (1.0 / (2 * r)) grad_output_scale = grad_output_scale * (1.0 - w_in.pow(2)) - grad_output_scale = grad_output_scale.clamp_(-1, 1) + else: raise NotImplementedError else: @@ -786,9 +792,8 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # 10. input reshape grad_x = grad_x.view(B, N, D) # [b, n, d] - # ----- Gradient w.r.t weight ----- - if ctx.needs_input_grad[1]: + if True or ctx.needs_input_grad[1]: # 0. gradient after x = weight.matmul(x) # grad_morr_matmul # [bs, p, q, k, 1] @@ -805,10 +810,12 @@ def _morr_linear_backward(ctx, grad_output, *ignored): idx = idx.expand(grad_w.shape).to(DEVICE) buffer = torch.zeros_like(grad_w, device=DEVICE) - buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k]cvb grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # [p, q, k] # 3. build_weight() + if finegrain_drop_mask is not None: + grad_w = grad_w * finegrain_drop_mask.float() # morr_scale: [p, q, 1] grad_morr_input_scale = None if ctx.w_bit < 16: @@ -827,10 +834,15 @@ def _morr_linear_backward(ctx, grad_output, *ignored): pass elif ctx.w_quant_alg == "dorefa_pos": # local recompute: - w_in = torch.tanh(w_quant) # [-1, 1] + w_in = torch.tanh(w_quant) + r = torch.max(w_in.abs()).detach() + 1e-12 # ε avoids /0 # ignore gradient for r here + # grad_w = grad_w * (1.0 - w_in.pow(2)) + # grad_w = grad_w.clamp_(-1, 1) + grad_w = grad_w * (2 * r) + grad_w = grad_w.clamp(-1.0, 1.0) + grad_w = grad_w / (2 * r) grad_w = grad_w * (1.0 - w_in.pow(2)) - grad_w = grad_w.clamp_(-1, 1) else: raise NotImplementedError else: diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py index 2f4866ecf..6ecc2e033 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -815,7 +815,7 @@ def _morr_linear_backward(ctx, grad_output, *ignored): ) # [M, P, Q, K] # ----- Gradient w.r.t input x ----- - if ctx.needs_input_grad[0]: + if True or ctx.needs_input_grad[0]: # 1. reshape grad_out = grad_out.view(M, -1) # [m, out_features] diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index d0387b9d2..c652dc9b5 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -180,44 +180,45 @@ def weight_randominit_circulant_linear_optical(x, y): category=RuntimeWarning, stacklevel=2, # point the warning at the caller ) - # Fetch original linear weight [out_features, in_features] - W = x.weight.data # [out_features, in_features] - - # Grab dimensions and zero-pad if needed - out_features_pad = y.out_features_pad # padded out_features in y - in_features_pad = y.in_features_pad # padded in_features in y - miniblock = y.miniblock - grid_dim_y = y.grid_dim_y - grid_dim_x = y.grid_dim_x - - # Construct padded weight tensor - W_padded = W.new_zeros((out_features_pad, in_features_pad)) - W_padded[: W.size(0), : W.size(1)] = W - - # Takes the mean across the miniblock slice. - new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) # [grid_dim_y, grid_dim_x, miniblock] - - # Fill new_weight by averaging the corresponding sub-blocks in W_padded - # original miniblock: [k, k] new miniblock: [k, 1] - with torch.no_grad(): - for p in range(grid_dim_y): - for q in range(grid_dim_x): - for k in range(miniblock): - row_idx = p * miniblock + k # The row in W_padded: - col_start = q * miniblock # The columns in W_padded: - col_end = (q + 1) * miniblock - block = W_padded[row_idx, col_start:col_end] - - new_weight[p, q, k] = block.mean() - - bound = 1 / math.sqrt(miniblock) - new_weight = torch.rand((grid_dim_y, grid_dim_x, miniblock), - device=W.device, - dtype=W.dtype) * 2 * bound - bound - # Copy the result into y.weight - y.load_parameters({"weight": new_weight}) - + # y.reset_parameters() return y + # # Fetch original linear weight [out_features, in_features] + # W = x.weight.data # [out_features, in_features] + + # # Grab dimensions and zero-pad if needed + # out_features_pad = y.out_features_pad # padded out_features in y + # in_features_pad = y.in_features_pad # padded in_features in y + # miniblock = y.miniblock + # grid_dim_y = y.grid_dim_y + # grid_dim_x = y.grid_dim_x + + # # Construct padded weight tensor + # W_padded = W.new_zeros((out_features_pad, in_features_pad)) + # W_padded[: W.size(0), : W.size(1)] = W + + # # Takes the mean across the miniblock slice. + # new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) # [grid_dim_y, grid_dim_x, miniblock] + + # # Fill new_weight by averaging the corresponding sub-blocks in W_padded + # # original miniblock: [k, k] new miniblock: [k, 1] + # with torch.no_grad(): + # for p in range(grid_dim_y): + # for q in range(grid_dim_x): + # for k in range(miniblock): + # row_idx = p * miniblock + k # The row in W_padded: + # col_start = q * miniblock # The columns in W_padded: + # col_end = (q + 1) * miniblock + # block = W_padded[row_idx, col_start:col_end] + + # new_weight[p, q, k] = block.mean() + + # bound = 1 / math.sqrt(miniblock) + # new_weight = torch.rand((grid_dim_y, grid_dim_x, miniblock), + # device=W.device, + # dtype=W.dtype) * 2 * bound - bound + # # Copy the result into y.weight + # y.load_parameters({"weight": new_weight}) + def weight_replacement_conv2d_optical(x, y): """ From e956c49a6f8f8b81591eebb8605d032fc7fbcc42 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 24 Jun 2025 17:13:50 +0100 Subject: [PATCH 34/38] remove breakpoint use black to format code --- .gitignore | 4 +- src/chop/nn/optical/modules/QuantizedGPT.py | 642 ------------------ src/chop/nn/optical/modules/__init__.py | 7 +- .../nn/optical/modules/morr_custom_linear.py | 494 -------------- src/chop/nn/optical/modules/morr_linear.py | 2 +- .../modules/morr_transformer/morr_bert.py | 143 ---- .../modules/morr_transformer/morr_matmul.py | 535 --------------- .../morr_transformer/morr_transformer.py | 168 ----- .../nn/optical/triton_modules/morr_linear.py | 44 +- .../triton_modules/morr_linear_kernel.py | 437 +++++++----- .../triton_modules/morr_linear_kernel_mem.py | 464 +++++++------ .../optical/triton_modules/morr_linear_mem.py | 46 +- .../nn/optical/triton_modules/quantize.py | 12 +- src/chop/nn/snn/modules/__init__.py | 4 +- .../add_metadata/common_metadata_layers.py | 4 +- .../passes/module/module_modify_helper.py | 18 +- .../attention/attention_transform_helper.py | 5 +- .../optical/module_transform_helper.py | 198 ++---- .../module/transforms/optical/optical.py | 58 +- .../difflogic_layers/passes.py | 1 - .../attention/test_attention_transform.py | 4 +- .../transforms/optical/bert-finetune.py | 140 ---- .../module/transforms/optical/playground.py | 22 - .../module/transforms/optical/run_glue.py | 637 ----------------- .../transforms/optical/test_optical_module.py | 52 ++ 25 files changed, 748 insertions(+), 3393 deletions(-) delete mode 100644 src/chop/nn/optical/modules/QuantizedGPT.py delete mode 100644 src/chop/nn/optical/modules/morr_custom_linear.py delete mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_bert.py delete mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_matmul.py delete mode 100644 src/chop/nn/optical/modules/morr_transformer/morr_transformer.py delete mode 100644 test/passes/module/transforms/optical/bert-finetune.py delete mode 100644 test/passes/module/transforms/optical/playground.py delete mode 100644 test/passes/module/transforms/optical/run_glue.py diff --git a/.gitignore b/.gitignore index d633017dd..c1b540763 100644 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,4 @@ test-trainer/ # DiffLogic: tutorial files docs/tutorials/difflogic/data-mnist/ -test/self -model_sst2/ -runs \ No newline at end of file +test/self \ No newline at end of file diff --git a/src/chop/nn/optical/modules/QuantizedGPT.py b/src/chop/nn/optical/modules/QuantizedGPT.py deleted file mode 100644 index 424570db1..000000000 --- a/src/chop/nn/optical/modules/QuantizedGPT.py +++ /dev/null @@ -1,642 +0,0 @@ -import sys -from numpy import outer -import torch -import torch.nn as nn -import pytorch_lightning as pl -import torchmetrics -from transformers import GPT2TokenizerFast -import transformers -import torch.nn.functional as F -import math -import numpy -sys.path.append('...') - - -def make_autoregressive_mask_for(x): - length = x.size(1) - ones = x.new_ones((length, length)) - mask = torch.triu(ones, diagonal=1) != 0.0 - return mask - - -def make_position_indices_for(x): - length = x.size(1) - batch_size = x.size(0) - indices = torch.arange(length, device=x.device).repeat(batch_size, 1) - return indices - - -def load_lookup_table(file, device): - data = torch.from_numpy(numpy.genfromtxt(file, delimiter='\t')).float() - levels = data.size(0) - lower_bound = data[0,1].item() - weight = data[:,1].unsqueeze(1).cuda(device) - return weight, lower_bound, levels - - -def apply_lut_to_normalized(x, lut, bit_degredation=0): - lut_weight, lut_lb, lut_levels = lut - deg_factor = 2**bit_degredation - x = x.mul(lut_levels - deg_factor).div(deg_factor).round().mul(deg_factor).to(dtype=torch.long) - x = F.embedding(x, lut_weight).squeeze(-1) - return x - - -class QuantizeValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x, quant_levels, min_val, max_val, quant_mode, lut_min=None): - with torch.no_grad(): - diff = max_val - min_val - x = x.clamp(min_val, max_val).add(-1.0 * min_val).div(diff + 1e-8).mul(quant_levels - 1) - - if quant_mode == 'det': - x = x.round() - x = x.div(quant_levels - 1).mul(diff).add(min_val) - elif quant_mode == 'rand': - x = x.add(torch.rand_like(x).add(-0.5)).round() # randn* 0.288 gives same std as 0-1 rand(), if want to use normal dist. - x = x.div(quant_levels - 1).mul(diff).add(min_val) - - if lut_min is not None: - pos_x = torch.relu(x) - neg_x = x - pos_x - lms = lut_min * max_val - pos_x[pos_x < lms] = lms - lms = lut_min * torch.abs(min_val) - neg_x[neg_x > -lms] = -lms - x = pos_x + neg_x - - return x - - @staticmethod - def backward(ctx, grad_output): - # STE - return grad_output, None, None, None, None, None - - -class QuantizeStats(nn.Module): - def __init__(self, percentile, use_clipping=True): - super(QuantizeStats, self).__init__() - self.register_buffer('running_min', torch.tensor(0.0)) - self.register_buffer('running_max', torch.tensor(0.0)) - self.max_calibration_steps = 1 - self.initial_calibration_steps = 0 - #self.register_buffer('calibration_done', torch.tensor(False)) - self.calibration_done = torch.tensor(False) - self.activations = [] - self.percentile = percentile - self.use_clipping = use_clipping - - def update(self, tensor): - if self.use_clipping: - if not self.calibration_done.item(): - self.initial_calibration_steps += 1 - finished = False - - if self.initial_calibration_steps >= self.max_calibration_steps: - finished = True - self.calibration_done = torch.tensor(True) - - with torch.no_grad(): - self.activations.extend(tensor.detach().cpu().tolist()) - - if finished: - maximum = numpy.percentile(self.activations, self.percentile) - self.running_max = torch.tensor(maximum, device=tensor.device, dtype=tensor.dtype) - minimum = tensor.min() - minimum = minimum if minimum >= 0.0 else -maximum - self.running_min = torch.tensor(minimum, device=tensor.device, dtype=tensor.dtype) - self.activations.clear() # free the memory - else: - self.running_min = tensor.min() - self.running_max = tensor.max() - - else: - alpha = 0.999 - with torch.no_grad(): - cur_min = tensor.min() - cur_max = tensor.max() - - if self.initial_calibration_steps == 0: - self.initial_calibration_steps += 1 - self.running_min = cur_min - self.running_max = cur_max - else: - self.running_min = alpha * self.running_min + (1.0 - alpha) * cur_min - self.running_max = alpha * self.running_max + (1.0 - alpha) * cur_max - - - - def get(self): - return self.running_min, self.running_max - - -def shot_noise_linear(w, x, n_photons_target, phone_lut=None, slm_lut=None, extract=False, extract_name=None): - noise_level = 0.021 - - if n_photons_target != 0: - quantize = QuantizeValue.apply - use_lut = (phone_lut is not None) and (slm_lut is not None) - w_max = torch.max(w) - w_norm = apply_lut_to_normalized(w / (1e-8 + w_max), slm_lut) if use_lut else w / (1e-8 + w_max) - x_max = torch.max(x, dim=2, keepdim=True)[0] - x_norm = apply_lut_to_normalized(x / (1e-8 + x_max), phone_lut, bit_degredation=0) if use_lut else x / (1e-8 + x_max) - - out_opt = F.linear(x_norm, w_norm, bias=None) - photons_per_act = n_photons_target * x_norm.size(2) / (x_norm.sum(dim=2, keepdim=True) + 1e-8) - fluence_Wx = out_opt * photons_per_act - noise_Wx = torch.poisson(fluence_Wx) - out = noise_Wx / photons_per_act - - random_noise = noise_level * out.mean() - out = torch.normal(out, random_noise) - - out = x_max * out * w_max - else: - out = F.linear(x, w, bias=None) - - if extract and n_photons_target != 0: - torch.save({'x': x_norm[1, :512, :].detach().clone(), - 'w': w_norm[:512].detach().clone(), - 'out': out_opt[1, :512, :512].detach().clone(), - 'noise_level': noise_level}, - #'noise_value': random_noise}, - extract_name) - - return out - - -def shot_noise_bhmm(x, y, n_photons_target, phone_lut=None, slm_lut=None, extract=False, extract_name=None): - # perform xy matrix-multiply like matrix vector, where matrix "slices" in y are like W and x is the vectors. Thus take max over whole matrices in y as we would for W - noise_level = 0.0565 - - if n_photons_target != 0: - quantize = QuantizeValue.apply - use_lut = (phone_lut is not None) and (slm_lut is not None) - x_max = torch.max(x, dim=3, keepdim=True)[0] - x_norm = apply_lut_to_normalized(x / (1e-8 + x_max), phone_lut, bit_degredation=0) if use_lut else x / (1e-8 + x_max) - y_max = torch.amax(y, dim=(2, 3), keepdim=True) - y_norm = apply_lut_to_normalized(y / (1e-8 + y_max), slm_lut) if use_lut else y / (1e-8 + y_max) - - out_opt = torch.matmul(x_norm, y_norm) - photons_per_act = n_photons_target * x_norm.size(3) / (x_norm.sum(dim=3, keepdim=True) + 1e-8) - fluence_mm = out_opt * photons_per_act - noise_Wx = torch.poisson(fluence_mm) - out = noise_Wx / photons_per_act - - random_noise = noise_level * out.mean() - out = torch.normal(out, random_noise) - - out = x_max * out * y_max - else: - out = torch.matmul(x, y) - - if extract and n_photons_target != 0: - torch.save({'x': x_norm[0, 0, :, :].detach().clone(), - 'y': y_norm[0, 0, :, :].detach().clone(), - 'out': out_opt[0, 0, :, :].detach().clone(), - 'noise_level': noise_level}, - #'noise_value': random_noise}, - extract_name) - - return out - - -class QuantizedLinear(nn.Module): - def __init__(self, in_feats, out_feats, use_noise=True): - super(QuantizedLinear, self).__init__() - self.weight = nn.Parameter(torch.zeros(out_feats, in_feats)) - self.input_stats = QuantizeStats(99.99) - self.output_stats = QuantizeStats(99.9999) - nn.init.xavier_uniform_(self.weight) - self.quantize = False - self.photon_target = 0 - self.slm_lut = load_lookup_table('LUTs/SLM_AmpLUT.txt', device=torch.device('cuda:0')) - self.phone_lut = load_lookup_table('LUTs/PhoneLUT.txt', device=torch.device('cuda:0')) - self.use_lut = (self.slm_lut is not None) and (self.phone_lut is not None) - if self.use_lut: - _, self.slm_cutoff, _ = self.slm_lut - else: - self.slm_cutoff = None - self.force_quantized_eval = False - self.extract_simulated = False - self.extract_name = '' - self.noise = use_noise - #print('L module using LUT: {}'.format(self.use_lut)) - - def _weight_min(self): - with torch.no_grad(): - return self.weight_min - - def _weight_max(self): - with torch.no_grad(): - return self.weight_max - - def enable_quantization(self, clipping=True): - with torch.no_grad(): - if clipping: - weight_values = self.weight.detach().cpu().tolist() - maximum = numpy.percentile(weight_values, 99).item() - self.weight_max = torch.tensor(maximum, dtype=self.weight.dtype, device=self.weight.device) - self.weight_min = torch.tensor(-maximum, dtype=self.weight.dtype, device=self.weight.device) - else: - self.weight_min = self.weight.min() - self.weight_max = self.weight.max() - self.quantize = True - - def set_photon_target(self, n_photons): - self.photon_target = n_photons - - def forward(self, x): - if self.quantize: - quantize = QuantizeValue.apply - if self.training or self.force_quantized_eval: - # QAT for activations - if self.training: - self.input_stats.update(x) - input_min, input_max = self.input_stats.get() - quantized_x = quantize(x, 256, input_min, input_max, 'det') - quantized_weights = quantize(self.weight, 256, self._weight_min(), self._weight_max(), 'det', self.slm_cutoff) # 160 - out = F.linear(quantized_x, quantized_weights, bias=None) - if self.training: - self.output_stats.update(out) - output_min, output_max = self.output_stats.get() - quantized_out = quantize(out, 256, output_min, output_max, 'rand') - return quantized_out - else: - # shot noise simulation for linear layer, per-token - input_min, input_max = self.input_stats.get() - weight_min, weight_max = self._weight_min(), self._weight_max() - - if self.use_lut: - w = self.weight.clamp(weight_min, weight_max) - x = x.clamp(input_min, input_max) - else: - quantize = QuantizeValue.apply - x = quantize(x, 256, input_min, input_max, 'det') - w = quantize(self.weight, 256, weight_min, weight_max, 'det', self.slm_cutoff) - - pos_x = F.relu(x) - neg_x = torch.abs(x - pos_x) - pos_w = F.relu(w) - neg_w = torch.abs(w - pos_w) - out = shot_noise_linear(pos_w, pos_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_0.pt') \ - + shot_noise_linear(neg_w, neg_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_1.pt') \ - - shot_noise_linear(pos_w, neg_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_2.pt') \ - - shot_noise_linear(neg_w, pos_x, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_3.pt') - - output_min, output_max = self.output_stats.get() - out = out.clamp(output_min, output_max) - #out = quantize(out, 64, output_min, output_max, 'det') - return out - else: - out = F.linear(x, self.weight, bias=None) - return out - - -class QuantizedMatmul(nn.Module): - def __init__(self): - super(QuantizedMatmul, self).__init__() - self.input1_stats = QuantizeStats(99.99) - self.input2_stats = QuantizeStats(98) - self.output_stats = QuantizeStats(99.9999) - self.quantize = False - self.photon_target = 0 - self.slm_lut = load_lookup_table('LUTs/SLM_AmpLUT.txt', device=torch.device('cuda:0')) - self.phone_lut = load_lookup_table('LUTs/PhoneLUT.txt', device=torch.device('cuda:0')) - self.use_lut = (self.slm_lut is not None) and (self.phone_lut is not None) - if self.use_lut: - _, self.slm_cutoff, _ = self.slm_lut - else: - self.slm_cutoff = None - self.force_quantized_eval = False - self.extract_simulated = False - self.extract_name = '' - #print('MM module using LUT: {}'.format(self.use_lut)) - - def enable_quantization(self): - self.quantize = True - - def set_photon_target(self, n_photons): - self.photon_target = n_photons - - def forward(self, x, y): - if self.quantize: - quantize = QuantizeValue.apply - if self.training or self.force_quantized_eval: - # QAT for activations - if self.training: - self.input1_stats.update(x) - self.input2_stats.update(y) - x_min, x_max = self.input1_stats.get() - y_min, y_max = self.input2_stats.get() - xq = quantize(x, 256, x_min, x_max, 'det') - yq = quantize(y, 256, y_min, y_max, 'det', self.slm_cutoff) - out = torch.matmul(xq, yq) - if self.training: - self.output_stats.update(out) - out_min, out_max = self.output_stats.get() - outq = quantize(out, 256, out_min, out_max, 'rand') - return outq - else: - # Shot noise simulation for broadcasted matrix-matrix multiply - x_min, x_max = self.input1_stats.get() - y_min, y_max = self.input2_stats.get() - - if self.use_lut: - x = x.clamp(x_min, x_max) - y = y.clamp(y_min, y_max) - else: - quantize = QuantizeValue.apply - x = quantize(x, 256, x_min, x_max, 'det') - y = quantize(y, 256, y_min, y_max, 'det', self.slm_cutoff) - - pos_x = F.relu(x) - neg_x = torch.abs(x - pos_x) - pos_y = F.relu(y) - neg_y = torch.abs(y - pos_y) - out = shot_noise_bhmm(pos_x, pos_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_0.pt') \ - + shot_noise_bhmm(neg_x, neg_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_1.pt') \ - - shot_noise_bhmm(pos_x, neg_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_2.pt') \ - - shot_noise_bhmm(neg_x, pos_y, self.photon_target, self.phone_lut, self.slm_lut, self.extract_simulated, self.extract_name + '_3.pt') - - output_min, output_max = self.output_stats.get() - out = out.clamp(output_min, output_max) - #out = quantize(out, 64, output_min, output_max, 'det') - return out - else: - out = torch.matmul(x, y) - return out - - -class QuantizedMHA(nn.Module): - def __init__(self, embed_dim, heads): - super(QuantizedMHA, self).__init__() - assert embed_dim % heads == 0 - self.n_heads = heads - self.Wq = QuantizedLinear(embed_dim, embed_dim) - self.Wk = QuantizedLinear(embed_dim, embed_dim) - self.Wv = QuantizedLinear(embed_dim, embed_dim) - self.qmm1 = QuantizedMatmul() - self.dropout_wq = nn.Dropout(0.1) - self.dropout_wk = nn.Dropout(0.1) - self.dropout_wv = nn.Dropout(0.1) - self.qmm2 = QuantizedMatmul() - self.Wout = QuantizedLinear(embed_dim, embed_dim) - self.dropout1 = nn.Dropout(0.1) - self.dropout2 = nn.Dropout(0.1) - - def forward(self, x, mask): - b = x.size(0) - n = x.size(1) - h = self.n_heads - d = x.size(2) - - def arrange_heads(acts): - # incoming shape of b, n, d, want b, h, n, d/h - return acts.view(b, n, h, -1).transpose(1, 2) - - q = arrange_heads(self.dropout_wq(self.Wq(x))) - k = arrange_heads(self.dropout_wk(self.Wk(x))) - v = arrange_heads(self.dropout_wv(self.Wv(x))) - - attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n - masked = attn.masked_fill(mask, float("-inf")) - softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) - out = self.qmm2(softmax_attn, v) # b, h, n, d/h - - out = out.transpose(1, 2).reshape(b, n, -1) - out = self.dropout2(out) - out = self.Wout(out) - return out - - -class QuantizedFF(nn.Module): - def __init__(self, embed_dim, expansion_dim): - super(QuantizedFF, self).__init__() - self.first_drop = nn.Dropout(0.1) - self.layer1 = QuantizedLinear(embed_dim, expansion_dim, use_noise=True) - self.act = nn.ReLU6(inplace=True) - self.dropout = nn.Dropout(0.1) - self.layer2 = QuantizedLinear(expansion_dim, embed_dim, use_noise=True) - - def forward(self, x): - out = self.first_drop(x) - out = self.layer1(out) - out = self.act(out) - out = self.dropout(out) - out = self.layer2(out) - return out - - -class QuantizedDecoderLayer(nn.Module): - def __init__(self, features, heads): - super(QuantizedDecoderLayer, self).__init__() - self.norm1 = nn.LayerNorm(features) - self.attn = QuantizedMHA(features, heads) - self.drop1 = nn.Dropout(0.1) - self.norm2 = nn.LayerNorm(features) - self.ff = QuantizedFF(features, features * 4) - self.drop2 = nn.Dropout(0.1) - - def forward(self, x, attn_mask): - # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right - identity = x - out = self.norm1(x) - out = self.attn(out, attn_mask) - out = self.drop1(out) - out = out + identity - identity = out - out = self.norm2(out) - out = self.ff(out) - out = self.drop2(out) - out = out + identity - return out - - -class _QuantizedGPT(nn.Module): - def __init__(self, features, heads, tokenizer, layers, max_length): - super(_QuantizedGPT, self).__init__() - vocab_size = len(tokenizer) + 8 - len(tokenizer) % 8 # pad vocab size to 8-multiple for tensor core acceleration - assert vocab_size % 8 == 0 - self.pos_embedding = nn.Embedding(max_length, features) - self.word_embedding = nn.Embedding(vocab_size, features, padding_idx = tokenizer.pad_token_id) - self.embedding_dropout = nn.Dropout(0.1) - self.decoders = nn.ModuleList([QuantizedDecoderLayer(features, heads) for _ in range(layers)]) - self.norm = nn.LayerNorm(features) - self.output_head = nn.Linear(features, vocab_size) - nn.init.normal_(self.word_embedding.weight, std=0.02) - nn.init.normal_(self.pos_embedding.weight, std=0.02) - - def forward_embedding(self, x): - embedded = self.word_embedding(x) - return embedded - - def forward_attn(self, x): - mask = make_autoregressive_mask_for(x) - pos = make_position_indices_for(x) - pos_embed = self.embedding_dropout(self.pos_embedding(pos) + x) - decoded = pos_embed - for layer in self.decoders: - decoded = layer(decoded, mask) - - out = self.norm(decoded) - return out - - def forward(self, x): - embedded = self.forward_embedding(x) - decoded = self.forward_attn(embedded) - out = self.output_head(decoded) - return out - - -class QuantizedGPT(pl.LightningModule): - def __init__(self, features, heads, layers=6, max_length=1024): - super().__init__() - self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") - self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'}) - self.transformer = _QuantizedGPT(features, heads, self.tokenizer, layers, max_length) - self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) - self.val_loss = torchmetrics.MeanMetric() - self.test_loss = torchmetrics.MeanMetric() - self.lr = 0.0005 - self.photon_target = 0 - self.training_steps = 100000 - self.extracting = False - self.use_adam = True - - def get_tokenizer(self): - return self.tokenizer - - def forward(self, x): - return self.transformer(x) - - def training_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.log('train loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.val_loss.update(loss) - - def validation_epoch_end(self, outputs): - self.log('validation loss', self.val_loss) - - def test_step(self, batch, batch_idx): - xs, ys = batch - preds = self(xs) - features = preds.size(2) - preds = preds.view(-1, features) - ys = ys.view(-1) - loss = self.loss(preds, ys) - self.test_loss.update(loss) - if self.extracting: - raise ValueError("Extraction done, aborting") - - def test_epoch_end(self, outputs): - self.log('test loss', self.test_loss) - self.log('photon target', self.photon_target) - - def configure_optimizers(self): - if self.use_adam: - decay = set() - no_decay = set() - blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) - - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(recurse=False): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - - if 'bias' in pn: - no_decay.add(fpn) - elif 'weight' in pn and not isinstance(m, blacklist_weight_modules): - decay.add(fpn) - else: - no_decay.add(fpn) - - param_dict = {pn: p for pn, p in self.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) - - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.02}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - - optimizer = torch.optim.AdamW(optim_groups, lr=self.lr) - scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) - return { - 'optimizer': optimizer, - 'lr_scheduler': { - 'scheduler': scheduler, - 'interval': 'step', - 'name': 'Cosine LR scheduler' - } - } - else: - optimizer = torch.optim.RMSprop(self.parameters(), lr=self.lr, weight_decay=1e-5) - scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=2500, num_training_steps=self.training_steps) - return { - 'optimizer': optimizer, - 'lr_scheduler': { - 'scheduler': scheduler, - 'interval': 'step', - 'name': 'Cosine LR scheduler' - } - } - - def replace_output_head(self, module): - self.transformer.output_head = module - - def enable_quantization(self): - for m in self.transformer.modules(): - if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): - m.enable_quantization() - - def set_photon_target(self, n_photons): - self.photon_target = n_photons - for m in self.transformer.modules(): - if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): - m.set_photon_target(n_photons) - - def set_quantized_eval(self, value=True): - for m in self.transformer.modules(): - if isinstance(m, QuantizedLinear) or isinstance(m, QuantizedMatmul): - print("setting quantized eval") - m.force_quantized_eval = value - - def save(self, fname): - torch.save(self.transformer.state_dict(), fname) - - def load(self, fname): - self.transformer.load_state_dict(torch.load(fname)) - - def enable_extraction(self): - lin1 = self.transformer.decoders[0].ff.layer2 - lin1.extract_simulated = True - lin1.extract_name = 'first_linear' - lin2 = self.transformer.decoders[-1].ff.layer2 - lin2.extract_simulated = True - lin2.extract_name = 'last_linear' - attn1 = self.transformer.decoders[0].attn.qmm1 - attn1.extract_simulated = True - attn1.extract_name = 'first_attn' - attn2 = self.transformer.decoders[-1].attn.qmm1 - attn2.extract_simulated = True - attn2.extract_name = 'last_attn' - self.extracting = True - \ No newline at end of file diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index cddc777fd..539ddb38a 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -1,16 +1,11 @@ from .morr_linear import AllPassMORRCirculantLinear from .morr_conv2d import AllPassMORRCirculantConv2d -from .morr_custom_linear import AllPassMORRLinear from ..triton_modules.morr_linear import TritonMORRLinear from ..triton_modules.morr_linear_mem import TritonMemMORRLinear -from .morr_transformer.morr_bert import BertMORRSelfAttention optical_module_map = { "linear_morr": AllPassMORRCirculantLinear, "conv2d_morr": AllPassMORRCirculantConv2d, - "linear_morr_full": AllPassMORRLinear, - "linear_morr_triton": TritonMORRLinear, - "linear_morr_triton_mem": TritonMemMORRLinear, - "bert_self_attention_morr": BertMORRSelfAttention, + "linear_morr_triton": TritonMemMORRLinear, } diff --git a/src/chop/nn/optical/modules/morr_custom_linear.py b/src/chop/nn/optical/modules/morr_custom_linear.py deleted file mode 100644 index 9e9f5272a..000000000 --- a/src/chop/nn/optical/modules/morr_custom_linear.py +++ /dev/null @@ -1,494 +0,0 @@ -""" -Description: -Author: Jiaqi Gu (jqgu@utexas.edu) -Date: 2022-04-18 14:19:57 -LastEditors: Jiaqi Gu (jqgu@utexas.edu) -LastEditTime: 2022-04-18 16:21:37 -""" - -from typing import Optional -import logging - -import numpy as np -import torch -import torch.fft -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device -import torch.nn.functional as F - -from ..utils import MORRConfig_20um_MQ -from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ..utils import toeplitz -from ..utils import morr_uniform_ -from ..utils import input_quantize_fn, weight_quantize_fn -from .base_layer import ONNBaseLayer - -logger = logging.getLogger(__name__) - -__all__ = ["AllPassMORRCirculantLinear"] - - -class AllPassMORRLinear(ONNBaseLayer): - """ - All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. - J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" - https://doi.org/10.23919/DATE51398.2021.9474147 - """ - - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - miniblock: int - weight: Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - config=None, - device: Device = torch.device("cpu"), - ) -> None: - super(AllPassMORRLinear, self).__init__() - self.in_features = in_features - self.out_features = out_features - - miniblock_size = config.get("miniblock", 4) - self.miniblock = miniblock_size - # M * N/k MORR grid - self.grid_dim_x = int(np.ceil(self.in_features / (self.miniblock))) - self.grid_dim_y = out_features - self.in_features_pad = self.grid_dim_x * miniblock_size - self.out_features_pad = self.grid_dim_y * miniblock_size - - self.v_max = 10.8 - self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 - self.w_bit = 32 - self.in_bit = 32 - - morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) - morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) - self.MORRConfig = morr_config - self.morr_init = morr_init_val - self.mrr_a = morr_config.attenuation_factor - self.mrr_r = morr_config.coupling_factor - self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) - self.trainable_morr_scale = config.get( - "trainable_morr_scale", MORRConfig_20um_MQ - ) - self.device = device - ### calculate FWHM (rad) - self.morr_fwhm = ( - -4 - * np.pi**2 - * morr_config.radius - * morr_config.effective_index - * ( - 1 / morr_config.resonance_wavelength - - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) - ) - ) - - ### allocate parameters - self.weight = None - self.x_zero_pad = None - self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs - self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = ( - None ## scaling factor for the round-trip phase shift within MORR - ) - self.morr_gain = ( - 100 / (self.in_features // self.miniblock) - ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 - ### build trainable parameters - self.build_parameters() - - ### quantization tool - self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) - self.weight_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_pos" - ) ## [0-1] positive only, maintain the original scale - self.morr_output_scale_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_sym" - ) ## [-1,1] full-range - - self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( - a=self.mrr_a, r=self.mrr_r, intensity=True - ) - - ### default set to slow forward - self.disable_fast_forward() - ### default set no gamma noise - self.set_gamma_noise(0) - ### default set no crosstalk - self.disable_crosstalk() - ### default set no phase variation - self.disable_phase_variation() - - if bias: - self.bias = Parameter(torch.Tensor(out_features).to(self.device)) - else: - self.register_parameter("bias", None) - - self.reset_parameters(morr_init=morr_init_val) - self.finegrain_drop_mask = None - - def build_parameters(self) -> None: - - self.weight = Parameter( - torch.ones( - self.grid_dim_y, - self.grid_dim_x, - self.miniblock, - device=self.device, - dtype=torch.float, - ) - ) - ### Learnable balancing factor (morr_output_scale) - ### We use a single scaling factor for each block - - # Init this to ones and non-trainable - # TODO: Verify the effectiveness of making this trainable - self.morr_output_scale = Parameter( - torch.ones(1, 1, max(1, self.grid_dim_x), 1, device=self.device) - ) - # self.morr_output_scale.requires_grad = False - - if self.trainable_morr_bias: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_bias = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - if self.trainable_morr_scale: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_scale = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - - def reset_parameters(self, morr_init: bool = False) -> None: - ### nonlinear curve aware initialization - if morr_init: - ## initialize weight - morr_uniform_( - self.weight, - MORRConfig=self.MORRConfig, - n_op=self.miniblock, - biased=self.w_bit >= 16, - gain=2 if self.in_bit < 16 else 1, - ) # quantization needs zero-center - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - - ## output distribution aware initialization to output scaling factor - t1 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True - ) - t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), - a=self.mrr_a, - r=self.mrr_r, - intensity=True, - ) - g = ( - (t2 - t1) / (2.4 * self.morr_fwhm) - ).item() ## 0~2.4 FWHM slope as a linear approximation - - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) - self.out_scale_quant_gain = None - init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) - else: - init.kaiming_normal_(self.weight.data) - init.kaiming_normal_(self.morr_output_scale.data) - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - self.sigma_out_scale = self.morr_output_scale.data.std().item() - self.out_scale_quant_gain = None - - if self.morr_input_bias is not None: - self.morr_input_bias.data.zero_() - if self.morr_input_scale is not None: - ### after sigmoid, it cooresponds to 1 scale - init.normal_(self.morr_input_scale.data, 2, 0.1) - - if self.bias is not None: - init.uniform_(self.bias, 0, 0) - - def sync_parameters(self, src: str = "weight") -> None: - """ - description: synchronize all parameters from the source parameters - """ - - raise NotImplementedError - - def build_weight(self) -> Tensor: - if self.w_bit < 16: - ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) - weight = self.weight_quantizer(self.weight) - - ## rescale weights after quantization can maintain the initialization distribution - if self.weight_quant_gain is None: - self.weight_quant_gain = self.sigma_weight / weight.data.std() - if self.trainable_morr_scale: - morr_scale = self.morr_scale * self.weight_quant_gain - else: - morr_scale = self.weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization - - ### quantize learnable balancing factor - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - else: - weight = self.weight.abs() # positive only - morr_output_scale = ( - self.morr_output_scale - self.morr_output_scale.data.mean() - ) - - if self.finegrain_drop_mask is not None: - weight = weight.mul(self.finegrain_drop_mask.float()) - - # morr_output_scale processing is removed here - - return weight, self.morr_output_scale.squeeze(-1).unsqueeze(0) - - def enable_fast_forward(self) -> None: - self.fast_forward_flag = True - - def disable_fast_forward(self) -> None: - self.fast_forward_flag = False - - def set_gamma_noise( - self, noise_std: float, random_state: Optional[int] = None - ) -> None: - self.gamma_noise_std = noise_std - - def load_parameters(self, param_dict) -> None: - """ - description: update parameters based on this parameter dictionary\\ - param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} - """ - for name, param in param_dict.items(): - getattr(self, name).data.copy_(param) - - def set_weight_bitwidth(self, w_bit: int) -> None: - self.w_bit = w_bit - self.weight_quantizer.set_bitwidth(w_bit) - self.morr_output_scale_quantizer.set_bitwidth(w_bit) - - def set_input_bitwidth(self, in_bit: int) -> None: - self.in_bit = in_bit - self.input_quantizer.set_bitwidth(in_bit) - - def input_modulator(self, x: Tensor) -> Tensor: - ### voltage to power, which is proportional to the phase shift - return x * x - - def set_crosstalk_coupling_matrix( - self, coupling_factor: float, drop_perc: float = 0 - ) -> None: - ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. - ### drop-perc is the pruning percentage. - assert 0 <= coupling_factor <= 1, logger.error( - f"Coupling factor must in [0,1], but got {coupling_factor}" - ) - - self.crosstalk_factor = ( - 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor - ) - - def enable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = True - - def disable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = False - - def set_phase_variation(self, phase_noise_std: float = 0) -> None: - self.phase_noise_std = phase_noise_std - - def enable_phase_variation(self) -> None: - self.enable_phase_noise = True - - def disable_phase_variation(self) -> None: - self.enable_phase_noise = False - - def enable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = True - - def disable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = False - - def enable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = True - - def disable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = False - - @property - def morr_bias(self) -> Tensor: - if self.morr_input_bias is None: - return None - # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) - return self.morr_fwhm * torch.tanh( - self.morr_input_bias.unsqueeze(0).unsqueeze(-1) - ) - - @property - def morr_scale(self) -> Tensor: - if self.morr_input_scale is None: - return None - return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] - - def _compute_single_pass(self, weight: Tensor, x: Tensor, morr_output_scale: Tensor) -> Tensor: - """Helper method to compute a single pass through the MORR.""" - ### x : [bs, N/k, k] - ### weights: [M, N/k, k] - - weight = weight.unsqueeze(0).unsqueeze(-2) # [1, M, N/k, 1, k] - x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, N/k, k, 1] - x = weight.matmul(x) # [bs, M, N/k, 1, 1] - x = x.squeeze(-1) # [bs, M, N/k, 1] - - if self.enable_phase_noise and self.phase_noise_std > 1e-5: - x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) - - # Input scaling/biasing if enabled - if self.trainable_morr_scale: - x = x * self.morr_scale - - if self.trainable_morr_bias: - x = x - self.morr_bias - - # Apply MORR transmission function - x = self.mrr_roundtrip_phase_to_tr(x) - - # Flatten output - x = morr_output_scale.matmul(x) # [1, 1, 1, N/k] x [bs, M, N/k, 1] = [bs, M, 1, 1] - x = x.squeeze(-1).squeeze(-1) # [bs, M] - - return x - - def propagate_morr( - self, weight: Tensor, x: Tensor, morr_output_scale: Optional[Tensor] = None - ) -> Tensor: - """ - @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul - @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators - @param x {torch.Tensor} complex-valued input - @param morr_output_scale {torch.Tensor} learnable balancing factors - @return: y {torch.Tensor} output of attenuators - - Noted Here, we use 4-pass matmals to preserve pre-trained weight losslessly. - """ - ### x : [bs, q, k] - ### weights: [p, q, k] - ### morr_output_scale: [1, 1, 1, q] - - ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable - ## build circulant weight matrix - # crosstalk on the weights are much cheaper to compute than on the phase shift - - # Split weights and inputs into positive and negative parts - pos_weight = F.relu(weight) - neg_weight = -F.relu(-weight) # |W-| - - x = x.view(-1, self.grid_dim_x, self.miniblock) # [bs, q, k] - pos_x = F.relu(x) - neg_x = -F.relu(-x) # |X-| - - if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: - # weight = weight * self.crosstalk_factor - pos_weight = pos_weight * self.crosstalk_factor - neg_weight = neg_weight * self.crosstalk_factor - - # Compute the four passes - result_pp = self._compute_single_pass(pos_weight, pos_x, morr_output_scale)# 1. W+X+ - result_np = self._compute_single_pass(neg_weight, pos_x, morr_output_scale)# 2. |W-|X+ - result_pn = self._compute_single_pass(pos_weight, neg_x, morr_output_scale)# 3. W+|X-| - result_nn = self._compute_single_pass(neg_weight, neg_x, morr_output_scale)# 4. W-X- - - x = result_pp - result_np - result_pn + result_nn - - return x - - def get_finegrain_drop_mask(self, topk: int) -> Tensor: - if self.w_bit < 16: - weight = self.weight_quantizer(self.weight.data) # [p, q, k] - else: - weight = self.weight.data.abs() - indices = weight.argsort(dim=-1) - mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) - - drop_indices = indices[:, :, 0:-topk] - mask.scatter_(2, drop_indices, 0) - self.finegrain_drop_mask = mask - return mask - - def apply_finegrain_drop_mask(self, mask: Tensor) -> None: - if self.w_bit < 16: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) - else: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) - - def forward(self, x: Tensor) -> Tensor: - # adjust output shape if used in transformer - is_transformer = len(x.shape) == 3 - if is_transformer: - B, N, D = x.shape - - assert ( - x.size(-1) == self.in_features - ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" - if self.in_bit < 16: - x = self.input_quantizer(x) - - # ignore morr_output_scale, as we apply a uniform scale for all rows. - weight, morr_output_scale = self.build_weight() - - if self.in_features_pad > self.in_features: - if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): - self.x_zero_pad = torch.zeros( - x.size(0), - self.in_features_pad - self.in_features, - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([x, self.x_zero_pad], dim=1) - - # Find max values for uniform scaling - w_max = weight.abs().max() - x_max = x.abs().max(dim=1, keepdim=True)[0] - - x = x.view(-1, self.grid_dim_x, self.miniblock) - - ### modulation - ### x: [bs, q, k] -> [bs, q, k] - x = self.input_modulator(x) - - ### propagate through morr array - ### x: [bs, q, k] -> [bs, p*k] - x = self.propagate_morr(weight, x, morr_output_scale) - - # Apply uniform scaling - # x = x * x_max * w_max - - if self.out_features < self.out_features_pad: - x = x[..., : self.out_features] - if self.bias is not None: - x = x + self.bias.unsqueeze(0) - - # adjust output shape if used in transformer - if is_transformer: - x = x.view(B, N, self.out_features) - return x \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py index e409070ea..946cf864a 100644 --- a/src/chop/nn/optical/modules/morr_linear.py +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -447,7 +447,7 @@ def forward(self, x: Tensor) -> Tensor: is_transformer = len(x.shape) == 3 if is_transformer: B, N, D = x.shape - + assert ( x.size(-1) == self.in_features ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_bert.py b/src/chop/nn/optical/modules/morr_transformer/morr_bert.py deleted file mode 100644 index 36aff15fc..000000000 --- a/src/chop/nn/optical/modules/morr_transformer/morr_bert.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Optional -import logging - -import numpy as np -import math -import torch -import torch.nn as nn -import torch.fft -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device -import pytorch_lightning as pl -import torchmetrics -import transformers -from transformers import GPT2TokenizerFast -from packaging import version -from typing import List, Optional, Tuple, Union - -from ...utils import MORRConfig_20um_MQ -from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ...utils import toeplitz -from ...utils import morr_uniform_ -from ...utils import input_quantize_fn, weight_quantize_fn -from ..base_layer import ONNBaseLayer -from ..morr_custom_linear import AllPassMORRLinear -from ..morr_linear import AllPassMORRCirculantLinear -from .morr_matmul import AllPassMORRCirculantMatMuls -from .morr_transformer import MORRSdpa - -from transformers.models.bert.modeling_bert import BertSelfAttention -from transformers.utils import ( - get_torch_version, -) - -class BertMORRSelfAttention(BertSelfAttention): - def __init__(self, config, position_embedding_type=None, morr_config=None): - super().__init__(config, position_embedding_type=position_embedding_type) - self.dropout_prob = config.attention_probs_dropout_prob - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - # define MORR object to perform SDPA - self.morr_spda = None - self.morr_config = morr_config - - # Adapted from BertSelfAttention - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. - # logger.warning_once( - # "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - # "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " - # "the manual attention implementation, but specifying the manual implementation will be required from " - # "Transformers version v5.0.0 onwards. This warning can be removed using the argument " - # '`attn_implementation="eager"` when loading the model.' - # ) - return super().forward( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - - bsz, tgt_len, _ = hidden_states.size() - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create - # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) - - self.morr_spda = MORRSdpa( - self.attention_head_size, # Dh - self.num_attention_heads, # H - hidden_states.shape[1], # N - dropout_p=self.dropout_prob, - use_morr=True, - morr_config=self.morr_config, - ) - attn_output = self.morr_spda( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py b/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py deleted file mode 100644 index 3102b6c37..000000000 --- a/src/chop/nn/optical/modules/morr_transformer/morr_matmul.py +++ /dev/null @@ -1,535 +0,0 @@ -from typing import Optional -import logging - -import numpy as np -import math -import torch -import torch.nn as nn -import torch.fft -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device -import pytorch_lightning as pl -import torchmetrics -import transformers -from transformers import GPT2TokenizerFast - -from ...utils import MORRConfig_20um_MQ -from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ...utils import toeplitz -from ...utils import morr_uniform_ -from ...utils import input_quantize_fn, weight_quantize_fn -from ..base_layer import ONNBaseLayer -from ..morr_custom_linear import AllPassMORRLinear -from ..morr_linear import AllPassMORRCirculantLinear - -from transformers import BertModel, BertForSequenceClassification -from transformers.models.gpt2.modeling_gpt2 import ( - GPT2Attention, - GPT2MLP, - GPT2Block, - Conv1D, -) - -logger = logging.getLogger(__name__) - -__all__ = ["AllPassMORRCirculantMatMuls"] - - -class AllPassMORRCirculantMatMuls(ONNBaseLayer): - """ - All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. - J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" - https://doi.org/10.23919/DATE51398.2021.9474147 - """ - - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - miniblock: int - weight: Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - config=None, - device: Device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), - ) -> None: - super(AllPassMORRCirculantMatMuls, self).__init__() - self.in_features = in_features - self.out_features = out_features - - miniblock_size = config.get("miniblock", 4) - self.miniblock = miniblock_size - self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) - self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) - self.in_features_pad = self.grid_dim_x * miniblock_size - self.out_features_pad = self.grid_dim_y * miniblock_size - - self.v_max = 10.8 - self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 - self.w_bit = 32 - self.in_bit = 32 - - morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) - morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) - self.MORRConfig = morr_config - self.morr_init = morr_init_val - self.mrr_a = morr_config.attenuation_factor - self.mrr_r = morr_config.coupling_factor - self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) - self.trainable_morr_scale = config.get( - "trainable_morr_scale", MORRConfig_20um_MQ - ) - self.device = device - ### calculate FWHM (rad) - self.morr_fwhm = ( - -4 - * np.pi**2 - * morr_config.radius - * morr_config.effective_index - * ( - 1 / morr_config.resonance_wavelength - - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) - ) - ) - - ### allocate parameters - self.weight = None - self.x_zero_pad = None - self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs - self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = ( - None ## scaling factor for the round-trip phase shift within MORR - ) - self.morr_gain = ( - 100 / (self.in_features // self.miniblock) - ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 - ### build trainable parameters - self.build_parameters() - - ### quantization tool - self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) - self.weight_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_pos" - ) ## [0-1] positive only, maintain the original scale - self.morr_output_scale_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_sym" - ) ## [-1,1] full-range - - self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( - a=self.mrr_a, r=self.mrr_r, intensity=True - ) - - ### default set to slow forward - self.disable_fast_forward() - ### default set no gamma noise - self.set_gamma_noise(0) - ### default set no crosstalk - self.disable_crosstalk() - ### default set no phase variation - self.disable_phase_variation() - - if bias: - self.bias = Parameter(torch.Tensor(out_features).to(self.device)) - else: - self.register_parameter("bias", None) - - self.reset_parameters(morr_init=morr_init_val) - self.finegrain_drop_mask = None - - def build_parameters(self) -> None: - - self.weight = Parameter( - torch.ones( - self.grid_dim_y, - self.grid_dim_x, - self.miniblock, - device=self.device, - dtype=torch.float, - ) - ) - ### Learnable balancing factor (morr_output_scale) - ### We use a single scaling factor for each block - self.morr_output_scale = Parameter( - torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) - ) - if self.trainable_morr_bias: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_bias = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - if self.trainable_morr_scale: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_scale = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - - def reset_parameters(self, morr_init: bool = False) -> None: - ### nonlinear curve aware initialization - if morr_init: - ## initialize weight - morr_uniform_( - self.weight, - MORRConfig=self.MORRConfig, - n_op=self.miniblock, - biased=self.w_bit >= 16, - gain=2 if self.in_bit < 16 else 1, - ) # quantization needs zero-center - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - - ## output distribution aware initialization to output scaling factor - t1 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True - ) - t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), - a=self.mrr_a, - r=self.mrr_r, - intensity=True, - ) - g = ( - (t2 - t1) / (2.4 * self.morr_fwhm) - ).item() ## 0~2.4 FWHM slope as a linear approximation - - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) - self.out_scale_quant_gain = None - init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) - else: - init.kaiming_normal_(self.weight.data) - init.kaiming_normal_(self.morr_output_scale.data) - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - self.sigma_out_scale = self.morr_output_scale.data.std().item() - self.out_scale_quant_gain = None - - if self.morr_input_bias is not None: - self.morr_input_bias.data.zero_() - if self.morr_input_scale is not None: - ### after sigmoid, it cooresponds to 1 scale - init.normal_(self.morr_input_scale.data, 2, 0.1) - - if self.bias is not None: - init.uniform_(self.bias, 0, 0) - - def sync_parameters(self, src: str = "weight") -> None: - """ - description: synchronize all parameters from the source parameters - """ - - raise NotImplementedError - - def build_weight(self) -> Tensor: - if self.w_bit < 16: - ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) - weight = self.weight_quantizer(self.weight) - - ## rescale weights after quantization can maintain the initialization distribution - if self.weight_quant_gain is None: - self.weight_quant_gain = self.sigma_weight / weight.data.std() - if self.trainable_morr_scale: - morr_scale = self.morr_scale * self.weight_quant_gain - else: - morr_scale = self.weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization - - ### quantize learnable balancing factor - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - else: - weight = self.weight.abs() # positive only - morr_output_scale = ( - self.morr_output_scale - self.morr_output_scale.data.mean() - ) - - if self.finegrain_drop_mask is not None: - weight = weight.mul(self.finegrain_drop_mask.float()) - - ## differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if self.grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if self.grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - - return weight, morr_output_scale - - def enable_fast_forward(self) -> None: - self.fast_forward_flag = True - - def disable_fast_forward(self) -> None: - self.fast_forward_flag = False - - def set_gamma_noise( - self, noise_std: float, random_state: Optional[int] = None - ) -> None: - self.gamma_noise_std = noise_std - - def load_parameters(self, param_dict) -> None: - """ - description: update parameters based on this parameter dictionary\\ - param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} - """ - for name, param in param_dict.items(): - getattr(self, name).data.copy_(param) - - def set_weight_bitwidth(self, w_bit: int) -> None: - self.w_bit = w_bit - self.weight_quantizer.set_bitwidth(w_bit) - self.morr_output_scale_quantizer.set_bitwidth(w_bit) - - def set_input_bitwidth(self, in_bit: int) -> None: - self.in_bit = in_bit - self.input_quantizer.set_bitwidth(in_bit) - - def input_modulator(self, x: Tensor) -> Tensor: - ### voltage to power, which is proportional to the phase shift - return x * x - - def set_crosstalk_coupling_matrix( - self, coupling_factor: float, drop_perc: float = 0 - ) -> None: - ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. - ### drop-perc is the pruning percentage. - assert 0 <= coupling_factor <= 1, logger.error( - f"Coupling factor must in [0,1], but got {coupling_factor}" - ) - - self.crosstalk_factor = ( - 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor - ) - - def enable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = True - - def disable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = False - - def set_phase_variation(self, phase_noise_std: float = 0) -> None: - self.phase_noise_std = phase_noise_std - - def enable_phase_variation(self) -> None: - self.enable_phase_noise = True - - def disable_phase_variation(self) -> None: - self.enable_phase_noise = False - - def enable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = True - - def disable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = False - - def enable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = True - - def disable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = False - - @property - def morr_bias(self) -> Tensor: - if self.morr_input_bias is None: - return None - # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) - return self.morr_fwhm * torch.tanh( - self.morr_input_bias.unsqueeze(0).unsqueeze(-1) - ) - - @property - def morr_scale(self) -> Tensor: - if self.morr_input_scale is None: - return None - return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] - - def propagate_morr( - self, weight: Tensor, x: Tensor, morr_output_scale: Tensor - ) -> Tensor: - """ - @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul - @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators - @param x {torch.Tensor} complex-valued input - @param morr_output_scale {torch.Tensor} learnable balancing factors - @return: y {torch.Tensor} output of attenuators - """ - ### x : [bs, q, k] - ### weights: [p, q, k] - ### morr_output_scale: [1, 1, 1, q] - - ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable - ## build circulant weight matrix - # crosstalk on the weights are much cheaper to compute than on the phase shift - if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: - weight = weight * self.crosstalk_factor - weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] - x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] - x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] - - if self.enable_phase_noise and self.phase_noise_std > 1e-5: - x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) - - ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] - if self.trainable_morr_bias: - x = x - self.morr_bias - - ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] - ### x is the phase detuning, x=0 means on-resonance - ### phase: [bs, p, q, k] - x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd - - ## implement balancing factor as dot-product - """ - if(self.w_bit < 16): - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - if(self.sigma_out_scale_quant_gain is None): - self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() - morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization - else: - morr_output_scale = self.morr_output_scale - # morr_output_scale = morr_output_scale * self.morr_gain - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - - # print("morr diff transmission:", end=", ") - # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] - # print_stat(diff) - if(self.grid_dim_x % 2 == 0): - #even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if(self.grid_dim_x > 1): - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - # print("output scale Q:", end=", ") - # print_stat(scale[..., :scale.size(-1)//2]) - """ - x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - x = x.flatten(1) # [bs, p*k] - return x - - def get_finegrain_drop_mask(self, topk: int) -> Tensor: - if self.w_bit < 16: - weight = self.weight_quantizer(self.weight.data) # [p, q, k] - else: - weight = self.weight.data.abs() - indices = weight.argsort(dim=-1) - mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) - - drop_indices = indices[:, :, 0:-topk] - mask.scatter_(2, drop_indices, 0) - self.finegrain_drop_mask = mask - return mask - - def apply_finegrain_drop_mask(self, mask: Tensor) -> None: - if self.w_bit < 16: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) - else: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) - - def load_compressed_weight(self, weight: Tensor) -> None: - """ - load weight data from torch.linear module.weight.data - """ - assert weight.shape == (self.out_features, self.in_features), ( - f"Expected {(self.out_features, self.in_features)}, got {weight.shape}" - ) - - W_padded = weight.new_zeros((self.out_features_pad, self.in_features_pad)) - W_padded[: weight.size(0), : weight.size(1)] = weight - - new_weight = weight.new_zeros((self.grid_dim_y, self.grid_dim_x, self.miniblock)) - for p in range(self.grid_dim_y): - for q in range(self.grid_dim_x): - for k in range(self.miniblock): - row_idx = p * self.miniblock + k # The row in W_padded: - col_start = q * self.miniblock # The columns in W_padded: - col_end = (q + 1) * self.miniblock - block = W_padded[row_idx, col_start:col_end] - - new_weight[p, q, k] = block.mean() - bound = 1 / math.sqrt(self.miniblock) - new_weight = torch.rand( - (self.grid_dim_y, self.grid_dim_x, self.miniblock), - device=weight.device, - dtype=weight.dtype - ) * 2 * bound - bound - - self.load_parameters({"weight": new_weight}) - - def forward(self, X: Tensor, Y: Tensor) -> Tensor: - """ - this module currently support 4-D multi-head attn MatMul only - - x: input, [B, H, N, D] - - y: weight, [B, H, D, N] - """ - assert len(X.shape) == 4, f"Expected a 4-D tensor, got shape {X.shape}" - B, H, N, D = X.shape - out_rows = [] - - for b in range(B): - for h in range(H): - self.load_compressed_weight(Y[b, h].t()) - x = X[b, h] - - assert ( - x.size(-1) == self.in_features - ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" - if self.in_bit < 16: - x = self.input_quantizer(x) - - weight, morr_output_scale = self.build_weight() - if self.in_features_pad > self.in_features: - if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): - self.x_zero_pad = torch.zeros( - x.size(0), - self.in_features_pad - self.in_features, - device=x.device, - dtype=x.dtype, - ) - x = torch.cat([x, self.x_zero_pad], dim=1) - - x = x.view(-1, self.grid_dim_x, self.miniblock) - - ### modulation - ### x: [bs, q, k] -> [bs, q, k] - x = self.input_modulator(x) - - ### propagate through morr array - ### x: [bs, q, k] -> [bs, p*k] - x = self.propagate_morr(weight, x, morr_output_scale) - - if self.out_features < self.out_features_pad: - x = x[..., : self.out_features] - if self.bias is not None: - x = x + self.bias.unsqueeze(0) - - out_rows.append(x) - - out = torch.stack(out_rows, dim=0) # (B·H, N, N) - out = out.view(B, H, N, self.out_features) - return out \ No newline at end of file diff --git a/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py b/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py deleted file mode 100644 index a9da28085..000000000 --- a/src/chop/nn/optical/modules/morr_transformer/morr_transformer.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import Optional -import logging - -import numpy as np -import math -import torch -import torch.nn as nn -import torch.fft -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device -import pytorch_lightning as pl -import torchmetrics -import transformers -from transformers import GPT2TokenizerFast - -from ...utils import MORRConfig_20um_MQ -from ...utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ...utils import toeplitz -from ...utils import morr_uniform_ -from ...utils import input_quantize_fn, weight_quantize_fn -from ..base_layer import ONNBaseLayer -from ..morr_custom_linear import AllPassMORRLinear -from ..morr_linear import AllPassMORRCirculantLinear -from .morr_matmul import AllPassMORRCirculantMatMuls - -from transformers import BertModel, BertForSequenceClassification -from transformers.models.gpt2.modeling_gpt2 import ( - GPT2Attention, - GPT2MLP, - GPT2Block, - Conv1D, -) - -logger = logging.getLogger(__name__) - -__all__ = [""] - - - -class MORRMHA(nn.Module): - def __init__(self, embed_dim, heads): - super(MORRMHA, self).__init__() - assert embed_dim % heads == 0 - self.n_heads = heads - self.Wq = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.Wk = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.Wv = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.qmm1 = AllPassMORRCirculantMatMuls() - self.dropout_wq = nn.Dropout(0.1) - self.dropout_wk = nn.Dropout(0.1) - self.dropout_wv = nn.Dropout(0.1) - self.qmm2 = AllPassMORRCirculantMatMuls() - self.Wout = AllPassMORRCirculantLinear(embed_dim, embed_dim) - self.dropout1 = nn.Dropout(0.1) - self.dropout2 = nn.Dropout(0.1) - - def forward(self, x, mask): - b = x.size(0) - n = x.size(1) - h = self.n_heads - d = x.size(2) - - def arrange_heads(acts): - # incoming shape of b, n, d, want b, h, n, d/h - return acts.view(b, n, h, -1).transpose(1, 2) - - q = arrange_heads(self.dropout_wq(self.Wq(x))) - k = arrange_heads(self.dropout_wk(self.Wk(x))) - v = arrange_heads(self.dropout_wv(self.Wv(x))) - - attn = self.qmm1(q, k.transpose(2, 3)) # yields b, h, n, n - masked = attn.masked_fill(mask, float("-inf")) - softmax_attn = self.dropout1(F.softmax(masked / math.sqrt(d // h), dim=3)) - out = self.qmm2(softmax_attn, v) # b, h, n, d/h - - out = out.transpose(1, 2).reshape(b, n, -1) - out = self.dropout2(out) - out = self.Wout(out) - return out - - -class MORRFF(nn.Module): - def __init__(self, embed_dim, expansion_dim): - super(MORRFF, self).__init__() - self.first_drop = nn.Dropout(0.1) - self.layer1 = AllPassMORRCirculantLinear(embed_dim, expansion_dim, use_noise=True) - self.act = nn.ReLU6(inplace=True) - self.dropout = nn.Dropout(0.1) - self.layer2 = AllPassMORRCirculantLinear(expansion_dim, embed_dim, use_noise=True) - - def forward(self, x): - out = self.first_drop(x) - out = self.layer1(out) - out = self.act(out) - out = self.dropout(out) - out = self.layer2(out) - return out - -class MORRDecoderLayer(nn.Module): - def __init__(self, features, heads): - super(MORRDecoderLayer, self).__init__() - self.norm1 = nn.LayerNorm(features) - self.attn = MORRMHA(features, heads) - self.drop1 = nn.Dropout(0.1) - self.norm2 = nn.LayerNorm(features) - self.ff = MORRFF(features, features * 4) - self.drop2 = nn.Dropout(0.1) - - def forward(self, x, attn_mask): - # no need for key mask for gpt; autoregressive masking already prevents 'real' tokens from attending to padding tokens to the right - identity = x - out = self.norm1(x) - out = self.attn(out, attn_mask) - out = self.drop1(out) - out = out + identity - identity = out - out = self.norm2(out) - out = self.ff(out) - out = self.drop2(out) - out = out + identity - return out - - -class MORRSdpa(nn.Module): - def __init__(self, attn_head_size, num_heads, seq_length, dropout_p, use_morr = False, morr_config = None): - super(MORRSdpa, self).__init__() - self.attn_head_size = attn_head_size - self.num_heads = num_heads - self.use_morr = use_morr - self.qmm1 = AllPassMORRCirculantMatMuls( - in_features=attn_head_size, # Dh - out_features=seq_length, # N - config = morr_config - ) - self.qmm1.disable_trainable_morr_scale() - self.qmm1.disable_trainable_morr_bias() - - self.qmm2 = AllPassMORRCirculantMatMuls( - in_features=seq_length, # D - out_features=attn_head_size, # N - config = morr_config - ) - self.qmm2.disable_trainable_morr_scale() - self.qmm2.disable_trainable_morr_bias() - self.dropout = nn.Dropout(dropout_p) - - def forward(self, query, key, value, attn_mask): - attn_head_size = self.attn_head_size - - if self.use_morr: - attn_scores = self.qmm1(query, key.transpose(2, 3)) # yields b, h, n, n - else: - attn_scores = torch.matmul(query, key.transpose(2, 3)) - - attn_scores = attn_scores / math.sqrt(attn_head_size) - if attn_mask is not None: - attn_scores = attn_scores + attn_mask - attn_probs = nn.functional.softmax(attn_scores, dim=-1) - attn_probs = self.dropout(attn_probs) - - if self.use_morr: - out = self.qmm2(attn_probs, value) # [B, H, N, N] * [B, H, N, Dh] -> [b, h, n, Dh] - else: - out = torch.matmul(attn_probs, value) - - return out \ No newline at end of file diff --git a/src/chop/nn/optical/triton_modules/morr_linear.py b/src/chop/nn/optical/triton_modules/morr_linear.py index e10c088e8..f474f7aed 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear.py +++ b/src/chop/nn/optical/triton_modules/morr_linear.py @@ -447,33 +447,37 @@ def forward(self, x: Tensor) -> Tensor: output, *_ = morr_linear_fn( x, self.weight, - morr_input_bias = self.morr_input_bias, - morr_output_scale = self.morr_output_scale, - bias = None, - morr_input_scale = self.morr_input_scale, - morr_bias = self.morr_bias.detach() if self.morr_bias is not None else None, - grid_dim_x = self.grid_dim_x, - grid_dim_y = self.grid_dim_y, - miniblock = self.miniblock, + morr_input_bias=self.morr_input_bias, + morr_output_scale=self.morr_output_scale, + bias=None, + morr_input_scale=self.morr_input_scale, + morr_bias=self.morr_bias.detach() if self.morr_bias is not None else None, + grid_dim_x=self.grid_dim_x, + grid_dim_y=self.grid_dim_y, + miniblock=self.miniblock, enable_thermal_crosstalk=self.enable_thermal_crosstalk, - crosstalk_factor=None if not self.enable_thermal_crosstalk else self.crosstalk_factor, + crosstalk_factor=( + None if not self.enable_thermal_crosstalk else self.crosstalk_factor + ), enable_phase_noise=self.enable_phase_noise, - phase_noise_std=None if not self.enable_phase_noise else self.phase_noise_std, + phase_noise_std=( + None if not self.enable_phase_noise else self.phase_noise_std + ), trainable_morr_bias=self.trainable_morr_bias, mrr_a=self.mrr_a, mrr_r=self.mrr_r, finegrain_drop_mask=None, - in_features = self.in_features, - in_features_pad = self.in_features_pad, - out_features = self.out_features, - out_features_pad = self.out_features_pad, - in_bit = self.in_bit, - w_bit = self.w_bit, - morr_fwhm = self.morr_fwhm, + in_features=self.in_features, + in_features_pad=self.in_features_pad, + out_features=self.out_features, + out_features_pad=self.out_features_pad, + in_bit=self.in_bit, + w_bit=self.w_bit, + morr_fwhm=self.morr_fwhm, sigma_weight=self.sigma_weight, - trainable_morr_scale=self.trainable_morr_scale, # bool + trainable_morr_scale=self.trainable_morr_scale, # bool morr_scale=self.morr_scale, weight_quant_gain=self.weight_quant_gain, - seed = 42, + seed=42, ) - return output \ No newline at end of file + return output diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py index 05c0577e1..6de1d97f5 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py @@ -1,4 +1,5 @@ import os + # os.environ["TRITON_INTERPRET"] = "1" import torch @@ -8,6 +9,7 @@ import pdb from .dtype import TORCH_DTYPE_TO_TRITON + PACKAGE_NAME = "mase_triton" from ..utils import ( toeplitz, @@ -37,6 +39,7 @@ def _get_autotune_configs(): ) return configs + @triton.jit def _mrr_roundtrip_phase_to_tr_func( x: tl.tensor, @@ -59,6 +62,7 @@ def _mrr_roundtrip_phase_to_tr_func( x = tl.sqrt(x) return x + # @triton.autotune( # configs= [ # triton.Config( @@ -75,7 +79,7 @@ def _mrr_roundtrip_phase_to_tr_func( # key=["M", "P", "Q", "K"], # ) @triton.autotune( - configs = _get_autotune_configs(), + configs=_get_autotune_configs(), key=["M", "P", "Q", "K"], ) @triton.jit @@ -99,10 +103,25 @@ def morr_propagate_kernel( w_bit, seed, # stride - stride_wm, stride_wp, stride_wq, stride_wk1, stride_wk2, - stride_xm, stride_xp, stride_xq, stride_xk1, stride_xk2, - stride_bm, stride_bp, stride_bq, stride_bk1, - stride_om, stride_op, stride_oq, stride_ok1, stride_ok2, + stride_wm, + stride_wp, + stride_wq, + stride_wk1, + stride_wk2, + stride_xm, + stride_xp, + stride_xq, + stride_xk1, + stride_xk2, + stride_bm, + stride_bp, + stride_bq, + stride_bk1, + stride_om, + stride_op, + stride_oq, + stride_ok1, + stride_ok2, finegrain_drop_mask, ENABLE_PHASE_NOISE: tl.constexpr, ENABLE_THERMAL_CROSSTALK: tl.constexpr, @@ -131,7 +150,7 @@ def morr_propagate_kernel( start_m = pid_m * BLOCK_SIZE_M start_p = pid_p * BLOCK_SIZE_P start_q = pid_q * BLOCK_SIZE_Q - + # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] offs_wm = tl.arange(0, 1) offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) @@ -171,8 +190,10 @@ def morr_propagate_kernel( + offs_bk1[None, None, None, :, None] * stride_bk1 ) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), dtype=tl.float32) + acc = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), + dtype=tl.float32, + ) m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] @@ -184,24 +205,24 @@ def morr_propagate_kernel( for q_local in range(BLOCK_SIZE_Q): q = start_q + q_local - w_mask = (p < P) & (q < Q) + w_mask = (p < P) & (q < Q) x_mask = (m < M) & (q < Q) b_mask = (p < P) & (q < Q) w = tl.load(w_ptrs, mask=w_mask, other=0.0) x = tl.load(x_ptrs, mask=x_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) - - w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] - x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] - + + w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] + x = x * x # input_modulator() # ----- propagate_morr() ----- # apply thermal crosstalk noise if ENABLE_THERMAL_CROSSTALK: w = w * crosstalk_factor - + # MatMals # TODO: tl.dot requires 16*16 matrix at least, this is a workaround x = tl.trans(x) @@ -212,30 +233,37 @@ def morr_propagate_kernel( # apply phase noise if ENABLE_PHASE_NOISE: block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 - offs = tl.reshape(block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2) , (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + offs = tl.reshape( + block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2), + (BLOCK_SIZE_K1, BLOCK_SIZE_K2), + ) noise = tl.randn(seed, offs) * phase_noise_std x = x + noise - + # add trainable bias b = b.reshape(1, 1) - # pdb.set_trace() + if TRAINABLE_MORR_BIAS: x = x - b - + # mrr_roundtrip_phase_to_tr x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) # store the value in acc using mask res = x - condition_mask = (m_indices == m_local) & (p_indices == p_local) & (q_indices == q_local) + condition_mask = ( + (m_indices == m_local) + & (p_indices == p_local) + & (q_indices == q_local) + ) res = res[None, None, None, :, :] - acc = tl.where(condition_mask, res, acc) + acc = tl.where(condition_mask, res, acc) # propagate pointer along Q dimension w_ptrs += stride_wq x_ptrs += stride_xq b_ptrs += stride_bq - + # Q loop end # reset pointer along Q dimension w_ptrs -= stride_wq * (BLOCK_SIZE_Q) @@ -245,7 +273,7 @@ def morr_propagate_kernel( w_ptrs += stride_wp b_ptrs += stride_bp # x_ptrs += stride_xp # x has P dimension = 1 - + # P loop end # reset pointer along P dimension w_ptrs -= stride_wp * (BLOCK_SIZE_P) @@ -257,8 +285,10 @@ def morr_propagate_kernel( x_ptrs += stride_xm out = acc.to(INPUT_DTYPE) - out = out.reshape(BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1) # [1, 1, q, k, 1] -> [1, 1, q, k] - + out = out.reshape( + BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1 + ) # [1, 1, q, k, 1] -> [1, 1, q, k] + offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) @@ -274,13 +304,14 @@ def morr_propagate_kernel( m_valid = offs_om[:, None, None, None] < M p_valid = offs_op[None, :, None, None] < P q_valid = offs_oq[None, None, :, None] < Q - k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 + k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 o_mask = m_valid & p_valid & q_valid & k_valid tl.store(o_ptrs, out, mask=o_mask) - # pdb.set_trace() + @torch.library.custom_op( - f"{PACKAGE_NAME}::optical_morr_linear_fn", mutates_args={}, + f"{PACKAGE_NAME}::optical_morr_linear_fn", + mutates_args={}, ) def morr_linear_fn( x: Tensor, @@ -315,7 +346,7 @@ def morr_linear_fn( in_quant_alg: str = "dorefa", w_quant_alg: str = "dorefa_pos", morr_output_scale_quant_alg: str = "dorefa_sym", - seed: int=42, + seed: int = 42, ) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float, Tensor, Tensor]: Device = x.device Dtype = x.dtype @@ -349,7 +380,7 @@ def morr_linear_fn( if in_features_pad > D: x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) x = torch.cat([x, x_pad], dim=1) - + assert Q * K == in_features_pad, "input and weight dimension mismatch" assert P * K == out_features_pad, "weight and output dimension mismatch" @@ -360,7 +391,7 @@ def morr_linear_fn( input_quantizer.set_bitwidth(in_bit) ctx_x_quant = x.clone() x = input_quantizer(x) - + # Build weight ctx_w_quant = torch.empty(0, device=Device, dtype=Dtype) if w_bit < 16: @@ -376,19 +407,17 @@ def morr_linear_fn( morr_scale = morr_scale * weight_quant_gain else: morr_scale = weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization ### quantize learnable balancing factor morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") morr_output_scale = morr_output_scale_quantizer(morr_output_scale) else: weight = weight.abs() # positive only - morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) - + morr_output_scale = morr_output_scale - morr_output_scale.data.mean() + if finegrain_drop_mask is not None: weight = weight.mul(finegrain_drop_mask.float()) - + # differential balancing factor concatenation scale = morr_output_scale[..., :-1, :] scale_pad = morr_output_scale[..., -1:, :] @@ -404,26 +433,27 @@ def morr_linear_fn( morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] ctx_morr_output_scale = morr_output_scale.clone() - # Reshape x and weight x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] - x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] - weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] + weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] - x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] + x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] w_ctx = weight.clone() - + # Allocate output output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) # Launch the Triton kernel grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(P, meta["BLOCK_SIZE_P"]) + * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), ) morr_propagate_kernel[grid]( - x_ptr = x, - w_ptr = weight, - o_ptr = output, - b_ptr = morr_bias, + x_ptr=x, + w_ptr=weight, + o_ptr=output, + b_ptr=morr_bias, M=M, P=P, Q=Q, @@ -460,16 +490,18 @@ def morr_linear_fn( stride_ok2=output.stride(4), ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, - TRAINABLE_MORR_BIAS = trainable_morr_bias, + TRAINABLE_MORR_BIAS=trainable_morr_bias, INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], - BLOCK_SIZE_K1 = K, + BLOCK_SIZE_K1=K, ) # Apply output scale output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] - ctx_x_scalematmul = output.clone() # record x input for matmul - output = morr_output_scale.matmul(output) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - output = output.flatten(1) # [bs, p*k] + ctx_x_scalematmul = output.clone() # record x input for matmul + output = morr_output_scale.matmul( + output + ) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + output = output.flatten(1) # [bs, p*k] # Trim output if needed if out_features < out_features_pad: @@ -486,52 +518,51 @@ def morr_linear_fn( # ) return ( - output, - seed, - torch.abs(w_ctx), - x_ctx, - ctx_morr_output_scale, - ctx_x_scalematmul, - morr_scale.clone(), + output, + seed, + torch.abs(w_ctx), + x_ctx, + ctx_morr_output_scale, + ctx_x_scalematmul, + morr_scale.clone(), weight_quant_gain if weight_quant_gain is not None else 0.0, ctx_x_quant, ctx_w_quant, ) - def _morr_linear_setup_context(ctx, inputs, output): """ Save for backward only what the backward routine really needs. """ ( - x, # 0 Tensor – input - weight, # 1 Tensor – learnable weight - morr_input_bias, # 23 Tensor - origin_morr_output_scale, # 3 Original input morr_output_scale - bias, # 4 Tensor | None – bias - morr_input_scale, # 5 Tensor - morr_bias, # 2 Tensor | None - grid_dim_x, # 5 int - grid_dim_y, # 6 int - miniblock, # 7 int (== K) - enable_thermal_crosstalk,# 8 bool - crosstalk_factor, # 9 float - enable_phase_noise, # 10 bool - phase_noise_std, # 11 float - trainable_morr_bias, # 12 bool - mrr_a, # 13 float - mrr_r, # 14 float - finegrain_drop_mask, # 15 Tensor | None - in_features, # 16 int - in_features_pad, # 17 int - out_features, # 18 int - out_features_pad, # 19 int - in_bit, # 20 int - w_bit, # 21 int - morr_fwhm, # 22 float + x, # 0 Tensor – input + weight, # 1 Tensor – learnable weight + morr_input_bias, # 23 Tensor + origin_morr_output_scale, # 3 Original input morr_output_scale + bias, # 4 Tensor | None – bias + morr_input_scale, # 5 Tensor + morr_bias, # 2 Tensor | None + grid_dim_x, # 5 int + grid_dim_y, # 6 int + miniblock, # 7 int (== K) + enable_thermal_crosstalk, # 8 bool + crosstalk_factor, # 9 float + enable_phase_noise, # 10 bool + phase_noise_std, # 11 float + trainable_morr_bias, # 12 bool + mrr_a, # 13 float + mrr_r, # 14 float + finegrain_drop_mask, # 15 Tensor | None + in_features, # 16 int + in_features_pad, # 17 int + out_features, # 18 int + out_features_pad, # 19 int + in_bit, # 20 int + w_bit, # 21 int + morr_fwhm, # 22 float sigma_weight, - trainable_morr_scale, # bool + trainable_morr_scale, # bool _morr_scale, _weight_quant_gain, in_quant_alg, @@ -541,19 +572,19 @@ def _morr_linear_setup_context(ctx, inputs, output): ) = inputs ( - output, - seed, - w_morr, - x_modulator, - morr_output_scale, - x_scalematmul, - morr_scale, + output, + seed, + w_morr, + x_modulator, + morr_output_scale, + x_scalematmul, + morr_scale, weight_quant_gain, x_quant, - w_quant + w_quant, ) = output # ( - # w_morr, + # w_morr, # x_modulator, # ) = aux_tensor @@ -569,13 +600,13 @@ def _morr_linear_setup_context(ctx, inputs, output): c1 = -2.0 * mrr_a * mrr_r c2 = mrr_a * mrr_a + mrr_r * mrr_r c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r - c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r intensity = True mrr_para = (c1, c2, c3, c4, intensity) - + # x_morr: x input of matmal in propagate_morr() - x_morr = x_modulator ** 2 # [m, q, k] - x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] + x_morr = x_modulator**2 # [m, q, k] + x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] # x_mrr: x input of mrr_roundtrip_phase_to_tr() x_mrr = w_morr.matmul(x_morr).squeeze(-1) @@ -586,32 +617,32 @@ def _morr_linear_setup_context(ctx, inputs, output): tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) - # 3. stash tensors + # 3. stash tensors ctx.save_for_backward( - x, # original input - weight, # original weight + x, # original input + weight, # original weight bias if bias is not None else torch.tensor([], device=device, dtype=dtype), - morr_output_scale, # morr_output_scale after modification in build_weight() - x_mrr, # x input for mrr_roundtrip_phase_to_tr() + morr_output_scale, # morr_output_scale after modification in build_weight() + x_mrr, # x input for mrr_roundtrip_phase_to_tr() x_morr, - w_morr, # w input for propagate_morr() matmul + w_morr, # w input for propagate_morr() matmul # morr_bias, - x_modulator, # x input for input_modulator() - # morr_input_bias, - x_scalematmul, # x input for morr_output_scale.matmul + x_modulator, # x input for input_modulator() + # morr_input_bias, + x_scalematmul, # x input for morr_output_scale.matmul tanh_input_bias, morr_input_scale, - morr_scale, # morr_scale after modification in build_weight() - x_quant, # x input for input_quantize_fn() - w_quant, # w input for weight_quantize_fn() - origin_morr_output_scale, # original morr_output_scale + morr_scale, # morr_scale after modification in build_weight() + x_quant, # x input for input_quantize_fn() + w_quant, # w input for weight_quantize_fn() + origin_morr_output_scale, # original morr_output_scale finegrain_drop_mask, ) ctx.tensor_shape = tensor_shape ctx.mrr_para = mrr_para - ctx.in_features = in_features - ctx.in_features_pad = in_features_pad - ctx.out_features = out_features + ctx.in_features = in_features + ctx.in_features_pad = in_features_pad + ctx.out_features = out_features ctx.out_features_pad = out_features_pad ctx.morr_fwhm = morr_fwhm ctx.grid_dim_x = grid_dim_x @@ -631,14 +662,13 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.morr_output_scale_quant_alg = morr_output_scale_quant_alg - def _morr_linear_backward(ctx, grad_output, *ignored): """ Backward pass for morr_linear_fn. """ ( - x, - weight, + x, + weight, bias, morr_output_scale, x_mrr, @@ -654,10 +684,10 @@ def _morr_linear_backward(ctx, grad_output, *ignored): x_quant, w_quant, origin_morr_output_scale, - finegrain_drop_mask + finegrain_drop_mask, ) = ctx.saved_tensors - M, P, Q, K = ctx.tensor_shape + M, P, Q, K = ctx.tensor_shape c1, c2, c3, c4, intensity = ctx.mrr_para in_features = ctx.in_features in_features_pad = ctx.in_features_pad @@ -679,70 +709,73 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) # if ctx.trainable_morr_bias: # x_mrr = x_mrr - morr_bias - + # ----- backward prop ----- # Reshape grad_out = grad_output.view( - x_input_shape[0], - w_input_shape[1], - w_input_shape[2], - -1 + x_input_shape[0], w_input_shape[1], w_input_shape[2], -1 ) # [M, P, Q, K] - + # ----- Gradient w.r.t input x ----- if True or ctx.needs_input_grad[0]: # 1. reshape - grad_out = grad_out.view(M, -1) # [m, out_features] + grad_out = grad_out.view(M, -1) # [m, out_features] if ctx.needs_input_grad[4] and bias: - grad_bias = grad_out.sum(dim=0) # [out_features] + grad_bias = grad_out.sum(dim=0) # [out_features] else: grad_bias = None - out_pad = torch.zeros(grad_out.shape[0], out_features_pad-out_features, device = DEVICE) # [m, out_features_pad - out_features] - grad_out = torch.cat([grad_out, out_pad], dim=1) # [m * out_features_pad] = [m, p*k] + out_pad = torch.zeros( + grad_out.shape[0], out_features_pad - out_features, device=DEVICE + ) # [m, out_features_pad - out_features] + grad_out = torch.cat( + [grad_out, out_pad], dim=1 + ) # [m * out_features_pad] = [m, p*k] # 2. x=x.flatten(1) # input: [m, p**k] - grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] + grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] # dL/d(morr_output_scale) if ctx.needs_input_grad[3]: - grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] - grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] - grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale + grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] + grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] + grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale - t = ctx.grid_dim_x // 2 - grad_output_scale = grad_s.new_zeros((1, 1, t+1, 1)) + t = ctx.grid_dim_x // 2 + grad_output_scale = grad_s.new_zeros((1, 1, t + 1, 1)) if ctx.grid_dim_x % 2 == 0: grad_output_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] elif ctx.grid_dim_x == 1: grad_output_scale = grad_s else: - grad_output_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] - grad_output_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] + grad_output_scale[..., :t, :] = ( + grad_s[..., :t, :] - grad_s[..., t + 1 :, :] + ) + grad_output_scale[..., t : t + 1, :] = grad_s[..., t : t + 1, :] # build_weight() if ctx.w_bit < 16: # morr_output_scale_quantizer() if ctx.morr_output_scale_quant_alg == "dorefa_sym": - # local recompute: - w_in = torch.tanh(origin_morr_output_scale) # [-1, 1] + # local recompute: + w_in = torch.tanh(origin_morr_output_scale) # [-1, 1] r = torch.max(w_in.abs()).detach() # ignore gradient for r here grad_output_scale = (grad_output_scale * 2 * r).clamp_(-1.0, 1.0) - grad_output_scale = grad_output_scale * (1.0 / (2 * r)) + grad_output_scale = grad_output_scale * (1.0 / (2 * r)) grad_output_scale = grad_output_scale * (1.0 - w_in.pow(2)) - + else: raise NotImplementedError else: grad_output_scale = None - + # dL/dx - grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] + grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] # 4. x = mrr_roundtrip_phase_to_tr(x) denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) @@ -751,36 +784,38 @@ def _morr_linear_backward(ctx, grad_output, *ignored): numerator = x_mrr.sin().mul_(c4) else: numerator = x_mrr.sin().mul_(c4 / 2) - denominator = ( - denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) - ) - grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] - + denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] + # 5. x += phase_noise and morr_bias if ctx.needs_input_grad[2]: - grad_inputbias = - grad_x # [bs, p, q, k] - grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] - grad_inputbias = grad_inputbias - tanh_input_bias * tanh_input_bias # [bs, p, q, k] + grad_inputbias = -grad_x # [bs, p, q, k] + grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] + grad_inputbias = ( + grad_inputbias - tanh_input_bias * tanh_input_bias + ) # [bs, p, q, k] grad_inputbias = grad_inputbias.sum(dim=(0, -1)) else: grad_inputbias = None # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] - grad_morr_matmul = grad_x # stash for weight gradient - + grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] + grad_morr_matmul = grad_x # stash for weight gradient + # dL/dx - grad_x = torch.matmul(w_morr.transpose(-1, -2), grad_x) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] - grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] + grad_x = torch.matmul( + w_morr.transpose(-1, -2), grad_x + ) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] + grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] # 7. input modulator(x) - grad_x = grad_x * 2 * x_modulator # [bs, q, k] + grad_x = grad_x * 2 * x_modulator # [bs, q, k] # 8. input reshape B, N, D = x_input_shape - grad_x = grad_x.view(-1, in_features_pad) # [b*n, in_features_pad] - grad_x = grad_x[:, :in_features] # [b*n, in_features = D] + grad_x = grad_x.view(-1, in_features_pad) # [b*n, in_features_pad] + grad_x = grad_x[:, :in_features] # [b*n, in_features = D] # 9.input quantization if ctx.in_bit >= 16 or ctx.in_quant_alg is None: @@ -791,27 +826,29 @@ def _morr_linear_backward(ctx, grad_output, *ignored): raise NotImplementedError # 10. input reshape - grad_x = grad_x.view(B, N, D) # [b, n, d] + grad_x = grad_x.view(B, N, D) # [b, n, d] # ----- Gradient w.r.t weight ----- if True or ctx.needs_input_grad[1]: - + # 0. gradient after x = weight.matmul(x) # grad_morr_matmul # [bs, p, q, k, 1] # 1. x = weight.matmul(x) - grad_w = torch.matmul(grad_morr_matmul, x_morr.transpose(-1,-2)) # [bs,p,q,k,k] - grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] + grad_w = torch.matmul( + grad_morr_matmul, x_morr.transpose(-1, -2) + ) # [bs,p,q,k,k] + grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] # 2. weight = toeplitz(weight) k = grad_w.size(-1) - row = torch.arange(k)[:, None] # (k,1) - col = torch.arange(k)[None, :] # (1,k) - idx = (row - col) & (k - 1) if (k & (k-1)) == 0 else (row - col + k) % k + row = torch.arange(k)[:, None] # (k,1) + col = torch.arange(k)[None, :] # (1,k) + idx = (row - col) & (k - 1) if (k & (k - 1)) == 0 else (row - col + k) % k idx = idx.expand(grad_w.shape).to(DEVICE) buffer = torch.zeros_like(grad_w, device=DEVICE) - buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k]cvb - grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # [p, q, k] + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k]cvb + grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # [p, q, k] # 3. build_weight() if finegrain_drop_mask is not None: @@ -819,23 +856,29 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # morr_scale: [p, q, 1] grad_morr_input_scale = None if ctx.w_bit < 16: - # grad w.r.t morr_scale + # grad w.r.t morr_scale if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: - grad_morr_scale = (grad_w * weight).sum(dim=2, keepdim=True) # [p, q, 1] - grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] + grad_morr_scale = (grad_w * weight).sum( + dim=2, keepdim=True + ) # [p, q, 1] + grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] # ∂L/∂self.morr_input_scale sigmoid_scale = torch.sigmoid(morr_input_scale) - grad_morr_input_scale = (grad_morr_scale * sigmoid_scale * (1-sigmoid_scale)).squeeze(-1) # [p, q] + grad_morr_input_scale = ( + grad_morr_scale * sigmoid_scale * (1 - sigmoid_scale) + ).squeeze( + -1 + ) # [p, q] # grad w.r.t weight - grad_w = grad_w * morr_scale # weight.mul(morr_scale) + grad_w = grad_w * morr_scale # weight.mul(morr_scale) # weight_quantizer() if ctx.w_quant_alg is None: pass elif ctx.w_quant_alg == "dorefa_pos": - # local recompute: + # local recompute: w_in = torch.tanh(w_quant) - r = torch.max(w_in.abs()).detach() + 1e-12 # ε avoids /0 + r = torch.max(w_in.abs()).detach() + 1e-12 # ε avoids /0 # ignore gradient for r here # grad_w = grad_w * (1.0 - w_in.pow(2)) # grad_w = grad_w.clamp_(-1, 1) @@ -847,21 +890,41 @@ def _morr_linear_backward(ctx, grad_output, *ignored): raise NotImplementedError else: grad_w = grad_w * weight.sign() - + return ( - grad_x, # ∂L/∂x - grad_w, # ∂L/∂w - grad_inputbias, # ∂L/∂morr_input_bias + grad_x, # ∂L/∂x + grad_w, # ∂L/∂w + grad_inputbias, # ∂L/∂morr_input_bias grad_output_scale, # ∂L/∂morr_output_scale - grad_bias, # ∂L/∂bias + grad_bias, # ∂L/∂bias grad_morr_input_scale, - None, None, None, None, None, None, None, None, None, - None, None, None, - None, None, None, None, None, None, None, - None, None, None, None + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, ) morr_linear_fn.register_autograd( - _morr_linear_backward, setup_context=_morr_linear_setup_context, -) \ No newline at end of file + _morr_linear_backward, + setup_context=_morr_linear_setup_context, +) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py index 6ecc2e033..c2eb11b27 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -1,4 +1,5 @@ import os + # os.environ["TRITON_INTERPRET"] = "1" import torch @@ -8,12 +9,13 @@ import pdb from .dtype import TORCH_DTYPE_TO_TRITON + PACKAGE_NAME = "mase_triton" from ..utils import ( toeplitz, input_quantize_fn, weight_quantize_fn, - mrr_roundtrip_phase_to_tr_func + mrr_roundtrip_phase_to_tr_func, ) from .quantize import _input_quantize_fn, _weight_quantize_fn @@ -38,6 +40,7 @@ def _get_autotune_configs(): ) return configs + @triton.jit def _mrr_roundtrip_phase_to_tr_func( x: tl.tensor, @@ -60,6 +63,7 @@ def _mrr_roundtrip_phase_to_tr_func( x = tl.sqrt(x) return x + # @triton.autotune( # configs= [ # triton.Config( @@ -76,7 +80,7 @@ def _mrr_roundtrip_phase_to_tr_func( # key=["M", "P", "Q", "K"], # ) @triton.autotune( - configs = _get_autotune_configs(), + configs=_get_autotune_configs(), key=["M", "P", "Q", "K"], ) @triton.jit @@ -100,10 +104,25 @@ def morr_propagate_kernel( w_bit, seed, # stride - stride_wm, stride_wp, stride_wq, stride_wk1, stride_wk2, - stride_xm, stride_xp, stride_xq, stride_xk1, stride_xk2, - stride_bm, stride_bp, stride_bq, stride_bk1, - stride_om, stride_op, stride_oq, stride_ok1, stride_ok2, + stride_wm, + stride_wp, + stride_wq, + stride_wk1, + stride_wk2, + stride_xm, + stride_xp, + stride_xq, + stride_xk1, + stride_xk2, + stride_bm, + stride_bp, + stride_bq, + stride_bk1, + stride_om, + stride_op, + stride_oq, + stride_ok1, + stride_ok2, finegrain_drop_mask, ENABLE_PHASE_NOISE: tl.constexpr, ENABLE_THERMAL_CROSSTALK: tl.constexpr, @@ -132,7 +151,7 @@ def morr_propagate_kernel( start_m = pid_m * BLOCK_SIZE_M start_p = pid_p * BLOCK_SIZE_P start_q = pid_q * BLOCK_SIZE_Q - + # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] offs_wm = tl.arange(0, 1) offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) @@ -172,8 +191,10 @@ def morr_propagate_kernel( + offs_bk1[None, None, None, :, None] * stride_bk1 ) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), dtype=tl.float32) + acc = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), + dtype=tl.float32, + ) m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] @@ -185,18 +206,16 @@ def morr_propagate_kernel( for q_local in range(BLOCK_SIZE_Q): q = start_q + q_local - w_mask = (p < P) & (q < Q) + w_mask = (p < P) & (q < Q) x_mask = (m < M) & (q < Q) b_mask = (p < P) & (q < Q) w = tl.load(w_ptrs, mask=w_mask, other=0.0) x = tl.load(x_ptrs, mask=x_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) - - w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] - x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] - + w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] x = x * x # input_modulator() # ----- propagate_morr() ----- @@ -204,7 +223,7 @@ def morr_propagate_kernel( # apply thermal crosstalk noise if ENABLE_THERMAL_CROSSTALK: w = w * crosstalk_factor - + # MatMals # TODO: tl.dot requires 16*16 matrix at least, this is a workaround x = tl.trans(x) @@ -215,30 +234,37 @@ def morr_propagate_kernel( # apply phase noise if ENABLE_PHASE_NOISE: block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 - offs = tl.reshape(block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2) , (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + offs = tl.reshape( + block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2), + (BLOCK_SIZE_K1, BLOCK_SIZE_K2), + ) noise = tl.randn(seed, offs) * phase_noise_std x = x + noise # add trainable bias b = b.reshape(1, 1) - # pdb.set_trace() + if TRAINABLE_MORR_BIAS: x = x - b - + # mrr_roundtrip_phase_to_tr x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) # store the value in acc using mask res = x - condition_mask = (m_indices == m_local) & (p_indices == p_local) & (q_indices == q_local) + condition_mask = ( + (m_indices == m_local) + & (p_indices == p_local) + & (q_indices == q_local) + ) res = res[None, None, None, :, :] - acc = tl.where(condition_mask, res, acc) + acc = tl.where(condition_mask, res, acc) # propagate pointer along Q dimension w_ptrs += stride_wq x_ptrs += stride_xq b_ptrs += stride_bq - + # Q loop end # reset pointer along Q dimension w_ptrs -= stride_wq * (BLOCK_SIZE_Q) @@ -248,7 +274,7 @@ def morr_propagate_kernel( w_ptrs += stride_wp b_ptrs += stride_bp # x_ptrs += stride_xp # x has P dimension = 1 - + # P loop end # reset pointer along P dimension w_ptrs -= stride_wp * (BLOCK_SIZE_P) @@ -259,9 +285,10 @@ def morr_propagate_kernel( # w_ptrs += stride_wp # weight has M dimension = 1 x_ptrs += stride_xm - out = acc.to(INPUT_DTYPE) - out = out.reshape(BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1) # [1, 1, q, k, 1] -> [1, 1, q, k] + out = out.reshape( + BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1 + ) # [1, 1, q, k, 1] -> [1, 1, q, k] offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) @@ -278,12 +305,14 @@ def morr_propagate_kernel( m_valid = offs_om[:, None, None, None] < M p_valid = offs_op[None, :, None, None] < P q_valid = offs_oq[None, None, :, None] < Q - k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 + k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 o_mask = m_valid & p_valid & q_valid & k_valid tl.store(o_ptrs, out, mask=o_mask) + @torch.library.custom_op( - f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", mutates_args={}, + f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", + mutates_args={}, ) def morr_linear_fn_mem( x: Tensor, @@ -315,7 +344,7 @@ def morr_linear_fn_mem( trainable_morr_scale: bool, morr_scale: Tensor, weight_quant_gain: float | None = None, - seed: int=42, + seed: int = 42, ) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float]: Device = x.device assert x.dtype in ( @@ -348,7 +377,7 @@ def morr_linear_fn_mem( if in_features_pad > D: x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) x = torch.cat([x, x_pad], dim=1) - + assert Q * K == in_features_pad, "input and weight dimension mismatch" assert P * K == out_features_pad, "weight and output dimension mismatch" @@ -357,7 +386,7 @@ def morr_linear_fn_mem( input_quantizer = input_quantize_fn(in_bit, device=Device) input_quantizer.set_bitwidth(in_bit) x = input_quantizer(x) - + # Build weight if w_bit < 16: weight_quantizer = weight_quantize_fn(w_bit, alg="dorefa_pos") @@ -371,19 +400,17 @@ def morr_linear_fn_mem( morr_scale = morr_scale * weight_quant_gain else: morr_scale = weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization ### quantize learnable balancing factor morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") morr_output_scale = morr_output_scale_quantizer(morr_output_scale) else: weight = weight.abs() # positive only - morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) - + morr_output_scale = morr_output_scale - morr_output_scale.data.mean() + if finegrain_drop_mask is not None: weight = weight.mul(finegrain_drop_mask.float()) - + # differential balancing factor concatenation scale = morr_output_scale[..., :-1, :] scale_pad = morr_output_scale[..., -1:, :] @@ -401,23 +428,25 @@ def morr_linear_fn_mem( # Reshape x and weight x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] - x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] - weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] + weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] - x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] + x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] w_ctx = weight.clone() - + # Allocate output output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) # Launch the Triton kernel grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(P, meta["BLOCK_SIZE_P"]) * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(P, meta["BLOCK_SIZE_P"]) + * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), ) morr_propagate_kernel[grid]( - x_ptr = x, - w_ptr = weight, - o_ptr = output, - b_ptr = morr_bias, + x_ptr=x, + w_ptr=weight, + o_ptr=output, + b_ptr=morr_bias, M=M, P=P, Q=Q, @@ -454,16 +483,18 @@ def morr_linear_fn_mem( stride_ok2=output.stride(4), ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, - TRAINABLE_MORR_BIAS = trainable_morr_bias, + TRAINABLE_MORR_BIAS=trainable_morr_bias, INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], BLOCK_SIZE_K1=K, ) # Apply output scale output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] - ctx_x_scalematmul = output.clone() # record x input for matmul - output = morr_output_scale.matmul(output) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - output = output.flatten(1) # [bs, p*k] + ctx_x_scalematmul = output.clone() # record x input for matmul + output = morr_output_scale.matmul( + output + ) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + output = output.flatten(1) # [bs, p*k] # Trim output if needed if out_features < out_features_pad: @@ -474,8 +505,16 @@ def morr_linear_fn_mem( if is_transformer: output = output.view(in_B, in_N, out_features) - return output, seed, torch.abs(w_ctx), x_ctx, ctx_morr_output_scale, ctx_x_scalematmul, morr_scale.clone(), weight_quant_gain if weight_quant_gain is not None else 0.0 - + return ( + output, + seed, + torch.abs(w_ctx), + x_ctx, + ctx_morr_output_scale, + ctx_x_scalematmul, + morr_scale.clone(), + weight_quant_gain if weight_quant_gain is not None else 0.0, + ) def _morr_linear_setup_context(ctx, inputs, output): @@ -483,39 +522,48 @@ def _morr_linear_setup_context(ctx, inputs, output): Save for backward only what the backward routine really needs. """ ( - x, # 0 Tensor – input - weight, # 1 Tensor – learnable weight - morr_input_bias, # 23 Tensor - _, # 3 morr_output_scale (original) - bias, # 4 Tensor | None – bias + x, # 0 Tensor – input + weight, # 1 Tensor – learnable weight + morr_input_bias, # 23 Tensor + _, # 3 morr_output_scale (original) + bias, # 4 Tensor | None – bias morr_input_scale, - morr_bias, # 2 Tensor | None - grid_dim_x, # 5 int - grid_dim_y, # 6 int - miniblock, # 7 int (== K) - enable_thermal_crosstalk,# 8 bool - crosstalk_factor, # 9 float - enable_phase_noise, # 10 bool - phase_noise_std, # 11 float - trainable_morr_bias, # 12 bool - mrr_a, # 13 float - mrr_r, # 14 float - finegrain_drop_mask, # 15 Tensor | None - in_features, # 16 int - in_features_pad, # 17 int - out_features, # 18 int - out_features_pad, # 19 int - in_bit, # 20 int - w_bit, # 21 int - morr_fwhm, # 22 float + morr_bias, # 2 Tensor | None + grid_dim_x, # 5 int + grid_dim_y, # 6 int + miniblock, # 7 int (== K) + enable_thermal_crosstalk, # 8 bool + crosstalk_factor, # 9 float + enable_phase_noise, # 10 bool + phase_noise_std, # 11 float + trainable_morr_bias, # 12 bool + mrr_a, # 13 float + mrr_r, # 14 float + finegrain_drop_mask, # 15 Tensor | None + in_features, # 16 int + in_features_pad, # 17 int + out_features, # 18 int + out_features_pad, # 19 int + in_bit, # 20 int + w_bit, # 21 int + morr_fwhm, # 22 float sigma_weight, - trainable_morr_scale, # bool + trainable_morr_scale, # bool _morr_scale, weight_quant_gain, - seed, # 23 int + seed, # 23 int ) = inputs - output, seed, w_morr, x_modulator, morr_output_scale, x_scalematmul, morr_scale, _weight_quant_gain = output + ( + output, + seed, + w_morr, + x_modulator, + morr_output_scale, + x_scalematmul, + morr_scale, + _weight_quant_gain, + ) = output device, dtype = x.device, x.dtype @@ -529,10 +577,10 @@ def _morr_linear_setup_context(ctx, inputs, output): # c1 = -2.0 * mrr_a * mrr_r # c2 = mrr_a * mrr_a + mrr_r * mrr_r # c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r - # c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + # c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r # intensity = True # mrr_para = (c1, c2, c3, c4, intensity) - + # # x_morr: x input of matmal in propagate_morr() # x_morr = x_modulator ** 2 # [m, q, k] # x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] @@ -546,28 +594,28 @@ def _morr_linear_setup_context(ctx, inputs, output): # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) # Added from linear.py - # 3. stash tensors + # 3. stash tensors ctx.save_for_backward( - x, # original input (stashing x for mem version, might need re-evaluation for pure mem-saving) - weight, # original weight (stashing weight for mem version) + x, # original input (stashing x for mem version, might need re-evaluation for pure mem-saving) + weight, # original weight (stashing weight for mem version) bias if bias is not None else torch.tensor([], device=device, dtype=dtype), - morr_output_scale, # original morr_output_scale + morr_output_scale, # original morr_output_scale # x_mrr, # x input for mrr_roundtrip_phase_to_tr() # x_morr, # w_morr, # w input for propagate_morr() matmul # x_modulator, # x input for input_modulator() - morr_input_bias, + morr_input_bias, # x_scalematmul, # x_scalematmul, # x input for morr_output_scale.matmul - morr_input_scale, # morr input scale at input + morr_input_scale, # morr input scale at input # morr_scale, # morr_scale after modification in build_weight() finegrain_drop_mask, ) ctx.tensor_shape = tensor_shape # ctx.mrr_para = mrr_para - ctx.in_features = in_features - ctx.in_features_pad = in_features_pad - ctx.out_features = out_features + ctx.in_features = in_features + ctx.in_features_pad = in_features_pad + ctx.out_features = out_features ctx.out_features_pad = out_features_pad ctx.morr_fwhm = morr_fwhm ctx.grid_dim_x = grid_dim_x @@ -590,6 +638,7 @@ def _morr_linear_setup_context(ctx, inputs, output): ctx.mrr_a = mrr_a ctx.mrr_r = mrr_r + def recompute_activations( ctx, x: Tensor, @@ -633,7 +682,7 @@ def recompute_activations( input_quantizer = input_quantize_fn(ctx.in_bit, device=Device) input_quantizer.set_bitwidth(ctx.in_bit) x = input_quantizer(x) - + ################# Build weight ################# if ctx.w_bit < 16: weight_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_pos") @@ -653,21 +702,19 @@ def recompute_activations( morr_scale = morr_scale * weight_quant_gain else: morr_scale = weight_quant_gain - + ctx_morr_scale = morr_scale.clone() - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization ### quantize learnable balancing factor morr_output_scale_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_sym") morr_output_scale = morr_output_scale_quantizer(morr_output_scale) else: weight = weight.abs() # positive only - morr_output_scale = (morr_output_scale - morr_output_scale.data.mean()) - + morr_output_scale = morr_output_scale - morr_output_scale.data.mean() + if finegrain_drop_mask is not None: weight = weight.mul(finegrain_drop_mask.float()) - + # differential balancing factor concatenation scale = morr_output_scale[..., :-1, :] scale_pad = morr_output_scale[..., -1:, :] @@ -688,12 +735,11 @@ def recompute_activations( # input_modulator() ctx_x_modulator = x.clone() - x = x ** 2 - + x = x**2 ################# propagate_morr() ################# if ctx.enable_thermal_crosstalk and ctx.crosstalk_factor > 1: - weight = weight * ctx.crosstalk_factor + weight = weight * ctx.crosstalk_factor weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] @@ -708,15 +754,17 @@ def recompute_activations( ctx_tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) morr_bias = ctx.morr_fwhm * ctx_tanh_input_bias x = x - morr_bias - + ctx_x_mrr = x.clone() - - mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func(a=ctx.mrr_a, r=ctx.mrr_r, intensity=True) + + mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=ctx.mrr_a, r=ctx.mrr_r, intensity=True + ) x = mrr_roundtrip_phase_to_tr(x) ctx_x_scalematmul = x.clone() - x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - x = x.flatten(1) # [bs, p*k] + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] # ------------------------------------------------------ @@ -731,23 +779,24 @@ def recompute_activations( return ( # x, weight, bias, morr_output_scale, - # output, - ctx_x_modulator, # x input for input_modulator() - ctx_x_morr, # x input for propagate_morr() matmul - ctx_w_morr, # w input for propagate_morr() matmul - ctx_x_mrr, # x input for mrr_roundtrip_phase_to_tr() - ctx_x_scalematmul, # x input for morr_output_scale.matmul - ctx_tanh_input_bias, # input_bias after tanh() - ctx_morr_scale, # morr_scale after modification in build_weight() + # output, + ctx_x_modulator, # x input for input_modulator() + ctx_x_morr, # x input for propagate_morr() matmul + ctx_w_morr, # w input for propagate_morr() matmul + ctx_x_mrr, # x input for mrr_roundtrip_phase_to_tr() + ctx_x_scalematmul, # x input for morr_output_scale.matmul + ctx_tanh_input_bias, # input_bias after tanh() + ctx_morr_scale, # morr_scale after modification in build_weight() ) + def _morr_linear_backward(ctx, grad_output, *ignored): """ Backward pass for morr_linear_fn. """ ( - x, - weight, + x, + weight, bias, morr_output_scale, # x_mrr, @@ -758,10 +807,10 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # x_scalematmul, morr_input_scale, # morr_scale, - finegrain_drop_mask + finegrain_drop_mask, ) = ctx.saved_tensors - M, P, Q, K = ctx.tensor_shape + M, P, Q, K = ctx.tensor_shape # c1, c2, c3, c4, intensity = ctx.mrr_para in_features = ctx.in_features in_features_pad = ctx.in_features_pad @@ -773,22 +822,22 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # --- calculate intermediate activation on the fly --- ( - x_modulator, # x input for input_modulator() - x_morr, # x input for propagate_morr() matmul - w_morr, # w input for propagate_morr() matmul - x_mrr, # x input for mrr_roundtrip_phase_to_tr() - x_scalematmul, # x input for morr_output_scale.matmul - tanh_input_bias, # input_bias after tanh() - morr_scale, # morr_scale after modificaiton in build_weight() + x_modulator, # x input for input_modulator() + x_morr, # x input for propagate_morr() matmul + w_morr, # w input for propagate_morr() matmul + x_mrr, # x input for mrr_roundtrip_phase_to_tr() + x_scalematmul, # x input for morr_output_scale.matmul + tanh_input_bias, # input_bias after tanh() + morr_scale, # morr_scale after modificaiton in build_weight() ) = recompute_activations( - ctx, - x, - weight, - bias, - morr_output_scale, - finegrain_drop_mask, - morr_input_bias, - morr_input_scale + ctx, + x, + weight, + bias, + morr_output_scale, + finegrain_drop_mask, + morr_input_bias, + morr_input_scale, ) # x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] @@ -802,65 +851,64 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) # if ctx.trainable_morr_bias: # x_mrr = x_mrr - morr_bias - - - + # ----- backward prop ----- # Reshape grad_out = grad_output.view( - x_input_shape[0], - w_input_shape[1], - w_input_shape[2], - -1 + x_input_shape[0], w_input_shape[1], w_input_shape[2], -1 ) # [M, P, Q, K] - + # ----- Gradient w.r.t input x ----- if True or ctx.needs_input_grad[0]: # 1. reshape - grad_out = grad_out.view(M, -1) # [m, out_features] + grad_out = grad_out.view(M, -1) # [m, out_features] if ctx.needs_input_grad[4] and bias: - grad_bias = grad_out.sum(dim=0) # [out_features] + grad_bias = grad_out.sum(dim=0) # [out_features] else: grad_bias = None - out_pad = torch.zeros(grad_out.shape[0], out_features_pad-out_features, device = DEVICE) # [m, out_features_pad - out_features] - grad_out = torch.cat([grad_out, out_pad], dim=1) # [m * out_features_pad] = [m, p*k] + out_pad = torch.zeros( + grad_out.shape[0], out_features_pad - out_features, device=DEVICE + ) # [m, out_features_pad - out_features] + grad_out = torch.cat( + [grad_out, out_pad], dim=1 + ) # [m * out_features_pad] = [m, p*k] # 2. x=x.flatten(1) # input: [m, p**k] - grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] + grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] # dL/d(morr_output_scale) if ctx.needs_input_grad[3]: - grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] - grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] - grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale + grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] + grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] + grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale - t = ctx.grid_dim_x // 2 - grad_scale = grad_s.new_zeros((1, 1, t+1, 1)) + t = ctx.grid_dim_x // 2 + grad_scale = grad_s.new_zeros((1, 1, t + 1, 1)) if ctx.grid_dim_x % 2 == 0: grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] elif ctx.grid_dim_x == 1: grad_scale = grad_s else: - grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t+1:, :] - grad_scale[..., t:t+1, :] = grad_s[..., t:t+1, :] - + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t + 1 :, :] + grad_scale[..., t : t + 1, :] = grad_s[..., t : t + 1, :] + else: grad_scale = None - + # dL/dx - grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] + grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] # 4. x = mrr_roundtrip_phase_to_tr(x) mrr_a, mrr_r = ctx.mrr_a, ctx.mrr_r c1 = -2.0 * mrr_a * mrr_r c2 = mrr_a * mrr_a + mrr_r * mrr_r c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r - c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r intensity = True denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) if intensity: @@ -868,57 +916,59 @@ def _morr_linear_backward(ctx, grad_output, *ignored): numerator = x_mrr.sin().mul_(c4) else: numerator = x_mrr.sin().mul_(c4 / 2) - denominator = ( - denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) - ) - grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] - + denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] + # 5. x += phase_noise and x -= morr_bias if ctx.trainable_morr_bias and ctx.needs_input_grad[2]: - grad_inputbias = - grad_x # [bs, p, q, k] - grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] - grad_inputbias = grad_inputbias - tanh_input_bias * tanh_input_bias # [bs, p, q, k] + grad_inputbias = -grad_x # [bs, p, q, k] + grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] + grad_inputbias = ( + grad_inputbias - tanh_input_bias * tanh_input_bias + ) # [bs, p, q, k] grad_inputbias = grad_inputbias.sum(dim=(0, -1)) else: grad_inputbias = None # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] - grad_morr_matmul = grad_x # stash for weight gradient - + grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] + grad_morr_matmul = grad_x # stash for weight gradient + # dL/dx - grad_x = torch.matmul(w_morr.transpose(-1, -2), grad_x) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] - grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] + grad_x = torch.matmul( + w_morr.transpose(-1, -2), grad_x + ) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] + grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] # 7. input modulator - grad_x = grad_x * 2 * x_modulator # [bs, q, k] + grad_x = grad_x * 2 * x_modulator # [bs, q, k] # 8. input reshape grad_x = grad_x.view(x_input_shape) grad_x = grad_x[:, :in_features] - - # ----- Gradient w.r.t weight ----- if ctx.needs_input_grad[1]: - + # 0. gradient after x = weight.matmul(x) # grad_morr_matmul # [bs, p, q, k, 1] # 1. x = weight.matmul(x) - grad_w = torch.matmul(grad_morr_matmul, x_morr.transpose(-1,-2)) # [bs,p,q,k,k] - grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] + grad_w = torch.matmul( + grad_morr_matmul, x_morr.transpose(-1, -2) + ) # [bs,p,q,k,k] + grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] # 2. weight = toeplitz(weight) k = grad_w.size(-1) - row = torch.arange(k)[:, None] # (k,1) - col = torch.arange(k)[None, :] # (1,k) - idx = (row - col) & (k - 1) if (k & (k-1)) == 0 else (row - col + k) % k + row = torch.arange(k)[:, None] # (k,1) + col = torch.arange(k)[None, :] # (1,k) + idx = (row - col) & (k - 1) if (k & (k - 1)) == 0 else (row - col + k) % k idx = idx.expand(grad_w.shape).to(DEVICE) buffer = torch.zeros_like(grad_w, device=DEVICE) - buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # 3. build_weight() @@ -927,33 +977,59 @@ def _morr_linear_backward(ctx, grad_output, *ignored): # morr_scale: [p, q, 1] grad_morr_input_scale = None if ctx.w_bit < 16: - # grad w.r.t morr_scale + # grad w.r.t morr_scale if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: - grad_morr_scale = (grad_w * weight).sum(dim=2, keepdim=True) # [p, q, 1] - grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] + grad_morr_scale = (grad_w * weight).sum( + dim=2, keepdim=True + ) # [p, q, 1] + grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] # ∂L/∂self.morr_input_scale sigmoid_scale = torch.sigmoid(morr_input_scale) - grad_morr_input_scale = (grad_morr_scale * sigmoid_scale * (1-sigmoid_scale)).squeeze(-1) # [p, q] + grad_morr_input_scale = ( + grad_morr_scale * sigmoid_scale * (1 - sigmoid_scale) + ).squeeze( + -1 + ) # [p, q] # grad w.r.t weight grad_w = grad_w * morr_scale else: grad_w = grad_w * weight.sign() - + return ( - grad_x, # ∂L/∂x - grad_w, # ∂L/∂w - grad_inputbias, # ∂L/∂morr_input_bias + grad_x, # ∂L/∂x + grad_w, # ∂L/∂w + grad_inputbias, # ∂L/∂morr_input_bias grad_scale, # ∂L/∂morr_output_scale - grad_bias, # ∂L/∂bias + grad_bias, # ∂L/∂bias grad_morr_input_scale, - None, None, None, None, None, None, None, None, None, - None, None, None, - None, None, None, None, None, None, None, - None, None, None, None + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, ) morr_linear_fn_mem.register_autograd( - _morr_linear_backward, setup_context=_morr_linear_setup_context, -) \ No newline at end of file + _morr_linear_backward, + setup_context=_morr_linear_setup_context, +) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_mem.py index 74f731d08..eb314b3a0 100644 --- a/src/chop/nn/optical/triton_modules/morr_linear_mem.py +++ b/src/chop/nn/optical/triton_modules/morr_linear_mem.py @@ -445,35 +445,39 @@ def apply_finegrain_drop_mask(self, mask: Tensor) -> None: def forward(self, x: Tensor) -> Tensor: output, *_ = morr_linear_fn_mem( - x, + x, self.weight, - morr_input_bias = self.morr_input_bias, - morr_output_scale = self.morr_output_scale, - bias = None, - morr_input_scale = self.morr_input_scale, - morr_bias = self.morr_bias.detach(), - grid_dim_x = self.grid_dim_x, - grid_dim_y = self.grid_dim_y, - miniblock = self.miniblock, + morr_input_bias=self.morr_input_bias, + morr_output_scale=self.morr_output_scale, + bias=None, + morr_input_scale=self.morr_input_scale, + morr_bias=self.morr_bias.detach(), + grid_dim_x=self.grid_dim_x, + grid_dim_y=self.grid_dim_y, + miniblock=self.miniblock, enable_thermal_crosstalk=self.enable_thermal_crosstalk, - crosstalk_factor=None if not self.enable_thermal_crosstalk else self.crosstalk_factor, + crosstalk_factor=( + None if not self.enable_thermal_crosstalk else self.crosstalk_factor + ), enable_phase_noise=self.enable_phase_noise, - phase_noise_std=None if not self.enable_phase_noise else self.phase_noise_std, + phase_noise_std=( + None if not self.enable_phase_noise else self.phase_noise_std + ), trainable_morr_bias=self.trainable_morr_bias, mrr_a=self.mrr_a, mrr_r=self.mrr_r, finegrain_drop_mask=None, - in_features = self.in_features, - in_features_pad = self.in_features_pad, - out_features = self.out_features, - out_features_pad = self.out_features_pad, - in_bit = self.in_bit, - w_bit = self.w_bit, - morr_fwhm = self.morr_fwhm, + in_features=self.in_features, + in_features_pad=self.in_features_pad, + out_features=self.out_features, + out_features_pad=self.out_features_pad, + in_bit=self.in_bit, + w_bit=self.w_bit, + morr_fwhm=self.morr_fwhm, sigma_weight=self.sigma_weight, - trainable_morr_scale=self.trainable_morr_scale, # bool + trainable_morr_scale=self.trainable_morr_scale, # bool morr_scale=self.morr_scale, weight_quant_gain=self.weight_quant_gain, - seed = 42, + seed=42, ) - return output \ No newline at end of file + return output diff --git a/src/chop/nn/optical/triton_modules/quantize.py b/src/chop/nn/optical/triton_modules/quantize.py index 738b1a925..fdd0848ef 100644 --- a/src/chop/nn/optical/triton_modules/quantize.py +++ b/src/chop/nn/optical/triton_modules/quantize.py @@ -11,7 +11,7 @@ def uniform_quantize(x: tl.tensor, k, gradient_clip=False): elif k == 1: out = tl.where(x >= 0, 1.0, -1.0) else: - n = float(2 ** k - 1) + n = float(2**k - 1) out = tl.extra.cuda.libdevice.rint(x * n) / n return out @@ -23,7 +23,7 @@ def uniform_quantize_new(x: tl.tensor, k, scale, zero_point, gradient_clip=False elif k == 1: out = tl.where(x > 0, 1.0, tl.where(x < 0, -1.0, 0.0)) else: - n = float(2 ** k - 1) + n = float(2**k - 1) out = tl.div(x, scale) out = out + zero_point out = tl.extra.cuda.libdevice.rint(out) @@ -35,7 +35,11 @@ def uniform_quantize_new(x: tl.tensor, k, scale, zero_point, gradient_clip=False @triton.jit def _input_quantize_fn( - x: tl.tensor, quant_ratio, training, in_bit, alg, # self.training + x: tl.tensor, + quant_ratio, + training, + in_bit, + alg, # self.training ): # init if alg == "dorefa": @@ -52,7 +56,7 @@ def _input_quantize_fn( qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=0, - quant_max=2 ** in_bit - 1, + quant_max=2**in_bit - 1, ) else: obs = None diff --git a/src/chop/nn/snn/modules/__init__.py b/src/chop/nn/snn/modules/__init__.py index 6b6efd229..cec1196bd 100644 --- a/src/chop/nn/snn/modules/__init__.py +++ b/src/chop/nn/snn/modules/__init__.py @@ -60,9 +60,7 @@ ) from .embedding import EmbeddingZIPTF -from .roberta import ( - RobertaSelfAttentionZIPTF, -) +from .roberta import RobertaSelfAttentionZIPTF spiking_basic_module_map = { "conv1d": Conv1d, diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index be02f0046..7d59937ac 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -315,9 +315,9 @@ "key": "config", "value": "data_in", }, - "invert": { # Added for Wave2Vec + "invert": { "input": "data_in", - }, + }, # Added for Wave2Vec } module_data = { diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index 129438275..ea0634be3 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -15,9 +15,7 @@ RobertaSelfOutput, ) -from transformers.models.llama.modeling_llama import ( - LlamaAttention, -) +from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.bert.modeling_bert import ( BertSdpaSelfAttention, @@ -224,8 +222,13 @@ def instantiate_llama_module( ) return llama_module + def instantiate_bert_module( - module, postfix, prefix, module_map, module_args, + module, + postfix, + prefix, + module_map, + module_args, ): bert_cls = module_map[f"{prefix}_{postfix}"] @@ -241,6 +244,7 @@ def instantiate_bert_module( ) return bert_module + def instantiate_module(module, postfix, module_map, additional_module_args): is_roberta, roberta_layer_name = check_module_instance(module, roberta_prefix_map) is_llama, llama_layer_name = check_module_instance(module, llama_prefix_map) @@ -268,7 +272,11 @@ def instantiate_module(module, postfix, module_map, additional_module_args): ) elif is_bert: module = instantiate_bert_module( - module, postfix, bert_layer_name, module_map, module_args, + module, + postfix, + bert_layer_name, + module_map, + module_args, ) else: raise ValueError(f"{module} is not supported.") diff --git a/src/chop/passes/module/transforms/attention/attention_transform_helper.py b/src/chop/passes/module/transforms/attention/attention_transform_helper.py index 9f9129b69..3216e1db9 100644 --- a/src/chop/passes/module/transforms/attention/attention_transform_helper.py +++ b/src/chop/passes/module/transforms/attention/attention_transform_helper.py @@ -11,9 +11,7 @@ MLA, ) from chop.nn.modules.mgqa import MGQALayers, MGQA -from chop.nn.modules.lora_linear import ( - LowRankLinear, -) +from chop.nn.modules.lora_linear import LowRankLinear from ...module_modify_helper import ( get_module_by_name, set_module_by_name, @@ -421,7 +419,6 @@ def _create_rotary_embeddings(self, seqlen, rope_dim, device): class MGQAWrapper(torch.nn.Module): - def __init__(self, mgqa: MGQA): super().__init__() self.mgqa = mgqa diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py index c652dc9b5..a88df57a0 100644 --- a/src/chop/passes/module/transforms/optical/module_transform_helper.py +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -22,23 +22,8 @@ RobertaSelfOutput, ) -from transformers.models.llama.modeling_llama import ( - LlamaAttention, -) - -from transformers.models.bert.modeling_bert import ( - BertSdpaSelfAttention, - BertSelfAttention, -) - -from transformers.models.bert.configuration_bert import BertConfig - - +from transformers.models.llama.modeling_llama import LlamaAttention -bert_prefix_map = { - BertSdpaSelfAttention: "bert_self_attention", - BertSelfAttention: "bert_self_attention", -} def check_module_instance(module, prefix_map): """ @@ -55,19 +40,21 @@ def check_module_instance(module, prefix_map): return True, name return False, None + def replace_by_name_optical(network, module_name: str, new_module, target_name): original = get_module_by_name(network, module_name) if target_name == "linear_morr_full": updated_module = weight_replacement_full_linear_optical(original, new_module) - elif target_name in ["linear_morr", "linear_morr_triton", "linear_morr_triton_mem"]: - updated_module = weight_replacement_circulant_linear_optical(original, new_module) - # updated_module = weight_randominit_circulant_linear_optical(original, new_module) - elif target_name in ["bert_self_attention_morr"]: - updated_module = weight_replacement_circulant_bert_attention(original, new_module) + elif target_name in ["linear_morr", "linear_morr_triton"]: + updated_module = weight_replacement_circulant_linear_optical( + original, new_module + ) else: - raise NotImplementedError(f"weight replacement function for the optical module {target_name} not implemented") - + raise NotImplementedError( + f"weight replacement function for the optical module {target_name} not implemented" + ) + network = set_module_by_name(network, module_name, updated_module) return network @@ -83,6 +70,7 @@ def weight_replacement_full_linear_optical(original, new_module): "weight replacement function for the optical module not implemented" ) + def weight_replacement_linear_optical(linear_layer, morr_layer): """ Replace the weights of AllPassMORRLinear (morr_layer) with those from a standard nn.Linear (linear_layer). @@ -95,34 +83,44 @@ def weight_replacement_linear_optical(linear_layer, morr_layer): grid_dim_x = morr_layer.grid_dim_x grid_dim_y = morr_layer.grid_dim_y in_features_pad = morr_layer.in_features_pad - + # Get the weights from the standard linear layer standard_weights = linear_layer.weight.data # [out_features, in_features] - + # Ensure the shapes match - assert standard_weights.shape[0] == out_features, "Output feature dimensions don't match" - assert standard_weights.shape[1] == in_features, "Input feature dimensions don't match" - + assert ( + standard_weights.shape[0] == out_features + ), "Output feature dimensions don't match" + assert ( + standard_weights.shape[1] == in_features + ), "Input feature dimensions don't match" + # Pad the standard weights to match in_features_pad if in_features_pad > in_features: - padded_weights = torch.zeros(out_features, in_features_pad, - device=standard_weights.device, - dtype=standard_weights.dtype) + padded_weights = torch.zeros( + out_features, + in_features_pad, + device=standard_weights.device, + dtype=standard_weights.dtype, + ) padded_weights[:, :in_features] = standard_weights - standard_weights = padded_weights # [out_features, in_features_pad] - + standard_weights = padded_weights # [out_features, in_features_pad] + # Reshape to match the MORR structure [grid_dim_y, grid_dim_x, miniblock] assert grid_dim_y == out_features, "grid_dim_y does not match out_features" - assert grid_dim_x * miniblock == in_features_pad, "grid_dim_x * miniblock does not match in_features_pad" - + assert ( + grid_dim_x * miniblock == in_features_pad + ), "grid_dim_x * miniblock does not match in_features_pad" + reshaped_weights = standard_weights.reshape(grid_dim_y, grid_dim_x, miniblock) - + # Copy the weights to the MORR layer with torch.no_grad(): morr_layer.weight.data.copy_(reshaped_weights) - + return morr_layer + def weight_replacement_circulant_linear_optical(x, y): """ Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). @@ -131,14 +129,14 @@ def weight_replacement_circulant_linear_optical(x, y): """ # Dense weight - W = x.weight.data # [out_features, in_features] + W = x.weight.data # [out_features, in_features] # Dimensions defined by the MORR layer - k = y.miniblock # miniblock size - grid_dim_y = y.grid_dim_y # #block-rows (p) - grid_dim_x = y.grid_dim_x # #block-cols (q) + k = y.miniblock # miniblock size + grid_dim_y = y.grid_dim_y # #block-rows (p) + grid_dim_x = y.grid_dim_x # #block-cols (q) out_features_p = y.out_features_pad - in_features_p = y.in_features_pad + in_features_p = y.in_features_pad # Zero-pad so every block is k×k W_padded = W.new_zeros((out_features_p, in_features_p)) @@ -154,71 +152,19 @@ def weight_replacement_circulant_linear_optical(x, y): for q in range(grid_dim_x): col_slice = slice(q * k, (q + 1) * k) - block = W_padded[row_slice, col_slice] # shape (k, k) + block = W_padded[row_slice, col_slice] # shape (k, k) # Frobenius-projection onto the circulant subspace: # c_j = mean of { block[i, (i+j) mod k], i=0…k-1 } - c = torch.stack([ - block[idx, (idx + j) % k].mean() - for j in range(k) - ]) + c = torch.stack([block[idx, (idx + j) % k].mean() for j in range(k)]) - new_weight[p, q, :] = c # first row + new_weight[p, q, :] = c # first row # Save back into the MORR layer y.load_parameters({"weight": new_weight}) return y -def weight_randominit_circulant_linear_optical(x, y): - """ - Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). - Focuses only on weight copying (no bias copying). - """ - warnings.warn( - "Random weight initiator is being used!", - category=RuntimeWarning, - stacklevel=2, # point the warning at the caller - ) - # y.reset_parameters() - return y - # # Fetch original linear weight [out_features, in_features] - # W = x.weight.data # [out_features, in_features] - - # # Grab dimensions and zero-pad if needed - # out_features_pad = y.out_features_pad # padded out_features in y - # in_features_pad = y.in_features_pad # padded in_features in y - # miniblock = y.miniblock - # grid_dim_y = y.grid_dim_y - # grid_dim_x = y.grid_dim_x - - # # Construct padded weight tensor - # W_padded = W.new_zeros((out_features_pad, in_features_pad)) - # W_padded[: W.size(0), : W.size(1)] = W - - # # Takes the mean across the miniblock slice. - # new_weight = W.new_zeros((grid_dim_y, grid_dim_x, miniblock)) # [grid_dim_y, grid_dim_x, miniblock] - - # # Fill new_weight by averaging the corresponding sub-blocks in W_padded - # # original miniblock: [k, k] new miniblock: [k, 1] - # with torch.no_grad(): - # for p in range(grid_dim_y): - # for q in range(grid_dim_x): - # for k in range(miniblock): - # row_idx = p * miniblock + k # The row in W_padded: - # col_start = q * miniblock # The columns in W_padded: - # col_end = (q + 1) * miniblock - # block = W_padded[row_idx, col_start:col_end] - - # new_weight[p, q, k] = block.mean() - - # bound = 1 / math.sqrt(miniblock) - # new_weight = torch.rand((grid_dim_y, grid_dim_x, miniblock), - # device=W.device, - # dtype=W.dtype) * 2 * bound - bound - # # Copy the result into y.weight - # y.load_parameters({"weight": new_weight}) - def weight_replacement_conv2d_optical(x, y): """ @@ -263,38 +209,26 @@ def weight_replacement_conv2d_optical(x, y): # with a simple block-circulant approximation of x's parameters. return y -def weight_replacement_circulant_bert_attention(original, new_module): - for name in ("query", "key", "value"): - src_linear = getattr(original, name) - dst_linear = getattr(new_module, name) - with torch.no_grad(): - dst_linear.weight.copy_(src_linear.weight) - if src_linear.bias is not None: - dst_linear.bias.copy_(src_linear.bias) - - return new_module - def instantiate_optical_module(module, postfix, module_map, additional_module_args): - is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) - module_args = additional_module_args["config"] additional_args = additional_module_args["additional"] network_args = additional_module_args.get("network_config", None) if isinstance(module, torch.nn.Linear): - module = instantiate_optical_linear(module, postfix, module_map, module_args, additional_args) + module = instantiate_optical_linear( + module, postfix, module_map, module_args, additional_args + ) elif isinstance(module, torch.nn.Conv2d): module = instantiate_optical_conv2d(module, postfix, module_map, module_args) - elif is_bert: - module = instantiate_optical_bert_module( - module, postfix, bert_layer_name, module_map, module_args, - ) else: raise ValueError(f"{module} is not supported.") return module -def instantiate_optical_linear(module, postfix, module_map, additional_module_args, additional_args): + +def instantiate_optical_linear( + module, postfix, module_map, additional_module_args, additional_args +): linear_cls = module_map[f"linear_{postfix}"] has_bias = not (module.bias is None) @@ -316,7 +250,7 @@ def instantiate_optical_linear(module, postfix, module_map, additional_module_ar ) if additional_args is None: return linear - + # extra handling for morr optical module enable_thermal_crosstalk = additional_args.get("thermal_crosstalk", False) enable_phase_noise = additional_args.get("phase_noise", False) @@ -329,29 +263,30 @@ def instantiate_optical_linear(module, postfix, module_map, additional_module_ar additional_args.get("coupling_factor", 0.04), additional_args.get("drop_perc", 0.0), ) - + if enable_phase_noise: linear.enable_phase_variation() phase_noise_std = additional_args.get("phase_noise_std", 0.04) linear.set_phase_variation(phase_noise_std) - + if enable_trainable_morr_scale: linear.enable_trainable_morr_scale() else: linear.disable_trainable_morr_scale() - + if enable_trainable_morr_bias: linear.enable_trainable_morr_bias() else: linear.disable_trainable_morr_bias() - + if "in_bit" in additional_args: - linear.set_input_bitwidth(in_bit = additional_args["in_bit"]) + linear.set_input_bitwidth(in_bit=additional_args["in_bit"]) if "w_bit" in additional_args: - linear.set_weight_bitwidth(w_bit = additional_args["w_bit"]) + linear.set_weight_bitwidth(w_bit=additional_args["w_bit"]) return linear + def instantiate_optical_conv2d(module, postfix, module_map, additional_module_args): conv2d_cls = module_map[f"conv2d_{postfix}"] has_bias = not (module.bias is None) @@ -384,20 +319,3 @@ def instantiate_optical_conv2d(module, postfix, module_map, additional_module_ar **additional_module_args, ) return conv2d - -def instantiate_optical_bert_module( - module, postfix, prefix, module_map, module_args, -): - bert_cls = module_map[f"{prefix}_{postfix}"] - - bert_module = bert_cls( - config=BertConfig( - hidden_size=module.query.in_features, - num_attention_heads=module.num_attention_heads, - attention_head_size=module.attention_head_size, - attention_probs_dropout_prob=module.dropout_prob, - is_decoder=False, - ), - morr_config=module_args, - ) - return bert_module \ No newline at end of file diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py index 0ec241e32..9e72f617f 100644 --- a/src/chop/passes/module/transforms/optical/optical.py +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -9,6 +9,7 @@ ) from ...state_dict_map import match_a_pattern, check_is_huggingface_model + def get_config(config: dict, name: str): if name in config: return config[name]["config"] @@ -30,15 +31,25 @@ def optical_transform_by_type(network, pass_args): type_name = "bert_self_attention" else: raise ValueError(f"{type_name} is not supported!") - config = config["config"] - postfix = config.pop("name") + + # config = config["config"] + # postfix = config.pop("name") + optical_config = config["config"] + optial_additional_config = config.get("additional", None) + postfix = optical_config["name"] + + additional_module_args = { + "config": optical_config, + "additional": optial_additional_config, + } for n, m in n_m.items(): if isinstance(m, module): - print(f"processing {n}") new_m = instantiate_optical_module( - m, postfix, optical_module_map, {"config": config} + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical( + network, n, new_m, type_name + "_" + postfix ) - network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) return network @@ -49,17 +60,32 @@ def optical_transform_by_name(network, pass_args): n_m[n] = m for n, m in n_m.items(): if n in optical_names: - optical_config = pass_args[n] + optical_config = pass_args[n]["config"] + optial_additional_config = pass_args[n].get("additional", None) + postfix = optical_config["name"] + + additional_module_args = { + "config": optical_config, + "additional": optial_additional_config, + } - optical_config = optical_config["config"] - postfix = optical_config.pop("name") + if isinstance(m, torch.nn.Linear): + type_name = "linear" + elif isinstance(m, torch.nn.Conv2d): + type_name = "conv2d" + else: + raise ValueError(f"{type_name} is not supported!") new_m = instantiate_optical_module( - m, postfix, optical_module_map, {"config": optical_config} + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical( + network, n, new_m, type_name + "_" + postfix ) - network = replace_by_name_optical(network, n, new_m) + return network + def optical_transform_by_regex_name(network, pass_args): is_huggingface_model = check_is_huggingface_model(network) @@ -79,10 +105,7 @@ def optical_transform_by_regex_name(network, pass_args): postfix = optical_config["name"] additional_module_args = ( - { - "config": optical_config, - "additional": optial_additional_config - } + {"config": optical_config, "additional": optial_additional_config} # if is_huggingface_model # else {"config": optical_config} ) @@ -91,18 +114,17 @@ def optical_transform_by_regex_name(network, pass_args): type_name = "linear" elif isinstance(m, torch.nn.Conv2d): type_name = "conv2d" - elif isinstance(m, BertSdpaSelfAttention): - type_name = "bert_self_attention" else: raise ValueError(f"{type_name} is not supported!") - + new_m = instantiate_optical_module( m, postfix, optical_module_map, additional_module_args ) - network = replace_by_name_optical(network, n, new_m, type_name +'_'+postfix) + network = replace_by_name_optical(network, n, new_m, type_name + "_" + postfix) return network + def optical_module_transform_pass(network, pass_args): """ Apply optical transformation to the given nn.Module. diff --git a/src/mase_components/difflogic_layers/passes.py b/src/mase_components/difflogic_layers/passes.py index c1b42977e..6693d5cf2 100644 --- a/src/mase_components/difflogic_layers/passes.py +++ b/src/mase_components/difflogic_layers/passes.py @@ -3,7 +3,6 @@ def difflogic_hardware_metadata_optimize_pass(graph, args={}): - def _is_logiclayer(node): return node.meta["mase"]["common"]["mase_op"] == "user_defined_module" diff --git a/test/passes/module/transforms/attention/test_attention_transform.py b/test/passes/module/transforms/attention/test_attention_transform.py index 14137c66e..7bd58f13d 100644 --- a/test/passes/module/transforms/attention/test_attention_transform.py +++ b/test/passes/module/transforms/attention/test_attention_transform.py @@ -7,9 +7,7 @@ sys.path.append(Path(__file__).resolve().parents[5].as_posix()) -from chop.passes.module.transforms import ( - attention_swap_transform_pass, -) +from chop.passes.module.transforms import attention_swap_transform_pass from pathlib import Path import time diff --git a/test/passes/module/transforms/optical/bert-finetune.py b/test/passes/module/transforms/optical/bert-finetune.py deleted file mode 100644 index db4d6fa84..000000000 --- a/test/passes/module/transforms/optical/bert-finetune.py +++ /dev/null @@ -1,140 +0,0 @@ -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" -import torch -import numpy as np -import evaluate -from datasets import load_dataset -import dill -from pathlib import Path -from transformers import ( - AutoTokenizer, - AutoModelForSequenceClassification, - Trainer, - TrainingArguments, - DataCollatorWithPadding, -) -from chop.passes.module.transforms.optical import optical_module_transform_pass - -def bert_onn_transform(model): - type_args = { - "by": "type", - "linear": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - } - - name_args = { - "by": "name", - "bert.encoder.layer.0.attention.self.query": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - "bert.encoder.layer.0.attention.self.key": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - "bert.encoder.layer.0.attention.self.value": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - } - - pattern = r"^bert\.encoder\.layer\.\d+\.attention\.self\.(key|query|value)$" - regex_args = { - "by": "regex_name", - pattern: { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - } - - model, _ = optical_module_transform_pass(model, regex_args) - return model - -def test_bert_inference(model, text="This is a test."): - """ - Passes a sample string through the model for quick debugging. - """ - tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - inputs = tokenizer(text, return_tensors="pt") - outputs = model(**inputs) - - return outputs - -def finetune_bert(model): - model_name = "bert-base-uncased" - tokenizer = AutoTokenizer.from_pretrained(model_name) - - dataset = load_dataset("glue", "sst2") - def preprocess(examples): - return tokenizer(examples["sentence"], truncation=True, padding=True) - dataset = dataset.map(preprocess, batched=True) - - data_collator = DataCollatorWithPadding(tokenizer) - metric = evaluate.load("accuracy") - def compute_metrics(eval_pred): - logits, labels = eval_pred - return metric.compute(predictions=np.argmax(logits, axis=1), references=labels) - - training_args = TrainingArguments( - output_dir="model_sst2", - run_name="bert_sst2_experiment", - evaluation_strategy="epoch", - report_to=["none"], - num_train_epochs=2, - logging_steps=25000, - per_device_train_batch_size=2, # set training batch size - per_device_eval_batch_size=2, # set evaluation batch size - ) - - trainer = Trainer( - model=model, - args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["validation"], - data_collator=data_collator, - compute_metrics=compute_metrics - ) - trainer.train() - return model - -if __name__ == "__main__": - model_name = "bert-base-uncased" - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) - - model = bert_onn_transform(model) - print(model) - - model = finetune_bert(model) - with open(f"{Path.home()}/bert-onn-2epoch", "wb") as f: - dill.dump(model, f) - # print(1) - # test_bert_inference(model) - # main() diff --git a/test/passes/module/transforms/optical/playground.py b/test/passes/module/transforms/optical/playground.py deleted file mode 100644 index d0d83d6bf..000000000 --- a/test/passes/module/transforms/optical/playground.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import torch.nn as nn - -def test_linear_out_shape(hidden_size=768, out_size=1024): - """ - Passes a [1, 7, hidden_size] tensor through nn.Linear - and prints input/output shapes. - """ - # Sample input tensor - x = torch.randn(1, 7, hidden_size) - - # Linear layer: change dims if needed - linear_layer = nn.Linear(hidden_size, out_size) - - # Forward pass - y = linear_layer(x) - - # Print shapes for quick verification - print("Input shape:", x.shape) - print("Output shape:", y.shape) - -test_linear_out_shape() \ No newline at end of file diff --git a/test/passes/module/transforms/optical/run_glue.py b/test/passes/module/transforms/optical/run_glue.py deleted file mode 100644 index b5d5aeb77..000000000 --- a/test/passes/module/transforms/optical/run_glue.py +++ /dev/null @@ -1,637 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Finetuning the library models for sequence classification on GLUE.""" -# You can also adapt this script on your own text classification task. Pointers for this are left as comments. - -import logging -import os -import random -import sys -from dataclasses import dataclass, field -from typing import Optional - -import datasets -import evaluate -import numpy as np -from datasets import load_dataset - -import transformers -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - DataCollatorWithPadding, - EvalPrediction, - HfArgumentParser, - PretrainedConfig, - Trainer, - TrainingArguments, - default_data_collator, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.50.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") - -task_to_keys = { - "cola": ("sentence", None), - "mnli": ("premise", "hypothesis"), - "mrpc": ("sentence1", "sentence2"), - "qnli": ("question", "sentence"), - "qqp": ("question1", "question2"), - "rte": ("sentence1", "sentence2"), - "sst2": ("sentence", None), - "stsb": ("sentence1", "sentence2"), - "wnli": ("sentence1", "sentence2"), -} - -logger = logging.getLogger(__name__) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - - Using `HfArgumentParser` we can turn this class - into argparse arguments to be able to specify them on - the command line. - """ - - task_name: Optional[str] = field( - default=None, - metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, - ) - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - max_seq_length: int = field( - default=128, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} - ) - pad_to_max_length: bool = field( - default=True, - metadata={ - "help": ( - "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - max_predict_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of prediction examples to this " - "value if set." - ) - }, - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "A csv or a json file containing the training data."} - ) - validation_file: Optional[str] = field( - default=None, metadata={"help": "A csv or a json file containing the validation data."} - ) - test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) - - def __post_init__(self): - if self.task_name is not None: - self.task_name = self.task_name.lower() - if self.task_name not in task_to_keys.keys(): - raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) - elif self.dataset_name is not None: - pass - elif self.train_file is None or self.validation_file is None: - raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") - else: - train_extension = self.train_file.split(".")[-1] - assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." - validation_extension = self.validation_file.split(".")[-1] - assert ( - validation_extension == train_extension - ), "`validation_file` should have the same extension (csv or json) as `train_file`." - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - token: str = field( - default=None, - metadata={ - "help": ( - "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " - "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." - ) - }, - ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether to trust the execution of code from datasets/models defined on the Hub." - " This option should only be set to `True` for repositories you trust and in which you have read the" - " code, as it will execute code present on the Hub on your local machine." - ) - }, - ) - ignore_mismatched_sizes: bool = field( - default=False, - metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, - ) - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_glue", model_args, data_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - if training_args.should_log: - # The default of training_args.log_level is passive, so we set log level at info here to have that default. - transformers.utils.logging.set_verbosity_info() - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " - + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) - # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the - # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named - # label if at least two columns are provided. - # - # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this - # single column. You can easily tweak this behavior (see below) - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if data_args.task_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - "nyu-mll/glue", - data_args.task_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - ) - elif data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - ) - else: - # Loading a dataset from your local files. - # CSV/JSON training and evaluation files are needed. - data_files = {"train": data_args.train_file, "validation": data_args.validation_file} - - # Get the test dataset: you can provide your own CSV/JSON test file (see below) - # when you use `do_predict` without specifying a GLUE benchmark task. - if training_args.do_predict: - if data_args.test_file is not None: - train_extension = data_args.train_file.split(".")[-1] - test_extension = data_args.test_file.split(".")[-1] - assert ( - test_extension == train_extension - ), "`test_file` should have the same extension (csv or json) as `train_file`." - data_files["test"] = data_args.test_file - else: - raise ValueError("Need either a GLUE task or a test file for `do_predict`.") - - for key in data_files.keys(): - logger.info(f"load a local file for {key}: {data_files[key]}") - - if data_args.train_file.endswith(".csv"): - # Loading a dataset from local csv files - raw_datasets = load_dataset( - "csv", - data_files=data_files, - cache_dir=model_args.cache_dir, - token=model_args.token, - ) - else: - # Loading a dataset from local json files - raw_datasets = load_dataset( - "json", - data_files=data_files, - cache_dir=model_args.cache_dir, - token=model_args.token, - ) - # See more about loading any type of standard or custom dataset at - # https://huggingface.co/docs/datasets/loading_datasets. - - # Labels - if data_args.task_name is not None: - is_regression = data_args.task_name == "stsb" - if not is_regression: - label_list = raw_datasets["train"].features["label"].names - num_labels = len(label_list) - else: - num_labels = 1 - else: - # Trying to have good defaults here, don't hesitate to tweak to your needs. - is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] - if is_regression: - num_labels = 1 - else: - # A useful fast method: - # https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.unique - label_list = raw_datasets["train"].unique("label") - label_list.sort() # Let's sort it for determinism - num_labels = len(label_list) - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - revision=model_args.model_revision, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - ) - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, - ) - - # Preprocessing the raw_datasets - if data_args.task_name is not None: - sentence1_key, sentence2_key = task_to_keys[data_args.task_name] - else: - # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. - non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] - if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: - sentence1_key, sentence2_key = "sentence1", "sentence2" - else: - if len(non_label_column_names) >= 2: - sentence1_key, sentence2_key = non_label_column_names[:2] - else: - sentence1_key, sentence2_key = non_label_column_names[0], None - - # Padding strategy - if data_args.pad_to_max_length: - padding = "max_length" - else: - # We will pad later, dynamically at batch creation, to the max sequence length in each batch - padding = False - - # Some models have set the order of the labels to use, so let's make sure we do use it. - label_to_id = None - if ( - model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id - and data_args.task_name is not None - and not is_regression - ): - # Some have all caps in their config, some don't. - label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} - if sorted(label_name_to_id.keys()) == sorted(label_list): - label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} - else: - logger.warning( - "Your model seems to have been trained with labels, but they don't match the dataset: " - f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}." - "\nIgnoring the model labels as a result.", - ) - elif data_args.task_name is None and not is_regression: - label_to_id = {v: i for i, v in enumerate(label_list)} - - if label_to_id is not None: - model.config.label2id = label_to_id - model.config.id2label = {id: label for label, id in config.label2id.items()} - elif data_args.task_name is not None and not is_regression: - model.config.label2id = {l: i for i, l in enumerate(label_list)} - model.config.id2label = {id: label for label, id in config.label2id.items()} - - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - - def preprocess_function(examples): - # Tokenize the texts - args = ( - (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) - ) - result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) - - # Map labels to IDs (not necessary for GLUE tasks) - if label_to_id is not None and "label" in examples: - result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] - return result - - with training_args.main_process_first(desc="dataset map pre-processing"): - raw_datasets = raw_datasets.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - if training_args.do_train: - if "train" not in raw_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = raw_datasets["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - - if training_args.do_eval: - if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: - if "test" not in raw_datasets and "test_matched" not in raw_datasets: - raise ValueError("--do_predict requires a test dataset") - predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] - if data_args.max_predict_samples is not None: - max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) - predict_dataset = predict_dataset.select(range(max_predict_samples)) - - # Log a few random samples from the training set: - if training_args.do_train: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # Get the metric function - if data_args.task_name is not None: - metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir) - elif is_regression: - metric = evaluate.load("mse", cache_dir=model_args.cache_dir) - else: - metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) - - # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a - # predictions and label_ids field) and has to return a dictionary string to float. - def compute_metrics(p: EvalPrediction): - preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions - preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) - result = metric.compute(predictions=preds, references=p.label_ids) - if len(result) > 1: - result["combined_score"] = np.mean(list(result.values())).item() - return result - - # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if - # we already did the padding. - if data_args.pad_to_max_length: - data_collator = default_data_collator - elif training_args.fp16: - data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) - else: - data_collator = None - - # Initialize our Trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - compute_metrics=compute_metrics, - processing_class=tokenizer, - data_collator=data_collator, - ) - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.save_model() # Saves the tokenizer too for easy upload - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - logger.info("*** Evaluate ***") - - # Loop to handle MNLI double evaluation (matched, mis-matched) - tasks = [data_args.task_name] - eval_datasets = [eval_dataset] - if data_args.task_name == "mnli": - tasks.append("mnli-mm") - valid_mm_dataset = raw_datasets["validation_mismatched"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples) - valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples)) - eval_datasets.append(valid_mm_dataset) - combined = {} - - for eval_dataset, task in zip(eval_datasets, tasks): - metrics = trainer.evaluate(eval_dataset=eval_dataset) - - max_eval_samples = ( - data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - ) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - - if task == "mnli-mm": - metrics = {k + "_mm": v for k, v in metrics.items()} - if task is not None and "mnli" in task: - combined.update(metrics) - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics) - - if training_args.do_predict: - logger.info("*** Predict ***") - - # Loop to handle MNLI double evaluation (matched, mis-matched) - tasks = [data_args.task_name] - predict_datasets = [predict_dataset] - if data_args.task_name == "mnli": - tasks.append("mnli-mm") - predict_datasets.append(raw_datasets["test_mismatched"]) - - for predict_dataset, task in zip(predict_datasets, tasks): - # Removing the `label` columns because it contains -1 and Trainer won't like that. - predict_dataset = predict_dataset.remove_columns("label") - predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions - predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) - - output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") - if trainer.is_world_process_zero(): - with open(output_predict_file, "w") as writer: - logger.info(f"***** Predict results {task} *****") - writer.write("index\tprediction\n") - for index, item in enumerate(predictions): - if is_regression: - writer.write(f"{index}\t{item:3.3f}\n") - else: - item = label_list[item] - writer.write(f"{index}\t{item}\n") - - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} - if data_args.task_name is not None: - kwargs["language"] = "en" - kwargs["dataset_tags"] = "glue" - kwargs["dataset_args"] = data_args.task_name - kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" - - if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) - else: - trainer.create_model_card(**kwargs) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index e65546822..067ecf1c8 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -67,4 +67,56 @@ def test_optical_module_transform_pass(): optical_module_transform_pass(model, pass_args) +def test_optical_module_transform_pass_2(): + model = Net() + # Sanity check and report + pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr_triton", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + "conv1": { + "config": { + "name": "morr_triton", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + optical_module_transform_pass(model, pass_args) + + +def test_optical_module_transform_pass_3(): + model = Net() + pass_args = { + "by": "regex_name", + "^fc1$": { + "config": {"name": "morr_triton", "miniblock": 4}, + "additional": { + "trainable_morr_bias": False, + "trainable_morr_scale": False, + "thermal_crosstalk": True, + "coupling_factor": 0.04, + "drop_perc": 0.0, + "phase_noise": True, + "phase_noise_std": 0.04, + "in_bit": 8, + "w_bit": 8, + }, + }, + } + new_model, _ = optical_module_transform_pass(model, pass_args) + print(new_model) + + test_optical_module_transform_pass() +test_optical_module_transform_pass_2() +test_optical_module_transform_pass_3() From c63c376cde2f669e41e2d43a24cf8152391228ce Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 24 Jun 2025 19:34:41 +0100 Subject: [PATCH 35/38] remove triton kernel --- src/chop/nn/optical/modules/__init__.py | 5 +- .../nn/optical/triton_modules/morr_linear.py | 483 --------- .../triton_modules/morr_linear_kernel.py | 930 ------------------ .../transforms/optical/test_optical_module.py | 63 +- 4 files changed, 4 insertions(+), 1477 deletions(-) delete mode 100644 src/chop/nn/optical/triton_modules/morr_linear.py delete mode 100644 src/chop/nn/optical/triton_modules/morr_linear_kernel.py diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index 539ddb38a..bfd7a558f 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -1,11 +1,10 @@ from .morr_linear import AllPassMORRCirculantLinear from .morr_conv2d import AllPassMORRCirculantConv2d -from ..triton_modules.morr_linear import TritonMORRLinear -from ..triton_modules.morr_linear_mem import TritonMemMORRLinear +# from ..triton_modules.morr_linear_mem import TritonMemMORRLinear optical_module_map = { "linear_morr": AllPassMORRCirculantLinear, "conv2d_morr": AllPassMORRCirculantConv2d, - "linear_morr_triton": TritonMemMORRLinear, + # "linear_morr_triton": TritonMemMORRLinear, } diff --git a/src/chop/nn/optical/triton_modules/morr_linear.py b/src/chop/nn/optical/triton_modules/morr_linear.py deleted file mode 100644 index f474f7aed..000000000 --- a/src/chop/nn/optical/triton_modules/morr_linear.py +++ /dev/null @@ -1,483 +0,0 @@ -""" -Description: -Author: Jiaqi Gu (jqgu@utexas.edu) -Date: 2022-04-18 14:19:57 -LastEditors: Jiaqi Gu (jqgu@utexas.edu) -LastEditTime: 2022-04-18 16:21:37 -""" - -from typing import Optional -import logging - -import numpy as np -import torch -import torch.fft -from torch import Tensor -from torch.nn import Parameter, init -from torch.types import Device - -from ..utils import MORRConfig_20um_MQ -from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused -from ..utils import toeplitz -from ..utils import morr_uniform_ -from ..utils import input_quantize_fn, weight_quantize_fn -from ..modules.base_layer import ONNBaseLayer -from .morr_linear_kernel import morr_linear_fn - -logger = logging.getLogger(__name__) - -__all__ = ["AllPassMORRCirculantLinear"] - - -class TritonMORRLinear(ONNBaseLayer): - """ - All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. - J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" - https://doi.org/10.23919/DATE51398.2021.9474147 - """ - - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - miniblock: int - weight: Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - config={}, - device: Device = torch.device("cpu"), - ) -> None: - super(TritonMORRLinear, self).__init__() - self.in_features = in_features - self.out_features = out_features - - miniblock_size = config.get("miniblock", 4) - self.miniblock = miniblock_size - self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) - self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) - self.in_features_pad = self.grid_dim_x * miniblock_size - self.out_features_pad = self.grid_dim_y * miniblock_size - - self.v_max = 10.8 - self.v_pi = 4.36 - self.gamma = np.pi / self.v_pi**2 - self.w_bit = 32 - self.in_bit = 32 - - morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) - morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) - self.MORRConfig = morr_config - self.morr_init = morr_init_val - self.mrr_a = morr_config.attenuation_factor - self.mrr_r = morr_config.coupling_factor - self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) - self.trainable_morr_scale = config.get( - "trainable_morr_scale", MORRConfig_20um_MQ - ) - self.device = device - ### calculate FWHM (rad) - self.morr_fwhm = ( - -4 - * np.pi**2 - * morr_config.radius - * morr_config.effective_index - * ( - 1 / morr_config.resonance_wavelength - - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) - ) - ) - - ### allocate parameters - self.weight = None - self.x_zero_pad = None - self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs - self.morr_input_bias = None ## round-trip phase shift bias within MORR - self.morr_input_scale = ( - None ## scaling factor for the round-trip phase shift within MORR - ) - self.morr_gain = ( - 100 / (self.in_features // self.miniblock) - ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 - ### build trainable parameters - self.build_parameters() - - ### quantization tool - self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) - self.weight_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_pos" - ) ## [0-1] positive only, maintain the original scale - self.morr_output_scale_quantizer = weight_quantize_fn( - self.w_bit, alg="dorefa_sym" - ) ## [-1,1] full-range - - self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( - a=self.mrr_a, r=self.mrr_r, intensity=True - ) - - ### default set to slow forward - self.disable_fast_forward() - ### default set no gamma noise - self.set_gamma_noise(0) - ### default set no crosstalk - self.disable_crosstalk() - ### default set no phase variation - self.disable_phase_variation() - - if bias: - self.bias = Parameter(torch.Tensor(out_features).to(self.device)) - else: - self.register_parameter("bias", None) - - self.reset_parameters(morr_init=morr_init_val) - self.finegrain_drop_mask = None - - def build_parameters(self) -> None: - - self.weight = Parameter( - torch.ones( - self.grid_dim_y, - self.grid_dim_x, - self.miniblock, - device=self.device, - dtype=torch.float, - ) - ) - ### Learnable balancing factor (morr_output_scale) - ### We use a single scaling factor for each block - self.morr_output_scale = Parameter( - torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) - ) - if self.trainable_morr_bias: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_bias = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - if self.trainable_morr_scale: - ### initialize with the finest-granularity, i.e., per mini-block - self.morr_input_scale = Parameter( - torch.zeros( - self.grid_dim_y, - self.grid_dim_x, - device=self.device, - dtype=torch.float, - ) - ) - - def reset_parameters(self, morr_init: bool = False) -> None: - ### nonlinear curve aware initialization - if morr_init: - ## initialize weight - morr_uniform_( - self.weight, - MORRConfig=self.MORRConfig, - n_op=self.miniblock, - biased=self.w_bit >= 16, - gain=2 if self.in_bit < 16 else 1, - ) # quantization needs zero-center - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - - ## output distribution aware initialization to output scaling factor - t1 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True - ) - t2 = mrr_roundtrip_phase_to_tr_fused( - torch.tensor([self.morr_fwhm * 2.4]).float(), - a=self.mrr_a, - r=self.mrr_r, - intensity=True, - ) - g = ( - (t2 - t1) / (2.4 * self.morr_fwhm) - ).item() ## 0~2.4 FWHM slope as a linear approximation - - self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) - self.out_scale_quant_gain = None - init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) - else: - init.kaiming_normal_(self.weight.data) - init.kaiming_normal_(self.morr_output_scale.data) - self.sigma_weight = self.weight.data.std().item() - self.weight_quant_gain = None - self.sigma_out_scale = self.morr_output_scale.data.std().item() - self.out_scale_quant_gain = None - - if self.morr_input_bias is not None: - self.morr_input_bias.data.zero_() - if self.morr_input_scale is not None: - ### after sigmoid, it cooresponds to 1 scale - init.normal_(self.morr_input_scale.data, 2, 0.1) - - if self.bias is not None: - init.uniform_(self.bias, 0, 0) - - def sync_parameters(self, src: str = "weight") -> None: - """ - description: synchronize all parameters from the source parameters - """ - - raise NotImplementedError - - def build_weight(self) -> Tensor: - if self.w_bit < 16: - ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) - weight = self.weight_quantizer(self.weight) - - ## rescale weights after quantization can maintain the initialization distribution - if self.weight_quant_gain is None: - self.weight_quant_gain = self.sigma_weight / weight.data.std() - if self.trainable_morr_scale: - morr_scale = self.morr_scale * self.weight_quant_gain - else: - morr_scale = self.weight_quant_gain - weight = weight.mul( - morr_scale - ) ### gain factor from Tanh used in quantization - - ### quantize learnable balancing factor - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - else: - weight = self.weight.abs() # positive only - morr_output_scale = ( - self.morr_output_scale - self.morr_output_scale.data.mean() - ) - - if self.finegrain_drop_mask is not None: - weight = weight.mul(self.finegrain_drop_mask.float()) - - ## differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if self.grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if self.grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - - return weight, morr_output_scale - - def enable_fast_forward(self) -> None: - self.fast_forward_flag = True - - def disable_fast_forward(self) -> None: - self.fast_forward_flag = False - - def set_gamma_noise( - self, noise_std: float, random_state: Optional[int] = None - ) -> None: - self.gamma_noise_std = noise_std - - def load_parameters(self, param_dict) -> None: - """ - description: update parameters based on this parameter dictionary\\ - param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} - """ - for name, param in param_dict.items(): - getattr(self, name).data.copy_(param) - - def set_weight_bitwidth(self, w_bit: int) -> None: - self.w_bit = w_bit - self.weight_quantizer.set_bitwidth(w_bit) - self.morr_output_scale_quantizer.set_bitwidth(w_bit) - - def set_input_bitwidth(self, in_bit: int) -> None: - self.in_bit = in_bit - self.input_quantizer.set_bitwidth(in_bit) - - def input_modulator(self, x: Tensor) -> Tensor: - ### voltage to power, which is proportional to the phase shift - return x * x - - def set_crosstalk_coupling_matrix( - self, coupling_factor: float, drop_perc: float = 0 - ) -> None: - ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. - ### drop-perc is the pruning percentage. - assert 0 <= coupling_factor <= 1, logger.error( - f"Coupling factor must in [0,1], but got {coupling_factor}" - ) - - self.crosstalk_factor = ( - 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor - ) - - def enable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = True - - def disable_crosstalk(self) -> None: - self.enable_thermal_crosstalk = False - - def set_phase_variation(self, phase_noise_std: float = 0) -> None: - self.phase_noise_std = phase_noise_std - - def enable_phase_variation(self) -> None: - self.enable_phase_noise = True - - def disable_phase_variation(self) -> None: - self.enable_phase_noise = False - - def enable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = True - - def disable_trainable_morr_scale(self) -> None: - self.trainable_morr_scale = False - - def enable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = True - - def disable_trainable_morr_bias(self) -> None: - self.trainable_morr_bias = False - - @property - def morr_bias(self) -> Tensor: - if self.morr_input_bias is None: - return None - # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) - return self.morr_fwhm * torch.tanh( - self.morr_input_bias.unsqueeze(0).unsqueeze(-1) - ) - - @property - def morr_scale(self) -> Tensor: - if self.morr_input_scale is None: - return None - return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] - - def propagate_morr( - self, weight: Tensor, x: Tensor, morr_output_scale: Tensor - ) -> Tensor: - """ - @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul - @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators - @param x {torch.Tensor} complex-valued input - @param morr_output_scale {torch.Tensor} learnable balancing factors - @return: y {torch.Tensor} output of attenuators - """ - ### x : [bs, q, k] - ### weights: [p, q, k] - ### morr_output_scale: [1, 1, 1, q] - - ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable - ## build circulant weight matrix - # crosstalk on the weights are much cheaper to compute than on the phase shift - if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: - weight = weight * self.crosstalk_factor - weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] - x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] - x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] - - if self.enable_phase_noise and self.phase_noise_std > 1e-5: - x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) - - ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. [-2FWHM ~ 0] - if self.trainable_morr_bias: - x = x - self.morr_bias - - ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] - ### x is the phase detuning, x=0 means on-resonance - ### phase: [bs, p, q, k] - x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd - - ## implement balancing factor as dot-product - """ - if(self.w_bit < 16): - morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) - if(self.sigma_out_scale_quant_gain is None): - self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() - morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization - else: - morr_output_scale = self.morr_output_scale - # morr_output_scale = morr_output_scale * self.morr_gain - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - - # print("morr diff transmission:", end=", ") - # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] - # print_stat(diff) - if(self.grid_dim_x % 2 == 0): - #even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if(self.grid_dim_x > 1): - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - # print("output scale Q:", end=", ") - # print_stat(scale[..., :scale.size(-1)//2]) - """ - x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - x = x.flatten(1) # [bs, p*k] - return x - - def get_finegrain_drop_mask(self, topk: int) -> Tensor: - if self.w_bit < 16: - weight = self.weight_quantizer(self.weight.data) # [p, q, k] - else: - weight = self.weight.data.abs() - indices = weight.argsort(dim=-1) - mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) - - drop_indices = indices[:, :, 0:-topk] - mask.scatter_(2, drop_indices, 0) - self.finegrain_drop_mask = mask - return mask - - def apply_finegrain_drop_mask(self, mask: Tensor) -> None: - if self.w_bit < 16: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) - else: - self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) - - def forward(self, x: Tensor) -> Tensor: - output, *_ = morr_linear_fn( - x, - self.weight, - morr_input_bias=self.morr_input_bias, - morr_output_scale=self.morr_output_scale, - bias=None, - morr_input_scale=self.morr_input_scale, - morr_bias=self.morr_bias.detach() if self.morr_bias is not None else None, - grid_dim_x=self.grid_dim_x, - grid_dim_y=self.grid_dim_y, - miniblock=self.miniblock, - enable_thermal_crosstalk=self.enable_thermal_crosstalk, - crosstalk_factor=( - None if not self.enable_thermal_crosstalk else self.crosstalk_factor - ), - enable_phase_noise=self.enable_phase_noise, - phase_noise_std=( - None if not self.enable_phase_noise else self.phase_noise_std - ), - trainable_morr_bias=self.trainable_morr_bias, - mrr_a=self.mrr_a, - mrr_r=self.mrr_r, - finegrain_drop_mask=None, - in_features=self.in_features, - in_features_pad=self.in_features_pad, - out_features=self.out_features, - out_features_pad=self.out_features_pad, - in_bit=self.in_bit, - w_bit=self.w_bit, - morr_fwhm=self.morr_fwhm, - sigma_weight=self.sigma_weight, - trainable_morr_scale=self.trainable_morr_scale, # bool - morr_scale=self.morr_scale, - weight_quant_gain=self.weight_quant_gain, - seed=42, - ) - return output diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel.py deleted file mode 100644 index 6de1d97f5..000000000 --- a/src/chop/nn/optical/triton_modules/morr_linear_kernel.py +++ /dev/null @@ -1,930 +0,0 @@ -import os - -# os.environ["TRITON_INTERPRET"] = "1" - -import torch -from torch import Tensor -import triton -import triton.language as tl -import pdb - -from .dtype import TORCH_DTYPE_TO_TRITON - -PACKAGE_NAME = "mase_triton" -from ..utils import ( - toeplitz, - input_quantize_fn, - weight_quantize_fn, -) -from .quantize import _input_quantize_fn, _weight_quantize_fn - - -def _get_autotune_configs(): - configs = [] - for _M in [1, 2, 4, 8]: - for _P in [1, 2, 4, 8]: - for _Q in [1, 2, 4, 8]: - configs.append( - triton.Config( - { - "BLOCK_SIZE_M": _M, - "BLOCK_SIZE_P": _P, - "BLOCK_SIZE_Q": _Q, - # "BLOCK_SIZE_K1": 4, - "BLOCK_SIZE_K2": 1, - }, - num_stages=3, - num_warps=8, - ) - ) - return configs - - -@triton.jit -def _mrr_roundtrip_phase_to_tr_func( - x: tl.tensor, - a: tl.constexpr = 0.8, - r: tl.constexpr = 0.9, - intensity: tl.constexpr = False, -): - """ - Applies a round-trip phase correction to the input tensor. - """ - c1 = -2.0 * a * r - c2 = a * a + r * r - c3 = 1.0 + r * r * a * a - a * a - r * r - - cos_x = tl.cos(x) - numerator = cos_x * c1 + c2 - denominator = numerator + c3 - x = numerator / denominator - if not intensity: - x = tl.sqrt(x) - return x - - -# @triton.autotune( -# configs= [ -# triton.Config( -# { -# "BLOCK_SIZE_M": 1, -# "BLOCK_SIZE_P": 1, -# "BLOCK_SIZE_Q": 1, -# # "BLOCK_SIZE_K1": 4, -# "BLOCK_SIZE_K2": 1, -# }, -# num_stages=3, -# num_warps=8, -# ),], -# key=["M", "P", "Q", "K"], -# ) -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "P", "Q", "K"], -) -@triton.jit -def morr_propagate_kernel( - x_ptr, - w_ptr, - o_ptr, - b_ptr, - M, - P, - Q, - K, - grid_dim_q, - grid_dim_p, - miniblock, - crosstalk_factor, - phase_noise_std, - mrr_a, - mrr_r, - in_bit, - w_bit, - seed, - # stride - stride_wm, - stride_wp, - stride_wq, - stride_wk1, - stride_wk2, - stride_xm, - stride_xp, - stride_xq, - stride_xk1, - stride_xk2, - stride_bm, - stride_bp, - stride_bq, - stride_bk1, - stride_om, - stride_op, - stride_oq, - stride_ok1, - stride_ok2, - finegrain_drop_mask, - ENABLE_PHASE_NOISE: tl.constexpr, - ENABLE_THERMAL_CROSSTALK: tl.constexpr, - TRAINABLE_MORR_BIAS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_P: tl.constexpr, - BLOCK_SIZE_Q: tl.constexpr, - BLOCK_SIZE_K1: tl.constexpr, - BLOCK_SIZE_K2: tl.constexpr, - INPUT_DTYPE: tl.constexpr, -): - - # Program ID for block-based processing - # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block - pid = tl.program_id(axis=0) - # number of blocks (each program needs to handle) along M, P, Q dimension - pnum_m = grid_dim_p * grid_dim_q - pnum_p = grid_dim_p // BLOCK_SIZE_P - pnum_q = grid_dim_q // BLOCK_SIZE_Q - # block dimension of current program - pid_m = pid // (pnum_q * pnum_p) - pid_p = (pid // pnum_q) % pnum_p - pid_q = pid % pnum_q - - # starting element's m, p, q coordinates in the global tensor - start_m = pid_m * BLOCK_SIZE_M - start_p = pid_p * BLOCK_SIZE_P - start_q = pid_q * BLOCK_SIZE_Q - - # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] - offs_wm = tl.arange(0, 1) - offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) - offs_wq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) - offs_wk1 = tl.arange(0, BLOCK_SIZE_K1) - offs_wk2 = tl.arange(0, BLOCK_SIZE_K1) - # x [m, 1, q, k, 1] - offs_xm = pid_m * BLOCK_SIZE_M + tl.arange(0, 1) - offs_xp = tl.arange(0, 1) - offs_xq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) - offs_xk1 = tl.arange(0, BLOCK_SIZE_K1) - offs_xk2 = tl.arange(0, BLOCK_SIZE_K2) - # morr_bias: [1, p, q, 1] - offs_bm = tl.arange(0, 1) - offs_bp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) - offs_bq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) - offs_bk1 = tl.arange(0, 1) - - w_ptrs = w_ptr + ( - offs_wm[:, None, None, None, None] * stride_wm - + offs_wp[None, :, None, None, None] * stride_wp - + offs_wq[None, None, :, None, None] * stride_wq - + offs_wk1[None, None, None, :, None] * stride_wk1 - + offs_wk2[None, None, None, None, :] * stride_wk2 - ) - x_ptrs = x_ptr + ( - offs_xm[:, None, None, None, None] * stride_xm - + offs_xp[None, :, None, None, None] * stride_xp - + offs_xq[None, None, :, None, None] * stride_xq - + offs_xk1[None, None, None, :, None] * stride_xk1 - + offs_xk2[None, None, None, None, :] * stride_xk2 - ) - b_ptrs = b_ptr + ( - offs_bm[:, None, None, None, None] * stride_bm - + offs_bp[None, :, None, None, None] * stride_bp - + offs_bq[None, None, :, None, None] * stride_bq - + offs_bk1[None, None, None, :, None] * stride_bk1 - ) - - acc = tl.zeros( - (BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), - dtype=tl.float32, - ) - m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] - p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] - q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] - - for m_local in range(BLOCK_SIZE_M): - m = start_m + m_local - for p_local in range(BLOCK_SIZE_P): - p = start_p + p_local - for q_local in range(BLOCK_SIZE_Q): - q = start_q + q_local - - w_mask = (p < P) & (q < Q) - x_mask = (m < M) & (q < Q) - b_mask = (p < P) & (q < Q) - - w = tl.load(w_ptrs, mask=w_mask, other=0.0) - x = tl.load(x_ptrs, mask=x_mask, other=0.0) - b = tl.load(b_ptrs, mask=b_mask, other=0.0) - - w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] - x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] - - x = x * x # input_modulator() - # ----- propagate_morr() ----- - - # apply thermal crosstalk noise - if ENABLE_THERMAL_CROSSTALK: - w = w * crosstalk_factor - - # MatMals - # TODO: tl.dot requires 16*16 matrix at least, this is a workaround - x = tl.trans(x) - x = tl.broadcast_to(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K1)) - x = tl.sum(w * x, axis=1) - x = tl.reshape(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) - - # apply phase noise - if ENABLE_PHASE_NOISE: - block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 - offs = tl.reshape( - block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2), - (BLOCK_SIZE_K1, BLOCK_SIZE_K2), - ) - noise = tl.randn(seed, offs) * phase_noise_std - x = x + noise - - # add trainable bias - b = b.reshape(1, 1) - - if TRAINABLE_MORR_BIAS: - x = x - b - - # mrr_roundtrip_phase_to_tr - x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) - - # store the value in acc using mask - res = x - condition_mask = ( - (m_indices == m_local) - & (p_indices == p_local) - & (q_indices == q_local) - ) - res = res[None, None, None, :, :] - acc = tl.where(condition_mask, res, acc) - - # propagate pointer along Q dimension - w_ptrs += stride_wq - x_ptrs += stride_xq - b_ptrs += stride_bq - - # Q loop end - # reset pointer along Q dimension - w_ptrs -= stride_wq * (BLOCK_SIZE_Q) - x_ptrs -= stride_xq * (BLOCK_SIZE_Q) - b_ptrs -= stride_bq * (BLOCK_SIZE_Q) - # propagate pointer along P dimension - w_ptrs += stride_wp - b_ptrs += stride_bp - # x_ptrs += stride_xp # x has P dimension = 1 - - # P loop end - # reset pointer along P dimension - w_ptrs -= stride_wp * (BLOCK_SIZE_P) - b_ptrs -= stride_bp * (BLOCK_SIZE_P) - # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 - - # propagate pointer along M dimension - # w_ptrs += stride_wp # weight has M dimension = 1 - x_ptrs += stride_xm - - out = acc.to(INPUT_DTYPE) - out = out.reshape( - BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1 - ) # [1, 1, q, k, 1] -> [1, 1, q, k] - - offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) - offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) - offs_ok1 = tl.arange(0, BLOCK_SIZE_K1) - # offs_ok2 = tl.arange(0, BLOCK_SIZE_K2) - o_ptrs = o_ptr + ( - stride_om * offs_om[:, None, None, None] - + stride_op * offs_op[None, :, None, None] - + stride_oq * offs_oq[None, None, :, None] - + stride_ok1 * offs_ok1[None, None, None, :] - ) - - m_valid = offs_om[:, None, None, None] < M - p_valid = offs_op[None, :, None, None] < P - q_valid = offs_oq[None, None, :, None] < Q - k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 - o_mask = m_valid & p_valid & q_valid & k_valid - tl.store(o_ptrs, out, mask=o_mask) - - -@torch.library.custom_op( - f"{PACKAGE_NAME}::optical_morr_linear_fn", - mutates_args={}, -) -def morr_linear_fn( - x: Tensor, - weight: Tensor, - morr_input_bias: Tensor, - morr_output_scale: Tensor, - bias: Tensor | None, - morr_input_scale: Tensor, - morr_bias: Tensor | None, - grid_dim_x: int, - grid_dim_y: int, - miniblock: int, - enable_thermal_crosstalk: bool, - crosstalk_factor: float | None, - enable_phase_noise: bool, - phase_noise_std: float | None, - trainable_morr_bias: bool, - mrr_a: float, - mrr_r: float, - finegrain_drop_mask: Tensor | None, - in_features: int, - in_features_pad: int, - out_features: int, - out_features_pad: int, - in_bit: int, - w_bit: int, - morr_fwhm: float, - sigma_weight: float, - trainable_morr_scale: bool, - morr_scale: Tensor, - weight_quant_gain: float | None = None, - in_quant_alg: str = "dorefa", - w_quant_alg: str = "dorefa_pos", - morr_output_scale_quant_alg: str = "dorefa_sym", - seed: int = 42, -) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float, Tensor, Tensor]: - Device = x.device - Dtype = x.dtype - assert x.dtype in ( - torch.bfloat16, - torch.float16, - torch.float32, - ), f"Unsupported dtype {x.dtype}" - assert x.is_contiguous(), "Input tensor must be contiguous" - assert weight.dtype in ( - torch.bfloat16, - torch.float16, - torch.float32, - ), f"Unsupported dtype {weight.dtype}" - - # Handle transformer vs non-transformer inputs - ori_x_shape = x.shape - is_transformer = len(ori_x_shape) == 3 - - if is_transformer: - in_B, in_N, in_D = x.shape - M = in_B * in_N - x = x.reshape(M, in_D) - else: - M = x.shape[0] - - # Get dimensions - M, D = x.shape - P, Q, K = weight.shape - - if in_features_pad > D: - x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) - x = torch.cat([x, x_pad], dim=1) - - assert Q * K == in_features_pad, "input and weight dimension mismatch" - assert P * K == out_features_pad, "weight and output dimension mismatch" - - # Quantize input - ctx_x_quant = torch.empty(0, device=Device, dtype=Dtype) - if in_bit < 16: - input_quantizer = input_quantize_fn(in_bit, device=Device) - input_quantizer.set_bitwidth(in_bit) - ctx_x_quant = x.clone() - x = input_quantizer(x) - - # Build weight - ctx_w_quant = torch.empty(0, device=Device, dtype=Dtype) - if w_bit < 16: - weight_quantizer = weight_quantize_fn(w_bit, alg="dorefa_pos") - weight_quantizer.set_bitwidth(w_bit) - ctx_w_quant = weight.clone() - weight = weight_quantizer(weight) - - ## rescale weights after quantization can maintain the initialization distribution - if weight_quant_gain is None: - weight_quant_gain = sigma_weight / weight.data.std() - if trainable_morr_scale: - morr_scale = morr_scale * weight_quant_gain - else: - morr_scale = weight_quant_gain - weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization - ### quantize learnable balancing factor - morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") - morr_output_scale = morr_output_scale_quantizer(morr_output_scale) - else: - weight = weight.abs() # positive only - morr_output_scale = morr_output_scale - morr_output_scale.data.mean() - - if finegrain_drop_mask is not None: - weight = weight.mul(finegrain_drop_mask.float()) - - # differential balancing factor concatenation - scale = morr_output_scale[..., :-1, :] - scale_pad = morr_output_scale[..., -1:, :] - if grid_dim_x % 2 == 0: - # even blocks - scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] - else: - # odd blocks - if grid_dim_x > 1: - scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] - else: - scale = scale_pad # [1, 1, q, 1] - morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] - ctx_morr_output_scale = morr_output_scale.clone() - - # Reshape x and weight - x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] - x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] - weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] - - x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] - w_ctx = weight.clone() - - # Allocate output - output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) - # Launch the Triton kernel - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) - * triton.cdiv(P, meta["BLOCK_SIZE_P"]) - * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), - ) - morr_propagate_kernel[grid]( - x_ptr=x, - w_ptr=weight, - o_ptr=output, - b_ptr=morr_bias, - M=M, - P=P, - Q=Q, - K=K, - grid_dim_q=grid_dim_x, - grid_dim_p=grid_dim_y, - miniblock=miniblock, - crosstalk_factor=crosstalk_factor, - phase_noise_std=phase_noise_std, - mrr_a=mrr_a, - mrr_r=mrr_r, - in_bit=in_bit, - w_bit=w_bit, - seed=seed, - finegrain_drop_mask=finegrain_drop_mask, - stride_wm=weight.stride(0), - stride_wp=weight.stride(1), - stride_wq=weight.stride(2), - stride_wk1=weight.stride(3), - stride_wk2=weight.stride(4), - stride_xm=x.stride(0), - stride_xp=x.stride(1), - stride_xq=x.stride(2), - stride_xk1=x.stride(3), - stride_xk2=x.stride(4), - stride_bm=morr_bias.stride(0) if morr_bias is not None else 0, - stride_bp=morr_bias.stride(1) if morr_bias is not None else 0, - stride_bq=morr_bias.stride(2) if morr_bias is not None else 0, - stride_bk1=morr_bias.stride(3) if morr_bias is not None else 0, - stride_om=output.stride(0), - stride_op=output.stride(1), - stride_oq=output.stride(2), - stride_ok1=output.stride(3), - stride_ok2=output.stride(4), - ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, - ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, - TRAINABLE_MORR_BIAS=trainable_morr_bias, - INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], - BLOCK_SIZE_K1=K, - ) - - # Apply output scale - output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] - ctx_x_scalematmul = output.clone() # record x input for matmul - output = morr_output_scale.matmul( - output - ) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - output = output.flatten(1) # [bs, p*k] - - # Trim output if needed - if out_features < out_features_pad: - output = output[:, :out_features] - if bias is not None: - output = output + bias.unsqueeze(0) - # Reshape back for transformer - if is_transformer: - output = output.view(in_B, in_N, out_features) - - # aux_tensor = ( - # torch.abs(w_ctx), # w_morr: weight in propagate_morr matmul - # x_ctx, # x_modulator: x before x^2 - # ) - - return ( - output, - seed, - torch.abs(w_ctx), - x_ctx, - ctx_morr_output_scale, - ctx_x_scalematmul, - morr_scale.clone(), - weight_quant_gain if weight_quant_gain is not None else 0.0, - ctx_x_quant, - ctx_w_quant, - ) - - -def _morr_linear_setup_context(ctx, inputs, output): - """ - Save for backward only what the backward routine really needs. - """ - ( - x, # 0 Tensor – input - weight, # 1 Tensor – learnable weight - morr_input_bias, # 23 Tensor - origin_morr_output_scale, # 3 Original input morr_output_scale - bias, # 4 Tensor | None – bias - morr_input_scale, # 5 Tensor - morr_bias, # 2 Tensor | None - grid_dim_x, # 5 int - grid_dim_y, # 6 int - miniblock, # 7 int (== K) - enable_thermal_crosstalk, # 8 bool - crosstalk_factor, # 9 float - enable_phase_noise, # 10 bool - phase_noise_std, # 11 float - trainable_morr_bias, # 12 bool - mrr_a, # 13 float - mrr_r, # 14 float - finegrain_drop_mask, # 15 Tensor | None - in_features, # 16 int - in_features_pad, # 17 int - out_features, # 18 int - out_features_pad, # 19 int - in_bit, # 20 int - w_bit, # 21 int - morr_fwhm, # 22 float - sigma_weight, - trainable_morr_scale, # bool - _morr_scale, - _weight_quant_gain, - in_quant_alg, - w_quant_alg, - morr_output_scale_quant_alg, - seed, - ) = inputs - - ( - output, - seed, - w_morr, - x_modulator, - morr_output_scale, - x_scalematmul, - morr_scale, - weight_quant_gain, - x_quant, - w_quant, - ) = output - # ( - # w_morr, - # x_modulator, - # ) = aux_tensor - - device, dtype = x.device, x.dtype - - # ----- Tensor meta-data that backward needs ----- - # Shapes - M = x.shape[0] if x.dim() == 2 else x.shape[0] * x.shape[1] - P, Q, K = weight.shape - tensor_shape = (M, P, Q, K) - - # mrr_para: para for mrr_roundtrip_phase_to_tr() - c1 = -2.0 * mrr_a * mrr_r - c2 = mrr_a * mrr_a + mrr_r * mrr_r - c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r - c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r - intensity = True - mrr_para = (c1, c2, c3, c4, intensity) - - # x_morr: x input of matmal in propagate_morr() - x_morr = x_modulator**2 # [m, q, k] - x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] - - # x_mrr: x input of mrr_roundtrip_phase_to_tr() - x_mrr = w_morr.matmul(x_morr).squeeze(-1) - if enable_phase_noise and phase_noise_std > 1e-5: - x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) - if trainable_morr_bias: - x_mrr = x_mrr - morr_bias - - tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) - - # 3. stash tensors - ctx.save_for_backward( - x, # original input - weight, # original weight - bias if bias is not None else torch.tensor([], device=device, dtype=dtype), - morr_output_scale, # morr_output_scale after modification in build_weight() - x_mrr, # x input for mrr_roundtrip_phase_to_tr() - x_morr, - w_morr, # w input for propagate_morr() matmul - # morr_bias, - x_modulator, # x input for input_modulator() - # morr_input_bias, - x_scalematmul, # x input for morr_output_scale.matmul - tanh_input_bias, - morr_input_scale, - morr_scale, # morr_scale after modification in build_weight() - x_quant, # x input for input_quantize_fn() - w_quant, # w input for weight_quantize_fn() - origin_morr_output_scale, # original morr_output_scale - finegrain_drop_mask, - ) - ctx.tensor_shape = tensor_shape - ctx.mrr_para = mrr_para - ctx.in_features = in_features - ctx.in_features_pad = in_features_pad - ctx.out_features = out_features - ctx.out_features_pad = out_features_pad - ctx.morr_fwhm = morr_fwhm - ctx.grid_dim_x = grid_dim_x - ctx.grid_dim_y = grid_dim_y - ctx.in_bit = in_bit - ctx.w_bit = w_bit - ctx.x_input_shape = x.shape - ctx.device = x.device - ctx.w_input_shape = weight.shape - ctx.enable_phase_noise = enable_phase_noise - ctx.phase_noise_std = phase_noise_std - ctx.trainable_morr_bias = trainable_morr_bias - ctx.trainable_morr_scale = trainable_morr_scale - ctx.weight_quant_gain = weight_quant_gain - ctx.in_quant_alg = in_quant_alg - ctx.w_quant_alg = w_quant_alg - ctx.morr_output_scale_quant_alg = morr_output_scale_quant_alg - - -def _morr_linear_backward(ctx, grad_output, *ignored): - """ - Backward pass for morr_linear_fn. - """ - ( - x, - weight, - bias, - morr_output_scale, - x_mrr, - x_morr, - w_morr, - # morr_bias, - x_modulator, - # morr_input_bias, - x_scalematmul, - tanh_input_bias, - morr_input_scale, - morr_scale, - x_quant, - w_quant, - origin_morr_output_scale, - finegrain_drop_mask, - ) = ctx.saved_tensors - - M, P, Q, K = ctx.tensor_shape - c1, c2, c3, c4, intensity = ctx.mrr_para - in_features = ctx.in_features - in_features_pad = ctx.in_features_pad - out_features = ctx.out_features - out_features_pad = ctx.out_features_pad - x_input_shape = ctx.x_input_shape - w_input_shape = ctx.w_input_shape - DEVICE = ctx.device - - # --- calculate intermediate activation on the fly --- - # x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] - - # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) - # morr_bias = ctx.morr_fwhm * tanh_input_bias - - # # x_mrr: x input of mrr_roundtrip_phase_to_tr() - # x_mrr = w_morr.matmul(x_morr).squeeze(-1) - # if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: - # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) - # if ctx.trainable_morr_bias: - # x_mrr = x_mrr - morr_bias - - # ----- backward prop ----- - # Reshape - grad_out = grad_output.view( - x_input_shape[0], w_input_shape[1], w_input_shape[2], -1 - ) # [M, P, Q, K] - - # ----- Gradient w.r.t input x ----- - if True or ctx.needs_input_grad[0]: - # 1. reshape - grad_out = grad_out.view(M, -1) # [m, out_features] - - if ctx.needs_input_grad[4] and bias: - grad_bias = grad_out.sum(dim=0) # [out_features] - else: - grad_bias = None - - out_pad = torch.zeros( - grad_out.shape[0], out_features_pad - out_features, device=DEVICE - ) # [m, out_features_pad - out_features] - grad_out = torch.cat( - [grad_out, out_pad], dim=1 - ) # [m * out_features_pad] = [m, p*k] - - # 2. x=x.flatten(1) - # input: [m, p**k] - grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] - - # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] - # dL/d(morr_output_scale) - if ctx.needs_input_grad[3]: - grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] - grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] - grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale - - t = ctx.grid_dim_x // 2 - grad_output_scale = grad_s.new_zeros((1, 1, t + 1, 1)) - - if ctx.grid_dim_x % 2 == 0: - grad_output_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] - elif ctx.grid_dim_x == 1: - grad_output_scale = grad_s - else: - grad_output_scale[..., :t, :] = ( - grad_s[..., :t, :] - grad_s[..., t + 1 :, :] - ) - grad_output_scale[..., t : t + 1, :] = grad_s[..., t : t + 1, :] - # build_weight() - if ctx.w_bit < 16: - # morr_output_scale_quantizer() - if ctx.morr_output_scale_quant_alg == "dorefa_sym": - # local recompute: - w_in = torch.tanh(origin_morr_output_scale) # [-1, 1] - r = torch.max(w_in.abs()).detach() - - # ignore gradient for r here - grad_output_scale = (grad_output_scale * 2 * r).clamp_(-1.0, 1.0) - grad_output_scale = grad_output_scale * (1.0 / (2 * r)) - grad_output_scale = grad_output_scale * (1.0 - w_in.pow(2)) - - else: - raise NotImplementedError - else: - grad_output_scale = None - - # dL/dx - grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] - - # 4. x = mrr_roundtrip_phase_to_tr(x) - denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) - if intensity: - denominator.square_() - numerator = x_mrr.sin().mul_(c4) - else: - numerator = x_mrr.sin().mul_(c4 / 2) - denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) - grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] - - # 5. x += phase_noise and morr_bias - if ctx.needs_input_grad[2]: - grad_inputbias = -grad_x # [bs, p, q, k] - grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] - grad_inputbias = ( - grad_inputbias - tanh_input_bias * tanh_input_bias - ) # [bs, p, q, k] - grad_inputbias = grad_inputbias.sum(dim=(0, -1)) - else: - grad_inputbias = None - - # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] - grad_morr_matmul = grad_x # stash for weight gradient - - # dL/dx - grad_x = torch.matmul( - w_morr.transpose(-1, -2), grad_x - ) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] - grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] - grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] - - # 7. input modulator(x) - grad_x = grad_x * 2 * x_modulator # [bs, q, k] - - # 8. input reshape - B, N, D = x_input_shape - grad_x = grad_x.view(-1, in_features_pad) # [b*n, in_features_pad] - grad_x = grad_x[:, :in_features] # [b*n, in_features = D] - - # 9.input quantization - if ctx.in_bit >= 16 or ctx.in_quant_alg is None: - pass - elif ctx.in_quant_alg == "dorefa": - grad_x = grad_x * ((x_quant > 0) & (x_quant < 1)) - else: - raise NotImplementedError - - # 10. input reshape - grad_x = grad_x.view(B, N, D) # [b, n, d] - # ----- Gradient w.r.t weight ----- - if True or ctx.needs_input_grad[1]: - - # 0. gradient after x = weight.matmul(x) - # grad_morr_matmul # [bs, p, q, k, 1] - - # 1. x = weight.matmul(x) - grad_w = torch.matmul( - grad_morr_matmul, x_morr.transpose(-1, -2) - ) # [bs,p,q,k,k] - grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] - - # 2. weight = toeplitz(weight) - k = grad_w.size(-1) - row = torch.arange(k)[:, None] # (k,1) - col = torch.arange(k)[None, :] # (1,k) - idx = (row - col) & (k - 1) if (k & (k - 1)) == 0 else (row - col + k) % k - - idx = idx.expand(grad_w.shape).to(DEVICE) - buffer = torch.zeros_like(grad_w, device=DEVICE) - buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k]cvb - grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) # [p, q, k] - - # 3. build_weight() - if finegrain_drop_mask is not None: - grad_w = grad_w * finegrain_drop_mask.float() - # morr_scale: [p, q, 1] - grad_morr_input_scale = None - if ctx.w_bit < 16: - # grad w.r.t morr_scale - if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: - grad_morr_scale = (grad_w * weight).sum( - dim=2, keepdim=True - ) # [p, q, 1] - grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] - # ∂L/∂self.morr_input_scale - sigmoid_scale = torch.sigmoid(morr_input_scale) - grad_morr_input_scale = ( - grad_morr_scale * sigmoid_scale * (1 - sigmoid_scale) - ).squeeze( - -1 - ) # [p, q] - - # grad w.r.t weight - grad_w = grad_w * morr_scale # weight.mul(morr_scale) - # weight_quantizer() - if ctx.w_quant_alg is None: - pass - elif ctx.w_quant_alg == "dorefa_pos": - # local recompute: - w_in = torch.tanh(w_quant) - r = torch.max(w_in.abs()).detach() + 1e-12 # ε avoids /0 - # ignore gradient for r here - # grad_w = grad_w * (1.0 - w_in.pow(2)) - # grad_w = grad_w.clamp_(-1, 1) - grad_w = grad_w * (2 * r) - grad_w = grad_w.clamp(-1.0, 1.0) - grad_w = grad_w / (2 * r) - grad_w = grad_w * (1.0 - w_in.pow(2)) - else: - raise NotImplementedError - else: - grad_w = grad_w * weight.sign() - - return ( - grad_x, # ∂L/∂x - grad_w, # ∂L/∂w - grad_inputbias, # ∂L/∂morr_input_bias - grad_output_scale, # ∂L/∂morr_output_scale - grad_bias, # ∂L/∂bias - grad_morr_input_scale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -morr_linear_fn.register_autograd( - _morr_linear_backward, - setup_context=_morr_linear_setup_context, -) diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 067ecf1c8..6786b0c0d 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -54,69 +54,10 @@ def test_optical_module_transform_pass(): "trainable_morr_scale": False, } }, - "conv1": { - "config": { - "name": "morr", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, } optical_module_transform_pass(model, pass_args) -def test_optical_module_transform_pass_2(): - model = Net() - # Sanity check and report - pass_args = { - "by": "name", - "fc1": { - "config": { - "name": "morr_triton", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - "conv1": { - "config": { - "name": "morr_triton", - "miniblock": 4, - "morr_init": True, - "trainable_morr_bias": False, - "trainable_morr_scale": False, - } - }, - } - optical_module_transform_pass(model, pass_args) - - -def test_optical_module_transform_pass_3(): - model = Net() - pass_args = { - "by": "regex_name", - "^fc1$": { - "config": {"name": "morr_triton", "miniblock": 4}, - "additional": { - "trainable_morr_bias": False, - "trainable_morr_scale": False, - "thermal_crosstalk": True, - "coupling_factor": 0.04, - "drop_perc": 0.0, - "phase_noise": True, - "phase_noise_std": 0.04, - "in_bit": 8, - "w_bit": 8, - }, - }, - } - new_model, _ = optical_module_transform_pass(model, pass_args) - print(new_model) - - test_optical_module_transform_pass() -test_optical_module_transform_pass_2() -test_optical_module_transform_pass_3() +# test_optical_module_transform_pass_2() +# test_optical_module_transform_pass_3() \ No newline at end of file From fe3c170f5d470fc2aee3e5ed4359f3f40d21f24d Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 24 Jun 2025 19:35:45 +0100 Subject: [PATCH 36/38] add notes --- test/passes/module/transforms/optical/note.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 test/passes/module/transforms/optical/note.md diff --git a/test/passes/module/transforms/optical/note.md b/test/passes/module/transforms/optical/note.md new file mode 100644 index 000000000..98aed8041 --- /dev/null +++ b/test/passes/module/transforms/optical/note.md @@ -0,0 +1 @@ +### Note on using custom kernel for MORR linear layer From 7fe6ecffe20b95ea199c1b70be02fd79dea6bde1 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 24 Jun 2025 19:40:48 +0100 Subject: [PATCH 37/38] reformat --- src/chop/nn/optical/modules/__init__.py | 1 + test/passes/module/transforms/optical/test_optical_module.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py index bfd7a558f..840b28a2a 100644 --- a/src/chop/nn/optical/modules/__init__.py +++ b/src/chop/nn/optical/modules/__init__.py @@ -1,5 +1,6 @@ from .morr_linear import AllPassMORRCirculantLinear from .morr_conv2d import AllPassMORRCirculantConv2d + # from ..triton_modules.morr_linear_mem import TritonMemMORRLinear diff --git a/test/passes/module/transforms/optical/test_optical_module.py b/test/passes/module/transforms/optical/test_optical_module.py index 6786b0c0d..f6e74e99e 100644 --- a/test/passes/module/transforms/optical/test_optical_module.py +++ b/test/passes/module/transforms/optical/test_optical_module.py @@ -60,4 +60,4 @@ def test_optical_module_transform_pass(): test_optical_module_transform_pass() # test_optical_module_transform_pass_2() -# test_optical_module_transform_pass_3() \ No newline at end of file +# test_optical_module_transform_pass_3() From ee5a1f2d0ae38d243f9bf8e777e231fc26010a95 Mon Sep 17 00:00:00 2001 From: Johnny1882 Date: Tue, 24 Jun 2025 20:48:52 +0100 Subject: [PATCH 38/38] complete note on further work --- test/passes/module/transforms/optical/note.md | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/passes/module/transforms/optical/note.md b/test/passes/module/transforms/optical/note.md index 98aed8041..92bce3b98 100644 --- a/test/passes/module/transforms/optical/note.md +++ b/test/passes/module/transforms/optical/note.md @@ -1 +1,47 @@ ### Note on using custom kernel for MORR linear layer + +Current optical transform pass only support MORR linear PyTorch module. To enbale substitution using Optimised MORR linear module (using Triton kernel): + +1. uncomment `TritonMemMORRLinear` inside [file](../../../../../src/chop/nn/optical/modules/__init__.py) +2. replace `morr_linear_fn_mem` function in [kernel wrapper](../../../../../src/chop/nn/optical/triton_modules/morr_linear_mem.py). Current implementation import it from a project file, import it from mase-triton instead. +3. You should now able to use optimised MORR linear module in optical transform pass. Two sample usage are shown below: + +```python + +# Minimal example ─ apply the MORR-Triton replacement to a single layer +model = Net() +pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr_triton", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, +} +new_model, _ = optical_module_transform_pass(model, pass_args) + +# Use additional config to initialise MORR linear module with noise modelling +model = Net() +pass_args = { + "by": "regex_name", + "^fc1$": { + "config": {"name": "morr_triton", "miniblock": 4}, + "additional": { + "trainable_morr_bias": False, + "trainable_morr_scale": False, + "thermal_crosstalk": True, + "coupling_factor": 0.04, + "drop_perc": 0.0, + "phase_noise": True, + "phase_noise_std": 0.04, + "in_bit": 8, + "w_bit": 8, + }, + }, +} +new_model, _ = optical_module_transform_pass(model, pass_args) +``` \ No newline at end of file