def _migrate_pth_to_safetensors(src: Path) -> Path:
    """Convert a legacy ``.pth`` checkpoint to ``.safetensors`` in-place.

    The converted file is written alongside the original (same stem,
    ``.safetensors`` suffix); the legacy ``.pth`` is deleted only *after*
    the new file has been saved, so an interrupted migration never loses
    the original weights.

    Handles both plain state-dicts and checkpoints that wrap the weights
    under a ``"state_dict"`` key, and drops non-tensor entries with a
    warning — mirroring the standalone ``convert_checkpoint.py`` tool so
    the two conversion paths behave identically.

    Args:
        src: Path to the existing ``.pth`` checkpoint.

    Returns:
        Path to the newly created ``.safetensors`` file.

    Raises:
        TypeError: If the loaded checkpoint is not a dict.
    """
    import torch
    from safetensors.torch import save_file

    dst = src.with_suffix(".safetensors")
    logger.info("Migrating %s → %s ...", src.name, dst.name)

    # weights_only=True refuses arbitrary pickled objects (security).
    checkpoint = torch.load(src, map_location="cpu", weights_only=True)
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a dict checkpoint in {src.name}, got {type(checkpoint).__name__}"
        )
    state_dict = checkpoint.get("state_dict", checkpoint)

    # safetensors can only serialise tensors; skip anything else rather
    # than crashing on v.contiguous() (consistent with convert_checkpoint.py).
    non_tensor = [k for k, v in state_dict.items() if not hasattr(v, "contiguous")]
    if non_tensor:
        logger.warning("Skipping non-tensor checkpoint keys: %s", non_tensor)

    # save_file requires contiguous tensors.
    tensors = {k: v.contiguous() for k, v in state_dict.items() if hasattr(v, "contiguous")}
    save_file(tensors, dst)
    src.unlink()
    logger.info("Migration complete. Legacy file removed: %s", src.name)
    return dst
