From eacdc75a6e9235ba4c5a8b980e5af9694bbd646d Mon Sep 17 00:00:00 2001
From: Michelia_zhx
Date: Fri, 14 Apr 2023 21:12:04 +0800
Subject: [PATCH] fix the while loop

---
 finetune.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/finetune.py b/finetune.py
index 0fed260..976204a 100644
--- a/finetune.py
+++ b/finetune.py
@@ -39,11 +39,11 @@ def end_to_end_finetune(train_loader, test_loader, model, t_model, args):
     top5 = AverageMeter()
     finish = False
     while not finish:
+        iter_nums += 1
+        if iter_nums > args.epoch:
+            finish = True
+            break
         for batch_idx, (data, target) in enumerate(train_loader):
-            iter_nums += 1
-            if iter_nums > args.epoch:
-                finish = True
-                break
             # measure data loading time
             data = data.cuda()
             target = target.cuda()
@@ -66,21 +66,21 @@ def end_to_end_finetune(train_loader, test_loader, model, t_model, args):
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
-            if iter_nums % args.print_freq == 0:
-                print('Train: [{0}/{1}]\t'
-                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                      'LR {lr}\t'
-                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                          iter_nums, args.epoch, batch_time=batch_time, lr=lr,
-                          data_time=data_time, loss=losses, top1=top1, top5=top5))
-            if iter_nums % args.eval_freq == 0:
-                validate(test_loader, model)
-                model.train()
-                model.get_feat = 'pre_GAP'
-            scheduler.step()
+        if iter_nums % args.print_freq == 0:
+            print('Train: [{0}/{1}]\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                  'LR {lr}\t'
+                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                      iter_nums, args.epoch, batch_time=batch_time, lr=lr,
+                      data_time=data_time, loss=losses, top1=top1, top5=top5))
+        if iter_nums % args.eval_freq == 0:
+            validate(test_loader, model)
+            model.train()
+            model.get_feat = 'pre_GAP'
+        scheduler.step()

     validate(test_loader, model)
