-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgenerator.py
More file actions
23 lines (19 loc) · 736 Bytes
/
generator.py
File metadata and controls
23 lines (19 loc) · 736 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def masks(module):
    r"""Iterate over the mask buffers of ``module``, yielding only the tensors.

    ``module.named_buffers()`` is called with its default arguments, so
    buffers registered on submodules are visited as well. Every buffer whose
    name contains the substring ``"weight_mask"`` is yielded; the names
    themselves are discarded. Nothing is read from the module until the
    returned generator is first advanced.
    """
    yield from (
        tensor
        for buffer_name, tensor in module.named_buffers()
        if "weight_mask" in buffer_name
    )
def parameters(model):
    r"""Iterate over the parameters of ``model``, yielding just the tensors.

    Every module returned by ``model.modules()`` is visited, and the
    parameters it owns directly (``recurse=False``) are yielded. No
    ``requires_grad`` filtering is applied, so frozen parameters are included.
    A parameter object owned by several distinct modules is yielded once per
    owning module.
    """
    for submodule in model.modules():
        yield from submodule.parameters(recurse=False)
def masked_parameters(model):
    r"""Iterate over a model's prunable parameters, yielding ``(mask, param)``.

    Every module returned by ``model.modules()`` is inspected on its own.
    Only the buffers and parameters that module owns directly
    (``recurse=False``) are considered, so a mask is never paired with a
    parameter that belongs to a different module.

    Each directly owned buffer whose name contains ``"weight_mask"`` is
    matched to a parameter by name rather than by position. The lookup tries
    the buffer name with ``"weight_mask"`` replaced by ``"weight"`` first, then
    that name with an ``"_orig"`` suffix. The ``"_orig"`` form is the name
    ``torch.nn.utils.prune`` gives the original tensor after reparametrizing
    ``weight``. A mask for which neither parameter exists is skipped.

    Args:
        model: module whose ``modules()`` tree is walked.

    Yields:
        ``(mask, param)`` tuples, where ``mask`` is the buffer tensor and
        ``param`` is the parameter of the same module it applies to.
    """
    for module in model.modules():
        # Only this module's own parameters; children are visited separately
        # by model.modules(), which keeps masks and params from the same owner.
        own_params = dict(module.named_parameters(recurse=False))
        for buf_name, mask in module.named_buffers(recurse=False):
            if "weight_mask" not in buf_name:
                continue
            # e.g. "weight_mask" -> "weight"
            param_name = buf_name.replace("weight_mask", "weight", 1)
            param = own_params.get(param_name)
            if param is None:
                # torch.nn.utils.prune removes "weight" from the parameters and
                # re-registers it as "weight_orig" (appended after e.g. "bias"),
                # which is why positional zip-pairing is not safe here.
                param = own_params.get(param_name + "_orig")
            if param is not None:
                yield mask, param