diff --git a/clip_manager.py b/clip_manager.py index 22aa6c9a..90540b45 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -593,15 +593,46 @@ def run_videomama( traceback.print_exc() +def _read_frame(cap, files, path, index, is_linear): + """ + Helper function to safely read an image frame and convert it to sRGB or linear output format. + Isolates OpenCV disk I/O logic from the main processing pipeline. + """ + if cap: + ret, frame = cap.read() + if not ret: + return None + img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + return img_rgb.astype(np.float32) / 255.0 + + fpath = os.path.join(path, files[index]) + is_exr = fpath.lower().endswith(".exr") + + if is_exr: + img_linear = cv2.imread(fpath, cv2.IMREAD_UNCHANGED) + if img_linear is None: + return None + img_linear_rgb = cv2.cvtColor(img_linear, cv2.COLOR_BGR2RGB) + return np.maximum(img_linear_rgb, 0.0) + else: + img_bgr = cv2.imread(fpath) + if img_bgr is None: + return None + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + return img_rgb.astype(np.float32) / 255.0 + + def run_inference( clips, device=None, backend=None, max_frames=None, + start_frame: int = 0, settings: InferenceSettings | None = None, *, on_clip_start: Callable[[str, int], None] | None = None, on_frame_complete: Callable[[int, int], None] | None = None, + engine_override=None, ): ready_clips = [c for c in clips if c.input_asset and c.alpha_asset] @@ -623,9 +654,12 @@ def run_inference( if device is None: device = resolve_device() - from CorridorKeyModule.backend import create_engine - engine = create_engine(backend=backend, device=device) + engine = engine_override + if engine is None: + from CorridorKeyModule.backend import create_engine + + engine = create_engine(backend=backend, device=device) for clip in ready_clips: logger.info(f"Running Inference on: {clip.name}") @@ -640,16 +674,23 @@ def run_inference( for d in [fg_dir, matte_dir, comp_dir, proc_dir]: os.makedirs(d, exist_ok=True) - num_frames = 
min(clip.input_asset.frame_count, clip.alpha_asset.frame_count) + total_available = min(clip.input_asset.frame_count, clip.alpha_asset.frame_count) + num_frames = total_available if max_frames is not None: num_frames = min(num_frames, max_frames) + + actual_processing_frames = max(0, num_frames - start_frame) + logger.info( f" Input frames: {clip.input_asset.frame_count}," - f" Alpha frames: {clip.alpha_asset.frame_count} -> Processing {num_frames} frames" + f" Alpha frames: {clip.alpha_asset.frame_count} -> Processing from frame {start_frame} up to {num_frames}" ) - if num_frames == 0: - logger.warning(f"Clip '{clip.name}': 0 frames to process, skipping.") + if actual_processing_frames <= 0: + logger.warning( + f"Clip '{clip.name}': 0 frames to process " + f"(start_frame={start_frame} >= num_frames={num_frames}), skipping." + ) continue input_cap = None @@ -659,18 +700,24 @@ def run_inference( if clip.input_asset.type == "video": input_cap = cv2.VideoCapture(clip.input_asset.path) + # Advance to start_frame + if start_frame > 0: + input_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) else: input_files = sorted([f for f in os.listdir(clip.input_asset.path) if is_image_file(f)]) if clip.alpha_asset.type == "video": alpha_cap = cv2.VideoCapture(clip.alpha_asset.path) + # Advance to start_frame + if start_frame > 0: + alpha_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) else: alpha_files = sorted([f for f in os.listdir(clip.alpha_asset.path) if is_image_file(f)]) if on_clip_start: - on_clip_start(clip.name, num_frames) + on_clip_start(clip.name, actual_processing_frames) - for i in range(num_frames): + for i in range(start_frame, num_frames): # 1.
Read Input img_srgb = None input_stem = f"{i:05d}" @@ -679,14 +726,13 @@ def run_inference( input_is_linear = settings.input_is_linear if input_cap: - ret, frame = input_cap.read() - if not ret: + img_srgb = _read_frame(input_cap, [], clip.input_asset.path, 0, input_is_linear) + if img_srgb is None: break - img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - img_srgb = img_rgb.astype(np.float32) / 255.0 input_stem = f"{i:05d}" else: fpath = os.path.join(clip.input_asset.path, input_files[i]) + img_srgb = _read_frame(None, input_files, clip.input_asset.path, i, input_is_linear) input_stem = os.path.splitext(input_files[i])[0] is_exr = fpath.lower().endswith(".exr") diff --git a/corridorkey_cli.py b/corridorkey_cli.py index cc2afcdc..d5b0579b 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -249,6 +249,10 @@ def run_inference_cmd( Optional[int], typer.Option("--max-frames", help="Limit frames per clip"), ] = None, + start_frame: Annotated[ + int, + typer.Option("--start-frame", help="Start inference from this frame index"), + ] = 0, linear: Annotated[ Optional[bool], typer.Option("--linear/--srgb", help="Input colorspace (default: prompt)"), @@ -280,7 +284,9 @@ def run_inference_cmd( # despeckle_size excluded — sensible default even in headless mode required_flags_set = all(v is not None for v in [linear, despill, despeckle, refiner]) if required_flags_set: - assert linear is not None and despill is not None and despeckle is not None and refiner is not None + if any(v is None for v in [linear, despill, despeckle, refiner]): + raise ValueError("Missing required flags for inference settings.") + despill_clamped = max(0, min(10, despill)) settings = InferenceSettings( input_is_linear=linear, @@ -304,6 +310,7 @@ def run_inference_cmd( device=ctx.obj["device"], backend=backend, max_frames=max_frames, + start_frame=start_frame, settings=settings, on_clip_start=ctx_progress.on_clip_start, on_frame_complete=ctx_progress.on_frame_complete, @@ -439,8 +446,8 @@ def
interactive_wizard(win_path: str, device: str | None = None) -> None: entry = ClipEntry(os.path.basename(d), d) try: entry.find_assets() - except (FileNotFoundError, ValueError, OSError): - pass + except (FileNotFoundError, ValueError, OSError) as e: + logger.debug(f"Skipping clip setup for '{d}': {e}") has_mask = False mask_dir = os.path.join(d, "VideoMamaMaskHint") diff --git a/tests/test_clip_manager.py b/tests/test_clip_manager.py index b7763ebe..a1ceecb4 100644 --- a/tests/test_clip_manager.py +++ b/tests/test_clip_manager.py @@ -1,292 +1,182 @@ -"""Tests for clip_manager.py utility functions and ClipEntry discovery. - -These tests verify the non-interactive parts of clip_manager: file type -detection, Windows→Linux path mapping, and the ClipEntry asset discovery -that scans directory trees to find Input/AlphaHint pairs. - -No GPU, model weights, or interactive input required. -""" - -from __future__ import annotations - import os +import unittest +from unittest.mock import MagicMock, patch -import cv2 import numpy as np -import pytest - -from clip_manager import ( - ClipAsset, - ClipEntry, - is_image_file, - is_video_file, - map_path, - organize_target, -) - -# --------------------------------------------------------------------------- -# is_image_file / is_video_file -# --------------------------------------------------------------------------- +from clip_manager import ClipAsset, ClipEntry, InferenceSettings, is_image_file, is_video_file, run_inference -class TestFileTypeDetection: - """Verify extension-based file type helpers. - These are used everywhere in clip_manager to decide how to read inputs. - A missed extension means a valid frame silently disappears from the batch. 
+class TestClipManagerInference(unittest.TestCase): """ - - @pytest.mark.parametrize( - "filename", - [ - "frame.png", - "SHOT_001.EXR", - "plate.jpg", - "ref.JPEG", - "scan.tif", - "deep.tiff", - "comp.bmp", - ], - ) - def test_image_extensions_recognized(self, filename): - assert is_image_file(filename) - - @pytest.mark.parametrize( - "filename", - [ - "frame.mp4", - "CLIP.MOV", - "take.avi", - "rushes.mkv", - ], - ) - def test_video_extensions_recognized(self, filename): - assert is_video_file(filename) - - @pytest.mark.parametrize( - "filename", - [ - "readme.txt", - "notes.pdf", - "project.nk", - "scene.blend", - ".DS_Store", - ], - ) - def test_non_media_rejected(self, filename): - assert not is_image_file(filename) - assert not is_video_file(filename) - - def test_image_is_not_video(self): - """Image and video extensions must not overlap.""" - assert not is_video_file("frame.png") - assert not is_video_file("plate.exr") - - def test_video_is_not_image(self): - assert not is_image_file("clip.mp4") - assert not is_image_file("rushes.mov") - - -# --------------------------------------------------------------------------- -# map_path -# --------------------------------------------------------------------------- - - -class TestMapPath: - r"""Windows→Linux path mapping. - - The tool is designed for studios running a Linux render farm with - Windows workstations. V:\ maps to /mnt/ssd-storage. 
- """ - - def test_basic_mapping(self): - result = map_path(r"V:\Projects\Shot1") - assert result == "/mnt/ssd-storage/Projects/Shot1" - - def test_case_insensitive_drive_letter(self): - result = map_path(r"v:\projects\shot1") - assert result == "/mnt/ssd-storage/projects/shot1" - - def test_trailing_whitespace_stripped(self): - result = map_path(r" V:\Projects\Shot1 ") - assert result == "/mnt/ssd-storage/Projects/Shot1" - - def test_backslashes_converted(self): - result = map_path(r"V:\Deep\Nested\Path\Here") - assert "\\" not in result - - def test_non_v_drive_passthrough(self): - """Paths not on V: are returned as-is (may already be Linux paths).""" - linux_path = "/mnt/other/data" - assert map_path(linux_path) == linux_path - - def test_drive_root_only(self): - result = map_path("V:\\") - assert result == "/mnt/ssd-storage/" - - -# --------------------------------------------------------------------------- -# ClipAsset -# --------------------------------------------------------------------------- - - -class TestClipAsset: - """ClipAsset wraps a directory of images or a video file and counts frames.""" - - def test_sequence_frame_count(self, tmp_path): - """Image sequence: frame count = number of image files in directory.""" - seq_dir = tmp_path / "Input" - seq_dir.mkdir() - tiny = np.zeros((4, 4, 3), dtype=np.uint8) - for i in range(5): - cv2.imwrite(str(seq_dir / f"frame_{i:04d}.png"), tiny) - - asset = ClipAsset(str(seq_dir), "sequence") - assert asset.frame_count == 5 - - def test_sequence_ignores_non_image_files(self, tmp_path): - """Non-image files (thumbs.db, .nk, etc.) 
should not be counted.""" - seq_dir = tmp_path / "Input" - seq_dir.mkdir() - tiny = np.zeros((4, 4, 3), dtype=np.uint8) - cv2.imwrite(str(seq_dir / "frame_0000.png"), tiny) - (seq_dir / "thumbs.db").write_text("junk") - (seq_dir / "notes.txt").write_text("notes") - - asset = ClipAsset(str(seq_dir), "sequence") - assert asset.frame_count == 1 - - def test_empty_sequence(self, tmp_path): - """Empty directory → 0 frames.""" - seq_dir = tmp_path / "Input" - seq_dir.mkdir() - asset = ClipAsset(str(seq_dir), "sequence") - assert asset.frame_count == 0 - - -# --------------------------------------------------------------------------- -# ClipEntry.find_assets -# --------------------------------------------------------------------------- - - -class TestClipEntryFindAssets: - """ClipEntry.find_assets() discovers Input and AlphaHint from a shot directory. - - This is the core discovery logic that decides what's ready for inference - vs. what still needs alpha generation. + Test suite for the CorridorKey inference pipeline using mocks. 
""" - def test_finds_image_sequence_input(self, tmp_clip_dir): - """shot_a has Input/ with 2 PNGs → input_asset is a sequence.""" - entry = ClipEntry("shot_a", str(tmp_clip_dir / "shot_a")) - entry.find_assets() - assert entry.input_asset is not None - assert entry.input_asset.type == "sequence" - assert entry.input_asset.frame_count == 2 - - def test_finds_alpha_hint(self, tmp_clip_dir): - """shot_a has AlphaHint/ with 2 PNGs → alpha_asset is populated.""" - entry = ClipEntry("shot_a", str(tmp_clip_dir / "shot_a")) + @patch("clip_manager.cv2.imread") + @patch("clip_manager.cv2.cvtColor") + @patch("clip_manager.cv2.imwrite") + @patch("clip_manager.is_image_file") + @patch("clip_manager.os.listdir") + def test_run_inference_basic_flow( + self, mock_listdir, mock_is_image_file, mock_imwrite, mock_cvt_color, mock_imread + ): + import tempfile + + with tempfile.TemporaryDirectory() as tmp_path: + clip_root = os.path.join(tmp_path, "TestClip") + os.makedirs(os.path.join(clip_root, "Input")) + os.makedirs(os.path.join(clip_root, "AlphaHint")) + + clip = ClipEntry("TestClip", clip_root) + clip.input_asset = ClipAsset(os.path.join(clip_root, "Input"), "sequence") + clip.alpha_asset = ClipAsset(os.path.join(clip_root, "AlphaHint"), "sequence") + clip.input_asset.frame_count = 2 + clip.alpha_asset.frame_count = 2 + + def side_effect_imread(path, flags=None): + if "alpha" in path.lower(): + return np.zeros((10, 10), dtype=np.uint8) + return np.zeros((10, 10, 3), dtype=np.uint8) + + mock_imread.side_effect = side_effect_imread + mock_cvt_color.return_value = np.zeros((10, 10, 3), dtype=np.uint8) + mock_is_image_file.return_value = True + + def side_effect_listdir(path): + if "input" in path.lower(): + return ["frame_000.png", "frame_001.png"] + else: + return ["alpha_000.png", "alpha_001.png"] + + mock_listdir.side_effect = side_effect_listdir + + import clip_manager + + clip_manager.cv2.IMREAD_UNCHANGED = -1 + clip_manager.cv2.IMREAD_ANYDEPTH = 2 + 
clip_manager.cv2.COLOR_BGR2RGB = 4 + + mock_engine = MagicMock() + mock_engine.process_frame.return_value = { + "alpha": np.zeros((10, 10, 1), dtype=np.float32), + "fg": np.zeros((10, 10, 3), dtype=np.float32), + "comp": np.zeros((10, 10, 3), dtype=np.float32), + "processed": np.zeros((10, 10, 4), dtype=np.float32), + } + + mock_on_clip_start = MagicMock() + mock_on_frame_complete = MagicMock() + settings = InferenceSettings() + + run_inference( + clips=[clip], + settings=settings, + on_clip_start=mock_on_clip_start, + on_frame_complete=mock_on_frame_complete, + engine_override=mock_engine, + ) + + self.assertEqual(mock_engine.process_frame.call_count, 2) + mock_on_clip_start.assert_called_once_with("TestClip", 2) + self.assertTrue(os.path.exists(os.path.join(clip_root, "Output", "FG"))) + self.assertTrue(os.path.exists(os.path.join(clip_root, "Output", "Matte"))) + + @patch("clip_manager.cv2.imread") + @patch("clip_manager.cv2.cvtColor") + @patch("clip_manager.cv2.imwrite") + @patch("clip_manager.is_image_file") + @patch("clip_manager.os.listdir") + def test_run_inference_start_frame( + self, mock_listdir, mock_is_image_file, mock_imwrite, mock_cvt_color, mock_imread + ): + import tempfile + + with tempfile.TemporaryDirectory() as tmp_path: + clip_root = os.path.join(tmp_path, "TestClipStartFrame") + os.makedirs(os.path.join(clip_root, "Input")) + os.makedirs(os.path.join(clip_root, "AlphaHint")) + + clip = ClipEntry("TestClipStartFrame", clip_root) + clip.input_asset = ClipAsset(os.path.join(clip_root, "Input"), "sequence") + clip.alpha_asset = ClipAsset(os.path.join(clip_root, "AlphaHint"), "sequence") + clip.input_asset.frame_count = 3 + clip.alpha_asset.frame_count = 3 + + def side_effect_imread(path, flags=None): + if "alpha" in path.lower(): + return np.zeros((10, 10), dtype=np.uint8) + return np.zeros((10, 10, 3), dtype=np.uint8) + + mock_imread.side_effect = side_effect_imread + mock_cvt_color.return_value = np.zeros((10, 10, 3), dtype=np.uint8) + 
mock_is_image_file.return_value = True + + def side_effect_listdir(path): + if "input" in path.lower(): + return ["frame_000.png", "frame_001.png", "frame_002.png"] + else: + return ["alpha_000.png", "alpha_001.png", "alpha_002.png"] + + mock_listdir.side_effect = side_effect_listdir + + import clip_manager + + clip_manager.cv2.IMREAD_UNCHANGED = -1 + clip_manager.cv2.IMREAD_ANYDEPTH = 2 + clip_manager.cv2.COLOR_BGR2RGB = 4 + + mock_engine = MagicMock() + mock_engine.process_frame.return_value = { + "alpha": np.zeros((10, 10, 1), dtype=np.float32), + "fg": np.zeros((10, 10, 3), dtype=np.float32), + "comp": np.zeros((10, 10, 3), dtype=np.float32), + "processed": np.zeros((10, 10, 4), dtype=np.float32), + } + + mock_on_clip_start = MagicMock() + + run_inference(clips=[clip], start_frame=2, on_clip_start=mock_on_clip_start, engine_override=mock_engine) + + self.assertEqual(mock_engine.process_frame.call_count, 1) + mock_on_clip_start.assert_called_once_with("TestClipStartFrame", 1) + + +class TestFileTypeHelpers(unittest.TestCase): + def test_is_image_file_accepted(self): + for ext in (".png", ".jpg", ".jpeg", ".exr", ".tif", ".tiff", ".bmp"): + self.assertTrue(is_image_file(f"frame{ext}")) + + def test_is_image_file_rejected(self): + for ext in (".mp4", ".txt", ".mov"): + self.assertFalse(is_image_file(f"clip{ext}")) + + def test_is_video_file_accepted(self): + for ext in (".mp4", ".mov", ".avi", ".mkv"): + self.assertTrue(is_video_file(f"clip{ext}")) + + def test_is_video_file_rejected(self): + self.assertFalse(is_video_file("frame.png")) + + +class TestClipAsset(unittest.TestCase): + @patch("clip_manager.os.listdir", return_value=["a.png", "b.png", "c.png"]) + @patch("clip_manager.is_image_file", return_value=True) + def test_sequence_frame_count(self, mock_iif, mock_ls): + asset = ClipAsset("/fake/path", "sequence") + self.assertEqual(asset.frame_count, 3) + + +class TestClipEntry(unittest.TestCase): + @patch("clip_manager.glob.glob", return_value=[]) + 
@patch("clip_manager.os.listdir", return_value=["frame_000.png"]) + @patch("clip_manager.is_image_file", return_value=True) + @patch("clip_manager.os.path.isdir", side_effect=lambda p: os.path.basename(p) == "Input") + def test_find_assets_sequence(self, mock_isdir, mock_iif, mock_ls, mock_glob): + entry = ClipEntry("shot_a", "/fake/root") entry.find_assets() - assert entry.alpha_asset is not None - assert entry.alpha_asset.type == "sequence" - assert entry.alpha_asset.frame_count == 2 - - def test_empty_alpha_hint_is_none(self, tmp_clip_dir): - """shot_b has empty AlphaHint/ → alpha_asset is None (needs generation).""" - entry = ClipEntry("shot_b", str(tmp_clip_dir / "shot_b")) - entry.find_assets() - assert entry.input_asset is not None - assert entry.alpha_asset is None - - def test_missing_input_raises(self, tmp_path): - """A shot with no Input directory or video raises ValueError.""" - empty_shot = tmp_path / "empty_shot" - empty_shot.mkdir() - entry = ClipEntry("empty_shot", str(empty_shot)) - with pytest.raises(ValueError, match="No 'Input' directory or video file found"): - entry.find_assets() - - def test_empty_input_dir_raises(self, tmp_path): - """An empty Input/ directory raises ValueError.""" - shot = tmp_path / "bad_shot" - (shot / "Input").mkdir(parents=True) - entry = ClipEntry("bad_shot", str(shot)) - with pytest.raises(ValueError, match="'Input' directory is empty"): - entry.find_assets() - - def test_validate_pair_frame_count_mismatch(self, tmp_path): - """Mismatched Input/AlphaHint frame counts raise ValueError.""" - shot = tmp_path / "mismatch" - (shot / "Input").mkdir(parents=True) - (shot / "AlphaHint").mkdir(parents=True) - - tiny = np.zeros((4, 4, 3), dtype=np.uint8) - tiny_mask = np.zeros((4, 4), dtype=np.uint8) - - # 3 input frames, 2 alpha frames - for i in range(3): - cv2.imwrite(str(shot / "Input" / f"frame_{i:04d}.png"), tiny) - for i in range(2): - cv2.imwrite(str(shot / "AlphaHint" / f"frame_{i:04d}.png"), tiny_mask) - - entry = 
ClipEntry("mismatch", str(shot)) - entry.find_assets() - with pytest.raises(ValueError, match="Frame count mismatch"): - entry.validate_pair() - - def test_validate_pair_matching_counts_ok(self, tmp_clip_dir): - """Matching frame counts pass validation without error.""" - entry = ClipEntry("shot_a", str(tmp_clip_dir / "shot_a")) - entry.find_assets() - entry.validate_pair() # should not raise - - -# --------------------------------------------------------------------------- -# organize_target -# --------------------------------------------------------------------------- - - -class TestOrganizeTarget: - """organize_target() sets up the hint directory structure for a shot. - - It creates AlphaHint/ and VideoMamaMaskHint/ directories if missing. - """ - - def test_creates_hint_directories(self, tmp_path): - """Missing hint directories should be created.""" - shot = tmp_path / "shot_x" - (shot / "Input").mkdir(parents=True) - tiny = np.zeros((4, 4, 3), dtype=np.uint8) - cv2.imwrite(str(shot / "Input" / "frame_0000.png"), tiny) - - organize_target(str(shot)) - - assert (shot / "AlphaHint").is_dir() - assert (shot / "VideoMamaMaskHint").is_dir() - - def test_existing_hint_dirs_preserved(self, tmp_clip_dir): - """Existing hint directories and their contents are not disturbed.""" - shot_a = tmp_clip_dir / "shot_a" - alpha_files_before = sorted(os.listdir(shot_a / "AlphaHint")) - - organize_target(str(shot_a)) - - alpha_files_after = sorted(os.listdir(shot_a / "AlphaHint")) - assert alpha_files_before == alpha_files_after - - def test_moves_loose_images_to_input(self, tmp_path): - """Loose image files in a shot dir get moved into Input/.""" - shot = tmp_path / "messy_shot" - shot.mkdir() - tiny = np.zeros((4, 4, 3), dtype=np.uint8) - cv2.imwrite(str(shot / "frame_0000.png"), tiny) - cv2.imwrite(str(shot / "frame_0001.png"), tiny) + self.assertIsNotNone(entry.input_asset) + self.assertEqual(entry.input_asset.type, "sequence") - organize_target(str(shot)) - assert (shot / 
"Input").is_dir() - input_files = os.listdir(shot / "Input") - assert len(input_files) == 2 - # Original loose files should be gone - assert not (shot / "frame_0000.png").exists() +if __name__ == "__main__": + unittest.main()