Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,32 @@ alphaclip:
checkpoint_dir: "..."

others:
wandb_token: "TOKEN"
wandb_token: "TOKEN"

training:
epochs: 50
optimizer:
adapter_lr: 0.001
lora_lr: 0.000001
scheduler:
eta_min: 0.0001
preprocess_params:
model: "alpha-clip"
only_masks: false
model_params:
lora_rank: 16
q4: true
q8: false
expand_factor: 2
mlp_adapter: true
mask_prob: 0.2
dropout: 0.2
noise_level: 0.000001
pos_weight: 2
neg_weight: 1
temperature: 0.1
end_turn_token: "<end_of_turn>\n"
seg_pos: "randomized"
text: true
log_interval: 10
val_every: 50
27 changes: 26 additions & 1 deletion configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,42 @@ def __post_init__(self):
@dataclass
class OthersConfig:
    """Miscellaneous settings (currently only the Weights & Biases token)."""

    # W&B API token; expanduser is applied, so it may also be supplied
    # as a "~/..." path — TODO confirm whether a file path is intended.
    wandb_token: str

    def __post_init__(self) -> None:
        # Normalize a possible "~" prefix into the user's home directory.
        expanded = os.path.expanduser(self.wandb_token)
        self.wandb_token = expanded


@dataclass
class OptimizerConfig:
    """Learning rates for the two optimizer parameter groups.

    Mirrors the ``training.optimizer`` section of the YAML config.
    """

    adapter_lr: float  # learning rate for the adapter parameters
    lora_lr: float  # learning rate for the LoRA parameters


@dataclass
class SchedulerConfig:
    """Learning-rate scheduler settings (``training.scheduler`` in the YAML).

    NOTE(review): ``eta_min`` suggests a cosine-annealing schedule — confirm
    against the trainer that constructs the scheduler.
    """

    eta_min: float  # floor value the learning rate decays toward


@dataclass
class TrainingConfig:
    """Top-level training-loop configuration (``training`` section of the YAML)."""

    epochs: int  # total number of training epochs
    optimizer: OptimizerConfig  # per-group learning rates
    scheduler: SchedulerConfig  # scheduler hyperparameters
    # Free-form dicts passed through from YAML; see config.example.yaml for
    # the expected keys (e.g. model, only_masks / lora_rank, temperature, ...).
    preprocess_params: dict
    model_params: dict
    log_interval: int = 10  # presumably "log every N steps" — confirm in trainer
    val_every: int = 50  # presumably "validate every N steps" — confirm in trainer
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2

@dataclass
class ProjectConfig:
    """Root configuration object: one field per top-level YAML section."""

    llava: LLavaConfig
    dataset: DatasetConfig
    sam: SAMConfig
    alphaclip: AlphaCLIPConfig
    others: OthersConfig
    training: TrainingConfig


def load_yaml_config(path) -> ProjectConfig:
Expand Down
124 changes: 56 additions & 68 deletions llava_finetune/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,35 +426,21 @@ def __init__(

self.to(device)

def optim_step(
def compute_loss(
self,
texts: list[str],
images: list[Image.Image],
labels: list[str],
pos_mask_embeds: list[torch.Tensor],
neg_mask_embeds: list[torch.Tensor],
optimizer: torch.optim.Optimizer,
):
"""
Forward pass + optimization step through the model

Args:
texts (list[str]): List of input text sequences I only the text
images (list[Image.Image]): List of input images
labels (list[str]): List of target text sequences for the model to predict I expect only the text that the model should predict
pos_mask_embeds (list[torch.Tensor]): List of positive mask embeddings with shape (num_pos_masks, seg_emb_size)
neg_mask_embeds (list[torch.Tensor]): List of negative mask embeddings with shape (num_neg_masks, seg_emb_size)
optimizer (torch.optim.Optimizer): Optimizer object to update the model's parameters

Returns:
Tuple[torch.Tensor]: Logits for the next tokens computed for the last num_generate tokens in the input sequence and the loss
"""
"""Compute logits and loss without performing an optimizer step."""
input_texts = []
free_token = 1
new_tokens = []

seg_pos = self.seg_pos

possible_texts = [
"The segmentation mask for the object in the image is",
"The requested object cab be found in",
Expand All @@ -473,22 +459,22 @@ def optim_step(
[pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))]
+ [neg_mask_embeds[i].size(0) for i in range(len(neg_mask_embeds))]
)
num_pos_tokens = sum([pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))])
num_neg_tokens = sum([neg_mask_embeds[i].size(0) for i in range(len(neg_mask_embeds))])
num_pos_tokens = sum(
[pos_mask_embeds[i].size(0) for i in range(len(pos_mask_embeds))]
)

token_masks = torch.zeros(
len(texts), self.llava_model.original_vocab_size + num_new_tokens
).to(self.device)
token_masks[:, : self.llava_model.tokenizer_vocab_size + 1] = 1
token_masks[:, self.llava_model.tokenizer_vocab_size + num_new_tokens + 1 :] = 1

# Pass all tokens to the adapter and add the corresponding token lemma to the labels texts
for i in range(len(pos_mask_embeds)):
if self.seg_pos == "randomized":
seg_pos = "before" if torch.rand(1) < 0.5 else "after"

if seg_pos == "after" and self.text:
labels[i] = labels[i] + " "+possible_texts[torch.randint(0, len(possible_texts), (1,)).item()]
labels[i] = labels[i] + " " + possible_texts[torch.randint(0, len(possible_texts), (1,)).item()]
elif seg_pos == "before" and self.text:
labels[i] = ". " + labels[i]
for j in range(pos_mask_embeds[i].size(0)):
Expand All @@ -501,11 +487,10 @@ def optim_step(
free_token += 1

if seg_pos == "before" and self.text:
labels[i] = possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] + labels[i]
labels[i] = possible_texts[torch.randint(0, len(possible_texts), (1,)).item()] + labels[i]
elif seg_pos == "after" and self.text:
labels[i] = labels[i] + "."

