Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions CorridorKeyModule/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
TORCH_EXT = ".pth"
MLX_EXT = ".safetensors"
TORCH_SAFETENSORS_EXT = ".safetensors"
DEFAULT_IMG_SIZE = 2048

BACKEND_ENV_VAR = "CORRIDORKEY_BACKEND"
Expand Down Expand Up @@ -163,9 +164,28 @@ def _ensure_torch_checkpoint() -> Path:
def _discover_checkpoint(ext: str) -> Path:
"""Find exactly one checkpoint with the given extension.

For the Torch backend (TORCH_EXT), safetensors files are checked first
(excluding the MLX model by name) so that converted checkpoints are
preferred over legacy .pth files without any manual rename step.

Raises FileNotFoundError (0 found) or ValueError (>1 found).
Includes cross-reference hints when wrong extension files exist.
"""
if ext == TORCH_EXT:
# Prefer .safetensors over .pth — same dir is shared with MLX, so
# exclude the known MLX filename to avoid ambiguity.
st_matches = [
m
for m in glob.glob(os.path.join(CHECKPOINT_DIR, f"*{TORCH_SAFETENSORS_EXT}"))
if os.path.basename(m) != MLX_MODEL_FILENAME
]
if len(st_matches) == 1:
return Path(st_matches[0])
if len(st_matches) > 1:
names = [os.path.basename(f) for f in st_matches]
raise ValueError(f"Multiple .safetensors Torch checkpoints in {CHECKPOINT_DIR}: {names}. Keep exactly one.")
# Fall through to .pth discovery below.

matches = glob.glob(os.path.join(CHECKPOINT_DIR, f"*{ext}"))

if len(matches) == 0:
Expand All @@ -186,6 +206,26 @@ def _discover_checkpoint(ext: str) -> Path:
return Path(matches[0])


def _migrate_pth_to_safetensors(src: Path) -> Path:
    """Convert a .pth checkpoint to .safetensors in-place.

    Saves the converted file alongside the original, then deletes the .pth.
    Handles both plain state-dicts and checkpoints that wrap the state-dict
    under a ``"state_dict"`` key, and drops non-tensor entries (step counters,
    optimizer state, ...) that safetensors cannot serialise — mirroring the
    behaviour of the standalone ``convert_checkpoint.py`` tool so the two
    migration paths cannot diverge.

    Args:
        src: Path to the legacy ``.pth`` checkpoint.

    Returns:
        Path to the new ``.safetensors`` file (same stem as *src*).

    Raises:
        Whatever ``torch.load`` / ``safetensors`` raise on a corrupt file; in
        that case the original .pth is left untouched (we write first, delete
        last).
    """
    import torch
    from safetensors.torch import save_file

    dst = src.with_suffix(".safetensors")
    logger.info("Migrating %s → %s ...", src.name, dst.name)
    # weights_only=True blocks arbitrary-code execution from pickled payloads.
    checkpoint = torch.load(src, map_location="cpu", weights_only=True)
    # Unwrap Lightning-style checkpoints; guard .get() — torch.load may return
    # a bare state-dict or even a non-dict object.
    state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    # safetensors requires contiguous tensors; skip any non-tensor values.
    state_dict = {k: v.contiguous() for k, v in state_dict.items() if hasattr(v, "contiguous")}
    # Write the new file first so a failed save never destroys the original.
    save_file(state_dict, dst)
    src.unlink()
    logger.info("Migration complete. Legacy file removed: %s", src.name)
    return dst


def _wrap_mlx_output(raw: dict, despill_strength: float, auto_despeckle: bool, despeckle_size: int) -> dict:
"""Normalize MLX uint8 output to match Torch float32 contract.

Expand Down Expand Up @@ -313,6 +353,8 @@ def create_engine(
return _MLXEngineAdapter(raw_engine)
else:
ckpt = _discover_checkpoint(TORCH_EXT)
if ckpt.suffix == ".pth":
ckpt = _migrate_pth_to_safetensors(ckpt)
from CorridorKeyModule.inference_engine import CorridorKeyEngine

logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device)
Expand Down
9 changes: 7 additions & 2 deletions CorridorKeyModule/inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,13 @@ def _load_model(self) -> GreenFormer:
if not os.path.isfile(self.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")

checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)
state_dict = checkpoint.get("state_dict", checkpoint)
if self.checkpoint_path.endswith(".safetensors"):
from safetensors.torch import load_file as _st_load

state_dict = _st_load(self.checkpoint_path, device=str(self.device))
else:
checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)
state_dict = checkpoint.get("state_dict", checkpoint)

# Fix Compiled Model Prefix & Handle PosEmbed Mismatch
new_state_dict = {}
Expand Down
123 changes: 123 additions & 0 deletions convert_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
"""Convert .pth / .pt checkpoints to the safetensors format.

Usage
-----
# Convert all checkpoints in the default directory:
python convert_checkpoint.py

# Convert specific files or directories:
python convert_checkpoint.py path/to/model.pth path/to/lora_dir/

