120 changes: 83 additions & 37 deletions benchmarks/communication/all_reduce.py
@@ -42,47 +42,95 @@ def timed_all_reduce(input, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def create_input_tensor(world_size, num_elements, global_rank, local_rank, args):
"""
    Create the input tensor for the all_reduce benchmark/validation.

Each rank's tensor is filled with its rank value, so after all_reduce SUM,
each element should equal sum(0..n-1) = n*(n-1)/2.

Args:
world_size: Total number of ranks
num_elements: Number of elements per rank
global_rank: This rank's global rank
local_rank: This rank's local rank
args: Benchmark arguments

Returns:
Input tensor on GPU, or None if OOM
"""
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

try:
mat = torch.ones(num_elements, dtype=getattr(torch, args.dtype)).cuda(local_rank)
input = mat.mul_(float(global_rank))
return input
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory.')
sync_all()
return None
else:
raise e
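
For reference, the rank-fill pattern above has a closed-form check that needs no process group; a minimal standalone sketch (world_size and num_elements here are hypothetical):

    import torch

    world_size, num_elements = 4, 8
    # Each simulated rank fills its tensor with its rank value.
    rank_tensors = [torch.full((num_elements,), float(r)) for r in range(world_size)]
    reduced = torch.stack(rank_tensors).sum(dim=0)  # elementwise SUM across ranks
    expected = world_size * (world_size - 1) / 2    # sum(0..n-1)
    assert torch.allclose(reduced, torch.full((num_elements,), expected))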


def run_all_reduce(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

# Prepare benchmark header unless validating
if not args.validate:
print_header(args, 'all_reduce')
else:
print_rank_0("Running Allreduce validation")

world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.single:
sync_all()
num_elements = world_size * (2 ** args.maxsize)
input = create_input_tensor(world_size, num_elements, global_rank, local_rank, args)
if input is None:
if dist.get_rank() == 0:
print('Exiting comm op.')
return
sync_all()

if args.validate:
run_validation(input, args)
else:
timed_all_reduce(input, start_event, end_event, args)

elif args.scan:
M_LIST = [2**p for p in range(1, args.maxsize)]
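        # M scans 2^1 .. 2^(maxsize-1); each tensor holds world_size * M elements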

sync_all()
# loop over various tensor sizes
for M in M_LIST:
num_elements = world_size * M
input = create_input_tensor(world_size, num_elements, global_rank, local_rank, args)
if input is None:
break
sync_all()

if args.validate:
run_validation(input, args)
else:
timed_all_reduce(input, start_event, end_event, args)

# Clean up for next iteration
del input
torch.cuda.empty_cache()
else:
        # Send the biggest message size our GPUs can fit. If you hit OOM errors, reduce --mem-factor.
        # No separate output tensor is needed, so we double mem_factor.
@@ -91,24 +139,22 @@ def run_all_reduce(local_rank, args):
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)

input = create_input_tensor(world_size, elements_per_gpu, global_rank, local_rank, args)
if input is None:
if dist.get_rank() == 0:
print('Try to reduce the --mem-factor argument!')
return
sync_all()

if args.validate:
run_validation(input, args)
else:
timed_all_reduce(input, start_event, end_event, args)
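
The doubled mem_factor above leans on all_reduce working in place (no output buffer). A rough, self-contained sketch of that sizing idea, with hypothetical defaults standing in for the sizing helper this diff collapses (requires a CUDA device):

    import torch

    dtype = torch.float32                        # assumption: benchmark default dtype
    mem_factor = 0.3                             # parser default for --mem-factor
    free_bytes, _ = torch.cuda.mem_get_info(0)   # (free, total) bytes on device 0
    bytes_per_element = torch.finfo(dtype).bits // 8
    # all_reduce writes into its input, so no output buffer is needed
    # and the memory budget can be doubled.
    elements_per_gpu = int(free_bytes * mem_factor * 2 // bytes_per_element)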


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
    run_all_reduce(local_rank=rank, args=args)
60 changes: 56 additions & 4 deletions benchmarks/communication/utils.py
@@ -13,7 +13,8 @@
def env2int(env_list, default=-1):
for e in env_list:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return default


@@ -133,7 +134,7 @@ def get_metric_strings(args, tput, busbw, duration):
duration_ms = duration * 1e3
duration_us = duration * 1e6
tput = f'{tput / 1e9:.3f}'
busbw = f'{busbw / 1e9:.3f}'

if duration_us < 1e3 or args.raw:
duration = f'{duration_us:.3f}'
Expand Down Expand Up @@ -207,6 +208,9 @@ def benchmark_parser():
parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations')
parser.add_argument("--warmups", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations')
parser.add_argument("--maxsize", type=int, default=24, help='Max message size as a power of 2')
group = parser.add_mutually_exclusive_group()
group.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
group.add_argument("--single", action="store_true", help='Run only at 2^maxsize message size')
parser.add_argument("--async-op", action="store_true", help='Enables non-blocking communication')
parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps'])
parser.add_argument("--backend",
@@ -219,7 +223,6 @@ def benchmark_parser():
default=DEFAULT_DIST,
choices=['deepspeed', 'torch'],
help='Distributed DL framework to use')
parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
parser.add_argument("--raw", action="store_true", help='Print the message size and latency without units')
parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
parser.add_argument("--reduce-scatter", action="store_true", help='Run reduce_scatter')
@@ -233,6 +236,55 @@ def benchmark_parser():
default=.3,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
    parser.add_argument('--all-to-all-v', action='store_true',
                        help='Use alltoallv instead of alltoall. Runs the all_to_all benchmark with the vector variant. Use with --all-to-all, or alone to run only this benchmark.')
parser.add_argument("--validate", action="store_true", help='Validate collective results')
return parser
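
Since --scan and --single now share a mutually exclusive group, argparse rejects passing both; a quick sketch of how the new flags parse (flag values here are hypothetical):

    args = benchmark_parser().parse_args(['--single', '--maxsize', '20', '--validate'])
    assert args.single and not args.scan and args.validate
    # benchmark_parser().parse_args(['--scan', '--single'])  # raises SystemExit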


def validate_allreduce(input, args):
"""
    Validate the all_reduce operation by checking the result against the expected value.

Each rank initializes its tensor elements to its rank value (0, 1, 2, ..., n-1).
After all_reduce with SUM, each element should equal n*(n-1)/2.

Args:
input: Input tensor (will be modified in-place by all_reduce)
args: Benchmark arguments containing dist framework info

Returns:
bool: True if validation passes, False otherwise
"""
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

dist.all_reduce(input, async_op=False)
sync_all()
n = dist.get_world_size()
expected = float(n * (n - 1) / 2)
return torch.allclose(input, torch.full_like(input, expected))


def run_validation(input, args):
"""
Run validation trials and print results.

Args:
input: Input tensor to validate (will be cloned for each trial)
args: Benchmark arguments
"""
passes = 0
for _ in range(args.trials):
if validate_allreduce(input.clone(), args):
passes += 1

size = input.element_size() * input.nelement()
if not args.raw:
size = convert_size(size)

desc = f"validation ({passes}/{args.trials})"
status = 'PASS' if passes == args.trials else 'FAIL'
print_rank_0(f"{size:<20} {desc:25s} {status}")