Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 191 additions & 20 deletions CorridorKeyModule/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from __future__ import annotations

import glob
import importlib
import logging
import os
import platform
import shutil
import subprocess
import sys
from pathlib import Path

Expand All @@ -20,12 +23,19 @@

BACKEND_ENV_VAR = "CORRIDORKEY_BACKEND"
VALID_BACKENDS = ("auto", "torch", "mlx")
MLX_INSTALL_SPEC = "corridorkey-mlx@git+https://github.com/nikopueringer/corridorkey-mlx.git"
PROJECT_ROOT = Path(__file__).resolve().parents[1]
MLX_CONVERTER_REPO_URL = "https://github.com/nikopueringer/corridorkey-mlx.git"
MLX_CONVERTER_REPO_DIR = PROJECT_ROOT / "corridorkey-mlx"
DEFAULT_MLX_CHECKPOINT_NAME = f"corridorkey_mlx{MLX_EXT}"
_MLX_INSTALL_ATTEMPTED = False


def resolve_backend(requested: str | None = None) -> str:
"""Resolve backend: CLI flag > env var > auto-detect.

Auto mode: Apple Silicon + corridorkey_mlx importable + .safetensors found → mlx.
Auto mode prefers MLX on Apple Silicon and auto-installs the MLX runtime
when needed.
Otherwise → torch.

Raises RuntimeError if explicit backend is unavailable.
Expand All @@ -48,38 +58,199 @@ def resolve_backend(requested: str | None = None) -> str:


def _auto_detect_backend() -> str:
"""Try MLX on Apple Silicon, fall back to Torch."""
if sys.platform != "darwin" or platform.machine() != "arm64":
"""Prefer MLX on Apple Silicon, otherwise fall back to Torch."""
if not _is_apple_silicon():
logger.info("Not Apple Silicon — using torch backend")
return "torch"

try:
import corridorkey_mlx # type: ignore[import-not-found] # noqa: F401
except ImportError:
logger.info("corridorkey_mlx not installed — using torch backend")
if not _mlx_runtime_available(auto_install=True):
logger.info("Apple Silicon detected but corridorkey_mlx could not be installed — using torch backend")
return "torch"

safetensor_files = glob.glob(os.path.join(CHECKPOINT_DIR, f"*{MLX_EXT}"))
if not safetensor_files:
logger.info("No %s checkpoint found — using torch backend", MLX_EXT)
return "torch"

logger.info("Apple Silicon + MLX available — using mlx backend")
logger.info("Apple Silicon detected — preferring mlx backend")
return "mlx"


def _validate_mlx_available() -> None:
"""Raise RuntimeError with actionable message if MLX can't be used."""
if sys.platform != "darwin" or platform.machine() != "arm64":
if not _is_apple_silicon():
raise RuntimeError("MLX backend requires Apple Silicon (M1+ Mac)")

if not _mlx_runtime_available(auto_install=True):
raise RuntimeError(
"MLX backend requested but corridorkey_mlx is unavailable and automatic installation failed. "
f"Tried: {_install_command_summary()}"
)


def _is_apple_silicon() -> bool:
"""Return True when running on an Apple Silicon Mac."""
return sys.platform == "darwin" and platform.machine() == "arm64"


def _mlx_runtime_available(*, auto_install: bool = False) -> bool:
"""Check whether the MLX runtime package can be imported."""
if _can_import_mlx_runtime():
return True

if not auto_install:
return False

return _install_mlx_runtime()


def _can_import_mlx_runtime() -> bool:
"""Return True when corridorkey_mlx can be imported."""
try:
import corridorkey_mlx # type: ignore[import-not-found] # noqa: F401
except ImportError as err:
raise RuntimeError(
"MLX backend requested but corridorkey_mlx is not installed. "
"Install with: uv pip install corridorkey-mlx@git+https://github.com/cmoyates/corridorkey-mlx.git"
) from err
except ImportError:
return False

return True


def _install_mlx_runtime() -> bool:
"""Attempt a one-time runtime install of corridorkey_mlx."""
global _MLX_INSTALL_ATTEMPTED

if _MLX_INSTALL_ATTEMPTED:
return _can_import_mlx_runtime()

_MLX_INSTALL_ATTEMPTED = True

if sys.version_info < (3, 11):
logger.warning("Automatic MLX install requires Python 3.11+")
return False

logger.info("corridorkey_mlx not installed — attempting automatic install")

for cmd in _install_commands():
try:
logger.info("Attempting MLX install via: %s", " ".join(cmd))
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
except Exception as err:
logger.warning("Automatic corridorkey_mlx install failed: %s", err)
continue

if result.returncode == 0:
importlib.invalidate_caches()
if _can_import_mlx_runtime():
logger.info("corridorkey_mlx installed successfully")
return True