# Convert and remove the originals after success:
python convert_checkpoint.py --delete-original
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

DEFAULT_CHECKPOINT_DIR = Path("CorridorKeyModule/checkpoints")


def _convert(src: Path) -> Path:
    """Load *src* (.pth / .pt) and write a .safetensors file beside it.

    Accepts either a plain state-dict or a checkpoint that nests the
    state-dict under a ``"state_dict"`` key (PyTorch Lightning / the
    CorridorKey training harness both do this).

    Returns the path of the freshly written .safetensors file.
    """
    import torch
    from safetensors.torch import save_file

    print(f" Loading {src} …")
    raw = torch.load(src, map_location="cpu", weights_only=True)
    tensors = raw.get("state_dict", raw) if isinstance(raw, dict) else raw

    # safetensors can only store tensors — warn about and drop anything else.
    non_tensor = [key for key, val in tensors.items() if not hasattr(val, "contiguous")]
    if non_tensor:
        print(f" Warning: skipping non-tensor keys: {non_tensor}")
        tensors = {key: val for key, val in tensors.items() if hasattr(val, "contiguous")}

    # save_file requires contiguous memory layouts.
    tensors = {key: val.contiguous() for key, val in tensors.items()}

    dst = src.with_suffix(".safetensors")
    print(f" Saving {dst} ({len(tensors)} tensors) …")
    save_file(tensors, dst)
    return dst


def convert_file(src: Path, *, delete_original: bool = False) -> Path:
    """Validate *src*, convert it to safetensors, optionally delete the original.

    Raises FileNotFoundError when *src* is missing and ValueError when its
    suffix is not .pth / .pt. Returns the path of the converted file.
    """
    # Guard clauses: existence first, then extension (error priority matters).
    if not src.exists():
        raise FileNotFoundError(f"Not found: {src}")
    if src.suffix not in (".pth", ".pt"):
        raise ValueError(f"Expected .pth or .pt, got: {src}")

    converted = _convert(src)

    if delete_original:
        src.unlink()
        print(f" Deleted {src}")

    print(f" Done: {converted}\n")
    return converted


def main() -> None:
    """CLI entry point: collect targets, convert each, report failures."""
    parser = argparse.ArgumentParser(
        description="Convert .pth/.pt checkpoints to safetensors.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "paths",
        nargs="*",
        help="Files or directories to convert. Defaults to CorridorKeyModule/checkpoints/.",
    )
    parser.add_argument(
        "--delete-original",
        action="store_true",
        help="Remove the original .pth/.pt file after a successful conversion.",
    )
    args = parser.parse_args()

    # Empty positional list falls back to the default checkpoint directory.
    roots = [Path(p) for p in args.paths] or [DEFAULT_CHECKPOINT_DIR]

    targets: list[Path] = []
    for root in roots:
        if root.is_file():
            targets.append(root)
        elif root.is_dir():
            # .pth first, then .pt — keeps the original conversion order.
            found = list(root.glob("*.pth")) + list(root.glob("*.pt"))
            if not found:
                print(f"No .pth/.pt files found in {root}")
            targets.extend(found)
        else:
            print(f"Warning: path not found — {root}", file=sys.stderr)

    if not targets:
        print("Nothing to convert.")
        sys.exit(0)

    # Convert everything, collecting failures so one bad file doesn't abort the run.
    errors: list[tuple[Path, Exception]] = []
    for src in targets:
        print(f"Converting {src.name}:")
        try:
            convert_file(src, delete_original=args.delete_original)
        except Exception as exc:  # noqa: BLE001
            print(f" ERROR: {exc}\n", file=sys.stderr)
            errors.append((src, exc))

    if errors:
        print(f"\n{len(errors)} file(s) failed to convert.", file=sys.stderr)
        sys.exit(1)

    print(f"Converted {len(targets)} file(s).")


if __name__ == "__main__":
    main()
10 changes: 8 additions & 2 deletions gvm_core/gvm/pipelines/pipeline_gvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ def load_lora_weights(
):

unet_lora_config = LoraConfig.from_pretrained(pretrained_model_name_or_path_or_dict)
checkpoint = os.path.join(pretrained_model_name_or_path_or_dict, f"pytorch_lora_weights.pt")
unet_lora_ckpt = torch.load(checkpoint)
checkpoint_st = os.path.join(pretrained_model_name_or_path_or_dict, "pytorch_lora_weights.safetensors")
checkpoint_pt = os.path.join(pretrained_model_name_or_path_or_dict, "pytorch_lora_weights.pt")
if os.path.exists(checkpoint_st):
from safetensors.torch import load_file as _st_load

unet_lora_ckpt = _st_load(checkpoint_st)
else:
unet_lora_ckpt = torch.load(checkpoint_pt)
self.unet = LoraModel(self.unet, unet_lora_config, "default")
set_peft_model_state_dict(self.unet, unet_lora_ckpt)

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies = [
"einops",
# BiRefNet alpha hint generator
"kornia",
# Safetensors — secure, fast tensor serialisation (replaces torch.load for .safetensors checkpoints)
"safetensors",
# CLI tools (huggingface-hub is also a transitive dep, but must be direct
# so that uv installs the "hf" console-script entry point)
"huggingface-hub",
Expand Down
79 changes: 79 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
HF_CHECKPOINT_FILENAME,
HF_REPO_ID,
MLX_EXT,
MLX_MODEL_FILENAME,
TORCH_EXT,
_discover_checkpoint,
_ensure_torch_checkpoint,
_migrate_pth_to_safetensors,
_wrap_mlx_output,
resolve_backend,
)
Expand Down Expand Up @@ -154,6 +156,44 @@ def test_skip_when_present(self, tmp_path):
assert result == ckpt
mock_dl.assert_not_called()

def test_safetensors_preferred_over_pth(self, tmp_path):
    """When both .safetensors and .pth are present, .safetensors is returned."""
    legacy = tmp_path / "model.pth"
    legacy.write_bytes(b"legacy")
    converted = tmp_path / "model.safetensors"
    converted.write_bytes(b"converted")
    with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), \
            mock.patch("huggingface_hub.hf_hub_download") as mock_dl:
        found = _discover_checkpoint(TORCH_EXT)
        assert found == converted
        mock_dl.assert_not_called()

def test_safetensors_torch_found_directly(self, tmp_path):
    """A lone .safetensors (non-MLX) is returned for TORCH_EXT discovery."""
    ckpt = tmp_path / "CorridorKey.safetensors"
    ckpt.write_bytes(b"converted")
    with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)):
        assert _discover_checkpoint(TORCH_EXT) == ckpt

def test_mlx_safetensors_excluded_from_torch_discovery(self, tmp_path):
    """The MLX checkpoint is not returned when discovering Torch checkpoints."""
    mlx_ckpt = tmp_path / MLX_MODEL_FILENAME
    mlx_ckpt.write_bytes(b"mlx-weights")
    # No Torch .safetensors → should fall through to .pth, find none, auto-download
    fake_download = mock.patch(
        "CorridorKeyModule.backend._ensure_torch_checkpoint",
        return_value=tmp_path / "CorridorKey.pth",
    )
    with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), \
            fake_download as mock_dl:
        _discover_checkpoint(TORCH_EXT)
        mock_dl.assert_called_once()

def test_multiple_safetensors_torch_raises(self, tmp_path):
    """More than one non-MLX .safetensors raises ValueError."""
    (tmp_path / "a.safetensors").write_bytes(b"x")
    (tmp_path / "b.safetensors").write_bytes(b"y")
    with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)), \
            pytest.raises(ValueError, match="Multiple"):
        _discover_checkpoint(TORCH_EXT)

def test_mlx_not_triggered(self, tmp_path):
"""MLX ext with empty dir raises FileNotFoundError, no download attempted."""
with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)):
Expand Down Expand Up @@ -203,6 +243,45 @@ def test_logging_on_download(self, tmp_path, caplog):
assert any("saved" in msg.lower() for msg in caplog.messages)


# --- _migrate_pth_to_safetensors ---


class TestMigratePthToSafetensors:
    """Behavioural checks for the .pth → .safetensors migration helper."""

    def test_converts_and_deletes_pth(self, tmp_path):
        """Successful migration saves .safetensors and removes the .pth."""
        import torch
        from safetensors.torch import load_file

        src = tmp_path / "model.pth"
        # Wrapped checkpoint, as produced by the training harness.
        torch.save({"state_dict": {"weight": torch.zeros(4, 4)}}, src)

        dst = _migrate_pth_to_safetensors(src)

        assert dst == src.with_suffix(".safetensors")
        assert dst.exists()
        assert not src.exists()
        assert "weight" in load_file(dst)

    def test_flat_state_dict(self, tmp_path):
        """Handles a .pth that is a plain flat dict (no 'state_dict' wrapper)."""
        import torch

        src = tmp_path / "flat.pth"
        torch.save({"bias": torch.ones(3)}, src)
        assert _migrate_pth_to_safetensors(src).exists()

    def test_returns_safetensors_path(self, tmp_path):
        import torch

        src = tmp_path / "ck.pth"
        torch.save({"w": torch.eye(2)}, src)
        assert _migrate_pth_to_safetensors(src).suffix == ".safetensors"


# --- _wrap_mlx_output ---


Expand Down
6 changes: 4 additions & 2 deletions tests/test_pbt_auto_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
# Strategies
# ---------------------------------------------------------------------------

# File extensions that are NOT .pth — used to populate "non-empty but no .pth" dirs
# File extensions that are NOT .pth and NOT .safetensors — used to populate
# "non-empty but no recognised checkpoint" dirs.
# .safetensors is excluded because it is now a valid Torch checkpoint format
# (any non-MLX .safetensors in the checkpoint dir is returned directly).
_non_pth_extensions = st.sampled_from(
[
".txt",
".json",
".safetensors",
".bin",
".onnx",
".csv",
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading