From 2f457912837de15883d5903be2d6d5ad287c15db Mon Sep 17 00:00:00 2001
From: MArty
Date: Tue, 10 Mar 2026 09:10:13 +0000
Subject: [PATCH] perf: implement multithreaded I/O optimization for high-core-count CPUs

Refactor optimization logic into a reusable RuntimeThreadPool context manager
in device_utils.py. Applied this refactor to clip_manager.py and
backend/service.py to improve inference throughput on systems like
Threadripper by overlapping I/O with GPU inference.
---
Review notes (this section is ignored by `git am`):

* backend/service.py: futures are collected in submission order so `results`
  and `skipped` keep the frame order of the serial loop this patch removes.
  Per-status handling matches the removed loop again: "read failed" frames are
  recorded silently, FrameReadError logs a warning and is added to `skipped`,
  WriteFailureError logs an error and is not added to `skipped`. A 'cancel'
  status returned by a worker now raises JobCancelledError instead of being
  dropped.
* clip_manager.py: `despill_strength=settings.despill_strength` is kept in the
  engine.postprocess() call (it was deleted by mistake); the unpacked locals
  that were never read are gone.
  NOTE(review): `refiner_scale=settings.refiner_scale` is still passed to
  engine.postprocess() here although backend/service.py does not pass it --
  confirm against the postprocess signature and drop it if it is not accepted.
* device_utils.py: RuntimeThreadPool snapshots cv2/torch/OPENEXR state in
  __enter__ and restores it in a `finally`, only when it changed it.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy and checked against the hunk line counts; the post-image blob ids on
  the `index` lines are those of the original commit.

 backend/service.py | 160 ++++++++++++++++++++++++++-------------------
 clip_manager.py    |  59 +++++++++++++----
 device_utils.py    |  66 +++++++++++++++++++
 3 files changed, 202 insertions(+), 83 deletions(-)

