diff --git a/scripts/vidtt.sh b/scripts/vidtt.sh
new file mode 100644
index 0000000..caf55ef
--- /dev/null
+++ b/scripts/vidtt.sh
@@ -0,0 +1,29 @@
+load_model=/home/rwkv/alic-li/WorldRWKV/rwkv-0.pth
+proj_dir=rwkv7-0.4b-video-siglip-ocr-base
+data_file=/home/rwkv/alic-li/video_datasets/LLaVA-Video-178K/1_2_m_nextqa
+
+n_layer=24
+n_embd=1024
+
+encoder_path="google/siglip2-base-patch16-384"
+encoder_type=siglip
+data_type=video
+
+micro_bsz=12
+epoch_save=1
+epoch_steps=570
+ctx_len=2048
+
+
+HF_ENDPOINT="https://hf-mirror.com" python world_train.py \
+--load_model $load_model \
+--proj_dir $proj_dir --data_file $data_file \
+--data_type $data_type \
+--vocab_size 65536 \
+--n_layer $n_layer --n_embd $n_embd \
+--ctx_len $ctx_len --micro_bsz $micro_bsz \
+--epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save \
+--lr_init 1e-3 --lr_final 0 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+--accelerator gpu --devices 4 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
+--encoder_path $encoder_path --encoder_type $encoder_type \
+--my_testing "x070" --train_step adapter rwkv
diff --git a/web/video_web.py b/web/video_web.py
new file mode 100644
index 0000000..b19761e
--- /dev/null
+++ b/web/video_web.py
@@ -0,0 +1,193 @@
+import gradio as gr
+from infer.worldmodel import Worldinfer
+from PIL import Image
+import re
+import torch
+import random
+from decord import VideoReader
+from decord import cpu
+import os
+import html
+
+# Model configuration.
+llm_path = '/home/rwkv/alic-li/WorldRWKV/rwkv7-0.4b-video-siglip-ocr-base/rwkv-0'
+encoder_path = 'google/siglip2-base-patch16-384'
+encoder_type = 'siglip'
+
+enable_think = False
+
+# Globals shared across Gradio callbacks.
+current_video_frames = None  # list of PIL key frames of the current video
+current_state = None  # recurrent model state carried across turns
+first_question = True  # True until the first turn has been answered
+
+model = Worldinfer(model_path=llm_path, encoder_type=encoder_type, encoder_path=encoder_path)
+
+
+def frame_att_generator(video_path, threshold=0.05, min_k=3, max_k=10):
+    """Extract key frames from a video via frame-difference attention.
+
+    Samples roughly one frame per second, keeps frames whose mean absolute
+    difference to the previous sample exceeds ``threshold``, and clamps the
+    number of key frames to ``[min_k, max_k]``.
+
+    Returns a list of PIL.Image.Image key frames.
+    NOTE(review): duplicated from world/frame_att.py; keep the two in sync.
+    """
+    vr = VideoReader(video_path, ctx=cpu(0))  # decode on CPU
+    fps = vr.get_avg_fps()
+    # One frame per second; max() guards against fps < 1 (zero interval).
+    sampling_interval = max(1, int(fps))
+
+    frames = []
+    frames_flattened = []
+    for idx in range(0, len(vr), sampling_interval):
+        frame_rgb = vr[idx].asnumpy() / 255.0  # normalise to [0, 1]
+        frame_tensor = torch.tensor(frame_rgb).permute(2, 0, 1).half()
+        frames.append(frame_tensor)
+        frames_flattened.append(frame_tensor.reshape(-1))
+
+    def to_pil(t):
+        return Image.fromarray((t.permute(1, 2, 0).cpu().numpy() * 255).astype("uint8"))
+
+    # Fix: always return PIL images (previously raw tensors leaked out here).
+    if len(frames) <= 1:
+        return [to_pil(t) for t in frames]
+
+    flattened = torch.stack(frames_flattened)  # (N, C*H*W)
+    diffs = torch.mean(torch.abs(flattened[1:] - flattened[:-1]), dim=1)
+    selected = [0] + [i + 1 for i, d in enumerate(diffs) if d > threshold]
+
+    if len(selected) < min_k:
+        # Too few key frames: top up with random unselected ones, but never
+        # ask random.sample for more items than exist.
+        candidates = [i for i in range(len(frames)) if i not in selected]
+        missing = min(min_k - len(selected), len(candidates))
+        selected = sorted(selected + random.sample(candidates, missing))
+    elif len(selected) > max_k:
+        # Too many: keep frame 0 plus the max_k-1 largest differences.
+        ranked = sorted(((d.item(), i + 1) for i, d in enumerate(diffs)), reverse=True)
+        selected = sorted([0] + [i for _, i in ranked[:max_k - 1]])
+
+    return [to_pil(frames[i]) for i in selected]
+
+
+def chat_fn(user_input, chat_history, video=None):
+    """Run one chat turn against the current video."""
+    global current_video_frames, current_state, first_question
+
+    # A video handed in directly (before the change event fired) starts a
+    # fresh conversation. Fix: do NOT re-extract frames on every message —
+    # video_input is wired into submit_btn, so ``video`` arrives each turn.
+    if video is not None and current_video_frames is None:
+        current_video_frames = frame_att_generator(video)
+        current_state = None
+        first_question = True
+
+    if not current_video_frames:
+        bot_response = "请先上传一个视频!"
+        chat_history.append((user_input, bot_response))
+        return "", chat_history
+
+    prompt = f'\x16User: {user_input}\x17Assistant:'
+
+    try:
+        if first_question:
+            # First turn: condition the model on the first key frame.
+            result, state = model.generate(prompt, current_video_frames[0], state=None)
+        else:
+            result, state = model.generate(prompt, 'none', state=current_state)
+
+        first_question = False
+        bot_response, current_state = result, state
+        if enable_think:
+            # NOTE(review): the original tag patterns were lost (empty regex
+            # in source); <think>/<answer> reconstructed — confirm against
+            # the model's actual output format.
+            think_matches = re.findall(r'<think>(.*?)</think>', bot_response, re.DOTALL)
+            answer_matches = re.findall(r'<answer>(.*?)</answer>', bot_response, re.DOTALL)
+
+            final_response = ""
+            for match in think_matches:
+                final_response += f"Think 🤔 {html.escape(match)}\n"
+            for match in answer_matches:
+                final_response += "Answer 💡\n" + html.escape(match)
+            bot_response = final_response
+
+    except Exception as e:
+        bot_response = f"生成回复时出错: {str(e)}"
+        current_state = None  # discard possibly-corrupt state
+
+    chat_history.append((user_input, bot_response))
+    return "", chat_history
+
+
+def update_video(video_path):
+    """Extract key frames for a newly uploaded video and reset the dialogue."""
+    global current_video_frames, current_state, first_question
+    if video_path is not None:
+        current_video_frames = frame_att_generator(video_path)
+        current_state = None
+        first_question = True
+        return "视频已上传成功!可以开始提问了。"
+    return "视频上传失败,请重新上传。"
+
+
+def clear_image():
+    """Drop the current video frames and model state."""
+    global current_state, current_video_frames
+    current_video_frames = None
+    current_state = None
+    return None, "图片已清除,请上传新图片。"
+
+
+def clear_all():
+    """Reset frames, state and chat history."""
+    global current_video_frames, current_state
+    current_video_frames = None
+    current_state = None
+    return [], "", "图片和对话已清空,请重新上传图片。"
+
+
+def chat_without_video_update(user_input, chat_history):
+    """Enter-key handler: chat without touching the uploaded video."""
+    return chat_fn(user_input, chat_history)
+
+
+with gr.Blocks(title="WORLD RWKV", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# WORLD RWKV")
+    gr.Markdown("上传一个视频,然后可以进行多轮提问")
+
+    with gr.Row():
+        # Left column: video upload and status.
+        with gr.Column(scale=2):
+            video_input = gr.Video(label="上传视频", height=400)
+
+            with gr.Row():
+                video_status = gr.Textbox(label="视频状态", value="请上传视频", interactive=False)
+                clear_video_btn = gr.Button("删除视频")
+
+        # Right column: chat.
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label="对话记录", bubble_full_width=False, height=500)
+
+            with gr.Row():
+                user_input = gr.Textbox(placeholder="请输入问题...", scale=7, container=False, label="问题输入")
+                with gr.Column(scale=1):
+                    submit_btn = gr.Button("发送", variant="primary")
+                    clear_btn = gr.Button("清空所有")
+
+    # Event wiring.
+    video_input.change(fn=update_video, inputs=[video_input], outputs=[video_status])
+    clear_video_btn.click(
+        fn=lambda: (None, "视频已清除,请上传新视频。"),
+        inputs=None,
+        outputs=[video_input, video_status],
+    )
+    submit_btn.click(fn=chat_fn, inputs=[user_input, chatbot, video_input], outputs=[user_input, chatbot])
+    user_input.submit(fn=chat_without_video_update, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
+    clear_btn.click(
+        fn=lambda: ([], "", "图片和对话已清空,请重新上传图片。", None),
+        inputs=None,
+        outputs=[chatbot, user_input, video_status, video_input],
+        queue=False,
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="127.0.0.1", server_port=7860)
diff --git a/world/dataset.py b/world/dataset.py
index 255eeaf..73fa060 100644
--- a/world/dataset.py
+++ b/world/dataset.py
@@ -18,6 +18,7 @@ pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")
 
 from PIL import Image
+from .frame_att import frame_att_generator
 import pandas as pd
 import librosa
 import io
@@ -32,3 +33,21 @@
 transforms.ToTensor()
 ])
 
