Description
The matmul flop counts reported by fvcore seem to be off by a factor of 2.
I tested the code on a simple MLP, which reads as follows:
```python
import torch.nn

## @package eml.mlp.Module
#  Simple MultiLayer Perceptron (MLP) with fixed dimensions.
#
#  The MLP assumes a 28^2 input image and 10 output classes.
#  These are the dimensions of the Fashion MNIST dataset.
class Model( torch.nn.Module ):
    ## Initializes the class.
    #  @param self object pointer.
    def __init__( self ):
        super( Model, self ).__init__()

        ## flattens the input
        self.m_flatten = torch.nn.Flatten()

        ## layers of the MLP: 3x(linear + relu)
        self.m_layers = torch.nn.Sequential( torch.nn.Linear( 28*28, 512 ),
                                             torch.nn.ReLU(),
                                             torch.nn.Linear( 512, 512 ),
                                             torch.nn.ReLU(),
                                             torch.nn.Linear( 512, 10 ) )

    ## Forward pass with the given input.
    #  @param self object pointer.
    #  @param i_input input for the forward pass.
    #  @return output of the MLP.
    def forward( self,
                 i_input ):
        l_flatten = self.m_flatten( i_input )
        l_result = self.m_layers( l_flatten )
        return l_result
```
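For reference, a quick shape check of this model on a Fashion-MNIST-sized batch (a minimal sketch of my own, not part of the report; the batch size of 64 matches the data loader output shown further below):

```python
import torch

# Instantiate the MLP above and run a dummy batch shaped like
# Fashion MNIST: [batch, channels, height, width].
l_model = Model()
l_dummy = torch.randn( 64, 1, 28, 28 )
l_out = l_model( l_dummy )

print( l_dummy.size() )  # torch.Size([64, 1, 28, 28])
print( l_out.size() )    # torch.Size([64, 10]), one score per class
```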
I embedded this in some driver code; the crucial piece is here:
```python
l_model = eml.mlp.model.Model()
[...]
print( l_model )

#
# flop count code
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
#
import fvcore.nn

l_x, l_y = next( iter( l_data_loader_train ) )
print( l_x.size() )

l_flops = fvcore.nn.FlopCountAnalysis( l_model,
                                       l_x )
print( l_flops.by_module_and_operator() )
print( fvcore.nn.flop_count_table( l_flops ) )
```
This returns:
```
Model(
  (m_flatten): Flatten(start_dim=1, end_dim=-1)
  (m_layers): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
{'': Counter({'addmm': 42795008}), 'm_flatten': Counter(), 'm_layers': Counter({'addmm': 42795008}), 'm_layers.0': Counter({'addmm': 25690112}), 'm_layers.1': Counter(), 'm_layers.2': Counter({'addmm': 16777216}), 'm_layers.3': Counter(), 'm_layers.4': Counter({'addmm': 327680})}
```
| module | #parameters or shape | #flops |
|:-----------|:-----------------------|:---------|
| m_layers | 0.67M | 42.795M |
| 0 | 0.402M | 25.69M |
| 0.weight | (512, 784) | |
| 0.bias | (512,) | |
| 2 | 0.263M | 16.777M |
| 2.weight | (512, 512) | |
| 2.bias | (512,) | |
| 4 | 5.13K | 0.328M |
| 4.weight | (10, 512) | |
| 4.bias | (10,) | |
Let's take the first linear layer as an example: matrix A in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html has shape (512, 784).
Matrix x (since the example is batched) has shape (64, 784).
Computing the result C = xA^T requires 2*64*512*784 - 64*512 floating point operations (64*512*784 multiplications and 64*512*(784-1) additions).
However, the example uses a bias, which adds another 64*512 additions, bringing the total to 2*64*512*784 = 51,380,224 flops; the tool reports 25,690,112 = 64*512*784 for the first layer, i.e., exactly half. Btw: I am not sure why the bias doesn't show up separately.
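The same factor of two appears for every layer. A quick arithmetic check (my own sketch, reusing only the layer shapes and the reported counts from the output above):

```python
# Compare the reported addmm counts against M*N*K and 2*M*N*K, with
# M = batch size, K = in_features, N = out_features of each Linear layer.
l_batch  = 64
l_layers = [ (784, 512, 25690112),   # m_layers.0
             (512, 512, 16777216),   # m_layers.2
             (512,  10,   327680) ]  # m_layers.4

for l_in, l_out, l_reported in l_layers:
    l_mnk = l_batch * l_in * l_out
    print( l_reported, l_mnk, 2*l_mnk )
    # every reported value equals M*N*K, i.e. half of the expected 2*M*N*K
    assert l_reported == l_mnk
```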
I believe that the code below is off, since the number of operations for the op C += A*B (using BLAS identifiers) is 2*M*N*K, not M*N*K:

fvcore/fvcore/nn/jit_handles.py, line 225 at commit e4f0b3d:

```python
flop = prod(input_shapes[0]) * input_shapes[-1][-1]
```
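As a workaround one can override the addmm handle and count 2*M*N*K instead. The sketch below is my own, not a confirmed recipe; it assumes the set_op_handle method of FlopCountAnalysis and the get_shape helper from jit_handles work as I read them in the current source, and it reuses l_model and l_x from the snippet above:

```python
from collections import Counter

import fvcore.nn
from fvcore.nn.jit_handles import get_shape

# Custom handle for aten::addmm that counts 2*M*N*K (one multiply plus
# one add per contribution) instead of M*N*K. The (inputs, outputs)
# signature mirrors the existing handles in jit_handles.py.
def addmm_flop_2x( inputs, outputs ):
    # addmm inputs are (bias, mat1, mat2); only the two matrices matter here
    l_shape_x = get_shape( inputs[1] )  # [M, K], assumed 2D after Flatten
    l_shape_w = get_shape( inputs[2] )  # [K, N]
    l_m, l_k  = l_shape_x
    l_n       = l_shape_w[-1]
    return Counter( { "addmm": 2 * l_m * l_n * l_k } )

l_flops = fvcore.nn.FlopCountAnalysis( l_model, l_x )
l_flops.set_op_handle( "aten::addmm", addmm_flop_2x )
print( fvcore.nn.flop_count_table( l_flops ) )
```

With such a handle the first layer would come out at 2*64*512*784 = 51,380,224, matching the hand count above.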