diff --git a/.gitignore b/.gitignore index 9289cd3..6c562a2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ .DS_store .idea -*.pyc \ No newline at end of file +*.pyc +eenet/ +datasets/cifar-100-python +datasets/cifar-100-python.tar.gz \ No newline at end of file diff --git a/main.py b/main.py index d196f74..fd9ce36 100644 --- a/main.py +++ b/main.py @@ -17,9 +17,9 @@ from config import Config from data_tools.dataloader import get_dataloaders from predict import validate -from predict_utils import dynamic_evaluate +from utils.predict_utils import dynamic_evaluate from train import train -from utils import save_checkpoint, load_checkpoint, measure_flops, load_state_dict +from utils.utils import save_checkpoint, load_checkpoint, measure_flops, load_state_dict # import sys # if all(['models' not in sys.path]): diff --git a/predict.py b/predict.py index 6eb9a15..a65e0d4 100644 --- a/predict.py +++ b/predict.py @@ -10,7 +10,7 @@ import torch.nn.parallel import torch.optim -from utils import accuracy, AverageMeter +from utils.utils import accuracy, AverageMeter def validate(model, val_loader, criterion, args): @@ -40,17 +40,19 @@ def validate(model, val_loader, criterion, args): if not isinstance(output, list): output = [output] - loss = torch.zeros(0) + #loss = torch.zeros(0) + loss = torch.tensor(0.0, dtype=torch.float) for j in range(len(output)): if 'bert' in model.__class__.__name__: - loss += (j + 1) * criterion(output[j], target_var) / (args.num_exits * (args.num_exits + 1)) + loss += (j + 1) * criterion(output[j][0][0], target_var) / (args.num_exits * (args.num_exits + 1)) else: - loss += criterion(output[j], target_var) / args.num_exits - + loss += criterion(output[j][0][0], target_var) / args.num_exits + losses.update(loss.item(), input.size(0)) for j in range(len(output)): - prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5)) + relevant_output = output[j][0][0].detach() + prec1, prec5 = accuracy(relevant_output, target, topk=(1, 5)) 
top1[j].update(prec1.item(), input.size(0)) top5[j].update(prec5.item(), input.size(0)) diff --git a/train.py b/train.py index 0f0d403..0d16f2c 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ import torch.nn.parallel import torch.optim -from utils import adjust_learning_rate, accuracy, AverageMeter +from utils.utils import adjust_learning_rate, accuracy, AverageMeter def train(model, train_loader, criterion, optimizer, epoch, args, train_params): @@ -52,18 +52,23 @@ def train(model, train_loader, criterion, optimizer, epoch, args, train_params): if not isinstance(output, list): output = [output] - loss = torch.tensor(0) + #the loss variable must be float type as it handles floating point values from the dataset + loss = torch.tensor(0.0, dtype=torch.float) + #print(f"Data type of loss: {loss.dtype}") for j in range(len(output)): - loss += (j + 1) * criterion(output[j], target_var) / (args.num_exits * (args.num_exits + 1)) - if epoch > train_params['num_epoch'] * 0.75 and j < len(output) - 1: - T = 3 - alpha_kl = 0.01 - loss += torch.nn.KLDivLoss()(torch.log_softmax(output[j] / T, dim=-1), torch.softmax(output[-1] / T, dim=-1)) * alpha_kl * T * T + #torch._C._nn.cross_entropy_loss function accepts a tensor. 
+ loss += (j + 1) * criterion(output[j][0][0], target_var) / (args.num_exits * (args.num_exits + 1)) + if epoch > train_params['num_epoch'] * 0.75 and j < len(output) - 1: + T = 3 + alpha_kl = 0.01 + loss += torch.nn.KLDivLoss()(torch.log_softmax(output[j][0][0] / T, dim=-1), torch.softmax(output[-1][0][0] / T, dim=-1)) * alpha_kl * T * T losses.update(loss.item(), input.size(0)) for j in range(len(output)): - prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5)) + #tensor has data attribute + relevant_output = output[j][0][0].detach() + prec1, prec5 = accuracy(relevant_output, target, topk=(1, 5)) top1[j].update(prec1.item(), input.size(0)) top5[j].update(prec5.item(), input.size(0)) diff --git a/utils/predict_utils.py b/utils/predict_utils.py index d527a21..f8a2c87 100644 --- a/utils/predict_utils.py +++ b/utils/predict_utils.py @@ -10,7 +10,7 @@ import pandas as pd -from predict_helpers import * +from utils.predict_helpers import * def dynamic_evaluate(model, test_loader, val_loader, args): diff --git a/utils/utils.py b/utils/utils.py index 4b610e8..2dcd197 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -6,7 +6,7 @@ import models import torch -from op_counter import measure_model +from utils.op_counter import measure_model def save_checkpoint(state, args, is_best, filename, result, prec1_per_exit, prec5_per_exit):