diff --git a/cycnn/main.py b/cycnn/main.py index 05ae9d5..a9540c3 100644 --- a/cycnn/main.py +++ b/cycnn/main.py @@ -70,7 +70,7 @@ def validate(model, device, criterion, test_loader, epoch, args): model.eval() validation_loss, correct, num_data = 0, 0, 0 - with torch.no_grad(): + with torch.inference_mode(): for batch_idx, (images, labels) in enumerate(test_loader): if args['dataset'] == 'svhn': @@ -122,7 +122,7 @@ def test(model, device, criterion, test_loader, args): model.eval() test_loss, correct, num_data = 0, 0, 0 - with torch.no_grad(): + with torch.inference_mode(): for batch_idx, (images, labels) in enumerate(test_loader): if args['dataset'] == 'svhn':