diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 56d2f28e..d72b4496 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -166,7 +166,7 @@ def main(args_eval, resume_preempt=False): num_classes=num_classes ).to(device) - train_loader = make_dataloader( + train_loader, dist_sampler = make_dataloader( dataset_name=dataset_name, root_path=root_path, resolution=resolution, @@ -175,7 +175,7 @@ def main(args_eval, resume_preempt=False): world_size=world_size, rank=rank, training=True) - val_loader = make_dataloader( + val_loader, _ = make_dataloader( dataset_name=dataset_name, root_path=root_path, resolution=resolution, @@ -229,6 +229,9 @@ def save_checkpoint(epoch): # TRAIN LOOP for epoch in range(start_epoch, num_epochs): logger.info('Epoch %d' % (epoch + 1)) + + dist_sampler.set_epoch(epoch) + train_acc = run_one_epoch( device=device, training=True, @@ -408,7 +411,7 @@ def make_dataloader( transforms.ToTensor(), transforms.Normalize(normalization[0], normalization[1])]) - data_loader, _ = init_data( + return init_data( data=dataset_name, transform=transform, batch_size=batch_size, @@ -420,7 +423,6 @@ def make_dataloader( copy_data=False, drop_last=False, subset_file=subset_file) - return data_loader def init_model( diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526d..a4617f3f 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -186,7 +186,7 @@ def main(args_eval, resume_preempt=False): num_classes=num_classes, ).to(device) - train_loader = make_dataloader( + train_loader, dist_sampler = make_dataloader( dataset_type=dataset_type, root_path=train_data_path, resolution=resolution, @@ -200,7 +200,7 @@ def main(args_eval, resume_preempt=False): world_size=world_size, rank=rank, training=True) - val_loader = make_dataloader( + val_loader, _ = make_dataloader( dataset_type=dataset_type, root_path=val_data_path, resolution=resolution, @@ -259,6 +259,9 @@ def save_checkpoint(epoch): # TRAIN LOOP for epoch in range(start_epoch, num_epochs): logger.info('Epoch %d' % (epoch + 1)) + + dist_sampler.set_epoch(epoch) + train_acc = run_one_epoch( device=device, training=True, @@ -469,7 +472,7 @@ def make_dataloader( crop_size=resolution, ) - data_loader, _ = init_data( + return init_data( data=dataset_type, root_path=root_path, transform=transform, @@ -485,7 +488,6 @@ def make_dataloader( copy_data=False, drop_last=False, subset_file=subset_file) - return data_loader def init_model(