Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ pip install -e benchmark/VLMEvalKit
python third_party/VLMEvalKit/run.py --work-dir ./results/ --config eval/vlmevalkit/config.json
```
Currently, multi-GPU is not tested.
<Directory to save results>

## Multimodal Token Fusion

We used [FrameFusion](https://github.com/thu-nics/FrameFusion) to fuse the image and video tokens so they could fit into the context length of 4096.

To enable the feature, simply set `use_token_reduction=True` in the `Worldinfer` constructor. Refer to the official codebase or [paper](https://arxiv.org/abs/2501.01986) for further details.

> [!NOTE]
> Though this method is capable of fusing multi-image and video tokens, current `RWKV7-*-siglip2` models have not been trained with multi-image and video QA tasks. Unexpected behaviors may occur.


## Training
Expand Down
11 changes: 10 additions & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,16 @@ export PYTHONPATH=$PYTHONPATH:$(pwd)
pip install -e benchmark/VLMEvalKit
python -m benchmark/VLMEvalKit/run.py --work-dir <Directory to save results> --config eval/vlmevalkit/config.json
```
Currenty multi-GPU is not tested.
目前尚未测试多GPU下的运行。

## 多模态Token融合

我们使用了 [FrameFusion](https://github.com/thu-nics/FrameFusion) 来融合多张图片和视频帧的token,使其能够适应4096的上下文长度。

为了使用这个功能,你需要在 `Worldinfer` 构造函数中设置 `use_token_reduction=True`,对于相关参数的设置请参考其官方代码仓库与[论文](https://arxiv.org/abs/2501.01986)。

> [!NOTE]
> 虽然这个Token融合方法能够融合多张图片和视频帧的token,但 `RWKV7-*-siglip2` 系列模型并未针对多张图片与视频的处理进行过训练,因此在相关任务中可能会出现意外行为。

# 训练
> [!NOTE]
Expand Down
25 changes: 15 additions & 10 deletions eval/vlmevalkit/config.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
{
"model": {
"worldrwkv7_siglip_custom": {
"worldrwkv7_3b_siglip": {
"class": "WorldRWKV7_Siglip2",
"model_path": "/home/rwkv/WorldRWKV/model/RWKV7-3B-siglip2/rwkv-0",
"encoder_path": "google/siglip2-base-patch16-384",
"use_custom_prompt": true,
"use_token_reduction":true,
"verbose": false,
"args": {
"temperature": 1.0,
Expand All @@ -19,17 +20,21 @@
}
},
"data": {
"MMBench_DEV_EN": {
"class": "ImageMCQDataset",
"dataset": "MMBench_DEV_EN"
},
"MMMU_TEST": {
"MMMU_DEV_VAL":{
"class": "MMMUDataset",
"dataset": "MMMU_TEST"
"dataset": "MMMU_DEV_VAL"
},
"MMMU_EVAL":{
"class": "MMMUDataset",
"dataset": "MMMU_DEV_EVAL"
"TextVQA_VAL":{
"class":"ImageVQADataset",
"dataset":"TextVQA_VAL"
},
"ScienceQA_VAL":{
"class":"ImageMCQDataset",
"dataset":"ScienceQA_VAL"
},
"COCO_VAL":{
"class":"ImageCaptionDataset",
"dataset":"COCO_VAL"
}
}
}
105 changes: 105 additions & 0 deletions framefusion/example_siglip_framefusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import torch
import argparse
from PIL import Image
import matplotlib.pyplot as plt
from world.encoder.siglip_encoder import SiglipEncoder
from framefusion.siglip_adapter import apply_siglip_framefusion
from framefusion.video_processor import load_video_frames, encode_video_frames_with_framefusion


def parse_args(argv=None):
    """Build and evaluate the CLI parser for the SIGLIP FrameFusion demo.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, preserving the original
            command-line behavior while allowing programmatic use and testing.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Test SIGLIP FrameFusion on a video")
    parser.add_argument("--video_path", type=str, required=True, help="Path to the video file")
    parser.add_argument("--encoder_path", type=str, default="google/siglip2-base-patch16-384", help="Path to the SIGLIP encoder model")
    parser.add_argument("--project_dim", type=int, default=768, help="Projection dimension for the encoder")
    parser.add_argument("--sample_steps", type=int, default=30, help="Number of frames to skip between samples")
    parser.add_argument("--max_frames", type=int, default=None, help="Maximum number of frames to extract")
    parser.add_argument("--cost", type=float, default=0.3, help="Computational budget for FrameFusion")
    parser.add_argument("--similarity_threshold", type=float, default=0.6, help="Similarity threshold for merging tokens")
    parser.add_argument("--ratio_threshold", type=float, default=0.1, help="Minimum ratio of tokens to keep")
    parser.add_argument("--output_dir", type=str, default="./framefusion_output", help="Directory to save visualization")
    return parser.parse_args(argv)


def visualize_frames(frames, output_path):
    """Render up to 16 frames as a 4x4 image grid and write it to *output_path*.

    When more than 16 frames are given, an evenly-spaced subset of 16 is
    selected. The parent directory of *output_path* is created if missing.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Evenly sample at most 16 frames so the grid stays readable.
    max_tiles = 16
    if len(frames) > max_tiles:
        picks = torch.linspace(0, len(frames) - 1, max_tiles).long().tolist()
        shown = [frames[p] for p in picks]
    else:
        shown = frames

    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    axes = axes.flatten()

    # One pass over all axes: draw a frame where available, and hide the
    # axis decorations everywhere (including unused tiles).
    for idx, ax in enumerate(axes):
        if idx < len(shown):
            ax.imshow(shown[idx])
            ax.set_title(f"Frame {idx}")
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
    print(f"Saved visualization to {output_path}")


def main():
    """Run the demo: encode a video with and without FrameFusion and report stats."""
    args = parse_args()

    # Ensure the visualization target directory exists up front.
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Processing video: {args.video_path}")

    # Encoder settings shared by both runs.
    shared = dict(
        encoder_path=args.encoder_path,
        project_dim=args.project_dim,
        sample_steps=args.sample_steps,
        max_frames=args.max_frames,
    )

    # Baseline: encode the sampled frames with FrameFusion disabled.
    print("\n=== Processing without FrameFusion ===")
    plain_tokens, plain_frames, _ = encode_video_frames_with_framefusion(
        args.video_path,
        use_framefusion=False,
        **shared,
    )

    # Second pass: same video, FrameFusion enabled with the CLI budget/thresholds.
    print("\n=== Processing with FrameFusion ===")
    fused_tokens, _, total_frames = encode_video_frames_with_framefusion(
        args.video_path,
        use_framefusion=True,
        cost=args.cost,
        similarity_lower_bound=args.similarity_threshold,
        ratio_lower_bound=args.ratio_threshold,
        **shared,
    )

    # Summarize the token/frame reduction achieved by fusion.
    print("\n=== Results ===")
    print(f"Original frame count: {total_frames}")
    print(f"Without FrameFusion: {plain_tokens.shape[0]} frames with {plain_tokens.shape[1]} patches per frame")
    print(f"With FrameFusion: {fused_tokens.shape[0]} frames with {fused_tokens.shape[1]} patches per frame")
    print(f"Frame reduction: {(1 - fused_tokens.shape[0] / total_frames) * 100:.2f}%")

    # Only the baseline frames are visualized, matching the original script.
    visualize_frames(plain_frames, os.path.join(args.output_dir, "original_frames.png"))

    print("\nFrameFusion successfully applied to SIGLIP encoder!")


if __name__ == "__main__":
    main()
145 changes: 145 additions & 0 deletions framefusion/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# common imports
from types import MethodType
from typing import Callable
import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module
from transformers import PreTrainedModel

# framefusion methods
from framefusion.main import FrameFusion
from framefusion.utils import TEXT_TOKEN, IGNORE_TOKEN, get_attr_by_name

# model types
from transformers import LlavaNextVideoForConditionalGeneration
from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM

# replace methods
from framefusion.models.llava_next_video.modeling_llava_next_video import _merge_input_ids_with_image_features_get_token_type
from framefusion.models.llava_video.modeling_llava_video import prepare_inputs_labels_for_multimodal_get_patch_type
from framefusion.models.minicpmv.modeling_minicpmv import get_vllm_embedding
from framefusion.models.qwen2.modeling_qwen2 import Qwen2Model_merge_then_fastv_cost_given_forward, Qwen2DecoderLayer_merge_then_prune_by_cost_forward, Qwen2SdpaAttention_merge_then_prune_by_cost_forward


def apply_framefusion(model, cost, similarity_lower_bound, ratio_lower_bound):
    """
    Apply FrameFusion to the model

    Args:
        model: the model to apply FrameFusion to
        cost: the cost of the FrameFusion
        similarity_lower_bound: the similarity lower bound of the FrameFusion
        ratio_lower_bound: the ratio lower bound of the FrameFusion

    Raises:
        NotImplementedError: if the model architecture is not supported.
    """
    # All supported backbones use the same Qwen2-based replacement forwards;
    # only the token-type patch and the path to the LLM differ per model.
    if isinstance(model, LlavaNextVideoForConditionalGeneration):
        # LlavaNextVideo Model
        model._merge_input_ids_with_image_features = MethodType(
            _merge_input_ids_with_image_features_get_token_type, model
        )
        llm_key = "model"
    elif isinstance(model, LlavaQwenForCausalLM):
        # LlavaVideo Model
        model.prepare_inputs_labels_for_multimodal = MethodType(
            prepare_inputs_labels_for_multimodal_get_patch_type, model
        )
        llm_key = "model"
    elif model.config.architectures[0] == "MiniCPMV":
        # MiniCPM Model
        model.get_vllm_embedding = MethodType(get_vllm_embedding, model)
        llm_key = "llm.model"
    else:
        raise NotImplementedError

    replace_framefusion_forward(
        model,
        cost=cost,
        similarity_lower_bound=similarity_lower_bound,
        ratio_lower_bound=ratio_lower_bound,
        llm_forward=Qwen2Model_merge_then_fastv_cost_given_forward,
        decoder_forward=Qwen2DecoderLayer_merge_then_prune_by_cost_forward,
        attention_forward=Qwen2SdpaAttention_merge_then_prune_by_cost_forward,
        llm_key=llm_key,
        decoder_key="layers",
        attention_key="self_attn",
    )


def get_token_type(model):
    """Patch *model* so its multimodal-merge step also produces token types.

    Installs the same per-architecture method patch as ``apply_framefusion``
    but without replacing any forward methods. Note: despite the name, this
    function returns ``None`` — it only mutates the model in place.

    Raises:
        NotImplementedError: if the model architecture is not supported.
    """
    if isinstance(model, LlavaNextVideoForConditionalGeneration):
        # LlavaNextVideo Model
        patched = MethodType(_merge_input_ids_with_image_features_get_token_type, model)
        model._merge_input_ids_with_image_features = patched
    elif isinstance(model, LlavaQwenForCausalLM):
        # LlavaVideo Model
        patched = MethodType(prepare_inputs_labels_for_multimodal_get_patch_type, model)
        model.prepare_inputs_labels_for_multimodal = patched
    elif model.config.architectures[0] == "MiniCPMV":
        # MiniCPM Model
        model.get_vllm_embedding = MethodType(get_vllm_embedding, model)
    else:
        raise NotImplementedError


def replace_framefusion_forward(
    module: torch.nn.Module,
    cost: float,
    similarity_lower_bound: float,
    ratio_lower_bound: float,
    llm_forward: Callable,
    decoder_forward: Callable,
    attention_forward: Callable,
    llm_key: str = "model",
    decoder_key: str = "layers",
    attention_key: str = "self_attn",
):
    """
    Replace the forward method of the model with the framefusion forward method.
    Make framefusion a property of the model.

    The keys are accessed in an hierarchical manner: llm_key -> decoder_key -> attention_key.
    Each key can have multiple hierarchies, e.g. "llm.model", which will be accessed by
    module.llm.model

    Args:
        module: top-level model wrapper to patch in place.
        cost: computational budget passed to :class:`FrameFusion`.
        similarity_lower_bound: similarity threshold passed to :class:`FrameFusion`.
        ratio_lower_bound: minimum token-keep ratio passed to :class:`FrameFusion`.
        llm_forward: replacement forward bound onto the LLM found at ``llm_key``.
        decoder_forward: replacement forward bound onto each decoder layer.
        attention_forward: replacement forward bound onto each attention module.
        llm_key: dotted attribute path from ``module`` to the LLM.
        decoder_key: dotted attribute path from the LLM to its decoder layers.
        attention_key: dotted attribute path from a decoder layer to its attention.

    Raises:
        AssertionError: if any resolved attribute has an unexpected type.
    """
    framefusion = FrameFusion(cost, similarity_lower_bound, ratio_lower_bound)

    # Share one FrameFusion instance across the wrapper, the LLM, every decoder
    # layer, and every attention module so they all see consistent state.
    module.framefusion = framefusion

    llm = get_attr_by_name(module, llm_key)
    assert isinstance(llm, PreTrainedModel), f"{llm_key} is not a PreTrainedModel"

    llm.framefusion = framefusion
    llm.forward = MethodType(llm_forward, llm)

    decoder_layers = get_attr_by_name(llm, decoder_key)
    for i, decoder_layer in enumerate(decoder_layers):
        assert isinstance(decoder_layer, nn.Module), f"{decoder_key}[{i}] is not a nn.Module"

        decoder_layer.framefusion = framefusion
        decoder_layer.forward = MethodType(decoder_forward, decoder_layer)

        # If accelerate attached a dispatch hook, rebind _old_forward to the new
        # implementation and re-install the hook so it wraps the patched forward
        # instead of the original one (ensures the hook is not lost).
        if hasattr(decoder_layer, "_hf_hook"):
            decoder_layer._old_forward = MethodType(decoder_forward, decoder_layer)
            add_hook_to_module(decoder_layer, decoder_layer._hf_hook)

        attention_instance = get_attr_by_name(decoder_layer, attention_key)
        # Fix: the failure message previously hard-coded "self_attn" even when a
        # custom attention_key was supplied; use the actual key for accuracy.
        assert isinstance(attention_instance, nn.Module), f"{decoder_key}[{i}].{attention_key} is not a nn.Module"

        # replace the forward method of the attention layer
        attention_instance.framefusion = framefusion
        attention_instance.forward = MethodType(attention_forward, attention_instance)
Loading