-
Notifications
You must be signed in to change notification settings - Fork 39
Open
Description
Hi, and thank you for making this code available!
I am running:
python video_infer.py D:\\DeBlur\\SimDeblur-main\\SimDeblur-main\\configs\\dbn\\dbn_gopro.yaml D:\\DeBlur\\SimDeblur-main\\SimDeblur-main\\saves\\checkpoints\\DBN\\dbn_ckpt.pth --frames_folder_path=datasets/input --save_path=deblur
where video_infer.py is:
import os
import sys
import argparse
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
from simdeblur.config import build_config
from simdeblur.model import build_backbone
from simdeblur.dataset.frames_folder import FramesFolder
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("config", type=str, help="The config .yaml file of deblurring model. ")
parser.add_argument("ckpt", type=str, help="The trained checkpoint of the selected deblurring model. ")
parser.add_argument("--frames_folder_path", type=str, help="The video frames folder path. ")
parser.add_argument("--save_path", type=str, help="The output deblurred path")
args = parser.parse_args()
return args
def frames_foler_demo():
args = parse_args()
config = build_config(args.config)
config.args = args
model = build_backbone(config.model).cuda()
ckpt = torch.load(args.ckpt, map_location="cuda:0")
model_ckpt = ckpt["model"]
model_ckpt = {k[7:]: v for k, v in model_ckpt.items()}
model.load_state_dict(model_ckpt)
data_config = edict({
"root_input": "D:\\DeBlur\\SimDeblur-main\\SimDeblur-main\\datasets\\input",
"num_frames": 5,
"overlapping": True,
"sampling": "n_c"
})
frames_data = FramesFolder(data_config)
frames_dataloader = torch.utils.data.DataLoader(frames_data, 1)
model.eval()
with torch.no_grad():
for i, batch_data in enumerate(frames_dataloader):
out = model(batch_data["input_frames"].cuda())
print(batch_data["gt_names"], out.shape)
save_image_path = f"./deblurred_frame{i}.png"
cv2.imwrite(save_image_path, out[0].cpu().permute(1, 2, 0).numpy())
if __name__ == "__main__":
frames_foler_demo()
The code runs, and prints:
Cannot inport EDVR modules!!!
[('00002.jpg',)] torch.Size([1, 3, 1080, 1920])
[('00003.jpg',)] torch.Size([1, 3, 1080, 1920])
....
But all the resulting frames are just solid black. What might be happening here?
Thanks!
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels