diff --git a/fvcore/nn/activation_count.py b/fvcore/nn/activation_count.py index 6ce543f..a0bc517 100644 --- a/fvcore/nn/activation_count.py +++ b/fvcore/nn/activation_count.py @@ -19,6 +19,9 @@ "aten::einsum": generic_activation_jit(), "aten::matmul": generic_activation_jit(), "aten::linear": generic_activation_jit(), + "aten::mm": generic_activation_jit(), + "aten::mul": generic_activation_jit(), + "aten::mul_": generic_activation_jit(), } diff --git a/fvcore/nn/flop_count.py b/fvcore/nn/flop_count.py index ac21f61..b45ab87 100644 --- a/fvcore/nn/flop_count.py +++ b/fvcore/nn/flop_count.py @@ -31,6 +31,8 @@ "aten::matmul": matmul_flop_jit, "aten::mm": matmul_flop_jit, "aten::linear": linear_flop_jit, + "aten::mul": elementwise_flop_counter(0, 1), + "aten::mul_": elementwise_flop_counter(0, 1), # You might want to ignore BN flops due to inference-time fusion. # Use `set_op_handle("aten::batch_norm", None) "aten::batch_norm": batchnorm_flop_jit, diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py index ad593b9..b354d1f 100644 --- a/tests/test_flop_count.py +++ b/tests/test_flop_count.py @@ -147,6 +147,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x +class MulNet(nn.Module): + """ + A network with a single torch.mul operation. This is used for testing + flop count for torch.mul. + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x = torch.mul(x, y) + return x + + class CustomNet(nn.Module): """ A network with a fully connected layer followed by a sigmoid layer. This is @@ -705,6 +716,26 @@ def test_einsum(self) -> None: "Einsum operation ntg,ncg->nct failed to pass the flop count test.", ) + def test_mul(self) -> None: + """ + Test flop count for operation torch.mul.
+ """ + m = 2 + n = 5 + p = 7 + net = MulNet() + x = torch.randn(m, 1, n) + y = torch.randn(p, 1) + flop_dict, _ = flop_count(net, (x, y)) + gt_flop = m * n * p / 1e9 + gt_dict = defaultdict(float) + gt_dict["mul"] = gt_flop + self.assertDictEqual( + flop_dict, + gt_dict, + "Mul operation failed to pass the flop count test." + ) + def test_batchnorm(self) -> None: """ Test flop count for operation batchnorm. The test cases include