Skip to content

Commit 88c5ce8

Browse files
committed
Refactor code structure for improved readability and maintainability. Improve CKA to support recursive hooks with string module names.
1 parent 9fb3941 commit 88c5ce8

File tree

4 files changed

+2582
-2808
lines changed

4 files changed

+2582
-2808
lines changed

cka_pytorch/cka.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ def __init__(
2828
self,
2929
model1: nn.Module,
3030
model2: nn.Module,
31-
model1_layers: List[str],
31+
model1_layers: List[str] | None = None,
3232
model2_layers: List[str] | None = None,
3333
model1_name: str = "Model 1",
3434
model2_name: str = "Model 2",
3535
batched_feature_size: int = 64,
3636
device: Optional[torch.device] = None,
37+
hook_recursive: bool = True,
3738
verbose: bool = True,
3839
) -> None:
3940
"""
@@ -56,6 +57,9 @@ def __init__(
5657
when dealing with large models or many layers. Defaults to 64.
5758
device: An optional `torch.device` to perform computations on (e.g., `torch.device("cuda")`
5859
or `torch.device("cpu")`). If `None`, the device of `model1`'s parameters will be used.
60+
hook_recursive: A boolean indicating whether to register hooks recursively on the model.
61+
If `True`, hooks will be registered on all submodules of the specified layers.
62+
Defaults to `True`.
5963
verbose: A boolean indicating whether to print progress bars during CKA calculation.
6064
Defaults to `True`.
6165
"""
@@ -70,9 +74,15 @@ def __init__(
7074
self.model1.eval()
7175
self.model2.eval()
7276

73-
self.hook_manager1 = HookManager(model1, model1_layers)
77+
self.hook_manager1 = HookManager(
78+
model1,
79+
model1_layers,
80+
recursive=hook_recursive,
81+
)
7482
self.hook_manager2 = HookManager(
75-
model2, model2_layers if model2_layers else model1_layers
83+
model2,
84+
model2_layers if model2_layers else model1_layers,
85+
recursive=hook_recursive,
7686
)
7787

7888
self.num_layers_x = len(self.hook_manager1.module_names)

cka_pytorch/hook_manager.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch
77
import torch.nn as nn
88

9+
from .utils import gram
10+
911

1012
class HookManager:
1113
"""
@@ -17,7 +19,12 @@ class HookManager:
1719
to clear collected features and remove hooks.
1820
"""
1921

20-
def __init__(self, model: nn.Module, layers: List[str]) -> None:
22+
def __init__(
23+
self,
24+
model: nn.Module,
25+
layers: List[str] | None = None,
26+
recursive: bool = True,
27+
) -> None:
2128
"""
2229
Initializes the HookManager and registers forward hooks on the specified layers.
2330
@@ -26,6 +33,9 @@ def __init__(self, model: nn.Module, layers: List[str]) -> None:
2633
layers: A list of strings, where each string is the fully qualified name of a
2734
module (layer) within the `model` from which to extract features.
2835
These names typically come from `model.named_modules()`.
36+
If `None`, all layers will be hooked. Defaults to `None`.
37+
recursive: A boolean indicating whether to register hooks recursively on the model.
38+
If `True`, hooks will be registered on all submodules of the specified layers.
2939
3040
Raises:
3141
ValueError: If no valid layers are found in the model based on the provided `layers` list.
@@ -34,8 +44,38 @@ def __init__(self, model: nn.Module, layers: List[str]) -> None:
3444
self.features: Dict[str, torch.Tensor] = {}
3545
self.handles: List[torch.utils.hooks.RemovableHandle] = []
3646

37-
# Use list(dict.fromkeys(layers)) to preserve order while removing duplicate layer names
38-
self.module_names = self._insert_hooks(list(dict.fromkeys(layers)))
47+
if layers is None:
48+
layers = self._extract_all_layers(model, recursive)
49+
50+
self.module_names = self._insert_hooks(
51+
module=model,
52+
layers=layers,
53+
recursive=recursive,
54+
)
55+
56+
def _extract_all_layers(
57+
self, module: nn.Module, recursive: bool = True
58+
) -> List[str]:
59+
"""
60+
Extracts all layer names from the model recursively.
61+
62+
This method traverses the model and collects the names of all modules
63+
(layers) in a flat list. It is useful for debugging or when you want to
64+
register hooks on all layers without specifying them explicitly.
65+
66+
Args:
67+
module: The PyTorch model (`torch.nn.Module`) to extract layer names from.
68+
69+
Returns:
70+
A list of strings, where each string is the name of a module in the model.
71+
"""
72+
layers = set()
73+
for name, child in module.named_children():
74+
if recursive and len(list(child.named_children())) > 0:
75+
layers.update(self._extract_all_layers(child))
76+
else:
77+
layers.add(name)
78+
return list(layers)
3979

4080
def _hook(
4181
self,
@@ -62,9 +102,18 @@ def _hook(
62102
module,
63103
inp,
64104
) # Unused parameters, but kept for compatibility with the hook signature
65-
self.features[module_name] = out.detach()
105+
batch_size = out.size(0)
106+
feature = out.reshape(batch_size, -1)
107+
feature = gram(feature)
108+
self.features[module_name] = feature
66109

67-
def _insert_hooks(self, layers: List[str]) -> List[str]:
110+
def _insert_hooks(
111+
self,
112+
module: nn.Module,
113+
layers: List[str],
114+
recursive: bool = True,
115+
prev_name: str = "",
116+
) -> List[str]:
68117
"""
69118
Registers forward hooks on the specified layers of the model.
70119
@@ -87,24 +136,26 @@ def _insert_hooks(self, layers: List[str]) -> List[str]:
87136
This typically indicates an issue with the provided layer names.
88137
"""
89138
filtered_layers: List[str] = []
90-
for module_name, module in self.model.named_modules():
139+
for module_name, child in module.named_children():
140+
curr_name = f"{prev_name}.{module_name}" if prev_name else module_name
141+
curr_name = curr_name.replace("_model.", "")
142+
num_grandchildren = len(list(child.named_children()))
143+
144+
if recursive and num_grandchildren > 0:
145+
# If the module has children, recursively register hooks for them
146+
filtered_layers.extend(
147+
self._insert_hooks(
148+
module=child,
149+
layers=layers,
150+
recursive=recursive,
151+
prev_name=curr_name,
152+
)
153+
)
154+
91155
if module_name in layers:
92-
handle = module.register_forward_hook(partial(self._hook, module_name)) # type: ignore
156+
handle = child.register_forward_hook(partial(self._hook, curr_name)) # type: ignore
93157
self.handles.append(handle)
94-
filtered_layers.append(module_name)
95-
96-
if len(filtered_layers) != len(layers):
97-
hooked_set = set(filtered_layers)
98-
not_hooked = [layer for layer in layers if layer not in hooked_set]
99-
print(
100-
f"Warning: Could not find layers: {not_hooked}. They will be ignored."
101-
)
102-
103-
if not filtered_layers:
104-
raise ValueError(
105-
"No layers were found in the model. Please use `model.named_modules()` "
106-
"to check the available layer names."
107-
)
158+
filtered_layers.append(curr_name)
108159

109160
return filtered_layers
110161

cka_pytorch/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
3+
4+
def gram(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the Gram matrix ``G = X X^T`` of a 2-D feature tensor.

    Each entry ``G_ij`` is the inner product between the feature vectors of
    samples ``i`` and ``j``, so the matrix captures the pairwise
    relationships between the samples in a batch.

    Args:
        x: A tensor of shape (N, D), where N is the number of samples
            (batch size) and D is the feature dimension.

    Returns:
        The Gram matrix of shape (N, N).
    """
    return x @ x.t()

0 commit comments

Comments
 (0)