-
Notifications
You must be signed in to change notification settings - Fork 238
Open
Description
I use the code below to test MHA FLOPs, but it reports the same FLOPs and parameter count whether nhead=4 or nhead=8.
import torch
import torch.nn as nn
class MHAModel(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` used for FLOP counting.

    Args:
        dim: Embedding dimension of the input/output tokens.
        nhead: Number of attention heads (must divide ``dim``).
        dropout: Attention dropout probability.
    """

    def __init__(self, dim, nhead, dropout):
        super(MHAModel, self).__init__()
        # BUG FIX: the original passed the module-level global `dim_out` here,
        # ignoring the `dim` parameter entirely — so the embed dimension never
        # varied with the constructor argument. Use `dim` instead.
        self.mha = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=True)

    def forward(self, x):
        # Self-attention: query, key, and value are all `x`; [0] drops the
        # returned attention weights and keeps only the output tensor.
        x = self.mha(x, x, x)[0]
        return x
from fvcore.nn import FlopCountAnalysis, flop_count_table
# NOTE(review): the class above reads the global `dim_out` directly (not its
# `dim` argument), so these names are load-bearing for the snippet as posted.
dim_out = 448
seq_len = 300
nhead = 4
dropout = 0.1
net = MHAModel(dim=dim_out, nhead=nhead, dropout=dropout)
net.eval()
data = torch.randn((1, seq_len, dim_out))
# NOTE(review): `(data)` is not a tuple — it is just `data` with parentheses;
# fvcore accepts a single tensor, but `(data,)` would be the explicit form.
flops = FlopCountAnalysis(net, (data))
# NOTE(review): identical FLOPs for nhead=4 vs 8 is plausibly expected — the
# total multiply-adds of multi-head attention do not depend on the head count
# when embed dim is fixed (each head is dim/nhead wide) — TODO confirm against
# fvcore's MHA flop handler.
print(flop_count_table(flops, max_depth=4))
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels
