From 738111aafedccf6e6dfabd45358d344a3dff9eda Mon Sep 17 00:00:00 2001
From: ankitade
Date: Tue, 21 Jun 2022 04:32:17 +0000
Subject: [PATCH] Temp CL

[ghstack-poisoned]
---
 torchmultimodal/models/flava/flava_model.py | 76 +++++++++++++++++----
 torchmultimodal/modules/losses/flava.py     | 28 +++++---
 2 files changed, 80 insertions(+), 24 deletions(-)

diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py
index 8a2af13d..88dbdb2c 100644
--- a/torchmultimodal/models/flava/flava_model.py
+++ b/torchmultimodal/models/flava/flava_model.py
@@ -37,8 +37,17 @@
 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
@@ -246,6 +262,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -254,6 +272,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
@@ -272,18 +292,30 @@ def forward(
         else:
             required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
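The image and text branches above now yield either a bare encoder output or an
(output, projected embedding) pair, disambiguated with len(). Below is a minimal
runnable sketch of that convention; EncoderOutput, image_projection, and the
free-standing encode_image are hypothetical stand-ins, not the real FLAVA
classes. Two things the patch relies on: FLAVATransformerOutput has more than
two fields, so len(...) == 2 can only mean the tuple case, and functools.partial
must already be imported in flava_model.py (no import hunk appears here). Note
also that the len() check assumes a non-None encoder output: when
_encode_data_to_embeddings returns None for a missing modality, len(None) would
raise.

    from functools import partial
    from typing import NamedTuple, Optional

    import torch
    from torch import Tensor, nn

    class EncoderOutput(NamedTuple):
        # Stand-in for FLAVATransformerOutput; four fields, so len(out) == 4
        # for the bare-output case and 2 only for the (output, projection) pair.
        last_hidden_state: Tensor
        pooler_output: Tensor
        hidden_states: Optional[Tensor] = None
        attentions: Optional[Tensor] = None

    image_projection = nn.Linear(16, 768)  # plays the role of self.image_projection

    def encode_image(image: Tensor, projection: bool = False):
        # Fake encoder: pretend `image` already went through the transformer.
        out = EncoderOutput(last_hidden_state=image, pooler_output=image[:, 0, :])
        if projection:
            # Project the CLS token, as encode_image does with last_hidden_state[:, 0, :].
            return out, image_projection(out.last_hidden_state[:, 0, :])
        return out

    encode = partial(encode_image, projection=True)
    encoding_out = encode(torch.randn(2, 5, 16))
    if len(encoding_out) == 2:  # (encoder output, projected CLS embedding)
        image_outputs, projected_image_embeddings = encoding_out
    else:
        image_outputs, projected_image_embeddings = encoding_out, None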
@@ -329,26 +361,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )
 
     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
    ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
     def _encode_data_to_embeddings(
         self,
@@ -361,7 +408,6 @@ def _encode_data_to_embeddings(
         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output
 
     def encode_mm(
@@ -450,6 +496,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )

diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py
index d40e9afa..b0f8c6ee 100644
--- a/torchmultimodal/modules/losses/flava.py
+++ b/torchmultimodal/modules/losses/flava.py
@@ -249,10 +249,10 @@ def __init__(
         else:
             self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))
 
-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
+        # self.image_projection = nn.Linear(image_embedding_size, projection_size)
+        # self.text_projection = nn.Linear(text_embedding_size, projection_size)
+        # self.image_embedding_index = image_embedding_index
+        # self.text_embedding_index = text_embedding_index
 
     def forward(
         self,
@@ -260,11 +260,17 @@ def forward(
         image_sequence: Tensor,
         text_sequence: Tensor,
         mask: Tensor,
     ):
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+        # text_embedding = nn.functional.normalize(
+        #     self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
+        # )
+        # image_embedding = nn.functional.normalize(
+        #     self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+        #     dim=-1,
+        # )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
         image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
             dim=-1,
         )
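With the projection layers moved into FLAVAModel, the contrastive forward above
only normalizes what it is handed. A condensed, single-process sketch of the
remaining math follows, with made-up tensors; the real module additionally
all-gathers embeddings across workers (backprop_in_gather=True) before forming
the logits, so this is the per-worker view only.

    import torch
    import torch.nn.functional as F

    # Projections now arrive pre-computed from FLAVAModel; the loss just
    # normalizes and takes temperature-scaled dot products in both directions.
    projected_image_embeddings = torch.randn(8, 768)
    projected_text_embeddings = torch.randn(8, 768)
    logit_scale = torch.ones([]).exp()  # stand-in for the learned temperature

    image_embedding = F.normalize(projected_image_embeddings, dim=-1)
    text_embedding = F.normalize(projected_text_embeddings, dim=-1)

    logits_per_image = logit_scale * image_embedding @ text_embedding.t()
    targets = torch.arange(8)  # matched image-text pairs sit on the diagonal
    loss = (
        F.cross_entropy(logits_per_image, targets)
        + F.cross_entropy(logits_per_image.t(), targets)
    ) / 2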
@@ -376,6 +382,8 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings: Optional[Tensor] = None,
+        projected_text_embeddings: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -386,8 +394,8 @@ def forward(
             and self.contrastive_loss_weight > 0
         ):
             outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                 pos_mask,
             )
             outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
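Net effect of the new signature: the pretraining loss consumes embeddings that
FLAVAModel already projected, and the global contrastive term fires only for
paired image-text batches. A reduced sketch of that gating; ToyPretrainingLoss
is hypothetical, not the real FLAVAPretrainingLoss, and it gates on the
projected embeddings directly where the patch checks the raw sequences.

    from typing import Optional

    import torch
    from torch import Tensor, nn

    class ToyPretrainingLoss(nn.Module):
        # Reduced stand-in: no projection layers live here anymore; the
        # contrastive term is computed only when both modalities are present.
        def __init__(self, contrastive_loss_weight: float = 1.0) -> None:
            super().__init__()
            self.contrastive_loss_weight = contrastive_loss_weight

        def forward(
            self,
            projected_image_embeddings: Optional[Tensor] = None,
            projected_text_embeddings: Optional[Tensor] = None,
        ) -> Optional[Tensor]:
            if (
                projected_image_embeddings is None
                or projected_text_embeddings is None
                or self.contrastive_loss_weight == 0
            ):
                return None  # unimodal batch: skip the global contrastive term
            image = nn.functional.normalize(projected_image_embeddings, dim=-1)
            text = nn.functional.normalize(projected_text_embeddings, dim=-1)
            logits = image @ text.t()
            targets = torch.arange(logits.size(0))
            loss = nn.functional.cross_entropy(logits, targets)
            return self.contrastive_loss_weight * loss

    loss_fn = ToyPretrainingLoss()
    paired = loss_fn(torch.randn(4, 8), torch.randn(4, 8))  # -> scalar tensor
    text_only = loss_fn(None, torch.randn(4, 8))            # -> None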