diff --git a/benchmarks/communication/all_reduce.py b/benchmarks/communication/all_reduce.py
index 438e18c..17a0d32 100644
--- a/benchmarks/communication/all_reduce.py
+++ b/benchmarks/communication/all_reduce.py
@@ -42,14 +42,53 @@ def timed_all_reduce(input, start_event, end_event, args):
     print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
 
 
+def create_input_tensor(world_size, num_elements, global_rank, local_rank, args):
+    """
+    Create input tensor for all_reduce benchmark/validation.
+
+    Each rank's tensor is filled with its rank value, so after all_reduce SUM,
+    each element should equal sum(0..n-1) = n*(n-1)/2.
+
+    Args:
+        world_size: Total number of ranks
+        num_elements: Number of elements per rank
+        global_rank: This rank's global rank
+        local_rank: This rank's local rank
+        args: Benchmark arguments
+
+    Returns:
+        Input tensor on GPU, or None if OOM
+    """
+    if args.dist == 'torch':
+        import torch.distributed as dist
+    elif args.dist == 'deepspeed':
+        import deepspeed.comm as dist
+
+    try:
+        mat = torch.ones(num_elements, dtype=getattr(torch, args.dtype)).cuda(local_rank)
+        input = mat.mul_(float(global_rank))
+        return input
+    except RuntimeError as e:
+        if 'out of memory' in str(e):
+            if dist.get_rank() == 0:
+                print('WARNING: Ran out of GPU memory.')
+            sync_all()
+            return None
+        else:
+            raise e
+
+
 def run_all_reduce(local_rank, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
         import deepspeed.comm as dist
 
-    # Prepare benchmark header
-    print_header(args, 'all_reduce')
+    # Prepare benchmark header unless validating
+    if not args.validate:
+        print_header(args, 'all_reduce')
+    else:
+        print_rank_0("Running Allreduce validation")
 
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()
@@ -57,32 +96,41 @@ def run_all_reduce(local_rank, args):
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
 
-    if args.scan:
-        M_LIST = []
-        for x in (2**p for p in range(1, args.maxsize)):
-            M_LIST.append(x)
+    if args.single:
+        sync_all()
+        num_elements = world_size * (2 ** args.maxsize)
+        input = create_input_tensor(world_size, num_elements, global_rank, local_rank, args)
+        if input is None:
+            if dist.get_rank() == 0:
+                print('Exiting comm op.')
+            return
+        sync_all()
+
+        if args.validate:
+            run_validation(input, args)
+        else:
+            timed_all_reduce(input, start_event, end_event, args)
+
+    elif args.scan:
+        M_LIST = [2**p for p in range(1, args.maxsize)]
 
         sync_all()
         # loop over various tensor sizes
         for M in M_LIST:
-            global_rank = dist.get_rank()
-            try:
-                mat = torch.ones(world_size, M,
-                                 dtype=getattr(torch, args.dtype)).cuda(local_rank)
-                sync_all()
-                input = ((mat.mul_(float(global_rank))).view(-1))
-                del mat
-                torch.cuda.empty_cache()
-            except RuntimeError as e:
-                if 'out of memory' in str(e):
-                    if dist.get_rank() == 0:
-                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
-                    sync_all()
-                    break
-                else:
-                    raise e
+            num_elements = world_size * M
+            input = create_input_tensor(world_size, num_elements, global_rank, local_rank, args)
+            if input is None:
+                break
             sync_all()
-            timed_all_reduce(input, start_event, end_event, args)
+
+            if args.validate:
+                run_validation(input, args)
+            else:
+                timed_all_reduce(input, start_event, end_event, args)
+
+            # Clean up for next iteration
+            del input
+            torch.cuda.empty_cache()
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so we double mem_factor
@@ -91,24 +139,22 @@ def run_all_reduce(local_rank, args):
                                      mem_factor=args.mem_factor * 2,
                                      local_rank=local_rank,
                                      args=args)
-        try:
-            mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
-                                                             args.dtype)).cuda(local_rank)
-            input = ((mat.mul_(float(global_rank))).view(-1))
-        except RuntimeError as e:
-            if 'out of memory' in str(e):
-                if dist.get_rank() == 0:
-                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
-                sync_all()
-                return
-            else:
-                raise e
+
+        input = create_input_tensor(world_size, elements_per_gpu, global_rank, local_rank, args)
+        if input is None:
+            if dist.get_rank() == 0:
+                print('Try to reduce the --mem-factor argument!')
+            return
         sync_all()
-        timed_all_reduce(input, start_event, end_event, args)
+
+        if args.validate:
+            run_validation(input, args)
+        else:
+            timed_all_reduce(input, start_event, end_event, args)
 
 
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
     init_processes(local_rank=rank, args=args)
-    run_all_reduce(local_rank=rank, args=args)
+    run_all_reduce(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/communication/utils.py b/benchmarks/communication/utils.py
index dee4d31..63a5e0c 100644
--- a/benchmarks/communication/utils.py
+++ b/benchmarks/communication/utils.py
@@ -13,7 +13,8 @@ def env2int(env_list, default=-1):
     for e in env_list:
         val = int(os.environ.get(e, -1))
-        if val >= 0: return val
+        if val >= 0:
+            return val
     return default
 
 
@@ -133,7 +134,7 @@ def get_metric_strings(args, tput, busbw, duration):
     duration_ms = duration * 1e3
     duration_us = duration * 1e6
     tput = f'{tput / 1e9:.3f}'
-    busbw = f'{busbw /1e9:.3f}'
+    busbw = f'{busbw / 1e9:.3f}'
 
     if duration_us < 1e3 or args.raw:
         duration = f'{duration_us:.3f}'
@@ -207,6 +208,9 @@ def benchmark_parser():
     parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations')
     parser.add_argument("--warmups", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations')
     parser.add_argument("--maxsize", type=int, default=24, help='Max message size as a power of 2')
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
+    group.add_argument("--single", action="store_true", help='Run only at 2^maxsize message size')
     parser.add_argument("--async-op", action="store_true", help='Enables non-blocking communication')
     parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps'])
     parser.add_argument("--backend",
@@ -219,7 +223,6 @@ def benchmark_parser():
                         default=DEFAULT_DIST,
                         choices=['deepspeed', 'torch'],
                         help='Distributed DL framework to use')
-    parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
     parser.add_argument("--raw", action="store_true", help='Print the message size and latency without units')
     parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
     parser.add_argument("--reduce-scatter", action="store_true", help='Run reduce_scatter')
@@ -233,6 +236,55 @@ def benchmark_parser():
                         default=.3,
                         help='Proportion of max available GPU memory to use for single-size evals')
     parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
-    parser.add_argument('--all-to-all-v', action='store_true', 
+    parser.add_argument('--all-to-all-v', action='store_true',
                         help='Use alltoallv instead of alltoall. This will run the all_to_all benchmark with vector variant. Use with --all-to-all or alone to run just this benchmark.')
+    parser.add_argument("--validate", action="store_true", help='Validate collective results')
     return parser
+
+
+def validate_allreduce(input, args):
+    """
+    Validate all_reduce operation by checking the result against expected value.
+
+    Each rank initializes its tensor elements to its rank value (0, 1, 2, ..., n-1).
+    After all_reduce with SUM, each element should equal n*(n-1)/2.
+
+    Args:
+        input: Input tensor (will be modified in-place by all_reduce)
+        args: Benchmark arguments containing dist framework info
+
+    Returns:
+        bool: True if validation passes, False otherwise
+    """
+    if args.dist == 'torch':
+        import torch.distributed as dist
+    elif args.dist == 'deepspeed':
+        import deepspeed.comm as dist
+
+    dist.all_reduce(input, async_op=False)
+    sync_all()
+    n = dist.get_world_size()
+    expected = float(n * (n - 1) / 2)
+    return torch.allclose(input, torch.full_like(input, expected))
+
+
+def run_validation(input, args):
+    """
+    Run validation trials and print results.
+
+    Args:
+        input: Input tensor to validate (will be cloned for each trial)
+        args: Benchmark arguments
+    """
+    passes = 0
+    for _ in range(args.trials):
+        if validate_allreduce(input.clone(), args):
+            passes += 1
+
+    size = input.element_size() * input.nelement()
+    if not args.raw:
+        size = convert_size(size)
+
+    desc = f"validation ({passes}/{args.trials})"
+    status = 'PASS' if passes == args.trials else 'FAIL'
+    print_rank_0(f"{size:<20} {desc:25s} {status}")
\ No newline at end of file
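
A minimal single-process sketch of the arithmetic the new validate_allreduce check relies on (illustrative values only; the dist.all_reduce SUM across n ranks is simulated here with an explicit loop over per-rank tensors, so no GPU, process group, or launcher is needed):

    # Each simulated rank fills its tensor with its own rank value, mirroring
    # create_input_tensor. A SUM all_reduce then leaves every element equal to
    # 0 + 1 + ... + (n-1) = n*(n-1)/2, which is what validate_allreduce asserts.
    import torch

    world_size = 4      # simulated number of ranks
    num_elements = 8    # elements per rank (tiny, for illustration)

    rank_tensors = [torch.ones(num_elements) * float(r) for r in range(world_size)]

    # Simulate the SUM reduction that dist.all_reduce performs across ranks.
    reduced = torch.stack(rank_tensors).sum(dim=0)

    expected = float(world_size * (world_size - 1) / 2)
    assert torch.allclose(reduced, torch.full_like(reduced, expected))
    print(f"every element equals {expected}, as validation expects")

With the parser changes above, a run along the lines of "deepspeed all_reduce.py --scan --validate" (exact launcher per the benchmark README) would print one PASS/FAIL validation line per message size instead of the usual latency/throughput/busbw columns.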