diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py index 34c95a0..d56aef0 100644 --- a/muse_maskgit_pytorch/dataset.py +++ b/muse_maskgit_pytorch/dataset.py @@ -167,27 +167,31 @@ def __getitem__(self, index): else: text = descriptions # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask + input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [0] + attn_mask = [0] if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text class URLTextDataset(ImageDataset): @@ -242,27 +246,31 @@ def __getitem__(self, index): else: text = descriptions # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask + 
input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [0] + attn_mask = [0] if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text class LocalTextImageDataset(Dataset): @@ -338,26 +346,31 @@ def __getitem__(self, index): embed = self.embeds[index] # max length from the paper - encoded = self.tokenizer.batch_encode_plus( - [str(text)], - return_tensors="pt", - padding="max_length", - max_length=MAX_LENGTH, - truncation=True, - ) + if self.tokenizer is not None: + encoded = self.tokenizer.batch_encode_plus( + [str(text)], + return_tensors="pt", + padding="max_length", + max_length=MAX_LENGTH, + truncation=True, + ) + + input_ids = encoded.input_ids + attn_mask = encoded.attention_mask + else: + input_ids = [0] + attn_mask = [0] - input_ids = encoded.input_ids - attn_mask = encoded.attention_mask if self.using_taming: if self.embeds: - return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed + return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), input_ids[0], attn_mask[0], [], text else: if self.embeds: - return self.transform(image), input_ids[0], attn_mask[0], embed + return self.transform(image), input_ids[0], attn_mask[0], embed, text else: - return self.transform(image), input_ids[0], attn_mask[0], [] + return self.transform(image), 
def encode_text(self, *args, **kwargs):
    """Encode text with this transformer's T5 encoder.

    All positional and keyword arguments are forwarded to
    ``t5_encode_text``; ``tokenizer`` and ``t5`` are always overridden with
    the instances held by this module.

    Returns:
        Whatever ``t5_encode_text`` returns for the given arguments.

    Raises:
        RuntimeError: if the module was built with ``use_clip=True``. In
            that mode ``__init__`` never loads a T5 model (``self.t5`` is
            not set and ``self.tokenizer`` is ``None``), so there is
            nothing to encode with.
    """
    if self.use_clip:
        # Fail loudly: merely printing here made the method return None,
        # which callers would then pass on as if it were text embeddings.
        raise RuntimeError(
            "encode_text() is unavailable when use_clip=True: no T5 encoder is loaded. "
            "Compute the text embeddings with the CLIP text model instead."
        )

    kwargs.update(tokenizer=self.tokenizer, t5=self.t5)
    return t5_encode_text(*args, **kwargs)
