Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 86 additions & 68 deletions backend/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,79 +618,97 @@ def run_inference(
frame_indices = range(num_frames)
range_count = num_frames

from device_utils import RuntimeThreadPool

def _process_single_frame(progress_i: int, i: int):
    """Run the full read -> preprocess -> infer -> postprocess -> write chain for frame ``i``.

    Intended to be submitted to the thread pool once per frame. Failures that
    are caught here are reported through the return value rather than raised.

    Args:
        progress_i: Position of the frame in the submission order (accepted
            for the caller's bookkeeping; not used in the body).
        i: Index of the frame to process.

    Returns:
        A ``(status, FrameResult, message)`` triple. ``status`` is one of
        ``'cancel'``, ``'error'``, ``'warning'``, ``'skip'`` or ``'success'``;
        ``message`` is the failure text for ``'error'``/``'warning'`` and
        ``None`` otherwise.
    """
    # Stem used whenever the real input stem is not (or no longer) known.
    fallback_stem = f"{i:05d}"

    def _failed(status, stem, reason):
        # Failure triples carry the reason both in the FrameResult and as the message.
        return (status, FrameResult(i, stem, False, reason), reason)

    # Bail out before doing any work if the job has been cancelled.
    if job and job.is_cancelled:
        return ('cancel', FrameResult(i, fallback_stem, False, "cancelled"), None)

    try:
        # --- Input frame -------------------------------------------------
        img, input_stem, is_linear = self._read_input_frame(
            clip,
            i,
            input_files,
            input_cap,
            params.input_is_linear,
        )
        if img is None:
            return _failed('error', fallback_stem, "video read failed")

        # Resume support: the stem is only known after the read above.
        if input_stem in skip_stems:
            return ('skip', FrameResult(i, input_stem, True, "resumed (skipped)"), None)

        # --- Alpha hint --------------------------------------------------
        mask = self._read_alpha_frame(clip, i, alpha_files, alpha_cap)
        if mask is None:
            return _failed('error', input_stem, "alpha read failed")

        # Bring the mask to the input frame's height/width when they differ.
        frame_h, frame_w = img.shape[:2]
        if mask.shape[:2] != (frame_h, frame_w):
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_LINEAR)

        # --- Compute -----------------------------------------------------
        # The timer covers preprocess, the wait for the GPU lock, inference
        # and postprocess.
        t_frame = time.monotonic()
        inp_t, h, w = engine.preprocess(img, mask, input_is_linear=is_linear)

        # Only the inference call itself runs under the GPU lock.
        with self._gpu_lock:
            pred_alpha, pred_fg = engine.infer(inp_t, refiner_scale=params.refiner_scale)

        res = engine.postprocess(
            pred_alpha,
            pred_fg,
            h,
            w,
            fg_is_straight=True,
            despill_strength=params.despill_strength,
            auto_despeckle=params.auto_despeckle,
            despeckle_size=params.despeckle_size,
        )
        logger.debug(f"Clip '{clip.name}' frame {i}: process/infer {time.monotonic() - t_frame:.3f}s")

        # --- Output ------------------------------------------------------
        self._write_outputs(res, dirs, input_stem, clip.name, i, cfg)
        return ('success', FrameResult(i, input_stem, True), None)

    except FrameReadError as e:
        # Read problems are reported as warnings, keyed by the index-based stem.
        return _failed('warning', fallback_stem, str(e))
    except WriteFailureError as e:
        # Write problems are reported as errors, keyed by the index-based stem.
        return _failed('error', fallback_stem, str(e))

try:
for progress_i, i in enumerate(frame_indices):
# Check cancellation between frames
if job and job.is_cancelled:
raise JobCancelledError(clip.name, i)
is_video = bool(input_cap or alpha_cap)
with RuntimeThreadPool(is_video) as executor:
from concurrent.futures import as_completed
futures = {}

# Report progress every frame (enables responsive cancel + timer)
if on_progress:
on_progress(clip.name, progress_i, range_count)

try:
# Read input
img, input_stem, is_linear = self._read_input_frame(
clip,
i,
input_files,
input_cap,
params.input_is_linear,
)
if img is None:
skipped.append(i)
results.append(FrameResult(i, f"{i:05d}", False, "video read failed"))
continue
for progress_i, i in enumerate(frame_indices):
fut = executor.submit(_process_single_frame, progress_i, i)
futures[fut] = (progress_i, i)

# Resume: skip frames that already have outputs
if input_stem in skip_stems:
results.append(FrameResult(i, input_stem, True, "resumed (skipped)"))
continue
processed_count = 0
for fut in as_completed(futures):
progress_i, i = futures[fut]

# Read alpha
mask = self._read_alpha_frame(clip, i, alpha_files, alpha_cap)
if mask is None:
# Fast cancellation check
if job and job.is_cancelled:
executor.shutdown(wait=False, cancel_futures=True)
raise JobCancelledError(clip.name, i)