+def process_video(self, video_file):
+    """Return a list of key-frame PIL images for ``video_file``.
+
+    ``self`` is the owning dataset object, passed explicitly because this
+    helper lives at module level.
+    """
+    try:
+        if "shareVideoGPTV" in video_file:
+            # TODO(review): frame sampling for shareVideoGPTV datasets was
+            # commented out upstream, so this branch yields an EMPTY list —
+            # restore per-frame sampling before training on that data.
+            video_list = []
+        else:
+            video_list = frame_att_generator(video_file, threshold=0.05, min_k=3, max_k=10)
+        return video_list
+    except Exception as e:
+        print(f"Error loading video {video_file}: {e}")
+        raise
+
@@ -199,5 +218,11 @@
         with jsonlines.open(args.data_file) as file:
             self.data = list(file)
 
+        elif args.data_type == 'video':
+            import jsonlines
+            # Conversations are stored next to the videos in chat.jsonl.
+            with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
+                self.data = list(file)
+
         else:
             self.data = pd.read_parquet(args.data_file)
@@ -234,4 +259,12 @@
             image = Image.open(mod_path).convert('RGB')
             sign = image
             text_tokens, text_labels = process_tokens(conversation_text)
+        elif args.data_type == 'video':
+            # Key frames act as the modality 'sign'; text comes from the
+            # paired conversation list.
+            vid_name = self.data[idx]['video']
+            conversation_text = self.data[idx]['conversations']
+            vid_path = f'{args.data_file}/{vid_name}'
+            sign = process_video(self, vid_path)
+            text_tokens, text_labels = process_tokens(conversation_text)
         return sign, text_tokens, text_labels
