diff --git a/README.md b/README.md
index 737c19a..ca4932a 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ To run the code, you need to download the pre-trained models from [here](https:/
 
     python3 eval.py -i image_file -s semantic_file
 
     # example
-    python3 eval.py -i ./sample_images/img1.jpg -s ./sample_images/img1.npy
+    python3 eval.py -i ./sample_images/img1.jpg -s ./sample_images/sem1.npy
 
 The generated scanpaths are saved in the **results** folder.
diff --git a/components.py b/components.py
index ca49cc0..9678093 100644
--- a/components.py
+++ b/components.py
@@ -23,8 +23,10 @@
 ).to(device)
 fix_duration = FixationDuration(512, 128).to(device)
 
-feature_extractor.load_state_dict(torch.load('./data/weights/vgg19.pth'))
-fuser.load_state_dict(torch.load('./data/weights/fuser.pth'))
-iorroi_lstm.load_state_dict(torch.load('./data/weights/iorroi.pth'))
-mdn.load_state_dict(torch.load('./data/weights/mdn.pth'))
-fix_duration.load_state_dict(torch.load('./data/weights/fix_duration.pth'))
+location_mapper = None if torch.cuda.is_available() else device
+
+feature_extractor.load_state_dict(torch.load('./data/weights/vgg19.pth', map_location=location_mapper))
+fuser.load_state_dict(torch.load('./data/weights/fuser.pth', map_location=location_mapper))
+iorroi_lstm.load_state_dict(torch.load('./data/weights/iorroi.pth', map_location=location_mapper))
+mdn.load_state_dict(torch.load('./data/weights/mdn.pth', map_location=location_mapper))
+fix_duration.load_state_dict(torch.load('./data/weights/fix_duration.pth', map_location=location_mapper))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..801c858
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+matplotlib
+torch
+torchvision
+numpy
+Pillow
+scipy