-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain.py
More file actions
630 lines (537 loc) · 28 KB
/
train.py
File metadata and controls
630 lines (537 loc) · 28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
import logging
import math
import os
import shutil
from pathlib import Path
import random
import accelerate
import torch
import torch.utils.checkpoint
from torch.utils.data import RandomSampler
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from einops import rearrange
import diffusers
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from svd_poseguider.models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
from svd_poseguider.models.controlnet_svd import ControlNetSVDModel
from svd_poseguider.pipelines.pipeline_stable_video_diffusion_controlnet import _get_add_time_ids, encode_image
from svd_poseguider.pipelines.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionPipelineControlNet
from svd_poseguider.argparse import parse_args
from svd_poseguider.dataset import make_train_dataset
from svd_poseguider.ops import rand_log_normal
from svd_poseguider.utils import load_images_from_folder, save_combined_frames, tensor_to_vae_latent
# Module-level logger obtained from accelerate's logging helper, created at INFO level.
# Calls below pass `main_process_only=False` when every process should emit the record.
logger = get_logger(__name__, log_level="INFO")
def main():
    """Train a ControlNet on top of a frozen Stable Video Diffusion UNet.

    The function parses the command-line arguments, loads the scheduler,
    CLIP feature extractor / image encoder, temporal VAE and UNet, then either
    loads an existing ``ControlNetSVDModel`` or initializes one from the UNet.
    Only the ControlNet parameters are handed to the optimizer; the VAE, image
    encoder and UNet are frozen.

    During training it periodically (on the main process) saves accelerator
    checkpoints under ``args.output_dir/checkpoint-<step>`` and runs the
    validation pipeline, writing results under
    ``args.output_dir/validation_images``. When training finishes, the full
    pipeline is saved to ``args.output_dir`` and optionally uploaded to the Hub.

    Takes no arguments and returns ``None``; all configuration comes from
    ``parse_args()``.

    Raises:
        ImportError: if ``--report_to wandb`` is requested without wandb, or
            8-bit Adam is requested without bitsandbytes.
        ValueError: if xformers attention is requested but not available.
    """
    args = parse_args()

    if args.non_ema_revision is not None:
        deprecate(
            "non_ema_revision!=None",
            "0.15.0",
            message=(
                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                " use `--variant=non_ema` instead."
            ),
        )

    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir)
    # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        # log_with=args.report_to,
        project_config=accelerator_project_config,
        # kwargs_handlers=[ddp_kwargs]
    )

    # NOTE(review): this generator is created but never passed anywhere; the
    # `generator=` argument of the validation pipeline call is commented out.
    generator = torch.Generator(
        device=accelerator.device).manual_seed(23123134)

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError(
                "Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load scheduler, feature extractor and models.
    # NOTE(review): `noise_scheduler` is loaded but not referenced again in this function.
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision, variant="fp16"
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16")
    unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
        args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet,
        subfolder="unet",
        low_cpu_mem_usage=True,
        variant="fp16",
    )

    if args.controlnet_model_name_or_path:
        logger.info("Loading existing controlnet weights")
        controlnet = ControlNetSVDModel.from_pretrained(args.controlnet_model_name_or_path)
    else:
        logger.info("Initializing controlnet weights from unet")
        controlnet = ControlNetSVDModel.from_unet(unet)

    # Freeze everything first; the controlnet is re-enabled just before the
    # optimizer is built.
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    controlnet.requires_grad_(False)

    # The frozen image_encoder, vae and unet are cast to the mixed-precision
    # dtype; the trainable controlnet keeps its loaded dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move image_encoder, vae and unet to the device and cast to weight_dtype
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    # controlnet.to(accelerator.device, dtype=weight_dtype)

    # Create EMA for the controlnet. The EMA is stepped with
    # `controlnet.parameters()` below, so its shadow parameters must be built
    # from the controlnet (not from the frozen unet).
    if args.use_ema:
        ema_controlnet = EMAModel(
            controlnet.parameters(),
            model_cls=ControlNetSVDModel,
            model_config=controlnet.config,
        )

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            """Save the controlnet (and its EMA) in diffusers format."""
            if args.use_ema:
                ema_controlnet.save_pretrained(os.path.join(output_dir, "controlnet_ema"))

            for model in models:
                # Both the frozen unet and the controlnet go through
                # `accelerator.prepare`, so both show up here. Only the
                # controlnet is trained and needs to be written out.
                unwrapped = accelerator.unwrap_model(model)
                if isinstance(unwrapped, ControlNetSVDModel):
                    unwrapped.save_pretrained(os.path.join(output_dir, "controlnet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            """Restore the controlnet (and its EMA) written by `save_model_hook`."""
            if args.use_ema:
                # Must match the folder name used in `save_model_hook`.
                load_model = EMAModel.from_pretrained(
                    os.path.join(input_dir, "controlnet_ema"), ControlNetSVDModel)
                ema_controlnet.load_state_dict(load_model.state_dict())
                ema_controlnet.to(accelerator.device)
                del load_model

            while models:
                # pop models so that they are not loaded again
                model = models.pop()
                target = accelerator.unwrap_model(model)
                if not isinstance(target, ControlNetSVDModel):
                    # The frozen unet is reloaded from the pretrained weights
                    # at startup; it has no entry in the checkpoint.
                    continue

                # load diffusers style into model
                load_model = ControlNetSVDModel.from_pretrained(
                    input_dir, subfolder="controlnet")
                target.register_to_config(**load_model.config)
                target.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.per_gpu_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # Only the controlnet is trained.
    controlnet.requires_grad_(True)
    optimizer = optimizer_cls(
        controlnet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # DataLoaders creation:
    args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes

    train_dataset = make_train_dataset(args)
    sampler = RandomSampler(train_dataset)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.per_gpu_batch_size,
        num_workers=args.num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, lr_scheduler, train_dataloader, controlnet = accelerator.prepare(
        unet, optimizer, lr_scheduler, train_dataloader, controlnet
    )

    if args.use_ema:
        ema_controlnet.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(
        args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("SVD_PoseGuider", config=vars(args))

    # Train!
    total_batch_size = args.per_gpu_batch_size * \
        accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(
        f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # The step count is encoded in the directory name "checkpoint-<step>".
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (
                num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps),
                        disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        controlnet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(controlnet):
                # Encode the video frames of the batch into VAE latents.
                pixel_values = batch["pixel_values"].to(weight_dtype).to(
                    accelerator.device, non_blocking=True
                )
                latents = tensor_to_vae_latent(pixel_values, vae)
                bsz = latents.shape[0]

                # Pick one random frame (1-based index) as the conditioning
                # frame and perturb it with a small log-normal noise level.
                refer_index = random.randint(1, int(args.num_frames))
                conditional_pixel_values = pixel_values[:, (refer_index - 1):refer_index, :, :, :]
                cond_sigmas = rand_log_normal(shape=[bsz, ], loc=-3.0, scale=0.5).to(latents)
                cond_sigmas = cond_sigmas[:, None, None, None, None]
                conditional_pixel_values = conditional_pixel_values + cond_sigmas * torch.randn_like(
                    conditional_pixel_values)

                # Encode the conditioning frame; the frame axis is folded into
                # the batch axis for the VAE and unfolded afterwards.
                video_length = conditional_pixel_values.shape[1]
                conditional_pixel_values = rearrange(conditional_pixel_values, "b f c h w -> (b f) c h w")
                conditional_latents = vae.encode(conditional_pixel_values).latent_dist.sample()
                conditional_latents = rearrange(conditional_latents, "(b f) c h w -> b f c h w", f=video_length)
                conditional_latents = conditional_latents[:, 0, :, :, :]

                # Sample a per-example noise level and derive the scaling
                # coefficients used for the input, output and loss below.
                sigma_data = 1
                P_mean = 0.7
                P_std = 1.6
                rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=accelerator.device)  # * N(0,1)
                sigma = (rnd_normal * P_std + P_mean).exp()  # * ln(\sigma) = N(0,1) * P_std + P_mean
                c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)  # * skip scaling
                c_out = -sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5  # * output scaling
                c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5  # * input scaling
                c_noise = sigma.log() / 4  # * noise conditioning
                timesteps = c_noise.reshape([bsz])
                loss_weight = (sigma ** 2 + sigma_data ** 2) / (sigma_data * sigma) ** 2  # * loss weighting

                noisy_latents = latents + torch.randn_like(latents) * sigma
                inp_noisy_latents = c_in * noisy_latents

                # NOTE(review): the CLIP embedding is taken from frame 0 while
                # the conditioning latent uses the random `refer_index` frame -
                # confirm this asymmetry is intended.
                image_embeddings = encode_image(
                    pixel_values[:, 0, :, :, :].float(),
                    accelerator=accelerator,
                    feature_extractor=feature_extractor,
                    image_encoder=image_encoder,
                    weight_dtype=weight_dtype)

                added_time_ids = _get_add_time_ids(
                    7,
                    batch["motion_values"],
                    0.02,  # noise_aug_strength
                    image_embeddings.dtype,
                    bsz,
                    unet
                )
                added_time_ids = added_time_ids.to(latents.device)

                # Concatenate the `conditional_latents` with the `noisy_latents`.
                conditional_latents = conditional_latents.unsqueeze(
                    1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
                inp_noisy_latents = torch.cat(
                    [inp_noisy_latents, conditional_latents], dim=2)

                controlnet_image = batch["pose_pixel_values"].to(dtype=weight_dtype)
                negative_image_embeddings = torch.zeros_like(image_embeddings)

                # Two forward passes per step: pass 0 uses all-zero image
                # embeddings, pass 1 uses the real ones. The losses are summed.
                losses = []
                for i in range(2):
                    encoder_hidden_states = (
                        negative_image_embeddings if i == 0 else image_embeddings
                    )
                    down_block_res_samples, mid_block_res_sample = controlnet(
                        inp_noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        added_time_ids=added_time_ids,
                        controlnet_cond=controlnet_image,
                        return_dict=False,
                    )

                    # Run the frozen unet with the controlnet residuals added.
                    model_pred = unet(
                        inp_noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        added_time_ids=added_time_ids,
                        down_block_additional_residuals=[
                            sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                        ],
                        mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
                    ).sample

                    # Reconstruct the clean latents and take the weighted MSE.
                    predict_x0 = c_out * model_pred + c_skip * noisy_latents
                    loss = torch.mean((predict_x0 - latents) ** 2 * loss_weight)
                    losses.append(loss)

                loss = losses[0] if len(losses) == 1 else losses[0] + losses[1]

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(
                    loss.repeat(args.per_gpu_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                # if accelerator.sync_gradients:
                #     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_controlnet.step(controlnet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process:
                    # save checkpoints!
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [
                                d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(
                                checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(
                                    checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(
                                    f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(
                                        args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(
                            args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    # sample images!
                    if (
                        (global_step % args.validation_steps == 0)
                        or (global_step == 1)
                    ):
                        logger.info(
                            f"Running validation... \n Generating {args.num_validation_images} videos."
                        )
                        # create pipelines
                        if args.use_ema:
                            # Store the controlnet parameters temporarily and load the EMA parameters to perform inference.
                            ema_controlnet.store(controlnet.parameters())
                            ema_controlnet.copy_to(controlnet.parameters())
                        # The models need unwrapping for compatibility in distributed training mode.
                        pipeline = StableVideoDiffusionPipelineControlNet.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet),
                            controlnet=accelerator.unwrap_model(
                                controlnet),
                            image_encoder=accelerator.unwrap_model(
                                image_encoder),
                            vae=accelerator.unwrap_model(vae),
                            revision=args.revision,
                            torch_dtype=weight_dtype,
                        )
                        pipeline = pipeline.to(accelerator.device)
                        pipeline.set_progress_bar_config(disable=True)

                        validation_images = load_images_from_folder(args.validation_image_folder)
                        validation_control_images = load_images_from_folder(args.validation_control_folder)

                        # run inference
                        val_save_dir = os.path.join(
                            args.output_dir, "validation_images")
                        os.makedirs(val_save_dir, exist_ok=True)

                        # `device.type` is the bare backend name ("cuda", "cpu", ...)
                        # regardless of the device index.
                        with torch.autocast(
                            accelerator.device.type, enabled=accelerator.mixed_precision == "fp16"
                        ):
                            for _ in range(args.num_validation_images):
                                num_frames = args.num_frames
                                video_frames = pipeline(
                                    validation_images[0],
                                    # One control image per generated frame.
                                    validation_control_images[:num_frames],
                                    height=args.height,
                                    width=args.width,
                                    num_frames=num_frames,
                                    decode_chunk_size=8,
                                    motion_bucket_id=127,
                                    fps=7,
                                    noise_aug_strength=0,
                                    # generator=generator,
                                ).frames

                                pipeline.save_pretrained(args.output_dir)
                                save_combined_frames(video_frames, validation_images, validation_control_images,
                                                     val_save_dir)

                        if args.use_ema:
                            # Switch back to the original controlnet parameters.
                            ema_controlnet.restore(controlnet.parameters())

                        del pipeline
                        torch.cuda.empty_cache()

            logs = {"step_loss": loss.detach().item(
            ), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    # Create the pipelines using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = accelerator.unwrap_model(controlnet)
        if args.use_ema:
            ema_controlnet.copy_to(controlnet.parameters())

        pipeline = StableVideoDiffusionPipelineControlNet.from_pretrained(
            args.pretrained_model_name_or_path,
            image_encoder=accelerator.unwrap_model(image_encoder),
            vae=accelerator.unwrap_model(vae),
            # Unwrap like the validation pipeline does, so a wrapped module is
            # never stored in the saved pipeline.
            unet=accelerator.unwrap_model(unet),
            controlnet=controlnet,
            revision=args.revision,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()