diff --git a/.gitignore b/.gitignore index 936138f2b..c1b540763 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,6 @@ mase-trainer/ test-trainer/ # DiffLogic: tutorial files -docs/tutorials/difflogic/data-mnist/ \ No newline at end of file +docs/tutorials/difflogic/data-mnist/ + +test/self \ No newline at end of file diff --git a/src/chop/actions/search/search_space/nas_bert.py b/src/chop/actions/search/search_space/nas_bert.py index 3c4838ea1..760d3b27f 100644 --- a/src/chop/actions/search/search_space/nas_bert.py +++ b/src/chop/actions/search/search_space/nas_bert.py @@ -1199,9 +1199,11 @@ def forward( # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = ( - encoder_hidden_states.size() - ) + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 9aa96ae68..f5e08f714 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -201,8 +201,9 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor: # did not already construct one random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) - first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - torch.Tensor, local_tensor_args[0] + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), ) rng_context = ( random._rng_tracker._distribute_region(first_arg._spec) diff --git a/src/chop/ir/onnx/mase_onnx_graph.py b/src/chop/ir/onnx/mase_onnx_graph.py index 40bb7915c..f6e1d31e6 
100644 --- a/src/chop/ir/onnx/mase_onnx_graph.py +++ b/src/chop/ir/onnx/mase_onnx_graph.py @@ -9,7 +9,6 @@ class MaseOnnxGraph: - def __init__( self, model_proto: onnx.onnx_ml_pb2.ModelProto, diff --git a/src/chop/models/cnv/cnv.py b/src/chop/models/cnv/cnv.py index 00df962a6..8bb9ad468 100644 --- a/src/chop/models/cnv/cnv.py +++ b/src/chop/models/cnv/cnv.py @@ -4,12 +4,8 @@ from typing import Any import numpy as np -from chop.nn.quantized.modules.conv2d import ( - Conv2dBinaryResidualSign, -) -from chop.nn.quantized.modules.linear import ( - LinearBinaryResidualSign, -) +from chop.nn.quantized.modules.conv2d import Conv2dBinaryResidualSign +from chop.nn.quantized.modules.linear import LinearBinaryResidualSign from chop.models.utils import register_mase_model, register_mase_checkpoint """ diff --git a/src/chop/nn/backward/modules/__init__.py b/src/chop/nn/backward/modules/__init__.py index 1dcc39e65..279e0c867 100644 --- a/src/chop/nn/backward/modules/__init__.py +++ b/src/chop/nn/backward/modules/__init__.py @@ -1,6 +1,4 @@ -from .linear import ( - CustomLinear, -) +from .linear import CustomLinear custom_module_map = { diff --git a/src/chop/nn/optical/__init__.py b/src/chop/nn/optical/__init__.py new file mode 100644 index 000000000..0310afb71 --- /dev/null +++ b/src/chop/nn/optical/__init__.py @@ -0,0 +1 @@ +from .modules import optical_module_map diff --git a/src/chop/nn/optical/modules/__init__.py b/src/chop/nn/optical/modules/__init__.py new file mode 100644 index 000000000..840b28a2a --- /dev/null +++ b/src/chop/nn/optical/modules/__init__.py @@ -0,0 +1,11 @@ +from .morr_linear import AllPassMORRCirculantLinear +from .morr_conv2d import AllPassMORRCirculantConv2d + +# from ..triton_modules.morr_linear_mem import TritonMemMORRLinear + + +optical_module_map = { + "linear_morr": AllPassMORRCirculantLinear, + "conv2d_morr": AllPassMORRCirculantConv2d, + # "linear_morr_triton": TritonMemMORRLinear, +} diff --git a/src/chop/nn/optical/modules/base_layer.py 
b/src/chop/nn/optical/modules/base_layer.py new file mode 100644 index 000000000..4ae2b35bf --- /dev/null +++ b/src/chop/nn/optical/modules/base_layer.py @@ -0,0 +1,76 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-08 18:55:05 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-08 18:55:05 +""" + +from typing import Any, Dict, Optional +import torch +from torch import nn +from torch.types import Device + +__all__ = ["ONNBaseLayer"] + + +class ONNBaseLayer(nn.Module): + def __init__(self, *args, device: Device = torch.device("cpu"), **kwargs) -> None: + super().__init__(*args, **kwargs) + # cuda or cpu, defaults to cpu + self.device = device + + def build_parameters(self) -> None: + raise NotImplementedError + + def reset_parameters(self) -> None: + raise NotImplementedError + + @classmethod + def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module: + raise NotImplementedError + + def get_num_parameters(self) -> int: + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_phase_variation( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.phase_noise_std = noise_std + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def set_crosstalk_factor(self, crosstalk_factor: float) -> None: + self.crosstalk_factor = crosstalk_factor + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + + def load_parameters(self, param_dict: Dict[str, Any]) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {param_name: param_tensor, ...} + """ + for name, param in param_dict.items(): + 
getattr(self, name).data.copy_(param) + + def switch_mode_to(self, mode: str) -> None: + self.mode = mode + + def forward(self, x): + raise NotImplementedError + + def extra_repr(self) -> str: + return "" diff --git a/src/chop/nn/optical/modules/morr_conv2d.py b/src/chop/nn/optical/modules/morr_conv2d.py new file mode 100644 index 000000000..13f9532c3 --- /dev/null +++ b/src/chop/nn/optical/modules/morr_conv2d.py @@ -0,0 +1,517 @@ +""" +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-01-27 01:08:44 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-07-18 00:40:18 +""" + +from typing import Optional, Tuple +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor, nn +from torch.nn import Parameter, init +from torch.nn.modules.utils import _pair +from torch.types import Device, _size + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import im2col_2d, toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn + +from .base_layer import ONNBaseLayer + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantConv2d"] + + +class AllPassMORRCirculantConv2d(ONNBaseLayer): + """ + All-pass MORR Conv2d layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = [ + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + "output_padding", + "in_channels", + "out_channels", + "kernel_size", + "miniblock", + ] + __annotations__ = {"bias": Optional[torch.Tensor]} + + _in_channels: int + out_channels: int + kernel_size: Tuple[int, ...] + stride: Tuple[int, ...] + padding: Tuple[int, ...] + dilation: Tuple[int, ...] 
+ transposed: bool + output_padding: Tuple[int, ...] + groups: int + padding_mode: str + weight: Tensor + bias: Optional[Tensor] + miniblock: int + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size, + stride: _size = 1, + padding: _size = 0, + dilation: _size = 1, + groups: int = 1, + bias: bool = True, + padding_mode=None, # @johnny: unused argument + config=None, + device: Device = torch.device("cpu"), + ) -> None: + super(AllPassMORRCirculantConv2d, self).__init__() + assert config is not None + + miniblock = config.get("miniblock", 4) + MORRConfig = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init = config.get("morr_init", True) + trainable_morr_bias = config.get("trainable_morr_bias", False) + trainable_morr_scale = config.get("trainable_morr_scale", False) + device = config.get("device", device) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + assert ( + groups == 1 + ), f"Currently group convolution is not supported, but got group: {groups}" + self.in_channels_flat = ( + self.in_channels * self.kernel_size[0] * self.kernel_size[1] + ) + self.grid_dim_x = int(np.ceil(self.in_channels_flat / miniblock)) + self.grid_dim_y = int(np.ceil(self.out_channels / miniblock)) + self.in_channels_pad = self.grid_dim_x * miniblock + self.out_channels_pad = self.grid_dim_y * miniblock + self.miniblock = miniblock + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + self.MORRConfig = MORRConfig + self.morr_init = morr_init + self.mrr_a = MORRConfig.attenuation_factor + self.mrr_r = MORRConfig.coupling_factor + self.trainable_morr_bias = trainable_morr_bias + self.trainable_morr_scale = trainable_morr_scale + self.device = device + + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * 
np.pi**2 + * MORRConfig.radius + * MORRConfig.effective_index + * ( + 1 / MORRConfig.resonance_wavelength + - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_channels_flat // self.miniblock) + ) ** 0.5 ## set this TIA gain such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_channels).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init) + + # support fine-grained structured pruning for MORRs + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + ### MORR weights + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### learnable balancing factor achieved by MRRs (morr_output_scale) + ### We use a single 
scaling factor for each block + self.morr_output_scale = Parameter( + torch.zeros(max(1, self.grid_dim_x // 2) + 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + if morr_init: + ### nonlinear curve aware initialization + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + ### output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + + else: + nn.init.kaiming_normal_(self.weight) + nn.init.kaiming_normal_(self.morr_output_scale) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + init.zeros_(self.morr_input_bias.data) + if self.morr_input_scale is not None: + 
init.zeros_(self.morr_input_scale.data) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + weight = self.weight_quantizer(self.weight) + else: + weight = self.weight.abs() ## have to be all positive + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + return weight + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. 
+ ### See SqueezeLight paper + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_scale(self) -> Tensor: + return torch.sigmoid(self.morr_input_scale.unsqueeze(0).unsqueeze(-1)) + 0.2 + + @property + def morr_bias(self) -> Tensor: + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + def propagate_morr(self, weight: Tensor, x: Tensor) -> Tensor: + """Propagate through the analytically calculated transfer matrix of MORR. + + :param weight: First column vectors in the block-circulant matrix. + :type weight: Tensor + :param x: Input tensor. + :type x: Tensor + + :return: Output of MORR array. 
+ :rtype: Tensor + """ + + x = x.t() # [h_out*w_out*bs, ks*ks*inc] + x = x.view(x.size(0), self.grid_dim_x, self.miniblock) # [h_out*w_out*bs, q, k] + + ### injecting crosstalk into weights is more efficient + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + + ### construct block-circulant matrix + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [h*w*bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [h*w*bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_( + 0, self.phase_noise_std + ) # [h*w*bs, p, q, k] + + ### input scaling, learnable MORR nonlinearity + if self.trainable_morr_scale: + x = x * self.morr_scale # [h*w*bs, p, q, k] + ### input biasing, learnable MORR nonlinearity + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity + ### x is the phase detuning, x=0 means on-resonance + ### x: [h_out*w_out*bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) + + ### output scaling or learnable balancing factors + if self.w_bit < 16: + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if self.out_scale_quant_gain is None: + self.out_scale_quant_gain = ( + self.sigma_out_scale / morr_output_scale.data.std().item() + ) + morr_output_scale = morr_output_scale.mul( + self.out_scale_quant_gain + ) ### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + + scale = morr_output_scale[:-1] + scale_pad = morr_output_scale[-1:] + + ### differential rails + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=0) + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=0) + else: + scale = scale_pad + scale = scale.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, q] + + x = 
scale.matmul( + x + ) # [1,1,1,q]x[h_out*w_out*bs, p, q, k]=[h_out*w_out*bs, p, 1, k] + x = x.view(x.size(0), -1).t() # [p*k, h_out*w_out*bs] + if self.out_channels_pad > self.out_channels: + x = x[: self.out_channels, :] # [outc, h_out*w_out*bs] + return x + + def morr_conv2d(self, X: Tensor, W: Tensor) -> Tensor: + ### W : [p, q, k] + n_x = X.size(0) + + _, X_col, h_out, w_out = im2col_2d( + None, + X, + stride=self.stride[0], + padding=self.padding[0], + w_size=( + self.out_channels, + self.in_channels, + self.kernel_size[0], + self.kernel_size[1], + ), + ) + ## zero-padding X_col + if self.in_channels_pad > self.in_channels_flat: + if self.x_zero_pad is None or self.x_zero_pad.size(1) != X_col.size(1): + self.x_zero_pad = torch.zeros( + self.in_channels_pad - self.in_channels_flat, + X_col.size(1), + dtype=torch.float32, + device=self.device, + ) + + X_col = torch.cat([X_col, self.x_zero_pad], dim=0) + # matmul + out = self.propagate_morr(W, X_col) # [outc, w_out] + out = out.view(self.out_channels, h_out, w_out, n_x) + out = out.permute(3, 0, 1, 2).contiguous() + + return out + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def get_output_dim(self, img_height: int, img_width: int) -> Tuple[int, int]: + """ + get the output features size + """ + h_out = (img_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[ + 0 + ] + 1 + w_out = (img_width - 
self.kernel_size[1] + 2 * self.padding[1]) / self.stride[ + 1 + ] + 1 + return (int(h_out), int(w_out)) + + def forward(self, x: Tensor) -> Tensor: + if self.in_bit < 16: + x = self.input_quantizer(x) + weight = self.build_weight() + x = self.input_modulator(x) + x = self.morr_conv2d(x, weight) + + if self.bias is not None: + x = x + self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + return x diff --git a/src/chop/nn/optical/modules/morr_linear.py b/src/chop/nn/optical/modules/morr_linear.py new file mode 100644 index 000000000..946cf864a --- /dev/null +++ b/src/chop/nn/optical/modules/morr_linear.py @@ -0,0 +1,486 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" + +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from .base_layer import ONNBaseLayer + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantLinear"] + + +class AllPassMORRCirculantLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. 
Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config={}, + device: Device = torch.device("cpu"), + ) -> None: + super(AllPassMORRCirculantLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implelemt by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## 
scaling factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + 
self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it cooresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + 
weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + 
self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + 
def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. 
[-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = 
mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + # if used in transformer + is_transformer = len(x.shape) == 3 + if is_transformer: + B, N, D = x.shape + + assert ( + x.size(-1) == self.in_features + ), f"[E] Input dimension does not match the weight size {self.out_features, self.in_features}, but got input size ({tuple(x.size())}))" + if self.in_bit < 16: + x = self.input_quantizer(x) + + weight, morr_output_scale = self.build_weight() + if self.in_features_pad > self.in_features: + if self.x_zero_pad is None or self.x_zero_pad.size(0) != x.size(0): + self.x_zero_pad = torch.zeros( + x.size(0), + self.in_features_pad - self.in_features, + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([x, self.x_zero_pad], dim=1) + + x = x.view(-1, self.grid_dim_x, self.miniblock) + + ### modulation + ### x: [bs, q, k] -> [bs, q, k] + x = self.input_modulator(x) + + ### propagate through morr array + ### x: [bs, q, k] -> [bs, p*k] + x = self.propagate_morr(weight, x, morr_output_scale) + + if self.out_features < self.out_features_pad: + x = x[..., : self.out_features] + if self.bias is not None: + x = x + self.bias.unsqueeze(0) + + # adjust output shape if used in transformer + if is_transformer: + x = x.view(B, N, self.out_features) + return x diff --git a/src/chop/nn/optical/triton_modules/dtype.py b/src/chop/nn/optical/triton_modules/dtype.py new file mode 100644 index 000000000..caaa77e69 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/dtype.py @@ -0,0 +1,17 @@ +import torch +import triton.language as tl + + +TORCH_DTYPE_TO_TRITON = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16, + torch.int8: tl.int8, + torch.uint8: tl.uint8, + torch.int16: tl.int16, + torch.uint16: 
tl.uint16, + torch.int32: tl.int32, + torch.uint32: tl.uint32, + torch.float8_e4m3fn: tl.float8e4nv, + torch.float8_e5m2: tl.float8e5, +} diff --git a/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py new file mode 100644 index 000000000..c2eb11b27 --- /dev/null +++ b/src/chop/nn/optical/triton_modules/morr_linear_kernel_mem.py @@ -0,0 +1,1035 @@ +import os + +# os.environ["TRITON_INTERPRET"] = "1" + +import torch +from torch import Tensor +import triton +import triton.language as tl +import pdb + +from .dtype import TORCH_DTYPE_TO_TRITON + +PACKAGE_NAME = "mase_triton" +from ..utils import ( + toeplitz, + input_quantize_fn, + weight_quantize_fn, + mrr_roundtrip_phase_to_tr_func, +) +from .quantize import _input_quantize_fn, _weight_quantize_fn + + +def _get_autotune_configs(): + configs = [] + for _M in [1, 2, 4, 8]: + for _P in [1, 2, 4, 8]: + for _Q in [1, 2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_SIZE_M": _M, + "BLOCK_SIZE_P": _P, + "BLOCK_SIZE_Q": _Q, + # "BLOCK_SIZE_K1": 4, + "BLOCK_SIZE_K2": 1, + }, + num_stages=3, + num_warps=8, + ) + ) + return configs + + +@triton.jit +def _mrr_roundtrip_phase_to_tr_func( + x: tl.tensor, + a: tl.constexpr = 0.8, + r: tl.constexpr = 0.9, + intensity: tl.constexpr = False, +): + """ + Applies a round-trip phase correction to the input tensor. 
+ """ + c1 = -2.0 * a * r + c2 = a * a + r * r + c3 = 1.0 + r * r * a * a - a * a - r * r + + cos_x = tl.cos(x) + numerator = cos_x * c1 + c2 + denominator = numerator + c3 + x = numerator / denominator + if not intensity: + x = tl.sqrt(x) + return x + + +# @triton.autotune( +# configs= [ +# triton.Config( +# { +# "BLOCK_SIZE_M": 1, +# "BLOCK_SIZE_P": 1, +# "BLOCK_SIZE_Q": 1, +# # "BLOCK_SIZE_K1": 2, +# "BLOCK_SIZE_K2": 1, +# }, +# num_stages=3, +# num_warps=8, +# ),], +# key=["M", "P", "Q", "K"], +# ) +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "P", "Q", "K"], +) +@triton.jit +def morr_propagate_kernel( + x_ptr, + w_ptr, + o_ptr, + b_ptr, + M, + P, + Q, + K, + grid_dim_q, + grid_dim_p, + miniblock, + crosstalk_factor, + phase_noise_std, + mrr_a, + mrr_r, + in_bit, + w_bit, + seed, + # stride + stride_wm, + stride_wp, + stride_wq, + stride_wk1, + stride_wk2, + stride_xm, + stride_xp, + stride_xq, + stride_xk1, + stride_xk2, + stride_bm, + stride_bp, + stride_bq, + stride_bk1, + stride_om, + stride_op, + stride_oq, + stride_ok1, + stride_ok2, + finegrain_drop_mask, + ENABLE_PHASE_NOISE: tl.constexpr, + ENABLE_THERMAL_CROSSTALK: tl.constexpr, + TRAINABLE_MORR_BIAS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_P: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K1: tl.constexpr, + BLOCK_SIZE_K2: tl.constexpr, + INPUT_DTYPE: tl.constexpr, +): + + # Program ID for block-based processing + # each program is assigned GROUP_SIZE_MPQ * [1, 1, miniblock, 1] block + pid = tl.program_id(axis=0) + # number of blocks (each program needs to handle) along M, P, Q dimension + pnum_m = grid_dim_p * grid_dim_q + pnum_p = grid_dim_p // BLOCK_SIZE_P + pnum_q = grid_dim_q // BLOCK_SIZE_Q + # block dimension of current program + pid_m = pid // (pnum_q * pnum_p) + pid_p = (pid // pnum_q) % pnum_p + pid_q = pid % pnum_q + + # starting element's m, p, q coordinates in the global tensor + start_m = pid_m * BLOCK_SIZE_M + start_p = pid_p * 
BLOCK_SIZE_P + start_q = pid_q * BLOCK_SIZE_Q + + # w [1, p, q, k, 1] -> toeplitz [1, p, q, k, k] + offs_wm = tl.arange(0, 1) + offs_wp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_wq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_wk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_wk2 = tl.arange(0, BLOCK_SIZE_K1) + + offs_xm = pid_m * BLOCK_SIZE_M + tl.arange(0, 1) + offs_xp = tl.arange(0, 1) + offs_xq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_xk1 = tl.arange(0, BLOCK_SIZE_K1) + offs_xk2 = tl.arange(0, BLOCK_SIZE_K2) + # morr_bias: [1, p, q, 1] + offs_bm = tl.arange(0, 1) + offs_bp = pid_p * BLOCK_SIZE_P + tl.arange(0, 1) + offs_bq = pid_q * BLOCK_SIZE_Q + tl.arange(0, 1) + offs_bk1 = tl.arange(0, 1) + + w_ptrs = w_ptr + ( + offs_wm[:, None, None, None, None] * stride_wm + + offs_wp[None, :, None, None, None] * stride_wp + + offs_wq[None, None, :, None, None] * stride_wq + + offs_wk1[None, None, None, :, None] * stride_wk1 + + offs_wk2[None, None, None, None, :] * stride_wk2 + ) + x_ptrs = x_ptr + ( + offs_xm[:, None, None, None, None] * stride_xm + + offs_xp[None, :, None, None, None] * stride_xp + + offs_xq[None, None, :, None, None] * stride_xq + + offs_xk1[None, None, None, :, None] * stride_xk1 + + offs_xk2[None, None, None, None, :] * stride_xk2 + ) + b_ptrs = b_ptr + ( + offs_bm[:, None, None, None, None] * stride_bm + + offs_bp[None, :, None, None, None] * stride_bp + + offs_bq[None, None, :, None, None] * stride_bq + + offs_bk1[None, None, None, :, None] * stride_bk1 + ) + + acc = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1, BLOCK_SIZE_K2), + dtype=tl.float32, + ) + m_indices = tl.arange(0, BLOCK_SIZE_M)[:, None, None, None, None] + p_indices = tl.arange(0, BLOCK_SIZE_P)[None, :, None, None, None] + q_indices = tl.arange(0, BLOCK_SIZE_Q)[None, None, :, None, None] + + for m_local in range(BLOCK_SIZE_M): + m = start_m + m_local + for p_local in range(BLOCK_SIZE_P): + p = start_p + p_local + for q_local in range(BLOCK_SIZE_Q): + q = 
start_q + q_local + + w_mask = (p < P) & (q < Q) + x_mask = (m < M) & (q < Q) + b_mask = (p < P) & (q < Q) + + w = tl.load(w_ptrs, mask=w_mask, other=0.0) + x = tl.load(x_ptrs, mask=x_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + w = w.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K1) # [1, 1, 1, k, k] -> [k, k] + x = x.reshape(BLOCK_SIZE_K1, BLOCK_SIZE_K2) # [1, 1, 1, k, 1] -> [k, 1] + + x = x * x # input_modulator() + # ----- propagate_morr() ----- + + # apply thermal crosstalk noise + if ENABLE_THERMAL_CROSSTALK: + w = w * crosstalk_factor + + # MatMals + # TODO: tl.dot requires 16*16 matrix at least, this is a workaround + x = tl.trans(x) + x = tl.broadcast_to(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K1)) + x = tl.sum(w * x, axis=1) + x = tl.reshape(x, (BLOCK_SIZE_K1, BLOCK_SIZE_K2)) + + # apply phase noise + if ENABLE_PHASE_NOISE: + block_start = pid * BLOCK_SIZE_K1 * BLOCK_SIZE_K2 + offs = tl.reshape( + block_start + tl.arange(0, BLOCK_SIZE_K1 * BLOCK_SIZE_K2), + (BLOCK_SIZE_K1, BLOCK_SIZE_K2), + ) + noise = tl.randn(seed, offs) * phase_noise_std + x = x + noise + + # add trainable bias + b = b.reshape(1, 1) + + if TRAINABLE_MORR_BIAS: + x = x - b + + # mrr_roundtrip_phase_to_tr + x = _mrr_roundtrip_phase_to_tr_func(x, mrr_a, mrr_r, intensity=True) + + # store the value in acc using mask + res = x + condition_mask = ( + (m_indices == m_local) + & (p_indices == p_local) + & (q_indices == q_local) + ) + res = res[None, None, None, :, :] + acc = tl.where(condition_mask, res, acc) + + # propagate pointer along Q dimension + w_ptrs += stride_wq + x_ptrs += stride_xq + b_ptrs += stride_bq + + # Q loop end + # reset pointer along Q dimension + w_ptrs -= stride_wq * (BLOCK_SIZE_Q) + x_ptrs -= stride_xq * (BLOCK_SIZE_Q) + b_ptrs -= stride_bq * (BLOCK_SIZE_Q) + # propagate pointer along P dimension + w_ptrs += stride_wp + b_ptrs += stride_bp + # x_ptrs += stride_xp # x has P dimension = 1 + + # P loop end + # reset pointer along P dimension + w_ptrs -= stride_wp * 
(BLOCK_SIZE_P) + b_ptrs -= stride_bp * (BLOCK_SIZE_P) + # x_ptrs -= stride_xp * (BLOCK_SIZE_P + 1) # x has P dimension = 1、 + + # propagate pointer along M dimension + # w_ptrs += stride_wp # weight has M dimension = 1 + x_ptrs += stride_xm + + out = acc.to(INPUT_DTYPE) + out = out.reshape( + BLOCK_SIZE_M, BLOCK_SIZE_P, BLOCK_SIZE_Q, BLOCK_SIZE_K1 + ) # [1, 1, q, k, 1] -> [1, 1, q, k] + + offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_op = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P) + offs_oq = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + offs_ok1 = tl.arange(0, BLOCK_SIZE_K1) + # offs_ok2 = tl.arange(0, BLOCK_SIZE_K2) + o_ptrs = o_ptr + ( + stride_om * offs_om[:, None, None, None] + + stride_op * offs_op[None, :, None, None] + + stride_oq * offs_oq[None, None, :, None] + + stride_ok1 * offs_ok1[None, None, None, :] + ) + + m_valid = offs_om[:, None, None, None] < M + p_valid = offs_op[None, :, None, None] < P + q_valid = offs_oq[None, None, :, None] < Q + k_valid = offs_ok1[None, None, None, :] < K # K == BLOCK_SIZE_K1 + o_mask = m_valid & p_valid & q_valid & k_valid + tl.store(o_ptrs, out, mask=o_mask) + + +@torch.library.custom_op( + f"{PACKAGE_NAME}::optical_morr_linear_linear_fn", + mutates_args={}, +) +def morr_linear_fn_mem( + x: Tensor, + weight: Tensor, + morr_input_bias: Tensor, + morr_output_scale: Tensor, + bias: Tensor | None, + morr_input_scale: Tensor, + morr_bias: Tensor | None, + grid_dim_x: int, + grid_dim_y: int, + miniblock: int, + enable_thermal_crosstalk: bool, + crosstalk_factor: float | None, + enable_phase_noise: bool, + phase_noise_std: float | None, + trainable_morr_bias: bool, + mrr_a: float, + mrr_r: float, + finegrain_drop_mask: Tensor | None, + in_features: int, + in_features_pad: int, + out_features: int, + out_features_pad: int, + in_bit: int, + w_bit: int, + morr_fwhm: float, + sigma_weight: float, + trainable_morr_scale: bool, + morr_scale: Tensor, + weight_quant_gain: float | None = None, + seed: 
int = 42, +) -> tuple[Tensor, int, Tensor, Tensor, Tensor, Tensor, Tensor, float]: + Device = x.device + assert x.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {x.dtype}" + assert x.is_contiguous(), "Input tensor must be contiguous" + assert weight.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), f"Unsupported dtype {weight.dtype}" + + # Handle transformer vs non-transformer inputs + ori_x_shape = x.shape + is_transformer = len(ori_x_shape) == 3 + + if is_transformer: + in_B, in_N, in_D = x.shape + M = in_B * in_N + x = x.reshape(M, in_D) + else: + M = x.shape[0] + + # Get dimensions + M, D = x.shape + P, Q, K = weight.shape + + if in_features_pad > D: + x_pad = torch.zeros(M, in_features_pad - D, device=Device, dtype=x.dtype) + x = torch.cat([x, x_pad], dim=1) + + assert Q * K == in_features_pad, "input and weight dimension mismatch" + assert P * K == out_features_pad, "weight and output dimension mismatch" + + # Quantize input + if in_bit < 16: + input_quantizer = input_quantize_fn(in_bit, device=Device) + input_quantizer.set_bitwidth(in_bit) + x = input_quantizer(x) + + # Build weight + if w_bit < 16: + weight_quantizer = weight_quantize_fn(w_bit, alg="dorefa_pos") + weight_quantizer.set_bitwidth(w_bit) + weight = weight_quantizer(weight) + + ## rescale weights after quantization can maintain the initialization distribution + if weight_quant_gain is None: + weight_quant_gain = sigma_weight / weight.data.std() + if trainable_morr_scale: + morr_scale = morr_scale * weight_quant_gain + else: + morr_scale = weight_quant_gain + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization + ### quantize learnable balancing factor + morr_output_scale_quantizer = weight_quantize_fn(w_bit, alg="dorefa_sym") + morr_output_scale = morr_output_scale_quantizer(morr_output_scale) + else: + weight = weight.abs() # positive only + morr_output_scale = morr_output_scale - morr_output_scale.data.mean() 
+ + if finegrain_drop_mask is not None: + weight = weight.mul(finegrain_drop_mask.float()) + + # differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Reshape x and weight + x = x.view(-1, grid_dim_x, miniblock) # [M, q, k] + x = x.unsqueeze(1).unsqueeze(-1) # [M, 1, q, k, 1] + weight = toeplitz(weight).unsqueeze(0) # [p, q, k] -> [1, p, q, k, k] + + x_ctx = x.squeeze(-1).squeeze(1).clone() # [M, q, k] + w_ctx = weight.clone() + + # Allocate output + output = torch.empty((M, P, Q, K, 1), device=Device, dtype=x.dtype) + # Launch the Triton kernel + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(P, meta["BLOCK_SIZE_P"]) + * triton.cdiv(Q, meta["BLOCK_SIZE_Q"]), + ) + morr_propagate_kernel[grid]( + x_ptr=x, + w_ptr=weight, + o_ptr=output, + b_ptr=morr_bias, + M=M, + P=P, + Q=Q, + K=K, + grid_dim_q=grid_dim_x, + grid_dim_p=grid_dim_y, + miniblock=miniblock, + crosstalk_factor=crosstalk_factor, + phase_noise_std=phase_noise_std, + mrr_a=mrr_a, + mrr_r=mrr_r, + in_bit=in_bit, + w_bit=w_bit, + seed=seed, + finegrain_drop_mask=finegrain_drop_mask, + stride_wm=weight.stride(0), + stride_wp=weight.stride(1), + stride_wq=weight.stride(2), + stride_wk1=weight.stride(3), + stride_wk2=weight.stride(4), + stride_xm=x.stride(0), + stride_xp=x.stride(1), + stride_xq=x.stride(2), + stride_xk1=x.stride(3), + stride_xk2=x.stride(4), + stride_bm=morr_bias.stride(0) if morr_bias is not None else 0, + stride_bp=morr_bias.stride(1) if morr_bias is not None else 0, + stride_bq=morr_bias.stride(2) if 
morr_bias is not None else 0, + stride_bk1=morr_bias.stride(3) if morr_bias is not None else 0, + stride_om=output.stride(0), + stride_op=output.stride(1), + stride_oq=output.stride(2), + stride_ok1=output.stride(3), + stride_ok2=output.stride(4), + ENABLE_THERMAL_CROSSTALK=enable_thermal_crosstalk, + ENABLE_PHASE_NOISE=enable_phase_noise and phase_noise_std > 1e-4, + TRAINABLE_MORR_BIAS=trainable_morr_bias, + INPUT_DTYPE=TORCH_DTYPE_TO_TRITON[x.dtype], + BLOCK_SIZE_K1=K, + ) + + # Apply output scale + output = output.squeeze(-1) # [m, p, q, k, 1] -> [m, p, q, k] + ctx_x_scalematmul = output.clone() # record x input for matmul + output = morr_output_scale.matmul( + output + ) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + output = output.flatten(1) # [bs, p*k] + + # Trim output if needed + if out_features < out_features_pad: + output = output[:, :out_features] + if bias is not None: + output = output + bias.unsqueeze(0) + # Reshape back for transformer + if is_transformer: + output = output.view(in_B, in_N, out_features) + + return ( + output, + seed, + torch.abs(w_ctx), + x_ctx, + ctx_morr_output_scale, + ctx_x_scalematmul, + morr_scale.clone(), + weight_quant_gain if weight_quant_gain is not None else 0.0, + ) + + +def _morr_linear_setup_context(ctx, inputs, output): + """ + Save for backward only what the backward routine really needs. 
+ """ + ( + x, # 0 Tensor – input + weight, # 1 Tensor – learnable weight + morr_input_bias, # 23 Tensor + _, # 3 morr_output_scale (original) + bias, # 4 Tensor | None – bias + morr_input_scale, + morr_bias, # 2 Tensor | None + grid_dim_x, # 5 int + grid_dim_y, # 6 int + miniblock, # 7 int (== K) + enable_thermal_crosstalk, # 8 bool + crosstalk_factor, # 9 float + enable_phase_noise, # 10 bool + phase_noise_std, # 11 float + trainable_morr_bias, # 12 bool + mrr_a, # 13 float + mrr_r, # 14 float + finegrain_drop_mask, # 15 Tensor | None + in_features, # 16 int + in_features_pad, # 17 int + out_features, # 18 int + out_features_pad, # 19 int + in_bit, # 20 int + w_bit, # 21 int + morr_fwhm, # 22 float + sigma_weight, + trainable_morr_scale, # bool + _morr_scale, + weight_quant_gain, + seed, # 23 int + ) = inputs + + ( + output, + seed, + w_morr, + x_modulator, + morr_output_scale, + x_scalematmul, + morr_scale, + _weight_quant_gain, + ) = output + + device, dtype = x.device, x.dtype + + # ----- Tensor meta-data that backward needs ----- + # Shapes + M = x.shape[0] if x.dim() == 2 else x.shape[0] * x.shape[1] + P, Q, K = weight.shape + tensor_shape = (M, P, Q, K) + + # mrr_para: para for mrr_roundtrip_phase_to_tr() + # c1 = -2.0 * mrr_a * mrr_r + # c2 = mrr_a * mrr_a + mrr_r * mrr_r + # c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + # c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + # intensity = True + # mrr_para = (c1, c2, c3, c4, intensity) + + # # x_morr: x input of matmal in propagate_morr() + # x_morr = x_modulator ** 2 # [m, q, k] + # x_morr = x_morr.unsqueeze(1).unsqueeze(-1) # [m, 1, q, k, 1] + + # # x_mrr: x input of mrr_roundtrip_phase_to_tr() + # x_mrr = w_morr.matmul(x_morr).squeeze(-1) + # if enable_phase_noise and phase_noise_std > 1e-5: + # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, phase_noise_std) + # if trainable_morr_bias: + # x_mrr = x_mrr - morr_bias # morr_bias here is the detached one 
from forward + + # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) # Added from linear.py + + # 3. stash tensors + ctx.save_for_backward( + x, # original input (stashing x for mem version, might need re-evaluation for pure mem-saving) + weight, # original weight (stashing weight for mem version) + bias if bias is not None else torch.tensor([], device=device, dtype=dtype), + morr_output_scale, # original morr_output_scale + # x_mrr, # x input for mrr_roundtrip_phase_to_tr() + # x_morr, + # w_morr, # w input for propagate_morr() matmul + # x_modulator, # x input for input_modulator() + morr_input_bias, + # x_scalematmul, + # x_scalematmul, # x input for morr_output_scale.matmul + morr_input_scale, # morr input scale at input + # morr_scale, # morr_scale after modification in build_weight() + finegrain_drop_mask, + ) + ctx.tensor_shape = tensor_shape + # ctx.mrr_para = mrr_para + ctx.in_features = in_features + ctx.in_features_pad = in_features_pad + ctx.out_features = out_features + ctx.out_features_pad = out_features_pad + ctx.morr_fwhm = morr_fwhm + ctx.grid_dim_x = grid_dim_x + ctx.grid_dim_y = grid_dim_y + ctx.in_bit = in_bit + ctx.w_bit = w_bit + ctx.x_input_shape = x.shape + ctx.device = x.device + ctx.w_input_shape = weight.shape + # ctx.morr_fwhm = morr_fwhm # Already exists + ctx.enable_phase_noise = enable_phase_noise + ctx.phase_noise_std = phase_noise_std + ctx.trainable_morr_bias = trainable_morr_bias + ctx.trainable_morr_scale = trainable_morr_scale + ctx.weight_quant_gain = weight_quant_gain + ctx.miniblock = miniblock + ctx.crosstalk_factor = crosstalk_factor + ctx.sigma_weight = sigma_weight + ctx.enable_thermal_crosstalk = enable_thermal_crosstalk + ctx.mrr_a = mrr_a + ctx.mrr_r = mrr_r + + +def recompute_activations( + ctx, + x: Tensor, + weight: Tensor, + bias: Tensor | None, + morr_output_scale: Tensor, + finegrain_drop_mask, + morr_input_bias: Tensor, + morr_input_scale: Tensor, +): + """ + Recompute activations for 
morr_linear_fn. + """ + Device = x.device + Dtype = x.dtype + + ctx_morr_scale = None + ctx_tanh_input_bias = None + + # Handle transformer vs non-transformer inputs + ori_x_shape = x.shape + is_transformer = len(ori_x_shape) == 3 + + if is_transformer: + in_B, in_N, in_D = x.shape + M = in_B * in_N + x = x.reshape(M, in_D) + else: + M = x.shape[0] + + # Get dimensions + M, D = x.shape + P, Q, K = weight.shape + + if ctx.in_features_pad > D: + x_pad = torch.zeros(M, ctx.in_features_pad - D, device=Device, dtype=x.dtype) + x = torch.cat([x, x_pad], dim=1) + + # Quantize input + if ctx.in_bit < 16: + input_quantizer = input_quantize_fn(ctx.in_bit, device=Device) + input_quantizer.set_bitwidth(ctx.in_bit) + x = input_quantizer(x) + + ################# Build weight ################# + if ctx.w_bit < 16: + weight_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_pos") + weight_quantizer.set_bitwidth(ctx.w_bit) + weight = weight_quantizer(weight) + + # Calculate morr_scale + if morr_input_scale is None: + return None + morr_scale = torch.sigmoid(morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + ## rescale weights after quantization can maintain the initialization distribution + weight_quant_gain = ctx.weight_quant_gain + if weight_quant_gain is None: + weight_quant_gain = ctx.sigma_weight / weight.data.std() + if ctx.trainable_morr_scale: + morr_scale = morr_scale * weight_quant_gain + else: + morr_scale = weight_quant_gain + + ctx_morr_scale = morr_scale.clone() + weight = weight.mul(morr_scale) ### gain factor from Tanh used in quantization + ### quantize learnable balancing factor + morr_output_scale_quantizer = weight_quantize_fn(ctx.w_bit, alg="dorefa_sym") + morr_output_scale = morr_output_scale_quantizer(morr_output_scale) + else: + weight = weight.abs() # positive only + morr_output_scale = morr_output_scale - morr_output_scale.data.mean() + + if finegrain_drop_mask is not None: + weight = weight.mul(finegrain_drop_mask.float()) + + # differential 
balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if ctx.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if ctx.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + ctx_morr_output_scale = morr_output_scale.clone() + + # Reshape x and weight + x = x.view(-1, ctx.grid_dim_x, ctx.miniblock) # [M, q, k] + + # input_modulator() + ctx_x_modulator = x.clone() + x = x**2 + + ################# propagate_morr() ################# + if ctx.enable_thermal_crosstalk and ctx.crosstalk_factor > 1: + weight = weight * ctx.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + + ctx_x_morr = x.clone() + ctx_w_morr = weight.clone() + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, ctx.phase_noise_std) + + if ctx.trainable_morr_bias: + ctx_tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + morr_bias = ctx.morr_fwhm * ctx_tanh_input_bias + x = x - morr_bias + + ctx_x_mrr = x.clone() + + mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=ctx.mrr_a, r=ctx.mrr_r, intensity=True + ) + x = mrr_roundtrip_phase_to_tr(x) + + ctx_x_scalematmul = x.clone() + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + + # ------------------------------------------------------ + + # # Trim output if needed + # if ctx.out_features < ctx.out_features_pad: + # output = output[:, :ctx.out_features] + # if bias is not None: + # output = output + bias.unsqueeze(0) + # # Reshape back for transformer + # if is_transformer: + # output = output.view(in_B, 
in_N, ctx.out_features) + + return ( + # x, weight, bias, morr_output_scale, + # output, + ctx_x_modulator, # x input for input_modulator() + ctx_x_morr, # x input for propagate_morr() matmul + ctx_w_morr, # w input for propagate_morr() matmul + ctx_x_mrr, # x input for mrr_roundtrip_phase_to_tr() + ctx_x_scalematmul, # x input for morr_output_scale.matmul + ctx_tanh_input_bias, # input_bias after tanh() + ctx_morr_scale, # morr_scale after modification in build_weight() + ) + + +def _morr_linear_backward(ctx, grad_output, *ignored): + """ + Backward pass for morr_linear_fn. + """ + ( + x, + weight, + bias, + morr_output_scale, + # x_mrr, + # x_morr, + # w_morr, + # x_modulator, + morr_input_bias, + # x_scalematmul, + morr_input_scale, + # morr_scale, + finegrain_drop_mask, + ) = ctx.saved_tensors + + M, P, Q, K = ctx.tensor_shape + # c1, c2, c3, c4, intensity = ctx.mrr_para + in_features = ctx.in_features + in_features_pad = ctx.in_features_pad + out_features = ctx.out_features + out_features_pad = ctx.out_features_pad + x_input_shape = ctx.x_input_shape + w_input_shape = ctx.w_input_shape + DEVICE = ctx.device + + # --- calculate intermediate activation on the fly --- + ( + x_modulator, # x input for input_modulator() + x_morr, # x input for propagate_morr() matmul + w_morr, # w input for propagate_morr() matmul + x_mrr, # x input for mrr_roundtrip_phase_to_tr() + x_scalematmul, # x input for morr_output_scale.matmul + tanh_input_bias, # input_bias after tanh() + morr_scale, # morr_scale after modificaiton in build_weight() + ) = recompute_activations( + ctx, + x, + weight, + bias, + morr_output_scale, + finegrain_drop_mask, + morr_input_bias, + morr_input_scale, + ) + + # x_morr = (x_modulator ** 2).unsqueeze(1).unsqueeze(-1) # [m, q, k] -> # [m, 1, q, k, 1] + + # tanh_input_bias = torch.tanh(morr_input_bias.unsqueeze(0).unsqueeze(-1)) + # morr_bias = ctx.morr_fwhm * tanh_input_bias + + # # x_mrr: x input of mrr_roundtrip_phase_to_tr() + # x_mrr = 
w_morr.matmul(x_morr).squeeze(-1) + # if ctx.enable_phase_noise and ctx.phase_noise_std > 1e-5: + # x_mrr = x_mrr + torch.zeros_like(x_mrr).normal_(0, ctx.phase_noise_std) + # if ctx.trainable_morr_bias: + # x_mrr = x_mrr - morr_bias + + # ----- backward prop ----- + # Reshape + grad_out = grad_output.view( + x_input_shape[0], w_input_shape[1], w_input_shape[2], -1 + ) # [M, P, Q, K] + + # ----- Gradient w.r.t input x ----- + if True or ctx.needs_input_grad[0]: + # 1. reshape + grad_out = grad_out.view(M, -1) # [m, out_features] + + if ctx.needs_input_grad[4] and bias: + grad_bias = grad_out.sum(dim=0) # [out_features] + else: + grad_bias = None + + out_pad = torch.zeros( + grad_out.shape[0], out_features_pad - out_features, device=DEVICE + ) # [m, out_features_pad - out_features] + grad_out = torch.cat( + [grad_out, out_pad], dim=1 + ) # [m * out_features_pad] = [m, p*k] + + # 2. x=x.flatten(1) + # input: [m, p**k] + grad_out = grad_out.view(M, P, 1, K) # [m, p, 1, k] + + # 3. x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + # dL/d(morr_output_scale) + if ctx.needs_input_grad[3]: + grad_s = grad_out.matmul(x_scalematmul.transpose(-2, -1)) # [bs, p, 1, q] + grad_s = grad_s.sum(dim=(0, 1)).unsqueeze(0).unsqueeze(1) # [1, 1, 1, q] + grad_s = grad_s.squeeze(0).unsqueeze(-1) # [1, 1, q, 1] gradient of scale + + t = ctx.grid_dim_x // 2 + grad_scale = grad_s.new_zeros((1, 1, t + 1, 1)) + + if ctx.grid_dim_x % 2 == 0: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t:, :] + elif ctx.grid_dim_x == 1: + grad_scale = grad_s + else: + grad_scale[..., :t, :] = grad_s[..., :t, :] - grad_s[..., t + 1 :, :] + grad_scale[..., t : t + 1, :] = grad_s[..., t : t + 1, :] + + else: + grad_scale = None + + # dL/dx + grad_x = morr_output_scale.transpose(-2, -1).matmul(grad_out) # [bs, p, q, k] + + # 4. 
x = mrr_roundtrip_phase_to_tr(x) + mrr_a, mrr_r = ctx.mrr_a, ctx.mrr_r + c1 = -2.0 * mrr_a * mrr_r + c2 = mrr_a * mrr_a + mrr_r * mrr_r + c3 = 1.0 + (mrr_r * mrr_r) * (mrr_a * mrr_a) - mrr_a * mrr_a - mrr_r * mrr_r + c4 = (mrr_a**2.0 - 1.0) * (mrr_r**2.0 - 1.0) * 2.0 * mrr_a * mrr_r + intensity = True + denominator = x_mrr.cos().mul_(c1).add_(c2 + c3) + if intensity: + denominator.square_() + numerator = x_mrr.sin().mul_(c4) + else: + numerator = x_mrr.sin().mul_(c4 / 2) + denominator = denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_()) + grad_x = numerator.div_(denominator).mul_(grad_x) # [bs, p, q, k] + + # 5. x += phase_noise and x -= morr_bias + if ctx.trainable_morr_bias and ctx.needs_input_grad[2]: + grad_inputbias = -grad_x # [bs, p, q, k] + grad_inputbias = grad_inputbias * ctx.morr_fwhm # [bs, p, q, k] + grad_inputbias = ( + grad_inputbias - tanh_input_bias * tanh_input_bias + ) # [bs, p, q, k] + grad_inputbias = grad_inputbias.sum(dim=(0, -1)) + else: + grad_inputbias = None + + # 6. x = weight.matmul(x) [1, p, q, k, k] * [bs, 1, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.unsqueeze(-1) # [bs, p, q, k, 1] + grad_morr_matmul = grad_x # stash for weight gradient + + # dL/dx + grad_x = torch.matmul( + w_morr.transpose(-1, -2), grad_x + ) # [1, p, q, k, k] x [bs, p, q, k, 1] = [bs, p, q, k, 1] + grad_x = grad_x.sum(dim=1, keepdim=True) # [bs, p, q, k, 1] -> [bs, 1, q, k, 1] + grad_x = grad_x.squeeze(-1).squeeze(1) # [bs, 1, q, k, 1] -> [bs, q, k] + + # 7. input modulator + grad_x = grad_x * 2 * x_modulator # [bs, q, k] + + # 8. input reshape + grad_x = grad_x.view(x_input_shape) + grad_x = grad_x[:, :in_features] + + # ----- Gradient w.r.t weight ----- + if ctx.needs_input_grad[1]: + + # 0. gradient after x = weight.matmul(x) + # grad_morr_matmul # [bs, p, q, k, 1] + + # 1. 
x = weight.matmul(x) + grad_w = torch.matmul( + grad_morr_matmul, x_morr.transpose(-1, -2) + ) # [bs,p,q,k,k] + grad_w = grad_w.sum(dim=0, keepdim=True) # [1,p,q,k,k] + + # 2. weight = toeplitz(weight) + k = grad_w.size(-1) + row = torch.arange(k)[:, None] # (k,1) + col = torch.arange(k)[None, :] # (1,k) + idx = (row - col) & (k - 1) if (k & (k - 1)) == 0 else (row - col + k) % k + + idx = idx.expand(grad_w.shape).to(DEVICE) + buffer = torch.zeros_like(grad_w, device=DEVICE) + buffer.scatter_add_(-2, idx, grad_w) # [1, p, q, k, k] + grad_w = buffer.sum(dim=-1, keepdim=True).squeeze(0).squeeze(-1) + + # 3. build_weight() + if finegrain_drop_mask is not None: + grad_w = grad_w * finegrain_drop_mask.float() + # morr_scale: [p, q, 1] + grad_morr_input_scale = None + if ctx.w_bit < 16: + # grad w.r.t morr_scale + if ctx.needs_input_grad[5] & ctx.trainable_morr_scale: + grad_morr_scale = (grad_w * weight).sum( + dim=2, keepdim=True + ) # [p, q, 1] + grad_morr_scale = grad_morr_scale * ctx.weight_quant_gain # [p, q, 1] + # ∂L/∂self.morr_input_scale + sigmoid_scale = torch.sigmoid(morr_input_scale) + grad_morr_input_scale = ( + grad_morr_scale * sigmoid_scale * (1 - sigmoid_scale) + ).squeeze( + -1 + ) # [p, q] + + # grad w.r.t weight + grad_w = grad_w * morr_scale + else: + grad_w = grad_w * weight.sign() + + return ( + grad_x, # ∂L/∂x + grad_w, # ∂L/∂w + grad_inputbias, # ∂L/∂morr_input_bias + grad_scale, # ∂L/∂morr_output_scale + grad_bias, # ∂L/∂bias + grad_morr_input_scale, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +morr_linear_fn_mem.register_autograd( + _morr_linear_backward, + setup_context=_morr_linear_setup_context, +) diff --git a/src/chop/nn/optical/triton_modules/morr_linear_mem.py b/src/chop/nn/optical/triton_modules/morr_linear_mem.py new file mode 100644 index 000000000..eb314b3a0 --- /dev/null +++ 
b/src/chop/nn/optical/triton_modules/morr_linear_mem.py @@ -0,0 +1,483 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2022-04-18 14:19:57 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2022-04-18 16:21:37 +""" + +from typing import Optional +import logging + +import numpy as np +import torch +import torch.fft +from torch import Tensor +from torch.nn import Parameter, init +from torch.types import Device + +from ..utils import MORRConfig_20um_MQ +from ..utils import mrr_roundtrip_phase_to_tr_func, mrr_roundtrip_phase_to_tr_fused +from ..utils import toeplitz +from ..utils import morr_uniform_ +from ..utils import input_quantize_fn, weight_quantize_fn +from ..modules.base_layer import ONNBaseLayer +from .morr_linear_kernel_mem import morr_linear_fn_mem + +logger = logging.getLogger(__name__) + +__all__ = ["AllPassMORRCirculantLinear"] + + +class TritonMemMORRLinear(ONNBaseLayer): + """ + All-pass MORR Linear layer, assumes (1) block-circulant matrix (2) differential rails (3) learnable balancing factors. + J. 
Gu, et al., "SqueezeLight: Towards Scalable Optical Neural Networks with Multi-Operand Ring Resonators" + https://doi.org/10.23919/DATE51398.2021.9474147 + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + miniblock: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + config={}, + device: Device = torch.device("cpu"), + ) -> None: + super(TritonMemMORRLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + + miniblock_size = config.get("miniblock", 4) + self.miniblock = miniblock_size + self.grid_dim_x = int(np.ceil(self.in_features / miniblock_size)) + self.grid_dim_y = int(np.ceil(self.out_features / miniblock_size)) + self.in_features_pad = self.grid_dim_x * miniblock_size + self.out_features_pad = self.grid_dim_y * miniblock_size + + self.v_max = 10.8 + self.v_pi = 4.36 + self.gamma = np.pi / self.v_pi**2 + self.w_bit = 32 + self.in_bit = 32 + + morr_config = config.get("MORRConfig", MORRConfig_20um_MQ) + morr_init_val = config.get("morr_init", MORRConfig_20um_MQ) + self.MORRConfig = morr_config + self.morr_init = morr_init_val + self.mrr_a = morr_config.attenuation_factor + self.mrr_r = morr_config.coupling_factor + self.trainable_morr_bias = config.get("trainable_morr_bias", MORRConfig_20um_MQ) + self.trainable_morr_scale = config.get( + "trainable_morr_scale", MORRConfig_20um_MQ + ) + self.device = device + ### calculate FWHM (rad) + self.morr_fwhm = ( + -4 + * np.pi**2 + * morr_config.radius + * morr_config.effective_index + * ( + 1 / morr_config.resonance_wavelength + - 1 / (morr_config.resonance_wavelength - morr_config.bandwidth / 2) + ) + ) + + ### allocate parameters + self.weight = None + self.x_zero_pad = None + self.morr_output_scale = None ## learnable balancing factors implemented by MRRs + self.morr_input_bias = None ## round-trip phase shift bias within MORR + self.morr_input_scale = ( + None ## scaling 
factor for the round-trip phase shift within MORR + ) + self.morr_gain = ( + 100 / (self.in_features // self.miniblock) + ) ** 0.5 ## TIA gain, calculated such that output variance is around 1 + ### build trainable parameters + self.build_parameters() + + ### quantization tool + self.input_quantizer = input_quantize_fn(self.in_bit, device=self.device) + self.weight_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_pos" + ) ## [0-1] positive only, maintain the original scale + self.morr_output_scale_quantizer = weight_quantize_fn( + self.w_bit, alg="dorefa_sym" + ) ## [-1,1] full-range + + self.mrr_roundtrip_phase_to_tr = mrr_roundtrip_phase_to_tr_func( + a=self.mrr_a, r=self.mrr_r, intensity=True + ) + + ### default set to slow forward + self.disable_fast_forward() + ### default set no gamma noise + self.set_gamma_noise(0) + ### default set no crosstalk + self.disable_crosstalk() + ### default set no phase variation + self.disable_phase_variation() + + if bias: + self.bias = Parameter(torch.Tensor(out_features).to(self.device)) + else: + self.register_parameter("bias", None) + + self.reset_parameters(morr_init=morr_init_val) + self.finegrain_drop_mask = None + + def build_parameters(self) -> None: + + self.weight = Parameter( + torch.ones( + self.grid_dim_y, + self.grid_dim_x, + self.miniblock, + device=self.device, + dtype=torch.float, + ) + ) + ### Learnable balancing factor (morr_output_scale) + ### We use a single scaling factor for each block + self.morr_output_scale = Parameter( + torch.randn(1, 1, max(1, self.grid_dim_x // 2) + 1, 1, device=self.device) + ) + if self.trainable_morr_bias: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_bias = Parameter( + torch.zeros( + self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + if self.trainable_morr_scale: + ### initialize with the finest-granularity, i.e., per mini-block + self.morr_input_scale = Parameter( + torch.zeros( + 
self.grid_dim_y, + self.grid_dim_x, + device=self.device, + dtype=torch.float, + ) + ) + + def reset_parameters(self, morr_init: bool = False) -> None: + ### nonlinear curve aware initialization + if morr_init: + ## initialize weight + morr_uniform_( + self.weight, + MORRConfig=self.MORRConfig, + n_op=self.miniblock, + biased=self.w_bit >= 16, + gain=2 if self.in_bit < 16 else 1, + ) # quantization needs zero-center + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + + ## output distribution aware initialization to output scaling factor + t1 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([0]).float(), a=self.mrr_a, r=self.mrr_r, intensity=True + ) + t2 = mrr_roundtrip_phase_to_tr_fused( + torch.tensor([self.morr_fwhm * 2.4]).float(), + a=self.mrr_a, + r=self.mrr_r, + intensity=True, + ) + g = ( + (t2 - t1) / (2.4 * self.morr_fwhm) + ).item() ## 0~2.4 FWHM slope as a linear approximation + + self.sigma_out_scale = 4 / (3 * self.grid_dim_x**0.5 * g * self.morr_fwhm) + self.out_scale_quant_gain = None + init.normal_(self.morr_output_scale, 0, self.sigma_out_scale) + else: + init.kaiming_normal_(self.weight.data) + init.kaiming_normal_(self.morr_output_scale.data) + self.sigma_weight = self.weight.data.std().item() + self.weight_quant_gain = None + self.sigma_out_scale = self.morr_output_scale.data.std().item() + self.out_scale_quant_gain = None + + if self.morr_input_bias is not None: + self.morr_input_bias.data.zero_() + if self.morr_input_scale is not None: + ### after sigmoid, it corresponds to 1 scale + init.normal_(self.morr_input_scale.data, 2, 0.1) + + if self.bias is not None: + init.uniform_(self.bias, 0, 0) + + def sync_parameters(self, src: str = "weight") -> None: + """ + description: synchronize all parameters from the source parameters + """ + + raise NotImplementedError + + def build_weight(self) -> Tensor: + if self.w_bit < 16: + ### differentiable quantizer based on STE to enable QAT (Dorefa-Net, arXiv 2016) + 
weight = self.weight_quantizer(self.weight) + + ## rescale weights after quantization can maintain the initialization distribution + if self.weight_quant_gain is None: + self.weight_quant_gain = self.sigma_weight / weight.data.std() + if self.trainable_morr_scale: + morr_scale = self.morr_scale * self.weight_quant_gain + else: + morr_scale = self.weight_quant_gain + weight = weight.mul( + morr_scale + ) ### gain factor from Tanh used in quantization + + ### quantize learnable balancing factor + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + else: + weight = self.weight.abs() # positive only + morr_output_scale = ( + self.morr_output_scale - self.morr_output_scale.data.mean() + ) + + if self.finegrain_drop_mask is not None: + weight = weight.mul(self.finegrain_drop_mask.float()) + + ## differential balancing factor concatenation + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + if self.grid_dim_x % 2 == 0: + # even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if self.grid_dim_x > 1: + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + morr_output_scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + + return weight, morr_output_scale + + def enable_fast_forward(self) -> None: + self.fast_forward_flag = True + + def disable_fast_forward(self) -> None: + self.fast_forward_flag = False + + def set_gamma_noise( + self, noise_std: float, random_state: Optional[int] = None + ) -> None: + self.gamma_noise_std = noise_std + + def load_parameters(self, param_dict) -> None: + """ + description: update parameters based on this parameter dictionary\\ + param param_dict {dict of dict} {layer_name: {param_name: param_tensor, ...}, ...} + """ + for name, param in param_dict.items(): + getattr(self, name).data.copy_(param) + + def set_weight_bitwidth(self, w_bit: int) -> None: + self.w_bit = w_bit + 
self.weight_quantizer.set_bitwidth(w_bit) + self.morr_output_scale_quantizer.set_bitwidth(w_bit) + + def set_input_bitwidth(self, in_bit: int) -> None: + self.in_bit = in_bit + self.input_quantizer.set_bitwidth(in_bit) + + def input_modulator(self, x: Tensor) -> Tensor: + ### voltage to power, which is proportional to the phase shift + return x * x + + def set_crosstalk_coupling_matrix( + self, coupling_factor: float, drop_perc: float = 0 + ) -> None: + ### crosstalk coupling matrix is a symmetric matrix, but the intra-MORR crosstalk can be taken as a round-trip phase shift scaling factor, which is proportional to the number of segments after pruned. + ### drop-perc is the pruning percentage. + assert 0 <= coupling_factor <= 1, logger.error( + f"Coupling factor must in [0,1], but got {coupling_factor}" + ) + + self.crosstalk_factor = ( + 1 + max(3, (self.miniblock * (1 - drop_perc) - 1)) * coupling_factor + ) + + def enable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = True + + def disable_crosstalk(self) -> None: + self.enable_thermal_crosstalk = False + + def set_phase_variation(self, phase_noise_std: float = 0) -> None: + self.phase_noise_std = phase_noise_std + + def enable_phase_variation(self) -> None: + self.enable_phase_noise = True + + def disable_phase_variation(self) -> None: + self.enable_phase_noise = False + + def enable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = True + + def disable_trainable_morr_scale(self) -> None: + self.trainable_morr_scale = False + + def enable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = True + + def disable_trainable_morr_bias(self) -> None: + self.trainable_morr_bias = False + + @property + def morr_bias(self) -> Tensor: + if self.morr_input_bias is None: + return None + # return 2 * self.morr_fwhm * torch.sigmoid(self.morr_input_bias.unsqueeze(0).unsqueeze(-1)) + return self.morr_fwhm * torch.tanh( + self.morr_input_bias.unsqueeze(0).unsqueeze(-1) + ) + + @property + 
def morr_scale(self) -> Tensor: + if self.morr_input_scale is None: + return None + return torch.sigmoid(self.morr_input_scale.unsqueeze(-1)) + 0.2 # [p, q, 1] + + def propagate_morr( + self, weight: Tensor, x: Tensor, morr_output_scale: Tensor + ) -> Tensor: + """ + @description: propagate through the analytically calculated transfer matrix of molg. We implement circulant matrix multiplication using fast circ matmul + @param weight {torch.Tensor} two phase shifters in the MZI-based attenuators + @param x {torch.Tensor} complex-valued input + @param morr_output_scale {torch.Tensor} learnable balancing factors + @return: y {torch.Tensor} output of attenuators + """ + ### x : [bs, q, k] + ### weights: [p, q, k] + ### morr_output_scale: [1, 1, 1, q] + + ### input scaling [TCAD'21], must have valid ranges. too small will have dead neuron and not enough nonlinearity; too large will have larger power, cross-channel crosstalk. [0.2 - 1.2] will be suitable + ## build circulant weight matrix + # crosstalk on the weights are much cheaper to compute than on the phase shift + if self.enable_thermal_crosstalk and self.crosstalk_factor > 1: + weight = weight * self.crosstalk_factor + weight = toeplitz(weight).unsqueeze(0) # [1, p, q, k, k] + x = x.unsqueeze(1).unsqueeze(-1) # [bs, 1, q, k, 1] + x = weight.matmul(x).squeeze(-1) # [bs, p, q, k] + + if self.enable_phase_noise and self.phase_noise_std > 1e-5: + x = x + torch.zeros_like(x).normal_(0, self.phase_noise_std) + + ### input biasing [TCAD'21], must have valid ranges. too large will have power issue and cross-channel crosstalk. 
[-2FWHM ~ 0] + if self.trainable_morr_bias: + x = x - self.morr_bias + + ### Use theoretical transmission function for trainable MORR nonlinearity [TCAD'21] + ### x is the phase detuning, x=0 means on-resonance + ### phase: [bs, p, q, k] + x = self.mrr_roundtrip_phase_to_tr(x) # 3x faster than autograd + + ## implement balancing factor as dot-product + """ + if(self.w_bit < 16): + morr_output_scale = self.morr_output_scale_quantizer(self.morr_output_scale) + if(self.sigma_out_scale_quant_gain is None): + self.sigma_out_scale_quant_gain = self.sigma_out_scale / morr_output_scale.data.std().item() + morr_output_scale = morr_output_scale.mul(self.sigma_out_scale_quant_gain)### gain factor from Tanh used in quantization + else: + morr_output_scale = self.morr_output_scale + # morr_output_scale = morr_output_scale * self.morr_gain + scale = morr_output_scale[..., :-1, :] + scale_pad = morr_output_scale[..., -1:, :] + + # print("morr diff transmission:", end=", ") + # diff = x[..., :x.size(2)//2,:]-x[..., x.size(2)//2:,:] + # print_stat(diff) + if(self.grid_dim_x % 2 == 0): + #even blocks + scale = torch.cat([scale, -scale], dim=2) # [1, 1, q, 1] + else: + # odd blocks + if(self.grid_dim_x > 1): + scale = torch.cat([morr_output_scale, -scale], dim=2) # [1, 1, q, 1] + else: + scale = scale_pad # [1, 1, q, 1] + scale = scale.squeeze(-1).unsqueeze(0) # [1 ,1, 1, q] + # print("output scale Q:", end=", ") + # print_stat(scale[..., :scale.size(-1)//2]) + """ + x = morr_output_scale.matmul(x) # [1, 1, 1, q] x [bs, p, q, k] = [bs, p, 1, k] + x = x.flatten(1) # [bs, p*k] + return x + + def get_finegrain_drop_mask(self, topk: int) -> Tensor: + if self.w_bit < 16: + weight = self.weight_quantizer(self.weight.data) # [p, q, k] + else: + weight = self.weight.data.abs() + indices = weight.argsort(dim=-1) + mask = torch.ones_like(weight, dtype=torch.bool, device=weight.device) + + drop_indices = indices[:, :, 0:-topk] + mask.scatter_(2, drop_indices, 0) + self.finegrain_drop_mask = 
mask + return mask + + def apply_finegrain_drop_mask(self, mask: Tensor) -> None: + if self.w_bit < 16: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), -1000) + else: + self.weight.data.masked_fill_(~mask.view_as(self.weight.data), 0) + + def forward(self, x: Tensor) -> Tensor: + output, *_ = morr_linear_fn_mem( + x, + self.weight, + morr_input_bias=self.morr_input_bias, + morr_output_scale=self.morr_output_scale, + bias=None, + morr_input_scale=self.morr_input_scale, + morr_bias=self.morr_bias.detach(), + grid_dim_x=self.grid_dim_x, + grid_dim_y=self.grid_dim_y, + miniblock=self.miniblock, + enable_thermal_crosstalk=self.enable_thermal_crosstalk, + crosstalk_factor=( + None if not self.enable_thermal_crosstalk else self.crosstalk_factor + ), + enable_phase_noise=self.enable_phase_noise, + phase_noise_std=( + None if not self.enable_phase_noise else self.phase_noise_std + ), + trainable_morr_bias=self.trainable_morr_bias, + mrr_a=self.mrr_a, + mrr_r=self.mrr_r, + finegrain_drop_mask=None, + in_features=self.in_features, + in_features_pad=self.in_features_pad, + out_features=self.out_features, + out_features_pad=self.out_features_pad, + in_bit=self.in_bit, + w_bit=self.w_bit, + morr_fwhm=self.morr_fwhm, + sigma_weight=self.sigma_weight, + trainable_morr_scale=self.trainable_morr_scale, # bool + morr_scale=self.morr_scale, + weight_quant_gain=self.weight_quant_gain, + seed=42, + ) + return output diff --git a/src/chop/nn/optical/triton_modules/quantize.py b/src/chop/nn/optical/triton_modules/quantize.py new file mode 100644 index 000000000..fdd0848ef --- /dev/null +++ b/src/chop/nn/optical/triton_modules/quantize.py @@ -0,0 +1,108 @@ +import torch +from torch import Tensor +import triton +import triton.language as tl + + +@triton.jit +def uniform_quantize(x: tl.tensor, k, gradient_clip=False): + if k == 32: + out = input + elif k == 1: + out = tl.where(x >= 0, 1.0, -1.0) + else: + n = float(2**k - 1) + out = tl.extra.cuda.libdevice.rint(x * n) / n + 
+ return out + + +def uniform_quantize_new(x: tl.tensor, k, scale, zero_point, gradient_clip=False): + if k == 32: + out = x + elif k == 1: + out = tl.where(x > 0, 1.0, tl.where(x < 0, -1.0, 0.0)) + else: + n = float(2**k - 1) + out = tl.div(x, scale) + out = out + zero_point + out = tl.extra.cuda.libdevice.rint(out) + out = tl.clamp(out, 0.0, n) + out = out - zero_point + out = out * scale + return out + + +@triton.jit +def _input_quantize_fn( + x: tl.tensor, + quant_ratio, + training, + in_bit, + alg, # self.training +): + # init + if alg == "dorefa": + uniform_q = uniform_quantize(k=in_bit) + elif alg == "normal": + uniform_q = uniform_quantize_new(k=in_bit) + scale = None + zero_point = None + # TODO: fix for triton + if 1 <= in_bit <= 8: # observer does not support higher than 8-bit + obs = torch.quantization.observer.MovingAverageMinMaxObserver( + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**in_bit - 1, + ) + else: + obs = None + + if quant_ratio > 1.0 and training: + rand_vals = tl.random(x.shape) + quant_noise_mask = tl.where(rand_vals > quant_ratio, 1, 0) + else: + quant_noise_mask = None + + if in_bit == 32: + input_q = x + elif in_bit == 1: + x = tl.clamp(x, 0.0, 1.0) + input_q = (uniform_q(x - 0.5) + 1) / 2 + if quant_noise_mask is not None: + noise = input_q - x + masked_noise = tl.where(quant_noise_mask, 0.0, noise) + input_q = x + masked_noise + else: + ### dorefa-style clamp for input data + if alg == "dorefa": + x = tl.clamp(x, 0.0, 1.0) + input_q = uniform_q(x) + elif alg == "normal": + if obs is not None: + if training: + obs(x) + scale, zero_point = obs.calculate_qparams() + # convert scale and zero_point type from qint8 + scale = scale.to(x.dtype) + zero_point = zero_point.to(x.dtype) + input_q = uniform_q(x, scale, zero_point) + else: + input_q = x # if no observer (in_bit > 8), do not quantize + else: + # raise NotImplementedError + input_q = 
tl.zeros_like(x) + # add noise + if quant_noise_mask is not None: + noise = input_q - x + masked_noise = tl.where(quant_noise_mask, 0.0, noise) + input_q = x + masked_noise + + return input_q + + +def _weight_quantize_fn(w: tl.tensor): + pass diff --git a/src/chop/nn/optical/utils/__init__.py b/src/chop/nn/optical/utils/__init__.py new file mode 100644 index 000000000..248e214da --- /dev/null +++ b/src/chop/nn/optical/utils/__init__.py @@ -0,0 +1,24 @@ +from .mrr import ( + MORRConfig_20um_MQ, + MRRConfig_5um_HQ, + MRRConfig_5um_MQ, + MRRConfig_5um_LQ, + MORRConfig_10um_MQ, +) + +from .compute import ( + im2col_2d, + toeplitz, +) + +from .initializer import morr_uniform_ + +from .quantize import ( + input_quantize_fn, + weight_quantize_fn, +) + +from .mrr_op import ( + mrr_roundtrip_phase_to_tr_func, + mrr_roundtrip_phase_to_tr_fused, +) diff --git a/src/chop/nn/optical/utils/compute.py b/src/chop/nn/optical/utils/compute.py new file mode 100644 index 000000000..8ae1b3279 --- /dev/null +++ b/src/chop/nn/optical/utils/compute.py @@ -0,0 +1,187 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 02:17:08 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 02:17:08 +""" + +import contextlib +import logging +from functools import lru_cache +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from scipy.stats import truncnorm +from torch import Tensor, nn +from torch.autograd import grad +from torch.nn.modules.utils import _pair +from torch.types import Device, _size + +__all__ = [ + "toeplitz", + "im2col_2d", +] + + +def shift(v: Tensor, f: float = 1) -> Tensor: + return torch.cat((f * v[..., -1:], v[..., :-1]), dim=-1) + + +def Krylov(linear_map: Callable, v: Tensor, n: Optional[int] = None) -> Tensor: + if n is None: + n = v.size(-1) + cols = [v] + for _ in range(n - 1): + v = linear_map(v) + cols.append(v) + return torch.stack(cols, dim=-2) + + +def circulant(eigens: Tensor) -> 
Tensor: + circ = Krylov(shift, eigens).transpose(-1, -2) + return circ + + +@lru_cache(maxsize=4) +def _get_toeplitz_indices(n: int, device: Device) -> Tensor: + # cached toeplitz indices. avoid repeatedly generate the indices. + indices = circulant(torch.arange(n, device=device)) + return indices + + +def toeplitz(col: Tensor) -> Tensor: + """ + Efficient Toeplitz matrix generation from the first column. The column vector must in the last dimension. Batch generation is supported. Suitable for AutoGrad. Circulant matrix multiplication is ~4x faster than rfft-based implementation!\\ + @col {torch.Tensor} (Batched) column vectors.\\ + return out {torch.Tensor} (Batched) circulant matrices + """ + n = col.size(-1) + indices = _get_toeplitz_indices(n, device=col.device) + return col[..., indices] + + +def im2col_2d( + W: Optional[Tensor] = None, + X: Optional[Tensor] = None, + stride: int = 1, + padding: int = 0, + w_size: Optional[_size] = None, +) -> Tuple[Tensor, Tensor, int, int]: + if W is not None: + W_col = W.view(W.size(0), -1) + else: + W_col = None + + if X is not None: + n_filters, d_filter, h_filter, w_filter = W.size() if W is not None else w_size + n_x, d_x, h_x, w_x = X.size() + + h_out = (h_x - h_filter + 2 * padding) / stride + 1 + w_out = (w_x - w_filter + 2 * padding) / stride + 1 + + h_out, w_out = int(h_out), int(w_out) + X_col = torch.nn.functional.unfold( + X.view(1, -1, h_x, w_x), + h_filter, + dilation=1, + padding=padding, + stride=stride, + ).view(n_x, -1, h_out * w_out) + X_col = X_col.permute(1, 2, 0).contiguous().view(X_col.size(1), -1) + else: + X_col, h_out, w_out = None, None, None + + return W_col, X_col, h_out, w_out + + +def complex_mult(X: Tensor, Y: Tensor) -> Tensor: + """Complex-valued element-wise multiplication + + Args: + X (Tensor): Real tensor with last dim of 2 or complex tensor + Y (Tensor): Real tensor with last dim of 2 or complex tensor + + Returns: + Tensor: tensor with the same type as input + """ + if not 
torch.is_complex(X) and not torch.is_complex(Y): + assert ( + X.shape[-1] == 2 and Y.shape[-1] == 2 + ), "Last dimension of real-valued tensor must be 2" + if hasattr(torch, "view_as_complex"): + return torch.view_as_real( + torch.view_as_complex(X) * torch.view_as_complex(Y) + ) + else: + return torch.stack( + ( + X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1], + X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0], + ), + dim=-1, + ) + else: + return X.mul(Y) + + +def polar_to_complex(mag: Tensor, angle: Tensor) -> Tensor: + # magnitude and angle to real and imag + if angle is None: + return real_to_complex(angle) + if mag is None: + if isinstance(angle, torch.Tensor): + x = torch.stack([angle.cos(), angle.sin()], dim=-1) + elif isinstance(angle, np.ndarray): + x = np.stack([np.cos(angle), np.sin(angle)], axis=-1) + else: + raise NotImplementedError + else: + if isinstance(angle, torch.Tensor): + x = torch.stack([mag * angle.cos(), mag * angle.sin()], dim=-1) + elif isinstance(angle, np.ndarray): + x = np.stack([mag * np.cos(angle), mag * np.sin(angle)], axis=-1) + else: + raise NotImplementedError + return x + + +@lru_cache(maxsize=4) +def _polynomial_order_base(order: int, device: Device) -> Tensor: + return torch.arange(order - 1, -1, -1, device=device) + + +def polynomial(x: Tensor | np.ndarray, coeff: Tensor | np.ndarray) -> Tensor: + """calculate polynomial function of x given coefficient coeff + + Args: + x (Tensor): input tensor + coeff (Tensor): Tensor of shape [n], where n is the degree of polynomial. Orders: [n, n-1, ..., 2, 1, constant] + + Returns: + Tensor: output tensor coeff[0]*x^n + coeff[1]*x^{n-1} + ... 
+ coeff[n-1]*x + coeff[n] + """ + # xs = [x] + # for i in range(2, coeff.size(0)): + # xs.append(xs[-1]*x) + # xs.reverse() + # x = torch.stack(xs, dim=-1) + + # Deprecated implementation + # x = torch.stack([x**i for i in range(coeff.size(0) - 1, 0, -1)], dim=-1) + # out = (x * coeff[:-1]).sum(dim=-1) + coeff[-1].data.item() + # return out + + ### x^n, x^{n-1}, ..., x^2, x, 1 + order = coeff.shape[0] # n+1 + if isinstance(x, Tensor): + ## torch from highest order to constant + x = x[..., None].expand([-1] * x.dim() + [order]) + order_base = _polynomial_order_base(order, x.device) + return x.pow(order_base).matmul(coeff) + elif isinstance(x, np.ndarray): + ## numpy polyval from constant to higher order + return np.polynomial.polynomial.polyval(x, coeff[::-1]) + else: + raise NotImplementedError diff --git a/src/chop/nn/optical/utils/initializer.py b/src/chop/nn/optical/utils/initializer.py new file mode 100644 index 000000000..cbdd1f83c --- /dev/null +++ b/src/chop/nn/optical/utils/initializer.py @@ -0,0 +1,60 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-06-06 01:57:16 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-06-06 01:57:18 +""" + +import numpy as np +import torch + +__all__ = [ + # "quant_kaiming_uniform", + # "quant_kaiming_uniform_", + # "truncated_normal", + # "truncated_normal_", + "morr_uniform_", + # "morr_uniform", +] + + +def morr_uniform_(tensor, MORRConfig, n_op=4, biased=False, gain=1): + """ + description: Uniform initialization for MORR array based tensor core [SqueezeLight, Gu+, DATE'21]. We only consider how n_op influence one MORR's output. 
How to balance vector length should be considered in learnable balancing factor\\ + @tensor {torch.Tensor} weight tensor/parameter\\ + @MORRConfig {Config} MORR configuration defined in the onnlib/model/layer/device/mrr\\ + @n_op {int scalar} Number of operands on an MORR\\ + @biased {bool} biased=True, weight in [0, L]; otherwise in [-L/2, L/2].\\ + @gain {float} Gain due to activation. ReLU=sqrt(2), Tanh=5/3, Clamp(0,1)=2\\ + return {} + """ + morr_fwhm = ( + -4 + * np.pi**2 + * MORRConfig.radius + * MORRConfig.effective_index + * ( + 1 / MORRConfig.resonance_wavelength + - 1 / (MORRConfig.resonance_wavelength - MORRConfig.bandwidth / 2) + ) + ) + ### first we need to calculate the information gain of an MORR, estimated by linear estimation at 0 and FWHM + # t1 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([0]).float(), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # t2 = mrr_roundtrip_phase_to_tr_fused(torch.tensor([morr_fwhm]).float(), a=MORRConfig.attenuation_factor, r=MORRConfig.coupling_factor, intensity=True) + # g = (t2 - t1) / morr_fwhm + + ### calculate the variance of the weight + # var_phi = 1 ## assume the input is normalized to have variance 1 + # var_w = 1/(3/2*g**4*n_op*var_phi) + + ### calculate range of uniform distribution U(-L,L) + # L = ((3 * var_w)**0.5).item() + # return torch.nn.init.uniform_(tensor, -L, L) + + ## approximation by assuming 4*std(phi)= 3*FWHM, E[x]=0, D[x]=1, W ~ U[0, L] + L = (3 / (4 * n_op)) ** 0.5 * morr_fwhm * gain + if biased: + return torch.nn.init.uniform_(tensor, 0, L) + else: + return torch.nn.init.uniform_(tensor, -L / 2, L / 2) diff --git a/src/chop/nn/optical/utils/mrr.py b/src/chop/nn/optical/utils/mrr.py new file mode 100644 index 000000000..826e8ad64 --- /dev/null +++ b/src/chop/nn/optical/utils/mrr.py @@ -0,0 +1,73 @@ +""" +Description: +Author: Jiaqi Gu (jqgu@utexas.edu) +Date: 2021-07-18 00:03:04 +LastEditors: Jiaqi Gu (jqgu@utexas.edu) +LastEditTime: 2021-07-18 00:03:05 
"""

import numpy as np


__all__ = [
    "MORRConfig_20um_MQ",
    "MRRConfig_5um_HQ",
    "MRRConfig_5um_MQ",
    "MRRConfig_5um_LQ",
    "MORRConfig_10um_MQ",
]


# Device parameter presets for (multi-operand) micro-ring resonators.
# Naming: <MORR|MRR>Config_<radius>um_<HQ|MQ|LQ>, where H/M/L denotes the
# quality-factor regime.  All lengths are in nanometers.
class MORRConfig_20um_MQ:
    attenuation_factor = 0.8578  # single-pass amplitude attenuation a
    coupling_factor = 0.8985  # self-coupling coefficient r
    radius = 20000  # nm
    group_index = 2.35316094
    effective_index = 2.35
    resonance_wavelength = 1554.252  # nm
    bandwidth = 0.67908  # nm (FWHM)
    quality_factor = 2288.7644639


class MRRConfig_5um_HQ:
    attenuation_factor = 0.987
    coupling_factor = 0.99
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 0.2278  # nm
    quality_factor = 6754.780509


class MRRConfig_5um_MQ:
    attenuation_factor = 0.925
    coupling_factor = 0.93
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 1.5068  # nm
    quality_factor = 1021.1965755


class MRRConfig_5um_LQ:
    attenuation_factor = 0.845
    coupling_factor = 0.85
    radius = 5000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 2.522  # nm
    quality_factor = 610.1265


class MORRConfig_10um_MQ:
    attenuation_factor = 0.8578
    coupling_factor = 0.8985
    radius = 10000  # nm
    group_index = 2.35316094
    effective_index = 2.4
    resonance_wavelength = 1538.739  # nm
    bandwidth = 1.6702  # nm
    quality_factor = 1213.047


# ---- new file in the original patch: src/chop/nn/optical/utils/mrr_op.py ----
"""
Description:
Author: Jiaqi Gu (jqgu@utexas.edu)
Date: 2021-07-18 00:01:34
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-07-18 00:01:36
"""

from .compute import (
    complex_mult,
    polar_to_complex,
    polynomial,
)
import logging

import numpy as np
import torch

# Disable the profiling executor so the @torch.jit.script function below is
# not re-specialized at runtime.
torch._C._jit_set_profiling_executor(False)


__all__ = [
    # "mrr_voltage_to_delta_lambda",
    # "mrr_tr_to_roundtrip_phase",
    # "mrr_roundtrip_phase_to_tr",
    "mrr_roundtrip_phase_to_tr_fused",
    # "mrr_roundtrip_phase_to_tr_grad_fused",
    "mrr_roundtrip_phase_to_tr_func",
    # "mrr_roundtrip_phase_to_out_phase",
    # "mrr_tr_to_out_phase",
    # "mrr_roundtrip_phase_to_tr_phase",
    # "mrr_roundtrip_phase_to_tr_phase_fused",
    # "mrr_modulator",
    # "mrr_filter",
    # "morr_filter",
    # "mrr_fwhm_to_ng",
    # "mrr_ng_to_fsr",
    # "mrr_finesse",
]


@torch.jit.script
def mrr_roundtrip_phase_to_tr_fused(
    rt_phi, a: float = 0.8, r: float = 0.9, intensity: bool = False
):
    """
    description: round trip phase shift to field transmission
    rt_phi {torch.Tensor or np.ndarray} abs of roundtrip phase shift (abs(phase lag)). range from abs([-pi, 0])=[0, pi]\\
    a {scalar} attenuation coefficient\\
    r {scalar} self-coupling coefficient\\
    intensity {bool scalar} whether output intensity transmission or field transmission\\
    return t {torch.Tensor or np.ndarray} mrr through port field/intensity transmission
    """
    # Fused, purely real evaluation of the all-pass MRR transfer curve.
    # Reference (slower) complex-arithmetic implementation kept for context:
    # angle = -rt_phi
    # ephi = torch.view_as_complex(torch.stack([angle.cos(), angle.sin()], dim=-1))  # sign from negative phase lag
    # a_ephi = -a * ephi
    # t = torch.view_as_real((r + a_ephi).div(1 + r * a_ephi))
    # if(intensity):
    #     t = get_complex_energy(t)
    # else:
    #     t = get_complex_magnitude(t)
    ra_cosphi_by_n2 = -2 * r * a * rt_phi.cos()
    t = (a * a + r * r + ra_cosphi_by_n2) / (1 + r * r * a * a + ra_cosphi_by_n2)
    if not intensity:
        # as long as a is not equal to r, t cannot be 0, so sqrt is safe here.
        t = t.sqrt()

    return t


def mrr_roundtrip_phase_to_tr_func(
    a: float = 0.8, r: float = 0.9, intensity: bool = False
):
    # Factory returning an autograd Function computing the same transfer curve
    # as mrr_roundtrip_phase_to_tr_fused, with a hand-derived backward pass.
    # c1..c4 are loop-invariant constants of the curve and its derivative.
    c1 = -2 * a * r
    c2 = a * a + r * r
    c3 = 1 + r * r * a * a - a * a - r * r
    c4 = (a**2 - 1) * (r**2 - 1) * 2 * a * r

    class MRRRoundTripPhaseToTrFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            # Algebraically equivalent to:
            # (c1*cos(phi) + c2) / (c1*cos(phi) + c2 + c3), fused in-place.
            t = input.cos().mul_(c1).add_(c2 + c3).reciprocal_().mul_(-c3).add_(1)
            if not intensity:
                # as long as a is not equal to r, t cannot be 0.
                t.sqrt_()
            return t

        @staticmethod
        def backward(ctx, grad_output):
            (input,) = ctx.saved_tensors
            denominator = input.cos().mul_(c1).add_(c2 + c3)

            if intensity:
                denominator.square_()
                numerator = input.sin().mul_(c4)
            else:
                # chain rule through the sqrt halves the numerator constant
                numerator = input.sin().mul_(c4 / 2)
                denominator = (
                    denominator.sub(1).pow_(1.5).mul_(denominator.sub(c3).sqrt_())
                )
            grad_input = numerator.div_(denominator).mul_(grad_output)
            return grad_input

    return MRRRoundTripPhaseToTrFunction.apply


# ---- new file in the original patch: src/chop/nn/optical/utils/quantize.py ----
# """
# Description:
# Author: Jiaqi Gu (jqgu@utexas.edu)
# Date: 2021-06-06 03:15:00
# LastEditors: Jiaqi Gu (jqgu@utexas.edu)
# LastEditTime: 2021-06-06 03:15:00
# """

import numpy as np
import torch
import logging


__all__ = [
    "input_quantize_fn",
    "weight_quantize_fn",
]


def uniform_quantize(k, gradient_clip=False):
    # Straight-through-estimator quantizer factory: forward rounds to
    # 2**k - 1 uniform levels; backward passes the gradient through.
    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            if k == 32:
                out = input
            # (definition continues in the next chunk)
            elif k == 1:
                out = torch.sign(input)
            else:
                n = float(2**k - 1)
                out = torch.round(input * n) / n
            return out

        @staticmethod
        def backward(ctx, grad_output):
            # straight-through estimator; optionally clipped to [-1, 1]
            grad_input = grad_output.clone()
            if gradient_clip:
                grad_input.clamp_(-1, 1)
            return grad_input

    return qfn.apply


############ add observer and new quant based on range and zeropoint for activation
def uniform_quantize_new(k, gradient_clip=False):
    # """
    # Support uniform quantization with auto-adjusted input data range
    # args:
    #     k: bitwidth
    #     scale, zeropoint: obtained from observer
    # """

    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, scale, zero_point):
            if k == 32:
                out = input
            elif k == 1:
                out = torch.sign(input)
            else:
                n = float(2**k - 1)
                # affine quantize/dequantize:
                # (clamp(round(x/scale + zp), 0, n) - zp) * scale, fused in-place
                out = (
                    input.div(scale)
                    .add_(zero_point)
                    .round_()
                    .clamp_(0, n)
                    .sub_(zero_point)
                    .mul_(scale)
                )
            return out

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()
            if gradient_clip:
                grad_input.clamp_(-1, 1)
            # no gradient w.r.t. scale / zero_point
            return grad_input, None, None

    return qfn.apply


class input_quantize_fn(torch.nn.Module):
    def __init__(
        self, in_bit, alg="dorefa", device=torch.device("cuda:0"), quant_ratio=1.0
    ):
        # """Input quantizer with Quant_Noise supported
        # Args:
        #     in_bit (int): Input quantization bitwidth.
        #     alg (str, optional): "dorefa" (clamp inputs to [0, 1]) or "normal"
        #         (observer-driven affine quantization).
        #     device (Device, optional): torch Device. Defaults to torch.device("cuda:0").
        #     quant_ratio (float, optional): Quantization ratio. Defaults to 1.0.
        # """
        super(input_quantize_fn, self).__init__()
        assert 1 <= in_bit <= 32
        self.in_bit = in_bit
        self.alg = alg
        assert alg in {
            "dorefa",
            "normal",
        }, f"Only support (dorefa, normal), but got {alg}"
        self.quant_ratio = quant_ratio
        assert 0 <= quant_ratio <= 1, logging.error(
            f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
        )
        self.device = device

        # define quant style
        # dorefa: clamp to 0-1
        # normal: obtain scale and zero_point via observer
        if self.alg == "dorefa":
            self.uniform_q = uniform_quantize(k=in_bit)
        elif self.alg == "normal":
            self.uniform_q = uniform_quantize_new(k=in_bit)
        self.scale = None
        self.zero_point = None
        ### select scale and zero-point using EMA: exponential moving averages
        # NOTE: MovingAverageMinMaxObserver did not support self-defined quant
        # bitwidths for pytorch 1.7, hence the explicit quant_min/quant_max.
        ### torch version must be higher than 1.7
        if 1 <= self.in_bit <= 8:  # observer does not support higher than 8-bit
            self.obs = torch.quantization.observer.MovingAverageMinMaxObserver(
                averaging_constant=0.01,
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                reduce_range=False,
                quant_min=0,
                quant_max=2**self.in_bit - 1,
            ).to(self.device)
        else:
            self.obs = None

    def set_bitwidth(self, bit: int) -> None:
        ### regenerate quantizer without changing observation statistics
        if bit != self.in_bit:
            if self.alg == "dorefa":
                self.uniform_q = uniform_quantize(k=bit)
            elif self.alg == "normal":
                self.uniform_q = uniform_quantize_new(k=bit)
            self.in_bit = bit

    def set_alg(self, alg: str) -> None:
        # switch quantization algorithm in place, rebuilding the quantizer
        assert alg in {
            "dorefa",
            "normal",
        }, f"Only support (dorefa, normal), but got {alg}"
        if alg != self.alg:
            if alg == "dorefa":
                self.uniform_q = uniform_quantize(k=self.in_bit)
            elif alg == "normal":
                self.uniform_q = uniform_quantize_new(k=self.in_bit)
            self.alg = alg

    def set_quant_ratio(self, quant_ratio=None):
        if quant_ratio is None:
            ### get recommended value per bitwidth (index 0 unused; capped at 16)
            quant_ratio = [
                None,
                0.2,
                0.3,
                0.4,
                0.5,
                0.55,
                0.6,
                0.7,
                0.8,
                0.83,
                0.86,
                0.89,
                0.92,
                0.95,
                0.98,
                0.99,
                1,
            ][min(self.in_bit, 16)]
        assert 0 <= quant_ratio <= 1, logging.error(
            f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}"
        )
        self.quant_ratio = quant_ratio

    def forward(self, x):
        if self.quant_ratio < 1 and self.training:
            ### implementation from fairseq
            ### must fully quantize during inference
            quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_(
                1 - self.quant_ratio
            )
        else:
            quant_noise_mask = None

        if self.in_bit == 32:
            input_q = x
        elif self.in_bit == 1:
            x = x.clamp(0, 1)
            input_q = (self.uniform_q(x - 0.5) + 1) / 2
            if quant_noise_mask is not None:
                noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                ### unquantized inputs have to be clamped
                input_q = x + noise
        else:
            ### dorefa-style clamp for input data
            if self.alg == "dorefa":
                x = x.clamp(0, 1)
                input_q = self.uniform_q(x)
            elif self.alg == "normal":
                if self.obs is not None:
                    if self.training:
                        self.obs(x)
                    scale, zero_point = self.obs.calculate_qparams()
                    # convert scale and zero_point type from qint8
                    self.scale = scale.to(x)
                    self.zero_point = zero_point.to(x)
                    input_q = self.uniform_q(x, self.scale, self.zero_point)
                else:
                    input_q = x  # if no observer (in_bit > 8), do not quantize
            else:
                raise NotImplementedError

            # add noise
            if quant_noise_mask is not None:
                noise = input_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0)
                ### unquantized inputs have to be clamped
                input_q = x + noise

        return input_q


class weight_quantize_fn(torch.nn.Module):
    def __init__(self, w_bit, mode="oconv", alg="dorefa", quant_ratio=1.0):
        # """Differentiable weight quantizer. Support different algorithms. Support Quant-Noise with partial quantization.
        # Args:
        #     w_bit (int): quantization bitwidth
        #     mode (str, optional): Different mode indicates different NN architectures. Defaults to "oconv".
        #     alg (str, optional): Quantization algorithms.
        # (definition continues in the next chunk)
[dorefa, dorefa_sym, qnn, dorefa_pos] Defaults to "dorefa". + # quant_ratio (float, optional): Quantization ratio to support full-precision gradient flow. Defaults to 1.0. + # """ + super(weight_quantize_fn, self).__init__() + assert 1 <= w_bit <= 32, logging.error( + f"Only support 1 - 32 bit quantization, but got {w_bit}" + ) + self.w_bit = w_bit + self.alg = alg + self.mode = mode + assert alg in {"dorefa", "dorefa_sym", "qnn", "dorefa_pos"}, logging.error( + f"Only support (dorefa, dorefa_sym, qnn, dorefa_pos) algorithms, but got {alg}" + ) + self.quant_ratio = quant_ratio + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. Must in [0,1], but got {quant_ratio}" + ) + self.uniform_q = uniform_quantize(k=w_bit, gradient_clip=True) + + def set_quant_ratio(self, quant_ratio=None): + if quant_ratio is None: + ### get recommended value + quant_ratio = [ + None, + 0.2, + 0.3, + 0.4, + 0.5, + 0.55, + 0.6, + 0.7, + 0.8, + 0.83, + 0.86, + 0.89, + 0.92, + 0.95, + 0.98, + 0.99, + 1, + ][min(self.w_bit, 16)] + assert 0 <= quant_ratio <= 1, logging.error( + f"Wrong quant ratio. 
Must in [0,1], but got {quant_ratio}" + ) + self.quant_ratio = quant_ratio + + def set_bitwidth(self, bit: int) -> None: + ### regenerate quantizer without changing observation statistics + if bit != self.w_bit: + self.uniform_q = uniform_quantize(k=bit, gradient_clip=True) + self.w_bit = bit + + def forward(self, x): + if self.quant_ratio < 1 and self.training: + ### implementation from fairseq + ### must fully quantize during inference + quant_noise_mask = torch.empty_like(x, dtype=torch.bool).bernoulli_( + 1 - self.quant_ratio + ) + else: + quant_noise_mask = None + + if self.w_bit == 32: + weight_q = torch.tanh(x) + weight_q = weight_q / torch.max(torch.abs(weight_q)) + elif self.w_bit == 1: + if self.mode == "ringonn": + weight_q = (self.uniform_q(x) / 4) + 0.5 + else: + if self.alg == "dorefa": + E = x.data.abs().mean() + weight_q = (self.uniform_q(x / E) * E + E) / 2 # [0, E] + if quant_noise_mask is not None: + x = (x + E) / 2 + noise = weight_q.data.sub_(x.data).masked_fill_( + quant_noise_mask, 0 + ) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = x + noise + elif self.alg == "dorefa_sym": + E = x.data.abs().mean() + weight_q = self.uniform_q(x / E) * E # [-E, E] + if quant_noise_mask is not None: + noise = weight_q.data.sub_(x.data).masked_fill_( + quant_noise_mask, 0 + ) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = x + noise + else: + assert NotImplementedError + else: + if self.alg == "dorefa": + weight = torch.tanh(x) # [-1, 1] + weight = weight / 2 / torch.max(torch.abs(weight.data)) + 0.5 + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight) + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) + ### unquantized weights have to follow reparameterization, i.e., tanh and scale + weight_q = weight + noise + + elif self.alg == "dorefa_sym": + weight = torch.tanh(x) # [-1, 1] + r = 
torch.max(torch.abs(weight.data)) + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight / (2 * r) + 0.5) * (2 * r) - r + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = weight + noise + elif self.alg == "dorefa_pos": + weight = torch.tanh(x) # [-1, 1] + r = torch.max(torch.abs(weight.data)) + weight = weight + r + # weight = weight / 2 + 0.5 + weight_q = self.uniform_q(weight / (2 * r)) * 2 * r + if quant_noise_mask is not None: + noise = weight_q.data.sub_(weight.data).masked_fill_( + quant_noise_mask, 0 + ) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = weight + noise + + elif self.alg == "qnn": + x_min = torch.min(x.data) + x_max = torch.max(x.data) + x_range = x_max - x_min + weight_q = self.uniform_q((x - x_min) / x_range) * x_range + x_min + if quant_noise_mask is not None: + noise = weight_q.data.sub_(x.data).masked_fill_(quant_noise_mask, 0) + ### unquantized weights have to follow reparameterization, i.e., tanh + weight_q = x + noise + else: + assert NotImplementedError + + return weight_q diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index dd7b49c4e..4369ea54e 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -82,9 +82,7 @@ BatchNorm2dInteger, BatchNorm2dBinary, ) -from .layer_norm import ( - LayerNormInteger, -) +from .layer_norm import LayerNormInteger from .group_norm import GroupNormInteger from .instance_norm2d import InstanceNorm2dInteger @@ -161,12 +159,8 @@ SoftplusBinary, SoftplusTernary, ) -from .batch_norm1d import ( - BatchNorm1dInteger, -) -from .gqa import ( - GroupedQueryAttentionInteger, -) +from .batch_norm1d import BatchNorm1dInteger +from .gqa import GroupedQueryAttentionInteger quantized_basic_module_map = { 
"conv1d_block_minifloat": Conv1dBlockMinifloat, diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py index 45819db75..315a74ec3 100644 --- a/src/chop/nn/quantized/modules/attention.py +++ b/src/chop/nn/quantized/modules/attention.py @@ -6,9 +6,7 @@ from transformers.models.bert.modeling_bert import BertSelfAttention -from chop.nn.quantized.modules.linear import ( - LinearInteger, -) +from chop.nn.quantized.modules.linear import LinearInteger from chop.nn.quantized.functional import fixed_softermax from chop.nn.quantized.functional import matmul_integer diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py index 8f9ea5969..93f4801e8 100644 --- a/src/chop/nn/quantized/modules/attention_head.py +++ b/src/chop/nn/quantized/modules/attention_head.py @@ -6,9 +6,7 @@ from typing import Optional, Tuple from functools import partial -from chop.nn.quantized.functional.matmul import ( - generic_matmul_integer, -) +from chop.nn.quantized.functional.matmul import generic_matmul_integer from chop.nn.quantizers.integer import integer_quantizer diff --git a/src/chop/nn/quantized/modules/batch_norm1d.py b/src/chop/nn/quantized/modules/batch_norm1d.py index b84c0d131..bafb96f9b 100644 --- a/src/chop/nn/quantized/modules/batch_norm1d.py +++ b/src/chop/nn/quantized/modules/batch_norm1d.py @@ -4,9 +4,7 @@ from torch import Tensor from torch.nn import functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _BatchNorm1dBase(torch.nn.BatchNorm1d): diff --git a/src/chop/nn/quantized/modules/group_norm.py b/src/chop/nn/quantized/modules/group_norm.py index a90e5b651..25721aae9 100644 --- a/src/chop/nn/quantized/modules/group_norm.py +++ b/src/chop/nn/quantized/modules/group_norm.py @@ -7,9 +7,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) 
+from chop.nn.quantizers import integer_quantizer from mase_components.scalar_operators.fixed.test.isqrt_sw import isqrt_sw2 diff --git a/src/chop/nn/quantized/modules/instance_norm2d.py b/src/chop/nn/quantized/modules/instance_norm2d.py index 0a7260443..d3946401f 100644 --- a/src/chop/nn/quantized/modules/instance_norm2d.py +++ b/src/chop/nn/quantized/modules/instance_norm2d.py @@ -4,9 +4,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _InstanceNorm2dBase(nn.InstanceNorm2d): diff --git a/src/chop/nn/quantized/modules/layer_norm.py b/src/chop/nn/quantized/modules/layer_norm.py index 2ca5c6068..0d7e4d413 100644 --- a/src/chop/nn/quantized/modules/layer_norm.py +++ b/src/chop/nn/quantized/modules/layer_norm.py @@ -4,9 +4,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class _LayerNormBase(nn.LayerNorm): diff --git a/src/chop/nn/quantized/modules/rms_norm.py b/src/chop/nn/quantized/modules/rms_norm.py index 91dd9d9d6..a6b893b32 100644 --- a/src/chop/nn/quantized/modules/rms_norm.py +++ b/src/chop/nn/quantized/modules/rms_norm.py @@ -5,9 +5,7 @@ from torch import Tensor import torch.nn.functional as F -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer def _rms_norm(x: Tensor, eps, scale: Tensor | None): diff --git a/src/chop/nn/snn/modules/__init__.py b/src/chop/nn/snn/modules/__init__.py index 6b6efd229..cec1196bd 100644 --- a/src/chop/nn/snn/modules/__init__.py +++ b/src/chop/nn/snn/modules/__init__.py @@ -60,9 +60,7 @@ ) from .embedding import EmbeddingZIPTF -from .roberta import ( - RobertaSelfAttentionZIPTF, -) +from .roberta import RobertaSelfAttentionZIPTF spiking_basic_module_map = { "conv1d": Conv1d, diff --git a/src/chop/nn/snn/modules/neuron/ifnode.py 
b/src/chop/nn/snn/modules/neuron/ifnode.py index e5e3c373f..c2376d310 100644 --- a/src/chop/nn/snn/modules/neuron/ifnode.py +++ b/src/chop/nn/snn/modules/neuron/ifnode.py @@ -244,10 +244,12 @@ def multi_step_forward(self, x_seq: torch.Tensor): self.v_float_to_tensor(x_seq[0]) if self.v_reset is None: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_with_v_seq( - x_seq, self.v, self.v_threshold - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_with_v_seq( + x_seq, self.v, self.v_threshold ) else: spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset( @@ -255,10 +257,12 @@ def multi_step_forward(self, x_seq: torch.Tensor): ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset ) else: spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset( diff --git a/src/chop/nn/snn/modules/neuron/lifnode.py b/src/chop/nn/snn/modules/neuron/lifnode.py index 90b02289c..0f0522f4d 100644 --- a/src/chop/nn/snn/modules/neuron/lifnode.py +++ b/src/chop/nn/snn/modules/neuron/lifnode.py @@ -417,29 +417,33 @@ def single_step_forward(self, x: torch.Tensor): self.v_float_to_tensor(x) if self.v_reset is None: if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_decay_input( - x, self.v, self.v_threshold, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_soft_reset_decay_input( + x, self.v, self.v_threshold, self.tau ) else: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_no_decay_input( - x, self.v, self.v_threshold, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_soft_reset_no_decay_input( + x, self.v, 
self.v_threshold, self.tau ) else: if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_hard_reset_decay_input( + x, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_no_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike, + self.v, + ) = self.jit_eval_single_step_forward_hard_reset_no_decay_input( + x, self.v, self.v_threshold, self.v_reset, self.tau ) return spike @@ -514,56 +518,68 @@ def multi_step_forward(self, x_seq: torch.Tensor): if self.v_reset is None: if self.decay_input: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_soft_reset_decay_input( + x_seq, self.v, self.v_threshold, self.tau ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_soft_reset_no_decay_input( + x_seq, self.v, self.v_threshold, self.tau ) 
else: if self.decay_input: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_hard_reset_decay_input( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + self.v_seq, + ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) + ( + spike_seq, + self.v, + ) = self.jit_eval_multi_step_forward_hard_reset_no_decay_input( + x_seq, self.v, self.v_threshold, self.v_reset, self.tau ) return spike_seq diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index f5aa7fe22..eadf7e849 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -46,8 +46,6 @@ ann2snn_module_transform_pass, ) -from .onnx.analysis import ( - export_fx_graph_analysis_pass, -) +from .onnx.analysis import export_fx_graph_analysis_pass from .graph.analysis.autosharding import autosharding_analysis_pass diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 09786f09f..61e076a27 100644 --- a/src/chop/passes/graph/__init__.py +++ 
b/src/chop/passes/graph/__init__.py @@ -47,9 +47,7 @@ from .transforms.quantize.quant_parsers import parse_node_config -from chop.passes.graph.analysis.runtime.runtime_analysis import ( - runtime_analysis_pass, -) +from chop.passes.graph.analysis.runtime.runtime_analysis import runtime_analysis_pass from .interface import tensorrt_engine_interface_pass diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index be02f0046..7d59937ac 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -315,9 +315,9 @@ "key": "config", "value": "data_in", }, - "invert": { # Added for Wave2Vec + "invert": { "input": "data_in", - }, + }, # Added for Wave2Vec } module_data = { diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index f1ad3a4fe..b149edf4a 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -302,8 +302,11 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): if not pass_args.get(f"skip_forward", False): tensor_sharding_map = _get_sharding_map(mg) - return mg, { - "autosharding_time": autosharding_time, - "tensor_sharding_map": tensor_sharding_map, - **pass_outs, - } + return ( + mg, + { + "autosharding_time": autosharding_time, + "tensor_sharding_map": tensor_sharding_map, + **pass_outs, + }, + ) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py index 05c76160a..96cbfa445 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py @@ -6,9 +6,7 @@ PlacementStrategy, StrategyType, ) -from 
torch.distributed.tensor.ops.utils import ( - is_tensor_partial, -) +from torch.distributed.tensor.ops.utils import is_tensor_partial from torch.distributed.tensor.placement_types import ( _DTensorSpec, Partial, diff --git a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py index 5bd5977b3..772206878 100644 --- a/src/chop/passes/graph/transforms/utils/logicnets_fusion.py +++ b/src/chop/passes/graph/transforms/utils/logicnets_fusion.py @@ -12,12 +12,8 @@ matches_module_pattern, replace_node_module, ) -from chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) -from chop.nn.quantized.modules.conv2d import ( - Conv2DLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets +from chop.nn.quantized.modules.conv2d import Conv2DLogicNets # Housekeeping ------------------------------------------------------------------------- logger = logging.getLogger(__file__) diff --git a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py index 3d8c0a16e..f0a52f75a 100644 --- a/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py +++ b/src/chop/passes/graph/transforms/verilog/logicnets/emit_linear.py @@ -4,9 +4,7 @@ import torch.nn as nn from chop.passes.graph.utils import init_project -from chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets from .util import ( generate_lut_verilog, diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index e04492572..ea0634be3 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -15,15 +15,15 @@ RobertaSelfOutput, ) -from transformers.models.llama.modeling_llama import ( - LlamaAttention, -) +from transformers.models.llama.modeling_llama import LlamaAttention from 
transformers.models.bert.modeling_bert import ( - BertSelfAttention, BertSdpaSelfAttention, + BertSelfAttention, ) +from transformers.models.bert.configuration_bert import BertConfig + roberta_prefix_map = { RobertaSdpaSelfAttention: "roberta_self_attention", RobertaSelfAttention: "roberta_self_attention", @@ -38,8 +38,8 @@ } bert_prefix_map = { - BertSelfAttention: "bert_self_attention", BertSdpaSelfAttention: "bert_self_attention", + BertSelfAttention: "bert_self_attention", } @@ -141,7 +141,7 @@ def instantiate_conv2d(module, postfix, module_map, additional_module_args): has_bias = not (module.bias is None) # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. # Need to handle this better - if "config" in inspect.signature(conv2d.__init__).parameters: + if "config" in inspect.signature(conv2d_cls.__init__).parameters: conv2d = conv2d_cls( in_channels=module.in_channels, out_channels=module.out_channels, @@ -224,14 +224,23 @@ def instantiate_llama_module( def instantiate_bert_module( - module, postfix, prefix, module_map, module_args, network_args + module, + postfix, + prefix, + module_map, + module_args, ): bert_cls = module_map[f"{prefix}_{postfix}"] bert_module = bert_cls( - config=network_args, - layer_idx=module.layer_idx, - q_config=module_args, + config=BertConfig( + hidden_size=module.query.in_features, + num_attention_heads=module.num_attention_heads, + attention_head_size=module.attention_head_size, + attention_probs_dropout_prob=module.dropout_prob, + is_decoder=False, + ), + morr_config=module_args, ) return bert_module @@ -240,6 +249,7 @@ def instantiate_module(module, postfix, module_map, additional_module_args): is_roberta, roberta_layer_name = check_module_instance(module, roberta_prefix_map) is_llama, llama_layer_name = check_module_instance(module, llama_prefix_map) is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) + 
is_bert, bert_layer_name = check_module_instance(module, bert_prefix_map) module_args = additional_module_args["config"] network_args = additional_module_args.get("network_config", None) @@ -262,7 +272,11 @@ def instantiate_module(module, postfix, module_map, additional_module_args): ) elif is_bert: module = instantiate_bert_module( - module, postfix, llama_layer_name, module_map, module_args, network_args + module, + postfix, + bert_layer_name, + module_map, + module_args, ) else: raise ValueError(f"{module} is not supported.") diff --git a/src/chop/passes/module/transforms/attention/attention_transform_helper.py b/src/chop/passes/module/transforms/attention/attention_transform_helper.py index 9f9129b69..3216e1db9 100644 --- a/src/chop/passes/module/transforms/attention/attention_transform_helper.py +++ b/src/chop/passes/module/transforms/attention/attention_transform_helper.py @@ -11,9 +11,7 @@ MLA, ) from chop.nn.modules.mgqa import MGQALayers, MGQA -from chop.nn.modules.lora_linear import ( - LowRankLinear, -) +from chop.nn.modules.lora_linear import LowRankLinear from ...module_modify_helper import ( get_module_by_name, set_module_by_name, @@ -421,7 +419,6 @@ def _create_rotary_embeddings(self, seqlen, rope_dim, device): class MGQAWrapper(torch.nn.Module): - def __init__(self, mgqa: MGQA): super().__init__() self.mgqa = mgqa diff --git a/src/chop/passes/module/transforms/optical/__init__.py b/src/chop/passes/module/transforms/optical/__init__.py new file mode 100644 index 000000000..9b1840c4e --- /dev/null +++ b/src/chop/passes/module/transforms/optical/__init__.py @@ -0,0 +1 @@ +from .optical import optical_module_transform_pass diff --git a/src/chop/passes/module/transforms/optical/module_transform_helper.py b/src/chop/passes/module/transforms/optical/module_transform_helper.py new file mode 100644 index 000000000..a88df57a0 --- /dev/null +++ b/src/chop/passes/module/transforms/optical/module_transform_helper.py @@ -0,0 +1,321 @@ +import torch +import 
torch.nn as nn +import math +from functools import reduce, partial +from copy import deepcopy +import logging +import inspect +import warnings + +from chop.passes.module.module_modify_helper import ( + get_module_by_name, + set_module_by_name, +) +from chop.passes.module.state_dict_map import SPECIAL_CONVERT_PATTERNS + +from transformers.models.roberta.modeling_roberta import ( + RobertaSelfAttention, + RobertaSdpaSelfAttention, + RobertaClassificationHead, + RobertaIntermediate, + RobertaOutput, + RobertaSelfOutput, +) + +from transformers.models.llama.modeling_llama import LlamaAttention + + +def check_module_instance(module, prefix_map): + """ + Check if the given module is an instance of any class in the prefix_map. If it is, return the corresponding prefix. + Args: + module (object): The module to check. + prefix_map (dict): A dictionary where keys are classes and values are prefixes. + Returns: + tuple: A tuple containing a boolean indicating if the module is an instance of any class in the prefix_map, + and the corresponding prefix if it is an instance, otherwise None. 
+ """ + for cls, name in prefix_map.items(): + if isinstance(module, cls): + return True, name + return False, None + + +def replace_by_name_optical(network, module_name: str, new_module, target_name): + + original = get_module_by_name(network, module_name) + if target_name == "linear_morr_full": + updated_module = weight_replacement_full_linear_optical(original, new_module) + elif target_name in ["linear_morr", "linear_morr_triton"]: + updated_module = weight_replacement_circulant_linear_optical( + original, new_module + ) + else: + raise NotImplementedError( + f"weight replacement function for the optical module {target_name} not implemented" + ) + + network = set_module_by_name(network, module_name, updated_module) + + return network + + +def weight_replacement_full_linear_optical(original, new_module): + if isinstance(original, nn.Linear): + return weight_replacement_linear_optical(original, new_module) + elif isinstance(original, nn.Conv2d): + return weight_replacement_conv2d_optical(original, new_module) + else: + raise NotImplementedError( + "weight replacement function for the optical module not implemented" + ) + + +def weight_replacement_linear_optical(linear_layer, morr_layer): + """ + Replace the weights of AllPassMORRLinear (morr_layer) with those from a standard nn.Linear (linear_layer). + Focuses only on weight copying (no bias copying). 
+ """ + # Extract dimensions + out_features = morr_layer.out_features + in_features = morr_layer.in_features + miniblock = morr_layer.miniblock + grid_dim_x = morr_layer.grid_dim_x + grid_dim_y = morr_layer.grid_dim_y + in_features_pad = morr_layer.in_features_pad + + # Get the weights from the standard linear layer + standard_weights = linear_layer.weight.data # [out_features, in_features] + + # Ensure the shapes match + assert ( + standard_weights.shape[0] == out_features + ), "Output feature dimensions don't match" + assert ( + standard_weights.shape[1] == in_features + ), "Input feature dimensions don't match" + + # Pad the standard weights to match in_features_pad + if in_features_pad > in_features: + padded_weights = torch.zeros( + out_features, + in_features_pad, + device=standard_weights.device, + dtype=standard_weights.dtype, + ) + padded_weights[:, :in_features] = standard_weights + standard_weights = padded_weights # [out_features, in_features_pad] + + # Reshape to match the MORR structure [grid_dim_y, grid_dim_x, miniblock] + assert grid_dim_y == out_features, "grid_dim_y does not match out_features" + assert ( + grid_dim_x * miniblock == in_features_pad + ), "grid_dim_x * miniblock does not match in_features_pad" + + reshaped_weights = standard_weights.reshape(grid_dim_y, grid_dim_x, miniblock) + + # Copy the weights to the MORR layer + with torch.no_grad(): + morr_layer.weight.data.copy_(reshaped_weights) + + return morr_layer + + +def weight_replacement_circulant_linear_optical(x, y): + """ + Replace the weights of AllPassMORRCirculantLinear (y) with those from a standard nn.Linear (x). + Focuses only on weight copying (no bias copying). 
+ take mean value along diagonal + """ + + # Dense weight + W = x.weight.data # [out_features, in_features] + + # Dimensions defined by the MORR layer + k = y.miniblock # miniblock size + grid_dim_y = y.grid_dim_y # #block-rows (p) + grid_dim_x = y.grid_dim_x # #block-cols (q) + out_features_p = y.out_features_pad + in_features_p = y.in_features_pad + + # Zero-pad so every block is k×k + W_padded = W.new_zeros((out_features_p, in_features_p)) + W_padded[: W.size(0), : W.size(1)] = W + + new_weight = W.new_zeros((grid_dim_y, grid_dim_x, k)) # [p, q, k] + + idx = torch.arange(k, device=W.device) # 0 … k-1, reused in every block + + with torch.no_grad(): + for p in range(grid_dim_y): + row_slice = slice(p * k, (p + 1) * k) + + for q in range(grid_dim_x): + col_slice = slice(q * k, (q + 1) * k) + block = W_padded[row_slice, col_slice] # shape (k, k) + + # Frobenius-projection onto the circulant subspace: + # c_j = mean of { block[i, (i+j) mod k], i=0…k-1 } + c = torch.stack([block[idx, (idx + j) % k].mean() for j in range(k)]) + + new_weight[p, q, :] = c # first row + + # Save back into the MORR layer + y.load_parameters({"weight": new_weight}) + + return y + + +def weight_replacement_conv2d_optical(x, y): + """ + Replace the weights (and bias, if present) of a standard nn.Conv2d (x) + into an AllPassMORRCirculantConv2d (y). + + Args: + x (nn.Conv2d): A standard PyTorch Conv2d module + y (AllPassMORRCirculantConv2d): An already-constructed optical Conv2d + module into which we copy weights/bias. + """ + with torch.no_grad(): + # 1) Copy bias (if both x and y actually have one). 
+ if x.bias is not None and y.bias is not None: + y.bias.copy_(x.bias) + + # 2) Flatten nn.Conv2d's weight => shape [out_channels, in_channels*kernel_h*kernel_w] + w_flat = x.weight.data.view(x.out_channels, -1) + + # 3) Zero-pad to match (out_channels_pad, in_channels_pad) + outC_pad = y.out_channels_pad # == y.grid_dim_y * y.miniblock + inC_pad = y.in_channels_pad # == y.grid_dim_x * y.miniblock + + W = torch.zeros(outC_pad, inC_pad, device=w_flat.device, dtype=w_flat.dtype) + # Copy as many channels/elements as we have + W[: x.out_channels, : w_flat.size(1)] = w_flat + + # 4) Reshape into blocks => shape [p, miniblock, q, miniblock] + p = y.grid_dim_y + q = y.grid_dim_x + k = y.miniblock + W_blocks = W.view(p, k, q, k) # => [p, k, q, k] + + # 5) For each p,q block, extract the "first column" of size 'k' and place it in y.weight + # That is, for a k x k sub-block, we interpret sub_block[:,0] as the "circulant first column". + for i in range(p): + for j in range(q): + sub_block = W_blocks[i, :, j, :] # shape [k, k] + y.weight.data[i, j, :] = sub_block[:, 0] + + # Done. At this point, y.weight and y.bias (if present) have been overwritten + # with a simple block-circulant approximation of x's parameters. 
+ return y + + +def instantiate_optical_module(module, postfix, module_map, additional_module_args): + module_args = additional_module_args["config"] + additional_args = additional_module_args["additional"] + network_args = additional_module_args.get("network_config", None) + + if isinstance(module, torch.nn.Linear): + module = instantiate_optical_linear( + module, postfix, module_map, module_args, additional_args + ) + elif isinstance(module, torch.nn.Conv2d): + module = instantiate_optical_conv2d(module, postfix, module_map, module_args) + else: + raise ValueError(f"{module} is not supported.") + return module + + +def instantiate_optical_linear( + module, postfix, module_map, additional_module_args, additional_args +): + linear_cls = module_map[f"linear_{postfix}"] + has_bias = not (module.bias is None) + + # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. + # Need to handle this better + if "config" in inspect.signature(linear_cls.__init__).parameters: + linear = linear_cls( + in_features=module.in_features, + out_features=module.out_features, + bias=has_bias, + config=additional_module_args, + ) + else: + linear = linear_cls( + in_features=module.in_features, + out_features=module.out_features, + bias=has_bias, + **additional_module_args, + ) + if additional_args is None: + return linear + + # extra handling for morr optical module + enable_thermal_crosstalk = additional_args.get("thermal_crosstalk", False) + enable_phase_noise = additional_args.get("phase_noise", False) + enable_trainable_morr_scale = additional_args.get("trainable_morr_scale", False) + enable_trainable_morr_bias = additional_args.get("trainable_morr_bias", False) + + if enable_thermal_crosstalk: + linear.enable_crosstalk() + linear.set_crosstalk_coupling_matrix( + additional_args.get("coupling_factor", 0.04), + additional_args.get("drop_perc", 0.0), + ) + + if enable_phase_noise: + 
linear.enable_phase_variation() + phase_noise_std = additional_args.get("phase_noise_std", 0.04) + linear.set_phase_variation(phase_noise_std) + + if enable_trainable_morr_scale: + linear.enable_trainable_morr_scale() + else: + linear.disable_trainable_morr_scale() + + if enable_trainable_morr_bias: + linear.enable_trainable_morr_bias() + else: + linear.disable_trainable_morr_bias() + + if "in_bit" in additional_args: + linear.set_input_bitwidth(in_bit=additional_args["in_bit"]) + if "w_bit" in additional_args: + linear.set_weight_bitwidth(w_bit=additional_args["w_bit"]) + + return linear + + +def instantiate_optical_conv2d(module, postfix, module_map, additional_module_args): + conv2d_cls = module_map[f"conv2d_{postfix}"] + has_bias = not (module.bias is None) + # TODO: some transformed modules have "config" as an argument then extract the additional_module_args from it. Some directly take the additional_module_args. + # Need to handle this better + if "config" in inspect.signature(conv2d_cls.__init__).parameters: + conv2d = conv2d_cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=has_bias, + padding_mode=module.padding_mode, + config=additional_module_args, + ) + else: + conv2d = conv2d_cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=has_bias, + padding_mode=module.padding_mode, + **additional_module_args, + ) + return conv2d diff --git a/src/chop/passes/module/transforms/optical/optical.py b/src/chop/passes/module/transforms/optical/optical.py new file mode 100644 index 000000000..9e72f617f --- /dev/null +++ b/src/chop/passes/module/transforms/optical/optical.py @@ -0,0 +1,152 @@ +import torch +from 
transformers.models.bert.modeling_bert import BertSdpaSelfAttention + +from chop.nn.optical.modules import optical_module_map +from chop.passes.module.module_modify_helper import instantiate_module +from chop.passes.module.transforms.optical.module_transform_helper import ( + replace_by_name_optical, + instantiate_optical_module, +) +from ...state_dict_map import match_a_pattern, check_is_huggingface_model + + +def get_config(config: dict, name: str): + if name in config: + return config[name]["config"] + else: + return config["default"]["config"] + + +def optical_transform_by_type(network, pass_args): + for type_name, config in pass_args.items(): + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + + if type_name == "linear": + module = torch.nn.Linear + elif type_name == "conv2d": + module = torch.nn.Conv2d + elif isinstance(m, BertSdpaSelfAttention): + type_name = "bert_self_attention" + else: + raise ValueError(f"{type_name} is not supported!") + + # config = config["config"] + # postfix = config.pop("name") + optical_config = config["config"] + optial_additional_config = config.get("additional", None) + postfix = optical_config["name"] + + additional_module_args = { + "config": optical_config, + "additional": optial_additional_config, + } + for n, m in n_m.items(): + if isinstance(m, module): + new_m = instantiate_optical_module( + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical( + network, n, new_m, type_name + "_" + postfix + ) + return network + + +def optical_transform_by_name(network, pass_args): + optical_names = pass_args.keys() + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + for n, m in n_m.items(): + if n in optical_names: + optical_config = pass_args[n]["config"] + optial_additional_config = pass_args[n].get("additional", None) + postfix = optical_config["name"] + + additional_module_args = { + "config": optical_config, + "additional": optial_additional_config, + } + + if 
isinstance(m, torch.nn.Linear): + type_name = "linear" + elif isinstance(m, torch.nn.Conv2d): + type_name = "conv2d" + else: + raise ValueError(f"{type_name} is not supported!") + + new_m = instantiate_optical_module( + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical( + network, n, new_m, type_name + "_" + postfix + ) + + return network + + +def optical_transform_by_regex_name(network, pass_args): + is_huggingface_model = check_is_huggingface_model(network) + + patterns = list(pass_args.keys()) + n_m = {} + for n, m in network.named_modules(): + n_m[n] = m + + for n, m in n_m.items(): + matched_pattern = match_a_pattern(n, patterns) + if not matched_pattern: + continue + print(f"processing {n}") + + optical_config = pass_args[matched_pattern]["config"] + optial_additional_config = pass_args[matched_pattern].get("additional", None) + postfix = optical_config["name"] + + additional_module_args = ( + {"config": optical_config, "additional": optial_additional_config} + # if is_huggingface_model + # else {"config": optical_config} + ) + + if isinstance(m, torch.nn.Linear): + type_name = "linear" + elif isinstance(m, torch.nn.Conv2d): + type_name = "conv2d" + else: + raise ValueError(f"{type_name} is not supported!") + + new_m = instantiate_optical_module( + m, postfix, optical_module_map, additional_module_args + ) + network = replace_by_name_optical(network, n, new_m, type_name + "_" + postfix) + + return network + + +def optical_module_transform_pass(network, pass_args): + """ + Apply optical transformation to the given nn.Module. + + :param network: The input network to be transformed. + :type network: torch.nn.Module + + :param pass_args: Additional arguments for the transformation. + :type pass_args: dict, optional + + :return: The transformed torch.nn.Module. + :rtype: tuple + :raises ValueError: If the "by" argument is unsupported. 
+ """ + by = pass_args.pop("by") + match by: + case "type": + network = optical_transform_by_type(network, pass_args) + case "name": + network = optical_transform_by_name(network, pass_args) + case "regex_name": + network = optical_transform_by_regex_name(network, pass_args) + case _: + raise ValueError(f'Unsupported quantize "by": {by}') + return network, {} diff --git a/src/mase_components/activation_layers/test/fixed_gelu_tb.py b/src/mase_components/activation_layers/test/fixed_gelu_tb.py index ab7d75f4c..1a7d760b6 100644 --- a/src/mase_components/activation_layers/test/fixed_gelu_tb.py +++ b/src/mase_components/activation_layers/test/fixed_gelu_tb.py @@ -9,9 +9,7 @@ from mase_cocotb.runner import mase_runner -from mase_components.helper.generate_memory import ( - generate_sv_lut, -) +from mase_components.helper.generate_memory import generate_sv_lut DATA_IN_0_PRECISION_1 = 8 diff --git a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py index d48214384..fa2b2aa75 100644 --- a/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py +++ b/src/mase_components/activation_layers/test/fixed_softermax_1d_tb.py @@ -19,9 +19,7 @@ from chop.nn.quantized.functional import fixed_softermax -from chop.nn.quantizers import ( - integer_quantizer, -) +from chop.nn.quantizers import integer_quantizer class SoftermaxTB(Testbench): diff --git a/src/mase_components/difflogic_layers/passes.py b/src/mase_components/difflogic_layers/passes.py index c1b42977e..6693d5cf2 100644 --- a/src/mase_components/difflogic_layers/passes.py +++ b/src/mase_components/difflogic_layers/passes.py @@ -3,7 +3,6 @@ def difflogic_hardware_metadata_optimize_pass(graph, args={}): - def _is_logiclayer(node): return node.meta["mase"]["common"]["mase_op"] == "user_defined_module" diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py 
b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py index 42bcb4edc..89c80aa83 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_lut_index_tb.py @@ -94,7 +94,6 @@ async def cocotb_test_fixed_lut_index(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_fixed_lut_index(): - def full_sweep(): parameter_list = [] lut_pow = 5 diff --git a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py index 1b54ea013..ed155214a 100644 --- a/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py +++ b/src/mase_components/linear_layers/fixed_operators/test/fixed_range_reduction_tb.py @@ -71,7 +71,6 @@ async def cocotb_test_fixed_range_reduction(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_fixed_range_reduction(): - def full_sweep(): parameter_list = [] for width in range(1, 17): diff --git a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py index 02297dd3b..5d9da4bc6 100644 --- a/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py +++ b/src/mase_components/normalization_layers/test/batch_norm_2d_tb.py @@ -344,7 +344,6 @@ async def valid_backpressure(dut): @pytest.mark.skip(reason="Needs to be fixed.") def test_batch_norm_2d(): - def gen_cfg( total_dim0: int = 4, total_dim1: int = 4, diff --git a/src/mase_components/normalization_layers/test/channel_selection_tb.py b/src/mase_components/normalization_layers/test/channel_selection_tb.py index ceb65df8a..0fb603971 100644 --- a/src/mase_components/normalization_layers/test/channel_selection_tb.py +++ b/src/mase_components/normalization_layers/test/channel_selection_tb.py @@ -61,7 +61,6 @@ async def basic(dut): @pytest.mark.skip(reason="Needs 
to be fixed.") def test_channel_selection(): - def gen_cfg(num_channels, num_blocks): return {"NUM_CHANNELS": num_channels, "NUM_SPATIAL_BLOCKS": num_blocks} diff --git a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py index cc7d23428..d999fdc59 100644 --- a/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py +++ b/src/mase_components/transformer_layers/test/fixed_grouped_query_attention_wrapper_tb.py @@ -89,13 +89,16 @@ def forward(self, x: Tensor): out = self.o_projection(attn_output) - return out, { - "query": query, - "key": key.transpose(1, 2), # Key is transposed in hardware - "value": value, - "heads_out": heads_out, - "attn_output": attn_output, - } + return ( + out, + { + "query": query, + "key": key.transpose(1, 2), # Key is transposed in hardware + "value": value, + "heads_out": heads_out, + "attn_output": attn_output, + }, + ) class FixedGroupedQueryAttentionTB(Testbench): diff --git a/test/nn/quantized/modules/attention_head.py b/test/nn/quantized/modules/attention_head.py index e0f52a61e..72b625f90 100644 --- a/test/nn/quantized/modules/attention_head.py +++ b/test/nn/quantized/modules/attention_head.py @@ -1,6 +1,4 @@ -from chop.nn.quantized.modules.attention_head import ( - BertSelfAttentionHeadInteger, -) +from chop.nn.quantized.modules.attention_head import BertSelfAttentionHeadInteger from transformers import AutoConfig import torch diff --git a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py index f1115cf56..72a0ce71d 100644 --- a/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py +++ b/test/passes/graph/transforms/quantize/test_quantize_lutnet_linear_2.py @@ -10,9 +10,7 @@ from pathlib import Path sys.path.append(Path(__file__).resolve().parents[5].as_posix()) -from 
chop.nn.quantized.modules.linear import ( - LinearLogicNets, -) +from chop.nn.quantized.modules.linear import LinearLogicNets def generate_input_tensor(batch_size, input_features, min_val, max_val): diff --git a/test/passes/graph/transforms/training/test_training_base_pass.py b/test/passes/graph/transforms/training/test_training_base_pass.py index 04ed2167c..06d8371dd 100644 --- a/test/passes/graph/transforms/training/test_training_base_pass.py +++ b/test/passes/graph/transforms/training/test_training_base_pass.py @@ -17,9 +17,7 @@ verify_common_metadata_analysis_pass, ) from chop.ir.graph.mase_graph import MaseGraph -from chop.passes.graph.transforms import ( - training_base_pass, -) +from chop.passes.graph.transforms import training_base_pass from chop.passes.graph.utils import deepcopy_mase_graph from chop.tools.logger import set_logging_verbosity diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py index daf78fbee..324a2e29d 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_norm.py @@ -11,9 +11,7 @@ from mase_components.scalar_operators.fixed.test.isqrt_sw import make_lut from mase_components.common.test.lut_tb import write_memb from chop.passes.graph.utils import get_module_by_name -from chop.nn.quantizers.quantizers_for_hw import ( - integer_quantizer_for_hw, -) +from chop.nn.quantizers.quantizers_for_hw import integer_quantizer_for_hw # import chop.models.manual.rms_norm as rms diff --git a/test/passes/module/transforms/attention/test_attention_transform.py b/test/passes/module/transforms/attention/test_attention_transform.py index 14137c66e..7bd58f13d 100644 --- a/test/passes/module/transforms/attention/test_attention_transform.py +++ b/test/passes/module/transforms/attention/test_attention_transform.py @@ -7,9 +7,7 @@ sys.path.append(Path(__file__).resolve().parents[5].as_posix()) 
-from chop.passes.module.transforms import ( - attention_swap_transform_pass, -) +from chop.passes.module.transforms import attention_swap_transform_pass from pathlib import Path import time diff --git a/test/passes/module/transforms/optical/note.md b/test/passes/module/transforms/optical/note.md new file mode 100644 index 000000000..92bce3b98 --- /dev/null +++ b/test/passes/module/transforms/optical/note.md @@ -0,0 +1,47 @@ +### Note on using custom kernel for MORR linear layer + +Current optical transform pass only support MORR linear PyTorch module. To enbale substitution using Optimised MORR linear module (using Triton kernel): + +1. uncomment `TritonMemMORRLinear` inside [file](../../../../../src/chop/nn/optical/modules/__init__.py) +2. replace `morr_linear_fn_mem` function in [kernel wrapper](../../../../../src/chop/nn/optical/triton_modules/morr_linear_mem.py). Current implementation import it from a project file, import it from mase-triton instead. +3. You should now able to use optimised MORR linear module in optical transform pass. 
Two sample usages are shown below:
self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def test_optical_module_transform_pass(): + model = Net() + # Sanity check and report + pass_args = { + "by": "name", + "fc1": { + "config": { + "name": "morr", + "miniblock": 4, + "morr_init": True, + "trainable_morr_bias": False, + "trainable_morr_scale": False, + } + }, + } + optical_module_transform_pass(model, pass_args) + + +test_optical_module_transform_pass() +# test_optical_module_transform_pass_2() +# test_optical_module_transform_pass_3()