From 5b0744527f964d30504855b572f2a6bf7852b5a4 Mon Sep 17 00:00:00 2001
From: rivers
Date: Thu, 17 Nov 2022 23:25:14 +0800
Subject: [PATCH 1/2] add flop counter for mul operation supporting broadcast

---
 fvcore/nn/activation_count.py |  3 +++
 fvcore/nn/flop_count.py       |  3 +++
 fvcore/nn/jit_handles.py      | 16 ++++++++++++++++
 tests/test_flop_count.py      | 31 +++++++++++++++++++++++++++++++
 4 files changed, 53 insertions(+)

diff --git a/fvcore/nn/activation_count.py b/fvcore/nn/activation_count.py
index 6ce543f..a0bc517 100644
--- a/fvcore/nn/activation_count.py
+++ b/fvcore/nn/activation_count.py
@@ -19,6 +19,9 @@
     "aten::einsum": generic_activation_jit(),
     "aten::matmul": generic_activation_jit(),
     "aten::linear": generic_activation_jit(),
+    "aten::mm": generic_activation_jit(),
+    "aten::mul": generic_activation_jit(),
+    "aten::mul_": generic_activation_jit(),
 }
 
 
diff --git a/fvcore/nn/flop_count.py b/fvcore/nn/flop_count.py
index ac21f61..6f5d9bd 100644
--- a/fvcore/nn/flop_count.py
+++ b/fvcore/nn/flop_count.py
@@ -19,6 +19,7 @@
     linear_flop_jit,
     matmul_flop_jit,
     norm_flop_counter,
+    mul_flop_jit,
 )
 
 
@@ -31,6 +32,8 @@
     "aten::matmul": matmul_flop_jit,
     "aten::mm": matmul_flop_jit,
     "aten::linear": linear_flop_jit,
+    "aten::mul": mul_flop_jit,
+    "aten::mul_": mul_flop_jit,
     # You might want to ignore BN flops due to inference-time fusion.
     # Use `set_op_handle("aten::batch_norm", None)
     "aten::batch_norm": batchnorm_flop_jit,
diff --git a/fvcore/nn/jit_handles.py b/fvcore/nn/jit_handles.py
index 747205c..ba0a0ef 100644
--- a/fvcore/nn/jit_handles.py
+++ b/fvcore/nn/jit_handles.py
@@ -227,6 +227,22 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flop
 
 
+def mul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+    """
+    Count flops for the mul operation supporting broadcast.
+    """
+    # Inputs should be a list of length 2.
+    # It contains the shapes of the two tensors.
+    assert len(inputs) == 2, len(inputs)
+    input_shapes = [get_shape(v) for v in inputs]
+    shape_zero_len, shape_one_len = len(input_shapes[0]), len(input_shapes[1])
+    max_len = max(shape_zero_len, shape_one_len)
+    shape_zero_padded = np.pad(input_shapes[0], (max_len - shape_zero_len, 0), 'constant', constant_values=(1, 1))
+    shape_one_padded = np.pad(input_shapes[1], (max_len - shape_one_len, 0), 'constant', constant_values=(1, 1))
+    flop = int(prod(np.maximum(shape_zero_padded, shape_one_padded)))
+    return flop
+
+
 def norm_flop_counter(affine_arg_index: int) -> Handle:
     """
     Args:
diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py
index ad593b9..b354d1f 100644
--- a/tests/test_flop_count.py
+++ b/tests/test_flop_count.py
@@ -147,6 +147,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class MulNet(nn.Module):
+    """
+    A network with a single torch.mul operation. This is used for testing
+    flop count for torch.mul.
+    """
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        x = torch.mul(x, y)
+        return x
+
+
 class CustomNet(nn.Module):
     """
     A network with a fully connected layer followed by a sigmoid layer. This is
@@ -705,6 +716,26 @@ def test_einsum(self) -> None:
             "Einsum operation ntg,ncg->nct failed to pass the flop count test.",
         )
 
+    def test_mul(self) -> None:
+        """
+        Test flop count for operation torch.mul.
+        """
+        m = 2
+        n = 5
+        p = 7
+        net = MulNet()
+        x = torch.randn(m, 1, n)
+        y = torch.randn(p, 1)
+        flop_dict, _ = flop_count(net, (x, y))
+        gt_flop = m * n * p / 1e9
+        gt_dict = defaultdict(float)
+        gt_dict["mul"] = gt_flop
+        self.assertDictEqual(
+            flop_dict,
+            gt_dict,
+            "Mul operation failed to pass the flop count test."
+        )
+
     def test_batchnorm(self) -> None:
         """
         Test flop count for operation batchnorm. The test cases include

From dfc5b8b0b4415334b7dcc4668c7805c38d0e0573 Mon Sep 17 00:00:00 2001
From: rivers
Date: Sat, 19 Nov 2022 11:32:46 +0800
Subject: [PATCH 2/2] correct the implementation

---
 fvcore/nn/flop_count.py  |  5 ++---
 fvcore/nn/jit_handles.py | 16 ----------------
 2 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/fvcore/nn/flop_count.py b/fvcore/nn/flop_count.py
index 6f5d9bd..b45ab87 100644
--- a/fvcore/nn/flop_count.py
+++ b/fvcore/nn/flop_count.py
@@ -19,7 +19,6 @@
     linear_flop_jit,
     matmul_flop_jit,
     norm_flop_counter,
-    mul_flop_jit,
 )
 
 
@@ -32,8 +31,8 @@
     "aten::matmul": matmul_flop_jit,
     "aten::mm": matmul_flop_jit,
     "aten::linear": linear_flop_jit,
-    "aten::mul": mul_flop_jit,
-    "aten::mul_": mul_flop_jit,
+    "aten::mul": elementwise_flop_counter(0, 1),
+    "aten::mul_": elementwise_flop_counter(0, 1),
     # You might want to ignore BN flops due to inference-time fusion.
     # Use `set_op_handle("aten::batch_norm", None)
     "aten::batch_norm": batchnorm_flop_jit,
diff --git a/fvcore/nn/jit_handles.py b/fvcore/nn/jit_handles.py
index ba0a0ef..747205c 100644
--- a/fvcore/nn/jit_handles.py
+++ b/fvcore/nn/jit_handles.py
@@ -227,22 +227,6 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flop
 
 
-def mul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
-    """
-    Count flops for the mul operation supporting broadcast.
-    """
-    # Inputs should be a list of length 2.
-    # It contains the shapes of the two tensors.
-    assert len(inputs) == 2, len(inputs)
-    input_shapes = [get_shape(v) for v in inputs]
-    shape_zero_len, shape_one_len = len(input_shapes[0]), len(input_shapes[1])
-    max_len = max(shape_zero_len, shape_one_len)
-    shape_zero_padded = np.pad(input_shapes[0], (max_len - shape_zero_len, 0), 'constant', constant_values=(1, 1))
-    shape_one_padded = np.pad(input_shapes[1], (max_len - shape_one_len, 0), 'constant', constant_values=(1, 1))
-    flop = int(prod(np.maximum(shape_zero_padded, shape_one_padded)))
-    return flop
-
-
 def norm_flop_counter(affine_arg_index: int) -> Handle:
     """
     Args:
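
Note on the final implementation: PATCH 2/2 defers to fvcore's generic elementwise_flop_counter(0, 1), which counts flops as a scaled sum of input and output element counts; with the arguments (0, 1) that is one flop per output element. For broadcast-compatible shapes this matches the padded-shape computation performed by the mul_flop_jit handler that the patch removes, because the element-wise maximum of the left-padded input shapes is exactly the broadcast output shape. Below is a minimal sketch of that equivalence; it is not part of either patch, the helper names are hypothetical, and it assumes NumPy >= 1.20 for np.broadcast_shapes.

# Illustrative sketch only -- not part of the patch series.
import numpy as np


def flops_via_output_numel(shape_a, shape_b):
    # What elementwise_flop_counter(0, 1) amounts to for mul: one multiply
    # per element of the broadcast output.
    return int(np.prod(np.broadcast_shapes(shape_a, shape_b)))


def flops_via_padded_max(shape_a, shape_b):
    # The removed mul_flop_jit approach: left-pad the shorter shape with ones,
    # then take the element-wise maximum of the two padded shapes.
    max_len = max(len(shape_a), len(shape_b))
    a = np.pad(shape_a, (max_len - len(shape_a), 0), constant_values=1)
    b = np.pad(shape_b, (max_len - len(shape_b), 0), constant_values=1)
    return int(np.prod(np.maximum(a, b)))


# The shapes used by test_mul: (2, 1, 5) * (7, 1) broadcasts to (2, 7, 5).
assert flops_via_output_numel((2, 1, 5), (7, 1)) == flops_via_padded_max((2, 1, 5), (7, 1)) == 2 * 5 * 7

The two computations agree only when the shapes are genuinely broadcast-compatible: np.broadcast_shapes raises for incompatible shapes, while the padded-max formula would silently return a number, which is one more reason to reuse the library's generic elementwise handler.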