diff --git a/backend/service.py b/backend/service.py
index f4aeb1ad..eb4fd12c 100644
--- a/backend/service.py
+++ b/backend/service.py
@@ -618,79 +618,103 @@ def run_inference(
             frame_indices = range(num_frames)
             range_count = num_frames
 
+        from device_utils import RuntimeThreadPool
+
+        def _process_single_frame(progress_i: int, i: int):
+            """Process one frame and return ``(status, FrameResult, message)``.
+
+            Runs on a RuntimeThreadPool worker. Only ``engine.infer`` is called
+            under ``self._gpu_lock``; the other steps run outside the lock.
+            """
+            if job and job.is_cancelled:
+                return ('cancel', FrameResult(i, f"{i:05d}", False, "cancelled"), None)
+
+            try:
+                # 1. IO: Read input
+                img, input_stem, is_linear = self._read_input_frame(
+                    clip, i, input_files, input_cap, params.input_is_linear
+                )
+                if img is None:
+                    return ('read_failed', FrameResult(i, f"{i:05d}", False, "video read failed"), None)
+
+                # Resume tracking
+                if input_stem in skip_stems:
+                    return ('skip', FrameResult(i, input_stem, True, "resumed (skipped)"), None)
+
+                # 2. IO: Read alpha mask and resize
+                mask = self._read_alpha_frame(clip, i, alpha_files, alpha_cap)
+                if mask is None:
+                    return ('read_failed', FrameResult(i, input_stem, False, "alpha read failed"), None)
+
+                if mask.shape[:2] != img.shape[:2]:
+                    mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+                # 3. Pre-process mapping (CPU)
+                t_frame = time.monotonic()
+                inp_t, h, w = engine.preprocess(img, mask, input_is_linear=is_linear)
+
+                # 4. Infer (GPU locked)
+                with self._gpu_lock:
+                    pred_alpha, pred_fg = engine.infer(inp_t, refiner_scale=params.refiner_scale)
+
+                # 5. Post-process mapping (CPU)
+                res = engine.postprocess(
+                    pred_alpha, pred_fg, h, w,
+                    fg_is_straight=True,
+                    despill_strength=params.despill_strength,
+                    auto_despeckle=params.auto_despeckle,
+                    despeckle_size=params.despeckle_size
+                )
+                logger.debug(f"Clip '{clip.name}' frame {i}: process/infer {time.monotonic() - t_frame:.3f}s")
+
+                # 6. IO: Write outputs
+                self._write_outputs(res, dirs, input_stem, clip.name, i, cfg)
+                return ('success', FrameResult(i, input_stem, True), None)
+
+            except FrameReadError as e:
+                return ('warning', FrameResult(i, f"{i:05d}", False, str(e)), str(e))
+            except WriteFailureError as e:
+                return ('error', FrameResult(i, f"{i:05d}", False, str(e)), str(e))
+
         try:
-            for progress_i, i in enumerate(frame_indices):
-                # Check cancellation between frames
-                if job and job.is_cancelled:
-                    raise JobCancelledError(clip.name, i)
+            is_video = bool(input_cap or alpha_cap)
+            with RuntimeThreadPool(is_video) as executor:
+                futures = {}
 
-                # Report progress every frame (enables responsive cancel + timer)
-                if on_progress:
-                    on_progress(clip.name, progress_i, range_count)
-
-                try:
-                    # Read input
-                    img, input_stem, is_linear = self._read_input_frame(
-                        clip,
-                        i,
-                        input_files,
-                        input_cap,
-                        params.input_is_linear,
-                    )
-                    if img is None:
-                        skipped.append(i)
-                        results.append(FrameResult(i, f"{i:05d}", False, "video read failed"))
-                        continue
+                for progress_i, i in enumerate(frame_indices):
+                    fut = executor.submit(_process_single_frame, progress_i, i)
+                    futures[fut] = i
 
-                    # Resume: skip frames that already have outputs
-                    if input_stem in skip_stems:
-                        results.append(FrameResult(i, input_stem, True, "resumed (skipped)"))
-                        continue
+                # Collect in submission order (dicts preserve insertion order) so
+                # results/skipped stay in frame order, as in the serial loop.
+                processed_count = 0
+                for fut, i in futures.items():
+                    # Fast cancellation check
+                    if job and job.is_cancelled:
+                        executor.shutdown(wait=False, cancel_futures=True)
+                        raise JobCancelledError(clip.name, i)
 
-                    # Read alpha
-                    mask = self._read_alpha_frame(clip, i, alpha_files, alpha_cap)
-                    if mask is None:
+                    status, frame_res, msg = fut.result()
+                    if status == 'cancel':
+                        # The worker saw the cancel flag after the check above
+                        executor.shutdown(wait=False, cancel_futures=True)
+                        raise JobCancelledError(clip.name, i)
+
+                    if status in ('read_failed', 'warning'):
                         skipped.append(i)
-                        results.append(FrameResult(i, input_stem, False, "alpha read failed"))
-                        continue
-
-                    # Resize mask if dimensions don't match input
-                    if mask.shape[:2] != img.shape[:2]:
-                        mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)
-
-                    # Process (GPU-locked — process_frame mutates model hooks)
-                    t_frame = time.monotonic()
-                    with self._gpu_lock:
-                        res = engine.process_frame(
-                            img,
-                            mask,
-                            input_is_linear=is_linear,
-                            fg_is_straight=True,
-                            despill_strength=params.despill_strength,
-                            auto_despeckle=params.auto_despeckle,
-                            despeckle_size=params.despeckle_size,
-                            refiner_scale=params.refiner_scale,
-                        )
-                    logger.debug(f"Clip '{clip.name}' frame {i}: process_frame {time.monotonic() - t_frame:.3f}s")
-
-                    # Write outputs
-                    self._write_outputs(res, dirs, input_stem, clip.name, i, cfg)
-                    results.append(FrameResult(i, input_stem, True))
-
-                except FrameReadError as e:
-                    logger.warning(str(e))
-                    skipped.append(i)
-                    results.append(FrameResult(i, f"{i:05d}", False, str(e)))
-                    if on_warning:
-                        on_warning(str(e))
-
-                except WriteFailureError as e:
-                    logger.error(str(e))
-                    results.append(FrameResult(i, f"{i:05d}", False, str(e)))
-                    if on_warning:
-                        on_warning(str(e))
-
-            # Final progress
+                    results.append(frame_res)
+                    if status == 'warning':
+                        logger.warning(msg)
+                    elif status == 'error':
+                        logger.error(msg)
+                    if msg and on_warning:
+                        on_warning(msg)
+
+                    # Update internal tracking
+                    processed_count += 1
+                    if on_progress:
+                        on_progress(clip.name, processed_count, range_count)
+
             if on_progress:
                 on_progress(clip.name, range_count, range_count)
 
