Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions clip_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,46 @@ def run_videomama(
traceback.print_exc()


def _read_frame(cap, files, path, index, is_linear):
    """
    Read one frame from either an open video capture or an image sequence
    on disk and return it as an RGB float array.

    Isolates OpenCV disk I/O logic from the main processing pipeline.

    Args:
        cap: An open cv2.VideoCapture, or None to read from `files` instead.
        files: Sorted list of image filenames (used only when `cap` is None).
        path: Directory containing `files`.
        index: Index into `files` (ignored when reading from `cap`).
        is_linear: Accepted for interface compatibility but currently
            unused — no colorspace conversion is performed in this helper.
            NOTE(review): confirm whether linear/sRGB conversion is meant
            to happen here or in the caller.

    Returns:
        (H, W, 3) float32 RGB ndarray, or None on any read failure.
        8-bit sources are scaled to [0, 1]; EXR sources keep their raw
        (presumably linear) values, clamped to be non-negative.
    """
    if cap:
        ret, frame = cap.read()
        if not ret:
            return None
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return img_rgb.astype(np.float32) / 255.0

    # An out-of-range index is treated like any other read failure so the
    # caller's "None means stop/skip" contract holds instead of raising.
    if index < 0 or index >= len(files):
        return None

    fpath = os.path.join(path, files[index])

    if fpath.lower().endswith(".exr"):
        # EXRs carry float data; IMREAD_UNCHANGED preserves it as stored.
        img_linear = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)
        if img_linear is None:
            return None
        img_linear_rgb = cv2.cvtColor(img_linear, cv2.COLOR_BGR2RGB)
        # Clamp negatives and normalize dtype so every branch returns float32.
        return np.maximum(img_linear_rgb, 0.0).astype(np.float32)

    img_bgr = cv2.imread(fpath)
    if img_bgr is None:
        return None
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return img_rgb.astype(np.float32) / 255.0


def run_inference(
clips,
device=None,
backend=None,
max_frames=None,
start_frame: int = 0,
settings: InferenceSettings | None = None,
*,
on_clip_start: Callable[[str, int], None] | None = None,
on_frame_complete: Callable[[int, int], None] | None = None,
engine_override=None,
):
ready_clips = [c for c in clips if c.input_asset and c.alpha_asset]

Expand All @@ -623,9 +654,12 @@ def run_inference(

if device is None:
device = resolve_device()
from CorridorKeyModule.backend import create_engine

engine = create_engine(backend=backend, device=device)
engine = engine_override
if engine is None:
from CorridorKeyModule.backend import create_engine

engine = create_engine(backend=backend, device=device)

for clip in ready_clips:
logger.info(f"Running Inference on: {clip.name}")
Expand All @@ -640,16 +674,23 @@ def run_inference(
for d in [fg_dir, matte_dir, comp_dir, proc_dir]:
os.makedirs(d, exist_ok=True)

num_frames = min(clip.input_asset.frame_count, clip.alpha_asset.frame_count)
total_available = min(clip.input_asset.frame_count, clip.alpha_asset.frame_count)
num_frames = total_available
if max_frames is not None:
num_frames = min(num_frames, max_frames)

actual_processing_frames = max(0, num_frames - start_frame)

logger.info(
f" Input frames: {clip.input_asset.frame_count},"
f" Alpha frames: {clip.alpha_asset.frame_count} -> Processing {num_frames} frames"
f" Alpha frames: {clip.alpha_asset.frame_count} -> Processing from frame {start_frame} up to {num_frames}"
)

if num_frames == 0:
logger.warning(f"Clip '{clip.name}': 0 frames to process, skipping.")
if actual_processing_frames <= 0:
logger.warning(
f"Clip '{clip.name}': 0 frames to process "
f"(start_frame={start_frame} >= num_frames={num_frames}), skipping."
)
continue

input_cap = None
Expand All @@ -659,18 +700,24 @@ def run_inference(

if clip.input_asset.type == "video":
input_cap = cv2.VideoCapture(clip.input_asset.path)
# Advance to start_frame
if start_frame > 0:
input_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
else:
input_files = sorted([f for f in os.listdir(clip.input_asset.path) if is_image_file(f)])

if clip.alpha_asset.type == "video":
alpha_cap = cv2.VideoCapture(clip.alpha_asset.path)
# Advance to start_frame
if start_frame > 0:
alpha_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
else:
alpha_files = sorted([f for f in os.listdir(clip.alpha_asset.path) if is_image_file(f)])

if on_clip_start:
on_clip_start(clip.name, num_frames)
on_clip_start(clip.name, actual_processing_frames)

for i in range(num_frames):
for i in range(start_frame, num_frames):
# 1. Read Input
img_srgb = None
input_stem = f"{i:05d}"
Expand All @@ -679,14 +726,12 @@ def run_inference(
input_is_linear = settings.input_is_linear

if input_cap:
ret, frame = input_cap.read()
if not ret:
img_srgb = _read_frame(input_cap, [], clip.input_asset.path, 0, input_is_linear)
if img_srgb is None:
break
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img_srgb = img_rgb.astype(np.float32) / 255.0
input_stem = f"{i:05d}"
else:
fpath = os.path.join(clip.input_asset.path, input_files[i])
img_srgb = _read_frame(None, input_files, clip.input_asset.path, i, input_is_linear)
input_stem = os.path.splitext(input_files[i])[0]

is_exr = fpath.lower().endswith(".exr")
Expand Down
13 changes: 10 additions & 3 deletions corridorkey_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ def run_inference_cmd(
Optional[int],
typer.Option("--max-frames", help="Limit frames per clip"),
] = None,
start_frame: Annotated[
int,
typer.Option("--start-frame", help="Start inference from this frame index"),
] = 0,
linear: Annotated[
Optional[bool],
typer.Option("--linear/--srgb", help="Input colorspace (default: prompt)"),
Expand Down Expand Up @@ -280,7 +284,9 @@ def run_inference_cmd(
# despeckle_size excluded — sensible default even in headless mode
required_flags_set = all(v is not None for v in [linear, despill, despeckle, refiner])
if required_flags_set:
assert linear is not None and despill is not None and despeckle is not None and refiner is not None
if any(v is None for v in [linear, despill, despeckle, refiner]):
raise ValueError("Missing required flags for inference settings.")

despill_clamped = max(0, min(10, despill))
settings = InferenceSettings(
input_is_linear=linear,
Expand All @@ -304,6 +310,7 @@ def run_inference_cmd(
device=ctx.obj["device"],
backend=backend,
max_frames=max_frames,
start_frame=start_frame,
settings=settings,
on_clip_start=ctx_progress.on_clip_start,
on_frame_complete=ctx_progress.on_frame_complete,
Expand Down Expand Up @@ -439,8 +446,8 @@ def interactive_wizard(win_path: str, device: str | None = None) -> None:
entry = ClipEntry(os.path.basename(d), d)
try:
entry.find_assets()
except (FileNotFoundError, ValueError, OSError):
pass
except (FileNotFoundError, ValueError, OSError) as e:
logger.debug(f"Skipping clip setup for '{d}': {e}")

has_mask = False
mask_dir = os.path.join(d, "VideoMamaMaskHint")
Expand Down
Loading
Loading