Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 64 additions & 75 deletions draw_dwpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,109 +4,98 @@
import numpy as np
from tqdm import tqdm
from PIL import Image
from concurrent.futures import ThreadPoolExecutor

from pose.script.tool import save_videos_from_pil
from pose.script.dwpose import draw_pose



def get_video_info(video_path):
    """Read the frame size and frame rate of a video file.

    Args:
        video_path: Path of the video to inspect.

    Returns:
        A ``(width, height, fps)`` tuple of ints. ``fps`` is the frame rate
        reported by OpenCV, rounded to an integer with ``np.around``.

    Raises:
        IOError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(np.around(cap.get(cv2.CAP_PROP_FPS)))
    finally:
        # Release the capture handle on the failure paths too, not only
        # after a successful read of the properties.
        cap.release()
    return width, height, fps
def _scale_to_short_edge(width, height, short_edge):
    """Return ``(h, w)`` scaled so the shorter side is about ``short_edge``.

    Both values are floored to an even number of pixels.
    """
    k = short_edge / min(width, height)
    return int(k * height // 2 * 2), int(k * width // 2 * 2)


def draw_dwpose(video_path, pose_path, out_path, draw_face):
    """Render the stored pose keypoints of one video into a pose video.

    Args:
        video_path: Source video; only its frame size and fps are read.
        pose_path: ``.npy`` file holding one pose entry per frame.
        out_path: Destination path handed to ``save_videos_from_pil``.
        draw_face: Forwarded to ``draw_pose``.

    Raises:
        IOError: If the source video cannot be opened.
        ValueError: If the source video reports a non-positive frame size.
    """
    width, height, fps = get_video_info(video_path)
    if min(width, height) <= 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(
            f"Invalid frame size {width}x{height} for video: {video_path}"
        )

    # Poses are drawn with a 1024 px short edge and then shrunk to a 768 px
    # short edge for the saved video.
    h_render, w_render = _scale_to_short_edge(width, height, 1024)
    h_save, w_save = _scale_to_short_edge(width, height, 768)

    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # loading; only use keypoint files that come from a trusted source.
    poses = np.load(pose_path, allow_pickle=True).tolist()

    frames = []
    for pose in poses:
        detected_map = draw_pose(pose, h_render, w_render, draw_face)
        detected_map = cv2.resize(
            detected_map, (w_save, h_save), interpolation=cv2.INTER_AREA
        )
        # Swap the channel order before handing the array to PIL
        # (presumably draw_pose returns BGR -- confirm against its source).
        detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(detected_map))

    save_videos_from_pil(frames, out_path, fps)


def process_video(video_path, video_dir, pose_dir, save_dir, draw_face):
    """Render the pose video for a single source video, if still needed.

    The keypoint file and the output file are located by mirroring the path
    of ``video_path`` relative to ``video_dir`` under ``pose_dir`` and
    ``save_dir``. Nothing is rendered when the keypoint file is missing or
    the output file already exists; a message is printed instead.

    Args:
        video_path: Path of the source video, located under ``video_dir``.
        video_dir: Root directory of the source videos.
        pose_dir: Root directory of the ``.npy`` keypoint files.
        save_dir: Root directory for the rendered ``.mp4`` files.
        draw_face: Forwarded to ``draw_dwpose``.
    """
    video_name = os.path.relpath(video_path, video_dir)
    base_name = os.path.splitext(video_name)[0]

    pose_path = os.path.join(pose_dir, base_name + '.npy')
    if not os.path.exists(pose_path):
        print(f'No keypoint file found for: {pose_path}')
        return

    out_path = os.path.join(save_dir, base_name + '.mp4')
    if os.path.exists(out_path):
        print(f'Already rendered pose video: {out_path}')
        return

    # base_name keeps any subdirectory of the relative path, so the mirrored
    # output subdirectory may not exist yet; create it before rendering.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    draw_dwpose(video_path, pose_path, out_path, draw_face)
    print(f"Processed: {video_path}")


def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` would treat every non-empty string, including ``"False"``,
    as true, so the text is interpreted explicitly here.
    """
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(argv=None):
    """Render pose videos for every ``.mp4`` found under ``--video_dir``.

    Args:
        argv: Optional list of command-line arguments; when ``None`` the
            arguments are taken from ``sys.argv`` by ``argparse``.
    """
    from concurrent.futures import as_completed

    parser = argparse.ArgumentParser()
    parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test",
                        help='Path to the directory containing video files')
    parser.add_argument("--pose_dir", type=str, default=None,
                        help='Directory containing pose keypoints '
                             '(default: <video_dir>_dwpose_keypoints)')
    parser.add_argument("--save_dir", type=str, default=None,
                        help='Directory to save output videos, created if missing '
                             '(default: <video_dir>_dwpose or '
                             '<video_dir>_dwpose_without_face)')
    # nargs="?" keeps "--draw_face True" working and also accepts a bare
    # "--draw_face" as true.
    parser.add_argument("--draw_face", type=_str_to_bool, nargs="?", const=True,
                        default=False, help='Whether to draw face keypoints or not')
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of workers for parallel video processing")
    args = parser.parse_args(argv)

    video_dir = args.video_dir
    pose_dir = args.pose_dir or f"{video_dir}_dwpose_keypoints"

    # An explicit --save_dir always wins; the draw_face flag only selects
    # which default directory name is derived from video_dir.
    if args.save_dir:
        save_dir = args.save_dir
    elif args.draw_face:
        save_dir = f"{video_dir}_dwpose"
    else:
        save_dir = f"{video_dir}_dwpose_without_face"
    os.makedirs(save_dir, exist_ok=True)

    # Collect every .mp4 below video_dir, sorted for a reproducible order.
    video_mp4_paths = sorted(
        os.path.join(root, file_name)
        for root, _, files in os.walk(video_dir)
        for file_name in files
        if file_name.endswith(".mp4")
    )
    print(f"Found {len(video_mp4_paths)} video(s)")

    if video_mp4_paths:
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = [
                executor.submit(process_video, video_path, video_dir,
                                pose_dir, save_dir, args.draw_face)
                for video_path in video_mp4_paths
            ]
            # Advance the bar as videos finish, in whatever order that is;
            # result() re-raises any exception raised inside a worker.
            for future in tqdm(as_completed(futures), total=len(futures),
                               desc="Processing videos"):
                future.result()

    print('All videos processed successfully!')


if __name__ == "__main__":
    main()