From baa59a1d4099fa2305f82eecefd45bb73a2ba971 Mon Sep 17 00:00:00 2001 From: marclie Date: Mon, 9 Mar 2026 15:19:38 +0100 Subject: [PATCH 01/44] torch.compile + precision optimizations --- CorridorKeyModule/core/model_transformer.py | 3 + CorridorKeyModule/inference_engine.py | 31 +- pyproject.toml | 7 +- test_vram.py | 19 +- uv.lock | 357 ++++++++++---------- 5 files changed, 219 insertions(+), 198 deletions(-) diff --git a/CorridorKeyModule/core/model_transformer.py b/CorridorKeyModule/core/model_transformer.py index fe6be4e4..0b6d20e9 100644 --- a/CorridorKeyModule/core/model_transformer.py +++ b/CorridorKeyModule/core/model_transformer.py @@ -1,4 +1,5 @@ from __future__ import annotations +import sys import timm import torch @@ -138,6 +139,8 @@ def forward(self, img: torch.Tensor, coarse_pred: torch.Tensor) -> torch.Tensor: return self.final(x) * 10.0 +# We only tested compilation on windows and linux. For other platforms compilation is disabled as a precaution. +@torch.compile(disable=(sys.platform != "linux" and sys.platform != "win32")) class GreenFormer(nn.Module): def __init__( self, diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 6cd0c803..acd9a8bd 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -2,6 +2,7 @@ import math import os +from timeit import timeit import cv2 import numpy as np @@ -14,7 +15,13 @@ class CorridorKeyEngine: def __init__( - self, checkpoint_path: str, device: str = "cpu", img_size: int = 2048, use_refiner: bool = True + self, + checkpoint_path: str, + device: str = "cpu", + img_size: int = 2048, + use_refiner: bool = True, + mixed_precision: bool = True, + model_precision: torch.dtype = torch.float32, ) -> None: self.device = torch.device(device) self.img_size = img_size @@ -23,8 +30,20 @@ def __init__( self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) self.std = np.array([0.229, 0.224, 0.225], 
dtype=np.float32).reshape(1, 1, 3) - - self.model = self._load_model() + + if mixed_precision or model_precision != torch.float32: + # Use faster matrix multiplication implementation + # This reduces the floating point precision a little bit, but it should be negligible compared to fp16 precision + torch.set_float32_matmul_precision('high') + + self.mixed_precision = mixed_precision + if mixed_precision and model_precision == torch.float16: + # using mixed precision, when the precision is already fp16, is slower + self.mixed_precision = False + + self.model_precision = model_precision + + self.model = self._load_model().to(model_precision) def _load_model(self) -> GreenFormer: print(f"Loading CorridorKey from {self.checkpoint_path}...") @@ -83,7 +102,7 @@ def _load_model(self) -> GreenFormer: return model - @torch.no_grad() + @torch.inference_mode() def process_frame( self, image: np.ndarray, @@ -149,7 +168,7 @@ def process_frame( # 4. Prepare Tensor inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device) + inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).float().unsqueeze(0).to(self.model_precision).to(self.device) # 5. 
Inference # Hook for Refiner Scaling @@ -161,7 +180,7 @@ def scale_hook(module, input, output): handle = self.model.refiner.register_forward_hook(scale_hook) - with torch.autocast(device_type=self.device.type, dtype=torch.float16): + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): out = self.model(inp_t) if handle: diff --git a/pyproject.toml b/pyproject.toml index 7abf8586..93c3f82f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "opencv-python", "tqdm", "setuptools", + # Triton fix for Windows + # There might still be issues though https://github.com/triton-lang/triton-windows?tab=readme-ov-file#windows-file-path-length-limit-260-causes-compilation-failure + "triton-windows==3.6.0.post25 ; sys_platform == 'win32'", # GVM alpha hint generator "diffusers", @@ -88,9 +91,11 @@ omit = [ [[tool.uv.index]] name = "pytorch" -url = "https://download.pytorch.org/whl/cu126" +url = "https://download.pytorch.org/whl/cu128" # CUDA 12.6 doesn't support RTX 5000 Series explicit = true [tool.uv.sources] torch = { index = "pytorch" } torchvision = { index = "pytorch" } +# Use Hiera fix in order to utilize the FlashAttention Kernel +timm = { git = "https://github.com/Raiden129/pytorch-image-models-fix", branch = "fix/hiera-flash-attention-global-4d" } diff --git a/test_vram.py b/test_vram.py index b10ab071..901161bd 100644 --- a/test_vram.py +++ b/test_vram.py @@ -1,22 +1,27 @@ +import timeit import numpy as np import torch from CorridorKeyModule.inference_engine import CorridorKeyEngine +def process_frame(engine): + img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) + + engine.process_frame(img, mask) + def test_vram(): print("Loading engine...") - engine = CorridorKeyEngine(checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", img_size=2048) - - # Create dummy data - img = np.random.randint(0, 255, 
(2160, 3840, 3), dtype=np.uint8) - mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) + engine = CorridorKeyEngine(checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", img_size=2048, device="cuda", model_precision=torch.float16) # Reset stats torch.cuda.reset_peak_memory_stats() - print("Running inference pass...") - engine.process_frame(img, mask) + iterations = 24 + print(f"Running {iterations} inference passes...") + time = timeit.timeit(lambda: process_frame(engine), number=iterations) + print(f"Seconds per frame: {time / iterations}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) print(f"Peak VRAM used: {peak_vram:.2f} GB") diff --git a/uv.lock b/uv.lock index b7a19785..fbcb6769 100644 --- a/uv.lock +++ b/uv.lock @@ -1,13 +1,10 @@ version = 1 revision = 3 -requires-python = ">=3.10" +requires-python = ">=3.10, <=3.14" resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and 
sys_platform != 'linux')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", ] [[package]] @@ -233,8 +230,7 @@ name = "contourpy" version = "1.3.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and sys_platform != 'linux')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version < '3.11'", ] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -304,10 +300,8 @@ name = "contourpy" version = "1.3.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 
'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] dependencies = [ { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -409,10 +403,10 @@ dependencies = [ { name = "setuptools" }, { name = "timm" }, { name = "torch" }, - { name = "torchvision", version = "0.24.1", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, - { name = "torchvision", version = "0.24.1+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "python_full_version >= '3.15' or platform_machine != 'aarch64' or platform_python_implementation != 'CPython' or sys_platform != 'linux'" }, + { name = "torchvision" }, { name = "tqdm" }, { name = "transformers" }, + { name = "triton-windows", marker = "sys_platform == 'win32'" }, ] [package.dev-dependencies] @@ -438,11 +432,12 @@ requires-dist = [ { name = "pillow" }, { name = "pims" }, { name = "setuptools" }, - { name = "timm", specifier = "==1.0.24" }, - { name = "torch", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu126" }, - { name = "torchvision", index = "https://download.pytorch.org/whl/cu126" }, + { name = "timm", git = "https://github.com/Raiden129/pytorch-image-models-fix?branch=fix%2Fhiera-flash-attention-global-4d" }, + { name = "torch", specifier = "==2.10.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torchvision", specifier = "==0.25.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "tqdm" }, { name = "transformers" }, + { name = "triton-windows", marker = "sys_platform == 'win32'", specifier = "==3.6.0.post25" }, ] [package.metadata.requires-dev] @@ -570,6 +565,38 @@ toml = [ { name = "tomli", marker = "python_full_version <= '3.11'" }, ] +[[package]] +name = "cuda-bindings" 
+version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/37/31/bfcc870f69c6a017c4ad5c42316207fc7551940db6f3639aa4466ec5faf3/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a022c96b8bd847e8dc0675523431149a4c3e872f440e3002213dbb9e08f0331a", size = 11800959, upload-time = "2025-10-21T14:51:26.458Z" }, + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/a9/2b/ebcbb60aa6dba830474cd360c42e10282f7a343c0a1f58d24fbd3b7c2d77/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6a429dc6c13148ff1e27c44f40a3dd23203823e637b87fd0854205195988306", size = 11840604, upload-time = "2025-10-21T14:51:34.565Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c2/65bfd79292b8ff18be4dd7f7442cea37bcbc1a228c1886f1dea515c45b67/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:694ba35023846625ef471257e6b5a4bc8af690f961d197d77d34b1d1db393f56", size = 11760260, upload-time = "2025-10-21T14:51:40.79Z" }, + { url = 
"https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/05/8b/b4b2d1c7775fa403b64333e720cfcfccef8dcb9cdeb99947061ca5a77628/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5", size = 11570071, upload-time = "2025-10-21T14:51:47.472Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = "https://files.pythonhosted.org/packages/ec/07/6aff13bc1e977e35aaa6b22f52b172e2890c608c6db22438cf7ed2bf43a6/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d", size = 11566797, upload-time = "2025-10-21T14:51:54.581Z" }, + { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, + { url = "https://files.pythonhosted.org/packages/1e/b5/96a6696e20c4ffd2b327f54c7d0fde2259bdb998d045c25d5dedbbe30290/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5", size = 11624530, upload-time = "2025-10-21T14:52:01.539Z" }, + { url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" }, + { url = "https://files.pythonhosted.org/packages/39/73/d2fc40c043bac699c3880bf88d3cebe9d88410cd043795382826c93a89f0/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f", size = 11565056, upload-time = "2025-10-21T14:52:08.338Z" }, + { url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" }, +] + +[[package]] +name = "cuda-pathfinder" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/02/59a5bc738a09def0b49aea0e460bdf97f65206d0d041246147cf6207e69c/cuda_pathfinder-1.4.1-py3-none-any.whl", hash = "sha256:40793006082de88e0950753655e55558a446bed9a7d9d0bcb48b2506d50ed82a", size = 43903, upload-time = "2026-03-06T21:05:24.372Z" }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -1153,8 +1180,7 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and sys_platform != 'linux')", - 
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -1166,10 +1192,8 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1181,8 +1205,7 @@ name = "numpy" version = "2.2.6" source = { registry = "https://pypi.org/simple" } 
resolution-markers = [ - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and sys_platform != 'linux')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } wheels = [ @@ -1247,10 +1270,8 @@ name = "numpy" version = "2.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = 
"sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" } wheels = [ @@ -1329,42 +1350,38 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.6.4.1" +version = "12.8.4.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, - { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615, upload-time = "2024-11-20T17:39:52.715Z" }, + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.6.80" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764, upload-time = "2024-11-20T17:35:41.03Z" }, - { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756, upload-time = "2024-10-01T16:57:45.507Z" }, - { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980, upload-time = "2024-11-20T17:36:04.019Z" }, - { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972, upload-time = "2024-10-01T16:58:06.036Z" }, + { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 
10248621, upload-time = "2025-03-07T01:40:21.213Z" }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.6.77" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/2f/72df534873235983cc0a5371c3661bebef7c4682760c275590b972c7b0f9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13", size = 23162955, upload-time = "2024-10-01T16:59:50.922Z" }, - { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052, upload-time = "2024-11-20T17:35:19.905Z" }, - { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040, upload-time = "2024-10-01T16:57:22.221Z" }, - { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690, upload-time = "2024-11-20T17:35:30.697Z" }, - { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678, upload-time = "2024-10-01T16:57:33.821Z" }, + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, ] [[package]] @@ -1381,41 +1398,37 @@ wheels = [ [[package]] name = "nvidia-cufft-cu12" -version = "11.3.0.4" +version = "11.3.3.83" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144, upload-time = "2024-11-20T17:40:58.288Z" }, - { url = "https://files.pythonhosted.org/packages/ce/f5/188566814b7339e893f8d210d3a5332352b1409815908dad6a363dcceac1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb", size = 200164135, upload-time = "2024-10-01T17:03:24.212Z" }, - { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" }, - { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622, upload-time = "2024-10-01T17:03:58.79Z" }, + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash 
= "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, ] [[package]] name = "nvidia-cufile-cu12" -version = "1.11.1.6" +version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103, upload-time = "2024-11-20T17:42:11.83Z" }, - { url = "https://files.pythonhosted.org/packages/17/bf/cc834147263b929229ce4aadd62869f0b195e98569d4c28b23edc72b85d9/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db", size = 1066155, upload-time = "2024-11-20T17:41:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.7.77" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/ac/36543605358a355632f1a6faa3e2d5dfb91eab1e4bc7d552040e0383c335/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8", size = 56289881, upload-time = "2024-10-01T17:04:18.981Z" }, - { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, - { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000, upload-time = "2024-10-01T17:04:45.274Z" }, - { url = "https://files.pythonhosted.org/packages/a6/02/5362a9396f23f7de1dd8a64369e87c85ffff8216fc8194ace0fa45ba27a5/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e", size = 56289882, upload-time = "2024-11-20T17:42:25.222Z" }, + { url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.7.1.2" +version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12" }, @@ -1423,24 
+1436,20 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628, upload-time = "2024-10-01T17:05:05.591Z" }, - { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" }, - { url = "https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780, upload-time = "2024-10-01T17:05:39.875Z" }, - { url = "https://files.pythonhosted.org/packages/7c/5f/07d0ba3b7f19be5a5ec32a8679fc9384cfd9fc6c869825e93be9f28d6690/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e", size = 157833630, upload-time = "2024-11-20T17:43:16.77Z" }, + { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = 
"sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.5.4.2" +version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147, upload-time = "2024-11-20T17:44:18.055Z" }, - { url = "https://files.pythonhosted.org/packages/d3/56/3af21e43014eb40134dea004e8d0f1ef19d9596a39e4d497d5a7de01669f/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1", size = 216451135, upload-time = "2024-10-01T17:06:03.826Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" }, - { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357, upload-time = "2024-10-01T17:06:29.861Z" }, + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 
288117129, upload-time = "2025-03-07T01:47:40.407Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, ] [[package]] @@ -1463,31 +1472,29 @@ wheels = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.6.85" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, - { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338, upload-time = "2024-11-20T17:46:29.758Z" }, + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" 
}, ] [[package]] name = "nvidia-nvshmem-cu12" -version = "3.3.20" +version = "3.4.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/92/9d/3dd98852568fb845ec1f7902c90a22b240fe1cbabda411ccedf2fd737b7b/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b0b960da3842212758e4fa4696b94f129090b30e5122fea3c5345916545cff0", size = 124484616, upload-time = "2025-08-04T20:24:59.172Z" }, - { url = "https://files.pythonhosted.org/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d00f26d3f9b2e3c3065be895e3059d6479ea5c638a3f38c9fec49b1b9dd7c1e5", size = 124657145, upload-time = "2025-08-04T20:25:19.995Z" }, + { url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" }, + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.6.77" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/93/80f8a520375af9d7ee44571a6544653a176e53c2b8ccce85b97b83c2491b/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b", size = 90549, upload-time = 
"2024-11-20T17:38:17.387Z" }, - { url = "https://files.pythonhosted.org/packages/2b/53/36e2fd6c7068997169b49ffc8c12d5af5e5ff209df6e1a2c4d373b3a638f/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059", size = 90539, upload-time = "2024-10-01T17:00:27.179Z" }, - { url = "https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276, upload-time = "2024-11-20T17:38:27.621Z" }, - { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" }, + { url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] [[package]] @@ -2069,8 +2076,7 @@ name = "tifffile" version = "2025.5.10" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or 
(python_full_version < '3.11' and sys_platform != 'linux')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version < '3.11'", ] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -2085,10 +2091,8 @@ name = "tifffile" version = "2026.2.24" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] dependencies = [ { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2100,19 +2104,14 @@ wheels = [ [[package]] name = "timm" -version = "1.0.24" -source = { registry = "https://pypi.org/simple" } +version = "1.0.25" +source = { git = "https://github.com/Raiden129/pytorch-image-models-fix?branch=fix%2Fhiera-flash-attention-global-4d#fc1dca92c5a44d5436605dacb19594117ce2eb0c" } dependencies = [ { name = 
"huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, { name = "torch" }, - { name = "torchvision", version = "0.24.1", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, - { name = "torchvision", version = "0.24.1+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "python_full_version >= '3.15' or platform_machine != 'aarch64' or platform_python_implementation != 'CPython' or sys_platform != 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/9d/0ea45640be447445c8664ce2b10c74f763b0b0b9ed11620d41a4d4baa10c/timm-1.0.24.tar.gz", hash = "sha256:c7b909f43fe2ef8fe62c505e270cd4f1af230dfbc37f2ee93e3608492b9d9a40", size = 2412239, upload-time = "2026-01-07T00:26:17.541Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/92/dd/c1f5b0890f7b5db661bde0864b41cb0275be76851047e5f7e085fe0b455a/timm-1.0.24-py3-none-any.whl", hash = "sha256:8301ac783410c6ad72c73c49326af6d71a9e4d1558238552796e825c2464913f", size = 2560563, upload-time = "2026-01-07T00:26:13.956Z" }, + { name = "torchvision" }, ] [[package]] @@ -2201,9 +2200,10 @@ wheels = [ [[package]] name = "torch" -version = "2.9.1+cu126" -source = { registry = "https://download.pytorch.org/whl/cu126" } +version = "2.10.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } dependencies = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, @@ -2230,84 +2230,61 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:472da048ab936302ee0dec3bedea16e697ecb41d51bd341142aca2677466f436" }, - { url = 
"https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:8840a4439668cad44961933cedee9b1242eb67da93ec49c1ab552f4dbce10bbb" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp310-cp310-win_amd64.whl", hash = "sha256:37249b92a40042cdd35e536fd8d628453093c879678c9e5587279e2055d69c40" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a1641ad5278e8d830f31eee2f628627d42c53892e1770d1d1e1c475576d327f7" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:57e4f908dda76a6d5bf7138727d95fcf7ce07115bc040e7ed541d9d25074b28d" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp311-cp311-win_amd64.whl", hash = "sha256:8afd366413aeb51a4732042719f168fae6f4c72326e59e9bdbe20a5c5be52418" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a4fc209b36bd4752db5370388b0ffaab58944240de36a2c0f88129fcf4b07eb2" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:67e9b1054f435d33af6fa67343f93d73dc2d37013623672d6ffb24ce39b666c2" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:f2f1c68c7957ed8b6b56fc450482eb3fa53947fb74838b03834a1760451cf60f" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ad4eb85330a6b7db124462d7e9e00dea3c150e96ca017cc53c4335625705a7a2" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:f58a36f53f6bf24312d5d548b640062c99a40394fcb7d0c5a70375aa5be31628" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313-win_amd64.whl", hash 
= "sha256:625703f377a53e20cade81291ac742f044ea46a1e86a3949069e62e321025ba3" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:8433729c5cf0f928ba4dd43adb3509e6faadd223f0f11028841af025e8721b18" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ad0d5dd90f8e43c5a739e662b0542448e36968002efc4c2a11c5ad3b01faf04b" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp313-cp313t-win_amd64.whl", hash = "sha256:2985f3ca723da9f8bc596b38698946a394a0cab541f008ac5bcf5b36696d4ecb" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c53b63557e2bdb28f94b2e27014f2947a975733b538874c6252c0c2ca47f69e7" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:a7feda6101616061bbd680665bd44cd8ddbdbf5a11ed4c20615821ba09cc9f1c" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314-win_amd64.whl", hash = "sha256:3c24c69528f328f844d4cd2677a076ff324fe24edde3ed9c00f28d008dc11166" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:9311e4e614356421a92d81de0dc78d38ed11074ee4d4e9059cd2d75884308fa2" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:a18e6b0eccee2163f90cc894d0a12ed0a83cf009c8597063a05237f2606438d0" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.9.1%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:5b8b89f0284bd0d3caf178b64cbc9a5ca785f6c8fa19980718a09e7c13c56131" }, -] - -[[package]] -name = "torchvision" -version = "0.24.1" -source = { registry = "https://download.pytorch.org/whl/cu126" } -resolution-markers = [ - "python_full_version >= '3.12' and 
python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'", -] -dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, - { name = "pillow", marker = "python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, - { name = "torch", marker = "python_full_version < '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:9c323bef9dae9e16829c8d7817be9aa9e69a41da5f93bf104f008bf55eb4c2fc" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6b2623f609a879c0c644e306d9cba7fe15ba13a992f8cbcaabc5cb0c62424717" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fd907f8ef669947ebfedaabfe6fe6377b8c996116cea2b9580e035555dbaba2d" }, - { url = 
"https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:b482eb94ae1804adc602a16d9fcacd3448f6252560e08d1a2f01e0cc97489669" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:1c92fef4ec7dbcc3e252027a3da3d82bdd756a4e70751bba693c3c495fac1c84" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0c61888bc79a99f7fa6cdf8c6b93a63d728fcf1694b7d0560247f0d9f1343aad" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:ce4f98b1ffa300a4be656190a50d13f7ce8bea9eed288245860ad5e63e1697b5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:e186f57ef1de1aa877943259819468fc6f27efb583b4a91f9215ada7b7f4e6cc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:36368507b56eaa51acbd3c96ac8893bb9a86991ffcd0699fea3a1a74a2b8bdcb" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:14d2831b9292c3a9b0d80116451315a08ffe8db745d403d06000bc47165b1f9e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:85ed7944655ea6fd69377692e9cbfd7bba28d99696ceae79985e7caa99cf0a95" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1d01ffaebf64715c0f507a39463149cb19e596ff702bd4bcf862601f2881dabc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:3523fda6e2cfab2b04ae09b1424681358e508bb3faa11ceb67004113d5e7acad" }, + { url = 
"https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6f09cdf2415516be028ae82e6b985bcfc3eac37bc52ab401142689f6224516ca" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:628e89bd5110ced7debee2a57c69959725b7fbc64eab81a39dd70e46c7e28ba5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:fbde8f6a9ec8c76979a0d14df21c10b9e5cab6f0d106a73ca73e2179bc597cae" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c42377bc2607e3e1c60da71b792fb507c3938c87fd6edab8b21c59c91473c36d" }, + { url = 
"https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:37d71feea068776855686a1512058df3f19f6f040a151f055aa746601678744f" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:c57017ca29e62271e362fdeee7d20070e254755a5148b30b553d8a10fc83c7ef" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:777461f50b2daf77e4bdd8e2ad34bdfc5a993bf1bdf2ab9ef39f5edfe4e9c12b" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7bcba6a7c5f0987a13298b1ca843155dcceceac758fa3c7ccd5c7af4059a1080" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" }, ] [[package]] name = "torchvision" -version = "0.24.1+cu126" -source = { registry = "https://download.pytorch.org/whl/cu126" } -resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.12' and platform_python_implementation != 'CPython') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and platform_python_implementation != 'CPython') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and sys_platform != 'linux')", -] +version = "0.25.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } 
dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and platform_python_implementation != 'CPython') or (python_full_version < '3.11' and sys_platform != 'linux')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.11' and platform_python_implementation != 'CPython') or (python_full_version >= '3.11' and sys_platform != 'linux')" }, - { name = "pillow", marker = "python_full_version >= '3.15' or platform_machine != 'aarch64' or platform_python_implementation != 'CPython' or sys_platform != 'linux'" }, - { name = "torch", marker = "python_full_version >= '3.15' or platform_machine != 'aarch64' or platform_python_implementation != 'CPython' or sys_platform != 'linux'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pillow" }, + { name = "torch" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c737acd56a7737c91961b72c303b3ad61c8e6d24f84acd49708c28c934388363" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp310-cp310-win_amd64.whl", hash = "sha256:ed14aae2437edfd975632d12666c02a21a7a863b16c0b9b9def330e36dd70b85" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp311-cp311-manylinux_2_28_x86_64.whl", hash = 
"sha256:a5d41ccba2f8ee675cb2101e815dcf6daf85bae1d027fa2bfb1d8d55916bc84c" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp311-cp311-win_amd64.whl", hash = "sha256:c855a39415d8a217cf6c488bcd29735ceefe33e49cb8afe6ebda44e6e30c106d" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:987243137be01e8c3e13bde20264c1c3fc22c570f607e19759d92bbc1364a75d" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:54c1902bad62bd113f66dd3cc0368aa4d0005837100d3ab9dc823aebf945ead0" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:853063c3774053c034b384abc2faf90b4b06f5d3f0af486411011346cc078e13" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp313-cp313-win_amd64.whl", hash = "sha256:4a445944ea2042f86dede9109412519e70759cfc3bbbab841b01680cd478f291" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b835cae9feae5aaa8153b6f07625fa6440ccd785542085afb9d2fab59dbc85d9" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp313-cp313t-win_amd64.whl", hash = "sha256:6c1d06b3807c1523ff3941b2c97fdde52e3ca8067658975a7f9b10c87a92d451" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:a5467e817c5da64909e240684d5470b5eb8d7e9b521cc41802dbfc336ac65973" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp314-cp314-win_amd64.whl", hash = "sha256:6e769b5fa356a4fb994c4f47541a9665ee7b66e8d10c626d041cd44ab74f7dda" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = 
"sha256:1a778c4f1a84a90edc9e38710f619742c5ffc00bc082da0a1e00bb27273ed53c" }, - { url = "https://download.pytorch.org/whl/cu126/torchvision-0.24.1%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:9fa881f07cc7eb0c64a832864db803158c2e06eb3af99bd36af36aaec469db7b" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:7fe593ec669dcb74c3749548c21ae302edbb610d1844bf2bae622accdef76dbc" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:43350ca113e9f224dedb90b0dbc31853fde4f322794cbe4a7589fccab5768f0c" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:cef372be1a78856000cd04b72e5d738a0d851022c68b6900cbbd61dc91c25f79" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5d576c65d40198627e0fad03bddeb0ef536371312f2bdfcc804c22fd28fa6018" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ebf2b495c76097796b9a2eac9290efbcae96e0fd9e5ae52c40eff188610bb440" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:af00b4e0cdb3f490f4393e9a335b622fe1b92fd5afb181033256ccba03b9637c" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8623e534ef6a815bd6407d4b52dd70c7154e2eda626ad4b9cb895d36c5a3305b" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1255a0ca2bf987acf9f103b96c5c4cfe3415fc4a1eef17fa08af527a04a4f573" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = 
"sha256:068e519838b4a8b32a09521244b170edd8c2ac9eeb6538b7bf492cd70e57ebf5" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:12c253520a26483fe3c614f63ff16eca6d9b0b4ebe510699b7d15d88e6c0cd35" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a9c0de893dce9c2913c9c7ae88a916910f92d02b99da149678806d18e8079f29" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:e2e0317e3861bba1b5aeba7c1cb4bcd50937cf0bffdbea478619d1f5f73e9050" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:58b2971b55c761f1d2491bd80fcc4618ea97d363d387a9dd3aff23220cbee264" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:1b6878b043513ea3dea1b90bfb5193455d9b248b8c4d5e66ea9f5d1643a43f13" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:96cd2ba7b289117873b2a8f4c80605d38118d920b1045f3ce21a9f0ca68a701e" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:e2dbf9ea9f4b2416822249e96ff3ad873d9a84e51285d6b9967732be3015c523" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:5b7ad3fb6cf03ef2a2fd617cb4b4e41efa9bb0143c67f506c2a3e6765c7b12ad" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:a52ff3b072e89280f41499813e11c418d168ffc502b86cb17767bab29f432b3a" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = 
"sha256:687987fbcb074fd7f7a61cf2b407b1eac07588ace8351a3a36978546a00adc52" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:84c5e2cb699235339b8a5c295e974a795244a45d1104ecee658d9d19600cdc75" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:d1cf27bc2da13fd9e83694ae601b1bf4135c24d9c9e9ec249056896395a78a9e" }, ] [[package]] @@ -2345,23 +2322,35 @@ wheels = [ [[package]] name = "triton" -version = "3.5.1" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/2e/f95e673222afa2c7f0c687d8913e98fcf2589ef0b1405de76894e37fe18f/triton-3.5.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f63e34dcb32d7bd3a1d0195f60f30d2aee8b08a69a0424189b71017e23dfc3d2", size = 159821655, upload-time = "2025-11-11T17:51:44.09Z" }, - { url = "https://files.pythonhosted.org/packages/fd/6e/676ab5019b4dde8b9b7bab71245102fc02778ef3df48218b298686b9ffd6/triton-3.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fc53d849f879911ea13f4a877243afc513187bc7ee92d1f2c0f1ba3169e3c94", size = 170320692, upload-time = "2025-11-11T17:40:46.074Z" }, - { url = "https://files.pythonhosted.org/packages/dc/dc/6ce44d055f2fc2403c4ec6b3cfd3a9b25f57b7d95efadccdea91497f8e81/triton-3.5.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da47169e30a779bade679ce78df4810fca6d78a955843d2ddb11f226adc517dc", size = 159928005, upload-time = "2025-11-11T17:51:50.008Z" }, - { url = "https://files.pythonhosted.org/packages/b0/72/ec90c3519eaf168f22cb1757ad412f3a2add4782ad3a92861c9ad135d886/triton-3.5.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:61413522a48add32302353fdbaaf92daaaab06f6b5e3229940d21b5207f47579", size = 170425802, upload-time = "2025-11-11T17:40:53.209Z" }, - { url = 
"https://files.pythonhosted.org/packages/db/53/2bcc46879910991f09c063eea07627baef2bc62fe725302ba8f46a2c1ae5/triton-3.5.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:275a045b6ed670dd1bd005c3e6c2d61846c74c66f4512d6f33cc027b11de8fd4", size = 159940689, upload-time = "2025-11-11T17:51:55.938Z" }, - { url = "https://files.pythonhosted.org/packages/f2/50/9a8358d3ef58162c0a415d173cfb45b67de60176e1024f71fbc4d24c0b6d/triton-3.5.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d2c6b915a03888ab931a9fd3e55ba36785e1fe70cbea0b40c6ef93b20fc85232", size = 170470207, upload-time = "2025-11-11T17:41:00.253Z" }, - { url = "https://files.pythonhosted.org/packages/f1/ba/805684a992ee32d486b7948d36aed2f5e3c643fc63883bf8bdca1c3f3980/triton-3.5.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56765ffe12c554cd560698398b8a268db1f616c120007bfd8829d27139abd24a", size = 159955460, upload-time = "2025-11-11T17:52:01.861Z" }, - { url = "https://files.pythonhosted.org/packages/27/46/8c3bbb5b0a19313f50edcaa363b599e5a1a5ac9683ead82b9b80fe497c8d/triton-3.5.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3f4346b6ebbd4fad18773f5ba839114f4826037c9f2f34e0148894cd5dd3dba", size = 170470410, upload-time = "2025-11-11T17:41:06.319Z" }, - { url = "https://files.pythonhosted.org/packages/84/1e/7df59baef41931e21159371c481c31a517ff4c2517343b62503d0cd2be99/triton-3.5.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02c770856f5e407d24d28ddc66e33cf026e6f4d360dcb8b2fabe6ea1fc758621", size = 160072799, upload-time = "2025-11-11T17:52:07.293Z" }, - { url = "https://files.pythonhosted.org/packages/37/92/e97fcc6b2c27cdb87ce5ee063d77f8f26f19f06916aa680464c8104ef0f6/triton-3.5.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0b4d2c70127fca6a23e247f9348b8adde979d2e7a20391bfbabaac6aebc7e6a8", size = 170579924, upload-time = 
"2025-11-11T17:41:12.455Z" }, - { url = "https://files.pythonhosted.org/packages/14/f9/0430e879c1e63a1016cb843261528fd3187c872c3a9539132efc39514753/triton-3.5.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f617aa7925f9ea9968ec2e1adaf93e87864ff51549c8f04ce658f29bbdb71e2d", size = 159956163, upload-time = "2025-11-11T17:52:12.999Z" }, - { url = "https://files.pythonhosted.org/packages/a4/e6/c595c35e5c50c4bc56a7bac96493dad321e9e29b953b526bbbe20f9911d0/triton-3.5.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0637b1efb1db599a8e9dc960d53ab6e4637db7d4ab6630a0974705d77b14b60", size = 170480488, upload-time = "2025-11-11T17:41:18.222Z" }, - { url = "https://files.pythonhosted.org/packages/41/1e/63d367c576c75919e268e4fbc33c1cb33b6dc12bb85e8bfe531c2a8bd5d3/triton-3.5.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8932391d7f93698dfe5bc9bead77c47a24f97329e9f20c10786bb230a9083f56", size = 160073620, upload-time = "2025-11-11T17:52:18.403Z" }, - { url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" }, +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180, upload-time = "2026-01-20T16:15:53.664Z" }, + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", 
hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190, upload-time = "2026-01-20T16:16:00.523Z" }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, + { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, + { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = 
"2026-01-20T16:16:31.528Z" }, + { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, +] + +[[package]] +name = "triton-windows" +version = "3.6.0.post25" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/ca/6d38c374a427a360dc4c7687f15fdb217dd4ca3bf87d0ad31f9818e22188/triton_windows-3.6.0.post25-cp310-cp310-win_amd64.whl", hash = "sha256:8c45b7f83eecb71c3aeded1da7914af0050bddda710f47a2cfae936d55fae0ca", size = 47380144, upload-time = "2026-01-26T03:21:14.942Z" }, + { url = "https://files.pythonhosted.org/packages/49/b8/2ce283452b0b9e0d239c7833626750befe94d5bbed18fb9449dcc5fa494e/triton_windows-3.6.0.post25-cp311-cp311-win_amd64.whl", hash = "sha256:5dabf103499825379c9ba877da46a4c34296466a628b539249482ab6d970708e", size = 47381466, upload-time = "2026-01-26T03:21:21.541Z" }, + { url = "https://files.pythonhosted.org/packages/66/b1/9744fc17eded50644ffb95f3f4b1ffd1f42d646d6e0a811d92e43834865e/triton_windows-3.6.0.post25-cp312-cp312-win_amd64.whl", hash = "sha256:8361375ee4b5e0a4fe7a3c7fc2fde368ce74237396d8ff95c2e26983dd32e342", size = 47382693, upload-time = "2026-01-26T03:21:28.157Z" }, + { url = "https://files.pythonhosted.org/packages/e5/cb/1f5f738cf8f6b8c6d475a92422251228a16ca2ee6f872d0f63c761f02896/triton_windows-3.6.0.post25-cp313-cp313-win_amd64.whl", hash = "sha256:d22e5f6f4896b43037d811910e2fcc5ff5f057b78f6094ab28999e4a21997b76", size = 47383937, upload-time = "2026-01-26T03:21:35.071Z" }, + { url = "https://files.pythonhosted.org/packages/c7/d3/58ad68518e04a97ce0549cad98eccbafac01ddba640379776a58b513020b/triton_windows-3.6.0.post25-cp314-cp314-win_amd64.whl", hash = 
"sha256:6f4c4775b22cfb18e9c60aead83deb7b9b970624ae3c13cd26b9be80b5cb8cd8", size = 48566374, upload-time = "2026-01-26T03:21:41.743Z" }, ] [[package]] From 769dc7b192ffaef6b0f8bfe677ad86ca48c80d37 Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 9 Mar 2026 15:32:37 +0100 Subject: [PATCH 02/44] fix tests and formatting --- CorridorKeyModule/core/model_transformer.py | 1 + CorridorKeyModule/inference_engine.py | 20 +-- test_vram.py | 11 +- tests/test_inference_engine.py | 2 + uv.lock | 175 +++++++++++++++----- 5 files changed, 155 insertions(+), 54 deletions(-) diff --git a/CorridorKeyModule/core/model_transformer.py b/CorridorKeyModule/core/model_transformer.py index 0b6d20e9..7333f27e 100644 --- a/CorridorKeyModule/core/model_transformer.py +++ b/CorridorKeyModule/core/model_transformer.py @@ -1,4 +1,5 @@ from __future__ import annotations + import sys import timm diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index acd9a8bd..f3c59e19 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -2,7 +2,6 @@ import math import os -from timeit import timeit import cv2 import numpy as np @@ -15,9 +14,9 @@ class CorridorKeyEngine: def __init__( - self, - checkpoint_path: str, - device: str = "cpu", + self, + checkpoint_path: str, + device: str = "cpu", img_size: int = 2048, use_refiner: bool = True, mixed_precision: bool = True, @@ -30,17 +29,18 @@ def __init__( self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) - + if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation - # This reduces the floating point precision a little bit, but it should be negligible compared to fp16 precision - torch.set_float32_matmul_precision('high') - + # This reduces the floating point precision a little bit, + # but it should be negligible compared to fp16 
precision + torch.set_float32_matmul_precision("high") + self.mixed_precision = mixed_precision if mixed_precision and model_precision == torch.float16: # using mixed precision, when the precision is already fp16, is slower self.mixed_precision = False - + self.model_precision = model_precision self.model = self._load_model().to(model_precision) @@ -168,7 +168,7 @@ def process_frame( # 4. Prepare Tensor inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).float().unsqueeze(0).to(self.model_precision).to(self.device) + inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) # 5. Inference # Hook for Refiner Scaling diff --git a/test_vram.py b/test_vram.py index 901161bd..2f734d8d 100644 --- a/test_vram.py +++ b/test_vram.py @@ -1,19 +1,26 @@ import timeit + import numpy as np import torch from CorridorKeyModule.inference_engine import CorridorKeyEngine + def process_frame(engine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) - + engine.process_frame(img, mask) def test_vram(): print("Loading engine...") - engine = CorridorKeyEngine(checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", img_size=2048, device="cuda", model_precision=torch.float16) + engine = CorridorKeyEngine( + checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", + img_size=2048, + device="cuda", + model_precision=torch.float16, + ) # Reset stats torch.cuda.reset_peak_memory_stats() diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index 8cf7b5fb..ba0f8d4c 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -37,6 +37,8 @@ def _make_engine_with_mock(mock_greenformer, img_size=64): engine.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) engine.std = np.array([0.229, 0.224, 0.225], 
dtype=np.float32).reshape(1, 1, 3) engine.model = mock_greenformer + engine.model_precision = torch.float32 + engine.mixed_precision = True return engine diff --git a/uv.lock b/uv.lock index fbcb6769..609ba3f1 100644 --- a/uv.lock +++ b/uv.lock @@ -2,9 +2,12 @@ version = 1 revision = 3 requires-python = ">=3.10, <=3.14" resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version < '3.11'", + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", ] [[package]] @@ -19,7 +22,8 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/4a/8e/ac2a9566747a93f8be36ee08532eb0160558b07630a081a6056a9f89bf1d/accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6", size = 398399, upload-time = "2025-11-21T11:27:46.973Z" } wheels = [ @@ -230,7 +234,8 @@ name = "contourpy" version = "1.3.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", ] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" 
}, @@ -300,8 +305,10 @@ name = "contourpy" version = "1.3.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] dependencies = [ { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -402,8 +409,10 @@ dependencies = [ { name = "pims" }, { name = "setuptools" }, { name = "timm" }, - { name = "torch" }, - { name = "torchvision" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, + { name = "torchvision", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torchvision", version = "0.25.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, { name = "tqdm" }, { name = "transformers" }, { name = "triton-windows", marker = "sys_platform == 'win32'" }, @@ -433,8 +442,10 @@ requires-dist = [ { name = "pims" }, { name = "setuptools" }, { name = "timm", git = "https://github.com/Raiden129/pytorch-image-models-fix?branch=fix%2Fhiera-flash-attention-global-4d" }, - { name = "torch", specifier = "==2.10.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "torchvision", specifier = "==0.25.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "sys_platform != 'darwin'", specifier = "==2.10.0", index = "https://download.pytorch.org/whl/cu128" }, + { 
name = "torch", marker = "sys_platform == 'darwin'", specifier = "==2.10.0" }, + { name = "torchvision", marker = "sys_platform != 'darwin'", specifier = "==0.25.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = "==0.25.0" }, { name = "tqdm" }, { name = "transformers" }, { name = "triton-windows", marker = "sys_platform == 'win32'", specifier = "==3.6.0.post25" }, @@ -570,7 +581,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder" }, + { name = "cuda-pathfinder", marker = "sys_platform != 'darwin'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/37/31/bfcc870f69c6a017c4ad5c42316207fc7551940db6f3639aa4466ec5faf3/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a022c96b8bd847e8dc0675523431149a4c3e872f440e3002213dbb9e08f0331a", size = 11800959, upload-time = "2025-10-21T14:51:26.458Z" }, @@ -1180,7 +1191,8 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -1192,8 +1204,10 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and 
sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1205,7 +1219,8 @@ name = "numpy" version = "2.2.6" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } wheels = [ @@ -1270,8 +1285,10 @@ name = "numpy" version = "2.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" } wheels = [ @@ -1389,7 +1406,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, ] wheels = 
[ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -1401,7 +1418,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -1431,9 +1448,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1445,7 +1462,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 
'darwin'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -1538,7 +1555,8 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -2076,7 +2094,8 @@ name = "tifffile" version = "2025.5.10" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", ] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -2091,8 +2110,10 @@ name = "tifffile" version = "2026.2.24" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] dependencies = [ { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2110,8 +2131,10 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = 
"safetensors" }, - { name = "torch" }, - { name = "torchvision" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, + { name = "torchvision", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torchvision", version = "0.25.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, ] [[package]] @@ -2198,17 +2221,55 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" }, ] +[[package]] +name = "torch" +version = "2.10.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", +] +dependencies = [ + { name = "filelock", marker = "sys_platform == 'darwin'" }, + { name = "fsspec", marker = "sys_platform == 'darwin'" }, + { name = "jinja2", marker = "sys_platform == 'darwin'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform == 'darwin'" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'darwin'" }, + { name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform == 'darwin'" }, + { name = "sympy", marker = "sys_platform == 'darwin'" }, + { name = 
"typing-extensions", marker = "sys_platform == 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = 
"https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, + { url = "https://files.pythonhosted.org/packages/4f/93/716b5ac0155f1be70ed81bacc21269c3ece8dba0c249b9994094110bfc51/torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a", size = 79464992, upload-time = "2026-01-21T16:23:05.162Z" }, + { url = "https://files.pythonhosted.org/packages/d8/94/71994e7d0d5238393df9732fdab607e37e2b56d26a746cb59fdb415f8966/torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28", size = 79850324, upload-time = "2026-01-21T16:22:09.494Z" }, +] + [[package]] name = "torch" version = "2.10.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform != 'darwin'", +] dependencies = [ { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, - { name = "filelock" }, - { name = 
"fsspec" }, - { name = "jinja2" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "filelock", marker = "sys_platform != 'darwin'" }, + { name = "fsspec", marker = "sys_platform != 'darwin'" }, + { name = "jinja2", marker = "sys_platform != 'darwin'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'darwin'" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'darwin'" }, { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" }, @@ -2224,10 +2285,10 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" }, - { name = "setuptools", marker = "python_full_version >= '3.12'" }, - { name = "sympy" }, + { name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform != 'darwin'" }, + { name = "sympy", marker = "sys_platform != 'darwin'" }, { name = "triton", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions" }, + { name = "typing-extensions", marker = "sys_platform != 'darwin'" }, ] wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:e186f57ef1de1aa877943259819468fc6f27efb583b4a91f9215ada7b7f4e6cc" }, @@ -2253,15 +2314,45 @@ wheels = [ { url = 
"https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" }, ] +[[package]] +name = "torchvision" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", +] +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform == 'darwin'" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'darwin'" }, + { name = "pillow", marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/3e/be/c704bceaf11c4f6b19d64337a34a877fcdfe3bd68160a8c9ae9bea4a35a3/torchvision-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:db74a551946b75d19f9996c419a799ffdf6a223ecf17c656f90da011f1d75b20", size = 1874923, upload-time = "2026-01-21T16:27:46.574Z" }, + { url = "https://files.pythonhosted.org/packages/56/3a/6ea0d73f49a9bef38a1b3a92e8dd455cea58470985d25635beab93841748/torchvision-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c2abe430c90b1d5e552680037d68da4eb80a5852ebb1c811b2b89d299b10573b", size = 1874920, upload-time = "2026-01-21T16:27:45.348Z" }, + { 
url = "https://files.pythonhosted.org/packages/f5/5b/1562a04a6a5a4cf8cf40016a0cdeda91ede75d6962cff7f809a85ae966a5/torchvision-0.25.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:24e11199e4d84ba9c5ee7825ebdf1cd37ce8deec225117f10243cae984ced3ec", size = 1874918, upload-time = "2026-01-21T16:27:39.02Z" }, + { url = "https://files.pythonhosted.org/packages/52/99/dca81ed21ebaeff2b67cc9f815a20fdaa418b69f5f9ea4c6ed71721470db/torchvision-0.25.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a8f8061284395ce31bcd460f2169013382ccf411148ceb2ee38e718e9860f5a7", size = 1896209, upload-time = "2026-01-21T16:27:32.159Z" }, + { url = "https://files.pythonhosted.org/packages/9e/1f/fa839532660e2602b7e704d65010787c5bb296258b44fa8b9c1cd6175e7d/torchvision-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:620a236288d594dcec7634c754484542dc0a5c1b0e0b83a34bda5e91e9b7c3a1", size = 1896193, upload-time = "2026-01-21T16:27:24.785Z" }, + { url = "https://files.pythonhosted.org/packages/97/36/96374a4c7ab50dea9787ce987815614ccfe988a42e10ac1a2e3e5b60319a/torchvision-0.25.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad9a8a5877782944d99186e4502a614770fe906626d76e9cd32446a0ac3075f2", size = 1896207, upload-time = "2026-01-21T16:27:23.383Z" }, +] + [[package]] name = "torchvision" version = "0.25.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform != 'darwin'", + "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version < '3.11' and sys_platform != 'darwin'", +] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "pillow" }, - { name = "torch" }, + { name = "numpy", version = "2.2.6", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'darwin'" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'darwin'" }, + { name = "pillow", marker = "sys_platform != 'darwin'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, ] wheels = [ { url = "https://download.pytorch.org/whl/cu128/torchvision-0.25.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:7fe593ec669dcb74c3749548c21ae302edbb610d1844bf2bae622accdef76dbc" }, From cd702552a1fcb220c0aa2aa9a0899c53db92a5ef Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 9 Mar 2026 19:52:13 +0100 Subject: [PATCH 03/44] add warmup run to remove compilation overhead from benchmark --- test_vram.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test_vram.py b/test_vram.py index 2f734d8d..7a179a2c 100644 --- a/test_vram.py +++ b/test_vram.py @@ -20,6 +20,7 @@ def test_vram(): img_size=2048, device="cuda", model_precision=torch.float16, + mixed_precision=True, ) # Reset stats @@ -27,11 +28,11 @@ def test_vram(): iterations = 24 print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations) + time = timeit.timeit(lambda: process_frame(engine), number=iterations, setup=lambda: process_frame(engine)) print(f"Seconds per frame: {time / iterations}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Peak VRAM used: {peak_vram:.2f} GB") + print(f"Peak VRAM used: {peak_vram:.2f} GiB") if __name__ == "__main__": From 6631537ff0cf0abfa737db622d941b015afebbc5 Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 9 Mar 2026 19:54:51 +0100 Subject: [PATCH 04/44] implement preprocessing on GPU using torchvision --- CorridorKeyModule/inference_engine.py | 47 
++++++++++++++++++--------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index f3c59e19..c90670ec 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -6,7 +6,10 @@ import cv2 import numpy as np import torch +import torchvision import torch.nn.functional as F +import torchvision.transforms.functional as TF + from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -27,8 +30,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device).reshape(3, 1, 1) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device).reshape(3, 1, 1) if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation @@ -132,43 +135,55 @@ def process_frame( Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ + image_was_uint8 = image.dtype == np.uint8 + mask_was_uint8 = mask_linear.dtype == np.uint8 + + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(image).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(mask_linear).to(self.model_precision).to(self.device) # 1. 
Inputs Check & Normalization - if image.dtype == np.uint8: - image = image.astype(np.float32) / 255.0 + if image_was_uint8: + image = image / 255.0 - if mask_linear.dtype == np.uint8: - mask_linear = mask_linear.astype(np.float32) / 255.0 + if mask_was_uint8: + mask_linear = mask_linear / 255.0 h, w = image.shape[:2] # Ensure Mask Shape if mask_linear.ndim == 2: - mask_linear = mask_linear[:, :, np.newaxis] + mask_linear = mask_linear.unsqueeze(-1) + + image = image.permute(2, 0, 1) # [C, H, W] + mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] # 2. Resize to Model Size # If input is linear, we resize in linear to preserve energy/highlights, # THEN convert to sRGB for the model. if input_is_linear: - # Resize in Linear - img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) + img_resized_lin = TF.resize( + image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) # Convert to sRGB for Model img_resized = cu.linear_to_srgb(img_resized_lin) else: # Standard sRGB Resize - img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + img_resized = TF.resize( + image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) - mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - - if mask_resized.ndim == 2: - mask_resized = mask_resized[:, :, np.newaxis] + mask_resized = TF.resize( + mask_linear, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) # 3. Normalize (ImageNet) # Model expects sRGB input normalized img_norm = (img_resized - self.mean) / self.std # 4. 
Prepare Tensor - inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) + inp_concat = torch.concat((img_norm, mask_resized), 0) # [4, H, W] + inp_t = inp_concat.unsqueeze(0) # 5. Inference # Hook for Refiner Scaling From ed578e8b2a6b0a95b6a73ddc37fc09fac9a1e744 Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 9 Mar 2026 22:14:07 +0100 Subject: [PATCH 05/44] add a batched frame processing function --- CorridorKeyModule/inference_engine.py | 257 +++++++++++++++++++------- test_vram.py | 18 +- 2 files changed, 202 insertions(+), 73 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index c90670ec..462eae23 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -105,6 +105,108 @@ def _load_model(self) -> GreenFormer: return model + def _preprocess_input( + self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool + ) -> torch.Tensor: + # 2. Resize to Model Size + # If input is linear, we resize in linear to preserve energy/highlights, + # THEN convert to sRGB for the model. + if input_is_linear: + # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) + img_resized_lin = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + # Convert to sRGB for Model + img_resized = cu.linear_to_srgb(img_resized_lin) + else: + # Standard sRGB Resize + img_resized = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + mask_resized = TF.resize( + mask_batch_linear, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + # 3. 
Normalize (ImageNet) + # Model expects sRGB input normalized + img_norm = (img_resized - self.mean) / self.std + + # 4. Prepare Tensor + inp_concat = torch.concat((img_norm, mask_resized), -3) # [4, H, W] + + return inp_concat + + def _postprocess_output( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + ) -> dict[str, np.ndarray]: + # 6. Post-Process (Resize Back to Original Resolution) + # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. + res_alpha = pred_alpha.permute(1, 2, 0).numpy() + res_fg = pred_fg.permute(1, 2, 0).numpy() + res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) + res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + + if res_alpha.ndim == 2: + res_alpha = res_alpha[:, :, np.newaxis] + + # --- ADVANCED COMPOSITING --- + + # A. Clean Matte (Auto-Despeckle) + if auto_despeckle: + processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha + + # B. Despill FG + # res_fg is sRGB. + fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + + # C. Premultiply (for EXR Output) + # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + + # D. Pack RGBA + # [H, W, 4] - All channels are now strictly Linear Float + processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + + # ---------------------------- + + # 7. 
Composite (on Checkerboard) for checking + # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) + + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + + comp_srgb = cu.linear_to_srgb(comp_lin) + + return { + "alpha": res_alpha, # Linear, Raw Prediction + "fg": res_fg, # sRGB, Raw Prediction (Straight) + "comp": comp_srgb, # sRGB, Composite + "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled + } + @torch.inference_mode() def process_frame( self, @@ -157,33 +259,10 @@ def process_frame( image = image.permute(2, 0, 1) # [C, H, W] mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] - # 2. Resize to Model Size - # If input is linear, we resize in linear to preserve energy/highlights, - # THEN convert to sRGB for the model. - if input_is_linear: - # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) - img_resized_lin = TF.resize( - image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = TF.resize( - image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) + image = image.unsqueeze(0) + mask_linear = mask_linear.unsqueeze(0) - mask_resized = TF.resize( - mask_linear, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - - # 3. Normalize (ImageNet) - # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std - - # 4. 
Prepare Tensor - inp_concat = torch.concat((img_norm, mask_resized), 0) # [4, H, W] - inp_t = inp_concat.unsqueeze(0) + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) # 5. Inference # Hook for Refiner Scaling @@ -196,63 +275,101 @@ def scale_hook(module, input, output): handle = self.model.refiner.register_forward_hook(scale_hook) with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - out = self.model(inp_t) + prediction = self.model(inp_t) if handle: handle.remove() - pred_alpha = out["alpha"] - pred_fg = out["fg"] # Output is sRGB (Sigmoid) + pred_alpha = prediction["alpha"][0].cpu().float() + pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) - # 6. Post-Process (Resize Back to Original Resolution) - # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. - res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy() - res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy() - res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) - res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + return self._postprocess_output( + pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) - if res_alpha.ndim == 2: - res_alpha = res_alpha[:, :, np.newaxis] + @torch.inference_mode() + def batch_process_frames( + self, + images: np.ndarray, + masks_linear: np.ndarray, + refiner_scale: float = 1.0, + input_is_linear: bool = False, + fg_is_straight: bool = True, + despill_strength: float = 1.0, + auto_despeckle: bool = True, + despeckle_size: int = 400, + ) -> list[dict[str, np.ndarray]]: + """ + Process a single frame. + Args: + images: Numpy array [B, H, W, 3] (0.0-1.0 or 0-255). + - If input_is_linear=False (Default): Assumed sRGB. + - If input_is_linear=True: Assumed Linear. + masks_linear: Numpy array [B, H, W] or [B, H, W, 1] (0.0-1.0). Assumed Linear. 
+ refiner_scale: Multiplier for Refiner Deltas (default 1.0). + input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. + If False, resizes in sRGB (standard). + fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). + If False, assumes FG output is Premultiplied. + despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. + auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. + despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + Returns: + list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] + """ + image_was_uint8 = images.dtype == np.uint8 + mask_was_uint8 = masks_linear.dtype == np.uint8 - # --- ADVANCED COMPOSITING --- + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(images).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(masks_linear).to(self.model_precision).to(self.device) + # 1. Inputs Check & Normalization + if image_was_uint8: + image = image / 255.0 - # A. Clean Matte (Auto-Despeckle) - if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) - else: - processed_alpha = res_alpha + if mask_was_uint8: + mask_linear = mask_linear / 255.0 - # B. Despill FG - # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + h, w = image.shape[1:3] - # C. Premultiply (for EXR Output) - # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + # Ensure Mask Shape + if mask_linear.ndim == 3: + mask_linear = mask_linear.unsqueeze(-1) - # D. 
Pack RGBA - # [H, W, 4] - All channels are now strictly Linear Float - processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + image = image.permute(0, 3, 1, 2) # [B, C, H, W] + mask_linear = mask_linear.permute(0, 3, 1, 2) # [B, C, H, W] - # ---------------------------- + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) - # 7. Composite (on Checkerboard) for checking - # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + torch.cuda.empty_cache() - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + # 5. Inference + # Hook for Refiner Scaling + handle = None + if refiner_scale != 1.0 and self.model.refiner is not None: - comp_srgb = cu.linear_to_srgb(comp_lin) + def scale_hook(module, input, output): + return output * refiner_scale - return { - "alpha": res_alpha, # Linear, Raw Prediction - "fg": res_fg, # sRGB, Raw Prediction (Straight) - "comp": comp_srgb, # sRGB, Composite - "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled - } + handle = self.model.refiner.register_forward_hook(scale_hook) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): + prediction = self.model(inp_t) + + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + del inp_t + torch.cuda.empty_cache() + + if handle: + handle.remove() + + out = [] + for pred_alpha, pred_fg in zip(prediction["alpha"].cpu().float(), prediction["fg"].cpu().float()): + out.append( + self._postprocess_output( + pred_alpha, pred_fg, w, h, 
fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) + ) + + return out diff --git a/test_vram.py b/test_vram.py index 7a179a2c..32c4d5ab 100644 --- a/test_vram.py +++ b/test_vram.py @@ -6,13 +6,20 @@ from CorridorKeyModule.inference_engine import CorridorKeyEngine -def process_frame(engine): +def process_frame(engine: CorridorKeyEngine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) engine.process_frame(img, mask) +def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): + imgs = np.random.randint(0, 255, (batch_size, 2160, 3840, 3), dtype=np.uint8) + masks = np.random.randint(0, 255, (batch_size, 2160, 3840), dtype=np.uint8) + + engine.batch_process_frames(imgs, masks) + + def test_vram(): print("Loading engine...") engine = CorridorKeyEngine( @@ -27,9 +34,14 @@ def test_vram(): torch.cuda.reset_peak_memory_stats() iterations = 24 + batch_size = 6 # works with a 16GB GPU print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations, setup=lambda: process_frame(engine)) - print(f"Seconds per frame: {time / iterations}") + time = timeit.timeit( + lambda: batch_process_frame(engine, batch_size), + number=iterations, + setup=lambda: batch_process_frame(engine, batch_size), + ) + print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) print(f"Peak VRAM used: {peak_vram:.2f} GiB") From efc3a4ddabc2c7bddfa05f9e95f8293ad4a00780 Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 9 Mar 2026 23:24:55 +0100 Subject: [PATCH 06/44] use multiprocessing to speed up post-processing --- CorridorKeyModule/inference_engine.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 462eae23..b505da03 100644 --- 
a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -298,6 +298,7 @@ def batch_process_frames( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, + num_workers: int = torch.multiprocessing.cpu_count() // 2, ) -> list[dict[str, np.ndarray]]: """ Process a single frame. @@ -314,6 +315,7 @@ def batch_process_frames( despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + num_workers: int. Number of worker threads used for post-processing Returns: list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] """ @@ -364,12 +366,17 @@ def scale_hook(module, input, output): if handle: handle.remove() - out = [] - for pred_alpha, pred_fg in zip(prediction["alpha"].cpu().float(), prediction["fg"].cpu().float()): - out.append( - self._postprocess_output( - pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size - ) + with torch.multiprocessing.Pool(num_workers) as pool: + input = zip( + prediction["alpha"].cpu().float(), + prediction["fg"].cpu().float(), + [w] * len(prediction["alpha"]), + [h] * len(prediction["alpha"]), + [fg_is_straight] * len(prediction["alpha"]), + [despill_strength] * len(prediction["alpha"]), + [auto_despeckle] * len(prediction["alpha"]), + [despeckle_size] * len(prediction["alpha"]), ) + out = pool.starmap(self._postprocess_output, input) return out From 18db1cb8cbba50c26377a7c3e5e77f39185b6a0b Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:52:05 +0100 Subject: [PATCH 07/44] implement batched frame processing --- CorridorKeyModule/inference_engine.py | 285 +++++++++++++++++++------- test_vram.py | 21 +- 2 files changed, 229 insertions(+), 77 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py 
b/CorridorKeyModule/inference_engine.py index f3c59e19..b505da03 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -6,7 +6,10 @@ import cv2 import numpy as np import torch +import torchvision import torch.nn.functional as F +import torchvision.transforms.functional as TF + from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -27,8 +30,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device).reshape(3, 1, 1) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device).reshape(3, 1, 1) if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation @@ -102,6 +105,108 @@ def _load_model(self) -> GreenFormer: return model + def _preprocess_input( + self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool + ) -> torch.Tensor: + # 2. Resize to Model Size + # If input is linear, we resize in linear to preserve energy/highlights, + # THEN convert to sRGB for the model. 
+ if input_is_linear: + # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) + img_resized_lin = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + # Convert to sRGB for Model + img_resized = cu.linear_to_srgb(img_resized_lin) + else: + # Standard sRGB Resize + img_resized = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + mask_resized = TF.resize( + mask_batch_linear, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + # 3. Normalize (ImageNet) + # Model expects sRGB input normalized + img_norm = (img_resized - self.mean) / self.std + + # 4. Prepare Tensor + inp_concat = torch.concat((img_norm, mask_resized), -3) # [4, H, W] + + return inp_concat + + def _postprocess_output( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + ) -> dict[str, np.ndarray]: + # 6. Post-Process (Resize Back to Original Resolution) + # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. + res_alpha = pred_alpha.permute(1, 2, 0).numpy() + res_fg = pred_fg.permute(1, 2, 0).numpy() + res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) + res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + + if res_alpha.ndim == 2: + res_alpha = res_alpha[:, :, np.newaxis] + + # --- ADVANCED COMPOSITING --- + + # A. Clean Matte (Auto-Despeckle) + if auto_despeckle: + processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha + + # B. Despill FG + # res_fg is sRGB. + fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + + # C. 
Premultiply (for EXR Output) + # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + + # D. Pack RGBA + # [H, W, 4] - All channels are now strictly Linear Float + processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + + # ---------------------------- + + # 7. Composite (on Checkerboard) for checking + # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) + + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + + comp_srgb = cu.linear_to_srgb(comp_lin) + + return { + "alpha": res_alpha, # Linear, Raw Prediction + "fg": res_fg, # sRGB, Raw Prediction (Straight) + "comp": comp_srgb, # sRGB, Composite + "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled + } + @torch.inference_mode() def process_frame( self, @@ -132,43 +237,32 @@ def process_frame( Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ + image_was_uint8 = image.dtype == np.uint8 + mask_was_uint8 = mask_linear.dtype == np.uint8 + + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(image).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(mask_linear).to(self.model_precision).to(self.device) # 1. 
Inputs Check & Normalization - if image.dtype == np.uint8: - image = image.astype(np.float32) / 255.0 + if image_was_uint8: + image = image / 255.0 - if mask_linear.dtype == np.uint8: - mask_linear = mask_linear.astype(np.float32) / 255.0 + if mask_was_uint8: + mask_linear = mask_linear / 255.0 h, w = image.shape[:2] # Ensure Mask Shape if mask_linear.ndim == 2: - mask_linear = mask_linear[:, :, np.newaxis] + mask_linear = mask_linear.unsqueeze(-1) - # 2. Resize to Model Size - # If input is linear, we resize in linear to preserve energy/highlights, - # THEN convert to sRGB for the model. - if input_is_linear: - # Resize in Linear - img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + image = image.permute(2, 0, 1) # [C, H, W] + mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] - mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + image = image.unsqueeze(0) + mask_linear = mask_linear.unsqueeze(0) - if mask_resized.ndim == 2: - mask_resized = mask_resized[:, :, np.newaxis] - - # 3. Normalize (ImageNet) - # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std - - # 4. Prepare Tensor - inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) # 5. 
Inference # Hook for Refiner Scaling @@ -181,63 +275,108 @@ def scale_hook(module, input, output): handle = self.model.refiner.register_forward_hook(scale_hook) with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - out = self.model(inp_t) + prediction = self.model(inp_t) if handle: handle.remove() - pred_alpha = out["alpha"] - pred_fg = out["fg"] # Output is sRGB (Sigmoid) + pred_alpha = prediction["alpha"][0].cpu().float() + pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) - # 6. Post-Process (Resize Back to Original Resolution) - # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. - res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy() - res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy() - res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) - res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + return self._postprocess_output( + pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) - if res_alpha.ndim == 2: - res_alpha = res_alpha[:, :, np.newaxis] + @torch.inference_mode() + def batch_process_frames( + self, + images: np.ndarray, + masks_linear: np.ndarray, + refiner_scale: float = 1.0, + input_is_linear: bool = False, + fg_is_straight: bool = True, + despill_strength: float = 1.0, + auto_despeckle: bool = True, + despeckle_size: int = 400, + num_workers: int = torch.multiprocessing.cpu_count() // 2, + ) -> list[dict[str, np.ndarray]]: + """ + Process a single frame. + Args: + images: Numpy array [B, H, W, 3] (0.0-1.0 or 0-255). + - If input_is_linear=False (Default): Assumed sRGB. + - If input_is_linear=True: Assumed Linear. + masks_linear: Numpy array [B, H, W] or [B, H, W, 1] (0.0-1.0). Assumed Linear. + refiner_scale: Multiplier for Refiner Deltas (default 1.0). + input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. 
+ If False, resizes in sRGB (standard). + fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). + If False, assumes FG output is Premultiplied. + despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. + auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. + despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + num_workers: int. Number of worker threads used for post-processing + Returns: + list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] + """ + image_was_uint8 = images.dtype == np.uint8 + mask_was_uint8 = masks_linear.dtype == np.uint8 - # --- ADVANCED COMPOSITING --- + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(images).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(masks_linear).to(self.model_precision).to(self.device) + # 1. Inputs Check & Normalization + if image_was_uint8: + image = image / 255.0 - # A. Clean Matte (Auto-Despeckle) - if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) - else: - processed_alpha = res_alpha + if mask_was_uint8: + mask_linear = mask_linear / 255.0 - # B. Despill FG - # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + h, w = image.shape[1:3] - # C. Premultiply (for EXR Output) - # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + # Ensure Mask Shape + if mask_linear.ndim == 3: + mask_linear = mask_linear.unsqueeze(-1) - # D. 
Pack RGBA - # [H, W, 4] - All channels are now strictly Linear Float - processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + image = image.permute(0, 3, 1, 2) # [B, C, H, W] + mask_linear = mask_linear.permute(0, 3, 1, 2) # [B, C, H, W] - # ---------------------------- + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) - # 7. Composite (on Checkerboard) for checking - # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + torch.cuda.empty_cache() - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + # 5. Inference + # Hook for Refiner Scaling + handle = None + if refiner_scale != 1.0 and self.model.refiner is not None: - comp_srgb = cu.linear_to_srgb(comp_lin) + def scale_hook(module, input, output): + return output * refiner_scale - return { - "alpha": res_alpha, # Linear, Raw Prediction - "fg": res_fg, # sRGB, Raw Prediction (Straight) - "comp": comp_srgb, # sRGB, Composite - "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled - } + handle = self.model.refiner.register_forward_hook(scale_hook) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): + prediction = self.model(inp_t) + + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + del inp_t + torch.cuda.empty_cache() + + if handle: + handle.remove() + + with torch.multiprocessing.Pool(num_workers) as pool: + input = zip( + prediction["alpha"].cpu().float(), + prediction["fg"].cpu().float(), + [w] * len(prediction["alpha"]), + [h] * 
len(prediction["alpha"]), + [fg_is_straight] * len(prediction["alpha"]), + [despill_strength] * len(prediction["alpha"]), + [auto_despeckle] * len(prediction["alpha"]), + [despeckle_size] * len(prediction["alpha"]), + ) + out = pool.starmap(self._postprocess_output, input) + + return out diff --git a/test_vram.py b/test_vram.py index 2f734d8d..32c4d5ab 100644 --- a/test_vram.py +++ b/test_vram.py @@ -6,13 +6,20 @@ from CorridorKeyModule.inference_engine import CorridorKeyEngine -def process_frame(engine): +def process_frame(engine: CorridorKeyEngine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) engine.process_frame(img, mask) +def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): + imgs = np.random.randint(0, 255, (batch_size, 2160, 3840, 3), dtype=np.uint8) + masks = np.random.randint(0, 255, (batch_size, 2160, 3840), dtype=np.uint8) + + engine.batch_process_frames(imgs, masks) + + def test_vram(): print("Loading engine...") engine = CorridorKeyEngine( @@ -20,18 +27,24 @@ def test_vram(): img_size=2048, device="cuda", model_precision=torch.float16, + mixed_precision=True, ) # Reset stats torch.cuda.reset_peak_memory_stats() iterations = 24 + batch_size = 6 # works with a 16GB GPU print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations) - print(f"Seconds per frame: {time / iterations}") + time = timeit.timeit( + lambda: batch_process_frame(engine, batch_size), + number=iterations, + setup=lambda: batch_process_frame(engine, batch_size), + ) + print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Peak VRAM used: {peak_vram:.2f} GB") + print(f"Peak VRAM used: {peak_vram:.2f} GiB") if __name__ == "__main__": From 3852f0db7218f82625eda2cc731bd11ff3b0a09f Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:55:10 
+0100 Subject: [PATCH 08/44] Revert "use multiprocessing to speed up post-processing" This reverts commit efc3a4ddabc2c7bddfa05f9e95f8293ad4a00780. --- CorridorKeyModule/inference_engine.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index b505da03..462eae23 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -298,7 +298,6 @@ def batch_process_frames( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, - num_workers: int = torch.multiprocessing.cpu_count() // 2, ) -> list[dict[str, np.ndarray]]: """ Process a single frame. @@ -315,7 +314,6 @@ def batch_process_frames( despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. despeckle_size: int. Minimum number of consecutive pixels required to keep an island. - num_workers: int. 
Number of worker threads used for post-processing Returns: list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] """ @@ -366,17 +364,12 @@ def scale_hook(module, input, output): if handle: handle.remove() - with torch.multiprocessing.Pool(num_workers) as pool: - input = zip( - prediction["alpha"].cpu().float(), - prediction["fg"].cpu().float(), - [w] * len(prediction["alpha"]), - [h] * len(prediction["alpha"]), - [fg_is_straight] * len(prediction["alpha"]), - [despill_strength] * len(prediction["alpha"]), - [auto_despeckle] * len(prediction["alpha"]), - [despeckle_size] * len(prediction["alpha"]), + out = [] + for pred_alpha, pred_fg in zip(prediction["alpha"].cpu().float(), prediction["fg"].cpu().float()): + out.append( + self._postprocess_output( + pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) ) - out = pool.starmap(self._postprocess_output, input) return out From 855d435fbf37ea7c7627ff02f02058fb6bd720a0 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:55:10 +0100 Subject: [PATCH 09/44] Revert "add a batched frame processing function" This reverts commit ed578e8b2a6b0a95b6a73ddc37fc09fac9a1e744. --- CorridorKeyModule/inference_engine.py | 257 +++++++------------------- test_vram.py | 18 +- 2 files changed, 73 insertions(+), 202 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 462eae23..c90670ec 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -105,108 +105,6 @@ def _load_model(self) -> GreenFormer: return model - def _preprocess_input( - self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool - ) -> torch.Tensor: - # 2. Resize to Model Size - # If input is linear, we resize in linear to preserve energy/highlights, - # THEN convert to sRGB for the model. 
- if input_is_linear: - # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) - img_resized_lin = TF.resize( - image_batch, - [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = TF.resize( - image_batch, - [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - - mask_resized = TF.resize( - mask_batch_linear, - [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - - # 3. Normalize (ImageNet) - # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std - - # 4. Prepare Tensor - inp_concat = torch.concat((img_norm, mask_resized), -3) # [4, H, W] - - return inp_concat - - def _postprocess_output( - self, - pred_alpha: torch.Tensor, - pred_fg: torch.Tensor, - w: int, - h: int, - fg_is_straight: bool, - despill_strength: float, - auto_despeckle: bool, - despeckle_size: int, - ) -> dict[str, np.ndarray]: - # 6. Post-Process (Resize Back to Original Resolution) - # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. - res_alpha = pred_alpha.permute(1, 2, 0).numpy() - res_fg = pred_fg.permute(1, 2, 0).numpy() - res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) - res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) - - if res_alpha.ndim == 2: - res_alpha = res_alpha[:, :, np.newaxis] - - # --- ADVANCED COMPOSITING --- - - # A. Clean Matte (Auto-Despeckle) - if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) - else: - processed_alpha = res_alpha - - # B. Despill FG - # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) - - # C. 
Premultiply (for EXR Output) - # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) - - # D. Pack RGBA - # [H, W, 4] - All channels are now strictly Linear Float - processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) - - # ---------------------------- - - # 7. Composite (on Checkerboard) for checking - # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) - - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) - - comp_srgb = cu.linear_to_srgb(comp_lin) - - return { - "alpha": res_alpha, # Linear, Raw Prediction - "fg": res_fg, # sRGB, Raw Prediction (Straight) - "comp": comp_srgb, # sRGB, Composite - "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled - } - @torch.inference_mode() def process_frame( self, @@ -259,10 +157,33 @@ def process_frame( image = image.permute(2, 0, 1) # [C, H, W] mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] - image = image.unsqueeze(0) - mask_linear = mask_linear.unsqueeze(0) + # 2. Resize to Model Size + # If input is linear, we resize in linear to preserve energy/highlights, + # THEN convert to sRGB for the model. 
+ if input_is_linear: + # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) + img_resized_lin = TF.resize( + image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + # Convert to sRGB for Model + img_resized = cu.linear_to_srgb(img_resized_lin) + else: + # Standard sRGB Resize + img_resized = TF.resize( + image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) - inp_t = self._preprocess_input(image, mask_linear, input_is_linear) + mask_resized = TF.resize( + mask_linear, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + + # 3. Normalize (ImageNet) + # Model expects sRGB input normalized + img_norm = (img_resized - self.mean) / self.std + + # 4. Prepare Tensor + inp_concat = torch.concat((img_norm, mask_resized), 0) # [4, H, W] + inp_t = inp_concat.unsqueeze(0) # 5. Inference # Hook for Refiner Scaling @@ -275,101 +196,63 @@ def scale_hook(module, input, output): handle = self.model.refiner.register_forward_hook(scale_hook) with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - prediction = self.model(inp_t) + out = self.model(inp_t) if handle: handle.remove() - pred_alpha = prediction["alpha"][0].cpu().float() - pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) - - return self._postprocess_output( - pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size - ) - - @torch.inference_mode() - def batch_process_frames( - self, - images: np.ndarray, - masks_linear: np.ndarray, - refiner_scale: float = 1.0, - input_is_linear: bool = False, - fg_is_straight: bool = True, - despill_strength: float = 1.0, - auto_despeckle: bool = True, - despeckle_size: int = 400, - ) -> list[dict[str, np.ndarray]]: - """ - Process a single frame. 
- Args: - images: Numpy array [B, H, W, 3] (0.0-1.0 or 0-255). - - If input_is_linear=False (Default): Assumed sRGB. - - If input_is_linear=True: Assumed Linear. - masks_linear: Numpy array [B, H, W] or [B, H, W, 1] (0.0-1.0). Assumed Linear. - refiner_scale: Multiplier for Refiner Deltas (default 1.0). - input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. - If False, resizes in sRGB (standard). - fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). - If False, assumes FG output is Premultiplied. - despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. - auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. - despeckle_size: int. Minimum number of consecutive pixels required to keep an island. - Returns: - list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] - """ - image_was_uint8 = images.dtype == np.uint8 - mask_was_uint8 = masks_linear.dtype == np.uint8 - - # immediately casting to float is fine since fp16 can represent all uint8 values exactly - image = torch.from_numpy(images).to(self.model_precision).to(self.device) - mask_linear = torch.from_numpy(masks_linear).to(self.model_precision).to(self.device) - # 1. Inputs Check & Normalization - if image_was_uint8: - image = image / 255.0 - - if mask_was_uint8: - mask_linear = mask_linear / 255.0 - - h, w = image.shape[1:3] + pred_alpha = out["alpha"] + pred_fg = out["fg"] # Output is sRGB (Sigmoid) - # Ensure Mask Shape - if mask_linear.ndim == 3: - mask_linear = mask_linear.unsqueeze(-1) + # 6. Post-Process (Resize Back to Original Resolution) + # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. 
+ res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy() + res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy() + res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) + res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) - image = image.permute(0, 3, 1, 2) # [B, C, H, W] - mask_linear = mask_linear.permute(0, 3, 1, 2) # [B, C, H, W] + if res_alpha.ndim == 2: + res_alpha = res_alpha[:, :, np.newaxis] - inp_t = self._preprocess_input(image, mask_linear, input_is_linear) + # --- ADVANCED COMPOSITING --- - # Free up unused VRAM in order to keep peak usage down and avoid OOM errors - torch.cuda.empty_cache() + # A. Clean Matte (Auto-Despeckle) + if auto_despeckle: + processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha - # 5. Inference - # Hook for Refiner Scaling - handle = None - if refiner_scale != 1.0 and self.model.refiner is not None: + # B. Despill FG + # res_fg is sRGB. + fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) - def scale_hook(module, input, output): - return output * refiner_scale + # C. Premultiply (for EXR Output) + # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) - handle = self.model.refiner.register_forward_hook(scale_hook) + # D. Pack RGBA + # [H, W, 4] - All channels are now strictly Linear Float + processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) - with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - prediction = self.model(inp_t) + # ---------------------------- - # Free up unused VRAM in order to keep peak usage down and avoid OOM errors - del inp_t - torch.cuda.empty_cache() + # 7. 
Composite (on Checkerboard) for checking + # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) - if handle: - handle.remove() + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) - out = [] - for pred_alpha, pred_fg in zip(prediction["alpha"].cpu().float(), prediction["fg"].cpu().float()): - out.append( - self._postprocess_output( - pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size - ) - ) + comp_srgb = cu.linear_to_srgb(comp_lin) - return out + return { + "alpha": res_alpha, # Linear, Raw Prediction + "fg": res_fg, # sRGB, Raw Prediction (Straight) + "comp": comp_srgb, # sRGB, Composite + "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled + } diff --git a/test_vram.py b/test_vram.py index 32c4d5ab..7a179a2c 100644 --- a/test_vram.py +++ b/test_vram.py @@ -6,20 +6,13 @@ from CorridorKeyModule.inference_engine import CorridorKeyEngine -def process_frame(engine: CorridorKeyEngine): +def process_frame(engine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) engine.process_frame(img, mask) -def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): - imgs = np.random.randint(0, 255, (batch_size, 2160, 3840, 3), dtype=np.uint8) - masks = np.random.randint(0, 255, (batch_size, 2160, 3840), dtype=np.uint8) - - engine.batch_process_frames(imgs, masks) - - def test_vram(): print("Loading engine...") engine = CorridorKeyEngine( @@ -34,14 +27,9 @@ def test_vram(): torch.cuda.reset_peak_memory_stats() iterations = 24 - batch_size = 6 # works with a 16GB GPU 
print(f"Running {iterations} inference passes...") - time = timeit.timeit( - lambda: batch_process_frame(engine, batch_size), - number=iterations, - setup=lambda: batch_process_frame(engine, batch_size), - ) - print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") + time = timeit.timeit(lambda: process_frame(engine), number=iterations, setup=lambda: process_frame(engine)) + print(f"Seconds per frame: {time / iterations}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) print(f"Peak VRAM used: {peak_vram:.2f} GiB") From d1bc29580e6edce84e8466479b14887d5065147f Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:55:10 +0100 Subject: [PATCH 10/44] Revert "implement preprocessing on GPU using torchvision" This reverts commit 6631537ff0cf0abfa737db622d941b015afebbc5. --- CorridorKeyModule/inference_engine.py | 47 +++++++++------------------ 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index c90670ec..f3c59e19 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -6,10 +6,7 @@ import cv2 import numpy as np import torch -import torchvision import torch.nn.functional as F -import torchvision.transforms.functional as TF - from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -30,8 +27,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device).reshape(3, 1, 1) - self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device).reshape(3, 1, 1) + self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) + self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation @@ -135,55 +132,43 
@@ def process_frame( Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ - image_was_uint8 = image.dtype == np.uint8 - mask_was_uint8 = mask_linear.dtype == np.uint8 - - # immediately casting to float is fine since fp16 can represent all uint8 values exactly - image = torch.from_numpy(image).to(self.model_precision).to(self.device) - mask_linear = torch.from_numpy(mask_linear).to(self.model_precision).to(self.device) # 1. Inputs Check & Normalization - if image_was_uint8: - image = image / 255.0 + if image.dtype == np.uint8: + image = image.astype(np.float32) / 255.0 - if mask_was_uint8: - mask_linear = mask_linear / 255.0 + if mask_linear.dtype == np.uint8: + mask_linear = mask_linear.astype(np.float32) / 255.0 h, w = image.shape[:2] # Ensure Mask Shape if mask_linear.ndim == 2: - mask_linear = mask_linear.unsqueeze(-1) - - image = image.permute(2, 0, 1) # [C, H, W] - mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] + mask_linear = mask_linear[:, :, np.newaxis] # 2. Resize to Model Size # If input is linear, we resize in linear to preserve energy/highlights, # THEN convert to sRGB for the model. 
if input_is_linear: - # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) - img_resized_lin = TF.resize( - image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) + # Resize in Linear + img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) # Convert to sRGB for Model img_resized = cu.linear_to_srgb(img_resized_lin) else: # Standard sRGB Resize - img_resized = TF.resize( - image, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) + img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - mask_resized = TF.resize( - mask_linear, [self.img_size, self.img_size], interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) + mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + + if mask_resized.ndim == 2: + mask_resized = mask_resized[:, :, np.newaxis] # 3. Normalize (ImageNet) # Model expects sRGB input normalized img_norm = (img_resized - self.mean) / self.std # 4. Prepare Tensor - inp_concat = torch.concat((img_norm, mask_resized), 0) # [4, H, W] - inp_t = inp_concat.unsqueeze(0) + inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] + inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) # 5. Inference # Hook for Refiner Scaling From 01a84c4c9ddb46367f0abb1eaf462dae9056cad7 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:55:11 +0100 Subject: [PATCH 11/44] Revert "add warmup run to remove compilation overhead from benchmark" This reverts commit cd702552a1fcb220c0aa2aa9a0899c53db92a5ef. 
--- test_vram.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test_vram.py b/test_vram.py index 7a179a2c..2f734d8d 100644 --- a/test_vram.py +++ b/test_vram.py @@ -20,7 +20,6 @@ def test_vram(): img_size=2048, device="cuda", model_precision=torch.float16, - mixed_precision=True, ) # Reset stats @@ -28,11 +27,11 @@ def test_vram(): iterations = 24 print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations, setup=lambda: process_frame(engine)) + time = timeit.timeit(lambda: process_frame(engine), number=iterations) print(f"Seconds per frame: {time / iterations}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Peak VRAM used: {peak_vram:.2f} GiB") + print(f"Peak VRAM used: {peak_vram:.2f} GB") if __name__ == "__main__": From 4a148e2a0ed8e4f3c7498d911703c811fea17f75 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 02:02:47 +0100 Subject: [PATCH 12/44] update uv.lock --- uv.lock | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/uv.lock b/uv.lock index 609ba3f1..8a2d453f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,12 +1,12 @@ version = 1 revision = 3 -requires-python = ">=3.10, <=3.14" +requires-python = ">=3.10, <3.15" resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform != 'darwin'", - "python_full_version == '3.11.*' and sys_platform == 'darwin'", "python_full_version < '3.11' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", "python_full_version < '3.11' and sys_platform == 'darwin'", ] @@ -306,8 +306,8 @@ version = "1.3.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= 
'3.12' and sys_platform != 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] dependencies = [ @@ -407,6 +407,7 @@ dependencies = [ { name = "peft" }, { name = "pillow" }, { name = "pims" }, + { name = "rich" }, { name = "setuptools" }, { name = "timm" }, { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -416,6 +417,7 @@ dependencies = [ { name = "tqdm" }, { name = "transformers" }, { name = "triton-windows", marker = "sys_platform == 'win32'" }, + { name = "typer" }, ] [package.dev-dependencies] @@ -424,6 +426,9 @@ dev = [ { name = "pytest-cov" }, { name = "ruff" }, ] +docs = [ + { name = "zensical" }, +] [package.metadata] requires-dist = [ @@ -440,6 +445,7 @@ requires-dist = [ { name = "peft" }, { name = "pillow" }, { name = "pims" }, + { name = "rich", specifier = ">=13" }, { name = "setuptools" }, { name = "timm", git = "https://github.com/Raiden129/pytorch-image-models-fix?branch=fix%2Fhiera-flash-attention-global-4d" }, { name = "torch", marker = "sys_platform != 'darwin'", specifier = "==2.10.0", index = "https://download.pytorch.org/whl/cu128" }, @@ -449,6 +455,7 @@ requires-dist = [ { name = "tqdm" }, { name = "transformers" }, { name = "triton-windows", marker = "sys_platform == 'win32'", specifier = "==3.6.0.post25" }, + { name = "typer", specifier = ">=0.12" }, ] [package.metadata.requires-dev] @@ -457,6 +464,7 @@ dev = [ { name = "pytest-cov" }, { name = "ruff" }, ] +docs = [{ name = "zensical", specifier = ">=0.0.24" }] [[package]] name = "coverage" @@ -617,6 +625,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = 
"sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] +[[package]] +name = "deepmerge" +version = "2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/3a/b0ba594708f1ad0bc735884b3ad854d3ca3bdc1d741e56e40bbda6263499/deepmerge-2.0.tar.gz", hash = "sha256:5c3d86081fbebd04dd5de03626a0607b809a98fb6ccba5770b62466fe940ff20", size = 19890, upload-time = "2024-08-30T05:31:50.308Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/82/e5d2c1c67d19841e9edc74954c827444ae826978499bde3dfc1d007c8c11/deepmerge-2.0-py3-none-any.whl", hash = "sha256:6de9ce507115cff0bed95ff0ce9ecc31088ef50cbdf09bc90a09349a318b3d00", size = 13475, upload-time = "2024-08-30T05:31:48.659Z" }, +] + [[package]] name = "diffusers" version = "0.36.0" @@ -996,6 +1013,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/e9/0d4add7873a73e462aeb45c036a2dead2562b825aa46ba326727b3f31016/kiwisolver-1.4.9-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1", size = 73929, upload-time = "2025-08-10T21:27:48.236Z" }, ] +[[package]] +name = "markdown" +version = "3.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -1205,8 +1231,8 
@@ version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } @@ -1286,8 +1312,8 @@ version = "2.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" } @@ -1724,6 +1750,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pymdown-extensions" +version = "10.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown" }, + { name = "pyyaml" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/ba/63/06673d1eb6d8f83c0ea1f677d770e12565fb516928b4109c9e2055656a9e/pymdown_extensions-10.21.tar.gz", hash = "sha256:39f4a020f40773f6b2ff31d2cd2546c2c04d0a6498c31d9c688d2be07e1767d5", size = 853363, upload-time = "2026-02-15T20:44:06.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/2c/5b079febdc65e1c3fb2729bf958d18b45be7113828528e8a0b5850dd819a/pymdown_extensions-10.21-py3-none-any.whl", hash = "sha256:91b879f9f864d49794c2d9534372b10150e6141096c3908a455e45ca72ad9d3f", size = 268877, upload-time = "2026-02-15T20:44:05.464Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -2111,8 +2150,8 @@ version = "2026.2.24" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform != 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", ] dependencies = [ @@ -2489,6 +2528,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] +[[package]] +name = "zensical" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "deepmerge" }, + { name = "markdown" }, + { name = "pygments" }, + { name = "pymdown-extensions" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/96/9c6cbdd7b351d1023cdbbcf7872d4cb118b0334cfe5821b99e0dd18e3f00/zensical-0.0.24.tar.gz", hash = 
"sha256:b5d99e225329bf4f98c8022bdf0a0ee9588c2fada7b4df1b7b896fcc62b37ec3", size = 3840688, upload-time = "2026-02-26T09:43:44.557Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/aa/b8201af30e376a67566f044a1c56210edac5ae923fd986a836d2cf593c9c/zensical-0.0.24-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d390c5453a5541ca35d4f9e1796df942b6612c546e3153dd928236d3b758409a", size = 12263407, upload-time = "2026-02-26T09:43:14.716Z" }, + { url = "https://files.pythonhosted.org/packages/78/8e/3d910214471ade604fd39b080db3696864acc23678b5b4b8475c7dbfd2ce/zensical-0.0.24-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:81ac072869cf4d280853765b2bfb688653da0dfb9408f3ab15aca96455ab8223", size = 12142610, upload-time = "2026-02-26T09:43:17.546Z" }, + { url = "https://files.pythonhosted.org/packages/cf/d7/eb0983640aa0419ddf670298cfbcf8b75629b6484925429b857851e00784/zensical-0.0.24-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5eb1dfa84cae8e960bfa2c6851d2bc8e9710c4c4c683bd3aaf23185f646ae46", size = 12508380, upload-time = "2026-02-26T09:43:20.114Z" }, + { url = "https://files.pythonhosted.org/packages/a3/04/4405b9e6f937a75db19f0d875798a7eb70817d6a3bec2a2d289a2d5e8aea/zensical-0.0.24-cp310-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d7c9e589da99c1879a1c703e67c85eaa6be4661cdc6ce6534f7bb3575983f4", size = 12440807, upload-time = "2026-02-26T09:43:22.679Z" }, + { url = "https://files.pythonhosted.org/packages/12/dc/a7ca2a4224b3072a2c2998b6611ad7fd4f8f131ceae7aa23238d97d26e22/zensical-0.0.24-cp310-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42fcc121c3095734b078a95a0dae4d4924fb8fbf16bf730456146ad6cab48ad0", size = 12782727, upload-time = "2026-02-26T09:43:25.347Z" }, + { url = "https://files.pythonhosted.org/packages/42/37/22f1727da356ed3fcbd31f68d4a477f15c232997c87e270cfffb927459ac/zensical-0.0.24-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:832d4a2a051b9f49561031a2986ace502326f82d9a401ddf125530d30025fdd4", size = 12547616, upload-time = "2026-02-26T09:43:28.031Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ff/c75ff111b8e12157901d00752beef9d691dbb5a034b6a77359972262416a/zensical-0.0.24-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e5fea3bb61238dba9f930f52669db67b0c26be98e1c8386a05eb2b1e3cb875dc", size = 12684883, upload-time = "2026-02-26T09:43:30.642Z" }, + { url = "https://files.pythonhosted.org/packages/b9/92/4f6ea066382e3d068d3cadbed99e9a71af25e46c84a403e0f747960472a2/zensical-0.0.24-cp310-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:75eef0428eec2958590633fdc82dc2a58af124879e29573aa7e153b662978073", size = 12713825, upload-time = "2026-02-26T09:43:33.273Z" }, + { url = "https://files.pythonhosted.org/packages/bc/fb/bf735b19bce0034b1f3b8e1c50b2896ebbd0c5d92d462777e759e78bb083/zensical-0.0.24-cp310-abi3-musllinux_1_2_i686.whl", hash = "sha256:3c6b39659156394ff805b4831dac108c839483d9efa4c9b901eaa913efee1ac7", size = 12854318, upload-time = "2026-02-26T09:43:35.632Z" }, + { url = "https://files.pythonhosted.org/packages/7e/28/0ddab6c1237e3625e7763ff666806f31e5760bb36d18624135a6bb6e8643/zensical-0.0.24-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9eef82865a18b3ca4c3cd13e245dff09a865d1da3c861e2fc86eaa9253a90f02", size = 12818270, upload-time = "2026-02-26T09:43:37.749Z" }, + { url = "https://files.pythonhosted.org/packages/2a/93/d2cef3705d4434896feadffb5b3e44744ef9f1204bc41202c1b84a4eeef6/zensical-0.0.24-cp310-abi3-win32.whl", hash = "sha256:f4d0ff47d505c786a26c9332317aa3e9ad58d1382f55212a10dc5bafcca97864", size = 11857695, upload-time = "2026-02-26T09:43:39.906Z" }, + { url = "https://files.pythonhosted.org/packages/f1/26/9707587c0f6044dd1e1cc5bc3b9fa5fed81ce6c7bcdb09c21a9795e802d9/zensical-0.0.24-cp310-abi3-win_amd64.whl", hash = "sha256:e00a62cf04526dbed665e989b8f448eb976247f077a76dfdd84699ace4aa3ac3", size = 12057762, upload-time = "2026-02-26T09:43:42.627Z" }, 
+] + [[package]] name = "zipp" version = "3.23.0" From fa1b4cbd6322bc39fc27bd945d48f0270d669268 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 02:25:41 +0100 Subject: [PATCH 13/44] fix lint --- CorridorKeyModule/core/model_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CorridorKeyModule/core/model_transformer.py b/CorridorKeyModule/core/model_transformer.py index 619b640c..31138cc4 100644 --- a/CorridorKeyModule/core/model_transformer.py +++ b/CorridorKeyModule/core/model_transformer.py @@ -1,7 +1,7 @@ from __future__ import annotations -import sys import logging +import sys import timm import torch From e3b2b03416b2d7d39bd557cc6681dd59b41611ee Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 01:52:05 +0100 Subject: [PATCH 14/44] implement batched frame processing --- CorridorKeyModule/inference_engine.py | 285 +++++++++++++++++++------- test_vram.py | 21 +- 2 files changed, 229 insertions(+), 77 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 73bf2bec..53ac96e4 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -7,7 +7,10 @@ import cv2 import numpy as np import torch +import torchvision import torch.nn.functional as F +import torchvision.transforms.functional as TF + from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -30,8 +33,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device).reshape(3, 1, 1) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device).reshape(3, 1, 1) if mixed_precision or model_precision != torch.float32: # Use faster matrix 
multiplication implementation @@ -105,6 +108,108 @@ def _load_model(self) -> GreenFormer: return model + def _preprocess_input( + self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool + ) -> torch.Tensor: + # 2. Resize to Model Size + # If input is linear, we resize in linear to preserve energy/highlights, + # THEN convert to sRGB for the model. + if input_is_linear: + # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) + img_resized_lin = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + # Convert to sRGB for Model + img_resized = cu.linear_to_srgb(img_resized_lin) + else: + # Standard sRGB Resize + img_resized = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + mask_resized = TF.resize( + mask_batch_linear, + [self.img_size, self.img_size], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + # 3. Normalize (ImageNet) + # Model expects sRGB input normalized + img_norm = (img_resized - self.mean) / self.std + + # 4. Prepare Tensor + inp_concat = torch.concat((img_norm, mask_resized), -3) # [..., 4, H, W] + + return inp_concat + + def _postprocess_output( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + ) -> dict[str, np.ndarray]: + # 6. Post-Process (Resize Back to Original Resolution) + # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original.
+ res_alpha = pred_alpha.permute(1, 2, 0).numpy() + res_fg = pred_fg.permute(1, 2, 0).numpy() + res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) + res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + + if res_alpha.ndim == 2: + res_alpha = res_alpha[:, :, np.newaxis] + + # --- ADVANCED COMPOSITING --- + + # A. Clean Matte (Auto-Despeckle) + if auto_despeckle: + processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha + + # B. Despill FG + # res_fg is sRGB. + fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + + # C. Premultiply (for EXR Output) + # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + + # D. Pack RGBA + # [H, W, 4] - All channels are now strictly Linear Float + processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + + # ---------------------------- + + # 7. 
Composite (on Checkerboard) for checking + # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) + + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + + comp_srgb = cu.linear_to_srgb(comp_lin) + + return { # type: ignore[return-value] # cu.* returns ndarray|Tensor but inputs are always ndarray here + "alpha": res_alpha, # Linear, Raw Prediction + "fg": res_fg, # sRGB, Raw Prediction (Straight) + "comp": comp_srgb, # sRGB, Composite + "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled + } + @torch.inference_mode() def process_frame( self, @@ -135,43 +240,32 @@ def process_frame( Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ + image_was_uint8 = image.dtype == np.uint8 + mask_was_uint8 = mask_linear.dtype == np.uint8 + + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(image).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(mask_linear).to(self.model_precision).to(self.device) # 1. Inputs Check & Normalization - if image.dtype == np.uint8: - image = image.astype(np.float32) / 255.0 + if image_was_uint8: + image = image / 255.0 - if mask_linear.dtype == np.uint8: - mask_linear = mask_linear.astype(np.float32) / 255.0 + if mask_was_uint8: + mask_linear = mask_linear / 255.0 h, w = image.shape[:2] # Ensure Mask Shape if mask_linear.ndim == 2: - mask_linear = mask_linear[:, :, np.newaxis] + mask_linear = mask_linear.unsqueeze(-1) - # 2. 
Resize to Model Size - # If input is linear, we resize in linear to preserve energy/highlights, - # THEN convert to sRGB for the model. - if input_is_linear: - # Resize in Linear - img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + image = image.permute(2, 0, 1) # [C, H, W] + mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] - mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + image = image.unsqueeze(0) + mask_linear = mask_linear.unsqueeze(0) - if mask_resized.ndim == 2: - mask_resized = mask_resized[:, :, np.newaxis] - - # 3. Normalize (ImageNet) - # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std - - # 4. Prepare Tensor - inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) # 5. Inference # Hook for Refiner Scaling @@ -184,63 +278,108 @@ def scale_hook(module, input, output): handle = self.model.refiner.register_forward_hook(scale_hook) with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - out = self.model(inp_t) + prediction = self.model(inp_t) if handle: handle.remove() - pred_alpha = out["alpha"] - pred_fg = out["fg"] # Output is sRGB (Sigmoid) + pred_alpha = prediction["alpha"][0].cpu().float() + pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) - # 6. Post-Process (Resize Back to Original Resolution) - # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. 
- res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy() - res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy() - res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) - res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) + return self._postprocess_output( + pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) - if res_alpha.ndim == 2: - res_alpha = res_alpha[:, :, np.newaxis] + @torch.inference_mode() + def batch_process_frames( + self, + images: np.ndarray, + masks_linear: np.ndarray, + refiner_scale: float = 1.0, + input_is_linear: bool = False, + fg_is_straight: bool = True, + despill_strength: float = 1.0, + auto_despeckle: bool = True, + despeckle_size: int = 400, + num_workers: int = torch.multiprocessing.cpu_count() // 2, + ) -> list[dict[str, np.ndarray]]: + """ + Process a batch of frames. + Args: + images: Numpy array [B, H, W, 3] (0.0-1.0 or 0-255). + - If input_is_linear=False (Default): Assumed sRGB. + - If input_is_linear=True: Assumed Linear. + masks_linear: Numpy array [B, H, W] or [B, H, W, 1] (0.0-1.0). Assumed Linear. + refiner_scale: Multiplier for Refiner Deltas (default 1.0). + input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. + If False, resizes in sRGB (standard). + fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). + If False, assumes FG output is Premultiplied. + despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. + auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. + despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + num_workers: int. 
Number of worker processes used for post-processing + Returns: + list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] + """ + image_was_uint8 = images.dtype == np.uint8 + mask_was_uint8 = masks_linear.dtype == np.uint8 - # --- ADVANCED COMPOSITING --- + # immediately casting to float is fine since fp16 can represent all uint8 values exactly + image = torch.from_numpy(images).to(self.model_precision).to(self.device) + mask_linear = torch.from_numpy(masks_linear).to(self.model_precision).to(self.device) + # 1. Inputs Check & Normalization + if image_was_uint8: + image = image / 255.0 - # A. Clean Matte (Auto-Despeckle) - if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) - else: - processed_alpha = res_alpha + if mask_was_uint8: + mask_linear = mask_linear / 255.0 - # B. Despill FG - # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + h, w = image.shape[1:3] - # C. Premultiply (for EXR Output) - # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + # Ensure Mask Shape + if mask_linear.ndim == 3: + mask_linear = mask_linear.unsqueeze(-1) - # D. Pack RGBA - # [H, W, 4] - All channels are now strictly Linear Float - processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1) + image = image.permute(0, 3, 1, 2) # [B, C, H, W] + mask_linear = mask_linear.permute(0, 3, 1, 2) # [B, C, H, W] - # ---------------------------- + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) - # 7. 
Composite (on Checkerboard) for checking - # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + torch.cuda.empty_cache() - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + # 5. Inference + # Hook for Refiner Scaling + handle = None + if refiner_scale != 1.0 and self.model.refiner is not None: - comp_srgb = cu.linear_to_srgb(comp_lin) + def scale_hook(module, input, output): + return output * refiner_scale - return { # type: ignore[return-value] # cu.* returns ndarray|Tensor but inputs are always ndarray here - "alpha": res_alpha, # Linear, Raw Prediction - "fg": res_fg, # sRGB, Raw Prediction (Straight) - "comp": comp_srgb, # sRGB, Composite - "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled - } + handle = self.model.refiner.register_forward_hook(scale_hook) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): + prediction = self.model(inp_t) + + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + del inp_t + torch.cuda.empty_cache() + + if handle: + handle.remove() + + with torch.multiprocessing.Pool(num_workers) as pool: + input = zip( + prediction["alpha"].cpu().float(), + prediction["fg"].cpu().float(), + [w] * len(prediction["alpha"]), + [h] * len(prediction["alpha"]), + [fg_is_straight] * len(prediction["alpha"]), + [despill_strength] * len(prediction["alpha"]), + [auto_despeckle] * len(prediction["alpha"]), + [despeckle_size] * len(prediction["alpha"]), + ) + out = pool.starmap(self._postprocess_output, input) + 
+ return out diff --git a/test_vram.py b/test_vram.py index 2f734d8d..32c4d5ab 100644 --- a/test_vram.py +++ b/test_vram.py @@ -6,13 +6,20 @@ from CorridorKeyModule.inference_engine import CorridorKeyEngine -def process_frame(engine): +def process_frame(engine: CorridorKeyEngine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) engine.process_frame(img, mask) +def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): + imgs = np.random.randint(0, 255, (batch_size, 2160, 3840, 3), dtype=np.uint8) + masks = np.random.randint(0, 255, (batch_size, 2160, 3840), dtype=np.uint8) + + engine.batch_process_frames(imgs, masks) + + def test_vram(): print("Loading engine...") engine = CorridorKeyEngine( @@ -20,18 +27,24 @@ def test_vram(): img_size=2048, device="cuda", model_precision=torch.float16, + mixed_precision=True, ) # Reset stats torch.cuda.reset_peak_memory_stats() iterations = 24 + batch_size = 6 # works with a 16GB GPU print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations) - print(f"Seconds per frame: {time / iterations}") + time = timeit.timeit( + lambda: batch_process_frame(engine, batch_size), + number=iterations, + setup=lambda: batch_process_frame(engine, batch_size), + ) + print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Peak VRAM used: {peak_vram:.2f} GB") + print(f"Peak VRAM used: {peak_vram:.2f} GiB") if __name__ == "__main__": From 809f37639652db38a916bc17e6239ba160a8bb0c Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 21:22:00 +0100 Subject: [PATCH 15/44] move compilation to function call to improve flexibility --- CorridorKeyModule/core/model_transformer.py | 3 --- CorridorKeyModule/inference_engine.py | 8 ++++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git 
a/CorridorKeyModule/core/model_transformer.py b/CorridorKeyModule/core/model_transformer.py index 31138cc4..2908c59b 100644 --- a/CorridorKeyModule/core/model_transformer.py +++ b/CorridorKeyModule/core/model_transformer.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import sys import timm import torch @@ -143,8 +142,6 @@ def forward(self, img: torch.Tensor, coarse_pred: torch.Tensor) -> torch.Tensor: return self.final(x) * 10.0 -# We only tested compilation on windows and linux. For other platforms compilation is disabled as a precaution. -@torch.compile(disable=(sys.platform != "linux" and sys.platform != "win32")) class GreenFormer(nn.Module): def __init__( self, diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 53ac96e4..0d18bbe3 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -3,15 +3,15 @@ import logging import math import os +import sys import cv2 import numpy as np import torch -import torchvision import torch.nn.functional as F +import torchvision import torchvision.transforms.functional as TF - from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -106,6 +106,10 @@ def _load_model(self) -> GreenFormer: if len(unexpected) > 0: print(f"[Warning] Unexpected keys: {unexpected}") + # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. 
+ if sys.platform == "linux" or sys.platform == "win32": + model = torch.compile(model) + return model def _preprocess_input( From 2ead927b71984b8b79b72caaea20c8bba904f425 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 10 Mar 2026 21:22:51 +0100 Subject: [PATCH 16/44] bound threads with batch size and fix lint --- CorridorKeyModule/inference_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 0d18bbe3..ef0a337e 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -339,7 +339,7 @@ def batch_process_frames( if mask_was_uint8: mask_linear = mask_linear / 255.0 - h, w = image.shape[1:3] + bs, h, w = image.shape[:3] # Ensure Mask Shape if mask_linear.ndim == 3: @@ -373,8 +373,8 @@ def scale_hook(module, input, output): if handle: handle.remove() - with torch.multiprocessing.Pool(num_workers) as pool: - input = zip( + with torch.multiprocessing.Pool(min(num_workers, bs)) as pool: + inp = zip( prediction["alpha"].cpu().float(), prediction["fg"].cpu().float(), [w] * len(prediction["alpha"]), @@ -383,7 +383,8 @@ def scale_hook(module, input, output): [despill_strength] * len(prediction["alpha"]), [auto_despeckle] * len(prediction["alpha"]), [despeckle_size] * len(prediction["alpha"]), + strict=True, ) - out = pool.starmap(self._postprocess_output, input) + out = pool.starmap(self._postprocess_output, inp) return out From 8775cb97ed5b6134491d7aed78e1bb4b05b934f7 Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 11 Mar 2026 21:02:56 +0100 Subject: [PATCH 17/44] add qualitative comparison helper script --- test_outputs.py | 142 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 test_outputs.py diff --git a/test_outputs.py b/test_outputs.py new file mode 100644 index 00000000..9e16e10f --- /dev/null +++ b/test_outputs.py @@ -0,0 +1,142 @@ +import os + +import torch 
+from torchvision.io import read_image +from torchvision.utils import save_image + +from CorridorKeyModule.inference_engine import CorridorKeyEngine + +# there is some compile weirdness when generating the images +torch._dynamo.config.cache_size_limit = 1024 + + +def load_engine(img_size, precision, mixed_precision): + return CorridorKeyEngine( + checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", + img_size=img_size, + device="cuda", + model_precision=precision, + mixed_precision=mixed_precision, + ) + + +def generate_test_images(img_path, mask_path): + img = read_image(img_path).permute(1, 2, 0).numpy() + mask = read_image(mask_path).permute(1, 2, 0).numpy() + img_sizes = [512, 1024, 2048] + precisions = [torch.float16, torch.float32, torch.float64] + for precision in precisions: + for img_size in img_sizes: + # Reset stats + torch.cuda.reset_peak_memory_stats() + + if precision == torch.float64 and img_size > 1024: + continue + + engine = load_engine(img_size, precision) + out = engine.process_frame(img, mask) + + save_image( + torch.from_numpy(out["fg"]).permute(2, 0, 1), + f"./Output/foreground_{img_size}_{str(precision)[-7:]}.png", + ) + save_image( + torch.from_numpy(out["alpha"]).permute(2, 0, 1), f"./Output/alpha_{img_size}_{str(precision)[-7:]}.png" + ) + + peak_vram = torch.cuda.max_memory_allocated() / (1024**3) + print(f"Precision: {precision}, Image Size: {img_size}, Peak VRAM: {peak_vram:.2f} GB") + + +def compare_implementations(src, comparison): + for _, _, files in os.walk(src): + for file in files: + src_img = read_image(str(os.path.join(src, file))).float() + comp_img = read_image(str(os.path.join(comparison, file))).float() + + is_mask = src_img.shape[0] == 1 or (src_img[0] == src_img[1]).all() and (src_img[1] == src_img[2]).all() + + difference = (src_img - comp_img).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), 
torch.zeros_like(difference)), dim=0 + ) + print(difference.shape) + print(difference.min(), difference.max()) + else: + difference = difference.abs() + + save_image(difference, f"./Output/diff_{file}") + + +def compare_floating_point_precision(folder, ref="float64"): + for _, _, files in os.walk(folder): + for file in files: + name, fmt = file.split(".") + typ, img_size, precision = name.split("_") + if precision != ref: + continue + float_ref = read_image(str(os.path.join(folder, file))).float() + float_32 = read_image(str(os.path.join(folder, f"{typ}_{img_size}_float32.{fmt}"))).float() + + is_mask = typ == "alpha" + + difference = (float_ref - float_32).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), torch.zeros_like(difference)), dim=0 + ) + else: + difference = difference.abs() + print( + is_mask, + difference.min().item(), + difference.max().item(), + difference.mean().item(), + difference.median().item(), + ) + + save_image(difference, f"./Output/prec_{ref}_{typ}_{img_size}.{fmt}") + + +def compare_img_sizes(folder, ref=1024): + for _, _, files in os.walk(folder): + for file in files: + name, fmt = file.split(".") + typ, img_size, precision = name.split("_") + if img_size != str(ref): + continue + if precision == "float64": + continue + img_ref = read_image(str(os.path.join(folder, file))).float() + img_2048 = read_image(str(os.path.join(folder, f"{typ}_2048_{precision}.{fmt}"))).float() + + is_mask = typ == "alpha" + + difference = (img_ref - img_2048).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), torch.zeros_like(difference)), dim=0 + ) + else: + difference = difference.abs() + print( + is_mask, + difference.min().item(), + difference.max().item(), + difference.mean().item(), + difference.median().item(), + ) + + save_image(difference, 
f"./Output/img_{ref}_{typ}_{precision}.{fmt}") + + +if __name__ == "__main__": + compare_img_sizes("./Output/original", 1024) + compare_img_sizes("./Output/original", 512) From 8bb35e845c43edacc534f309ffd9b69f6a51b3c4 Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 11 Mar 2026 23:20:59 +0100 Subject: [PATCH 18/44] fix tests --- tests/test_inference_engine.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index cd3d7739..ef09d297 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -22,7 +22,7 @@ # --------------------------------------------------------------------------- -def _make_engine_with_mock(mock_greenformer, img_size=64): +def _make_engine_with_mock(mock_greenformer, img_size=64, device="cpu"): """Create a CorridorKeyEngine with a mocked model, bypassing __init__. Manually sets the attributes that __init__ would create, avoiding the @@ -31,12 +31,12 @@ def _make_engine_with_mock(mock_greenformer, img_size=64): from CorridorKeyModule.inference_engine import CorridorKeyEngine engine = object.__new__(CorridorKeyEngine) - engine.device = torch.device("cpu") + engine.device = torch.device(device) engine.img_size = img_size engine.checkpoint_path = "/fake/checkpoint.pth" engine.use_refiner = False - engine.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - engine.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + engine.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32, device=torch.device(device)).reshape(3, 1, 1) + engine.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=torch.device(device)).reshape(3, 1, 1) engine.model = mock_greenformer engine.model_precision = torch.float32 engine.mixed_precision = True @@ -224,8 +224,7 @@ def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenfor if not torch.cuda.is_available(): 
pytest.skip("CUDA not available") - engine = _make_engine_with_mock(mock_greenformer) - engine.device = torch.device("cuda") + engine = _make_engine_with_mock(mock_greenformer, device=torch.device("cuda")) result = engine.process_frame(sample_frame_rgb, sample_mask) assert result["alpha"].dtype == np.float32 From 12308e46ea469a3d6f7236ec43622f1ebeb981dc Mon Sep 17 00:00:00 2001 From: marclie Date: Thu, 12 Mar 2026 16:55:10 +0100 Subject: [PATCH 19/44] initial GPU pipeline draft --- CorridorKeyModule/inference_engine.py | 159 +++++++++++++++++++++++--- 1 file changed, 144 insertions(+), 15 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index ef0a337e..f33288f6 100644 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -150,7 +150,7 @@ def _preprocess_input( return inp_concat - def _postprocess_output( + def _postprocess_cpu( self, pred_alpha: torch.Tensor, pred_fg: torch.Tensor, @@ -213,6 +213,147 @@ def _postprocess_output( "comp": comp_srgb, # sRGB, Composite "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled } + + def _postprocess_gpu( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + h: int, + w: int, + auto_despeckle: bool, + despeckle_size: int, + despill_strength: float, + fg_is_straight: bool, + sync: bool = True, + ) -> dict[str, np.ndarray]: + """Post-process on GPU, transfer final results to CPU. + + When ``sync=True`` (default), blocks until transfer completes and + returns numpy arrays. When ``sync=False``, starts the DMA + non-blocking and returns a :class:`PendingTransfer` — call + ``.resolve()`` to get the numpy dict later. 
+ """ + # Resize on GPU using F.interpolate (much faster than cv2 at 4K) + alpha_up = TF.resize( + pred_alpha.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + fg_up = TF.resize( + pred_fg.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + # Convert to HWC on GPU + res_alpha = alpha_up.permute(0, 2, 3, 1) # [B, H, W, 1] + res_fg = fg_up.permute(0, 2, 3, 1) # [B, H, W, 3] + + # A. Clean matte + if auto_despeckle: + processed_alpha = self._clean_matte_gpu(res_alpha, despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha + + # B. Despill on GPU + fg_despilled = self._despill_gpu(res_fg, despill_strength) + + # C. sRGB → linear on GPU + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + + # D. Premultiply on GPU + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + + # E. Pack RGBA on GPU + processed_rgba = torch.cat([fg_premul_lin, processed_alpha], dim=-1) + + # === Bulk transfer to CPU via copy stream === + # + # Pack all outputs into one contiguous tensor, DMA to a pinned + # buffer on the copy stream. 
Layout varies by comp mode: + # Checkerboard: [H, W, 1+3+3+4] = [alpha, fg, comp_rgb, processed] + # Transparent: [H, W, 1+3+4+4] = [alpha, fg, comp_rgba, processed] + # No comp: [H, W, 1+3+3+4] = [alpha, fg, zeros, processed] + comp_channels = 3 + bulk = torch.cat([res_alpha, res_fg, torch.zeros_like(res_fg), processed_rgba], dim=-1) + + bulk_np = bulk.cpu().numpy() + cc = comp_channels + if bulk_np.shape[0] == 1: + result = { + "alpha": bulk_np[:, :, 0:1], + "fg": bulk_np[:, :, 1:4], + "comp": bulk_np[:, :, 4 : 4 + cc], + "processed": bulk_np[:, :, 4 + cc : 4 + cc + 4], + } + return result + + out = [] + for i in range(bulk_np.shape[0]): + result = { + "alpha": bulk_np[i, :, :, 0:1], + "fg": bulk_np[i, :, :, 1:4], + "comp": bulk_np[i, :, :, 4 : 4 + cc], + "processed": bulk_np[i, :, :, 4 + cc : 4 + cc + 4], + } + out.append(result) + + @staticmethod + @torch.compile() + def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: + """Fully GPU matte cleanup using morphological operations. + + Approximates connected-component removal by eroding small regions + away, then dilating back. Avoids the GPU→CPU→GPU roundtrip that + ``cv2.connectedComponentsWithStats`` would require. + + The erosion radius is derived from ``area_threshold``: a circular + spot of area A has radius sqrt(A/pi), so erosion by that radius + eliminates it. 
+ """ + _device = alpha.device + # alpha: [H, W, 1] + a2d = alpha[..., 0] + mask = (a2d > 0.5).float().unsqueeze(0) # [B, 1, H, W] + + # Erode: kill spots smaller than area_threshold + # A circle of area A has radius r = sqrt(A / pi) + import math + + erode_r = max(1, int(math.sqrt(area_threshold / math.pi))) + erode_k = erode_r * 2 + 1 + # Erosion = negative of max_pool on negated mask + mask = -F.max_pool2d(-mask, erode_k, stride=1, padding=erode_r) + + # Dilate back to restore edges of large regions + dilate_r = erode_r + (dilation if dilation > 0 else 0) + dilate_k = dilate_r * 2 + 1 + mask = F.max_pool2d(mask, dilate_k, stride=1, padding=dilate_r) + + # Blur for soft edges + if blur_size > 0: + k = int(blur_size * 2 + 1) + mask = F.avg_pool2d(mask, k, stride=1, padding=blur_size) + + safe = mask.squeeze(0).squeeze(0) # [H, W] + return (a2d * safe).unsqueeze(-1) # [H, W, 1] + + @staticmethod + @torch.compile() + def _despill_gpu(image: torch.Tensor, strength: float) -> torch.Tensor: + """GPU despill — keeps data on device.""" + if strength <= 0.0: + return image + r, g, b = image[..., 0], image[..., 1], image[..., 2] + limit = (r + b) / 2.0 + spill = torch.clamp(g - limit, min=0.0) + g_new = g - spill + r_new = r + spill * 0.5 + b_new = b + spill * 0.5 + despilled = torch.stack([r_new, g_new, b_new], dim=-1) + if strength < 1.0: + return image * (1.0 - strength) + despilled * strength + return despilled @torch.inference_mode() def process_frame( @@ -290,7 +431,7 @@ def scale_hook(module, input, output): pred_alpha = prediction["alpha"][0].cpu().float() pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) - return self._postprocess_output( + return self._postprocess_cpu( pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size ) @@ -373,18 +514,6 @@ def scale_hook(module, input, output): if handle: handle.remove() - with torch.multiprocessing.Pool(min(num_workers, bs)) as pool: - inp = zip( - 
prediction["alpha"].cpu().float(), - prediction["fg"].cpu().float(), - [w] * len(prediction["alpha"]), - [h] * len(prediction["alpha"]), - [fg_is_straight] * len(prediction["alpha"]), - [despill_strength] * len(prediction["alpha"]), - [auto_despeckle] * len(prediction["alpha"]), - [despeckle_size] * len(prediction["alpha"]), - strict=True, - ) - out = pool.starmap(self._postprocess_output, inp) + out = self._postprocess_gpu(prediction["alpha"], prediction["fg"], h, w, auto_despeckle, despeckle_size, despill_strength, fg_is_straight) return out From 75cafcdf01f4be55600a1d9c8d90e94e14fd2f1d Mon Sep 17 00:00:00 2001 From: Marclie Date: Sat, 14 Mar 2026 01:42:19 +0100 Subject: [PATCH 20/44] fix tests --- CorridorKeyModule/inference_engine.py | 190 ++++++++++++++------------ 1 file changed, 99 insertions(+), 91 deletions(-) mode change 100644 => 100755 CorridorKeyModule/inference_engine.py diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py old mode 100644 new mode 100755 index f33288f6..dae66995 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -213,91 +213,97 @@ def _postprocess_cpu( "comp": comp_srgb, # sRGB, Composite "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled } - + + def _get_checkerboard_linear_gpu(self, w: int, h: int) -> torch.Tensor: + """Return a cached checkerboard tensor [H, W, 3] on device in linear space.""" + checker_size = 128 + y_coords = torch.arange(h, device=self.device) // checker_size + x_coords = torch.arange(w, device=self.device) // checker_size + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") + checker = ((x_grid + y_grid) % 2).float() + # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching + bg_srgb = checker * 0.4 + 0.15 # [H, W] + bg_srgb_3 = bg_srgb.unsqueeze(-1).expand(-1, -1, 3) + return cu.srgb_to_linear(bg_srgb_3) + def _postprocess_gpu( - self, - pred_alpha: torch.Tensor, - pred_fg: 
torch.Tensor, - h: int, - w: int, - auto_despeckle: bool, - despeckle_size: int, - despill_strength: float, - fg_is_straight: bool, - sync: bool = True, - ) -> dict[str, np.ndarray]: - """Post-process on GPU, transfer final results to CPU. - - When ``sync=True`` (default), blocks until transfer completes and - returns numpy arrays. When ``sync=False``, starts the DMA - non-blocking and returns a :class:`PendingTransfer` — call - ``.resolve()`` to get the numpy dict later. - """ - # Resize on GPU using F.interpolate (much faster than cv2 at 4K) - alpha_up = TF.resize( - pred_alpha.float(), - [h, w], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - fg_up = TF.resize( - pred_fg.float(), - [h, w], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - - # Convert to HWC on GPU - res_alpha = alpha_up.permute(0, 2, 3, 1) # [B, H, W, 1] - res_fg = fg_up.permute(0, 2, 3, 1) # [B, H, W, 3] - - # A. Clean matte - if auto_despeckle: - processed_alpha = self._clean_matte_gpu(res_alpha, despeckle_size, dilation=25, blur_size=5) - else: - processed_alpha = res_alpha - - # B. Despill on GPU - fg_despilled = self._despill_gpu(res_fg, despill_strength) - - # C. sRGB → linear on GPU - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) - - # D. Premultiply on GPU - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) - - # E. Pack RGBA on GPU - processed_rgba = torch.cat([fg_premul_lin, processed_alpha], dim=-1) - - # === Bulk transfer to CPU via copy stream === - # - # Pack all outputs into one contiguous tensor, DMA to a pinned - # buffer on the copy stream. 
Layout varies by comp mode: - # Checkerboard: [H, W, 1+3+3+4] = [alpha, fg, comp_rgb, processed] - # Transparent: [H, W, 1+3+4+4] = [alpha, fg, comp_rgba, processed] - # No comp: [H, W, 1+3+3+4] = [alpha, fg, zeros, processed] - comp_channels = 3 - bulk = torch.cat([res_alpha, res_fg, torch.zeros_like(res_fg), processed_rgba], dim=-1) - - bulk_np = bulk.cpu().numpy() - cc = comp_channels - if bulk_np.shape[0] == 1: - result = { - "alpha": bulk_np[:, :, 0:1], - "fg": bulk_np[:, :, 1:4], - "comp": bulk_np[:, :, 4 : 4 + cc], - "processed": bulk_np[:, :, 4 + cc : 4 + cc + 4], - } - return result - - out = [] - for i in range(bulk_np.shape[0]): - result = { - "alpha": bulk_np[i, :, :, 0:1], - "fg": bulk_np[i, :, :, 1:4], - "comp": bulk_np[i, :, :, 4 : 4 + cc], - "processed": bulk_np[i, :, :, 4 + cc : 4 + cc + 4], - } - out.append(result) - + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + ) -> list[dict[str, np.ndarray]]: + """Post-process on GPU, transfer final results to CPU. + + When ``sync=True`` (default), blocks until transfer completes and + returns numpy arrays. When ``sync=False``, starts the DMA + non-blocking and returns a :class:`PendingTransfer` — call + ``.resolve()`` to get the numpy dict later. + """ + # Resize on GPU using F.interpolate (much faster than cv2 at 4K) + alpha_up = TF.resize( + pred_alpha.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + fg_up = TF.resize( + pred_fg.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + # Convert to HWC on GPU + res_alpha = alpha_up.permute(0, 2, 3, 1) # [B, H, W, 1] + res_fg = fg_up.permute(0, 2, 3, 1) # [B, H, W, 3] + + # A. Clean matte + if auto_despeckle: + processed_alpha = self._clean_matte_gpu(res_alpha, despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = res_alpha + + # B. 
Despill on GPU + fg_despilled = self._despill_gpu(res_fg, despill_strength) + + # C. sRGB → linear on GPU + fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + + # D. Premultiply on GPU + fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + + # E. Pack RGBA on GPU + processed_rgba = torch.cat([fg_premul_lin, processed_alpha], dim=-1) + + # F. Composite + bg_lin = self._get_checkerboard_linear_gpu(w, h) + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + comp_srgb = cu.linear_to_srgb(comp_lin) # [H, W, 3] opaque + + res_alpha, res_fg, comp_srgb, processed_rgba = ( + res_alpha.cpu(), + res_fg.cpu(), + comp_srgb.cpu(), + processed_rgba.cpu(), + ) + + out = [] + for i in range(res_alpha.shape[0]): + result = { + "alpha": res_alpha[i].numpy(), + "fg": res_fg[i].numpy(), + "comp": comp_srgb[i].numpy(), + "processed": processed_rgba[i].numpy(), + } + out.append(result) + return out + @staticmethod @torch.compile() def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: @@ -314,7 +320,7 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl _device = alpha.device # alpha: [H, W, 1] a2d = alpha[..., 0] - mask = (a2d > 0.5).float().unsqueeze(0) # [B, 1, H, W] + mask = (a2d > 0.5).float().unsqueeze(-3) # [B, 1, H, W] # Erode: kill spots smaller than area_threshold # A circle of area A has radius r = sqrt(A / pi) @@ -333,7 +339,7 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl # Blur for soft edges if blur_size > 0: k = int(blur_size * 2 + 1) - mask = F.avg_pool2d(mask, k, stride=1, padding=blur_size) + mask = TF.gaussian_blur(mask, [k, k]) safe = mask.squeeze(0).squeeze(0) # [H, W] return (a2d * safe).unsqueeze(-1) # [H, W, 1] @@ -428,12 +434,12 @@ def scale_hook(module, input, output): if handle: 
handle.remove() - pred_alpha = prediction["alpha"][0].cpu().float() - pred_fg = prediction["fg"][0].cpu().float() # Output is sRGB (Sigmoid) + pred_alpha = prediction["alpha"].float() + pred_fg = prediction["fg"].float() # Output is sRGB (Sigmoid) - return self._postprocess_cpu( + return self._postprocess_gpu( pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size - ) + )[0] @torch.inference_mode() def batch_process_frames( @@ -514,6 +520,8 @@ def scale_hook(module, input, output): if handle: handle.remove() - out = self._postprocess_gpu(prediction["alpha"], prediction["fg"], h, w, auto_despeckle, despeckle_size, despill_strength, fg_is_straight) + out = self._postprocess_gpu( + prediction["alpha"], prediction["fg"], w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + ) return out From 13091631b4243aa27cfab20691b95b97ec609eea Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 16 Mar 2026 15:50:22 +0100 Subject: [PATCH 21/44] optimize VRAM usage --- CorridorKeyModule/inference_engine.py | 63 +++++++++++++++------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index c154cdbb..ac34f822 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -241,7 +241,7 @@ def _get_checkerboard_linear_gpu(self, w: int, h: int) -> torch.Tensor: bg_srgb_3 = bg_srgb.unsqueeze(-1).expand(-1, -1, 3) return cu.srgb_to_linear(bg_srgb_3) - def _postprocess_gpu( + def _postprocess_torch( self, pred_alpha: torch.Tensor, pred_fg: torch.Tensor, @@ -251,6 +251,7 @@ def _postprocess_gpu( despill_strength: float, auto_despeckle: bool, despeckle_size: int, + generate_comp: bool = False, ) -> list[dict[str, np.ndarray]]: """Post-process on GPU, transfer final results to CPU. @@ -260,61 +261,65 @@ def _postprocess_gpu( ``.resolve()`` to get the numpy dict later. 
""" # Resize on GPU using F.interpolate (much faster than cv2 at 4K) - alpha_up = TF.resize( + alpha = TF.resize( pred_alpha.float(), [h, w], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ) - fg_up = TF.resize( + fg = TF.resize( pred_fg.float(), [h, w], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ) # Convert to HWC on GPU - res_alpha = alpha_up.permute(0, 2, 3, 1) # [B, H, W, 1] - res_fg = fg_up.permute(0, 2, 3, 1) # [B, H, W, 3] + alpha = alpha.permute(0, 2, 3, 1) # [B, H, W, 1] + fg = fg.permute(0, 2, 3, 1) # [B, H, W, 3] # A. Clean matte if auto_despeckle: - processed_alpha = self._clean_matte_gpu(res_alpha, despeckle_size, dilation=25, blur_size=5) + processed_alpha = self._clean_matte_gpu(alpha, despeckle_size, dilation=25, blur_size=5) else: - processed_alpha = res_alpha + processed_alpha = alpha # B. Despill on GPU - fg_despilled = self._despill_gpu(res_fg, despill_strength) + processed_fg = self._despill_gpu(fg, despill_strength) # C. sRGB → linear on GPU - fg_despilled_lin = cu.srgb_to_linear(fg_despilled) + processed_fg = cu.srgb_to_linear(processed_fg) # D. Premultiply on GPU - fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha) + processed_fg = cu.premultiply(processed_fg, processed_alpha) # E. Pack RGBA on GPU - processed_rgba = torch.cat([fg_premul_lin, processed_alpha], dim=-1) + packed_processed = torch.cat([processed_fg, processed_alpha], dim=-1) # F. 
Composite - bg_lin = self._get_checkerboard_linear_gpu(w, h) - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + if generate_comp: + bg_lin = self._get_checkerboard_linear_gpu(w, h) + if fg_is_straight: + comp = cu.composite_straight(processed_fg, bg_lin, processed_alpha) + else: + comp = cu.composite_premul(processed_fg, bg_lin, processed_alpha) + comp = cu.linear_to_srgb(comp) # [H, W, 3] opaque else: - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) - comp_srgb = cu.linear_to_srgb(comp_lin) # [H, W, 3] opaque - - res_alpha, res_fg, comp_srgb, processed_rgba = ( - res_alpha.cpu(), - res_fg.cpu(), - comp_srgb.cpu(), - processed_rgba.cpu(), + del processed_fg, processed_alpha + comp = torch.zeros(h, w, 3, device=fg.device, dtype=fg.dtype) + + alpha, fg, comp, packed_processed = ( + alpha.cpu().numpy(), + fg.cpu().numpy(), + comp.cpu().numpy(), + packed_processed.cpu().numpy(), ) out = [] - for i in range(res_alpha.shape[0]): + for i in range(alpha.shape[0]): result = { - "alpha": res_alpha[i].numpy(), - "fg": res_fg[i].numpy(), - "comp": comp_srgb[i].numpy(), - "processed": processed_rgba[i].numpy(), + "alpha": alpha[i], + "fg": fg[i], + "comp": comp[i], + "processed": packed_processed[i], } out.append(result) return out @@ -452,7 +457,7 @@ def scale_hook(module, input, output): pred_alpha = prediction["alpha"].float() pred_fg = prediction["fg"].float() # Output is sRGB (Sigmoid) - return self._postprocess_gpu( + return self._postprocess_torch( pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size )[0] @@ -535,7 +540,7 @@ def scale_hook(module, input, output): if handle: handle.remove() - out = self._postprocess_gpu( + out = self._postprocess_torch( prediction["alpha"], prediction["fg"], w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size ) From 2649c10be8af87be34a57ee0c73f1828d38bedb8 Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 
16 Mar 2026 21:52:31 +0100 Subject: [PATCH 22/44] Move to channels first format --- CorridorKeyModule/inference_engine.py | 191 +++++++++++++------------- 1 file changed, 96 insertions(+), 95 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index ac34f822..e4e1a6a5 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -4,13 +4,15 @@ import math import os import sys +from functools import lru_cache import cv2 import numpy as np import torch import torch.nn.functional as F import torchvision -import torchvision.transforms.functional as TF +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as TF from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -18,6 +20,21 @@ logger = logging.getLogger(__name__) +@lru_cache(maxsize=4) +def _get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor: + """Return a cached checkerboard tensor [3, H, W] on device in linear space.""" + print("Uncached checkerboard generation on GPU...") + checker_size = 128 + y_coords = torch.arange(h, device=device) // checker_size + x_coords = torch.arange(w, device=device) // checker_size + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") + checker = ((x_grid + y_grid) % 2).float() + # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching + bg_srgb = checker * 0.4 + 0.15 # [H, W] + bg_srgb_3 = bg_srgb.unsqueeze(0).expand(3, -1, -1) + return cu.srgb_to_linear(bg_srgb_3) + + class CorridorKeyEngine: def __init__( self, @@ -33,8 +50,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device).reshape(3, 1, 1) - self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device).reshape(3, 1, 1) + self.mean = torch.tensor([0.485, 0.456, 
0.406], dtype=model_precision, device=self.device) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device) if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation @@ -49,22 +66,7 @@ def __init__( self.model_precision = model_precision - model = self._load_model().to(model_precision) - - # We only tested compilation on windows and linux. For other platforms compilation is disabled as a precaution. - if sys.platform == "linux" or sys.platform == "win32": - # Try compiling the model. Fallback to eager mode if it fails. - try: - self.model = torch.compile(model) - # Trigger compilation with a dummy input - dummy_input = torch.zeros(1, 4, img_size, img_size, dtype=model_precision, device=self.device) - with torch.inference_mode(): - self.model(dummy_input) - except Exception as e: - logger.info(f"Model compilation failed with error: {e}") - logger.warning("Model compilation failed. Falling back to eager mode.") - torch.cuda.empty_cache() - self.model = model + self.model = self._load_model().to(model_precision) def _load_model(self) -> GreenFormer: logger.info("Loading CorridorKey from %s", self.checkpoint_path) @@ -121,9 +123,29 @@ def _load_model(self) -> GreenFormer: if len(unexpected) > 0: print(f"[Warning] Unexpected keys: {unexpected}") + model = model.to(self.model_precision) + # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. if sys.platform == "linux" or sys.platform == "win32": - model = torch.compile(model) + # Try compiling the model. Fallback to eager mode if it fails. 
+ try: + compiled_model = torch.compile(model, mode="max-autotune") + # Trigger compilation with a dummy input + dummy_input = torch.zeros( + 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device + ).to(memory_format=torch.channels_last) + with torch.inference_mode(): + compiled_model(dummy_input) + model = compiled_model + + self._preprocess_input = torch.compile(self._preprocess_input, mode="max-autotune") + self._despill_gpu = torch.compile(self._despill_gpu, mode="max-autotune") + self._clean_matte_gpu = torch.compile(self._clean_matte_gpu, mode="max-autotune") + + except Exception as e: + print(f"Model compilation failed with error: {e}") + logger.warning("Model compilation failed. Falling back to eager mode.") + torch.cuda.empty_cache() return model @@ -133,39 +155,30 @@ def _preprocess_input( # 2. Resize to Model Size # If input is linear, we resize in linear to preserve energy/highlights, # THEN convert to sRGB for the model. + image_batch = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=T.InterpolationMode.BILINEAR, + ) if input_is_linear: - # TODO: Check if interpolation is comparable to cv2.INTER_LINEAR (probably close enough) - img_resized_lin = TF.resize( - image_batch, - [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = TF.resize( - image_batch, - [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - ) + image_batch = cu.linear_to_srgb(image_batch) - mask_resized = TF.resize( + mask_batch_linear = TF.resize( mask_batch_linear, [self.img_size, self.img_size], - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + interpolation=T.InterpolationMode.BILINEAR, ) # 3. 
Normalize (ImageNet) # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std + image_batch = TF.normalize(image_batch, self.mean, self.std) # 4. Prepare Tensor - inp_concat = torch.concat((img_norm, mask_resized), -3) # [4, H, W] + inp_concat = torch.concat((image_batch, mask_batch_linear), -3) # [4, H, W] return inp_concat - def _postprocess_cpu( + def _postprocess_opencv( self, pred_alpha: torch.Tensor, pred_fg: torch.Tensor, @@ -229,18 +242,6 @@ def _postprocess_cpu( "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled } - def _get_checkerboard_linear_gpu(self, w: int, h: int) -> torch.Tensor: - """Return a cached checkerboard tensor [H, W, 3] on device in linear space.""" - checker_size = 128 - y_coords = torch.arange(h, device=self.device) // checker_size - x_coords = torch.arange(w, device=self.device) // checker_size - y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") - checker = ((x_grid + y_grid) % 2).float() - # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching - bg_srgb = checker * 0.4 + 0.15 # [H, W] - bg_srgb_3 = bg_srgb.unsqueeze(-1).expand(-1, -1, 3) - return cu.srgb_to_linear(bg_srgb_3) - def _postprocess_torch( self, pred_alpha: torch.Tensor, @@ -251,7 +252,7 @@ def _postprocess_torch( despill_strength: float, auto_despeckle: bool, despeckle_size: int, - generate_comp: bool = False, + generate_comp: bool = True, ) -> list[dict[str, np.ndarray]]: """Post-process on GPU, transfer final results to CPU. @@ -272,9 +273,8 @@ def _postprocess_torch( interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ) - # Convert to HWC on GPU - alpha = alpha.permute(0, 2, 3, 1) # [B, H, W, 1] - fg = fg.permute(0, 2, 3, 1) # [B, H, W, 3] + del pred_fg, pred_alpha + torch.cuda.empty_cache() # A. Clean matte if auto_despeckle: @@ -292,11 +292,11 @@ def _postprocess_torch( processed_fg = cu.premultiply(processed_fg, processed_alpha) # E. 
Pack RGBA on GPU - packed_processed = torch.cat([processed_fg, processed_alpha], dim=-1) + packed_processed = torch.cat([processed_fg, processed_alpha], dim=1) # F. Composite if generate_comp: - bg_lin = self._get_checkerboard_linear_gpu(w, h) + bg_lin = _get_checkerboard_linear_torch(w, h, self.device) if fg_is_straight: comp = cu.composite_straight(processed_fg, bg_lin, processed_alpha) else: @@ -304,13 +304,13 @@ def _postprocess_torch( comp = cu.linear_to_srgb(comp) # [H, W, 3] opaque else: del processed_fg, processed_alpha - comp = torch.zeros(h, w, 3, device=fg.device, dtype=fg.dtype) + comp = [None] * alpha.shape[0] # placeholder alpha, fg, comp, packed_processed = ( - alpha.cpu().numpy(), - fg.cpu().numpy(), - comp.cpu().numpy(), - packed_processed.cpu().numpy(), + alpha.cpu().permute(0, 2, 3, 1).numpy(), + fg.cpu().permute(0, 2, 3, 1).numpy(), + comp.cpu().permute(0, 2, 3, 1).numpy() if generate_comp else comp, + packed_processed.cpu().permute(0, 2, 3, 1).numpy(), ) out = [] @@ -325,7 +325,6 @@ def _postprocess_torch( return out @staticmethod - @torch.compile() def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: """Fully GPU matte cleanup using morphological operations. 
@@ -339,8 +338,7 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl """ _device = alpha.device # alpha: [H, W, 1] - a2d = alpha[..., 0] - mask = (a2d > 0.5).float().unsqueeze(-3) # [B, 1, H, W] + mask = (alpha > 0.5).float() # [B, 1, H, W] # Erode: kill spots smaller than area_threshold # A circle of area A has radius r = sqrt(A / pi) @@ -361,22 +359,21 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl k = int(blur_size * 2 + 1) mask = TF.gaussian_blur(mask, [k, k]) - safe = mask.squeeze(0).squeeze(0) # [H, W] - return (a2d * safe).unsqueeze(-1) # [H, W, 1] + safe = mask + return alpha * safe @staticmethod - @torch.compile() def _despill_gpu(image: torch.Tensor, strength: float) -> torch.Tensor: """GPU despill — keeps data on device.""" if strength <= 0.0: return image - r, g, b = image[..., 0], image[..., 1], image[..., 2] + r, g, b = image[:, 0], image[:, 1], image[:, 2] limit = (r + b) / 2.0 spill = torch.clamp(g - limit, min=0.0) g_new = g - spill r_new = r + spill * 0.5 b_new = b + spill * 0.5 - despilled = torch.stack([r_new, g_new, b_new], dim=-1) + despilled = torch.stack([r_new, g_new, b_new], dim=1) if strength < 1.0: return image * (1.0 - strength) + despilled * strength return despilled @@ -472,7 +469,6 @@ def batch_process_frames( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, - num_workers: int = torch.multiprocessing.cpu_count() // 2, ) -> list[dict[str, np.ndarray]]: """ Process a single frame. @@ -489,33 +485,32 @@ def batch_process_frames( despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. despeckle_size: int. Minimum number of consecutive pixels required to keep an island. - num_workers: int. 
Number of worker threads used for post-processing Returns: list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] """ - image_was_uint8 = images.dtype == np.uint8 - mask_was_uint8 = masks_linear.dtype == np.uint8 + bs, h, w = images.shape[:3] - # immediately casting to float is fine since fp16 can represent all uint8 values exactly - image = torch.from_numpy(images).to(self.model_precision).to(self.device) - mask_linear = torch.from_numpy(masks_linear).to(self.model_precision).to(self.device) # 1. Inputs Check & Normalization - if image_was_uint8: - image = image / 255.0 - - if mask_was_uint8: - mask_linear = mask_linear / 255.0 - - bs, h, w = image.shape[:3] - - # Ensure Mask Shape - if mask_linear.ndim == 3: - mask_linear = mask_linear.unsqueeze(-1) - - image = image.permute(0, 3, 1, 2) # [B, C, H, W] - mask_linear = mask_linear.permute(0, 3, 1, 2) # [B, C, H, W] + images = ( + TF.to_dtype( + torch.from_numpy(images).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ) + .pin_memory() + .to(self.device, non_blocking=True) + ) + masks_linear = ( + TF.to_dtype( + torch.from_numpy(masks_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ) + .pin_memory() + .to(self.device, non_blocking=True) + ) - inp_t = self._preprocess_input(image, mask_linear, input_is_linear) + inp_t = self._preprocess_input(images, masks_linear, input_is_linear) # Free up unused VRAM in order to keep peak usage down and avoid OOM errors torch.cuda.empty_cache() @@ -535,13 +530,19 @@ def scale_hook(module, input, output): # Free up unused VRAM in order to keep peak usage down and avoid OOM errors del inp_t - torch.cuda.empty_cache() if handle: handle.remove() out = self._postprocess_torch( - prediction["alpha"], prediction["fg"], w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size + prediction["alpha"], + prediction["fg"], + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, ) 
return out From 64d993814f0043b2eaef0dfe8fa582d72f48b55c Mon Sep 17 00:00:00 2001 From: Marclie Date: Mon, 16 Mar 2026 21:54:44 +0100 Subject: [PATCH 23/44] improve logic --- test_vram.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test_vram.py b/test_vram.py index 32c4d5ab..2306caf9 100644 --- a/test_vram.py +++ b/test_vram.py @@ -21,6 +21,8 @@ def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): def test_vram(): + torch.backends.cudnn.benchmark = True + print("Loading engine...") engine = CorridorKeyEngine( checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", @@ -33,13 +35,19 @@ def test_vram(): # Reset stats torch.cuda.reset_peak_memory_stats() - iterations = 24 - batch_size = 6 # works with a 16GB GPU + total_seconds = 6 + batch_size = 2 # works with a 16GB GPU + iterations = total_seconds * 24 // batch_size print(f"Running {iterations} inference passes...") time = timeit.timeit( lambda: batch_process_frame(engine, batch_size), number=iterations, - setup=lambda: batch_process_frame(engine, batch_size), + setup=lambda: ( + batch_process_frame(engine, batch_size), + torch.cuda.synchronize(), + torch.cuda.empty_cache(), + print("Compilation and warmup complete, starting timed runs..."), + ), ) print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") From ed7340d6dc7d21824a892345b47945bd4da010ce Mon Sep 17 00:00:00 2001 From: marclie Date: Tue, 17 Mar 2026 02:07:14 +0100 Subject: [PATCH 24/44] optimize clean_matte --- CorridorKeyModule/inference_engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index e4e1a6a5..983ff4d3 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -337,7 +337,6 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl eliminates it. 
""" _device = alpha.device - # alpha: [H, W, 1] mask = (alpha > 0.5).float() # [B, 1, H, W] # Erode: kill spots smaller than area_threshold @@ -351,16 +350,17 @@ def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, bl # Dilate back to restore edges of large regions dilate_r = erode_r + (dilation if dilation > 0 else 0) - dilate_k = dilate_r * 2 + 1 - mask = F.max_pool2d(mask, dilate_k, stride=1, padding=dilate_r) + # How many applications with kernel size 5 are needed to achieve the desired dilation radius + repeats = dilate_r // 2 + for _ in range(repeats): + mask = F.max_pool2d(mask, 5, stride=1, padding=2) # Blur for soft edges if blur_size > 0: k = int(blur_size * 2 + 1) mask = TF.gaussian_blur(mask, [k, k]) - safe = mask - return alpha * safe + return alpha * mask @staticmethod def _despill_gpu(image: torch.Tensor, strength: float) -> torch.Tensor: From dc9e1091ba5d0812b1b2f65ca9660b06ec66b146 Mon Sep 17 00:00:00 2001 From: marclie Date: Tue, 17 Mar 2026 02:07:42 +0100 Subject: [PATCH 25/44] Add config options --- CorridorKeyModule/inference_engine.py | 67 ++++++++++++++++++--------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 983ff4d3..b0bf033d 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -188,6 +188,7 @@ def _postprocess_opencv( despill_strength: float, auto_despeckle: bool, despeckle_size: int, + generate_comp: bool, ) -> dict[str, np.ndarray]: # 6. Post-Process (Resize Back to Original Resolution) # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. @@ -224,16 +225,19 @@ def _postprocess_opencv( # 7. 
Composite (on Checkerboard) for checking # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) + if generate_comp: + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) - comp_srgb = cu.linear_to_srgb(comp_lin) + comp_srgb = cu.linear_to_srgb(comp_lin) + else: + comp_srgb = None return { # type: ignore[return-value] # cu.* returns ndarray|Tensor but inputs are always ndarray here "alpha": res_alpha, # Linear, Raw Prediction @@ -252,7 +256,7 @@ def _postprocess_torch( despill_strength: float, auto_despeckle: bool, despeckle_size: int, - generate_comp: bool = True, + generate_comp: bool, ) -> list[dict[str, np.ndarray]]: """Post-process on GPU, transfer final results to CPU. @@ -469,6 +473,8 @@ def batch_process_frames( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, + generate_comp: bool = False, + post_process_on_gpu: bool = True, ) -> list[dict[str, np.ndarray]]: """ Process a single frame. 
@@ -497,7 +503,6 @@ def batch_process_frames( self.model_precision, scale=True, ) - .pin_memory() .to(self.device, non_blocking=True) ) masks_linear = ( @@ -506,7 +511,6 @@ def batch_process_frames( self.model_precision, scale=True, ) - .pin_memory() .to(self.device, non_blocking=True) ) @@ -534,15 +538,36 @@ def scale_hook(module, input, output): if handle: handle.remove() - out = self._postprocess_torch( - prediction["alpha"], - prediction["fg"], - w, - h, - fg_is_straight, - despill_strength, - auto_despeckle, - despeckle_size, - ) + if post_process_on_gpu: + out = self._postprocess_torch( + prediction["alpha"], + prediction["fg"], + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + ) + else: + # Move prediction to CPU before post-processing + pred_alpha = prediction["alpha"].cpu().float() + pred_fg = prediction["fg"].cpu().float() + + out = [] + for i in range(bs): + result = self._postprocess_opencv( + pred_alpha[i], + pred_fg[i], + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + ) + out.append(result) return out From 9bacf50a09ed76c077f69ef3cd37bd829205f770 Mon Sep 17 00:00:00 2001 From: marclie Date: Tue, 17 Mar 2026 02:15:39 +0100 Subject: [PATCH 26/44] Add changes to single frame method --- CorridorKeyModule/inference_engine.py | 39 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index b0bf033d..844ce5a0 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -192,8 +192,8 @@ def _postprocess_opencv( ) -> dict[str, np.ndarray]: # 6. Post-Process (Resize Back to Original Resolution) # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. 
- res_alpha = pred_alpha.permute(1, 2, 0).numpy() - res_fg = pred_fg.permute(1, 2, 0).numpy() + res_alpha = pred_alpha.permute(1, 2, 0).cpu().numpy() + res_fg = pred_fg.permute(1, 2, 0).cpu().numpy() res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) @@ -300,7 +300,7 @@ def _postprocess_torch( # F. Composite if generate_comp: - bg_lin = _get_checkerboard_linear_torch(w, h, self.device) + bg_lin = _get_checkerboard_linear_torch(w, h, processed_fg.device) if fg_is_straight: comp = cu.composite_straight(processed_fg, bg_lin, processed_alpha) else: @@ -393,6 +393,8 @@ def process_frame( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, + generate_comp: bool = True, + post_process_on_gpu: bool = True, ) -> dict[str, np.ndarray]: """ Process a single frame. @@ -455,12 +457,31 @@ def scale_hook(module, input, output): if handle: handle.remove() - pred_alpha = prediction["alpha"].float() - pred_fg = prediction["fg"].float() # Output is sRGB (Sigmoid) - - return self._postprocess_torch( - pred_alpha, pred_fg, w, h, fg_is_straight, despill_strength, auto_despeckle, despeckle_size - )[0] + if post_process_on_gpu: + out = self._postprocess_torch( + prediction["alpha"].float(), + prediction["fg"].float(), + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + )[0] # batch of 1, take first element + else: + out = self._postprocess_opencv( + prediction["alpha"][0].float(), + prediction["fg"][0].float(), + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + ) + return out @torch.inference_mode() def batch_process_frames( From b210153d60b7a8b686fd2b6586ac7c9adaf70e00 Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 17 Mar 2026 23:23:04 +0100 Subject: [PATCH 27/44] clean up --- CorridorKeyModule/inference_engine.py | 31 +++++++++++---------------- 1 file 
changed, 13 insertions(+), 18 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index b0bf033d..6abb2f12 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -23,7 +23,6 @@ @lru_cache(maxsize=4) def _get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor: """Return a cached checkerboard tensor [3, H, W] on device in linear space.""" - print("Uncached checkerboard generation on GPU...") checker_size = 128 y_coords = torch.arange(h, device=device) // checker_size x_coords = torch.arange(w, device=device) // checker_size @@ -473,7 +472,7 @@ def batch_process_frames( despill_strength: float = 1.0, auto_despeckle: bool = True, despeckle_size: int = 400, - generate_comp: bool = False, + generate_comp: bool = True, post_process_on_gpu: bool = True, ) -> list[dict[str, np.ndarray]]: """ @@ -491,28 +490,24 @@ def batch_process_frames( despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + generate_comp: bool. If True, also generates a composite on checkerboard for quick checking. + post_process_on_gpu: bool. If True, performs post-processing on GPU using PyTorch instead of OpenCV. Returns: list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] """ bs, h, w = images.shape[:3] # 1. 
Inputs Check & Normalization - images = ( - TF.to_dtype( - torch.from_numpy(images).permute((0, 3, 1, 2)), - self.model_precision, - scale=True, - ) - .to(self.device, non_blocking=True) - ) - masks_linear = ( - TF.to_dtype( - torch.from_numpy(masks_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), - self.model_precision, - scale=True, - ) - .to(self.device, non_blocking=True) - ) + images = TF.to_dtype( + torch.from_numpy(images).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ).to(self.device, non_blocking=True) + masks_linear = TF.to_dtype( + torch.from_numpy(masks_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ).to(self.device, non_blocking=True) inp_t = self._preprocess_input(images, masks_linear, input_is_linear) From 42cf2f88d5f247b82f52b7a970c86da28f961dba Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 17 Mar 2026 23:23:34 +0100 Subject: [PATCH 28/44] use new methods --- clip_manager.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/clip_manager.py b/clip_manager.py index 4bc1bd44..908518fa 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -754,16 +754,16 @@ def run_inference( # 3. Process USE_STRAIGHT_MODEL = True - res = engine.process_frame( - img_srgb, - mask_linear, + res = engine.batch_process_frames( + img_srgb[np.newaxis, :], + mask_linear[np.newaxis, :], input_is_linear=input_is_linear, fg_is_straight=USE_STRAIGHT_MODEL, despill_strength=settings.despill_strength, auto_despeckle=settings.auto_despeckle, despeckle_size=settings.despeckle_size, refiner_scale=settings.refiner_scale, - ) + )[0] pred_fg = res["fg"] # sRGB pred_alpha = res["alpha"] # Linear @@ -782,10 +782,11 @@ def run_inference( cv2.imwrite(os.path.join(matte_dir, f"{input_stem}.exr"), pred_alpha, EXR_WRITE_FLAGS) # 5. 
Generate Reference Comp - comp_srgb = res["comp"] - # Save Comp (PNG 8-bit) - comp_bgr = cv2.cvtColor((np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR) - cv2.imwrite(os.path.join(comp_dir, f"{input_stem}.png"), comp_bgr) + if res["comp"] is not None: + comp_srgb = res["comp"] + # Save Comp (PNG 8-bit) + comp_bgr = cv2.cvtColor((np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(comp_dir, f"{input_stem}.png"), comp_bgr) # 6. Save Processed (RGBA EXR) if "processed" in res: From 1c56ae288768875e69f0cf81c5ec3465b917646c Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 18 Mar 2026 17:33:21 +0100 Subject: [PATCH 29/44] improved fast despeckle --- CorridorKeyModule/core/color_utils.py | 36 ++++++++++++++++++++++++++ CorridorKeyModule/inference_engine.py | 37 +++++++++++---------------- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/CorridorKeyModule/core/color_utils.py b/CorridorKeyModule/core/color_utils.py index 4a3b17e0..cc28ea7b 100644 --- a/CorridorKeyModule/core/color_utils.py +++ b/CorridorKeyModule/core/color_utils.py @@ -6,6 +6,7 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F def _is_tensor(x: np.ndarray | torch.Tensor) -> bool: @@ -247,6 +248,41 @@ def despill( return despilled +def connected_components(mask: torch.Tensor, min_component_width=1, max_iterations=100) -> torch.Tensor: + """ + Adapted from: https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc + mask: torch Tensor [B, 1, H, W] binary 1 or 0 + min_component_width: Minimum width of connected components that are separated instead of merged. + max_iterations: Maximum number of flood fill iterations. Adjust based on expected component sizes. 
+ """ + bs, _, H, W = mask.shape + + # Reference implementation uses torch.arange instead of torch.randperm + # torch.randperm converges considerably faster and more uniformly + comp = torch.randperm(W * H).repeat(bs, 1).view(mask.shape).float().to(mask.device) + comp[mask != 1] = 0 + + prev_comp = torch.zeros_like(comp) + + iteration = 0 + + while not torch.equal(comp, prev_comp) and iteration < max_iterations: + prev_comp = comp.clone() + comp[mask == 1] = F.max_pool2d( + comp, kernel_size=(2 * min_component_width) + 1, stride=1, padding=min_component_width + )[mask == 1] + iteration += 1 + + comp = comp.long() + # Relabel components to have contiguous labels starting from 1 + unique_labels = torch.unique(comp) + label_map = torch.zeros(unique_labels.max().item() + 1, dtype=torch.long, device=mask.device) + label_map[unique_labels] = torch.arange(len(unique_labels), device=mask.device) + comp = label_map[comp] + + return comp + + def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray: """ Cleans up small disconnected components (like tracking markers) from a predicted alpha matte. diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 6abb2f12..ed06464a 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -329,34 +329,27 @@ def _postprocess_torch( @staticmethod def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: - """Fully GPU matte cleanup using morphological operations. - - Approximates connected-component removal by eroding small regions - away, then dilating back. Avoids the GPU→CPU→GPU roundtrip that - ``cv2.connectedComponentsWithStats`` would require. - - The erosion radius is derived from ``area_threshold``: a circular - spot of area A has radius sqrt(A/pi), so erosion by that radius - eliminates it. 
+ """ + Fully GPU matte cleanup """ _device = alpha.device - mask = (alpha > 0.5).float() # [B, 1, H, W] + mask = alpha > 0.5 # [B, 1, H, W] - # Erode: kill spots smaller than area_threshold - # A circle of area A has radius r = sqrt(A / pi) - import math + # Find the largest connected components in the mask + # only a limited amount of iterations is needed to find components above the area threshold + components = cu.connected_components(mask, max_iterations=area_threshold // 8, min_component_width=2) + sizes = torch.bincount(components.flatten()) + big_sizes = torch.nonzero(sizes >= area_threshold) - erode_r = max(1, int(math.sqrt(area_threshold / math.pi))) - erode_k = erode_r * 2 + 1 - # Erosion = negative of max_pool on negated mask - mask = -F.max_pool2d(-mask, erode_k, stride=1, padding=erode_r) + mask = torch.zeros_like(mask).float() + mask[torch.isin(components, big_sizes)] = 1.0 # Dilate back to restore edges of large regions - dilate_r = erode_r + (dilation if dilation > 0 else 0) - # How many applications with kernel size 5 are needed to achieve the desired dilation radius - repeats = dilate_r // 2 - for _ in range(repeats): - mask = F.max_pool2d(mask, 5, stride=1, padding=2) + if dilation > 0: + # How many applications with kernel size 5 are needed to achieve the desired dilation radius + repeats = dilation // 2 + for _ in range(repeats): + mask = F.max_pool2d(mask, 5, stride=1, padding=2) # Blur for soft edges if blur_size > 0: From 15d3f669de2034ddd966eac6656654eb6a404ea3 Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 18 Mar 2026 18:27:48 +0100 Subject: [PATCH 30/44] small fixes --- CorridorKeyModule/core/color_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/CorridorKeyModule/core/color_utils.py b/CorridorKeyModule/core/color_utils.py index cc28ea7b..31ff8c3a 100644 --- a/CorridorKeyModule/core/color_utils.py +++ b/CorridorKeyModule/core/color_utils.py @@ -251,15 +251,18 @@ def despill( def 
connected_components(mask: torch.Tensor, min_component_width=1, max_iterations=100) -> torch.Tensor: """ Adapted from: https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc - mask: torch Tensor [B, 1, H, W] binary 1 or 0 - min_component_width: Minimum width of connected components that are separated instead of merged. - max_iterations: Maximum number of flood fill iterations. Adjust based on expected component sizes. + Args: + mask: torch Tensor [B, 1, H, W] binary 1 or 0 + min_component_width: int. Minimum width of connected components that are separated instead of merged. + max_iterations: int. Maximum number of flood fill iterations. Adjust based on expected component sizes. + Returns: + comp: torch Tensor [B, 1, H, W] with connected component labels (0 = background, 1..N = components) """ bs, _, H, W = mask.shape # Reference implementation uses torch.arange instead of torch.randperm # torch.randperm converges considerably faster and more uniformly - comp = torch.randperm(W * H).repeat(bs, 1).view(mask.shape).float().to(mask.device) + comp = (torch.randperm(W * H) + 1).repeat(bs, 1).view(mask.shape).float().to(mask.device) comp[mask != 1] = 0 prev_comp = torch.zeros_like(comp) From bd26920ecdf37dc0afbdabccacce14d3f7704db9 Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 18 Mar 2026 21:29:37 +0100 Subject: [PATCH 31/44] fix compositing --- CorridorKeyModule/inference_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index cbc9dd26..c4135d05 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -289,10 +289,10 @@ def _postprocess_torch( processed_fg = self._despill_gpu(fg, despill_strength) # C. sRGB → linear on GPU - processed_fg = cu.srgb_to_linear(processed_fg) + processed_fg_lin = cu.srgb_to_linear(processed_fg) # D. 
Premultiply on GPU - processed_fg = cu.premultiply(processed_fg, processed_alpha) + processed_fg = cu.premultiply(processed_fg_lin, processed_alpha) # E. Pack RGBA on GPU packed_processed = torch.cat([processed_fg, processed_alpha], dim=1) @@ -301,9 +301,9 @@ def _postprocess_torch( if generate_comp: bg_lin = _get_checkerboard_linear_torch(w, h, processed_fg.device) if fg_is_straight: - comp = cu.composite_straight(processed_fg, bg_lin, processed_alpha) + comp = cu.composite_straight(processed_fg_lin, bg_lin, processed_alpha) else: - comp = cu.composite_premul(processed_fg, bg_lin, processed_alpha) + comp = cu.composite_premul(processed_fg_lin, bg_lin, processed_alpha) comp = cu.linear_to_srgb(comp) # [H, W, 3] opaque else: del processed_fg, processed_alpha From 39eb7fab3288199865f7f39217b4ade7066d53a1 Mon Sep 17 00:00:00 2001 From: Marclie Date: Wed, 18 Mar 2026 21:30:29 +0100 Subject: [PATCH 32/44] small fixes --- CorridorKeyModule/inference_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index c4135d05..d64e9594 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -139,7 +139,8 @@ def _load_model(self) -> GreenFormer: self._preprocess_input = torch.compile(self._preprocess_input, mode="max-autotune") self._despill_gpu = torch.compile(self._despill_gpu, mode="max-autotune") - self._clean_matte_gpu = torch.compile(self._clean_matte_gpu, mode="max-autotune") + # Raises runtime errors due to complicated logic being hard to compile + # self._clean_matte_gpu = torch.compile(self._clean_matte_gpu, mode="max-autotune") except Exception as e: print(f"Model compilation failed with error: {e}") @@ -264,7 +265,7 @@ def _postprocess_torch( non-blocking and returns a :class:`PendingTransfer` — call ``.resolve()`` to get the numpy dict later. 
""" - # Resize on GPU using F.interpolate (much faster than cv2 at 4K) + # Resize on GPU using torchvision (much faster than cv2 at 4K) alpha = TF.resize( pred_alpha.float(), [h, w], @@ -406,6 +407,7 @@ def process_frame( Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ + torch.compiler.cudagraph_mark_step_begin() image_was_uint8 = image.dtype == np.uint8 mask_was_uint8 = mask_linear.dtype == np.uint8 @@ -509,6 +511,7 @@ def batch_process_frames( Returns: list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] """ + torch.compiler.cudagraph_mark_step_begin() bs, h, w = images.shape[:3] # 1. Inputs Check & Normalization From 64840069f6ca338bdf3650541924746cfda6ccfe Mon Sep 17 00:00:00 2001 From: Marclie Date: Thu, 19 Mar 2026 15:02:58 +0100 Subject: [PATCH 33/44] fix tests --- tests/test_e2e_workflow.py | 20 +++++++++++--------- tests/test_exr_gamma_bug_condition.py | 20 +++++++++++--------- tests/test_exr_gamma_preservation.py | 20 +++++++++++--------- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/tests/test_e2e_workflow.py b/tests/test_e2e_workflow.py index de6df153..b2d1b900 100644 --- a/tests/test_e2e_workflow.py +++ b/tests/test_e2e_workflow.py @@ -22,14 +22,16 @@ # --------------------------------------------------------------------------- -def _fake_result(h: int = 4, w: int = 4) -> dict: +def _fake_result(h: int = 4, w: int = 4) -> list[dict]: """Return a minimal but valid process_frame result dict sized to (h, w).""" - return { - "alpha": np.full((h, w, 1), 0.8, dtype=np.float32), - "fg": np.full((h, w, 3), 0.6, dtype=np.float32), - "comp": np.full((h, w, 3), 0.5, dtype=np.float32), - "processed": np.full((h, w, 4), 0.4, dtype=np.float32), - } + return [ + { + "alpha": np.full((h, w, 1), 0.8, dtype=np.float32), + "fg": np.full((h, w, 3), 0.6, dtype=np.float32), + "comp": np.full((h, w, 3), 0.5, dtype=np.float32), + "processed": np.full((h, w, 4), 0.4, dtype=np.float32), + } + ] # 
--------------------------------------------------------------------------- @@ -56,7 +58,7 @@ def test_output_directories_created(self, tmp_clip_dir, monkeypatch): monkeypatch.setattr("builtins.input", lambda prompt="": "") mock_engine = MagicMock() - mock_engine.process_frame.return_value = _fake_result() + mock_engine.batch_process_frames.return_value = _fake_result() with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): run_inference([entry], device="cpu") @@ -81,7 +83,7 @@ def test_output_files_written_per_frame(self, tmp_clip_dir, monkeypatch): monkeypatch.setattr("builtins.input", lambda prompt="": "") mock_engine = MagicMock() - mock_engine.process_frame.return_value = _fake_result() + mock_engine.batch_process_frames.return_value = _fake_result() with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): run_inference([entry], device="cpu") diff --git a/tests/test_exr_gamma_bug_condition.py b/tests/test_exr_gamma_bug_condition.py index 52b75105..b7cbd9f8 100644 --- a/tests/test_exr_gamma_bug_condition.py +++ b/tests/test_exr_gamma_bug_condition.py @@ -195,17 +195,19 @@ def test_exr_srgb_frame_is_gamma_corrected(self, data: np.ndarray) -> None: # We'll capture what process_frame actually receives by patching it captured_args = {} - def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): + def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kwargs): captured_args["image"] = image.copy() captured_args["input_is_linear"] = input_is_linear # Return minimal valid result h_img, w_img = image.shape[:2] - return { - "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), - "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), - "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), - "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), - } + return [ + { + "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), + "fg": np.zeros((h_img, w_img, 3), 
dtype=np.float32), + "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), + "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), + } + ] # Build a mock clip that looks like an EXR image sequence mock_clip = MagicMock() @@ -228,7 +230,7 @@ def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): # Mock the engine mock_engine = MagicMock() - mock_engine.process_frame = mock_process_frame + mock_engine.batch_process_frames = mock_batch_process_frames # Patch create_engine where it's imported from inside run_inference with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): @@ -243,7 +245,7 @@ def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): assert "image" in captured_args, "process_frame was never called — clip setup may be wrong" - actual_image = captured_args["image"] + actual_image = captured_args["image"][0] actual_is_linear = captured_args["input_is_linear"] # Defect 1: The frame should be gamma-corrected (sRGB), not raw linear diff --git a/tests/test_exr_gamma_preservation.py b/tests/test_exr_gamma_preservation.py index fdd903e8..82a8aa73 100644 --- a/tests/test_exr_gamma_preservation.py +++ b/tests/test_exr_gamma_preservation.py @@ -251,16 +251,18 @@ def test_linear_exr_passes_through_unchanged(self, data: np.ndarray) -> None: # Capture what process_frame actually receives captured_args = {} - def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): + def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kwargs): captured_args["image"] = image.copy() captured_args["input_is_linear"] = input_is_linear h_img, w_img = image.shape[:2] - return { - "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), - "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), - "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), - "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), - } + return [ + { + "alpha": np.zeros((h_img, w_img, 1), 
dtype=np.float32), + "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), + "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), + "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), + } + ] mock_clip = MagicMock() mock_clip.name = "test_clip" @@ -281,7 +283,7 @@ def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): mock_settings.refiner_scale = 1.0 mock_engine = MagicMock() - mock_engine.process_frame = mock_process_frame + mock_engine.batch_process_frames = mock_batch_process_frames with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): from clip_manager import run_inference @@ -297,7 +299,7 @@ def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): # The frame should be raw linear data — no gamma correction np.testing.assert_allclose( - captured_args["image"], + captured_args["image"][0], expected_linear, atol=1e-6, err_msg=( From 8221f78259406ede2b69ba5ff09e8a1a98b0116e Mon Sep 17 00:00:00 2001 From: Marclie Date: Thu, 19 Mar 2026 15:29:45 +0100 Subject: [PATCH 34/44] match batch processing --- CorridorKeyModule/inference_engine.py | 43 +++++++++++++-------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index d64e9594..09c26b23 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -404,34 +404,33 @@ def process_frame( despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + generate_comp: bool. If True, also generates a composite on checkerboard for quick checking. + post_process_on_gpu: bool. If True, performs post-processing on GPU using PyTorch instead of OpenCV. 
Returns: dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ torch.compiler.cudagraph_mark_step_begin() - image_was_uint8 = image.dtype == np.uint8 - mask_was_uint8 = mask_linear.dtype == np.uint8 - - # immediately casting to float is fine since fp16 can represent all uint8 values exactly - image = torch.from_numpy(image).to(self.model_precision).to(self.device) - mask_linear = torch.from_numpy(mask_linear).to(self.model_precision).to(self.device) - # 1. Inputs Check & Normalization - if image_was_uint8: - image = image / 255.0 - - if mask_was_uint8: - mask_linear = mask_linear / 255.0 - h, w = image.shape[:2] - # Ensure Mask Shape - if mask_linear.ndim == 2: - mask_linear = mask_linear.unsqueeze(-1) - - image = image.permute(2, 0, 1) # [C, H, W] - mask_linear = mask_linear.permute(2, 0, 1) # [C, H, W] - - image = image.unsqueeze(0) - mask_linear = mask_linear.unsqueeze(0) + # 1. Inputs Check & Normalization + image = ( + TF.to_dtype( + torch.from_numpy(image).permute((2, 0, 1)), + self.model_precision, + scale=True, + ) + .to(self.device, non_blocking=True) + .unsqueeze(0) + ) + mask_linear = ( + TF.to_dtype( + torch.from_numpy(mask_linear.reshape((h, w, 1))).permute((2, 0, 1)), + self.model_precision, + scale=True, + ) + .to(self.device, non_blocking=True) + .unsqueeze(0) + ) inp_t = self._preprocess_input(image, mask_linear, input_is_linear) From 1f1970e3e283d355d8848c93d947cbaf34963170 Mon Sep 17 00:00:00 2001 From: Marclie Date: Thu, 19 Mar 2026 15:31:37 +0100 Subject: [PATCH 35/44] parameterize tests over backend --- tests/test_inference_engine.py | 92 ++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index 8c17dcb1..e244a20e 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -53,47 +53,52 @@ def _make_engine_with_mock(mock_greenformer, img_size=64, device="cpu"): class TestProcessFrameOutputs: 
"""Verify shape, dtype, and key presence of process_frame outputs.""" - def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """process_frame must return alpha, fg, comp, and processed.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert "alpha" in result assert "fg" in result assert "comp" in result assert "processed" in result - def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """All outputs should match the spatial dimensions of the input.""" h, w = sample_frame_rgb.shape[:2] engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].shape[:2] == (h, w) assert result["fg"].shape[:2] == (h, w) assert result["comp"].shape == (h, w, 3) assert result["processed"].shape == (h, w, 4) - def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """All outputs should be float32 numpy arrays.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") for key in ("alpha", "fg", "comp", "processed"): assert 
result[key].dtype == np.float32, f"{key} should be float32" - def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """Alpha output must be in [0, 1] — values outside this range corrupt compositing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") alpha = result["alpha"] assert alpha.min() >= -0.01, f"alpha min {alpha.min():.4f} is below 0" assert alpha.max() <= 1.01, f"alpha max {alpha.max():.4f} is above 1" - def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """FG output must be in [0, 1] — required for downstream sRGB conversion and EXR export.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") fg = result["fg"] assert fg.min() >= -0.01, f"fg min {fg.min():.4f} is below 0" assert fg.max() <= 1.01, f"fg max {fg.max():.4f} is above 1" @@ -112,34 +117,42 @@ class TestProcessFrameColorSpace: When False (default), it resizes in sRGB directly. 
""" - def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """Default sRGB path should not crash and should return valid outputs.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=False) + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch" + ) np.testing.assert_allclose(result["comp"], 0.545655, atol=1e-4) - def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """Linear input path should convert to sRGB before model input.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=True) + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch" + ) assert result["comp"].shape == sample_frame_rgb.shape - def test_uint8_input_normalized(self, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_uint8_input_normalized(self, sample_mask, mock_greenformer, backend): """uint8 input should be auto-converted to float32 [0, 1].""" img_uint8 = np.random.default_rng(42).integers(0, 256, (64, 64, 3), dtype=np.uint8) engine = _make_engine_with_mock(mock_greenformer) # Should not crash — uint8 is auto-normalized to float32 - result = engine.process_frame(img_uint8, sample_mask) + result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 - def test_model_called_exactly_once(self, sample_frame_rgb, 
sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """The neural network model must be called exactly once per process_frame() call. Double-inference would double latency and produce incorrect outputs. """ engine = _make_engine_with_mock(mock_greenformer) - engine.process_frame(sample_frame_rgb, sample_mask) + engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert mock_greenformer.call_count == 1 @@ -151,7 +164,8 @@ def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_gre class TestProcessFramePostProcessing: """Verify post-processing: despill, despeckle, premultiply, composite.""" - def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask, backend): """despill_strength=1.0 must reduce green in spill pixels; strength=0.0 must leave it unchanged. 
The default mock_greenformer returns uniform gray (R=G=B=0.6) which has no @@ -178,8 +192,13 @@ def green_heavy_forward(x): green_mock.use_refiner = False engine = _make_engine_with_mock(green_mock) - result_no_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=0.0) - result_full_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=1.0) + + result_no_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" + ) + result_full_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" + ) rgb_none = result_no_despill["processed"][:, :, :3] rgb_full = result_full_despill["processed"][:, :, :3] @@ -194,13 +213,17 @@ def green_heavy_forward(x): "despill_strength=1.0 should reduce the green channel relative to strength=0.0 when G > (R+B)/2" ) - def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """auto_despeckle=False should skip clean_matte without crashing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, auto_despeckle=False) + result = engine.process_frame( + sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] - def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """The 'processed' output should be 4-channel RGBA (linear, premultiplied). 
This is the EXR-ready output that compositors load into Nuke for @@ -208,7 +231,7 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo means color is already multiplied by alpha). """ engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") processed = result["processed"] assert processed.shape[2] == 4 @@ -223,22 +246,26 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo np.testing.assert_allclose(alpha, 0.8, atol=1e-5) np.testing.assert_allclose(rgb, expected_premul, atol=1e-4) - def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer, backend): """process_frame should accept both [H, W] and [H, W, 1] masks.""" engine = _make_engine_with_mock(mock_greenformer) mask_2d = np.ones((64, 64), dtype=np.float32) * 0.5 mask_3d = mask_2d[:, :, np.newaxis] - result_2d = engine.process_frame(sample_frame_rgb, mask_2d) - result_3d = engine.process_frame(sample_frame_rgb, mask_3d) + result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch") + result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch") # Both should produce the same output np.testing.assert_allclose(result_2d["alpha"], result_3d["alpha"], atol=1e-5) - def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """Non-default refiner_scale must not raise — the parameter must be threaded through.""" engine = _make_engine_with_mock(mock_greenformer) - result = 
engine.process_frame(sample_frame_rgb, sample_mask, refiner_scale=0.5) + result = engine.process_frame( + sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] @@ -249,7 +276,8 @@ def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, m class TestNvidiaGPUProcess: @pytest.mark.gpu - def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): """ Scenario: Process a frame using a CUDA-configured engine to verify cross-device compatibility. Expected: The mock model detects the GPU input and returns matching tensors without a device mismatch error. @@ -259,5 +287,5 @@ def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenfor engine = _make_engine_with_mock(mock_greenformer, device=torch.device("cuda")) - result = engine.process_frame(sample_frame_rgb, sample_mask) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 From 2e5b6ab11df51c69597ef7c4e6b347211dbd8779 Mon Sep 17 00:00:00 2001 From: Marclie Date: Thu, 19 Mar 2026 18:25:11 +0100 Subject: [PATCH 36/44] feat: add tests for batched frame processing --- tests/test_inference_engine.py | 210 +++++++++++++++++++++++++-------- 1 file changed, 164 insertions(+), 46 deletions(-) diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index e244a20e..8790cb83 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -54,10 +54,18 @@ class TestProcessFrameOutputs: """Verify shape, dtype, and key presence of process_frame outputs.""" @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_output_keys(self, sample_frame_rgb, sample_mask, 
mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """process_frame must return alpha, fg, comp, and processed.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert "alpha" in result assert "fg" in result @@ -65,11 +73,19 @@ def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer, back assert "processed" in result @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """All outputs should match the spatial dimensions of the input.""" h, w = sample_frame_rgb.shape[:2] engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].shape[:2] == (h, w) assert result["fg"].shape[:2] == (h, w) @@ -77,28 +93,52 @@ def 
test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_gre assert result["processed"].shape == (h, w, 4) @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """All outputs should be float32 numpy arrays.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") for key in ("alpha", "fg", "comp", "processed"): assert result[key].dtype == np.float32, f"{key} should be float32" @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Alpha output must be in [0, 1] — values outside this range corrupt compositing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, 
sample_mask, post_process_on_gpu=backend == "torch") alpha = result["alpha"] assert alpha.min() >= -0.01, f"alpha min {alpha.min():.4f} is below 0" assert alpha.max() <= 1.01, f"alpha max {alpha.max():.4f} is above 1" @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """FG output must be in [0, 1] — required for downstream sRGB conversion and EXR export.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") fg = result["fg"] assert fg.min() >= -0.01, f"fg min {fg.min():.4f} is below 0" assert fg.max() <= 1.01, f"fg max {fg.max():.4f} is above 1" @@ -118,35 +158,57 @@ class TestProcessFrameColorSpace: """ @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Default sRGB path should not crash and should return valid outputs.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame( - sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch" - ) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + 
sample_mask = np.stack([sample_mask] * 2, axis=0)
+            result = engine.batch_process_frames(
+                sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch"
+            )[0]
+        else:
+            result = engine.process_frame(
+                sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch"
+            )
 
         np.testing.assert_allclose(result["comp"], 0.545655, atol=1e-4)
 
     @pytest.mark.parametrize("backend", ["openCV", "torch"])
-    def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer, backend):
+    @pytest.mark.parametrize("batched", [True, False])
+    def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched):
         """Linear input path should convert to sRGB before model input."""
         engine = _make_engine_with_mock(mock_greenformer)
