diff --git a/scripts/register.py b/scripts/register.py index f680fd7..9649163 100644 --- a/scripts/register.py +++ b/scripts/register.py @@ -6,6 +6,8 @@ import random from argparse import ArgumentParser import torchio as tio +import pathlib +import platform from pathlib import Path from keymorph.model import KeyMorph @@ -375,7 +377,7 @@ def get_model(args): return registration_model -if __name__ == "__main__": +if __name__ == "__main__": args = parse_args() # Select GPU @@ -438,7 +440,13 @@ def get_model(args): # Model registration_model = get_model(args) registration_model.eval() - + + # pathlib support for windows OS - use WindowsPath as PosixPath if on windows to unpickle the weights + posixpath_backup = pathlib.PosixPath + plt = platform.system() + if plt == 'Windows': + pathlib.PosixPath = pathlib.WindowsPath + # Checkpoint loading if args.registration_model == "keymorph": args.load_path = get_foundation_weights_path( @@ -476,3 +484,4 @@ def get_model(args): args, save_dir_prefix="", ) + pathlib.PosixPath = posixpath_backup # set PosixPath back to whatever it was before