Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions scripts/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import random
from argparse import ArgumentParser
import torchio as tio
import pathlib
import platform
from pathlib import Path

from keymorph.model import KeyMorph
Expand Down Expand Up @@ -375,7 +377,7 @@ def get_model(args):
return registration_model


if __name__ == "__main__":
if __name__ == "__main__":
args = parse_args()

# Select GPU
Expand Down Expand Up @@ -438,7 +440,13 @@ def get_model(args):
# Model
registration_model = get_model(args)
registration_model.eval()


# pathlib support for windows OS - use WindowsPath as PosixPath if on windows to unpickle the weights
posixpath_backup = pathlib.PosixPath
plt = platform.system()
if plt == 'Windows':
pathlib.PosixPath = pathlib.WindowsPath

# Checkpoint loading
if args.registration_model == "keymorph":
args.load_path = get_foundation_weights_path(
Expand Down Expand Up @@ -476,3 +484,4 @@ def get_model(args):
args,
save_dir_prefix="",
)
pathlib.PosixPath = posixpath_backup # set PosixPath back to whatever it was before