Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
4d544b0
MORR ONN transform pass and test script implementation
Jan 23, 2025
5871543
add support for onn conv2d layer
Jan 26, 2025
5dc0974
black format
Jan 26, 2025
bd44a07
fix optical test script
Jan 26, 2025
913d8fa
black format fixed
Jan 26, 2025
b681e30
remove unnecessary import
Jan 26, 2025
3d99177
fix cuda error for morr_conv2d and morr_linear
Jan 28, 2025
d27d729
Merge remote-tracking branch 'upstream/main' into jy/onn_transform_pass
Jan 29, 2025
09da13d
reduce sphinx docstring error
ChengZhang-98 Feb 11, 2025
aac9097
undo docstrings change — Revert "reduce sphinx docstring error"
ChengZhang-98 Feb 11, 2025
9dba4fd
rename files
ChengZhang-98 Feb 11, 2025
1886884
remove torch_train.py
ChengZhang-98 Feb 11, 2025
745865a
Merge branch 'main' into jy/onn_transform_pass
ChengZhang-98 Feb 11, 2025
6d9dddd
remove redundant onn functions
Feb 15, 2025
8aa1dbd
.
Feb 16, 2025
1d7fdda
Merge branch 'main' into jy/onn_transform_pass
Johnny1882 Feb 16, 2025
d9e15b3
remove optical test to prevent doc generation error
Feb 20, 2025
ea77e74
reformat
Feb 20, 2025
6be34c9
remove unorganized test
Feb 21, 2025
87a222b
remove self test file
Feb 21, 2025
67b91b5
add back test_optical_module
Mar 10, 2025
7587103
Merge branch 'main' of https://github.com/Johnny1882/mase into jy/onn…
Mar 10, 2025
ce0f533
add bert transform file
Mar 11, 2025
cadc417
add debug file
Mar 11, 2025
fe9fff6
switch branch
Mar 12, 2025
4dfaa74
.
Mar 15, 2025
bee7f6d
modify morr_linear, implement it inside onn-transform-pass
Apr 2, 2025
7c1c7ae
add morr_full layer, perform testing on bert model
Apr 6, 2025
c2813ee
update morr_transformer
Apr 10, 2025
8a27d0d
update onn kernel and testing script
May 11, 2025
a06933d
Merge remote-tracking branch 'upstream/main' into jy/bert-onn
May 11, 2025
dddc817
complete transform pass
May 11, 2025
9461079
added quantization in custom kernel
May 15, 2025
c9974ab
fix memory-efficient kernel
May 20, 2025
38d5da4
new memory-saving kernel
May 21, 2025
b59e8d1
fix weight loading function
May 25, 2025
bd5d90b
add warning in random weight
May 28, 2025
6e82a67
fix kernel bugs to enable memory and training speed evaluation to be run
Jun 5, 2025
e956c49
remove breakpoint; use black to format code
Jun 24, 2025
c63c376
remove triton kernel
Jun 24, 2025
fe3c170
add notes
Jun 24, 2025
7fe6ecf
reformat
Jun 24, 2025
ee5a1f2
complete note on further work
Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@ mase-trainer/
test-trainer/

# DiffLogic: tutorial files
docs/tutorials/difflogic/data-mnist/
docs/tutorials/difflogic/data-mnist/

test/self
8 changes: 5 additions & 3 deletions src/chop/actions/search/search_space/nas_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,9 +1199,11 @@ def forward(
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = (
encoder_hidden_states.size()
)
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
Expand Down
5 changes: 3 additions & 2 deletions src/chop/distributed/tensor/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def default_tensor(spec: _DTensorSpec) -> torch.Tensor:
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)