# apply the chat template to the texts
possibile_output_texts = [
"Output the segmentation mask for the object in the image",
"Output the mask for the object in the image",
Expand All @@ -520,15 +505,15 @@ def optim_step(
"Provide the segmentation mask",
"Provide the mask",
]

if self.seg_pos == "randomized":
seg_pos = "before" if torch.rand(1) < 0.5 else "after"

if seg_pos == "after":
user_message = f"<image>\n{texts[i]} {possibile_output_texts[torch.randint(0, len(possibile_output_texts), (1,)).item()]}."
elif seg_pos == "before":
user_message = f"<image>\n{possibile_output_texts[torch.randint(0, len(possibile_output_texts), (1,)).item()]}. {texts[i]}"

input_texts.append(
self.llava_model.processor.tokenizer.apply_chat_template(
[
Expand All @@ -539,16 +524,8 @@ def optim_step(
add_generation_prompt=False,
)
)
# get the last token of the text as the end token to be added to the labels
labels[i] += f"{self.end_token}"

if DEBUG_PRINTS:
print()
print("MODEL INPUTS")
print(f"\tInput texts: {input_texts[0]}")
print(f"\tLabels: {labels[0]}")
print()

for i in range(len(neg_mask_embeds)):
for j in range(neg_mask_embeds[i].size(0)):
new_tokens.append(neg_mask_embeds[i][j])
Expand All @@ -558,23 +535,17 @@ def optim_step(
new_tokens = torch.stack(new_tokens)
new_tokens = self.adapter(new_tokens.to(self.device))


# tokenize the texts
inputs = self.llava_model.processor(
text=input_texts,
images=images,
return_tensors="pt",
padding=True,
add_special_tokens=False,
).to(self.device)
# remove last token from the input text
inputs["input_ids"] = inputs["input_ids"][:, :-1]

# drop random tokens and substitute them with the unk token

mask = torch.rand(inputs["input_ids"].size()) < self.mask_tok_prob
# keep image tokens
mask[inputs["input_ids"] == self.llava_model.processor.tokenizer.encode("<image>", add_special_tokens=False)[0]] = False
# mask the tokens
inputs["input_ids"][mask] = self.llava_model.processor.tokenizer.unk_token_id

labels_ids = self.llava_model.processor(
Expand All @@ -586,13 +557,6 @@ def optim_step(
)
labels_input_ids = labels_ids["input_ids"].to(self.device)

# print()
# print("MODEL TOKENS INPUT")
# print(f"\tInput tokens: {inputs['input_ids']}")
# print(f"\tLabels: {labels_input_ids}")
# print()

# Forward pass through the model
logits = self.llava_model(
**inputs,
additional_tokens=new_tokens,
Expand All @@ -607,45 +571,69 @@ def optim_step(

class_weights = torch.ones(logits.size(-1)).to(self.device)
class_weights[
self.llava_model.tokenizer_vocab_size + 1 :
self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1
self.llava_model.tokenizer_vocab_size + 1 : self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1
] = self.pos_weight
class_weights[
self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 :
self.llava_model.tokenizer_vocab_size + num_new_tokens + 1
self.llava_model.tokenizer_vocab_size + num_pos_tokens + 1 : self.llava_model.tokenizer_vocab_size + num_new_tokens + 1
] = self.neg_weight

loss_fn = nn.CrossEntropyLoss(
ignore_index=self.llava_model.processor.tokenizer.pad_token_id,
reduction="none",
weight=class_weights,
)
# where the labels_input_ids are one of the new tokens, the loss should be multiplied by 2
loss = loss_fn(logits, labels_input_ids)

weights = torch.ones_like(labels_input_ids)
# weights[labels_input_ids > self.llava_model.tokenizer_vocab_size] = 1.5
# # penilize the model for not generating the end tokens
# weights[labels_input_ids == self.tokenized_end_token] = 1.5
loss = (loss * weights).mean()

# if loss is nan break

if torch.isnan(loss):
print("NAN LOSS")
return None, None, None

return logits, loss, new_tokens

def optim_step(
self,
texts: list[str],
images: list[Image.Image],
labels: list[str],
pos_mask_embeds: list[torch.Tensor],
neg_mask_embeds: list[torch.Tensor],
optimizer: torch.optim.Optimizer,
):
"""
Forward pass + optimization step through the model

Args:
texts (list[str]): List of input text sequences I only the text
images (list[Image.Image]): List of input images
labels (list[str]): List of target text sequences for the model to predict I expect only the text that the model should predict
pos_mask_embeds (list[torch.Tensor]): List of positive mask embeddings with shape (num_pos_masks, seg_emb_size)
neg_mask_embeds (list[torch.Tensor]): List of negative mask embeddings with shape (num_neg_masks, seg_emb_size)
optimizer (torch.optim.Optimizer): Optimizer object to update the model's parameters

Returns:
Tuple[torch.Tensor]: Logits for the next tokens computed for the last num_generate tokens in the input sequence and the loss
"""
logits, loss, new_tokens = self.compute_loss(
texts,
images,
labels,
pos_mask_embeds,
neg_mask_embeds,
)

if loss is None:
return None, None

# backward pass
optimizer.zero_grad()

loss.backward()

# copy the gradients to the mask embeddings
emb_grads = self.llava_model.llava_model.get_input_embeddings().weight.grad
new_tokens.backward(
emb_grads[
self.llava_model.tokenizer_vocab_size + 1 :
self.llava_model.tokenizer_vocab_size
+ num_new_tokens
+ 1
self.llava_model.tokenizer_vocab_size + 1 :
self.llava_model.tokenizer_vocab_size + new_tokens.size(0) + 1
]
)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"tqdm>=4.67.1",
"transformers==4.46.0",
"wandb>=0.19.0",
"pycocotools>=2.0.6",
]

[tool.uv.sources]
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
norecursedirs = AlphaCLIP
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ sentry-sdk==2.19.2
setproctitle==1.3.4
setuptools==75.6.0
shapely==2.0.6
pycocotools==2.0.10
shellingham==1.5.4 ; sys_platform != 'emscripten'
six==1.16.0
smmap==5.0.1
Expand Down
Loading