logger.warning("corridorkey_mlx install completed but the package is still unavailable")
return False

details = (result.stderr or result.stdout or "").strip()
if details:
logger.warning("Automatic corridorkey_mlx install failed: %s", details.splitlines()[-1])
else:
logger.warning("Automatic corridorkey_mlx install failed with exit code %s", result.returncode)

return False


def _install_commands() -> list[list[str]]:
"""Return installer commands in preferred order for the active interpreter."""
commands: list[list[str]] = []

uv_path = shutil.which("uv")
if uv_path:
commands.append([uv_path, "pip", "install", "--python", sys.executable, MLX_INSTALL_SPEC])

commands.append([sys.executable, "-m", "pip", "install", MLX_INSTALL_SPEC])
return commands


def _install_command_summary() -> str:
"""Human-readable summary of attempted installer commands."""
return " or ".join(" ".join(cmd) for cmd in _install_commands())


def _ensure_mlx_checkpoint() -> Path:
"""Return an MLX checkpoint, converting from the Torch checkpoint when needed."""
matches = glob.glob(os.path.join(CHECKPOINT_DIR, f"*{MLX_EXT}"))

if len(matches) == 1:
return Path(matches[0])

if len(matches) > 1:
names = [os.path.basename(f) for f in matches]
raise ValueError(f"Multiple {MLX_EXT} checkpoints in {CHECKPOINT_DIR}: {names}. Keep exactly one.")

torch_checkpoint = _discover_checkpoint(TORCH_EXT)
output_path = Path(CHECKPOINT_DIR) / DEFAULT_MLX_CHECKPOINT_NAME

logger.info(
"No %s checkpoint found — attempting automatic MLX conversion from %s",
MLX_EXT,
torch_checkpoint.name,
)
_convert_torch_checkpoint_to_mlx(torch_checkpoint, output_path)

if not output_path.is_file():
raise FileNotFoundError(
f"Automatic MLX conversion completed but no {MLX_EXT} checkpoint was created at {output_path}."
)

return output_path


def _convert_torch_checkpoint_to_mlx(torch_checkpoint: Path, output_path: Path) -> None:
"""Clone the converter repo if needed, then generate MLX weights."""
git_path = shutil.which("git")
uv_path = shutil.which("uv")

if not git_path:
raise RuntimeError("Automatic MLX weight conversion requires `git` to be installed.")
if not uv_path:
raise RuntimeError("Automatic MLX weight conversion requires `uv` to be installed.")

repo_dir = MLX_CONVERTER_REPO_DIR
if not repo_dir.exists():
logger.info("Cloning MLX converter repo into %s", repo_dir)
_run_checked_command([git_path, "clone", MLX_CONVERTER_REPO_URL, str(repo_dir)], cwd=PROJECT_ROOT)

convert_script = repo_dir / "scripts" / "convert_weights.py"
if not convert_script.is_file():
raise RuntimeError(f"MLX converter script not found at {convert_script}")

logger.info("Syncing MLX converter dependencies in %s (including reference group)", repo_dir)
_run_checked_command([uv_path, "sync", "--group", "reference"], cwd=repo_dir)

logger.info("Converting Torch checkpoint %s -> %s", torch_checkpoint.name, output_path.name)
_run_checked_command(
[
uv_path,
"run",
"--group",
"reference",
"python",
"scripts/convert_weights.py",
"--checkpoint",
str(torch_checkpoint),
"--output",
str(output_path),
],
cwd=repo_dir,
)


def _run_checked_command(cmd: list[str], *, cwd: Path) -> None:
"""Run a subprocess and raise a concise error if it fails."""
try:
result = subprocess.run(cmd, cwd=str(cwd), capture_output=True, text=True, check=False)
except Exception as err:
raise RuntimeError(f"Command `{' '.join(cmd)}` failed to start: {err}") from err

if result.returncode == 0:
return

details = (result.stderr or result.stdout or "").strip()
if details:
raise RuntimeError(f"Command `{' '.join(cmd)}` failed: {details.splitlines()[-1]}")

raise RuntimeError(f"Command `{' '.join(cmd)}` failed with exit code {result.returncode}")


def _discover_checkpoint(ext: str) -> Path:
Expand Down Expand Up @@ -224,7 +395,7 @@ def create_engine(
backend = resolve_backend(backend)

if backend == "mlx":
ckpt = _discover_checkpoint(MLX_EXT)
ckpt = _ensure_mlx_checkpoint()
from corridorkey_mlx import CorridorKeyMLXEngine # type: ignore[import-not-found]

raw_engine = CorridorKeyMLXEngine(str(ckpt), img_size=img_size, tile_size=tile_size, overlap=overlap)
Expand Down
Loading