40 changes: 40 additions & 0 deletions fvcore/nn/jit_analysis.py
@@ -111,6 +111,22 @@ def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
            yield name, mod


def _maybe_flatten(object) -> List[torch.Tensor]:
    # Tries its best to find all tensors within the object and put them
    # into a flattened list. Custom structures cannot be recognized.
    # TODO: improve coverage of other structures, e.g. by using __dict__
    ret = []
    if isinstance(object, torch.Tensor):
        ret.append(object)
    if isinstance(object, (list, tuple)):
        for x in object:
            ret.extend(_maybe_flatten(x))
    if isinstance(object, dict):
        for x in object.values():
            ret.extend(_maybe_flatten(x))
    return ret
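
# A quick illustration of the intended behavior, with hypothetical tensors
# t1 and t2:
#   _maybe_flatten({"a": [t1, (t2,)], "b": 1.0}) -> [t1, t2]
# Tensors nested in lists, tuples, and dicts are collected; tensors held in
# other containers (e.g. sets or custom classes) are missed, per the TODO.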


def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -149,8 +165,11 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            # Don't save all intermediate tensors on GPU. There could be a lot.
            all_output_tensors.extend([x.cpu() for x in _maybe_flatten(outputs)])
            return outputs

    all_output_tensors: List[torch.Tensor] = []
    hook_handles: List[Any] = []
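    # all_output_tensors is extended by the scope-popping post-hook above on
    # each module call and returned from WrapperModule below, which keeps the
    # captured intermediate tensors from being pruned out of the trace.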

    def register_hooks(mod: nn.Module, name: str) -> None:
@@ -173,6 +192,27 @@ def register_hooks(mod: nn.Module, name: str) -> None:
        name = aliases[mod]
        register_hooks(mod, name)

    class WrapperModule(nn.Module):
        def __init__(self, module):
            super().__init__()
            self._wrapped = module

        def forward(self, *args):
            # Some intermediate tensors may not be directly connected to the final
            # model output, for example due to:
            # * control flow not observed by tracing
            # * tensor -> numpy/int conversion
            # Operations that produce such tensors will get pruned by pytorch's DCE,
            # but we want to include them in the graph.
            # There is currently no way to disable DCE. So we capture all tensors we
            # can and return them here, to reduce missing flops.
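            # For example, with `y = linear(x).item()` the tensor from linear(x)
            # never reaches the traced output, so its op would otherwise be
            # eliminated; keeping it in all_output_tensors preserves its flops.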
            outputs = self._wrapped(*args)
            return outputs, all_output_tensors

    # Hooks are registered with their original scope names before wrapping, so
    # adding the wrapper here won't affect scopes.
    module = WrapperModule(module)

    graph, _ = _get_trace_graph(module, inputs)

    for handle in hook_handles:
33 changes: 30 additions & 3 deletions tests/test_jit_model_analysis.py
@@ -6,13 +6,14 @@
import unittest
import warnings
from collections import Counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

import torch
import torch.nn as nn
from fvcore.nn.flop_count import FlopCountAnalysis
from fvcore.nn.jit_analysis import JitModelAnalysis
from fvcore.nn.jit_handles import addmm_flop_jit, conv_flop_jit, Handle, linear_flop_jit
from torch.nn import functional as F


class NestedNetInnerModule(nn.Module):
@@ -283,20 +284,28 @@ class TraceWarningNet(nn.Module):
    will be skipped and raise a warning.
    """

    class IntLinear(nn.Linear):
        """
        A linear layer whose forward returns a Python scalar via .item(),
        so its output cannot be traced.
        """

        def forward(self, x) -> Union[float, int]:
            return F.linear(x, self.weight, self.bias).item()

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10,)
        fc1_in, fc1_out = 10, 1
        fc2_in, fc2_out = 10, 10

-        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
+        self.fc1 = TraceWarningNet.IntLinear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        y = self.fc1(x).item()
+        y = self.fc1(x)
        warnings.warn("Dummy RuntimeWarning.", RuntimeWarning)
        if y < 0.0:
            x = self.fc2(x)
@@ -806,6 +815,24 @@ def test_disable_warnings(self) -> None:
        self.assertTrue(any(uncalled_msg in s for s in cm.output))
        self.assertTrue(any(uncalled_modules in s for s in cm.output))

    def test_capture_intermediate_outputs(self) -> None:
        class TestCaptureNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(10, 1)
                self.fc2 = nn.Linear(10, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = self.fc1(x)
                del y  # unused by output
                return self.fc2(x) + 2

        model = TestCaptureNet()
        inputs = (torch.randn((1, 10)),)
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        _ = analyzer.total()
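        # fc1's result is discarded inside forward, so without the intermediate
        # output capture in jit_analysis, tracing's dead code elimination would
        # prune fc1's op and fc1 would be reported as uncalled.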
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_skip_uncalled_containers_warnings(self) -> None:
        # uncalled containers should not warn
