"""Render DWPose keypoint files into pose videos, one per input video.

For every ``.mp4`` found under ``--video_dir`` the matching ``.npy``
keypoint file is looked up under ``--pose_dir``, each stored pose is drawn
with ``draw_pose``, and the frames are written as an ``.mp4`` under
``--save_dir``. Videos are processed concurrently in a thread pool.
"""
# NOTE(review): `argparse`, `os` and `cv2` are used below but their import
# lines sit above the visible region of the patch; they are restated here so
# the module is self-contained -- confirm against the real file header.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

from pose.script.dwpose import draw_pose
from pose.script.tool import save_videos_from_pil

_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _scaled_even_size(width, height, short_edge):
    """Return ``(h, w)`` scaled so the short edge is ``short_edge``, both even."""
    scale = short_edge / min(width, height)
    # Floor-divide by 2 then multiply by 2 to round each side down to an even
    # number of pixels.
    return int(scale * height // 2 * 2), int(scale * width // 2 * 2)


def get_video_info(video_path):
    """Retrieve video properties such as width, height, and fps.

    Args:
        video_path: Path of the video file to inspect.

    Returns:
        A ``(width, height, fps)`` tuple of ints; fps is rounded to nearest.

    Raises:
        IOError: if the video cannot be opened or reports a non-positive
            frame size.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(np.around(cap.get(cv2.CAP_PROP_FPS)))
    finally:
        # Release the capture on the failure path as well.
        cap.release()

    if width <= 0 or height <= 0:
        # A zero dimension would otherwise surface later as a bare
        # ZeroDivisionError when the render scale is computed.
        raise IOError(
            f"Invalid frame size {width}x{height} reported for: {video_path}"
        )
    return width, height, fps


def draw_dwpose(video_path, pose_path, out_path, draw_face):
    """Draw every stored pose of one video and save the result as a video.

    Args:
        video_path: Source video; only its width, height and fps are read.
        pose_path: ``.npy`` file holding the sequence of poses.
        out_path: Destination path of the rendered pose video.
        draw_face: Forwarded to ``draw_pose``.
    """
    width, height, fps = get_video_info(video_path)

    # Poses are drawn with a 1024-pixel short edge and then downscaled to a
    # 768-pixel short edge before saving.
    h_render, w_render = _scaled_even_size(width, height, 1024)
    h_save, w_save = _scaled_even_size(width, height, 768)

    # SECURITY: allow_pickle=True can execute arbitrary code while loading;
    # only use keypoint files from a trusted source.
    poses = np.load(pose_path, allow_pickle=True).tolist()

    frames = []
    for pose in poses:
        detected_map = draw_pose(pose, h_render, w_render, draw_face)
        detected_map = cv2.resize(
            detected_map, (w_save, h_save), interpolation=cv2.INTER_AREA
        )
        # Swap channel order before handing the array to PIL (presumably
        # draw_pose returns a BGR image -- TODO confirm).
        detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(detected_map))

    save_videos_from_pil(frames, out_path, fps)


def process_video(video_path, video_dir, pose_dir, save_dir, draw_face):
    """Processes a single video by drawing poses and saving the output.

    The video is skipped (with a message) when its keypoint file is missing
    or when the output video already exists.

    Args:
        video_path: Path of the video, located somewhere under ``video_dir``.
        video_dir: Root directory used to compute the relative name.
        pose_dir: Root directory of the ``.npy`` keypoint files.
        save_dir: Root directory for rendered videos.
        draw_face: Forwarded to ``draw_dwpose``.
    """
    video_name = os.path.relpath(video_path, video_dir)
    base_name = os.path.splitext(video_name)[0]

    pose_path = os.path.join(pose_dir, base_name + '.npy')
    if not os.path.exists(pose_path):
        print(f'No keypoint file found for: {pose_path}')
        return

    out_path = os.path.join(save_dir, base_name + '.mp4')
    if os.path.exists(out_path):
        print(f'Already rendered pose video: {out_path}')
        return

    # Videos are collected recursively, so base_name may contain
    # sub-directories that do not exist under save_dir yet.
    out_parent = os.path.dirname(out_path)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)

    draw_dwpose(video_path, pose_path, out_path, draw_face)
    print(f"Processed: {video_path}")


def main():
    """Parse command-line arguments and render pose videos in parallel."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_dir", type=str, default="./UBC_fashion/test",
        help='Path to the directory containing video files',
    )
    parser.add_argument(
        "--pose_dir", type=str,
        help='Directory containing pose keypoints; '
             'defaults to <video_dir>_dwpose_keypoints',
    )
    parser.add_argument(
        "--save_dir", type=str,
        help='Directory to save output videos; auto-created, defaults to '
             '<video_dir>_dwpose or <video_dir>_dwpose_without_face',
    )
    # type=bool would treat "--draw_face False" as True; parse the text
    # instead. A bare "--draw_face" (no value) enables face drawing.
    parser.add_argument(
        "--draw_face", type=_str2bool, nargs="?", const=True, default=False,
        help='Whether to draw face keypoints or not',
    )
    parser.add_argument(
        "--num_workers", type=int, default=4,
        help="Number of workers for parallel video processing",
    )
    args = parser.parse_args()

    video_dir = args.video_dir
    pose_dir = args.pose_dir or f"{video_dir}_dwpose_keypoints"

    # An explicit --save_dir always wins. Written as separate branches because
    # "a or b if c else d" parses as "(a or b) if c else d", which would drop
    # the user's directory whenever draw_face is false.
    if args.save_dir:
        save_dir = args.save_dir
    elif args.draw_face:
        save_dir = f"{video_dir}_dwpose"
    else:
        save_dir = f"{video_dir}_dwpose_without_face"

    os.makedirs(save_dir, exist_ok=True)

    # Sorted so the processing order does not depend on os.walk's
    # directory-listing order.
    video_mp4_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(video_dir)
        for file in files
        if file.endswith(".mp4")
    )
    print(f"Found {len(video_mp4_paths)} video(s)")

    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        futures = [
            executor.submit(
                process_video, video_path, video_dir, pose_dir, save_dir,
                args.draw_face,
            )
            for video_path in video_mp4_paths
        ]

        # result() re-raises any exception raised inside a worker.
        for future in tqdm(
            futures, total=len(video_mp4_paths), desc="Processing videos"
        ):
            future.result()

    print('All videos processed successfully!')


if __name__ == "__main__":
    main()