# Accumulate parsed results
status, frame_res, msg = fut.result()

if status in ('error', 'warning'):
skipped.append(i)
results.append(FrameResult(i, input_stem, False, "alpha read failed"))
continue

# Resize mask if dimensions don't match input
if mask.shape[:2] != img.shape[:2]:
mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)

# Process (GPU-locked — process_frame mutates model hooks)
t_frame = time.monotonic()
with self._gpu_lock:
res = engine.process_frame(
img,
mask,
input_is_linear=is_linear,
fg_is_straight=True,
despill_strength=params.despill_strength,
auto_despeckle=params.auto_despeckle,
despeckle_size=params.despeckle_size,
refiner_scale=params.refiner_scale,
)
logger.debug(f"Clip '{clip.name}' frame {i}: process_frame {time.monotonic() - t_frame:.3f}s")

# Write outputs
self._write_outputs(res, dirs, input_stem, clip.name, i, cfg)
results.append(FrameResult(i, input_stem, True))

except FrameReadError as e:
logger.warning(str(e))
skipped.append(i)
results.append(FrameResult(i, f"{i:05d}", False, str(e)))
if on_warning:
on_warning(str(e))

except WriteFailureError as e:
logger.error(str(e))
results.append(FrameResult(i, f"{i:05d}", False, str(e)))
if on_warning:
on_warning(str(e))

# Final progress
results.append(frame_res)
if msg:
logger.warning(msg)
if on_warning:
on_warning(msg)
elif status in ('success', 'skip'):
results.append(frame_res)

# Update internal tracking
processed_count += 1
if on_progress:
on_progress(clip.name, processed_count, range_count)

if on_progress:
on_progress(clip.name, range_count, range_count)

