From 9993cdd8ca521d6bf833d1da2c5dd68de88d7b23 Mon Sep 17 00:00:00 2001
From: alic-li <2591709191@qq.com>
Date: Sun, 25 May 2025 13:07:37 +0800
Subject: [PATCH] Add World RWKV Video Understanding~
---
scripts/vidtt.sh | 29 ++++
web/video_web.py | 243 ++++++++++++++++++++++++++++++++
world/dataset.py | 41 ++++++
world/encoder/siglip_encoder.py | 41 +++++-
world/frame_att.py | 73 ++++++++++
5 files changed, 422 insertions(+), 5 deletions(-)
create mode 100644 scripts/vidtt.sh
create mode 100644 web/video_web.py
create mode 100644 world/frame_att.py
diff --git a/scripts/vidtt.sh b/scripts/vidtt.sh
new file mode 100644
index 0000000..caf55ef
--- /dev/null
+++ b/scripts/vidtt.sh
@@ -0,0 +1,29 @@
# Training launcher for World-RWKV video understanding (SigLIP vision encoder).

# Base RWKV checkpoint to fine-tune, output/project directory, and dataset root.
load_model=/home/rwkv/alic-li/WorldRWKV/rwkv-0.pth
proj_dir=rwkv7-0.4b-video-siglip-ocr-base
# LLaVA-Video-178K subset used as training data.
data_file=/home/rwkv/alic-li/video_datasets/LLaVA-Video-178K/1_2_m_nextqa

# RWKV-7 0.4B architecture.
n_layer=24
n_embd=1024

# Vision encoder checkpoint/type and dataset modality (enables the video path in world/dataset.py).
encoder_path="google/siglip2-base-patch16-384"
encoder_type=siglip
data_type=video

# Batch/scheduling settings.
micro_bsz=12
epoch_save=1
epoch_steps=570
ctx_len=2048


# HF_ENDPOINT points encoder downloads at a Hugging Face mirror.
HF_ENDPOINT="https://hf-mirror.com" python world_train.py \
--load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type $data_type \
--vocab_size 65536 \
--n_layer $n_layer --n_embd $n_embd \
--ctx_len $ctx_len --micro_bsz $micro_bsz \
--epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save \
--lr_init 1e-3 --lr_final 0 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 4 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--encoder_path $encoder_path --encoder_type $encoder_type \
--my_testing "x070" --train_step adapter rwkv
\ No newline at end of file
diff --git a/web/video_web.py b/web/video_web.py
new file mode 100644
index 0000000..b19761e
--- /dev/null
+++ b/web/video_web.py
@@ -0,0 +1,243 @@
+import gradio as gr
+from infer.worldmodel import Worldinfer
+from PIL import Image
+import re
+import torch
+import random
+from PIL import Image
+from decord import VideoReader
+from decord import cpu
+import os
+
# Model paths / settings for the web demo.
llm_path = '/home/rwkv/alic-li/WorldRWKV/rwkv7-0.4b-video-siglip-ocr-base/rwkv-0'
encoder_path = 'google/siglip2-base-patch16-384'
encoder_type = 'siglip'

enable_think = False  # when True, chat_fn parses think/answer sections out of replies
# Module-level state shared across Gradio callbacks.
current_video_frames = None  # key frames of the uploaded video (list of PIL images), or None
current_state = None  # RWKV recurrent state carried across conversation turns
first_question = False  # whether the next turn is the first one for this video
# Initialize the model once at startup.
model = Worldinfer(model_path=llm_path, encoder_type=encoder_type, encoder_path=encoder_path)

