diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 459319c1..b1ece13d 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -15,6 +15,8 @@ import numpy as np import torch +from .model_assets import ensure_corridorkey_assets + logger = logging.getLogger(__name__) CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints") @@ -303,7 +305,29 @@ def create_engine( Set to None to disable tiling and use full-frame inference. overlap: MLX only — overlap pixels between tiles (default 64). """ - backend = resolve_backend(backend) + requested_backend = None if backend is None else backend.lower() + + if requested_backend in (None, "auto"): + # Auto-detect needs assets on disk before resolve_backend() runs: + # _auto_detect_backend() checks for local MLX weights and can prefer MLX + # only after they have been downloaded or discovered. + ensure_corridorkey_assets( + ensure_torch=True, + ensure_mlx=False, + download_mlx_if_available=True, + checkpoint_dir=CHECKPOINT_DIR, + ) + backend = resolve_backend(backend) + else: + # Explicit backends can resolve first because the caller has already + # chosen the runtime; Torch is still ensured, MLX only if backend is "mlx". + backend = resolve_backend(backend) + ensure_corridorkey_assets( + ensure_torch=True, + ensure_mlx=backend == "mlx", + download_mlx_if_available=True, + checkpoint_dir=CHECKPOINT_DIR, + ) if backend == "mlx": ckpt = _discover_checkpoint(MLX_EXT) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 3f43588d..f2e6411d 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -50,7 +50,7 @@ def __init__( self.model_precision = model_precision - self.model = self._load_model() + self.model = self._load_model().to(model_precision) # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. 
if sys.platform == "linux" or sys.platform == "win32": diff --git a/CorridorKeyModule/model_assets.py b/CorridorKeyModule/model_assets.py new file mode 100644 index 00000000..f5332b80 --- /dev/null +++ b/CorridorKeyModule/model_assets.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import importlib.util +import logging +import os +import platform +import shutil +import subprocess +import sys +import threading +import time +import urllib.request +from pathlib import Path +from typing import Any, Callable + +from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub.utils import GatedRepoError, HfHubHTTPError, LocalEntryNotFoundError, RepositoryNotFoundError + +logger = logging.getLogger(__name__) + +PACKAGE_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = PACKAGE_DIR.parent + +DEFAULT_CORRIDORKEY_CHECKPOINT_DIR = PACKAGE_DIR / "checkpoints" +DEFAULT_GVM_WEIGHTS_DIR = PROJECT_ROOT / "gvm_core" / "weights" +DEFAULT_VIDEOMAMA_CHECKPOINTS_DIR = PROJECT_ROOT / "VideoMaMaInferenceModule" / "checkpoints" + +TORCH_EXT = ".pth" +MLX_EXT = ".safetensors" + +CORRIDORKEY_REPO_ID = "nikopueringer/CorridorKey_v1.0" +CORRIDORKEY_TORCH_FILENAME = "CorridorKey_v1.0.pth" +CORRIDORKEY_MLX_FILENAME = "corridorkey_mlx.safetensors" +CORRIDORKEY_MLX_REPO = "nikopueringer/corridorkey-mlx" +CORRIDORKEY_MLX_TAG = "v1.0.0" + +GVM_REPO_ID = "geyongtao/gvm" +VIDEOMAMA_REPO_ID = "SammyLim/VideoMaMa" +VIDEOMAMA_BASE_REPO_ID = "stabilityai/stable-video-diffusion-img2vid-xt" +VIDEOMAMA_BASE_LICENSE_URL = "https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt" + +DOWNLOAD_ATTEMPTS = 3 +_LOCKS: dict[str, threading.Lock] = {} +_LOCKS_GUARD = threading.Lock() + + +def mlx_runtime_available() -> bool: + """Return True when MLX can actually be used on this machine.""" + if sys.platform != "darwin" or platform.machine() != "arm64": + return False + return importlib.util.find_spec("corridorkey_mlx") is not None + + +def ensure_corridorkey_assets( + 
*, + ensure_torch: bool = True, + ensure_mlx: bool = False, + download_mlx_if_available: bool = False, + checkpoint_dir: str | os.PathLike[str] | None = None, +) -> Path: + """Ensure the main CorridorKey checkpoints needed for first-run inference exist. + + Behavior: + - Torch weights are downloaded when requested and missing. + - MLX weights are downloaded when explicitly requested and missing. + - If the folder held no Torch or MLX checkpoints when the call began, MLX is + auto-downloaded as well when `download_mlx_if_available=True` and the local runtime supports MLX. + """ + checkpoint_dir = _ensure_dir(checkpoint_dir or DEFAULT_CORRIDORKEY_CHECKPOINT_DIR) + lock_name = f"corridorkey:{checkpoint_dir.resolve()}" + + with _get_lock(lock_name): + torch_files = _find_checkpoint_files(checkpoint_dir, TORCH_EXT) + mlx_files = _find_checkpoint_files(checkpoint_dir, MLX_EXT) + was_empty = not torch_files and not mlx_files + + if ensure_torch and not torch_files: + _download_corridorkey_torch(checkpoint_dir) + torch_files = _find_checkpoint_files(checkpoint_dir, TORCH_EXT) + + should_download_mlx = False + if ensure_mlx and not mlx_files: + should_download_mlx = True + elif download_mlx_if_available and was_empty and not mlx_files and mlx_runtime_available(): + should_download_mlx = True + + if should_download_mlx: + _download_corridorkey_mlx(checkpoint_dir) + + return checkpoint_dir + + +def ensure_gvm_weights(weights_dir: str | os.PathLike[str] | None = None) -> Path: + """Ensure the GVM alpha-generation weights are available locally.""" + weights_dir = _ensure_dir(weights_dir or DEFAULT_GVM_WEIGHTS_DIR) + lock_name = f"gvm:{weights_dir.resolve()}" + + with _get_lock(lock_name): + if _gvm_weights_ready(weights_dir): + return weights_dir + + logger.info("Downloading GVM weights to %s", weights_dir) + _download_with_retries( + label="GVM weights", + action=lambda: snapshot_download( + repo_id=GVM_REPO_ID, + local_dir=str(weights_dir), + local_dir_use_symlinks=False, + allow_patterns=["vae/*", 
"scheduler/*", "unet/*"], + ), + ) + + if not _gvm_weights_ready(weights_dir): + raise RuntimeError(f"GVM download completed but required files are still missing in {weights_dir}") + + return weights_dir + + +def ensure_videomama_weights( + checkpoints_dir: str | os.PathLike[str] | None = None, +) -> tuple[Path, Path]: + """Ensure both VideoMaMa checkpoints are present locally.""" + checkpoints_dir = _ensure_dir(checkpoints_dir or DEFAULT_VIDEOMAMA_CHECKPOINTS_DIR) + base_dir = _ensure_dir(checkpoints_dir / "stable-video-diffusion-img2vid-xt") + unet_dir = _ensure_dir(checkpoints_dir / "VideoMaMa") + lock_name = f"videomama:{checkpoints_dir.resolve()}" + + with _get_lock(lock_name): + if not _videomama_base_ready(base_dir): + logger.info("Downloading VideoMaMa base weights to %s", base_dir) + _download_with_retries( + label="VideoMaMa base weights", + action=lambda: snapshot_download( + repo_id=VIDEOMAMA_BASE_REPO_ID, + local_dir=str(base_dir), + local_dir_use_symlinks=False, + allow_patterns=["feature_extractor/*", "image_encoder/*", "vae/*"], + ), + gated_repo_url=VIDEOMAMA_BASE_LICENSE_URL, + ) + + if not _videomama_unet_ready(unet_dir): + logger.info("Downloading VideoMaMa UNet weights to %s", unet_dir) + _download_with_retries( + label="VideoMaMa UNet weights", + action=lambda: snapshot_download( + repo_id=VIDEOMAMA_REPO_ID, + local_dir=str(unet_dir), + local_dir_use_symlinks=False, + allow_patterns=["unet/*"], + ), + ) + + if not _videomama_base_ready(base_dir): + raise RuntimeError(f"VideoMaMa base download completed but files are still missing in {base_dir}") + if not _videomama_unet_ready(unet_dir): + raise RuntimeError(f"VideoMaMa UNet download completed but files are still missing in {unet_dir}") + + return base_dir, unet_dir + + +def _download_corridorkey_torch(checkpoint_dir: Path) -> Path: + logger.info("Downloading CorridorKey Torch checkpoint to %s", checkpoint_dir) + + cached_path = _download_with_retries( + label="CorridorKey Torch checkpoint", + 
action=lambda: hf_hub_download( + repo_id=CORRIDORKEY_REPO_ID, + filename=CORRIDORKEY_TORCH_FILENAME, + ), + ) + + destination = checkpoint_dir / CORRIDORKEY_TORCH_FILENAME + _copy_atomic(Path(cached_path), destination) + return destination + + +def _download_corridorkey_mlx(checkpoint_dir: Path) -> Path: + if not mlx_runtime_available(): + raise RuntimeError("MLX weights were requested, but MLX is not available on this machine.") + + logger.info("Downloading CorridorKey MLX checkpoint to %s", checkpoint_dir) + release_tag = os.environ.get("CORRIDORKEY_MLX_WEIGHTS_TAG", CORRIDORKEY_MLX_TAG) + repo_override = os.environ.get("CORRIDORKEY_MLX_WEIGHTS_REPO", CORRIDORKEY_MLX_REPO) + + def action() -> Path: + command = [ + sys.executable, + "-m", + "corridorkey_mlx", + "weights", + "download", + "--tag", + release_tag, + "--asset", + CORRIDORKEY_MLX_FILENAME, + "--print-path", + ] + env = os.environ.copy() + env["CORRIDORKEY_MLX_WEIGHTS_REPO"] = repo_override + + try: + completed = subprocess.run(command, check=True, capture_output=True, text=True, env=env) + except subprocess.CalledProcessError: + logger.info( + "corridorkey_mlx CLI download failed for %s@%s; falling back to direct download.", + repo_override, + release_tag, + ) + return _download_corridorkey_mlx_direct( + checkpoint_dir, repo_override=repo_override, release_tag=release_tag + ) + + cached_path = _extract_path_from_output(completed) + if cached_path is None or not cached_path.exists(): + logger.info( + "corridorkey_mlx CLI did not return a usable path for %s@%s; falling back to direct download.", + repo_override, + release_tag, + ) + return _download_corridorkey_mlx_direct( + checkpoint_dir, repo_override=repo_override, release_tag=release_tag + ) + return cached_path + + cached_path = _download_with_retries(label="CorridorKey MLX checkpoint", action=action) + + destination = checkpoint_dir / CORRIDORKEY_MLX_FILENAME + if cached_path.resolve() != destination.resolve(): + _copy_atomic(cached_path, 
destination) + return destination + + +def _download_corridorkey_mlx_direct(checkpoint_dir: Path, *, repo_override: str, release_tag: str) -> Path: + destination = checkpoint_dir / CORRIDORKEY_MLX_FILENAME + if destination.exists(): + return destination + + download_url = f"https://github.com/{repo_override}/releases/download/{release_tag}/{CORRIDORKEY_MLX_FILENAME}" + tmp_path = destination.with_name(f".{destination.name}.download-{os.getpid()}-{threading.get_ident()}") + + try: + urllib.request.urlretrieve(download_url, tmp_path) + os.replace(tmp_path, destination) + finally: + if tmp_path.exists(): + tmp_path.unlink(missing_ok=True) + + return destination + + +def _download_with_retries( + *, + label: str, + action: Callable[[], Any], + gated_repo_url: str | None = None, +) -> Any: + last_exc: Exception | None = None + + for attempt in range(1, DOWNLOAD_ATTEMPTS + 1): + try: + return action() + except GatedRepoError as exc: + last_exc = exc + break + except RepositoryNotFoundError as exc: + last_exc = exc + break + except subprocess.CalledProcessError as exc: + last_exc = exc + except (HfHubHTTPError, LocalEntryNotFoundError, OSError, RuntimeError) as exc: + last_exc = exc + + if attempt < DOWNLOAD_ATTEMPTS: + logger.warning("%s download failed on attempt %d/%d. Retrying...", label, attempt, DOWNLOAD_ATTEMPTS) + time.sleep(min(2**attempt, 8)) + + assert last_exc is not None + raise _wrap_download_error(label=label, exc=last_exc, gated_repo_url=gated_repo_url) from last_exc + + +def _wrap_download_error(*, label: str, exc: Exception, gated_repo_url: str | None = None) -> RuntimeError: + if isinstance(exc, GatedRepoError): + if gated_repo_url: + return RuntimeError( + f"{label} could not be downloaded because the repository is gated. " + f"Accept the license at {gated_repo_url}, then retry." 
+ ) + return RuntimeError(f"{label} could not be downloaded because the repository is gated.") + + if isinstance(exc, RepositoryNotFoundError): + return RuntimeError(f"{label} could not be downloaded because the source repository was not found.") + + if isinstance(exc, subprocess.CalledProcessError): + stderr = (exc.stderr or "").strip() + stdout = (exc.stdout or "").strip() + details = stderr or stdout or str(exc) + return RuntimeError(f"{label} download failed: {details}") + + if isinstance(exc, HfHubHTTPError): + return RuntimeError(f"{label} download failed with an HTTP error: {exc}") + + if isinstance(exc, LocalEntryNotFoundError): + return RuntimeError(f"{label} download failed because the files could not be fetched from Hugging Face.") + + return RuntimeError(f"{label} download failed: {exc}") + + +def _find_checkpoint_files(directory: Path, ext: str) -> list[Path]: + return sorted(path for path in directory.glob(f"*{ext}") if path.is_file()) + + +def _ensure_dir(path: str | os.PathLike[str]) -> Path: + directory = Path(path) + directory.mkdir(parents=True, exist_ok=True) + return directory + + +def _copy_atomic(source: Path, destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + if destination.exists(): + return + + tmp_path = destination.with_name(f".{destination.name}.tmp-{os.getpid()}-{threading.get_ident()}") + + try: + shutil.copy2(source, tmp_path) + os.replace(tmp_path, destination) + finally: + if tmp_path.exists(): + tmp_path.unlink(missing_ok=True) + + +def _extract_path_from_output( + completed: subprocess.CompletedProcess[str], *, expected_name: str = CORRIDORKEY_MLX_FILENAME +) -> Path | None: + for stream in (completed.stdout, completed.stderr): + lines = [line.strip() for line in stream.splitlines() if line.strip()] + for line in reversed(lines): + if not _looks_like_path(line): + continue + candidate = Path(line).expanduser() + if candidate.name == expected_name: + return candidate + return None + + +def 
_looks_like_path(line: str) -> bool: + return line.startswith(("/", "~", "./", "../")) or (len(line) >= 3 and line[1] == ":" and line[2] in ("\\", "/")) + + +def _has_weight_file(directory: Path) -> bool: + if not directory.exists(): + return False + return any( + path.is_file() and path.suffix in {".bin", ".pt", ".pth", ".safetensors"} for path in directory.rglob("*") + ) + + +def _gvm_weights_ready(weights_dir: Path) -> bool: + return ( + _has_weight_file(weights_dir / "vae") + and (weights_dir / "scheduler" / "scheduler_config.json").is_file() + and _has_weight_file(weights_dir / "unet") + ) + + +def _videomama_base_ready(base_dir: Path) -> bool: + return ( + (base_dir / "feature_extractor" / "preprocessor_config.json").is_file() + and _has_weight_file(base_dir / "image_encoder") + and _has_weight_file(base_dir / "vae") + ) + + +def _videomama_unet_ready(unet_dir: Path) -> bool: + return _has_weight_file(unet_dir / "unet") + + +def _get_lock(name: str) -> threading.Lock: + with _LOCKS_GUARD: + lock = _LOCKS.get(name) + if lock is None: + lock = threading.Lock() + _LOCKS[name] = lock + return lock diff --git a/VideoMaMaInferenceModule/inference.py b/VideoMaMaInferenceModule/inference.py index 4032db5e..93106360 100644 --- a/VideoMaMaInferenceModule/inference.py +++ b/VideoMaMaInferenceModule/inference.py @@ -12,6 +12,8 @@ from typing import List, Union, Optional from pathlib import Path +from CorridorKeyModule.model_assets import ensure_videomama_weights + # Add current directory to path so that pipeline.py's intra-package imports # (e.g. "from pipeline import ...") resolve when this module is imported from # outside the VideoMaMaInferenceModule directory. This is a workaround for the @@ -37,12 +39,13 @@ def load_videomama_model(base_model_path: Optional[str] = None, unet_checkpoint_ Returns: VideoInferencePipeline: Loaded pipeline instance. 
""" - # Default to local checkpoints if not provided - if base_model_path is None: - base_model_path = os.path.join(current_dir, "checkpoints", "stable-video-diffusion-img2vid-xt") - - if unet_checkpoint_path is None: - unet_checkpoint_path = os.path.join(current_dir, "checkpoints", "VideoMaMa") + checkpoints_dir = os.path.join(current_dir, "checkpoints") + if base_model_path is None or unet_checkpoint_path is None: + ensured_base_path, ensured_unet_path = ensure_videomama_weights(checkpoints_dir) + if base_model_path is None: + base_model_path = str(ensured_base_path) + if unet_checkpoint_path is None: + unet_checkpoint_path = str(ensured_unet_path) print(f"Loading Base model from {base_model_path}...") print(f"Loading VideoMaMa UNet from {unet_checkpoint_path}...") @@ -198,4 +201,3 @@ def save_video(frames: List[np.ndarray], output_path: str, fps: float): out.release() print(f"Saved video to {output_path}") - diff --git a/backend/service.py b/backend/service.py index 0d341876..b78c433e 100644 --- a/backend/service.py +++ b/backend/service.py @@ -289,19 +289,11 @@ def _get_engine(self): return self._engine try: - from CorridorKeyModule.backend import TORCH_EXT, _discover_checkpoint - from CorridorKeyModule.inference_engine import CorridorKeyEngine + from CorridorKeyModule.backend import create_engine except ImportError as exc: raise RuntimeError("CorridorKeyModule is not installed. 
Run: uv sync") from exc - - ckpt_path = _discover_checkpoint(TORCH_EXT) - logger.info(f"Loading checkpoint: {os.path.basename(ckpt_path)}") t0 = time.monotonic() - self._engine = CorridorKeyEngine( - checkpoint_path=ckpt_path, - device=self._device, - img_size=2048, - ) + self._engine = create_engine(backend="torch", device=self._device, img_size=2048) logger.info(f"Engine loaded in {time.monotonic() - t0:.1f}s") return self._engine diff --git a/gvm_core/wrapper.py b/gvm_core/wrapper.py index 66d6217f..3f153303 100644 --- a/gvm_core/wrapper.py +++ b/gvm_core/wrapper.py @@ -16,6 +16,8 @@ from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler from tqdm import tqdm +from CorridorKeyModule.model_assets import ensure_gvm_weights + # Relative imports from the internal gvm package # Assuming this file is inside gvm_core/ from .gvm.pipelines.pipeline_gvm import GVMPipeline @@ -70,6 +72,7 @@ def __init__(self, # Resolve default weights path relative to this file if model_base is None: model_base = osp.join(osp.dirname(__file__), "weights") + ensure_gvm_weights(model_base) self.model_base = model_base self.unet_base = unet_base diff --git a/tests/test_backend.py b/tests/test_backend.py index 7276ff61..1c61f866 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3,6 +3,8 @@ import errno import logging import os +import sys +from types import ModuleType from unittest import mock import numpy as np @@ -17,6 +19,7 @@ _discover_checkpoint, _ensure_torch_checkpoint, _wrap_mlx_output, + create_engine, resolve_backend, ) @@ -203,6 +206,50 @@ def test_logging_on_download(self, tmp_path, caplog): assert any("saved" in msg.lower() for msg in caplog.messages) +class TestCreateEngineBootstrap: + def test_auto_backend_bootstraps_assets_before_engine_load(self, tmp_path): + ensured_kwargs = {} + + def fake_ensure(**kwargs): + ensured_kwargs.update(kwargs) + (tmp_path / "model.pth").touch() + + inference_mod = 
ModuleType("CorridorKeyModule.inference_engine") + + class DummyEngine: + def __init__( + self, + checkpoint_path: str, + device: str, + img_size: int, + model_precision=None, + ): + self.checkpoint_path = checkpoint_path + self.device = device + self.img_size = img_size + self.model_precision = model_precision + + inference_mod.CorridorKeyEngine = DummyEngine + + with ( + mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), + mock.patch("CorridorKeyModule.backend.ensure_corridorkey_assets", side_effect=fake_ensure), + mock.patch("CorridorKeyModule.backend.resolve_backend", return_value="torch"), + mock.patch.dict(sys.modules, {"CorridorKeyModule.inference_engine": inference_mod}), + ): + engine = create_engine(backend="auto", device="cpu", img_size=1024) + + assert ensured_kwargs == { + "ensure_torch": True, + "ensure_mlx": False, + "download_mlx_if_available": True, + "checkpoint_dir": str(tmp_path), + } + assert engine.checkpoint_path == str(tmp_path / "model.pth") + assert engine.device == "cpu" + assert engine.img_size == 1024 + + # --- _wrap_mlx_output --- diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index 243b72bf..3e26b988 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -13,6 +13,8 @@ from __future__ import annotations +from unittest import mock + import numpy as np import pytest import torch @@ -45,6 +47,37 @@ def _make_engine_with_mock(mock_greenformer, img_size=64, device="cpu"): return engine +# --------------------------------------------------------------------------- +# __init__ regression coverage +# --------------------------------------------------------------------------- + + +class _InitDummyModel: + def __init__(self): + self.to_calls = [] + + def to(self, arg): + self.to_calls.append(arg) + return self + + +class TestEngineInitialization: + def test_non_compiled_platform_still_sets_model(self): + """macOS/eager path must still assign self.model before first 
inference.""" + from CorridorKeyModule.inference_engine import CorridorKeyEngine + + dummy_model = _InitDummyModel() + + with ( + mock.patch("CorridorKeyModule.inference_engine.CorridorKeyEngine._load_model", return_value=dummy_model), + mock.patch("CorridorKeyModule.inference_engine.sys.platform", "darwin"), + ): + engine = CorridorKeyEngine(checkpoint_path="/fake/checkpoint.pth", device="cpu") + + assert engine.model is dummy_model + assert dummy_model.to_calls == [torch.float32] + + # --------------------------------------------------------------------------- # process_frame output structure # --------------------------------------------------------------------------- diff --git a/tests/test_model_assets.py b/tests/test_model_assets.py new file mode 100644 index 00000000..6ecb6fbb --- /dev/null +++ b/tests/test_model_assets.py @@ -0,0 +1,192 @@ +"""Tests for runtime model asset bootstrap/download helpers.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path +from unittest import mock + +from CorridorKeyModule import model_assets as assets + + +def _write_text(path: Path, text: str = "x") -> Path: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(text) + return path + + +class TestEnsureCorridorKeyAssets: + def test_empty_checkpoint_dir_downloads_torch_and_mlx_when_available(self, tmp_path): + torch_downloads: list[Path] = [] + mlx_downloads: list[Path] = [] + + def fake_torch_download(checkpoint_dir: Path) -> Path: + torch_downloads.append(checkpoint_dir) + return _write_text(checkpoint_dir / assets.CORRIDORKEY_TORCH_FILENAME) + + def fake_mlx_download(checkpoint_dir: Path) -> Path: + mlx_downloads.append(checkpoint_dir) + return _write_text(checkpoint_dir / assets.CORRIDORKEY_MLX_FILENAME) + + with ( + mock.patch.object(assets, "_download_corridorkey_torch", side_effect=fake_torch_download), + mock.patch.object(assets, "_download_corridorkey_mlx", side_effect=fake_mlx_download), + mock.patch.object(assets, 
"mlx_runtime_available", return_value=True), + ): + assets.ensure_corridorkey_assets( + checkpoint_dir=tmp_path, + ensure_torch=True, + ensure_mlx=False, + download_mlx_if_available=True, + ) + + assert torch_downloads == [tmp_path] + assert mlx_downloads == [tmp_path] + assert (tmp_path / assets.CORRIDORKEY_TORCH_FILENAME).is_file() + assert (tmp_path / assets.CORRIDORKEY_MLX_FILENAME).is_file() + + def test_existing_torch_checkpoint_skips_opportunistic_mlx_download(self, tmp_path): + _write_text(tmp_path / "existing_model.pth") + + with ( + mock.patch.object(assets, "_download_corridorkey_torch") as mock_torch, + mock.patch.object(assets, "_download_corridorkey_mlx") as mock_mlx, + mock.patch.object(assets, "mlx_runtime_available", return_value=True), + ): + assets.ensure_corridorkey_assets( + checkpoint_dir=tmp_path, + ensure_torch=True, + ensure_mlx=False, + download_mlx_if_available=True, + ) + + mock_torch.assert_not_called() + mock_mlx.assert_not_called() + + def test_explicit_mlx_request_downloads_mlx_even_when_torch_exists(self, tmp_path): + _write_text(tmp_path / "existing_model.pth") + + with mock.patch.object( + assets, + "_download_corridorkey_mlx", + side_effect=lambda checkpoint_dir: _write_text(checkpoint_dir / assets.CORRIDORKEY_MLX_FILENAME), + ) as mock_mlx: + assets.ensure_corridorkey_assets( + checkpoint_dir=tmp_path, + ensure_torch=False, + ensure_mlx=True, + download_mlx_if_available=False, + ) + + mock_mlx.assert_called_once_with(tmp_path) + assert (tmp_path / assets.CORRIDORKEY_MLX_FILENAME).is_file() + + +class TestDownloadCorridorKeyMlx: + def test_cli_download_uses_expected_repo_override_and_release_tag(self, tmp_path): + cached = _write_text(tmp_path / "cache" / assets.CORRIDORKEY_MLX_FILENAME) + + completed = subprocess.CompletedProcess( + args=[], + returncode=0, + stdout=str(cached), + stderr="", + ) + + with ( + mock.patch.object(assets, "mlx_runtime_available", return_value=True), + mock.patch.object(assets, "subprocess") as 
mock_subprocess, + ): + mock_subprocess.run.return_value = completed + result = assets._download_corridorkey_mlx(tmp_path) + + assert result == tmp_path / assets.CORRIDORKEY_MLX_FILENAME + assert result.is_file() + + _, kwargs = mock_subprocess.run.call_args + command = kwargs["args"] if "args" in kwargs else mock_subprocess.run.call_args.args[0] + env = kwargs["env"] + + assert "--tag" in command + assert assets.CORRIDORKEY_MLX_TAG in command + assert env["CORRIDORKEY_MLX_WEIGHTS_REPO"] == assets.CORRIDORKEY_MLX_REPO + + def test_cli_failure_falls_back_to_direct_release_download(self, tmp_path): + with ( + mock.patch.object(assets, "mlx_runtime_available", return_value=True), + mock.patch.object( + assets.subprocess, + "run", + side_effect=subprocess.CalledProcessError(1, ["corridorkey_mlx"]), + ), + mock.patch.object(assets.urllib.request, "urlretrieve") as mock_urlretrieve, + ): + mock_urlretrieve.side_effect = lambda url, dest: Path(dest).write_bytes(b"mlx-weights") + result = assets._download_corridorkey_mlx(tmp_path) + + assert result == tmp_path / assets.CORRIDORKEY_MLX_FILENAME + assert result.read_bytes() == b"mlx-weights" + download_url = mock_urlretrieve.call_args.args[0] + assert assets.CORRIDORKEY_MLX_REPO in download_url + assert assets.CORRIDORKEY_MLX_TAG in download_url + assert download_url.endswith(f"/{assets.CORRIDORKEY_MLX_FILENAME}") + + +class TestExtractPathFromOutput: + def test_returns_bare_expected_path(self, tmp_path): + expected = _write_text(tmp_path / assets.CORRIDORKEY_MLX_FILENAME) + completed = subprocess.CompletedProcess( + args=[], + returncode=0, + stdout=f"Downloading...\n{expected}\n", + stderr="", + ) + + assert assets._extract_path_from_output(completed) == expected + + def test_ignores_non_path_logs_and_other_safetensors_files(self, tmp_path): + other = _write_text(tmp_path / "cache" / "other_model.safetensors") + expected = _write_text(tmp_path / "cache" / assets.CORRIDORKEY_MLX_FILENAME) + completed = 
subprocess.CompletedProcess( + args=[], + returncode=0, + stdout=f"{other}\ncache hit for {expected}\n", + stderr=f"{expected}\n", + ) + + assert assets._extract_path_from_output(completed) == expected + + +class TestEnsureOptionalStepWeights: + def test_gvm_weights_download_once_and_reuse(self, tmp_path): + def fake_snapshot_download(*, local_dir: str, **kwargs): + weights_dir = Path(local_dir) + _write_text(weights_dir / "vae" / "diffusion_pytorch_model.safetensors") + _write_text(weights_dir / "scheduler" / "scheduler_config.json", "{}") + _write_text(weights_dir / "unet" / "diffusion_pytorch_model.safetensors") + return str(weights_dir) + + with mock.patch.object(assets, "snapshot_download", side_effect=fake_snapshot_download) as mock_snapshot: + assets.ensure_gvm_weights(tmp_path) + assets.ensure_gvm_weights(tmp_path) + + assert mock_snapshot.call_count == 1 + + def test_videomama_weights_download_base_and_unet_once(self, tmp_path): + def fake_snapshot_download(*, repo_id: str, local_dir: str, **kwargs): + target_dir = Path(local_dir) + if repo_id == assets.VIDEOMAMA_BASE_REPO_ID: + _write_text(target_dir / "feature_extractor" / "preprocessor_config.json", "{}") + _write_text(target_dir / "image_encoder" / "model.safetensors") + _write_text(target_dir / "vae" / "diffusion_pytorch_model.safetensors") + elif repo_id == assets.VIDEOMAMA_REPO_ID: + _write_text(target_dir / "unet" / "diffusion_pytorch_model.safetensors") + else: + raise AssertionError(f"Unexpected repo_id: {repo_id}") + return str(target_dir) + + with mock.patch.object(assets, "snapshot_download", side_effect=fake_snapshot_download) as mock_snapshot: + assets.ensure_videomama_weights(tmp_path) + assets.ensure_videomama_weights(tmp_path) + + assert mock_snapshot.call_count == 2