Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions bitorch/layers/qactivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@


class GradientCancellation(Function):
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(inputs[0], torch.tensor(inputs[1], device=inputs[0].device))

@staticmethod
@typing.no_type_check
def forward(
ctx: torch.autograd.function.BackwardCFunction, # type: ignore
input_tensor: torch.Tensor,
threshold: float,
) -> torch.Tensor:
Expand All @@ -24,8 +27,7 @@ def forward(
Returns:
tensor: binarized input tensor
"""
ctx.save_for_backward(input_tensor, torch.tensor(threshold, device=input_tensor.device))
return input_tensor
return input_tensor.view_as(input_tensor)

@staticmethod
@typing.no_type_check
Expand All @@ -50,6 +52,34 @@ def backward(
)
return cancelled, None

@staticmethod
def vmap(info, in_dims, input_tensor, threshold):
"""
Define behavior of the autograd function under vmap.

Args:
info: Information about vmap (e.g., batch_size).
in_dims: Tuple specifying the batch dimension for each input.
input_tensor: Batched input tensor (with batch dimension at in_dims[0]).
threshold: Scalar threshold (or batched threshold, depending on in_dims[1]).

Returns:
Tuple[output, out_dims]: Output tensor and its batch dimension.
"""
# Ensure input_tensor has a batch dimension
input_batch_dim = in_dims[0]
if input_batch_dim is not None and input_batch_dim != 0:
input_tensor = input_tensor.movedim(input_batch_dim, 0)

# Ensure threshold has a batch dimension if provided
if in_dims[1] is not None:
threshold = threshold.movedim(in_dims[1], 0)

# Forward pass is a view operation, so return the input directly
output = input_tensor.view_as(input_tensor)
out_dims = 0 # Output has batch dimension at index 0

return output, out_dims

class QActivation(nn.Module):
"""Activation layer for quantization"""
Expand Down
11 changes: 9 additions & 2 deletions bitorch/layers/qconv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .qactivation import QActivation
from .qconv_mixin import QConvArgsProviderMixin
from .register import QConv2dImplementation
import copy


class QConv2d_NoAct(Conv2d): # type: ignore # noqa: N801
Expand Down Expand Up @@ -112,8 +113,14 @@ class _QConv2dComposed(DefaultImplementationMixin, QConv2dBase):

To implement a custom QConv2d implementation use QConv2dBase as a super class instead.
"""

pass
def __deepcopy__(self, memo):
    """Deep-copy this layer while preserving its concrete runtime class.

    The copy is created with ``__new__`` and populated by writing
    ``__dict__`` directly rather than via ``setattr``:
    ``nn.Module.__setattr__`` intercepts attribute assignment to register
    parameters/submodules and can misbehave (or fail before internal
    registries like ``_parameters`` exist) when replaying an already
    deep-copied ``__dict__``.

    Args:
        memo: standard deepcopy memo dict; ``self`` is registered in it
            before copying attributes so reference cycles terminate.

    Returns:
        A deep copy of this instance.
    """
    cls = self.__class__
    new_instance = cls.__new__(cls)
    # Register the copy first so self-references inside attributes
    # resolve to the new instance instead of recursing forever.
    memo[id(self)] = new_instance
    new_instance.__dict__.update(copy.deepcopy(self.__dict__, memo))
    return new_instance


QConv2d: Type[_QConv2dComposed] = QConv2dImplementation(RuntimeMode.DEFAULT)(_QConv2dComposed) # type: ignore
Expand Down
31 changes: 30 additions & 1 deletion bitorch/quantizations/sign.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@


class SignFunction(STE):
@staticmethod
def setup_context(ctx, inputs, output):
pass

@staticmethod
@typing.no_type_check
def forward(
ctx: torch.autograd.function.BackwardCFunction, # type: ignore
input_tensor: torch.Tensor,
) -> torch.Tensor:
"""Binarize the input tensor using the sign function.
Expand All @@ -27,6 +30,32 @@ def forward(
sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor)
return sign_tensor

@staticmethod
def vmap(info, in_dims, input_tensor):
"""
Vectorized implementation of the forward method for batched inputs.

Args:
info: Contains vmap-related information (batch_size, randomness).
in_dims (tuple): Specifies which dimension of `input_tensor` is the batch dimension.
input_tensor (torch.Tensor): Batched input tensor.

Returns:
Tuple[torch.Tensor, int]: The batched output tensor and the dimension of its batch.
"""
# Ensure input_tensor has the batch dimension as the first dimension
input_batch_dim = in_dims[0] if in_dims[0] is not None else 0
if input_batch_dim != 0:
input_tensor = input_tensor.movedim(input_batch_dim, 0)

# Apply the sign and binarize operation across the batch
sign_tensor = torch.sign(input_tensor)
sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor)

# Output has batch dimension at index 0
out_dims = 0
return sign_tensor, out_dims


class Sign(Quantization):
"""Module for applying the sign function with straight through estimator in backward pass."""
Expand Down