From 8782b63937379efd3569b248464c3cdf03e91b5f Mon Sep 17 00:00:00 2001 From: kimmishi Date: Mon, 28 Aug 2023 09:42:27 +0800 Subject: [PATCH 1/9] tried tensor parallel --- examples/text_to_image/train_sdxl.py | 1162 ++++++++++++++++++ examples/text_to_image/train_sdxl_new.py | 1222 +++++++++++++++++++ src/diffusers/models/attention.py | 88 +- src/diffusers/models/attention_processor.py | 11 +- src/diffusers/models/tp_utils.py | 224 ++++ src/diffusers/models/transformer_2d.py | 8 +- src/diffusers/models/unet_2d_blocks.py | 7 + src/diffusers/models/unet_2d_condition.py | 17 +- 8 files changed, 2726 insertions(+), 13 deletions(-) create mode 100644 examples/text_to_image/train_sdxl.py create mode 100644 examples/text_to_image/train_sdxl_new.py create mode 100644 src/diffusers/models/tp_utils.py diff --git a/examples/text_to_image/train_sdxl.py b/examples/text_to_image/train_sdxl.py new file mode 100644 index 000000000000..09f51eccf633 --- /dev/null +++ b/examples/text_to_image/train_sdxl.py @@ -0,0 +1,1162 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion XL for text2image.""" + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.20.0.dev0") + +logger = get_logger(__name__) + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + dataset_name=str, + repo_folder=None, + vae_path=None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +dataset: {dataset_name} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +Special VAE used for training: {vae_path}. +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sdxl-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + numel = sum([p.numel() for p in unet.parameters()]) + print('UNET numel:', numel) + # import pdb;pdb.set_trace() + # Freeze vae and text encoders. + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + print('vae numel:', sum([p.numel() for p in vae.parameters()])) + print('text_encoder_one numel:', sum([p.numel() for p in text_encoder_one.parameters()])) + print('text_encoder_two numel:', sum([p.numel() for p in text_encoder_two.parameters()])) + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = unet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + # flip + x1 = image.width - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + examples["pixel_values"] = all_images + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + compute_embeddings_fn = functools.partial( + encode_prompt, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + caption_column=args.caption_column, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = [example["crop_top_lefts"] for example in examples] + prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + model_input = vae.encode(pixel_values.to(vae.dtype)).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device + ) + + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + # Predict the noise residual + unet_added_conditions = {"time_ids": add_time_ids} + prompt_embeds = batch["prompt_embeds"].to(accelerator.device) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) + prompt_embeds = prompt_embeds + model_pred = unet( + noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions + ).sample + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "Max MEM": torch.cuda.max_memory_allocated()} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + # create pipeline + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + pipeline_args = {"prompt": args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline = pipeline.to(accelerator.device) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/examples/text_to_image/train_sdxl_new.py b/examples/text_to_image/train_sdxl_new.py new file mode 100644 index 000000000000..ffe0e0b03112 --- /dev/null +++ b/examples/text_to_image/train_sdxl_new.py @@ -0,0 +1,1222 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion XL for text2image.""" + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig +from apex.optimizers import FusedAdam + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + +import time + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.20.0.dev0") + +logger = get_logger(__name__) + +from apex.normalization import FusedLayerNorm +from flash_attn.ops.rms_norm import RMSNorm + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + +def replace_all_layernorms(model, tgt): + for name, module in model.named_children(): + if isinstance(module, torch.nn.LayerNorm): + if tgt == 'fusedln': + setattr(model, name, FusedLayerNorm( + module.normalized_shape, module.eps, module.elementwise_affine)) + elif tgt == 'rmsnorm': + setattr(model, name, RMSNorm(module.normalized_shape, eps=module.eps)) + else: + import pdb;pdb.set_trace() + raise ValueError + else: + replace_all_layernorms(module, tgt) + return model + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + dataset_name=str, + repo_folder=None, + vae_path=None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +dataset: {dataset_name} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +Special VAE used for training: {vae_path}. +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sdxl-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument("--profile", action="store_true") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} + + +def compute_vae_encodings(batch, vae): + images = batch.pop("pixel_values") + pixel_values = torch.stack(list(images)) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) + + with torch.no_grad(): + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + return {"model_input": model_input.cpu()} + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + # replace_all_layernorms(unet, 'rmsnorm') + # unet=torch.compile(unet, mode="reduce-overhead") + unet=torch.compile(unet) + torch._dynamo.config.suppress_errors = True + # unet = FullyShardedDataParallel(unet, sharding_strategy=torch.distributed.fsdp.ShardingStrategy.SHARD_GRAD_OP) + # import pdb;pdb.set_trace() + # Freeze vae and text encoders. + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + else: + unet.disable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + optimizer_class = FusedAdam + + # Optimizer creation + params_to_optimize = unet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + # flip + x1 = image.width - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + examples["pixel_values"] = all_images + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. We will pre-compute the VAE encodings too. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + compute_embeddings_fn = functools.partial( + encode_prompt, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + caption_column=args.caption_column, + ) + compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + new_fingerprint_for_vae = Hasher.hash("vae") + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + train_dataset = train_dataset.map( + compute_vae_encodings_fn, + batched=True, + batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, + new_fingerprint=new_fingerprint_for_vae, + ) + + del text_encoders, tokenizers, vae + gc.collect() + torch.cuda.empty_cache() + + def collate_fn(examples): + # for k in examples: + # print(k, type(examples[k])) + model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = [example["crop_top_lefts"] for example in examples] + prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + + return { + "model_input": model_input, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + # prof = torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) + from torch.profiler import profile, record_function, ProfilerActivity + prof=None + if args.profile: + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=10, warmup=1, active=1, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./prof/trace_sdxl_bs24_torchcompile'), + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + profile_memory=True) + prof.start() + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + print("epoch", epoch) + beg = time.time() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + # if epoch==5 & step==0: + # beg = time.time() + # if epoch==5 & step==5: + # print("time used per step:", (time.time()-beg)/5, flush=True) + with accelerator.accumulate(unet): + # Sample noise that we'll add to the latents + model_input = batch["model_input"].to(accelerator.device) + noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device + ) + + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + # Predict the noise residual + unet_added_conditions = {"time_ids": add_time_ids} + prompt_embeds = batch["prompt_embeds"].to(accelerator.device) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) + prompt_embeds = prompt_embeds + model_pred = unet( + noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions + ).sample + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + print(f"epoch {epoch}, step {step}, time: {time.time()-beg}") + beg=time.time() + if prof is not None: + prof.step() + + + # if step==21: + # prof.stop() + # prof.export_chrome_trace('./sdxl_prof_bf16_np1_bs32.json') + # import pdb;pdb.set_trace() + # break + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "Max MEM": torch.cuda.max_memory_allocated()/1e9} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + # create pipeline + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + pipeline_args = {"prompt": args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline = pipeline.to(accelerator.device) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ad899212e5a5..7b64266b47e1 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings from .lora import LoRACompatibleLinear - +from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, _split_along_first_dim @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): @@ -139,6 +139,10 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): + # import pdb;pdb.set_trace() + # split + hidden_states = _split_along_first_dim(hidden_states) + # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention if self.use_ada_layer_norm: @@ -152,6 +156,7 @@ def forward( cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -167,7 +172,7 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - + norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -195,12 +200,14 @@ def forward( dim=self._chunk_dim, ) else: + norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states + hidden_states = gather_from_sequence_parallel_region(hidden_states) return hidden_states @@ -218,6 +225,56 @@ class FeedForward(nn.Module): final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. """ + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + assert activation_fn == "geglu" + act_fn = GEGLU_ColParallel(dim, inner_dim) + + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + proj2 = RowParallelLinear(inner_dim, dim_out, sequence_parallel=True) + # LoRACompatibleLinear(inner_dim, dim_out) + self.net.append(proj2) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + # import pdb;pdb.set_trace() + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class FeedForwardBKUP(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + def __init__( self, dim: int, @@ -256,7 +313,6 @@ def forward(self, hidden_states): hidden_states = module(hidden_states) return hidden_states - class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. @@ -300,9 +356,35 @@ def gelu(self, gate): def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + import pdb;pdb.set_trace() return hidden_states * self.gelu(gate) +class GEGLU_ColParallel(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + # self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) + self.proj = ColParallelLinear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + # import pdb;pdb.set_trace() + return hidden_states * self.gelu(gate) + class ApproximateGELU(nn.Module): """ The approximate form of Gaussian Error Linear Unit (GELU) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 43497c2284ac..a07e3a64c10c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,6 +21,7 @@ from ..utils.import_utils import is_xformers_available from .lora import LoRALinearLayer +from .tp_utils import ColParallelLinear,RowParallelLinear logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -135,12 +136,12 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - + # self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_q = ColParallelLinear(query_dim, inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_k = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias) else: self.to_k = None self.to_v = None @@ -150,7 +151,7 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(RowParallelLinear(inner_dim, query_dim, bias=out_bias, sequence_parallel=True)) self.to_out.append(nn.Dropout(dropout)) # set attention processor diff --git a/src/diffusers/models/tp_utils.py b/src/diffusers/models/tp_utils.py new file mode 100644 index 000000000000..23bcff23ffa1 --- /dev/null +++ b/src/diffusers/models/tp_utils.py @@ -0,0 +1,224 @@ +import torch +from torch import nn as nn +from torch.nn.parameter import Parameter + +import torch.distributed as dist + +TP_GROUP=None +def get_tp_group(): + global TP_GROUP + return TP_GROUP + +def set_tp_group(group): + global TP_GROUP + TP_GROUP=group + +def get_tensor_model_parallel_world_size(): + return dist.get_world_size(get_tp_group()) + +class mylinear(nn.Module): + def __init__(self, fin, fout, bias=True): + super(mylinear, self).__init__() + self.weight = Parameter(torch.rand((fin, fout))) + self.bias =None + if bias: + self.bias = Parameter(torch.zeros(fout)) + + def forward(self, x): + out = torch.matmul(x, self.weight) + if self.bias is not None: + out +=self.bias + return out + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + torch.distributed.all_reduce(input_, group=get_tp_group()) + return input_ + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert dim_size[0] % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base(output, input_.contiguous(), + group=get_tp_group()) + return output + +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_tp_group()) + + return output +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert dim_size % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = torch.distributed.get_rank(get_tp_group()) + dim_offset = rank * local_dim_size + + output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + + return output + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. + if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + +class ColParallelLinear(nn.Module): + def __init__(self, fin, fout, bias=True): + super(ColParallelLinear, self).__init__() + tp_group=get_tp_group() + self.tp_world_size = torch.distributed.get_world_size(tp_group) + assert fout%self.tp_world_size==0 + self.fout = int(fout/self.tp_world_size) + + self.linear = mylinear(fin, self.fout, bias) + + + def forward(self, x): + """ + 1. automatically split fout + 2. do linear compute + 3. the output is split , no communication + """ + out = self.linear(x) + return out + + def init_weight_from_full(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + start_ind = cur_rank*self.fout + end_ind = (cur_rank+1)*self.fout + slice = fullwt[:,start_ind:end_ind] + with torch.no_grad(): + self.linear.weight.copy_(slice) + + def init_weight_from_full_attn(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + ws = torch.distributed.get_world_size(get_tp_group()) + dim=fullwt.shape[0] + dim3=fullwt.shape[1] + fullwts = fullwt.split(dim3//3, dim=-1) # (q,k,v) + splits = [] + for wt in fullwts: + splits.append(wt.split(wt.shape[-1]//ws, dim=-1)[cur_rank]) + + cat_full = torch.cat(splits, dim=-1) + + with torch.no_grad(): + # org_shape = self.linear.weight.shape + # self.linear.weight = self.linear.weight.reshape(dim, nh//ws, dim3//nh) + self.linear.weight.copy_(cat_full) + +class RowParallelLinear(nn.Module): + def __init__(self, fin, fout, bias=True, sequence_parallel=False): + super(RowParallelLinear, self).__init__() + tp_group=get_tp_group() + self.tp_world_size = torch.distributed.get_world_size(tp_group) + assert fin%self.tp_world_size==0 + self.fin = int(fin/self.tp_world_size) + self.linear = mylinear(self.fin, fout, bias) + self.sequence_parallel = sequence_parallel + + + def forward(self, x): + """ + 1. automatically split fout + 2. do linear compute + 3. the output is allreduced + """ + out = self.linear(x) + if not self.sequence_parallel: + out = _ReduceFromModelParallelRegion.apply(out) + else: + out = reduce_scatter_to_sequence_parallel_region(out) + return out + + def init_weight_from_full(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + start_ind = cur_rank*self.fin + end_ind = (cur_rank+1)*self.fin + slice = fullwt[start_ind:end_ind] + with torch.no_grad(): + self.linear.weight.copy_(slice) \ No newline at end of file diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 344a9441ced1..a709ee7158d5 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -270,6 +270,7 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + print("before input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -289,6 +290,8 @@ def forward( elif self.is_input_patches: hidden_states = self.pos_embed(hidden_states) + # print("after input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # 2. Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: @@ -313,7 +316,8 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) - + # print("after a block", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # import pdb;pdb.set_trace() # 3. Output if self.is_input_continuous: if not self.use_linear_projection: @@ -350,7 +354,7 @@ def forward( output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) - + # print("after output", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) if not return_dict: return (output,) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e894628462ef..942ec6ae7901 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1041,7 +1041,10 @@ def custom_forward(*inputs): return_dict=False, )[0] else: + print("before resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) hidden_states = resnet(hidden_states, temb) + print("after resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1050,6 +1053,8 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] + print("after attn:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -1120,6 +1125,7 @@ def __init__( def forward(self, hidden_states, temb=None): output_states = () + print("before down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -1148,6 +1154,7 @@ def custom_forward(*inputs): hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) + print("after down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) return hidden_states, output_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index fea1b4cd7823..97234b368323 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -893,13 +893,19 @@ def forward( image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process - sample = self.conv_in(sample) + print("before conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + sample = self.conv_in(sample) + print("after conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + def get_size(t): + ts = 4 if t.dtype==torch.float else 2 + return t.numel() * ts + # import pdb;pdb.set_trace() down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: @@ -924,7 +930,9 @@ def forward( sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples - + # print("after block:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("down_block_res_samples total mem",sum([get_size(t) for t in down_block_res_samples])/1e6) + # import pdb;pdb.set_trace() if is_controlnet: new_down_block_res_samples = () @@ -936,6 +944,8 @@ def forward( down_block_res_samples = new_down_block_res_samples + # import pdb;pdb.set_trace() + # print(torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -946,7 +956,8 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) - + # import pdb;pdb.set_trace() + # print("after mid block",torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) if is_controlnet: sample = sample + mid_block_additional_residual From 60b5a6b371ee83ee4ea7dd4db64465c6ce13a1a7 Mon Sep 17 00:00:00 2001 From: kimmishi Date: Tue, 29 Aug 2023 11:25:53 +0800 Subject: [PATCH 2/9] xattn sequence parallel --- src/diffusers/models/attention.py | 10 +++++-- src/diffusers/models/tp_utils.py | 20 +++++++++++++- src/diffusers/models/transformer_2d.py | 10 ++++++- src/diffusers/models/unet_2d_blocks.py | 33 ++++++++++++++++++----- src/diffusers/models/unet_2d_condition.py | 7 ++--- 5 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7b64266b47e1..f6295ade03ec 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -62,9 +62,11 @@ def __init__( norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, + sequence_parallel = True, ): super().__init__() self.only_cross_attention = only_cross_attention + self.sequence_parallel = sequence_parallel self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" @@ -141,7 +143,9 @@ def forward( ): # import pdb;pdb.set_trace() # split - hidden_states = _split_along_first_dim(hidden_states) + # hidden_states = _split_along_first_dim(hidden_states) + # if not hasattr(hidden_states, 'sequence_parallel'): + # hidden_states = _split_along_first_dim(hidden_states) # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention @@ -179,6 +183,7 @@ def forward( attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) + # import pdb;pdb.set_trace() hidden_states = attn_output + hidden_states # 3. Feed-forward @@ -207,7 +212,8 @@ def forward( ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states - hidden_states = gather_from_sequence_parallel_region(hidden_states) + # hidden_states = gather_from_sequence_parallel_region(hidden_states) + setattr(hidden_states, "sequence_parallel", True) return hidden_states diff --git a/src/diffusers/models/tp_utils.py b/src/diffusers/models/tp_utils.py index 23bcff23ffa1..f02962c24670 100644 --- a/src/diffusers/models/tp_utils.py +++ b/src/diffusers/models/tp_utils.py @@ -30,6 +30,20 @@ def forward(self, x): out +=self.bias return out +def maybe_gather_from_sequence_parallel(inp): + if hasattr(inp, "sequence_parallel") and inp.sequence_parallel: + return gather_from_sequence_parallel_region(inp) + return inp + +def maybe_split_into_sequnce_parallel(inp): + if not hasattr(inp, "sequence_parallel"): + return _split_along_first_dim(inp) + return inp + +def set_sequence_parallel_attr(inp, value=True): + setattr(inp, "sequence_parallel", value) + return inp + class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -98,6 +112,7 @@ def _split_along_first_dim(input_): output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + setattr(output, "sequence_parallel", True) return output class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): @@ -142,7 +157,10 @@ def backward(ctx, grad_output): return _split_along_first_dim(grad_output), None def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): - return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + output = _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + setattr(output, "sequence_parallel", False) + return output + def reduce_scatter_to_sequence_parallel_region(input_): return _ReduceScatterToSequenceParallelRegion.apply(input_) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index a709ee7158d5..418e25a1014a 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -26,6 +26,7 @@ from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin +from .tp_utils import _split_along_first_dim, gather_from_sequence_parallel_region @dataclass class Transformer2DModelOutput(BaseOutput): @@ -270,7 +271,10 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - print("before input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # hidden_states = _split_along_first_dim(hidden_states) + # import pdb;pdb.set_trace() + + # print("before input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -292,6 +296,8 @@ def forward( # print("after input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + setattr(hidden_states, "sequence_parallel", True) + # input to blocks are sequence parallel # 2. Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: @@ -354,6 +360,8 @@ def forward( output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) + # output = gather_from_sequence_parallel_region(output) + # print("after output", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) if not return_dict: return (output,) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 942ec6ae7901..d3fcfdd9694c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -25,7 +25,8 @@ from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel - +from .tp_utils import gather_from_sequence_parallel_region, _split_along_first_dim,maybe_gather_from_sequence_parallel, \ + maybe_split_into_sequnce_parallel, set_sequence_parallel_attr logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -634,6 +635,8 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + hidden_states = maybe_split_into_sequnce_parallel(hidden_states) + temb = maybe_split_into_sequnce_parallel(temb) hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -671,8 +674,11 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = set_sequence_parallel_attr(hidden_states) + # hidden_states = gather_from_sequence_parallel_region(hidden_states) return hidden_states @@ -783,6 +789,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask + hidden_states = _split_along_first_dim(hidden_states) hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn @@ -796,6 +803,7 @@ def forward( # resnet hidden_states = resnet(hidden_states, temb) + hidden_states = gather_from_sequence_parallel_region(hidden_states) return hidden_states @@ -1013,6 +1021,8 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) + hidden_states = maybe_split_into_sequnce_parallel(hidden_states) + temb = maybe_split_into_sequnce_parallel(temb) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: @@ -1041,9 +1051,9 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - print("before resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("before resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) hidden_states = resnet(hidden_states, temb) - print("after resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("after resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) hidden_states = attn( hidden_states, @@ -1053,9 +1063,11 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - print("after attn:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + set_sequence_parallel_attr(hidden_states) + # print("after attn:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # full_hidden_states = gather_from_sequence_parallel_region(hidden_states) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals @@ -1068,6 +1080,7 @@ def custom_forward(*inputs): output_states = output_states + (hidden_states,) + set_sequence_parallel_attr(hidden_states) return hidden_states, output_states @@ -1125,7 +1138,7 @@ def __init__( def forward(self, hidden_states, temb=None): output_states = () - print("before down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("before down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -1154,7 +1167,7 @@ def custom_forward(*inputs): hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) - print("after down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("after down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) return hidden_states, output_states @@ -2147,10 +2160,15 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + hidden_states = maybe_split_into_sequnce_parallel(hidden_states) + temb = maybe_split_into_sequnce_parallel(temb) + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + # import pdb;pdb.set_trace() + res_hidden_states = maybe_split_into_sequnce_parallel(res_hidden_states) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -2194,6 +2212,8 @@ def custom_forward(*inputs): for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) + # hidden_states = gather_from_sequence_parallel_region(hidden_states) + hidden_states = set_sequence_parallel_attr(hidden_states) return hidden_states @@ -2246,6 +2266,7 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + hidden_states = maybe_gather_from_sequence_parallel(hidden_states) for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 97234b368323..459cd51b1bdd 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -41,7 +41,7 @@ get_down_block, get_up_block, ) - +from .tp_utils import _split_along_first_dim, gather_from_sequence_parallel_region logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -893,10 +893,11 @@ def forward( image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process - print("before conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("before conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # sample = _split_along_first_dim(sample) sample = self.conv_in(sample) - print("after conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # print("after conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None From 1e68b82ba8a295708c9fc20eeb159df49925ff5f Mon Sep 17 00:00:00 2001 From: kimmishi Date: Tue, 29 Aug 2023 12:36:18 +0800 Subject: [PATCH 3/9] down block, upblock sequence parallel --- src/diffusers/models/unet_2d_blocks.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d3fcfdd9694c..33c879ab99a2 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1074,10 +1074,12 @@ def custom_forward(*inputs): output_states = output_states + (hidden_states,) + # import pdb;pdb.set_trace() if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) + set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) set_sequence_parallel_attr(hidden_states) @@ -1140,6 +1142,9 @@ def forward(self, hidden_states, temb=None): output_states = () # print("before down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + hidden_states = maybe_split_into_sequnce_parallel(hidden_states) + if temb is not None: + temb = maybe_split_into_sequnce_parallel(temb) for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -1159,16 +1164,19 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) - + set_sequence_parallel_attr(hidden_states) + # import pdb;pdb.set_trace() output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) + set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) # print("after down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + set_sequence_parallel_attr(hidden_states) return hidden_states, output_states @@ -2266,10 +2274,14 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - hidden_states = maybe_gather_from_sequence_parallel(hidden_states) + # hidden_states = maybe_gather_from_sequence_parallel(hidden_states) + hidden_states = maybe_split_into_sequnce_parallel(hidden_states) + if temb is not None: + temb = maybe_split_into_sequnce_parallel(temb) for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states = maybe_split_into_sequnce_parallel(res_hidden_states) res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -2291,11 +2303,13 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) + set_sequence_parallel_attr(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = maybe_gather_from_sequence_parallel(hidden_states) return hidden_states From 6067e6b9a229b58d946ada6c4062f8db2726e638 Mon Sep 17 00:00:00 2001 From: KimmiShi Date: Tue, 29 Aug 2023 19:02:10 +0800 Subject: [PATCH 4/9] Delete train_sdxl.py --- examples/text_to_image/train_sdxl.py | 1162 -------------------------- 1 file changed, 1162 deletions(-) delete mode 100644 examples/text_to_image/train_sdxl.py diff --git a/examples/text_to_image/train_sdxl.py b/examples/text_to_image/train_sdxl.py deleted file mode 100644 index 09f51eccf633..000000000000 --- a/examples/text_to_image/train_sdxl.py +++ /dev/null @@ -1,1162 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Fine-tuning script for Stable Diffusion XL for text2image.""" - -import argparse -import functools -import gc -import logging -import math -import os -import random -import shutil -from pathlib import Path - -import accelerate -import datasets -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from datasets import load_dataset -from huggingface_hub import create_repo, upload_folder -from packaging import version -from torchvision import transforms -from torchvision.transforms.functional import crop -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - -import diffusers -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - StableDiffusionXLPipeline, - UNet2DConditionModel, -) -from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available - - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.20.0.dev0") - -logger = get_logger(__name__) - - -DATASET_NAME_MAPPING = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} - - -def save_model_card( - repo_id: str, - images=None, - base_model=str, - dataset_name=str, - repo_folder=None, - vae_path=None, -): - img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" - - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -dataset: {dataset_name} -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -inference: true ---- - """ - model_card = f""" -# Text-to-image finetuning - {repo_id} - -This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n -{img_str} - -Special VAE used for training: {vae_path}. -""" - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) - - -def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" -): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection - - return CLIPTextModelWithProjection - else: - raise ValueError(f"{model_class} is not supported.") - - -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--pretrained_vae_model_name_or_path", - type=str, - default=None, - help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." - ), - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) - parser.add_argument( - "--train_data_dir", - type=str, - default=None, - help=( - "A folder containing the training data. Folder contents must follow the structure described in" - " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" - " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." - ), - ) - parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing an image." - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing a caption or a list of captions.", - ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="A prompt that is used during validation to verify that the model is learning.", - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=4, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_epochs", - type=int, - default=1, - help=( - "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`." - ), - ) - parser.add_argument( - "--max_train_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ), - ) - parser.add_argument( - "--proportion_empty_prompts", - type=float, - default=0, - help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", - ) - parser.add_argument( - "--output_dir", - type=str, - default="sdxl-model-finetuned", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=1024, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--random_flip", - action="store_true", - help="whether to randomly flip images horizontally", - ) - parser.add_argument( - "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument("--num_train_epochs", type=int, default=100) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--snr_gamma", - type=float, - default=None, - help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " - "More details here: https://arxiv.org/abs/2303.09556.", - ) - parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--prediction_type", - type=str, - default=None, - help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) - parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - # Sanity checks - if args.dataset_name is None and args.train_data_dir is None: - raise ValueError("Need either a dataset name or a training folder.") - - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: - raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") - - return args - - -# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): - prompt_embeds_list = [] - prompt_batch = batch[caption_column] - - captions = [] - for caption in prompt_batch: - if random.random() < proportion_empty_prompts: - captions.append("") - elif isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - # take a random caption if there are multiple - captions.append(random.choice(caption) if is_train else caption[0]) - - with torch.no_grad(): - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - captions, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, - ) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) - return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} - - -def main(args): - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - ) - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id - - # Load the tokenizers - tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False - ) - tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False - ) - - # import correct text encoder classes - text_encoder_cls_one = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) - text_encoder_cls_two = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" - ) - - # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision - ) - vae_path = ( - args.pretrained_model_name_or_path - if args.pretrained_vae_model_name_or_path is None - else args.pretrained_vae_model_name_or_path - ) - vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - numel = sum([p.numel() for p in unet.parameters()]) - print('UNET numel:', numel) - # import pdb;pdb.set_trace() - # Freeze vae and text encoders. - vae.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) - print('vae numel:', sum([p.numel() for p in vae.parameters()])) - print('text_encoder_one numel:', sum([p.numel() for p in text_encoder_one.parameters()])) - print('text_encoder_two numel:', sum([p.numel() for p in text_encoder_two.parameters()])) - - # For mixed precision training we cast all non-trainable weigths to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move unet, vae and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. - if args.pretrained_vae_model_name_or_path is None: - vae.to(accelerator.device, dtype=torch.float32) - else: - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - - # Create EMA for the unet. - if args.use_ema: - ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - def compute_snr(timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if args.use_ema: - ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - - for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "unet")) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) - ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) - del load_model - - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Optimizer creation - params_to_optimize = unet.parameters() - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # Get the datasets: you can either provide your own training and evaluation files (see below) - # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). - - # In distributed training, the load_dataset function guarantees that only one local process can concurrently - # download the dataset. - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - ) - else: - data_files = {} - if args.train_data_dir is not None: - data_files["train"] = os.path.join(args.train_data_dir, "**") - dataset = load_dataset( - "imagefolder", - data_files=data_files, - cache_dir=args.cache_dir, - ) - # See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder - - # Preprocessing the datasets. - # We need to tokenize inputs and targets. - column_names = dataset["train"].column_names - - # 6. Get the column names for input/target. - dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) - if args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] - else: - image_column = args.image_column - if image_column not in column_names: - raise ValueError( - f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" - ) - if args.caption_column is None: - caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] - else: - caption_column = args.caption_column - if caption_column not in column_names: - raise ValueError( - f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" - ) - - # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def preprocess_train(examples): - images = [image.convert("RGB") for image in examples[image_column]] - # image aug - original_sizes = [] - all_images = [] - crop_top_lefts = [] - for image in images: - original_sizes.append((image.height, image.width)) - image = train_resize(image) - if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width - args.resolution) / 2.0))) - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) - image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) - crop_top_left = (y1, x1) - crop_top_lefts.append(crop_top_left) - image = train_transforms(image) - all_images.append(image) - - examples["original_sizes"] = original_sizes - examples["crop_top_lefts"] = crop_top_lefts - examples["pixel_values"] = all_images - return examples - - with accelerator.main_process_first(): - if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) - # Set the training transforms - train_dataset = dataset["train"].with_transform(preprocess_train) - - # Let's first compute all the embeddings so that we can free up the text encoders - # from memory. - text_encoders = [text_encoder_one, text_encoder_two] - tokenizers = [tokenizer_one, tokenizer_two] - compute_embeddings_fn = functools.partial( - encode_prompt, - text_encoders=text_encoders, - tokenizers=tokenizers, - proportion_empty_prompts=args.proportion_empty_prompts, - caption_column=args.caption_column, - ) - with accelerator.main_process_first(): - from datasets.fingerprint import Hasher - - # fingerprint used by the cache for the other processes to load the result - # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 - new_fingerprint = Hasher.hash(args) - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) - - del text_encoders, tokenizers - gc.collect() - torch.cuda.empty_cache() - - def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - original_sizes = [example["original_sizes"] for example in examples] - crop_top_lefts = [example["crop_top_lefts"] for example in examples] - prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) - pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) - - return { - "pixel_values": pixel_values, - "prompt_embeds": prompt_embeds, - "pooled_prompt_embeds": pooled_prompt_embeds, - "original_sizes": original_sizes, - "crop_top_lefts": crop_top_lefts, - } - - # DataLoaders creation: - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=collate_fn, - batch_size=args.train_batch_size, - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - ) - - # Prepare everything with our `accelerator`. - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - train_loss = 0.0 - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - with accelerator.accumulate(unet): - # Convert images to latent space - pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - model_input = vae.encode(pixel_values.to(vae.dtype)).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor - - # Sample noise that we'll add to the latents - noise = torch.randn_like(model_input) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn( - (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device - ) - - bsz = model_input.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - - # time ids - def compute_time_ids(original_size, crops_coords_top_left): - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - target_size = (args.resolution, args.resolution) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) - return add_time_ids - - add_time_ids = torch.cat( - [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] - ) - - # Predict the noise residual - unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds = batch["prompt_embeds"].to(accelerator.device) - pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) - prompt_embeds = prompt_embeds - model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions - ).sample - - # Get the target for loss depending on the prediction type - if args.prediction_type is not None: - # set prediction_type of scheduler if defined - noise_scheduler.register_to_config(prediction_type=args.prediction_type) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(model_input, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. - # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(timesteps) - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = loss.mean() - - # Gather the losses across all processes for logging (if we use distributed training). - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps - - # Backpropagate - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "Max MEM": torch.cuda.max_memory_allocated()} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - # create pipeline - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - unet=accelerator.unwrap_model(unet), - revision=args.revision, - torch_dtype=weight_dtype, - ) - - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - pipeline_args = {"prompt": args.validation_prompt} - - with torch.cuda.amp.autocast(): - images = [ - pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - if args.use_ema: - ema_unet.copy_to(unet.parameters()) - - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype - ) - pipeline = pipeline.to(accelerator.device) - - # run inference - images = [] - if args.validation_prompt and args.num_validation_images > 0: - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.push_to_hub: - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - prompt=args.instance_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - - accelerator.end_training() - - -if __name__ == "__main__": - args = parse_args() - main(args) \ No newline at end of file From de1dbaa456c6d2d97642b182adeb0af1574a60c6 Mon Sep 17 00:00:00 2001 From: KimmiShi Date: Tue, 29 Aug 2023 19:02:20 +0800 Subject: [PATCH 5/9] Delete train_sdxl_new.py --- examples/text_to_image/train_sdxl_new.py | 1222 ---------------------- 1 file changed, 1222 deletions(-) delete mode 100644 examples/text_to_image/train_sdxl_new.py diff --git a/examples/text_to_image/train_sdxl_new.py b/examples/text_to_image/train_sdxl_new.py deleted file mode 100644 index ffe0e0b03112..000000000000 --- a/examples/text_to_image/train_sdxl_new.py +++ /dev/null @@ -1,1222 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Fine-tuning script for Stable Diffusion XL for text2image.""" - -import argparse -import functools -import gc -import logging -import math -import os -import random -import shutil -from pathlib import Path - -import accelerate -import datasets -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from datasets import load_dataset -from huggingface_hub import create_repo, upload_folder -from packaging import version -from torchvision import transforms -from torchvision.transforms.functional import crop -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig -from apex.optimizers import FusedAdam - -import diffusers -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - StableDiffusionXLPipeline, - UNet2DConditionModel, -) -from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available - -import time - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.20.0.dev0") - -logger = get_logger(__name__) - -from apex.normalization import FusedLayerNorm -from flash_attn.ops.rms_norm import RMSNorm - - -DATASET_NAME_MAPPING = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} - -def replace_all_layernorms(model, tgt): - for name, module in model.named_children(): - if isinstance(module, torch.nn.LayerNorm): - if tgt == 'fusedln': - setattr(model, name, FusedLayerNorm( - module.normalized_shape, module.eps, module.elementwise_affine)) - elif tgt == 'rmsnorm': - setattr(model, name, RMSNorm(module.normalized_shape, eps=module.eps)) - else: - import pdb;pdb.set_trace() - raise ValueError - else: - replace_all_layernorms(module, tgt) - return model - -def save_model_card( - repo_id: str, - images=None, - base_model=str, - dataset_name=str, - repo_folder=None, - vae_path=None, -): - img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" - - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -dataset: {dataset_name} -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -inference: true ---- - """ - model_card = f""" -# Text-to-image finetuning - {repo_id} - -This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n -{img_str} - -Special VAE used for training: {vae_path}. -""" - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) - - -def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" -): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection - - return CLIPTextModelWithProjection - else: - raise ValueError(f"{model_class} is not supported.") - - -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--pretrained_vae_model_name_or_path", - type=str, - default=None, - help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." - ), - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) - parser.add_argument( - "--train_data_dir", - type=str, - default=None, - help=( - "A folder containing the training data. Folder contents must follow the structure described in" - " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" - " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." - ), - ) - parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing an image." - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing a caption or a list of captions.", - ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="A prompt that is used during validation to verify that the model is learning.", - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=4, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_epochs", - type=int, - default=1, - help=( - "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`." - ), - ) - parser.add_argument( - "--max_train_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ), - ) - parser.add_argument( - "--proportion_empty_prompts", - type=float, - default=0, - help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", - ) - parser.add_argument( - "--output_dir", - type=str, - default="sdxl-model-finetuned", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=1024, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--random_flip", - action="store_true", - help="whether to randomly flip images horizontally", - ) - parser.add_argument( - "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument("--num_train_epochs", type=int, default=100) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--snr_gamma", - type=float, - default=None, - help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " - "More details here: https://arxiv.org/abs/2303.09556.", - ) - parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--prediction_type", - type=str, - default=None, - help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." - ) - parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") - parser.add_argument("--profile", action="store_true") - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - # Sanity checks - if args.dataset_name is None and args.train_data_dir is None: - raise ValueError("Need either a dataset name or a training folder.") - - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: - raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") - - return args - - -# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): - prompt_embeds_list = [] - prompt_batch = batch[caption_column] - - captions = [] - for caption in prompt_batch: - if random.random() < proportion_empty_prompts: - captions.append("") - elif isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - # take a random caption if there are multiple - captions.append(random.choice(caption) if is_train else caption[0]) - - with torch.no_grad(): - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - captions, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, - ) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) - return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} - - -def compute_vae_encodings(batch, vae): - images = batch.pop("pixel_values") - pixel_values = torch.stack(list(images)) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) - - with torch.no_grad(): - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor - return {"model_input": model_input.cpu()} - - -def main(args): - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - ) - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id - - # Load the tokenizers - tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False - ) - tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False - ) - - # import correct text encoder classes - text_encoder_cls_one = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) - text_encoder_cls_two = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" - ) - - # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision - ) - vae_path = ( - args.pretrained_model_name_or_path - if args.pretrained_vae_model_name_or_path is None - else args.pretrained_vae_model_name_or_path - ) - vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - # replace_all_layernorms(unet, 'rmsnorm') - # unet=torch.compile(unet, mode="reduce-overhead") - unet=torch.compile(unet) - torch._dynamo.config.suppress_errors = True - # unet = FullyShardedDataParallel(unet, sharding_strategy=torch.distributed.fsdp.ShardingStrategy.SHARD_GRAD_OP) - # import pdb;pdb.set_trace() - # Freeze vae and text encoders. - vae.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weigths to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move unet, vae and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. - if args.pretrained_vae_model_name_or_path is None: - vae.to(accelerator.device, dtype=torch.float32) - else: - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - - # Create EMA for the unet. - if args.use_ema: - ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - def compute_snr(timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if args.use_ema: - ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - - for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "unet")) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) - ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) - del load_model - - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - else: - unet.disable_gradient_checkpointing() - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - optimizer_class = FusedAdam - - # Optimizer creation - params_to_optimize = unet.parameters() - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # Get the datasets: you can either provide your own training and evaluation files (see below) - # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). - - # In distributed training, the load_dataset function guarantees that only one local process can concurrently - # download the dataset. - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - ) - else: - data_files = {} - if args.train_data_dir is not None: - data_files["train"] = os.path.join(args.train_data_dir, "**") - dataset = load_dataset( - "imagefolder", - data_files=data_files, - cache_dir=args.cache_dir, - ) - # See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder - - # Preprocessing the datasets. - # We need to tokenize inputs and targets. - column_names = dataset["train"].column_names - - # 6. Get the column names for input/target. - dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) - if args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] - else: - image_column = args.image_column - if image_column not in column_names: - raise ValueError( - f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" - ) - if args.caption_column is None: - caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] - else: - caption_column = args.caption_column - if caption_column not in column_names: - raise ValueError( - f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" - ) - - # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) - - def preprocess_train(examples): - images = [image.convert("RGB") for image in examples[image_column]] - # image aug - original_sizes = [] - all_images = [] - crop_top_lefts = [] - for image in images: - original_sizes.append((image.height, image.width)) - image = train_resize(image) - if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width - args.resolution) / 2.0))) - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) - image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - # flip - x1 = image.width - x1 - image = train_flip(image) - crop_top_left = (y1, x1) - crop_top_lefts.append(crop_top_left) - image = train_transforms(image) - all_images.append(image) - - examples["original_sizes"] = original_sizes - examples["crop_top_lefts"] = crop_top_lefts - examples["pixel_values"] = all_images - return examples - - with accelerator.main_process_first(): - if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) - # Set the training transforms - train_dataset = dataset["train"].with_transform(preprocess_train) - - # Let's first compute all the embeddings so that we can free up the text encoders - # from memory. We will pre-compute the VAE encodings too. - text_encoders = [text_encoder_one, text_encoder_two] - tokenizers = [tokenizer_one, tokenizer_two] - compute_embeddings_fn = functools.partial( - encode_prompt, - text_encoders=text_encoders, - tokenizers=tokenizers, - proportion_empty_prompts=args.proportion_empty_prompts, - caption_column=args.caption_column, - ) - compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) - with accelerator.main_process_first(): - from datasets.fingerprint import Hasher - - # fingerprint used by the cache for the other processes to load the result - # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 - new_fingerprint = Hasher.hash(args) - new_fingerprint_for_vae = Hasher.hash("vae") - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) - train_dataset = train_dataset.map( - compute_vae_encodings_fn, - batched=True, - batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, - new_fingerprint=new_fingerprint_for_vae, - ) - - del text_encoders, tokenizers, vae - gc.collect() - torch.cuda.empty_cache() - - def collate_fn(examples): - # for k in examples: - # print(k, type(examples[k])) - model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) - original_sizes = [example["original_sizes"] for example in examples] - crop_top_lefts = [example["crop_top_lefts"] for example in examples] - prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) - pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) - - return { - "model_input": model_input, - "prompt_embeds": prompt_embeds, - "pooled_prompt_embeds": pooled_prompt_embeds, - "original_sizes": original_sizes, - "crop_top_lefts": crop_top_lefts, - } - - # DataLoaders creation: - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=collate_fn, - batch_size=args.train_batch_size, - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - ) - - # Prepare everything with our `accelerator`. - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") - - # prof = torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) - from torch.profiler import profile, record_function, ProfilerActivity - prof=None - if args.profile: - prof = torch.profiler.profile( - schedule=torch.profiler.schedule(wait=10, warmup=1, active=1, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./prof/trace_sdxl_bs24_torchcompile'), - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - with_stack=True, - profile_memory=True) - prof.start() - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - train_loss = 0.0 - print("epoch", epoch) - beg = time.time() - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - # if epoch==5 & step==0: - # beg = time.time() - # if epoch==5 & step==5: - # print("time used per step:", (time.time()-beg)/5, flush=True) - with accelerator.accumulate(unet): - # Sample noise that we'll add to the latents - model_input = batch["model_input"].to(accelerator.device) - noise = torch.randn_like(model_input) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn( - (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device - ) - - bsz = model_input.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - - # time ids - def compute_time_ids(original_size, crops_coords_top_left): - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - target_size = (args.resolution, args.resolution) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) - return add_time_ids - - add_time_ids = torch.cat( - [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] - ) - - # Predict the noise residual - unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds = batch["prompt_embeds"].to(accelerator.device) - pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) - prompt_embeds = prompt_embeds - model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions - ).sample - - # Get the target for loss depending on the prediction type - if args.prediction_type is not None: - # set prediction_type of scheduler if defined - noise_scheduler.register_to_config(prediction_type=args.prediction_type) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(model_input, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. - # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(timesteps) - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) - # We first calculate the original loss. Then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = loss.mean() - - # Gather the losses across all processes for logging (if we use distributed training). - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps - - # Backpropagate - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - print(f"epoch {epoch}, step {step}, time: {time.time()-beg}") - beg=time.time() - if prof is not None: - prof.step() - - - # if step==21: - # prof.stop() - # prof.export_chrome_trace('./sdxl_prof_bf16_np1_bs32.json') - # import pdb;pdb.set_trace() - # break - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "Max MEM": torch.cuda.max_memory_allocated()/1e9} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - # create pipeline - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - revision=args.revision, - torch_dtype=weight_dtype, - ) - - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - pipeline_args = {"prompt": args.validation_prompt} - - with torch.cuda.amp.autocast(): - images = [ - pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - if args.use_ema: - ema_unet.copy_to(unet.parameters()) - - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype - ) - pipeline = pipeline.to(accelerator.device) - - # run inference - images = [] - if args.validation_prompt and args.num_validation_images > 0: - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.push_to_hub: - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - prompt=args.instance_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - - accelerator.end_training() - - -if __name__ == "__main__": - args = parse_args() - main(args) \ No newline at end of file From a52a8762aeffa0685eda4b2cdcf9f5fc994ce051 Mon Sep 17 00:00:00 2001 From: kimmishi Date: Wed, 30 Aug 2023 11:46:13 +0800 Subject: [PATCH 6/9] remove comments --- src/diffusers/models/tp_utils.py | 8 ++++---- src/diffusers/models/transformer_2d.py | 10 ---------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/tp_utils.py b/src/diffusers/models/tp_utils.py index f02962c24670..8c4ef15f8199 100644 --- a/src/diffusers/models/tp_utils.py +++ b/src/diffusers/models/tp_utils.py @@ -16,9 +16,9 @@ def set_tp_group(group): def get_tensor_model_parallel_world_size(): return dist.get_world_size(get_tp_group()) -class mylinear(nn.Module): +class TpLinear(nn.Module): def __init__(self, fin, fout, bias=True): - super(mylinear, self).__init__() + super(TpLinear, self).__init__() self.weight = Parameter(torch.rand((fin, fout))) self.bias =None if bias: @@ -172,7 +172,7 @@ def __init__(self, fin, fout, bias=True): assert fout%self.tp_world_size==0 self.fout = int(fout/self.tp_world_size) - self.linear = mylinear(fin, self.fout, bias) + self.linear = TpLinear(fin, self.fout, bias) def forward(self, x): @@ -216,7 +216,7 @@ def __init__(self, fin, fout, bias=True, sequence_parallel=False): self.tp_world_size = torch.distributed.get_world_size(tp_group) assert fin%self.tp_world_size==0 self.fin = int(fin/self.tp_world_size) - self.linear = mylinear(self.fin, fout, bias) + self.linear = TpLinear(self.fin, fout, bias) self.sequence_parallel = sequence_parallel diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 418e25a1014a..a99a6ec6c9ff 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -271,10 +271,6 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # hidden_states = _split_along_first_dim(hidden_states) - # import pdb;pdb.set_trace() - - # print("before input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -294,8 +290,6 @@ def forward( elif self.is_input_patches: hidden_states = self.pos_embed(hidden_states) - # print("after input", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - setattr(hidden_states, "sequence_parallel", True) # input to blocks are sequence parallel # 2. Blocks @@ -322,8 +316,6 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) - # print("after a block", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - # import pdb;pdb.set_trace() # 3. Output if self.is_input_continuous: if not self.use_linear_projection: @@ -360,9 +352,7 @@ def forward( output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) - # output = gather_from_sequence_parallel_region(output) - # print("after output", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) if not return_dict: return (output,) From e13fcf88c0204c19f2c8a1220a523b133df80f23 Mon Sep 17 00:00:00 2001 From: kimmishi Date: Thu, 31 Aug 2023 17:09:24 +0800 Subject: [PATCH 7/9] refact sequence parallel code --- src/diffusers/models/attention.py | 78 ++++----------------- src/diffusers/models/attention_processor.py | 10 ++- src/diffusers/models/tp_utils.py | 41 ++++++----- src/diffusers/models/transformer_2d.py | 6 +- src/diffusers/models/unet_2d_blocks.py | 76 ++++++++------------ src/diffusers/models/unet_2d_condition.py | 31 ++++---- 6 files changed, 90 insertions(+), 152 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f6295ade03ec..29ff024af079 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -62,7 +62,7 @@ def __init__( norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, - sequence_parallel = True, + sequence_parallel = False, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -93,6 +93,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + sequence_parallel = sequence_parallel ) # 2. Cross-Attn @@ -113,6 +114,7 @@ def __init__( dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, + sequence_parallel = sequence_parallel ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None @@ -120,7 +122,8 @@ def __init__( # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, + sequence_parallel = sequence_parallel) # let chunk size default to None self._chunk_size = None @@ -140,13 +143,8 @@ def forward( timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, + sequence_parallel=False ): - # import pdb;pdb.set_trace() - # split - # hidden_states = _split_along_first_dim(hidden_states) - # if not hasattr(hidden_states, 'sequence_parallel'): - # hidden_states = _split_along_first_dim(hidden_states) - # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention if self.use_ada_layer_norm: @@ -160,7 +158,6 @@ def forward( cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -169,6 +166,7 @@ def forward( ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states # 2. Cross-Attention @@ -176,14 +174,12 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) - # import pdb;pdb.set_trace() hidden_states = attn_output + hidden_states # 3. Feed-forward @@ -205,14 +201,12 @@ def forward( dim=self._chunk_dim, ) else: - norm_hidden_states = gather_from_sequence_parallel_region(norm_hidden_states) ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states - # hidden_states = gather_from_sequence_parallel_region(hidden_states) setattr(hidden_states, "sequence_parallel", True) return hidden_states @@ -239,10 +233,12 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + sequence_parallel = False, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim + self.sequence_parallel = sequence_parallel assert activation_fn == "geglu" act_fn = GEGLU_ColParallel(dim, inner_dim) @@ -254,7 +250,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - proj2 = RowParallelLinear(inner_dim, dim_out, sequence_parallel=True) + proj2 = RowParallelLinear(inner_dim, dim_out, sequence_parallel = sequence_parallel) # LoRACompatibleLinear(inner_dim, dim_out) self.net.append(proj2) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout @@ -262,63 +258,13 @@ def __init__( self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states): - # import pdb;pdb.set_trace() + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) for module in self.net: hidden_states = module(hidden_states) return hidden_states -class FeedForwardBKUP(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a07e3a64c10c..8b70f99059f2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,7 +21,7 @@ from ..utils.import_utils import is_xformers_available from .lora import LoRALinearLayer -from .tp_utils import ColParallelLinear,RowParallelLinear +from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -72,6 +72,7 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, + sequence_parallel = False, ): super().__init__() inner_dim = dim_head * heads @@ -97,6 +98,7 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention + self.sequence_parallel = sequence_parallel if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( @@ -151,7 +153,7 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(RowParallelLinear(inner_dim, query_dim, bias=out_bias, sequence_parallel=True)) + self.to_out.append(RowParallelLinear(inner_dim, query_dim, bias=out_bias, sequence_parallel=sequence_parallel)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -317,6 +319,10 @@ def set_processor(self, processor: "AttnProcessor"): self.processor = processor def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = gather_from_sequence_parallel_region(encoder_hidden_states) # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty diff --git a/src/diffusers/models/tp_utils.py b/src/diffusers/models/tp_utils.py index 8c4ef15f8199..e18e023aba8f 100644 --- a/src/diffusers/models/tp_utils.py +++ b/src/diffusers/models/tp_utils.py @@ -16,20 +16,6 @@ def set_tp_group(group): def get_tensor_model_parallel_world_size(): return dist.get_world_size(get_tp_group()) -class TpLinear(nn.Module): - def __init__(self, fin, fout, bias=True): - super(TpLinear, self).__init__() - self.weight = Parameter(torch.rand((fin, fout))) - self.bias =None - if bias: - self.bias = Parameter(torch.zeros(fout)) - - def forward(self, x): - out = torch.matmul(x, self.weight) - if self.bias is not None: - out +=self.bias - return out - def maybe_gather_from_sequence_parallel(inp): if hasattr(inp, "sequence_parallel") and inp.sequence_parallel: return gather_from_sequence_parallel_region(inp) @@ -44,6 +30,11 @@ def set_sequence_parallel_attr(inp, value=True): setattr(inp, "sequence_parallel", value) return inp +def is_squence_parallel_tensor(inp): + return hasattr(inp, "sequence_parallel") and inp.sequence_parallel==True + + +# the following Reduce and Gather functions are adopted from https://github.com/NVIDIA/Megatron-LM class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -112,7 +103,7 @@ def _split_along_first_dim(input_): output = input_[dim_offset:dim_offset+local_dim_size].contiguous() - setattr(output, "sequence_parallel", True) + set_sequence_parallel_attr(output, True) return output class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): @@ -158,12 +149,27 @@ def backward(ctx, grad_output): def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): output = _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) - setattr(output, "sequence_parallel", False) + set_sequence_parallel_attr(output, False) return output def reduce_scatter_to_sequence_parallel_region(input_): return _ReduceScatterToSequenceParallelRegion.apply(input_) + +class TpLinear(nn.Module): + def __init__(self, fin, fout, bias=True): + super(TpLinear, self).__init__() + self.weight = Parameter(torch.rand((fin, fout))) + self.bias =None + if bias: + self.bias = Parameter(torch.zeros(fout)) + + def forward(self, x): + out = torch.matmul(x, self.weight) + if self.bias is not None: + out +=self.bias + return out + class ColParallelLinear(nn.Module): def __init__(self, fin, fout, bias=True): super(ColParallelLinear, self).__init__() @@ -174,7 +180,6 @@ def __init__(self, fin, fout, bias=True): self.linear = TpLinear(fin, self.fout, bias) - def forward(self, x): """ 1. automatically split fout @@ -205,8 +210,6 @@ def init_weight_from_full_attn(self, fullwt): cat_full = torch.cat(splits, dim=-1) with torch.no_grad(): - # org_shape = self.linear.weight.shape - # self.linear.weight = self.linear.weight.reshape(dim, nh//ws, dim3//nh) self.linear.weight.copy_(cat_full) class RowParallelLinear(nn.Module): diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index a99a6ec6c9ff..4633b069812e 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -26,8 +26,6 @@ from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin -from .tp_utils import _split_along_first_dim, gather_from_sequence_parallel_region - @dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -92,6 +90,7 @@ def __init__( upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, + sequence_parallel=False, ): super().__init__() self.use_linear_projection = use_linear_projection @@ -184,6 +183,7 @@ def __init__( upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, + sequence_parallel=sequence_parallel ) for d in range(num_layers) ] @@ -290,8 +290,6 @@ def forward( elif self.is_input_patches: hidden_states = self.pos_embed(hidden_states) - setattr(hidden_states, "sequence_parallel", True) - # input to blocks are sequence parallel # 2. Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 33c879ab99a2..6789762fa01e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -25,8 +25,6 @@ from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel -from .tp_utils import gather_from_sequence_parallel_region, _split_along_first_dim,maybe_gather_from_sequence_parallel, \ - maybe_split_into_sequnce_parallel, set_sequence_parallel_attr logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,6 +53,7 @@ def get_down_block( cross_attention_norm=None, attention_head_dim=None, downsample_type=None, + sequence_parallel=True, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -130,6 +129,7 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + sequence_parallel=sequence_parallel ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: @@ -250,6 +250,7 @@ def get_up_block( cross_attention_norm=None, attention_head_dim=None, upsample_type=None, + sequence_parallel=True, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -308,6 +309,7 @@ def get_up_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + sequence_parallel=sequence_parallel ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: @@ -557,6 +559,7 @@ def __init__( dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, + sequence_parallel=True, ): super().__init__() @@ -593,6 +596,7 @@ def __init__( norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) ) else: @@ -625,6 +629,7 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward( self, @@ -635,8 +640,6 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - hidden_states = maybe_split_into_sequnce_parallel(hidden_states) - temb = maybe_split_into_sequnce_parallel(temb) hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -674,11 +677,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb) - hidden_states = set_sequence_parallel_attr(hidden_states) - # hidden_states = gather_from_sequence_parallel_region(hidden_states) + # if self.sequence_parallel: + # hidden_states = set_sequence_parallel_attr(hidden_states) return hidden_states @@ -789,7 +791,6 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = _split_along_first_dim(hidden_states) hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn @@ -803,7 +804,6 @@ def forward( # resnet hidden_states = resnet(hidden_states, temb) - hidden_states = gather_from_sequence_parallel_region(hidden_states) return hidden_states @@ -942,6 +942,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -978,6 +979,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel, ) ) else: @@ -1006,6 +1008,7 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward( self, @@ -1021,8 +1024,6 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) - hidden_states = maybe_split_into_sequnce_parallel(hidden_states) - temb = maybe_split_into_sequnce_parallel(temb) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: @@ -1051,9 +1052,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - # print("before resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) hidden_states = resnet(hidden_states, temb) - # print("after resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) hidden_states = attn( hidden_states, @@ -1064,25 +1063,23 @@ def custom_forward(*inputs): return_dict=False, )[0] - set_sequence_parallel_attr(hidden_states) - # print("after attn:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) - # full_hidden_states = gather_from_sequence_parallel_region(hidden_states) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals output_states = output_states + (hidden_states,) - # import pdb;pdb.set_trace() if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - set_sequence_parallel_attr(hidden_states) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) - set_sequence_parallel_attr(hidden_states) return hidden_states, output_states @@ -1102,6 +1099,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -1137,14 +1135,10 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward(self, hidden_states, temb=None): output_states = () - # print("before down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - - hidden_states = maybe_split_into_sequnce_parallel(hidden_states) - if temb is not None: - temb = maybe_split_into_sequnce_parallel(temb) for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -1164,19 +1158,18 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) - set_sequence_parallel_attr(hidden_states) - # import pdb;pdb.set_trace() + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - - set_sequence_parallel_attr(hidden_states) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) - # print("after down2d resnet:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - set_sequence_parallel_attr(hidden_states) return hidden_states, output_states @@ -2096,6 +2089,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + sequence_parallel=False ): super().__init__() resnets = [] @@ -2134,6 +2128,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) ) else: @@ -2167,16 +2162,12 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + sequence_parallel=False, ): - hidden_states = maybe_split_into_sequnce_parallel(hidden_states) - temb = maybe_split_into_sequnce_parallel(temb) - for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - # import pdb;pdb.set_trace() - res_hidden_states = maybe_split_into_sequnce_parallel(res_hidden_states) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -2220,8 +2211,6 @@ def custom_forward(*inputs): for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) - # hidden_states = gather_from_sequence_parallel_region(hidden_states) - hidden_states = set_sequence_parallel_attr(hidden_states) return hidden_states @@ -2241,6 +2230,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -2272,16 +2262,12 @@ def __init__( self.upsamplers = None self.gradient_checkpointing = False + # self.sequence_parallel = sequence_parallel - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - # hidden_states = maybe_gather_from_sequence_parallel(hidden_states) - hidden_states = maybe_split_into_sequnce_parallel(hidden_states) - if temb is not None: - temb = maybe_split_into_sequnce_parallel(temb) + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, sequence_parallel=False): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states = maybe_split_into_sequnce_parallel(res_hidden_states) res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -2303,13 +2289,13 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) - set_sequence_parallel_attr(hidden_states) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) - hidden_states = maybe_gather_from_sequence_parallel(hidden_states) return hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 459cd51b1bdd..480ddb6b456a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -41,7 +41,7 @@ get_down_block, get_up_block, ) -from .tp_utils import _split_along_first_dim, gather_from_sequence_parallel_region +from .tp_utils import gather_from_sequence_parallel_region, _split_along_first_dim logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -202,6 +202,7 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, + sequence_parallel: bool = True, ): super().__init__() @@ -450,6 +451,7 @@ def __init__( resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + sequence_parallel=sequence_parallel ) self.down_blocks.append(down_block) @@ -469,6 +471,7 @@ def __init__( dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( @@ -539,6 +542,7 @@ def __init__( resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + sequence_parallel=sequence_parallel ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -560,6 +564,8 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + self.sequence_parallel=sequence_parallel + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" @@ -893,20 +899,16 @@ def forward( image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process - # print("before conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - - # sample = _split_along_first_dim(sample) + if self.sequence_parallel: + sample = _split_along_first_dim(sample) + emb = _split_along_first_dim(emb) + encoder_hidden_states = _split_along_first_dim(encoder_hidden_states) sample = self.conv_in(sample) - # print("after conv_in", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None - def get_size(t): - ts = 4 if t.dtype==torch.float else 2 - return t.numel() * ts - # import pdb;pdb.set_trace() down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: @@ -931,9 +933,7 @@ def get_size(t): sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples - # print("after block:", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) - # print("down_block_res_samples total mem",sum([get_size(t) for t in down_block_res_samples])/1e6) - # import pdb;pdb.set_trace() + if is_controlnet: new_down_block_res_samples = () @@ -945,8 +945,6 @@ def get_size(t): down_block_res_samples = new_down_block_res_samples - # import pdb;pdb.set_trace() - # print(torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) # 4. mid if self.mid_block is not None: sample = self.mid_block( @@ -957,8 +955,6 @@ def get_size(t): cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) - # import pdb;pdb.set_trace() - # print("after mid block",torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9) if is_controlnet: sample = sample + mid_block_additional_residual @@ -996,6 +992,9 @@ def get_size(t): sample = self.conv_act(sample) sample = self.conv_out(sample) + if self.sequence_parallel: + sample = gather_from_sequence_parallel_region(sample) + if not return_dict: return (sample,) From 2f687782639769ca9074da037babb14476cfbf74 Mon Sep 17 00:00:00 2001 From: KimmiShi Date: Mon, 18 Sep 2023 15:30:02 +0800 Subject: [PATCH 8/9] Create train_unet.py --- examples/text_to_image/train_unet.py | 164 +++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 examples/text_to_image/train_unet.py diff --git a/examples/text_to_image/train_unet.py b/examples/text_to_image/train_unet.py new file mode 100644 index 000000000000..a07671cb937e --- /dev/null +++ b/examples/text_to_image/train_unet.py @@ -0,0 +1,164 @@ +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from apex.optimizers import FusedAdam +import torch +import torch.nn.functional as F +from accelerate import Accelerator +import time +import functools + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy + + +def enable_tf32(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +def batch_inp(bs20inp, target_bs): + mbs=2 + bs4inp = bs20inp[:mbs] + if target_bs10: + perf_times.append(time.time()-beg) + beg=time.time() + # prof.step() + + # if torch.distributed.get_rank()==0: + print("max mem", torch.cuda.max_memory_allocated()/1e9) + print(perf_times) + # prof.stop() + + + +enable_tf32() +# rank, world_size, port, addr=setup_distributed_slurm() + +pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" + +pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" +unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", revision=None, + low_cpu_mem_usage=False, device_map=None +).cuda() +unet.train() +# unet = unet.to(memory_format=torch.channels_last) + +vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").cuda() + + +# optimizer_class = FusedAdam +optimizer_class = functools.partial(torch.optim.Adam,fused = True) +# optimizer_class = torch.optim.AdamW +train(unet, vae, optimizer_class, '/mnt/petrelfs/share_data/shidongxing/tensors20/', 16, + use_amp=True, use_zero=False, h=512, w=512, is_xl ='xl' in pretrained_model_name_or_path) From 884d51329f1be449a11b5374e165ce45c468878c Mon Sep 17 00:00:00 2001 From: KimmiShi Date: Mon, 18 Sep 2023 15:30:37 +0800 Subject: [PATCH 9/9] Update train_unet.py --- examples/text_to_image/train_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_unet.py b/examples/text_to_image/train_unet.py index a07671cb937e..227fbd3dd691 100644 --- a/examples/text_to_image/train_unet.py +++ b/examples/text_to_image/train_unet.py @@ -39,7 +39,7 @@ def batch_inp(bs20inp, target_bs): -def train(model, vae, optimizer_class, path, batchsize, use_zero=False, use_amp=True, h=512, w=512, is_xl=False): +def train(model, vae, optimizer_class, batchsize, use_zero=False, use_amp=True, h=512, w=512, is_xl=False): timesteps = torch.arange(batchsize, dtype=torch.int64).cuda()+100 prompt_embeds = torch.rand([batchsize,77,768], dtype=torch.float16).cuda() time_ids = torch.rand([batchsize,6], dtype=torch.float16).cuda() @@ -160,5 +160,5 @@ def train(model, vae, optimizer_class, path, batchsize, use_zero=False, use_amp= # optimizer_class = FusedAdam optimizer_class = functools.partial(torch.optim.Adam,fused = True) # optimizer_class = torch.optim.AdamW -train(unet, vae, optimizer_class, '/mnt/petrelfs/share_data/shidongxing/tensors20/', 16, +train(unet, vae, optimizer_class, 16, use_amp=True, use_zero=False, h=512, w=512, is_xl ='xl' in pretrained_model_name_or_path)