66import torch
77import torch .nn as nn
88
9+ from .utils import gram
10+
911
1012class HookManager :
1113 """
@@ -17,7 +19,12 @@ class HookManager:
1719 to clear collected features and remove hooks.
1820 """
1921
20- def __init__ (self , model : nn .Module , layers : List [str ]) -> None :
22+ def __init__ (
23+ self ,
24+ model : nn .Module ,
25+ layers : List [str ] | None = None ,
26+ recursive : bool = True ,
27+ ) -> None :
2128 """
2229 Initializes the HookManager and registers forward hooks on the specified layers.
2330
@@ -26,6 +33,9 @@ def __init__(self, model: nn.Module, layers: List[str]) -> None:
2633 layers: A list of strings, where each string is the fully qualified name of a
2734 module (layer) within the `model` from which to extract features.
2835 These names typically come from `model.named_modules()`.
36+ If `None`, all layers will be hooked. Defaults to `None`.
37+ recursive: A boolean indicating whether to register hooks recursively on the model.
38+ If `True`, hooks will be registered on all submodules of the specified layers.
2939
3040 Raises:
3141 ValueError: If no valid layers are found in the model based on the provided `layers` list.
@@ -34,8 +44,38 @@ def __init__(self, model: nn.Module, layers: List[str]) -> None:
3444 self .features : Dict [str , torch .Tensor ] = {}
3545 self .handles : List [torch .utils .hooks .RemovableHandle ] = []
3646
37- # Use list(dict.fromkeys(layers)) to preserve order while removing duplicate layer names
38- self .module_names = self ._insert_hooks (list (dict .fromkeys (layers )))
47+ if layers is None :
48+ layers = self ._extract_all_layers (model , recursive )
49+
50+ self .module_names = self ._insert_hooks (
51+ module = model ,
52+ layers = layers ,
53+ recursive = recursive ,
54+ )
55+
56+ def _extract_all_layers (
57+ self , module : nn .Module , recursive : bool = True
58+ ) -> List [str ]:
59+ """
60+ Extracts all layer names from the model recursively.
61+
62+ This method traverses the model and collects the names of all modules
63+ (layers) in a flat list. It is useful for debugging or when you want to
64+ register hooks on all layers without specifying them explicitly.
65+
66+ Args:
67+ module: The PyTorch model (`torch.nn.Module`) to extract layer names from.
68+
69+ Returns:
70+ A list of strings, where each string is the name of a module in the model.
71+ """
72+ layers = set ()
73+ for name , child in module .named_children ():
74+ if recursive and len (list (child .named_children ())) > 0 :
75+ layers .update (self ._extract_all_layers (child ))
76+ else :
77+ layers .add (name )
78+ return list (layers )
3979
4080 def _hook (
4181 self ,
@@ -62,9 +102,18 @@ def _hook(
62102 module ,
63103 inp ,
64104 ) # Unused parameters, but kept for compatibility with the hook signature
65- self .features [module_name ] = out .detach ()
105+ batch_size = out .size (0 )
106+ feature = out .reshape (batch_size , - 1 )
107+ feature = gram (feature )
108+ self .features [module_name ] = feature
66109
67- def _insert_hooks (self , layers : List [str ]) -> List [str ]:
110+ def _insert_hooks (
111+ self ,
112+ module : nn .Module ,
113+ layers : List [str ],
114+ recursive : bool = True ,
115+ prev_name : str = "" ,
116+ ) -> List [str ]:
68117 """
69118 Registers forward hooks on the specified layers of the model.
70119
@@ -87,24 +136,26 @@ def _insert_hooks(self, layers: List[str]) -> List[str]:
87136 This typically indicates an issue with the provided layer names.
88137 """
89138 filtered_layers : List [str ] = []
90- for module_name , module in self .model .named_modules ():
139+ for module_name , child in module .named_children ():
140+ curr_name = f"{ prev_name } .{ module_name } " if prev_name else module_name
141+ curr_name = curr_name .replace ("_model." , "" )
142+ num_grandchildren = len (list (child .named_children ()))
143+
144+ if recursive and num_grandchildren > 0 :
145+ # If the module has children, recursively register hooks for them
146+ filtered_layers .extend (
147+ self ._insert_hooks (
148+ module = child ,
149+ layers = layers ,
150+ recursive = recursive ,
151+ prev_name = curr_name ,
152+ )
153+ )
154+
91155 if module_name in layers :
92- handle = module .register_forward_hook (partial (self ._hook , module_name )) # type: ignore
156+ handle = child .register_forward_hook (partial (self ._hook , curr_name )) # type: ignore
93157 self .handles .append (handle )
94- filtered_layers .append (module_name )
95-
96- if len (filtered_layers ) != len (layers ):
97- hooked_set = set (filtered_layers )
98- not_hooked = [layer for layer in layers if layer not in hooked_set ]
99- print (
100- f"Warning: Could not find layers: { not_hooked } . They will be ignored."
101- )
102-
103- if not filtered_layers :
104- raise ValueError (
105- "No layers were found in the model. Please use `model.named_modules()` "
106- "to check the available layer names."
107- )
158+ filtered_layers .append (curr_name )
108159
109160 return filtered_layers
110161
0 commit comments