From b70fe33528cd3cb12cc4aef8618572b657d3aaa0 Mon Sep 17 00:00:00 2001 From: sablin39 <1020030829@qq.com> Date: Sun, 23 Mar 2025 20:30:09 +0800 Subject: [PATCH 1/7] Initial commit after copy --- framefusion/interface.py | 145 + framefusion/main.py | 314 ++ .../modeling_llava_next_video.py | 236 ++ .../llava_video/modeling_llava_video.py | 339 +++ .../models/minicpmv/modeling_minicpmv.py | 109 + framefusion/models/qwen2/modeling_qwen2.py | 333 +++ .../models/qwen2/modeling_qwen2_baseline.py | 2562 +++++++++++++++++ framefusion/utils.py | 101 + 8 files changed, 4139 insertions(+) create mode 100644 framefusion/interface.py create mode 100644 framefusion/main.py create mode 100644 framefusion/models/llava_next_video/modeling_llava_next_video.py create mode 100644 framefusion/models/llava_video/modeling_llava_video.py create mode 100644 framefusion/models/minicpmv/modeling_minicpmv.py create mode 100644 framefusion/models/qwen2/modeling_qwen2.py create mode 100644 framefusion/models/qwen2/modeling_qwen2_baseline.py create mode 100644 framefusion/utils.py diff --git a/framefusion/interface.py b/framefusion/interface.py new file mode 100644 index 0000000..a39af80 --- /dev/null +++ b/framefusion/interface.py @@ -0,0 +1,145 @@ +# common imports +from types import MethodType +from typing import Callable +import torch +import torch.nn as nn +from accelerate.hooks import add_hook_to_module +from transformers import PreTrainedModel + +# framefusion methods +from framefusion.main import FrameFusion +from framefusion.utils import TEXT_TOKEN, IGNORE_TOKEN, get_attr_by_name + +# model types +from transformers import LlavaNextVideoForConditionalGeneration +from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM + +# replace methods +from framefusion.models.llava_next_video.modeling_llava_next_video import _merge_input_ids_with_image_features_get_token_type +from framefusion.models.llava_video.modeling_llava_video import 
prepare_inputs_labels_for_multimodal_get_patch_type +from framefusion.models.minicpmv.modeling_minicpmv import get_vllm_embedding +from framefusion.models.qwen2.modeling_qwen2 import Qwen2Model_merge_then_fastv_cost_given_forward, Qwen2DecoderLayer_merge_then_prune_by_cost_forward, Qwen2SdpaAttention_merge_then_prune_by_cost_forward + + +def apply_framefusion(model, cost, similarity_lower_bound, ratio_lower_bound): + """ + Apply FrameFusion to the model + + Args: + model: the model to apply FrameFusion to + cost: the cost of the FrameFusion + similarity_lower_bound: the similarity lower bound of the FrameFusion + ratio_lower_bound: the ratio lower bound of the FrameFusion + """ + # LlavaNextVideo Model + if isinstance(model, LlavaNextVideoForConditionalGeneration): + model._merge_input_ids_with_image_features = MethodType(_merge_input_ids_with_image_features_get_token_type, model) + + llm_forward = Qwen2Model_merge_then_fastv_cost_given_forward + decoder_forward = Qwen2DecoderLayer_merge_then_prune_by_cost_forward + attention_forward = Qwen2SdpaAttention_merge_then_prune_by_cost_forward + llm_key = "model" + decoder_key = "layers" + attention_key = "self_attn" + + # LlavaVideo Model + elif isinstance(model, LlavaQwenForCausalLM): + model.prepare_inputs_labels_for_multimodal = MethodType(prepare_inputs_labels_for_multimodal_get_patch_type, model) + + llm_forward = Qwen2Model_merge_then_fastv_cost_given_forward + decoder_forward = Qwen2DecoderLayer_merge_then_prune_by_cost_forward + attention_forward = Qwen2SdpaAttention_merge_then_prune_by_cost_forward + llm_key = "model" + decoder_key = "layers" + attention_key = "self_attn" + + # MiniCPM Model + elif model.config.architectures[0] == "MiniCPMV": + + model.get_vllm_embedding = MethodType(get_vllm_embedding, model) + llm_forward = Qwen2Model_merge_then_fastv_cost_given_forward + decoder_forward = Qwen2DecoderLayer_merge_then_prune_by_cost_forward + attention_forward = 
Qwen2SdpaAttention_merge_then_prune_by_cost_forward + llm_key = "llm.model" + decoder_key = "layers" + attention_key = "self_attn" + + else: + raise NotImplementedError + + replace_framefusion_forward( + model, + cost=cost, + similarity_lower_bound=similarity_lower_bound, + ratio_lower_bound=ratio_lower_bound, + llm_forward=llm_forward, + decoder_forward=decoder_forward, + attention_forward=attention_forward, + llm_key=llm_key, + decoder_key=decoder_key, + attention_key=attention_key, + ) + + +def get_token_type(model): + # LlavaNextVideo Model + if isinstance(model, LlavaNextVideoForConditionalGeneration): + model._merge_input_ids_with_image_features = MethodType(_merge_input_ids_with_image_features_get_token_type, model) + + # LlavaVideo Model + elif isinstance(model, LlavaQwenForCausalLM): + model.prepare_inputs_labels_for_multimodal = MethodType(prepare_inputs_labels_for_multimodal_get_patch_type, model) + + # MiniCPM Model + elif model.config.architectures[0] == "MiniCPMV": + model.get_vllm_embedding = MethodType(get_vllm_embedding, model) + else: + raise NotImplementedError + + +def replace_framefusion_forward( + module: torch.nn.Module, + cost: float, + similarity_lower_bound: float, + ratio_lower_bound: float, + llm_forward: Callable, + decoder_forward: Callable, + attention_forward: Callable, + llm_key: str = "model", + decoder_key: str = "layers", + attention_key: str = "self_attn", +): + """ + Replace the forward method of the model with the framefusion forward method. + Make framefusion a property of the model. + + The keys are accessed in an hierarchical manner: llm_key -> decoder_key -> attention_key. Each key can have multiple hierarchies, e.g. 
"llm.model", which will be accessed by module.llm.model + """ + framefusion = FrameFusion(cost, similarity_lower_bound, ratio_lower_bound) + + module.framefusion = framefusion + + llm = get_attr_by_name(module, llm_key) + assert isinstance(llm, PreTrainedModel), f"{llm_key} is not a PreTrainedModel" + + llm.framefusion = framefusion + llm.forward = MethodType(llm_forward, llm) + + decoder_layers = get_attr_by_name(llm, decoder_key) + for i, decoder_layer in enumerate(decoder_layers): + assert isinstance(decoder_layer, nn.Module), f"{decoder_key}[{i}] is not a nn.Module" + + decoder_layer.framefusion = framefusion + decoder_layer.forward = MethodType(decoder_forward, decoder_layer) + + # ensure accelerate hooks are not removed + if hasattr(decoder_layer, "_hf_hook"): + decoder_layer._old_forward = MethodType(decoder_forward, decoder_layer) + add_hook_to_module(decoder_layer, decoder_layer._hf_hook) + + qwen2_attention_instance = get_attr_by_name(decoder_layer, attention_key) + assert isinstance(qwen2_attention_instance, nn.Module), f"{decoder_key}[{i}].self_attn is not a nn.Module" + + # replace the forward method of the attention layer + qwen2_attention_instance.framefusion = framefusion + qwen2_attention_instance.forward = MethodType(attention_forward, qwen2_attention_instance) diff --git a/framefusion/main.py b/framefusion/main.py new file mode 100644 index 0000000..2e4b520 --- /dev/null +++ b/framefusion/main.py @@ -0,0 +1,314 @@ +from typing import List +import torch +from torch import nn + +TEXT_TOKEN = -1 +IGNORE_TOKEN = -2 + +class FrameFusion(nn.Module): + def __init__(self, cost=0.3, similarity_lower_bound=0.6, ratio_lower_bound=0.1): + super(FrameFusion, self).__init__() + self.cost = cost + self.similarity_lower_bound = similarity_lower_bound + self.ratio_lower_bound = ratio_lower_bound + + def prepare(self, patch_type, patch_num, image_token_start_index, image_token_end_index, image_token_length, original_length, finish_merging = False, finish_pruning = 
False, sparsity_list: List = None): + self.patch_type = patch_type + self.patch_num = patch_num + self.image_token_start_index = image_token_start_index + self.image_token_end_index = image_token_end_index + self.image_token_length = image_token_length + self.original_length = original_length + self.finish_merging = finish_merging + self.finish_pruning = finish_pruning + if sparsity_list is None: + self.sparsity_list = [] + else: + self.sparsity_list = sparsity_list + + def forward(self, hidden_states, position_embeddings, attention_mask, self_attn_weights = None): + """ + This is the forward method of the FrameFusion class. + + Args: + hidden_states (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size). + position_embeddings (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size). + attention_mask (torch.Tensor): A tensor of shape (batch_size, sequence_length, sequence_length). + self_attn_weights (torch.Tensor): A tensor of shape (batch_size, sequence_length, sequence_length). + + Returns: + hidden_states (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size). + position_embeddings (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size). + attention_mask (torch.Tensor): A tensor of shape (batch_size, sequence_length, sequence_length). 
+ """ + bsz, q_len, hidden_size = hidden_states.size() + device = hidden_states.device + + # pruning + if q_len >1 and self.finish_merging == True and self.finish_pruning == False: + + image_token_pruning_start_index = self.image_token_start_index.item() + image_token_pruning_length = self.image_token_length + # update image_token_pruning_length + image_token_pruning_length = (self.image_token_length - (self.original_length - q_len)) + + last_layer_attention = self_attn_weights + last_layer_attention_avg = torch.mean(last_layer_attention, dim=(1,2))[0] + last_layer_attention_avg_image = last_layer_attention_avg[image_token_pruning_start_index:image_token_pruning_start_index+image_token_pruning_length] + + pruning_ratio = self._compute_pruning_ratio(self.sparsity_list, self.cost) + top_attention_rank_index = last_layer_attention_avg_image.topk(round(image_token_pruning_length*(1-pruning_ratio))).indices + image_token_pruning_start_index + + keep_indexs = torch.cat( (torch.arange(image_token_pruning_start_index,device=device), top_attention_rank_index, torch.arange(image_token_pruning_start_index+image_token_pruning_length, q_len, device=device))) + keep_indexs = keep_indexs.sort().values + + hidden_states = hidden_states[:,keep_indexs,:] + position_embeddings[0] = position_embeddings[0][:,keep_indexs,:] + position_embeddings[1] = position_embeddings[1][:,keep_indexs,:] + if attention_mask != None: + attention_mask = attention_mask[:,:,keep_indexs,:][:,:,:,keep_indexs] + self.finish_pruning = True + + # merging + if q_len >1 and (not self.finish_merging): + # align devices + self.patch_type = self.patch_type.to(device) + + # prefill + sparsity_upper_bound = self._compute_pruning_ratio(self.sparsity_list, self.cost) + similarity_by_patch, token_index_by_patch = self.compute_similarity_and_token_index_by_patch(hidden_states, self.patch_type, self.patch_num) # only support bsz = 1 + + frame_token_num = torch.sum(self.patch_type != TEXT_TOKEN).item() + 
merge_index_by_patch = torch.where(similarity_by_patch >= self.similarity_lower_bound)[1] + above_k_ratio = merge_index_by_patch.shape[0] / frame_token_num + + if above_k_ratio < sparsity_upper_bound: + self.sparsity_list.append(above_k_ratio) + + if above_k_ratio < self.ratio_lower_bound: + self.finish_merging = True + else: + topk_values, topk_indices = torch.topk(similarity_by_patch, int(sparsity_upper_bound*frame_token_num)) + topk_indices, _ = torch.sort(topk_indices) + merge_index_by_patch = topk_indices[0] + + self.finish_merging = True + self.finish_pruning = True + + + hidden_states, token_mask = self.merge_tokens_and_get_mask(hidden_states, similarity_by_patch, token_index_by_patch, merge_index_by_patch) + # here only bsz=1 + # update patch type + self.patch_type = self.patch_type.to(device)[token_mask].reshape(bsz, -1) + hidden_states = hidden_states[token_mask, :].reshape(bsz, -1, hidden_size) + position_embeddings[0] = position_embeddings[0][:,token_mask[0],:] + position_embeddings[1] = position_embeddings[1][:,token_mask[0],:] + if attention_mask is not None: + attention_mask = attention_mask[:,:,token_mask[0],:][:,:,:,token_mask[0]] + + return hidden_states, position_embeddings, attention_mask + + @staticmethod + def compute_similarity_and_token_index_by_patch(hidden_states, token_patch_type, patch_num): + """ + Compute the similarity between consecutive tokens of the same patch type and record the token index. + + Args: + hidden_states (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size). + token_patch_type (torch.Tensor): A tensor indicating the patch type of each token in the sequence. + patch_num (int): The total number of patches of one image in the model. + + Returns: + similarity_by_patch (torch.Tensor): A tensor of shape (batch_size, sequence_length) containing + the cosine similarity between consecutive tokens of the + same patch type. Tokens from different patches are set to -2. 
+ token_index_by_patch (torch.Tensor): A tensor of shape (batch_size, sequence_length) containing + the token index corresponding to the new order after + sorting by patch type. + + """ + + bsz, q_len, _ = hidden_states.size() + device = hidden_states.device + + assert bsz == 1, "Only support batch size 1" + + token_index_by_patch = [] + similarity_by_patch = [] + + + token_patch_type_by_patch, token_index_by_patch = torch.where( + token_patch_type == torch.arange(patch_num, device=device)[:, None] + ) + + # noqa: reshape to batch size = 1, with shape (batch_size, q_len), + token_patch_type_by_patch = token_patch_type_by_patch[None, :] + token_index_by_patch = token_index_by_patch[None, :] + + similarity_by_patch = cosine_similarity( + hidden_states[ + torch.arange(bsz, device=device), token_index_by_patch[:, :-1], : + ], + hidden_states[ + torch.arange(bsz, device=device), token_index_by_patch[:, 1:], : + ], + ) + + similarity_by_patch[token_patch_type_by_patch[:, :-1] != token_patch_type_by_patch[:, 1:]] = -2 + + similarity_by_patch = torch.cat( + ( + torch.full( + size=(bsz, 1), + fill_value=IGNORE_TOKEN, + dtype=hidden_states.dtype, + device=device, + ), + similarity_by_patch, + ), + dim=1, + ) + + assert similarity_by_patch.shape[1] == token_index_by_patch.shape[1] + return similarity_by_patch, token_index_by_patch + + + @staticmethod + def merge_tokens_and_get_mask(hidden_states: torch.Tensor, similarity_by_patch, token_index_by_patch, merge_index_by_patch): + """ + Merge tokens and get a mask indicating which tokens to keep. + + Args: + hidden_states (torch.Tensor): A tensor of shape (batch_size, sequence_length, hidden_size) + similarity_by_patch (torch.Tensor): A tensor of shape (batch_size, sequence_length) containing + the cosine similarity between consecutive tokens of the + same patch type. 
+ token_index_by_patch (torch.Tensor): A tensor of shape (batch_size, sequence_length) containing + the token indices corresponding to the new order after + sorting by patch type. + merge_index_by_patch (torch.Tensor): A tensor containing the indices of tokens to be merged, in the patch_type order. + + Returns: + hidden_states (torch.Tensor): A tensor containing the hidden states of the tokens after merging. + keep_mask (torch.Tensor): A boolean tensor of shape (batch_size, sequence_length) indicating + which tokens in the original sequence should be kept after merging. + """ + device = hidden_states.device + if merge_index_by_patch.shape[0] == 0: + keep_mask = torch.ones(hidden_states.shape[:-1], dtype=torch.bool, device=device) + return hidden_states, keep_mask + bsz, q_len, _ = hidden_states.size() + bsz_index = torch.arange(bsz, device=hidden_states.device)[:, None] + merge_mask_by_patch: torch.LongTensor = torch.zeros( + bsz, + similarity_by_patch.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + merge_mask_by_patch[bsz_index, merge_index_by_patch] = 1 + last_merge_token_by_patch = find_contigious_latter_index(merge_mask_by_patch) + + keep_mask = torch.ones(hidden_states.shape[:-1], dtype=torch.bool, device=device) + keep_mask[bsz_index, token_index_by_patch[bsz_index, merge_index_by_patch]] = False + + # noqa: batch size = 1 + unique_merge_nums = torch.sort(torch.unique(last_merge_token_by_patch.to(torch.long))).values + unique_merge_nums = (unique_merge_nums[1:] if (unique_merge_nums[0] == 0).item() else unique_merge_nums) + + merge_num_indices, token_merge_index_in_patch = torch.where( + last_merge_token_by_patch == unique_merge_nums[:, None] + ) + + merge_nums = unique_merge_nums[merge_num_indices] + token_merge_start_index_in_patch = token_merge_index_in_patch - merge_nums + token_merge_member_start_index_in_patch = torch.repeat_interleave(token_merge_start_index_in_patch, merge_nums) + + merge_member_length = 
torch.sum(merge_nums) + merge_member_contigious_sequence = torch.arange(1, merge_member_length + 1, device = device) + + merge_nums_cumulative_counts = torch.cumsum(merge_nums, dim=0) + merge_nums_start = torch.cat((torch.tensor([0], device = device), merge_nums_cumulative_counts[:-1])) + + contigious_sequence_by_merge_nums = merge_member_contigious_sequence - torch.repeat_interleave(merge_nums_start, merge_nums) + + token_merge_member_index_in_patch = token_merge_member_start_index_in_patch + contigious_sequence_by_merge_nums + + # noqa: this function may have numerical instability + hidden_states.index_add_( + dim = 1, + index = token_index_by_patch[0, token_merge_member_start_index_in_patch], + source = hidden_states[ + bsz_index, + token_index_by_patch[bsz_index, token_merge_member_index_in_patch], + ] + ) + + # divide to get average + hidden_states[ + bsz_index, + token_index_by_patch[bsz_index, token_merge_start_index_in_patch], + ] /= (merge_nums[None, :, None] + 1) + + + return hidden_states, keep_mask + + @staticmethod + def _compute_pruning_ratio(sparsity_list, cost, num_layers = 28): + """ + Args: + sparsity_list (list): A list containing the sparsity values of the model's first few layers. + cost (float): The total computation budget given by the user. + num_layers (int, optional): The number of layers in the model. 
+ + Returns: + float: the required sparsity for the next layer to achieve the given cost + """ + list_length = len(sparsity_list) + s = 1 + total_calcution =0 + for i in range(list_length): + s *= (1 - sparsity_list[i]) + total_calcution += s + remain_calcution = num_layers * cost - total_calcution + if remain_calcution < 0: + raise ValueError("The cost is too small") + if remain_calcution/((num_layers-list_length)*s) > 1: + return 0 + return 1 - (remain_calcution/((num_layers-list_length)*s)) + +def cosine_similarity(mat1, mat2): + dot_product = torch.sum(mat1*mat2, dim=-1) + norm_vec1 = torch.norm(mat1, dim=-1) + norm_vec2 = torch.norm(mat2, dim=-1) + return dot_product / (norm_vec1 * norm_vec2) + +def find_contigious_latter_index(index_tensor: torch.LongTensor) -> torch.Tensor: + """ + Args: + index_tensor (torch.LongTensor): A binary tensor containing sequences of ones and zeros. + + Returns: + torch.Tensor: A tensor where each contiguous sequence of ones in the input tensor + is replaced by zeros, except for the last element of each sequence, + which is replaced by the length of that sequence. 
+ + Example: + Input: torch.tensor([0, 1, 1, 1, 0, 0, 1, 1]) + Output: torch.tensor([0, 0, 0, 3, 0, 0, 0, 2]) + """ + bsz, n = index_tensor.shape + t_prev = torch.cat([torch.zeros((bsz, 1), dtype=index_tensor.dtype, device=index_tensor.device), index_tensor[:, :-1]], dim=1) + t_next = torch.cat([index_tensor[:, 1:], torch.zeros((bsz, 1), dtype=index_tensor.dtype, device=index_tensor.device)], dim=1) + + # Identify the starts and ends of runs of ones + run_starts = (index_tensor == 1) & (t_prev == 0) + run_ends = (index_tensor == 1) & (t_next == 0) + + start_indices = torch.nonzero(run_starts, as_tuple=True) + end_indices = torch.nonzero(run_ends, as_tuple=True) + run_lengths = (end_indices[1] - start_indices[1] + 1).to(index_tensor.dtype) + + output = torch.zeros_like(index_tensor, dtype=index_tensor.dtype) + output[end_indices[0], end_indices[1]] = run_lengths + + return output \ No newline at end of file diff --git a/framefusion/models/llava_next_video/modeling_llava_next_video.py b/framefusion/models/llava_next_video/modeling_llava_next_video.py new file mode 100644 index 0000000..476eb27 --- /dev/null +++ b/framefusion/models/llava_next_video/modeling_llava_next_video.py @@ -0,0 +1,236 @@ +import torch +from transformers.models.llava_next_video.modeling_llava_next_video import logger + + +def _merge_input_ids_with_image_features_get_token_type( + self, + image_features, + feature_lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids=None, + labels=None, + image_token_index=None, + ignore_index=-100, +): + """ + Merge input_ids with with image features into final embeddings + + Args: + image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): + All vision vectors of all images in the batch + feature_lens (`torch.LongTensor` of shape `(num_images)`): + The length of visual embeddings of each image as stacked in `image_features` + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings 
before merging with visual embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with image token + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + :abels need to be recalculated to support training (if provided) + image_token_index (`int`, *optional*) + Token id used to indicate the special "image" token. Defaults to `config.image_token_index` + ignore_index (`int`, *optional*) + Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. + Returns: + final_embedding, final_attention_mask, position_ids, final_labels + + Explanation: + each image has variable length embeddings, with length specified by feature_lens + image_features is concatenation of all visual embed vectors + task: fill each with the correct number of visual embeddings + Example: + X (5 patches), Y (3 patches), Z (8) + X, Y are in the same sequence (in-context learning) + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s 
t u v + ] + Edge cases: + * If tokens are same but image token sizes are different, then cannot infer left or right padding + ```python + cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw) + prompts = [ + "[INST] \nWhat is shown in this image? [/INST]", + "[INST] \nWhat is shown in this image? [/INST]", + ] + inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") + chart_img has 2634 tokens, while cat_img has 2340 tokens + ``` + + input_ids: [ + a b c d X g h + i j Y k l m n + ] + where X is 3 tokens while Y is 5, this mean after merge + if left-padding (batched generation) + input_ids should be: [ + _ _ a b c d X X X g h + i j Y Y Y Y Y k l m n + ] + elif (right padding) (training) + input_ids should be: [ + a b c d X X X g h _ _ + i j Y Y Y Y Y k l m n + ] + """ + image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index + ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index + + if self.training and self.padding_side == "left": + logger.warning_once( + "Padding side is set to 'left' but the model is in training mode. For training " "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. " "If that's intended, ignore this warning" + ) + if not self.training and self.padding_side == "right": + logger.warning_once( + "Padding side is set to 'right' but the model is in inference mode. For correct " "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " "If that's intended, ignore this warning" + ) + + with torch.no_grad(): + # ! 
in llava 1.6, number of patches is variable + num_images = feature_lens.size(0) + num_image_features, embed_dim = image_features.shape + if feature_lens.sum() != num_image_features: + raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") + batch_size = input_ids.shape[0] + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = self.padding_side == "left" + if batch_size > 1: + if _left_padding and _right_padding: + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + elif _right_padding and left_padding: + left_padding = False + elif _left_padding and not left_padding: + left_padding = True + + # Whether to turn off right padding + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == image_token_index + # special_image_token_mask: [bsz, seqlen] + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # num_special_image_tokens: [bsz] + # Reserve for padding of num_images + total_num_special_image_tokens = torch.sum(special_image_token_mask) + if total_num_special_image_tokens != num_images: + raise ValueError(f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images}).") + # Compute the maximum embed dimension + # max_image_feature_lens is max_feature_lens per batch + feature_lens = feature_lens.to(input_ids.device) + feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) + feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) + embed_sequence_lengths = (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum + max_embed_dim = embed_sequence_lengths.max() + + batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) + # 2. 
Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + # ! instead of special_image_token_mask * (num_image_patches - 1) + # special_image_token_mask * (num_feature_len - 1) + special_image_token_mask = special_image_token_mask.long() + special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 + new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 + if left_padding: + # shift right token positions so that they are ending at the same number + # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] + new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] + + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros(batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device) + final_input_ids = torch.full((batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. 
+ target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] + final_labels = None + if labels is not None: + labels = labels.to(target_device) + final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) + with torch.no_grad(): + image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device) + image_to_overwrite[batch_indices, text_to_overwrite] = False + embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) + embed_indices = embed_indices.expand(batch_size, max_embed_dim) + embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) + + if left_padding: + # exclude padding on the left + max_embed_dim = max_embed_dim.to(target_device) + val = (max_embed_dim - embed_indices) <= embed_seq_lens + else: + # exclude padding on the right + val = embed_indices < embed_seq_lens + image_to_overwrite &= val + + if image_to_overwrite.sum() != num_image_features: + raise ValueError( + f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " + f"The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. " + f"This prevents correct indexing and breaks batch generation." 
import math
import re
import torch
import torch.nn as nn

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX

from llava.mm_utils import get_anyres_image_grid_shape
from llava.utils import rank0_print
import random

# Sentinel values used in the FrameFusion patch-type map:
# SPECIAL_TOKEN marks special tokens, IGNORE_TOKEN marks tokens FrameFusion
# must never merge/prune, TEXT_TOKEN marks ordinary text tokens.
SPECIAL_TOKEN = -9
IGNORE_TOKEN = -2
TEXT_TOKEN = -1


def prepare_inputs_labels_for_multimodal_get_patch_type(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
    """LLaVA-Video multimodal input preparation, extended to record per-token
    patch types for FrameFusion.

    This is a copy of LLaVA's `prepare_inputs_labels_for_multimodal` with a
    FRAMEFUSION section appended: after text/image embeddings are merged, it
    builds a `patch_type` map (TEXT_TOKEN for text positions, a repeating
    0..patch_num-1 pattern for video-frame positions) and hands it to
    `self.framefusion.prepare(...)`.

    NOTE(review): `modalities=["image"]` is a mutable default argument — the
    list is shared across calls. Harmless as long as it is never mutated, but
    worth confirming.
    """
    vision_tower = self.get_vision_tower()
    # rank_print(modalities)
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        # Decode step (single token) or no vision input: pass through unchanged.
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    if isinstance(modalities, str):
        modalities = [modalities]

    # import pdb; pdb.set_trace()
    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

        # Indices of batch entries that are videos (vs. still images).
        video_idx_in_batch = []
        for _ in range(len(modalities)):
            if modalities[_] == "video":
                video_idx_in_batch.append(_)

        images_list = []
        for image in images:
            if image.ndim == 4:
                images_list.append(image)
            else:
                images_list.append(image.unsqueeze(0))

        concat_images = torch.cat([image for image in images_list], dim=0)
        split_sizes = [image.shape[0] for image in images_list]
        encoded_image_features = self.encode_images(concat_images)
        # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)

        # This is a list, each element is [num_images, patch * patch, dim]
        # rank_print(f"Concat images : {concat_images.shape}")
        encoded_image_features = torch.split(encoded_image_features, split_sizes)
        image_features = []
        for idx, image_feat in enumerate(encoded_image_features):
            if idx in video_idx_in_batch:
                # Videos get 2x2 spatial pooling to cut tokens per frame.
                image_features.append(self.get_2dPool(image_feat))
            else:
                image_features.append(image_feat)
        # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
        # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
        # image_features = torch.split(image_features, split_sizes, dim=0)
        mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")

        if mm_patch_merge_type == "flat":
            image_features = [x.flatten(0, 1) for x in image_features]

        elif mm_patch_merge_type.startswith("spatial"):
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                # FIXME: now assume the image is square, and split to 2x2 patches
                # num_patches = h * w, where h = w = sqrt(num_patches)
                # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                # we want to first unflatten it to (2, 2, h, w, hidden_size)
                # rank0_print("At least we are reaching here")
                # import pdb; pdb.set_trace()
                if image_idx in video_idx_in_batch:  # video operations
                    # rank0_print("Video")
                    if mm_newline_position == "grid":
                        # Grid-wise
                        image_feature = self.add_token_per_grid(image_feature)
                        if getattr(self.config, "add_faster_video", False):
                            # NOTE(review): `all_faster_video_features` is never
                            # assigned in this copy (the producing call above is
                            # commented out) — this branch would raise NameError
                            # if `add_faster_video` is enabled. Confirm upstream.
                            faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
                            # Add a token for each frame
                            concat_slow_fater_token = []
                            # import pdb; pdb.set_trace()
                            for _ in range(image_feature.shape[0]):
                                if _ % self.config.faster_token_stride == 0:
                                    concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
                                else:
                                    concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
                            # import pdb; pdb.set_trace()
                            image_feature = torch.cat(concat_slow_fater_token)

                        # print("!!!!!!!!!!!!")

                        new_image_features.append(image_feature)
                    elif mm_newline_position == "frame":
                        # Frame-wise
                        image_feature = self.add_token_per_frame(image_feature)

                        new_image_features.append(image_feature.flatten(0, 1))

                    elif mm_newline_position == "one_token":
                        # one-token
                        image_feature = image_feature.flatten(0, 1)
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
                        new_image_features.append(image_feature)
                    elif mm_newline_position == "no_token":
                        new_image_features.append(image_feature.flatten(0, 1))
                    else:
                        raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
                elif image_feature.shape[0] > 1:  # multi patches and multi images operations
                    # rank0_print("Single-images")
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]

                    if "anyres_max" in image_aspect_ratio:
                        matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                        if matched_anyres_max_num_patches:
                            max_num_patches = int(matched_anyres_max_num_patches.group(1))

                    if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                        if hasattr(self.get_vision_tower(), "image_size"):
                            vision_tower_image_size = self.get_vision_tower().image_size
                        else:
                            raise ValueError("vision_tower_image_size is not found in the vision tower.")
                        try:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
                        except Exception as e:
                            rank0_print(f"Error: {e}")
                            num_patch_width, num_patch_height = 2, 2
                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                    else:
                        image_feature = image_feature.view(2, 2, height, width, -1)

                    if "maxpool2x2" in mm_patch_merge_type:
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = nn.functional.max_pool2d(image_feature, 2)
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                        # NOTE(review): `unpad_image` is not imported in this
                        # module — the "unpad" branches would raise NameError.
                        # Confirm it should come from llava.mm_utils.
                        unit = image_feature.shape[2]
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        c, h, w = image_feature.shape
                        times = math.sqrt(h * w / (max_num_patches * unit**2))
                        if times > 1.1:
                            image_feature = image_feature[None]
                            image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
                        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    elif "unpad" in mm_patch_merge_type:
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    else:
                        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                    if "nobase" in mm_patch_merge_type:
                        pass
                    else:
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    new_image_features.append(image_feature)
                else:  # single image operations
                    image_feature = image_feature[0]
                    if "unpad" in mm_patch_merge_type:
                        image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)

                    new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
    else:
        image_features = self.encode_images(images)

    # TODO: image start / end is not implemented here to support pretraining.
    if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
        raise NotImplementedError
    # rank_print(f"Total images : {len(image_features)}")

    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- FIXME
    _input_ids = input_ids
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    # rank_print("Inserting Images embedding")
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        # rank0_print(num_images)
        if num_images == 0:
            # No image placeholders: keep text embeddings, but consume one
            # (empty) image feature so multi-batch indices stay aligned.
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        # Split the sequence around each IMAGE_TOKEN_INDEX placeholder.
        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

        cur_new_input_embeds = []
        cur_new_labels = []

        # Interleave text spans with image features; image positions get
        # IGNORE_INDEX labels so they never contribute to the loss.
        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    cur_image_features = image_features[cur_image_idx - 1]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        # import pdb; pdb.set_trace()
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
    # rank_print("Finishing Inserting")

    new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
    new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
    # TODO: Hard code for control loss spike
    # if tokenizer_model_max_length is not None:
    #     new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
    #     new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
    # rank0_print("Prepare pos id")

    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, "tokenizer_padding_side", "right") == "left":
            new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
        else:
            new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # rank0_print("tokenizer padding")

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None
    if getattr(self.config, "use_pos_skipping", False) and self.training:
        position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
        split_position = random.randint(0, new_input_embeds.size(1))
        left_add = random.randint(0, self.config.pos_skipping_range)
        right_add = random.randint(left_add, self.config.pos_skipping_range)
        position_ids[:, :split_position] += left_add
        position_ids[:, split_position:] += right_add
    # import pdb; pdb.set_trace()
    # rank0_print("Finish preparing")

    ### FRAMEFUSION START ###
    # Recompute tokens-per-frame after 2x2 pooling; the "bilinear" pool mode
    # rounds up, the default mode rounds down, matching get_2dPool's output.
    # patch_num includes the extra per-row newline token (patch_size + 1).
    if self.config.mm_spatial_pool_mode == "bilinear":
        patch_size = math.ceil(self.get_vision_tower().num_patches_per_side / 2)
    else:
        patch_size = self.get_vision_tower().num_patches_per_side // 2
    patch_num = patch_size * (patch_size + 1)

    # FrameFusion currently only supports a single sequence with exactly one
    # video placeholder.
    assert batch_size == 1
    # NOTE(review): `num_images` here is the value left over from the last
    # iteration of the embedding loop above, not a per-batch value — it only
    # works because batch_size == 1 is asserted.
    assert num_images == 1
    image_token_length = image_features[0].shape[0]
    n_frames = image_token_length // patch_num
    image_token_start_index = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0]
    image_token_end_index = image_token_start_index + image_token_length - 1
    # The single placeholder token expands into image_token_length embeddings.
    original_length = input_ids[0].shape[0] + image_token_length - 1
    patch_type = [TEXT_TOKEN] * image_token_start_index + list(range(patch_num)) * n_frames + [TEXT_TOKEN] * (original_length - image_token_end_index - 1)
    patch_type = torch.tensor([patch_type], device=new_input_embeds.device)

    self.framefusion.prepare(patch_type, patch_num, image_token_start_index, image_token_end_index, image_token_length, original_length)
    ### FRAMEFUSION END ###

    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def get_vllm_embedding(self, data):
    """MiniCPM-V embedding builder, extended to register FrameFusion metadata.

    Computes vision features with the vision tower + resampler (unless
    precomputed `vision_hidden_states` are supplied in `data`), scatters them
    into the token-embedding sequence at the positions given by
    `data["image_bound"]`, then builds a per-token `patch_type` map and calls
    `self.framefusion.prepare(...)`.

    Args:
        data: dict with at least "input_ids", "pixel_values", "tgt_sizes",
            "image_bound"; optionally "vision_hidden_states".

    Returns:
        (vllm_embedding, vision_hidden_states) — the fused input embeddings
        and the per-sample vision features.
    """
    if "vision_hidden_states" not in data:
        dtype = self.llm.model.embed_tokens.weight.dtype
        device = self.llm.model.embed_tokens.weight.device
        tgt_sizes = data["tgt_sizes"]
        pixel_values_list = data["pixel_values"]
        vision_hidden_states = []
        all_pixel_values = []
        img_cnt = []
        for pixel_values in pixel_values_list:
            img_cnt.append(len(pixel_values))
            all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])

        # exist image
        if all_pixel_values:
            tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)

            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

            # Pad all slices to the longest patch count, then reshape to the
            # (B, 3, -1, L) layout the vision tower expects.
            all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
            B, L, _ = all_pixel_values.shape
            all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
            for i in range(B):
                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

            # Chunk the vision forward pass to bound peak memory.
            vision_batch_size = self.config.vision_batch_size
            all_pixel_values = all_pixel_values.type(dtype)
            if B > vision_batch_size:
                hs = []
                for i in range(0, B, vision_batch_size):
                    start_idx = i
                    end_idx = i + vision_batch_size
                    tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
                    hs.append(tmp_hs)
                vision_embedding = torch.cat(hs, dim=0)
            else:
                vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

            # Redistribute the flat vision embeddings back per input sample.
            start = 0
            for pixel_values in pixel_values_list:
                # NOTE(review): this rebinds `img_cnt` (the per-sample count
                # list built above) to an int — confusing shadowing, though
                # the list is no longer read afterwards.
                img_cnt = len(pixel_values)
                if img_cnt > 0:
                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
                    start += img_cnt
                else:
                    vision_hidden_states.append([])
        else:  # no image
            if self.training:
                # Dummy forward keeps gradients flowing for the vision stack
                # in image-free training batches.
                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
                tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
                dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
            else:
                dummy_feature = []
            for _ in range(len(pixel_values_list)):
                vision_hidden_states.append(dummy_feature)

    else:
        vision_hidden_states = data["vision_hidden_states"]

    if hasattr(self.llm.config, "scale_emb"):
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
    else:
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

    vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states]

    bs = len(data["input_ids"])
    for i in range(bs):
        cur_vs_hs = vision_hidden_states[i]
        if len(cur_vs_hs) > 0:
            cur_vllm_emb = vllm_embedding[i]
            cur_image_bound = data["image_bound"][i]
            if len(cur_image_bound) > 0:
                image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]).to(vllm_embedding.device)

                # In-place scatter: overwrite text-embedding rows at image
                # positions with the vision features.
                cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
            elif self.training:
                # Zero-valued touch keeps the vision path in the autograd graph.
                cur_vllm_emb += cur_vs_hs[0].mean() * 0

    ### FRAMEFUSION START ###
    # Single-sequence assumption: the patch-type map below is only built for
    # batch index 0.
    assert bs == 1
    patch_type = torch.full((bs, vllm_embedding.shape[1]), TEXT_TOKEN, dtype=torch.long, device=vllm_embedding.device)
    num_frames = self.num_frames

    image_bound = data["image_bound"][0]
    patch_per_frame = image_bound.shape[0] // num_frames
    token_per_frame = image_bound[patch_per_frame, 0] - image_bound[0, 0]
    # NOTE(review): `i` here is the leftover index from the `for i in
    # range(bs)` loop above (0, given the assert) — prefer an explicit 0.
    # The `+ 2` slack presumably covers frame-delimiter tokens after the last
    # bound — TODO confirm against the chat template.
    patch_type[i, image_bound[0, 0] : (image_bound[-1, 1] + 2)] = torch.arange(0, image_bound[-1, 1] - image_bound[0, 0] + 2, device=patch_type.device) % token_per_frame

    patch_num = token_per_frame
    # First / last positions whose patch type is non-negative (i.e. image).
    image_token_start_index = torch.argmax((patch_type >= 0).int(), dim=1)
    image_token_end_index = patch_type.shape[1] - 1 - torch.argmax((torch.flip(patch_type, dims=[1]) >= 0).int(), dim=1)
    original_length = patch_type.shape[1]
    image_token_length = image_token_end_index - image_token_start_index + 1

    self.framefusion.prepare(patch_type, patch_num, image_token_start_index, image_token_end_index, image_token_length, original_length)
    ### FRAMEFUSION END ###

    return vllm_embedding, vision_hidden_states
def Qwen2DecoderLayer_merge_then_prune_by_cost_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]], torch.Tensor, Optional[torch.Tensor]]:
    """Qwen2 decoder-layer forward with FrameFusion token merging/pruning.

    Identical to the stock Qwen2DecoderLayer.forward except that
    `self.framefusion` is invoked to shrink the token sequence (before
    attention on layer 0 only, and after attention on every layer), and the
    possibly-shortened `position_embeddings` and `attention_mask` are appended
    to the returned tuple so the model loop can propagate them.

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, sequence_length)` where padding elements are indicated by 0.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence.
        position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
            Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
            with `head_dim` being the embedding dimension of each attention head.
        kwargs (`dict`, *optional*):
            Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
            into the model
    """
    ### start token merging at layer 0 before attention
    if self.self_attn.layer_idx == 0:
        hidden_states, position_embeddings, attention_mask = self.framefusion(hidden_states, position_embeddings, attention_mask)
    ### end token merging at layer 0 before attention

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
    hidden_states = residual + hidden_states

    ### start token merging or fastv after attention
    # self_attn_weights may be None here; framefusion decides merge vs. prune
    # from its own state plus these weights when available.
    hidden_states, position_embeddings, attention_mask = self.framefusion(hidden_states, position_embeddings, attention_mask, self_attn_weights)
    ### end token merging or fastv after attention

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    ### start return the updated position embeddings and attention mask
    # Extra trailing entries beyond the stock contract; the patched
    # Qwen2Model forward reads them back via layer_outputs[-2] / [-1].
    outputs += (position_embeddings, attention_mask)
    return outputs
    ### end return the updated position embeddings and attention mask
def Qwen2SdpaAttention_merge_then_prune_by_cost_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Qwen2 SDPA attention forward, extended for FrameFusion.

    Identical to the stock Qwen2SdpaAttention.forward except that, during
    prefill (q_len > 1) after merging has finished but before pruning, it
    additionally computes last-query attention weights via the custom
    `scaled_dot_product_attention(..., num=1, ...)` helper and returns them so
    the decoder layer can pass them to framefusion for cost-based pruning.
    """
    if output_attentions:
        # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
        logger.warning_once(
            "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
        # NOTE(review): zero-argument super() only works inside a class body;
        # this function is bound via MethodType, so this fallback would raise
        # RuntimeError if output_attentions=True — confirm it is unreachable.
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
    is_causal = True if causal_mask is None and q_len > 1 else False

    ### start storing attn_weights if needed
    # Only computed in the window between "merging finished" and "pruning
    # finished"; `num=1` restricts the helper to the last query position.
    attn_weights = None
    if (q_len > 1) and (self.framefusion.finish_merging) and (not self.framefusion.finish_pruning):
        attn_weights = scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            num=1,
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
    ### end storing attn_weights if needed

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=is_causal,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights, past_key_value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def Qwen2Model_merge_then_fastv_cost_given_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Qwen2Model forward patched for FrameFusion.

    Identical to the stock Qwen2Model.forward except that (a)
    `position_embeddings` is converted to a mutable list so layers can shorten
    it when tokens are merged/pruned, and (b) after each decoder layer the
    (possibly shortened) position embeddings and causal mask are read back
    from the layer's extended output tuple and fed to the next layer.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            past_key_values = DynamicCache()
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    ### change position_embeddings into a list for future pruning
    # Tuples are immutable; framefusion rewrites the cos/sin entries in place
    # when it drops tokens.
    position_embeddings = list(position_embeddings)
    ### end changing

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        ### start update the attention mask and position embeddings modified by framefusion
        # The patched decoder layer appends (position_embeddings, attention_mask)
        # at the end of its output tuple; feed the shortened versions forward.
        position_embeddings = layer_outputs[-2]
        causal_mask = layer_outputs[-1]
        ### end changing position embedding

        hidden_states = layer_outputs[0]

        if use_cache:
            # Indices 1/2 still hold attn weights / cache because the extra
            # framefusion entries were appended after them.
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint +from transformers.cache_utils import Cache, DynamicCache,DynamicCache, SinkCache +from transformers.models.qwen2.modeling_qwen2 import repeat_kv,apply_rotary_pos_emb, logger, QWEN2_INPUTS_DOCSTRING +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils.doc import add_start_docstrings_to_model_forward +from transformers.models.qwen2.modeling_qwen2 import Qwen2SdpaAttention, Qwen2DecoderLayer, Qwen2Model +from functools import partial +try: + from minference import streaming_forward +except ImportError: + # minference is not needed if streamingllm is not used + streaming_forward = None + +from framefusion.utils import TEXT_TOKEN, IGNORE_TOKEN +from framefusion.main import find_contigious_latter_index + +""" +Utils +""" + +def compute_density_overhead(sparsity_list) -> tuple: + """ + Compute the average cumulative product and total product of the sparsity list. + """ + density_list = [1-s for s in sparsity_list] + + cost = 0.0 + remaining_density = 1.0 + for density in density_list: + remaining_density *= density + cost += remaining_density + + norm_cost = cost / len(density_list) + return norm_cost, remaining_density + +""" +Meta Interface +""" + +def replace_Qwen2_forward(model, mode="merge_then_fastv_cost_given", **kwargs): + print(f"replace_Qwen2_forward mode: {mode} and kwargs: {kwargs}") + + if mode=="prefill_merge": + prefill_merge_kwargs = { + "sparsity": kwargs.get("sparsity", [0.0] * 28), + } + + print(f"Config\n{prefill_merge_kwargs}") + + cost, remaining_density=compute_density_overhead(prefill_merge_kwargs['sparsity']) + print(f"Computational cost: {cost:.3f}, Remaining density: {remaining_density:.3f}") + + replace_Qwen2_merging( + model, + **prefill_merge_kwargs + ) + elif mode=="fastv": + fastv_kwargs = { + "fastv_k": kwargs.get("fastv_k", 3), + "fastv_r": kwargs.get("fastv_r", 0.5) + } + print(f"Config\n{fastv_kwargs}") + + 
replace_Qwen2_fastv( + model, + **fastv_kwargs + ) + elif mode=="merge_then_fastv": + merge_then_fastv_kwargs = { + "sparsity": kwargs.get("sparsity", [0.1] * 28), + "fastv_k": kwargs.get("fastv_k", 3), + "fastv_r": kwargs.get("fastv_r", 0.5) + } + print(f"Config\n{merge_then_fastv_kwargs}") + + replace_Qwen2_merge_then_fastv( + model, + **merge_then_fastv_kwargs + ) + elif mode=="streamingllm": + streamingllm_kwargs = { + "init_num": kwargs.get("init_num", 8), + "length_rate": kwargs.get("length_rate", 0.3), + } + print(f"Config\n{streamingllm_kwargs}") + + replace_Qwen2_streamingllm( + model, + **streamingllm_kwargs + ) + elif mode=="fastv_then_merge": + fastv_then_merge_kwargs = { + "fastv_k": kwargs.get("fastv_k", 2), + "fastv_r": kwargs.get("fastv_r", 0.75), + "merging_sparsity": kwargs.get("merging_sparsity", 0.3) + } + print(f"Config\n{fastv_then_merge_kwargs}") + + replace_Qwen2_fastv_then_merge( + model, + **fastv_then_merge_kwargs + ) + else: + raise NotImplementedError(f"Mode {mode} is not implemented yet.") + +def replace_minicpmv_forward(model, mode="fastv", **kwargs): + print(f"replace_minicpmv_forward mode: {mode} and kwargs: {kwargs}") + if mode=="fastv": + fastv_kwargs = { + "fastv_k": kwargs.get("fastv_k", 3), + "fastv_r": kwargs.get("fastv_r", 0.5) + } + print(f"Config\n{fastv_kwargs}") + + replace_minicpmv_fastv( + model, + **fastv_kwargs + ) + elif mode=="streamingllm": + streamingllm_kwargs = { + "init_num": kwargs.get("init_num", 8), + "length_rate": kwargs.get("length_rate", 0.3), + } + print(f"Config\n{streamingllm_kwargs}") + + replace_minicpmv_streamingllm( + model, + **streamingllm_kwargs + ) + else: + raise NotImplementedError(f"Mode {mode} is not implemented yet.") + + +""" +Forward functions +""" + +""" +Fastv forward functions +""" + +def replace_Qwen2_fastv(model, fastv_k = 3, fastv_r = 0.5): + model.fastv_k = fastv_k + model.fastv_r = fastv_r + + if isinstance(model.model, Qwen2Model): + model.model.forward = 
MethodType(partial(Qwen2Model_fastv_forward, model=model), model.model) + for i, decoder_layer in enumerate(model.model.layers): + if isinstance(decoder_layer, Qwen2DecoderLayer): + decoder_layer.forward=MethodType(Qwen2DecoderLayer_fastv_forward, decoder_layer) + qwen2_attention_instance = decoder_layer.self_attn + if isinstance(qwen2_attention_instance, Qwen2SdpaAttention): + qwen2_attention_instance.forward = MethodType(partial(Qwen2SdpaAttention_fastv_forward, model=model), qwen2_attention_instance) + else: + raise TypeError("language model is not Qwen2.") + +def replace_minicpmv_fastv(model, fastv_k = 3, fastv_r = 0.5): + model.fastv_k = fastv_k + model.fastv_r = fastv_r + + if isinstance(model.llm.model, Qwen2Model): + model.llm.model.forward = MethodType(partial(Qwen2Model_fastv_forward, model=model), model.llm.model) + for i, decoder_layer in enumerate(model.llm.model.layers): + if isinstance(decoder_layer, Qwen2DecoderLayer): + decoder_layer.forward=MethodType(Qwen2DecoderLayer_fastv_forward, decoder_layer) + qwen2_attention_instance = decoder_layer.self_attn + if isinstance(qwen2_attention_instance, Qwen2SdpaAttention): + qwen2_attention_instance.forward = MethodType(partial(Qwen2SdpaAttention_fastv_forward, model=model), qwen2_attention_instance) + else: + raise TypeError("language model is not Qwen2.") + +@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) +def Qwen2Model_fastv_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + model = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = 
output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + ### change position_embeddings into a list for future pruning + position_embeddings = list(position_embeddings) + + ### end changing + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + ### implement fastv + FASTV_k = model.fastv_k # the layer_idx to prune + FASTV_r = model.fastv_r # the pruning ratio + FASTV_image_token_start_index = model.image_token_start_index.item() + FASTV_image_token_length = model.image_token_length.item() + device = self.device + #seq_length_with_past = past_seen_tokens + inputs_embeds.shape[1] (here because cache position in minicpmv is not none,so past_seen_tokens is not defined ) + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # pruning hidden states, no kv cache + + if use_cache: + if hidden_states.shape[1] != 1: + if layer_idx