# train.py — web-page extraction chrome (GitHub UI text and line-number gutter) removed.
# a simple training script for a diffusion model
import os
import torch
import random
from datasets import load_dataset
from torch.utils.data import DataLoader
import torchvision
from dotenv import load_dotenv
from model.unet2d import UNet2DConfig, SimpleUNet2D
from train_scheme import (
OTFlowMatching,
NoisePrediction,
)
from utils.exp_control import ExperimentController
from argparse import ArgumentParser
# Data augmentation
def transform(example):
    """Resize images to the training resolution and randomly flip them.

    Called via ``dataset.map(transform, batched=True)``, so
    ``example["image"]`` holds a *batch* of images, not a single one.
    NOTE(review): reads the module-level ``args`` populated in the
    ``__main__`` block — this only works when the script is run directly;
    confirm before importing this module elsewhere.
    """
    resize = torchvision.transforms.Resize((args.train_size, args.train_size))
    images = resize(example["image"])
    # Flip each image independently. The previous version tossed a single
    # coin for the whole batch, which made the augmentation correlated
    # across every sample in the batch.
    hflip = torchvision.transforms.functional.hflip
    example["image"] = [
        hflip(img) if random.random() > 0.5 else img for img in images
    ]
    return example
def main(args):
    """Train a diffusion model according to the parsed command-line ``args``.

    Sets up experiment tracking, loads the HF dataset's train split, builds
    the UNet and optimizer, runs the training loop, periodically saves
    checkpoints and sample generations, then finalizes the experiment.

    Raises:
        ValueError: if ``args.mode`` is not ``"ot"`` or ``"noise"``.
    """
    # The experiment controller owns logging, metric tracking, checkpoints
    # and image dumps for this run.
    config = vars(args)  # argparse Namespace -> dict for logging
    exp_controller = ExperimentController(
        experiment_name=f"{args.dataset}_{args.model}_{args.mode}".replace("/", "_"),
        config=config,
    )

    device = torch.device(args.device)
    exp_controller.logger.info(f"Using device: {device}")

    # Load the train split; HUGGINGFACE_TOKEN comes from the environment
    # (populated by load_dotenv() in the __main__ block).
    exp_controller.logger.info(f"Loading dataset: {args.dataset}")
    dataset = load_dataset(
        args.dataset,
        token=os.getenv("HUGGINGFACE_TOKEN"),
        cache_dir="datasets",
        split="train",
    )
    dataset.set_format(type="torch", columns=["image"])
    dataset = dataset.map(transform, batched=True)
    train_dataset = dataset["image"]
    exp_controller.logger.info(f"Dataset loaded with {len(train_dataset)} images")

    # Build the UNet; most architecture hyperparameters are fixed here.
    unet_config = UNet2DConfig(
        in_channels=args.in_channels,
        base_channels=32,
        num_res_blocks=3,
        attention_layers=(1, 2, 3, 4),
        dropout=0.2,
        channel_multiplier=(1, 2, 3, 4),  # 4-block encoder/decoder structure
        time_emb_dim=128,
        num_heads=4,
    )
    model = SimpleUNet2D(config=unet_config).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.bs,
        shuffle=True,
        num_workers=4,
    )

    # Select the training scheme. Fail fast on an unknown mode — previously
    # an unrecognized mode fell through and crashed mid-loop with a
    # NameError on `train_scheme`.
    if args.mode == "ot":
        train_scheme = OTFlowMatching(device=device)
    elif args.mode == "noise":
        train_scheme = NoisePrediction(device=device)
    else:
        raise ValueError(
            f"Unknown training mode: {args.mode!r} (expected 'ot' or 'noise')"
        )

    global_step = 0
    for epoch in range(args.epoch):
        model.train()
        epoch_losses = []
        for idx, batch in enumerate(train_loader):
            # Pixel values arrive in [0, 255]; normalize to [0, 1].
            batch = batch.to(device) / 255.0
            noise_sample = torch.randn_like(batch)
            data_sample = batch

            # Forward pass: scheme-specific loss (flow-matching or noise
            # prediction).
            loss = train_scheme.loss(model, noise_sample, data_sample)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Per-step metric logging.
            loss_value = loss.item()
            epoch_losses.append(loss_value)
            exp_controller.log_metrics({'loss': loss_value}, global_step, phase='train')
            if idx % 100 == 0:
                exp_controller.logger.info(f"Epoch {epoch}, Step {idx}, Loss: {loss_value}")
            global_step += 1

        # Epoch-level average loss; guard against an empty loader so we
        # don't divide by zero.
        if epoch_losses:
            epoch_avg_loss = sum(epoch_losses) / len(epoch_losses)
            exp_controller.log_metrics({'epoch_avg_loss': epoch_avg_loss}, epoch, phase='train')

        # Checkpoint (and sample) every `save_every` epochs and on the
        # final epoch.
        if epoch % args.save_every == 0 or epoch == args.epoch - 1:
            exp_controller.save_model(model, optimizer, step=epoch)
            # Sample a few generations from pure noise.
            model.eval()
            with torch.no_grad():
                for i in range(5):
                    # Use the configured channel count rather than the last
                    # batch's shape, which is undefined when the loader
                    # yields no batches.
                    noise = torch.randn(
                        1, args.in_channels, args.train_size, args.train_size
                    ).to(device)
                    generated_image = train_scheme.generate(
                        model, steps=args.sample_steps, noise=noise
                    )
                    image_filename = f"generated_epoch{epoch}_sample{i}.png"
                    exp_controller.save_image(
                        generated_image,
                        image_filename,
                        normalize=True,
                        tag=f"generation/epoch_{epoch}",
                    )

    # Finalize experiment: dump metric plots and close trackers.
    exp_controller.plot_metrics()
    exp_controller.finish()
    exp_controller.logger.info("Training completed!")
if __name__ == "__main__":
    # Populate environment variables (e.g. HUGGINGFACE_TOKEN) from a .env file.
    load_dotenv()
    parser = ArgumentParser(description="Train a diffusion model")
    parser.add_argument("--dataset", type=str, default="korexyz/celeba-hq-256x256", help="Dataset to use in HF format")
    parser.add_argument("--model", type=str, default="unet2d", help="Model architecture")
    parser.add_argument("--device", type=str, default="cuda:7", help="Device to use for training")
    parser.add_argument("--bs", type=int, default=24, help="Batch size")
    parser.add_argument("--epoch", type=int, default=10, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--train_size", type=int, default=48, help="Size of training images")
    parser.add_argument("--save_every", type=int, default=2, help="Save checkpoints every N epochs")
    parser.add_argument("--sample_steps", type=int, default=25, help="Number of sampling steps")
    # Default must be one of the valid choices: argparse does NOT validate
    # defaults against `choices`, so the previous default "train" slipped
    # through parsing and crashed inside main().
    parser.add_argument("--mode", type=str, default="ot", choices=["ot", "noise"],
                        help="Training mode: 'ot' for OTFlowMatching, 'noise' for NoisePrediction")
    parser.add_argument("--in_channels", type=int, default=3, help="Number of input channels")
    args = parser.parse_args()
    main(args)