From 6d2e204626ec10dc61d345ff1a1facff8b46d947 Mon Sep 17 00:00:00 2001 From: yiliu Date: Tue, 14 Jan 2025 23:39:00 -0800 Subject: [PATCH 1/4] fix: fix bugs when using deepcopy on resnet --- bitorch/layers/qconv2d.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/bitorch/layers/qconv2d.py b/bitorch/layers/qconv2d.py index e01f06d..17fbc65 100644 --- a/bitorch/layers/qconv2d.py +++ b/bitorch/layers/qconv2d.py @@ -106,14 +106,23 @@ def forward(self, input_tensor: Tensor) -> Tensor: return super().forward(self.activation(input_tensor)) +import copy class _QConv2dComposed(DefaultImplementationMixin, QConv2dBase): """ This class defines the default implementation of a QConv2d layer (which is actually implemented by QConv2dBase). To implement a custom QConv2d implementation use QConv2dBase as a super class instead. """ - - pass + def __deepcopy__(self, memo): + # 创建一个新的实例 + new_instance = self.__class__.__new__(self.__class__) + memo[id(self)] = new_instance + + # 深拷贝所有属性 + for k, v in self.__dict__.items(): + setattr(new_instance, k, copy.deepcopy(v, memo)) + + return new_instance QConv2d: Type[_QConv2dComposed] = QConv2dImplementation(RuntimeMode.DEFAULT)(_QConv2dComposed) # type: ignore From 1b70a94385cdf36b4962de1b5e125e51bedfa41a Mon Sep 17 00:00:00 2001 From: yiliu Date: Wed, 15 Jan 2025 11:03:10 -0800 Subject: [PATCH 2/4] feat: qactivation will support vmap now --- bitorch/layers/qactivation.py | 36 ++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/bitorch/layers/qactivation.py b/bitorch/layers/qactivation.py index c6722c0..d93bc92 100644 --- a/bitorch/layers/qactivation.py +++ b/bitorch/layers/qactivation.py @@ -9,10 +9,13 @@ class GradientCancellation(Function): + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0], torch.tensor(inputs[1], device=inputs[0].device)) + @staticmethod @typing.no_type_check def forward( - ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, threshold: float, ) -> torch.Tensor: @@ -24,8 +27,7 @@ def forward( Returns: tensor: binarized input tensor """ - ctx.save_for_backward(input_tensor, torch.tensor(threshold, device=input_tensor.device)) - return input_tensor + return input_tensor.view_as(input_tensor) @staticmethod @typing.no_type_check @@ -50,6 +52,34 @@ def backward( ) return cancelled, None + @staticmethod + def vmap(info, in_dims, input_tensor, threshold): + """ + Define behavior of the autograd function under vmap. + + Args: + info: Information about vmap (e.g., batch_size). + in_dims: Tuple specifying the batch dimension for each input. + input_tensor: Batched input tensor (with batch dimension at in_dims[0]). + threshold: Scalar threshold (or batched threshold, depending on in_dims[1]). + + Returns: + Tuple[output, out_dims]: Output tensor and its batch dimension. + """ + # Ensure input_tensor has a batch dimension + input_batch_dim = in_dims[0] + if input_batch_dim is not None and input_batch_dim != 0: + input_tensor = input_tensor.movedim(input_batch_dim, 0) + + # Ensure threshold has a batch dimension if provided + if in_dims[1] is not None: + threshold = threshold.movedim(in_dims[1], 0) + + # Forward pass is a view operation, so return the input directly + output = input_tensor.view_as(input_tensor) + out_dims = 0 # Output has batch dimension at index 0 + + return output, out_dims class QActivation(nn.Module): """Activation layer for quantization""" From c2a7977c2a4de7b8264ae8102a3c057ed8de9164 Mon Sep 17 00:00:00 2001 From: yiliu Date: Wed, 15 Jan 2025 11:34:15 -0800 Subject: [PATCH 3/4] feat: sign will support vmap now --- bitorch/quantizations/sign.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/bitorch/quantizations/sign.py b/bitorch/quantizations/sign.py index 99d45e7..e11496a 100644 --- a/bitorch/quantizations/sign.py +++ b/bitorch/quantizations/sign.py @@ -8,10 +8,15 @@ class SignFunction(STE): + @staticmethod + def setup_context(ctx, inputs, output): + # ctx.save_for_backward(inputs[0], torch.tensor(inputs[1], device=inputs[0].device)) + pass + @staticmethod @typing.no_type_check def forward( - ctx: torch.autograd.function.BackwardCFunction, # type: ignore + # ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, ) -> torch.Tensor: """Binarize the input tensor using the sign function. @@ -27,6 +32,32 @@ def forward( sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor) return sign_tensor + @staticmethod + def vmap(info, in_dims, input_tensor): + """ + Vectorized implementation of the forward method for batched inputs. + + Args: + info: Contains vmap-related information (batch_size, randomness). + in_dims (tuple): Specifies which dimension of `input_tensor` is the batch dimension. + input_tensor (torch.Tensor): Batched input tensor. + + Returns: + Tuple[torch.Tensor, int]: The batched output tensor and the dimension of its batch. + """ + # Ensure input_tensor has the batch dimension as the first dimension + input_batch_dim = in_dims[0] if in_dims[0] is not None else 0 + if input_batch_dim != 0: + input_tensor = input_tensor.movedim(input_batch_dim, 0) + + # Apply the sign and binarize operation across the batch + sign_tensor = torch.sign(input_tensor) + sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=sign_tensor.device), sign_tensor) + + # Output has batch dimension at index 0 + out_dims = 0 + return sign_tensor, out_dims + class Sign(Quantization): """Module for applying the sign function with straight through estimator in backward pass.""" From 5b4dfe260d6e7bd01a4c7954941dfd2ba912eae2 Mon Sep 17 00:00:00 2001 From: yiliu Date: Wed, 15 Jan 2025 12:00:23 -0800 Subject: [PATCH 4/4] docs: update comments in code --- bitorch/layers/qconv2d.py | 4 +--- bitorch/quantizations/sign.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/bitorch/layers/qconv2d.py b/bitorch/layers/qconv2d.py index 17fbc65..1595f30 100644 --- a/bitorch/layers/qconv2d.py +++ b/bitorch/layers/qconv2d.py @@ -13,6 +13,7 @@ from .qactivation import QActivation from .qconv_mixin import QConvArgsProviderMixin from .register import QConv2dImplementation +import copy class QConv2d_NoAct(Conv2d): # type: ignore # noqa: N801 @@ -106,7 +107,6 @@ def forward(self, input_tensor: Tensor) -> Tensor: return super().forward(self.activation(input_tensor)) -import copy class _QConv2dComposed(DefaultImplementationMixin, QConv2dBase): """ This class defines the default implementation of a QConv2d layer (which is actually implemented by QConv2dBase). @@ -114,11 +114,9 @@ class _QConv2dComposed(DefaultImplementationMixin, QConv2dBase): To implement a custom QConv2d implementation use QConv2dBase as a super class instead. """ def __deepcopy__(self, memo): - # 创建一个新的实例 new_instance = self.__class__.__new__(self.__class__) memo[id(self)] = new_instance - # 深拷贝所有属性 for k, v in self.__dict__.items(): setattr(new_instance, k, copy.deepcopy(v, memo)) diff --git a/bitorch/quantizations/sign.py b/bitorch/quantizations/sign.py index e11496a..f823bd7 100644 --- a/bitorch/quantizations/sign.py +++ b/bitorch/quantizations/sign.py @@ -10,13 +10,11 @@ class SignFunction(STE): @staticmethod def setup_context(ctx, inputs, output): - # ctx.save_for_backward(inputs[0], torch.tensor(inputs[1], device=inputs[0].device)) pass @staticmethod @typing.no_type_check def forward( - # ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, ) -> torch.Tensor: """Binarize the input tensor using the sign function.