diff --git a/world/encoder/siglip_encoder.py b/world/encoder/siglip_encoder.py
index 62b7b20..16f297b 100644
--- a/world/encoder/siglip_encoder.py
+++ b/world/encoder/siglip_encoder.py
@@ -38,5 +38,28 @@ def forward(self, x):
         x = self.proj(x)
         return x + self.pre_norm(x)
 
+class VideoAdapter(nn.Module):
+    """Temporal down-sampler: two strided depthwise-separable 1-D conv
+    stages reduce the concatenated frame-token sequence to 1/4 length."""
+
+    def __init__(self, input_dim=1024):
+        super().__init__()
+        self.input_dim = input_dim
+
+        def depthwise_conv():
+            # stride-2 depthwise conv halves the sequence length
+            return nn.Conv1d(input_dim, input_dim, kernel_size=3, stride=2, padding=1, groups=input_dim)
+
+        self.reduce_conv = nn.Sequential(
+            depthwise_conv(), nn.Conv1d(input_dim, input_dim, kernel_size=1), nn.ReLU(),  # 1/2
+            depthwise_conv(), nn.Conv1d(input_dim, input_dim, kernel_size=1), nn.ReLU(),  # 1/4
+        )
+
+    def forward(self, x):
+        # (B, T, C) -> (B, C, T) for Conv1d, then back.
+        x = x.transpose(1, 2)
+        x = self.reduce_conv(x)
+        return x.transpose(1, 2)
 
 class SiglipEncoder(nn.Module):
