# params.py
import torch
from torch.optim import AdamW

from attack import APGDAttack, AttackModel
from loss import l2_loss, ce_loss
from model import UniBindClassifier, ForwardMode
from transform import unnormalize_inplace, normalize_inplace

def find_lr(
    logger,
    device,
    raw_emb,
    raw_lbls,
    lbl_to_idx,
    idx_to_lbl,
    pretrain_weights,
    use_flash_attention,
    train_mean,
    train_std,
    train_loader,
    attack_loss_type,
    train_loss_type,
    epsilon,
    steps=100,
):
    """Run an exponential learning-rate range test for adversarial fine-tuning.

    The learning rate is swept geometrically from 1e-7 to 1.0 over at most
    `steps` batches; each batch is perturbed with APGD before the training
    loss is computed. Returns the tested learning rates together with the
    raw and EMA-smoothed losses so a suitable rate can be read off the curve.
    """
    logger.info("Finding learning rate ...")

    # Frozen reference model: provides clean embeddings as the L2 target.
    logger.info("Initializing original model ...")
    model_original = UniBindClassifier(
        device=device,
        pretrain_weights=pretrain_weights,
        modality="image",
        centre_embeddings=raw_emb,
        centre_labels=raw_lbls,
        label_to_index=lbl_to_idx,
        index_to_label=idx_to_lbl,
        logger=logger,
        use_flash_attention=use_flash_attention,
        modality_head_mlp_weights=None,
    )
    model_original.to(device)

    # Trainable model: only the LoRA adapter parameters require gradients.
    logger.info("Initializing training model ...")
    model_train = UniBindClassifier(
        device=device,
        pretrain_weights=pretrain_weights,
        modality="image",
        centre_embeddings=raw_emb,
        centre_labels=raw_lbls,
        label_to_index=lbl_to_idx,
        index_to_label=idx_to_lbl,
        logger=logger,
        use_flash_attention=use_flash_attention,
        use_lora=True,
        use_modality_head_mlp=False,
        fine_tuned_weights=None,
    )
    model_train.to(device)

    trainable_params = [p for p in model_train.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.95))

    # APGD operates on unnormalized inputs; AttackModel re-applies normalization.
    attack = APGDAttack(
        model=AttackModel(model_train, train_mean, train_std),
        norm="linf",
        n_restarts=1,
        n_iter=10,
        eps=epsilon,
        loss_type=attack_loss_type,
        device=device,
        logger=logger,
    )

    model_train.train()
    model_original.eval()

    logger.info("Running LR finder ...")
    # Geometric schedule from init_value to final_value over the available batches.
    num_batches = min(len(train_loader), steps)
    init_value = 1e-7
    final_value = 1.0
    lr_multiplier = (final_value / init_value) ** (1.0 / max(1, num_batches - 1))
    lr = init_value
    logger.info(
        f"Initial learning rate: {lr:.2e}, final learning rate: {final_value:.2e}, "
        f"multiplier: {lr_multiplier:.6f}"
    )
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Exponential moving average of the loss, with bias correction.
    beta = 0.98
    avg_loss = 0.0
    best_loss = float("inf")
    batch_num = 0
    losses = []
    smoothed_losses = []
    lrs = []
    for batch_idx, (inp, lbl) in enumerate(train_loader):
        if batch_idx >= steps:
            break
        logger.info(f"Batch {batch_idx + 1}/{num_batches} ...")
        logger.info(f"Learning rate: {lr:.2e}")
        batch_num += 1
        inp, lbl = inp.to(device), lbl.to(device)

        # Generate adversarial examples in unnormalized pixel space.
        model_train.eval()
        inp_unorm = inp.clone().detach()
        unnormalize_inplace(inp_unorm, train_mean, train_std)
        emb_orig = None
        if train_loss_type == "l2":
            # Clean embeddings from the frozen model serve as the L2 target.
            with torch.no_grad():
                emb_orig = model_original(inp, mode=ForwardMode.EMBEDDINGS)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            adv_inp = attack.perturb(inp_unorm, lbl, emb_orig)
        normalize_inplace(adv_inp, train_mean, train_std)

        # Training step on the adversarial batch.
        model_train.train()
        optimizer.zero_grad()
        if train_loss_type == "l2":
            emb_adv = model_train(adv_inp, mode=ForwardMode.EMBEDDINGS)
            loss = l2_loss(emb_orig, emb_adv)
            del emb_orig, emb_adv
        elif train_loss_type == "ce":
            logits_adv, _ = model_train(adv_inp, mode=ForwardMode.LOGITS)
            loss = ce_loss(logits_adv, lbl)
            del logits_adv
        else:
            raise ValueError(f"Unknown loss type: {train_loss_type}")

        loss_val = loss.item()
        avg_loss = beta * avg_loss + (1 - beta) * loss_val
        smoothed_loss = avg_loss / (1 - beta ** batch_num)  # bias-corrected EMA
        logger.info(
            f"Loss: {loss_val}, Smoothed loss: {smoothed_loss:.6f}, "
            f"best loss: {best_loss:.6f}"
        )
        if smoothed_loss < best_loss:
            best_loss = smoothed_loss
        losses.append(loss_val)
        smoothed_losses.append(smoothed_loss)
        lrs.append(lr)

        loss.backward()
        optimizer.step()

        # Advance the learning rate geometrically for the next batch.
        lr *= lr_multiplier
        logger.info(f"Updated learning rate: {lr:.2e}")
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Stop once the smoothed loss diverges well past its best value.
        if smoothed_loss > 4 * best_loss:
            logger.info("Stopping early due to loss explosion.")
            break

        del inp, lbl, inp_unorm, adv_inp, loss
        torch.cuda.empty_cache()

    logger.info("Finished running LR finder.")
    return lrs, losses, smoothed_losses
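

# --- Usage sketch (illustrative; not part of the original pipeline) ---
# A common way to pick a rate from the curves returned by find_lr is the
# point of steepest descent on the smoothed loss. `suggest_lr` below is a
# minimal sketch of that heuristic; the function name is hypothetical and
# the caller-side objects in the commented example (logger, loaders,
# normalization statistics, epsilon) are assumed to exist in the
# surrounding training setup.


def suggest_lr(lrs, smoothed_losses):
    """Return the learning rate at the steepest drop of the smoothed loss."""
    if len(lrs) < 2:
        raise ValueError("Need at least two points to estimate a slope.")
    # Slope between consecutive smoothed-loss points; the most negative
    # slope marks the fastest improvement, a standard LR-finder pick.
    slopes = [
        smoothed_losses[i + 1] - smoothed_losses[i]
        for i in range(len(smoothed_losses) - 1)
    ]
    best_idx = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[best_idx]


# Example call, assuming the surrounding objects exist:
#
#     lrs, losses, smoothed_losses = find_lr(
#         logger, device, raw_emb, raw_lbls, lbl_to_idx, idx_to_lbl,
#         pretrain_weights, use_flash_attention, train_mean, train_std,
#         train_loader, attack_loss_type="ce", train_loss_type="ce",
#         epsilon=8 / 255,
#     )
#     lr = suggest_lr(lrs, smoothed_losses)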