-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
61 lines (49 loc) · 1.55 KB
/
train.py
File metadata and controls
61 lines (49 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import torch
import torch.optim as optim
from .utils import show_samples
def train_diffusion(
    model,
    dataloader,
    num_epochs=50,
    lr=1e-4,
    log_interval=800,
    sample_interval=1,
    checkpoint_dir="./checkpoints",
    device="cuda",
    sample_shape=(16, 1, 32, 32),
):
    """Train a diffusion model with Adam, checkpointing after every epoch.

    Args:
        model: Module whose forward returns ``(loss, _, _, _)`` and that
            exposes ``sample(shape, device=...)`` for generation.
        dataloader: Sized iterable of ``(images, labels)`` batches; the
            labels are ignored.
        num_epochs: Number of full passes over ``dataloader``.
        lr: Adam learning rate.
        log_interval: Print the loss every this many batches (per epoch).
        sample_interval: Generate and show samples every this many epochs.
        checkpoint_dir: Directory for per-epoch checkpoints; created if absent.
        device: Torch device string to move the model and batches to.
        sample_shape: Shape forwarded to ``model.sample`` (previously the
            hard-coded ``(16, 1, 32, 32)``; default preserves old behavior).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    # Ensure training mode up front; the sampling branch below assumes it
    # (it toggles eval() and back to train()), but a caller could hand us a
    # model left in eval mode.
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            loss, _, _, _ = model(images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE: the old per-epoch `step` counter was reset every epoch
            # and therefore always equal to batch_idx; use batch_idx directly.
            if batch_idx % log_interval == 0:
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}] Step [{batch_idx}/{len(dataloader)}] "
                    f"Loss: {loss.item():.4f}"
                )
        # Persist model + optimizer state so training can be resumed.
        ckpt_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pt")
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            },
            ckpt_path,
        )
        print(f"Saved checkpoint: {ckpt_path}")
        # Generate & show samples
        if (epoch + 1) % sample_interval == 0:
            model.eval()
            with torch.no_grad():
                samples = model.sample(sample_shape, device=device)
            show_samples(samples, epoch + 1)
            model.train()