Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions configs/ffs/ffs_img_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
dataset: "ffs_img"

data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/"
frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt"
pretrained_model_path: "/path/to/pretrained/Latte/"
frame_data_path: "/mnt/petrelfs/share_data/linzhihao/dataset/llava_data/LLaVA-Pretrain/images/00000"
frame_data_txt: "/mnt/petrelfs/caoweihan/projects/Latte/data/toy_train_list.txt"
pretrained_model_path: "maxin-cn/Latte"

# save and load
results_dir: "./results_img"
pretrained:

# model config:
model: LatteIMG-XL/2
num_frames: 16
# num_frames: 16
num_frames: 0
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
Expand All @@ -22,6 +23,7 @@ learn_sigma: True # important
extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation

# train config:
sequence_parallel_size: 1
save_ceph: True # important
use_image_num: 8
learning_rate: 1e-4
Expand All @@ -32,7 +34,7 @@ local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 100
log_every: 10
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
Expand Down
47 changes: 47 additions & 0 deletions configs/ffs/ffs_img_train_sp2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# dataset
dataset: "ffs_img"

data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
frame_data_path: "/mnt/petrelfs/share_data/linzhihao/dataset/llava_data/LLaVA-Pretrain/images/00000"
frame_data_txt: "/mnt/petrelfs/caoweihan/projects/Latte/data/toy_train_list.txt"
pretrained_model_path: "maxin-cn/Latte"

# save and load
results_dir: "./results_img"
pretrained:

# model config:
model: LatteIMG-XL/2
# num_frames: 16
num_frames: 0
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True # important
extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation

# train config:
sequence_parallel_size: 2
save_ceph: True # important
use_image_num: 8
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 500000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 10
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:

# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
55 changes: 32 additions & 23 deletions datasets/ffs_image_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,38 +139,40 @@ def __init__(self,
transform=None,
temporal_sample=None):
self.configs = configs
self.data_path = configs.data_path
self.video_lists = get_filelist(configs.data_path)
# self.data_path = configs.data_path
# self.video_lists = get_filelist(configs.data_path)
self.transform = transform
self.temporal_sample = temporal_sample
self.target_video_len = self.configs.num_frames
self.v_decoder = DecordInit()
self.video_length = len(self.video_lists)
# self.target_video_len = self.configs.num_frames
# self.v_decoder = DecordInit()
# self.video_length = len(self.video_lists)

# ffs video frames
self.video_frame_path = configs.frame_data_path
self.video_frame_txt = configs.frame_data_txt
self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
random.shuffle(self.video_frame_files)
# random.shuffle(self.video_frame_files)
self.use_image_num = configs.use_image_num
self.image_tranform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(configs.image_size),
transforms.CenterCrop(configs.image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])

def __getitem__(self, index):
video_index = index % self.video_length
path = self.video_lists[video_index]
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
total_frames = len(vframes)
# video_index = index % self.video_length
# path = self.video_lists[video_index]
# vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
# total_frames = len(vframes)

# Sampling video frames
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
assert end_frame_ind - start_frame_ind >= self.target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
video = vframes[frame_indice]
# video transformer data preprocess
video = self.transform(video) # T C H W
# # Sampling video frames
# start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
# assert end_frame_ind - start_frame_ind >= self.target_video_len
# frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
# video = vframes[frame_indice]
# # video transformer data preprocess
# video = self.transform(video) # T C H W

# get video frames
images = []
Expand All @@ -182,15 +184,16 @@ def __getitem__(self, index):
images.append(image)
break
except Exception as e:
traceback.print_exc()
# traceback.print_exc()
index = random.randint(0, len(self.video_frame_files) - self.use_image_num)
images = torch.cat(images, dim=0)

assert len(images) == self.use_image_num
return {'video': images, 'video_name': 1}

video_cat = torch.cat([video, images], dim=0)
# video_cat = torch.cat([video, images], dim=0)

return {'video': video_cat, 'video_name': 1}
# return {'video': video_cat, 'video_name': 1}

def __len__(self):
return len(self.video_frame_files)
Expand All @@ -216,6 +219,7 @@ def __len__(self):
parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/")
parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt")
parser.add_argument("--image-size", type=int, default=256)
config = parser.parse_args()

temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
Expand All @@ -226,21 +230,26 @@ def __len__(self):
])

dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample)
dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)
print(len(dataset))
# breakpoint()
dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=1)

for i, video_data in enumerate(dataloader):
video, video_label = video_data['video'], video_data['video_name']
# print(video_label)
# print(image_label)

print(video.shape)
print(video_label)
# if i % 100 == 0:
# print(i)
# video_ = ((video[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
# print(video_.shape)
# try:
# torchvision.io.write_video(f'./test/{i:03d}_{video_label}.mp4', video_[:16], fps=8)
# except:
# pass

# if i % 100 == 0 and i != 0:
# break
if i % 100 == 0 and i != 0:
break
print('Done!')
57 changes: 43 additions & 14 deletions models/latte_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
import torch
import torch.nn as nn
import numpy as np
import torch.distributed as dist

from einops import rearrange, repeat
from timm.models.vision_transformer import Mlp, PatchEmbed
from xtuner.parallel.sequence import (
get_sequence_parallel_world_size, get_sequence_parallel_group,
gather_forward_split_backward, split_forward_gather_backward, all_to_all)