-        result = engine.process_frame(
-            sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch"
-        )
-        assert result["comp"].shape == sample_frame_rgb.shape
+        if batched:
+            sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0)
+            sample_mask = np.stack([sample_mask] * 2, axis=0)
+            result = engine.batch_process_frames(
+                sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch"
+            )[0]
+        else:
+            result = engine.process_frame(
+                sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch"
+            )
+        assert result["comp"].shape == (sample_frame_rgb.shape[1:] if batched else sample_frame_rgb.shape)
 
     @pytest.mark.parametrize("backend", ["openCV", "torch"])
-    def test_uint8_input_normalized(self, sample_mask, mock_greenformer, backend):
+    @pytest.mark.parametrize("batched", [True, False])
+    def test_uint8_input_normalized(self, sample_mask, mock_greenformer, backend, batched):
         """uint8 input should be auto-converted to float32 [0, 1]."""
         img_uint8 = np.random.default_rng(42).integers(0, 256, (64, 64, 3), dtype=np.uint8)
         engine = _make_engine_with_mock(mock_greenformer)
-        # 
Should not crash — uint8 is auto-normalized to float32 - result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + img_uint8 = np.stack([img_uint8] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(img_uint8, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """The neural network model must be called exactly once per process_frame() call. Double-inference would double latency and produce incorrect outputs. @@ -165,7 +227,8 @@ class TestProcessFramePostProcessing: """Verify post-processing: despill, despeckle, premultiply, composite.""" @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask, backend, batched): """despill_strength=1.0 must reduce green in spill pixels; strength=0.0 must leave it unchanged. 
The default mock_greenformer returns uniform gray (R=G=B=0.6) which has no @@ -193,12 +256,22 @@ def green_heavy_forward(x): engine = _make_engine_with_mock(green_mock) - result_no_despill = engine.process_frame( - sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" - ) - result_full_despill = engine.process_frame( - sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" - ) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result_no_despill = engine.batch_process_frames( + sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" + )[0] + result_full_despill = engine.batch_process_frames( + sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" + )[0] + else: + result_no_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" + ) + result_full_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" + ) rgb_none = result_no_despill["processed"][:, :, :3] rgb_full = result_full_despill["processed"][:, :, :3] @@ -214,16 +287,26 @@ def green_heavy_forward(x): ) @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """auto_despeckle=False should skip clean_matte without crashing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame( - sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" - ) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = 
np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames( + sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" + )[0] + sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """The 'processed' output should be 4-channel RGBA (linear, premultiplied). This is the EXR-ready output that compositors load into Nuke for @@ -231,7 +314,14 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo means color is already multiplied by alpha). 
""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") processed = result["processed"] assert processed.shape[2] == 4 @@ -247,25 +337,46 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo np.testing.assert_allclose(rgb, expected_premul, atol=1e-4) @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer, backend, batched): """process_frame should accept both [H, W] and [H, W, 1] masks.""" engine = _make_engine_with_mock(mock_greenformer) mask_2d = np.ones((64, 64), dtype=np.float32) * 0.5 mask_3d = mask_2d[:, :, np.newaxis] - result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch") - result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + mask_2d = np.stack([mask_2d] * 2, axis=0) + mask_3d = np.stack([mask_3d] * 2, axis=0) + result_2d = engine.batch_process_frames(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch")[ + 0 + ] + result_3d = engine.batch_process_frames(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch")[ + 0 + ] + else: + result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch") + result_3d = engine.process_frame(sample_frame_rgb, 
mask_3d, post_process_on_gpu=backend == "torch") # Both should produce the same output np.testing.assert_allclose(result_2d["alpha"], result_3d["alpha"], atol=1e-5) @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Non-default refiner_scale must not raise — the parameter must be threaded through.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame( - sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" - ) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames( + sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" + )[0] + sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] @@ -277,7 +388,8 @@ def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, m class TestNvidiaGPUProcess: @pytest.mark.gpu @pytest.mark.parametrize("backend", ["openCV", "torch"]) - def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer, backend): + @pytest.mark.parametrize("batched", [True, False]) + def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """ Scenario: Process a frame using a CUDA-configured engine to verify cross-device compatibility. Expected: The mock model detects the GPU input and returns matching tensors without a device mismatch error. 
@@ -285,7 +397,13 @@ def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenfor if not torch.cuda.is_available(): pytest.skip("CUDA not available") - engine = _make_engine_with_mock(mock_greenformer, device=torch.device("cuda")) + engine = _make_engine_with_mock(mock_greenformer, device="cuda") - result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + result = result[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 From cfa20123598b322931999aeb0b8ac6e3825cf5bf Mon Sep 17 00:00:00 2001 From: Marclie Date: Fri, 20 Mar 2026 21:23:10 +0100 Subject: [PATCH 37/44] feat: cleanup + reorganization --- CorridorKeyModule/backend.py | 4 +- CorridorKeyModule/core/color_utils.py | 70 ++++++++++++++- CorridorKeyModule/inference_engine.py | 117 ++++++-------------------- test_outputs.py | 9 +- tests/test_color_utils.py | 32 +++---- 5 files changed, 117 insertions(+), 115 deletions(-) diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 2e50d119..69358661 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -179,12 +179,12 @@ def _wrap_mlx_output(raw: dict, despill_strength: float, auto_despeckle: bool, d # Apply despeckle (MLX stubs this) if auto_despeckle: - processed_alpha = cu.clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + processed_alpha = cu.clean_matte_opencv(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) else: processed_alpha = alpha # Apply despill (MLX stubs this) - fg_despilled = cu.despill(fg, green_limit_mode="average", strength=despill_strength) + fg_despilled = 
cu.despill_opencv(fg, green_limit_mode="average", strength=despill_strength) # Composite over checkerboard for comp output h, w = fg.shape[:2] diff --git a/CorridorKeyModule/core/color_utils.py b/CorridorKeyModule/core/color_utils.py index 31ff8c3a..6ac059c9 100644 --- a/CorridorKeyModule/core/color_utils.py +++ b/CorridorKeyModule/core/color_utils.py @@ -7,6 +7,7 @@ import numpy as np import torch import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF def _is_tensor(x: np.ndarray | torch.Tensor) -> bool: @@ -203,7 +204,7 @@ def apply_garbage_matte( return predicted_matte * garbage_mask -def despill( +def despill_opencv( image: np.ndarray | torch.Tensor, green_limit_mode: str = "average", strength: float = 1.0 ) -> np.ndarray | torch.Tensor: """ @@ -248,6 +249,22 @@ def despill( return despilled +def despill_torch(image: torch.Tensor, strength: float) -> torch.Tensor: + """GPU despill — keeps data on device.""" + if strength <= 0.0: + return image + r, g, b = image[:, 0], image[:, 1], image[:, 2] + limit = (r + b) / 2.0 + spill = torch.clamp(g - limit, min=0.0) + g_new = g - spill + r_new = r + spill * 0.5 + b_new = b + spill * 0.5 + despilled = torch.stack([r_new, g_new, b_new], dim=1) + if strength < 1.0: + return image * (1.0 - strength) + despilled * strength + return despilled + + def connected_components(mask: torch.Tensor, min_component_width=1, max_iterations=100) -> torch.Tensor: """ Adapted from: https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc @@ -286,7 +303,9 @@ def connected_components(mask: torch.Tensor, min_component_width=1, max_iteratio return comp -def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray: +def clean_matte_opencv( + alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5 +) -> np.ndarray: """ Cleans up small disconnected components (like tracking markers) from a predicted alpha matte. 
alpha_np: Numpy array [H, W] or [H, W, 1] float (0.0 - 1.0) @@ -334,6 +353,39 @@ def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = return result_alpha +def clean_matte_torch(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: + """ + Cleans up small disconnected components (like tracking markers) from a predicted alpha matte. + Supports fully running on the GPU + alpha_np: torch Tensor [B, 1, H, W] (0.0 - 1.0) + """ + _device = alpha.device + mask = alpha > 0.5 # [B, 1, H, W] + + # Find the largest connected components in the mask + # only a limited amount of iterations is needed to find components above the area threshold + components = connected_components(mask, max_iterations=area_threshold // 8, min_component_width=2) + sizes = torch.bincount(components.flatten()) + big_sizes = torch.nonzero(sizes >= area_threshold) + + mask = torch.zeros_like(mask, dtype=torch.float32) + mask[torch.isin(components, big_sizes)] = 1.0 + + # Dilate back to restore edges of large regions + if dilation > 0: + # How many applications with kernel size 5 are needed to achieve the desired dilation radius + repeats = dilation // 2 + for _ in range(repeats): + mask = F.max_pool2d(mask, 5, stride=1, padding=2) + + # Blur for soft edges + if blur_size > 0: + k = int(blur_size * 2 + 1) + mask = TF.gaussian_blur(mask, [k, k]) + + return alpha * mask + + def create_checkerboard( width: int, height: int, checker_size: int = 64, color1: float = 0.2, color2: float = 0.4 ) -> np.ndarray: @@ -360,3 +412,17 @@ def create_checkerboard( # Make it 3-channel return np.stack([bg_img, bg_img, bg_img], axis=-1) + + +@functools.lru_cache(maxsize=4) +def get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor: + """Return a cached checkerboard tensor [3, H, W] on device in linear space.""" + checker_size = 128 + y_coords = torch.arange(h, device=device) // checker_size + x_coords = torch.arange(w, device=device) 
// checker_size + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") + checker = ((x_grid + y_grid) % 2).float() + # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching + bg_srgb = checker * 0.4 + 0.15 # [H, W] + bg_srgb_3 = bg_srgb.unsqueeze(0).expand(3, -1, -1) + return srgb_to_linear(bg_srgb_3) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 09c26b23..4c667faf 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -4,7 +4,6 @@ import math import os import sys -from functools import lru_cache import cv2 import numpy as np @@ -20,20 +19,6 @@ logger = logging.getLogger(__name__) -@lru_cache(maxsize=4) -def _get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor: - """Return a cached checkerboard tensor [3, H, W] on device in linear space.""" - checker_size = 128 - y_coords = torch.arange(h, device=device) // checker_size - x_coords = torch.arange(w, device=device) // checker_size - y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") - checker = ((x_grid + y_grid) % 2).float() - # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching - bg_srgb = checker * 0.4 + 0.15 # [H, W] - bg_srgb_3 = bg_srgb.unsqueeze(0).expand(3, -1, -1) - return cu.srgb_to_linear(bg_srgb_3) - - class CorridorKeyEngine: def __init__( self, @@ -65,7 +50,11 @@ def __init__( self.model_precision = model_precision - self.model = self._load_model().to(model_precision) + self.model = self._load_model() + + # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. 
+ if sys.platform == "linux" or sys.platform == "win32": + self._compile() def _load_model(self) -> GreenFormer: logger.info("Loading CorridorKey from %s", self.checkpoint_path) @@ -124,31 +113,24 @@ def _load_model(self) -> GreenFormer: model = model.to(self.model_precision) - # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. - if sys.platform == "linux" or sys.platform == "win32": - # Try compiling the model. Fallback to eager mode if it fails. - try: - compiled_model = torch.compile(model, mode="max-autotune") - # Trigger compilation with a dummy input - dummy_input = torch.zeros( - 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device - ).to(memory_format=torch.channels_last) - with torch.inference_mode(): - compiled_model(dummy_input) - model = compiled_model - - self._preprocess_input = torch.compile(self._preprocess_input, mode="max-autotune") - self._despill_gpu = torch.compile(self._despill_gpu, mode="max-autotune") - # Raises runtime errors due to complicated logic being hard to compile - # self._clean_matte_gpu = torch.compile(self._clean_matte_gpu, mode="max-autotune") - - except Exception as e: - print(f"Model compilation failed with error: {e}") - logger.warning("Model compilation failed. Falling back to eager mode.") - torch.cuda.empty_cache() - return model + def _compile(self): + try: + compiled_model = torch.compile(self.model, mode="max-autotune") + # Trigger compilation with a dummy input + dummy_input = torch.zeros( + 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device + ).to(memory_format=torch.channels_last) + with torch.inference_mode(): + compiled_model(dummy_input) + self.model = compiled_model + + except Exception as e: + logger.info(f"Compilation error: {e}") + logger.warning("Model compilation failed. 
Falling back to eager mode.") + torch.cuda.empty_cache() + def _preprocess_input( self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool ) -> torch.Tensor: @@ -204,13 +186,13 @@ def _postprocess_opencv( # A. Clean Matte (Auto-Despeckle) if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + processed_alpha = cu.clean_matte_opencv(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) else: processed_alpha = res_alpha # B. Despill FG # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + fg_despilled = cu.despill_opencv(res_fg, green_limit_mode="average", strength=despill_strength) # C. Premultiply (for EXR Output) # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. @@ -282,12 +264,12 @@ def _postprocess_torch( # A. Clean matte if auto_despeckle: - processed_alpha = self._clean_matte_gpu(alpha, despeckle_size, dilation=25, blur_size=5) + processed_alpha = cu.clean_matte_torch(alpha, despeckle_size, dilation=25, blur_size=5) else: processed_alpha = alpha # B. Despill on GPU - processed_fg = self._despill_gpu(fg, despill_strength) + processed_fg = cu.despill_torch(fg, despill_strength) # C. sRGB → linear on GPU processed_fg_lin = cu.srgb_to_linear(processed_fg) @@ -300,7 +282,7 @@ def _postprocess_torch( # F. 
Composite if generate_comp: - bg_lin = _get_checkerboard_linear_torch(w, h, processed_fg.device) + bg_lin = cu.get_checkerboard_linear_torch(w, h, processed_fg.device) if fg_is_straight: comp = cu.composite_straight(processed_fg_lin, bg_lin, processed_alpha) else: @@ -328,53 +310,6 @@ def _postprocess_torch( out.append(result) return out - @staticmethod - def _clean_matte_gpu(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: - """ - Fully GPU matte cleanup - """ - _device = alpha.device - mask = alpha > 0.5 # [B, 1, H, W] - - # Find the largest connected components in the mask - # only a limited amount of iterations is needed to find components above the area threshold - components = cu.connected_components(mask, max_iterations=area_threshold // 8, min_component_width=2) - sizes = torch.bincount(components.flatten()) - big_sizes = torch.nonzero(sizes >= area_threshold) - - mask = torch.zeros_like(mask).float() - mask[torch.isin(components, big_sizes)] = 1.0 - - # Dilate back to restore edges of large regions - if dilation > 0: - # How many applications with kernel size 5 are needed to achieve the desired dilation radius - repeats = dilation // 2 - for _ in range(repeats): - mask = F.max_pool2d(mask, 5, stride=1, padding=2) - - # Blur for soft edges - if blur_size > 0: - k = int(blur_size * 2 + 1) - mask = TF.gaussian_blur(mask, [k, k]) - - return alpha * mask - - @staticmethod - def _despill_gpu(image: torch.Tensor, strength: float) -> torch.Tensor: - """GPU despill — keeps data on device.""" - if strength <= 0.0: - return image - r, g, b = image[:, 0], image[:, 1], image[:, 2] - limit = (r + b) / 2.0 - spill = torch.clamp(g - limit, min=0.0) - g_new = g - spill - r_new = r + spill * 0.5 - b_new = b + spill * 0.5 - despilled = torch.stack([r_new, g_new, b_new], dim=1) - if strength < 1.0: - return image * (1.0 - strength) + despilled * strength - return despilled - @torch.inference_mode() def process_frame( self, diff --git 
a/test_outputs.py b/test_outputs.py index 9e16e10f..12098976 100644 --- a/test_outputs.py +++ b/test_outputs.py @@ -48,7 +48,7 @@ def generate_test_images(img_path, mask_path): print(f"Precision: {precision}, Image Size: {img_size}, Peak VRAM: {peak_vram:.2f} GB") -def compare_implementations(src, comparison): +def compare_implementations(src, comparison, output_dir="./Output"): for _, _, files in os.walk(src): for file in files: src_img = read_image(str(os.path.join(src, file))).float() @@ -68,7 +68,9 @@ def compare_implementations(src, comparison): else: difference = difference.abs() - save_image(difference, f"./Output/diff_{file}") + os.makedirs(output_dir, exist_ok=True) + + save_image(difference, f"{output_dir}/diff_{file}") def compare_floating_point_precision(folder, ref="float64"): @@ -138,5 +140,4 @@ def compare_img_sizes(folder, ref=1024): if __name__ == "__main__": - compare_img_sizes("./Output/original", 1024) - compare_img_sizes("./Output/original", 512) + compare_implementations("./Output/Comp", "./Output/Comp") diff --git a/tests/test_color_utils.py b/tests/test_color_utils.py index 10ea0aef..4a97569d 100644 --- a/tests/test_color_utils.py +++ b/tests/test_color_utils.py @@ -246,32 +246,32 @@ class TestDespill: def test_pure_green_reduced_average_mode_numpy(self): """A pure green pixel should have green clamped to (R+B)/2 = 0.""" img = _to_np([[0.0, 1.0, 0.0]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) # Green should be 0 (clamped to avg of R=0, B=0) assert result[0, 1] == pytest.approx(0.0, abs=1e-6) def test_pure_green_reduced_max_mode_numpy(self): """With 'max' mode, green clamped to max(R, B) = 0 for pure green.""" img = _to_np([[0.0, 1.0, 0.0]]) - result = cu.despill(img, green_limit_mode="max", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="max", strength=1.0) assert result[0, 1] == pytest.approx(0.0, abs=1e-6) def 
test_pure_red_unchanged_numpy(self): """A pixel with no green excess should not be modified.""" img = _to_np([[1.0, 0.0, 0.0]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result, img, atol=1e-6) def test_strength_zero_is_noop_numpy(self): """strength=0 should return the input unchanged.""" img = _to_np([[0.2, 0.9, 0.1]]) - result = cu.despill(img, strength=0.0) + result = cu.despill_opencv(img, strength=0.0) np.testing.assert_allclose(result, img, atol=1e-7) def test_partial_green_average_mode_numpy(self): """Green slightly above (R+B)/2 should be reduced, not zeroed.""" img = _to_np([[0.4, 0.8, 0.2]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) limit = (0.4 + 0.2) / 2.0 # 0.3 expected_green = limit # green clamped to limit assert result[0, 1] == pytest.approx(expected_green, abs=1e-5) @@ -279,16 +279,16 @@ def test_partial_green_average_mode_numpy(self): def test_max_mode_higher_limit_than_average(self): """'max' mode uses max(R,B) which is >= (R+B)/2, so less despill.""" img = _to_np([[0.6, 0.8, 0.1]]) - result_avg = cu.despill(img, green_limit_mode="average", strength=1.0) - result_max = cu.despill(img, green_limit_mode="max", strength=1.0) + result_avg = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) + result_max = cu.despill_opencv(img, green_limit_mode="max", strength=1.0) # max(R,B)=0.6 vs avg(R,B)=0.35, so max mode removes less green assert result_max[0, 1] >= result_avg[0, 1] def test_fractional_strength_interpolates(self): """strength=0.5 should produce a result between original and fully despilled.""" img = _to_np([[0.2, 0.9, 0.1]]) - full = cu.despill(img, green_limit_mode="average", strength=1.0) - half = cu.despill(img, green_limit_mode="average", strength=0.5) + full = cu.despill_opencv(img, 
green_limit_mode="average", strength=1.0) + half = cu.despill_opencv(img, green_limit_mode="average", strength=0.5) # Half-strength green should be between original green and fully despilled green assert half[0, 1] < img[0, 1] # less green than original assert half[0, 1] > full[0, 1] # more green than full despill @@ -300,8 +300,8 @@ def test_despill_torch(self): """Verify torch path matches numpy path.""" img_np = _to_np([[0.3, 0.9, 0.2]]) img_t = _to_torch([[0.3, 0.9, 0.2]]) - result_np = cu.despill(img_np, green_limit_mode="average", strength=1.0) - result_t = cu.despill(img_t, green_limit_mode="average", strength=1.0) + result_np = cu.despill_opencv(img_np, green_limit_mode="average", strength=1.0) + result_t = cu.despill_opencv(img_t, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result_np, result_t.numpy(), atol=1e-5) def test_green_below_limit_unchanged_numpy(self): @@ -315,7 +315,7 @@ def test_green_below_limit_unchanged_numpy(self): # G=0.3 is well below the average limit (0.8+0.6)/2 = 0.7 # spill_amount = max(0.3 - 0.7, 0) = 0 → output equals input img = _to_np([[0.8, 0.3, 0.6]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result, img, atol=1e-6) @@ -335,7 +335,7 @@ def test_large_blob_preserved(self): """A single large opaque region should survive cleanup.""" matte = np.zeros((100, 100), dtype=np.float32) matte[20:80, 20:80] = 1.0 # 60x60 = 3600 pixels - result = cu.clean_matte(matte, area_threshold=300) + result = cu.clean_matte_opencv(matte, area_threshold=300) # Center of the blob should still be opaque assert result[50, 50] > 0.9 @@ -343,7 +343,7 @@ def test_small_blob_removed(self): """A tiny blob below the threshold should be removed.""" matte = np.zeros((100, 100), dtype=np.float32) matte[5:8, 5:8] = 1.0 # 3x3 = 9 pixels - result = cu.clean_matte(matte, area_threshold=300) + result = 
cu.clean_matte_opencv(matte, area_threshold=300) assert result[6, 6] == pytest.approx(0.0, abs=1e-5) def test_mixed_blobs(self): @@ -354,7 +354,7 @@ def test_mixed_blobs(self): # Small blob: 5x5 = 25 px matte[150:155, 150:155] = 1.0 - result = cu.clean_matte(matte, area_threshold=100) + result = cu.clean_matte_opencv(matte, area_threshold=100) assert result[35, 35] > 0.9 # large blob center preserved assert result[152, 152] < 0.01 # small blob removed @@ -362,7 +362,7 @@ def test_3d_input_preserved(self): """[H, W, 1] input should return [H, W, 1] output.""" matte = np.zeros((50, 50, 1), dtype=np.float32) matte[10:40, 10:40, 0] = 1.0 - result = cu.clean_matte(matte, area_threshold=100) + result = cu.clean_matte_opencv(matte, area_threshold=100) assert result.ndim == 3 assert result.shape[2] == 1 From bfd0b31f23254a12a5c7d8089c54503695dcd558 Mon Sep 17 00:00:00 2001 From: Marclie Date: Fri, 20 Mar 2026 22:28:39 +0100 Subject: [PATCH 38/44] feat: use full float16 precision --- CorridorKeyModule/backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 69358661..02e4427f 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -12,6 +12,7 @@ from pathlib import Path import numpy as np +import torch logger = logging.getLogger(__name__) @@ -287,4 +288,6 @@ def create_engine( from CorridorKeyModule.inference_engine import CorridorKeyEngine logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device) - return CorridorKeyEngine(checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size) + return CorridorKeyEngine( + checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=torch.float16 + ) From db17b634dbe9f7ad36832731b056d91f3ad2991e Mon Sep 17 00:00:00 2001 From: marclie Date: Sat, 21 Mar 2026 12:11:36 +0100 Subject: [PATCH 39/44] fix: remove channels_last format --- CorridorKeyModule/inference_engine.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 4c667faf..baeca87f 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -121,7 +121,7 @@ def _compile(self): # Trigger compilation with a dummy input dummy_input = torch.zeros( 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device - ).to(memory_format=torch.channels_last) + ) with torch.inference_mode(): compiled_model(dummy_input) self.model = compiled_model From 93f8df8ff7d995ae9358d409321c204ca3392dd5 Mon Sep 17 00:00:00 2001 From: Marclie Date: Sat, 21 Mar 2026 21:41:02 +0100 Subject: [PATCH 40/44] feat: add CLI options --- clip_manager.py | 4 ++++ corridorkey_cli.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/clip_manager.py b/clip_manager.py index 908518fa..24001c5c 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -38,6 +38,8 @@ class InferenceSettings: auto_despeckle: bool = True despeckle_size: int = 400 refiner_scale: float = 1.0 + generate_comp: bool = True + gpu_post_processing: bool = False # Core Paths @@ -763,6 +765,8 @@ def run_inference( auto_despeckle=settings.auto_despeckle, despeckle_size=settings.despeckle_size, refiner_scale=settings.refiner_scale, + generate_comp=settings.generate_comp, + gpu_post_processing=settings.gpu_post_processing, )[0] pred_fg = res["fg"] # sRGB diff --git a/corridorkey_cli.py b/corridorkey_cli.py index 575ae128..6e4fda03 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -137,6 +137,8 @@ def _prompt_inference_settings( default_despeckle: bool | None = None, default_despeckle_size: int | None = None, default_refiner: float | None = None, + default_comp: bool | None = None, + default_gpu_post: bool | None = None, ) -> InferenceSettings: """Interactively prompt for inference settings, skipping any pre-filled values.""" console.print(Panel("Inference 
Settings", style="bold cyan")) @@ -189,12 +191,30 @@ def _prompt_inference_settings( except ValueError: refiner_scale = 1.0 + if default_comp is not None: + generate_comp = default_comp + else: + generate_comp = Confirm.ask( + "Generate composition previews", + default=True, + ) + + if default_gpu_post is not None: + gpu_post_processing = default_gpu_post + else: + gpu_post_processing = Confirm.ask( + "Use GPU accelerated post-processing [dim](experimental)[/dim]", + default=False, + ) + return InferenceSettings( input_is_linear=input_is_linear, despill_strength=despill_strength, auto_despeckle=auto_despeckle, despeckle_size=despeckle_size, refiner_scale=refiner_scale, + generate_comp=generate_comp, + gpu_post_processing=gpu_post_processing, ) @@ -273,6 +293,14 @@ def run_inference_cmd( Optional[float], typer.Option("--refiner", help="Refiner strength multiplier (default: prompt)"), ] = None, + generate_comp: Annotated[ + Optional[bool], + typer.Option("--comp/--no-comp", help="Generate comp previews (default: prompt)"), + ] = None, + gpu_post: Annotated[ + Optional[bool], + typer.Option("--gpu-post/--cpu-post", help="Use GPU post-processing (default: prompt)"), + ] = None, ) -> None: """Run CorridorKey inference on clips with Input + AlphaHint. 
@@ -300,6 +328,8 @@ def run_inference_cmd( default_despeckle=despeckle, default_despeckle_size=despeckle_size, default_refiner=refiner, + default_comp=generate_comp, + default_gpu_post=gpu_post, ) with ProgressContext() as ctx_progress: From d5208ab8fe79241d81fda4d639ee53e205db1bef Mon Sep 17 00:00:00 2001 From: Marclie Date: Sat, 21 Mar 2026 23:09:06 +0100 Subject: [PATCH 41/44] feat: remove redundant batch processing method --- CorridorKeyModule/inference_engine.py | 126 ++++---------------------- clip_manager.py | 8 +- test_outputs.py | 2 +- tests/test_e2e_workflow.py | 20 ++-- tests/test_exr_gamma_bug_condition.py | 20 ++-- tests/test_exr_gamma_preservation.py | 20 ++-- tests/test_inference_engine.py | 48 ++++------ 7 files changed, 65 insertions(+), 179 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index baeca87f..3f43588d 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -323,14 +323,14 @@ def process_frame( despeckle_size: int = 400, generate_comp: bool = True, post_process_on_gpu: bool = True, - ) -> dict[str, np.ndarray]: + ) -> dict[str, np.ndarray] | list[dict[str, np.ndarray]]: """ Process a single frame. Args: - image: Numpy array [H, W, 3] (0.0-1.0 or 0-255). + image: Numpy array [H, W, 3] or [B, H, W, 3] (0.0-1.0 or 0-255). - If input_is_linear=False (Default): Assumed sRGB. - If input_is_linear=True: Assumed Linear. - mask_linear: Numpy array [H, W] or [H, W, 1] (0.0-1.0). Assumed Linear. + mask_linear: Numpy array [H, W] or [B, H, W] or [H, W, 1] or [B, H, W, 1] (0.0-1.0). Assumed Linear. refiner_scale: Multiplier for Refiner Deltas (default 1.0). input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. If False, resizes in sRGB (standard). @@ -345,125 +345,30 @@ def process_frame( dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} """ torch.compiler.cudagraph_mark_step_begin() - h, w = image.shape[:2] - - # 1. 
Inputs Check & Normalization - image = ( - TF.to_dtype( - torch.from_numpy(image).permute((2, 0, 1)), - self.model_precision, - scale=True, - ) - .to(self.device, non_blocking=True) - .unsqueeze(0) - ) - mask_linear = ( - TF.to_dtype( - torch.from_numpy(mask_linear.reshape((h, w, 1))).permute((2, 0, 1)), - self.model_precision, - scale=True, - ) - .to(self.device, non_blocking=True) - .unsqueeze(0) - ) - - inp_t = self._preprocess_input(image, mask_linear, input_is_linear) - - # 5. Inference - # Hook for Refiner Scaling - handle = None - if refiner_scale != 1.0 and self.model.refiner is not None: - - def scale_hook(module, input, output): - return output * refiner_scale - - handle = self.model.refiner.register_forward_hook(scale_hook) - with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - prediction = self.model(inp_t) - - if handle: - handle.remove() - - if post_process_on_gpu: - out = self._postprocess_torch( - prediction["alpha"].float(), - prediction["fg"].float(), - w, - h, - fg_is_straight, - despill_strength, - auto_despeckle, - despeckle_size, - generate_comp, - )[0] # batch of 1, take first element - else: - out = self._postprocess_opencv( - prediction["alpha"][0].float(), - prediction["fg"][0].float(), - w, - h, - fg_is_straight, - despill_strength, - auto_despeckle, - despeckle_size, - generate_comp, - ) - return out + # If input is a single image, add batch dimension + if image.ndim == 3: + image = image[np.newaxis, :] + mask_linear = mask_linear[np.newaxis, :] - @torch.inference_mode() - def batch_process_frames( - self, - images: np.ndarray, - masks_linear: np.ndarray, - refiner_scale: float = 1.0, - input_is_linear: bool = False, - fg_is_straight: bool = True, - despill_strength: float = 1.0, - auto_despeckle: bool = True, - despeckle_size: int = 400, - generate_comp: bool = True, - post_process_on_gpu: bool = True, - ) -> list[dict[str, np.ndarray]]: - """ - Process a single frame. 
- Args: - images: Numpy array [B, H, W, 3] (0.0-1.0 or 0-255). - - If input_is_linear=False (Default): Assumed sRGB. - - If input_is_linear=True: Assumed Linear. - masks_linear: Numpy array [B, H, W] or [B, H, W, 1] (0.0-1.0). Assumed Linear. - refiner_scale: Multiplier for Refiner Deltas (default 1.0). - input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. - If False, resizes in sRGB (standard). - fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). - If False, assumes FG output is Premultiplied. - despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. - auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. - despeckle_size: int. Minimum number of consecutive pixels required to keep an island. - generate_comp: bool. If True, also generates a composite on checkerboard for quick checking. - post_process_on_gpu: bool. If True, performs post-processing on GPU using PyTorch instead of OpenCV. - Returns: - list[dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)}] - """ - torch.compiler.cudagraph_mark_step_begin() - bs, h, w = images.shape[:3] + bs, h, w = image.shape[:3] # 1. 
Inputs Check & Normalization - images = TF.to_dtype( - torch.from_numpy(images).permute((0, 3, 1, 2)), + image = TF.to_dtype( + torch.from_numpy(image).permute((0, 3, 1, 2)), self.model_precision, scale=True, ).to(self.device, non_blocking=True) - masks_linear = TF.to_dtype( - torch.from_numpy(masks_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), + mask_linear = TF.to_dtype( + torch.from_numpy(mask_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), self.model_precision, scale=True, ).to(self.device, non_blocking=True) - inp_t = self._preprocess_input(images, masks_linear, input_is_linear) + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) # Free up unused VRAM in order to keep peak usage down and avoid OOM errors - torch.cuda.empty_cache() + del image, mask_linear # 5. Inference # Hook for Refiner Scaling @@ -516,4 +421,7 @@ def scale_hook(module, input, output): ) out.append(result) + if bs == 1: + return out[0] + return out diff --git a/clip_manager.py b/clip_manager.py index 24001c5c..80023281 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -756,9 +756,9 @@ def run_inference( # 3. 
Process USE_STRAIGHT_MODEL = True - res = engine.batch_process_frames( - img_srgb[np.newaxis, :], - mask_linear[np.newaxis, :], + res = engine.process_frame( + img_srgb, + mask_linear, input_is_linear=input_is_linear, fg_is_straight=USE_STRAIGHT_MODEL, despill_strength=settings.despill_strength, @@ -767,7 +767,7 @@ def run_inference( refiner_scale=settings.refiner_scale, generate_comp=settings.generate_comp, gpu_post_processing=settings.gpu_post_processing, - )[0] + ) pred_fg = res["fg"] # sRGB pred_alpha = res["alpha"] # Linear diff --git a/test_outputs.py b/test_outputs.py index 12098976..7c1e9645 100644 --- a/test_outputs.py +++ b/test_outputs.py @@ -140,4 +140,4 @@ def compare_img_sizes(folder, ref=1024): if __name__ == "__main__": - compare_implementations("./Output/Comp", "./Output/Comp") + compare_implementations("./Output/gpu_full_res/Comp", "./Output/gpu_fp16/Comp", "./Output/diff/fp16_vs_fp32") diff --git a/tests/test_e2e_workflow.py b/tests/test_e2e_workflow.py index b2d1b900..de6df153 100644 --- a/tests/test_e2e_workflow.py +++ b/tests/test_e2e_workflow.py @@ -22,16 +22,14 @@ # --------------------------------------------------------------------------- -def _fake_result(h: int = 4, w: int = 4) -> list[dict]: +def _fake_result(h: int = 4, w: int = 4) -> dict: """Return a minimal but valid process_frame result dict sized to (h, w).""" - return [ - { - "alpha": np.full((h, w, 1), 0.8, dtype=np.float32), - "fg": np.full((h, w, 3), 0.6, dtype=np.float32), - "comp": np.full((h, w, 3), 0.5, dtype=np.float32), - "processed": np.full((h, w, 4), 0.4, dtype=np.float32), - } - ] + return { + "alpha": np.full((h, w, 1), 0.8, dtype=np.float32), + "fg": np.full((h, w, 3), 0.6, dtype=np.float32), + "comp": np.full((h, w, 3), 0.5, dtype=np.float32), + "processed": np.full((h, w, 4), 0.4, dtype=np.float32), + } # --------------------------------------------------------------------------- @@ -58,7 +56,7 @@ def test_output_directories_created(self, tmp_clip_dir, 
monkeypatch): monkeypatch.setattr("builtins.input", lambda prompt="": "") mock_engine = MagicMock() - mock_engine.batch_process_frames.return_value = _fake_result() + mock_engine.process_frame.return_value = _fake_result() with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): run_inference([entry], device="cpu") @@ -83,7 +81,7 @@ def test_output_files_written_per_frame(self, tmp_clip_dir, monkeypatch): monkeypatch.setattr("builtins.input", lambda prompt="": "") mock_engine = MagicMock() - mock_engine.batch_process_frames.return_value = _fake_result() + mock_engine.process_frame.return_value = _fake_result() with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): run_inference([entry], device="cpu") diff --git a/tests/test_exr_gamma_bug_condition.py b/tests/test_exr_gamma_bug_condition.py index b7cbd9f8..52b75105 100644 --- a/tests/test_exr_gamma_bug_condition.py +++ b/tests/test_exr_gamma_bug_condition.py @@ -195,19 +195,17 @@ def test_exr_srgb_frame_is_gamma_corrected(self, data: np.ndarray) -> None: # We'll capture what process_frame actually receives by patching it captured_args = {} - def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kwargs): + def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): captured_args["image"] = image.copy() captured_args["input_is_linear"] = input_is_linear # Return minimal valid result h_img, w_img = image.shape[:2] - return [ - { - "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), - "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), - "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), - "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), - } - ] + return { + "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), + "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), + "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), + "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), + } # Build a mock clip 
that looks like an EXR image sequence mock_clip = MagicMock() @@ -230,7 +228,7 @@ def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kw # Mock the engine mock_engine = MagicMock() - mock_engine.batch_process_frames = mock_batch_process_frames + mock_engine.process_frame = mock_process_frame # Patch create_engine where it's imported from inside run_inference with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): @@ -245,7 +243,7 @@ def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kw assert "image" in captured_args, "process_frame was never called — clip setup may be wrong" - actual_image = captured_args["image"][0] + actual_image = captured_args["image"] actual_is_linear = captured_args["input_is_linear"] # Defect 1: The frame should be gamma-corrected (sRGB), not raw linear diff --git a/tests/test_exr_gamma_preservation.py b/tests/test_exr_gamma_preservation.py index 82a8aa73..fdd903e8 100644 --- a/tests/test_exr_gamma_preservation.py +++ b/tests/test_exr_gamma_preservation.py @@ -251,18 +251,16 @@ def test_linear_exr_passes_through_unchanged(self, data: np.ndarray) -> None: # Capture what process_frame actually receives captured_args = {} - def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kwargs): + def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs): captured_args["image"] = image.copy() captured_args["input_is_linear"] = input_is_linear h_img, w_img = image.shape[:2] - return [ - { - "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), - "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), - "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), - "processed": np.zeros((h_img, w_img, 4), dtype=np.float32), - } - ] + return { + "alpha": np.zeros((h_img, w_img, 1), dtype=np.float32), + "fg": np.zeros((h_img, w_img, 3), dtype=np.float32), + "comp": np.zeros((h_img, w_img, 3), dtype=np.float32), + "processed": np.zeros((h_img, 
w_img, 4), dtype=np.float32), + } mock_clip = MagicMock() mock_clip.name = "test_clip" @@ -283,7 +281,7 @@ def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kw mock_settings.refiner_scale = 1.0 mock_engine = MagicMock() - mock_engine.batch_process_frames = mock_batch_process_frames + mock_engine.process_frame = mock_process_frame with patch("CorridorKeyModule.backend.create_engine", return_value=mock_engine): from clip_manager import run_inference @@ -299,7 +297,7 @@ def mock_batch_process_frames(image, mask_linear, *, input_is_linear=False, **kw # The frame should be raw linear data — no gamma correction np.testing.assert_allclose( - captured_args["image"][0], + captured_args["image"], expected_linear, atol=1e-6, err_msg=( diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index 98f312f3..243b72bf 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -61,9 +61,7 @@ def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer, back if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") @@ -81,9 +79,7 @@ def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_gre if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, 
sample_mask, post_process_on_gpu=backend == "torch") @@ -100,9 +96,7 @@ def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenfor if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") @@ -117,9 +111,7 @@ def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") alpha = result["alpha"] @@ -134,9 +126,7 @@ def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, moc if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") fg = result["fg"] @@ -165,7 +155,7 @@ def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenforme if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = 
engine.batch_process_frames( + result = engine.process_frame( sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch" )[0] else: @@ -183,7 +173,7 @@ def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames( + result = engine.process_frame( sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch" )[0] else: @@ -201,7 +191,7 @@ def test_uint8_input_normalized(self, sample_mask, mock_greenformer, backend, ba if batched: img_uint8 = np.stack([img_uint8] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(img_uint8, sample_mask, post_process_on_gpu=backend == "torch")[0] + result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 @@ -259,10 +249,10 @@ def green_heavy_forward(x): if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result_no_despill = engine.batch_process_frames( + result_no_despill = engine.process_frame( sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" )[0] - result_full_despill = engine.batch_process_frames( + result_full_despill = engine.process_frame( sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" )[0] else: @@ -294,7 +284,7 @@ def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenfo if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames( + result = engine.process_frame( 
sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" )[0] sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below @@ -317,9 +307,7 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[ - 0 - ] + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") processed = result["processed"] @@ -348,12 +336,8 @@ def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer, backend, sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) mask_2d = np.stack([mask_2d] * 2, axis=0) mask_3d = np.stack([mask_3d] * 2, axis=0) - result_2d = engine.batch_process_frames(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch")[ - 0 - ] - result_3d = engine.batch_process_frames(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch")[ - 0 - ] + result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch")[0] + result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch")[0] else: result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch") result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch") @@ -369,7 +353,7 @@ def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, m if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames( + result = engine.process_frame( sample_frame_rgb, sample_mask, refiner_scale=0.5, 
post_process_on_gpu=backend == "torch" )[0] sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below @@ -412,7 +396,7 @@ def spy_forward(x): if batched: sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) sample_mask = np.stack([sample_mask] * 2, axis=0) - result = engine.batch_process_frames(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") result = result[0] else: result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") From a90900bfce3cc062026d20f561b84585a081fe29 Mon Sep 17 00:00:00 2001 From: Marclie Date: Sat, 21 Mar 2026 23:35:30 +0100 Subject: [PATCH 42/44] fix: parameter name --- clip_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clip_manager.py b/clip_manager.py index 80023281..18d58c6a 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -766,7 +766,7 @@ def run_inference( despeckle_size=settings.despeckle_size, refiner_scale=settings.refiner_scale, generate_comp=settings.generate_comp, - gpu_post_processing=settings.gpu_post_processing, + post_process_on_gpu=settings.gpu_post_processing, ) pred_fg = res["fg"] # sRGB From b4373378281fd2c60b9a7dec76c7a015a848fabb Mon Sep 17 00:00:00 2001 From: Marclie Date: Sun, 22 Mar 2026 17:32:39 +0100 Subject: [PATCH 43/44] feat: add safeguards for mlx --- CorridorKeyModule/backend.py | 1 + corridorkey_cli.py | 30 ++++++++++++++++-------------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 02e4427f..f4a097a6 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -224,6 +224,7 @@ def process_frame( despill_strength=1.0, auto_despeckle=True, despeckle_size=400, + **_kwargs, ): """Delegate to MLX engine, then normalize output to Torch contract.""" # MLX engine expects uint8 input — convert if 
float diff --git a/corridorkey_cli.py b/corridorkey_cli.py index 6e4fda03..b040b3fe 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -43,6 +43,7 @@ run_videomama, scan_clips, ) +from CorridorKeyModule.backend import resolve_backend from device_utils import resolve_device logger = logging.getLogger(__name__) @@ -191,21 +192,22 @@ def _prompt_inference_settings( except ValueError: refiner_scale = 1.0 - if default_comp is not None: - generate_comp = default_comp - else: - generate_comp = Confirm.ask( - "Generate composition previews", - default=True, - ) + if resolve_backend() == "torch": + if default_comp is not None: + generate_comp = default_comp + else: + generate_comp = Confirm.ask( + "Generate composition previews", + default=True, + ) - if default_gpu_post is not None: - gpu_post_processing = default_gpu_post - else: - gpu_post_processing = Confirm.ask( - "Use GPU accelerated post-processing [dim](experimental)[/dim]", - default=False, - ) + if default_gpu_post is not None: + gpu_post_processing = default_gpu_post + else: + gpu_post_processing = Confirm.ask( + "Use GPU accelerated post-processing [dim](experimental)[/dim]", + default=False, + ) return InferenceSettings( input_is_linear=input_is_linear, From f2501fc8f9468d4a3dbb6724157700de146aa54a Mon Sep 17 00:00:00 2001 From: Marclie Date: Tue, 24 Mar 2026 23:18:35 +0100 Subject: [PATCH 44/44] fix: use sample paths --- test_outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_outputs.py b/test_outputs.py index 7c1e9645..cff802d6 100644 --- a/test_outputs.py +++ b/test_outputs.py @@ -140,4 +140,4 @@ def compare_img_sizes(folder, ref=1024): if __name__ == "__main__": - compare_implementations("./Output/gpu_full_res/Comp", "./Output/gpu_fp16/Comp", "./Output/diff/fp16_vs_fp32") + compare_implementations("./Output/base/Comp", "./Output/compare/Comp", output_dir="./Output/diff")