starting and the ending tokens + text_input_chunk[:, 0] = inputs[0, 0] + text_input_chunk[:, -1] = inputs[0, -1] + text_embedding = clip_model(text_input_chunk)[0] + + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device) + else: + text_embeds = clip_model(inputs)[0].to(self.accelerator.device) + else: + text_embeds = self.transformer.encode_text(texts) + + demask_fn = self.transformer.forward_with_cond_scale - neg_text_embeds = None - if exists(negative_texts): - assert len(texts) == len(negative_texts) + # negative prompting, as in paper + + neg_text_embeds = None + if exists(negative_texts): + assert len(texts) == len(negative_texts) neg_text_embeds = self.transformer.encode_text(negative_texts) demask_fn = partial( @@ -536,11 +584,11 @@ def generate( neg_text_embeds=neg_text_embeds, ) - if use_token_critic: - token_critic_fn = partial( - self.token_critic.forward_with_neg_prompt, - neg_text_embeds=neg_text_embeds, - ) + if use_token_critic: + token_critic_fn = partial( + self.token_critic.forward_with_neg_prompt, + neg_text_embeds=neg_text_embeds, + ) if self.resize_image_for_cond_image: if cond_images is None: @@ -565,14 +613,18 @@ def generate( ids = ids.scatter(1, masked_indices, self.mask_id) - logits, embed = demask_fn( - ids, - text_embeds=text_embeds, - self_cond_embed=self_cond_embed, - conditioning_token_ids=cond_ids, - cond_scale=cond_scale, - return_embed=True, - ) + if self.clip is None: + logits, embed = demask_fn( + ids, + text_embeds=text_embeds, + self_cond_embed=self_cond_embed, + conditioning_token_ids=cond_ids, + cond_scale=cond_scale, + return_embed=True, + ) + else: + embed = 
def divide_string(string: str, parts: int) -> list:
    """Split ``string`` into exactly ``parts`` consecutive substrings.

    The first ``parts - 1`` substrings each have ``len(string) // parts``
    characters; the last substring receives everything that is left, so
    joining the result always reproduces ``string``.

    Args:
        string: The text to split.
        parts: Number of substrings to produce; must be at least 1.

    Returns:
        A list of ``parts`` substrings. When ``string`` is shorter than
        ``parts`` the leading substrings are empty and the last one holds
        the whole string.

    Raises:
        ValueError: if ``parts`` is smaller than 1.
    """
    if parts < 1:
        raise ValueError(f"parts must be a positive integer, got {parts!r}")

    # Base length of every substring except the last one
    part_length = len(string) // parts

    # Slice by chunk index rather than stepping a range by ``part_length``:
    # a zero step (string shorter than ``parts``) would make range() raise,
    # and a large remainder would otherwise yield more than ``parts`` chunks.
    substrings = [string[i * part_length : (i + 1) * part_length] for i in range(parts - 1)]

    # The last substring absorbs any leftover characters
    substrings.append(string[part_length * (parts - 1) :])

    return substrings
{step} | Logging with prompts: {[' | '.join(validation_prompts)]}" ) - - images = self.model.generate( - validation_prompts, - cond_images=cond_image, - cond_scale=cond_scale, - temperature=temperature, - timesteps=timesteps, - ).to(self.accelerator.device) + images = [] + for text in validation_prompts: + images.append( + self.model.generate( + (text,), + cond_images=cond_image, + cond_scale=cond_scale, + temperature=temperature, + timesteps=timesteps, + ).to(self.accelerator.device) + ) save_dir = self.results_dir.joinpath("MaskGit") save_dir.mkdir(exist_ok=True, parents=True) @@ -154,15 +181,49 @@ def train(self): # logs for epoch in range(self.current_step // len(self.dl), self.num_epochs): - for imgs, input_ids, attn_mask, text_embeds in iter(self.dl): + for imgs, input_ids, attn_mask, text_embeds, text in iter(self.dl): train_loss = 0.0 steps = int(self.steps.item()) - if not text_embeds: - with torch.no_grad(): - text_embeds = t5_encode_text_from_encoded( - input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device - ) + if not self.use_clip: + if not text_embeds: + with torch.no_grad(): + text_embeds = t5_encode_text_from_encoded( + input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device + ) + else: + print(text) + clip_model, clip_tokenizer = self.clip_model + inputs = [token[1:-1] for token in clip_tokenizer(text, truncation=True).input_ids] + + inputs = torch.tensor(inputs, device=self.accelerator.device) + + max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = inputs[0, 0] + text_input_chunk[:, -1] = inputs[0, -1] + text_embedding = clip_model(text_input_chunk)[0] + + if i == 0: + # discard the ending token + 
text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device) + else: + text_embeds = clip_model(inputs)[0].to(self.accelerator.device) with self.accelerator.accumulate(self.model), self.accelerator.autocast(): loss = self.model(imgs, text_embeds=text_embeds) diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py index 4d11423..d0ed167 100644 --- a/train_muse_maskgit.py +++ b/train_muse_maskgit.py @@ -18,6 +18,7 @@ from rich import inspect from torch.optim import Optimizer from tqdm import tqdm +from transformers import AutoTokenizer, CLIPTextModel import wandb from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded @@ -124,6 +125,12 @@ def decompress_pickle(file): parser.add_argument("--heads", type=int, default=8, help="Attention heads") parser.add_argument("--ff_mult", type=int, default=4, help="Feed forward expansion factor") parser.add_argument("--t5_name", type=str, default="t5-small", help="Name of your t5 model") +parser.add_argument( + "--use_clip", + action="store_true", + default=False, + help="whether to use MetaClip instead of a T5", +) parser.add_argument("--cond_image_size", type=int, default=None, help="Conditional image size.") parser.add_argument( "--validation_prompt", @@ -480,6 +487,7 @@ class Arguments: heads: int = 8 ff_mult: int = 4 t5_name: str = "t5-small" + use_clip: bool = False mixed_precision: str = "no" cond_image_size: Optional[int] = None validation_prompt: str = "A photo of a dog" @@ -750,12 +758,27 @@ def main(): cache_path=args.cache_path, flash=flash, xformers=xformers, + use_clip=args.use_clip, ) + if args.use_clip: + model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", 
cache_dir=args.cache_path).to( + accelerator.device + ) + tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", cache_dir=args.cache_path) + + clip = (model, tokenizer) + else: + model = None + tokenizer = None + clip = None + # (2) pass your trained VAE and the base transformer to MaskGit maskgit = MaskGit( vae=vae, # vqgan vae transformer=transformer, # transformer + clip=model, + clip_tokenizer=tokenizer, accelerator=accelerator, # accelerator image_size=args.image_size, # image size cond_drop_prob=args.cond_drop_prob, # conditional dropout, for classifier free guidance @@ -815,6 +838,9 @@ def main(): else: embeds = [] + if args.use_clip: + transformer.tokenizer = None + # Create the dataset objects with accelerator.main_process_first(): if args.no_cache and args.train_data_dir: @@ -1017,6 +1043,7 @@ def main(): only_save_last_checkpoint=args.only_save_last_checkpoint, num_epochs=args.num_epochs, args=args, + clip=clip, ) # Prepare the trainer for distributed training