import os
import sys
Expand Down Expand Up @@ -51,29 +55,46 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
def forward(self, x, use_sequence_parallel=True):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)

enable_sequence_parallel = (
dist.is_initialized() and get_sequence_parallel_world_size() > 1
and self.training and use_sequence_parallel)
if enable_sequence_parallel:
# (bs, n_head, seq_len / p, dim) -> (bs, n_head / p, seq_len, dim)
sp_group = get_sequence_parallel_group()
q = all_to_all(q, sp_group, scatter_dim=1, gather_dim=2)
k = all_to_all(k, sp_group, scatter_dim=1, gather_dim=2)
v = all_to_all(v, sp_group, scatter_dim=1, gather_dim=2)

if self.attention_mode == 'xformers': # cause loss nan while using with amp
x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C)
x = xformers.ops.memory_efficient_attention(q, k, v)

elif self.attention_mode == 'flash':
# cause loss nan while using with amp
# Optionally use the context manager to ensure one of the fused kernels is run
with torch.backends.cuda.sdp_kernel(enable_math=False):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0
x = torch.nn.functional.scaled_dot_product_attention(q, k, v) # require pytorch 2.0

elif self.attention_mode == 'math':
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2)

else:
raise NotImplemented

if enable_sequence_parallel:
# (bs, seq_len, n_head / p, dim) -> (bs, seq_len / p, n_head, dim)
sp_group = get_sequence_parallel_group()
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)

x = x.reshape(B, N, C)

x = self.proj(x)
x = self.proj_drop(x)
Expand Down Expand Up @@ -177,9 +198,9 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)

def forward(self, x, c):
def forward(self, x, c, use_sequence_parallel=True):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), use_sequence_parallel)
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x

Expand Down Expand Up @@ -324,12 +345,20 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0):
"""
if use_fp16:
x = x.to(dtype=torch.float16)

batches, frames, channels, high, weight = x.shape
x = rearrange(x, 'b f c h w -> (b f) c h w')
x = self.x_embedder(x) + self.pos_embed
t = self.t_embedder(t, use_fp16=use_fp16)
timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num)
timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])
bs, seq_len, dim = x.shape
seq_parallel_world_size = get_sequence_parallel_world_size()
assert seq_len % seq_parallel_world_size == 0

x = split_forward_gather_backward(x, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="down")
t = self.t_embedder(t, use_fp16=use_fp16)
timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num) # (1, num_frames, hidden_size)
timestep_temp = repeat(t, 'n d -> n c d', c=self.pos_embed.shape[1])
timestep_temp = split_forward_gather_backward(timestep_temp, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="down")
timestep_temp = timestep_temp.flatten(0, 1)

if self.extras == 2:
y = self.y_embedder(y, self.training)
Expand Down Expand Up @@ -367,7 +396,7 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0):
c = timestep_spatial + text_embedding_spatial
else:
c = timestep_spatial
x = spatial_block(x, c)
x = spatial_block(x, c, use_sequence_parallel=True)

x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)
x_video = x[:, :(frames-use_image_num), :]
Expand All @@ -384,18 +413,18 @@ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0):
else:
c = timestep_temp

x_video = temp_block(x_video, c)
x_video = temp_block(x_video, c, use_sequence_parallel=False)
x = torch.cat([x_video, x_image], dim=1)
x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)

if self.extras == 2:
c = timestep_spatial + y_spatial
else:
c = timestep_spatial
x = self.final_layer(x, c)
x = self.final_layer(x, c)
x = gather_forward_split_backward(x, dim=1, sp_group=get_sequence_parallel_group(), grad_scale="up")
x = self.unpatchify(x)
x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
# print(x.shape)
return x


Expand Down
24 changes: 15 additions & 9 deletions train_with_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
write_tensorboard, setup_distributed, get_experiment_dir)


from xtuner.parallel.sequence import init_sequence_parallel, SequenceParallelSampler

#################################################################################
# Training Loop #
#################################################################################
Expand All @@ -47,6 +49,7 @@ def main(args):

# Setup DDP:
setup_distributed()
init_sequence_parallel(args.sequence_parallel_size)
# dist.init_process_group("nccl")
# assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
# rank = dist.get_rank()
Expand All @@ -57,7 +60,8 @@ def main(args):
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

seed = args.global_seed + rank
# seed = args.global_seed + rank
seed = args.global_seed
torch.manual_seed(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.")
Expand Down Expand Up @@ -147,14 +151,16 @@ def main(args):

# Setup data:
dataset = get_dataset(args)

sampler = SequenceParallelSampler(dataset, shuffle=True, seed=args.global_seed)

sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=rank,
shuffle=True,
seed=args.global_seed
)
# sampler = DistributedSampler(
# dataset,
# num_replicas=dist.get_world_size(),
# rank=rank,
# shuffle=True,
# seed=args.global_seed
# )
loader = DataLoader(
dataset,
batch_size=int(args.local_batch_size),
Expand All @@ -164,7 +170,7 @@ def main(args):
pin_memory=True,
drop_last=True
)
logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})")
# logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})")

# Scheduler
lr_scheduler = get_scheduler(
Expand Down