
Matmul Flops #108

@breuera

Description

The matmul flop counts seem to be off by 2x.

I tested the flop counting on a simple MLP, which reads as follows:

import torch.nn

## @package eml.mlp.Module
#  Simple MultiLayer Perceptron (MLP) with fixed dimensions.
#
#  The MLP assumes a 28^2 input image and 10 output classes.
#  These are the dimensions of the Fashion MNIST dataset.
class Model( torch.nn.Module ):
  ## Initializes the class.
  #  @param self object pointer.
  def __init__( self ):
    super( Model, self ).__init__()
    ## flattens the input
    self.m_flatten = torch.nn.Flatten()
    ## layers of the MLP: three linear layers with ReLU activations in between
    self.m_layers = torch.nn.Sequential( torch.nn.Linear( 28*28, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 10 ) )

  ## Forward pass with the given input.
  #  @param self object pointer.
  #  @param i_input input for the forward pass.
  #  @return output of the MLP.
  def forward( self,
               i_input ):
    l_flatten = self.m_flatten( i_input )
    l_result = self.m_layers( l_flatten )
    return l_result

I embedded this in some driver code; the crucial piece is here:

l_model = eml.mlp.model.Model()
[...]
print( l_model )

#
# flop count code
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
#
import fvcore.nn

l_x, l_y = next(iter(l_data_loader_train))

print( l_x.size() )

l_flops = fvcore.nn.FlopCountAnalysis( l_model,
                                       l_x )

print( l_flops.by_module_and_operator() )

print( fvcore.nn.flop_count_table( l_flops ) )

This returns:

Model(
  (m_flatten): Flatten(start_dim=1, end_dim=-1)
  (m_layers): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
{'': Counter({'addmm': 42795008}), 'm_flatten': Counter(), 'm_layers': Counter({'addmm': 42795008}), 'm_layers.0': Counter({'addmm': 25690112}), 'm_layers.1': Counter(), 'm_layers.2': Counter({'addmm': 16777216}), 'm_layers.3': Counter(), 'm_layers.4': Counter({'addmm': 327680})}
| module     | #parameters or shape   | #flops   |
|:-----------|:-----------------------|:---------|
| m_layers   | 0.67M                  | 42.795M  |
|  0         |  0.402M                |  25.69M  |
|   0.weight |   (512, 784)           |          |
|   0.bias   |   (512,)               |          |
|  2         |  0.263M                |  16.777M |
|   2.weight |   (512, 512)           |          |
|   2.bias   |   (512,)               |          |
|  4         |  5.13K                 |  0.328M  |
|   4.weight |   (10, 512)            |          |
|   4.bias   |   (10,)                |          |

Let's take the first linear layer as an example: matrix A in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html has shape (512, 784).
Matrix x (since the example is batched) has shape (64, 784).
Computing the result C = xA^T requires 2*64*512*784 - 64*512 floating point operations (64*512*784 multiplications plus 64*512*(784-1) additions).
However, the example also uses a bias, which adds another 64*512 additions on top, giving 2*64*512*784 = 51,380,224 flops in total; the tool reports 25,690,112 for the first layer. As an aside, I am not sure why the bias doesn't show up separately.
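
A quick arithmetic cross-check (plain Python, variable names are mine) shows that the reported number is exactly the multiply-accumulate (MAC) count, i.e., half of the expected flops:

l_batch, l_in_features, l_out_features = 64, 784, 512

l_macs  = l_batch * l_out_features * l_in_features   # one fused multiply-add per output element
l_flops = 2 * l_macs                                 # multiplications and additions counted separately

print( l_macs )   # 25690112 -> matches the tool's report for m_layers.0
print( l_flops )  # 51380224 -> 2*64*512*784, the expected total including the bias add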

I believe that the code below is off, since the number of operations for C += AB (using BLAS identifiers, with A of shape MxK and B of shape KxN) is 2*M*N*K, not M*N*K:

flop = prod(input_shapes[0]) * input_shapes[-1][-1]
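
As a possible workaround, fvcore allows overriding individual op handlers via set_op_handle. The sketch below is untested; get_shape from fvcore.nn.jit_handles and the (inputs, outputs) handler signature are assumptions that may differ across fvcore versions. It would count aten::addmm as 2*M*N*K:

import fvcore.nn
from fvcore.nn.jit_handles import get_shape

def addmm_flops_2x( i_inputs, i_outputs ):
  ## aten::addmm takes (bias, mat1, mat2, ...); mat1 is (M, K), mat2 is (K, N)
  l_m, l_k = get_shape( i_inputs[1] )
  l_n = get_shape( i_inputs[2] )[1]
  ## count multiplications and additions separately
  return 2 * l_m * l_n * l_k

l_flops = fvcore.nn.FlopCountAnalysis( l_model, l_x )
l_flops.set_op_handle( "aten::addmm", addmm_flops_2x )
print( fvcore.nn.flop_count_table( l_flops ) )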
