diff --git a/benchmark/synthetic_benchmark.py b/benchmark/synthetic_benchmark.py
index 3358b15..1f1fc21 100644
--- a/benchmark/synthetic_benchmark.py
+++ b/benchmark/synthetic_benchmark.py
@@ -64,10 +64,16 @@
 )
 parser.add_argument(
     "--async-sync-interval",
-    default=50,
+    default=500,
     type=int,
     help="Model synchronization interval(ms) for async algorithm",
 )
+parser.add_argument(
+    "--async-warmup-steps",
+    default=0,
+    type=int,
+    help="Warmup(allreduce) steps for async algorithm",
+)
 parser.add_argument(
     "--amp",
     action="store_true",
@@ -131,13 +137,16 @@
 elif args.algorithm == "qadam":
     from bagua.torch_api.algorithms import q_adam

-    optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=0.01 * bagua.get_world_size(), warmup_steps=100)
+    optimizer = q_adam.QAdamOptimizer(
+        model.parameters(), lr=0.01 * bagua.get_world_size(), warmup_steps=100
+    )
     algorithm = q_adam.QAdamAlgorithm(optimizer)
 elif args.algorithm == "async":
     from bagua.torch_api.algorithms import async_model_average

     algorithm = async_model_average.AsyncModelAverageAlgorithm(
-        sync_interval_ms=args.async_sync_interval
+        sync_interval_ms=args.async_sync_interval,
+        warmup_steps=args.async_warmup_steps,
     )
 else:
     raise NotImplementedError
@@ -186,6 +195,7 @@ def benchmark_step():

 # Warm-up
 logging.info("Running warmup...")
+
 timeit.timeit(benchmark_step, number=args.num_warmup_batches)

 # Benchmark
diff --git a/imagenet/main.py b/imagenet/main.py
index 5cf6adc..705f399 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -149,11 +149,18 @@
 parser.add_argument(
     "--async-sync-interval",
-    default=100,
+    default=500,
     type=int,
     help="Model synchronization interval(ms) for async algorithm",
 )
+parser.add_argument(
+    "--async-warmup-steps",
+    default=100,
+    type=int,
+    help="Warmup(allreduce) steps for async algorithm",
+)
+
 best_acc1 = 0
@@ -232,13 +239,16 @@ def main_worker(args):
     elif args.algorithm == "qadam":
         from bagua.torch_api.algorithms import q_adam

-        optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=args.lr, warmup_steps=100)
+        optimizer = q_adam.QAdamOptimizer(
+            model.parameters(), lr=args.lr, warmup_steps=100
+        )
         algorithm = q_adam.QAdamAlgorithm(optimizer)
     elif args.algorithm == "async":
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
+            warmup_steps=args.async_warmup_steps,
         )
     else:
         raise NotImplementedError
@@ -335,9 +345,15 @@ def main_worker(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)

+        if args.algorithm == "async":
+            algorithm.resume(model)
+
         # train for one epoch
         train(train_loader, model, criterion, optimizer, scaler, epoch, args)
+        if args.algorithm == "async":
+            algorithm.abort(model)
+
         # evaluate on validation set
         acc1 = validate(val_loader, model, criterion, epoch, args)
@@ -357,9 +373,6 @@ def main_worker(args):
             is_best,
         )

-        if args.algorithm == "async":
-            algorithm.abort(model)
-

 def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
     batch_time = AverageMeter("Time", ":6.3f")
@@ -415,10 +428,17 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
         top5.update(acc5[0], images.size(0))

         if args.prof >= 0:
-            torch.cuda.nvtx.range_push("optimizer.step()")
+            torch.cuda.nvtx.range_push("backward")

         # compute gradient and do SGD step
         scaler.scale(loss).backward()
+
+        if args.prof >= 0:
+            torch.cuda.nvtx.range_pop()
+
+        if args.prof >= 0:
+            torch.cuda.nvtx.range_push("optimizer.step()")
+
         scaler.step(optimizer)
         scaler.update()
@@ -439,6 +459,9 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
         if args.prof >= 0 and i == args.prof + 10:
             print("Profiling ended at iteration {}".format(i))
             torch.cuda.cudart().cudaProfilerStop()
+
+            if args.algorithm == "async":
+                model.bagua_algorithm.abort(model)
             quit()
diff --git a/mnist/main.py b/mnist/main.py
index 8d3654c..1e75b48 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -232,13 +232,15 @@ def main():
     elif args.algorithm == "qadam":
         from bagua.torch_api.algorithms import q_adam

-        optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=args.lr, warmup_steps=100)
+        optimizer = q_adam.QAdamOptimizer(
+            model.parameters(), lr=args.lr, warmup_steps=100
+        )
         algorithm = q_adam.QAdamAlgorithm(optimizer)
     elif args.algorithm == "async":
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
         )
     else:
         raise NotImplementedError
@@ -250,14 +252,17 @@ def main():
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     for epoch in range(1, args.epochs + 1):
+        if args.algorithm == "async":
+            algorithm.resume(model)
+
         train(args, model, train_loader, optimizer, epoch)
+        if args.algorithm == "async":
+            algorithm.abort(model)
+
         test(model, test_loader)
         scheduler.step()

-    if args.algorithm == "async":
-        algorithm.abort(model)
-
     if args.save_model:
         torch.save(model.state_dict(), "mnist_cnn.pt")
diff --git a/squad/main.py b/squad/main.py
index 3cd0bca..9aff996 100644
--- a/squad/main.py
+++ b/squad/main.py
@@ -164,7 +164,8 @@ def train(args, train_dataset, model, tokenizer):
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
+            warmup_steps=args.async_warmup_steps,
         )
     else:
         raise NotImplementedError
@@ -399,7 +400,6 @@ def train(args, train_dataset, model, tokenizer):
     if args.algorithm == "async":
         algorithm.abort(model)
-        torch.cuda.synchronize()

     return global_step, tr_loss / global_step
@@ -919,6 +919,12 @@ def main():
         type=int,
         help="Model synchronization interval(ms) for async algorithm",
     )
+    parser.add_argument(
+        "--async-warmup-steps",
+        default=100,
+        type=int,
+        help="Warmup(allreduce) steps for async algorithm",
+    )
     args = parser.parse_args()

     if args.doc_stride >= args.max_seq_length - args.max_query_length:
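
All four examples apply the same pattern: construct the async algorithm with a synchronous (allreduce) warmup, then wrap each training epoch in resume()/abort() so background model averaging is stopped before evaluation. The sketch below illustrates that pattern; it is not part of the patch. `args`, `model`, `train_one_epoch`, and `evaluate` are placeholders, and `model` is assumed to already be wrapped with Bagua using this `algorithm`.

from bagua.torch_api.algorithms import async_model_average

# Both arguments come from the flags added in this patch; per their help text,
# warmup_steps runs plain allreduce steps before asynchronous averaging starts.
algorithm = async_model_average.AsyncModelAverageAlgorithm(
    sync_interval_ms=args.async_sync_interval,
    warmup_steps=args.async_warmup_steps,
)

for epoch in range(args.epochs):
    algorithm.resume(model)        # (re)start background model averaging for this epoch
    train_one_epoch(model, epoch)  # placeholder for the per-batch training loop
    algorithm.abort(model)         # stop async communication before evaluating
    evaluate(model, epoch)         # placeholder for validation/testing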