From a05e8bb6492f56d2493d759612962c6f92f7090b Mon Sep 17 00:00:00 2001 From: AlexMili Date: Mon, 31 May 2021 11:33:02 +0200 Subject: [PATCH 1/3] Fix sementic filename in example command --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 737c19a..ca4932a 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ To run the code, you need to download the pre-trained models from [here](https:/ python3 eval.py -i image_file -s semantic_file # example - python3 eval.py -i ./sample_images/img1.jpg -s ./sample_images/img1.npy + python3 eval.py -i ./sample_images/img1.jpg -s ./sample_images/sem1.npy The generated scanpaths are saved in the **results** folder. From d07c3b9978f4d40c6fd33a5df03b190a5fed2fee Mon Sep 17 00:00:00 2001 From: AlexMili Date: Mon, 31 May 2021 11:36:23 +0200 Subject: [PATCH 2/3] Create requirements.txt --- requirements.txt | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..801c858 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +matplotlib +torch +torchvision +numpy +Pillow +scipy From 0ed7122ded4800e232e626e756df8d0185ac7f99 Mon Sep 17 00:00:00 2001 From: AlexMili Date: Mon, 31 May 2021 11:39:58 +0200 Subject: [PATCH 3/3] Fix weigths loading when only CPU is available --- components.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/components.py b/components.py index ca49cc0..9678093 100644 --- a/components.py +++ b/components.py @@ -23,8 +23,10 @@ ).to(device) fix_duration = FixationDuration(512, 128).to(device) -feature_extractor.load_state_dict(torch.load('./data/weights/vgg19.pth')) -fuser.load_state_dict(torch.load('./data/weights/fuser.pth')) -iorroi_lstm.load_state_dict(torch.load('./data/weights/iorroi.pth')) -mdn.load_state_dict(torch.load('./data/weights/mdn.pth')) -fix_duration.load_state_dict(torch.load('./data/weights/fix_duration.pth')) +location_mapper = None if torch.cuda.is_available() else device + +feature_extractor.load_state_dict(torch.load('./data/weights/vgg19.pth', map_location=location_mapper)) +fuser.load_state_dict(torch.load('./data/weights/fuser.pth', map_location=location_mapper)) +iorroi_lstm.load_state_dict(torch.load('./data/weights/iorroi.pth', map_location=location_mapper)) +mdn.load_state_dict(torch.load('./data/weights/mdn.pth', map_location=location_mapper)) +fix_duration.load_state_dict(torch.load('./data/weights/fix_duration.pth', map_location=location_mapper))