-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_profiling.py
More file actions
executable file
·383 lines (329 loc) · 14.5 KB
/
train_profiling.py
File metadata and controls
executable file
·383 lines (329 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
A DiT training script instrumented with torch.profiler.

Trains a DiT model on pre-extracted latent features (or on all-zero synthetic
data with ``--dummydata``) through Hugging Face Accelerate, logs throughput
measured after a short warm-up, and dumps a per-operator profiling table plus
one Chrome trace per process.
"""
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os
from accelerate import Accelerator
from tqdm import tqdm
from models import DiT_models
from diffusion import create_diffusion
from accelerate.utils import set_seed
#import wandb
# NOTE(review): presumably disables Dynamo's DDP graph-splitting optimizer for
# compiled models; main() sets this again when --compile is passed.
torch._dynamo.config.optimize_ddp=False
from torch.profiler import profile, record_function, ProfilerActivity,schedule
#################################################################################
# Training Helper Functions #
#################################################################################
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    For every parameter of ``model`` this performs, in place,
    ``ema = decay * ema + (1 - decay) * param``. With ``decay=0`` the model
    weights are simply copied into the EMA model.

    Args:
        ema_model: module holding the EMA weights (not wrapped).
        model: module being trained. It may be wrapped (e.g. by
            ``accelerator.prepare``), in which case the ``"module."`` part of
            its parameter names is removed before matching against ``ema_model``.
        decay: EMA decay factor.

    Raises:
        KeyError: if a parameter of ``model`` has no counterpart in ``ema_model``.
    """
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Drop the wrapper prefix so names line up with the unwrapped EMA model.
        name = name.replace("module.", "")
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_param = ema_params.get(name)
        if ema_param is None:
            sample = list(ema_params)[:5]
            raise KeyError(
                f"parameter {name!r} not found in EMA model "
                f"({len(ema_params)} EMA params, e.g. {sample})"
            )
        ema_param.mul_(decay).add_(param.data, alpha=1 - decay)
def trace_handler(p):
    """
    ``on_trace_ready`` callback for the profiler schedule used in ``main``.

    Intentionally a no-op. The aggregated operator table is printed after
    training finishes rather than after every active profiling cycle. To
    inspect each cycle, print e.g.
    ``p.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=20)``.

    Args:
        p: the ``torch.profiler.profile`` instance that finished a cycle (unused).
    """
    return None
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: any ``torch.nn.Module``.
        flag: value assigned to each parameter's ``requires_grad`` attribute.
    """
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = flag
def create_logger(logging_dir):
    """
    Configure root logging to stdout and ``<logging_dir>/log.txt``; return a logger.

    ``logging.basicConfig`` does nothing if the root logger already has
    handlers. In that case the handlers built here are not attached, although
    the log file is still created when its handler is constructed.

    Args:
        logging_dir: existing directory in which ``log.txt`` is created.

    Returns:
        The logger named after this module.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
    logging.basicConfig(
        handlers=[console_handler, file_handler],
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
    )
    return logging.getLogger(__name__)
def center_crop_arr(pil_image, image_size):
    """
    Resize so the shorter side equals ``image_size``, then center-crop a square.

    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    While the shorter side is at least twice the target, the image is halved
    with a BOX filter. It is then resized with BICUBIC so the shorter side is
    ``image_size``, and the central ``image_size`` x ``image_size`` window is
    returned as a new PIL image.
    """
    image = pil_image
    while True:
        if min(image.size) < 2 * image_size:
            break
        halved = tuple(side // 2 for side in image.size)
        image = image.resize(halved, resample=Image.BOX)

    ratio = image_size / min(image.size)
    target = tuple(round(side * ratio) for side in image.size)
    image = image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
class CustomDataset(Dataset):
    """
    Dataset of pre-extracted features stored as one ``.npy`` file per sample.

    The integer label is parsed from each filename: it is the text between the
    last occurrence of ``'label'`` and ``'.npy'`` (e.g. ``xxx_label123.npy``
    -> 123).

    Args:
        features_dir: directory containing the ``.npy`` feature files.
        labels_dir: stored on the instance only; labels come from the feature
            filenames, so this directory is never read.
    """

    def __init__(self, features_dir, labels_dir):
        self.features_dir = features_dir
        self.labels_dir = labels_dir
        self.features_files = sorted(os.listdir(features_dir))
        # Maximum number of load attempts per __getitem__ call.
        self.retries = 20

    def __len__(self):
        return len(self.features_files)

    @staticmethod
    def _parse_label(feature_file):
        """Return the integer encoded between the last 'label' and '.npy' in the name."""
        start = feature_file.rfind('label') + len('label')
        end = feature_file.rfind('.npy')
        return int(feature_file[start:end])

    def __getitem__(self, idx):
        """
        Load the sample at ``idx``.

        If it cannot be loaded or its label cannot be parsed, a replacement is
        drawn uniformly from the whole dataset. Up to ``self.retries`` attempts
        are made in total.

        Returns:
            ``(torch.from_numpy(features), torch.tensor([label]))``.

        Raises:
            RuntimeError: if every attempt fails.
        """
        last_error = None
        for _ in range(self.retries):
            feature_file = self.features_files[idx]
            try:
                features = np.load(os.path.join(self.features_dir, feature_file))
                labels = self._parse_label(feature_file)
            except Exception as e:
                # Best-effort: skip unreadable/corrupt files by resampling.
                print('error when loading file: %s' % feature_file)
                print(e)
                print('retrying...')
                last_error = e
                # randint's upper bound is exclusive; len() keeps the last file eligible.
                idx = np.random.randint(0, len(self.features_files))
                continue
            return torch.from_numpy(features), torch.tensor([labels])
        raise RuntimeError(
            f"failed to load a sample after {self.retries} attempts"
        ) from last_error
class DummyDataset(Dataset):
    """
    Synthetic stand-in for ``CustomDataset`` with 100000 all-zero samples.

    Each item is ``(features, label)``:
    - ``features`` is a zero tensor of shape ``(4, latent_dim, latent_dim)`` in the default dtype.
    - ``label`` is a 1-element int64 zero tensor.

    This lets the training loop run without reading data from disk.
    """

    def __init__(self, latent_dim=32):
        self.latent_dim = latent_dim

    def __len__(self):
        return 100000

    def __getitem__(self, idx):
        side = self.latent_dim
        features = torch.zeros(4, side, side)
        label = torch.zeros(1, dtype=torch.long)
        return features, label
#################################################################################
# Training Loop #
#################################################################################
def main(args):
    """
    Train a DiT model under ``torch.profiler`` and report throughput.

    Runs the DiT diffusion training loop (AdamW, EMA, periodic checkpoints)
    for at most ``args.max_train_steps`` steps. Steps after the first
    ``WARMUP_ITERS`` are timed with CUDA events. At the end, the profiler's
    per-operator table is printed and saved, and a Chrome trace is exported
    for each process.

    Args:
        args: the ``argparse.Namespace`` produced by the CLI parser of this script.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    if args.gemm_tuning:
        print("using gemm tuning")
        # Use TunableOp and a GEMM tuning CSV file.
        os.environ["PYTORCH_TUNABLEOP_VERBOSE"] = "1"
        os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
        tuning_file = "gemm_tuning_results/profiling_dit_{}.csv".format(args.image_size)
        # Make sure the folder of the tuning results file exists.
        os.makedirs(os.path.dirname(tuning_file), exist_ok=True)
        os.environ["PYTORCH_TUNABLEOP_FILENAME"] = tuning_file

    # Setup accelerator:
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )
    device = accelerator.device

    # Setup an experiment folder. `logger` and `checkpoint_dir` exist only on
    # the main process, so every use below is guarded by is_main_process.
    if accelerator.is_main_process:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")

    # Seed every process, not just the main one.
    if args.seed is not None:
        set_seed(args.seed)

    # Create model:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_fa=args.use_fa
    )
    # Note that parameter initialization is done within the DiT constructor
    model = model.to(device)
    if args.compile:
        torch._dynamo.config.optimize_ddp = False
        model = torch.compile(model)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
    if accelerator.is_main_process:
        logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

    # Setup data:
    features_dir = f"{args.feature_path}/imagenet256_features"
    labels_dir = f"{args.feature_path}/imagenet256_labels"
    if args.dummydata:
        dataset = DummyDataset(latent_size)
    else:
        dataset = CustomDataset(features_dir, labels_dir)
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // accelerator.num_processes),
        shuffle=False,  # NOTE(review): no shuffling; presumably deliberate for reproducible profiling
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} images ({args.feature_path})")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode
    model, opt, loader = accelerator.prepare(model, opt, loader)

    # Variables for monitoring/logging purposes:
    train_steps = 0
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=train_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # CUDA events bracketing the timed region (all steps after the warm-up).
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    WARMUP_ITERS = 10

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    my_schedule = schedule(
        skip_first=10,
        wait=5,
        warmup=1,
        active=3,
        repeat=2)

    if accelerator.is_main_process:
        logger.info(f"Training for {args.epochs} epochs...")
    with profile(activities=activities, record_shapes=True, schedule=my_schedule,
                 on_trace_ready=trace_handler, with_flops=True) as prof:
        for epoch in range(args.epochs):
            if accelerator.is_main_process:
                logger.info(f"Beginning epoch {epoch}...")
            for x, y in loader:
                x = x.to(device).squeeze(dim=1)
                y = y.to(device).squeeze(dim=1)
                t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
                model_kwargs = dict(y=y)
                loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
                loss = loss_dict["loss"].mean()
                opt.zero_grad()
                accelerator.backward(loss)
                opt.step()
                update_ema(ema, model)

                train_steps += 1
                progress_bar.update(1)
                # gather() is a collective op, so it must run on every process.
                logs = {
                    "loss": accelerator.gather(loss).mean().detach().item(),
                }
                accelerator.log(logs, step=train_steps)
                progress_bar.set_postfix(**logs)

                # Save DiT checkpoint:
                if train_steps % args.ckpt_every == 0 and train_steps > 0:
                    if accelerator.is_main_process:
                        checkpoint = {
                            # unwrap_model also works when the model was not
                            # wrapped (single process), unlike `model.module`.
                            "model": accelerator.unwrap_model(model).state_dict(),
                            "ema": ema.state_dict(),
                            "opt": opt.state_dict(),
                            "args": args
                        }
                        checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")

                if train_steps == WARMUP_ITERS:
                    t0.record()
                # Mark the end of a full iteration for the profiler schedule.
                prof.step()
                if train_steps >= args.max_train_steps:
                    break
            if train_steps >= args.max_train_steps:
                break

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
    progress_bar.close()

    t1.record()
    torch.cuda.synchronize()
    if accelerator.is_main_process:
        logger.info("Done!")
        timed_steps = train_steps - WARMUP_ITERS
        if timed_steps > 0:
            dt = t0.elapsed_time(t1) / 1000  # elapsed_time() returns milliseconds
            logger.info(f"{timed_steps * args.global_batch_size / dt:0.2f} samples/s ({dt:0.4g}s)")
        else:
            # t0 was never recorded (or nothing was timed); avoid a bogus/crashing report.
            logger.info(f"Ran {train_steps} steps (<= {WARMUP_ITERS} warm-up steps); no throughput measured.")

    sort_by_keyword = "cuda_time_total"
    output = prof.key_averages(group_by_input_shape=True).table(
        sort_by=sort_by_keyword,
        row_limit=100,
        max_src_column_width=100,
        max_shapes_column_width=100,
        max_name_column_width=100,
    )
    model_name = args.model.split("/")[0]  # e.g. DiT-XL/2 --> DiT-XL
    # One trace per process; process_index is unique across nodes, local_process_index is not.
    prof.export_chrome_trace("trace_nv_fa3_{}_rank_{}.json".format(model_name, accelerator.process_index))
    if accelerator.is_main_process:
        print(output)
        # Write the table as plain text; torch.save would pickle the string.
        with open(f"profiling_nv_fa3_{model_name}.txt", "w", encoding="utf-8") as f:
            f.write(output)
if __name__ == "__main__":
    # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
    parser = argparse.ArgumentParser()
    parser.add_argument("--feature-path", type=str, default="features")
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--global-batch-size", type=int, default=256)
    parser.add_argument("--global-seed", type=int, default=0)  # NOTE: not read by main(); --seed is used instead
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--log-every", type=int, default=100)  # NOTE: not read by main()
    parser.add_argument("--ckpt-every", type=int, default=50_000)
    parser.add_argument("--max-train-steps", type=int, default=400_000)
    parser.add_argument(
        "--dummydata",
        action='store_true',
        help="whether to use dummy data",
    )
    parser.add_argument("--exp-name", type=str, default="init_exp")  # NOTE: not read by main()
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    parser.add_argument("--seed", type=int, default=0)
    # NOTE: not read by main(); the latent size is derived from --image-size.
    parser.add_argument("--latent_dim", type=int, default=32)
    parser.add_argument(
        "--use_fa",
        action='store_true',
        help="whether to use flash attention",
    )
    parser.add_argument(
        "--compile",
        action='store_true',
        help="whether to use torch.compile",
    )
    parser.add_argument(
        "--gemm-tuning",
        action='store_true',
        help="whether to enable PyTorch TunableOp GEMM tuning",
    )
    args = parser.parse_args()
    main(args)