diff --git a/clip_manager.py b/clip_manager.py
index ebdab434..477ecf8c 100644
--- a/clip_manager.py
+++ b/clip_manager.py
@@ -607,18 +607,30 @@ def run_inference(
         if on_clip_start:
             on_clip_start(clip.name, num_frames)
 
-        for i in range(num_frames):
+        # Unpack settings for easier access in the worker function
+        user_input_is_linear = settings.input_is_linear
+        refiner_scale = settings.refiner_scale
+
+        # Create a lock for GPU operations to ensure thread-safe inference
+        import threading
+        _gpu_lock = threading.Lock()
+
+        from device_utils import RuntimeThreadPool
+
+        def _process_single_frame(i: int):
+            if i % 10 == 0:
+                print(f" Frame {i}/{num_frames}...", end="\r")
+
             # 1. Read Input
             img_srgb = None
             input_stem = f"{i:05d}"
+            input_is_linear = user_input_is_linear
 
-            # Use the settings-defined gamma
-            input_is_linear = settings.input_is_linear
-
+            # VideoCapture reads stay in order: video jobs run on a single worker
             if input_cap:
                 ret, frame = input_cap.read()
                 if not ret:
-                    break
+                    return False
                 img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 img_srgb = img_rgb.astype(np.float32) / 255.0
                 input_stem = f"{i:05d}"
@@ -630,14 +642,13 @@ def run_inference(
                 if is_exr:
                     img_linear = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)
                     if img_linear is None:
-                        continue
+                        return False
                     img_linear_rgb = cv2.cvtColor(img_linear, cv2.COLOR_BGR2RGB)
-                    # Support overriding EXR behavior if user picked 's'
                     img_srgb = np.maximum(img_linear_rgb, 0.0)
                 else:
                     img_bgr = cv2.imread(fpath)
                     if img_bgr is None:
-                        continue
+                        return False
                     img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                     img_srgb = img_rgb.astype(np.float32) / 255.0
 
@@ -646,14 +657,14 @@ def run_inference(
             if alpha_cap:
                 ret, frame = alpha_cap.read()
                 if not ret:
-                    break
+                    return False
                 mask_linear = frame[:, :, 2].astype(np.float32) / 255.0
             else:
                 fpath = os.path.join(clip.alpha_asset.path, alpha_files[i])
                 mask_in = cv2.imread(fpath, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED)
 
                 if mask_in is None:
-                    continue
+                    return False
 
                 if mask_in.ndim == 3:
                     if mask_in.shape[2] == 3:
@@ -675,14 +686,19 @@ def run_inference(
                     mask_linear, (img_srgb.shape[1], img_srgb.shape[0]), interpolation=cv2.INTER_LINEAR
                 )
 
-            # 3. Process
+            # 3. Pre-process mapping (CPU)
             USE_STRAIGHT_MODEL = True
-            res = engine.process_frame(
-                img_srgb,
-                mask_linear,
-                input_is_linear=input_is_linear,
+            inp_t, h, w = engine.preprocess(img_srgb, mask_linear, input_is_linear=input_is_linear)
+
+            # 4. Infer (GPU locked)
+            with _gpu_lock:
+                pred_alpha, pred_fg = engine.infer(inp_t, refiner_scale=refiner_scale)
+
+            # 5. Post-process mapping (CPU)
+            res = engine.postprocess(
+                pred_alpha, pred_fg, h, w,
                 fg_is_straight=USE_STRAIGHT_MODEL,
                 despill_strength=settings.despill_strength,
                 auto_despeckle=settings.auto_despeckle,
                 despeckle_size=settings.despeckle_size,
                 refiner_scale=settings.refiner_scale,
@@ -720,6 +736,19 @@ def run_inference(
 
             if on_frame_complete:
                 on_frame_complete(i, num_frames)
+
+            return True
+
+        # Parallelize I/O using the managed thread pool
+        is_video = bool(input_cap or alpha_cap)
+        with RuntimeThreadPool(is_video) as executor:
+            from concurrent.futures import as_completed
+            futures = [executor.submit(_process_single_frame, i) for i in range(num_frames)]
+            for fut in as_completed(futures):
+                try:
+                    fut.result()
+                except Exception as e:
+                    logger.error(f"Frame processing failed: {e}")
 
         if input_cap:
             input_cap.release()
diff --git a/device_utils.py b/device_utils.py
index 6894d082..e4bca308 100644
--- a/device_utils.py
+++ b/device_utils.py
@@ -74,3 +74,69 @@ def clear_device_cache(device: torch.device | str) -> None:
         torch.cuda.empty_cache()
     elif device_type == "mps":
         torch.mps.empty_cache()
+
+
+class RuntimeThreadPool:
+    """Context manager to handle parallel I/O with constrained compute threads.
+
+    On high-core-count systems (e.g. Threadripper), nested thread pools in
+    OpenCV, PyTorch, and OpenEXR can cause massive context-switching overhead
+    when running inside a top-level ThreadPoolExecutor.
+
+    This context manager:
+    1. Determines optimal worker count based on CPU and workload type.
+    2. Temporarily constrains library internal threads to 1 per worker.
+    3. Provides a managed ThreadPoolExecutor for I/O-heavy tasks.
+
+    The cv2/torch thread counts and OPENEXR_NUM_THREADS are process-wide
+    settings, so they also apply to other threads while the block is active.
+    """
+
+    def __init__(self, is_video: bool):
+        self.is_video = is_video
+        self.max_workers = 1 if is_video else min(32, (os.cpu_count() or 1) + 4)
+        self.executor = None
+        self._constrained = False
+        self._prev_cv2_threads = None
+        self._prev_torch_threads = None
+        self._prev_exr_threads = None
+
+    def __enter__(self):
+        from concurrent.futures import ThreadPoolExecutor
+
+        # VideoCapture is not thread-safe: video jobs get a single worker
+        # (max_workers == 1) and library thread settings are left untouched.
+        if not self.is_video:
+            # Constrain library threads to prevent thrashing
+            import cv2
+            import torch
+
+            self._prev_cv2_threads = cv2.getNumThreads()
+            self._prev_torch_threads = torch.get_num_threads()
+            self._prev_exr_threads = os.environ.get("OPENEXR_NUM_THREADS")
+            self._constrained = True
+
+            cv2.setNumThreads(0)  # 0 makes OpenCV run its functions sequentially
+            torch.set_num_threads(1)
+            os.environ["OPENEXR_NUM_THREADS"] = "1"
+
+        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
+        return self.executor
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        try:
+            if self.executor:
+                self.executor.shutdown(wait=True)
+        finally:
+            # Restore only what __enter__ changed, even if shutdown() raised
+            if self._constrained:
+                import cv2
+                import torch
+
+                cv2.setNumThreads(self._prev_cv2_threads)
+                torch.set_num_threads(self._prev_torch_threads)
+                if self._prev_exr_threads is not None:
+                    os.environ["OPENEXR_NUM_THREADS"] = self._prev_exr_threads
+                else:
+                    os.environ.pop("OPENEXR_NUM_THREADS", None)
+                self._constrained = False