# Core handler logic for user input.
import html  # used to escape model output before rendering
+
+
def frame_att_generator(video_path, threshold=0.05, min_k=3, max_k=10):
    """Select key frames from a video via per-second sampling + frame differencing.

    Frames are sampled roughly once per second; consecutive sampled frames are
    compared by mean absolute pixel difference and frames whose difference
    exceeds ``threshold`` are kept.  The selection is padded with random frames
    up to ``min_k`` and trimmed to the ``max_k`` most-changed frames.

    Returns a list of PIL.Image.Image key frames.
    """
    vr = VideoReader(video_path, ctx=cpu(0))  # decode on CPU
    fps = vr.get_avg_fps()
    # Guard: int(fps) would be 0 for fps < 1 and crash the stride below.
    sampling_interval = max(1, int(fps))  # ~one frame per second

    frames = []
    frames_flattened = []

    # Decode only the sampled frames (same indices as idx % interval == 0).
    for idx in range(0, len(vr), sampling_interval):
        frame = vr[idx].asnumpy()
        frame_rgb = frame / 255.0  # normalize to [0, 1]
        frame_tensor = torch.tensor(frame_rgb).permute(2, 0, 1).half()
        frames.append(frame_tensor)
        frames_flattened.append(frame_tensor.reshape(-1))  # flattened cache for diffing

    def _to_pil(i):
        # Convert a cached CHW half tensor back to an RGB PIL image.
        return Image.fromarray(
            (frames[i].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
        )

    # Bug fix: the original returned raw tensors here, inconsistent with the
    # PIL-image return type of the normal path.
    if len(frames) <= 1:
        return [_to_pil(i) for i in range(len(frames))]

    # Batched frame differencing: mean |f[i+1] - f[i]| per consecutive pair.
    flattened_tensor = torch.stack(frames_flattened)  # (N, C*H*W)
    diffs = torch.mean(torch.abs(flattened_tensor[1:] - flattened_tensor[:-1]), dim=1)
    selected = [0] + [i + 1 for i, d in enumerate(diffs) if d > threshold]

    if len(selected) < min_k:
        # Too few key frames: pad with random unselected frames.  Cap the count
        # so random.sample never raises when few candidates remain (bug fix).
        candidates = [i for i in range(len(frames)) if i not in selected]
        missing = min(min_k - len(selected), len(candidates))
        selected = sorted(selected + random.sample(candidates, missing))
    elif len(selected) > max_k:
        # Too many: keep frame 0 plus the (max_k - 1) largest differences.
        frame_diffs = [(d.item(), i + 1) for i, d in enumerate(diffs)]
        frame_diffs.sort(reverse=True, key=lambda x: x[0])
        selected = sorted([0] + [idx for _, idx in frame_diffs[:max_k - 1]])

    return [_to_pil(i) for i in selected]
def chat_fn(user_input, chat_history, video=None):
    """Answer one user turn about the currently loaded video.

    Returns ("", updated chat_history) so the caller clears the textbox and
    refreshes the chatbot component.
    """
    global current_video_frames, current_state, first_question

    # A freshly selected video replaces the key frames AND restarts the
    # conversation (the original forgot to reset state/first_question here,
    # leaving stale RWKV state attached to the new video).
    if video is not None:
        current_video_frames = frame_att_generator(video)
        current_state = None
        first_question = True

    # No frames yet: ask the user to upload a video first.
    if current_video_frames is None or len(current_video_frames) == 0:
        bot_response = "请先上传一个视频!"
        chat_history.append((user_input, bot_response))
        return "", chat_history

    # Prompt template (\x16/\x17 are the model's turn delimiters).
    prompt = f'\x16User: {user_input}\x17Assistant:'

    try:
        if first_question:
            # First turn: feed the first key frame as the visual context.
            result, state = model.generate(prompt, current_video_frames[0], state=None)
        else:
            # Later turns reuse the recurrent state; no new image is passed.
            result, state = model.generate(prompt, 'none', state=current_state)

        first_question = False
        bot_response, current_state = result, state
        if enable_think:
            # NOTE(review): the original regex patterns were garbled (r'' and
            # r'(.*?)'), apparently by HTML-tag stripping; reconstructed as
            # <think>/<answer> tag parsing — confirm against the model's
            # actual output format.
            think_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL)
            think_matches = think_pattern.findall(bot_response)

            answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
            answer_matches = answer_pattern.findall(bot_response)

            # Build the rendered reply: think sections first, then answers.
            final_response = ""
            for match in think_matches:
                # NOTE(review): this line was broken across two source lines
                # (syntax error in the patch); rejoined minimally.
                final_response += f"Think 🤔\n{html.escape(match)} "

            for match in answer_matches:
                final_response += "Answer 💡"
                final_response += "\n"
                final_response += html.escape(match)

            bot_response = final_response

    except Exception as e:
        bot_response = f"生成回复时出错: {str(e)}"
        current_state = None  # reset state on failure

    # Append the turn and return updated component values.
    chat_history.append((user_input, bot_response))
    return "", chat_history
def update_video(video_path):
    """Extract key frames for a newly uploaded video and reset the chat state."""
    global current_video_frames, current_state, first_question
    if video_path is None:
        return "视频上传失败,请重新上传。"
    current_video_frames = frame_att_generator(video_path)  # key-frame list
    current_state = None
    first_question = True
    return "视频已上传成功!可以开始提问了。"
+
+# 清空图片
def clear_image():
    """Drop the cached key frames and recurrent state; return cleared UI values."""
    global current_state, current_video_frames
    current_video_frames, current_state = None, None
    return None, "图片已清除,请上传新图片。"
+
+# 清空历史和图片
def clear_all():
    """Reset frames and model state; return cleared chatbot/input/status values."""
    global current_video_frames, current_state
    current_video_frames = current_state = None
    return [], "", "图片和对话已清空,请重新上传图片。"
+
def chat_without_video_update(user_input, chat_history):
    """Textbox-submit handler: run a chat turn without re-reading the video input."""
    return chat_fn(user_input, chat_history, video=None)
+
# UI layout components
with gr.Blocks(title="WORLD RWKV", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WORLD RWKV")
    gr.Markdown("上传一个视频,然后可以进行多轮提问")

    with gr.Row():
        # Left column: video upload area
        with gr.Column(scale=2):
            video_input = gr.Video(
                label="上传视频",
                height=400
            )

            # Video status readout and actions
            with gr.Row():
                video_status = gr.Textbox(
                    label="视频状态",
                    value="请上传视频",
                    interactive=False
                )
                clear_video_btn = gr.Button("删除视频")

        # Right column: chat area
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="对话记录",
                bubble_full_width=False,
                height=500
            )

            # Control area
            with gr.Row():
                # Question input component
                user_input = gr.Textbox(
                    placeholder="请输入问题...",
                    scale=7,
                    container=False,
                    label="问题输入"
                )

                # Action buttons
                with gr.Column(scale=1):
                    submit_btn = gr.Button("发送", variant="primary")
                    clear_btn = gr.Button("清空所有")
    # Event bindings
    # Video upload: extract key frames and reset conversation state.
    video_input.change(
        fn=update_video,
        inputs=[video_input],
        outputs=[video_status]
    )

    # "Delete video" button: lambda returns exactly the two output values.
    clear_video_btn.click(
        fn=lambda: (None, "视频已清除,请上传新视频。"),
        inputs=None,
        outputs=[video_input, video_status]
    )

    # "Send" button: full chat turn; may pick up a newly selected video.
    submit_btn.click(
        fn=chat_fn,
        inputs=[user_input, chatbot, video_input],
        outputs=[user_input, chatbot]
    )

    # Enter key in the textbox: variant that does not re-read the video input.
    user_input.submit(
        fn=chat_without_video_update,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot]
    )

    # "Clear all" button: reset chat history, input box, status, and video.
    clear_btn.click(
        fn=lambda: ([], "", "图片和对话已清空,请重新上传图片。", None),
        inputs=None,
        outputs=[chatbot, user_input, video_status, video_input],
        queue=False
    )

# Launch the app (local only: 127.0.0.1:7860)
if __name__ == "__main__":
    demo.launch(server_name="127.0.0.1", server_port=7860)
\ No newline at end of file
diff --git a/world/dataset.py b/world/dataset.py
index 255eeaf..73fa060 100644
--- a/world/dataset.py
+++ b/world/dataset.py
@@ -18,6 +18,7 @@
pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")
from PIL import Image
+from .frame_att import frame_att_generator
import pandas as pd
import librosa
import io
@@ -32,6 +33,28 @@
transforms.ToTensor() # 将图像转换为张量
])
def process_video(self, video_file):
    """Return the key frames of ``video_file`` as a list of PIL images.

    Module-level helper: ``self`` is the dataset instance, passed explicitly
    (it is only needed by the commented-out shareVideoGPTV branch below).
    """
    try:
        if "shareVideoGPTV" in video_file:
            # # shareVideoGPTV-style data: read frames directly from image files
            # frame_files = sorted([os.path.join(video_file, f) for f in os.listdir(video_file)])
            # num_frames_to_sample = getattr(self.args, 'frames_upbound', 10)
            # sampled_indices = np.linspace(0, len(frame_files) - 1, num_frames_to_sample, dtype=int)

            # NOTE(review): the loading code for this branch is commented out,
            # so shareVideoGPTV samples currently yield an EMPTY frame list —
            # confirm whether this path is ever hit in training data.
            video_list = []
            # for idx in sampled_indices:
            #     frame = Image.open(frame_files[idx]).convert("RGB")
            #     video_list.append(frame)
        else:
            # Key-frame selection via frame differencing (world/frame_att.py).
            video_list = frame_att_generator(video_file, threshold=0.05, min_k=3, max_k=10)

        return video_list  # list of key-frame PIL images

    except Exception as e:
        print(f"Error loading video {video_file}: {e}")
        raise
