From 92be27989a526fca87a32524436ccdfdb9744aec Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Thu, 17 Apr 2025 16:01:35 -0700 Subject: [PATCH] Disable gradients in inference mode --- python/mpact/models/train.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py index 2fdaf68..d5e27cc 100644 --- a/python/mpact/models/train.py +++ b/python/mpact/models/train.py @@ -28,20 +28,21 @@ def training_loop(model, optimizer, loss_function, train, validation, epochs=10) tloss += loss.data.item() # Switch to inference mode. - model.eval() - vloss = 0.0 - num_validation = len(validation) # in batches - num_correct = 0 - num_total = 0 - for inp, target in validation: # batch loop (validation) - output = model(inp) - loss = loss_function(output, target) - vloss += loss.data.item() - correct = torch.eq( - torch.max(F.softmax(output, dim=1), dim=1)[1], target - ).view(-1) - num_correct += torch.sum(correct).item() - num_total += correct.shape[0] + model.eval() # disables e.g. model drop-out + with torch.no_grad(): # disables gradient computations + vloss = 0.0 + num_validation = len(validation) # in batches + num_correct = 0 + num_total = 0 + for inp, target in validation: # batch loop (validation) + output = model(inp) + loss = loss_function(output, target) + vloss += loss.data.item() + correct = torch.eq( + torch.max(F.softmax(output, dim=1), dim=1)[1], target + ).view(-1) + num_correct += torch.sum(correct).item() + num_total += correct.shape[0] # Report stats. print(