Counting FLOPS for a custom op with set_op_handle: a toy example that doesn't work. #147

@guynich

I'm extending the given example for fvcore.nn.FlopCountAnalysis so that the FLOP count of a custom op in my model class is included in the results.

import torch

from collections import Counter

from fvcore.nn import FlopCountAnalysis
from torch import nn

class TestModel(nn.Module):
    """Toy model."""
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
        self.fc = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        _ = self.custom_op_flop_counter(inputs=x, outputs=None)
        return self.fc(self.act(self.conv(x)).flatten(1))

    # Static method: has no access to anything else in the class.
    @staticmethod
    def custom_op_flop_counter(inputs, outputs) -> Counter:
        """Returns counter value to include in flops."""
        # The function should return a counter object with per-operator statistics.
        return Counter({'custom_op': 500})


model = TestModel()
inputs = (torch.randn((1, 3, 10, 10)),)

flops = FlopCountAnalysis(model, inputs).set_op_handle(
    "custom_op", model.custom_op_flop_counter
)

print(flops.by_module_and_operator())

The printed output does not include "custom_op" or its returned value of 500. It does show the expected values for the linear and conv operators, e.g.: {'': Counter({'linear': 10000, 'conv': 3000}), 'act': Counter(), 'conv': Counter({'conv': 3000}), 'fc': Counter({'linear': 10000})}. What am I doing wrong that prevents my custom op count from being included?
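
One way to narrow this down is to list the operator names that tracing actually records for the toy model. This is only a diagnostic sketch, assuming that FlopCountAnalysis matches handles against the names of operators in the JIT-traced graph (such as "aten::relu"); the helper name traced_op_kinds is hypothetical and not part of fvcore or PyTorch.

```python
import torch

from collections import Counter

from torch import nn


class TestModel(nn.Module):
    """Same toy model as in the report above."""

    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
        self.fc = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        # Plain Python call: it runs during tracing but does no tensor work.
        _ = self.custom_op_flop_counter(inputs=x, outputs=None)
        return self.fc(self.act(self.conv(x)).flatten(1))

    @staticmethod
    def custom_op_flop_counter(inputs, outputs) -> Counter:
        return Counter({'custom_op': 500})


def traced_op_kinds(model, inputs):
    """Hypothetical helper: sorted operator names recorded by torch.jit.trace."""
    traced = torch.jit.trace(model, inputs)
    return sorted({node.kind() for node in traced.inlined_graph.nodes()})


kinds = traced_op_kinds(TestModel(), (torch.randn((1, 3, 10, 10)),))
print(kinds)
```

If no entry in the printed list mentions "custom_op", then a handle registered under that name would have nothing to match, which could explain the missing count. Under the same assumption, registering the handle under one of the names that does appear (for example "aten::relu") may be what makes the returned Counter show up; that has not been verified here.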
