diff --git a/bitorch/layers/qactivation.py b/bitorch/layers/qactivation.py index c6722c0..d93bc92 100644 --- a/bitorch/layers/qactivation.py +++ b/bitorch/layers/qactivation.py @@ -9,10 +9,13 @@ class GradientCancellation(Function): + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0], torch.tensor(inputs[1], device=inputs[0].device)) + @staticmethod @typing.no_type_check def forward( - ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, threshold: float, ) -> torch.Tensor: @@ -24,8 +27,7 @@ def forward( Returns: tensor: binarized input tensor """ - ctx.save_for_backward(input_tensor, torch.tensor(threshold, device=input_tensor.device)) - return input_tensor + return input_tensor.view_as(input_tensor) @staticmethod @typing.no_type_check @@ -50,6 +52,34 @@ def backward( ) return cancelled, None + @staticmethod + def vmap(info, in_dims, input_tensor, threshold): + """ + Define behavior of the autograd function under vmap. + + Args: + info: Information about vmap (e.g., batch_size). + in_dims: Tuple specifying the batch dimension for each input. + input_tensor: Batched input tensor (with batch dimension at in_dims[0]). + threshold: Scalar threshold (or batched threshold, depending on in_dims[1]). + + Returns: + Tuple[output, out_dims]: Output tensor and its batch dimension.
+ """ + # Ensure input_tensor has a batch dimension + input_batch_dim = in_dims[0] + if input_batch_dim is not None and input_batch_dim != 0: + input_tensor = input_tensor.movedim(input_batch_dim, 0) + + # Ensure threshold has a batch dimension if provided + if in_dims[1] is not None: + threshold = threshold.movedim(in_dims[1], 0) + + # Forward pass is a view operation, so return the input directly + output = input_tensor.view_as(input_tensor) + out_dims = 0 # Output has batch dimension at index 0 + + return output, out_dims class QActivation(nn.Module): """Activation layer for quantization""" diff --git a/bitorch/layers/qconv2d.py b/bitorch/layers/qconv2d.py index e01f06d..1595f30 100644 --- a/bitorch/layers/qconv2d.py +++ b/bitorch/layers/qconv2d.py @@ -13,6 +13,7 @@ from .qactivation import QActivation from .qconv_mixin import QConvArgsProviderMixin from .register import QConv2dImplementation +import copy class QConv2d_NoAct(Conv2d): # type: ignore # noqa: N801 @@ -112,8 +113,14 @@ class _QConv2dComposed(DefaultImplementationMixin, QConv2dBase): To implement a custom QConv2d implementation use QConv2dBase as a super class instead. 
""" - - pass + def __deepcopy__(self, memo): + new_instance = self.__class__.__new__(self.__class__) + memo[id(self)] = new_instance + + for k, v in self.__dict__.items(): + setattr(new_instance, k, copy.deepcopy(v, memo)) + + return new_instance QConv2d: Type[_QConv2dComposed] = QConv2dImplementation(RuntimeMode.DEFAULT)(_QConv2dComposed) # type: ignore diff --git a/bitorch/quantizations/sign.py b/bitorch/quantizations/sign.py index 99d45e7..f823bd7 100644 --- a/bitorch/quantizations/sign.py +++ b/bitorch/quantizations/sign.py @@ -8,10 +8,13 @@ class SignFunction(STE): + @staticmethod + def setup_context(ctx, inputs, output): + pass + @staticmethod @typing.no_type_check def forward( - ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, ) -> torch.Tensor: """Binarize the input tensor using the sign function. @@ -27,6 +30,32 @@ def forward( sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor) return sign_tensor + @staticmethod + def vmap(info, in_dims, input_tensor): + """ + Vectorized implementation of the forward method for batched inputs. + + Args: + info: Contains vmap-related information (batch_size, randomness). + in_dims (tuple): Specifies which dimension of `input_tensor` is the batch dimension. + input_tensor (torch.Tensor): Batched input tensor. + + Returns: + Tuple[torch.Tensor, int]: The batched output tensor and the dimension of its batch. 
+ """ + # Ensure input_tensor has the batch dimension as the first dimension + input_batch_dim = in_dims[0] if in_dims[0] is not None else 0 + if input_batch_dim != 0: + input_tensor = input_tensor.movedim(input_batch_dim, 0) + + # Apply the sign and binarize operation across the batch + sign_tensor = torch.sign(input_tensor) + sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor) + + # Output has batch dimension at index 0 + out_dims = 0 + return sign_tensor, out_dims + class Sign(Quantization): """Module for applying the sign function with straight through estimator in backward pass."""