-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil_pth.py
More file actions
29 lines (19 loc) · 750 Bytes
/
util_pth.py
File metadata and controls
29 lines (19 loc) · 750 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from resnet import ResIRSE
from resnet18 import ResNet18
from collections import OrderedDict
def strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return *state_dict* with a data-parallel key prefix removed.

    Checkpoints saved from a data-parallel wrapper have every parameter key
    prefixed with ``"module."``. Only keys that actually start with *prefix*
    are rewritten, so this is safe to call on checkpoints saved without data
    parallelism.

    Args:
        state_dict: Mapping of parameter/buffer names to tensors, as returned
            by ``torch.load`` on a saved ``state_dict()``.
        prefix: Key prefix to strip. Defaults to ``"module."``.

    Returns:
        The original *state_dict* object when no key carries *prefix*.
        Otherwise a new ``OrderedDict`` with the prefix stripped and the key
        order preserved.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        # Nothing to strip: return the very same object so any attached
        # ``_metadata`` is passed to load_state_dict exactly as before.
        return state_dict

    stripped = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        # Module-version metadata is keyed by module path: "" is the wrapper,
        # "module" the wrapped model and "module.xxx" its submodules. Re-key
        # it so the wrapped model becomes the root, as in the stripped keys.
        bare = prefix.rstrip(".")
        new_metadata = OrderedDict()
        for key, value in metadata.items():
            if key == bare:
                new_metadata[""] = value  # wrapped model replaces wrapper root
            elif key.startswith(prefix):
                new_metadata[key[len(prefix):]] = value
            else:
                new_metadata.setdefault(key, value)
        stripped._metadata = new_metadata
    return stripped


# Teacher was trained with data parallel, so its checkpoint keys carry a
# ``module.`` prefix. Uncomment to convert it:
# state_dict = torch.load("models/teacher.pth", map_location='cpu')
# model = ResIRSE(embedding_size=512, drop_ratio=0.5)
# model.load_state_dict(strip_data_parallel_prefix(state_dict))
# torch.save(model.state_dict(), "TeacherWithoutDataParalle.pth")

# Student was trained without data parallel. Stripping is a no-op for
# unprefixed keys, but lets this conversion also accept a prefixed checkpoint.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = ResNet18()
state_dict = torch.load("models/s.pth", map_location='cpu')
model.load_state_dict(strip_data_parallel_prefix(state_dict))
torch.save(model.state_dict(), "StudentWithoutDataParalle.pth")