From 1214af478defcd489e056cd40dbc85ba98b8f336 Mon Sep 17 00:00:00 2001
From: Lang Xu
Date: Mon, 22 Sep 2025 15:59:11 -0400
Subject: [PATCH] Pytorch profiler integration

---
 benchmarks/communication/all_gather.py     | 28 +++++++-------
 benchmarks/communication/all_reduce.py     | 10 +++--
 benchmarks/communication/all_to_all.py     | 16 ++++----
 benchmarks/communication/broadcast.py      | 10 +++--
 benchmarks/communication/pt2pt.py          | 29 +++++++-------
 benchmarks/communication/reduce_scatter.py | 26 +++++++------
 benchmarks/communication/utils.py          | 45 ++++++++++++++++++++++
 7 files changed, 110 insertions(+), 54 deletions(-)

diff --git a/benchmarks/communication/all_gather.py b/benchmarks/communication/all_gather.py
index b55e96c..0f40d3b 100644
--- a/benchmarks/communication/all_gather.py
+++ b/benchmarks/communication/all_gather.py
@@ -33,19 +33,21 @@ def timed_all_gather(input, output, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        if args.dist == 'torch':
-            if hasattr(torch.distributed, "_all_gather_base"):
-                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
-            else:
-                output_tensors = list(
-                    torch.chunk(output,
-                                dist.get_world_size()))
-                dist.all_gather(output_tensors, input, group=None, async_op=True)
-        elif args.dist == 'deepspeed':
-            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            if args.dist == 'torch':
+                if hasattr(torch.distributed, "_all_gather_base"):
+                    dist._all_gather_base(output, input, group=None, async_op=args.async_op)
+                else:
+                    output_tensors = list(
+                        torch.chunk(output,
+                                    dist.get_world_size()))
+                    dist.all_gather(output_tensors, input, group=None, async_op=True)
+            elif args.dist == 'deepspeed':
+                dist.allgather_fn(output, input, group=None, async_op=args.async_op)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/all_reduce.py b/benchmarks/communication/all_reduce.py
index 438e18c..76e506f 100644
--- a/benchmarks/communication/all_reduce.py
+++ b/benchmarks/communication/all_reduce.py
@@ -21,10 +21,12 @@ def timed_all_reduce(input, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        dist.all_reduce(input, async_op=args.async_op)
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            dist.all_reduce(input, async_op=args.async_op)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/all_to_all.py b/benchmarks/communication/all_to_all.py
index f779a4a..8103230 100644
--- a/benchmarks/communication/all_to_all.py
+++ b/benchmarks/communication/all_to_all.py
@@ -31,13 +31,15 @@ def timed_all_to_all(input, output, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        if args.all_to_all_v:
-            dist.all_to_all(output_list, input_list, async_op=args.async_op)
-        else:
-            dist.all_to_all_single(output, input, async_op=args.async_op)
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            if args.all_to_all_v:
+                dist.all_to_all(output_list, input_list, async_op=args.async_op)
+            else:
+                dist.all_to_all_single(output, input, async_op=args.async_op)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/broadcast.py b/benchmarks/communication/broadcast.py
index f2955f9..532afac 100644
--- a/benchmarks/communication/broadcast.py
+++ b/benchmarks/communication/broadcast.py
@@ -21,10 +21,12 @@ def timed_broadcast(input, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        dist.broadcast(input, 0, async_op=args.async_op)
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            dist.broadcast(input, 0, async_op=args.async_op)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/pt2pt.py b/benchmarks/communication/pt2pt.py
index c99da8c..cd3c18c 100644
--- a/benchmarks/communication/pt2pt.py
+++ b/benchmarks/communication/pt2pt.py
@@ -30,20 +30,21 @@ def timed_pt2pt(input, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        if dist.get_rank() == 0:
-            if args.async_op:
-                dist.isend(input, 1)
-            else:
-                dist.send(input, 1)
-        if dist.get_rank() == 1:
-            if args.async_op:
-                dist.irecv(input, src=0)
-            else:
-                dist.recv(input, src=0)
-
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            if dist.get_rank() == 0:
+                if args.async_op:
+                    dist.isend(input, 1)
+                else:
+                    dist.send(input, 1)
+            if dist.get_rank() == 1:
+                if args.async_op:
+                    dist.irecv(input, src=0)
+                else:
+                    dist.recv(input, src=0)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/reduce_scatter.py b/benchmarks/communication/reduce_scatter.py
index 9edc427..5e00683 100644
--- a/benchmarks/communication/reduce_scatter.py
+++ b/benchmarks/communication/reduce_scatter.py
@@ -33,18 +33,20 @@ def timed_reduce_scatter(input, start_event, end_event, args):
     sync_all()
     # time the actual comm op trials times and average it
-    start_event.record()
-    for i in range(args.trials):
-        if hasattr(torch.distributed, "reduce_scatter_tensor"):
-            dist.reduce_scatter_tensor(output, input, async_op=args.async_op)
-        elif hasattr(torch.distributed, "_reduce_scatter_base"):
-            dist._reduce_scatter_base(output, input, async_op=args.async_op)
-        else:
-            input_tensors = list(
-                torch.chunk(input,
-                            dist.get_world_size()))
-            dist.reduce_scatter(output, input_tensors, async_op=args.async_op)
-    end_event.record()
+    with prof(args) as profiler:
+        start_event.record()
+        for i in range(args.trials):
+            if hasattr(torch.distributed, "reduce_scatter_tensor"):
+                dist.reduce_scatter_tensor(output, input, async_op=args.async_op)
+            elif hasattr(torch.distributed, "_reduce_scatter_base"):
+                dist._reduce_scatter_base(output, input, async_op=args.async_op)
+            else:
+                input_tensors = list(
+                    torch.chunk(input,
+                                dist.get_world_size()))
+                dist.reduce_scatter(output, input_tensors, async_op=args.async_op)
+            profiler.step()
+        end_event.record()
     sync_all()
     duration = start_event.elapsed_time(end_event) / 1000

diff --git a/benchmarks/communication/utils.py b/benchmarks/communication/utils.py
index dee4d31..ca6c9d5 100644
--- a/benchmarks/communication/utils.py
+++ b/benchmarks/communication/utils.py
@@ -2,6 +2,7 @@
 import os, sys
 import math
 import argparse
+from contextlib import nullcontext
 
 COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
 sys.path.append(COMMS_BENCH_DIR)
@@ -235,4 +236,48 @@ def benchmark_parser():
     parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
     parser.add_argument('--all-to-all-v', action='store_true', help='Use alltoallv instead of alltoall. This will run the all_to_all benchmark with vector variant. Use with --all-to-all or alone to run just this benchmark.')
+    parser.add_argument("--profile", action="store_true", help='Enable the PyTorch profiler during the timed iterations')
     return parser
+
+
+class PassProfile:
+    """
+    No-op profiler stand-in so call sites can still invoke step() when profiling is disabled.
+    """
+    def step(self):
+        pass
+
+
+def prof(args):
+    """
+    Return a context manager: the PyTorch profiler when args.profile is True, otherwise a no-op stand-in.
+    """
+    if not getattr(args, 'profile', False):
+        return nullcontext(PassProfile())
+
+    try:
+        from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler
+    except Exception:  # torch.profiler unavailable (older PyTorch); fall back to the no-op profiler
+        return nullcontext(PassProfile())
+
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+
+    prof_schedule = schedule(wait=1, warmup=1, active=5, repeat=1)
+
+    # save profiler traces under the communication benchmarks folder
+    comm_dir = os.path.abspath(os.path.dirname(__file__))
+    log_dir = os.path.join(comm_dir, 'profiles')
+    os.makedirs(log_dir, exist_ok=True)
+
+    rank = 0
+    if 'dist' in globals(): rank = dist.get_rank()
+    handler = tensorboard_trace_handler(os.path.join(log_dir, f'rank_{rank}'))
+
+    return profile(
+        activities=activities,
+        schedule=prof_schedule,
+        on_trace_ready=handler,
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    )
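
Note for reviewers (not part of the patch): with schedule(wait=1, warmup=1, active=5, repeat=1), the first trial is skipped, the second is warmup, and the next five are recorded, so a trace is only written when --trials is at least 7. The sketch below shows the same profiling pattern in isolation as a hedged illustration; the dummy matmul loop, the ten-step count, and the local 'profiles' output directory are assumptions made for the example, not code from this repository.

# Standalone sketch of the prof() pattern introduced in utils.py (illustrative only).
import os

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def main():
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    # assumed output location for this example
    log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'profiles')
    os.makedirs(log_dir, exist_ok=True)

    x = torch.randn(1024, 1024)
    # wait=1, warmup=1, active=5, repeat=1: step 0 is skipped, step 1 warms up,
    # steps 2-6 are recorded, then on_trace_ready writes a single trace file.
    with profile(activities=activities,
                 schedule=schedule(wait=1, warmup=1, active=5, repeat=1),
                 on_trace_ready=tensorboard_trace_handler(log_dir),
                 record_shapes=True,
                 profile_memory=True,
                 with_stack=True) as profiler:
        for _ in range(10):  # at least 7 steps are needed before a trace is emitted
            _ = x @ x        # stand-in for the timed communication op
            profiler.step()  # advance the profiler schedule once per trial


if __name__ == '__main__':
    main()

The handler writes TensorBoard-style trace files; assuming the torch-tb-profiler plugin is installed, they can be viewed with tensorboard --logdir benchmarks/communication/profiles, and they can typically also be opened in a Chrome-trace viewer.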