From e692dd632515ed9bf4fbdc953f50d80dd5af27b8 Mon Sep 17 00:00:00 2001 From: michal1000w Date: Tue, 17 Mar 2026 15:27:11 +0100 Subject: [PATCH 1/2] automatic mlx selection and setup --- CorridorKeyModule/backend.py | 211 ++++++++++++++++++++++++++++++---- corridorkey-mlx | 1 + tests/test_backend.py | 213 +++++++++++++++++++++++++++++++++-- 3 files changed, 395 insertions(+), 30 deletions(-) create mode 160000 corridorkey-mlx diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 8c51137b..356a32a5 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -3,9 +3,12 @@ from __future__ import annotations import glob +import importlib import logging import os import platform +import shutil +import subprocess import sys from pathlib import Path @@ -20,12 +23,19 @@ BACKEND_ENV_VAR = "CORRIDORKEY_BACKEND" VALID_BACKENDS = ("auto", "torch", "mlx") +MLX_INSTALL_SPEC = "corridorkey-mlx@git+https://github.com/nikopueringer/corridorkey-mlx.git" +PROJECT_ROOT = Path(__file__).resolve().parents[1] +MLX_CONVERTER_REPO_URL = "https://github.com/nikopueringer/corridorkey-mlx.git" +MLX_CONVERTER_REPO_DIR = PROJECT_ROOT / "corridorkey-mlx" +DEFAULT_MLX_CHECKPOINT_NAME = f"corridorkey_mlx{MLX_EXT}" +_MLX_INSTALL_ATTEMPTED = False def resolve_backend(requested: str | None = None) -> str: """Resolve backend: CLI flag > env var > auto-detect. - Auto mode: Apple Silicon + corridorkey_mlx importable + .safetensors found → mlx. + Auto mode prefers MLX on Apple Silicon and auto-installs the MLX runtime + when needed. Otherwise → torch. Raises RuntimeError if explicit backend is unavailable. @@ -48,38 +58,199 @@ def resolve_backend(requested: str | None = None) -> str: def _auto_detect_backend() -> str: - """Try MLX on Apple Silicon, fall back to Torch.""" - if sys.platform != "darwin" or platform.machine() != "arm64": + """Prefer MLX on Apple Silicon, otherwise fall back to Torch.""" + if not _is_apple_silicon(): logger.info("Not Apple Silicon — using torch backend") return "torch" - try: - import corridorkey_mlx # type: ignore[import-not-found] # noqa: F401 - except ImportError: - logger.info("corridorkey_mlx not installed — using torch backend") + if not _mlx_runtime_available(auto_install=True): + logger.info("Apple Silicon detected but corridorkey_mlx could not be installed — using torch backend") return "torch" - safetensor_files = glob.glob(os.path.join(CHECKPOINT_DIR, f"*{MLX_EXT}")) - if not safetensor_files: - logger.info("No %s checkpoint found — using torch backend", MLX_EXT) - return "torch" - - logger.info("Apple Silicon + MLX available — using mlx backend") + logger.info("Apple Silicon detected — preferring mlx backend") return "mlx" def _validate_mlx_available() -> None: """Raise RuntimeError with actionable message if MLX can't be used.""" - if sys.platform != "darwin" or platform.machine() != "arm64": + if not _is_apple_silicon(): raise RuntimeError("MLX backend requires Apple Silicon (M1+ Mac)") + if not _mlx_runtime_available(auto_install=True): + raise RuntimeError( + "MLX backend requested but corridorkey_mlx is unavailable and automatic installation failed. " + f"Tried: {_install_command_summary()}" + ) + + +def _is_apple_silicon() -> bool: + """Return True when running on an Apple Silicon Mac.""" + return sys.platform == "darwin" and platform.machine() == "arm64" + + +def _mlx_runtime_available(*, auto_install: bool = False) -> bool: + """Check whether the MLX runtime package can be imported.""" + if _can_import_mlx_runtime(): + return True + + if not auto_install: + return False + + return _install_mlx_runtime() + + +def _can_import_mlx_runtime() -> bool: + """Return True when corridorkey_mlx can be imported.""" try: import corridorkey_mlx # type: ignore[import-not-found] # noqa: F401 - except ImportError as err: - raise RuntimeError( - "MLX backend requested but corridorkey_mlx is not installed. " - "Install with: uv pip install corridorkey-mlx@git+https://github.com/cmoyates/corridorkey-mlx.git" - ) from err + except ImportError: + return False + + return True + + +def _install_mlx_runtime() -> bool: + """Attempt a one-time runtime install of corridorkey_mlx.""" + global _MLX_INSTALL_ATTEMPTED + + if _MLX_INSTALL_ATTEMPTED: + return _can_import_mlx_runtime() + + _MLX_INSTALL_ATTEMPTED = True + + if sys.version_info < (3, 11): + logger.warning("Automatic MLX install requires Python 3.11+") + return False + + logger.info("corridorkey_mlx not installed — attempting automatic install") + + for cmd in _install_commands(): + try: + logger.info("Attempting MLX install via: %s", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + except Exception as err: + logger.warning("Automatic corridorkey_mlx install failed: %s", err) + continue + + if result.returncode == 0: + importlib.invalidate_caches() + if _can_import_mlx_runtime(): + logger.info("corridorkey_mlx installed successfully") + return True + + logger.warning("corridorkey_mlx install completed but the package is still unavailable") + return False + + details = (result.stderr or result.stdout or "").strip() + if details: + logger.warning("Automatic corridorkey_mlx install failed: %s", details.splitlines()[-1]) + else: + logger.warning("Automatic corridorkey_mlx install failed with exit code %s", result.returncode) + + return False + + +def _install_commands() -> list[list[str]]: + """Return installer commands in preferred order for the active interpreter.""" + commands: list[list[str]] = [] + + uv_path = shutil.which("uv") + if uv_path: + commands.append([uv_path, "pip", "install", "--python", sys.executable, MLX_INSTALL_SPEC]) + + commands.append([sys.executable, "-m", "pip", "install", MLX_INSTALL_SPEC]) + return commands + + +def _install_command_summary() -> str: + """Human-readable summary of attempted installer commands.""" + return " or ".join(" ".join(cmd) for cmd in _install_commands()) + + +def _ensure_mlx_checkpoint() -> Path: + """Return an MLX checkpoint, converting from the Torch checkpoint when needed.""" + matches = glob.glob(os.path.join(CHECKPOINT_DIR, f"*{MLX_EXT}")) + + if len(matches) == 1: + return Path(matches[0]) + + if len(matches) > 1: + names = [os.path.basename(f) for f in matches] + raise ValueError(f"Multiple {MLX_EXT} checkpoints in {CHECKPOINT_DIR}: {names}. Keep exactly one.") + + torch_checkpoint = _discover_checkpoint(TORCH_EXT) + output_path = Path(CHECKPOINT_DIR) / DEFAULT_MLX_CHECKPOINT_NAME + + logger.info( + "No %s checkpoint found — attempting automatic MLX conversion from %s", + MLX_EXT, + torch_checkpoint.name, + ) + _convert_torch_checkpoint_to_mlx(torch_checkpoint, output_path) + + if not output_path.is_file(): + raise FileNotFoundError( + f"Automatic MLX conversion completed but no {MLX_EXT} checkpoint was created at {output_path}." + ) + + return output_path + + +def _convert_torch_checkpoint_to_mlx(torch_checkpoint: Path, output_path: Path) -> None: + """Clone the converter repo if needed, then generate MLX weights.""" + git_path = shutil.which("git") + uv_path = shutil.which("uv") + + if not git_path: + raise RuntimeError("Automatic MLX weight conversion requires `git` to be installed.") + if not uv_path: + raise RuntimeError("Automatic MLX weight conversion requires `uv` to be installed.") + + repo_dir = MLX_CONVERTER_REPO_DIR + if not repo_dir.exists(): + logger.info("Cloning MLX converter repo into %s", repo_dir) + _run_checked_command([git_path, "clone", MLX_CONVERTER_REPO_URL, str(repo_dir)], cwd=PROJECT_ROOT) + + convert_script = repo_dir / "scripts" / "convert_weights.py" + if not convert_script.is_file(): + raise RuntimeError(f"MLX converter script not found at {convert_script}") + + logger.info("Syncing MLX converter dependencies in %s (including reference group)", repo_dir) + _run_checked_command([uv_path, "sync", "--group", "reference"], cwd=repo_dir) + + logger.info("Converting Torch checkpoint %s -> %s", torch_checkpoint.name, output_path.name) + _run_checked_command( + [ + uv_path, + "run", + "--group", + "reference", + "python", + "scripts/convert_weights.py", + "--checkpoint", + str(torch_checkpoint), + "--output", + str(output_path), + ], + cwd=repo_dir, + ) + + +def _run_checked_command(cmd: list[str], *, cwd: Path) -> None: + """Run a subprocess and raise a concise error if it fails.""" + try: + result = subprocess.run(cmd, cwd=str(cwd), capture_output=True, text=True, check=False) + except Exception as err: + raise RuntimeError(f"Command `{' '.join(cmd)}` failed to start: {err}") from err + + if result.returncode == 0: + return + + details = (result.stderr or result.stdout or "").strip() + if details: + raise RuntimeError(f"Command `{' '.join(cmd)}` failed: {details.splitlines()[-1]}") + + raise RuntimeError(f"Command `{' '.join(cmd)}` failed with exit code {result.returncode}") def _discover_checkpoint(ext: str) -> Path: @@ -224,7 +395,7 @@ def create_engine( backend = resolve_backend(backend) if backend == "mlx": - ckpt = _discover_checkpoint(MLX_EXT) + ckpt = _ensure_mlx_checkpoint() from corridorkey_mlx import CorridorKeyMLXEngine # type: ignore[import-not-found] raw_engine = CorridorKeyMLXEngine(str(ckpt), img_size=img_size, tile_size=tile_size, overlap=overlap) diff --git a/corridorkey-mlx b/corridorkey-mlx new file mode 160000 index 00000000..04503e79 --- /dev/null +++ b/corridorkey-mlx @@ -0,0 +1 @@ +Subproject commit 04503e797060e091f991bc88b85ec61b0b9b862b diff --git a/tests/test_backend.py b/tests/test_backend.py index 93c9cf99..93bfcd4b 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,6 +1,8 @@ """Unit tests for CorridorKeyModule.backend — no GPU/MLX required.""" import os +import sys +from types import ModuleType from unittest import mock import numpy as np @@ -11,7 +13,10 @@ MLX_EXT, TORCH_EXT, _discover_checkpoint, + _ensure_mlx_checkpoint, + _install_commands, _wrap_mlx_output, + create_engine, resolve_backend, ) @@ -38,26 +43,59 @@ def test_auto_non_darwin(self): mock_sys.platform = "linux" assert resolve_backend("auto") == "torch" - def test_auto_darwin_no_mlx_package(self): + def test_auto_darwin_no_mlx_package_and_install_fails(self): with ( mock.patch("CorridorKeyModule.backend.sys") as mock_sys, mock.patch("CorridorKeyModule.backend.platform") as mock_platform, + mock.patch("CorridorKeyModule.backend._can_import_mlx_runtime", return_value=False), + mock.patch("CorridorKeyModule.backend._install_mlx_runtime", return_value=False) as mock_install, ): mock_sys.platform = "darwin" mock_platform.machine.return_value = "arm64" - # corridorkey_mlx not importable - import builtins + assert resolve_backend("auto") == "torch" + mock_install.assert_called_once_with() + + def test_auto_darwin_missing_mlx_package_installs_and_prefers_mlx(self): + with ( + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + mock.patch("CorridorKeyModule.backend.platform") as mock_platform, + mock.patch("CorridorKeyModule.backend._can_import_mlx_runtime", return_value=False), + mock.patch("CorridorKeyModule.backend._install_mlx_runtime", return_value=True) as mock_install, + ): + mock_sys.platform = "darwin" + mock_platform.machine.return_value = "arm64" + + assert resolve_backend("auto") == "mlx" + mock_install.assert_called_once_with() + + def test_auto_darwin_with_mlx_package_prefers_mlx(self): + fake_mlx = ModuleType("corridorkey_mlx") + with ( + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + mock.patch("CorridorKeyModule.backend.platform") as mock_platform, + mock.patch.dict(sys.modules, {"corridorkey_mlx": fake_mlx}, clear=False), + ): + mock_sys.platform = "darwin" + mock_platform.machine.return_value = "arm64" - real_import = builtins.__import__ + assert resolve_backend("auto") == "mlx" + + def test_explicit_mlx_attempts_install_before_raising(self): + with ( + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + mock.patch("CorridorKeyModule.backend.platform") as mock_platform, + mock.patch("CorridorKeyModule.backend._can_import_mlx_runtime", return_value=False), + mock.patch("CorridorKeyModule.backend._install_mlx_runtime", return_value=False) as mock_install, + ): + mock_sys.platform = "darwin" + mock_platform.machine.return_value = "arm64" + mock_sys.executable = "/tmp/fake-python" - def fail_mlx(name, *args, **kwargs): - if name == "corridorkey_mlx": - raise ImportError - return real_import(name, *args, **kwargs) + with pytest.raises(RuntimeError, match="automatic installation failed"): + resolve_backend("mlx") - with mock.patch("builtins.__import__", side_effect=fail_mlx): - assert resolve_backend("auto") == "torch" + mock_install.assert_called_once_with() def test_unknown_backend_raises(self): with pytest.raises(RuntimeError, match="Unknown backend"): @@ -101,6 +139,161 @@ def test_safetensors(self, tmp_path): assert result == ckpt +class TestInstallCommands: + def test_prefers_uv_for_active_interpreter(self): + with ( + mock.patch("CorridorKeyModule.backend.shutil.which", return_value="/opt/homebrew/bin/uv"), + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + ): + mock_sys.executable = "/tmp/project/.venv/bin/python" + + assert _install_commands() == [ + [ + "/opt/homebrew/bin/uv", + "pip", + "install", + "--python", + "/tmp/project/.venv/bin/python", + "corridorkey-mlx@git+https://github.com/nikopueringer/corridorkey-mlx.git", + ], + [ + "/tmp/project/.venv/bin/python", + "-m", + "pip", + "install", + "corridorkey-mlx@git+https://github.com/nikopueringer/corridorkey-mlx.git", + ], + ] + + def test_falls_back_to_pip_when_uv_missing(self): + with ( + mock.patch("CorridorKeyModule.backend.shutil.which", return_value=None), + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + ): + mock_sys.executable = "/tmp/project/.venv/bin/python" + + assert _install_commands() == [ + [ + "/tmp/project/.venv/bin/python", + "-m", + "pip", + "install", + "corridorkey-mlx@git+https://github.com/nikopueringer/corridorkey-mlx.git", + ] + ] + + +class TestEnsureMlxCheckpoint: + def test_returns_existing_mlx_checkpoint_without_conversion(self, tmp_path): + ckpt = tmp_path / "existing.safetensors" + ckpt.touch() + + with ( + mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), + mock.patch("CorridorKeyModule.backend._convert_torch_checkpoint_to_mlx") as mock_convert, + ): + result = _ensure_mlx_checkpoint() + + assert result == ckpt + mock_convert.assert_not_called() + + def test_converts_torch_checkpoint_when_mlx_weights_missing(self, tmp_path): + torch_ckpt = tmp_path / "CorridorKey_v1.0.pth" + torch_ckpt.touch() + + repo_dir = tmp_path / "corridorkey-mlx" + output_ckpt = tmp_path / "corridorkey_mlx.safetensors" + command_calls: list[tuple[list[str], str]] = [] + + def fake_run_checked_command(cmd, *, cwd): + command_calls.append((cmd, str(cwd))) + if cmd[0] == "/usr/bin/git": + (repo_dir / "scripts").mkdir(parents=True, exist_ok=True) + (repo_dir / "scripts" / "convert_weights.py").touch() + if cmd[:5] == ["/usr/local/bin/uv", "run", "--group", "reference", "python"]: + output_ckpt.touch() + + with ( + mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), + mock.patch("CorridorKeyModule.backend.PROJECT_ROOT", tmp_path), + mock.patch("CorridorKeyModule.backend.MLX_CONVERTER_REPO_DIR", repo_dir), + mock.patch("CorridorKeyModule.backend.shutil.which") as mock_which, + mock.patch("CorridorKeyModule.backend._run_checked_command", side_effect=fake_run_checked_command), + ): + mock_which.side_effect = lambda name: { + "git": "/usr/bin/git", + "uv": "/usr/local/bin/uv", + }.get(name) + + result = _ensure_mlx_checkpoint() + + assert result == output_ckpt + assert command_calls == [ + ( + [ + "/usr/bin/git", + "clone", + "https://github.com/nikopueringer/corridorkey-mlx.git", + str(repo_dir), + ], + str(tmp_path), + ), + ( + [ + "/usr/local/bin/uv", + "sync", + "--group", + "reference", + ], + str(repo_dir), + ), + ( + [ + "/usr/local/bin/uv", + "run", + "--group", + "reference", + "python", + "scripts/convert_weights.py", + "--checkpoint", + str(torch_ckpt), + "--output", + str(output_ckpt), + ], + str(repo_dir), + ), + ] + + +class TestCreateEngine: + def test_auto_on_apple_silicon_uses_mlx_engine(self, tmp_path): + ckpt = tmp_path / "model.safetensors" + ckpt.touch() + + fake_raw_engine = mock.Mock() + fake_mlx = ModuleType("corridorkey_mlx") + fake_mlx.CorridorKeyMLXEngine = mock.Mock(return_value=fake_raw_engine) + + with ( + mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), + mock.patch("CorridorKeyModule.backend.sys") as mock_sys, + mock.patch("CorridorKeyModule.backend.platform") as mock_platform, + mock.patch.dict(sys.modules, {"corridorkey_mlx": fake_mlx}, clear=False), + ): + mock_sys.platform = "darwin" + mock_platform.machine.return_value = "arm64" + + engine = create_engine(backend="auto", device="mps", img_size=1024) + + fake_mlx.CorridorKeyMLXEngine.assert_called_once_with( + str(ckpt), + img_size=1024, + tile_size=512, + overlap=64, + ) + assert engine._engine is fake_raw_engine + + # --- _wrap_mlx_output --- From 841b336900a4f43191d022f7b0d41615911cf328 Mon Sep 17 00:00:00 2001 From: michal Date: Mon, 23 Mar 2026 13:09:14 +0100 Subject: [PATCH 2/2] remove repo --- corridorkey-mlx | 1 - 1 file changed, 1 deletion(-) delete mode 160000 corridorkey-mlx diff --git a/corridorkey-mlx b/corridorkey-mlx deleted file mode 160000 index 04503e79..00000000 --- a/corridorkey-mlx +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 04503e797060e091f991bc88b85ec61b0b9b862b