-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcount_key.py
More file actions
31 lines (26 loc) · 1017 Bytes
/
count_key.py
File metadata and controls
31 lines (26 loc) · 1017 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from collections import defaultdict
def count_keys_by_full_prefix(weight_path):
    """Count checkpoint keys grouped by their full module prefix.

    Loads a PyTorch checkpoint, unwraps the common ``state_dict`` /
    ``model`` container keys, strips the trailing parameter name
    (e.g. ``weight`` / ``bias``) from each key, and prints one
    ``"prefix: N keys"`` line per prefix in sorted order.

    Args:
        weight_path: Path to a checkpoint file readable by ``torch.load``.

    Returns:
        dict mapping each prefix string to its key count
        (also printed to stdout).
    """
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects —
    # only load checkpoints from trusted sources, or pass
    # weights_only=True (torch >= 1.13) when the file holds tensors only.
    state_dict = torch.load(weight_path, map_location='cpu')

    # Unwrap common training-framework checkpoint wrappers.
    if isinstance(state_dict, dict):
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        elif 'model' in state_dict:
            state_dict = state_dict['model']

    # Count keys per prefix. rsplit handles both dotted keys
    # ('a.b.weight' -> 'a.b') and bare keys ('bias' -> 'bias')
    # without an explicit length check.
    prefix_counts = defaultdict(int)
    for key in state_dict:
        prefix_counts[key.rsplit('.', 1)[0]] += 1

    # Sort and print a stable, human-readable summary.
    for prefix in sorted(prefix_counts):
        print(f"{prefix}: {prefix_counts[prefix]} keys")

    # Returning the counts keeps the original print behavior while
    # making the function usable programmatically (was: returned None).
    return dict(prefix_counts)
# Example usage: guard the call so importing this module does not
# immediately try to load a checkpoint from disk as a side effect.
if __name__ == "__main__":
    count_keys_by_full_prefix("./ckpts/pretrained_weights.pt")