From df0cb4dbb52171a269ecef16e33fa0f57a7fe564 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Mon, 8 Apr 2024 13:49:55 +0800 Subject: [PATCH 1/7] WIP: support sequence parallel --- configs/ffs/ffs_img_train.yaml | 12 ++-- configs/ffs/ffs_img_train_sp2.yaml | 47 +++++++++++++++ datasets/ffs_image_datasets.py | 55 ++++++++++-------- diffusion/gaussian_diffusion.py | 40 ++++++++++++- models/latte_img.py | 92 ++++++++++++++++++++++++------ train_with_img.py | 63 +++++++++++++++++--- 6 files changed, 253 insertions(+), 56 deletions(-) create mode 100644 configs/ffs/ffs_img_train_sp2.yaml diff --git a/configs/ffs/ffs_img_train.yaml b/configs/ffs/ffs_img_train.yaml index 15315fc..983a8b3 100644 --- a/configs/ffs/ffs_img_train.yaml +++ b/configs/ffs/ffs_img_train.yaml @@ -2,9 +2,9 @@ dataset: "ffs_img" data_path: "/path/to/datasets/preprocessed_ffs/train/videos/" -frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/" -frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt" -pretrained_model_path: "/path/to/pretrained/Latte/" +frame_data_path: "/mnt/petrelfs/share_data/linzhihao/dataset/llava_data/LLaVA-Pretrain/images/00000" +frame_data_txt: "/mnt/petrelfs/caoweihan/projects/Latte/data/toy_train_list.txt" +pretrained_model_path: "maxin-cn/Latte" # save and load results_dir: "./results_img" @@ -12,7 +12,8 @@ pretrained: # model config: model: LatteIMG-XL/2 -num_frames: 16 +# num_frames: 16 +num_frames: 0 image_size: 256 # choices=[256, 512] num_sampling_steps: 250 frame_interval: 3 @@ -22,6 +23,7 @@ learn_sigma: True # important extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation # train config: +sequence_parallel_size: 1 save_ceph: True # important use_image_num: 8 learning_rate: 1e-4 @@ -32,7 +34,7 @@ local_batch_size: 4 # important max_train_steps: 1000000 global_seed: 3407 num_workers: 8 -log_every: 100 +log_every: 10 lr_warmup_steps: 0 resume_from_checkpoint: gradient_accumulation_steps: 1 # TODO diff 
--git a/configs/ffs/ffs_img_train_sp2.yaml b/configs/ffs/ffs_img_train_sp2.yaml new file mode 100644 index 0000000..0f832f9 --- /dev/null +++ b/configs/ffs/ffs_img_train_sp2.yaml @@ -0,0 +1,47 @@ +# dataset +dataset: "ffs_img" + +data_path: "/path/to/datasets/preprocessed_ffs/train/videos/" +frame_data_path: "/mnt/petrelfs/share_data/linzhihao/dataset/llava_data/LLaVA-Pretrain/images/00000" +frame_data_txt: "/mnt/petrelfs/caoweihan/projects/Latte/data/toy_train_list.txt" +pretrained_model_path: "maxin-cn/Latte" + +# save and load +results_dir: "./results_img" +pretrained: + +# model config: +model: LatteIMG-XL/2 +# num_frames: 16 +num_frames: 0 +image_size: 256 # choices=[256, 512] +num_sampling_steps: 250 +frame_interval: 3 +fixed_spatial: False +attention_bias: True +learn_sigma: True # important +extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation + +# train config: +sequence_parallel_size: 2 +save_ceph: True # important +use_image_num: 8 +learning_rate: 1e-4 +ckpt_every: 10000 +clip_max_norm: 0.1 +start_clip_iter: 500000 +local_batch_size: 4 # important +max_train_steps: 1000000 +global_seed: 3407 +num_workers: 8 +log_every: 10 +lr_warmup_steps: 0 +resume_from_checkpoint: +gradient_accumulation_steps: 1 # TODO +num_classes: + +# low VRAM and speed up training +use_compile: False +mixed_precision: False +enable_xformers_memory_efficient_attention: False +gradient_checkpointing: False \ No newline at end of file diff --git a/datasets/ffs_image_datasets.py b/datasets/ffs_image_datasets.py index 1140e16..00240d1 100644 --- a/datasets/ffs_image_datasets.py +++ b/datasets/ffs_image_datasets.py @@ -139,38 +139,40 @@ def __init__(self, transform=None, temporal_sample=None): self.configs = configs - self.data_path = configs.data_path - self.video_lists = get_filelist(configs.data_path) + # self.data_path = configs.data_path + # self.video_lists = get_filelist(configs.data_path) self.transform = transform self.temporal_sample = temporal_sample - 
self.target_video_len = self.configs.num_frames - self.v_decoder = DecordInit() - self.video_length = len(self.video_lists) + # self.target_video_len = self.configs.num_frames + # self.v_decoder = DecordInit() + # self.video_length = len(self.video_lists) # ffs video frames self.video_frame_path = configs.frame_data_path self.video_frame_txt = configs.frame_data_txt self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)] - random.shuffle(self.video_frame_files) + # random.shuffle(self.video_frame_files) self.use_image_num = configs.use_image_num self.image_tranform = transforms.Compose([ transforms.ToTensor(), + transforms.Resize(configs.image_size), + transforms.CenterCrop(configs.image_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) def __getitem__(self, index): - video_index = index % self.video_length - path = self.video_lists[video_index] - vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') - total_frames = len(vframes) + # video_index = index % self.video_length + # path = self.video_lists[video_index] + # vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') + # total_frames = len(vframes) - # Sampling video frames - start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) - assert end_frame_ind - start_frame_ind >= self.target_video_len - frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int) - video = vframes[frame_indice] - # videotransformer data proprecess - video = self.transform(video) # T C H W + # # Sampling video frames + # start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.target_video_len + # frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int) + # video = vframes[frame_indice] + # # videotransformer data 
proprecess + # video = self.transform(video) # T C H W # get video frames images = [] @@ -182,15 +184,16 @@ def __getitem__(self, index): images.append(image) break except Exception as e: - traceback.print_exc() + # traceback.print_exc() index = random.randint(0, len(self.video_frame_files) - self.use_image_num) images = torch.cat(images, dim=0) assert len(images) == self.use_image_num + return {'video': images, 'video_name': 1} - video_cat = torch.cat([video, images], dim=0) + # video_cat = torch.cat([video, images], dim=0) - return {'video': video_cat, 'video_name': 1} + # return {'video': video_cat, 'video_name': 1} def __len__(self): return len(self.video_frame_files) @@ -216,6 +219,7 @@ def __len__(self): parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/") parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/") parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt") + parser.add_argument("--image-size", type=int, default=256) config = parser.parse_args() temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval) @@ -226,14 +230,19 @@ def __len__(self): ]) dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample) - dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4) + print(len(dataset)) + # breakpoint() + dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=1) for i, video_data in enumerate(dataloader): video, video_label = video_data['video'], video_data['video_name'] # print(video_label) # print(image_label) + print(video.shape) print(video_label) + # if i % 100 == 0: + # print(i) # video_ = ((video[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) # print(video_.shape) # try: @@ -241,6 +250,6 @@ def 
__len__(self): # except: # pass - # if i % 100 == 0 and i != 0: - # break + if i % 100 == 0 and i != 0: + break print('Done!') \ No newline at end of file diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index 9933d65..593324c 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -141,6 +141,36 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): return np.array(betas) +from xtuner.parallel.sequence import get_sequence_parallel_world_size, get_sequence_parallel_rank, reduce_sequence_parallel_loss +import torch +import torch.distributed as dist +def split_for_sequence_parallel(tokens, split_dim=1): + seq_parallel_world_size = get_sequence_parallel_world_size() + if seq_parallel_world_size == 1: + return tokens + + seq_parallel_world_rank = get_sequence_parallel_rank() + + # bs, seq_len, dim = tokens.shape + seq_len = tokens.shape[split_dim] + assert seq_len % seq_parallel_world_size == 0 + sub_seq_len = seq_len // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_len + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len + + if split_dim == 0: + tokens = tokens[sub_seq_start:sub_seq_end] + elif split_dim == 1: + tokens = tokens[:, sub_seq_start:sub_seq_end] + elif split_dim == 2: + tokens = tokens[:, :, sub_seq_start:sub_seq_end] + elif split_dim == 3: + tokens = tokens[:, :, :, sub_seq_start:sub_seq_end] + else: + raise NotImplementedError + return tokens + + class GaussianDiffusion: """ Utilities for training and sampling diffusion models. 
@@ -732,7 +762,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): model_kwargs = {} if noise is None: noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) + x_t = self.q_sample(x_start, t, noise=noise) # x: (bs, frame, c_in, h, w), t: (bs, ) terms = {} @@ -748,7 +778,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_output = model(x_t, t, **model_kwargs) + model_output = model(x_t, t, **model_kwargs) # (bs, frame, 2*c_in, h//p, w) + x_t = split_for_sequence_parallel(x_t, split_dim=3) + x_start = split_for_sequence_parallel(x_start, split_dim=3) + noise = split_for_sequence_parallel(noise, split_dim=3) # try: # model_output = model(x_t, t, **model_kwargs).sample # for tav unet # except: @@ -784,7 +817,8 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): ModelMeanType.EPSILON: noise, }[self.model_mean_type] assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) + loss = mean_flat((target - model_output) ** 2) + terms["mse"] = reduce_sequence_parallel_loss(loss, torch.tensor(1, device=loss.device, dtype=loss.dtype)) if "vb" in terms: terms["loss"] = terms["mse"] + terms["vb"] else: diff --git a/models/latte_img.py b/models/latte_img.py index c468c63..c5a49a1 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -36,6 +36,9 @@ def modulate(x, shift, scale): # Attention Layers from TIMM # ################################################################################# +from xtuner.parallel.sequence.attention import post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn +import torch.distributed as dist + class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., 
proj_drop=0., use_lora=False, attention_mode='math'): super().__init__() @@ -52,28 +55,46 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x): + def forward(self, x, use_sequence_parallel=True): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and self.training and use_sequence_parallel) + if enable_sequence_parallel: + q, k, v = pre_process_for_sequence_parallel_attn(q, k, v) + + q = q.transpose(1, 2).contiguous() # bshd -> bhsd + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() if self.attention_mode == 'xformers': # cause loss nan while using with amp - x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C) + x = xformers.ops.memory_efficient_attention(q, k, v) elif self.attention_mode == 'flash': # cause loss nan while using with amp # Optionally use the context manager to ensure one of the fused kerenels is run with torch.backends.cuda.sdp_kernel(enable_math=False): - x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0 + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) # require pytorch 2.0 elif self.attention_mode == 'math': attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 
2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).contiguous() else: raise NotImplemented + + if enable_sequence_parallel: + x = post_process_for_sequence_parallel_attn(x).contiguous() + + x = x.reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -177,9 +198,9 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) - def forward(self, x, c): + def forward(self, x, c, use_sequence_parallel=True): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), use_sequence_parallel) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x @@ -204,6 +225,35 @@ def forward(self, x, c): return x +from xtuner.parallel.sequence import get_sequence_parallel_world_size, get_sequence_parallel_rank + +def split_for_sequence_parallel(tokens, split_dim=1): + seq_parallel_world_size = get_sequence_parallel_world_size() + if seq_parallel_world_size == 1: + return tokens + + seq_parallel_world_rank = get_sequence_parallel_rank() + + # bs, seq_len, dim = tokens.shape + seq_len = tokens.shape[split_dim] + assert seq_len % seq_parallel_world_size == 0 + sub_seq_len = seq_len // seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_len + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len + + if split_dim == 0: + tokens = tokens[sub_seq_start:sub_seq_end] + elif split_dim == 1: + tokens = tokens[:, sub_seq_start:sub_seq_end] + elif split_dim == 2: + tokens = tokens[:, :, sub_seq_start:sub_seq_end] + elif split_dim == 3: + tokens = tokens[:, :, :, sub_seq_start:sub_seq_end] + else: + raise NotImplementedError + return tokens + + class Latte(nn.Module): """ Diffusion model with a Transformer backbone. 
@@ -301,14 +351,17 @@ def unpatchify(self, x): x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ + seq_parallel_world_size = get_sequence_parallel_world_size() c = self.out_channels p = self.x_embedder.patch_size[0] - h = w = int(x.shape[1] ** 0.5) - assert h * w == x.shape[1] + h = w = int((x.shape[1] * seq_parallel_world_size) ** 0.5) + assert h * w == x.shape[1] * seq_parallel_world_size + assert h % seq_parallel_world_size == 0 + h_per_rank = h // seq_parallel_world_size - x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = x.reshape(shape=(x.shape[0], h_per_rank, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], c, h_per_rank * p, w * p)) return imgs # @torch.cuda.amp.autocast() @@ -324,12 +377,19 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): """ if use_fp16: x = x.to(dtype=torch.float16) + batches, frames, channels, high, weight = x.shape x = rearrange(x, 'b f c h w -> (b f) c h w') x = self.x_embedder(x) + self.pos_embed - t = self.t_embedder(t, use_fp16=use_fp16) - timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) + # print(f'before model x.shape = {x.shape}') + bs, seq_len, dim = x.shape + seq_parallel_world_size = get_sequence_parallel_world_size() + assert seq_len % seq_parallel_world_size == 0 + x = split_for_sequence_parallel(x, split_dim=1) + t = self.t_embedder(t, use_fp16=use_fp16) + timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) # (1, num_frames, hidden_size) timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) + timestep_temp = split_for_sequence_parallel(timestep_temp, split_dim=0) if self.extras == 2: y = self.y_embedder(y, self.training) @@ -367,7 +427,7 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): c = timestep_spatial + text_embedding_spatial else: c = timestep_spatial - 
x = spatial_block(x, c) + x = spatial_block(x, c, use_sequence_parallel=True) x = rearrange(x, '(b f) t d -> (b t) f d', b=batches) x_video = x[:, :(frames-use_image_num), :] @@ -384,7 +444,7 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): else: c = timestep_temp - x_video = temp_block(x_video, c) + x_video = temp_block(x_video, c, use_sequence_parallel=False) x = torch.cat([x_video, x_image], dim=1) x = rearrange(x, '(b t) f d -> (b f) t d', b=batches) diff --git a/train_with_img.py b/train_with_img.py index 2355e4a..3460606 100644 --- a/train_with_img.py +++ b/train_with_img.py @@ -37,6 +37,47 @@ write_tensorboard, setup_distributed, get_experiment_dir) +from xtuner.parallel.sequence import ( + init_sequence_parallel, get_sequence_parallel_world_size, + get_sequence_parallel_rank, get_data_parallel_world_size, + get_data_parallel_rank, SequenceParallelSampler) + + +# class SequenceParallelSampler(DistributedSampler): +# def __init__(self, dataset, num_replicas = None, +# rank = None, shuffle: bool = True, +# seed: int = 0, drop_last: bool = False): +# if num_replicas is None: +# if not dist.is_available(): +# raise RuntimeError("Requires distributed package to be available") +# num_replicas = get_data_parallel_world_size() +# if rank is None: +# if not dist.is_available(): +# raise RuntimeError("Requires distributed package to be available") +# rank = get_data_parallel_rank() +# if rank >= num_replicas or rank < 0: +# raise ValueError( +# f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") +# self.dataset = dataset +# self.num_replicas = num_replicas +# self.rank = rank +# self.epoch = 0 +# self.drop_last = drop_last +# # If the dataset length is evenly divisible by # of replicas, then there +# # is no need to drop any data, since the dataset will be split equally. 
+# if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] +# # Split to nearest available length that is evenly divisible. +# # This is to ensure each rank receives the same amount of data when +# # using this Sampler. +# self.num_samples = math.ceil( +# (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] +# ) +# else: +# self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] +# self.total_size = self.num_samples * self.num_replicas +# self.shuffle = shuffle +# self.seed = seed + ################################################################################# # Training Loop # ################################################################################# @@ -47,6 +88,7 @@ def main(args): # Setup DDP: setup_distributed() + init_sequence_parallel(args.sequence_parallel_size) # dist.init_process_group("nccl") # assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 
# rank = dist.get_rank() @@ -57,7 +99,8 @@ def main(args): local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device("cuda", local_rank) - seed = args.global_seed + rank + # seed = args.global_seed + rank + seed = args.global_seed torch.manual_seed(seed) torch.cuda.set_device(device) print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.") @@ -147,14 +190,16 @@ def main(args): # Setup data: dataset = get_dataset(args) + + sampler = SequenceParallelSampler(dataset, shuffle=True, seed=args.global_seed) - sampler = DistributedSampler( - dataset, - num_replicas=dist.get_world_size(), - rank=rank, - shuffle=True, - seed=args.global_seed - ) + # sampler = DistributedSampler( + # dataset, + # num_replicas=dist.get_world_size(), + # rank=rank, + # shuffle=True, + # seed=args.global_seed + # ) loader = DataLoader( dataset, batch_size=int(args.local_batch_size), @@ -164,7 +209,7 @@ def main(args): pin_memory=True, drop_last=True ) - logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})") + # logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})") # Scheduler lr_scheduler = get_scheduler( From f5e8aa773671eda2ebf459ce9827b4848d875874 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Tue, 9 Apr 2024 18:48:20 +0800 Subject: [PATCH 2/7] use new XTuner sequence parallel API --- diffusion/gaussian_diffusion.py | 32 ++------------------------------ models/latte_img.py | 30 +----------------------------- 2 files changed, 3 insertions(+), 59 deletions(-) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index 593324c..0041032 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -9,6 +9,8 @@ import numpy as np import torch as th import enum +from xtuner.parallel.sequence import reduce_sequence_parallel_loss, split_for_sequence_parallel +import torch from .diffusion_utils import 
discretized_gaussian_log_likelihood, normal_kl @@ -141,36 +143,6 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): return np.array(betas) -from xtuner.parallel.sequence import get_sequence_parallel_world_size, get_sequence_parallel_rank, reduce_sequence_parallel_loss -import torch -import torch.distributed as dist -def split_for_sequence_parallel(tokens, split_dim=1): - seq_parallel_world_size = get_sequence_parallel_world_size() - if seq_parallel_world_size == 1: - return tokens - - seq_parallel_world_rank = get_sequence_parallel_rank() - - # bs, seq_len, dim = tokens.shape - seq_len = tokens.shape[split_dim] - assert seq_len % seq_parallel_world_size == 0 - sub_seq_len = seq_len // seq_parallel_world_size - sub_seq_start = seq_parallel_world_rank * sub_seq_len - sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len - - if split_dim == 0: - tokens = tokens[sub_seq_start:sub_seq_end] - elif split_dim == 1: - tokens = tokens[:, sub_seq_start:sub_seq_end] - elif split_dim == 2: - tokens = tokens[:, :, sub_seq_start:sub_seq_end] - elif split_dim == 3: - tokens = tokens[:, :, :, sub_seq_start:sub_seq_end] - else: - raise NotImplementedError - return tokens - - class GaussianDiffusion: """ Utilities for training and sampling diffusion models. 
diff --git a/models/latte_img.py b/models/latte_img.py index c5a49a1..473f1ae 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -14,6 +14,7 @@ from einops import rearrange, repeat from timm.models.vision_transformer import Mlp, PatchEmbed +from xtuner.parallel.sequence import get_sequence_parallel_world_size, split_for_sequence_parallel import os import sys @@ -225,35 +226,6 @@ def forward(self, x, c): return x -from xtuner.parallel.sequence import get_sequence_parallel_world_size, get_sequence_parallel_rank - -def split_for_sequence_parallel(tokens, split_dim=1): - seq_parallel_world_size = get_sequence_parallel_world_size() - if seq_parallel_world_size == 1: - return tokens - - seq_parallel_world_rank = get_sequence_parallel_rank() - - # bs, seq_len, dim = tokens.shape - seq_len = tokens.shape[split_dim] - assert seq_len % seq_parallel_world_size == 0 - sub_seq_len = seq_len // seq_parallel_world_size - sub_seq_start = seq_parallel_world_rank * sub_seq_len - sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len - - if split_dim == 0: - tokens = tokens[sub_seq_start:sub_seq_end] - elif split_dim == 1: - tokens = tokens[:, sub_seq_start:sub_seq_end] - elif split_dim == 2: - tokens = tokens[:, :, sub_seq_start:sub_seq_end] - elif split_dim == 3: - tokens = tokens[:, :, :, sub_seq_start:sub_seq_end] - else: - raise NotImplementedError - return tokens - - class Latte(nn.Module): """ Diffusion model with a Transformer backbone. 
From 60d514d89de789f82113c3b46113f2a073562498 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 10 Apr 2024 16:16:52 +0800 Subject: [PATCH 3/7] update sp usage according to the new API in XTuner --- diffusion/gaussian_diffusion.py | 8 ++----- models/latte_img.py | 26 +++++++++++---------- train_with_img.py | 41 +-------------------------------- 3 files changed, 17 insertions(+), 58 deletions(-) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index 0041032..0c9af4b 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -750,10 +750,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_output = model(x_t, t, **model_kwargs) # (bs, frame, 2*c_in, h//p, w) - x_t = split_for_sequence_parallel(x_t, split_dim=3) - x_start = split_for_sequence_parallel(x_start, split_dim=3) - noise = split_for_sequence_parallel(noise, split_dim=3) + model_output = model(x_t, t, **model_kwargs) # (bs, frame, 2*c_in, h, w) # try: # model_output = model(x_t, t, **model_kwargs).sample # for tav unet # except: @@ -789,8 +786,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): ModelMeanType.EPSILON: noise, }[self.model_mean_type] assert model_output.shape == target.shape == x_start.shape - loss = mean_flat((target - model_output) ** 2) - terms["mse"] = reduce_sequence_parallel_loss(loss, torch.tensor(1, device=loss.device, dtype=loss.dtype)) + terms["mse"] = mean_flat((target - model_output) ** 2) if "vb" in terms: terms["loss"] = terms["mse"] + terms["vb"] else: diff --git a/models/latte_img.py b/models/latte_img.py index 473f1ae..956d7d7 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -14,7 +14,7 @@ from einops import rearrange, repeat from timm.models.vision_transformer import
Mlp, PatchEmbed -from xtuner.parallel.sequence import get_sequence_parallel_world_size, split_for_sequence_parallel +from xtuner.parallel.sequence import get_sequence_parallel_world_size, split_for_sequence_parallel, get_sequence_parallel_group import os import sys @@ -37,7 +37,10 @@ def modulate(x, shift, scale): # Attention Layers from TIMM # ################################################################################# -from xtuner.parallel.sequence.attention import post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn +from xtuner.parallel.sequence import ( + post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn, + gather_forward_split_backward, get_sequence_parallel_group + ) import torch.distributed as dist class Attention(nn.Module): @@ -323,17 +326,14 @@ def unpatchify(self, x): x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ - seq_parallel_world_size = get_sequence_parallel_world_size() c = self.out_channels p = self.x_embedder.patch_size[0] - h = w = int((x.shape[1] * seq_parallel_world_size) ** 0.5) - assert h * w == x.shape[1] * seq_parallel_world_size - assert h % seq_parallel_world_size == 0 - h_per_rank = h // seq_parallel_world_size + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] - x = x.reshape(shape=(x.shape[0], h_per_rank, w, p, p, c)) + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h_per_rank * p, w * p)) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) return imgs # @torch.cuda.amp.autocast() @@ -357,11 +357,11 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): bs, seq_len, dim = x.shape seq_parallel_world_size = get_sequence_parallel_world_size() assert seq_len % seq_parallel_world_size == 0 - x = split_for_sequence_parallel(x, split_dim=1) + x = split_for_sequence_parallel(x, get_sequence_parallel_group(), split_dim=1) t = self.t_embedder(t, 
use_fp16=use_fp16) timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) # (1, num_frames, hidden_size) timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) - timestep_temp = split_for_sequence_parallel(timestep_temp, split_dim=0) + timestep_temp = split_for_sequence_parallel(timestep_temp, get_sequence_parallel_group(), split_dim=0) if self.extras == 2: y = self.y_embedder(y, self.training) @@ -424,7 +424,9 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): c = timestep_spatial + y_spatial else: c = timestep_spatial - x = self.final_layer(x, c) + x = self.final_layer(x, c) + # breakpoint() + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") x = self.unpatchify(x) x = rearrange(x, '(b f) c h w -> b f c h w', b=batches) # print(x.shape) diff --git a/train_with_img.py b/train_with_img.py index 3460606..ae1f6d8 100644 --- a/train_with_img.py +++ b/train_with_img.py @@ -37,46 +37,7 @@ write_tensorboard, setup_distributed, get_experiment_dir) -from xtuner.parallel.sequence import ( - init_sequence_parallel, get_sequence_parallel_world_size, - get_sequence_parallel_rank, get_data_parallel_world_size, - get_data_parallel_rank, SequenceParallelSampler) - - -# class SequenceParallelSampler(DistributedSampler): -# def __init__(self, dataset, num_replicas = None, -# rank = None, shuffle: bool = True, -# seed: int = 0, drop_last: bool = False): -# if num_replicas is None: -# if not dist.is_available(): -# raise RuntimeError("Requires distributed package to be available") -# num_replicas = get_data_parallel_world_size() -# if rank is None: -# if not dist.is_available(): -# raise RuntimeError("Requires distributed package to be available") -# rank = get_data_parallel_rank() -# if rank >= num_replicas or rank < 0: -# raise ValueError( -# f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") -# self.dataset = dataset -# 
self.num_replicas = num_replicas -# self.rank = rank -# self.epoch = 0 -# self.drop_last = drop_last -# # If the dataset length is evenly divisible by # of replicas, then there -# # is no need to drop any data, since the dataset will be split equally. -# if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] -# # Split to nearest available length that is evenly divisible. -# # This is to ensure each rank receives the same amount of data when -# # using this Sampler. -# self.num_samples = math.ceil( -# (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] -# ) -# else: -# self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] -# self.total_size = self.num_samples * self.num_replicas -# self.shuffle = shuffle -# self.seed = seed +from xtuner.parallel.sequence import init_sequence_parallel, SequenceParallelSampler ################################################################################# # Training Loop # From 14fb495b6da707692512789d9d9da6cf00bd0a61 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 10 Apr 2024 18:02:15 +0800 Subject: [PATCH 4/7] use split_forward_gather_backward --- models/latte_img.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/models/latte_img.py b/models/latte_img.py index 956d7d7..0983900 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -11,10 +11,15 @@ import torch import torch.nn as nn import numpy as np +import torch.distributed as dist from einops import rearrange, repeat from timm.models.vision_transformer import Mlp, PatchEmbed -from xtuner.parallel.sequence import get_sequence_parallel_world_size, split_for_sequence_parallel, get_sequence_parallel_group +from xtuner.parallel.sequence import ( + get_sequence_parallel_world_size, split_for_sequence_parallel, + get_sequence_parallel_group, gather_forward_split_backward, + 
post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn, + split_forward_gather_backward) import os import sys @@ -37,12 +42,6 @@ def modulate(x, shift, scale): # Attention Layers from TIMM # ################################################################################# -from xtuner.parallel.sequence import ( - post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn, - gather_forward_split_backward, get_sequence_parallel_group - ) -import torch.distributed as dist - class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): super().__init__() @@ -357,11 +356,15 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): bs, seq_len, dim = x.shape seq_parallel_world_size = get_sequence_parallel_world_size() assert seq_len % seq_parallel_world_size == 0 - x = split_for_sequence_parallel(x, get_sequence_parallel_group(), split_dim=1) + + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") + # x = split_for_sequence_parallel(x, get_sequence_parallel_group(), dim=1) t = self.t_embedder(t, use_fp16=use_fp16) timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) # (1, num_frames, hidden_size) - timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) - timestep_temp = split_for_sequence_parallel(timestep_temp, get_sequence_parallel_group(), split_dim=0) + timestep_temp = repeat(t, 'n d -> n c d', c=self.pos_embed.shape[1]) + timestep_temp = split_forward_gather_backward(timestep_temp, get_sequence_parallel_group(), dim=1, grad_scale="down") + # timestep_temp = split_for_sequence_parallel(timestep_temp, get_sequence_parallel_group(), dim=1) + timestep_temp = timestep_temp.flatten(0, 1) if self.extras == 2: y = self.y_embedder(y, self.training) From 9dc4df491b26ce5d7a85d1ea773674f676cf0891 Mon Sep 17 00:00:00 2001 From: 
HIT-cwh <2892770585@qq.com> Date: Wed, 10 Apr 2024 18:37:04 +0800 Subject: [PATCH 5/7] fix bug --- models/latte_img.py | 38 ++++++++++++++++---------------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/models/latte_img.py b/models/latte_img.py index 0983900..c3f7fbc 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -19,7 +19,7 @@ get_sequence_parallel_world_size, split_for_sequence_parallel, get_sequence_parallel_group, gather_forward_split_backward, post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn, - split_forward_gather_backward) + split_forward_gather_backward, all_to_all) import os import sys @@ -57,24 +57,21 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - + def forward(self, x, use_sequence_parallel=True): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] - - # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - # q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) enable_sequence_parallel = ( dist.is_initialized() and get_sequence_parallel_world_size() > 1 and self.training and use_sequence_parallel) if enable_sequence_parallel: - q, k, v = pre_process_for_sequence_parallel_attn(q, k, v) - - q = q.transpose(1, 2).contiguous() # bshd -> bhsd - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() + # (bs, n_head, seq_len / p, dim) -> (bs, n_head / p, seq_len, dim) + sp_group = get_sequence_parallel_group() + q = all_to_all(q, sp_group, scatter_dim=1, gather_dim=2) + k = 
all_to_all(k, sp_group, scatter_dim=1, gather_dim=2) + v = all_to_all(v, sp_group, scatter_dim=1, gather_dim=2) if self.attention_mode == 'xformers': # cause loss nan while using with amp x = xformers.ops.memory_efficient_attention(q, k, v) @@ -89,13 +86,15 @@ def forward(self, x, use_sequence_parallel=True): attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).contiguous() + x = (attn @ v).transpose(1, 2) else: raise NotImplemented if enable_sequence_parallel: - x = post_process_for_sequence_parallel_attn(x).contiguous() + # (bs, seq_len, n_head / p, dim) -> (bs, seq_len / p, n_head, dim) + sp_group = get_sequence_parallel_group() + x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) x = x.reshape(B, N, C) @@ -352,18 +351,15 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): batches, frames, channels, high, weight = x.shape x = rearrange(x, 'b f c h w -> (b f) c h w') x = self.x_embedder(x) + self.pos_embed - # print(f'before model x.shape = {x.shape}') bs, seq_len, dim = x.shape seq_parallel_world_size = get_sequence_parallel_world_size() assert seq_len % seq_parallel_world_size == 0 - x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") - # x = split_for_sequence_parallel(x, get_sequence_parallel_group(), dim=1) + x = split_forward_gather_backward(x, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="down") t = self.t_embedder(t, use_fp16=use_fp16) timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) # (1, num_frames, hidden_size) timestep_temp = repeat(t, 'n d -> n c d', c=self.pos_embed.shape[1]) - timestep_temp = split_forward_gather_backward(timestep_temp, get_sequence_parallel_group(), dim=1, grad_scale="down") - # timestep_temp = split_for_sequence_parallel(timestep_temp, get_sequence_parallel_group(), dim=1) + timestep_temp = 
split_forward_gather_backward(timestep_temp, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="down") timestep_temp = timestep_temp.flatten(0, 1) if self.extras == 2: @@ -428,11 +424,9 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0): else: c = timestep_spatial x = self.final_layer(x, c) - # breakpoint() - x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") + x = gather_forward_split_backward(x, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="up") x = self.unpatchify(x) x = rearrange(x, '(b f) c h w -> b f c h w', b=batches) - # print(x.shape) return x From 66ad86db917dc2aa25a084522c68468dc7dee6f7 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 10 Apr 2024 18:38:29 +0800 Subject: [PATCH 6/7] delete useless codes --- diffusion/gaussian_diffusion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index 0c9af4b..9933d65 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -9,8 +9,6 @@ import numpy as np import torch as th import enum -from xtuner.parallel.sequence import reduce_sequence_parallel_loss, split_for_sequence_parallel -import torch from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl @@ -734,7 +732,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): model_kwargs = {} if noise is None: noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) # x: (bs, frame, c_in, h, w), t: (bs, ) + x_t = self.q_sample(x_start, t, noise=noise) terms = {} @@ -750,7 +748,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_output = model(x_t, t, **model_kwargs) # (bs, frame, 
2*c_in, h, w) + model_output = model(x_t, t, **model_kwargs) # try: # model_output = model(x_t, t, **model_kwargs).sample # for tav unet # except: From a8af8d9581d99b37c5ff04550e90fafc165893f3 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 10 Apr 2024 18:39:37 +0800 Subject: [PATCH 7/7] delete useless codes --- models/latte_img.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/models/latte_img.py b/models/latte_img.py index c3f7fbc..7f0e485 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -16,10 +16,8 @@ from einops import rearrange, repeat from timm.models.vision_transformer import Mlp, PatchEmbed from xtuner.parallel.sequence import ( - get_sequence_parallel_world_size, split_for_sequence_parallel, - get_sequence_parallel_group, gather_forward_split_backward, - post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn, - split_forward_gather_backward, all_to_all) + get_sequence_parallel_world_size, get_sequence_parallel_group, + gather_forward_split_backward, split_forward_gather_backward, all_to_all) import os import sys