-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels_finetune.py
More file actions
executable file
·129 lines (105 loc) · 5.16 KB
/
models_finetune.py
File metadata and controls
executable file
·129 lines (105 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from torch import nn
from timm.models.layers import trunc_normal_
from models_pretrain import *
from models.models_vit import vit_customized
from models.unetformer import Segmentation_ViT_UNetFormer
from utils.pos_embed import interpolate_pos_embed
def create_finetune_model_vit(train_model, which_finetuning, drop_path, global_pool, config, pretrained_path, finetuning_type, device, args):
    """Build a ViT model for fine-tuning on a classification or segmentation task.

    Parameters
    ----------
    train_model : str
        ViT variant name: "vit-t-16", "vit-s-16", "vit-b-16" or "vit-l-16".
    which_finetuning : str
        "scratch_training" skips weight loading; any other value loads the
        checkpoint at ``pretrained_path``. Linear probing only freezes the
        encoder for "imagenet_pretrained", "mae_imagenet_pretrained" or
        "checkpoint".
    drop_path : float
        Stochastic-depth drop-path rate for the encoder.
    global_pool : bool
        Passed to the ViT; also decides which keys are expected to be missing
        when loading pre-trained weights.
    config : dict
        Must provide "input_size", "num_classes" and "task_type".
    pretrained_path : str
        Path to the checkpoint (ignored for scratch training).
    finetuning_type : str
        "lp" freezes encoder weights (linear probing); anything else leaves
        all parameters trainable.
    device : torch.device or str
        Device the finished model is moved to.
    args
        Unused here; kept for interface compatibility with callers.

    Returns
    -------
    nn.Module
        The (possibly decoder-wrapped) model, already moved to ``device``.

    Raises
    ------
    ValueError
        If ``train_model`` is not a known variant.
    """
    ### Per-variant architecture hyper-parameters: (embed_dim, depth, num_heads).
    ### A single table replaces two previously duplicated if/elif chains whose
    ### error messages had drifted out of sync.
    vit_specs = {
        "vit-t-16": (192, 12, 3),
        "vit-s-16": (384, 12, 6),
        "vit-b-16": (768, 12, 12),
        "vit-l-16": (1024, 24, 16),
    }
    if train_model not in vit_specs:
        raise ValueError(f"Unknown model type: {train_model}. Available types: vit-t-16, vit-s-16, vit-b-16, vit-l-16")
    encoder_output_dim, depth, num_heads = vit_specs[train_model]
    ### Define model
    model = vit_customized(
        img_size=config["input_size"][0], patch_size=16, in_chans=3,
        embed_dim=encoder_output_dim, depth=depth, num_heads=num_heads,
        drop_path_rate=drop_path, global_pool=global_pool, num_classes=config["num_classes"]
    )
    if which_finetuning == "scratch_training":
        ### Do nothing to train model from scratch
        pass
    else:
        ### Load pre-trained weights if provided
        checkpoint = torch.load(pretrained_path, map_location='cpu', weights_only=False)
        print("\nLoad pre-trained checkpoint from: %s" % pretrained_path)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        ### Drop classifier weights whose shape disagrees with the current
        ### num_classes; they cannot be loaded.
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]
        ### Interpolate position embedding (handles input-size mismatch)
        interpolate_pos_embed(model, checkpoint_model)
        ### Load pre-trained model; only head (and fc_norm if global_pool)
        ### are expected to be missing from the checkpoint
        msg = model.load_state_dict(checkpoint_model, strict=False)
        if global_pool:
            assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
        else:
            assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
        ### Manually initialize fc layer
        trunc_normal_(model.head.weight, std=2e-5)
    ### Which combinations actually freeze the encoder for linear probing
    lp_sources = ("imagenet_pretrained", "mae_imagenet_pretrained", "checkpoint")
    if "classification" in config["task_type"]:
        if finetuning_type == "lp" and which_finetuning in lp_sources:
            ### Linear probing: freeze everything, then re-enable the head
            ### (and fc_norm, when present) so only they train
            for param in model.parameters():
                param.requires_grad = False
            if hasattr(model, 'head'):
                for param in model.head.parameters():
                    param.requires_grad = True
            if hasattr(model, 'fc_norm'):
                for param in model.fc_norm.parameters():
                    param.requires_grad = True
        model = model.to(device)
        return model
    if "segmentation" in config["task_type"]:
        ### The encoder feeds a UNetFormer decoder; its classification head
        ### is not needed
        model.head = nn.Identity()
        if finetuning_type == "lp" and which_finetuning in lp_sources:
            ### Freeze the encoder only; the decoder created below is
            ### trainable by default (requires_grad=True on fresh modules)
            for param in model.parameters():
                param.requires_grad = False
        model = Segmentation_ViT_UNetFormer(encoder=model, encoder_output_dim=encoder_output_dim, num_classes=config["num_classes"], decoder_channels=64, window_size=8, dropout=0.1)
        ### NOTE: decoder parameters (projection_layers, spatial_projection,
        ### stage1-4, final_conv) are newly constructed and therefore already
        ### trainable; no explicit re-enabling is needed.
    model = model.to(device)
    return model