+
def process_conversation_text(conversations):
conversation_text = f"\x16"
@@ -56,10 +79,12 @@ def process_tokens(conversations):
if role == 'human':
question = f"\x16User: {content}\x17"
+ # print(question, "\n")
input = torch.tensor(pipeline.encode(question))
label = torch.full_like(input, -100)
elif role in ['assistant', 'gpt']:
answer = f"\x16Assistant: {content}\x17"
+ # print(answer, "\n")
input= torch.tensor(pipeline.encode(answer))
label = input
inputs.append(input)
@@ -199,6 +224,13 @@ def list_subdirectories(base_path):
with jsonlines.open(args.data_file) as file:
self.data = list(file)
+
+ elif args.data_type=='video':
+ import jsonlines
+ # with open(f'{args.data_file}/chat.json', 'r', encoding='utf-8') as file:
+ # self.data = json.load(file)
+ with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
+ self.data = list(file)
else:
self.data = pd.read_parquet(args.data_file)
@@ -234,6 +266,15 @@ def __getitem__(self, idx):
image = Image.open(mod_path).convert('RGB')
sign = image
text_tokens, text_labels = process_tokens(conversation_text)
+ elif args.data_type == 'video':
+
+ vid_name = self.data[idx]['video']
+ conversation_text = self.data[idx]['conversations']
+
+ vid_path = f'{args.data_file}/{vid_name}'
+ sign = process_video(self, vid_path)
+ text_tokens, text_labels = process_tokens(conversation_text)
+ # print(sign,"\n")
return sign, text_tokens, text_labels
# if args.data_type =='wav':
diff --git a/world/encoder/siglip_encoder.py b/world/encoder/siglip_encoder.py
index 62b7b20..16f297b 100644
--- a/world/encoder/siglip_encoder.py
+++ b/world/encoder/siglip_encoder.py
@@ -38,7 +38,25 @@ def forward(self, x):
x = self.proj(x)
return x + self.pre_norm(x)
class VideoAdapter(nn.Module):
    """Temporal downsampler for video token sequences.

    Applies two depthwise-separable 1-D convolution stages, each with stride 2,
    reducing a (B, T, C) sequence to roughly (B, T/4, C).
    """

    def __init__(self, input_dim=1024):
        super().__init__()
        self.input_dim = input_dim

        stages = []
        for _ in range(2):  # each stage halves the sequence length
            stages.append(
                nn.Conv1d(input_dim, input_dim, kernel_size=3, stride=2,
                          padding=1, groups=input_dim)  # depthwise, stride 2
            )
            stages.append(nn.Conv1d(input_dim, input_dim, kernel_size=1))  # pointwise mix
            stages.append(nn.ReLU())
        self.reduce_conv = nn.Sequential(*stages)

    def forward(self, x):
        # Conv1d expects (B, C, T): transpose in, convolve, transpose back.
        return self.reduce_conv(x.transpose(1, 2)).transpose(1, 2)
