-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathutils.py
More file actions
25 lines (22 loc) · 667 Bytes
/
utils.py
File metadata and controls
25 lines (22 loc) · 667 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
def _get_compute_dtype():
    """Select the compute dtype for the current device.

    Returns:
        ``torch.bfloat16`` when the current CUDA device reports a major
        compute capability of 8 or higher, ``torch.float16`` for any other
        CUDA device, and ``torch.float32`` when CUDA is not available.
    """
    # get_device_capability() raises when there is no usable CUDA device,
    # so answer with full precision instead of crashing on CPU-only hosts.
    if not torch.cuda.is_available():
        return torch.float32

    # Only the major version drives the decision; the minor one is ignored.
    major, _minor = torch.cuda.get_device_capability()
    return torch.bfloat16 if major >= 8 else torch.float16
def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Walks ``model.named_parameters()`` once, counting the elements of every
    parameter and, separately, of those with ``requires_grad`` set, then
    prints both totals and the trainable share as a percentage.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad``.

    Returns:
        None. The summary is written to standard output.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A model without parameters would otherwise divide by zero; report 0.0.
    percentage = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
        f"trainable: {percentage}"
    )