Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion CorridorKeyModule/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import torch

from .model_assets import ensure_corridorkey_assets

logger = logging.getLogger(__name__)

CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
Expand Down Expand Up @@ -303,7 +305,29 @@ def create_engine(
Set to None to disable tiling and use full-frame inference.
overlap: MLX only — overlap pixels between tiles (default 64).
"""
backend = resolve_backend(backend)
requested_backend = None if backend is None else backend.lower()

if requested_backend in (None, "auto"):
# Auto-detect needs assets on disk before resolve_backend() runs:
# _auto_detect_backend() checks for local MLX weights and can prefer MLX
# only after they have been downloaded or discovered.
ensure_corridorkey_assets(
ensure_torch=True,
ensure_mlx=False,
download_mlx_if_available=True,
checkpoint_dir=CHECKPOINT_DIR,
)
backend = resolve_backend(backend)
else:
# Explicit backends can resolve first because the caller has already
# chosen the runtime; we only need to fetch assets for that backend.
backend = resolve_backend(backend)
ensure_corridorkey_assets(
ensure_torch=True,
ensure_mlx=backend == "mlx",
download_mlx_if_available=True,
checkpoint_dir=CHECKPOINT_DIR,
)

if backend == "mlx":
ckpt = _discover_checkpoint(MLX_EXT)
Expand Down
2 changes: 1 addition & 1 deletion CorridorKeyModule/inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(

self.model_precision = model_precision

self.model = self._load_model()
self.model = self._load_model().to(model_precision)

# We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution.
if sys.platform == "linux" or sys.platform == "win32":
Expand Down
Loading