From 5d75519888aa57bb9c4de155817210310ca60c49 Mon Sep 17 00:00:00 2001 From: zxgx Date: Thu, 26 Dec 2024 15:41:42 +0800 Subject: [PATCH 1/6] fix contiguous memory for opensora --- videosys/core/dcp/profiler.py | 30 ++++++++++++------- videosys/core/distributed/parallel_mgr.py | 1 + .../transformers/open_sora_transformer_3d.py | 2 +- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/videosys/core/dcp/profiler.py b/videosys/core/dcp/profiler.py index a91313b3..64741839 100644 --- a/videosys/core/dcp/profiler.py +++ b/videosys/core/dcp/profiler.py @@ -310,6 +310,7 @@ def _load_profile(self): self.latest_raw_result = None self.raw_results = [] self.dp_results = [] + self.sp_detail_results = [] logging.info(f"Profile results: {pformat(self.profile_results, sort_dicts=False)}") if self.dynamic_sp and not self.dynamic_recompute and not self.auto_grad_acc: @@ -635,7 +636,7 @@ def profile(self, batch, model, gas): if self.auto_grad_acc: self.dp_results.append(result_row) else: - self.latest_raw_result = result_row + self.sp_detail_results.append(result_row) self.detail_results.append(result_row) @@ -648,14 +649,23 @@ def profile(self, batch, model, gas): self.next_warmup_iter = not self.auto_grad_acc else: if not self.dynamic_recompute and not self.auto_grad_acc: - if bs == 1: - if self.logger: - self.logger.info( - f">>> [Profiling] bucket {ar_name} {num_frame} cannot fit into sp: {sp_size}" - ) + if bs == 1 and self.logger: + self.logger.info( + f">>> [Profiling] bucket {ar_name} {num_frame} cannot fit into sp: {sp_size}" + ) else: - assert self.latest_raw_result is not None - self.dp_results.append(self.latest_raw_result) + last = self.sp_detail_results[-1] + throughput = last[2] / last[3] / last[4] + if len(self.sp_detail_results)>1: + prev = self.sp_detail_results[-2] + prev_throughput = prev[2] / prev[3] / prev[4] + if prev_throughput > throughput: + self.dp_results.append(prev) + else: + self.dp_results.append(last) + else: + self.dp_results.append(last) + self.sp_detail_results = [] if sp_size < self.max_sp: self.next_sp_size = sp_size * 2 @@ -735,13 +745,13 @@ def profile(self, batch, model, gas): pred_full_time, pred_full_mem = self.estimate_overhead(self.latest_raw_result) cur_throughput = bs / sp_size / pred_full_time - if len(self.dp_results) > 0: + if len(self.dp_results) > 1: prev_row = self.dp_results[-2] prev_time, prev_mem = self.estimate_overhead(prev_row) throughput = prev_row.bs / prev_row.sp_size / prev_time # override for empty cache operation caused slow down - if (throughput / cur_throughput) > 2: + if (throughput / cur_throughput) > 1.5: bs = prev_row.bs sp_size = prev_row.sp_size pred_full_time = prev_time diff --git a/videosys/core/distributed/parallel_mgr.py b/videosys/core/distributed/parallel_mgr.py index 50a4a4a0..d8a0be75 100644 --- a/videosys/core/distributed/parallel_mgr.py +++ b/videosys/core/distributed/parallel_mgr.py @@ -125,6 +125,7 @@ def set_distributed_state(distributed_profile=None): node_rank = int(os.getenv("NODE_RANK", os.getenv("OMPI_COMM_WORLD_NODE_RANK", "0"))) node_size = int(os.getenv("NNODES", "1")) + print(f">>> [Distributed] Rank: {rank}/{world_size}, local rank: {os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', None)}") if distributed_profile: "launch multiple single-node instances for fast profile" assert world_size % device_count == 0 diff --git a/videosys/models/transformers/open_sora_transformer_3d.py b/videosys/models/transformers/open_sora_transformer_3d.py index 33a622cb..1bbf2b1b 100644 --- a/videosys/models/transformers/open_sora_transformer_3d.py +++ b/videosys/models/transformers/open_sora_transformer_3d.py @@ -590,7 +590,7 @@ def forward( y, y_lens = self.encode_text(y, mask) # === get x embed === - x = self.x_embedder(x) # [B, N, C] + x = self.x_embedder(x).contiguous() # [B, N, C] x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) x = x + pos_emb From 73770395144d46577682e7116f894521f5049ed0 Mon Sep 17 00:00:00 2001 From: zxgx Date: Mon, 30 Dec 2024 16:13:58 +0800 Subject: [PATCH 2/6] add dcp support for cogvidex-5b, pass profiling --- .../configs/benchmarks/baseline.yaml | 68 + .../configs/benchmarks/dcp_inter.yaml | 71 ++ .../configs/benchmarks/dcp_inter_ckpt.yaml | 71 ++ .../configs/benchmarks/dcp_intra.yaml | 70 ++ examples/training/cogvideox/train.py | 514 ++++++++ videosys/core/dcp/profiler.py | 140 ++- .../transformers/cogvideox_transformer_3d.py | 67 +- .../schedulers/scheduling_dpm_cogvideox.py | 79 ++ .../training/datasets/cogvideox/__init__.py | 0 .../training/datasets/cogvideox/aspect.py | 642 ++++++++++ .../training/datasets/cogvideox/bucket.py | 151 +++ .../training/datasets/cogvideox/dataloader.py | 138 +++ .../training/datasets/cogvideox/datasets.py | 549 +++++++++ .../training/datasets/cogvideox/read_video.py | 258 ++++ .../training/datasets/cogvideox/sampler.py | 1098 +++++++++++++++++ videosys/training/datasets/cogvideox/utils.py | 363 ++++++ .../datasets/cogvideox/video_transforms.py | 520 ++++++++ 17 files changed, 4733 insertions(+), 66 deletions(-) create mode 100644 examples/training/cogvideox/configs/benchmarks/baseline.yaml create mode 100644 examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml create mode 100644 examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml create mode 100644 examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml create mode 100644 examples/training/cogvideox/train.py create mode 100644 videosys/training/datasets/cogvideox/__init__.py create mode 100644 videosys/training/datasets/cogvideox/aspect.py create mode 100644 videosys/training/datasets/cogvideox/bucket.py create mode 100644 videosys/training/datasets/cogvideox/dataloader.py create mode 100644 videosys/training/datasets/cogvideox/datasets.py create mode 100644 videosys/training/datasets/cogvideox/read_video.py create mode 100644 videosys/training/datasets/cogvideox/sampler.py create mode 100644 videosys/training/datasets/cogvideox/utils.py create mode 100644 videosys/training/datasets/cogvideox/video_transforms.py diff --git a/examples/training/cogvideox/configs/benchmarks/baseline.yaml b/examples/training/cogvideox/configs/benchmarks/baseline.yaml new file mode 100644 index 00000000..25764076 --- /dev/null +++ b/examples/training/cogvideox/configs/benchmarks/baseline.yaml @@ -0,0 +1,68 @@ +zipf_offset: 5 +outputs: exp/cogvideox/baseline +profile_path: exp/cogvideox/profile/baseline +sp_size: 1 +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true + + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: false +drop_last: true + +# train +ckpt_path: "THUDM/CogVideoX-5b" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 720]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml new file mode 100644 index 00000000..aa6b6ee4 --- /dev/null +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml @@ -0,0 +1,71 @@ +zipf_offset: 5 +outputs: exp/cogvideox/dcp_inter +profile_path: exp/cogvideox/profile/dcp_inter +dynamic_sp: true +dynamic_recompute: false +auto_grad_accumulation: true +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true +max_grad_accumulation_steps: 5 + + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: true + +# train +ckpt_path: "THUDM/CogVideoX-5b" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 720]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml new file mode 100644 index 00000000..a592b2b3 --- /dev/null +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml @@ -0,0 +1,71 @@ +zipf_offset: 5 +outputs: exp/cogvideox/dcp_inter_ckpt +profile_path: exp/cogvideox/profile/dcp_inter_ckpt +dynamic_sp: true +dynamic_recompute: true +auto_grad_accumulation: true +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true +max_grad_accumulation_steps: 5 +min_grad_accumulation_steps: 15 + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: true + +# train +ckpt_path: "THUDM/CogVideoX-5b" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 720]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml new file mode 100644 index 00000000..5d8011ef --- /dev/null +++ b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml @@ -0,0 +1,70 @@ +zipf_offset: 5 +outputs: exp/cogvideox/dcp_intra +profile_path: exp/cogvideox/profile/dcp_intra +dynamic_sp: true +dynamic_recompute: false +auto_grad_accumulation: false +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true + + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: true + +# train +ckpt_path: "THUDM/CogVideoX-5b" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 720]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/cogvideox/train.py b/examples/training/cogvideox/train.py new file mode 100644 index 00000000..33c95fa2 --- /dev/null +++ b/examples/training/cogvideox/train.py @@ -0,0 +1,514 @@ +import argparse +import logging +import os +from datetime import timedelta +from pprint import pformat + +import deepspeed +import torch +import torch.distributed as dist +import wandb +from omegaconf import OmegaConf +from tqdm import tqdm +from transformers import AutoTokenizer, T5EncoderModel + +from videosys.core.dcp.profiler import Profiler, set_profiler +from videosys.core.distributed.parallel_mgr import DynamicParallelManager, ParallelManager, set_distributed_state +from videosys.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel +from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler +from videosys.training.ckpt_io import load, save, save_training_config +from videosys.training.datasets.cogvideox.dataloader import prepare_dataloader +from videosys.training.datasets.cogvideox.datasets import DummyVariableVideoTextDataset, VariableVideoTextDataset +from videosys.training.lr_schedulers.linear_warmup_open_sora import LinearWarmupLR +from videosys.utils.logging import init_logger +from videosys.utils.training import ( + all_reduce_mean, + define_experiment_workspace, + format_numel_str, + get_model_numel, + requires_grad, +) +from videosys.utils.utils import merge_args, set_seed, str_to_dtype + + +def main(args): + # ====================================================== + # 1. configs & runtime variables + # ====================================================== + # == device and dtype == + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + assert args.dtype in ["fp16", "bf16"], f"Unknown mixed precision {args.dtype}" + dtype = str_to_dtype(args.dtype) + + # == init distributed training == + rank, world_size, node_rank, node_size = set_distributed_state(args.distributed_profile) + dist.init_process_group( + rank=rank, + world_size=world_size, + backend="nccl", + timeout=timedelta(minutes=10), + ) + deepspeed.init_distributed(timeout=timedelta(seconds=10)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(args.seed) + device = torch.cuda.current_device() + + # == init exp_dir == + exp_name, exp_dir = define_experiment_workspace(args.outputs) + dist.barrier() + if dist.get_rank() == 0: + os.makedirs(exp_dir, exist_ok=True) + save_training_config(vars(args), exp_dir) + dist.barrier() + + # == init logger, tensorboard & wandb == + init_logger(exp_dir) + logging.info(f"Experiment directory created at {exp_dir}") + logging.info(f"Training configuration:\n {pformat(vars(args))}") + if dist.get_rank() == 0: + if args.wandb: + wandb.init(project="Open-Sora", name=exp_name, config=vars(args), dir="./outputs/wandb") + + # == init parallel manager == + torch.set_num_threads(1) + if args.dynamic_sp: + parallel_mgr = DynamicParallelManager() + else: + parallel_mgr = ParallelManager(dist.get_world_size() // args.sp_size, 1, args.sp_size) + preprocessed_data = args.preprocessed_data + if args.profile_path is None or not os.path.exists(args.profile_path): + do_profile = True + preprocessed_data = True + logging.info( + f"[ATTENTION!] Profile file is not found at `{args.profile_path}`! Profiling will be performed then exit." + ) + else: + do_profile = False + + # import pdb + # if torch.distributed.get_rank() == 0: + # pdb.set_trace() + + # ====================================================== + # 2. build model + # ====================================================== + logging.info("Building models...") + + model_path = args.ckpt_path + # == build text-encoder and vae == + if not preprocessed_data: + text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder", torch_dtype=dtype + ).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, subfolder="vae", torch_dtype=dtype + ).to(device).eval() + vae.enable_slicing() + vae.enable_tiling() + + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + else: + vae_scale_factor_spatial = 2 ** 3 + + # == build diffusion model == + model = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer", torch_dtype=dtype, + # override for 720p, 408 frame + sample_height=90, sample_width=160, sample_frames=409, + ).to(device).train() + model_numel, model_numel_trainable = get_model_numel(model) + logging.info( + f"[Diffusion] Trainable model params: {format_numel_str(model_numel_trainable)}, " + f"Total model params: {format_numel_str(model_numel)}", + ) + + # == setup loss function, build scheduler == + scheduler = CogVideoXDPMScheduler.from_pretrained( + model_path, + subfolder="scheduler", + ) + + # == setup optimizer == + optimizer = torch.optim.AdamW( + filter(lambda p: p.requires_grad, model.parameters()), + lr=args.lr, + weight_decay=args.weight_decay, + eps=args.adam_eps, + ) + + # == setup learning rate scheduler == + warmup_steps = args.warmup_steps + if warmup_steps is None: + lr_scheduler = None + else: + lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=args.warmup_steps) + + # == additional preparation == + if args.grad_checkpoint: + model.enable_grad_checkpointing() + model.enable_parallel(parallel_mgr=parallel_mgr) + + # ====================================================== + # 3. build dataset and dataloader + # ====================================================== + logging.info("Building dataset...") + # create dcp profiler + # TODO: scheduler is a better name? + profiler: Profiler = set_profiler( + model_type=model.config._name_or_path, + total_layers=model.config.num_layers, + bucket_config=args.bucket_config, + text_max_seq_len=model.config.max_text_seq_length, + text_hidden_size=model.config.text_embed_dim, + global_interpolation=not args.no_global_interpolation, + dynamic_sp=args.dynamic_sp, + dynamic_recompute=args.dynamic_recompute, + auto_grad_acc=args.auto_grad_accumulation, + do_profile=do_profile, + distributed_profile=args.distributed_profile, + node_rank=node_rank, + node_size=node_size, + alloc_fraction=args.alloc_memory_fraction, + profile_path=args.profile_path, + parallel_mgr=parallel_mgr, + verbose=args.verbose, + ) + + # == build dataset == + if args.dummy_dataset: + dataset = DummyVariableVideoTextDataset( + data_size=args.dummy_data_size, + seed=args.seed, + data_path=args.data_path, + transform_name="resize_crop", + preprocessed_data=preprocessed_data, + bucket_config=args.bucket_config, + common_ar=args.common_ar, + distribution=args.distribution, + zipf_offset=args.zipf_offset, + image_mixing_type=args.image_mixing_type, + image_mixing_frac=args.image_mixing_frac, + text_max_seq_len=model.config.max_text_seq_length, + text_hidden_size=model.config.text_embed_dim, + ) + else: + dataset = VariableVideoTextDataset( + transform_name="resize_crop", data_path=args.data_path, preprocessed_data=preprocessed_data + ) + logging.info(f"Dataset contains {len(dataset)} samples.") + + # == build dataloader == + dataloader, sampler = prepare_dataloader( + dataset=dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + seed=args.seed, + shuffle=True, + drop_last=args.drop_last, + process_group=parallel_mgr.dp_group, + prefetch_factor=args.prefetch_factor, + auto_grad_accumulation=args.auto_grad_accumulation, + bucket_config=args.bucket_config, + num_bucket_build_workers=args.num_bucket_build_workers, + parallel_mgr=parallel_mgr, + calculate_imbalance=args.calculate_imbalance, + verbose=args.verbose, + max_grad_accumulation_steps=args.max_grad_accumulation_steps, + min_grad_accumulation_steps=args.min_grad_accumulation_steps, + ) + + # ======================================================= + # 4. distributed training preparation + # ======================================================= + logging.info("Preparing for distributed training...") + # == boosting == + # we set dtype first to make initialization of model consistent with the dtype + # then reset it to the fp32 as we make diffusion scheduler in fp32 + torch.set_default_dtype(dtype) + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1e8, # dont print + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 1, + "reduce_scatter": True, + "allgather_bucket_size": 5e8, + "reduce_bucket_size": 5e8, + "overlap_comm": True, + "contiguous_gradients": True, + }, + "bf16": {"enabled": True}, + } + # Initialize the model, optimizer, and lr scheduler + model, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + ) + torch.set_default_dtype(torch.float) + logging.info("Boosting model for distributed training") + profiler.register_modules( + { + "layer": model.module.transformer_blocks, + } + ) + + start_epoch = start_step = log_step = acc_step = 0 + # TODO: resume functionality should consider the profiler status + # == resume == + # if args.load is not None: + # logging.info("Loading checkpoint") + # ret = load( + # args.load, + # model=model, + # ema=ema, + # sampler=None if args.start_from_scratch else sampler, + # ) + # if not args.start_from_scratch: + # start_epoch, start_step = ret + # logging.info(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") + + # == global variables == + if do_profile: + start_epoch, cfg_epochs = 0, 1 + else: + cfg_epochs = args.epochs + running_loss = 0.0 + logging.info(f"Training for {cfg_epochs} epochs{' with profiling' if profiler.need_profile() else ''}.") + + # ======================================================= + # 5. training loop + # ======================================================= + dist.barrier() + token_counter = torch.zeros((1,), dtype=torch.double, device=device) + + for epoch in range(start_epoch, cfg_epochs): + local_token_counter = 0.0 + if profiler.need_profile(): + # TODO: add timer for profile + disable = True + num_steps_per_epoch = None + dataloader_iter = profiler.get_data_iter() + epoch_desc = "Profiling" + profiler.init_profiler() + else: + # == set dataloader to new epoch == + sampler.set_epoch(epoch) + disable = not dist.get_rank() == 0 + num_steps_per_epoch = len(dataloader) + dataloader_iter = iter(dataloader) + epoch_desc = f"Epoch {epoch}" + logging.info(f"Beginning {epoch_desc}...") + + # == training loop in an epoch == + pbar = tqdm( + enumerate(dataloader_iter, start=start_step), + desc=epoch_desc, + disable=disable, + initial=start_step, + total=num_steps_per_epoch, + ) + for step, batch in pbar: + # TODO: more elegant here + profiler.optimize_dynamics(batch, model) + + total_gas = batch["gas"] + iter_loss = 0.0 + + for gas in range(total_gas): + with profiler.profile(batch, model, gas) as valid_depth: + batch_data = batch["data"][gas] + height = batch_data.pop("height")[0].item() + width = batch_data.pop("width")[0].item() + + if preprocessed_data: + # move data + x = batch_data.pop("video").permute(0, 2, 1, 3, 4).to(device, dtype) # [B, T, C, H, W] + y = batch_data.pop("text").to(device, dtype) + else: + raise NotImplementedError("Not implemented for non-preprocessed data") + # with torch.no_grad(): + # x = batch_data.pop("video").to(device, dtype) # [B, C, T, H, W] + # y = batch_data.pop("text") + # # Prepare visual inputs + # x = vae.encode(x) # [B, C, T, H/P, W/P] + # # Prepare text inputs + # model_args = encode_prompt(text_encoder, tokenizer, y) + # for k, v in batch_data.items(): + # if isinstance(v, torch.Tensor): + # model_args[k] = v.to(device, dtype) + # # TODO: polish + model_args = dict(valid_depth=valid_depth) + + # mask + # mask = None + # if mask_generator is not None: + # mask = mask_generator.get_masks(x) + # model_args["x_mask"] = mask + + # diffusion + loss_dict = scheduler.training_losses(model, x, y, height, width, vae_scale_factor_spatial, model_args) + + # backward + profiler.set_gradient_accumulation_boundary(model, batch, gas) + + loss = loss_dict["loss"].mean() + model.backward(loss) + + model.step() + if lr_scheduler is not None: + lr_scheduler.step() + + iter_loss += loss.detach() + + if profiler.need_profile(): + continue + + # == update EMA == + # update_ema(ema, model.module, decay=args.ema_decay) + + # == update log info == + all_reduce_mean(iter_loss) + iter_loss = iter_loss.item() / total_gas + running_loss += iter_loss + global_step = epoch * num_steps_per_epoch + step + log_step += 1 + acc_step += 1 + + # == logging == + if dist.get_rank() == 0 and (global_step + 1) % args.log_every == 0: + avg_loss = running_loss / log_step + # progress bar + pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step}) + # wandb + if args.wandb: + wandb.log( + { + "iter": global_step, + "acc_step": acc_step, + "epoch": epoch, + "loss": iter_loss, + "avg_loss": avg_loss, + "lr": optimizer.param_groups[0]["lr"], + }, + step=global_step, + ) + + running_loss = 0.0 + log_step = 0 + + # == checkpoint saving == + # if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0: + # ema_gathering(model.module, ema) + # save_dir = save( + # save_dir=exp_dir, + # save_optimizer=args.save_optimizer, + # model=model, + # ema=ema, + # sampler=sampler, + # epoch=epoch, + # step=step + 1, + # global_step=global_step + 1, + # batch_size=args.batch_size, + # ) + # ema_sharding(model.module, ema) + # logging.info( + # f"Saved checkpoint at epoch {epoch}, step {step + 1}, global_step {global_step + 1} to {save_dir}" + # ) + + token_counter.fill_(local_token_counter) + dist.all_reduce(token_counter) + if rank == 0 and not disable: + elapsed_time = pbar.format_dict["elapsed"] + logging.info( + f"Epoch {epoch}: steps: {num_steps_per_epoch} elapsed time: {elapsed_time:.2f} s" + f", effective samples: {sampler.effective_samples}" + f", sample throughput: {sampler.effective_samples / elapsed_time:.2f} samples/s" + f", token throughput: {token_counter.item()/elapsed_time:.2f} token/s" + ) + + sampler.reset() + start_step = 0 + dist.barrier() + + if do_profile: + logging.info( + f"Profiling is done and saved to {args.profile_path}. Please restart this programe for training with " + f"`profile_path: {args.profile_path}` in the config file. Exiting..." + ) + else: + logging.info("Training is done. Exiting...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # model config + parser.add_argument("config", help="model config file path") + + parser.add_argument("--seed", default=1024, type=int, help="seed for reproducibility") + parser.add_argument("--batch-size", default=None, type=int, help="batch size") + parser.add_argument("--outputs", default="./outputs", type=str, help="the dir to save model weights") + parser.add_argument("--data-path", default=None, type=str, help="path to data csv") + parser.add_argument("--dtype", default="bf16", type=str, help="data type") + parser.add_argument("--grad-clip", default=0, type=float, help="gradient clipping") + parser.add_argument("--sp-size", default=1, type=int, help="sequence parallelism size") + parser.add_argument("--reduce-bucket-size-in-m", default=20, type=int, help="reduce bucket size in MB") + parser.add_argument("--epochs", default=100, type=int, help="number of epochs") + parser.add_argument("--num-workers", default=4, type=int, help="number of workers") + parser.add_argument("--prefetch-factor", default=2, type=int, help="prefetch factor") + parser.add_argument("--bucket-config", default=None, type=str, help="bucket config") + parser.add_argument("--num-bucket-build-workers", default=1, type=int, help="number of bucket build workers") + parser.add_argument("--weight-decay", default=0, type=float, help="weight decay") + parser.add_argument("--adam-eps", default=1e-8, type=float, help="adam epsilon") + parser.add_argument("--grad-checkpoint", default=False, action="store_true", help="gradient checkpoint") + parser.add_argument("--mask-ratios", default=None, type=str, help="mask ratios") + parser.add_argument("--ema-decay", default=0.99, type=float, help="ema decay") + parser.add_argument("--log-every", default=1, type=int, help="log every") + parser.add_argument("--ckpt-every", default=-1, type=int, help="checkpoint every") + parser.add_argument("--ckpt-path", default="hpcai-tech/OpenSora-STDiT-v3", type=str, help="path to model ckpt") + + parser.add_argument("--lr", default=1e-4, type=float, help="learning rate") + parser.add_argument("--wandb", default=False, action="store_true", help="enable wandb") + parser.add_argument("--load", default=None, type=str, help="path to continue training") + parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch") + parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps") + parser.add_argument("--verbose", action="store_true", help="verbose") + parser.add_argument("--save-optimizer", action="store_true", help="save optimizer") + + # experimental features + parser.add_argument("--drop-last", action="store_true") + parser.add_argument("--dummy-dataset", action="store_true") + parser.add_argument("--dummy-data-size", default=100, type=int) + parser.add_argument("--common-ar", type=dict, default=None) + parser.add_argument("--preprocessed-data", action="store_true") + parser.add_argument("--image-mixing-type", default="exclusive", type=str, choices=["inclusive", "exclusive"]) + parser.add_argument("--image-mixing-frac", default=1, type=float) + parser.add_argument("--distribution", default="zipf", type=str, choices=["zipf", "uniform"]) + parser.add_argument("--zipf-offset", type=int, default=5) + parser.add_argument("--no-global-interpolation", action="store_true") + parser.add_argument("--dynamic-sp", action="store_true") + parser.add_argument("--dynamic-recompute", action="store_true") + parser.add_argument("--auto-grad-accumulation", action="store_true") + parser.add_argument( + "--alloc-memory-fraction", + default=0.70, + type=float, + help="This is an empirical value to cap the allocated memory during profiling with dynamic sp. Communication in different ranks can cause free memory discrepancy, which can leads to comm deadlock. So you need to leave enough space to bear this discrepancy. If you meet this problem during profiling, try to decrease this value.", + ) + parser.add_argument("--profile-path", default="exp/profile", type=str) + parser.add_argument("--distributed-profile", action="store_true") + parser.add_argument("--calculate-imbalance", action="store_true") + parser.add_argument("--max-grad-accumulation-steps", default=3, type=int) + parser.add_argument("--min-grad-accumulation-steps", default=2, type=int) + + args = parser.parse_args() + config_args = OmegaConf.load(args.config) + args = merge_args(args, config_args) + + main(args) diff --git a/videosys/core/dcp/profiler.py b/videosys/core/dcp/profiler.py index 64741839..0675611c 100644 --- a/videosys/core/dcp/profiler.py +++ b/videosys/core/dcp/profiler.py @@ -15,11 +15,27 @@ from videosys.core.dcp.recompute import disable_profile, enable_profile, get_profile_context from videosys.core.distributed.parallel_mgr import DynamicParallelManager -from videosys.training.datasets.open_sora.aspect import ASPECT_RATIOS, DEFAULT_AR_MAP from videosys.utils.training import GroupTimer, set_grad_accumulation_steps PROFILER = None GB = 1024**3 +BATCH_SYHTHESIZER = None +LOCAL_ASPECT_RATIOS = None +LOCAL_DEFAULT_AR_MAP = None + +def setup_batch_synthesizer(model_type): + global BATCH_SYHTHESIZER, LOCAL_ASPECT_RATIOS, LOCAL_DEFAULT_AR_MAP + if "OpenSora" in model_type: + BATCH_SYHTHESIZER = open_sora_synthesizer + from videosys.training.datasets.open_sora.aspect import ASPECT_RATIOS, DEFAULT_AR_MAP + LOCAL_ASPECT_RATIOS = ASPECT_RATIOS + LOCAL_DEFAULT_AR_MAP = DEFAULT_AR_MAP + + elif "CogVideoX" in model_type: + BATCH_SYHTHESIZER = cogvideox_synthesizer + from videosys.training.datasets.cogvideox.aspect import ASPECT_RATIOS, DEFAULT_AR_MAP + LOCAL_ASPECT_RATIOS = ASPECT_RATIOS + LOCAL_DEFAULT_AR_MAP = DEFAULT_AR_MAP def clean_cache(): @@ -91,14 +107,85 @@ def to_list(self): return ret +def open_sora_synthesizer(data_plan, auto_grad_acc, data_idx, text_max_seq_len, text_hidden_size): + height, width = LOCAL_DEFAULT_AR_MAP[data_plan.ar_name] + nf = 1 + if data_plan.num_frame > 1: + nf = data_plan.num_frame * 5 // 17 + + ret = dict( + ar_name=data_plan.ar_name, + num_frame=data_plan.num_frame, + sp_size=data_plan.sp_size, + gas=data_plan.gas, + data=[], + profile_grad_acc=auto_grad_acc and data_idx > 0, + ) + + for _ in range(data_plan.gas): + ret["data"].append( + dict( + video=torch.rand(data_plan.bs, 4, nf, height // 8, width // 8), + text=torch.rand( + data_plan.bs, + 1, + text_max_seq_len, + text_hidden_size, + ), + mask=torch.ones(data_plan.bs, text_max_seq_len, dtype=torch.long), + num_frames=torch.tensor([data_plan.num_frame] * data_plan.bs), + height=torch.tensor([height] * data_plan.bs), + width=torch.tensor([width] * data_plan.bs), + fps=torch.tensor([24 if data_plan.num_frame > 1 else 120] * data_plan.bs), + ar=torch.tensor([height / width] * data_plan.bs), + plan_idx=data_idx, + warmup_iter=data_plan.warmup_iter, + ) + ) + + return ret + + +def cogvideox_synthesizer(data_plan, auto_grad_acc, data_idx, text_max_seq_len, text_hidden_size): + height, width = LOCAL_DEFAULT_AR_MAP[data_plan.ar_name] + nf = max(1, data_plan.num_frame // 4) + + ret = dict( + ar_name=data_plan.ar_name, + num_frame=data_plan.num_frame, + sp_size=data_plan.sp_size, + gas=data_plan.gas, + data=[], + profile_grad_acc=auto_grad_acc and data_idx > 0, + ) + + for _ in range(data_plan.gas): + ret["data"].append( + dict( + video=torch.rand(data_plan.bs, 16, nf, height // 8, width // 8), + text=torch.rand( + data_plan.bs, + text_max_seq_len, + text_hidden_size, + ), + height=torch.tensor([height] * data_plan.bs), + width=torch.tensor([width] * data_plan.bs), + plan_idx=data_idx, + warmup_iter=data_plan.warmup_iter, + ) + ) + + return ret + + class ProfileDataIter: - def __init__(self, profiler): + def __init__(self, profiler, init_bucket): self.profiler: Profiler = profiler self.data_plan = [ DataPlan( - ar_name="144p", - num_frame=51, + ar_name=init_bucket[0], + num_frame=init_bucket[1], sp_size=self.profiler.max_sp, gas=1, bs=1, @@ -113,42 +200,11 @@ def __iter__(self): data_idx = self.next_idx self.next_idx += 1 - height, width = DEFAULT_AR_MAP[data_plan.ar_name] - nf = 1 - if data_plan.num_frame > 1: - nf = data_plan.num_frame * 5 // 17 - - ret = dict( - ar_name=data_plan.ar_name, - num_frame=data_plan.num_frame, - sp_size=data_plan.sp_size, - gas=data_plan.gas, - data=[], - profile_grad_acc=self.profiler.auto_grad_acc and data_idx > 0, + yield BATCH_SYHTHESIZER( + data_plan, self.profiler.auto_grad_acc, data_idx, + self.profiler.text_max_seq_len, self.profiler.text_hidden_size ) - for _ in range(data_plan.gas): - ret["data"].append( - dict( - video=torch.rand(data_plan.bs, 4, nf, height // 8, width // 8), - text=torch.rand( - data_plan.bs, - 1, - self.profiler.text_max_seq_len, - self.profiler.text_hidden_size, - ), - mask=torch.ones(data_plan.bs, self.profiler.text_max_seq_len, dtype=torch.long), - num_frames=torch.tensor([data_plan.num_frame] * data_plan.bs), - height=torch.tensor([height] * data_plan.bs), - width=torch.tensor([width] * data_plan.bs), - fps=torch.tensor([24 if data_plan.num_frame > 1 else 120] * data_plan.bs), - ar=torch.tensor([height / width] * data_plan.bs), - plan_idx=data_idx, - warmup_iter=data_plan.warmup_iter, - ) - ) - yield ret - if self.profiler.has_next_data_plan(): self.data_plan.append(self.profiler.next_data_plan()) self.profiler.finalize_profile() @@ -157,6 +213,7 @@ def __iter__(self): class Profiler: def __init__( self, + model_type, total_layers, bucket_config, text_max_seq_len, @@ -176,6 +233,7 @@ def __init__( profile_depth=2, parallel_mgr=None, ): + setup_batch_synthesizer(model_type) self.total_layers = total_layers # [(ar_name, num_frame)] @@ -183,7 +241,7 @@ def __init__( for ar_name in bucket_config: for num_frame in bucket_config[ar_name]: self.bucket_config.append((ar_name, num_frame)) - self.bucket_config = sorted(self.bucket_config, key=lambda x: ASPECT_RATIOS[x[0]][0] * x[1], reverse=True) + self.bucket_config = sorted(self.bucket_config, key=lambda x: LOCAL_ASPECT_RATIOS[x[0]][0] * x[1], reverse=True) self.text_max_seq_len = text_max_seq_len self.text_hidden_size = text_hidden_size @@ -387,7 +445,7 @@ def get_recompute_cfg(self, ar_name, num_frame): ############################################################ # Key functionality: profiling and planning for bs, sp size, and recompute cfg def get_data_iter(self): - return ProfileDataIter(self) + return ProfileDataIter(self, self.bucket_config[-1]) def has_next_data_plan(self): "Move to next bucket" @@ -869,6 +927,7 @@ def update_timer_group(self): def set_profiler( + model_type, total_layers, bucket_config, text_max_seq_len, @@ -888,6 +947,7 @@ def set_profiler( ) -> Profiler: global PROFILER PROFILER = Profiler( + model_type=model_type, total_layers=total_layers, bucket_config=bucket_config, text_max_seq_len=text_max_seq_len, diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py index 3b995209..5abe5c0b 100644 --- a/videosys/models/transformers/cogvideox_transformer_3d.py +++ b/videosys/models/transformers/cogvideox_transformer_3d.py @@ -25,6 +25,7 @@ from videosys.core.distributed.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence from videosys.core.distributed.parallel_mgr import ParallelManager from videosys.core.pab.pab_mgr import enable_pab, if_broadcast_spatial +from videosys.core.dcp.recompute import auto_recompute from videosys.models.modules.embeddings import apply_rotary_emb from videosys.utils.utils import batch_func @@ -265,6 +266,8 @@ def __init__( self.last_attn = None self.block_idx = block_idx + self.grad_checkpointing = True + def forward( self, hidden_states: torch.Tensor, @@ -459,15 +462,18 @@ def __init__( # parallel self.parallel_manager = None - def enable_parallel(self, dp_size, sp_size, enable_cp): - # update cfg parallel - if enable_cp and sp_size % 2 == 0: - sp_size = sp_size // 2 - cp_size = 2 + def enable_parallel(self, dp_size=None, sp_size=None, enable_cp=None, parallel_mgr=None): + if parallel_mgr: + self.parallel_manager = parallel_mgr else: - cp_size = 1 + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 - self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size) + self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size) for _, module in self.named_modules(): if hasattr(module, "parallel_manager"): @@ -476,6 +482,10 @@ def enable_parallel(self, dp_size, sp_size, enable_cp): def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value + def enable_grad_checkpointing(self): + for block in self.transformer_blocks: + block.grad_checkpointing = True + def forward( self, hidden_states: torch.Tensor, @@ -484,6 +494,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, + **kwargs, ): if self.parallel_manager.cp_size > 1: ( @@ -500,8 +511,16 @@ def forward( timestep_cond, image_rotary_emb, ) - + batch_size, num_frames, channels, height, width = hidden_states.shape + height_pad, width_pad = 0, 0 + if height % self.config.patch_size != 0: + height_pad = height % self.config.patch_size + height = height + self.config.patch_size - height_pad + if width % self.config.patch_size != 0: + width_pad = width % self.config.patch_size + width = width + self.config.patch_size - width_pad + hidden_states = F.pad(hidden_states, (0, width_pad, 0, height_pad), value=0) # 1. Time embedding timesteps = timestep @@ -519,6 +538,7 @@ def forward( # 3. Position embedding text_seq_length = encoder_hidden_states.shape[1] if not self.config.use_rotary_positional_embeddings: + # TODO: fix odd dims for CogVideoX-2b seq_length = height * width * num_frames // (self.config.patch_size**2) pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] @@ -533,23 +553,12 @@ def forward( hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) # 4. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, + valid_depth = kwargs.get("valid_depth", len(self.transformer_blocks)) + for i in range(valid_depth): + block = self.transformer_blocks[i] + if self.training: + hidden_states, encoder_hidden_states = auto_recompute( + block, hidden_states, encoder_hidden_states, emb, image_rotary_emb ) else: hidden_states, encoder_hidden_states = block( @@ -579,8 +588,14 @@ def custom_forward(*inputs): # 6. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) + # b, f, c, h, w output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - + if height_pad > 0: + # unpad + output = output[:, :, :, :-height_pad] + if width_pad > 0: + output = output[:, :, :, :, :-width_pad] + if self.parallel_manager.cp_size > 1: output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) diff --git a/videosys/schedulers/scheduling_dpm_cogvideox.py b/videosys/schedulers/scheduling_dpm_cogvideox.py index 3209dbe7..2410cf36 100644 --- a/videosys/schedulers/scheduling_dpm_cogvideox.py +++ b/videosys/schedulers/scheduling_dpm_cogvideox.py @@ -17,6 +17,8 @@ import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.utils import BaseOutput from diffusers.utils.torch_utils import randn_tensor @@ -116,6 +118,39 @@ def rescale_zero_terminal_snr(alphas_cumprod): return alphas_bar +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + patch_size_t: int = 1, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = (height // vae_scale_factor_spatial + patch_size - 1) // patch_size + grid_width = (width // vae_scale_factor_spatial + patch_size - 1) // patch_size + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + p_t = patch_size_t + base_num_frames = (num_frames + p_t - 1) // p_t + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): """ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with @@ -481,3 +516,47 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor def __len__(self): return self.config.num_train_timesteps + + def training_losses(self, model, model_input, prompt_embeds, height, width, vae_scale_factor_spatial, model_args): + model_config = model.module.config + noise = torch.randn_like(model_input) + batch_size, num_frames = model_input.shape[:2] + timesteps = torch.randint( + 0, self.config.num_train_timesteps, (batch_size,), + device=model_input.device, dtype=torch.long, + ) + + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height, + width=width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + #model_config.patch_size_t if model_config.patch_size_t is not None else 1, + patch_size_t=getattr(model_config, "patch_size_t", 1), + attention_head_dim=model_config.attention_head_dim, + device=torch.cuda.current_device(), + ) + if model_config.use_rotary_positional_embeddings + else None + ) + noisy_model_input = self.add_noise(model_input, noise, timesteps) + + model_output = model( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + **model_args)[0] + model_pred = self.get_velocity(model_output, noisy_model_input, timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + return {"loss": loss} diff --git a/videosys/training/datasets/cogvideox/__init__.py b/videosys/training/datasets/cogvideox/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/videosys/training/datasets/cogvideox/aspect.py b/videosys/training/datasets/cogvideox/aspect.py new file mode 100644 index 00000000..85bfc156 --- /dev/null +++ b/videosys/training/datasets/cogvideox/aspect.py @@ -0,0 +1,642 @@ +import math + + +# computation +def get_h_w(a, ts, eps=1e-4): + h = (ts * a) ** 0.5 + h = h + eps + h = math.ceil(h) if math.ceil(h) % 2 == 0 else math.floor(h) + w = h / a + w = w + eps + w = math.ceil(w) if math.ceil(w) % 2 == 0 else math.floor(w) + return h, w + + +def get_aspect_ratios_dict(ars, ts=360 * 640): + est = {f"{a:.2f}": get_h_w(a, ts) for a in ars} + return est + + +def get_ar(ratio): + h, w = ratio.split(":") + return int(h) / int(w) + + +# H:W +ASPECT_RATIO_MAP = { + "3:8": "0.38", + "9:21": "0.43", + "12:25": "0.48", + "1:2": "0.50", + "9:17": "0.53", + "27:50": "0.54", + "9:16": "0.56", + "5:8": "0.62", + "2:3": "0.67", + "3:4": "0.75", + "1:1": "1.00", + "4:3": "1.33", + "3:2": "1.50", + "16:9": "1.78", + "17:9": "1.89", + "2:1": "2.00", + "50:27": "2.08", +} + + +AR = [get_ar(ratio) for ratio in ASPECT_RATIO_MAP.keys()] + +# computed from above code +# S = 8294400 +ASPECT_RATIO_4K = { + "0.38": (1764, 4704), + "0.43": (1886, 4400), + "0.48": (1996, 4158), + "0.50": (2036, 4072), + "0.53": (2096, 3960), + "0.54": (2118, 3918), + "0.62": (2276, 3642), + "0.56": (2160, 3840), # base + "0.67": (2352, 3528), + "0.75": (2494, 3326), + "1.00": (2880, 2880), + "1.33": (3326, 2494), + "1.50": (3528, 2352), + "1.78": (3840, 2160), + "1.89": (3958, 2096), + "2.00": (4072, 2036), + "2.08": (4156, 1994), +} + +# S = 3686400 +ASPECT_RATIO_2K = { + "0.38": (1176, 3136), + "0.43": (1256, 2930), + "0.48": (1330, 2770), + "0.50": (1358, 2716), + "0.53": (1398, 2640), + "0.54": (1412, 2612), + "0.56": (1440, 2560), # base + "0.62": (1518, 2428), + "0.67": (1568, 2352), + "0.75": (1662, 2216), + "1.00": (1920, 1920), + "1.33": (2218, 1664), + "1.50": (2352, 1568), + "1.78": (2560, 1440), + "1.89": (2638, 1396), + "2.00": (2716, 1358), + "2.08": (2772, 1330), +} + +# S = 2073600 +ASPECT_RATIO_1080P = { + "0.38": (882, 2352), + "0.43": (942, 2198), + "0.48": (998, 2080), + "0.50": (1018, 2036), + "0.53": (1048, 1980), + "0.54": (1058, 1958), + "0.56": (1080, 1920), # base + "0.62": (1138, 1820), + "0.67": (1176, 1764), + "0.75": (1248, 1664), + "1.00": (1440, 1440), + "1.33": (1662, 1246), + "1.50": (1764, 1176), + "1.78": (1920, 1080), + "1.89": (1980, 1048), + "2.00": (2036, 1018), + "2.08": (2078, 998), +} + +# S = 921600 +ASPECT_RATIO_720P = { + "0.38": (588, 1568), + "0.43": (628, 1466), + "0.48": (666, 1388), + "0.50": (678, 1356), + "0.53": (698, 1318), + "0.54": (706, 1306), + "0.56": (720, 1280), # base + "0.62": (758, 1212), + "0.67": (784, 1176), + "0.75": (832, 1110), + "1.00": (960, 960), + "1.33": (1108, 832), + "1.50": (1176, 784), + "1.78": (1280, 720), + "1.89": (1320, 698), + "2.00": (1358, 680), + "2.08": (1386, 666), +} + +# S = 409920 +ASPECT_RATIO_480P = { + "0.38": (392, 1046), + "0.43": (420, 980), + "0.48": (444, 925), + "0.50": (452, 904), + "0.53": (466, 880), + "0.54": (470, 870), + "0.56": (480, 854), # base + "0.62": (506, 810), + "0.67": (522, 784), + "0.75": (554, 738), + "1.00": (640, 640), + "1.33": (740, 555), + "1.50": (784, 522), + "1.78": (854, 480), + "1.89": (880, 466), + "2.00": (906, 454), + "2.08": (924, 444), +} + +# S = 230400 +ASPECT_RATIO_360P = { + "0.38": (294, 784), + "0.43": (314, 732), + "0.48": (332, 692), + "0.50": (340, 680), + "0.53": (350, 662), + "0.54": (352, 652), + "0.56": (360, 640), # base + "0.62": (380, 608), + "0.67": (392, 588), + "0.75": (416, 554), + "1.00": (480, 480), + "1.33": (554, 416), + "1.50": (588, 392), + "1.78": (640, 360), + "1.89": (660, 350), + "2.00": (678, 340), + "2.08": (692, 332), +} + +# S = 102240 +ASPECT_RATIO_240P = { + "0.38": (196, 522), + "0.43": (210, 490), + "0.48": (222, 462), + "0.50": (226, 452), + "0.53": (232, 438), + "0.54": (236, 436), + "0.56": (240, 426), # base + "0.62": (252, 404), + "0.67": (262, 393), + "0.75": (276, 368), + "1.00": (320, 320), + "1.33": (370, 278), + "1.50": (392, 262), + "1.78": (426, 240), + "1.89": (440, 232), + "2.00": (452, 226), + "2.08": (462, 222), +} + +# S = 36864 +ASPECT_RATIO_144P = { + "0.38": (117, 312), + "0.43": (125, 291), + "0.48": (133, 277), + "0.50": (135, 270), + "0.53": (139, 262), + "0.54": (141, 260), + "0.56": (144, 256), # base + "0.62": (151, 241), + "0.67": (156, 234), + "0.75": (166, 221), + "1.00": (192, 192), + "1.33": (221, 165), + "1.50": (235, 156), + "1.78": (256, 144), + "1.89": (263, 139), + "2.00": (271, 135), + "2.08": (277, 132), +} + +# from PixArt +# S = 8294400 +ASPECT_RATIO_2880 = { + "0.25": (1408, 5760), + "0.26": (1408, 5568), + "0.27": (1408, 5376), + "0.28": (1408, 5184), + "0.32": (1600, 4992), + "0.33": (1600, 4800), + "0.34": (1600, 4672), + "0.40": (1792, 4480), + "0.42": (1792, 4288), + "0.47": (1920, 4096), + "0.49": (1920, 3904), + "0.51": (1920, 3776), + "0.55": (2112, 3840), + "0.59": (2112, 3584), + "0.68": (2304, 3392), + "0.72": (2304, 3200), + "0.78": (2496, 3200), + "0.83": (2496, 3008), + "0.89": (2688, 3008), + "0.93": (2688, 2880), + "1.00": (2880, 2880), + "1.07": (2880, 2688), + "1.12": (3008, 2688), + "1.21": (3008, 2496), + "1.28": (3200, 2496), + "1.39": (3200, 2304), + "1.47": (3392, 2304), + "1.70": (3584, 2112), + "1.82": (3840, 2112), + "2.03": (3904, 1920), + "2.13": (4096, 1920), + "2.39": (4288, 1792), + "2.50": (4480, 1792), + "2.92": (4672, 1600), + "3.00": (4800, 1600), + "3.12": (4992, 1600), + "3.68": (5184, 1408), + "3.82": (5376, 1408), + "3.95": (5568, 1408), + "4.00": (5760, 1408), +} + +# S = 4194304 +ASPECT_RATIO_2048 = { + "0.25": (1024, 4096), + "0.26": (1024, 3968), + "0.27": (1024, 3840), + "0.28": (1024, 3712), + "0.32": (1152, 3584), + "0.33": (1152, 3456), + "0.35": (1152, 3328), + "0.40": (1280, 3200), + "0.42": (1280, 3072), + "0.48": (1408, 2944), + "0.50": (1408, 2816), + "0.52": (1408, 2688), + "0.57": (1536, 2688), + "0.60": (1536, 2560), + "0.68": (1664, 2432), + "0.72": (1664, 2304), + "0.78": (1792, 2304), + "0.82": (1792, 2176), + "0.88": (1920, 2176), + "0.94": (1920, 2048), + "1.00": (2048, 2048), + "1.07": (2048, 1920), + "1.13": (2176, 1920), + "1.21": (2176, 1792), + "1.29": (2304, 1792), + "1.38": (2304, 1664), + "1.46": (2432, 1664), + "1.67": (2560, 1536), + "1.75": (2688, 1536), + "2.00": (2816, 1408), + "2.09": (2944, 1408), + "2.40": (3072, 1280), + "2.50": (3200, 1280), + "2.89": (3328, 1152), + "3.00": (3456, 1152), + "3.11": (3584, 1152), + "3.62": (3712, 1024), + "3.75": (3840, 1024), + "3.88": (3968, 1024), + "4.00": (4096, 1024), +} + +# S = 1048576 +ASPECT_RATIO_1024 = { + "0.25": (512, 2048), + "0.26": (512, 1984), + "0.27": (512, 1920), + "0.28": (512, 1856), + "0.32": (576, 1792), + "0.33": (576, 1728), + "0.35": (576, 1664), + "0.40": (640, 1600), + "0.42": (640, 1536), + "0.48": (704, 1472), + "0.50": (704, 1408), + "0.52": (704, 1344), + "0.57": (768, 1344), + "0.60": (768, 1280), + "0.68": (832, 1216), + "0.72": (832, 1152), + "0.78": (896, 1152), + "0.82": (896, 1088), + "0.88": (960, 1088), + "0.94": (960, 1024), + "1.00": (1024, 1024), + "1.07": (1024, 960), + "1.13": (1088, 960), + "1.21": (1088, 896), + "1.29": (1152, 896), + "1.38": (1152, 832), + "1.46": (1216, 832), + "1.67": (1280, 768), + "1.75": (1344, 768), + "2.00": (1408, 704), + "2.09": (1472, 704), + "2.40": (1536, 640), + "2.50": (1600, 640), + "2.89": (1664, 576), + "3.00": (1728, 576), + "3.11": (1792, 576), + "3.62": (1856, 512), + "3.75": (1920, 512), + "3.88": (1984, 512), + "4.00": (2048, 512), +} + +# S = 262144 +ASPECT_RATIO_512 = { + "0.25": (256, 1024), + "0.26": (256, 992), + "0.27": (256, 960), + "0.28": (256, 928), + "0.32": (288, 896), + "0.33": (288, 864), + "0.35": (288, 832), + "0.40": (320, 800), + "0.42": (320, 768), + "0.48": (352, 736), + "0.50": (352, 704), + "0.52": (352, 672), + "0.57": (384, 672), + "0.60": (384, 640), + "0.68": (416, 608), + "0.72": (416, 576), + "0.78": (448, 576), + "0.82": (448, 544), + "0.88": (480, 544), + "0.94": (480, 512), + "1.00": (512, 512), + "1.07": (512, 480), + "1.13": (544, 480), + "1.21": (544, 448), + "1.29": (576, 448), + "1.38": (576, 416), + "1.46": (608, 416), + "1.67": (640, 384), + "1.75": (672, 384), + "2.00": (704, 352), + "2.09": (736, 352), + "2.40": (768, 320), + "2.50": (800, 320), + "2.89": (832, 288), + "3.00": (864, 288), + "3.11": (896, 288), + "3.62": (928, 256), + "3.75": (960, 256), + "3.88": (992, 256), + "4.00": (1024, 256), +} + +# S = 65536 +ASPECT_RATIO_256 = { + "0.25": (128, 512), + "0.26": (128, 496), + "0.27": (128, 480), + "0.28": (128, 464), + "0.32": (144, 448), + "0.33": (144, 432), + "0.35": (144, 416), + "0.40": (160, 400), + "0.42": (160, 384), + "0.48": (176, 368), + "0.50": (176, 352), + "0.52": (176, 336), + "0.57": (192, 336), + "0.60": (192, 320), + "0.68": (208, 304), + "0.72": (208, 288), + "0.78": (224, 288), + "0.82": (224, 272), + "0.88": (240, 272), + "0.94": (240, 256), + "1.00": (256, 256), + "1.07": (256, 240), + "1.13": (272, 240), + "1.21": (272, 224), + "1.29": (288, 224), + "1.38": (288, 208), + "1.46": (304, 208), + "1.67": (320, 192), + "1.75": (336, 192), + "2.00": (352, 176), + "2.09": (368, 176), + "2.40": (384, 160), + "2.50": (400, 160), + "2.89": (416, 144), + "3.00": (432, 144), + "3.11": (448, 144), + "3.62": (464, 128), + "3.75": (480, 128), + "3.88": (496, 128), + "4.00": (512, 128), +} + + +def get_closest_ratio(height: float, width: float, ratios: dict): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return closest_ratio + + +ASPECT_RATIOS = { + "144p": (36864, ASPECT_RATIO_144P), + "256": (65536, ASPECT_RATIO_256), + "240p": (102240, ASPECT_RATIO_240P), + "360p": (230400, ASPECT_RATIO_360P), + "512": (262144, ASPECT_RATIO_512), + "480p": (409920, ASPECT_RATIO_480P), + "720p": (921600, ASPECT_RATIO_720P), + "1024": (1048576, ASPECT_RATIO_1024), + "1080p": (2073600, ASPECT_RATIO_1080P), + "2k": (3686400, ASPECT_RATIO_2K), + "2048": (4194304, ASPECT_RATIO_2048), + "2880": (8294400, ASPECT_RATIO_2880), + "4k": (8294400, ASPECT_RATIO_4K), +} + + +DEFAULT_AR_MAP = { + "144p": (144, 256), + "240p": (240, 426), + "360p": (360, 640), + "480p": (480, 720), + "720p": (720, 1280), + # "1080p": (1080, 1920), +} + +# used for data synthesis +COMMON_AR_144P = { + "0.56": (144, 256), + "0.75": (166, 221), + "1.00": (192, 192), + "1.33": (221, 165), + "1.78": (256, 144), +} + +COMMON_AR_256 = { + "0.57": (192, 336), + "0.78": (224, 288), + "1.00": (256, 256), + "1.29": (288, 224), + "1.75": (336, 192), +} + +COMMON_AR_240P = { + "0.56": (240, 426), + "0.75": (276, 368), + "1.00": (320, 320), + "1.33": (370, 278), + "1.78": (426, 240), +} + +COMMON_AR_360P = { + "0.56": (360, 640), + "0.75": (416, 554), + "1.00": (480, 480), + "1.33": (554, 416), + "1.78": (640, 360), +} + +COMMON_AR_512 = { + "0.57": (384, 672), + "0.78": (448, 576), + "1.00": (512, 512), + "1.29": (576, 448), + "1.75": (672, 384), +} + +COMMON_AR_480P = { + "0.56": (480, 854), + "0.75": (554, 738), + "1.00": (640, 640), + "1.33": (740, 555), + "1.78": (854, 480), +} + +COMMON_AR_720P = { + "0.56": (720, 1280), + "0.75": (832, 1110), + "1.00": (960, 960), + "1.33": (1108, 832), + "1.78": (1280, 720), +} + +COMMON_AR_1024 = { + "0.57": (768, 1344), + "0.78": (896, 1152), + "1.00": (1024, 1024), + "1.29": (1152, 896), + "1.75": (1344, 768), +} + +COMMON_AR_1080P = { + "0.56": (1080, 1920), + "0.75": (1248, 1664), + "1.00": (1440, 1440), + "1.33": (1662, 1246), + "1.78": (1920, 1080), +} + +COMMON_AR_2K = { + "0.56": (1440, 2560), + "0.75": (1662, 2216), + "1.00": (1920, 1920), + "1.33": (2218, 1664), + "1.78": (2560, 1440), +} + +COMMON_AR_2048 = { + "0.57": (1536, 2688), + "0.78": (1792, 2304), + "1.00": (2048, 2048), + "1.29": (2304, 1792), + "1.75": (2688, 1536), +} + +COMMON_AR_2880 = { + "0.55": (2112, 3840), + "0.78": (2496, 3200), + "1.00": (2880, 2880), + "1.28": (3200, 2496), + "1.82": (3840, 2112), +} + +COMMON_AR_4K = { + "0.56": (2160, 3840), + "0.75": (2494, 3326), + "1.00": (2880, 2880), + "1.33": (3326, 2494), + "1.78": (3840, 2160), +} + +COMMON_AR = { + "144p": (144, COMMON_AR_144P), + "256": (256, COMMON_AR_256), + "240p": (240, COMMON_AR_240P), + "360p": (360, COMMON_AR_360P), + "512": (512, COMMON_AR_512), + "480p": (480, COMMON_AR_480P), + "720p": (720, COMMON_AR_720P), + "1024": (1024, COMMON_AR_1024), + "1080p": (1080, COMMON_AR_1080P), + "2k": (1440, COMMON_AR_2K), + "2048": (2048, COMMON_AR_2048), + "2880": (2880, COMMON_AR_2880), + "4k": (2160, COMMON_AR_4K), +} + + +def update_common_ar(bucket_config, override_common_ar): + new_common_ar = {} + global COMMON_AR + for res in COMMON_AR: + if res not in bucket_config: + continue + new_common_ar[res] = COMMON_AR[res] + + if override_common_ar is not None: + for name in override_common_ar: + if name not in new_common_ar: + continue + new_common_ar[name][1].clear() + new_common_ar[name][1].update(override_common_ar[name]) + + COMMON_AR.clear() + COMMON_AR.update(new_common_ar) + + +def get_num_pixels(name): + return ASPECT_RATIOS[name][0] + + +def get_image_size(resolution, ar_ratio): + ar_key = ASPECT_RATIO_MAP[ar_ratio] + rs_dict = ASPECT_RATIOS[resolution][1] + assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}" + return rs_dict[ar_key] + + +NUM_FRAMES_MAP = { + "1x": 51, + "2x": 102, + "4x": 204, + "8x": 408, + "16x": 816, + "2s": 51, + "4s": 102, + "8s": 204, + "16s": 408, + "32s": 816, +} + + +def get_num_frames(num_frames): + if num_frames in NUM_FRAMES_MAP: + return NUM_FRAMES_MAP[num_frames] + else: + return int(num_frames) diff --git a/videosys/training/datasets/cogvideox/bucket.py b/videosys/training/datasets/cogvideox/bucket.py new file mode 100644 index 00000000..a7411799 --- /dev/null +++ b/videosys/training/datasets/cogvideox/bucket.py @@ -0,0 +1,151 @@ +import logging +from collections import OrderedDict +from typing import Iterable + +import numpy as np + +from .aspect import ASPECT_RATIOS, get_closest_ratio + + +def find_approximate_hw(hw, hw_dict, approx=0.8): + for k, v in hw_dict.items(): + if hw >= v * approx: + return k + return None + + +def find_closet_smaller_bucket(t, t_dict, frame_interval): + # process image + if t == 1: + if 1 in t_dict: + return 1 + else: + return None + # process video + for k, v in t_dict.items(): + if t >= v * frame_interval and v != 1: + return k + return None + + +class Bucket: + def __init__(self, bucket_config): + for key in bucket_config: + assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found." + # wrap config with OrderedDict + bucket_probs = OrderedDict() + bucket_bs = OrderedDict() + bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True) + # print(bucket_config) + for key in bucket_names: + bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True) + bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names}) + bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names}) + + # first level: HW + num_bucket = 0 + hw_criteria = dict() + t_criteria = dict() + ar_criteria = dict() + bucket_id = OrderedDict() + bucket_id_cnt = 0 + for k1, v1 in bucket_probs.items(): + hw_criteria[k1] = ASPECT_RATIOS[k1][0] + t_criteria[k1] = dict() + ar_criteria[k1] = dict() + bucket_id[k1] = dict() + for k2, _ in v1.items(): + t_criteria[k1][k2] = k2 + bucket_id[k1][k2] = bucket_id_cnt + bucket_id_cnt += 1 + ar_criteria[k1][k2] = dict() + for k3, v3 in ASPECT_RATIOS[k1][1].items(): + ar_criteria[k1][k2][k3] = v3 + num_bucket += 1 + + self.bucket_probs = bucket_probs + self.bucket_bs = bucket_bs + self.bucket_id = bucket_id + self.hw_criteria = hw_criteria + self.t_criteria = t_criteria + self.ar_criteria = ar_criteria + self.num_bucket = num_bucket + + self.hw2ar_name_map = dict() + for ar_name in ASPECT_RATIOS: + ar_hw = ASPECT_RATIOS[ar_name][1].values() + for each in ar_hw: + self.hw2ar_name_map[each] = ar_name + + logging.info("Number of buckets: %s", num_bucket) + + def get_bucket_id(self, T, H, W, frame_interval=1, seed=None): + resolution = H * W + approx = 0.8 + + fail = True + for hw_id, t_criteria in self.bucket_probs.items(): + if resolution < self.hw_criteria[hw_id] * approx: + continue + + # if sample is an image + if T == 1: + if 1 in t_criteria: + rng = np.random.default_rng(seed + self.bucket_id[hw_id][1]) + if rng.random() < t_criteria[1]: + fail = False + t_id = 1 + break + else: + continue + + # otherwise, find suitable t_id for video + t_fail = True + for t_id, prob in t_criteria.items(): + rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id]) + if isinstance(prob, Iterable): + prob_t = prob[1] + if rng.random() > prob_t: + continue + if T >= t_id * frame_interval and t_id != 1: + t_fail = False + break + if t_fail: + continue + + # leave the loop if prob is high enough + if isinstance(prob, Iterable): + prob = prob[0] + if prob >= 1 or rng.random() < prob: + fail = False + break + if fail: + print(f"pass: {T}, {H}, {W}") + return None + + # get aspect ratio id + ar_criteria = self.ar_criteria[hw_id][t_id] + ar_id = get_closest_ratio(H, W, ar_criteria) + return hw_id, t_id, ar_id + + def get_thw(self, bucket_id): + assert len(bucket_id) == 3 + T = self.t_criteria[bucket_id[0]][bucket_id[1]] + H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]] + return T, H, W + + def get_prob(self, bucket_id): + return self.bucket_probs[bucket_id[0]][bucket_id[1]] + + def get_batch_size(self, bucket_id): + return self.bucket_bs[bucket_id[0]][bucket_id[1]] + + def __len__(self): + return self.num_bucket + + +def closet_smaller_bucket(value, bucket): + for i in range(1, len(bucket)): + if value < bucket[i]: + return bucket[i - 1] + return bucket[-1] diff --git a/videosys/training/datasets/cogvideox/dataloader.py b/videosys/training/datasets/cogvideox/dataloader.py new file mode 100644 index 00000000..a15ccd86 --- /dev/null +++ b/videosys/training/datasets/cogvideox/dataloader.py @@ -0,0 +1,138 @@ +import random +from typing import Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader + +from .datasets import DummyVariableVideoTextDataset, VariableVideoTextDataset, VideoTextDataset +from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler + + +# Deterministic dataloader +def get_seed_worker(seed): + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return seed_worker + + +def prepare_dataloader( + dataset, + batch_size: Optional[int] = None, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = True, + num_workers: int = 0, + process_group: Optional[ProcessGroup] = None, + bucket_config=None, + num_bucket_build_workers: int = 1, + prefetch_factor: Optional[int] = None, + sp_balance_scope: str = "iter", + auto_grad_accumulation: bool = False, + max_grad_accumulation_steps: int = 2, + parallel_mgr=None, + calculate_imbalance: bool = False, + verbose: bool = False, + min_grad_accumulation_steps: int = 2, + **kwargs, +): + _kwargs = kwargs.copy() + if isinstance(dataset, (VariableVideoTextDataset, DummyVariableVideoTextDataset)): + batch_sampler = VariableVideoBatchSampler( + dataset, + bucket_config, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + verbose=verbose, + num_bucket_build_workers=num_bucket_build_workers, + sp_balance_scope=sp_balance_scope, + auto_grad_accumulation=auto_grad_accumulation, + max_grad_accumulation_steps=max_grad_accumulation_steps, + parallel_mgr=parallel_mgr, + calculate_imbalance=calculate_imbalance, + min_grad_accumulation_steps=min_grad_accumulation_steps, + ) + return ( + DataLoader( + dataset, + batch_sampler=batch_sampler, + worker_init_fn=get_seed_worker(seed), + pin_memory=pin_memory, + num_workers=num_workers, + collate_fn=_collate_fn, + prefetch_factor=prefetch_factor, + **_kwargs, + ), + batch_sampler, + ) + elif isinstance(dataset, VideoTextDataset): + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + ) + return ( + DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=get_seed_worker(seed), + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + collate_fn=collate_fn_default, + prefetch_factor=prefetch_factor, + **_kwargs, + ), + sampler, + ) + else: + raise ValueError(f"Unsupported dataset type: {type(dataset)}") + + +def _collate_fn(batch): + ar_name = batch[0]["ar_name"] + num_frame = batch[0]["num_frames"] + sp_size = batch[0]["sp_size"] + gas = batch[0]["gas"] + + stride = (len(batch) + gas - 1) // gas + ret = dict(ar_name=ar_name, num_frame=num_frame, sp_size=sp_size, gas=gas, data=[]) + for i in range(0, len(batch), stride): + assert all(each.pop("sp_size") == sp_size for each in batch[i : i + stride]) + assert all(each.pop("gas") == gas for each in batch[i : i + stride]) + assert all(each.pop("ar_name") == ar_name for each in batch[i : i + stride]) + assert all(each["num_frames"] == num_frame for each in batch[i : i + stride]) + + ret["data"].append(torch.utils.data.default_collate(batch[i : i + stride])) + return ret + + +def collate_fn_default(batch): + # HACK: for loading text features + use_mask = False + if "mask" in batch[0] and isinstance(batch[0]["mask"], int): + masks = [x.pop("mask") for x in batch] + + texts = [x.pop("text") for x in batch] + texts = torch.cat(texts, dim=1) + use_mask = True + + ret = torch.utils.data.default_collate(batch) + + if use_mask: + ret["mask"] = masks + ret["text"] = texts + return ret diff --git a/videosys/training/datasets/cogvideox/datasets.py b/videosys/training/datasets/cogvideox/datasets.py new file mode 100644 index 00000000..1babc0aa --- /dev/null +++ b/videosys/training/datasets/cogvideox/datasets.py @@ -0,0 +1,549 @@ +import logging +import os +from pprint import pformat + +import numpy as np +import pandas as pd +import torch +from PIL import ImageFile +from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader + +from .aspect import COMMON_AR, update_common_ar +from .read_video import read_video +from .utils import ( + VID_EXTENSIONS, + get_transforms_image, + get_transforms_video, + read_file, + remove_interval, + split_df_by_rank, + temporal_random_crop, +) + +ImageFile.LOAD_TRUNCATED_IMAGES = True +IMG_FPS = 120 + + +def half_normal_pdf(x, sigma=1.0): + if x < 0: + return 0 + + distribution = torch.distributions.normal.Normal(0, sigma) + return np.sqrt(2 / np.pi) * torch.exp(distribution.log_prob(torch.tensor([x], dtype=torch.float))).numpy() + + +class VideoTextDataset(torch.utils.data.Dataset): + """load video according to the csv file. + + Args: + target_video_len (int): the number of video frames will be load. + align_transform (callable): Align different videos in a specified size. + temporal_sample (callable): Sample the target length of a video. + """ + + def __init__( + self, + data_path=None, + num_frames=16, + frame_interval=1, + image_size=(256, 256), + transform_name="center", + ): + self.data_path = data_path + self.data = read_file(data_path) + self.get_text = "text" in self.data.columns + self.num_frames = num_frames + self.frame_interval = frame_interval + self.image_size = image_size + self.transforms = { + "image": get_transforms_image(transform_name, image_size), + "video": get_transforms_video(transform_name, image_size), + } + + def _print_data_number(self): + num_videos = 0 + num_images = 0 + for path in self.data["path"]: + if self.get_type(path) == "video": + num_videos += 1 + else: + num_images += 1 + print(f"Dataset contains {num_videos} videos and {num_images} images.") + + def get_type(self, path): + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return "video" + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return "image" + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + file_type = self.get_type(path) + + if file_type == "video": + # loading + vframes, vinfo = read_video(path, backend="av") + video_fps = vinfo["video_fps"] if "video_fps" in vinfo else 24 + + # Sampling video frames + video = temporal_random_crop(vframes, self.num_frames, self.frame_interval) + + # transform + transform = self.transforms["video"] + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + video_fps = IMG_FPS + + # transform + transform = self.transforms["image"] + image = transform(image) + + # repeat + video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + + ret = {"video": video, "fps": video_fps} + if self.get_text: + ret["text"] = sample["text"] + return ret + + def __getitem__(self, index): + for _ in range(10): + try: + return self.getitem(index) + except Exception as e: + path = self.data.iloc[index]["path"] + print(f"data {path}: {e}") + index = np.random.randint(len(self)) + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.data) + + +class VariableVideoTextDataset(VideoTextDataset): + def __init__( + self, + data_path=None, + num_frames=None, + frame_interval=1, + image_size=(None, None), + transform_name=None, + preprocessed_data=False, + ): + super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None) + + # repeat data if it is the example data + if "/assets/example_data/demo" in data_path: + logging.info( + f"Repeat example data {data_path} 10 times (size: {len(self.data)} -> {len(self.data) * 10}) for easy training." + ) + self.data = pd.concat([self.data] * 10, ignore_index=True) + + self.transform_name = transform_name + self.data["id"] = np.arange(len(self.data)) + self.preprocessed_data = preprocessed_data + + def get_data_info(self, index): + T = self.data.iloc[index]["num_frames"] + H = self.data.iloc[index]["height"] + W = self.data.iloc[index]["width"] + return T, H, W + + def getitem(self, index): + # a hack to pass in the (time, height, width) info from sampler + index, num_frames, height, width, ar_name, sp_size, gas = index + + sample = self.data.iloc[index] + path = sample["path"] + text = sample["text"] + ar = height / width + video_fps = 24 # default fps + + if not self.preprocessed_data: + file_type = self.get_type(path) + if file_type == "video": + # loading + vframes, vinfo = read_video(path, backend="av") + video_fps = vinfo["video_fps"] if "video_fps" in vinfo else 24 + + # Sampling video frames + video = temporal_random_crop(vframes, num_frames, self.frame_interval) + video = video.clone() + del vframes + + video_fps = video_fps // self.frame_interval + + # transform + transform = get_transforms_video(self.transform_name, (height, width)) + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + video_fps = IMG_FPS + + # transform + transform = get_transforms_image(self.transform_name, (height, width)) + image = transform(image) + + # repeat + video = image.unsqueeze(0) + else: + video = torch.load(sample["vae_emb"], weights_only=False) # C T H W + nf = max(num_frames * 5 // 17, 1) + video = video.permute(1, 0, 2, 3) # T C H W + video = temporal_random_crop(video, nf, 1) + model_args = torch.load(sample["text_emb"], weights_only=False) + text = model_args["y"] + mask = model_args["mask"] + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + ret = { + "video": video, + "num_frames": num_frames, + "height": height, + "width": width, + "ar": ar, + "fps": video_fps, + "ar_name": ar_name, + "sp_size": sp_size, + "gas": gas, + "text": text, + } + if self.preprocessed_data: + ret["mask"] = mask + return ret + + def __getitem__(self, index): + return self.getitem(index) + + +class DummyVariableVideoTextDataset(torch.utils.data.Dataset): + def __init__( + self, + data_size, + seed, + data_path=None, + frame_interval=1, + transform_name=None, + preprocessed_data=False, + num_frames=None, + image_size=(None, None), + bucket_config=None, + common_ar=None, + distribution: str = "zipf", # uniform or zipf + zipf_offset: float = 10, + image_mixing_type: str = "exclusive", + image_mixing_frac: float = -1.0, + res_scale: float = None, + frame_scale: float = None, + text_max_seq_len: int = 100, + text_hidden_size: int = 4096 + ): + self.data_size = data_size + self.seed = seed # use this generator to ensure global consistency + self.generator = np.random.default_rng(seed + data_size) # data are generated on host mem + self.torch_generator = torch.Generator() + self.data_path = data_path + self.preprocessed_data = preprocessed_data + self.image_mixing_type = image_mixing_type + self.image_mixing_frac = image_mixing_frac + self.res_scale = res_scale + self.frame_scale = frame_scale + self.text_max_seq_len = text_max_seq_len + self.text_hidden_size = text_hidden_size + update_common_ar(bucket_config, common_ar) + logging.info(f"common ar for data synthesis: {pformat(COMMON_AR)}") + self._build_dummy_dataset(bucket_config, distribution, zipf_offset) + + self.frame_interval = frame_interval + self.transform_name = transform_name + self.num_frames = num_frames + self.image_size = image_size + + def _build_dummy_dataset(self, bucket_config, distribution, zipf_offset): + self.get_text = False + assert bucket_config is not None + + data, frame_data = [], [] + log_str = "build dummy dataset:" + if self.res_scale is not None and self.frame_scale is not None: + res_list, frame_list = set(), set() + data_dict = {} + for ar_name in bucket_config: + res_list.add(COMMON_AR[ar_name][0]) + for num_frame in bucket_config[ar_name]: + if bucket_config[ar_name][num_frame][1] is not None: + frame_list.add(num_frame) + data_dict[(ar_name, num_frame)] = {} + res_list = np.array(sorted(list(res_list))) + frame_list = np.array(sorted(list(frame_list))) + + res_weights = res_list / max(res_list) * self.res_scale + frame_weights = np.sqrt(np.sqrt(frame_list / max(frame_list))) * self.frame_scale + + total = 0.0 + for ar_name, num_frame in data_dict: + res_w = res_weights[np.where(res_list == COMMON_AR[ar_name][0])[0][0]] + frame_w = frame_weights[np.where(frame_list == num_frame)[0][0]] + prob = half_normal_pdf(res_w) * half_normal_pdf(frame_w) + data_dict[(ar_name, num_frame)] = dict( + res_weight=res_w, + frame_weight=frame_w, + prob=prob, + ) + total += prob + + img_cnt, vid_cnt = 0, 0 + keys = sorted(data_dict.keys(), key=lambda x: COMMON_AR[x[0]][0] * x[1]) + for ar_name, num_frame in keys: + prob = data_dict[(ar_name, num_frame)]["prob"] / total + data_dict[(ar_name, num_frame)]["prob"] = prob + + cnt = int(prob * self.data_size) + data_dict[(ar_name, num_frame)]["cnt"] = cnt + + log_str += f"\n ({ar_name}, {num_frame}), cnt: {cnt}" + if num_frame == 1: + img_cnt += cnt + else: + vid_cnt += cnt + + height_width_pool = pd.DataFrame(COMMON_AR[ar_name][1].values(), columns=["height", "width"]) + idx = self.generator.integers(low=0, high=len(height_width_pool), size=(cnt,)) + bucket_data = height_width_pool.iloc[idx] + bucket_data.reset_index(drop=True, inplace=True) + data.append(bucket_data) + frame_data.extend([num_frame] * cnt) + log_str += f"\nimg_cnt: {img_cnt}, vid_cnt: {vid_cnt}, ratio: {img_cnt / vid_cnt}" + else: + # collect valid bucket candidate, only consider ar_name and num_frame + img_candidates, vid_candidates = [], [] + for ar_name in bucket_config: + for num_frame in bucket_config[ar_name]: + if bucket_config[ar_name][num_frame][1] is not None: + if num_frame == 1: + img_candidates.append((ar_name, num_frame)) + else: + vid_candidates.append((ar_name, num_frame)) + # sort by total pixels = num_frame * pixels + vid_candidates = sorted(vid_candidates, key=lambda x: x[1] * COMMON_AR[x[0]][0]) + img_candidates = sorted(img_candidates, key=lambda x: COMMON_AR[x[0]][0]) + + if self.image_mixing_type == "inclusive": + if self.image_mixing_frac < 0: + img_size = int((len(img_candidates) / (len(vid_candidates) + len(img_candidates))) * self.data_size) + vid_size = self.data_size - img_size + else: + assert self.image_mixing_frac <= 1.0 + img_size = int(self.image_mixing_frac * self.data_size) + vid_size = self.data_size - img_size + elif self.image_mixing_type == "exclusive": + assert self.image_mixing_frac >= 0 + img_size = int(self.image_mixing_frac * self.data_size) + vid_size = self.data_size + else: + raise ValueError(f"unsupported image mixing type: {self.image_mixing_type}") + + if distribution == "uniform": + idx = self.generator.integers(low=0, high=len(img_candidates), size=(img_size,)) + img_candidate_cnts = np.bincount(idx, minlength=len(img_candidates)) + + idx = self.generator.integers(low=0, high=len(vid_candidates), size=(vid_size,)) + vid_candidate_cnts = np.bincount(idx, minlength=len(vid_candidates)) + + elif distribution == "zipf": + zipf_alpha = 3.0 + # https://en.wikipedia.org/wiki/Zipf%27s_law#Formal_definition + ranks = np.power(np.arange(1, len(img_candidates) + 1 + zipf_offset), zipf_alpha) + H_N_s = np.sum(1 / ranks) + img_candidate_prob = 1 / (ranks * H_N_s) + img_candidate_prob = img_candidate_prob[zipf_offset:] + img_candidate_cnts = img_size * img_candidate_prob / np.sum(img_candidate_prob) + + img_candidate_cnts = np.round(img_candidate_cnts).astype(int) + if len(img_candidate_cnts) > 0: + img_candidate_cnts[0] += img_size - np.sum(img_candidate_cnts) + + ranks = np.power(np.arange(1, len(vid_candidates) + 1 + zipf_offset), zipf_alpha) + H_N_s = np.sum(1 / ranks) + vid_candidate_prob = 1 / (ranks * H_N_s) + vid_candidate_prob = vid_candidate_prob[zipf_offset:] + vid_candidate_cnts = vid_size * vid_candidate_prob / np.sum(vid_candidate_prob) + + vid_candidate_cnts = np.round(vid_candidate_cnts).astype(int) + if len(vid_candidate_cnts) > 0: + vid_candidate_cnts[0] += vid_size - np.sum(vid_candidate_cnts) + else: + raise ValueError(f"unsupported distributionL {distribution}") + + for candidates, candidate_cnts in zip( + [img_candidates, vid_candidates], [img_candidate_cnts, vid_candidate_cnts] + ): + for bucket, cnt in zip(candidates, candidate_cnts): + log_str += f"\nbucket: {bucket}, cnt: {cnt}" + + ar_name, num_frame = bucket + height_width_pool = pd.DataFrame(COMMON_AR[ar_name][1].values(), columns=["height", "width"]) + idx = self.generator.integers(low=0, high=len(height_width_pool), size=(cnt,)) + bucket_data = height_width_pool.iloc[idx] + bucket_data.reset_index(drop=True, inplace=True) + data.append(bucket_data) + frame_data.extend([num_frame] * cnt) + + data = pd.concat(data, ignore_index=True) + data.reset_index(drop=True, inplace=True) + + data["num_frames"] = np.array(frame_data) + data["id"] = np.arange(len(data)) + self.data = data + log_str += f"\ndefault data_size: {self.data_size}, full data size: {data.shape[0]}" + logging.info(log_str) + self.data_size = data.shape[0] + + def __getitem__(self, index): + if isinstance(index, int): + # for unit test only + return self.data.iloc[index] + + # a hack to pass in the (time, height, width) info from sampler + index, num_frames, height, width, ar_name, sp_size, gas = index + ar = height / width + + video_fps = 24 + ret = { + # "id": index, + "num_frames": num_frames, + "height": height, + "width": width, + "ar": ar, + "fps": video_fps, + } + if not self.preprocessed_data: + ret["video"] = torch.randn(3, num_frames, height, width, generator=self.torch_generator) + ret["text"] = "dummy text" + else: + nf = max(1, num_frames//4) + self.torch_generator.manual_seed(self.seed + index + self.data_size) + ret["video"] = torch.randn(16, nf, height // 8, width // 8, generator=self.torch_generator) + ret["text"] = torch.rand(self.text_max_seq_len, self.text_hidden_size, generator=self.torch_generator) + + ret["ar_name"] = ar_name + ret["sp_size"] = sp_size + ret["gas"] = gas + return ret + + def __len__( + self, + ): + return self.data_size + + +class VideoPreProcesssDataset(torch.utils.data.Dataset): + """load video according to the csv file. + + Args: + target_video_len (int): the number of video frames will be load. + align_transform (callable): Align different videos in a specified size. + temporal_sample (callable): Sample the target length of a video. + """ + + def __init__( + self, + data_path=None, + num_frames=16, + frame_interval=1, + image_size=(256, 256), + transform_name="center", + ): + self.data_path = data_path + self.data = read_file(data_path) + self.num_frames = num_frames + self.frame_interval = frame_interval + self.image_size = image_size + self.transforms = { + "image": get_transforms_image(transform_name, image_size), + "video": get_transforms_video(transform_name, image_size), + } + self.data = split_df_by_rank(self.data) + + def get_type(self, path): + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return "video" + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return "image" + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + file_type = self.get_type(path) + + if file_type == "video": + # loading + vframes, vinfo = read_video(path, backend="av") + video_fps = vinfo["video_fps"] if "video_fps" in vinfo else 24 + + # Sampling video frames + video = remove_interval(vframes, self.frame_interval) + + # transform + transform = self.transforms["video"] + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + video_fps = IMG_FPS + + # transform + transform = self.transforms["image"] + image = transform(image) + + # repeat + video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + + ret = {"video": video, "fps": video_fps, "path": path, "index": index} + return ret + + def __getitem__(self, index): + for _ in range(10): + try: + return self.getitem(index) + except Exception as e: + path = self.data.iloc[index]["path"] + print(f"data {path}: {e}") + index = np.random.randint(len(self)) + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.data) + + +class TextPreProcessDataset(torch.utils.data.Dataset): + def __init__( + self, + data_path=None, + ): + self.data_path = data_path + self.data = read_file(data_path) + self.data = split_df_by_rank(self.data) + + def __getitem__(self, index): + sample = self.data.iloc[index] + ret = {"text": sample["text"], "path": sample["path"], "index": index} + return ret + + def __len__(self): + return len(self.data) diff --git a/videosys/training/datasets/cogvideox/read_video.py b/videosys/training/datasets/cogvideox/read_video.py new file mode 100644 index 00000000..2408c4c2 --- /dev/null +++ b/videosys/training/datasets/cogvideox/read_video.py @@ -0,0 +1,258 @@ +import gc +import math +import os +import re +import warnings +from fractions import Fraction +from typing import Any, Dict, List, Optional, Tuple, Union + +import av +import cv2 +import numpy as np +import torch +from torchvision import get_video_backend +from torchvision.io.video import _check_av_available + +MAX_NUM_FRAMES = 2500 + + +def read_video_av( + filename: str, + start_pts: Union[float, Fraction] = 0, + end_pts: Optional[Union[float, Fraction]] = None, + pts_unit: str = "pts", + output_format: str = "THWC", +) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Reads a video from a file, returning both the video frames and the audio frames + + This method is modified from torchvision.io.video.read_video, with the following changes: + + 1. will not extract audio frames and return empty for aframes + 2. remove checks and only support pyav + 3. add container.close() and gc.collect() to avoid thread leakage + 4. try our best to avoid memory leak + + Args: + filename (str): path to the video file + start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The start presentation time of the video + end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The end presentation time + pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, + either 'pts' or 'sec'. Defaults to 'pts'. + output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". + + Returns: + vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points + info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) + """ + # format + output_format = output_format.upper() + if output_format not in ("THWC", "TCHW"): + raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") + # file existence + if not os.path.exists(filename): + raise RuntimeError(f"File not found: {filename}") + # backend check + assert get_video_backend() == "pyav", "pyav backend is required for read_video_av" + _check_av_available() + # end_pts check + if end_pts is None: + end_pts = float("inf") + if end_pts < start_pts: + raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + + # == get video info == + info = {} + # TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + # fps + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + iter_video = container.decode(**{"video": 0}) + frame = next(iter_video).to_rgb().to_ndarray() + height, width = frame.shape[:2] + total_frames = container.streams.video[0].frames + if total_frames == 0: + total_frames = MAX_NUM_FRAMES + warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback") + container.close() + del container + + # HACK: must create before iterating stream + # use np.zeros will not actually allocate memory + # use np.ones will lead to a little memory leak + video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8) + + # == read == + try: + # TODO: The reading has memory leak (4G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + assert container.streams.video is not None + video_frames = _read_from_stream( + video_frames, + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + filename=filename, + ) + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") + + vframes = torch.from_numpy(video_frames).clone() + del video_frames + if output_format == "TCHW": + # [T,H,W,C] --> [T,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + + aframes = torch.empty((1, 0), dtype=torch.float32) + return vframes, aframes, info + + +def _read_from_stream( + video_frames, + container: "av.container.Container", + start_offset: float, + end_offset: float, + pts_unit: str, + stream: "av.stream.Stream", + stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]], + filename: Optional[str] = None, +) -> List["av.frame.Frame"]: + if pts_unit == "sec": + # TODO: we should change all of this from ground up to simply take + # sec and convert to MS in C++ + start_offset = int(math.floor(start_offset * (1 / stream.time_base))) + if end_offset != float("inf"): + end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) + else: + warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.") + + should_buffer = True + max_buffer_size = 5 + if stream.type == "video": + # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) + # so need to buffer some extra frames to sort everything + # properly + extradata = stream.codec_context.extradata + # overly complicated way of finding if `divx_packed` is set, following + # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 + if extradata and b"DivX" in extradata: + # can't use regex directly because of some weird characters sometimes... + pos = extradata.find(b"DivX") + d = extradata[pos:] + o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) + if o is None: + o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) + if o is not None: + should_buffer = o.group(3) == b"p" + seek_offset = start_offset + # some files don't seek to the right location, so better be safe here + seek_offset = max(seek_offset - 1, 0) + if should_buffer: + # FIXME this is kind of a hack, but we will jump to the previous keyframe + # so this will be safe + seek_offset = max(seek_offset - max_buffer_size, 0) + try: + # TODO check if stream needs to always be the video stream here or not + container.seek(seek_offset, any_frame=False, backward=True, stream=stream) + except av.AVError as e: + print(f"[Warning] Error while seeking video {filename}: {e}") + return [] + + # == main == + buffer_count = 0 + frames_pts = [] + cnt = 0 + try: + for _idx, frame in enumerate(container.decode(**stream_name)): + frames_pts.append(frame.pts) + video_frames[cnt] = frame.to_rgb().to_ndarray() + cnt += 1 + if cnt >= len(video_frames): + break + if frame.pts >= end_offset: + if should_buffer and buffer_count < max_buffer_size: + buffer_count += 1 + continue + break + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") + + # garbage collection for thread leakage + container.close() + del container + # NOTE: manually garbage collect to close pyav threads + gc.collect() + + # ensure that the results are sorted wrt the pts + # NOTE: here we assert frames_pts is sorted + start_ptr = 0 + end_ptr = cnt + while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset: + start_ptr += 1 + while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset: + end_ptr -= 1 + if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]: + # if there is no frame that exactly matches the pts of start_offset + # add the last frame smaller than start_offset, to guarantee that + # we will have all the necessary data. This is most useful for audio + if start_ptr > 0: + start_ptr -= 1 + result = video_frames[start_ptr:end_ptr].copy() + return result + + +def read_video_cv2(video_path): + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + # print("Error: Unable to open video") + raise ValueError + else: + fps = cap.get(cv2.CAP_PROP_FPS) + vinfo = { + "video_fps": fps, + } + + frames = [] + while True: + # Read a frame from the video + ret, frame = cap.read() + + # If frame is not read correctly, break the loop + if not ret: + break + + frames.append(frame[:, :, ::-1]) # BGR to RGB + + # Exit if 'q' is pressed + if cv2.waitKey(25) & 0xFF == ord("q"): + break + + # Release the video capture object and close all windows + cap.release() + cv2.destroyAllWindows() + + frames = np.stack(frames) + frames = torch.from_numpy(frames) # [T, H, W, C=3] + frames = frames.permute(0, 3, 1, 2) + return frames, vinfo + + +def read_video(video_path, backend="av"): + if backend == "cv2": + vframes, vinfo = read_video_cv2(video_path) + elif backend == "av": + vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW") + else: + raise ValueError + + return vframes, vinfo diff --git a/videosys/training/datasets/cogvideox/sampler.py b/videosys/training/datasets/cogvideox/sampler.py new file mode 100644 index 00000000..93933149 --- /dev/null +++ b/videosys/training/datasets/cogvideox/sampler.py @@ -0,0 +1,1098 @@ +import logging +import math +import time +from collections import OrderedDict, defaultdict +from dataclasses import dataclass +from pprint import pformat +from typing import Iterator, List, Optional, Union + +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DistributedSampler + +from videosys.core.dcp.profiler import get_profiler + +from .bucket import Bucket +from .datasets import DummyVariableVideoTextDataset, VariableVideoTextDataset + +GB = 1024**3 + + +# use pandarallel to accelerate bucket processing +# NOTE: pandarallel should only access local variables +def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None): + return method( + data["num_frames"], + data["height"], + data["width"], + frame_interval, + seed + data["id"] * num_bucket, + ) + + +@dataclass +class BucketPlan: + bucket_id: tuple + batch_size: int + sp_size: int + exec_time: float + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def reset(self) -> None: + self.start_index = 0 + + def state_dict(self, step) -> dict: + return {"start_index": step} + + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) + + +class VariableVideoBatchSampler(DistributedSampler): + def __init__( + self, + dataset: Union[VariableVideoTextDataset, DummyVariableVideoTextDataset], + bucket_config: dict, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + verbose: bool = False, + num_bucket_build_workers: int = 1, + sp_balance_scope: str = "iter", + auto_grad_accumulation: bool = False, + max_grad_accumulation_steps: int = 5, + parallel_mgr=None, + calculate_imbalance: bool = False, + min_grad_accumulation_steps: int = 2, + ) -> None: + super().__init__( + dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last + ) + self.dataset = dataset + self.bucket = Bucket(bucket_config) + self.verbose = verbose + self.last_micro_batch_access_index = 0 + self.approximate_num_batch = None + self.keep_last = not drop_last + self._get_num_batch_cached_bucket_sample_dict = None + self.num_bucket_build_workers = num_bucket_build_workers + + self.sp_balance_scope = sp_balance_scope + self.auto_grad_accumulation = auto_grad_accumulation + self.max_grad_accumulation_steps = max_grad_accumulation_steps + self.min_grad_accumulation_steps = min_grad_accumulation_steps + self.profiler = get_profiler() + self.optimized_schedule = "local" if self.profiler.dynamic_sp else None + self.generator = None + if self.shuffle: + self.generator = torch.Generator() + self.generator.manual_seed(self.seed + self.epoch) + self.cached_bucket_id_access_order = None + self.effective_samples = 0 + self.parallel_mgr = parallel_mgr + self.calculate_imbalance = calculate_imbalance + self.imbalance_list = [] + self.est_total_execution_time = 0.0 + + def __iter__(self) -> Iterator[List[int]]: + if self._get_num_batch_cached_bucket_sample_dict is not None: + bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict + self._get_num_batch_cached_bucket_sample_dict = None + else: + bucket_sample_dict = self.group_by_bucket() + if self.optimized_schedule is not None: + self.get_num_batch_with_optimized_schedule(bucket_sample_dict) + else: + self.get_num_batch(bucket_sample_dict) + + if self.optimized_schedule is not None: + yield from self._optimized_schedule_iter(bucket_sample_dict) + else: + yield from self._bucketized_iter(bucket_sample_dict) + + def change_timer_group(self, timers): + cur_group = self.parallel_mgr.sp_group + for t in timers: + timers[t].group = cur_group + + def _build_bucketized_bucket_id_access_order(self, bucket_sample_dict): + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + bucket_micro_batch_count = OrderedDict() + self.effective_samples = 0 + + # process the samples + for bucket_id, data_list in bucket_sample_dict.items(): + ar_name, num_frame = bucket_id[:2] + if not self.profiler.is_valid_bucket(ar_name, num_frame): + if self.verbose: + logging.info(f"skip building batches for bucket {bucket_id} because it's invalid") + continue + # handle droplast + bs_per_gpu = self.get_batch_size(bucket_id) + org_num_samples = len(data_list) + remainder = org_num_samples % bs_per_gpu + + if (not self.keep_last) and remainder > 0: + # we just drop the remainder to make it divisible + data_list = data_list[:-remainder] + + # handle shuffle + if self.shuffle: + data_indices = torch.randperm(len(data_list), generator=g).tolist() + data_list = [data_list[i] for i in data_indices] + bucket_sample_dict[bucket_id] = data_list + + # compute how many micro-batches each bucket has + effect_size = len(data_list) + if self.keep_last: + num_micro_batches = (effect_size + bs_per_gpu - 1) // bs_per_gpu + self.effective_samples += effect_size + else: + num_micro_batches = effect_size // bs_per_gpu + self.effective_samples += bs_per_gpu * num_micro_batches + bucket_micro_batch_count[bucket_id] = num_micro_batches + + # compute the bucket access order + # each bucket may have more than one batch of data + # thus bucket_id may appear more than 1 time + bucket_id_access_order = [] + for bucket_id, num_micro_batch in bucket_micro_batch_count.items(): + bucket_id_access_order.extend([bucket_id] * num_micro_batch) + + # randomize the access order + if self.shuffle: + bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist() + bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices] + + # make the number of bucket accesses divisible by dp size + original_batches = len(bucket_id_access_order) + remainder = original_batches % self.num_replicas + bucket_num_batch_to_deduct = defaultdict(int) + if remainder > 0: + for i in range(original_batches - remainder, original_batches): + bucket_num_batch_to_deduct[bucket_id_access_order[i]] += 1 + + bucket_id_access_order = bucket_id_access_order[: original_batches - remainder] + + for bucket_id, num_batch_to_deduct in bucket_num_batch_to_deduct.items(): + total_samples = len(bucket_sample_dict[bucket_id]) + total_batches = bucket_micro_batch_count[bucket_id] + + left_batchs = total_batches - num_batch_to_deduct + left_samples = left_batchs * self.get_batch_size(bucket_id) + self.effective_samples -= total_samples - left_samples + + if self.verbose: + for i in range(len(bucket_id_access_order)): + logging.info(f"iter {i}, bucket_id: {bucket_id_access_order[i]}") + logging.info(f"dropped: {pformat(bucket_num_batch_to_deduct, sort_dicts=False)}") + return bucket_id_access_order + + def _bucketized_iter(self, bucket_sample_dict): + bucket_last_consumed = OrderedDict() + # acc_num_samples = torch.zeros(1, device=torch.cuda.current_device(), dtype=torch.float) + if self.cached_bucket_id_access_order is not None: + bucket_id_access_order = self.cached_bucket_id_access_order + self.cached_bucket_id_access_order = None + else: + bucket_id_access_order = self._build_bucketized_bucket_id_access_order(bucket_sample_dict) + + # prepare each batch from its bucket + # according to the predefined bucket access order + num_iters = len(bucket_id_access_order) // self.num_replicas + start_iter_idx = self.last_micro_batch_access_index // self.num_replicas + + # re-compute the micro-batch consumption + # this is useful when resuming from a state dict with a different number of GPUs + self.last_micro_batch_access_index = start_iter_idx * self.num_replicas + for i in range(self.last_micro_batch_access_index): + bucket_id = bucket_id_access_order[i] + bucket_bs = self.get_batch_size(bucket_id) + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + self.est_total_execution_time = 0.0 + for i in range(start_iter_idx, num_iters): + bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas] + self.last_micro_batch_access_index += self.num_replicas + + # compute the data samples consumed by each access + bucket_access_boundaries = [] + for bucket_id in bucket_access_list: + bucket_bs = self.get_batch_size(bucket_id) + last_consumed_index = bucket_last_consumed.get(bucket_id, 0) + bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs]) + + # update consumption + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + if self.calculate_imbalance: + total_time = [] + for bucket_id in bucket_access_list: + cur_time = self.profiler.get_execution_time(bucket_id[0], bucket_id[1]) + total_time.append(cur_time) + max_time = max(total_time) + imbalance = sum([(max_time - t) for t in total_time]) / len(total_time) + self.imbalance_list.append(imbalance) + self.est_total_execution_time += max_time + logging.info( + f"iter {i}, \nbucket_access_list: {bucket_access_list}\ntotal time: {total_time}" + f"\ncur imbalance: {imbalance/max_time*100:.4f} %, \nestimate total imbalance: {sum(self.imbalance_list) / len(self.imbalance_list) * num_iters:.4f}s" + ) + + # compute the range of data accessed by each GPU + bucket_id = bucket_access_list[self.rank] + boundary = bucket_access_boundaries[self.rank] + cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]] + + # encode t, h, w into the sample index + real_t, real_h, real_w = self.bucket.get_thw(bucket_id) + cur_micro_batch = [ + (idx, real_t, real_h, real_w, bucket_id[0], self.parallel_mgr.sp_size, 1) for idx in cur_micro_batch + ] + yield cur_micro_batch + + self.reset() + + def get_batch_size(self, bucket_id): + bs_from_bucket_config = self.profiler.get_batch_size(bucket_id[0], bucket_id[1]) + return bs_from_bucket_config + + def __len__(self) -> int: + bucket_sample_dict = self.group_by_bucket() + self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict + + if self.optimized_schedule is not None: + return self.get_num_batch_with_optimized_schedule(bucket_sample_dict) + else: + return self.get_num_batch(bucket_sample_dict) // self.num_replicas + + def group_by_bucket(self) -> dict: + bucket_sample_dict = OrderedDict() + + from pandarallel import pandarallel + + pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False, verbose=self.verbose) + if self.verbose: + logging.info(f"Building buckets...") + bucket_ids = self.dataset.data.parallel_apply( + apply, + axis=1, + method=self.bucket.get_bucket_id, + frame_interval=self.dataset.frame_interval, + seed=self.seed + self.epoch, + num_bucket=self.bucket.num_bucket, + ) + + # group by bucket + # each data sample is put into a bucket with a similar image/video size + for i in range(len(self.dataset)): + bucket_id = bucket_ids[i] + if bucket_id is None: + continue + if bucket_id not in bucket_sample_dict: + bucket_sample_dict[bucket_id] = [] + bucket_sample_dict[bucket_id].append(i) + return bucket_sample_dict + + def _calculate_grad_accumulation_num(self, cur_first_batch_bucket_id_list): + def score_func(new_time, median_time): + if new_time > median_time: + return (new_time - median_time) * 1.2 + else: + return (median_time - new_time) * 1 + + exec_time_list = [self.profiler.get_execution_time(*i[0][:2]) for i in cur_first_batch_bucket_id_list] + max_time = max(exec_time_list) * self.max_grad_accumulation_steps + min_diff = float("inf") + num_gas = None + for exec_time_outer in exec_time_list: + max_mult_outer = int(max_time / exec_time_outer) + for mult in range(1, max_mult_outer + 1): + time_outer = exec_time_outer * mult + if time_outer > max_time: + break + gas_outer, diff_outer = [], 0 + for exec_time_inner in exec_time_list: + gas_inner, diff_inner = None, float("inf") + max_mult_inner = int(max_time / exec_time_inner) + for gas_val in range(1, max_mult_inner + 1): + time_inner = exec_time_inner * gas_val + if time_inner > max_time: + break + now_diff = score_func(time_inner, time_outer) + if now_diff < diff_inner: + diff_inner = now_diff + gas_inner = gas_val + diff_outer += diff_inner + gas_outer.append(gas_inner) + if diff_outer < min_diff: + min_diff = diff_outer + num_gas = gas_outer + + # if max grad accumulation is less than min grad accumulation, repeat the grad accumulation + if max(num_gas) < self.min_grad_accumulation_steps: + grad_accumulation_steps = math.ceil(self.min_grad_accumulation_steps / max(num_gas)) + num_gas = [i * grad_accumulation_steps for i in num_gas] + + return num_gas + + def _build_local_bucket_id_access_order_acc(self, bucket_sample_dict): + wsize = dist.get_world_size() + bucket_id_access_order = [] + self.effective_samples = 0 + + bucket_sp_map, sp_bucket_map = dict(), dict() + for bucket_id, data_list in bucket_sample_dict.items(): + ar_name, num_frame = bucket_id[:2] + if not self.profiler.is_valid_bucket(ar_name, num_frame): + if self.verbose: + logging.info(f"skip building batches for bucket {bucket_id} because it's invalid") + continue + + # collect bucket_sp_map, sp_bucket_map + sp_size = self.profiler.get_sp_size(ar_name, num_frame) + max_bs = self.profiler.get_batch_size(ar_name, num_frame) + cur_len = len(data_list) + remainder = cur_len % max_bs + if (not self.keep_last) and remainder > 0: + if self.drop_last: + data_list = data_list[:-remainder] + else: + pad = max_bs - remainder + if pad > cur_len: + data_list = data_list * ((pad + cur_len - 1) // cur_len + 1) + data_list = data_list[: pad + cur_len] + else: + data_list += data_list[:pad] + logging.info(f"bucket {bucket_id} original len: {cur_len} padded len: {len(data_list)} for bs {max_bs}") + + bucket_sp_map[bucket_id] = sp_size + if sp_size not in sp_bucket_map: + sp_bucket_map[sp_size] = [] + sp_bucket_map[sp_size].append(bucket_id) + + if self.generator is not None: + data_indices = torch.randperm(len(data_list), generator=self.generator).tolist() + data_list = [data_list[i] for i in data_indices] + + bucket_sample_dict[bucket_id] = data_list + + bucket_sample_dict_last_access = {k: 0 for k in bucket_sample_dict.keys()} + sp_size_list = sorted(sp_bucket_map.keys()) + while sp_size_list: + cur_first_batch_bucket_id_list = [] + remain_gpus = wsize + has_one_more_batch = True + while remain_gpus > 0: + max_sp_idx = 0 + while max_sp_idx < len(sp_size_list) and remain_gpus >= sp_size_list[max_sp_idx]: + max_sp_idx += 1 + + if max_sp_idx == 0: + # if false, cur_first_batch_bucket_id_list will be discarded + has_one_more_batch = False + break + + # select sp size + if self.generator is not None: + cur_sp_size_list = sp_size_list[:max_sp_idx] + sp_size_sample_num = OrderedDict({k: len(sp_bucket_map[k]) for k in cur_sp_size_list}) + total_samples = sum(sp_size_sample_num.values()) + val = torch.rand(size=(1,), generator=self.generator).item() + idx = list(sp_size_sample_num.keys())[-1] + for k, v in sp_size_sample_num.items(): + if val < v / total_samples: + idx = k + break + idx = sp_size_list.index(idx) + else: + idx = max_sp_idx - 1 + sp = sp_size_list[idx] + remain_gpus -= sp + + # select bucket id + if self.generator is not None: + bucket_index = torch.randint( + low=0, high=len(sp_bucket_map[sp]), size=(1,), generator=self.generator + ).item() + else: + bucket_index = 0 + bucket_id = sp_bucket_map[sp][bucket_index] + ar_name, num_frame = bucket_id[:2] + # max bs for first batch + bs = self.profiler.get_batch_size(ar_name, num_frame) + + offset = bucket_sample_dict_last_access[bucket_id] + num_samples = min(bs, len(bucket_sample_dict[bucket_id]) - offset) + cur_first_batch_bucket_id_list.append((bucket_id, num_samples)) + + offset += num_samples + bucket_sample_dict_last_access[bucket_id] = offset + if offset == len(bucket_sample_dict[bucket_id]): + sp_bucket_map[sp].pop(bucket_index) + if not sp_bucket_map[sp]: + sp_size_list.remove(sp) + sp_bucket_map.pop(sp) + + # to be more efficient, only pop when gpu is full + if not self.keep_last and remain_gpus <= 0: + # get max grad accumulation + exec_time_list = [ + self.profiler.get_execution_time(*i[0][:2]) for i in cur_first_batch_bucket_id_list + ] + num_gas = self._calculate_grad_accumulation_num(cur_first_batch_bucket_id_list) + max_time = max([exec_time * gas for exec_time, gas in zip(exec_time_list, num_gas)]) + # remove bucket that is not enough for grad accumulation + index = 0 + while index < len(cur_first_batch_bucket_id_list): + bucket_id, bs = cur_first_batch_bucket_id_list[index] + ar_name, num_frame = bucket_id[:2] + exec_time = self.profiler.get_execution_time(ar_name, num_frame) + max_bs = self.profiler.get_batch_size(ar_name, num_frame) + sp_size = self.profiler.get_sp_size(ar_name, num_frame) + + required_gas = max_time // exec_time - 1 + remain_batches = ( + len(bucket_sample_dict[bucket_id]) - bucket_sample_dict_last_access[bucket_id] + ) // max_bs + + # calculate repeat times for this bucket + bucket_id_list = [i[0] for i in cur_first_batch_bucket_id_list] + occur_times = bucket_id_list.count(bucket_id) + required_gas *= occur_times + + if remain_batches < required_gas: + cur_first_batch_bucket_id_list.pop(index) + remain_gpus += sp_size + bucket_sample_dict_last_access[bucket_id] = len(bucket_sample_dict[bucket_id]) + if sp_size in sp_bucket_map: + if bucket_id in sp_bucket_map[sp_size]: + sp_bucket_map[sp_size].remove(bucket_id) + if not sp_bucket_map[sp_size]: + sp_size_list.remove(sp_size) + sp_bucket_map.pop(sp_size) + else: + index += 1 + + if has_one_more_batch: + # sort to make sure fitting + cur_first_batch_bucket_id_list = sorted( + cur_first_batch_bucket_id_list, key=lambda x: bucket_sp_map[x[0]], reverse=True + ) + + if self.auto_grad_accumulation: + num_gas = self._calculate_grad_accumulation_num(cur_first_batch_bucket_id_list) + else: + num_gas = [1 for _ in cur_first_batch_bucket_id_list] + + # decide accumulate batch_size + # [[bucket_id, bs] * ] * + cur_batch_bucket_id_list = [] + batch_log = [] + # TODO: potential optimization to decide batch size and number of micro batches (for grad acc) + for bidx, each in enumerate(cur_first_batch_bucket_id_list): + this_bucket_acc_list = [each] + bucket_id, max_bs = each + batch_log.append( + [ + (bucket_id + (bucket_sp_map[bucket_id], max_bs)), + ] + ) + # collect effective samples for the first batch of this iter + self.effective_samples += max_bs + + # if has remaining samples for grad acc batch + offset = bucket_sample_dict_last_access[bucket_id] + total_len = len(bucket_sample_dict[bucket_id]) + if offset < total_len: + ar_name, num_frame = bucket_id[:2] + sp = bucket_sp_map[bucket_id] + + # here I use the max bs from profile + bs = max_bs + + # minus one because of the first batch + num_acc = num_gas[bidx] - 1 + + while num_acc > 0 and offset < total_len: + num_samples = min(bs, total_len - offset) + this_bucket_acc_list.append((bucket_id, num_samples)) + + offset += num_samples + num_acc -= 1 + + # collect effective samples for grad acc batches of this iter + self.effective_samples += num_samples + batch_log[-1].append((bucket_id + (bucket_sp_map[bucket_id], num_samples))) + + bucket_sample_dict_last_access[bucket_id] = offset + # remove exhausted buckets from local indices + if offset == total_len: + sp_bucket_map[sp].remove(bucket_id) + if not sp_bucket_map[sp]: + sp_size_list.remove(sp) + sp_bucket_map.pop(sp) + + cur_batch_bucket_id_list.append(this_bucket_acc_list) + logging.info( + f"iter {len(bucket_id_access_order)}, gas: {num_gas} actual: {[len(each) for each in cur_batch_bucket_id_list]}" + f", buckets: {batch_log}" + ) + bucket_id_access_order.append(cur_batch_bucket_id_list) + + return bucket_id_access_order + + def _build_local_bucket_id_access_order_sp_balance(self, bucket_sample_dict): + wsize = dist.get_world_size() + bucket_id_access_order = [] + self.effective_samples = 0 + + bucket_sample_counts = {} + sp_bucket_map = dict() + for bucket_id, data_list in bucket_sample_dict.items(): + ar_name, num_frame = bucket_id[:2] + if not self.profiler.is_valid_bucket(ar_name, num_frame): + if self.verbose: + logging.info(f"skip building batches for bucket {bucket_id} because it's invalid") + continue + + # shuffle + if self.generator is not None: + data_indices = torch.randperm(len(data_list), generator=self.generator).tolist() + data_list = [data_list[i] for i in data_indices] + + # record + bucket_sample_dict[bucket_id] = data_list + bucket_sample_counts[bucket_id] = len(data_list) + sp_size = self.profiler.get_sp_size(ar_name, num_frame) + if sp_size not in sp_bucket_map: + sp_bucket_map[sp_size] = [] + sp_bucket_map[sp_size].append(bucket_id) + + sp_size_list = sorted(sp_bucket_map.keys()) + while sp_size_list: + cur_batch_bucket_id_list = [] + remain_gpus = wsize + has_one_more_batch = True + while remain_gpus > 0: + max_sp_idx = 0 + while max_sp_idx < len(sp_size_list) and remain_gpus >= sp_size_list[max_sp_idx]: + max_sp_idx += 1 + + if max_sp_idx == 0: + # if false, cur_batch_bucket_id_list will be discarded + has_one_more_batch = False + break + + # select sp + if self.generator is not None: + cur_sp_size_list = sp_size_list[:max_sp_idx] + probs = torch.tensor([len(sp_bucket_map[k]) for k in cur_sp_size_list], dtype=torch.float) + idx = torch.multinomial(probs, 1, generator=self.generator).item() + else: + idx = max_sp_idx - 1 + sp = sp_size_list[idx] + + # select bucket + if self.generator is not None: + bucket_index = torch.randint(0, len(sp_bucket_map[sp]), (1,), generator=self.generator).item() + else: + bucket_index = 0 + bucket_id = sp_bucket_map[sp][bucket_index] + ar_name, num_frame = bucket_id[:2] + bs = self.profiler.get_batch_size(ar_name, num_frame) + + num_samples = min(bs, bucket_sample_counts[bucket_id]) + bucket_sample_counts[bucket_id] -= num_samples + if bucket_sample_counts[bucket_id] == 0: + sp_bucket_map[sp].remove(bucket_id) + if not sp_bucket_map[sp]: + sp_size_list.remove(sp) + sp_bucket_map.pop(sp) + + exec_time = self.profiler.get_execution_time(ar_name, num_frame) / bs * num_samples + cur_batch_bucket_id_list.append( + BucketPlan( + bucket_id=bucket_id, + batch_size=num_samples, + sp_size=sp, + exec_time=exec_time, + ) + ) + + remain_gpus -= sp + + if not self.keep_last and len(cur_batch_bucket_id_list) > 1: + min_time_idx, min_time = -1, float("inf") + for i, each in enumerate(cur_batch_bucket_id_list): + max_bs = self.profiler.get_batch_size(*each.bucket_id[:2]) + if each.exec_time < min_time and max_bs > each.batch_size: + min_time = each.exec_time + min_time_idx = i + if min_time_idx > -1: + # drop last batch for this bucket if it is the shortest time + pop_plan = cur_batch_bucket_id_list.pop(min_time_idx) + remain_gpus += pop_plan.sp_size + + if not has_one_more_batch: + continue + + logging.info( + f"iter {len(bucket_id_access_order)}\noriginal buckets: {[(each.bucket_id, each.batch_size, each.sp_size, each.exec_time) for each in cur_batch_bucket_id_list]}" + ) + + min_time, min_bucket = min( + [(each.exec_time, each.bucket_id) for each in cur_batch_bucket_id_list], key=lambda x: x[0] + ) + skip_bucket_idx = [] + if self.keep_last: + no_last_batch, last_batch = [], [] + for i, bucket_plan in enumerate(cur_batch_bucket_id_list): + if bucket_sample_counts[bucket_plan.bucket_id] > 0: + no_last_batch.append(bucket_plan) + else: + last_batch.append(bucket_plan) + + if not no_last_batch: + assert len(last_batch) == len(cur_batch_bucket_id_list) + skip_bucket_idx = list(range(len(cur_batch_bucket_id_list))) + min_time = 0 + else: + min_time, min_bucket = min( + [(each.exec_time, each.bucket_id) for each in no_last_batch], key=lambda x: x[0] + ) + skip_bucket_idx = [] + if last_batch: + for i, bucket_plan in enumerate(cur_batch_bucket_id_list): + if bucket_plan.exec_time < min_time: + skip_bucket_idx.append(i) + logs = [] + for i, bucket_plan in enumerate(cur_batch_bucket_id_list): + if i in skip_bucket_idx or bucket_plan.bucket_id == min_bucket: + continue + + ar_name, num_frame = bucket_plan.bucket_id[:2] + + original_exec_time = bucket_plan.exec_time + original_batch_size = bucket_plan.batch_size + original_sp_size = bucket_plan.sp_size + unit_time = original_exec_time / original_batch_size + + original_remain_samples = bucket_sample_counts[bucket_plan.bucket_id] + + best_diff = float("inf") + best_exec_time, best_bs, best_sp = original_exec_time, original_batch_size, original_sp_size + cur_sp_size = original_sp_size + log_str = f"\n>>> bucket {bucket_plan.bucket_id}, bs: {original_batch_size}, sp: {original_sp_size}, time: {original_exec_time}" + while cur_sp_size <= self.profiler.max_sp: + max_bs = self.profiler.detail_results[ar_name][num_frame][cur_sp_size]["bs"] + cur_unit_time = self.profiler.detail_results[ar_name][num_frame][cur_sp_size]["pred_time"] / max_bs + + if max_bs - original_batch_size > original_remain_samples: + max_bs = original_batch_size + original_remain_samples + + cur_bs = max(1, round(min_time / cur_unit_time)) + if cur_bs > max_bs: + cur_bs = max_bs + + cur_exec_time = cur_unit_time * cur_bs + cur_diff = abs(cur_exec_time - min_time) + log_str += ( + f"\nCHECK sp: {cur_sp_size}, bs: {cur_bs}, time: {cur_exec_time}, diff: {cur_diff}/{best_diff}" + ) + if cur_diff < best_diff: + best_diff = cur_diff + best_exec_time = cur_exec_time + best_bs = cur_bs + best_sp = cur_sp_size + + if abs(cur_exec_time / min_time - 1) < 0.1: + break + cur_sp_size *= 2 + logs.append(log_str) + assert ( + best_bs > 0 + ), f"best_bs: {best_bs} - {original_batch_size}, best sp: {best_sp} - {original_sp_size}, best time: {best_exec_time} - {original_exec_time}, min time: {min_time}" + + # return left samples back to record + if best_bs < cur_batch_bucket_id_list[i].batch_size: + bucket_id = cur_batch_bucket_id_list[i].bucket_id + org_sp = bucket_plan.sp_size + left_bs = cur_batch_bucket_id_list[i].batch_size - best_bs + + bucket_sample_counts[bucket_id] += left_bs + if org_sp not in sp_bucket_map: + sp_bucket_map[org_sp] = [] + sp_bucket_map[org_sp].append(bucket_id) + sp_size_list.append(org_sp) + else: + if sp_bucket_map[org_sp].count(bucket_id) == 0: + sp_bucket_map[org_sp].append(bucket_id) + elif best_bs > cur_batch_bucket_id_list[i].batch_size: + bucket_id = cur_batch_bucket_id_list[i].bucket_id + org_sp = bucket_plan.sp_size + + bucket_sample_counts[bucket_id] -= best_bs - cur_batch_bucket_id_list[i].batch_size + if bucket_sample_counts[bucket_id] == 0: + sp_bucket_map[org_sp].remove(bucket_id) + if not sp_bucket_map[org_sp]: + sp_size_list.remove(org_sp) + sp_bucket_map.pop(org_sp) + + cur_batch_bucket_id_list[i].batch_size = best_bs + cur_batch_bucket_id_list[i].exec_time = best_exec_time + cur_batch_bucket_id_list[i].sp_size = best_sp + + # pop and recover buckets out of limit + cur_batch_bucket_id_list = sorted(cur_batch_bucket_id_list, key=lambda x: x.sp_size, reverse=True) + total_gpus = sum([each.sp_size for each in cur_batch_bucket_id_list]) + poped = [] + while total_gpus > wsize: + bucket_plan = cur_batch_bucket_id_list.pop() + bucket_id = bucket_plan.bucket_id + ar_name, num_frame = bucket_id[:2] + org_sp = self.profiler.get_sp_size(ar_name, num_frame) + sp = bucket_plan.sp_size + bs = bucket_plan.batch_size + + bucket_sample_counts[bucket_id] += bs + if org_sp not in sp_bucket_map: + sp_bucket_map[org_sp] = [] + sp_bucket_map[org_sp].append(bucket_id) + sp_size_list.append(org_sp) + elif sp_bucket_map[org_sp].count(bucket_id) == 0: + sp_bucket_map[org_sp].append(bucket_id) + + total_gpus -= sp + poped.append(bucket_plan) + assert total_gpus == wsize + + # rebalance bs only + min_max_time, min_max_bucket = float("inf"), None + for i, bucket_plan in enumerate(cur_batch_bucket_id_list): + bucket_id = bucket_plan.bucket_id + ar_name, num_frame = bucket_id[:2] + cur_bs = bucket_plan.batch_size + cur_sp = bucket_plan.sp_size + cur_time = bucket_plan.exec_time + unit_time = cur_time / cur_bs + + max_bs = self.profiler.detail_results[ar_name][num_frame][cur_sp]["bs"] + if max_bs - cur_bs > bucket_sample_counts[bucket_id]: + max_bs = cur_bs + bucket_sample_counts[bucket_id] + + max_tmp_time = unit_time * max_bs + if max_tmp_time < min_max_time: + min_max_time = max_tmp_time + + for i, bucket_plan in enumerate(cur_batch_bucket_id_list): + bucket_id = bucket_plan.bucket_id + # if bucket_id == min_max_bucket: + # continue + + ar_name, num_frame = bucket_id[:2] + cur_sp = bucket_plan.sp_size + max_bs = self.profiler.detail_results[ar_name][num_frame][cur_sp]["bs"] + + cur_exec_time = bucket_plan.exec_time + cur_bs = bucket_plan.batch_size + unit_time = cur_exec_time / cur_bs + + diff_time = min_max_time - cur_exec_time + if diff_time <= 0: + continue + + increment_bs = int(diff_time // unit_time) + if increment_bs + cur_bs > max_bs: + increment_bs = max_bs - cur_bs + increment_bs = min(increment_bs, bucket_sample_counts[bucket_id]) + increment_time = unit_time * increment_bs + + if increment_bs > 0: + sp = self.profiler.get_sp_size(ar_name, num_frame) + bucket_sample_counts[bucket_id] -= increment_bs + if bucket_sample_counts[bucket_id] == 0: + sp_bucket_map[sp].remove(bucket_id) + if not sp_bucket_map[sp]: + sp_size_list.remove(sp) + sp_bucket_map.pop(sp) + + bucket_plan.batch_size += increment_bs + bucket_plan.exec_time += increment_time + assert ( + bucket_plan.batch_size > 0 + ), f"increment_bs: {increment_bs}, cur_bs: {cur_bs}, max_bs: {max_bs}, increment_time: {increment_time}, cur_time: {cur_exec_time}, min_max_time: {min_max_time}" + + this_bucket_acc_list = [] + for bucket_plan in cur_batch_bucket_id_list: + self.effective_samples += bucket_plan.batch_size + this_bucket_acc_list.append( + [(bucket_plan.bucket_id, bucket_plan.batch_size, bucket_plan.sp_size, bucket_plan.exec_time)] + ) + bucket_id_access_order.append(this_bucket_acc_list) + logging.info( + f"iter {len(bucket_id_access_order)}\nbuckets: {[(each.bucket_id, each.batch_size, each.sp_size, each.exec_time) for each in cur_batch_bucket_id_list]}" + f"\npoped: {[(each.bucket_id, each.batch_size, each.sp_size, each.exec_time) for each in poped]}" + f"\nmin time: {min_time:.2f}, max time: {min_max_time:.2f}" + f"\n{logs}" + ) + + return bucket_id_access_order + + def _optimized_schedule_iter(self, bucket_sample_dict): + rank, wsize = dist.get_rank(), dist.get_world_size() + is_sp_balance_iter = ( + self.profiler.dynamic_sp + and not self.profiler.dynamic_recompute + and not self.auto_grad_accumulation + and self.sp_balance_scope == "iter" + ) + + # bucket_id_access_order: [[(bucket_id, bs)] * ] * + if self.cached_bucket_id_access_order is not None: + bucket_id_access_order = self.cached_bucket_id_access_order + self.cached_bucket_id_access_order = None + elif is_sp_balance_iter: + bucket_id_access_order = self._build_local_bucket_id_access_order_sp_balance(bucket_sample_dict) + else: + # support grad acc + bucket_id_access_order = self._build_local_bucket_id_access_order_acc(bucket_sample_dict) + + num_iter = len(bucket_id_access_order) + # skip resume code + start_iter_idx = self.last_micro_batch_access_index + self.est_total_execution_time = 0.0 + # generate execution plan + bucket_last_consumed = OrderedDict() + for i in range(start_iter_idx, num_iter): + bucket_id_access_list = bucket_id_access_order[i] + + sp_size_map_list, bucket_id_map_list = [], [] + bucket_access_boundaries = [] + for bucket_list in bucket_id_access_list: + boundary_gas_list = [] + for item in bucket_list: + bucket_id, bs = item[:2] + + last_consumed_index = bucket_last_consumed.get(bucket_id, 0) + boundary_gas_list.append([last_consumed_index, last_consumed_index + bs]) + + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bs + else: + bucket_last_consumed[bucket_id] = bs + assert bucket_last_consumed[bucket_id] <= len( + bucket_sample_dict[bucket_id] + ), f"rank {rank} iter: {i}, bucket_id_access_list: {bucket_id_access_list}, bucket_last_consumed[{bucket_id}] = {bucket_last_consumed[bucket_id]} > {len(bucket_sample_dict[bucket_id])}" + + bucket_id = bucket_list[0][0] + if is_sp_balance_iter: + sp_size = bucket_list[0][2] + else: + sp_size = self.profiler.get_sp_size(bucket_id[0], bucket_id[1]) + + sp_size_map_list.extend([sp_size] * sp_size) + bucket_id_map_list.extend([bucket_list] * sp_size) + bucket_access_boundaries.extend([boundary_gas_list] * sp_size) + + if self.calculate_imbalance: + log_bucket_list, log_time_list = [], [] + for bucket_list in bucket_id_access_list: + bucket_id = bucket_list[0][0] + + log_bucket_list.append(bucket_id) + if is_sp_balance_iter: + cur_time = bucket_list[0][3] + else: + cur_time = self.profiler.get_execution_time(bucket_id[0], bucket_id[1]) + log_time_list.append(len(bucket_list) * cur_time) + + total_time = [] + for bucket_list in bucket_id_map_list: + gas = len(bucket_list) + bucket_id = bucket_list[0][0] + if is_sp_balance_iter: + cur_time = bucket_list[0][3] + else: + cur_time = self.profiler.get_execution_time(bucket_id[0], bucket_id[1]) + cur_time = cur_time * gas + total_time.append(cur_time) + max_time = max(total_time) + imbalance = sum([(max_time - t) for t in total_time]) / len(total_time) + self.imbalance_list.append(imbalance) + self.est_total_execution_time += max_time + logging.info( + f"iter {i}, \nbucket_id_map_list: {log_bucket_list}\ntotal time: {log_time_list}" + f"\ncur imbalance: {imbalance/max_time*100:.4f} %, \nestimate total imbalance: {sum(self.imbalance_list) / len(self.imbalance_list) * num_iter:.4f}s" + ) + + assert len(sp_size_map_list) == wsize + sp_size = sp_size_map_list[rank] + bucket_list = bucket_id_map_list[rank] + boundaries = bucket_access_boundaries[rank] + + gas = len(bucket_list) + cur_micro_batches = [] + for bucket, boundary in zip(bucket_list, boundaries): + bucket_id, bs = bucket[:2] + gas_micro_batches = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]] + assert ( + len(gas_micro_batches) == bs + ), f"iter {i}, rank {rank}, target bs: {bs}, actual bs: {len(gas_micro_batches)}" + + real_t, real_h, real_w = self.bucket.get_thw(bucket_id) + cur_micro_batches.extend( + [(idx, real_t, real_h, real_w, bucket_id[0], sp_size, gas) for idx in gas_micro_batches] + ) + + assert ( + len(cur_micro_batches) > 0 + ), f"rank: {rank} iter: {i}, bucket_id_map_list: {bucket_id_map_list}, bucket_access_boundaries: {bucket_access_boundaries}" + yield cur_micro_batches + + self.reset() + + def get_num_batch_with_optimized_schedule(self, bucket_sample_dict) -> int: + start_ = time.time() + if ( + self.profiler.dynamic_sp + and not self.profiler.dynamic_recompute + and not self.auto_grad_accumulation + and self.sp_balance_scope == "iter" + ): + bucket_id_access_order = self._build_local_bucket_id_access_order_sp_balance(bucket_sample_dict) + self.cached_bucket_id_access_order = bucket_id_access_order + else: + bucket_id_access_order = self._build_local_bucket_id_access_order_acc(bucket_sample_dict) + self.cached_bucket_id_access_order = bucket_id_access_order + self.approximate_num_batch = len(bucket_id_access_order) + elapsed = time.time() - start_ + + # collect statistics + total_samples = 0 + bucket_stat_dict = dict() + for k, v in bucket_sample_dict.items(): + ar_name, num_frame = k[:2] + if not self.profiler.is_valid_bucket(ar_name, num_frame): + continue + size = len(v) + max_bs = self.profiler.get_batch_size(ar_name, num_frame) + if self.keep_last: + effect_size = size + max_bs - 1 + else: + effect_size = size + num_batch = effect_size // max_bs + if not self.keep_last: + size = max_bs * num_batch + + total_samples += size + + bucket_stat_dict[k] = [size, num_batch] + + # log + if dist.get_rank() == 0 and self.verbose: + logging.info(f"Building index costs: {elapsed:.2f}s") + logging.info(f"Bucket Info at epoch {self.epoch} with optimized schedule:") + logging.info("Bucket [#sample, #batch]:\n%s", pformat(bucket_stat_dict, sort_dicts=False)) + logging.info( + "#training batch: %s, #training sample: %s, #non empty bucket: %s", + self.approximate_num_batch, + total_samples, + len(bucket_sample_dict), + ) + + return self.approximate_num_batch + + def get_num_batch(self, bucket_sample_dict) -> int: + start_ = time.time() + bucket_id_access_order = self._build_bucketized_bucket_id_access_order(bucket_sample_dict) + self.cached_bucket_id_access_order = bucket_id_access_order + self.approximate_num_batch = len(bucket_id_access_order) + elapsed = time.time() - start_ + + # collect statistics + total_samples = 0 + total_batch = 0 + + bucket_stat_dict = dict() + for k, v in bucket_sample_dict.items(): + if not self.profiler.is_valid_bucket(k[0], k[1]): + continue + size = len(v) + bs = self.get_batch_size(k) + if self.keep_last: + effect_size = size + bs - 1 + else: + effect_size = size + num_batch = effect_size // bs + if not self.keep_last: + size = bs * num_batch + + total_samples += size + total_batch += num_batch + + bucket_stat_dict[k] = [size, num_batch] + + # log + if dist.get_rank() == 0 and self.verbose: + logging.info(f"Building index costs: {elapsed:.2f}s") + logging.info(f"Bucket Info at epoch {self.epoch} with bucketized schedule:") + logging.info("Bucket [#sample, #batch]:\n%s", pformat(bucket_stat_dict, sort_dicts=False)) + logging.info( + "#training batch: %s, #training sample: %s, #non empty bucket: %s", + total_batch, + total_samples, + len(bucket_sample_dict), + ) + return self.approximate_num_batch + + def reset(self): + if self.calculate_imbalance and len(self.imbalance_list) > 0: + total_imbalance_time = sum(self.imbalance_list) + logging.info( + f"Total imbalance for this epoch: {total_imbalance_time:.2f}/{self.est_total_execution_time:.2f} ({total_imbalance_time/self.est_total_execution_time*100:.2f}%)" + ) + self.imbalance_list = [] + self.est_total_execution_time = 0.0 + self.last_micro_batch_access_index = 0 + + def state_dict(self, num_steps: int) -> dict: + # the last_micro_batch_access_index in the __iter__ is often + # not accurate during multi-workers and data prefetching + # thus, we need the user to pass the actual steps which have been executed + # to calculate the correct last_micro_batch_access_index + return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas} + + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) diff --git a/videosys/training/datasets/cogvideox/utils.py b/videosys/training/datasets/cogvideox/utils.py new file mode 100644 index 00000000..00737e5f --- /dev/null +++ b/videosys/training/datasets/cogvideox/utils.py @@ -0,0 +1,363 @@ +import logging +import math +import os +import random +import re + +import numpy as np +import pandas as pd +import requests +import torch +import torch.distributed as dist +import torchvision +import torchvision.transforms as transforms +from PIL import Image +from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader +from torchvision.io import write_video +from torchvision.utils import save_image + +from . import video_transforms + +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + +regex = re.compile( + r"^(?:http|ftp)s?://" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain... + r"localhost|" # localhost... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) + + +def is_img(path): + ext = os.path.splitext(path)[-1].lower() + return ext in IMG_EXTENSIONS + + +def is_vid(path): + ext = os.path.splitext(path)[-1].lower() + return ext in VID_EXTENSIONS + + +def is_url(url): + return re.match(regex, url) is not None + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def split_df_by_rank(df): + world_size = dist.get_world_size() + rank = dist.get_rank() + chunk_size = max(1, len(df) // world_size) + return df.iloc[rank * chunk_size : (rank + 1) * chunk_size] + + +def download_url(input_path): + output_dir = "cache" + os.makedirs(output_dir, exist_ok=True) + base_name = os.path.basename(input_path) + output_path = os.path.join(output_dir, base_name) + img_data = requests.get(input_path).content + with open(output_path, "wb") as handler: + handler.write(img_data) + print(f"URL {input_path} downloaded to {output_path}") + return output_path + + +def temporal_random_crop(vframes, num_frames, frame_interval): + temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) + total_frames = len(vframes) + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + assert ( + end_frame_ind - start_frame_ind >= num_frames + ), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}" + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int) + video = vframes[frame_indice] + return video + + +def remove_interval(vframes, frame_interval): + total_frames = len(vframes) + target_frames = total_frames // frame_interval + frame_indice = np.linspace(0, total_frames - 1, target_frames, dtype=int) + video = vframes[frame_indice] + return video + + +def get_transforms_video(name="center", image_size=(256, 256)): + if name is None: + return None + elif name == "center": + assert image_size[0] == image_size[1], "image_size must be square for center crop" + transform_video = transforms.Compose( + [ + video_transforms.ToTensorVideo(), # TCHW + # video_transforms.RandomHorizontalFlipVideo(), + video_transforms.UCFCenterCropVideo(image_size[0]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + elif name == "resize_crop": + transform_video = transforms.Compose( + [ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeCrop(image_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + else: + raise NotImplementedError(f"Transform {name} not implemented") + return transform_video + + +def get_transforms_image(name="center", image_size=(256, 256)): + if name is None: + return None + elif name == "center": + assert image_size[0] == image_size[1], "Image size must be square for center crop" + transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])), + # transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + elif name == "resize_crop": + transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + else: + raise NotImplementedError(f"Transform {name} not implemented") + return transform + + +def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)): + image = pil_loader(path) + if transform is None: + transform = get_transforms_image(image_size=image_size, name=transform_name) + image = transform(image) + video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1) + video = video.permute(1, 0, 2, 3) + return video + + +def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + if transform is None: + transform = get_transforms_video(image_size=image_size, name=transform_name) + video = transform(vframes) # T C H W + video = video.permute(1, 0, 2, 3) + return video + + +def read_from_path(path, image_size, transform_name="center"): + if is_url(path): + path = download_url(path) + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return read_video_from_path(path, image_size=image_size, transform_name=transform_name) + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return read_image_from_path(path, image_size=image_size, transform_name=transform_name) + + +def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True): + """ + Args: + x (Tensor): shape [C, T, H, W] + """ + assert x.ndim == 4 + + if not force_video and x.shape[1] == 1: # T = 1: save as image + save_path += ".png" + x = x.squeeze(1) + save_image([x], save_path, normalize=normalize, value_range=value_range) + else: + save_path += ".mp4" + if normalize: + low, high = value_range + x.clamp_(min=low, max=high) + x.sub_(low).div_(max(high - low, 1e-5)) + + x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) + write_video(save_path, x, fps=fps, video_codec="h264") + if verbose: + print(f"Saved to {save_path}") + return save_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) + + +def resize_crop_to_fill(pil_image, image_size): + w, h = pil_image.size # PIL is (W, H) + th, tw = image_size + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + image = pil_image.resize((sw, sh), Image.BICUBIC) + i = 0 + j = int(round((sw - tw) / 2.0)) + else: + sh, sw = round(h * rw), tw + image = pil_image.resize((sw, sh), Image.BICUBIC) + i = int(round((sh - th) / 2.0)) + j = 0 + arr = np.array(image) + assert i + th <= arr.shape[0] and j + tw <= arr.shape[1] + return Image.fromarray(arr[i : i + th, j : j + tw]) + + +class MaskGenerator: + def __init__(self, mask_ratios): + valid_mask_names = [ + "identity", + "quarter_random", + "quarter_head", + "quarter_tail", + "quarter_head_tail", + "image_random", + "image_head", + "image_tail", + "image_head_tail", + "random", + "intepolate", + ] + assert all( + mask_name in valid_mask_names for mask_name in mask_ratios.keys() + ), f"mask_name should be one of {valid_mask_names}, got {mask_ratios.keys()}" + assert all( + mask_ratio >= 0 for mask_ratio in mask_ratios.values() + ), f"mask_ratio should be greater than or equal to 0, got {mask_ratios.values()}" + assert all( + mask_ratio <= 1 for mask_ratio in mask_ratios.values() + ), f"mask_ratio should be less than or equal to 1, got {mask_ratios.values()}" + # sum of mask_ratios should be 1 + if "identity" not in mask_ratios: + mask_ratios["identity"] = 1.0 - sum(mask_ratios.values()) + assert math.isclose( + sum(mask_ratios.values()), 1.0, abs_tol=1e-6 + ), f"sum of mask_ratios should be 1, got {sum(mask_ratios.values())}" + logging.info("mask ratios: %s", mask_ratios) + self.mask_ratios = mask_ratios + + def get_mask(self, x): + mask_type = random.random() + mask_name = None + prob_acc = 0.0 + for mask, mask_ratio in self.mask_ratios.items(): + prob_acc += mask_ratio + if mask_type < prob_acc: + mask_name = mask + break + + num_frames = x.shape[2] + # Hardcoded condition_frames + condition_frames_max = num_frames // 4 + + mask = torch.ones(num_frames, dtype=torch.bool, device=x.device) + if num_frames <= 1: + return mask + + if mask_name == "quarter_random": + random_size = random.randint(1, condition_frames_max) + random_pos = random.randint(0, x.shape[2] - random_size) + mask[random_pos : random_pos + random_size] = 0 + elif mask_name == "image_random": + random_size = 1 + random_pos = random.randint(0, x.shape[2] - random_size) + mask[random_pos : random_pos + random_size] = 0 + elif mask_name == "quarter_head": + random_size = random.randint(1, condition_frames_max) + mask[:random_size] = 0 + elif mask_name == "image_head": + random_size = 1 + mask[:random_size] = 0 + elif mask_name == "quarter_tail": + random_size = random.randint(1, condition_frames_max) + mask[-random_size:] = 0 + elif mask_name == "image_tail": + random_size = 1 + mask[-random_size:] = 0 + elif mask_name == "quarter_head_tail": + random_size = random.randint(1, condition_frames_max) + mask[:random_size] = 0 + mask[-random_size:] = 0 + elif mask_name == "image_head_tail": + random_size = 1 + mask[:random_size] = 0 + mask[-random_size:] = 0 + elif mask_name == "intepolate": + random_start = random.randint(0, 1) + mask[random_start::2] = 0 + elif mask_name == "random": + mask_ratio = random.uniform(0.1, 0.9) + mask = torch.rand(num_frames, device=x.device) > mask_ratio + # if mask is all False, set the last frame to True + if not mask.any(): + mask[-1] = 1 + + return mask + + def get_masks(self, x): + masks = [] + for _ in range(len(x)): + mask = self.get_mask(x) + masks.append(mask) + masks = torch.stack(masks, dim=0) + return masks + + +def get_text_embeddings(tokenizer, text_encoder, texts): + text_tokens_and_mask = tokenizer( + texts, + max_length=300, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + device = text_encoder.device + input_ids = text_tokens_and_mask["input_ids"].to(device) + attention_mask = text_tokens_and_mask["attention_mask"].to(device) + with torch.no_grad(): + text_encoder_embs = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + )["last_hidden_state"].detach() + return text_encoder_embs, attention_mask + + +def encode_prompt(text_encoder, tokenizer, text): + caption_embs, emb_masks = get_text_embeddings(tokenizer, text_encoder, text) + caption_embs = caption_embs[:, None] + return dict(y=caption_embs, mask=emb_masks) diff --git a/videosys/training/datasets/cogvideox/video_transforms.py b/videosys/training/datasets/cogvideox/video_transforms.py new file mode 100644 index 00000000..8cf50468 --- /dev/null +++ b/videosys/training/datasets/cogvideox/video_transforms.py @@ -0,0 +1,520 @@ +# Copyright 2024 Vchitect/Latte + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte + +# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py + + +import numbers +import random + +import numpy as np +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def resize_crop_to_fill(clip, target_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = target_size[0], target_size[1] + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + clip = resize(clip, (sh, sw), "bilinear") + i = 0 + j = int(round(sw - tw) / 2.0) + else: + sh, sw = round(h * rw), tw + clip = resize(clip, (sh, sw), "bilinear") + i = int(round(sh - th) / 2.0) + j = 0 + assert i + th <= clip.size(-2) and j + tw <= clip.size(-1) + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + """ + Slide along the long edge, with the short edge as crop size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + short_edge = h + else: + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class ResizeCrop: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + clip = resize_crop_to_fill(clip, self.size) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class CenterCropResizeVideo: + """ + First use the short side for cropping length, + center crop video, then resize to the specified size + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop_resize = resize( + clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode + ) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class UCFCenterCropVideo: + """ + First scale to the specified size in equal proportion to the short edge, + then center cropping + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class KineticsRandomCropResizeVideo: + """ + Slide along the long edge, with the short edge as crop size. And resie to the desired size. + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +if __name__ == "__main__": + import os + + import numpy as np + import torchvision.io as io + from torchvision import transforms + from torchvision.utils import save_image + + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW") + + trans = transforms.Compose( + [ + ToTensorVideo(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + + for i in range(target_video_len): + save_image( + select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1) + ) From 91b5faf68dd0dac767b8778fe3f188811314fcb9 Mon Sep 17 00:00:00 2001 From: zxgx Date: Wed, 1 Jan 2025 11:40:48 +0800 Subject: [PATCH 3/6] pass training --- .../cogvideox/configs/benchmarks/baseline.yaml | 4 ++-- examples/training/cogvideox/train.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/training/cogvideox/configs/benchmarks/baseline.yaml b/examples/training/cogvideox/configs/benchmarks/baseline.yaml index 25764076..a0027fc2 100644 --- a/examples/training/cogvideox/configs/benchmarks/baseline.yaml +++ b/examples/training/cogvideox/configs/benchmarks/baseline.yaml @@ -1,7 +1,7 @@ zipf_offset: 5 outputs: exp/cogvideox/baseline profile_path: exp/cogvideox/profile/baseline -sp_size: 1 +sp_size: 8 dummy_dataset: true dummy_data_size: 2000 verbose: true @@ -12,7 +12,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" -preprocessed_data: false +preprocessed_data: true drop_last: true # train diff --git a/examples/training/cogvideox/train.py b/examples/training/cogvideox/train.py index 33c95fa2..a1a03b8e 100644 --- a/examples/training/cogvideox/train.py +++ b/examples/training/cogvideox/train.py @@ -85,11 +85,7 @@ def main(args): ) else: do_profile = False - - # import pdb - # if torch.distributed.get_rank() == 0: - # pdb.set_trace() - + # ====================================================== # 2. build model # ====================================================== @@ -339,7 +335,10 @@ def main(args): # x = vae.encode(x) # [B, C, T, H/P, W/P] # # Prepare text inputs # model_args = encode_prompt(text_encoder, tokenizer, y) - # for k, v in batch_data.items(): + + local_token_counter += x.shape[0] * x.shape[1] * x.shape[3] * x.shape[4] / parallel_mgr.sp_size + + # for k, v in batch_data.items(): # if isinstance(v, torch.Tensor): # model_args[k] = v.to(device, dtype) # # TODO: polish From 068f0a1f32321662837ce7dc09cf7529dacaafed Mon Sep 17 00:00:00 2001 From: zxgx Date: Sun, 9 Mar 2025 11:56:58 +0800 Subject: [PATCH 4/6] update opensora profiler --- examples/training/open_sora/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/training/open_sora/train.py b/examples/training/open_sora/train.py index 2b3c1b0d..dfdd5409 100644 --- a/examples/training/open_sora/train.py +++ b/examples/training/open_sora/train.py @@ -157,6 +157,7 @@ def main(args): # create dcp profiler # TODO: scheduler is a better name? profiler: Profiler = set_profiler( + model_type=model.config._name_or_path, total_layers=model.config.depth, bucket_config=args.bucket_config, text_max_seq_len=model.config.model_max_length, From 73a546be75dc0f13cb77150cd3f9b33c8ab02055 Mon Sep 17 00:00:00 2001 From: zxgx Date: Sat, 9 Aug 2025 22:57:56 +0800 Subject: [PATCH 5/6] fall back to drop last, remove memory hard limit --- examples/training/cogvideox/1x8_baseline.sh | 69 ++++++++++++++++++ examples/training/cogvideox/1x8_dcp_inter.sh | 69 ++++++++++++++++++ .../training/cogvideox/1x8_dcp_inter_ckpt.sh | 69 ++++++++++++++++++ examples/training/cogvideox/1x8_dcp_intra.sh | 69 ++++++++++++++++++ .../training/cogvideox/baseline_profile.sh | 50 +++++++++++++ .../configs/benchmarks/baseline.yaml | 14 ++-- .../configs/benchmarks/dcp_inter.yaml | 14 ++-- .../configs/benchmarks/dcp_inter_ckpt.yaml | 14 ++-- .../configs/benchmarks/dcp_intra.yaml | 14 ++-- .../cogvideox/profile-dcp-inter-ckpt.sh | 49 +++++++++++++ .../training/cogvideox/profile-dcp-inter.sh | 49 +++++++++++++ .../training/cogvideox/profile-dcp-intra.sh | 49 +++++++++++++ examples/training/cogvideox/train.py | 2 +- examples/training/open_sora/1x8_baseline.sh | 69 ++++++++++++++++++ examples/training/open_sora/1x8_dcp_inter.sh | 69 ++++++++++++++++++ .../training/open_sora/1x8_dcp_inter_ckpt.sh | 69 ++++++++++++++++++ examples/training/open_sora/1x8_dcp_intra.sh | 69 ++++++++++++++++++ .../training/open_sora/baseline_profile.sh | 50 +++++++++++++ .../configs/benchmarks-sp4/dcp_inter.yaml | 71 +++++++++++++++++++ .../benchmarks-sp4/dcp_inter_ckpt.yaml | 71 +++++++++++++++++++ .../configs/benchmarks-sp4/dcp_intra.yaml | 70 ++++++++++++++++++ .../configs/benchmarks/baseline.yaml | 8 +-- .../configs/benchmarks/dcp_inter.yaml | 4 +- .../configs/benchmarks/dcp_inter_ckpt.yaml | 4 +- .../configs/benchmarks/dcp_intra.yaml | 4 +- .../open_sora/profile-dcp-inter-ckpt-sp4.sh | 49 +++++++++++++ .../open_sora/profile-dcp-inter-sp4.sh | 49 +++++++++++++ .../open_sora/profile-dcp-intra-sp4.sh | 49 +++++++++++++ examples/training/open_sora/train.py | 10 +-- parse_log.py | 56 +++++++++++++++ videosys/core/dcp/profiler.py | 11 ++- videosys/core/distributed/parallel_mgr.py | 2 +- 32 files changed, 1268 insertions(+), 47 deletions(-) create mode 100755 examples/training/cogvideox/1x8_baseline.sh create mode 100755 examples/training/cogvideox/1x8_dcp_inter.sh create mode 100755 examples/training/cogvideox/1x8_dcp_inter_ckpt.sh create mode 100755 examples/training/cogvideox/1x8_dcp_intra.sh create mode 100755 examples/training/cogvideox/baseline_profile.sh create mode 100755 examples/training/cogvideox/profile-dcp-inter-ckpt.sh create mode 100755 examples/training/cogvideox/profile-dcp-inter.sh create mode 100755 examples/training/cogvideox/profile-dcp-intra.sh create mode 100755 examples/training/open_sora/1x8_baseline.sh create mode 100755 examples/training/open_sora/1x8_dcp_inter.sh create mode 100755 examples/training/open_sora/1x8_dcp_inter_ckpt.sh create mode 100755 examples/training/open_sora/1x8_dcp_intra.sh create mode 100755 examples/training/open_sora/baseline_profile.sh create mode 100755 examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml create mode 100755 examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml create mode 100755 examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml create mode 100755 examples/training/open_sora/profile-dcp-inter-ckpt-sp4.sh create mode 100755 examples/training/open_sora/profile-dcp-inter-sp4.sh create mode 100755 examples/training/open_sora/profile-dcp-intra-sp4.sh create mode 100644 parse_log.py diff --git a/examples/training/cogvideox/1x8_baseline.sh b/examples/training/cogvideox/1x8_baseline.sh new file mode 100755 index 00000000..3045a18b --- /dev/null +++ b/examples/training/cogvideox/1x8_baseline.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-cogvideox-baseline.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_inter.sh b/examples/training/cogvideox/1x8_dcp_inter.sh new file mode 100755 index 00000000..5f372f14 --- /dev/null +++ b/examples/training/cogvideox/1x8_dcp_inter.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-cogvideox-dcp-inter.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh b/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh new file mode 100755 index 00000000..d6bc04a5 --- /dev/null +++ b/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-cogvideox-dcp-inter-ckpt.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_intra.sh b/examples/training/cogvideox/1x8_dcp_intra.sh new file mode 100755 index 00000000..a8db520a --- /dev/null +++ b/examples/training/cogvideox/1x8_dcp_intra.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-cogvideox-dcp-intra.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/baseline_profile.sh b/examples/training/cogvideox/baseline_profile.sh new file mode 100755 index 00000000..b33e2dfa --- /dev/null +++ b/examples/training/cogvideox/baseline_profile.sh @@ -0,0 +1,50 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-baseline-cogvideox.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# baseline +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/baseline.yaml +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/configs/benchmarks/baseline.yaml b/examples/training/cogvideox/configs/benchmarks/baseline.yaml index a0027fc2..e6d9a164 100644 --- a/examples/training/cogvideox/configs/benchmarks/baseline.yaml +++ b/examples/training/cogvideox/configs/benchmarks/baseline.yaml @@ -13,7 +13,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: true +drop_last: false # true # train ckpt_path: "THUDM/CogVideoX-5b" @@ -34,14 +34,14 @@ adam_eps: 1e-15 warmup_steps: 10 # data -image_mixing_frac: 50 +# image_mixing_frac: 50 num_bucket_build_workers: 16 bucket_config: - "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} - "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} - "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} - "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} - "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + "144p": {1: [1.0, 345], 25: [1.0, 48], 49: [1.0, 25], 73: [1.0, 12], 97: [1.0, 6]} + "240p": {1: [1.0, 128], 25: [1.0, 16], 49: [1.0, 8], 73: [1.0, 4], 97: [1.0, 2]} + "360p": {1: [1.0, 64], 25: [1.0, 7], 49: [1.0, 4], 73: [1.0, 2], 97: [1.0, 1]} + "480p": {1: [1.0, 32], 25: [1.0, 4], 49: [1.0, 2], 73: [1.0, 1], 97: [1.0, 1]} + "720p": {1: [1.0, 14], 25: [1.0, 1], 49: [1.0, 1], 73: [1.0, 1], 97: [1.0, 1]} # override default common ar # for benchmark, we use single ar for all resolutions diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml index aa6b6ee4..87987294 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml @@ -16,7 +16,7 @@ max_grad_accumulation_steps: 5 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: true +drop_last: false # true # train ckpt_path: "THUDM/CogVideoX-5b" @@ -37,14 +37,14 @@ adam_eps: 1e-15 warmup_steps: 10 # data -image_mixing_frac: 50 +# image_mixing_frac: 50 num_bucket_build_workers: 16 bucket_config: - "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} - "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} - "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} - "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} - "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + "144p": {1: [1.0, 345], 25: [1.0, 48], 49: [1.0, 25], 73: [1.0, 12], 97: [1.0, 6]} + "240p": {1: [1.0, 128], 25: [1.0, 16], 49: [1.0, 8], 73: [1.0, 4], 97: [1.0, 2]} + "360p": {1: [1.0, 64], 25: [1.0, 7], 49: [1.0, 4], 73: [1.0, 2], 97: [1.0, 1]} + "480p": {1: [1.0, 32], 25: [1.0, 4], 49: [1.0, 2], 73: [1.0, 1], 97: [1.0, 1]} + "720p": {1: [1.0, 14], 25: [1.0, 1], 49: [1.0, 1], 73: [1.0, 1], 97: [1.0, 1]} # override default common ar # for benchmark, we use single ar for all resolutions diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml index a592b2b3..3c441e2f 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml @@ -16,7 +16,7 @@ min_grad_accumulation_steps: 15 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: true +drop_last: false # true # train ckpt_path: "THUDM/CogVideoX-5b" @@ -37,14 +37,14 @@ adam_eps: 1e-15 warmup_steps: 10 # data -image_mixing_frac: 50 +# image_mixing_frac: 50 num_bucket_build_workers: 16 bucket_config: - "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} - "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} - "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} - "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} - "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + "144p": {1: [1.0, 345], 25: [1.0, 48], 49: [1.0, 25], 73: [1.0, 12], 97: [1.0, 6]} + "240p": {1: [1.0, 128], 25: [1.0, 16], 49: [1.0, 8], 73: [1.0, 4], 97: [1.0, 2]} + "360p": {1: [1.0, 64], 25: [1.0, 7], 49: [1.0, 4], 73: [1.0, 2], 97: [1.0, 1]} + "480p": {1: [1.0, 32], 25: [1.0, 4], 49: [1.0, 2], 73: [1.0, 1], 97: [1.0, 1]} + "720p": {1: [1.0, 14], 25: [1.0, 1], 49: [1.0, 1], 73: [1.0, 1], 97: [1.0, 1]} # override default common ar # for benchmark, we use single ar for all resolutions diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml index 5d8011ef..139408dc 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml @@ -15,7 +15,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: true +drop_last: false # true # train ckpt_path: "THUDM/CogVideoX-5b" @@ -36,14 +36,14 @@ adam_eps: 1e-15 warmup_steps: 10 # data -image_mixing_frac: 50 +# image_mixing_frac: 50 num_bucket_build_workers: 16 bucket_config: - "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} - "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} - "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} - "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} - "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + "144p": {1: [1.0, 345], 25: [1.0, 48], 49: [1.0, 25], 73: [1.0, 12], 97: [1.0, 6]} + "240p": {1: [1.0, 128], 25: [1.0, 16], 49: [1.0, 8], 73: [1.0, 4], 97: [1.0, 2]} + "360p": {1: [1.0, 64], 25: [1.0, 7], 49: [1.0, 4], 73: [1.0, 2], 97: [1.0, 1]} + "480p": {1: [1.0, 32], 25: [1.0, 4], 49: [1.0, 2], 73: [1.0, 1], 97: [1.0, 1]} + "720p": {1: [1.0, 14], 25: [1.0, 1], 49: [1.0, 1], 73: [1.0, 1], 97: [1.0, 1]} # override default common ar # for benchmark, we use single ar for all resolutions diff --git a/examples/training/cogvideox/profile-dcp-inter-ckpt.sh b/examples/training/cogvideox/profile-dcp-inter-ckpt.sh new file mode 100755 index 00000000..ef434309 --- /dev/null +++ b/examples/training/cogvideox/profile-dcp-inter-ckpt.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-cogvideox-dcp-inter-ckpt.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/profile-dcp-inter.sh b/examples/training/cogvideox/profile-dcp-inter.sh new file mode 100755 index 00000000..be5619c2 --- /dev/null +++ b/examples/training/cogvideox/profile-dcp-inter.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-cogvideox-dcp-inter.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/profile-dcp-intra.sh b/examples/training/cogvideox/profile-dcp-intra.sh new file mode 100755 index 00000000..3ffcec80 --- /dev/null +++ b/examples/training/cogvideox/profile-dcp-intra.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-cogvideox-dcp-intra.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/cogvideox/train.py \ + examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml +" + +rm $HOSTFILE diff --git a/examples/training/cogvideox/train.py b/examples/training/cogvideox/train.py index a1a03b8e..825b1c7b 100644 --- a/examples/training/cogvideox/train.py +++ b/examples/training/cogvideox/train.py @@ -496,7 +496,7 @@ def main(args): parser.add_argument("--auto-grad-accumulation", action="store_true") parser.add_argument( "--alloc-memory-fraction", - default=0.70, + default=0.56, type=float, help="This is an empirical value to cap the allocated memory during profiling with dynamic sp. Communication in different ranks can cause free memory discrepancy, which can leads to comm deadlock. So you need to leave enough space to bear this discrepancy. If you meet this problem during profiling, try to decrease this value.", ) diff --git a/examples/training/open_sora/1x8_baseline.sh b/examples/training/open_sora/1x8_baseline.sh new file mode 100755 index 00000000..01d03e35 --- /dev/null +++ b/examples/training/open_sora/1x8_baseline.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-opensora-baseline.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks/baseline.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_inter.sh b/examples/training/open_sora/1x8_dcp_inter.sh new file mode 100755 index 00000000..753040f5 --- /dev/null +++ b/examples/training/open_sora/1x8_dcp_inter.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-opensora-dcp-inter.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ + --image-mixing-frac 10 +" + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_inter_ckpt.sh b/examples/training/open_sora/1x8_dcp_inter_ckpt.sh new file mode 100755 index 00000000..987fcb0f --- /dev/null +++ b/examples/training/open_sora/1x8_dcp_inter_ckpt.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-opensora-dcp-inter-ckpt.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# # =============== zipf-1 ================ +# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ +# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ +# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ +# python examples/training/open_sora/train.py \ +# examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ +# --image-mixing-frac 1 +# " + +# # =============== zipf-10 ================ +# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ +# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ +# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ +# python examples/training/open_sora/train.py \ +# examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ +# --image-mixing-frac 10 +# " + +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ + --image-mixing-frac 50 +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_intra.sh b/examples/training/open_sora/1x8_dcp_intra.sh new file mode 100755 index 00000000..0d5a1d07 --- /dev/null +++ b/examples/training/open_sora/1x8_dcp_intra.sh @@ -0,0 +1,69 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=8 +#PBS -l place=vscatter +#PBS -l walltime=96:00:00 +#PBS -j oe +#PBS -o 1x8-opensora-dcp-intra.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=29502 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ + --image-mixing-frac 1 +" + +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ + --image-mixing-frac 10 +" + +# # =============== zipf-50 ================ +# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ +# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ +# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ +# python examples/training/open_sora/train.py \ +# examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ +# --image-mixing-frac 50 +# " + +rm $HOSTFILE diff --git a/examples/training/open_sora/baseline_profile.sh b/examples/training/open_sora/baseline_profile.sh new file mode 100755 index 00000000..a32ac4e5 --- /dev/null +++ b/examples/training/open_sora/baseline_profile.sh @@ -0,0 +1,50 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=4 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-baseline-opensora.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=4 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +# baseline +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks/baseline.yaml +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml new file mode 100755 index 00000000..4c2f82bc --- /dev/null +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml @@ -0,0 +1,71 @@ +zipf_offset: 5 +outputs: exp/opensora/dcp_inter +profile_path: exp/opensora/profile-sp4/dcp_inter +dynamic_sp: true +dynamic_recompute: false +auto_grad_accumulation: true +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true +max_grad_accumulation_steps: 5 + + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: false # true + +# train +ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +# image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 854]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml new file mode 100755 index 00000000..c09b1754 --- /dev/null +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml @@ -0,0 +1,71 @@ +zipf_offset: 5 +outputs: exp/opensora/dcp_inter_ckpt +profile_path: exp/opensora/profile-sp4/dcp_inter_ckpt +dynamic_sp: true +dynamic_recompute: true +auto_grad_accumulation: true +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true +max_grad_accumulation_steps: 5 +min_grad_accumulation_steps: 15 + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: false # true + +# train +ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +# image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 854]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml new file mode 100755 index 00000000..8cd3b840 --- /dev/null +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml @@ -0,0 +1,70 @@ +zipf_offset: 5 +outputs: exp/opensora/dcp_intra +profile_path: exp/opensora/profile-sp4/dcp_intra +dynamic_sp: true +dynamic_recompute: false +auto_grad_accumulation: false +dummy_dataset: true +dummy_data_size: 2000 +verbose: true +calculate_imbalance: true + + +# ==== training config ==== + +# preprocess embedding +data_path: "./assets/example_data/demo_preprocess.csv" +preprocessed_data: true +drop_last: false # true + +# train +ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" +grad_checkpoint: True +num_workers: 8 +dtype: "bf16" + +# log +seed: 42 +epochs: 1 +log_every: 1e10 + +# optimization +grad_clip: 1.0 +lr: 1e-8 +ema_decay: 0.99 +adam_eps: 1e-15 +warmup_steps: 10 + +# data +# image_mixing_frac: 50 +num_bucket_build_workers: 16 +bucket_config: + "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} + "240p": {1: [1.0, 128], 51: [1.0, 16], 102: [1.0, 8], 204: [1.0, 4], 408: [1.0, 2]} + "360p": {1: [1.0, 64], 51: [1.0, 7], 102: [1.0, 4], 204: [1.0, 2], 408: [1.0, 1]} + "480p": {1: [1.0, 32], 51: [1.0, 4], 102: [1.0, 2], 204: [1.0, 1], 408: [1.0, 1]} + "720p": {1: [1.0, 14], 51: [1.0, 1], 102: [1.0, 1], 204: [1.0, 1], 408: [1.0, 1]} + +# override default common ar +# for benchmark, we use single ar for all resolutions +# otherwise the data will be too sparse +common_ar: + "144p": {"0.56": [144, 256]} + "240p": {"0.56": [240, 426]} + "360p": {"0.56": [360, 640]} + "480p": {"0.56": [480, 854]} + "720p": {"0.56": [720, 1280]} + +# mask +mask_ratios: { + "random": 0.01, + "intepolate": 0.002, + "quarter_random": 0.002, + "quarter_head": 0.002, + "quarter_tail": 0.002, + "quarter_head_tail": 0.002, + "image_random": 0.0, + "image_head": 0.22, + "image_tail": 0.005, + "image_head_tail": 0.005, +} diff --git a/examples/training/open_sora/configs/benchmarks/baseline.yaml b/examples/training/open_sora/configs/benchmarks/baseline.yaml index b707d903..6d2907b8 100644 --- a/examples/training/open_sora/configs/benchmarks/baseline.yaml +++ b/examples/training/open_sora/configs/benchmarks/baseline.yaml @@ -1,6 +1,6 @@ zipf_offset: 5 -outputs: exp/baseline -profile_path: exp/profile/baseline +outputs: exp/opensora/baseline +profile_path: exp/opensora/profile/baseline sp_size: 4 dummy_dataset: true dummy_data_size: 2000 @@ -13,7 +13,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: true +drop_last: false # true # train ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" @@ -34,7 +34,7 @@ adam_eps: 1e-15 warmup_steps: 10 # data -image_mixing_frac: 50 +# image_mixing_frac: 50 num_bucket_build_workers: 16 bucket_config: "144p": {1: [1.0, 345], 51: [1.0, 48], 102: [1.0, 25], 204: [1.0, 12], 408: [1.0, 6]} diff --git a/examples/training/open_sora/configs/benchmarks/dcp_inter.yaml b/examples/training/open_sora/configs/benchmarks/dcp_inter.yaml index d85d86ef..7485bd16 100644 --- a/examples/training/open_sora/configs/benchmarks/dcp_inter.yaml +++ b/examples/training/open_sora/configs/benchmarks/dcp_inter.yaml @@ -1,6 +1,6 @@ zipf_offset: 5 -outputs: exp/dcp_inter -profile_path: exp/profile/dcp_inter +outputs: exp/opensora/dcp_inter +profile_path: exp/opensora/profile/dcp_inter dynamic_sp: true dynamic_recompute: false auto_grad_accumulation: true diff --git a/examples/training/open_sora/configs/benchmarks/dcp_inter_ckpt.yaml b/examples/training/open_sora/configs/benchmarks/dcp_inter_ckpt.yaml index b62093d3..718d3fe0 100644 --- a/examples/training/open_sora/configs/benchmarks/dcp_inter_ckpt.yaml +++ b/examples/training/open_sora/configs/benchmarks/dcp_inter_ckpt.yaml @@ -1,6 +1,6 @@ zipf_offset: 5 -outputs: exp/dcp_inter_ckpt -profile_path: exp/profile/dcp_inter_ckpt +outputs: exp/opensora/dcp_inter_ckpt +profile_path: exp/opensora/profile/dcp_inter_ckpt dynamic_sp: true dynamic_recompute: true auto_grad_accumulation: true diff --git a/examples/training/open_sora/configs/benchmarks/dcp_intra.yaml b/examples/training/open_sora/configs/benchmarks/dcp_intra.yaml index d6bf3cda..4970996d 100644 --- a/examples/training/open_sora/configs/benchmarks/dcp_intra.yaml +++ b/examples/training/open_sora/configs/benchmarks/dcp_intra.yaml @@ -1,6 +1,6 @@ zipf_offset: 5 -outputs: exp/dcp_intra -profile_path: exp/profile/dcp_intra +outputs: exp/opensora/dcp_intra +profile_path: exp/opensora/profile/dcp_intra dynamic_sp: true dynamic_recompute: false auto_grad_accumulation: false diff --git a/examples/training/open_sora/profile-dcp-inter-ckpt-sp4.sh b/examples/training/open_sora/profile-dcp-inter-ckpt-sp4.sh new file mode 100755 index 00000000..5a30fee7 --- /dev/null +++ b/examples/training/open_sora/profile-dcp-inter-ckpt-sp4.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=4 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-dcp-inter-ckpt-opensora-sp4.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=4 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/profile-dcp-inter-sp4.sh b/examples/training/open_sora/profile-dcp-inter-sp4.sh new file mode 100755 index 00000000..50979593 --- /dev/null +++ b/examples/training/open_sora/profile-dcp-inter-sp4.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=4 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-dcp-inter-opensora-sp4.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=4 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9527 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/profile-dcp-intra-sp4.sh b/examples/training/open_sora/profile-dcp-intra-sp4.sh new file mode 100755 index 00000000..1ceceb91 --- /dev/null +++ b/examples/training/open_sora/profile-dcp-intra-sp4.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#PBS -P CFP02-CF-004 +#PBS -l select=1:ngpus=4 +#PBS -l place=vscatter +#PBS -l walltime=12:00:00 +#PBS -j oe +#PBS -o profile-dcp-intra-opensora-sp4.log + +# =============== env params ================ +# This script is for NSCC which uses PBS Pro as the scheduler + +# where the singularity image is saved +SCRATCH_PATH=$HPCTMP + +cd $PBS_O_WORKDIR +echo "JOB ID: $PBS_JOBID, pwd: $PWD, pbs workdir: $PBS_O_WORKDIR" + +# for torch.distributed +export NNODES=1 +# export NODE_RANK=0 +export GPUS_PER_NODE=4 +export WORLD_SIZE=$(($NNODES*$GPUS_PER_NODE)) +export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | awk -F'.' '{print $1}') +export MASTER_PORT=9528 +echo "master node: $MASTER_ADDR" + +# used by OpenMPI +export HOSTFILE="$PBS_JOBID.hostfile" +cat $PBS_NODEFILE | awk -F'.' '{for(i=1;i<=NF;i+=6) print $1 " slots="ENVIRON["GPUS_PER_NODE"]}' > $HOSTFILE +echo "detected hosts: $(cat $HOSTFILE)" + +# refer to: https://apptainer.org/user-docs/master/gpu.html +# for apptainer, replace SINGULARITYENV_* with APPTAINERENV_* +# export SINGULARITYENV_CUDA_VISIBLE_DEVICES=$(printf "%s," $(seq 0 $(($GPUS_PER_NODE-1))) | sed 's/,$//') +# echo "singularity cuda visible devices: $SINGULARITYENV_CUDA_VISIBLE_DEVICES" + +# =============== program params ================ +export PYTHONPATH=$PYTHONPATH:$PWD +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export TOKENIZERS_PARALLELISM=false + +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml +" + +rm $HOSTFILE diff --git a/examples/training/open_sora/train.py b/examples/training/open_sora/train.py index dfdd5409..e3cc6c7d 100644 --- a/examples/training/open_sora/train.py +++ b/examples/training/open_sora/train.py @@ -52,7 +52,7 @@ def main(args): backend="nccl", timeout=timedelta(minutes=10), ) - deepspeed.init_distributed(timeout=timedelta(seconds=10)) + deepspeed.init_distributed(timeout=timedelta(minutes=5)) torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) set_seed(args.seed) device = torch.cuda.current_device() @@ -153,7 +153,7 @@ def main(args): # ====================================================== # 3. build dataset and dataloader # ====================================================== - logging.info("Building dataset...") + logging.info(f"Building dataset... model max length: {model.config.model_max_length}, ") # create dcp profiler # TODO: scheduler is a better name? profiler: Profiler = set_profiler( @@ -328,7 +328,9 @@ def main(args): # move data x = batch_data.pop("video").to(device, dtype) # [B, C, T, H, W] y = batch_data.pop("text").to(device, dtype) - mask = batch_data.pop("mask").to(device) + mask = batch_data.pop("mask") + if mask is not None: + mask = mask.to(device) model_args = dict(y=y, mask=mask) else: with torch.no_grad(): @@ -499,7 +501,7 @@ def main(args): parser.add_argument("--auto-grad-accumulation", action="store_true") parser.add_argument( "--alloc-memory-fraction", - default=0.70, + default=0.56, type=float, help="This is an empirical value to cap the allocated memory during profiling with dynamic sp. Communication in different ranks can cause free memory discrepancy, which can leads to comm deadlock. So you need to leave enough space to bear this discrepancy. If you meet this problem during profiling, try to decrease this value.", ) diff --git a/parse_log.py b/parse_log.py new file mode 100644 index 00000000..4e5b67d5 --- /dev/null +++ b/parse_log.py @@ -0,0 +1,56 @@ +import re +import argparse +import os +import pandas as pd + + +def parse_log(file_path): + imbalance_pattern = r'Total imbalance for this epoch:.*?\((\d+\.\d+)%\)' + throughput_pattern = r'token throughput: (\d+\.\d+) token/s' + + with open(file_path, 'r') as file: + log_lines = file.read() + + imbalance_match = re.search(imbalance_pattern, log_lines) + throughput_match = re.search(throughput_pattern, log_lines) + + imbalance_percent = float(imbalance_match.group(1)) if imbalance_match else None + token_throughput = float(throughput_match.group(1)) if throughput_match else None + + return imbalance_percent, token_throughput + + +def main(): + parser = argparse.ArgumentParser(description='Parse log file for imbalance and throughput metrics.') + parser.add_argument("--log_dir", type=str, required=True) + args = parser.parse_args() + results = [] + for dirpath, dirnames, filenames in os.walk(args.log_dir): + if 'log.txt' in filenames: + log_path = os.path.join(dirpath, 'log.txt') + try: + imbalance_percent, token_throughput = parse_log(log_path) + + relative_path = os.path.relpath(dirpath, args.log_dir) + + results.append({ + 'experiment': relative_path.split('/')[0], # baseline, dcp_inter, etc. + 'run': relative_path.split('/')[-1], # 000-OpenSora, etc. + 'log_path': relative_path, + 'imbalance_percent': imbalance_percent, + 'token_throughput': token_throughput + }) + + except Exception as e: + print(f"Error reading {log_path}: {e}") + + df = pd.DataFrame(results) + df = df.sort_values(by=['run', 'experiment']).reset_index(drop=True) + # print(df) + save_path = os.path.join(args.log_dir, 'summary.csv') + df.to_csv(save_path, index=False) + print(f"Parsed results saved to '{save_path}'") + + +if __name__ == "__main__": + main() diff --git a/videosys/core/dcp/profiler.py b/videosys/core/dcp/profiler.py index 0675611c..3ff54ee0 100644 --- a/videosys/core/dcp/profiler.py +++ b/videosys/core/dcp/profiler.py @@ -132,7 +132,7 @@ def open_sora_synthesizer(data_plan, auto_grad_acc, data_idx, text_max_seq_len, text_max_seq_len, text_hidden_size, ), - mask=torch.ones(data_plan.bs, text_max_seq_len, dtype=torch.long), + mask=None, # torch.ones(data_plan.bs, text_max_seq_len, dtype=torch.long), num_frames=torch.tensor([data_plan.num_frame] * data_plan.bs), height=torch.tensor([height] * data_plan.bs), width=torch.tensor([width] * data_plan.bs), @@ -286,7 +286,7 @@ def _load_profile(self): if not self.do_profile: assert os.path.isdir(self.profile_path) self.profile_results = {} - + max_sp = 0 # Iterate through all profile_*.json files in the directory for filename in os.listdir(self.profile_path): if filename.startswith("profile") and filename.endswith(".json"): @@ -298,6 +298,11 @@ def _load_profile(self): if ar_name not in self.profile_results: self.profile_results[ar_name] = {} self.profile_results[ar_name].update(num_frame_dict) + for num_frame in num_frame_dict: + sp_size = num_frame_dict[num_frame]["sp_size"] + if sp_size > max_sp: + max_sp = sp_size + self.max_sp = max_sp # Convert frame numbers from strings to integers for ar_name in self.profile_results: @@ -517,7 +522,7 @@ def finalize_profile(self): clean_cache() def init_profiler(self): - torch.cuda.set_per_process_memory_fraction(self.alloc_fraction) + # torch.cuda.set_per_process_memory_fraction(self.alloc_fraction) self.profile_pbar = tqdm( range(self.next_bucket_idx, self.bucket_partition_boundary), desc="Profiling", diff --git a/videosys/core/distributed/parallel_mgr.py b/videosys/core/distributed/parallel_mgr.py index d8a0be75..0d17872a 100644 --- a/videosys/core/distributed/parallel_mgr.py +++ b/videosys/core/distributed/parallel_mgr.py @@ -73,7 +73,7 @@ def _build_clusters(self): group_start_indices = list(range(0, wsize, _s)) for group_start_idx in group_start_indices: group_ranks = global_ranks[group_start_idx : group_start_idx + _s] - gpu_group = dist.new_group(group_ranks, use_local_synchronization=True, timeout=timedelta(seconds=60)) + gpu_group = dist.new_group(group_ranks, use_local_synchronization=True, timeout=timedelta(minutes=5)) cpu_group = dist.new_group(group_ranks, backend="gloo", use_local_synchronization=True) if self._rank in group_ranks: self.sp_clusters[_s] = gpu_group From 71becec3274505fc4245be4ebcde05a0810e3d20 Mon Sep 17 00:00:00 2001 From: zxgx Date: Wed, 27 Aug 2025 15:45:36 +0800 Subject: [PATCH 6/6] integrate flops counter, finish exp --- examples/training/cogvideox/1x8_baseline.sh | 8 ++--- examples/training/cogvideox/1x8_dcp_inter.sh | 36 +++++++++---------- .../training/cogvideox/1x8_dcp_inter_ckpt.sh | 8 ++--- examples/training/cogvideox/1x8_dcp_intra.sh | 8 ++--- .../configs/benchmarks/baseline.yaml | 2 +- .../configs/benchmarks/dcp_inter.yaml | 4 +-- .../configs/benchmarks/dcp_inter_ckpt.yaml | 2 +- .../configs/benchmarks/dcp_intra.yaml | 2 +- examples/training/cogvideox/train.py | 20 ++++++++++- examples/training/open_sora/1x8_baseline.sh | 6 ++-- examples/training/open_sora/1x8_dcp_inter.sh | 6 ++-- .../training/open_sora/1x8_dcp_inter_ckpt.sh | 34 +++++++++--------- examples/training/open_sora/1x8_dcp_intra.sh | 20 +++++------ .../configs/benchmarks-sp4/dcp_inter.yaml | 4 +-- .../benchmarks-sp4/dcp_inter_ckpt.yaml | 4 +-- .../configs/benchmarks-sp4/dcp_intra.yaml | 4 +-- .../configs/benchmarks/baseline.yaml | 4 +-- examples/training/open_sora/train.py | 20 ++++++++++- parse_log.py | 12 ++++--- 19 files changed, 122 insertions(+), 82 deletions(-) diff --git a/examples/training/cogvideox/1x8_baseline.sh b/examples/training/cogvideox/1x8_baseline.sh index 3045a18b..4b86f0e9 100755 --- a/examples/training/cogvideox/1x8_baseline.sh +++ b/examples/training/cogvideox/1x8_baseline.sh @@ -2,7 +2,7 @@ #PBS -P CFP02-CF-004 #PBS -l select=1:ngpus=8 #PBS -l place=vscatter -#PBS -l walltime=96:00:00 +#PBS -l walltime=24:00:00 #PBS -j oe #PBS -o 1x8-cogvideox-baseline.log @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,7 +54,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " # =============== zipf-50 ================ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_inter.sh b/examples/training/cogvideox/1x8_dcp_inter.sh index 5f372f14..9ba60d5f 100755 --- a/examples/training/cogvideox/1x8_dcp_inter.sh +++ b/examples/training/cogvideox/1x8_dcp_inter.sh @@ -2,7 +2,7 @@ #PBS -P CFP02-CF-004 #PBS -l select=1:ngpus=8 #PBS -l place=vscatter -#PBS -l walltime=96:00:00 +#PBS -l walltime=24:00:00 #PBS -j oe #PBS -o 1x8-cogvideox-dcp-inter.log @@ -45,25 +45,25 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " -# =============== zipf-10 ================ -mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ - singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ - /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ - python examples/training/cogvideox/train.py \ - examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ - --image-mixing-frac 10 -" +# # =============== zipf-10 ================ +# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ +# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ +# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ +# python examples/training/cogvideox/train.py \ +# examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ +# --image-mixing-frac 10 --profile-flops +# " -# =============== zipf-50 ================ -mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ - singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ - /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ - python examples/training/cogvideox/train.py \ - examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ - --image-mixing-frac 50 -" +# # =============== zipf-50 ================ +# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ +# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ +# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ +# python examples/training/cogvideox/train.py \ +# examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml \ +# --image-mixing-frac 50 --profile-flops +# " rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh b/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh index d6bc04a5..3b976d5d 100755 --- a/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh +++ b/examples/training/cogvideox/1x8_dcp_inter_ckpt.sh @@ -2,7 +2,7 @@ #PBS -P CFP02-CF-004 #PBS -l select=1:ngpus=8 #PBS -l place=vscatter -#PBS -l walltime=96:00:00 +#PBS -l walltime=24:00:00 #PBS -j oe #PBS -o 1x8-cogvideox-dcp-inter-ckpt.log @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,7 +54,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " # =============== zipf-50 ================ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/cogvideox/1x8_dcp_intra.sh b/examples/training/cogvideox/1x8_dcp_intra.sh index a8db520a..e7531d8a 100755 --- a/examples/training/cogvideox/1x8_dcp_intra.sh +++ b/examples/training/cogvideox/1x8_dcp_intra.sh @@ -2,7 +2,7 @@ #PBS -P CFP02-CF-004 #PBS -l select=1:ngpus=8 #PBS -l place=vscatter -#PBS -l walltime=96:00:00 +#PBS -l walltime=24:00:00 #PBS -j oe #PBS -o 1x8-cogvideox-dcp-intra.log @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,7 +54,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " # =============== zipf-50 ================ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/cogvideox/train.py \ examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/cogvideox/configs/benchmarks/baseline.yaml b/examples/training/cogvideox/configs/benchmarks/baseline.yaml index e6d9a164..9dbb247d 100644 --- a/examples/training/cogvideox/configs/benchmarks/baseline.yaml +++ b/examples/training/cogvideox/configs/benchmarks/baseline.yaml @@ -13,7 +13,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "THUDM/CogVideoX-5b" diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml index 87987294..d5aee363 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter.yaml @@ -5,7 +5,7 @@ dynamic_sp: true dynamic_recompute: false auto_grad_accumulation: true dummy_dataset: true -dummy_data_size: 2000 +dummy_data_size: 10000 verbose: true calculate_imbalance: true max_grad_accumulation_steps: 5 @@ -16,7 +16,7 @@ max_grad_accumulation_steps: 5 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "THUDM/CogVideoX-5b" diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml index 3c441e2f..6e344909 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_inter_ckpt.yaml @@ -16,7 +16,7 @@ min_grad_accumulation_steps: 15 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "THUDM/CogVideoX-5b" diff --git a/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml index 139408dc..d4523960 100644 --- a/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml +++ b/examples/training/cogvideox/configs/benchmarks/dcp_intra.yaml @@ -15,7 +15,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "THUDM/CogVideoX-5b" diff --git a/examples/training/cogvideox/train.py b/examples/training/cogvideox/train.py index 825b1c7b..dcc63500 100644 --- a/examples/training/cogvideox/train.py +++ b/examples/training/cogvideox/train.py @@ -3,6 +3,8 @@ import os from datetime import timedelta from pprint import pformat +import numpy as np +import time import deepspeed import torch @@ -277,6 +279,8 @@ def main(args): running_loss = 0.0 logging.info(f"Training for {cfg_epochs} epochs{' with profiling' if profiler.need_profile() else ''}.") + if args.profile_flops: + prof = deepspeed.profiling.flops_profiler.FlopsProfiler(model) # ======================================================= # 5. training loop # ======================================================= @@ -300,7 +304,7 @@ def main(args): dataloader_iter = iter(dataloader) epoch_desc = f"Epoch {epoch}" logging.info(f"Beginning {epoch_desc}...") - + flops_list = [] # == training loop in an epoch == pbar = tqdm( enumerate(dataloader_iter, start=start_step), @@ -316,6 +320,9 @@ def main(args): total_gas = batch["gas"] iter_loss = 0.0 + if args.profile_flops: + prof.start_profile() + start_time = time.time() for gas in range(total_gas): with profiler.profile(batch, model, gas) as valid_depth: batch_data = batch["data"][gas] @@ -365,6 +372,13 @@ def main(args): iter_loss += loss.detach() + if args.profile_flops: + prof.stop_profile() + flops = prof.get_total_flops() + prof.end_profile() + step_elapsed = time.time() - start_time + flops = flops / step_elapsed / 1e12 + flops_list.append(flops) if profiler.need_profile(): continue @@ -430,6 +444,9 @@ def main(args): f", sample throughput: {sampler.effective_samples / elapsed_time:.2f} samples/s" f", token throughput: {token_counter.item()/elapsed_time:.2f} token/s" ) + if args.profile_flops: + logging.info(f"Final FLOPS: {np.mean(flops_list):.2f} +- {np.std(flops_list):.2f} [ {np.min(flops_list):.2f} - {np.max(flops_list):.2f} ]") + flops_list.clear() sampler.reset() start_step = 0 @@ -505,6 +522,7 @@ def main(args): parser.add_argument("--calculate-imbalance", action="store_true") parser.add_argument("--max-grad-accumulation-steps", default=3, type=int) parser.add_argument("--min-grad-accumulation-steps", default=2, type=int) + parser.add_argument("--profile-flops", action="store_true", help="enable flops profiler") args = parser.parse_args() config_args = OmegaConf.load(args.config) diff --git a/examples/training/open_sora/1x8_baseline.sh b/examples/training/open_sora/1x8_baseline.sh index 01d03e35..b295583a 100755 --- a/examples/training/open_sora/1x8_baseline.sh +++ b/examples/training/open_sora/1x8_baseline.sh @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,7 +54,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " # =============== zipf-50 ================ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks/baseline.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_inter.sh b/examples/training/open_sora/1x8_dcp_inter.sh index 753040f5..2de2d47f 100755 --- a/examples/training/open_sora/1x8_dcp_inter.sh +++ b/examples/training/open_sora/1x8_dcp_inter.sh @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,7 +54,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " # =============== zipf-50 ================ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_inter_ckpt.sh b/examples/training/open_sora/1x8_dcp_inter_ckpt.sh index 987fcb0f..a99e1623 100755 --- a/examples/training/open_sora/1x8_dcp_inter_ckpt.sh +++ b/examples/training/open_sora/1x8_dcp_inter_ckpt.sh @@ -39,23 +39,23 @@ export PYTHONPATH=$PYTHONPATH:$PWD export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export TOKENIZERS_PARALLELISM=false -# # =============== zipf-1 ================ -# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ -# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ -# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ -# python examples/training/open_sora/train.py \ -# examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ -# --image-mixing-frac 1 -# " +# =============== zipf-1 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ + --image-mixing-frac 1 --profile-flops +" -# # =============== zipf-10 ================ -# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ -# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ -# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ -# python examples/training/open_sora/train.py \ -# examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ -# --image-mixing-frac 10 -# " +# =============== zipf-10 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ + --image-mixing-frac 10 --profile-flops +" # =============== zipf-50 ================ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ @@ -63,7 +63,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml \ - --image-mixing-frac 50 + --image-mixing-frac 50 --profile-flops " rm $HOSTFILE diff --git a/examples/training/open_sora/1x8_dcp_intra.sh b/examples/training/open_sora/1x8_dcp_intra.sh index 0d5a1d07..efff278e 100755 --- a/examples/training/open_sora/1x8_dcp_intra.sh +++ b/examples/training/open_sora/1x8_dcp_intra.sh @@ -45,7 +45,7 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ - --image-mixing-frac 1 + --image-mixing-frac 1 --profile-flops " # =============== zipf-10 ================ @@ -54,16 +54,16 @@ mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ python examples/training/open_sora/train.py \ examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ - --image-mixing-frac 10 + --image-mixing-frac 10 --profile-flops " -# # =============== zipf-50 ================ -# mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ -# singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ -# /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ -# python examples/training/open_sora/train.py \ -# examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ -# --image-mixing-frac 50 -# " +# =============== zipf-50 ================ +mpirun --hostfile $HOSTFILE --np $WORLD_SIZE -N $GPUS_PER_NODE --oversubscribe \ + singularity exec --nv /app1/common/singularity-img/hopper/cuda/cuda_12.1.1-cudnn8-devel-ubuntu22.04.sif \ + /bin/bash -c "source /hpctmp/e1154485/venvs/videosys/bin/activate && \ + python examples/training/open_sora/train.py \ + examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml \ + --image-mixing-frac 50 --profile-flops +" rm $HOSTFILE diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml index 4c2f82bc..5ad32077 100755 --- a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter.yaml @@ -5,7 +5,7 @@ dynamic_sp: true dynamic_recompute: false auto_grad_accumulation: true dummy_dataset: true -dummy_data_size: 2000 +dummy_data_size: 5000 verbose: true calculate_imbalance: true max_grad_accumulation_steps: 5 @@ -16,7 +16,7 @@ max_grad_accumulation_steps: 5 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml index c09b1754..158e9591 100755 --- a/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_inter_ckpt.yaml @@ -5,7 +5,7 @@ dynamic_sp: true dynamic_recompute: true auto_grad_accumulation: true dummy_dataset: true -dummy_data_size: 2000 +dummy_data_size: 5000 verbose: true calculate_imbalance: true max_grad_accumulation_steps: 5 @@ -16,7 +16,7 @@ min_grad_accumulation_steps: 15 # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" diff --git a/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml b/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml index 8cd3b840..8c9b09ed 100755 --- a/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml +++ b/examples/training/open_sora/configs/benchmarks-sp4/dcp_intra.yaml @@ -5,7 +5,7 @@ dynamic_sp: true dynamic_recompute: false auto_grad_accumulation: false dummy_dataset: true -dummy_data_size: 2000 +dummy_data_size: 5000 verbose: true calculate_imbalance: true @@ -15,7 +15,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" diff --git a/examples/training/open_sora/configs/benchmarks/baseline.yaml b/examples/training/open_sora/configs/benchmarks/baseline.yaml index 6d2907b8..fc7285ea 100644 --- a/examples/training/open_sora/configs/benchmarks/baseline.yaml +++ b/examples/training/open_sora/configs/benchmarks/baseline.yaml @@ -3,7 +3,7 @@ outputs: exp/opensora/baseline profile_path: exp/opensora/profile/baseline sp_size: 4 dummy_dataset: true -dummy_data_size: 2000 +dummy_data_size: 5000 verbose: true calculate_imbalance: true @@ -13,7 +13,7 @@ calculate_imbalance: true # preprocess embedding data_path: "./assets/example_data/demo_preprocess.csv" preprocessed_data: true -drop_last: false # true +drop_last: true # train ckpt_path: "hpcai-tech/OpenSora-STDiT-v3" diff --git a/examples/training/open_sora/train.py b/examples/training/open_sora/train.py index e3cc6c7d..75dbc651 100644 --- a/examples/training/open_sora/train.py +++ b/examples/training/open_sora/train.py @@ -4,6 +4,8 @@ from copy import deepcopy from datetime import timedelta from pprint import pformat +import numpy as np +import time import deepspeed import torch @@ -281,6 +283,8 @@ def main(args): running_loss = 0.0 logging.info(f"Training for {cfg_epochs} epochs{' with profiling' if profiler.need_profile() else ''}.") + if args.profile_flops: + prof = deepspeed.profiling.flops_profiler.FlopsProfiler(model) # ======================================================= # 5. training loop # ======================================================= @@ -304,7 +308,7 @@ def main(args): dataloader_iter = iter(dataloader) epoch_desc = f"Epoch {epoch}" logging.info(f"Beginning {epoch_desc}...") - + flops_list = [] # == training loop in an epoch == pbar = tqdm( enumerate(dataloader_iter, start=start_step), @@ -320,6 +324,9 @@ def main(args): total_gas = batch["gas"] iter_loss = 0.0 + if args.profile_flops: + prof.start_profile() + start_time = time.time() for gas in range(total_gas): with profiler.profile(batch, model, gas) as valid_depth: batch_data = batch["data"][gas] @@ -370,6 +377,13 @@ def main(args): iter_loss += loss.detach() + if args.profile_flops: + prof.stop_profile() + flops = prof.get_total_flops() + prof.end_profile() + step_elapsed = time.time() - start_time + flops = flops / step_elapsed / 1e12 + flops_list.append(flops) if profiler.need_profile(): continue @@ -435,6 +449,9 @@ def main(args): f", sample throughput: {sampler.effective_samples / elapsed_time:.2f} samples/s" f", token throughput: {token_counter.item()/elapsed_time:.2f} token/s" ) + if args.profile_flops: + logging.info(f"Final FLOPS: {np.mean(flops_list):.2f} +- {np.std(flops_list):.2f} [ {np.min(flops_list):.2f} - {np.max(flops_list):.2f} ]") + flops_list.clear() sampler.reset() start_step = 0 @@ -510,6 +527,7 @@ def main(args): parser.add_argument("--calculate-imbalance", action="store_true") parser.add_argument("--max-grad-accumulation-steps", default=3, type=int) parser.add_argument("--min-grad-accumulation-steps", default=2, type=int) + parser.add_argument("--profile-flops", action="store_true", help="enable flops profiler") args = parser.parse_args() config_args = OmegaConf.load(args.config) diff --git a/parse_log.py b/parse_log.py index 4e5b67d5..750f8207 100644 --- a/parse_log.py +++ b/parse_log.py @@ -7,17 +7,20 @@ def parse_log(file_path): imbalance_pattern = r'Total imbalance for this epoch:.*?\((\d+\.\d+)%\)' throughput_pattern = r'token throughput: (\d+\.\d+) token/s' + flops_pattern = r'Final FLOPS: (\d+\.\d+)' # Extracts the first number after 'Final FLOPS:' with open(file_path, 'r') as file: log_lines = file.read() - + imbalance_match = re.search(imbalance_pattern, log_lines) throughput_match = re.search(throughput_pattern, log_lines) + flops_match = re.search(flops_pattern, log_lines) imbalance_percent = float(imbalance_match.group(1)) if imbalance_match else None token_throughput = float(throughput_match.group(1)) if throughput_match else None + flops = float(flops_match.group(1)) if flops_match else None - return imbalance_percent, token_throughput + return imbalance_percent, token_throughput, flops def main(): @@ -29,7 +32,7 @@ def main(): if 'log.txt' in filenames: log_path = os.path.join(dirpath, 'log.txt') try: - imbalance_percent, token_throughput = parse_log(log_path) + imbalance_percent, token_throughput, flops = parse_log(log_path) relative_path = os.path.relpath(dirpath, args.log_dir) @@ -38,7 +41,8 @@ def main(): 'run': relative_path.split('/')[-1], # 000-OpenSora, etc. 'log_path': relative_path, 'imbalance_percent': imbalance_percent, - 'token_throughput': token_throughput + 'token_throughput': token_throughput, + 'flops': flops }) except Exception as e: