From 088a5b77e0089934018a01bcbae97343175d2bce Mon Sep 17 00:00:00 2001 From: davide cavicchini <40665241+DavidC001@users.noreply.github.com> Date: Thu, 5 Jun 2025 18:12:01 +0200 Subject: [PATCH] Use config for training parameters --- config.example.yaml | 30 +++++++++- configuration.py | 27 ++++++++- llava_finetune/model.py | 124 ++++++++++++++++++---------------------- pyproject.toml | 1 + pytest.ini | 2 + requirements.txt | 1 + train_trainer.py | 71 +++++++++++++++++++++++ trainer_hf.py | 89 ++++++++++++++++++++++++++++ 8 files changed, 275 insertions(+), 70 deletions(-) create mode 100644 pytest.ini create mode 100644 train_trainer.py create mode 100644 trainer_hf.py diff --git a/config.example.yaml b/config.example.yaml index 51fc04d..5b8e05e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -17,4 +17,32 @@ alphaclip: checkpoint_dir: "..." others: - wandb_token: "TOKEN" \ No newline at end of file + wandb_token: "TOKEN" + +training: + epochs: 50 + optimizer: + adapter_lr: 0.001 + lora_lr: 0.000001 + scheduler: + eta_min: 0.0001 + preprocess_params: + model: "alpha-clip" + only_masks: false + model_params: + lora_rank: 16 + q4: true + q8: false + expand_factor: 2 + mlp_adapter: true + mask_prob: 0.2 + dropout: 0.2 + noise_level: 0.000001 + pos_weight: 2 + neg_weight: 1 + temperature: 0.1 + end_turn_token: "\n" + seg_pos: "randomized" + text: true + log_interval: 10 + val_every: 50 diff --git a/configuration.py b/configuration.py index 3d5bd83..a02f0ca 100644 --- a/configuration.py +++ b/configuration.py @@ -68,10 +68,34 @@ def __post_init__(self): @dataclass class OthersConfig: wandb_token: str - + def __post_init__(self): self.wandb_token = os.path.expanduser(self.wandb_token) + +@dataclass +class OptimizerConfig: + adapter_lr: float + lora_lr: float + + +@dataclass +class SchedulerConfig: + eta_min: float + + +@dataclass +class TrainingConfig: + epochs: int + optimizer: OptimizerConfig + scheduler: SchedulerConfig + preprocess_params: dict + model_params: dict + log_interval: int = 10 + val_every: int = 50 + per_device_train_batch_size: int = 2 + per_device_eval_batch_size: int = 2 + @dataclass class ProjectConfig: llava: LLavaConfig @@ -79,6 +103,7 @@ class ProjectConfig: sam: SAMConfig alphaclip: AlphaCLIPConfig others: OthersConfig + training: TrainingConfig def load_yaml_config(path) -> ProjectConfig: diff --git a/llava_finetune/model.py b/llava_finetune/model.py index e1f9880..b6a6fb0 100644 --- a/llava_finetune/model.py +++ b/llava_finetune/model.py @@ -426,35 +426,21 @@ def __init__( self.to(device) - def optim_step( + def compute_loss( self, texts: list[str], images: list[Image.Image], labels: list[str], pos_mask_embeds: list[torch.Tensor], neg_mask_embeds: list[torch.Tensor], - optimizer: torch.optim.Optimizer, ): - """ - Forward pass + optimization step through the model - - Args: - texts (list[str]): List of input text sequences I only the text - images (list[Image.Image]): List of input images - labels (list[str]): List of target text sequences for the model to predict I expect only the text that the model should predict - pos_mask_embeds (list[torch.Tensor]): List of positive mask embeddings with shape (num_pos_masks, seg_emb_size) - neg_mask_embeds (list[torch.Tensor]): List of negative mask embeddings with shape (num_neg_masks, seg_emb_size) - optimizer (torch.optim.Optimizer): Optimizer object to update the model's parameters - - Returns: - Tuple[torch.Tensor]: Logits for the next tokens computed for the last num_generate tokens in the input sequence and 
the loss - """ + """Compute logits and loss without performing an optimizer step.""" input_texts = [] free_token = 1 new_tokens = [] - + seg_pos = self.seg_pos - + possible_texts = [ "The segmentation mask for the object in the image is", "The requested object cab be found in", @@ -473,8 +459,9 @@ def optim_step( [pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))] + [neg_mask_embeds[i].size(0) for i in range(len(neg_mask_embeds))] ) - num_pos_tokens = sum([pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))]) - num_neg_tokens = sum([neg_mask_embeds[i].size(0) for i in range(len(neg_mask_embeds))]) + num_pos_tokens = sum( + [pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))] + ) token_masks = torch.zeros( len(texts), self.llava_model.original_vocab_size + num_new_tokens @@ -482,13 +469,12 @@ def optim_step( token_masks[:, : self.llava_model.tokenizer_vocab_size + 1] = 1 token_masks[:, self.llava_model.tokenizer_vocab_size + num_new_tokens + 1 :] = 1 - # Pass all tokens to the adapter and add the corresponding token lemma to the labels texts for i in range(len(pos_mask_embeds)): if self.seg_pos == "randomized": seg_pos = "before" if torch.rand(1) < 0.5 else "after" - + if seg_pos == "after" and self.text: - labels[i] = labels[i] + " "+possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] + labels[i] = labels[i] + " " + possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] elif seg_pos == "before" and self.text: labels[i] = ". " + labels[i] for j in range(pos_mask_embeds[i].size(0)): @@ -501,11 +487,10 @@ def optim_step( free_token += 1 if seg_pos == "before" and self.text: - labels[i] = possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] + labels[i] + labels[i] = possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] + labels[i] elif seg_pos == "after" and self.text: labels[i] = labels[i] + "." - # apply the chat template to the texts possibile_output_texts = [ "Output the segmentation mask for the object in the image", "Output the mask for the object in the image", @@ -520,15 +505,15 @@ def optim_step( "Provide the segmentation mask", "Provide the mask", ] - + if self.seg_pos == "randomized": seg_pos = "before" if torch.rand(1) < 0.5 else "after" - + if seg_pos == "after": user_message = f"\n{texts[i]} {possibile_output_texts[torch.randint(0, len(possibile_output_texts), (1,)).item()]}." elif seg_pos == "before": user_message = f"\n{possibile_output_texts[torch.randint(0, len(possibile_output_texts), (1,)).item()]}. 
{texts[i]}" - + input_texts.append( self.llava_model.processor.tokenizer.apply_chat_template( [ @@ -539,16 +524,8 @@ def optim_step( add_generation_prompt=False, ) ) - # get the last token of the text as the end token to be added to the labels labels[i] += f"{self.end_token}" - if DEBUG_PRINTS: - print() - print("MODEL INPUTS") - print(f"\tInput texts: {input_texts[0]}") - print(f"\tLabels: {labels[0]}") - print() - for i in range(len(neg_mask_embeds)): for j in range(neg_mask_embeds[i].size(0)): new_tokens.append(neg_mask_embeds[i][j]) @@ -558,8 +535,6 @@ def optim_step( new_tokens = torch.stack(new_tokens) new_tokens = self.adapter(new_tokens.to(self.device)) - - # tokenize the texts inputs = self.llava_model.processor( text=input_texts, images=images, @@ -567,14 +542,10 @@ def optim_step( padding=True, add_special_tokens=False, ).to(self.device) - # remove last token from the input text inputs["input_ids"] = inputs["input_ids"][:, :-1] - - # drop random tokens and substitute them with the unk token + mask = torch.rand(inputs["input_ids"].size()) < self.mask_tok_prob - # keep image tokens mask[inputs["input_ids"] == self.llava_model.processor.tokenizer.encode("", add_special_tokens=False)[0]] = False - # mask the tokens inputs["input_ids"][mask] = self.llava_model.processor.tokenizer.unk_token_id labels_ids = self.llava_model.processor( @@ -586,13 +557,6 @@ def optim_step( ) labels_input_ids = labels_ids["input_ids"].to(self.device) - # print() - # print("MODEL TOKENS INPUT") - # print(f"\tInput tokens: {inputs['input_ids']}") - # print(f"\tLabels: {labels_input_ids}") - # print() - - # Forward pass through the model logits = self.llava_model( **inputs, additional_tokens=new_tokens, @@ -607,45 +571,69 @@ def optim_step( class_weights = torch.ones(logits.size(-1)).to(self.device) class_weights[ - self.llava_model.tokenizer_vocab_size + 1 : - self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 + self.llava_model.tokenizer_vocab_size + 1 : self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 ] = self.pos_weight class_weights[ - self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 : - self.llava_model.tokenizer_vocab_size + num_new_tokens + 1 + self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 : self.llava_model.tokenizer_vocab_size + num_new_tokens + 1 ] = self.neg_weight + loss_fn = nn.CrossEntropyLoss( ignore_index=self.llava_model.processor.tokenizer.pad_token_id, reduction="none", weight=class_weights, ) - # where the labels_input_ids are one of the new tokens, the loss should be multiplied by 2 loss = loss_fn(logits, labels_input_ids) - weights = torch.ones_like(labels_input_ids) - # weights[labels_input_ids > self.llava_model.tokenizer_vocab_size] = 1.5 - # # penilize the model for not generating the end tokens - # weights[labels_input_ids == self.tokenized_end_token] = 1.5 loss = (loss * weights).mean() - - # if loss is nan break + if torch.isnan(loss): print("NAN LOSS") + return None, None, None + + return logits, loss, new_tokens + + def optim_step( + self, + texts: list[str], + images: list[Image.Image], + labels: list[str], + pos_mask_embeds: list[torch.Tensor], + neg_mask_embeds: list[torch.Tensor], + optimizer: torch.optim.Optimizer, + ): + """ + Forward pass + optimization step through the model + + Args: + texts (list[str]): List of input text sequences I only the text + images (list[Image.Image]): List of input images + labels (list[str]): List of target text sequences for the model to predict I expect only the text that the model should predict + 
pos_mask_embeds (list[torch.Tensor]): List of positive mask embeddings with shape (num_pos_masks, seg_emb_size) + neg_mask_embeds (list[torch.Tensor]): List of negative mask embeddings with shape (num_neg_masks, seg_emb_size) + optimizer (torch.optim.Optimizer): Optimizer object to update the model's parameters + + Returns: + Tuple[torch.Tensor]: Logits for the next tokens computed for the last num_generate tokens in the input sequence and the loss + """ + logits, loss, new_tokens = self.compute_loss( + texts, + images, + labels, + pos_mask_embeds, + neg_mask_embeds, + ) + + if loss is None: return None, None - # backward pass optimizer.zero_grad() - loss.backward() - # copy the gradients to the mask embeddings emb_grads = self.llava_model.llava_model.get_input_embeddings().weight.grad new_tokens.backward( emb_grads[ - self.llava_model.tokenizer_vocab_size + 1 : - self.llava_model.tokenizer_vocab_size - + num_new_tokens - + 1 + self.llava_model.tokenizer_vocab_size + 1 : + self.llava_model.tokenizer_vocab_size + new_tokens.size(0) + 1 ] ) diff --git a/pyproject.toml b/pyproject.toml index e8626ac..9c43fd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "tqdm>=4.67.1", "transformers==4.46.0", "wandb>=0.19.0", + "pycocotools>=2.0.6", ] [tool.uv.sources] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..168889a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +norecursedirs = AlphaCLIP diff --git a/requirements.txt b/requirements.txt index a1ba790..cfbb3e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -108,6 +108,7 @@ sentry-sdk==2.19.2 setproctitle==1.3.4 setuptools==75.6.0 shapely==2.0.6 +pycocotools==2.0.10 shellingham==1.5.4 ; sys_platform != 'emscripten' six==1.16.0 smmap==5.0.1 diff --git a/train_trainer.py b/train_trainer.py new file mode 100644 index 0000000..a47dce8 --- /dev/null +++ b/train_trainer.py @@ -0,0 +1,71 @@ +import warnings +import random +import numpy as np +import torch +import wandb +from transformers import TrainingArguments +from configuration import load_yaml_config +from llava_finetune.utils import get_dataloaders, collate_fn +from llava_finetune.model import LISA_Model +from trainer_hf import LISATrainer + +warnings.filterwarnings("ignore") + +seed = 42 +torch.manual_seed(seed) +random.seed(seed) +np.random.seed(seed) + +config = load_yaml_config("config.yaml") +wandb.login(key=config.others.wandb_token) + +if __name__ == "__main__": + print("Loading Datasets") + train_loader, val_loader, test_loader = get_dataloaders( + config.dataset.json_path, + config.dataset.image_dir, + 2, + "data/train_final.jsonl", + "data/val_final.jsonl", + "data/test_final.jsonl", + ) + print("Datasets Loaded Successfully") + + train_dataset = train_loader.dataset + val_dataset = val_loader.dataset + test_dataset = test_loader.dataset + + training = config.training + + model = LISA_Model( + model_name=config.llava.model, + seg_emb_size=train_dataset[0]["gt_embs"].shape[1], + **training.model_params, + ).to("cuda") + + args = TrainingArguments( + output_dir="models/run", + per_device_train_batch_size=training.per_device_train_batch_size, + per_device_eval_batch_size=training.per_device_eval_batch_size, + num_train_epochs=training.epochs, + logging_steps=training.log_interval, + evaluation_strategy="steps", + eval_steps=training.val_every, + report_to=["wandb"], + learning_rate=training.optimizer.adapter_lr, + ) + + trainer = LISATrainer( + model=model, + args=args, + train_dataset=train_dataset, + 
eval_dataset=val_dataset, + data_collator=collate_fn, + lora_lr=training.optimizer.lora_lr, + eta_min=training.scheduler.eta_min, + ) + + trainer.train() + trainer.save_model("models/run") + metrics = trainer.evaluate(eval_dataset=test_dataset) + print(metrics) diff --git a/trainer_hf.py b/trainer_hf.py new file mode 100644 index 0000000..aa3608c --- /dev/null +++ b/trainer_hf.py @@ -0,0 +1,89 @@ +import torch +from torch.utils.data import DataLoader +from transformers import Trainer +from bitsandbytes.optim import AdamW8bit +from torch.optim.lr_scheduler import CosineAnnealingLR + +from llava_finetune.utils import collate_fn +from llava_finetune.functions import val_step + +class LISATrainer(Trainer): + def __init__(self, *args, lora_lr=1e-6, eta_min=1e-4, **kwargs): + self.lora_lr = lora_lr + self.eta_min = eta_min + super().__init__(*args, **kwargs) + self.current_new_tokens = None + + def compute_loss(self, model, inputs, return_outputs=False): + logits, loss, new_tokens = model.compute_loss( + texts=inputs["queries"], + images=inputs["image"], + labels=inputs["answer"], + pos_mask_embeds=inputs["gt_embs"], + neg_mask_embeds=inputs["sam_embs"], + ) + self.current_new_tokens = new_tokens + return (loss, {"logits": logits}) if return_outputs else loss + + def backward(self, loss): + super().backward(loss) + if self.current_new_tokens is not None: + emb_grads = ( + self.model.llava_model.llava_model.get_input_embeddings().weight.grad + ) + self.current_new_tokens.backward( + emb_grads[ + self.model.llava_model.tokenizer_vocab_size + 1 : self.model.llava_model.tokenizer_vocab_size + + self.current_new_tokens.size(0) + 1 + ] + ) + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **kwargs): + super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, closure, **kwargs) + self.model.llava_model.reset_tokens() + + def create_optimizer(self): + if self.optimizer is None: + adapter_params = [p for p in self.model.adapter.parameters()] + lora_params = [p for p in self.model.llava_model.llava_model.parameters() if p.requires_grad] + self.optimizer = AdamW8bit( + [ + {"params": adapter_params, "lr": self.args.learning_rate}, + {"params": lora_params, "lr": self.lora_lr}, + ] + ) + return self.optimizer + + def create_scheduler(self, num_training_steps, optimizer=None): + if self.lr_scheduler is None: + self.lr_scheduler = CosineAnnealingLR( + optimizer or self.optimizer, + T_max=self.args.num_train_epochs, + eta_min=self.eta_min, + ) + return self.lr_scheduler + + def evaluate(self, eval_dataset=None, **kwargs): + dataset = eval_dataset or self.eval_dataset + dataloader = DataLoader( + dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + collate_fn=collate_fn, + ) + ( + accuracy, + precision, + recall, + f1, + _, + _, + _, + _, + ) = val_step(self.model, dataloader, 0) + return { + "eval_accuracy": accuracy, + "eval_precision": precision, + "eval_recall": recall, + "eval_f1": f1, + }
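
Note on using the new configuration (a minimal sketch, not part of the patch): after this change, the training run is driven by the new `training:` block of config.yaml instead of hard-coded constants — `epochs`, `log_interval`, and `val_every` map onto `TrainingArguments(num_train_epochs, logging_steps, eval_steps)`, `optimizer.adapter_lr`/`optimizer.lora_lr` feed the two AdamW8bit parameter groups in `LISATrainer.create_optimizer`, and `model_params` is unpacked directly into `LISA_Model(**model_params)`. The snippet below assumes `load_yaml_config` recursively instantiates the nested `TrainingConfig`/`OptimizerConfig`/`SchedulerConfig` dataclasses (the function body is not shown in this diff) and that a filled-in config.yaml, copied from config.example.yaml, sits next to the scripts.

    # Illustrative sketch only: inspect the new training section of the config.
    # Assumes load_yaml_config builds the nested dataclasses from the YAML mapping.
    from configuration import load_yaml_config

    cfg = load_yaml_config("config.yaml")
    print(cfg.training.epochs)                # e.g. 50 in config.example.yaml
    print(cfg.training.optimizer.adapter_lr)  # adapter learning rate, e.g. 0.001
    print(cfg.training.model_params)          # dict forwarded as LISA_Model(**model_params)

    # The HF-Trainer-based entry point added by this patch reads the same file:
    #   python train_trainer.py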