first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
torch.Tensor, local_tensor_args[0]
first_arg, first_local_arg = (
cast(dtensor.DTensor, args[0]),
cast(torch.Tensor, local_tensor_args[0]),
)
rng_context = (
random._rng_tracker._distribute_region(first_arg._spec)
Expand Down
1 change: 0 additions & 1 deletion src/chop/ir/onnx/mase_onnx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class MaseOnnxGraph:

def __init__(
self,
model_proto: onnx.onnx_ml_pb2.ModelProto,
Expand Down
8 changes: 2 additions & 6 deletions src/chop/models/cnv/cnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
from typing import Any
import numpy as np

from chop.nn.quantized.modules.conv2d import (
Conv2dBinaryResidualSign,
)
from chop.nn.quantized.modules.linear import (
LinearBinaryResidualSign,
)
from chop.nn.quantized.modules.conv2d import Conv2dBinaryResidualSign
from chop.nn.quantized.modules.linear import LinearBinaryResidualSign
from chop.models.utils import register_mase_model, register_mase_checkpoint

"""
Expand Down
4 changes: 1 addition & 3 deletions src/chop/nn/backward/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from .linear import (
CustomLinear,
)
from .linear import CustomLinear


custom_module_map = {
Expand Down
1 change: 1 addition & 0 deletions src/chop/nn/optical/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .modules import optical_module_map
11 changes: 11 additions & 0 deletions src/chop/nn/optical/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .morr_linear import AllPassMORRCirculantLinear
from .morr_conv2d import AllPassMORRCirculantConv2d

# from ..triton_modules.morr_linear_mem import TritonMemMORRLinear


# Maps transform-pass config keys to the optical (MORR) module classes that
# replace the corresponding standard layers during the ONN transform.
optical_module_map = {
    "linear_morr": AllPassMORRCirculantLinear,
    "conv2d_morr": AllPassMORRCirculantConv2d,
    # "linear_morr_triton": TritonMemMORRLinear,
}
76 changes: 76 additions & 0 deletions src/chop/nn/optical/modules/base_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Description:
Author: Jiaqi Gu (jqgu@utexas.edu)
Date: 2021-06-08 18:55:05
LastEditors: Jiaqi Gu (jqgu@utexas.edu)
LastEditTime: 2021-06-08 18:55:05
"""

from typing import Any, Dict, Optional
import torch
from torch import nn
from torch.types import Device

__all__ = ["ONNBaseLayer"]


class ONNBaseLayer(nn.Module):
    """Shared base class for optical neural-network (ONN) layers.

    Holds the common bookkeeping used by the concrete MORR layers
    (target device, noise/quantization knobs, fast-forward flag) and
    declares the hooks subclasses are expected to implement.
    """

    def __init__(self, *args, device: Device = torch.device("cpu"), **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Device the layer's tensors live on; defaults to CPU.
        self.device = device

    # ---- hooks that concrete layers must override -----------------------
    def build_parameters(self) -> None:
        raise NotImplementedError

    def reset_parameters(self) -> None:
        raise NotImplementedError

    @classmethod
    def from_layer(cls, layer: nn.Module, *args, **kwargs) -> nn.Module:
        raise NotImplementedError

    def forward(self, x):
        raise NotImplementedError

    # ---- introspection ---------------------------------------------------
    def get_num_parameters(self) -> int:
        """Return the number of trainable (requires_grad) parameters."""
        trainable = (p.numel() for p in self.parameters() if p.requires_grad)
        return sum(trainable)

    def extra_repr(self) -> str:
        return ""

    # ---- runtime configuration knobs ------------------------------------
    def enable_fast_forward(self) -> None:
        self.fast_forward_flag = True

    def disable_fast_forward(self) -> None:
        self.fast_forward_flag = False

    def set_phase_variation(
        self, noise_std: float, random_state: Optional[int] = None
    ) -> None:
        # NOTE(review): random_state is accepted for interface compatibility
        # but is currently unused here.
        self.phase_noise_std = noise_std

    def set_gamma_noise(
        self, noise_std: float, random_state: Optional[int] = None
    ) -> None:
        # NOTE(review): random_state is accepted but currently unused here.
        self.gamma_noise_std = noise_std

    def set_crosstalk_factor(self, crosstalk_factor: float) -> None:
        self.crosstalk_factor = crosstalk_factor

    def set_weight_bitwidth(self, w_bit: int) -> None:
        self.w_bit = w_bit

    def set_input_bitwidth(self, in_bit: int) -> None:
        self.in_bit = in_bit

    def switch_mode_to(self, mode: str) -> None:
        self.mode = mode

    def load_parameters(self, param_dict: Dict[str, Any]) -> None:
        """Copy each tensor in ``param_dict`` into the attribute of the
        same name, in place (``{param_name: param_tensor, ...}``)."""
        for name, tensor in param_dict.items():
            getattr(self, name).data.copy_(tensor)
Loading