class SiglipEncoder(nn.Module):
@@ -57,10 +75,23 @@ def __init__(
self.encoder_dim = 768 #self.model.config.hidden_size
self.adapter = VisualAdapter(self.encoder_dim, project_dim)
+ self.VideoAdapter = VideoAdapter()
def forward(self, x):
+ if isinstance(x, list): # 输入图像列表,用于视频理解
+ img_tensor_list = []
+ for frame in x:
+ frame = torch.from_numpy(self.image_processor(frame)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
+ frame = self.model(frame.unsqueeze(0), output_hidden_states=True).last_hidden_state
+ img_tensor_list.append(frame)
+ out_put = torch.cat(img_tensor_list, dim=1)
+ out_put = self.adapter(out_put)
+ out_put = self.VideoAdapter(out_put)
+ # print(out_put.shape)
+ return out_put
- x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
- x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
- x = self.adapter(x)
-
- return x
\ No newline at end of file
+ else: # 默认处理单个图像
+ x = torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
+ x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
+ x = self.adapter(x)
+ return x
+
\ No newline at end of file
diff --git a/world/frame_att.py b/world/frame_att.py
new file mode 100644
index 0000000..ac53dc9
--- /dev/null
+++ b/world/frame_att.py
@@ -0,0 +1,73 @@
+import torch
+import random
+from PIL import Image
+from decord import VideoReader
+from decord import cpu
+import os
+
def frame_att_generator(video_path, threshold=0.05, min_k=3, max_k=10):
    """Select key frames from a video via per-second sampling + frame differencing.

    Frames are sampled roughly once per second; consecutive sampled frames are
    compared by mean absolute pixel difference and frames whose difference
    exceeds ``threshold`` are kept.  The selection is padded with random frames
    up to ``min_k`` and trimmed to the ``max_k`` most-changed frames.

    Returns a list of PIL.Image.Image key frames.
    """
    vr = VideoReader(video_path, ctx=cpu(0))  # decode on CPU
    fps = vr.get_avg_fps()
    # Guard: int(fps) would be 0 for fps < 1 and crash the stride below.
    sampling_interval = max(1, int(fps))  # ~one frame per second

    frames = []
    frames_flattened = []

    # Decode only the sampled frames (same indices as idx % interval == 0).
    for idx in range(0, len(vr), sampling_interval):
        frame = vr[idx].asnumpy()
        frame_rgb = frame / 255.0  # normalize to [0, 1]
        frame_tensor = torch.tensor(frame_rgb).permute(2, 0, 1).half()
        frames.append(frame_tensor)
        frames_flattened.append(frame_tensor.reshape(-1))  # flattened cache for diffing

    def _to_pil(i):
        # Convert a cached CHW half tensor back to an RGB PIL image.
        return Image.fromarray(
            (frames[i].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
        )

    # Bug fix: the original returned raw tensors here, inconsistent with the
    # PIL-image return type of the normal path.
    if len(frames) <= 1:
        return [_to_pil(i) for i in range(len(frames))]

    # Batched frame differencing: mean |f[i+1] - f[i]| per consecutive pair.
    flattened_tensor = torch.stack(frames_flattened)  # (N, C*H*W)
    diffs = torch.mean(torch.abs(flattened_tensor[1:] - flattened_tensor[:-1]), dim=1)
    selected = [0] + [i + 1 for i, d in enumerate(diffs) if d > threshold]

    if len(selected) < min_k:
        # Too few key frames: pad with random unselected frames.  Cap the count
        # so random.sample never raises when few candidates remain (bug fix).
        candidates = [i for i in range(len(frames)) if i not in selected]
        missing = min(min_k - len(selected), len(candidates))
        selected = sorted(selected + random.sample(candidates, missing))
    elif len(selected) > max_k:
        # Too many: keep frame 0 plus the (max_k - 1) largest differences.
        frame_diffs = [(d.item(), i + 1) for i, d in enumerate(diffs)]
        frame_diffs.sort(reverse=True, key=lambda x: x[0])
        selected = sorted([0] + [idx for _, idx in frame_diffs[:max_k - 1]])

    return [_to_pil(i) for i in selected]
+
if __name__ == "__main__":
    import time

    # Quick manual benchmark: extract key frames and dump them as JPEGs.
    video_path = "/home/alic-li/Videos/VID_20241030121322.mp4"
    output_folder = "./output_frames"
    os.makedirs(output_folder, exist_ok=True)

    t0 = time.time()
    key_frames_pil = frame_att_generator(video_path)
    elapsed = time.time() - t0

    total_frames = len(key_frames_pil)
    processing_time = elapsed
    fps = total_frames / processing_time
    print(f"共处理 {total_frames} 帧,耗时 {processing_time:.2f} 秒,处理速度:{fps:.2f} FPS")

    for i, img in enumerate(key_frames_pil):
        img.save(os.path.join(output_folder, f"key_frame_{i}.jpg"))
\ No newline at end of file