@@ -57,11 +80,24 @@ def __init__(
         self.encoder_dim = 768 #self.model.config.hidden_size
 
         self.adapter = VisualAdapter(self.encoder_dim, project_dim)
+        # NOTE: attribute name 'VideoAdapter' is part of trained checkpoint
+        # state dicts; do not rename.
+        self.VideoAdapter = VideoAdapter()
 
     def forward(self, x):
+        if isinstance(x, list):  # list of PIL frames => video input
+            frame_feats = []
+            for frame in x:
+                t = torch.from_numpy(self.image_processor(frame)['pixel_values'][0]).to(self.device, dtype=torch.bfloat16)
+                frame_feats.append(self.model(t.unsqueeze(0), output_hidden_states=True).last_hidden_state)
+            # Concatenate per-frame tokens along the sequence axis, project,
+            # then temporally down-sample.
+            out = torch.cat(frame_feats, dim=1)
+            out = self.adapter(out)
+            return self.VideoAdapter(out)
 
-        x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
-        x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
-        x = self.adapter(x)
-
-        return x
+        # Single-image path.
+        x = torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device, dtype=torch.bfloat16)
+        x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
+        return self.adapter(x)
diff --git a/world/frame_att.py b/world/frame_att.py
new file mode 100644
index 0000000..ac53dc9
--- /dev/null
+++ b/world/frame_att.py
@@ -0,0 +1,74 @@
+import torch
+import random
+from PIL import Image
+from decord import VideoReader
+from decord import cpu
+import os
+
+
+def frame_att_generator(video_path, threshold=0.05, min_k=3, max_k=10):
+    """Extract key frames from a video via frame-difference attention.
+
+    Samples roughly one frame per second, keeps frames whose mean absolute
+    difference to the previous sample exceeds ``threshold``, and clamps the
+    number of key frames to ``[min_k, max_k]``.
+
+    Returns a list of PIL.Image.Image key frames.
+    """
+    vr = VideoReader(video_path, ctx=cpu(0))  # decode on CPU
+    fps = vr.get_avg_fps()
+    # One frame per second; max() guards against fps < 1 (zero interval).
+    sampling_interval = max(1, int(fps))
+
+    frames = []
+    frames_flattened = []
+    for idx in range(0, len(vr), sampling_interval):
+        frame_rgb = vr[idx].asnumpy() / 255.0  # normalise to [0, 1]
+        frame_tensor = torch.tensor(frame_rgb).permute(2, 0, 1).half()
+        frames.append(frame_tensor)
+        frames_flattened.append(frame_tensor.reshape(-1))
+
+    def to_pil(t):
+        return Image.fromarray((t.permute(1, 2, 0).cpu().numpy() * 255).astype("uint8"))
+
+    # Fix: always return PIL images (previously raw tensors leaked out here).
+    if len(frames) <= 1:
+        return [to_pil(t) for t in frames]
+
+    flattened = torch.stack(frames_flattened)  # (N, C*H*W)
+    diffs = torch.mean(torch.abs(flattened[1:] - flattened[:-1]), dim=1)
+    selected = [0] + [i + 1 for i, d in enumerate(diffs) if d > threshold]
+
+    if len(selected) < min_k:
+        # Too few key frames: top up with random unselected ones, but never
+        # ask random.sample for more items than exist.
+        candidates = [i for i in range(len(frames)) if i not in selected]
+        missing = min(min_k - len(selected), len(candidates))
+        selected = sorted(selected + random.sample(candidates, missing))
+    elif len(selected) > max_k:
+        # Too many: keep frame 0 plus the max_k-1 largest differences.
+        ranked = sorted(((d.item(), i + 1) for i, d in enumerate(diffs)), reverse=True)
+        selected = sorted([0] + [i for _, i in ranked[:max_k - 1]])
+
+    return [to_pil(frames[i]) for i in selected]
+
+
+if __name__ == "__main__":
+    import time
+
+    video_path = "/home/alic-li/Videos/VID_20241030121322.mp4"
+    output_folder = "./output_frames"
+    os.makedirs(output_folder, exist_ok=True)
+
+    start_time = time.time()
+    key_frames_pil = frame_att_generator(video_path)
+    end_time = time.time()
+
+    total_frames = len(key_frames_pil)
+    processing_time = end_time - start_time
+    fps = total_frames / processing_time
+
+    print(f"共处理 {total_frames} 帧,耗时 {processing_time:.2f} 秒,处理速度:{fps:.2f} FPS")
+
+    for i, img in enumerate(key_frames_pil):
+        img.save(os.path.join(output_folder, f"key_frame_{i}.jpg"))