Expand Down
63 changes: 47 additions & 16 deletions clip_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,18 +607,33 @@ def run_inference(
if on_clip_start:
on_clip_start(clip.name, num_frames)

for i in range(num_frames):
# Unpack settings for easier access in the worker function
user_input_is_linear = settings.input_is_linear
despill_strength = settings.despill_strength
auto_despeckle = settings.auto_despeckle
despeckle_size = settings.despeckle_size
refiner_scale = settings.refiner_scale

# Create a lock for GPU operations to ensure thread-safe inference
import threading
_gpu_lock = threading.Lock()

from device_utils import RuntimeThreadPool

def _process_single_frame(i: int):
if i % 10 == 0:
print(f" Frame {i}/{num_frames}...", end="\r")

# 1. Read Input
img_srgb = None
input_stem = f"{i:05d}"
input_is_linear = user_input_is_linear

# Use the settings-defined gamma
input_is_linear = settings.input_is_linear

# IO is strictly locked to prevent scramble if using VideoCapture
if input_cap:
ret, frame = input_cap.read()
if not ret:
break
return False
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img_srgb = img_rgb.astype(np.float32) / 255.0
input_stem = f"{i:05d}"
Expand All @@ -630,14 +645,13 @@ def run_inference(
if is_exr:
img_linear = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)
if img_linear is None:
continue
return False
img_linear_rgb = cv2.cvtColor(img_linear, cv2.COLOR_BGR2RGB)
# Support overriding EXR behavior if user picked 's'
img_srgb = np.maximum(img_linear_rgb, 0.0)
else:
img_bgr = cv2.imread(fpath)
if img_bgr is None:
continue
return False
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_srgb = img_rgb.astype(np.float32) / 255.0

Expand All @@ -646,14 +660,14 @@ def run_inference(
if alpha_cap:
ret, frame = alpha_cap.read()
if not ret:
break
return False
mask_linear = frame[:, :, 2].astype(np.float32) / 255.0
else:
fpath = os.path.join(clip.alpha_asset.path, alpha_files[i])
mask_in = cv2.imread(fpath, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED)

if mask_in is None:
continue
return False

if mask_in.ndim == 3:
if mask_in.shape[2] == 3:
Expand All @@ -675,14 +689,18 @@ def run_inference(
mask_linear, (img_srgb.shape[1], img_srgb.shape[0]), interpolation=cv2.INTER_LINEAR
)

# 3. Process
# 3. Pre-process mapping (CPU)
USE_STRAIGHT_MODEL = True
res = engine.process_frame(
img_srgb,
mask_linear,
input_is_linear=input_is_linear,
inp_t, h, w = engine.preprocess(img_srgb, mask_linear, input_is_linear=input_is_linear)

# 4. Infer (GPU locked)
with _gpu_lock:
pred_alpha, pred_fg = engine.infer(inp_t, refiner_scale=refiner_scale)

# 5. Post-process mapping (CPU)
res = engine.postprocess(
pred_alpha, pred_fg, h, w,
fg_is_straight=USE_STRAIGHT_MODEL,
despill_strength=settings.despill_strength,
auto_despeckle=settings.auto_despeckle,
despeckle_size=settings.despeckle_size,
refiner_scale=settings.refiner_scale,
Expand Down Expand Up @@ -720,6 +738,19 @@ def run_inference(

if on_frame_complete:
on_frame_complete(i, num_frames)

return True

# Parallelize I/O using the managed thread pool
is_video = bool(input_cap or alpha_cap)
with RuntimeThreadPool(is_video) as executor:
from concurrent.futures import as_completed
futures = [executor.submit(_process_single_frame, i) for i in range(num_frames)]
for fut in as_completed(futures):
try:
fut.result()
except Exception as e:
logger.error(f"Frame processing failed: {e}")

if input_cap:
input_cap.release()
Expand Down
62 changes: 62 additions & 0 deletions device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,65 @@ def clear_device_cache(device: torch.device | str) -> None:
torch.cuda.empty_cache()
elif device_type == "mps":
torch.mps.empty_cache()


class RuntimeThreadPool:
    """Context manager providing a thread pool with constrained library threads.

    On high-core-count systems (e.g. Threadripper), nested thread pools in
    OpenCV, PyTorch, and OpenEXR can cause massive context-switching overhead
    when running inside a top-level ThreadPoolExecutor.

    Entering the context:
        1. Picks the worker count: 1 for video sources, otherwise
           ``min(32, cpu_count + 4)``.
        2. For non-video workloads, records the current OpenCV / PyTorch
           thread counts and the ``OPENEXR_NUM_THREADS`` environment variable,
           then restricts them (``cv2.setNumThreads(0)``, which makes OpenCV
           run its functions sequentially; ``torch.set_num_threads(1)``;
           ``OPENEXR_NUM_THREADS=1``).
        3. Returns a ``ThreadPoolExecutor`` with that worker count.

    Leaving the context shuts the executor down and restores exactly the
    settings that were changed on entry. Video mode changes none of them, so
    it restores none of them.

    Args:
        is_video: True when frames come from a ``VideoCapture``-style source;
            the pool is then limited to a single worker.
    """

    def __init__(self, is_video: bool):
        self.is_video = is_video
        self.max_workers = 1 if is_video else min(32, (os.cpu_count() or 1) + 4)
        self.executor = None
        # Snapshots of the library settings; filled in by __enter__ only when
        # the settings are actually about to be modified.
        self._prev_cv2_threads = None
        self._prev_torch_threads = None
        self._prev_exr_threads = None
        # True while library thread settings are overridden by this instance.
        self._constrained = False

    def __enter__(self):
        """Create the executor, constraining library threads for non-video work.

        Returns:
            The managed ``ThreadPoolExecutor``.
        """
        from concurrent.futures import ThreadPoolExecutor

        # VideoCapture is not thread-safe; video mode uses a single worker and
        # leaves the library thread settings alone.
        if not self.is_video:
            self._constrain_library_threads()

        try:
            self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        except BaseException:
            # Do not leave the process throttled if the pool cannot be built.
            self._restore_library_threads()
            raise
        return self.executor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut the executor down and undo any thread-count overrides.

        When the ``with`` body raised, futures that have not started yet are
        cancelled so the exception is not delayed by work nobody will consume.
        Exceptions are never suppressed.
        """
        try:
            if self.executor is not None:
                self.executor.shutdown(wait=True, cancel_futures=exc_type is not None)
        finally:
            # Restore even if shutdown itself fails.
            self._restore_library_threads()

    def _constrain_library_threads(self):
        """Snapshot and then restrict OpenCV / PyTorch / OpenEXR threading."""
        # Imported lazily so the video path never needs these modules here.
        import cv2
        import torch

        self._prev_cv2_threads = cv2.getNumThreads()
        self._prev_torch_threads = torch.get_num_threads()
        # Snapshot taken at entry (not construction) so changes made between
        # the two are preserved on exit.
        self._prev_exr_threads = os.environ.get("OPENEXR_NUM_THREADS")
        self._constrained = True

        cv2.setNumThreads(0)
        torch.set_num_threads(1)
        os.environ["OPENEXR_NUM_THREADS"] = "1"

    def _restore_library_threads(self):
        """Put back the settings recorded by ``_constrain_library_threads``."""
        if not self._constrained:
            return
        self._constrained = False

        import cv2
        import torch

        if self._prev_cv2_threads is not None:
            cv2.setNumThreads(self._prev_cv2_threads)
        if self._prev_torch_threads is not None:
            torch.set_num_threads(self._prev_torch_threads)

        if self._prev_exr_threads is not None:
            os.environ["OPENEXR_NUM_THREADS"] = self._prev_exr_threads
        else:
            os.environ.pop("OPENEXR_NUM_THREADS", None)
Loading