diff --git a/train/comms/pt/comms.py b/train/comms/pt/comms.py index b4d914c8..f31873bd 100755 --- a/train/comms/pt/comms.py +++ b/train/comms/pt/comms.py @@ -11,14 +11,10 @@ import time import numpy as np - -# pytorch import torch - from param_bench.train.comms.pt import comms_utils from param_bench.train.comms.pt.comms_utils import ( bootstrap_info_holder, - commsParamsHolder, commsParamsHolderBase, ensureTensorFlush, MultilineFormatter, @@ -33,9 +29,7 @@ commsQuantCollPerfMetrics, customized_perf_loggers, ) - from param_bench.train.comms.pt.pytorch_backend_utils import ( - backendFunctions, pt2ptPatterns, supportedC10dBackends, supportedCollectives, @@ -174,6 +168,16 @@ def readArgs(self, parser): default=None, help="execute pytorch profiler at specified size", ) # execute pytorch profiler at specified size if applicable + parser.add_argument( + "--profiler-active-iters", + "--pa", + type=int, + required=False, + help=( + "If set, profiler will only record these many iters. " + "Otherwise it records the full --num_iters across one benchmark size iteration." + ), + ) parser.add_argument( "--tag", type=str, @@ -1320,13 +1324,19 @@ def benchComm(self, index, commsParams, backendFuncs): self.collectiveArgs.data_type = commsParams.data_type if commsParams.size_start_profiler == curSize: + profiler_active_iters = commsParams.profiler_active_iters + if profiler_active_iters is None: + # not specified in arg + profiler_active_iters = ( + self.collectiveArgs.graph_launches + if self.collectiveArgs.graph_launches + else self.collectiveArgs.numIters + ) self.collectiveArgs.enable_profiler = comms_utils.startProfiler( rank=self.backendFuncs.get_global_rank(), device=self.collectiveArgs.device, numWarmupIters=self.collectiveArgs.numWarmupIters, - numIters=self.collectiveArgs.graph_launches - if self.collectiveArgs.graph_launches - else self.collectiveArgs.numIters, + numIters=profiler_active_iters, ) # self.collectiveArgs has all the information on the experiment. diff --git a/train/comms/pt/comms_utils.py b/train/comms/pt/comms_utils.py index 79f93cf2..6ada709a 100644 --- a/train/comms/pt/comms_utils.py +++ b/train/comms/pt/comms_utils.py @@ -526,12 +526,12 @@ def startProfiler(rank: int, device: str, numWarmupIters: int, numIters: int) -> rank=rank, device=device, warmup=numWarmupIters, - iters=numIters, + active=numIters, ) fbStartProfiler() return True else: - logger.debug("Internal profiler is not available, skip...") + logger.warning("Internal profiler is not available, skip...") return False @@ -901,6 +901,7 @@ def __init__( self.bootstrap_info = bootstrap_info self.size_start_profiler = args.size_start_profiler + self.profiler_active_iters = args.profiler_active_iters self.groupRanks = groupRanks self.include_0B = args.include_0B