file mode 100644 index 00000000..90603c6f --- /dev/null +++ b/convert_checkpoint.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +"""Convert .pth / .pt checkpoints to the safetensors format. + +Usage +----- +# Convert all checkpoints in the default directory: +python convert_checkpoint.py + +# Convert specific files or directories: +python convert_checkpoint.py path/to/model.pth path/to/lora_dir/ + +# Convert and remove the originals after success: +python convert_checkpoint.py --delete-original +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +DEFAULT_CHECKPOINT_DIR = Path("CorridorKeyModule/checkpoints") + + +def _convert(src: Path) -> Path: + """Load *src* (.pth / .pt) and save as a .safetensors file next to it. + + Handles both plain state-dicts and checkpoints that wrap the state-dict + under a ``"state_dict"`` key (e.g. checkpoints saved by PyTorch Lightning + or the CorridorKey training harness). + + Returns the path of the newly created .safetensors file. 
+ """ + import torch + from safetensors.torch import save_file + + print(f" Loading {src} …") + checkpoint = torch.load(src, map_location="cpu", weights_only=True) + state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint + + non_tensor = [k for k, v in state_dict.items() if not hasattr(v, "contiguous")] + if non_tensor: + print(f" Warning: skipping non-tensor keys: {non_tensor}") + state_dict = {k: v for k, v in state_dict.items() if hasattr(v, "contiguous")} + + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + + dst = src.with_suffix(".safetensors") + print(f" Saving {dst} ({len(state_dict)} tensors) …") + save_file(state_dict, dst) + return dst + + +def convert_file(src: Path, *, delete_original: bool = False) -> Path: + if not src.exists(): + raise FileNotFoundError(f"Not found: {src}") + if src.suffix not in (".pth", ".pt"): + raise ValueError(f"Expected .pth or .pt, got: {src}") + + dst = _convert(src) + + if delete_original: + src.unlink() + print(f" Deleted {src}") + + print(f" Done: {dst}\n") + return dst + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert .pth/.pt checkpoints to safetensors.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "paths", + nargs="*", + help="Files or directories to convert. 
Defaults to CorridorKeyModule/checkpoints/.", + ) + parser.add_argument( + "--delete-original", + action="store_true", + help="Remove the original .pth/.pt file after a successful conversion.", + ) + args = parser.parse_args() + + targets: list[Path] = [] + search_roots = [Path(p) for p in args.paths] if args.paths else [DEFAULT_CHECKPOINT_DIR] + + for root in search_roots: + if root.is_file(): + targets.append(root) + elif root.is_dir(): + found = list(root.glob("*.pth")) + list(root.glob("*.pt")) + if not found: + print(f"No .pth/.pt files found in {root}") + targets.extend(found) + else: + print(f"Warning: path not found — {root}", file=sys.stderr) + + if not targets: + print("Nothing to convert.") + sys.exit(0) + + errors: list[tuple[Path, Exception]] = [] + for src in targets: + print(f"Converting {src.name}:") + try: + convert_file(src, delete_original=args.delete_original) + except Exception as exc: # noqa: BLE001 + print(f" ERROR: {exc}\n", file=sys.stderr) + errors.append((src, exc)) + + if errors: + print(f"\n{len(errors)} file(s) failed to convert.", file=sys.stderr) + sys.exit(1) + + print(f"Converted {len(targets)} file(s).") + + +if __name__ == "__main__": + main() diff --git a/gvm_core/gvm/pipelines/pipeline_gvm.py b/gvm_core/gvm/pipelines/pipeline_gvm.py index 814e85d5..41e9c82b 100644 --- a/gvm_core/gvm/pipelines/pipeline_gvm.py +++ b/gvm_core/gvm/pipelines/pipeline_gvm.py @@ -36,8 +36,14 @@ def load_lora_weights( ): unet_lora_config = LoraConfig.from_pretrained(pretrained_model_name_or_path_or_dict) - checkpoint = os.path.join(pretrained_model_name_or_path_or_dict, f"pytorch_lora_weights.pt") - unet_lora_ckpt = torch.load(checkpoint) + checkpoint_st = os.path.join(pretrained_model_name_or_path_or_dict, "pytorch_lora_weights.safetensors") + checkpoint_pt = os.path.join(pretrained_model_name_or_path_or_dict, "pytorch_lora_weights.pt") + if os.path.exists(checkpoint_st): + from safetensors.torch import load_file as _st_load + + unet_lora_ckpt = 
_st_load(checkpoint_st) + else: + unet_lora_ckpt = torch.load(checkpoint_pt) self.unet = LoraModel(self.unet, unet_lora_config, "default") set_peft_model_state_dict(self.unet, unet_lora_ckpt) diff --git a/pyproject.toml b/pyproject.toml index f4f76372..b3c21751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "einops", # BiRefNet alpha hint generator "kornia", + # Safetensors — secure, fast tensor serialisation (replaces torch.load for .safetensors checkpoints) + "safetensors", # CLI tools (huggingface-hub is also a transitive dep, but must be direct # so that uv installs the "hf" console-script entry point) "huggingface-hub", diff --git a/tests/test_backend.py b/tests/test_backend.py index 7276ff61..d8f2719a 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -13,9 +13,11 @@ HF_CHECKPOINT_FILENAME, HF_REPO_ID, MLX_EXT, + MLX_MODEL_FILENAME, TORCH_EXT, _discover_checkpoint, _ensure_torch_checkpoint, + _migrate_pth_to_safetensors, _wrap_mlx_output, resolve_backend, ) @@ -154,6 +156,44 @@ def test_skip_when_present(self, tmp_path): assert result == ckpt mock_dl.assert_not_called() + def test_safetensors_preferred_over_pth(self, tmp_path): + """When both .safetensors and .pth are present, .safetensors is returned.""" + (tmp_path / "model.pth").write_bytes(b"legacy") + st = tmp_path / "model.safetensors" + st.write_bytes(b"converted") + with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)): + with mock.patch("huggingface_hub.hf_hub_download") as mock_dl: + result = _discover_checkpoint(TORCH_EXT) + assert result == st + mock_dl.assert_not_called() + + def test_safetensors_torch_found_directly(self, tmp_path): + """A lone .safetensors (non-MLX) is returned for TORCH_EXT discovery.""" + st = tmp_path / "CorridorKey.safetensors" + st.write_bytes(b"converted") + with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)): + result = _discover_checkpoint(TORCH_EXT) + assert result == st + + 
class TestMigratePthToSafetensors:
    """Behavioural checks for the one-shot .pth → .safetensors migration."""

    def test_converts_and_deletes_pth(self, tmp_path):
        """Successful migration saves .safetensors and removes the .pth."""
        import torch
        from safetensors.torch import load_file

        legacy = tmp_path / "model.pth"
        torch.save({"state_dict": {"weight": torch.zeros(4, 4)}}, legacy)

        converted = _migrate_pth_to_safetensors(legacy)

        assert converted == legacy.with_suffix(".safetensors")
        assert converted.exists()
        assert not legacy.exists()
        assert "weight" in load_file(converted)

    def test_flat_state_dict(self, tmp_path):
        """Handles a .pth that is a plain flat dict (no 'state_dict' wrapper)."""
        import torch

        legacy = tmp_path / "flat.pth"
        torch.save({"bias": torch.ones(3)}, legacy)

        assert _migrate_pth_to_safetensors(legacy).exists()

    def test_returns_safetensors_path(self, tmp_path):
        import torch

        legacy = tmp_path / "ck.pth"
        torch.save({"w": torch.eye(2)}, legacy)

        assert _migrate_pth_to_safetensors(legacy).suffix == ".safetensors"