-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathOwn1CycleLR.py
More file actions
79 lines (65 loc) · 2.63 KB
/
Own1CycleLR.py
File metadata and controls
79 lines (65 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import math
from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR
import matplotlib.pyplot as plt
class Own1CycleLR():
    """One-cycle-style learning-rate schedule assembled from torch's
    ``LinearLR`` / ``ConstantLR`` / ``SequentialLR`` building blocks.

    The schedule has three phases over ``total_steps`` optimizer steps:

    1. linear warmup:   ``start_lr`` -> ``max_lr``  (first ``warumup`` fraction)
    2. constant plateau at ``max_lr``
    3. linear cooldown: ``max_lr`` -> ``end_lr``    (last ``cooldown`` fraction)
    """

    def __init__(
            self,
            start_lr: float,
            max_lr: float,
            end_lr: float,
            warumup: float,  # NOTE: typo of "warmup", kept for backward compatibility with existing callers
            cooldown: float,
            total_steps: int) -> None:
        """Configure the schedule.

        Args:
            start_lr: learning rate at step 0 (ignored when ``warumup`` is 0).
            max_lr: peak learning rate; must match the optimizer's base ``lr``,
                because the sub-schedulers scale the optimizer's lr by factors
                relative to it.
            end_lr: learning rate at the end of the cooldown phase.
            warumup: fraction of ``total_steps`` spent warming up, in [0, 1].
            cooldown: fraction of ``total_steps`` spent cooling down, in [0, 1].
            total_steps: total number of scheduler steps.

        Raises:
            ValueError: if a fraction is negative or the fractions sum to > 1.
        """
        # Raise instead of assert: asserts are stripped under `python -O`,
        # so they must never be used for input validation.
        if warumup < 0.0 or cooldown < 0.0:
            raise ValueError("warumup and cooldown must be non-negative")
        if warumup + cooldown > 1.0:
            raise ValueError("warumup + cooldown must be smaller or equal to 1.0")
        if warumup == 0.0:
            # No warmup phase: start directly at the peak rate.
            start_lr = max_lr
        if warumup == 0.0 and cooldown == 0.0:
            # Degenerate case: turn the schedule into a constant lr by giving
            # the warmup/cooldown phases dummy lengths with identical rates.
            warumup = 0.2
            cooldown = 0.2
            start_lr = end_lr = max_lr
        self.start_lr = start_lr
        self.max_lr = max_lr
        self.end_lr = end_lr
        self.warumup_steps = math.floor(total_steps * warumup)
        self.cooldown_steps = math.floor(total_steps * cooldown)
        self.total_steps = total_steps

    def getScheduler(self, optimizer):
        """Build and return the ``SequentialLR`` for ``optimizer``.

        Only constructed on request, because ``LinearLR`` mutates the
        optimizer's lr as soon as it is initialized.
        """
        warumup_scheduler = LinearLR(
            optimizer,
            start_factor=self.start_lr / self.max_lr,
            end_factor=1.0,
            total_iters=self.warumup_steps)
        constant_scheduler = ConstantLR(
            optimizer,
            factor=1.0,
            total_iters=self.total_steps - self.warumup_steps - self.cooldown_steps
        )
        decay_scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=self.end_lr / self.max_lr,
            total_iters=self.cooldown_steps)
        step_scheduler = SequentialLR(
            optimizer,
            schedulers=[warumup_scheduler, constant_scheduler, decay_scheduler],
            milestones=[self.warumup_steps, self.total_steps - self.cooldown_steps - 1]  # -1 to also reach the last cooldown learning rate
        )
        return step_scheduler

    def plot(self):
        """Visualize the lr trajectory over ``total_steps`` using a dummy
        model/optimizer, so plotting never touches a real training run."""
        dummy_model = torch.nn.Sequential(
            torch.nn.Linear(3, 1),
            torch.nn.Flatten(0, 1))
        dummy_optimizer = torch.optim.AdamW(dummy_model.parameters(), lr=self.max_lr)
        step_scheduler = self.getScheduler(dummy_optimizer)
        lr = []
        for i in range(self.total_steps):
            # Record the lr in effect for step i, then advance the schedule.
            lr.append(dummy_optimizer.param_groups[0]['lr'])
            step_scheduler.step()
        # show the plot with dots instead of lines and show x y for each dot
        plt.plot(lr, marker='o', linestyle='--', linewidth=0.5, markersize=2)
        plt.grid(True)
        plt.show()