From 462ae6361bd2f60d85b723b8bca1256d1a63b929 Mon Sep 17 00:00:00 2001 From: syedshazli Date: Tue, 9 Dec 2025 13:44:02 -0500 Subject: [PATCH 1/5] removed unused functions --- apex/transformer/tensor_parallel/random.py | 44 ---------------------- 1 file changed, 44 deletions(-) diff --git a/apex/transformer/tensor_parallel/random.py b/apex/transformer/tensor_parallel/random.py index 8944f9bde..dfa7c6fe6 100644 --- a/apex/transformer/tensor_parallel/random.py +++ b/apex/transformer/tensor_parallel/random.py @@ -37,50 +37,6 @@ # Default name for the model parallel rng tracker. _MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" -# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use. -# Whether apply model parallelism to checkpointed hidden states. -_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None - - -# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use. -def init_checkpointed_activations_memory_buffer( - micro_batch_size, - max_position_embeddings, - hidden_size, - num_layers, - tensor_model_parallel_size, - checkpoint_num_layers, - fp16, -): - """Initializ the memory buffer for the checkpointed activations.""" - - per_layer = ( - micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size - ) - assert num_layers % checkpoint_num_layers == 0, ( - "number of layers is not divisible by checkpoint-num-layers" - ) - num_checkpointer_layers = num_layers // checkpoint_num_layers - numel = per_layer * num_checkpointer_layers - dtype = torch.half - if not fp16: - dtype = torch.float - - global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER - assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, ( - "checkpointed activations memory buffer is already allocated." - ) - _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff( - "checkpointed activations", numel, dtype, track_usage=False - ) - - -# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use. -def reset_checkpointed_activations_memory_buffer(): - """Reset the memory used for checkpointing.""" - if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: - _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset() - def _set_cuda_rng_state(new_state, device=-1): """Sets the random number generator state of the current GPU. From 19004969ec952bd2d9b4e45a2696a0c3e316cf43 Mon Sep 17 00:00:00 2001 From: syedshazli Date: Tue, 23 Dec 2025 17:56:43 -0500 Subject: [PATCH 2/5] apex.transformer no longer mentioned outside of tests and apex/transformer --- apex/__init__.py | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/apex/__init__.py b/apex/__init__.py index 5c5a167d6..a8cf880e5 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -14,33 +14,8 @@ from . import optimizers from . import normalization -if torch.distributed.is_available(): - from . 
import transformer - - __all__ = ["optimizers", "normalization", "transformer"] - - # Logging utilities for apex.transformer module - class RankInfoFormatter(logging.Formatter): - def format(self, record): - from apex.transformer.parallel_state import get_rank_info - - record.rank_info = get_rank_info() - return super().format(record) - - _library_root_logger = logging.getLogger(__name__) - handler = logging.StreamHandler() - handler.setFormatter( - RankInfoFormatter( - "%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", - "%y-%m-%d %H:%M:%S", - ) - ) - _library_root_logger.addHandler(handler) - _library_root_logger.propagate = False -else: - # Transformers require PyTorch built with distributed support - __all__ = ["optimizers", "normalization"] +__all__ = ["optimizers", "normalization"] def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: cudnn_available = torch.backends.cudnn.is_available() From 010b8ab390112fc9702274a93c6d9c38820e01be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:57:01 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apex/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apex/__init__.py b/apex/__init__.py index a8cf880e5..5c991871d 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -17,6 +17,7 @@ __all__ = ["optimizers", "normalization"] + def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: cudnn_available = torch.backends.cudnn.is_available() cudnn_version = torch.backends.cudnn.version() if cudnn_available else None From c364a446b7d48a67ad349b0cab0ad5c88e0b0963 Mon Sep 17 00:00:00 2001 From: syedshazli Date: Tue, 23 Dec 2025 18:18:57 -0500 Subject: [PATCH 4/5] remove transformer folder --- .../test/bottleneck/test_bottleneck_module.py | 345 ---- apex/contrib/test/cudnn_gbn/__init__.py | 0 .../cudnn_gbn/test_cudnn_gbn_with_two_gpus.py | 169 -- apex/contrib/test/optimizers/__init__.py | 0 .../contrib/test/optimizers/test_dist_adam.py | 852 --------- .../optimizers/test_distributed_fused_lamb.py | 168 -- apex/contrib/test/peer_memory/__init__.py | 0 .../test_peer_halo_exchange_module.py | 335 ---- apex/transformer/README.md | 81 - apex/transformer/__init__.py | 23 - apex/transformer/_data/__init__.py | 8 - apex/transformer/_data/_batchsampler.py | 196 --- apex/transformer/_ucc_util.py | 10 - apex/transformer/amp/__init__.py | 6 - apex/transformer/amp/grad_scaler.py | 140 -- apex/transformer/enums.py | 35 - apex/transformer/functional/__init__.py | 15 - apex/transformer/functional/fused_rope.py | 281 --- apex/transformer/functional/fused_softmax.py | 306 ---- apex/transformer/layers/__init__.py | 11 - apex/transformer/layers/layer_norm.py | 101 -- apex/transformer/log_util.py | 18 - apex/transformer/microbatches.py | 191 -- apex/transformer/parallel_state.py | 810 --------- .../transformer/pipeline_parallel/__init__.py | 8 - apex/transformer/pipeline_parallel/_timers.py | 83 - .../pipeline_parallel/p2p_communication.py | 713 -------- .../pipeline_parallel/schedules/__init__.py | 36 - .../pipeline_parallel/schedules/common.py | 412 ----- .../schedules/fwd_bwd_no_pipelining.py | 132 -- .../fwd_bwd_pipelining_with_interleaving.py | 754 -------- ...fwd_bwd_pipelining_without_interleaving.py | 614 ------- apex/transformer/pipeline_parallel/utils.py | 370 ---- 
apex/transformer/tensor_parallel/__init__.py | 75 - .../tensor_parallel/cross_entropy.py | 155 -- apex/transformer/tensor_parallel/data.py | 127 -- apex/transformer/tensor_parallel/layers.py | 884 ---------- apex/transformer/tensor_parallel/mappings.py | 309 ---- apex/transformer/tensor_parallel/memory.py | 147 -- apex/transformer/tensor_parallel/random.py | 256 --- apex/transformer/tensor_parallel/utils.py | 66 - apex/transformer/testing/__init__.py | 0 apex/transformer/testing/arguments.py | 1513 ---------------- apex/transformer/testing/commons.py | 315 ---- .../testing/distributed_test_base.py | 131 -- apex/transformer/testing/global_vars.py | 283 --- apex/transformer/testing/standalone_bert.py | 268 --- apex/transformer/testing/standalone_gpt.py | 113 -- .../testing/standalone_transformer_lm.py | 1553 ----------------- apex/transformer/utils.py | 50 - tests/L0/run_transformer/__init__.py | 0 tests/L0/run_transformer/gpt_scaling_test.py | 118 -- .../L0/run_transformer/test_batch_sampler.py | 169 -- tests/L0/run_transformer/test_bert_minimal.py | 262 --- .../L0/run_transformer/test_cross_entropy.py | 109 -- tests/L0/run_transformer/test_data.py | 66 - .../run_transformer/test_dynamic_batchsize.py | 229 --- tests/L0/run_transformer/test_fused_rope.py | 329 ---- .../L0/run_transformer/test_fused_softmax.py | 398 ----- tests/L0/run_transformer/test_gpt_minimal.py | 238 --- tests/L0/run_transformer/test_layers.py | 575 ------ tests/L0/run_transformer/test_mapping.py | 84 - tests/L0/run_transformer/test_microbatches.py | 95 - tests/L0/run_transformer/test_p2p_comm.py | 129 -- .../L0/run_transformer/test_parallel_state.py | 183 -- .../test_pipeline_parallel_fwd_bwd.py | 891 ---------- tests/L0/run_transformer/test_random.py | 122 -- .../run_transformer/test_transformer_utils.py | 41 - .../pipeline_parallel_fwd_bwd_ucc_async.py | 265 --- 69 files changed, 17771 deletions(-) delete mode 100644 apex/contrib/test/bottleneck/test_bottleneck_module.py delete mode 100644 apex/contrib/test/cudnn_gbn/__init__.py delete mode 100644 apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py delete mode 100644 apex/contrib/test/optimizers/__init__.py delete mode 100644 apex/contrib/test/optimizers/test_dist_adam.py delete mode 100644 apex/contrib/test/optimizers/test_distributed_fused_lamb.py delete mode 100644 apex/contrib/test/peer_memory/__init__.py delete mode 100644 apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py delete mode 100644 apex/transformer/README.md delete mode 100644 apex/transformer/__init__.py delete mode 100644 apex/transformer/_data/__init__.py delete mode 100644 apex/transformer/_data/_batchsampler.py delete mode 100644 apex/transformer/_ucc_util.py delete mode 100644 apex/transformer/amp/__init__.py delete mode 100644 apex/transformer/amp/grad_scaler.py delete mode 100644 apex/transformer/enums.py delete mode 100644 apex/transformer/functional/__init__.py delete mode 100644 apex/transformer/functional/fused_rope.py delete mode 100644 apex/transformer/functional/fused_softmax.py delete mode 100644 apex/transformer/layers/__init__.py delete mode 100644 apex/transformer/layers/layer_norm.py delete mode 100644 apex/transformer/log_util.py delete mode 100644 apex/transformer/microbatches.py delete mode 100644 apex/transformer/parallel_state.py delete mode 100644 apex/transformer/pipeline_parallel/__init__.py delete mode 100644 apex/transformer/pipeline_parallel/_timers.py delete mode 100644 apex/transformer/pipeline_parallel/p2p_communication.py delete mode 100644 
apex/transformer/pipeline_parallel/schedules/__init__.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/common.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py delete mode 100644 apex/transformer/pipeline_parallel/utils.py delete mode 100644 apex/transformer/tensor_parallel/__init__.py delete mode 100644 apex/transformer/tensor_parallel/cross_entropy.py delete mode 100644 apex/transformer/tensor_parallel/data.py delete mode 100644 apex/transformer/tensor_parallel/layers.py delete mode 100644 apex/transformer/tensor_parallel/mappings.py delete mode 100644 apex/transformer/tensor_parallel/memory.py delete mode 100644 apex/transformer/tensor_parallel/random.py delete mode 100644 apex/transformer/tensor_parallel/utils.py delete mode 100644 apex/transformer/testing/__init__.py delete mode 100644 apex/transformer/testing/arguments.py delete mode 100644 apex/transformer/testing/commons.py delete mode 100644 apex/transformer/testing/distributed_test_base.py delete mode 100644 apex/transformer/testing/global_vars.py delete mode 100644 apex/transformer/testing/standalone_bert.py delete mode 100644 apex/transformer/testing/standalone_gpt.py delete mode 100644 apex/transformer/testing/standalone_transformer_lm.py delete mode 100644 apex/transformer/utils.py delete mode 100644 tests/L0/run_transformer/__init__.py delete mode 100644 tests/L0/run_transformer/gpt_scaling_test.py delete mode 100644 tests/L0/run_transformer/test_batch_sampler.py delete mode 100644 tests/L0/run_transformer/test_bert_minimal.py delete mode 100644 tests/L0/run_transformer/test_cross_entropy.py delete mode 100644 tests/L0/run_transformer/test_data.py delete mode 100644 tests/L0/run_transformer/test_dynamic_batchsize.py delete mode 100644 tests/L0/run_transformer/test_fused_rope.py delete mode 100644 tests/L0/run_transformer/test_fused_softmax.py delete mode 100644 tests/L0/run_transformer/test_gpt_minimal.py delete mode 100644 tests/L0/run_transformer/test_layers.py delete mode 100644 tests/L0/run_transformer/test_mapping.py delete mode 100644 tests/L0/run_transformer/test_microbatches.py delete mode 100644 tests/L0/run_transformer/test_p2p_comm.py delete mode 100644 tests/L0/run_transformer/test_parallel_state.py delete mode 100644 tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py delete mode 100644 tests/L0/run_transformer/test_random.py delete mode 100644 tests/L0/run_transformer/test_transformer_utils.py delete mode 100644 tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py diff --git a/apex/contrib/test/bottleneck/test_bottleneck_module.py b/apex/contrib/test/bottleneck/test_bottleneck_module.py deleted file mode 100644 index 234095c55..000000000 --- a/apex/contrib/test/bottleneck/test_bottleneck_module.py +++ /dev/null @@ -1,345 +0,0 @@ -import unittest - -import torch -from torch.testing._internal import common_utils - -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -SKIP_TEST = None -try: - from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck - from apex.contrib.bottleneck import HaloExchangerPeer - from apex.contrib.peer_memory import PeerMemoryPool -except ImportError as e: - SKIP_TEST = e - - -def ground_truth_bottleneck(C, dtype, explicit_nhwc): - bottleneck = Bottleneck(C, C, C, 
use_cudnn=True, explicit_nhwc=explicit_nhwc) - bottleneck.to(dtype=dtype, device="cuda") - for p in bottleneck.parameters(): - torch.distributed.broadcast(p, 0) - for b in bottleneck.buffers(): - torch.distributed.broadcast(b, 0) - return bottleneck - - -def print_bottleneck_p_and_b(bottleneck): - with torch.no_grad(): - for n, p in bottleneck.named_parameters(): - print("%s :: %s" % (n, str(p.norm(p=2, dtype=torch.float32)))) - for n, p in bottleneck.named_buffers(): - print("%s :: %s" % (n, str(p.norm(p=2, dtype=torch.float32)))) - - -def has_nan(x): - if isinstance(x, list) or isinstance(x, tuple): - for xx in x: - if torch.any(torch.isnan(xx)): - return True - return False - elif isinstance(x, dict): - for k, v in x.items(): - if torch.any(torch.isnan(v)): - return True - else: - return torch.any(torch.isnan(x)) - - -def rel_diff_t(xx1, xx2): - return ( - (xx1 - xx2).norm(p=2, dtype=torch.float32) / (xx1 + xx2).norm(p=2, dtype=torch.float32) - ).item() - - -def rel_diff(x1, x2): - if isinstance(x1, list) or isinstance(x1, tuple): - return [rel_diff_t(xx1, xx2) for xx1, xx2 in zip(x1, x2)] - elif isinstance(x1, dict): - return [rel_diff_t(xx1, xx2) for (k1, xx1), (k2, xx2) in zip(x1.items(), x2.items())] - else: - return rel_diff_t(x1, x2) - - -def graph_it(bottleneck, x): - print("Graphing") - with torch.no_grad(): - x = x.clone() - x.grad = None - x.requires_grad = True - return torch.cuda.make_graphed_callables(bottleneck, (x,)) - - -def clone_inputs(bottleneck, x, dy=None): - with torch.no_grad(): - x = x.clone() - x.grad = None - x.requires_grad = True - if dy is None: - y = bottleneck(x) - dy = torch.randn_like(y) / 1e2 - torch.distributed.broadcast(dy, 0) - return x, dy - - -def fprop_and_bprop(bottleneck, x, dy): - y = bottleneck(x) - y.backward(dy) - dgrad = x.grad.detach() - wgrad = {} - for n, p in bottleneck.named_parameters(): - wgrad[n] = p.grad.detach() - return x, y, dy, dgrad, wgrad - - -def ground_truth(N, C, H, W, dtype, memory_format, bottleneck): - if memory_format == 1: - # 1 -> explicit nhwc - explicit_nhwc = True - with torch.no_grad(): - x = torch.randn([N, H, W, C], dtype=dtype, device="cuda") - torch.distributed.broadcast(x, 0) - x, dy = clone_inputs(bottleneck, x) - return fprop_and_bprop(bottleneck, x, dy) - else: - # 2 -> native nhwc - # 3 -> nchw - explicit_nhwc = False - assert False, "Not implemented yet" - - -def print_ground_truth(gt): - x, y, dy, dgrad, wgrad = gt - if has_nan(y) or has_nan(dgrad) or has_nan(wgrad): - print("Error! Ground truth has NAN") - else: - print("Ok! 
No NAN found in ground truth") - - -def apply_to_different_bottleneck(gt, bottleneck): - with torch.no_grad(): - x, _, dy, _, _ = gt - x, dy = clone_inputs(bottleneck, x, dy) - return fprop_and_bprop(bottleneck, x, dy) - - -def compare_single_field(results, f1, f2, l0, l1, l2): - if has_nan(f1) and has_nan(f2): - results[l0] = "both NAN" - elif has_nan(f1): - results[l0] = "%s.%s NAN" % (l1, l0) - elif has_nan(f2): - results[l0] = "%s.%s NAN" % (l2, l0) - else: - results[l0] = "%s" % (str(rel_diff(f1, f2))) - - -def compare(gt, bt): - x1, y1, dy1, dgrad1, wgrad1 = gt - x2, y2, dy2, dgrad2, wgrad2 = bt - results = {} - compare_single_field(results, y1, y2, "y", "gt", "bt") - compare_single_field(results, dy1, dy2, "dy", "gt", "bt") - compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt") - compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt") - for i in range(torch.distributed.get_world_size()): - if i == torch.distributed.get_rank(): - print(i, results) - torch.distributed.barrier() - - -def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args): - spatial_bottleneck = SpatialBottleneck( - C, - C, - C, - use_cudnn=True, - explicit_nhwc=explicit_nhwc, - spatial_parallel_args=spatial_parallel_args, - ) - spatial_bottleneck.to(dtype=dtype, device="cuda") - with torch.no_grad(): - sp = {} - for n, p in spatial_bottleneck.named_parameters(): - sp[n] = p - for n, p in gt_bottleneck.named_parameters(): - sp[n].copy_(p) - sb = {} - for n, b in spatial_bottleneck.named_buffers(): - sb[n] = b - for n, b in gt_bottleneck.named_buffers(): - sb[n].copy_(b) - return spatial_bottleneck - - -def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False): - assert explicit_nhwc, "Only tested for explicit nhwc" - - x, _, dy, _, _ = gt - N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel - dtype = x.dtype - - spatial_group_size = world_size - spatial_group_rank = rank - spatial_communicator = None - spatial_halo_exchanger = halex - spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x - use_delay_kernel = False - spatial_parallel_args = ( - spatial_group_size, - spatial_group_rank, - spatial_communicator, - spatial_halo_exchanger, - spatial_method, - use_delay_kernel, - ) - spatial_bottleneck = spatial_parallel_bottleneck( - C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args - ) - - with torch.no_grad(): - Hs = H // spatial_group_size - xs = x[:, spatial_group_rank * Hs : (spatial_group_rank + 1) * Hs, :, :].clone() - dys = dy[:, spatial_group_rank * Hs : (spatial_group_rank + 1) * Hs, :, :].clone() - xs.requires_grad = True - - spatial_bottleneck = graph_it(spatial_bottleneck, xs) - _, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys) - - # gather output pieces - for n, p in wgrad.items(): - if fp32_reduce: - p32 = p.float() - torch.distributed.all_reduce(p32) - p.copy_(p32.half()) - else: - torch.distributed.all_reduce(p) - ys = [torch.empty_like(y) for _ in range(spatial_group_size)] - torch.distributed.all_gather(ys, y) - y = torch.cat(ys, dim=1) - dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)] - torch.distributed.all_gather(dgrads, dgrad) - dgrad = torch.cat(dgrads, dim=1) - return x, y, dy, dgrad, wgrad - - -def main(): - torch.use_deterministic_algorithms(True) - - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - world_size = 
torch.distributed.get_world_size() - torch.cuda.set_device(rank) - - explicit_nhwc = True - - dtype = torch.float16 - N, C, H, W = 1, 64, 200, 336 - Hs = ((H + 8 * world_size - 1) // (8 * world_size)) * 8 - H = Hs * world_size - gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc) - gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck) - - # verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck - spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None) - bt = apply_to_different_bottleneck(gt, spatial_bottleneck) - compare(gt, bt) - # print_bottleneck_p_and_b(gt_bottleneck) - # print_bottleneck_p_and_b(spatial_bottleneck) - - group_size = world_size - group = rank // group_size - ranks = [group * group_size + i for i in range(group_size)] - rank_in_group = rank % group_size - - spatial_group_size = world_size - spatial_communicator = None - - peer_pool = PeerMemoryPool(0, 64 * 1024 * 1024, ranks) - - # class HaloExchangerNoComm(HaloExchanger): - # def __init__(self, ranks, rank_in_group): - # class HaloExchangerAllGather(HaloExchanger): - # def __init__(self, ranks, rank_in_group, comm): - # class HaloExchangerSendRecv(HaloExchanger): - # def __init__(self, ranks, rank_in_group): - # class HaloExchangerPeer(HaloExchanger): - # def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1): - - # halex = HaloExchangerAllGather(ranks, rank_in_group) - # halex = HaloExchangerSendRecv(ranks, rank_in_group) - - halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0) - # print("halex.signals = %s" % (str(halex.signals))) - # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding - # torch.cuda.synchronize() - # torch.distributed.barrier() - - bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True) - compare(gt, bt2) - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TestBottleneck(NcclDistributedTestBase): - # PyTorch's float16 tolerance values, see https://pytorch.org/docs/stable/testing.html#torch.testing.assert_close - fp16_tolerance = {"atol": 1e-5, "rtol": 1e-3} - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 2) - - def test_bottleneck_without_peer_memory(self) -> None: - explicit_nhwc: bool = True - dtype: torch.dtype = torch.float16 - N, C, H, W = 1, 64, 200, 336 - Hs = ((H + 8 * self.world_size - 1) // (8 * self.world_size)) * 8 - H = Hs * self.world_size - - gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc) - gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck) - - spatial_bottleneck = spatial_parallel_bottleneck( - C, dtype, explicit_nhwc, gt_bottleneck, None - ) - bt = apply_to_different_bottleneck(gt, spatial_bottleneck) - self.assertEqual(gt, bt, **self.fp16_tolerance) - - @unittest.skipIf( - torch.cuda.device_count() < 2 or not torch.cuda.can_device_access_peer(0, 1), - "peer memory access not supported", - ) - def test_bottleneck_with_peer_memory(self) -> None: - explicit_nhwc: bool = True - dtype: torch.dtype = torch.float16 - N, C, H, W = 1, 64, 200, 336 - Hs = ((H + 8 * self.world_size - 1) // (8 * self.world_size)) * 8 - H = Hs * self.world_size - - gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc) - gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck) - - group = self.rank // self.world_size - ranks = [group * self.world_size + i for i in range(self.world_size)] - rank_in_group = 
self.rank % self.world_size - - spatial_group_size, spatial_communicator = self.world_size, None - peer_pool = PeerMemoryPool(0, 64 * 1024 * 1024, ranks) - halo_exchanger_peer = HaloExchangerPeer( - ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0 - ) - bt2 = n_way_spatial( - halo_exchanger_peer, - gt_bottleneck, - gt, - explicit_nhwc, - self.world_size, - self.rank, - fp32_reduce=True, - ) - # TODO(crcrpar): Investigate the implementation to mitigate the numerical errors. - # NOTE(crcrpar): This assert often fails due to numerical errors. - # self.assertEqual(gt, bt2, **self.fp16_tolerance) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/apex/contrib/test/cudnn_gbn/__init__.py b/apex/contrib/test/cudnn_gbn/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py b/apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py deleted file mode 100644 index 0b4ff9c64..000000000 --- a/apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py +++ /dev/null @@ -1,169 +0,0 @@ -import copy -import typing -import unittest - -import torch -import torch.nn as nn -from torch.testing._internal import common_utils - -SKIP_TEST = None -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -try: - from apex.contrib.cudnn_gbn import GroupBatchNorm2d as GBN -except ImportError as e: - SKIP_TEST = e - - -# Usage: python /path/to/cudnn_gbn/test_gbn_with_two_gpus.py - -input_shapes = [ - [1, 1024, 48, 72], - [1, 128, 192, 288], - [1, 128, 384, 576], - [1, 1536, 48, 72], - [1, 2048, 48, 72], - [1, 256, 1, 1], - [1, 256, 192, 288], - [1, 256, 384, 576], - [1, 256, 48, 72], - [1, 256, 96, 144], - [1, 32, 384, 576], - [1, 48, 192, 288], - [1, 64, 384, 576], - [1, 728, 48, 72], - [1, 728, 96, 144], -] - - -class BNModelRef(nn.Module): - def __init__(self, num_features, num_layers=1000): - super().__init__() - self.fwd = nn.Sequential( - *[ - nn.BatchNorm2d( - num_features, - eps=1e-05, - momentum=0.1, - affine=True, - track_running_stats=True, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, x): - return self.fwd(x) - - -class BNModel(nn.Module): - def __init__(self, num_features, num_layers=1000): - super().__init__() - self.fwd = nn.Sequential( - *[ - GBN( - num_features, - group_size=2, - eps=1e-05, - momentum=0.1, - affine=True, - track_running_stats=True, - ) - for _ in range(num_layers) - ] - ) - - def forward(self, x): - return self.fwd(x) - - -def get_rand_tensors(global_shape, device): - inp_t = torch.rand(global_shape, dtype=torch.float32, device=device).to( - memory_format=torch.channels_last - ) - weight = torch.rand(global_shape[1], dtype=torch.float32, device=device) - bias = torch.rand(global_shape[1], dtype=torch.float32, device=device) - _grad_out = torch.rand(global_shape, dtype=torch.float32, device=device).to( - memory_format=torch.channels_last - ) - return inp_t, weight, bias, _grad_out - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TestCudnnGBN(NcclDistributedTestBase): - def _prep(self): - torch.cuda.manual_seed(333) - torch.manual_seed(333) - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 2) - - @torch.backends.cudnn.flags(enabled=True, benchmark=True) - def _test_cudnn_gbn( - self, - num_layers: int, - shape: typing.List[int], - *, - memory_format: torch.memory_format = torch.channels_last, - ) -> None: - global_shape = copy.deepcopy(shape) - global_shape[0] = self.world_size - - device = 
torch.device("cuda", self.rank) - cudnn_gbn_model = BNModel( - num_features=shape[1], - num_layers=num_layers, - ).to(device=device, memory_format=memory_format) - ref_model = BNModelRef( - num_features=shape[1], - num_layers=num_layers, - ).to(device=device, memory_format=memory_format) - - input, weight, bias, grad_out = get_rand_tensors(global_shape, device) - with torch.no_grad(): - ref_model.fwd[0].weight.copy_(weight) - ref_model.fwd[0].bias.copy_(bias) - cudnn_gbn_model.fwd[0].weight.copy_(weight) - cudnn_gbn_model.fwd[0].bias.copy_(bias) - - ref_input = input.clone().detach().requires_grad_() - input = input[self.rank : self.rank + 1, ...].clone().detach().requires_grad_() - - ref_grad_out = grad_out.half().clone().detach() - grad_out = grad_out[self.rank : self.rank + 1, ...].half().clone().detach() - - with torch.amp.autocast("cuda"): - out = cudnn_gbn_model(input) - ref_out = ref_model(ref_input.half()) - out.backward(grad_out) - ref_out.backward(ref_grad_out) - - kwargs = {"rtol": 3.5e-3, "atol": 3e-2, "msg": f"shape: {shape}"} - - torch.testing.assert_close(ref_out[self.rank : self.rank + 1], out, **kwargs) - torch.testing.assert_close(ref_input.grad[self.rank : self.rank + 1], input.grad, **kwargs) - # compensating the averaging over processes done by DDP - # in order to produce mathematically equivalent result - # https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368 - torch.testing.assert_close( - ref_model.fwd[0].weight.grad / self.world_size, - cudnn_gbn_model.fwd[0].weight.grad, - **kwargs, - ) - torch.testing.assert_close( - ref_model.fwd[0].bias.grad / self.world_size, - cudnn_gbn_model.fwd[0].bias.grad, - **kwargs, - ) - - def test_cudnngbn(self): - if self.world_size != 2: - self.skipTest(f"This test is written for world_size of 2 but {self.world_size}") - for shape in input_shapes: - self._prep() - self._test_cudnn_gbn(1, shape) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/apex/contrib/test/optimizers/__init__.py b/apex/contrib/test/optimizers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py deleted file mode 100644 index 3435e80be..000000000 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ /dev/null @@ -1,852 +0,0 @@ -from contextlib import contextmanager -import io -from typing import Callable, Optional -import unittest -import warnings -from contextlib import nullcontext - -import torch -from torch.testing._internal import common_utils - -SKIP_TEST = None -try: - from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam -except ImportError as e: - SKIP_TEST = e -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - - -class SimpleModel(torch.nn.Module): - def __init__(self, num_layers, size): - super().__init__() - self.params = torch.nn.ParameterList( - [torch.nn.Parameter(torch.rand(1, size) + 1) for _ in range(num_layers)] - ) - - def forward(self, x): - y = 0 - for i, param in enumerate(self.params): - y += (i + 1) * param * x - return y - - -def make_models( - num_layers: int, - size: int, - *, - lr: float = 0.1, - adam_w_mode: bool = True, - model_dtype: torch.dtype = torch.float32, - optim_dtype: Optional[torch.dtype] = None, - grad_sync_dtype: Optional[torch.dtype] = None, - param_sync_dtype: Optional[torch.dtype] = None, - device: torch.device = "cuda", - process_group: Optional[torch.distributed.ProcessGroup] = None, - 
average_grad_sync: bool = True, - overlap_communication: bool = True, - bucket_cap_mb: float = 71 / (4 * 1024 * 1024), - contiguous_buffers: bool = False, - store_params: bool = False, - store_param_remainders: bool = False, - with_scaled_states: bool = False, - nccl_ub: bool = False, - with_cuda_graph: bool = False, -): - # Construct models with same parameters - ref_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) - dist_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) - with torch.no_grad(): - for ref_param, dist_param in zip(dist_model.parameters(), ref_model.parameters()): - dist_param.copy_(ref_param) - - # Initialize reference model with data-parallelism - rank = torch.distributed.get_rank() - ref_model = torch.nn.parallel.DistributedDataParallel( - ref_model, - device_ids=[rank] if device == "cuda" else None, - output_device=rank if device == "cuda" else None, - process_group=process_group, - ) - - # Construct optimizers with same hyperparameters - if optim_dtype is None: - optim_dtype = model_dtype - optim_args = dict(lr=lr, betas=(0.1, 0.2), eps=0.25, weight_decay=0.1) - ref_optim_class = torch.optim.AdamW if adam_w_mode else torch.optim.Adam - ref_optim = ref_optim_class( - [ - {"params": list(ref_model.parameters())[1::2], "lr": lr * 2}, - {"params": list(ref_model.parameters())[0::2]}, - ], - **optim_args, - ) - dist_optim = DistributedFusedAdam( - [ - {"params": list(dist_model.parameters())[1::2], "lr": lr * 2}, - {"params": list(dist_model.parameters())[0::2]}, - ], - adam_w_mode=adam_w_mode, - overlap_grad_sync=overlap_communication, - overlap_param_sync=overlap_communication, - bucket_cap_mb=bucket_cap_mb, - dtype=optim_dtype, - grad_sync_dtype=grad_sync_dtype, - param_sync_dtype=param_sync_dtype, - process_group=process_group, - average_grad_sync=average_grad_sync, - contiguous_param_buffer=contiguous_buffers, - contiguous_grad_buffer=contiguous_buffers, - store_params=store_params, - store_param_remainders=store_param_remainders, - with_scaled_states=with_scaled_states, - nccl_ub=nccl_ub, - capturable=with_cuda_graph, - **optim_args, - ) - - return ref_model, ref_optim, dist_model, dist_optim - - -@contextmanager -def dummy_context(): - try: - yield - finally: - pass - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TestDistributedFusedAdam(NcclDistributedTestBase): - seed = 1234 - - def test_matches_pytorch( - self, - rtol: Optional[float] = None, - atol: Optional[float] = None, - num_layers: int = 11, - layer_size: int = 7, - batch_size: int = 3, - num_steps: int = 3, - micro_batch_steps: int = 3, - adam_w_mode: bool = True, - overlap_communication: bool = True, - use_nosync: bool = True, - model_dtype: torch.dtype = torch.float32, - optim_dtype: Optional[torch.dtype] = None, - grad_sync_dtype: Optional[torch.dtype] = None, - param_sync_dtype: Optional[torch.dtype] = None, - device: torch.device = "cuda", - bucket_cap_mb: float = 71 / (4 * 1024 * 1024), - contiguous_buffers: bool = False, - store_params: bool = False, - store_param_remainders: bool = False, - with_scaled_states: bool = False, - nccl_ub: bool = False, - init_optim_func: Optional[Callable[[DistributedFusedAdam], None]] = None, - with_cuda_graph: bool = False, - ): - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - ref_model, ref_optim, dist_model, dist_optim = make_models( - num_layers, - layer_size, - adam_w_mode=adam_w_mode, - 
model_dtype=model_dtype, - optim_dtype=optim_dtype, - grad_sync_dtype=grad_sync_dtype, - param_sync_dtype=param_sync_dtype, - device=device, - overlap_communication=overlap_communication, - bucket_cap_mb=bucket_cap_mb, - contiguous_buffers=contiguous_buffers, - store_params=store_params, - store_param_remainders=store_param_remainders, - with_scaled_states=with_scaled_states, - nccl_ub=nccl_ub, - with_cuda_graph=with_cuda_graph, - ) - - # Initialize distributed optimizer - if init_optim_func is not None: - with torch.cuda.stream(stream): - init_optim_func(dist_optim) - - # Static data - static_xs, static_dys = [], [] - ys_ref, grad_xs_ref = [], [] - ys_dist, grad_xs_dist = [], [] - - graph = torch.cuda.CUDAGraph() if with_cuda_graph else None - CAPTURE_ITERATION = 11 - if with_cuda_graph: - assert num_steps > CAPTURE_ITERATION + 3, "Not enough iterations for CUDA graph test." - - # Training loop - with torch.cuda.stream(stream): - for step in range(num_steps): - # Synthetic data - for micro_step in range(micro_batch_steps): - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.to(dtype=model_dtype, device=device) - dy = dy.to(dtype=model_dtype, device=device) - if step == 0: - static_xs.append(x) - static_dys.append(dy) - else: - static_xs[micro_step].copy_(x) - static_dys[micro_step].copy_(dy) - - # Reference implementation - ref_optim.zero_grad() - for micro_step in range(micro_batch_steps): - x, dy = static_xs[micro_step], static_dys[micro_step] - - x_ref = x.detach().clone().requires_grad_(True) - y_ref = ref_model(x_ref) - y_ref.backward(dy) - - if step == 0: - ys_ref.append(y_ref) - grad_xs_ref.append(x_ref.grad) - else: - with torch.no_grad(): - ys_ref[micro_step].copy_(y_ref) - grad_xs_ref[micro_step].copy_(x_ref.grad) - ref_optim.step() - - # Distributed implementation - if not with_cuda_graph or step <= CAPTURE_ITERATION: - if with_cuda_graph and step == CAPTURE_ITERATION: - ctx = torch.cuda.graph(graph) - torch.cuda.synchronize() - else: - ctx = nullcontext() - - with ctx: - dist_optim.zero_grad() - for micro_step in range(micro_batch_steps): - x, dy = static_xs[micro_step], static_dys[micro_step] - - x_dist = x.detach().clone().requires_grad_(True) - y_dist = dist_model(x_dist) - backward_context = dummy_context - if use_nosync and micro_step < micro_batch_steps - 1: - backward_context = dist_optim.no_sync - with backward_context(): - y_dist.backward(dy) - - if step == 0: - ys_dist.append(y_dist) - grad_xs_dist.append(x_dist.grad) - else: - with torch.no_grad(): - ys_dist[micro_step].copy_(y_dist) - grad_xs_dist[micro_step].copy_(x_dist.grad) - dist_optim.step() - - if with_cuda_graph and step == CAPTURE_ITERATION: - graph.replay() - else: - graph.replay() - - # Check that data tensors match - for mbs in range(micro_batch_steps): - torch.testing.assert_close(ys_dist[mbs], ys_ref[mbs], rtol=rtol, atol=atol) - torch.testing.assert_close( - grad_xs_dist[mbs], grad_xs_ref[mbs], rtol=rtol, atol=atol - ) - - # Check that parameters match - for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()): - torch.testing.assert_close(dist_param, ref_param, rtol=rtol, atol=atol) - - def test_matches_pytorch_l2_reg(self): - self.test_matches_pytorch(adam_w_mode=False) - - def test_matches_pytorch_no_overlap(self): - self.test_matches_pytorch( - overlap_communication=False, - use_nosync=False, - ) - - def test_matches_pytorch_sync_every_step(self): - self.test_matches_pytorch(use_nosync=False) - - def 
test_matches_pytorch_contiguous_buffers(self): - self.test_matches_pytorch(contiguous_buffers=True) - - def test_matches_pytorch_fp64(self): - self.test_matches_pytorch( - rtol=1.3e-6, - atol=1e-5, - model_dtype=torch.float64, - optim_dtype=torch.float32, - ) - - def test_matches_pytorch_fp16(self): - self.test_matches_pytorch( - rtol=5e-3, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.float16, - optim_dtype=torch.float16, - ) - - def test_matches_pytorch_bf16(self): - self.test_matches_pytorch( - rtol=5e-2, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.bfloat16, - optim_dtype=torch.bfloat16, - ) - - def test_matches_pytorch_fp16_params(self): - self.test_matches_pytorch( - rtol=5e-3, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.float16, - optim_dtype=torch.float32, - param_sync_dtype=torch.float16, - store_params=True, - ) - - def test_matches_pytorch_bf16_grads(self): - self.test_matches_pytorch( - rtol=5e-2, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.float32, - optim_dtype=torch.float32, - grad_sync_dtype=torch.bfloat16, - ) - - def test_matches_pytorch_bf16_param_remainders(self): - self.test_matches_pytorch( - rtol=5e-2, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.bfloat16, - optim_dtype=torch.float32, - param_sync_dtype=torch.bfloat16, - store_params=False, - store_param_remainders=True, - ) - - def test_matches_pytorch_multi_dtypes(self): - def init_optim(optim: DistributedFusedAdam): - params = list(optim.parameters()) - optim.init_params(params[0::3], grad_sync_dtype=torch.bfloat16) - optim.init_params(params[1::3], param_sync_dtype=torch.bfloat16) - - self.test_matches_pytorch( - rtol=5e-2, - atol=1e-5, - init_optim_func=init_optim, - ) - - def test_matches_pytorch_int64_param_sync(self): - self.test_matches_pytorch( - param_sync_dtype=torch.int64, - ) - - def test_matches_pytorch_int32_param_sync_contiguous_buffers(self): - self.test_matches_pytorch( - param_sync_dtype=torch.int32, - contiguous_buffers=True, - ) - - def test_matches_pytorch_uint8_param_sync(self): - self.test_matches_pytorch( - rtol=0.5, - atol=0.05, - model_dtype=torch.float16, - optim_dtype=torch.float16, - micro_batch_steps=1, - param_sync_dtype=torch.uint8, - ) - - def test_matches_pytorch_scaled_state(self): - self.test_matches_pytorch( - rtol=5e-2, - atol=1e-5, - micro_batch_steps=1, - model_dtype=torch.bfloat16, - optim_dtype=torch.float16, - param_sync_dtype=torch.int, - store_params=True, - with_scaled_states=True, - ) - - def test_matches_pytorch_nccl_ub(self): - self.test_matches_pytorch( - contiguous_buffers=True, - nccl_ub=True, - ) - - def test_raises_on_mismatch(self): - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - num_layers = 11 - layer_size = 7 - ref_model, ref_optim, dist_model, dist_optim = make_models( - num_layers, - layer_size, - ) - - # Only perform training step with distributed model - dist_optim.zero_grad() - x = torch.rand(3, layer_size) - 0.5 - x = x.to(dtype=torch.float32, device="cuda") - dy = torch.rand_like(x) - 0.5 - y = dist_model(x) - y.backward(dy) - dist_optim.step() - - # Check that parameters do not match - for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - dist_param, - ref_param, - ) - - def test_clip_grad_norm(self): - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = 
make_models(1, 1) - - # Training steps with pre-determined gradients - xs = [3, 1, 4, 1, 5, 9] - dys = [1, -1, 1, -1, 1, -1] - for x, dy in zip(xs, dys): - x = torch.tensor([[x]], dtype=torch.float32, device="cuda") - dy = torch.tensor([[dy]], dtype=torch.float32, device="cuda") - - # Reference implementation - ref_optim.zero_grad() - y_ref = ref_model(x.detach()) - y_ref.backward(dy.detach()) - ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5) - ref_optim.step() - - # Distributed implementation - dist_optim.zero_grad() - y_dist = dist_model(x.detach()) - y_dist.backward(dy.detach()) - dist_grad_norm = dist_optim.clip_grad_norm(3.5) - dist_optim.step() - - # Check that parameters match - torch.testing.assert_close(dist_grad_norm, ref_grad_norm) - for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()): - torch.testing.assert_close(dist_param, ref_param) - - def test_grad_scaler(self): - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1) - grad_scaler_args = dict( - init_scale=3.21, - growth_factor=1.23, - backoff_factor=0.876, - growth_interval=1, - ) - ref_scaler = torch.amp.GradScaler("cuda", **grad_scaler_args) - dist_scaler = torch.amp.GradScaler("cuda", **grad_scaler_args) - - # Training steps with pre-determined gradients - xs = [3, 1, 4, 1, 5, 9] - dys = [1, float("inf"), 1, 1, float("nan"), -1] - for x, dy in zip(xs, dys): - x = torch.tensor([[x]], dtype=torch.float32, device="cuda") - dy = torch.tensor([[dy]], dtype=torch.float32, device="cuda") - - # Reference implementation - ref_optim.zero_grad() - y_ref = ref_model(x.detach()) - ref_scaler.scale(y_ref).backward(dy.detach()) - ref_scaler.step(ref_optim) - ref_scaler.update() - - # Distributed implementation - dist_optim.zero_grad() - y_dist = dist_model(x.detach()) - dist_scaler.scale(y_dist).backward(dy.detach()) - dist_scaler.step(dist_optim) - dist_scaler.update() - - # Check that parameters match - for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()): - torch.testing.assert_close(dist_param, ref_param) - - def test_checkpoint( - self, - rtol: Optional[float] = None, - atol: Optional[float] = None, - num_layers: int = 2, - layer_size: int = 2, - num_steps: int = 3, - save_group_size: Optional[int] = None, - load_group_size: Optional[int] = None, - save_model_kwargs: Optional[dict] = None, - load_model_kwargs: Optional[dict] = None, - ): - """Test state_dict and load_state_dict functions - - Two models are constructed, possibly on different process - groups. One of the models is trained for a few steps, a - checkpoint is saved, and the checkpoint is loaded on the other - model. Both models are then trained for a few steps and - checked to make sure that they produce identical results. - - Arguments: - rtol (float): Relative tolerance for numerical checks (see - torch.allclose). - atol (float): Absolute tolerance for numerical checks (see - torch.allclose). - num_layers (int): Number of layers in test model. - layer_size (int): Number of features in model layers. - num_steps (int): Number of training steps to perform - before and after checkpointing. - save_group_size (int): Process group size for model that - saves the checkpoint. Uses the default process group - by default. - load_group_size (int): Process group size for model that - loads the checkpoint. Uses the default process group - by default. 
- save_model_kwargs (dict): keyword arguments passed to - make_models when constructing the model that saves the - checkpoint. - load_model_kwargs (dict): keyword arguments passed to - make_models when constructing the model that loads the - checkpoint. - - """ - - # Initialize process groups - world_size = torch.distributed.get_world_size() - if save_group_size is None: - save_group_size = world_size - save_group = None - else: - if save_group_size > world_size: - self.skipTest(f"Requires {save_group_size} ranks, found {world_size}") - save_ranks = list(range(save_group_size)) - save_group = torch.distributed.new_group(ranks=save_ranks) - if load_group_size is None: - load_group_size = world_size - load_group = None - else: - if load_group_size > world_size: - self.skipTest(f"Requires {load_group_size} ranks, found {world_size}") - load_ranks = list(range(load_group_size)) - load_group = torch.distributed.new_group(ranks=load_ranks) - - # Construct two models with same config and different params - torch.manual_seed(self.seed) - if self.rank < save_group_size: - if not save_model_kwargs: - save_model_kwargs = {} - _, _, model_save, optim_save = make_models( - num_layers, - layer_size, - lr=0.1, - process_group=save_group, - average_grad_sync=False, - overlap_communication=False, - **save_model_kwargs, - ) - optim_save.init_params(reversed(list(model_save.parameters()))) - torch.manual_seed(self.seed + 1) - if self.rank < load_group_size: - if not load_model_kwargs: - load_model_kwargs = {} - _, _, model_load, optim_load = make_models( - num_layers, - layer_size, - lr=1234.0, - process_group=load_group, - average_grad_sync=False, - overlap_communication=False, - **load_model_kwargs, - ) - optim_load.init_params(list(model_load.parameters())) - - batch_size = 2 * save_group_size * load_group_size - - def make_global_batch() -> torch.Tensor: - """Generate random tensor on root rank and broadcast""" - x = torch.empty(batch_size, layer_size, device="cuda") - if self.rank == 0: - torch.rand(x.size(), out=x) - x -= 0.5 - torch.distributed.broadcast(x, src=0) - return x - - def to_local_batch( - global_batch: torch.Tensor, - group: Optional[torch.distributed.ProcessGroup], - ) -> Optional[torch.Tensor]: - """Get local portion of tensor that is replicated across all ranks""" - group_size = torch.distributed.get_world_size(group) - if group_size < 0: - return None - local_batch_size = batch_size // group_size - batch_start = self.rank * local_batch_size - batch_end = (self.rank + 1) * local_batch_size - return global_batch[batch_start:batch_end, ...] - - def to_global_batch( - local_batch: torch.Tensor, - group: Optional[torch.distributed.ProcessGroup], - ) -> torch.Tensor: - """Gather distributed tensor and broadcast to all ranks""" - - # Allocate buffer - global_batch = torch.empty(batch_size, layer_size, device="cuda") - - # Gather data on root rank - group_size = torch.distributed.get_world_size(group) - if group_size > 0: - local_batches = None - if self.rank == 0: - local_batch_size = batch_size // group_size - local_batches = [ - global_batch[rank * local_batch_size : (rank + 1) * local_batch_size, ...] 
- for rank in range(group_size) - ] - torch.distributed.gather( - local_batch, - local_batches, - dst=0, - group=group, - ) - - # Broadcast data to all ranks - torch.distributed.broadcast(global_batch, src=0) - return global_batch - - # Train one of the models - torch.manual_seed(self.seed + 2) - for step in range(num_steps): - if self.rank < save_group_size: - optim_save.zero_grad() - x = make_global_batch() - dy = make_global_batch() - if self.rank < save_group_size: - x = to_local_batch(x, save_group) - dy = to_local_batch(dy, save_group) - y = model_save(x) - y.backward(dy) - optim_save.step() - - # Make sure models are different - if self.rank < min(save_group_size, load_group_size): - for param_save, param_load in zip(model_save.parameters(), model_load.parameters()): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - param_load, - param_save, - rtol=rtol, - atol=atol, - ) - - # Save state - state_bytes = None - if self.rank < save_group_size: - state_dict = { - "model": model_save.state_dict(), - "optim": optim_save.state_dict(), - } - byte_stream = io.BytesIO() - torch.save(state_dict, byte_stream) - state_bytes = byte_stream.getvalue() - - # Broadcast state from root rank and load - if self.rank < load_group_size: - if load_group_size != save_group_size: - if self.rank != 0: - state_bytes = None - state_bytes = [state_bytes] - torch.distributed.broadcast_object_list( - state_bytes, - src=0, - group=load_group, - ) - state_bytes = state_bytes[0] - state_dict = torch.load(io.BytesIO(state_bytes)) - model_load.load_state_dict(state_dict["model"]) - optim_load.load_state_dict(state_dict["optim"]) - - # Make sure models are identical - if self.rank < min(save_group_size, load_group_size): - for param_save, param_load in zip(model_save.parameters(), model_load.parameters()): - torch.testing.assert_close(param_load, param_save, rtol=rtol, atol=atol) - - # Train both models - torch.manual_seed(self.seed + 3) - for step in range(num_steps): - # Reset grads - if self.rank < save_group_size: - optim_save.zero_grad() - if self.rank < load_group_size: - optim_load.zero_grad() - - # Synthetic data - x = make_global_batch() - dy = make_global_batch() - - # Training step for model that saved checkpoint - y_save = None - dx_save = None - if self.rank < save_group_size: - x_save = to_local_batch(x, save_group) - x_save = x_save.detach().clone().requires_grad_(True) - dy_save = to_local_batch(dy, save_group) - y_save = model_save(x_save) - y_save.backward(dy_save) - dx_save = x_save.grad - y_save = to_global_batch(y_save, save_group) - dx_save = to_global_batch(dx_save, save_group) - - # Training step for model that loaded checkpoint - y_load = None - dx_load = None - if self.rank < load_group_size: - x_load = to_local_batch(x, load_group) - x_load = x_load.detach().clone().requires_grad_(True) - dy_load = to_local_batch(dy, load_group) - y_load = model_load(x_load) - y_load.backward(dy_load) - dx_load = x_load.grad - y_load = to_global_batch(y_load, load_group) - dx_load = to_global_batch(dx_load, load_group) - - # Check that data tensors match - torch.testing.assert_close(y_load, y_save, rtol=rtol, atol=atol) - torch.testing.assert_close(dx_load, dx_save, rtol=rtol, atol=atol) - - # Optimizer step - if self.rank < save_group_size: - optim_save.step() - if self.rank < load_group_size: - optim_load.step() - - # Check that parameters match - if self.rank < min(save_group_size, load_group_size): - for param_save, param_load in zip(model_save.parameters(), model_load.parameters()): 
- torch.testing.assert_close( - param_load, - param_save, - rtol=rtol, - atol=atol, - ) - - def test_checkpoint_save_1gpu(self): - """Test loading checkpoint with one GPU""" - self.test_checkpoint(save_group_size=1) - - def test_checkpoint_load_1gpu(self): - """Test saving checkpoint with one GPU""" - self.test_checkpoint(load_group_size=1) - - def test_checkpoint_bf16(self): - """Test checkpoint with BF16 model""" - self.test_checkpoint( - rtol=5e-2, - atol=1e-5, - save_model_kwargs=dict( - model_dtype=torch.bfloat16, - optim_dtype=torch.float32, - param_sync_dtype=torch.bfloat16, - store_params=False, - store_param_remainders=True, - ), - load_model_kwargs=dict( - model_dtype=torch.bfloat16, - optim_dtype=torch.float32, - param_sync_dtype=torch.bfloat16, - store_params=False, - store_param_remainders=True, - ), - ) - - def test_checkpoint_scaled_state(self): - """Test checkpoint with scaled FP16 state""" - self.test_checkpoint( - rtol=5e-2, - atol=1e-5, - save_model_kwargs=dict( - model_dtype=torch.bfloat16, - optim_dtype=torch.float16, - param_sync_dtype=torch.int, - store_params=True, - with_scaled_states=True, - ), - load_model_kwargs=dict( - model_dtype=torch.bfloat16, - optim_dtype=torch.float16, - param_sync_dtype=torch.int, - store_params=True, - with_scaled_states=True, - ), - ) - - def test_bucket_low_utilization_warning(self): - """Test warning when bucket utilization is low""" - layer_size = 2 * 1024 * 1024 - num_layers = 4 - fairish_bucket_cap_mb = 4 * num_layers * layer_size / (1024 * 1024) - - # Check that warning is raised when bucket utilization is low - with self.assertWarnsRegex(Warning, ".*Consider decreasing the bucket_cap_mb argument."): - self.test_matches_pytorch( - num_layers=num_layers, - layer_size=layer_size, - overlap_communication=False, - bucket_cap_mb=fairish_bucket_cap_mb * 2, - ) - - # Check that warning is not raised when bucket utilization is high - with warnings.catch_warnings(record=True) as warns: - self.test_matches_pytorch( - num_layers=num_layers, - layer_size=layer_size, - overlap_communication=False, - bucket_cap_mb=fairish_bucket_cap_mb, - ) - for w in warns: - self.assertNotRegex( - str(w.message), ".*Consider decreasing the bucket_cap_mb argument." 
- ) - - def test_cuda_graph(self): - """Test distributed adam with CUDA graph""" - if self.world_size <= 8: - self.skipTest(f"{self.world_size=} is expected to be >= 8") - self.test_matches_pytorch( - rtol=5e-3, - atol=1e-5, - num_steps=15, - micro_batch_steps=1, - model_dtype=torch.float16, - optim_dtype=torch.float16, - contiguous_buffers=True, - with_cuda_graph=True, - ) - - -if __name__ == "__main__": - # Assume script has been run with torchrun - common_utils.run_tests() diff --git a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py deleted file mode 100644 index 4501e99b1..000000000 --- a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py +++ /dev/null @@ -1,168 +0,0 @@ -import inspect - -import torch -from torch.cuda.amp import GradScaler -from torch.testing._internal import common_utils -from torch.distributed.distributed_c10d import _coalescing_manager - -from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - - -def flat_dist_call(param_list: list[torch.Tensor], op, args): - with _coalescing_manager(async_ops=True) as cm: - for p in param_list: - op(p, *args) - - cm.wait() - - -def get_init_weights_func(): - @torch.no_grad() - def init_weights(m): - if isinstance(m, torch.nn.Linear): - m.weight.fill_(1.0) - - return init_weights - - -class ModelFoo(torch.nn.Module): - def __init__(self): - super(ModelFoo, self).__init__() - self.linear = torch.nn.Linear(128, 128, bias=False) - self.loss = torch.nn.MSELoss() - - def forward(self, input_tensor, gt): - y = self.linear(input_tensor) - loss = self.loss(y, gt) - return loss - - -# A test for distributed fused Lamb optimizer: run several iterations and see if loss decreases -# There are two instances of the same test because based on `world_size` the optimizer decides what collectives operation to use. -# If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`. -# If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`. 
-class NcclDistributedFusedLAMB(NcclDistributedTestBase): - @property - def world_size(self) -> int: - return torch.cuda.device_count() - - @common_utils.parametrize("no_copy", [False, True]) - @common_utils.parametrize( - "opt_kwargs", - [ - dict( - overlap_reductions=True, - dwu_num_blocks=2, - dwu_num_chunks=2, - fused_norm=False, - fuse_scale=False, - clip_after_ar=True, - full_ar=False, - ), - dict( - overlap_reductions=False, - dwu_num_blocks=1, - dwu_num_chunks=1, - fused_norm=True, - fuse_scale=True, - clip_after_ar=False, - ), - ], - ) - def test_distributed_fused_lamb(self, no_copy, opt_kwargs): - if ( - no_copy - and "no_copy" not in inspect.getfullargspec(torch.distributed.reduce_scatter).args - ): - self.skipTest("does not support no_copy") - if no_copy and "no_copy" not in inspect.getfullargspec(torch.distributed.all_gather).args: - self.skipTest("does not support no_copy") - - assert torch.distributed.is_initialized() - gpu_count = torch.distributed.get_world_size() - - init_scale = 100 - lr = torch.tensor(0.1).cuda() - grad_scaler = GradScaler(init_scale=init_scale, growth_interval=1000) - - model = ModelFoo() - model = model.cuda().half() - model.apply(get_init_weights_func()) - - param_optimizer = list(model.named_parameters()) - no_decay = ["bias", "gamma", "beta", "LayerNorm"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], - "weight_decay": 0.01, - }, - { - "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - - if "full_ar" not in opt_kwargs: - opt_kwargs["full_ar"] = gpu_count == torch.cuda.device_count() - - # Aidyn-A: not sure what parameters are the best for testing purposes, - # setting up whatever I think appropriate. 
- optimizer = DistributedFusedLAMB( - optimizer_grouped_parameters, - lr=0.1, - betas=(0.9, 0.9), - eps=1e-6, - max_grad_norm=1.0, - dwu_group_size=gpu_count, - dwu_num_rs_pg=1, - dwu_num_ar_pg=1, - dwu_num_ag_pg=1, - use_nvlamb=False, - set_param_views_to_flat_buffer=False, - e5m2_allgather=False, - **opt_kwargs, - ) - optimizer.set_global_scale(init_scale) - - optimizer._reduce_scatter_no_copy = no_copy - optimizer._all_gather_no_copy = no_copy - - flat_dist_call( - [param.data for param in model.parameters()], - torch.distributed.broadcast, - (0,), - ) - - x = torch.randn(4096, 128, dtype=torch.float16).cuda() - y = torch.randn(4096, 128, dtype=torch.float16).cuda() - - losses = [] - for _ in range(10): - loss = model(x, y) - optimizer._lazy_init_stage1() - grad_scaler.scale(loss).backward() - optimizer._lazy_init_stage2() - optimizer._lr = lr - optimizer.complete_reductions() - optimizer.set_global_scale(grad_scaler._get_scale_async()) - grad_scaler.step(optimizer) - grad_scaler.update() - optimizer.zero_grad(set_to_none=True) - - losses.append(loss.item()) - - self.assertTrue(losses == sorted(losses, reverse=True)) - - -common_utils.instantiate_parametrized_tests(NcclDistributedFusedLAMB) - - -class NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB): - @property - def world_size(self) -> int: - return max(torch.cuda.device_count() - 1, 1) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/apex/contrib/test/peer_memory/__init__.py b/apex/contrib/test/peer_memory/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py b/apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py deleted file mode 100644 index b7fb5488a..000000000 --- a/apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py +++ /dev/null @@ -1,335 +0,0 @@ -import unittest - -import torch -from torch.testing._internal import common_utils - -SKIP_TEST = None -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -try: - from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d -except ImportError as e: - SKIP_TEST = e - -# How to run: -# python /path/to/test_peer_halo_exchange_module.py - - -# Output of this function is used as ground truth in module tests. 
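The reference function below rebuilds a 1D halo exchange with `all_gather` so it can serve as ground truth for the peer-memory module. For intuition only (an editorial sketch, not part of the patch and not how `PeerHaloExchanger1d` works), the same exchange can be expressed with batched point-to-point calls, assuming an NCHW tensor whose rows are split across ranks along H:

```python
import torch
import torch.distributed as dist


def p2p_halo_exchange_1d(y: torch.Tensor, half_halo: int) -> torch.Tensor:
    """y: [N, C, H + 2*half_halo, W]; interior rows are partitioned over ranks along H."""
    rank, world_size = dist.get_rank(), dist.get_world_size()
    top_out = y[:, :, half_halo:2 * half_halo].contiguous()    # rows sent to the rank above
    btm_out = y[:, :, -2 * half_halo:-half_halo].contiguous()  # rows sent to the rank below
    top_in = torch.zeros_like(top_out)
    btm_in = torch.zeros_like(btm_out)
    ops = []
    if rank > 0:
        ops += [dist.P2POp(dist.isend, top_out, rank - 1),
                dist.P2POp(dist.irecv, top_in, rank - 1)]
    if rank < world_size - 1:
        ops += [dist.P2POp(dist.isend, btm_out, rank + 1),
                dist.P2POp(dist.irecv, btm_in, rank + 1)]
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    y[:, :, :half_halo] = top_in    # stays zero on the first rank, matching the NCCL reference
    y[:, :, -half_halo:] = btm_in   # stays zero on the last rank, matching the NCCL reference
    return y
```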
-def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split): - if explicit_nhwc: - if H_split: - _, Hp, _, _ = list(y.shape) - H = Hp - 2 * half_halo - top_out_halo = y[:, half_halo : 2 * half_halo, :, :] - top_inp_halo = y[:, :half_halo, :, :] - btm_out_halo = y[:, H : H + half_halo, :, :] - btm_inp_halo = y[:, H + half_halo : H + 2 * half_halo, :, :] - else: - _, _, Wp, _ = list(y.shape) - W = Wp - 2 * half_halo - top_out_halo = y[:, :, half_halo : 2 * half_halo, :] - top_inp_halo = y[:, :, :half_halo, :] - btm_out_halo = y[:, :, W : W + half_halo, :] - btm_inp_halo = y[:, :, W + half_halo : W + 2 * half_halo, :] - else: - if H_split: - _, _, Hp, _ = list(y.shape) - H = Hp - 2 * half_halo - top_out_halo = y[:, :, half_halo : 2 * half_halo, :] - top_inp_halo = y[:, :, :half_halo, :] - btm_out_halo = y[:, :, H : H + half_halo, :] - btm_inp_halo = y[:, :, H + half_halo : H + 2 * half_halo, :] - else: - _, _, _, Wp = list(y.shape) - W = Wp - 2 * half_halo - top_out_halo = y[:, :, :, half_halo : 2 * half_halo] - top_inp_halo = y[:, :, :, :half_halo] - btm_out_halo = y[:, :, :, W : W + half_halo] - btm_inp_halo = y[:, :, :, W + half_halo : W + 2 * half_halo] - - mf = ( - torch.channels_last - if y.is_contiguous(memory_format=torch.channels_last) - else torch.contiguous_format - ) - top_out_halo = top_out_halo.contiguous() - btm_out_halo = btm_out_halo.contiguous() - - top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)] - torch.distributed.all_gather(top_inp_halos, top_out_halo) - btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)] - torch.distributed.all_gather(btm_inp_halos, btm_out_halo) - top_rank = (peer_rank + peer_group_size - 1) % peer_group_size - btm_rank = (peer_rank + 1) % peer_group_size - if peer_rank == 0: - top_inp_halo.zero_() - else: - top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf)) - if peer_rank == peer_group_size - 1: - btm_inp_halo.zero_() - else: - btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf)) - - -def single_test( - peer_rank, - peer_group_size, - halo_ex, - C, - H, - W, - half_halo, - dtype, - memory_format, - H_split, - num_steps, - numSM=1, -): - if memory_format == 1: - # 1 -> explicit nhwc - explicit_nhwc = True - if H_split: - y = torch.randn([1, H + 2 * half_halo, W, C], dtype=dtype, device="cuda") - ym = y[:, half_halo : H + half_halo, :, :] - else: - y = torch.randn([1, H, W + 2 * half_halo, C], dtype=dtype, device="cuda") - ym = y[:, :, half_halo : W + half_halo, :] - else: - # 2 -> native nhwc - # 3 -> nchw - explicit_nhwc = False - if H_split: - y = torch.randn([1, C, H + 2 * half_halo, W], dtype=dtype, device="cuda") - if memory_format == 2: - y = y.to(memory_format=torch.channels_last) - ym = y[:, :, half_halo : H + half_halo, :] - else: - y = torch.randn([1, C, H, W + 2 * half_halo], dtype=dtype, device="cuda") - if memory_format == 2: - y = y.to(memory_format=torch.channels_last) - ym = y[:, :, :, half_halo : W + half_halo] - y3 = y.clone() - list_y = [] - for step in range(num_steps): - halo_ex(y, H_split, explicit_nhwc, numSM) - list_y.append(y.clone()) - y.copy_(y3) - halo_ex.peer_pool.reset() - torch.distributed.barrier() - y2 = y3.clone() - list_y2 = [] - for step in range(num_steps): - nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split) - list_y2.append(y2.clone()) - y2.copy_(y3) - if memory_format == 1: - memory_format_str = "explicit_nhwc" - elif memory_format == 2: - memory_format_str = "native nhwc" - 
elif memory_format == 3: - memory_format_str = "nchw" - else: - memory_format_str = "???" - torch.testing.assert_close(list_y, list_y2, msg=memory_format_str) - # is_equal = [torch.all(torch.eq(yy, yy2)) for yy, yy2 in zip(list_y, list_y2)] - # is_equal = torch.tensor(is_equal, dtype=torch.bool) - # is_equal = torch.all(is_equal) - # if peer_rank == 0: - # if is_equal: - # print( - # "SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" - # % ( - # C, - # H, - # W, - # half_halo, - # str(dtype), - # memory_format_str, - # "H-split" if H_split else "W-split", - # ) - # ) - # else: - # print( - # "FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" - # % ( - # C, - # H, - # W, - # half_halo, - # str(dtype), - # memory_format_str, - # "H-split" if H_split else "W-split", - # ) - # ) - # - # peer memory flag sync relies on there being at least one barrier per step - # torch.distributed.barrier() - - -def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps): - Hr = 8 * world_size - Hp = ((H + Hr - 1) // Hr) * 8 - - for i in range(4): - div = int(pow(2, i)) - single_test( - rank, - world_size, - halo_ex, - C * div, - Hp // div, - W // div, - half_halo, - torch.float16, - 1, - True, - num_steps, - ) - single_test( - rank, - world_size, - halo_ex, - C * div, - Hp // div, - W // div, - half_halo, - torch.float16, - 2, - True, - num_steps, - ) - single_test( - rank, - world_size, - halo_ex, - C * div, - Hp // div, - W // div, - half_halo, - torch.float16, - 3, - True, - num_steps, - ) - - -def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps): - Wr = 8 * world_size - Wp = ((W + Wr - 1) // Wr) * 8 - - for i in range(4): - div = int(pow(2, i)) - single_test( - rank, - world_size, - halo_ex, - C * div, - H // div, - Wp // div, - half_halo, - torch.float16, - 1, - False, - num_steps, - ) - single_test( - rank, - world_size, - halo_ex, - C * div, - H // div, - Wp // div, - half_halo, - torch.float16, - 2, - False, - num_steps, - ) - single_test( - rank, - world_size, - halo_ex, - C * div, - H // div, - Wp // div, - half_halo, - torch.float16, - 3, - False, - num_steps, - ) - - -def main(): - # for this trivial example peer_rank == rank and peer_group_size == world_size - - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - torch.cuda.set_device(rank) - peer_ranks = [i for i in range(world_size)] - pool = PeerMemoryPool(0, 2 * 1024 * 1024, peer_ranks) - - num_steps = 100 - - half_halo = 1 - halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo) - - H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps) - W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps) - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TestPeerMemory(NcclDistributedTestBase): - HALF_HALO = 1 - NUM_STEPS = 100 - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 2) - - # TODO(crcrpar): Check if `world_size` being multiple of 2 is must. 
- def _check_world_size_and_may_skip(self) -> None: - if not (self.world_size >= 2 and self.world_size % 2 == 0): - self.skipTest(f"world_size is expected to be a multiple of 2 but, {self.world_size}") - - def get_halo_excnahger_1d(self): - peer_ranks = [i for i in range(self.world_size)] - pool = PeerMemoryPool(64 * 1024, 2 * 1024 * 1024, peer_ranks) - halo_exchanger_1d = PeerHaloExchanger1d( - peer_ranks, self.rank, pool, TestPeerMemory.HALF_HALO - ) - return halo_exchanger_1d - - def test_height_split(self): - self._check_world_size_and_may_skip() - H_split_tests( - 1, - 64, - 336, - 200, - TestPeerMemory.HALF_HALO, - self.rank, - self.world_size, - self.get_halo_excnahger_1d(), - TestPeerMemory.NUM_STEPS, - ) - - def test_width_split(self): - self._check_world_size_and_may_skip() - W_split_tests( - 1, - 64, - 200, - 336, - TestPeerMemory.HALF_HALO, - self.rank, - self.world_size, - self.get_halo_excnahger_1d(), - TestPeerMemory.NUM_STEPS, - ) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/apex/transformer/README.md b/apex/transformer/README.md deleted file mode 100644 index 7383f65eb..000000000 --- a/apex/transformer/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# apex.transformer - -`apex.transformer` is a module which enables efficient large Transformer models at scale. - -`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module. -The former is based on `megatron.mpu` and the latter is on `megatron.schedules` and `megatron.p2p_communication`. - -## Tensor Model Parallel (TP) - -APEX's tensor model parallel utilities provides some `torch.nn.Module`'s, custom fused kernels, and PRNG state handling. -See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of -PRNG state handling. - -## Pipeline Model Parallel (PP) -APEX's pipeline model parallel functions require models to have `.set_input_tensor` because -the input tensor for `.forward` method can be `None`. - -The following is a really casual sketch of training script with apex pp. - -```python -import torch -import torch.nn as nn -import torch.nn.functional as F - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import get_forward_backward_func - - -class Model(nn.Module): - - ... - - def __init__(self, *args, **kwargs): - super().__init__() - pre_process = kwargs.pop("pre_process") - post_process = kwargs.pop("post_process") - - def set_input_tensor(self, tensor): - self.input_tensor = tensor - - def forward(self, x, ...): - if parallel_state.is_pipeline_first_stage(): - input = x - else: - input = self.input_tensor - ... - - -def model_provider_func(*args, **kwargs): - return Model(*args, **kwargs) - - -def loss_func(pred, label): - loss = ... 
- averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'nice_loss': averaged_loss} - - -def forward_step_func(batch, model): - input, label = process_batch(batch) - out = model(input) - return out, partial(loss_func, label) - - -forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size) - - -parallel_state.initialize_model_parallel( - tensor_model_parallel_size, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, -) -# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)` -model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs) -optimizer = ... -data_loader = ... -for epoch in range(num_epochs): - for batch in data_loader: - forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape) - optimizer.step() -``` diff --git a/apex/transformer/__init__.py b/apex/transformer/__init__.py deleted file mode 100644 index ff9c7b95d..000000000 --- a/apex/transformer/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from apex.transformer import amp -from apex.transformer import functional -from apex.transformer import parallel_state -from apex.transformer import pipeline_parallel -from apex.transformer import tensor_parallel -from apex.transformer import utils -from apex.transformer.enums import LayerType -from apex.transformer.enums import AttnType -from apex.transformer.enums import AttnMaskType - - -__all__ = [ - "amp", - "functional", - "parallel_state", - "pipeline_parallel", - "tensor_parallel", - "utils", - # enums.py - "LayerType", - "AttnType", - "AttnMaskType", -] diff --git a/apex/transformer/_data/__init__.py b/apex/transformer/_data/__init__.py deleted file mode 100644 index 2831dfb11..000000000 --- a/apex/transformer/_data/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler -from apex.transformer._data._batchsampler import MegatronPretrainingSampler - - -__all__ = [ - "MegatronPretrainingRandomSampler", - "MegatronPretrainingSampler", -] diff --git a/apex/transformer/_data/_batchsampler.py b/apex/transformer/_data/_batchsampler.py deleted file mode 100644 index 5f321c463..000000000 --- a/apex/transformer/_data/_batchsampler.py +++ /dev/null @@ -1,196 +0,0 @@ -"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support. - -Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py. -""" # NOQA - -import abc - -import torch - - -__all__ = [ - "MegatronPretrainingSampler", - "MegatronPretrainingRandomSampler", -] - - -class _Base: - """Base class for Megatron style BatchSampler.""" - - @abc.abstractmethod - def __len__(self) -> int: ... - - @abc.abstractmethod - def __iter__(self): ... - - @property - @abc.abstractmethod - def local_minibatch_size(self) -> int: ... - - @local_minibatch_size.setter - @abc.abstractclassmethod - def local_minibatch_size(self) -> None: ... - - -class MegatronPretrainingSampler(_Base): - def __init__( - self, - total_samples: int, - consumed_samples: int, - local_minibatch_size: int, - data_parallel_rank: int, - data_parallel_size: int, - drop_last: bool = True, - ): - # Sanity checks. 
- if total_samples <= 0: - raise RuntimeError("no sample to consume: {}".format(self.total_samples)) - if consumed_samples >= total_samples: - raise RuntimeError( - "no samples left to consume: {}, {}".format( - self.consumed_samples, self.total_samples - ) - ) - if local_minibatch_size <= 0: - raise RuntimeError( - f"local minibatch size must be greater than 0: {local_minibatch_size}" - ) - if data_parallel_size <= 0: - raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}") - if data_parallel_rank >= data_parallel_size: - raise RuntimeError( - "data_parallel_rank should be smaller than data size: {}, {}".format( - self.data_parallel_rank, data_parallel_size - ) - ) - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self._local_minibatch_size = local_minibatch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.local_minibatch_times_data_parallel_size = ( - self._local_minibatch_size * data_parallel_size - ) - self.drop_last = drop_last - - def __len__(self): - return self.total_samples - - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.local_minibatch_size - end_idx = start_idx + self.local_minibatch_size - return start_idx, end_idx - - @property - def local_minibatch_size(self) -> int: - return self._local_minibatch_size - - @local_minibatch_size.setter - def local_minibatch_size(self, new_local_minibatch_size) -> None: - self._local_minibatch_size = new_local_minibatch_size - self.local_minibatch_times_data_parallel_size = ( - self._local_minibatch_size * self.data_parallel_size - ) - - def __iter__(self): - batch = [] - # Last batch will be dropped if drop_last is not set False - for idx in range(self.consumed_samples, self.total_samples): - batch.append(idx) - if len(batch) == self.local_minibatch_size: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - batch = [] - - # Check the last partial batch and see drop_last is set - if len(batch) > 0 and not self.drop_last: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - - -class MegatronPretrainingRandomSampler(_Base): - """Megatron style Random Batch Sampler. - - Major difference is that `__iter__` yields a local minibatch, not a microbatch. - A local minibatch consists of `global_batch_size / data_parallel_size` - - Args: - total_samples: The number of data samples, i.e. ``len(dataset)``. - consumed_samples: The number of samples already consumed in pretraining. - local_minibatch_size: The number of data in each batch returned from `__iter__`. Basically - `local_minibatch_size = global_batch_size / data_parallel_size`. - data_parallel_rank: - data_parallel_size: - """ - - def __init__( - self, - total_samples: int, - consumed_samples: int, - local_minibatch_size: int, - data_parallel_rank: int, - data_parallel_size: int, - ) -> None: - if total_samples <= 0: - raise ValueError(f"no sample to consume: total_samples of {total_samples}") - if local_minibatch_size <= 0: - raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}") - if data_parallel_size <= 0: - raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}") - if data_parallel_rank >= data_parallel_size: - raise ValueError( - f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}" - ) - # Keep a copy of input params for later use. 
- self.total_samples = total_samples - self.consumed_samples = consumed_samples - self._local_minibatch_size = local_minibatch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.local_minibatch_times_data_parallel_size = ( - self._local_minibatch_size * self.data_parallel_size - ) - self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size - - def __len__(self) -> int: - return self.total_samples - - @property - def local_minibatch_size(self) -> int: - return self._local_minibatch_size - - @local_minibatch_size.setter - def local_minibatch_size(self, new_local_minibatch_size) -> None: - self._local_minibatch_size = new_local_minibatch_size - self.local_minibatch_times_data_parallel_size = ( - self._local_minibatch_size * self.data_parallel_size - ) - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - # note(mkozuki): might be better to uncomment - # assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0 - - # data sharding and random sampling - bucket_size = ( - self.total_samples // self.local_minibatch_times_data_parallel_size - ) * self.local_minibatch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.local_minibatch_size: - self.consumed_samples += self.local_minibatch_times_data_parallel_size - yield batch - batch = [] diff --git a/apex/transformer/_ucc_util.py b/apex/transformer/_ucc_util.py deleted file mode 100644 index 1302d2d74..000000000 --- a/apex/transformer/_ucc_util.py +++ /dev/null @@ -1,10 +0,0 @@ -from torch import distributed as dist - -HAS_UCC = hasattr(dist, "is_ucc_available") and dist.is_ucc_available() -if not HAS_UCC: - try: - import torch_ucc - - HAS_UCC = True - except ImportError: - HAS_UCC = False diff --git a/apex/transformer/amp/__init__.py b/apex/transformer/amp/__init__.py deleted file mode 100644 index dbef36a3c..000000000 --- a/apex/transformer/amp/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from apex.transformer.amp.grad_scaler import GradScaler - - -__all__ = [ - "GradScaler", -] diff --git a/apex/transformer/amp/grad_scaler.py b/apex/transformer/amp/grad_scaler.py deleted file mode 100644 index 8ab94fede..000000000 --- a/apex/transformer/amp/grad_scaler.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from collections import defaultdict - -import torch - -from apex.transformer import parallel_state - - -class GradScaler(torch.cuda.amp.GradScaler): - """ - Gradient scaler for model-parallel inf check. The inf in gradients are checked across tensor-parallel - ranks in (1) executing optimizer step and (2) gradient scaler update. - """ - - def __init__( - self, - init_scale=2.0**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True, - ): - super().__init__( - init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - enabled=enabled, - ) - - def _unscale_grads_(self, optimizer, *args): - if getattr(optimizer, "_custom_amp_unscale_grads", False): - return optimizer.unscale_grads(*args) - else: - return super()._unscale_grads_(optimizer, *args) - - def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): - retval = None - found_inf = torch.cuda.FloatTensor( - [sum(v.item() for v in optimizer_state["found_inf_per_device"].values())] - ) - - # Update across all model parallel instances. - torch.distributed.all_reduce( - found_inf, - op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_model_parallel_group(), - ) - - if found_inf.item() == 0: - retval = optimizer.step(*args, **kwargs) - return retval - - def update(self, new_scale=None): - """ - Updates the scale factor. - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [ - found_inf.to(device=_scale.device, non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - - # Update across all model parallel instances. 
- torch.distributed.all_reduce( - found_inf_combined, - op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_model_parallel_group(), - ) - - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf = found_infs[i] - # Update across all model parallel instances. - torch.distributed.all_reduce( - found_inf, - op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_model_parallel_group(), - ) - found_inf_combined += found_inf - - torch._amp_update_scale_( - _scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval, - ) - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict( - torch.cuda.amp.grad_scaler._refresh_per_optimizer_state - ) diff --git a/apex/transformer/enums.py b/apex/transformer/enums.py deleted file mode 100644 index 78da6c995..000000000 --- a/apex/transformer/enums.py +++ /dev/null @@ -1,35 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import enum - - -class LayerType(enum.Enum): - encoder = 1 - decoder = 2 - - -class AttnType(enum.Enum): - self_attn = 1 - cross_attn = 2 - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - - -class ModelType(enum.Enum): - encoder_or_decoder = 1 - encoder_and_decoder = 2 diff --git a/apex/transformer/functional/__init__.py b/apex/transformer/functional/__init__.py deleted file mode 100644 index f307df79f..000000000 --- a/apex/transformer/functional/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from apex.transformer.functional.fused_rope import ( - fused_apply_rotary_pos_emb, - fused_apply_rotary_pos_emb_cached, - fused_apply_rotary_pos_emb_thd, - fused_apply_rotary_pos_emb_2d, -) -from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - -__all__ = [ - "FusedScaleMaskSoftmax", - "fused_apply_rotary_pos_emb", - "fused_apply_rotary_pos_emb_cached", - "fused_apply_rotary_pos_emb_thd", - "fused_apply_rotary_pos_emb_2d", -] diff --git a/apex/transformer/functional/fused_rope.py b/apex/transformer/functional/fused_rope.py deleted file mode 100644 index bf3ae6f88..000000000 --- a/apex/transformer/functional/fused_rope.py +++ /dev/null @@ -1,281 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Tuple, Union -import torch - - -class FusedRoPEFunc(torch.autograd.Function): - """ - Fused RoPE function - - This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be - of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive - `.contiguous()` calls, thus it may not achieve the best memory access pattern. - """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - freqs: torch.Tensor, - transpose_output_memory: bool = False, - ) -> torch.Tensor: - import fused_rotary_positional_embedding - - output = fused_rotary_positional_embedding.forward(t, freqs, transpose_output_memory) - ctx.save_for_backward(freqs) - ctx.transpose_output_memory = transpose_output_memory - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - import fused_rotary_positional_embedding - - (freqs,) = ctx.saved_tensors - grad_input = fused_rotary_positional_embedding.backward( - grad_output, freqs, ctx.transpose_output_memory - ) - - return grad_input, None, None - - -def fused_apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, - transpose_output_memory: bool = False, -) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T in `sbhd` format, where - s: sequence length - b: batch size - h: head num - d: dim of each head - - Args: - t (Tensor): Input tensor T is of shape [s, b, h, d] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [s, 1, 1, d] and - `float` dtype - transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' - dimension of the output's underlying memory format. This is very helpful when you want to - get a contiguous tensor after calling `output.transpose(0, 1)`. - - Returns: - Tensor: The input tensor after applying RoPE - """ - return FusedRoPEFunc.apply(t, freqs, transpose_output_memory) - - -class FusedRoPECachedFunc(torch.autograd.Function): - """ - Fused RoPE function - - This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be - of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive - `.contiguous()` calls, thus it may not achieve the best memory access pattern. 
- """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - cos_: torch.Tensor, - sin_: torch.Tensor, - transpose_output_memory: bool = False, - ) -> torch.Tensor: - import fused_rotary_positional_embedding - - output = fused_rotary_positional_embedding.forward_cached( - t, cos_, sin_, transpose_output_memory - ) - ctx.save_for_backward(cos_, sin_) - ctx.transpose_output_memory = transpose_output_memory - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - import fused_rotary_positional_embedding - - cos_, sin_ = ctx.saved_tensors - grad_input = fused_rotary_positional_embedding.backward_cached( - grad_output, cos_, sin_, ctx.transpose_output_memory - ) - - return grad_input, None, None, None - - -def fused_apply_rotary_pos_emb_cached( - t: torch.Tensor, - cos_: torch.Tensor, - sin_: torch.Tensor, - transpose_output_memory: bool = False, -) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T in `sbhd` format, where - s: sequence length - b: batch size - h: head num - d: dim of each head - - Args: - t (Tensor): Input tensor T is of shape [s, b, h, d] - cos_ (Tensor): Cached cosine of the rotary positional embedding tensor is of - shape [s, 1, 1, d] and dtype either `float` or the same as `t`. - sin_ (Tensor): Cached sine of the rotary positional embedding tensor is of - shape [s, 1, 1, d] and dtype either `float` or the same as `t`. - transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' - dimension of the output's underlying memory format. This is very helpful when you want to - get a contiguous tensor after calling `output.transpose(0, 1)`. - - Returns: - Tensor: The input tensor after applying RoPE - """ - return FusedRoPECachedFunc.apply(t, cos_, sin_, transpose_output_memory) - - -class FusedRoPETHDFunc(torch.autograd.Function): - """ - Fused RoPE function for `thd` format. - - This implementation accepts arbitrary memory layouts to avoid the expensive - `.contiguous()` calls, thus it may not achieve the best memory access pattern. - """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - cu_seqlens: torch.Tensor, - freqs: torch.Tensor, - ) -> torch.Tensor: - import fused_rotary_positional_embedding - - output = fused_rotary_positional_embedding.forward_thd(t, cu_seqlens, freqs) - ctx.save_for_backward(cu_seqlens, freqs) - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - import fused_rotary_positional_embedding - - cu_seqlens, freqs = ctx.saved_tensors - grad_input = fused_rotary_positional_embedding.backward_thd(grad_output, cu_seqlens, freqs) - - return grad_input, None, None - - -def fused_apply_rotary_pos_emb_thd( - t: torch.Tensor, - cu_seqlens: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T in `thd` format, where - t: cumulative sum of sequence lengths - h: head num - d: dim of each head - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. 
- freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] and - `float` dtype - - Returns: - Tensor: The input tensor after applying RoPE - """ - return FusedRoPETHDFunc.apply(t, cu_seqlens, freqs) - - -class FusedRoPE2DFunc(torch.autograd.Function): - """ - Fused 2D RoPE function - """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - img_h: int, - img_w: int, - cos_h: torch.Tensor, - sin_h: torch.Tensor, - cos_w: torch.Tensor, - sin_w: torch.Tensor, - ) -> torch.Tensor: - import fused_rotary_positional_embedding - - t = t.view(t.shape[0], img_h, img_w, t.shape[2], t.shape[3]) - output = fused_rotary_positional_embedding.forward_2d(t, cos_h, sin_h, cos_w, sin_w) - ctx.save_for_backward(cos_h, sin_h, cos_w, sin_w) - ctx.img_h = img_h - ctx.img_w = img_w - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - import fused_rotary_positional_embedding - - grad_output = grad_output.view( - grad_output.shape[0], - ctx.img_h, - ctx.img_w, - grad_output.shape[2], - grad_output.shape[3], - ) - cos_h, sin_h, cos_w, sin_w = ctx.saved_tensors - grad_input = fused_rotary_positional_embedding.backward_2d( - grad_output, cos_h, sin_h, cos_w, sin_w - ) - - return grad_input, None, None, None, None, None, None - - -def fused_apply_rotary_pos_emb_2d( - t: torch.Tensor, - img_h: int, - img_w: int, - cos_h: torch.Tensor, - sin_h: torch.Tensor, - cos_w: torch.Tensor, - sin_w: torch.Tensor, -) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T in `bshd` format, where - b: batch size - s: sequence length - h: head num - d: dim of each head - - Args: - t (Tensor): Input tensor T is of shape [b, s, h, d] - img_h (int): s == img_h * img_w - img_w (int): s == img_h * img_w - cos_h (Tensor): shape [1, H, 1, d // 2] and dtype either `float` or - the same as `t`. H >= img_h. - sin_h (Tensor): shape [1, H, 1, d // 2] and dtype either `float` or - the same as `t`. H >= img_h. - cos_w (Tensor): shape [1, W, 1, d // 2] and dtype either `float` or - the same as `t`. W >= img_w. - sin_w (Tensor): shape [1, W, 1, d // 2] and dtype either `float` or - the same as `t`. W >= img_w. - - Returns: - Tensor: The input tensor after applying RoPE - """ - assert t.size(1) == img_h * img_w, "The sequence length should be equal to img_h * img_w" - assert cos_h.size() == sin_h.size(), "The shape of cos_h and sin_h should be the same" - assert cos_w.size() == sin_w.size(), "The shape of cos_w and sin_w should be the same" - return FusedRoPE2DFunc.apply(t, img_h, img_w, cos_h, sin_h, cos_w, sin_w) diff --git a/apex/transformer/functional/fused_softmax.py b/apex/transformer/functional/fused_softmax.py deleted file mode 100644 index b3b581be7..000000000 --- a/apex/transformer/functional/fused_softmax.py +++ /dev/null @@ -1,306 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
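As a plain-PyTorch point of reference for the fused RoPE functions above (a hedged sketch, not the CUDA kernel and not part of the patch), applying a rotary embedding to an `sbhd` tensor with a `freqs` tensor of shape [s, 1, 1, d] can be written with the usual rotate-half convention; the fused ops are expected to be numerically equivalent to something of this form:

```python
import torch


def apply_rotary_pos_emb_ref(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """t: [s, b, h, d]; freqs: [s, 1, 1, d2] with d2 <= d, broadcast over b and h."""
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    x1, x2 = t_rot.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # "rotate half" of the rotary dimensions
    t_rot = t_rot * freqs.cos() + rotated * freqs.sin()
    return torch.cat((t_rot, t_pass), dim=-1)
```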
-import torch - -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.amp.autocast("cuda", enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. 
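Since the masked-softmax variants below all dispatch to custom CUDA extensions, a short eager-mode equivalent may help pin down the computation (a hedged reference only; the `-10000.0` fill mirrors the mask function typically paired with these kernels and is an assumption here, not part of the patch):

```python
import torch


def scaled_masked_softmax_ref(inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    """inputs: [b, np, sq, sk]; mask: broadcastable boolean tensor with True = masked out."""
    scores = inputs * scale                        # 1. scale the tensor
    scores = scores.masked_fill(mask, -10000.0)    # 2. apply the mask
    return torch.softmax(scores, dim=-1)           # 3. perform softmax
```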
-class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - if mask is not None: - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.amp.autocast("cuda", enabled=False): - return ScaledMaskedSoftmax.apply(*args) - else: - args = _cast_if_autocast_enabled(inputs, scale) - with torch.amp.autocast("cuda", enabled=False): - return ScaledSoftmax.apply(*args) - - -class GenericScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - import generic_scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - - input_grads = generic_scaled_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None - - -def generic_scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.amp.autocast("cuda", enabled=False): - return GenericScaledMaskedSoftmax.apply(*args) - - -class ScaledSoftmax(torch.autograd.Function): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. 
- """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or self.attn_mask_type == AttnMaskType.padding - ) - and 16 < sk <= 16384 # sk must be 16 ~ 16384 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 16384: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - import scaled_masked_softmax_cuda - - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) - - -class GenericFusedScaleMaskSoftmax(FusedScaleMaskSoftmax): - """ - Generic version of FusedSacleMaskSoftmax. - It removes the seq-len limitations and has slight performance degragation compared with FusedScaleMaskSoftmax - - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. 
- scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__( - input_in_fp16, - input_in_bf16, - AttnMaskType.padding, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ) - self.scaled_masked_softmax_fusion = generic_scaled_masked_softmax - - def is_kernel_available(self, mask, b, np, sq, sk): - if self.scaled_masked_softmax_fusion and 0 < sk: # user want to fuse # sk must be 1 ~ - return True - return False diff --git a/apex/transformer/layers/__init__.py b/apex/transformer/layers/__init__.py deleted file mode 100644 index bc247d3c1..000000000 --- a/apex/transformer/layers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from apex.transformer.layers.layer_norm import FastLayerNorm -from apex.transformer.layers.layer_norm import FusedLayerNorm -from apex.transformer.layers.layer_norm import MixedFusedLayerNorm - - -__all__ = [ - "FastLayerNorm", - "FusedLayerNorm", - "MixedFusedLayerNorm", -] diff --git a/apex/transformer/layers/layer_norm.py b/apex/transformer/layers/layer_norm.py deleted file mode 100644 index 2cd0aa97a..000000000 --- a/apex/transformer/layers/layer_norm.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# NOTE(mkozuki): This file defines two LayerNorm that are compatible with Megatron-LM. -# while avoiding introducing the breaking change of `"sequence_parallel_enabled"` attribute into apex.normalization.FusedLayerNorm -# and apex.contrib.layer_norm.FastLayerNorm. -import warnings - -import torch - -from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm -from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm - -try: - from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm -except ImportError: - HAS_FAST_LAYER_NORM = False -else: - HAS_FAST_LAYER_NORM = True - - -__all__ = [ - "FusedLayerNorm", - "FastLayerNorm", - "MixedFusedLayerNorm", -] - - -def _set_sequence_parallel_enabled( - param: torch.Tensor, - sequence_parallel_enabled: bool, -) -> None: - setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled) - - -class FusedLayerNorm(OrigFusedLayerNorm): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - elementwise_affine: bool = True, - *, - sequence_parallel_enabled: bool = False, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - ) - self.sequence_parallel_enabled = sequence_parallel_enabled - if self.elementwise_affine: - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) - - -# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`. 
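The layer-norm wrappers above exist mainly to tag their parameters with a `sequence_parallel_enabled` attribute. As a hedged illustration of how such a tag is typically consumed (the helper below is hypothetical, not apex or Megatron API), a training loop can collect the tagged parameters and all-reduce their gradients across the tensor-parallel group:

```python
import torch
import torch.distributed as dist


def allreduce_sequence_parallel_grads(model: torch.nn.Module, tensor_parallel_group) -> None:
    """All-reduce grads of params tagged with `sequence_parallel_enabled` across TP ranks."""
    for param in model.parameters():
        if getattr(param, "sequence_parallel_enabled", False) and param.grad is not None:
            dist.all_reduce(param.grad, group=tensor_parallel_group)
```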
-class MixedFusedLayerNorm(OrigMixedFusedLayerNorm): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - **kwargs, - ) -> None: - self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False) - super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs) - if self.sequence_parallel_enabled: - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) - - -if HAS_FAST_LAYER_NORM: - - class FastLayerNorm(OrigFastLayerNorm): - def __init__( - self, - hidden_size, - eps: float = 1e-5, - *, - sequence_parallel_enabled: bool = False, - ): - super().__init__(hidden_size=hidden_size, eps=eps) - self.sequence_parallel_enabled = sequence_parallel_enabled - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) -else: - - class FastLayerNorm(FusedLayerNorm): - def __init__( - self, - hidden_size, - eps: float = 1e-5, - *, - sequence_parallel_enabled: bool = False, - ): - warnings.warn( - "`apex.contrib.layer_norm.FastLayerNorm` isn't available thus falling back to `apex.normalization.FusedLayerNorm`" - ) - super().__init__( - normalized_shape=hidden_size, - eps=eps, - elementwise_affine=True, - sequence_parallel_enabled=sequence_parallel_enabled, - ) diff --git a/apex/transformer/log_util.py b/apex/transformer/log_util.py deleted file mode 100644 index 7eaafee22..000000000 --- a/apex/transformer/log_util.py +++ /dev/null @@ -1,18 +0,0 @@ -import logging -import os - - -def get_transformer_logger(name: str) -> logging.Logger: - name_wo_ext = os.path.splitext(name)[0] - return logging.getLogger(name_wo_ext) - - -def set_logging_level(verbosity) -> None: - """Change logging severity. - - Args: - verbosity - """ - from apex import _library_root_logger - - _library_root_logger.setLevel(verbosity) diff --git a/apex/transformer/microbatches.py b/apex/transformer/microbatches.py deleted file mode 100644 index d96828d1e..000000000 --- a/apex/transformer/microbatches.py +++ /dev/null @@ -1,191 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Megatron number of micro-batches calculators.""" - -from abc import ABC -from abc import abstractmethod -from typing import Optional, List - -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -def build_num_microbatches_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -): - # Constant num micro-batches. 
- if rampup_batch_size is None: - num_microbatches_calculator = ConstantNumMicroBatches( - global_batch_size, micro_batch_size, data_parallel_size - ) - if rank == 0: - _logger.info( - "setting number of micro-batches to constant {}".format( - num_microbatches_calculator.get() - ) - ) - - else: - assert len(rampup_batch_size) == 3, ( - "expected the following " - "format: --rampup-batch-size " - " " - ) - start_batch_size = int(rampup_batch_size[0]) - batch_size_increment = int(rampup_batch_size[1]) - ramup_samples = int(rampup_batch_size[2]) - if rank == 0: - _logger.info( - "will use batch size rampup starting from global batch " - "size {} to global batch size {} with batch size increments " - "{} over {} samples.".format( - start_batch_size, - global_batch_size, - batch_size_increment, - ramup_samples, - ) - ) - num_microbatches_calculator = RampupBatchsizeNumMicroBatches( - start_batch_size, - batch_size_increment, - ramup_samples, - global_batch_size, - micro_batch_size, - data_parallel_size, - ) - - return num_microbatches_calculator - - -class NumMicroBatchesCalculator(ABC): - def __init__(self): - self.num_micro_batches = None - self.current_global_batch_size = None - - def get(self): - return self.num_micro_batches - - def get_current_global_batch_size(self): - return self.current_global_batch_size - - @abstractmethod - def update(self, consumed_samples, consistency_check): - pass - - -class ConstantNumMicroBatches(NumMicroBatchesCalculator): - def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): - micro_batch_times_data_parallel = micro_batch_size * data_parallel_size - assert global_batch_size % micro_batch_times_data_parallel == 0, ( - "global batch size ({}) is not divisible by micro batch size ({})" - " times data parallel size ({})".format( - global_batch_size, micro_batch_size, data_parallel_size - ) - ) - self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel - assert self.num_micro_batches >= 1 - self.current_global_batch_size = global_batch_size - - self.micro_batch_size = micro_batch_size - - def update(self, consumed_samples, consistency_check): - pass - - -class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): - def __init__( - self, - start_batch_size, - batch_size_increment, - ramup_samples, - global_batch_size, - micro_batch_size, - data_parallel_size, - ): - """Batch size ramp up. - Over - steps = (global-batch-size - start-batch-size) / batch_size_increment - increment batch size from start-batch-size to global-batch-size using - rampup-samples / steps - samples. - Arguments: - start_batch_size: global batch size to start with - batch_size_increment: global batch size increments - ramup_samples: number of samples to use ramp up global - batch size from `start_batch_size` to `global_batch_size` - global_batch_size: global batch size post rampup - micro_batch_size: micro batch size - data_parallel_size: data parallel size. 
- """ - - self.micro_batch_size = micro_batch_size - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size - assert self.micro_batch_times_data_parallel_size > 0 - - assert start_batch_size > 0 - self.start_batch_size = start_batch_size - - assert global_batch_size > 0 - self.global_batch_size = global_batch_size - diff_batch_size = self.global_batch_size - self.start_batch_size - assert diff_batch_size >= 0 - assert batch_size_increment > 0 - self.batch_size_increment = batch_size_increment - assert diff_batch_size % batch_size_increment == 0, ( - "expected " - "global batch size interval ({}) to be divisible by global batch " - "size increment ({})".format(diff_batch_size, batch_size_increment) - ) - - num_increments = diff_batch_size // self.batch_size_increment - self.ramup_samples = ramup_samples - assert self.ramup_samples >= 0 - self.rampup_samples_per_increment = self.ramup_samples / num_increments - - # Initialize number of microbatches. - self.update(0, False) - - def update(self, consumed_samples, consistency_check): - if consumed_samples > self.ramup_samples: - self.current_global_batch_size = self.global_batch_size - else: - steps = int(consumed_samples / self.rampup_samples_per_increment) - self.current_global_batch_size = ( - self.start_batch_size + steps * self.batch_size_increment - ) - assert self.current_global_batch_size <= self.global_batch_size - - if consistency_check: - assert ( - self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0 - ), ( - "current global " - "batch size ({}) is not divisible by micro-batch-size ({}) times" - "data parallel size ({})".format( - self.current_global_batch_size, - self.micro_batch_size, - self.data_parallel_size, - ) - ) - self.num_micro_batches = ( - self.current_global_batch_size // self.micro_batch_times_data_parallel_size - ) diff --git a/apex/transformer/parallel_state.py b/apex/transformer/parallel_state.py deleted file mode 100644 index 1cd9b0d2c..000000000 --- a/apex/transformer/parallel_state.py +++ /dev/null @@ -1,810 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# TODO (mkozuki): Replace assert with RuntimeError. -# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py -"""Model and data parallel groups.""" - -from typing import Tuple, Optional -import warnings -import os -import torch - -from apex.transformer.log_util import get_transformer_logger -from apex.transformer._ucc_util import HAS_UCC - - -_logger = get_transformer_logger(__name__) - -# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state -# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) = -# { -# 'get_num_layers', -# } - - -# Intra-layer model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None -# Inter-layer model parallel group that the current rank belongs to. 
-_PIPELINE_MODEL_PARALLEL_GROUP = None -# Model parallel group (both intra- and pipeline) that the current rank belongs to. -_MODEL_PARALLEL_GROUP = None -# Embedding group. -_EMBEDDING_GROUP = None -# Position embedding group. -_POSITION_EMBEDDING_GROUP = None -# Relative position embedding group. -_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None -_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None -# Data parallel group that the current rank belongs to. -_DATA_PARALLEL_GROUP = None -# Data parallel AMAX reduction group that the current rank belongs to. -_AMAX_REDUCTION_GROUP = None - -_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None -_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None - -# These values enable us to change the mpu sizes on the fly. -_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_TENSOR_MODEL_PARALLEL_RANK = None -_MPU_PIPELINE_MODEL_PARALLEL_RANK = None - -# A list of ranks that have a copy of the embedding. -_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the position embedding. -_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the relative position embedding. -_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None -_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of global ranks for each pipeline group to ease calculation of the source -# rank when broadcasting from the first or last pipeline stage -_PIPELINE_GLOBAL_RANKS = None - - -def is_unitialized(): - """Useful for code segments that may be accessed with or without mpu initialization""" - return _DATA_PARALLEL_GROUP is None - - -def set_nccl_socket_envs(): - if os.getenv("NCCL_SOCKET_IFNAME") is None: - raise RuntimeError("NCCL_SOCKET_IFNAME was not set") - os.environ["NCCL_NET"] = "Socket" - - -def set_nccl_ib_envs(): - os.environ["NCCL_NET"] = "IB" - - -def init_nccl_net(group): - temp = torch.ones(1, device="cuda") - torch.distributed.all_reduce(temp, group=group) - torch.cuda.synchronize() - - -def new_nccl_socket_group(ranks): - set_nccl_socket_envs() - group = torch.distributed.new_group(ranks, backend="nccl") - init_nccl_net(group=group) - return group - - -def new_nccl_ib_group(ranks): - set_nccl_ib_envs() - group = torch.distributed.new_group(ranks, backend="nccl") - init_nccl_net(group=group) - return group - - -def new_process_group(ranks, backend): - """ - This function creates process groups. - - In addition to simply creating the process groups, it initializes NCCL - for hybrid IB/Socket network like in the following diagram: - - ____________ - [GPU Node 0]---TCP---| |---TCP---[GPU Node 2] - | | | | - | | | | - IB | IP Network | IB - | | | | - | | | | - [GPU Node 1]---TCP---|____________|---TCP---[GPU Node 3] - - - If an environment variable NUM_GPUS_PER_IB_BLOCK is defined it looks up the ranks - and determines whether the list of ranks belong to the same computational block where - GPUs nodes are interconnected via IB type of connection or not. - If all ranks are in the same block, the process group will use NCCL_NET=IB for - communication, otherwise it will use NCCL_NET=Socket. - - If NCCL_NET=Socket is ever to be used, the user must set NCCL_SOCKET_IFNAME. - Additionally, it is recommended to set NCCL_SOCKET_NTHREADS and - NCCL_NSOCKS_PERTHREAD before running the job. 
- See: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html - for more info - - The core assumption for this functionality is that the ranks are evenly divided - into IB blocks and all these IB blocks are of the same size. - """ - if backend is None: - backend = "nccl" - - compute_block_size = os.getenv("NUM_GPUS_PER_IB_BLOCK") - if backend == "nccl" and compute_block_size is not None: - compute_block_size = int(compute_block_size) - blocks = [rank // compute_block_size for rank in ranks] - use_ib = all(block == blocks[0] for block in blocks) - if use_ib: - return new_nccl_ib_group(ranks) - else: - return new_nccl_socket_group(ranks) - else: - return torch.distributed.new_group(ranks, backend=backend) - - -def initialize_model_parallel( - tensor_model_parallel_size_: int = 1, - pipeline_model_parallel_size_: int = 1, - virtual_pipeline_model_parallel_size_: Optional[int] = None, - pipeline_model_parallel_split_rank_: Optional[int] = None, - use_fp8_: bool = False, - init_mpi_proc_group: bool = False, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, -) -> None: - """ - Initialize model data parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used to parallelize model tensor. - pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. - virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline). - pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point. - use_fp8_: FP8 training that needs AMAX reduction across data-parallel ranks. - init_mpi_proc_group: Create a MPI process group, which is used for UCX-based communication APIs. - Keyword Arguments: - default_backend: Backend of process groups except for pipeline parallel ones. - If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used. - p2p_backend: Backend of process groups for pipeline model parallel. - If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used. - - .. note:: - `torch_ucc `_ is - necessary for "ucc" backend. - - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 8 tensor model-parallel groups, 4 pipeline model-parallel groups - and 8 data-parallel groups as: - 8 data_parallel groups: - [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] - 8 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] - 4 pipeline model-parallel groups: - [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - # Get world size and rank. Ensure some consistencies. 
- assert torch.distributed.is_initialized() - assert default_backend is None or default_backend in ("nccl", "ucc") - assert p2p_backend is None or p2p_backend in ("nccl", "ucc") - if "ucc" in (default_backend, p2p_backend): - if not HAS_UCC: - raise ImportError( - "UCC backend requires pytorch source build with UCC installed and enabled" - ) - warnings.warn("`ucc` backend support is experimental", ExperimentalWarning) - if default_backend == "ucc": - warnings.warn( - "The UCC's functionality as `default_backend` is not well verified", - ExperimentalWarning, - ) - - # Saving the NCCL_NET type for reusing it at the epilogue - default_nccl_net = os.getenv("NCCL_NET") - - world_size: int = torch.distributed.get_world_size() - tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size) - pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size) - if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: - raise RuntimeError( - f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" - ) - data_parallel_size: int = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size - ) - if torch.distributed.get_rank() == 0: - _logger.info( - "> initializing tensor model parallel with size {}".format(tensor_model_parallel_size) - ) - _logger.info( - "> initializing pipeline model parallel with size {}".format( - pipeline_model_parallel_size - ) - ) - _logger.info("> initializing data parallel with size {}".format(data_parallel_size)) - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - num_data_parallel_groups: int = world_size // data_parallel_size - - if virtual_pipeline_model_parallel_size_ is not None: - # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit - # the root cause as we do see numerical mismatches with 2 stages and - # the interleaved schedule - assert pipeline_model_parallel_size_ > 2, ( - "pipeline-model-parallel size should be greater than 2 with interleaved schedule" - ) - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ - - if pipeline_model_parallel_split_rank_ is not None: - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_ - - rank = torch.distributed.get_rank() - - # Build the data-parallel groups. - global _DATA_PARALLEL_GROUP - assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" - all_data_parallel_group_ranks = [] - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks.append(list(ranks)) - group = new_process_group(ranks, backend=default_backend) - if rank in ranks: - _DATA_PARALLEL_GROUP = group - - # Build the amax-reduction groups for fp8 precision conversion. 
- if use_fp8_: - global _AMAX_REDUCTION_GROUP - assert _AMAX_REDUCTION_GROUP is None, "amax reduction group is already initialized" - amax_group_size: int = tensor_model_parallel_size * data_parallel_size - num_amax_groups: int = world_size // amax_group_size - for i in range(num_amax_groups): - start_rank = i * amax_group_size - end_rank = (i + 1) * amax_group_size - ranks = range(start_rank, end_rank) - group = torch.distributed.new_group(ranks, backend=default_backend) - if rank in ranks: - _AMAX_REDUCTION_GROUP = group - - # Build the model-parallel groups. - global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" - for i in range(data_parallel_size): - ranks = [ - data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks - ] - group = new_process_group(ranks, backend=default_backend) - if rank in ranks: - _MODEL_PARALLEL_GROUP = group - - # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( - "tensor model parallel group is already initialized" - ) - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group = new_process_group(ranks, backend=default_backend) - if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group - - # Build the pipeline model-parallel groups and embedding groups - # (first and last rank in each pipeline model-parallel group). - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( - "pipeline model parallel group is already initialized" - ) - global _EMBEDDING_GROUP - global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, "embedding group is already initialized" - global _POSITION_EMBEDDING_GROUP - global _POSITION_EMBEDDING_GLOBAL_RANKS - assert _POSITION_EMBEDDING_GROUP is None, "position embedding group is already initialized" - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - assert ( - _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None - or _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None - ), "relative position embedding group is already initialized" - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - group = new_process_group(ranks, backend=p2p_backend) - if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). 
- encoder_relative_position_embedding_ranks = None - decoder_relative_position_embedding_ranks = None - if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] - position_embedding_ranks = [ranks[0]] - encoder_relative_position_embedding_ranks = [ranks[0]] - decoder_relative_position_embedding_ranks = [ranks[0]] - if pipeline_model_parallel_split_rank_ is not None: - encoder_relative_position_embedding_ranks = ranks[ - :pipeline_model_parallel_split_rank_ - ] - decoder_relative_position_embedding_ranks = ranks[ - pipeline_model_parallel_split_rank_: - ] - if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks: - embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank_], - ranks[-1], - ] - if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks: - position_embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank_], - ] - else: - embedding_ranks = ranks - position_embedding_ranks = ranks - encoder_relative_position_embedding_ranks = ranks - decoder_relative_position_embedding_ranks = ranks - - group = new_process_group(embedding_ranks, backend=p2p_backend) - if rank in embedding_ranks: - _EMBEDDING_GROUP = group - if rank in ranks: - _EMBEDDING_GLOBAL_RANKS = embedding_ranks - - group = new_process_group(position_embedding_ranks, backend=p2p_backend) - if rank in position_embedding_ranks: - _POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks - - if encoder_relative_position_embedding_ranks: - group = new_process_group( - encoder_relative_position_embedding_ranks, backend=p2p_backend - ) - if rank in encoder_relative_position_embedding_ranks: - _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = ( - encoder_relative_position_embedding_ranks - ) - - if decoder_relative_position_embedding_ranks: - group = new_process_group( - decoder_relative_position_embedding_ranks, backend=p2p_backend - ) - if rank in decoder_relative_position_embedding_ranks: - _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = ( - decoder_relative_position_embedding_ranks - ) - - if init_mpi_proc_group: - torch.distributed.new_group(backend="mpi") - - if default_nccl_net == "Socket": - set_nccl_socket_envs() - elif default_nccl_net == "IB": - set_nccl_ib_envs() - elif default_nccl_net is None: - os.unsetenv("NCCL_NET") - else: - os.environ["NCCL_NET"] = default_nccl_net - - -def get_rank_info() -> Tuple[int, int, int]: - """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger.""" - if model_parallel_is_initialized(): - return ( - get_data_parallel_rank(), - get_tensor_model_parallel_rank(), - get_pipeline_model_parallel_rank(), - get_virtual_pipeline_model_parallel_rank(), - ) - return (0, 0, 0, 0) - - -def model_parallel_is_initialized(): - """Check if model and data parallel groups are initialized.""" - if ( - _TENSOR_MODEL_PARALLEL_GROUP is None - or _PIPELINE_MODEL_PARALLEL_GROUP is None - or _DATA_PARALLEL_GROUP is None - ): - return False - return True - - -def get_model_parallel_group(): - """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" - return _MODEL_PARALLEL_GROUP - - -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - assert 
_TENSOR_MODEL_PARALLEL_GROUP is not None, ( - "intra_layer_model parallel group is not initialized" - ) - return _TENSOR_MODEL_PARALLEL_GROUP - - -def get_pipeline_model_parallel_group(): - """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( - "pipeline_model parallel group is not initialized" - ) - return _PIPELINE_MODEL_PARALLEL_GROUP - - -def get_data_parallel_group(): - """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" - return _DATA_PARALLEL_GROUP - - -def get_amax_reduction_group(): - """Get the amax reduction group the caller rank belongs to.""" - assert _AMAX_REDUCTION_GROUP is not None, "AMAX reduction group is not initialized" - return _AMAX_REDUCTION_GROUP - - -def get_embedding_group(): - """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" - return _EMBEDDING_GROUP - - -def get_position_embedding_group(): - """Get the position embedding group the caller rank belongs to.""" - assert _POSITION_EMBEDDING_GROUP is not None, "position embedding group is not initialized" - return _POSITION_EMBEDDING_GROUP - - -def get_encoder_relative_position_embedding_group(): - """Get the encoder relative position embedding group the caller rank belongs to.""" - assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, ( - "encoder relative position embedding group is not initialized" - ) - return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - - -def get_decoder_relative_position_embedding_group(): - """Get the decoder relative position embedding group the caller rank belongs to.""" - assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, ( - "decoder relative position embedding group is not initialized" - ) - return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - - -def is_rank_in_embedding_group(ignore_virtual=False): - """Return true if current rank is in embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _EMBEDDING_GLOBAL_RANKS - if ignore_virtual: - return rank in _EMBEDDING_GLOBAL_RANKS - if rank in _EMBEDDING_GLOBAL_RANKS: - if rank == _EMBEDDING_GLOBAL_RANKS[0]: - return is_pipeline_first_stage(ignore_virtual=False) - elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: - return is_pipeline_last_stage(ignore_virtual=False) - else: - return True - return False - - -def is_rank_in_position_embedding_group(): - """Return whether the current rank is in position embedding group.""" - rank = torch.distributed.get_rank() - global _POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _POSITION_EMBEDDING_GLOBAL_RANKS - - -def is_rank_in_encoder_relative_position_embedding_group(): - """Return true if current rank is in encoder relative position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - - -def is_rank_in_decoder_relative_position_embedding_group(): - """Return true if current rank is in decoder relative position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - - -def is_pipeline_stage_before_split(rank=None): - """Return True if pipeline stage executes encoder block for a model - with both encoder and decoder.""" - 
if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_after_split(rank=None): - """Return True if pipeline stage executes decoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_at_split(): - """Return true if pipeline stage executes decoder block and next - stage executes encoder block for a model with both encoder and - decoder.""" - rank = get_pipeline_model_parallel_rank() - return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1) - - -def set_tensor_model_parallel_world_size(world_size): - """Set the tensor model parallel size""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def set_pipeline_model_parallel_world_size(world_size): - """Set the pipeline model parallel size""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_world_size(): - """Return world size for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) - - -def set_tensor_model_parallel_rank(rank): - """Set tensor model parallel rank.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_rank(rank): - """Set pipeline model parallel rank.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: - return _MPU_TENSOR_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_rank(): - """Return my rank for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) - - -# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe? 
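
For context on how the module deleted in this hunk was typically driven, a minimal usage sketch follows. It is illustrative only and not part of the patch: it assumes a job launched with torchrun (one process per GPU, with the world size divisible by the product of the tensor and pipeline parallel sizes, e.g. 8 GPUs giving a data-parallel size of 2), the 2x2 parallel sizes and the script name are arbitrary choices, and only functions defined in the file being removed are called. apex.transformer is deprecated in favor of Megatron Core, which exposes broadly similar helpers under megatron.core.parallel_state.

# Illustrative sketch only; assumes launch via
#   torchrun --nproc_per_node=8 parallel_state_example.py   (script name hypothetical)
import os

import torch
from apex.transformer import parallel_state

# One process per GPU; torchrun provides LOCAL_RANK.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Carve the world into tensor-, pipeline- and data-parallel groups
# (sizes chosen arbitrarily for this sketch; world size must be divisible by 2 * 2).
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=2,
)

# Each rank can then query its coordinates in the parallel grid.
tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
dp_rank = parallel_state.get_data_parallel_rank()
print(f"tp={tp_rank} pp={pp_rank} dp={dp_rank}", flush=True)

# Tear down the groups before exiting.
parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()
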
- - -def get_pipeline_model_parallel_split_rank(): - """Return my rank for the pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - - -def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int): - """Set my rank for the pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank - - -def is_pipeline_first_stage(ignore_virtual=False): - """Return True if in the first pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - if ( - get_virtual_pipeline_model_parallel_world_size() is not None - and get_virtual_pipeline_model_parallel_rank() != 0 - ): - return False - return get_pipeline_model_parallel_rank() == 0 - - -def is_pipeline_last_stage(ignore_virtual=False): - """Return True if in the last pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - virtual_pipeline_model_parallel_world_size = ( - get_virtual_pipeline_model_parallel_world_size() - ) - if ( - virtual_pipeline_model_parallel_world_size is not None - and get_virtual_pipeline_model_parallel_rank() - != (virtual_pipeline_model_parallel_world_size - 1) - ): - return False - return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) - - -def get_virtual_pipeline_model_parallel_rank(): - """Return the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - - -def set_virtual_pipeline_model_parallel_rank(rank): - """Set the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_virtual_pipeline_model_parallel_world_size(): - """Return the virtual pipeline-parallel world size.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - - -def set_virtual_pipeline_model_parallel_world_size(size): - """Return the virtual pipeline-parallel world size.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = size - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size - - -def get_data_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank in the data parallel group.""" - global_rank = torch.distributed.get_rank() - data_parallel_size: int = get_data_parallel_world_size() - num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size - return global_rank % num_data_parallel_groups - - -def get_pipeline_model_parallel_first_rank(): - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - return _PIPELINE_GLOBAL_RANKS[0] - - -def get_pipeline_model_parallel_last_rank(): - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] - - -def get_pipeline_model_parallel_next_rank(): - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - rank_in_pipeline = 
get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] - - -def get_pipeline_model_parallel_prev_rank(): - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] - - -def get_data_parallel_world_size(): - """Return world size for the data parallel group.""" - return torch.distributed.get_world_size(group=get_data_parallel_group()) - - -def get_data_parallel_rank(): - """Return my rank for the data parallel group.""" - return torch.distributed.get_rank(group=get_data_parallel_group()) - - -# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM. -# Otherwise pipeline parallel forward_backward functions test hangs possibly because -# the clean-up of the original is NOT enough. -def destroy_model_parallel(): - """Set the groups to none.""" - global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None - global _TENSOR_MODEL_PARALLEL_GROUP - _TENSOR_MODEL_PARALLEL_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _DATA_PARALLEL_GROUP - _DATA_PARALLEL_GROUP = None - global _AMAX_REDUCTION_GROUP - _AMAX_REDUCTION_GROUP = None - global _EMBEDDING_GROUP - _EMBEDDING_GROUP = None - global _POSITION_EMBEDDING_GROUP - _POSITION_EMBEDDING_GROUP = None - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None - global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = None - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = None - - -# Used to warn when the UCC is specified. 
-class ExperimentalWarning(Warning): - pass diff --git a/apex/transformer/pipeline_parallel/__init__.py b/apex/transformer/pipeline_parallel/__init__.py deleted file mode 100644 index 98bb96028..000000000 --- a/apex/transformer/pipeline_parallel/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func -from apex.transformer.pipeline_parallel.schedules.common import build_model - - -__all__ = [ - "get_forward_backward_func", - "build_model", -] diff --git a/apex/transformer/pipeline_parallel/_timers.py b/apex/transformer/pipeline_parallel/_timers.py deleted file mode 100644 index 55d89f351..000000000 --- a/apex/transformer/pipeline_parallel/_timers.py +++ /dev/null @@ -1,83 +0,0 @@ -import time - -import torch - - -class _Timer: - """Timer.""" - - def __init__(self, name): - self.name_ = name - self.elapsed_ = 0.0 - self.started_ = False - self.start_time = time.time() - - def start(self): - """Start the timer.""" - assert not self.started_, "timer has already been started" - torch.cuda.synchronize() - self.start_time = time.time() - self.started_ = True - - def stop(self): - """Stop the timer.""" - assert self.started_, "timer is not started" - torch.cuda.synchronize() - self.elapsed_ += time.time() - self.start_time - self.started_ = False - - def reset(self): - """Reset timer.""" - self.elapsed_ = 0.0 - self.started_ = False - - def elapsed(self, reset=True): - """Calculate the elapsed time.""" - started_ = self.started_ - # If the timing in progress, end it first. - if self.started_: - self.stop() - # Get the elapsed time. - elapsed_ = self.elapsed_ - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if started_: - self.start() - return elapsed_ - - -class _Timers: - """Group of timers.""" - - def __init__(self): - self.timers = {} - - def __call__(self, name): - if name not in self.timers: - self.timers[name] = _Timer(name) - return self.timers[name] - - def write(self, names, writer, iteration, normalizer=1.0, reset=False): - """Write timers to a tensorboard writer""" - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - for name in names: - value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + "-time", value, iteration) - - def log(self, names, normalizer=1.0, reset=True): - """Log a group of timers.""" - assert normalizer > 0.0 - string = "time (ms)" - for name in names: - elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += " | {}: {:.2f}".format(name, elapsed_time) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): - print(string, flush=True) - else: - print(string, flush=True) diff --git a/apex/transformer/pipeline_parallel/p2p_communication.py b/apex/transformer/pipeline_parallel/p2p_communication.py deleted file mode 100644 index f62d56297..000000000 --- a/apex/transformer/pipeline_parallel/p2p_communication.py +++ /dev/null @@ -1,713 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# TODO(mkozuki): Consider removing `timers`. - -from functools import reduce -import operator -from typing import Union, Optional, Tuple - -import torch - -from apex.transformer import parallel_state -from apex.transformer.log_util import get_transformer_logger -from apex.transformer.utils import split_tensor_into_1d_equal_chunks -from apex.transformer.utils import gather_split_1d_tensor -from apex.transformer.pipeline_parallel.utils import Shape -from apex.transformer.pipeline_parallel._timers import _Timers - - -_logger = get_transformer_logger(__name__) - - -class FutureTensor: - def __init__(self, tensor: torch.Tensor, waitfunc): - self.tensor = tensor - self.waitfunc = waitfunc - - def get(self): - if self.waitfunc is not None: - res = self.waitfunc() - if isinstance(res, torch.Tensor): - self.tensor = res - self.waitfunc = None - return self.tensor - - -def _run_p2pops( - tensor_send_prev: Union[torch.Tensor, None], - tensor_send_next: Union[torch.Tensor, None], - tensor_recv_prev: Union[torch.Tensor, None], - tensor_recv_next: Union[torch.Tensor, None], - async_comm: bool = False, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, -): - p2p_group = parallel_state.get_pipeline_model_parallel_group() - default_group = parallel_state.get_model_parallel_group() - - need_to_sync = p2p_group.name() != default_group.name() - reqs = [] - ops = [] - - if batch_p2p_comm and p2p_group.name() == "nccl": - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( - op=torch.distributed.isend, - tensor=tensor_send_prev, - peer=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( - op=torch.distributed.irecv, - tensor=tensor_recv_prev, - peer=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - op=torch.distributed.isend, - tensor=tensor_send_next, - peer=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( - op=torch.distributed.irecv, - tensor=tensor_recv_next, - peer=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - ops.append(recv_next_op) - if len(ops) > 0: - # sync before communication if needed - if need_to_sync: - torch.cuda.synchronize() - reqs = torch.distributed.batch_isend_irecv(ops) - else: - # sync before communication if needed - if need_to_sync and any( - [ - tensor_send_prev is not None, - tensor_recv_prev is not None, - tensor_send_next is not None, - tensor_recv_next is not None, - ] - ): - torch.cuda.synchronize() - - if tensor_send_prev is not None: - send_prev_req = torch.distributed.isend( - tensor=tensor_send_prev, - dst=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - reqs.append(send_prev_req) - if tensor_recv_prev is not None: - recv_prev_req = torch.distributed.irecv( - 
tensor=tensor_recv_prev, - src=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - reqs.append(recv_prev_req) - if tensor_send_next is not None: - send_next_req = torch.distributed.isend( - tensor=tensor_send_next, - dst=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - reqs.append(send_next_req) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.irecv( - tensor=tensor_recv_next, - src=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - reqs.append(recv_next_op) - - if len(reqs) > 0: - if overlap_p2p_comm: - return (None, None, None, None, reqs) - - if async_comm: - if len(ops) == 0 or len(reqs) == len(ops): - tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0) - tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0) - tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0) - tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0) - elif len(reqs) == 1: - tensor_send_prev_req = None if tensor_send_prev is None else reqs[0] - tensor_recv_prev_req = None if tensor_recv_prev is None else reqs[0] - tensor_send_next_req = None if tensor_send_next is None else reqs[0] - tensor_recv_next_req = None if tensor_recv_next is None else reqs[0] - else: - assert False, "failed to manage p2p requests and handles" - return ( - tensor_send_prev_req, - tensor_recv_prev_req, - tensor_send_next_req, - tensor_recv_next_req, - None, - ) - else: - for req in reqs: - req.wait() - return (None, None, None, None, None) - return (None, None, None, None, None) - - -# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`. -# TODO(mkozuki): Think about if it's possible to push some logic and arguments e.g. -# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and -# `override_scatter_gather_tensors_in_pipeline` # to the user of -# apex.transformer forward_backwardfunctions. -def _communicate( - tensor_send_next: Optional[torch.Tensor], - tensor_send_prev: Optional[torch.Tensor], - recv_prev: bool, - recv_next: bool, - tensor_shape: Optional[Shape] = None, - override_scatter_gather_tensors_in_pipeline: bool = False, - dtype_: Optional[torch.dtype] = None, - *, - scatter_gather_tensors_in_pipeline: bool = True, - params_dtype: Optional[torch.dtype] = None, - fp32_residual_connection: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, -) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]: - """Base function for communication of tensors between stages. - - - .. note:: - Reference https://github.com/NVIDIA/Megatron-LM/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159 - - dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified, - torch.float32 is used. - - See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159 - for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``. - - Args: - tensor_send_next: tensor to send to next rank (no tensor sent if set to None). - tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None). - recv_prev: boolean for whether tensor should be received from previous rank. 
- recv_next: boolean for whether tensor should be received from next rank. - tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length - override_scatter_gather_tensors_in_pipeline: - optional, this is used when tensor_shape is provided to override scatter gather tensors - dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape - - Keyword args: - scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors. - params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on - your model deliberately, pass this argument. - fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32. - sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled. - This argument is here for consistency with Megatron-LM. - This argument has an effect on the communication optimization, not on tensor_shape update. - sync_batch_comm: If :obj:`False`, disable cuda synchronization after the batched communication. - To disable, https://github.com/pytorch/pytorch/pull/82450 would be required. - overlap_p2p_comm: If :obj:`True`, returns cuda wait handles to scheduler instead of completing - the communication within the p2p transfer API instance. The scheduler manages the communication completion - to overlap with computation. - batch_p2p_comm: If :obj:`True`, use the batched send and receive api to conduct the communication of - a collection of send and receive operations between peer. If :obj:`False`, conduct each send and recv operation - individually. - - Returns: - tuple containing - - - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise. - - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise. - """ - if async_comm and sequence_parallel_enabled: - import warnings # NOQA - - class ExperimentalWarning(UserWarning): - pass # NOQA - - warnings.warn( - "The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.", - ExperimentalWarning, - ) - # Create placeholder tensors for receive in forward and backward directions if needed. - tensor_recv_prev = None - tensor_recv_next = None - if tensor_shape is None: - # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)` - raise RuntimeError( - "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`" - ) - - tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size() - override_scatter_gather_tensors_in_pipeline_ = False - # TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible. - # NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here. - # However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work. - # I've not managed to reproduce the hang using standalone GPT with sequence parallel. - # The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of stead phase inside - # forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world - # size of 2 and pipeline model parallel world size of 2). 
The commit then of APEX and NeMo were - # https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and - # https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205. - # The PyTorch version was 1.13.0a0+git2d354cd, for what is worth. - # Currently, indiscriminately this is set to `False`, which can lead to an unexpected performance regression - # for non sequence parallel case. - scatter_gather_tensors_in_pipeline = False - if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled: - tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1)) - if tensor_chunk_size % tensor_parallel_size == 0: - tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size] - else: - tensor_chunk_shape = tensor_shape - override_scatter_gather_tensors_in_pipeline_ = True - else: - tensor_chunk_shape = tensor_shape - - # The dtype logic below is copied from NVIDIA/Megatron-LM repo: - # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81 - dtype = params_dtype or torch.float - if fp32_residual_connection: - dtype = torch.float - requires_grad = True - if dtype_ is not None: - dtype = dtype_ - # TODO(mkozuki): Figure out why this logic of requires_grad isn't working - # when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of - # https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360 - # fails. - # requires_grad = False - - if recv_prev: - tensor_recv_prev = torch.empty( - tensor_chunk_shape, - requires_grad=requires_grad, - device=torch.cuda.current_device(), - dtype=dtype, - ) - if recv_next: - tensor_recv_next = torch.empty( - tensor_chunk_shape, - requires_grad=requires_grad, - device=torch.cuda.current_device(), - dtype=dtype, - ) - - # Split tensor into smaller chunks if using scatter-gather optimization. - scatter_gather_optimization_doable = ( - not override_scatter_gather_tensors_in_pipeline_ - and scatter_gather_tensors_in_pipeline - and not sequence_parallel_enabled - ) - if scatter_gather_optimization_doable: - if tensor_send_next is not None: - tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next) - - if tensor_send_prev is not None: - tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev) - - # Send tensors in both the forward and backward directions as appropriate. - ( - tensor_send_prev_req, - tensor_recv_prev_req, - tensor_send_next_req, - tensor_recv_next_req, - wait_handles, - ) = _run_p2pops( - tensor_send_prev, - tensor_send_next, - tensor_recv_prev, - tensor_recv_next, - async_comm, - overlap_p2p_comm, - batch_p2p_comm, - ) - - if async_comm: - tensor_recv_prev_waitfunc = None - tensor_recv_next_waitfunc = None - # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642) - # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait - if tensor_recv_prev_req is not None: - - def tensor_recv_prev_wait(): - tensor_recv_prev_req.wait() - torch.cuda.synchronize() - - tensor_recv_prev_waitfunc = tensor_recv_prev_wait - if tensor_recv_next_req is not None: - - def tensor_recv_next_wait(): - tensor_recv_next_req.wait() - torch.cuda.synchronize() - - tensor_recv_next_waitfunc = tensor_recv_next_wait - else: - if sync_batch_comm: - # To protect against race condition when using batch_isend_irecv(). 
- torch.cuda.synchronize() - - # If using scatter-gather optimization, gather smaller chunks. - if scatter_gather_optimization_doable: - if not async_comm: - if recv_prev: - tensor_recv_prev = ( - gather_split_1d_tensor(tensor_recv_prev).view(tensor_shape).requires_grad_() - ) - - if recv_next: - tensor_recv_next = ( - gather_split_1d_tensor(tensor_recv_next).view(tensor_shape).requires_grad_() - ) - else: - - def gather_recv_prev_wait(): - tensor_recv_prev_req.wait() - # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14 - # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test - torch.cuda.synchronize() - return gather_split_1d_tensor(tensor_recv_prev).view(tensor_shape).requires_grad_() - - def gather_recv_next_wait(): - tensor_recv_next_req.wait() - torch.cuda.synchronize() - return gather_split_1d_tensor(tensor_recv_next).view(tensor_shape).requires_grad_() - - tensor_recv_prev_waitfunc = gather_recv_prev_wait - tensor_recv_next_waitfunc = gather_recv_next_wait - if async_comm: - future_tensor_recv_prev = None - future_tensor_recv_next = None - if tensor_recv_prev is not None: - future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc) - if tensor_recv_next is not None: - future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc) - return future_tensor_recv_prev, future_tensor_recv_next, None - return tensor_recv_prev, tensor_recv_next, wait_handles - - -def recv_forward( - tensor_shape: Shape, - override_scatter_gather_tensors_in_pipeline: bool = False, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Receive tensor from previous rank in pipeline (forward receive).""" - if parallel_state.is_pipeline_first_stage(): - return None - # if timers is not None: - # timers("forward-recv").start() - input_tensor, _, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("forward-recv").stop() - return input_tensor - - -def recv_backward( - tensor_shape: Shape = None, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Receive tensor from next rank in pipeline (backward receive).""" - if parallel_state.is_pipeline_last_stage(): - return None - # if timers is not None: - # timers("backward-recv").start() - _, output_tensor_grad, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("backward-recv").stop() - return output_tensor_grad - - -def send_forward( - output_tensor: torch.Tensor, - 
override_scatter_gather_tensors_in_pipeline: bool = False, - tensor_shape: Shape = None, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> None: - """Send tensor to next rank in pipeline (forward send).""" - if parallel_state.is_pipeline_last_stage(): - return - # if timers is not None: - # timers("forward-send").start() - _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, - override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("forward-send").stop() - - -def send_backward( - input_tensor_grad: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> None: - """Send tensor to previous rank in pipeline (backward send).""" - if parallel_state.is_pipeline_first_stage(): - return - # if timers is not None: - # timers("backward-send").start() - _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("backward-send").stop() - - -def send_forward_recv_backward( - output_tensor: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Batched send and recv with next rank in pipeline.""" - if parallel_state.is_pipeline_last_stage(): - return None - # if timers is not None: - # timers("forward-send-backward-recv").start() - _, output_tensor_grad, _ = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("forward-send-backward-recv").stop() - return output_tensor_grad - - -def send_backward_recv_forward( - input_tensor_grad: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Batched send and recv with previous rank in pipeline.""" - if parallel_state.is_pipeline_first_stage(): - return None - # if timers is not None: - # timers("backward-send-forward-recv").start() - input_tensor, _, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - 
sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("backward-send-forward-recv").stop() - return input_tensor - - -def send_forward_recv_forward( - output_tensor: torch.Tensor, - recv_prev: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor]: - """Batched recv from previous rank and send to next rank in pipeline.""" - # if timers is not None: - # timers("forward-send-forward-recv").start() - input_tensor, _, wait_handles = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=overlap_p2p_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("forward-send-forward-recv").stop() - if overlap_p2p_comm: - return input_tensor, wait_handles - return input_tensor - - -def send_backward_recv_backward( - input_tensor_grad: torch.Tensor, - recv_next: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor]: - """Batched recv from next rank and send to previous rank in pipeline.""" - # if timers is not None: - # timers("backward-send-backward-recv").start() - _, output_tensor_grad, wait_handles = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=overlap_p2p_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # timers("backward-send-backward-recv").stop() - if overlap_p2p_comm: - return output_tensor_grad, wait_handles - return output_tensor_grad - - -def send_forward_backward_recv_forward_backward( - output_tensor: torch.Tensor, - input_tensor_grad: torch.Tensor, - recv_prev: bool, - recv_next: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, - timers: _Timers = None, -) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]: - """Batched send and recv with previous and next ranks in pipeline.""" - # if timers is not None: - # timers("forward-backward-send-forward-backward-recv").start() - input_tensor, output_tensor_grad, wait_handles = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=overlap_p2p_comm, - batch_p2p_comm=batch_p2p_comm, - ) - # if timers is not None: - # 
timers("forward-backward-send-forward-backward-recv").stop() - if overlap_p2p_comm: - return input_tensor, output_tensor_grad, wait_handles - return input_tensor, output_tensor_grad diff --git a/apex/transformer/pipeline_parallel/schedules/__init__.py b/apex/transformer/pipeline_parallel/schedules/__init__.py deleted file mode 100644 index 7e0d0c25d..000000000 --- a/apex/transformer/pipeline_parallel/schedules/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) - -__all__ = [ - "get_forward_backward_func", -] - - -class ExperimentalWarning(Warning): - pass - - -def get_forward_backward_func( - virtual_pipeline_model_parallel_size, - pipeline_model_parallel_size, -): - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if virtual_pipeline_model_parallel_size is not None: - if get_num_microbatches() % pipeline_model_parallel_size != 0: - msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule" - raise RuntimeError(msg) - forward_backward_func = _forward_backward_pipelining_with_interleaving - else: - forward_backward_func = forward_backward_pipelining_without_interleaving - else: - forward_backward_func = forward_backward_no_pipelining - return forward_backward_func diff --git a/apex/transformer/pipeline_parallel/schedules/common.py b/apex/transformer/pipeline_parallel/schedules/common.py deleted file mode 100644 index 802bf7231..000000000 --- a/apex/transformer/pipeline_parallel/schedules/common.py +++ /dev/null @@ -1,412 +0,0 @@ -from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence - -import torch -from torch.autograd.variable import Variable - -from apex.normalization.fused_layer_norm import FusedLayerNorm -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import unwrap_model -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.tensor_parallel.layers import ( - set_defaults_if_not_set_tensor_model_parallel_attributes, -) -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -Batch = Union[ - torch.Tensor, - FutureTensor, - List[Union[torch.Tensor, FutureTensor]], - Tuple[Union[torch.Tensor, FutureTensor], ...], -] -LossFunc = Callable[[torch.Tensor], torch.Tensor] -FwdStepFunc = Callable[[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]] - - -def build_model( - model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module], - wrap_with_ddp: bool = True, - virtual_pipeline_model_parallel_size: Optional[int] = None, - model_type: ModelType = ModelType.encoder_or_decoder, - *args: Any, - **kwargs: Any, -) -> List[torch.nn.Module]: - """Build the model 
satisfying pipeline model parallel requirements. - - This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to - `model_provider_func`. - - Args: - model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`. - wrap_with_ddp: If :obj:`True`, wrap the instantiated model - with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`. - virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel. - model_type: - *args: arguments for model provider func - **kwargs: Keyword arguments for model provider func - - Returns: - a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None, - the list has multiple models, otherwise one. - """ - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and virtual_pipeline_model_parallel_size is not None - ): - model = [] - for i in range(virtual_pipeline_model_parallel_size): - cur_args = args - cur_kwargs = kwargs - parallel_state.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - { - "pre_process": pre_process, - "post_process": post_process, - } - ) - this_model = model_provider_func(*cur_args, **cur_kwargs) - model.append(this_model) - else: - cur_args = args - cur_kwargs = kwargs - if model_type == ModelType.encoder_or_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - { - "pre_process": pre_process, - "post_process": post_process, - } - ) - model = model_provider_func(*cur_args, **cur_kwargs) - elif model_type == ModelType.encoder_and_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - # `add_encoder` & `add_decoder` logic. - add_encoder, add_decoder = True, True - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - split_rank = parallel_state.get_pipeline_model_parallel_split_rank() - if split_rank is None: - raise RuntimeError( - "Split rank needs to be specified for model with both encoder and decoder." - ) - rank = parallel_state.get_pipeline_model_parallel_rank() - world_size = parallel_state.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == split_rank - post_process = rank == (split_rank - 1) or rank == (world_size - 1) - add_encoder = parallel_state.is_pipeline_stage_before_split() - add_decoder = parallel_state.is_pipeline_stage_after_split() - cur_kwargs.update( - { - "pre_process": pre_process, - "post_process": post_process, - "add_encoder": add_encoder, - "add_decoder": add_decoder, - } - ) - model = model_provider_func(*cur_args, **cur_kwargs) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - # Print number of parameters. 
- if ( - parallel_state.model_parallel_is_initialized() - and parallel_state.get_data_parallel_rank() == 0 - ): - msg = ( - " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_rank(), - _calc_number_of_params(model), - ) - ) - print(msg, flush=True) - - # GPU allocation. - for model_module in model: - model_module.cuda(torch.cuda.current_device()) - - if wrap_with_ddp: - i = torch.cuda.current_device() - model = [ - torch.nn.parallel.distributed.DistributedDataParallel( - model_module, - device_ids=[i], - output_device=i, - process_group=parallel_state.get_data_parallel_group(), - ) - for model_module in model - ] - return model - - -def _calc_number_of_params(model: List[torch.nn.Module]) -> int: - assert isinstance(model, list) - return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]) - - -def _get_params_for_weight_decay_optimization( - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - no_weight_decay_modules=(FusedLayerNorm,), -) -> Dict[str, torch.nn.Parameter]: - """Divide params into with-weight-decay and without-weight-decay groups. - - Layernorms and biases will have no weight decay but the rest will. - """ - modules = listify_model(model) - weight_decay_params = {"params": []} - no_weight_decay_params = {"params": [], "weight_decay": 0.0} - for module in modules: - for module_ in module.modules(): - if isinstance(module_, no_weight_decay_modules): - no_weight_decay_params["params"].extend( - [p for p in list(module_._parameters.values()) if p is not None] - ) - else: - weight_decay_params["params"].extend( - [ - p - for n, p in list(module_._parameters.items()) - if p is not None and n != "bias" - ] - ) - no_weight_decay_params["params"].extend( - [ - p - for n, p in list(module_._parameters.items()) - if p is not None and n == "bias" - ] - ) - - return weight_decay_params, no_weight_decay_params - - -def free_output_tensor( - output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], - deallocate_pipeline_outputs: bool = False, -) -> None: - """Pseudo-free the output tensor's `.data` field. - - This method should be called right after the output tensor has been sent to the next - pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field, - and not its `.data`. - """ - if not deallocate_pipeline_outputs: - return - if output_tensors is None: - return - if isinstance(output_tensors, torch.Tensor): - output_tensors = [output_tensors] - for output_tensor in output_tensors: - output_tensor.data = torch.cuda.FloatTensor([0]) - - -def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None: - """Directly call C++ autograd engine. - - To make the `free_output_tensor` optimization work, the C++ autograd engine must be called - directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the - output and grad have the same shape, while C++ `backward` does not. - """ - assert output.numel() == 1, ( - "output should be pseudo-freed in schedule, to optimize memory consumption" - ) - assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__) - assert isinstance(grad_output, (torch.Tensor, type(None))), "grad_outptu == {}.".format( - type(grad_output).__name__ - ) - - # Handle scalar output - if grad_output is None: - assert output.numel() == 1, "Implicit grad requires scalar output." 
- grad_output = torch.ones_like(output, memory_format=torch.preserve_format) - - # Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ] - Variable._execution_engine.run_backward( - tensors=(output,), - grad_tensors=(grad_output,), - keep_graph=False, - create_graph=False, - inputs=(), - allow_unreachable=True, - accumulate_grad=True, - ) - - -def forward_step( - forward_step_func: FwdStepFunc, - batch: Optional[Batch], - model: torch.nn.Module, - input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]], - losses_reduced: List[torch.Tensor], - dtype: torch.dtype, - disable_autocast: bool = False, - checkpoint_activations_micro_batch: Optional[bool] = None, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - """Forward step for passed-in model. - - If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used. - - Returns output tensor. - - Args: - forward_step_func: Model specific function. This takes a minibatch and model as its arguments and - returns the model's output and the loss function. - batch: minibatch - model: unwrappable model - input_tensor: - losses_reduced: - dtype: - disable_autocast: - checkpoint_activations_micro_batch: - - Returns: - output_tensor - """ - # timers = get_timers() - # timers("forward-compute").start() - unwrapped_model = unwrap_model(model) - model_type = get_model_type(unwrapped_model) - # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`. - # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA - # for the details of `set_input_tensor`. - unwrap_output_tensor = not isinstance(input_tensor, list) - if unwrap_output_tensor: - input_tensor = [input_tensor] - - input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor] - - unwrapped_model.set_input_tensor(input_tensor) - with torch.amp.autocast( - "cuda", - enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16), - dtype=dtype, - ): - if checkpoint_activations_micro_batch is None: - output_tensor, loss_func = forward_step_func(batch, model) - else: - output_tensor, loss_func = forward_step_func( - batch, model, checkpoint_activations_micro_batch - ) - if parallel_state.is_pipeline_last_stage(): - output_tensor = loss_func(output_tensor) - loss, loss_reduced = output_tensor - output_tensor = loss / get_num_microbatches() - losses_reduced.append(loss_reduced) - # timers("forward-compute").stop() - - # If T5 model (or other model with encoder and decoder) - # and in decoder stack, then send encoder_hidden_state - # downstream as well. - if ( - parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - return [output_tensor, input_tensor[-1]] - if unwrap_output_tensor: - return output_tensor - return [output_tensor] - - -def backward_step( - input_tensor: Optional[torch.Tensor], - output_tensor: torch.Tensor, - output_tensor_grad: Optional[torch.Tensor], - model_type: ModelType, - *, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - deallocate_pipeline_outputs: bool = False, -) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]: - """Backward step through passed-in output tensor. - - If last stage, output_tensor_grad is None, otherwise gradient of loss - with respect to stage's output tensor. - - Returns gradient of loss with respect to input tensor (None if first - stage). 
- - Args: - input_tensor: - output_tensor: - output_tensor_grad: - Keyword Arguments: - grad_scaler: - deallocate_pipeline_outputs: Experimental. - Returns: - input_tensor_grad - """ - - # timers = get_timers() - # timers("backward-compute").start() - - # Retain the grad on the input_tensor. - unwrap_input_tensor_grad = not isinstance(input_tensor, list) - if unwrap_input_tensor_grad: - input_tensor = [input_tensor] - - input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor] - - for x in input_tensor: - if x is not None: - x.retain_grad() - - if not isinstance(output_tensor, list): - output_tensor = [output_tensor] - - output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor] - - if not isinstance(output_tensor_grad, list): - output_tensor_grad = [output_tensor_grad] - - output_tensor_grad = [ - ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad - ] - - # Backward pass. - if grad_scaler is not None and output_tensor_grad[0] is None: - output_tensor[0] = grad_scaler.scale(output_tensor[0]) - if deallocate_pipeline_outputs: - custom_backward(output_tensor[0], output_tensor_grad[0]) - else: - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) - - # Collect the grad of the input_tensor. - input_tensor_grad = [None] - if input_tensor is not None: - input_tensor_grad = [] - for x in input_tensor: - input_tensor_grad.append(None if x is None else x.grad) - - # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder). - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - if output_tensor_grad[1] is not None: - # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`? 
- input_tensor_grad[-1].add_(output_tensor_grad[1]) - - # timers("backward-compute").stop() - return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py deleted file mode 100644 index 66e664865..000000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py +++ /dev/null @@ -1,132 +0,0 @@ -import contextlib -from typing import List, Union, Optional - -import torch - -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.log_util import get_transformer_logger - - -_all__ = ["forward_backward_no_pipelining"] - - -_logger = get_transformer_logger(__name__) - - -def forward_backward_no_pipelining( - forward_step_func: FwdStepFunc, - batch: Batch, - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - forward_only: bool, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - custom_sync_context_handler=None, - **kwargs, -): - """Run forward and backward passes with no pipeline parallelism (no inter-stage communication). - - This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients. - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A List of torch.Tensors - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - grad_scaler: - dtype: - disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`. - Should be used when your forward and loss computation is in the autocast context to - avoid unnecesarily nest autocast context. - custom_sync_context_handler: Context manager to disable asynchronous gradient reductions. - **kwargs: Added to handle `tensor_shape` which has no effect on this function. - - Returns: - a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise. - """ - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." 
- ) - model = listify_model(model) - if len(model) != 1: - msg = f"`model` is expected be a `nn.Module`, but {type(model)}" - raise RuntimeError(msg) - model = model[0] - model_type = get_model_type(model) - - if custom_sync_context_handler is not None: - context_handler = custom_sync_context_handler - elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): - context_handler = model.no_sync - else: - context_handler = contextlib.nullcontext - - losses_reduced = [] - input_tensor, output_tensor_grad = None, None - num_micro_batches = get_num_microbatches() - with context_handler(): - for i in range(num_micro_batches - 1): - _logger.info(f"Iter {i} of {num_micro_batches - 1}") - cur_micro_batch = get_kth_microbatch(batch, i) - _logger.debug("Call `forward_step`") - output_tensor = forward_step( - forward_step_func, - cur_micro_batch, - model, - input_tensor, - losses_reduced, - dtype=dtype, - disable_autocast=disable_autocast, - ) - if not forward_only: - _logger.debug("Call `backward_step`") - backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - ) - - # Run computation for last microbatch out of context handler (want to - # synchronize gradients). - _logger.info("Cooldown") - _logger.debug("Call `forward_step`") - output_tensor = forward_step( - forward_step_func, - get_kth_microbatch(batch, num_micro_batches - 1), - model, - input_tensor, - losses_reduced, - dtype=dtype, - disable_autocast=disable_autocast, - ) - if not forward_only: - _logger.debug("Call `backward_step`") - backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - ) - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py deleted file mode 100644 index 2a41c2200..000000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py +++ /dev/null @@ -1,754 +0,0 @@ -import contextlib -from typing import Callable, List, Optional, Sequence, Union -import warnings - -import torch - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.log_util import get_transformer_logger - - -__all__ = ["_forward_backward_pipelining_with_interleaving"] - - -_logger = get_transformer_logger(__name__) - - -# TODO(mkozuki): Reduce cyclomatic complexity -def _forward_backward_pipelining_with_interleaving( - forward_step_func: FwdStepFunc, - batch: List[Optional[Batch]], - model: List[torch.nn.Module], - *, - forward_only: bool, - tensor_shape: Optional[Union[List[int], torch.Size]] = None, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - 
deallocate_pipeline_outputs: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - custom_sync_context_handler: Optional[Callable] = None, - custom_grad_sync_func: Optional[Callable] = None, - custom_param_sync_func: Optional[Callable] = None, - sync_batch_comm: bool = True, - num_micro_batches_with_partial_activation_checkpoints: Optional[int] = None, - overlap_p2p_comm: bool = False, - batch_p2p_comm: bool = True, - **kwargs, -) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]: - """Run interleaved 1F1B schedule with communication between pipeline stages as needed. - - This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively. - This means that model is split into model chunks. - - This pipeline parallel scheduling consists of three steps: - 1. warmup - 2. 1F1B a.k.a. steady state - 3. cooldown - Note that if `forward_only` this scheduling consists of only warmup phase. - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A minibatch, i.e., a list of `torch.Tensor`'s. - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension - is supposed to be ``(sequence, batch, hidden)``. - dtype: dtype used in p2p communication. If ``None`` (default value), - torch.float32 will be used even if ``autocast`` is enabled. - grad_scaler: - disable_autocast: - deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of - each pipeline stage. Experimental. - sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length. - When :obj:`True`, the sequence length on each tensor model parallel rank is updated - to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`. - custom_sync_context_handler: If provided, this is treated as a - function to construct a context manager to disable - asynchronous gradient reductions. Asynchronous gradient - reductions are only enabled in the final backward pass of - each model chunk. - custom_grad_sync_func: If provided, this is treated as a - function to launch asynchronous gradient reductions (e.g. - reduce-scatters with distributed optimizer). The function - should take one positional argument: a list of parameters - whose gradients should be synchronized. Asynchronous - gradient reductions are launched after the final backward - pass of each model chunk. - custom_param_sync_func: If provided, this is treated as a - function to launch asynchronous parameter synchronizations - (e.g. all-gathers with distributed optimizer). The - function should take one positional argument: a list of - parameters whose values should be synchronized. - Asynchronous parameter synchronizations are launched - before the first forward pass of each model chunk. - sync_batch_comm: If :obj:`False`, disable cuda synchronization after the batched communication. - To disable, https://github.com/pytorch/pytorch/pull/82450 would be required. - num_micro_batches_with_partial_activation_checkpoints: If :obj:`int`, set the number of - micro-batches checkpointing the activation of partial number of Transformer layers. 
- The rest of the micro-batch within the window of maximum outstanding micro-batch - backpropagations would checkpoint all Transformer layers. - overlap_p2p_comm: If :obj:`True`, returns cuda wait handles to scheduler instead of completing - the communication within the p2p transfer API instance. The scheduler manages the communication completion - to overlap with computation. - batch_p2p_comm: If :obj:`True`, use the batched send and receive api to conduct the communication of - a collection of send and receive operations between peer. If :obj:`False`, conduct each send and recv operation - individually. - - Returns: - a list of loss `torch.Tensor`s if the last stage, empty list otherwise. - - """ - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - if not isinstance(model, list): - raise RuntimeError("`model` must be a list of `nn.Module`'s'") - - if deallocate_pipeline_outputs: - warnings.warn( - "`deallocate_pipeline_outputs` is experimental and subject to change. " - "This option is not recommended." - ) - - # Construct helper functions for async grad reductions - if custom_sync_context_handler is not None: - sync_context_handler = custom_sync_context_handler - else: - sync_context_handler = contextlib.nullcontext - sync_context = None - - def disable_grad_sync(): - """Disable asynchronous grad reductions""" - nonlocal sync_context - if sync_context is None: - sync_context = sync_context_handler() - sync_context.__enter__() - - def enable_grad_sync(): - """Enable asynchronous grad reductions""" - nonlocal sync_context - if sync_context is not None: - sync_context.__exit__(None, None, None) - sync_context = None - - disable_grad_sync() - - # mypy will blame the following if statement - if sequence_parallel_enabled: - seq_length, batch_size, hidden = tensor_shape - tensor_shape = ( - seq_length // parallel_state.get_tensor_model_parallel_world_size(), - batch_size, - hidden, - ) - - num_model_chunks: int = len(model) - input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)] - output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)] - curr_iters: List[int] = [0 for _ in range(num_model_chunks)] - losses_reduced: List[Union[None, torch.Tensor]] = [] - if not forward_only: - output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [ - [] for _ in range(num_model_chunks) - ] - - pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size() - pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank() - - # Compute number of warmup and remaining microbatches. - num_microbatches: int = get_num_microbatches() * num_model_chunks - all_warmup_microbatches: bool = False - if forward_only: - num_warmup_microbatches: int = num_microbatches - else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. - # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on - # all workers, followed by more microbatches after depending on - # stage ID (more forward passes for earlier stages, later stages can - # immediately start with 1F1B). 
- if get_num_microbatches() == pipeline_parallel_size: - num_warmup_microbatches = num_microbatches - all_warmup_microbatches = True - else: - num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches - - # Checkpoint the activations of partial Transformer layers in a number of micro-batches - # within the maximum outstanding micro-batch backpropagations. - # Micro-batches with the ids less than 'num_micro_batches_with_partial_activation_checkpoints' - # checkpoint partial Transformer layers (or skip checkpointing) and - # the rest of micro-batches within a window of micro-batches checkpoint - # all Transformer layers. The window of micro-batches is set by the maximum - # outstanding backpropagations and becomes smaller at later pipeline stages. - # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf - max_outstanding_backprops = None - if num_micro_batches_with_partial_activation_checkpoints is not None: - max_outstanding_backprops = num_warmup_microbatches + 1 - - _logger.info( - f"num_microbatches: {num_microbatches}, " - f"num_warmup_microbatches: {num_warmup_microbatches}, " - f"num_microbatches_remaining: {num_microbatches_remaining}" - ) - - # Synchronize params for first two model chunks - if custom_param_sync_func is not None: - custom_param_sync_func(model[0].parameters()) - custom_param_sync_func(model[1].parameters()) - - ################################################################################################################### - # Helper function definitions. - ################################################################################################################### - def get_model_chunk_id(microbatch_id: int, forward: bool) -> int: - """Helper function to get the model chunk ID given the iteration number. - - Each model chunk processes pipeline_parallel_size microbatches - at a time. We assume that the number of microbatches is a - multiple of pipeline_parallel_size*num_model_chunks. - """ - microbatch_group_size = pipeline_parallel_size * num_model_chunks - microbatch_id_in_group = microbatch_id % microbatch_group_size - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Helper function to check if an iteration is the first for a model - chunk. - """ - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == 0: - return microbatch_id_in_group % pipeline_parallel_size == 0 - else: - return False - - def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Helper function to check if an iteration is the last for a model - chunk. 
- """ - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == num_microbatch_groups - 1: - return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 - else: - return False - - def forward_step_helper( - microbatch_id: int, - curr_iters: List[int], - checkpoint_activations_micro_batch: Optional[bool] = None, - ) -> torch.Tensor: - """Helper method to run forward step with model split into chunks - - (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()). - """ - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - # launch param synchronization for next model chunk - # Note: To achieve maximum performance, pipeline parallelism - # assumes all ranks have the same compute time. However, - # asynchronous communication tends to slow down compute. Thus, - # we launch asynchronous communication at the same time across - # the pipeline-parallel group. - if custom_param_sync_func is not None: - param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank - if param_sync_microbatch_id < num_microbatches and is_first_microbatch_for_model_chunk( - param_sync_microbatch_id - ): - param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1 - if 1 < param_sync_chunk_id < num_model_chunks: - custom_param_sync_func(model[param_sync_chunk_id].parameters()) - - # forward step - if parallel_state.is_pipeline_first_stage() and len(input_tensors[model_chunk_id]) == len( - output_tensors[model_chunk_id] - ): - input_tensors[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = forward_step( - forward_step_func, - get_kth_microbatch(batch, curr_iters[model_chunk_id]), - model[model_chunk_id], - input_tensor, - losses_reduced, - dtype, - disable_autocast, - checkpoint_activations_micro_batch, - ) - curr_iters[model_chunk_id] += 1 - output_tensors[model_chunk_id].append(output_tensor) - - # if forward-only, no need to save tensors for a backward pass - if forward_only: - input_tensors[model_chunk_id].pop() - output_tensors[model_chunk_id].pop() - - return output_tensor - - def backward_step_helper(microbatch_id: int) -> torch.Tensor: - """Helper method to run backward step with model split into chunks - - (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()). 
- """ - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - model_type = get_model_type(model[model_chunk_id]) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - # launch grad synchronization (default) - if custom_grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id): - enable_grad_sync() - - # backward step - if parallel_state.is_pipeline_last_stage(): - if len(output_tensor_grads[model_chunk_id]) == 0: - output_tensor_grads[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id].pop(0) - output_tensor = output_tensors[model_chunk_id].pop(0) - output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - # launch grad synchronization (custom grad sync) - # Note: To achieve maximum performance, pipeline parallelism - # assumes all ranks have the same compute time. However, - # asynchronous communication tends to slow down compute. Thus, - # we launch asynchronous communication at the same time across - # the pipeline-parallel group. - if custom_grad_sync_func is not None: - grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank - if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( - grad_sync_microbatch_id - ): - grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False) - enable_grad_sync() - custom_grad_sync_func(model[grad_sync_chunk_id].parameters()) - disable_grad_sync() - - return input_tensor_grad - - ################################################################################################################### - # Run warmup forward passes. - ################################################################################################################### - fwd_wait_handles, bwd_wait_handles = None, None - parallel_state.set_virtual_pipeline_model_parallel_rank(0) - input_tensors[0].append( - p2p_communication.recv_forward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - ) - _logger.info("Warmup phase") - for k in range(num_warmup_microbatches): - _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}") - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_micro_batch = ( - k % max_outstanding_backprops - >= num_micro_batches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_micro_batch = None - - if fwd_wait_handles is not None: - for wait_handle in fwd_wait_handles: - wait_handle.wait() - - output_tensor = forward_step_helper(k, curr_iters, checkpoint_activations_micro_batch) - - # Determine if tensor should be received from previous stage. - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (num_microbatches - 1): - recv_prev = False - _logger.debug( - f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}" - ) - - # Don't send tensor downstream if on last stage. 
- if parallel_state.is_pipeline_last_stage(): - _logger.debug("Pipeline last stage, not sending tensor downstream") - output_tensor = None - - if overlap_p2p_comm: - # P2P communications in warmup are not overlapped with computes. We split P2P - # communications for activation forward and activation_gradient backward in warmup, - # to match the send/recv API granularity in 1F1B in case of using batched send/recv API. - - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - _logger.debug("send fwd and receive fwd") - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=True, - batch_p2p_comm=batch_p2p_comm, - ) - if ( - k == (num_warmup_microbatches - 1) - and not forward_only - and not all_warmup_microbatches - ): - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - _logger.debug("send bwd and receive bwd") - output_tensor_grad, bwd_wait_handles = ( - p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=True, - batch_p2p_comm=batch_p2p_comm, - ) - ) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - else: - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - if ( - k == (num_warmup_microbatches - 1) - and not forward_only - and not all_warmup_microbatches - ): - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - _logger.debug("send fwd&bwd and receive fwd&bwd") - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - else: - _logger.debug("send fwd and receive fwd") - input_tensor = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - ################################################################################################################### - # Run 1F1B in steady state. - ################################################################################################################### - _logger.info("Steady phase") - for k in range(num_microbatches_remaining): - # Forward pass. 
- _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}") - forward_k = k + num_warmup_microbatches - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_micro_batch = ( - forward_k % max_outstanding_backprops - >= num_micro_batches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_micro_batch = None - - if overlap_p2p_comm: - if fwd_wait_handles is not None: - for wait_handle in fwd_wait_handles: - wait_handle.wait() - - output_tensor = forward_step_helper( - forward_k, curr_iters, checkpoint_activations_micro_batch - ) - - # Set forward model chunk id - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - - # Last virtual stage no activation tensor to send - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - # Determine if the current virtual stage has an activation tensor to receive - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Send activation tensor to the next stage and receive activation tensor from the - # previous stage - _logger.debug("send fwd and receive fwd") - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=True, - batch_p2p_comm=batch_p2p_comm, - ) - - if bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() - - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - - # Set backward model chunk id - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - _logger.debug( - f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}" - ) - - # First virtual stage no activation gradient tensor to send - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if the current virtual stage has an activation gradient tensor to receive - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
- next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - # Send activation grad tensor to the previous stage and receive activation grad tensor - # from the previous stage - _logger.debug("send bwd and receive bwd") - output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - overlap_p2p_comm=True, - batch_p2p_comm=batch_p2p_comm, - ) - else: - output_tensor = forward_step_helper( - forward_k, curr_iters, checkpoint_activations_micro_batch - ) - - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. - - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - _logger.debug( - f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}" - ) - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Communicate tensors. 
- _logger.debug("send fwd&bwd and receive fwd&bwd") - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Put input_tensor and output_tensor_grad in data structures in the - # right location. - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) - - ################################################################################################################### - # Run cooldown backward passes (flush out pipeline). - ################################################################################################################### - _logger.info("Cooldown phase") - if not forward_only: - if overlap_p2p_comm and bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() - - if all_warmup_microbatches: - output_tensor_grads[num_model_chunks - 1].append( - p2p_communication.recv_backward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - ) - - for k in range(num_microbatches_remaining, num_microbatches): - _logger.debug( - f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})" - ) - input_tensor_grad = backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False - if k == (num_microbatches - 1): - recv_next = False - - output_tensor_grads[next_backward_model_chunk_id].append( - p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - batch_p2p_comm=batch_p2p_comm, - ) - ) - - # Make sure to exit context handler for async grad reductions - enable_grad_sync() - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py deleted file mode 100644 index f84779f2f..000000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py +++ /dev/null @@ -1,614 +0,0 @@ -import contextlib -from typing import Any, List, Optional, Sequence, Union -import warnings - -import torch - -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_model_type -from 
apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor -from apex.transformer.log_util import get_transformer_logger - - -__all__ = ["forward_backward_pipelining_without_interleaving"] - - -_logger = get_transformer_logger(__name__) - - -def get_tensor_shapes( - rank: int, - model_type: ModelType, - *, - tensor_shape: Union[List[int], torch.Size], - decoder_sequence_length: Optional[int] = None, - sequence_parallel_enabled: bool = False, -) -> Sequence[Sequence[int]]: - """Get tensors shapes - - Args: - rank: pipeline parallel rank - model_type: - - Keyword Args: - tensor_shape: - decoder_sequence_length: - sequence_parallel_enabled: - """ - # Determine right tensor sizes (based on position of rank with respect to split - # rank) and model size. - # Send two tensors if model is T5 and rank is in decoder stage: - # first tensor is decoder (pre-transpose), - # second tensor is encoder (post-transpose). - # If model is T5 and rank is at the boundary: - # send one tensor (post-transpose from encoder). - # Otherwise, send one tensor (pre-transpose). - assert len(tensor_shape) == 3, ( - f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}" - ) - - sequence_length, micro_batch_size, hidden_size = tensor_shape - - tensor_shapes = [] - - if sequence_parallel_enabled: - seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size() - else: - seq_length = sequence_length - - if model_type == ModelType.encoder_and_decoder: - if sequence_parallel_enabled: - dec_seq_length = ( - decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size() - ) - else: - dec_seq_length = decoder_sequence_length - - if parallel_state.is_pipeline_stage_before_split(rank): - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - else: - tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size)) - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - else: - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - - return tensor_shapes - - -def recv_forward( - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - input_tensors = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - input_tensors.append(None) - else: - input_tensors.append( - p2p_communication.recv_forward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - ) - return input_tensors - - -def recv_backward( - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - output_tensor_grads = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - output_tensor_grads.append(None) - else: - output_tensor_grads.append( - p2p_communication.recv_backward( - tensor_shape=tensor_shape, - dtype=dtype, - 
async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - ) - return output_tensor_grads - - -def send_forward( - output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> None: - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_forward( - output_tensor, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - -def send_backward( - input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> None: - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_backward( - input_tensor_grad, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - -def send_forward_recv_backward( - output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - output_tensor_grads = [] - for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - output_tensor_grads.append(None) - continue - output_tensor_grad = p2p_communication.send_forward_recv_backward( - output_tensor, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - output_tensor_grads.append(output_tensor_grad) - return output_tensor_grads - - -def send_backward_recv_forward( - input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - sync_batch_comm: bool = True, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - input_tensors = [] - for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - input_tensors.append(None) - continue - input_tensor = p2p_communication.send_backward_recv_forward( - input_tensor_grad, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - input_tensors.append(input_tensor) - return input_tensors - - -def forward_backward_pipelining_without_interleaving( - forward_step_func: FwdStepFunc, - batch: 
Optional[Batch], - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - forward_only: bool, - tensor_shape: Optional[Union[List[int], torch.Size]] = None, - decoder_sequence_length: Optional[int] = None, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - deallocate_pipeline_outputs: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - custom_sync_context_handler: Optional[Any] = None, - custom_grad_sync_func: Optional[Any] = None, - sync_batch_comm: bool = True, - num_micro_batches_with_partial_activation_checkpoints: Optional[int] = None, - **kwargs, -) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]: - """Run non-interleaved 1F1B schedule, with communication between pipeline stages. - - This pipeline parallel scheduling consists of three steps: - 1. warmup - 2. 1F1B a.k.a. steady state - 3. cooldown if not forward_only - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A minibatch, i.e., a list of `torch.Tensor`'s. - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension - is supposed to be ``(sequence, batch, hidden)``. - dtype: dtype used in p2p communication. If ``None`` (default value), - torch.float32 will be used even if ``autocast`` is enabled. - grad_scaler: - disable_autocast: - deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of - each pipeline stage. Experimental. - sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length. - When :obj:`True`, the sequence length on each tensor model parallel rank is updated - to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`. - custom_sync_context_handler: Does nothing if ``None`` (default - value). Otherwise, a function to construct a context - manager that disable asynchronous gradient reductions. - Asynchronous gradient reductions are only enabled in the - first pipeline stage, during the last backward pass. - custom_grad_sync_func: Does nothing if ``None`` (default - value). Otherwise, a function to perform gradient - reductions. This is called in all pipeline stages except - the first, during the bubble overhead. - sync_batch_comm: If :obj:`False`, disable cuda synchronization after the batched communication. - To disable, https://github.com/pytorch/pytorch/pull/82450 would be required. - num_micro_batches_with_partial_activation_checkpoints: If :obj:`int`, set the number of - micro-batches checkpointing the activation of partial number of Transformer layers. - The rest of the micro-batch within the window of maximum outstanding micro-batch - backpropagations would checkpoint all Transformer layers. - - Returns: - a list of loss `torch.Tensor`s if the last stage, empty list otherwise. - - """ - # timers = get_timers() - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. 
" - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - - if deallocate_pipeline_outputs: - warnings.warn( - "`deallocate_pipeline_outputs` is experimental and subject to change. " - "This option is not recommended." - ) - - model: List[torch.nn.Module] = listify_model(model) - if len(model) != 1: - msg = f"`model` is expected be a `nn.Module`, but {type(model)}" - raise RuntimeError(msg) - model: torch.nn.Module = model[0] - - # Disable async grad reductions - if custom_sync_context_handler is not None: - sync_context_handler = custom_sync_context_handler - else: - sync_context_handler = contextlib.nullcontext - sync_context = None - - def disable_grad_sync(): - """Disable asynchronous grad reductions""" - nonlocal sync_context - if sync_context is None: - sync_context = sync_context_handler() - sync_context.__enter__() - - def enable_grad_sync(): - """Enable asynchronous grad reductions""" - nonlocal sync_context - if sync_context is not None: - sync_context.__exit__(None, None, None) - sync_context = None - - disable_grad_sync() - - # Compute number of warmup microbatches. - num_microbatches: int = get_num_microbatches() - num_warmup_microbatches: int = ( - parallel_state.get_pipeline_model_parallel_world_size() - - parallel_state.get_pipeline_model_parallel_rank() - - 1 - ) - num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches - - # Checkpoint the activations of partial Transformer layers in a number of micro-batches - # within the maximum outstanding micro-batch backpropagations. - # Micro-batches with the ids less than 'num_micro_batches_with_partial_activation_checkpoints' - # checkpoint partial Transformer layers (or skip checkpointing) and - # the rest of micro-batches within a window of micro-batches checkpoint - # all Transformer layers. The window of micro-batches is set by the maximum - # outstanding backpropagations and becomes smaller at later pipeline stages. - # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf - max_outstanding_backprops = None - if num_micro_batches_with_partial_activation_checkpoints is not None: - max_outstanding_backprops = num_warmup_microbatches + 1 - - model_type = get_model_type(model) - rank: int = parallel_state.get_pipeline_model_parallel_rank() - recv_tensor_shapes: List[List[int]] = get_tensor_shapes( - rank - 1, - model_type, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_sequence_length, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - send_tensor_shapes: List[List[int]] = get_tensor_shapes( - rank, - model_type, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_sequence_length, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - _logger.info( - f"num_microbatches: {num_microbatches}, " - f"num_warmup_microbatches: {num_warmup_microbatches}, " - f"num_microbatches_remaining: {num_microbatches_remaining}" - ) - - # Input, output tensors only need to be saved when doing backward passes - input_tensors: List[Union[None, torch.Tensor]] = [] - output_tensors: List[Union[None, torch.Tensor]] = [] - losses_reduced: List[Union[None, torch.Tensor]] = [] - ################################################################################################################### - # Run warmup forward passes. 
- ################################################################################################################### - _logger.info("Warmup") - for i in range(num_warmup_microbatches): - _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}") - _logger.debug("receive fwd") - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_micro_batch = ( - i % max_outstanding_backprops - >= num_micro_batches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_micro_batch = None - input_tensor = recv_forward( - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i) - output_tensor = forward_step( - forward_step_func, - cur_microbatch, - model, - input_tensor, - losses_reduced, - dtype, - disable_autocast, - checkpoint_activations_micro_batch, - ) - _logger.debug("send fwd") - send_forward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - if not forward_only: - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Before running 1F1B, need to receive first forward tensor. - # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_microbatches_remaining > 0: - _logger.debug("recv_forward before steady state start") - input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward( - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sync_batch_comm=sync_batch_comm, - ) - - ################################################################################################################### - # Run 1F1B in steady state. 
- ################################################################################################################### - _logger.info("Steady phase") - for i in range(num_microbatches_remaining): - _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}") - last_iteration: bool = i == (num_microbatches_remaining - 1) - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_micro_batch = ( - (i + num_warmup_microbatches) % max_outstanding_backprops - ) >= num_micro_batches_with_partial_activation_checkpoints - else: - checkpoint_activations_micro_batch = None - cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch( - batch, i + num_warmup_microbatches - ) - output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step( - forward_step_func, - cur_microbatch, - model, - input_tensor, - losses_reduced, - dtype, - disable_autocast, - checkpoint_activations_micro_batch, - ) - if forward_only: - _logger.debug("send fwd") - send_forward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - if not last_iteration: - _logger.debug("receive fwd (last iteration)") - input_tensor = recv_forward( - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - else: - _logger.debug("send fwd & receive bwd") - output_tensor_grad = send_forward_recv_backward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - # Add input_tensor and output_tensor to end of list. - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Pop input_tensor and output_tensor from the start of the list for the backward pass. - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - if last_iteration: - input_tensor = None - _logger.debug("send bwd") - send_backward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - else: - _logger.debug("send bwd and receive fwd") - input_tensor = send_backward_recv_forward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - ################################################################################################################### - # Run cooldown backward passes. 
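For reference, a minimal standalone sketch of the activation-checkpoint window decision applied in the warmup and steady-state loops above; the helper name and the counts are hypothetical, not part of the removed module.

def checkpoint_all_layers(microbatch_id, num_warmup_microbatches, num_partial_ckpt_microbatches):
    # Window of maximum outstanding backpropagations, as computed above.
    max_outstanding_backprops = num_warmup_microbatches + 1
    # Micro-batches past the first `num_partial_ckpt_microbatches` slots of the window
    # checkpoint every Transformer layer; earlier ones checkpoint only a partial set
    # of layers (or skip checkpointing).
    return (microbatch_id % max_outstanding_backprops) >= num_partial_ckpt_microbatches

for i in range(8):
    print(i, checkpoint_all_layers(i, num_warmup_microbatches=3, num_partial_ckpt_microbatches=2))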
- ################################################################################################################### - _logger.info("Cooldown phase") - if not forward_only: - for i in range(num_warmup_microbatches): - _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}") - - if i == num_warmup_microbatches - 1 and rank == 0: - # Async grad reduction in first pipeline stage, during - # last backward pass - enable_grad_sync() - - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - _logger.debug("receive bwd") - output_tensor_grad = recv_backward( - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - _logger.debug("send bwd") - send_backward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - sync_batch_comm=sync_batch_comm, - ) - - # Grad reduction in all pipeline stages except the first, during - # the bubble overhead - enable_grad_sync() - if rank != 0 and custom_grad_sync_func is not None: - custom_grad_sync_func() - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/utils.py b/apex/transformer/pipeline_parallel/utils.py deleted file mode 100644 index 73a724d03..000000000 --- a/apex/transformer/pipeline_parallel/utils.py +++ /dev/null @@ -1,370 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
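The warmup/steady/cooldown split computed by the schedule above depends only on the pipeline rank and the number of micro-batches; a minimal sketch with a hypothetical helper name and sizes:

def one_f_one_b_phases(pipeline_world_size, pipeline_rank, num_microbatches):
    # Later pipeline stages need fewer warmup forward passes.
    num_warmup = min(pipeline_world_size - pipeline_rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    # Cooldown (when not forward-only) runs one backward pass per warmup micro-batch.
    return num_warmup, num_steady, num_warmup

for rank in range(4):
    print(rank, one_f_one_b_phases(pipeline_world_size=4, pipeline_rank=rank, num_microbatches=8))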
- -"""Utilities for pipeline model parallel.""" - -from typing import Optional, List, Union, Tuple - -import torch -from torch.nn.parallel import DistributedDataParallel - -from apex.multi_tensor_apply import multi_tensor_applier -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.microbatches import build_num_microbatches_calculator -from apex.transformer.pipeline_parallel._timers import _Timers - -if multi_tensor_applier.available: - import amp_C - - -_GLOBAL_ARGS = None -_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None -_GLOBAL_TOKENIZER = None -_GLOBAL_TENSORBOARD_WRITER = None -_GLOBAL_AUTORESUME = None -_GLOBAL_TIMERS = None - - -Shape = Union[List[int], torch.Size] - - -def listify_model( - model: Union[torch.nn.Module, List[torch.nn.Module]], -) -> List[torch.nn.Module]: - if isinstance(model, list): - return model - return [model] - - -def _ensure_var_is_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is not None, "{} is not initialized.".format(name) - - -def _ensure_var_is_not_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is None, "{} is already initialized.".format(name) - - -def setup_microbatch_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -) -> None: - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - _ensure_var_is_not_initialized( - _GLOBAL_NUM_MICROBATCHES_CALCULATOR, "num microbatches calculator" - ) - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size - ) - - -def _reconfigure_microbatch_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -) -> None: - if torch.distributed.get_rank() == 0: - import warnings - - warnings.warn("This function is only for unittest") - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size - ) - - -def get_micro_batch_size(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size - - -def get_num_microbatches(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() - - -def get_current_global_batch_size(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() - - -def update_num_microbatches(consumed_samples, consistency_check=True): - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check) - - -# note (mkozuki): Comment out in favor of `get_kth_microbatch` -def _split_batch_into_microbatch( - batch: List[torch.Tensor], - *, - _micro_batch_size: Optional[int] = None, - _global_batch_size: Optional[int] = None, -) -> List[List[torch.Tensor]]: - micro_batch_size = _micro_batch_size - global_batch_size = _global_batch_size - if micro_batch_size is None: - micro_batch_size = get_micro_batch_size() - if global_batch_size is None: - global_batch_size = get_current_global_batch_size() - for i in range(0, global_batch_size, micro_batch_size): - yield [x[i * micro_batch_size : (i + 1) * micro_batch_size] for x in batch] - - -# TODO(mkozuki): Support non-tensor local minibatches? 
-def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]: - """Create a list of microbatches from a list of local minibatches. - - This function creates a list of `k`th microbatches from a list of local minibatches. - `a local minibatch` consists of `global_batch_size / data_parallel_size` samples. - """ - if batch is None or not isinstance(batch, (List, Tuple)): - return batch - micro_batch_size = get_micro_batch_size() - start = k * micro_batch_size - end = start + micro_batch_size - microbatch = list() - for x in batch: - size = x.size(0) - assert size > start and size >= end - microbatch.append(x[start:end]) - assert len(microbatch) > 0 - return microbatch - - -def get_autoresume(): - return _GLOBAL_AUTORESUME - - -def _set_timers(): - """Initialize timers.""" - global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") - _GLOBAL_TIMERS = _Timers() - - -def get_timers(): - """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") - return _GLOBAL_TIMERS - - -def print_rank_0(message: str) -> None: - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - - -def is_last_rank(): - return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) - - -def print_rank_last(message): - """If distributed is initialized, print only on last rank.""" - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True) - - -def param_is_not_shared(param: torch.nn.Parameter) -> bool: - return getattr(param, "shared", False) - - -def unwrap_model(model, module_instances=(DistributedDataParallel,)): - return_list = True - if not isinstance(model, list): - model = [model] - return_list = False - unwrapped_model = [] - for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module - unwrapped_model.append(model_module) - if not return_list: - return unwrapped_model[0] - return unwrapped_model - - -def get_model_type( - model: torch.nn.Module, -) -> ModelType: - """Get `model_type` of `model`. - - If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``. - - Args: - model - """ - return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder) - - -def calc_params_l2_norm(model: torch.nn.Module, bf16: bool): - """Calculate l2 norm of parameters""" - # args = get_args() - if not isinstance(model, list): - model = [model] - # Remove duplicate params. - params_data = [] - for model_ in model: - for param in model_.parameters(): - is_not_shared = param_is_not_shared(param) - is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param) - if is_not_shared and is_not_tp_duplicate: - if bf16: - params_data.append(param.data.float()) - else: - params_data.append(param.data) - # Calculate norm - dummy_overflow_buf = torch.cuda.IntTensor([0]) - norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [params_data], - False, # no per-parameter norm - ) - norm_2 = norm * norm - # Sum across all model-parallel GPUs. 
- torch.distributed.all_reduce( - norm_2, - op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_model_parallel_group(), - ) - return norm_2.item() ** 0.5 - - -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group()) - averaged_losses = averaged_losses / torch.distributed.get_world_size( - group=parallel_state.get_data_parallel_group() - ) - - return averaged_losses - - -def report_memory(name): - """Simple GPU memory report.""" - mega_bytes = 1024.0 * 1024.0 - string = name + " memory (MB)" - string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes) - string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes) - string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes) - string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes) - if parallel_state.get_data_parallel_rank() == 0: - print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True) - - -def print_params_min_max_norm(optimizer, iteration): - """Print min, max, and norm of all parameters.""" - index = 0 - rank = torch.distributed.get_rank() - string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n" - optimizer_ = optimizer.optimizer - for param_group in optimizer_.param_groups: - for param in param_group["params"]: - index += 1 - min_ = param.data.min() - max_ = param.data.max() - norm = torch.linalg.norm(param.data) - string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format( - iteration, rank, index, int(param.tensor_model_parallel) - ) - string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm) - print(string, flush=True) - - -# NOTE (mkozuki): APEX doesn't have anything equivalent for -# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM. -# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool): -# """Check for autoresume signal and exit if it is received.""" -# from apex.ppu.checkpointing import save_checkpoint -# -# autoresume = get_adlr_autoresume() -# # Add barrier to ensure consistency. -# torch.distributed.barrier() -# if autoresume.termination_requested(): -# if save: -# save_checkpoint(iteration, model, optimizer, lr_scheduler) -# print_rank_0(">>> autoresume termination request found!") -# if torch.distributed.get_rank() == 0: -# autoresume.request_resume() -# print_rank_0(">>> training terminated. Returning") -# sys.exit(0) - - -def get_ltor_masks_and_position_ids( - data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss -): - """Build masks and position id for left to right model.""" - - # Extract batch size and sequence length. - micro_batch_size, seq_length = data.size() - - # Attention mask (lower triangular). - if reset_attention_mask: - att_mask_batch = micro_batch_size - else: - att_mask_batch = 1 - attention_mask = torch.tril( - torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) - ).view(att_mask_batch, 1, seq_length, seq_length) - - # Loss mask. - loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) - if eod_mask_loss: - loss_mask[data == eod_token] = 0.0 - - # Position ids. 
- position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) - position_ids = position_ids.unsqueeze(0).expand_as(data) - # We need to clone as the ids will be modifed based on batch index. - if reset_position_ids: - position_ids = position_ids.clone() - - if reset_position_ids or reset_attention_mask: - # Loop through the batches: - for b in range(micro_batch_size): - # Find indecies where EOD token is. - eod_index = position_ids[b, data[b] == eod_token] - # Detach indecies from positions if going to modify positions. - if reset_position_ids: - eod_index = eod_index.clone() - - # Loop through EOD indecies: - prev_index = 0 - for j in range(eod_index.size()[0]): - i = eod_index[j] - # Mask attention loss. - if reset_attention_mask: - attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 - # Reset positions. - if reset_position_ids: - position_ids[b, (i + 1) :] -= i + 1 - prev_index - prev_index = i + 1 - - # Convert attention mask to binary: - attention_mask = attention_mask < 0.5 - - return attention_mask, loss_mask, position_ids diff --git a/apex/transformer/tensor_parallel/__init__.py b/apex/transformer/tensor_parallel/__init__.py deleted file mode 100644 index ccad80e6b..000000000 --- a/apex/transformer/tensor_parallel/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
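A toy illustration of the masking convention produced by get_ltor_masks_and_position_ids above, using hypothetical token ids and no EOD-based resets; after the binary conversion, True marks positions that attention must not see:

import torch

data = torch.tensor([[5, 7, 2, 9]])                  # (micro_batch_size, seq_length)
_, seq_length = data.size()
attention_mask = torch.tril(
    torch.ones(1, seq_length, seq_length, device=data.device)
).view(1, 1, seq_length, seq_length)
attention_mask = attention_mask < 0.5                # convert to binary; True == masked
loss_mask = torch.ones(data.size(), dtype=torch.float)
position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand_as(data)
print(attention_mask[0, 0])
print(loss_mask)
print(position_ids)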
- -"""Model parallel utility interface.""" - -from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy - -from apex.transformer.tensor_parallel.data import broadcast_data - -from apex.transformer.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding, - set_tensor_model_parallel_attributes, - set_defaults_if_not_set_tensor_model_parallel_attributes, - copy_tensor_model_parallel_attributes, -) - -from apex.transformer.tensor_parallel.mappings import ( - copy_to_tensor_model_parallel_region, - gather_from_tensor_model_parallel_region, - reduce_from_tensor_model_parallel_region, - scatter_to_tensor_model_parallel_region, - scatter_to_sequence_parallel_region, -) - -from .random import ( - checkpoint, - get_cuda_rng_tracker, - init_checkpointed_activations_memory_buffer, - model_parallel_cuda_manual_seed, - reset_checkpointed_activations_memory_buffer, -) - -from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim - - -__all__ = [ - # cross_entropy.py - "vocab_parallel_cross_entropy", - # data.py - "broadcast_data", - # layers.py - "ColumnParallelLinear", - "RowParallelLinear", - "VocabParallelEmbedding", - "set_tensor_model_parallel_attributes", - "set_defaults_if_not_set_tensor_model_parallel_attributes", - "copy_tensor_model_parallel_attributes", - # mappings.py - "copy_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - # random.py - "checkpoint", - "get_cuda_rng_tracker", - "init_checkpointed_activations_memory_buffer", - "model_parallel_cuda_manual_seed", - "reset_checkpointed_activations_memory_buffer", - # utils.py - "split_tensor_along_last_dim", -] diff --git a/apex/transformer/tensor_parallel/cross_entropy.py b/apex/transformer/tensor_parallel/cross_entropy.py deleted file mode 100644 index 9867720b4..000000000 --- a/apex/transformer/tensor_parallel/cross_entropy.py +++ /dev/null @@ -1,155 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.tensor_parallel.utils import VocabUtility - - -class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. 
" - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce( - logits_max, - op=torch.distributed.ReduceOp.MAX, - group=get_tensor_model_parallel_group(), - ) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce( - predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group(), - ) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce( - sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group(), - ) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits) - predicted_logits - - # Store softmax, target-mask and masked-target for backward pass. - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - - vocab_size = exp_logits.size(-1) - if label_smoothing > 0: - """ - We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. - = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) - = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i - = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K - From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py - """ - assert 1.0 > label_smoothing > 0.0 - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - - # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. 
- log_probs = torch.log(exp_logits) - mean_log_probs = log_probs.mean(dim=-1) - loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs - - ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - - return loss - - @staticmethod - def backward(ctx, grad_output): - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - - softmax_update = 1.0 - target_mask.view(-1).float() - - if label_smoothing > 0: - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update - average_grad = 1 / vocab_size - grad_2d[arange_1d, :] -= smoothing * average_grad - else: - grad_2d[arange_1d, masked_target_1d] -= softmax_update - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - - return grad_input, None, None - - -def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): - """Helper function for the cross entropy.""" - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) diff --git a/apex/transformer/tensor_parallel/data.py b/apex/transformer/tensor_parallel/data.py deleted file mode 100644 index 558775e35..000000000 --- a/apex/transformer/tensor_parallel/data.py +++ /dev/null @@ -1,127 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
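A single-GPU reference for the label-smoothed loss combination used by _VocabParallelCrossEntropy above, with no tensor parallelism and hypothetical logits, targets, and smoothing factor:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                           # (tokens, vocab)
target = torch.randint(0, 10, (4,))
label_smoothing, vocab_size = 0.1, logits.size(-1)
smoothing = label_smoothing * vocab_size / (vocab_size - 1)

log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)   # log(sum(exp(logits))) - logit[target]
loss = (1.0 - smoothing) * nll - smoothing * log_probs.mean(dim=-1)
print(loss)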
-import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank - - -_MAX_DATA_DIM = 5 - - -def _check_data_types(keys, data, target_dtype): - """Check that all the keys have the same target data type.""" - for key in keys: - assert data[key].dtype == target_dtype, ( - "{} has data type {} which is different than {}".format( - key, data[key].dtype, target_dtype - ) - ) - - -def _build_key_size_numel_dictionaries(keys, data): - """Build the size on rank 0 and broadcast.""" - max_dim = _MAX_DATA_DIM - sizes = [0 for _ in range(max_dim) for _ in keys] - - # Pack the sizes on rank zero. - if get_tensor_model_parallel_rank() == 0: - offset = 0 - for key in keys: - assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" - size = data[key].size() - for i, s in enumerate(size): - sizes[i + offset] = s - offset += max_dim - - # Move to GPU and broadcast. - sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast( - sizes_cuda, - get_tensor_model_parallel_src_rank(), - group=get_tensor_model_parallel_group(), - ) - - # Move back to cpu and unpack. - sizes_cpu = sizes_cuda.cpu() - key_size = {} - key_numel = {} - total_numel = 0 - offset = 0 - for key in keys: - i = 0 - size = [] - numel = 1 - while sizes_cpu[offset + i] > 0: - this_size = sizes_cpu[offset + i] - size.append(this_size) - numel *= this_size - i += 1 - key_size[key] = size - key_numel[key] = numel - total_numel += numel - offset += max_dim - - return key_size, key_numel, total_numel - - -def broadcast_data(keys, data, datatype): - """Broadcast data from rank zero of each model parallel group to the - members of the same model parallel group. - - Arguments: - keys: list of keys in the data disctionary to be broadcasted - data: data dictionary of string keys and cpu tensor values. - datatype: torch data type of all tensors in data associated - with keys. - """ - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - # Build (key, size) and (key, number of elements) dictionaries along - # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) - # Pack on rank zero. - if get_tensor_model_parallel_rank() == 0: - # Check that all keys have the same data type. 
- _check_data_types(keys, data, datatype) - # Flatten the data associated with the keys - flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() - else: - flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) - - # Broadcast - torch.distributed.broadcast( - flatten_data, - get_tensor_model_parallel_src_rank(), - group=get_tensor_model_parallel_group(), - ) - - # Unpack - output = {} - offset = 0 - for key in keys: - size = key_size[key] - numel = key_numel[key] - output[key] = flatten_data.narrow(0, offset, numel).view(size) - offset += numel - - return output diff --git a/apex/transformer/tensor_parallel/layers.py b/apex/transformer/tensor_parallel/layers.py deleted file mode 100644 index f54756dbc..000000000 --- a/apex/transformer/tensor_parallel/layers.py +++ /dev/null @@ -1,884 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -from typing import Optional, Tuple -import warnings - -import torch -import torch.nn.functional as F -import torch.nn.init as init -from torch.nn.parameter import Parameter - -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.utils import divide -from apex.transformer.tensor_parallel.mappings import ( - copy_to_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - gather_from_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - reduce_from_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - scatter_to_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - reduce_scatter_to_sequence_parallel_region, -) -from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker -from apex.transformer.tensor_parallel.utils import VocabUtility -from apex.transformer.log_util import get_transformer_logger - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 4 lines are for backward comparability with -# older PyTorch. 
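For reference on broadcast_data above: each key's shape is packed into a fixed-width slot of _MAX_DATA_DIM entries so that a single broadcast carries the metadata for every key; a sketch with hypothetical shapes:

max_dim = 5                                    # mirrors _MAX_DATA_DIM
shapes = {"text": (4, 128), "types": (4,)}     # hypothetical per-key tensor shapes
sizes = []
for shape in shapes.values():
    sizes.extend(list(shape) + [0] * (max_dim - len(shape)))
print(sizes)   # [4, 128, 0, 0, 0, 4, 0, 0, 0, 0]; broadcast, then unpacked per key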
-if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - -_logger = get_transformer_logger(__name__) - - -_grad_accum_fusion_available = True -try: - import fused_weight_gradient_mlp_cuda -except ImportError: - _grad_accum_fusion_available = False - - -_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { - "tensor_model_parallel": False, - "partition_dim": -1, - "partition_stride": 1, -} - - -def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool: - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or ( - get_tensor_model_parallel_rank() == 0 - ) - - -def set_tensor_model_parallel_attributes( - tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int -) -> None: - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - # Make sure the attributes are not set. - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) - # Set the attributes. - setattr(tensor, "tensor_model_parallel", is_parallel) - setattr(tensor, "partition_dim", dim) - setattr(tensor, "partition_stride", stride) - - -def set_defaults_if_not_set_tensor_model_parallel_attributes( - tensor: torch.Tensor, -) -> None: - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - - def maybe_set(attribute, value): - if not hasattr(tensor, attribute): - setattr(tensor, attribute, value) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) - - -def copy_tensor_model_parallel_attributes( - destination_tensor: torch.Tensor, source_tensor: torch.Tensor -) -> None: - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." 
- ) - - def maybe_copy(attribute): - if hasattr(source_tensor, attribute): - setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_copy(attribute) - - -def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): - """Initialize affine weight for model parallel on GPU. - - Args: - weight (Parameter): - init_method (Callable[[Tensor], None]): Taking a Tensor and initialize its elements. - partition_dim (int): Dimension to apply partition. - stride (int): - """ - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - with get_cuda_rng_tracker().fork(): - init_method(weight) - - -# TODO (mkozuki): Re-consider removing params_dtype from arguments to make this -# more parallel with _initialize_affine_weight_gpu -def _initialize_affine_weight_cpu( - weight, - output_size, - input_size, - per_partition_size, - partition_dim, - init_method, - stride=1, - return_master_weight=False, - *, - params_dtype=torch.float32, -): - """Initialize affine weight for model parallel. - - Build the master weight on all processes and scatter - the relevant chunk.""" - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - # Initialize master weight - master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False) - init_method(master_weight) - master_weight = master_weight.to(dtype=params_dtype) - - # Split and copy - per_partition_per_stride_size = divide(per_partition_size, stride) - weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - my_weight_list = weight_list[rank::world_size] - - with torch.no_grad(): - torch.cat(my_weight_list, dim=partition_dim, out=weight) - if return_master_weight: - return master_weight - return None - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - init_method: method to initialize weights. - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - init_method=init.xavier_normal_, - *, - params_dtype: torch.dtype = torch.float32, - use_cpu_initialization: bool = False, - ): - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - super().__init__() - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - # Set the detauls for compatibility. - self.padding_idx = None - self.max_norm = None - self.norm_type = 2.0 - self.scale_grad_by_freq = False - self.sparse = False - self._weight = None - self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() - # Divide the weight matrix along the vocabulary dimension. 
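The vocabulary slice computed below divides the global vocabulary evenly across tensor-model-parallel ranks; a minimal sketch of that arithmetic with a hypothetical helper name and sizes, mirroring VocabUtility:

def vocab_range(global_vocab_size, rank, world_size):
    # Each rank owns a contiguous [start, end) slice of the vocabulary.
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition

print([vocab_range(50304, rank, 4) for rank in range(4)])   # 12576 ids per rank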
- ( - self.vocab_start_index, - self.vocab_end_index, - ) = VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, - get_tensor_model_parallel_rank(), - self.tensor_model_parallel_size, - ) - self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index - - # Allocate weights and initialize. - if use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, - self.embedding_dim, - dtype=params_dtype, - ) - ) - _initialize_affine_weight_cpu( - self.weight, - self.num_embeddings, - self.embedding_dim, - self.num_embeddings_per_partition, - 0, - init_method, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1) - - def forward(self, input_): - if self.tensor_model_parallel_size > 1: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = F.embedding( - masked_input, - self.weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) - # Mask the output embedding. - if self.tensor_model_parallel_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. - output = reduce_from_tensor_model_parallel_region(output_parallel) - return output - - -class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): - """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.""" - - @staticmethod - def forward( - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel_enabled: bool, - use_16bit_in_wgrad_accum_fusion: Optional[bool] = None, - ): - ctx.use_bias = bias is not None and weight.requires_grad - ctx.gradient_accumulation_fusion = gradient_accumulation_fusion - ctx.async_grad_allreduce = async_grad_allreduce - ctx.sequence_parallel_enabled = sequence_parallel_enabled - ctx.compute_weight_gradient = weight.requires_grad - - if use_16bit_in_wgrad_accum_fusion is not None: - warnings.warn( - "Deprecated option `use_16bit_in_wgrad_accum_fusion` " - f"is set to {use_16bit_in_wgrad_accum_fusion}" - ) - - if ctx.compute_weight_gradient: - ctx.save_for_backward(input, weight) - else: - ctx.save_for_backward(weight) - - if ctx.sequence_parallel_enabled: - world_size = get_tensor_model_parallel_world_size() - # `input` is supposed to be 3D and its order of dimension is [sequence, batch, hidden] - shape = list(input.shape) - shape[0] *= world_size - - all_gather_buffer = torch.empty( - shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - torch.distributed.all_gather_into_tensor( - all_gather_buffer, input, group=get_tensor_model_parallel_group() - ) - total_input = all_gather_buffer - else: - total_input = input - output = torch.matmul(total_input, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.compute_weight_gradient: - input, weight = ctx.saved_tensors - else: - weight = 
ctx.saved_tensors[0] - input = None - - use_bias = ctx.use_bias - - # only get sequence parallel inputs if need to calculate weight grad - handle = None - if ctx.compute_weight_gradient: - if ctx.sequence_parallel_enabled: - world_size = get_tensor_model_parallel_world_size() - shape = list(input.shape) - shape[0] *= world_size - - all_gather_buffer = torch.empty( - shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - handle = torch.distributed.all_gather_into_tensor( - all_gather_buffer, - input, - group=get_tensor_model_parallel_group(), - async_op=True, - ) - total_input = all_gather_buffer - else: - total_input = input - - grad_input = grad_output.matmul(weight) - - if handle is not None: - handle.wait() - - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = torch.distributed.all_reduce( - grad_input, group=get_tensor_model_parallel_group(), async_op=True - ) - - # if no weight gradient, immediately return - if not ctx.compute_weight_gradient: - if ctx.sequence_parallel_enabled: - assert not ctx.async_grad_allreduce - world_size = get_tensor_model_parallel_world_size() - shape = list(grad_input.shape) - shape[0] //= world_size - - sub_grad_input = torch.empty( - torch.Size(shape), - dtype=grad_input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - handle = torch.distributed.reduce_scatter_tensor( - sub_grad_input, - grad_input, - group=get_tensor_model_parallel_group(), - async_op=True, - ) - handle.wait() - return sub_grad_input, None, None, None, None, None, None - if ctx.async_grad_allreduce: - handle.wait() - return grad_input, None, None, None, None, None, None - - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.contiguous() - grad_output = grad_output.view( - grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] - ) - total_input = total_input.view( - total_input.shape[0] * total_input.shape[1], total_input.shape[2] - ) - - if ctx.sequence_parallel_enabled: - assert not ctx.async_grad_allreduce - sub_grad_input = torch.empty( - input.shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - handle = torch.distributed.reduce_scatter_tensor( - sub_grad_input, - grad_input, - group=get_tensor_model_parallel_group(), - async_op=True, - ) - - if ctx.gradient_accumulation_fusion: - if not hasattr(weight, "main_grad"): - raise RuntimeError( - "attempted to perform gradient accumulation fusion on param without setting main_grad" - ) - if weight.main_grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( - total_input, grad_output, weight.main_grad - ) - elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( - total_input, grad_output, weight.main_grad - ) - else: - raise RuntimeError(f"unsupported dtype for main_grad ({weight.main_grad.dtype})") - grad_weight = None - else: - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.sequence_parallel_enabled: - handle.wait() - return sub_grad_input, grad_weight, grad_bias, None, None, None, None - if ctx.async_grad_allreduce: - handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None - - -def linear_with_grad_accumulation_and_async_allreduce( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - 
sequence_parallel_enabled: bool, -) -> torch.Tensor: - args = _cast_if_autocast_enabled( - input, - weight, - bias, - gradient_accumulation_fusion, - async_grad_allreduce, - sequence_parallel_enabled, - ) - with torch.amp.autocast("cuda", enabled=False): - return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - .. note:: - Input is supposed to be three dimensional and each dimension - is expected to be sequence, batch, and hidden feature, respectively. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias - gather_output: If true, call all-gether on output and make Y avaiable - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip - adding bias but instead return it. - - Keyword Arguments: - no_async_tensor_model_parallel_allreduce: - params_dtype: - use_cpu_initialization: - gradient_accumulation_fusion: - sequence_parallel_enabled: - accumulation_in_fp16: Deprecated - """ - - def __init__( - self, - input_size, - output_size, - bias=True, - gather_output=True, - init_method=init.xavier_normal_, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - *, - no_async_tensor_model_parallel_allreduce=False, - params_dtype=torch.float32, - use_cpu_initialization=False, - gradient_accumulation_fusion=False, - sequence_parallel_enabled: bool = False, - accumulation_in_fp16: Optional[bool] = None, - ): - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, world_size) - self.skip_bias_add = skip_bias_add - - if accumulation_in_fp16 is not None: - warnings.warn( - f"Deprecated option `accumulation_in_fp16` is set to {accumulation_in_fp16}" - ) - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. 
- if use_cpu_initialization: - self.weight = Parameter( - torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype) - ) - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.output_size_per_partition, - 0, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size_per_partition, - self.input_size, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride) - - if bias: - if use_cpu_initialization: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, dtype=params_dtype) - ) - else: - self.bias = Parameter( - torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - set_tensor_model_parallel_attributes(self.bias, True, 0, stride) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) - - self.async_tensor_model_parallel_allreduce = ( - not no_async_tensor_model_parallel_allreduce and world_size > 1 - ) - if sequence_parallel_enabled: - if world_size <= 1: - warnings.warn( - f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}" - ) - # sequence_parallel_enabled = False - self.sequence_parallel_enabled = sequence_parallel_enabled - if gradient_accumulation_fusion: - if not _grad_accum_fusion_available: - # Basically, apex.transformer module users are expected to install APEX's - # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows: - # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." - # at the root of APEX repository. - warnings.warn( - "`gradient_accumulation_fusion` is set to `True` but " - "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not " - "found. Thus `gradient_accumulation_fusion` set to `False`. " - "Note that the extension requires CUDA>=11." - ) - gradient_accumulation_fusion = False - self.gradient_accumulation_fusion = gradient_accumulation_fusion - - if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled: - raise RuntimeError( - "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time." - ) - - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - - def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward of ColumnParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - bias = self.bias if not self.skip_bias_add else None - - if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled: - input_parallel = input_ - else: - input_parallel = copy_to_tensor_model_parallel_region(input_) - - # Matrix multiply. - output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, - bias=bias, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=self.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=self.sequence_parallel_enabled, - ) - if self.gather_output: - # All-gather across the partitions. 
- assert not self.sequence_parallel_enabled - output = gather_from_tensor_model_parallel_region(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - - .. note:: - Input is supposed to be three dimensional and each dimension - is expected to be sequence, batch, and hidden feature, respectively. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimization where bias - can be fused with other elementwise operations. We skip - adding bias but instead return it. - Keyword Arguments: - params_dtype: - use_cpu_initialization: - gradient_accumulation_fusion: - sequence_parallel_enabled: - accumulation_in_fp16: Deprecated - """ - - def __init__( - self, - input_size, - output_size, - bias=True, - input_is_parallel=False, - init_method=init.xavier_normal_, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - *, - params_dtype=torch.float32, - use_cpu_initialization=False, - gradient_accumulation_fusion=False, - sequence_parallel_enabled: bool = False, - accumulation_in_fp16: Optional[bool] = None, - ): - from apex import deprecated_warning - - deprecated_warning( - "`apex.transformer` is deprecated and will be removed in September 2025. " - "We encourage you to migrate to Megatron Core. " - "It is available on PyPI at https://pypi.org/project/megatron-core/ " - "and its documentation can be found at https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html." - ) - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, world_size) - self.skip_bias_add = skip_bias_add - self.gradient_accumulation_fusion = gradient_accumulation_fusion - self.sequence_parallel_enabled = sequence_parallel_enabled - if self.sequence_parallel_enabled and not self.input_is_parallel: - raise RuntimeError( - "To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`" - ) - - if accumulation_in_fp16 is not None: - warnings.warn( - f"Deprecated option `accumulation_in_fp16` is set to {accumulation_in_fp16}" - ) - - # as an argument to this function? - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. 
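The docstrings above define Y = XA + b with A split along its columns for ColumnParallelLinear and along its rows (with X split along its last dimension) for RowParallelLinear. A minimal single-process check of that partitioning arithmetic, using plain torch and no distributed setup:

import torch

p = 4                    # stand-in tensor-model-parallel world size
x = torch.randn(8, 16)   # [tokens, hidden]
a = torch.randn(16, 32)

# Column parallel: A = [A_1, ..., A_p] along dim 1; the per-partition outputs are
# concatenated (which is what the gather does when gather_output=True).
y_col = torch.cat([x @ a_i for a_i in a.chunk(p, dim=1)], dim=1)
assert torch.allclose(y_col, x @ a, atol=1e-5)

# Row parallel: A split along dim 0 and X along its last dim; the per-partition
# partial products are summed, which is what the all-reduce across partitions computes.
y_row = sum(x_i @ a_i for x_i, a_i in zip(x.chunk(p, dim=1), a.chunk(p, dim=0)))
assert torch.allclose(y_row, x @ a, atol=1e-5)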
- if use_cpu_initialization: - self.weight = Parameter( - torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype) - ) - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.input_size_per_partition, - 1, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size, - self.input_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride) - if bias: - if use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) - else: - self.bias = Parameter( - torch.empty( - self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled) - else: - self.register_parameter("bias", None) - - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - - def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward of RowParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - assert not self.sequence_parallel_enabled - input_parallel = scatter_to_tensor_model_parallel_region(input_) - # Matrix multiply. - output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, - bias=None, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=False, - sequence_parallel_enabled=False, - ) - # All-reduce across all the partitions. - if self.sequence_parallel_enabled: - output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) - else: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias diff --git a/apex/transformer/tensor_parallel/mappings.py b/apex/transformer/tensor_parallel/mappings.py deleted file mode 100644 index a2501a17f..000000000 --- a/apex/transformer/tensor_parallel/mappings.py +++ /dev/null @@ -1,309 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
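A usage sketch of the typical pairing of ColumnParallelLinear and RowParallelLinear above, as they behaved before this removal: the column-parallel GEMM keeps its output sharded and feeds a row-parallel GEMM that treats its input as already split, so only one collective is needed per pair. It assumes CUDA and that apex.transformer.parallel_state.initialize_model_parallel(...) has already been called on every rank; the sizes are illustrative.

import torch
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

hidden, ffn = 1024, 4096
# Weight split along its output (column) dimension; keep the activation sharded.
fc1 = ColumnParallelLinear(hidden, ffn, gather_output=False)
# Weight split along its input (row) dimension; the sharded activation comes in
# directly and the partial products are summed (all-reduce, or reduce-scatter
# when sequence parallelism is enabled) inside forward.
fc2 = RowParallelLinear(ffn, hidden, input_is_parallel=True)

x = torch.randn(128, 4, hidden, device="cuda")  # [sequence, batch, hidden]
h, _ = fc1(x)                                   # [128, 4, ffn // world_size] per rank
y, _ = fc2(torch.nn.functional.gelu(h))         # [128, 4, hidden], replicated across ranks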
-import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 4 lines are for backward comparability with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base -if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base - - -def _reduce(input_: torch.Tensor) -> torch.Tensor: - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size() == 1: - return input_ - - # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) - - return input_ - - -def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor: - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along last dimension. - input_list = split_tensor_along_last_dim(input_, world_size) - - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() - - return output - - -def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Split the tensor along its first dimension and keep the corresponding slice.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU for tensor model parallel. - if world_size == 1: - return input_ - - # Split along first dimension. - dim_size = input_.size(0) - assert dim_size % world_size == 0 - local_dim_size = dim_size // world_size - dim_offset = get_tensor_model_parallel_rank() * local_dim_size - output = input_[dim_offset : dim_offset + local_dim_size].contiguous() - return output - - -def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor: - """Gather tensors and concatenate along the last dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output - - -def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Gather tensors and concatenate along the first dimension.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - - shape = list(input_.shape) - shape[0] *= world_size - - output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed.all_gather_into_tensor( - output, input_.contiguous(), group=get_tensor_model_parallel_group() - ) - return output - - -def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Reduce-scatter the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - shape = list(input_.shape) - assert shape[0] % world_size == 0 - shape[0] //= world_size - output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed.reduce_scatter_tensor( - output, input_.contiguous(), group=get_tensor_model_parallel_group() - ) - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the tensor model parallel region.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce the input from the tensor model parallel region.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _ScatterToModelParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _split_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_last_dim(grad_output) - - -class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from tensor model parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split_along_last_dim(grad_output) - - -class _ScatterToSequenceParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chunk to the rank.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _split_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_first_dim(input_) - 
- @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _GatherFromSequenceParallelRegion(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_, to_model_parallel: bool = True): - return _gather_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_, to_model_parallel: bool = True): - ctx.to_model_parallel = to_model_parallel - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - if ctx.to_model_parallel: - return _reduce_scatter_along_first_dim(grad_output), None - else: - return _split_along_first_dim(grad_output), None - - -class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): - """Reduce scatter the input from the sequence parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -# ----------------- -# Helper functions. -# ----------------- - - -def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _CopyToModelParallelRegion.apply(input_) - - -def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ReduceFromModelParallelRegion.apply(input_) - - -def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ScatterToModelParallelRegion.apply(input_) - - -def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _GatherFromModelParallelRegion.apply(input_) - - -def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ScatterToSequenceParallelRegion.apply(input_) - - -def gather_from_sequence_parallel_region( - input_: torch.Tensor, to_model_parallel: bool = True -) -> torch.Tensor: - return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel) - - -def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ReduceScatterToSequenceParallelRegion.apply(input_) - - -__all__ = [ - "copy_to_tensor_model_parallel_region", - "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - "gather_from_sequence_parallel_region", - "reduce_scatter_to_sequence_parallel_region", -] diff --git a/apex/transformer/tensor_parallel/memory.py b/apex/transformer/tensor_parallel/memory.py deleted file mode 100644 index 61e35c7a9..000000000 --- a/apex/transformer/tensor_parallel/memory.py +++ /dev/null @@ -1,147 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
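The autograd functions above come in conjugate pairs: `copy_to_tensor_model_parallel_region` is the identity in forward and an all-reduce in backward, while `reduce_from_tensor_model_parallel_region` is an all-reduce in forward and the identity in backward (the scatter/gather pairs mirror each other the same way). A single-process sketch of that structure, with the collective replaced by a stand-in `fake_all_reduce` (an illustrative name, not an apex or torch API):

import torch

def fake_all_reduce(t):
    # Stand-in for torch.distributed.all_reduce over the tensor model parallel
    # group; pretend two ranks contributed identical values.
    return t * 2.0

class CopyToRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x                      # identity in forward
    @staticmethod
    def backward(ctx, grad):
        return fake_all_reduce(grad)  # gradients are all-reduced in backward

class ReduceFromRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return fake_all_reduce(x)     # activations are all-reduced in forward
    @staticmethod
    def backward(ctx, grad):
        return grad                   # identity in backward

x = torch.randn(4, requires_grad=True)
y = ReduceFromRegion.apply(CopyToRegion.apply(x))
y.sum().backward()                    # gradients flow through both conjugate halves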
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so. -import torch - - -# A dictionary of all the memory buffers allocated. -_MEM_BUFFS = dict() - - -def allocate_mem_buff(name, numel, dtype, track_usage): - """Allocate a memory buffer.""" - assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name) - _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) - return _MEM_BUFFS[name] - - -def get_mem_buff(name): - """Get the memory buffer.""" - return _MEM_BUFFS[name] - - -class MemoryBuffer: - """Contiguous memory buffer. - Allocate a contiguous memory of type `dtype` and size `numel`. It is - used to reduce memory fragmentation. - - Usage: After the allocation, the `_start` index is set tot the first - index of the memory. A memory chunk starting from `_start` index - can be `allocated` for an input tensor, with the elements of the - tensor being coppied. The buffer can be reused by resetting the - `_start` index. - - """ - - def __init__(self, name, numel, dtype, track_usage): - if torch.distributed.get_rank() == 0: - element_size = torch.tensor([], dtype=dtype).element_size() - print( - "> building the {} memory buffer with {} num elements " - "and {} dtype ({:.1f} MB)...".format( - name, numel, dtype, numel * element_size / 1024 / 1024 - ), - flush=True, - ) - self.name = name - self.numel = numel - self.dtype = dtype - self.data = torch.empty( - self.numel, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - # Index tracking the start of the free memory. - self._start = 0 - - # Values used for tracking usage. - self.track_usage = track_usage - if self.track_usage: - self.in_use_value = 0.0 - self.total_value = 0.0 - - def reset(self): - """Reset the buffer start index to the beginning of the buffer.""" - self._start = 0 - - def is_in_use(self): - """Whether the current buffer hold on to any memory.""" - return self._start > 0 - - def numel_in_use(self): - """Return number of elements in use.""" - return self._start - - def add(self, tensor): - """Allocate a chunk of memory from the buffer to tensor and copy - the values.""" - assert tensor.dtype == self.dtype, ( - "Input tensor type {} different from buffer type {}".format(tensor.dtype, self.dtype) - ) - # Number of elements of the input tensor. - tensor_numel = torch.numel(tensor) - new_start = self._start + tensor_numel - assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format( - tensor_numel, self.numel - self._start - ) - # New tensor is a view into the memory. - new_tensor = self.data[self._start : new_start] - self._start = new_start - new_tensor = new_tensor.view(tensor.shape) - new_tensor.copy_(tensor) - # Return a pointer to the new tensor. - return new_tensor - - def get_data(self): - """Return the data currently in use.""" - if self.track_usage: - self.in_use_value += float(self._start) - self.total_value += float(self.numel) - return self.data[: self._start] - - def print_average_usage(self): - """Print memory usage average over time. 
We would like this value - to be as high as possible.""" - assert self.track_usage, "You need to enable track usage." - if torch.distributed.get_rank() == 0: - print( - " > usage of {} memory buffer: {:.2f} %".format( - self.name, self.in_use_value * 100.0 / self.total_value - ), - flush=True, - ) - - -class RingMemBuffer: - """A ring of memory buffers.""" - - def __init__(self, name, num_buffers, numel, dtype, track_usage): - self.num_buffers = num_buffers - self.buffers = [ - allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) - for i in range(num_buffers) - ] - self._index = -1 - - def get_next_buffer(self): - self._index += 1 - self._index = self._index % self.num_buffers - buff = self.buffers[self._index] - assert not buff.is_in_use(), "buffer is already in use." - return buff diff --git a/apex/transformer/tensor_parallel/random.py b/apex/transformer/tensor_parallel/random.py deleted file mode 100644 index dfa7c6fe6..000000000 --- a/apex/transformer/tensor_parallel/random.py +++ /dev/null @@ -1,256 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences: -# - Not using "viewless" tensor: -# - _kernel_make_viewless_tensor -# - MakeViewlessTensor -# - make_viewless_tensor -# - assert_viewless_tensor -# - safely_set_viewless_tensor_data - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -import contextlib - -import torch -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager -from torch.utils.checkpoint import detach_variable - -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.tensor_parallel.memory import allocate_mem_buff - - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - - else: - # newer PyTorch - if device == -1: - device = torch.device("cuda") - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("cuda", device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. 
- - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception("seed {} already exists".format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception("cuda rng state {} already exists".format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception("cuda rng state {} is not added".format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model paralle groups. This is used for - example for dropout in the non-tensor-model-parallel regions. - tensor-model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() - # Data parallel gets the original seed. - data_parallel_seed = seed - - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. 
- torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed) - - -# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file. -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - """ - - @staticmethod - def forward(ctx, run_function, distribute_saved_activations, *args): - ctx.run_function = run_function - ctx.distribute_saved_activations = distribute_saved_activations - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - with torch.no_grad(): - outputs = run_function(*args) - - # Divide hidden states across model parallel group and only keep - # the chunk corresponding to the current rank. - if ctx.distribute_saved_activations: - ctx.input_0_shape = args[0].shape - - # Store everything. - ctx.save_for_backward(*args) - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad(), please use .backward() if possible" - ) - inputs = ctx.saved_tensors - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # Compute the forward pass. - detached_inputs = detach_variable(inputs) - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - torch.autograd.backward(outputs, args) - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) - return (None, None) + grads - - -# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer -# but I added this change to reduce the superficial difference from Megatron-LM. -def checkpoint(function, distribute_saved_activations, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint.""" - return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/apex/transformer/tensor_parallel/utils.py b/apex/transformer/tensor_parallel/utils.py deleted file mode 100644 index 6d8291022..000000000 --- a/apex/transformer/tensor_parallel/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
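A usage sketch of the RNG helpers and `checkpoint` above, as they behaved before this removal. It assumes CUDA and that apex.transformer.parallel_state.initialize_model_parallel(...) has already run on every rank, so the tensor model parallel rank is defined; the `block` function is illustrative.

import torch
from apex.transformer.tensor_parallel.random import (
    checkpoint,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)

# The default (data parallel) stream gets seed 1234; the tensor model parallel
# stream gets 1234 + 2718 + tensor_model_parallel_rank, so it differs across TP ranks.
model_parallel_cuda_manual_seed(1234)

def block(x):
    # Dropout inside the fork draws from the tensor-model-parallel RNG stream and
    # leaves the default stream untouched.
    with get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(x, p=0.1)

x = torch.randn(16, 32, device="cuda", requires_grad=True)
# Recompute `block` during backward, restoring the CPU, CUDA, and tracker RNG
# states so the recomputed dropout mask matches the original forward pass.
y = checkpoint(block, False, x)
y.sum().backward()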
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Sequence - -import torch - -from apex.transformer.utils import divide - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class VocabUtility: - """Split the vocabulary into `world_size` chunks and return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" - - @staticmethod - def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank, world_size: int - ) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - @staticmethod - def vocab_range_from_global_vocab_size( - global_vocab_size: int, rank: int, world_size: int - ) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size - ) diff --git a/apex/transformer/testing/__init__.py b/apex/transformer/testing/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apex/transformer/testing/arguments.py b/apex/transformer/testing/arguments.py deleted file mode 100644 index 5c0a4e46d..000000000 --- a/apex/transformer/testing/arguments.py +++ /dev/null @@ -1,1513 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron arguments.""" - -import argparse -import os - -import torch - - -def parse_args(extra_args_provider=None, defaults={}, override_args={}, ignore_unknown_args=False): - """Parse all arguments.""" - parser = argparse.ArgumentParser(description="Megatron-LM Arguments", allow_abbrev=False) - - # Standard arguments. 
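A small single-process illustration of the two tensor_parallel.utils helpers above, as importable before this removal:

import torch
from apex.transformer.tensor_parallel.utils import VocabUtility, split_tensor_along_last_dim

t = torch.arange(24).reshape(2, 12)
chunks = split_tensor_along_last_dim(t, num_partitions=4)
assert [c.shape[-1] for c in chunks] == [3, 3, 3, 3]   # four equal slices of the last dim

# Rank 1 of a 4-way split of a 50,000-token vocabulary owns indices [12500, 25000).
first, last = VocabUtility.vocab_range_from_global_vocab_size(50000, rank=1, world_size=4)
assert (first, last) == (12500, 25000)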
- parser = _add_network_size_args(parser) - parser = _add_regularization_args(parser) - parser = _add_training_args(parser) - parser = _add_initialization_args(parser) - parser = _add_learning_rate_args(parser) - parser = _add_checkpointing_args(parser) - parser = _add_mixed_precision_args(parser) - parser = _add_distributed_args(parser) - parser = _add_validation_args(parser) - parser = _add_data_args(parser) - parser = _add_autoresume_args(parser) - parser = _add_biencoder_args(parser) - parser = _add_vision_args(parser) - parser = _add_logging_args(parser) - - # NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`. - # ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu. - parser.add_argument( - "--cpu-offload", - action="store_true", - default=False, - help="Turns on CPU offloading", - ) - - # Custom arguments. - if extra_args_provider is not None: - parser = extra_args_provider(parser) - - # Parse. - if ignore_unknown_args: - args, _ = parser.parse_known_args() - else: - args = parser.parse_args() - - # Distributed args. - args.rank = int(os.getenv("RANK", "0")) - args.world_size = int(os.getenv("WORLD_SIZE", "1")) - - for key in override_args: - setattr(args, key, override_args[key]) - - # Tensor model parallel size. - args.tensor_model_parallel_size = min(args.tensor_model_parallel_size, args.world_size) - assert args.world_size % args.tensor_model_parallel_size == 0, ( - "world size ({}) is not divisible by tensor model parallel size ({})".format( - args.world_size, args.tensor_model_parallel_size - ) - ) - - # Pipeline model parallel size. - args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size), - ) - - args.transformer_pipeline_model_parallel_size = ( - args.pipeline_model_parallel_size - 1 - if args.standalone_embedding_stage - else args.pipeline_model_parallel_size - ) - # Checks. 
- model_parallel_size = args.pipeline_model_parallel_size * args.tensor_model_parallel_size - assert args.world_size % model_parallel_size == 0, ( - "world size is not" - " divisible by tensor parallel size ({}) times pipeline parallel " - "size ({})".format( - args.world_size, - args.tensor_model_parallel_size, - ) - ) - args.data_parallel_size = args.world_size // model_parallel_size - if args.rank == 0: - print( - "using world size: {}, data-parallel-size: {}, " - "tensor-model-parallel size: {}, " - "pipeline-model-parallel size: {} ".format( - args.world_size, - args.data_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - ), - flush=True, - ) - if args.pipeline_model_parallel_size > 1: - if args.pipeline_model_parallel_split_rank is not None: - assert args.pipeline_model_parallel_split_rank < args.pipeline_model_parallel_size, ( - "split rank needs to be less than pipeline model parallel size ({})".format( - args.pipeline_model_parallel_size - ) - ) - - # Deprecated arguments - assert args.batch_size is None, ( - "--batch-size argument is no longer valid, use --micro-batch-size instead" - ) - del args.batch_size - assert args.warmup is None, ( - "--warmup argument is no longer valid, use --lr-warmup-fraction instead" - ) - del args.warmup - assert args.model_parallel_size is None, ( - "--model-parallel-size is no longer valid, use --tensor-model-parallel-size instead" - ) - del args.model_parallel_size - if args.checkpoint_activations: - args.recompute_granularity = "full" - args.recompute_method = "uniform" - if args.rank == 0: - print( - "--checkpoint-activations is no longer valid, " - "use --recompute-granularity and --recompute-method instead. " - "Defaulting to recompute-granularity=full and recompute-method=uniform." - ) - del args.checkpoint_activations - - if args.recompute_activations: - args.recompute_granularity = "selective" - del args.recompute_activations - - # Set input defaults. - for key in defaults: - # For default to be valid, it should not be provided in the - # arguments that are passed to the program. We check this by - # ensuring the arg is set to None. - if getattr(args, key) is not None: - if args.rank == 0: - print( - "WARNING: overriding default arguments for {key}:{v} \ - with {key}:{v2}".format(key=key, v=defaults[key], v2=getattr(args, key)), - flush=True, - ) - else: - setattr(args, key, defaults[key]) - - # Batch size. - assert args.micro_batch_size is not None - assert args.micro_batch_size > 0 - if args.global_batch_size is None: - args.global_batch_size = args.micro_batch_size * args.data_parallel_size - if args.rank == 0: - print( - "setting global batch size to {}".format(args.global_batch_size), - flush=True, - ) - assert args.global_batch_size > 0 - if args.num_layers_per_virtual_pipeline_stage is not None: - assert args.pipeline_model_parallel_size > 2, ( - "pipeline-model-parallel size should be greater than 2 with interleaved schedule" - ) - assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, ( - "number of layers is not divisible by number of layers per virtual pipeline stage" - ) - args.virtual_pipeline_model_parallel_size = ( - args.num_layers // args.pipeline_model_parallel_size - ) // args.num_layers_per_virtual_pipeline_stage - else: - args.virtual_pipeline_model_parallel_size = None - - # Parameters dtype. 
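A worked example of the sizes derived above, for a hypothetical 32-GPU run (values chosen to satisfy the divisibility and interleaved-schedule checks):

world_size = 32
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4          # > 2, as the interleaved schedule requires

model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size   # 8
data_parallel_size = world_size // model_parallel_size                            # 4

micro_batch_size = 2
global_batch_size = micro_batch_size * data_parallel_size   # 8, the default when --global-batch-size is unset

num_layers = 32
num_layers_per_virtual_pipeline_stage = 2
virtual_pipeline_model_parallel_size = (
    num_layers // pipeline_model_parallel_size
) // num_layers_per_virtual_pipeline_stage                   # (32 // 4) // 2 = 4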
- args.params_dtype = torch.float - if args.fp16: - assert not args.bf16 - args.params_dtype = torch.half - if args.bf16: - assert not args.fp16 - args.params_dtype = torch.bfloat16 - # bfloat16 requires gradient accumulation and all-reduce to - # be done in fp32. - if not args.accumulate_allreduce_grads_in_fp32: - args.accumulate_allreduce_grads_in_fp32 = True - if args.rank == 0: - print( - "accumulate and all-reduce gradients in fp32 for bfloat16 data type.", - flush=True, - ) - - if args.rank == 0: - print("using {} for parameters ...".format(args.params_dtype), flush=True) - - # If we do accumulation and all-reduces in fp32, we need to have local DDP - # and we should make sure use-contiguous-buffers-in-local-ddp is not off. - if args.accumulate_allreduce_grads_in_fp32: - assert args.DDP_impl == "local" - assert args.use_contiguous_buffers_in_local_ddp - else: - if args.gradient_accumulation_fusion: - args.gradient_accumulation_fusion = False - if args.rank == 0: - print( - "Gradient accumulation fusion to linear layer weight " - "gradient computation is supported only with fp32 " - "gradient accumulation. Setting gradient_accumulation_fusion " - "to False", - flush=True, - ) - - # For torch DDP, we do not use contiguous buffer - if args.DDP_impl == "torch": - args.use_contiguous_buffers_in_local_ddp = False - - if args.dataloader_type is None: - args.dataloader_type = "single" - - # Consumed tokens. - args.consumed_train_samples = 0 - args.consumed_valid_samples = 0 - - # Iteration-based training. - if args.train_iters: - # If we use iteration-based training, make sure the - # sample-based options are off. - assert args.train_samples is None, "expected iteration-based training" - assert args.lr_decay_samples is None, "expected iteration-based learning rate decay" - assert args.lr_warmup_samples == 0, "expected iteration-based learning rate warmup" - assert args.rampup_batch_size is None, ( - "expected no batch-size rampup for iteration-based training" - ) - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_iters == 0, ( - "can only specify one of lr-warmup-fraction and lr-warmup-iters" - ) - - # Sample-based training. - if args.train_samples: - # If we use sample-based training, make sure the - # iteration-based options are off. - assert args.train_iters is None, "expected sample-based training" - assert args.lr_decay_iters is None, "expected sample-based learning rate decay" - assert args.lr_warmup_iters == 0, "expected sample-based learnig rate warmup" - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_samples == 0, ( - "can only specify one of lr-warmup-fraction and lr-warmup-samples" - ) - - # Check required arguments. - required_args = [ - "num_layers", - "hidden_size", - "num_attention_heads", - "max_position_embeddings", - ] - for req_arg in required_args: - _check_arg_is_not_none(args, req_arg) - - # Checks. 
- if args.ffn_hidden_size is None: - args.ffn_hidden_size = 4 * args.hidden_size - - if args.kv_channels is None: - assert args.hidden_size % args.num_attention_heads == 0 - args.kv_channels = args.hidden_size // args.num_attention_heads - - if args.seq_length is not None: - assert args.encoder_seq_length is None - args.encoder_seq_length = args.seq_length - else: - assert args.encoder_seq_length is not None - args.seq_length = args.encoder_seq_length - - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length - if args.lr is not None: - assert args.min_lr <= args.lr - if args.save is not None: - assert args.save_interval is not None - # Mixed precision checks. - if args.fp16_lm_cross_entropy: - assert args.fp16, "lm cross entropy in fp16 only support in fp16 mode." - if args.fp32_residual_connection: - assert args.fp16 or args.bf16, ( - "residual connection in fp32 only supported when using fp16 or bf16." - ) - - if args.weight_decay_incr_style == "constant": - assert args.start_weight_decay is None - assert args.end_weight_decay is None - args.start_weight_decay = args.weight_decay - args.end_weight_decay = args.weight_decay - else: - assert args.start_weight_decay is not None - assert args.end_weight_decay is not None - - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - # Persistent fused layer norm. - if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): - args.no_persist_layer_norm = True - if args.rank == 0: - print( - "Persistent fused layer norm kernel is supported from " - "pytorch v1.11 (nvidia pytorch container paired with v1.11). " - "Defaulting to no_persist_layer_norm=True" - ) - - # Activation recomputing. - if args.distribute_saved_activations: - assert args.tensor_model_parallel_size > 1, ( - "can distribute recomputed activations only across tensor model parallel groups" - ) - assert args.recompute_granularity == "full", ( - "distributed recompute activations is only application to full recompute granularity" - ) - assert args.recompute_method is not None, ( - "for distributed recompute activations to work you need to use a recompute method " - ) - assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, ( - "distributed recompute activations are supported for pytorch " - "v1.10 and above (Nvidia Pytorch container >= 21.07). Current " - "pytorch version is v%s.%s." % (TORCH_MAJOR, TORCH_MINOR) - ) - - if args.recompute_granularity == "selective": - assert args.recompute_method is None, ( - "recompute method is not yet supported for selective recomputing granularity" - ) - - # disable async_tensor_model_parallel_allreduce when - # model parallel memory optimization is enabled - if args.sequence_parallel: - args.async_tensor_model_parallel_allreduce = False - - _print_args(args) - return args - - -def _print_args(args): - """Print arguments.""" - if args.rank == 0: - print("------------------------ arguments ------------------------", flush=True) - str_list = [] - for arg in vars(args): - dots = "." 
* (48 - len(arg)) - str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print("-------------------- end of arguments ---------------------", flush=True) - - -def _check_arg_is_not_none(args, arg): - assert getattr(args, arg) is not None, "{} argument is None".format(arg) - - -def _add_inference_args(parser): - group = parser.add_argument_group(title="inference") - - group.add_argument( - "--inference-batch-times-seqlen-threshold", - type=int, - default=512, - help="During inference, if batch-size times " - "sequence-length is smaller than this threshold " - "then we will not use pipelining, otherwise we will.", - ) - - return parser - - -def _add_network_size_args(parser): - group = parser.add_argument_group(title="network size") - - group.add_argument("--num-layers", type=int, default=None, help="Number of transformer layers.") - group.add_argument("--hidden-size", type=int, default=None, help="Tansformer hidden size.") - group.add_argument( - "--ffn-hidden-size", - type=int, - default=None, - help="Transformer Feed-Forward Network hidden size. " - "This is set to 4*hidden-size if not provided", - ) - group.add_argument( - "--num-attention-heads", - type=int, - default=None, - help="Number of transformer attention heads.", - ) - group.add_argument( - "--kv-channels", - type=int, - default=None, - help="Projection weights dimension in multi-head " - "attention. This is set to " - " args.hidden_size // args.num_attention_heads " - "if not provided.", - ) - group.add_argument( - "--max-position-embeddings", - type=int, - default=None, - help="Maximum number of position embeddings to use. " - "This is the size of position embedding.", - ) - group.add_argument( - "--make-vocab-size-divisible-by", - type=int, - default=128, - help="Pad the vocab size to be divisible by this value." - "This is added for computational efficieny reasons.", - ) - group.add_argument("--layernorm-epsilon", type=float, default=1e-5, help="Layer norm epsilon.") - group.add_argument( - "--apply-residual-connection-post-layernorm", - action="store_true", - help="If set, use original BERT residula connection ordering.", - ) - group.add_argument( - "--openai-gelu", - action="store_true", - help="Use OpenAIs GeLU implementation. 
This option" - "should not be used unless for backward compatibility" - "reasons.", - ) - group.add_argument( - "--onnx-safe", - type=bool, - required=False, - help="Use workarounds for known problems with Torch ONNX exporter", - ) - group.add_argument( - "--bert-no-binary-head", - action="store_false", - help="Disable BERT binary head.", - dest="bert_binary_head", - ) - group.add_argument( - "--num-experts", - type=int, - default=None, - help="Number of Experts in Switch Transformer (None means no Switch)", - ) - - return parser - - -def _add_logging_args(parser): - group = parser.add_argument_group(title="logging") - - group.add_argument( - "--log-params-norm", - action="store_true", - help="If set, calculate and log parameters norm.", - ) - group.add_argument( - "--log-num-zeros-in-grad", - action="store_true", - help="If set, calculate and log the number of zeros in gradient.", - ) - group.add_argument( - "--tensorboard-log-interval", - type=int, - default=1, - help="Report to tensorboard interval.", - ) - group.add_argument( - "--tensorboard-queue-size", - type=int, - default=1000, - help="Size of the tensorboard queue for pending events " - "and summaries before one of the ‘add’ calls forces a " - "flush to disk.", - ) - group.add_argument( - "--log-timers-to-tensorboard", - action="store_true", - help="If set, write timers to tensorboard.", - ) - group.add_argument( - "--log-batch-size-to-tensorboard", - action="store_true", - help="If set, write batch-size to tensorboard.", - ) - group.add_argument( - "--no-log-learnig-rate-to-tensorboard", - action="store_false", - help="Disable learning rate logging to tensorboard.", - dest="log_learning_rate_to_tensorboard", - ) - group.add_argument( - "--no-log-loss-scale-to-tensorboard", - action="store_false", - help="Disable loss-scale logging to tensorboard.", - dest="log_loss_scale_to_tensorboard", - ) - group.add_argument( - "--log-validation-ppl-to-tensorboard", - action="store_true", - help="If set, write validation perplexity to tensorboard.", - ) - group.add_argument( - "--log-memory-to-tensorboard", - action="store_true", - help="Enable memory logging to tensorboard.", - ) - group.add_argument( - "--log-world-size-to-tensorboard", - action="store_true", - help="Enable world size logging to tensorboard.", - ) - - return parser - - -def _add_regularization_args(parser): - group = parser.add_argument_group(title="regularization") - - group.add_argument( - "--attention-dropout", - type=float, - default=0.1, - help="Post attention dropout probability.", - ) - group.add_argument( - "--hidden-dropout", - type=float, - default=0.1, - help="Dropout probability for hidden state transformer.", - ) - group.add_argument( - "--weight-decay", - type=float, - default=0.01, - help="Weight decay coefficient for L2 regularization.", - ) - group.add_argument( - "--start-weight-decay", - type=float, - help="Initial weight decay coefficient for L2 regularization.", - ) - group.add_argument( - "--end-weight-decay", - type=float, - help="End of run weight decay coefficient for L2 regularization.", - ) - group.add_argument( - "--weight-decay-incr-style", - type=str, - default="constant", - choices=["constant", "linear", "cosine"], - help="Weight decay increment function.", - ) - group.add_argument( - "--clip-grad", - type=float, - default=1.0, - help="Gradient clipping based on global L2 norm.", - ) - group.add_argument( - "--adam-beta1", - type=float, - default=0.9, - help="First coefficient for computing running averages of gradient and its square", - ) - 
group.add_argument( - "--adam-beta2", - type=float, - default=0.999, - help="Second coefficient for computing running averages of gradient and its square", - ) - group.add_argument( - "--adam-eps", - type=float, - default=1e-08, - help="Term added to the denominator to improvenumerical stability", - ) - group.add_argument("--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd") - - return parser - - -def _add_training_args(parser): - group = parser.add_argument_group(title="training") - - group.add_argument( - "--micro-batch-size", - type=int, - default=None, - help="Batch size per model instance (local batch size). " - "Global batch size is local batch size times data " - "parallel size times number of micro batches.", - ) - group.add_argument( - "--batch-size", - type=int, - default=None, - help="Old batch size parameter, do not use. Use --micro-batch-size instead", - ) - group.add_argument( - "--global-batch-size", - type=int, - default=None, - help="Training batch size. If set, it should be a " - "multiple of micro-batch-size times data-parallel-size. " - "If this value is None, then " - "use micro-batch-size * data-parallel-size as the " - "global batch size. This choice will result in 1 for " - "number of micro-batches.", - ) - group.add_argument( - "--rampup-batch-size", - nargs="*", - default=None, - help="Batch size ramp up with the following values:" - " --rampup-batch-size " - " " - " " - "For example:" - " --rampup-batch-size 16 8 300000 \ " - " --global-batch-size 1024" - "will start with global batch size 16 and over " - " (1024 - 16) / 8 = 126 intervals will increase" - "the batch size linearly to 1024. In each interval" - "we will use approximately 300000 / 126 = 2380 samples.", - ) - group.add_argument( - "--recompute-activations", - action="store_true", - help="recompute activation to allow for training " - "with larger models, sequences, and batch sizes.", - ) - group.add_argument( - "--recompute-granularity", - type=str, - default=None, - choices=["full", "selective"], - help="Checkpoint activations to allow for training " - "with larger models, sequences, and batch sizes. 
" - "It is supported at two granularities 1) full: " - "whole transformer layer is recomputed, " - "2) selective: core attention part of the transformer " - "layer is recomputed.", - ) - group.add_argument( - "--distribute-saved-activations", - action="store_true", - help="If set, distribute recomputed activations across model parallel group.", - ) - group.add_argument( - "--recompute-method", - type=str, - default=None, - choices=["uniform", "block"], - help="1) uniform: uniformly divide the total number of " - "Transformer layers and recompute the input activation of " - "each divided chunk at specified granularity, " - "2) recompute the input activations of only a set number of " - "individual Transformer layers per pipeline stage and do the " - "rest without any recomputing at specified granularity" - "default) do not apply activations recompute to any layers", - ) - group.add_argument( - "--recompute-num-layers", - type=int, - default=1, - help="1) uniform: the number of Transformer layers in each " - "uniformly divided recompute unit, " - "2) block: the number of individual Transformer layers " - "to recompute within each pipeline stage.", - ) - - # deprecated - group.add_argument( - "--checkpoint-activations", - action="store_true", - help="Checkpoint activation to allow for training " - "with larger models, sequences, and batch sizes.", - ) - group.add_argument( - "--train-iters", - type=int, - default=None, - help="Total number of iterations to train over all " - "training runs. Note that either train-iters or " - "train-samples should be provided.", - ) - group.add_argument( - "--train-samples", - type=int, - default=None, - help="Total number of samples to train over all " - "training runs. Note that either train-iters or " - "train-samples should be provided.", - ) - group.add_argument( - "--log-interval", type=int, default=100, help="Report loss and timing interval." - ) - group.add_argument( - "--exit-interval", - type=int, - default=None, - help="Exit the program after the iteration is divisible by this value.", - ) - group.add_argument( - "--exit-duration-in-mins", - type=int, - default=None, - help="Exit the program after this many minutes.", - ) - group.add_argument( - "--tensorboard-dir", - type=str, - default=None, - help="Write TensorBoard logs to this directory.", - ) - group.add_argument( - "--no-masked-softmax-fusion", - action="store_false", - help="Disable fusion of query_key_value scaling, masking, and softmax.", - dest="masked_softmax_fusion", - ) - group.add_argument( - "--no-bias-gelu-fusion", - action="store_false", - help="Disable bias and gelu fusion.", - dest="bias_gelu_fusion", - ) - group.add_argument( - "--no-bias-dropout-fusion", - action="store_false", - help="Disable bias and dropout fusion.", - dest="bias_dropout_fusion", - ) - group.add_argument( - "--optimizer", - type=str, - default="adam", - choices=["adam", "sgd"], - help="Optimizer function", - ) - group.add_argument( - "--dataloader-type", - type=str, - default=None, - choices=["single", "cyclic"], - help="Single pass vs multiple pass data loader", - ) - group.add_argument( - "--no-async-tensor-model-parallel-allreduce", - action="store_true", - help="Disable asynchronous execution of " - "tensor-model-parallel all-reduce with weight " - "gradient compuation of a column-linear layer.", - dest="async_tensor_model_parallel_allreduce", - ) - group.add_argument( - "--no-persist-layer-norm", - action="store_true", - help="Disable using persistent fused layer norm kernel. 
" - "This kernel supports only a set of hidden sizes. Please " - "check persist_ln_hidden_sizes if your hidden " - "size is supported.", - ) - group.add_argument( - "--sequence-parallel", - action="store_true", - help="Enable sequence parallel optimization.", - ) - group.add_argument( - "--no-gradient-accumulation-fusion", - action="store_false", - help="Disable fusing gradient accumulation to weight gradient computation of linear layers", - dest="gradient_accumulation_fusion", - ) - return parser - - -def _add_initialization_args(parser): - group = parser.add_argument_group(title="initialization") - - group.add_argument( - "--seed", - type=int, - default=1234, - help="Random seed used for python, numpy, pytorch, and cuda.", - ) - group.add_argument( - "--init-method-std", - type=float, - default=0.02, - help="Standard deviation of the zero mean normal " - "distribution used for weight initialization.", - ) - group.add_argument( - "--init-method-xavier-uniform", - action="store_true", - help="Enable Xavier uniform parameter initialization", - ) - - return parser - - -def _add_learning_rate_args(parser): - group = parser.add_argument_group(title="learning rate") - - group.add_argument( - "--lr", - type=float, - default=None, - help="Initial learning rate. Depending on decay style " - "and initial warmup, the learing rate at each " - "iteration would be different.", - ) - group.add_argument( - "--lr-decay-style", - type=str, - default="linear", - choices=["constant", "linear", "cosine"], - help="Learning rate decay function.", - ) - group.add_argument( - "--lr-decay-iters", - type=int, - default=None, - help="number of iterations to decay learning rate over," - " If None defaults to `--train-iters`", - ) - group.add_argument( - "--lr-decay-samples", - type=int, - default=None, - help="number of samples to decay learning rate over, If None defaults to `--train-samples`", - ) - group.add_argument( - "--lr-warmup-fraction", - type=float, - default=None, - help="fraction of lr-warmup-(iters/samples) to use for warmup (as a float)", - ) - group.add_argument( - "--lr-warmup-iters", - type=int, - default=0, - help="number of iterations to linearly warmup learning rate over.", - ) - group.add_argument( - "--lr-warmup-samples", - type=int, - default=0, - help="number of samples to linearly warmup learning rate over.", - ) - group.add_argument( - "--warmup", - type=int, - default=None, - help="Old lr warmup argument, do not use. Use one of the--lr-warmup-* arguments above", - ) - group.add_argument( - "--min-lr", - type=float, - default=0.0, - help="Minumum value for learning rate. The schedulerclip values below this threshold.", - ) - group.add_argument( - "--override-lr-scheduler", - action="store_true", - help="Reset the values of the scheduler (learning rate," - "warmup iterations, minimum learning rate, maximum " - "number of iterations, and decay style from input " - "arguments and ignore values from checkpoints. 
Note" - "that all the above values will be reset.", - ) - group.add_argument( - "--use-checkpoint-lr-scheduler", - action="store_true", - help="Use checkpoint to set the values of the scheduler " - "(learning rate, warmup iterations, minimum learning " - "rate, maximum number of iterations, and decay style " - "from checkpoint and ignore input arguments.", - ) - - return parser - - -def _add_checkpointing_args(parser): - group = parser.add_argument_group(title="checkpointing") - - group.add_argument( - "--save", - type=str, - default=None, - help="Output directory to save checkpoints to.", - ) - group.add_argument( - "--save-interval", - type=int, - default=None, - help="Number of iterations between checkpoint saves.", - ) - group.add_argument( - "--no-save-optim", - action="store_true", - default=None, - help="Do not save current optimizer.", - ) - group.add_argument( - "--no-save-rng", - action="store_true", - default=None, - help="Do not save current rng state.", - ) - group.add_argument( - "--load", - type=str, - default=None, - help="Directory containing a model checkpoint.", - ) - group.add_argument( - "--no-load-optim", - action="store_true", - default=None, - help="Do not load optimizer when loading checkpoint.", - ) - group.add_argument( - "--no-load-rng", - action="store_true", - default=None, - help="Do not load rng state when loading checkpoint.", - ) - group.add_argument( - "--finetune", - action="store_true", - help="Load model for finetuning. Do not load optimizer " - "or rng state from checkpoint and set iteration to 0. " - "Assumed when loading a release checkpoint.", - ) - - return parser - - -def _add_mixed_precision_args(parser): - group = parser.add_argument_group(title="mixed precision") - - group.add_argument("--fp16", action="store_true", help="Run model in fp16 mode.") - group.add_argument("--bf16", action="store_true", help="Run model in bfloat16 mode.") - group.add_argument( - "--loss-scale", - type=float, - default=None, - help="Static loss scaling, positive power of 2 " - "values can improve fp16 convergence. If None, dynamic" - "loss scaling is used.", - ) - group.add_argument( - "--initial-loss-scale", - type=float, - default=2**32, - help="Initial loss-scale for dynamic loss scaling.", - ) - group.add_argument( - "--min-loss-scale", - type=float, - default=1.0, - help="Minimum loss scale for dynamic loss scale.", - ) - group.add_argument( - "--loss-scale-window", - type=float, - default=1000, - help="Window over which to raise/lower dynamic scale.", - ) - group.add_argument( - "--hysteresis", type=int, default=2, help="hysteresis for dynamic loss scaling" - ) - group.add_argument( - "--fp32-residual-connection", - action="store_true", - help="Move residual connections to fp32.", - ) - group.add_argument( - "--no-query-key-layer-scaling", - action="store_false", - help="Do not scale Q * K^T by 1 / layer-number.", - dest="apply_query_key_layer_scaling", - ) - group.add_argument( - "--attention-softmax-in-fp32", - action="store_true", - help="Run attention masking and softmax in fp32. 
" - "This flag is ignored unless " - "--no-query-key-layer-scaling is specified.", - ) - group.add_argument( - "--accumulate-allreduce-grads-in-fp32", - action="store_true", - help="Gradient accumulation and all-reduce in fp32.", - ) - group.add_argument( - "--fp16-lm-cross-entropy", - action="store_true", - help="Move the cross entropy unreduced loss calculationfor lm head to fp16.", - ) - - return parser - - -def _add_distributed_args(parser): - group = parser.add_argument_group(title="distributed") - - group.add_argument( - "--tensor-model-parallel-size", - type=int, - default=1, - help="Degree of tensor model parallelism.", - ) - group.add_argument( - "--pipeline-model-parallel-size", - type=int, - default=1, - help="Degree of pipeline model parallelism.", - ) - group.add_argument( - "--pipeline-model-parallel-split-rank", - type=int, - default=None, - help="Rank where encoder and decoder should be split.", - ) - group.add_argument( - "--model-parallel-size", - type=int, - default=None, - help="Old model parallel argument, do not use. Use --tensor-model-parallel-size instead.", - ) - group.add_argument( - "--num-layers-per-virtual-pipeline-stage", - type=int, - default=None, - help="Number of layers per virtual pipeline stage", - ) - group.add_argument( - "--distributed-backend", - default="nccl", - choices=["nccl", "gloo"], - help="Which backend to use for distributed training.", - ) - group.add_argument( - "--DDP-impl", - default="local", - choices=["local", "torch"], - help="which DistributedDataParallel implementation to use.", - ) - group.add_argument( - "--no-contiguous-buffers-in-local-ddp", - action="store_false", - help="If set, dont use contiguous buffer in local DDP.", - dest="use_contiguous_buffers_in_local_ddp", - ) - group.add_argument( - "--no-scatter-gather-tensors-in-pipeline", - action="store_false", - help="Use scatter/gather to optimize communication of tensors in pipeline", - dest="scatter_gather_tensors_in_pipeline", - ) - group.add_argument( - "--local_rank", - type=int, - default=None, - help="local rank passed from distributed launcher.", - ) - group.add_argument( - "--lazy-mpu-init", - type=bool, - required=False, - help="If set to True, initialize_megatron() " - "skips DDP initialization and returns function to " - "complete it instead.Also turns on " - "--use-cpu-initialization flag. This is for " - "external DDP manager.", - ) - group.add_argument( - "--use-cpu-initialization", - action="store_true", - default=None, - help="If set, affine parallel weights initialization uses CPU", - ) - group.add_argument( - "--empty-unused-memory-level", - default=0, - type=int, - choices=[0, 1, 2], - help="Call torch.cuda.empty_cache() each iteration " - "(training and eval), to reduce fragmentation." - "0=off, 1=moderate, 2=aggressive.", - ) - group.add_argument( - "--standalone-embedding-stage", - action="store_true", - default=False, - help="If set, *input* embedding layer " - "is placed on its own pipeline stage, without any " - "transformer layers. 
(For T5, this flag currently only " - "affects the encoder embedding.)", - ) - return parser - - -def _add_validation_args(parser): - group = parser.add_argument_group(title="validation") - - group.add_argument( - "--eval-iters", - type=int, - default=100, - help="Number of iterations to run for evaluationvalidation/test for.", - ) - group.add_argument( - "--eval-interval", - type=int, - default=1000, - help="Interval between running evaluation on validation set.", - ) - - return parser - - -def _add_data_args(parser): - group = parser.add_argument_group(title="data and dataloader") - - group.add_argument( - "--data-path", - nargs="*", - default=None, - help="Path to the training dataset. Accepted format:" - "1) a single data path, 2) multiple datasets in the" - "form: dataset1-weight dataset1-path dataset2-weight " - "dataset2-path ...", - ) - group.add_argument( - "--split", - type=str, - default="969, 30, 1", - help="Comma-separated list of proportions for training," - " validation, and test split. For example the split " - "`90,5,5` will use 90%% of data for training, 5%% for " - "validation and 5%% for test.", - ) - group.add_argument("--vocab-file", type=str, default=None, help="Path to the vocab file.") - group.add_argument("--merge-file", type=str, default=None, help="Path to the BPE merge file.") - group.add_argument( - "--vocab-extra-ids", - type=int, - default=0, - help="Number of additional vocabulary tokens. " - "They are used for span masking in the T5 model", - ) - group.add_argument( - "--seq-length", - type=int, - default=None, - help="Maximum sequence length to process.", - ) - group.add_argument( - "--encoder-seq-length", - type=int, - default=None, - help="Maximum encoder sequence length to process.This should be exclusive of --seq-length", - ) - group.add_argument( - "--decoder-seq-length", - type=int, - default=None, - help="Maximum decoder sequence length to process.", - ) - group.add_argument( - "--retriever-seq-length", - type=int, - default=256, - help="Maximum sequence length for the biencoder model for retriever", - ) - group.add_argument( - "--sample-rate", - type=float, - default=1.0, - help="sample rate for training data. 
Supposed to be 0 < sample_rate < 1", - ) - group.add_argument( - "--mask-prob", - type=float, - default=0.15, - help="Probability of replacing a token with mask.", - ) - group.add_argument( - "--short-seq-prob", - type=float, - default=0.1, - help="Probability of producing a short sequence.", - ) - group.add_argument("--mmap-warmup", action="store_true", help="Warm up mmap files.") - group.add_argument("--num-workers", type=int, default=2, help="Dataloader number of workers.") - group.add_argument( - "--tokenizer-type", - type=str, - default=None, - choices=["BertWordPieceLowerCase", "BertWordPieceCase", "GPT2BPETokenizer"], - help="What type of tokenizer to use.", - ) - group.add_argument( - "--data-impl", - type=str, - default="infer", - choices=["lazy", "cached", "mmap", "infer"], - help="Implementation of indexed datasets.", - ) - group.add_argument( - "--reset-position-ids", - action="store_true", - help="Reset posistion ids after end-of-document token.", - ) - group.add_argument( - "--reset-attention-mask", - action="store_true", - help="Reset self attention maske after end-of-document token.", - ) - group.add_argument( - "--eod-mask-loss", - action="store_true", - help="Mask loss for the end of document tokens.", - ) - - return parser - - -def _add_autoresume_args(parser): - group = parser.add_argument_group(title="autoresume") - - group.add_argument( - "--adlr-autoresume", - action="store_true", - help="Enable autoresume on adlr cluster.", - ) - group.add_argument( - "--adlr-autoresume-interval", - type=int, - default=1000, - help="Intervals over which check for autoresumetermination signal", - ) - - return parser - - -def _add_biencoder_args(parser): - group = parser.add_argument_group(title="biencoder") - - # network size - group.add_argument( - "--ict-head-size", - type=int, - default=None, - help="Size of block embeddings to be used in ICT and REALM (paper default: 128)", - ) - group.add_argument( - "--biencoder-projection-dim", - type=int, - default=0, - help="Size of projection head used in biencoder (paper default: 128)", - ) - group.add_argument( - "--biencoder-shared-query-context-model", - action="store_true", - help="Whether to share the parameters of the query and context models or not", - ) - - # checkpointing - group.add_argument( - "--ict-load", - type=str, - default=None, - help="Directory containing an ICTBertModel checkpoint", - ) - group.add_argument( - "--bert-load", - type=str, - default=None, - help="Directory containing an BertModel checkpoint (needed to start ICT and REALM)", - ) - - # data - group.add_argument( - "--titles-data-path", - type=str, - default=None, - help="Path to titles dataset used for ICT", - ) - group.add_argument( - "--query-in-block-prob", - type=float, - default=0.1, - help="Probability of keeping query in block for ICT dataset", - ) - group.add_argument( - "--use-one-sent-docs", - action="store_true", - help="Whether to use one sentence documents in ICT", - ) - group.add_argument( - "--evidence-data-path", - type=str, - default=None, - help="Path to Wikipedia Evidence frm DPR paper", - ) - - # training - group.add_argument( - "--retriever-report-topk-accuracies", - nargs="+", - type=int, - default=[], - help="Which top-k accuracies to report (e.g. 
'1 5 20')", - ) - group.add_argument( - "--retriever-score-scaling", - action="store_true", - help="Whether to scale retriever scores by inverse square root of hidden size", - ) - - # faiss index - group.add_argument( - "--block-data-path", - type=str, - default=None, - help="Where to save/load BlockData to/from", - ) - group.add_argument( - "--embedding-path", - type=str, - default=None, - help="Where to save/load Open-Retrieval Embedding data to/from", - ) - - # indexer - group.add_argument( - "--indexer-batch-size", - type=int, - default=128, - help="How large of batches to use when doing indexing jobs", - ) - group.add_argument( - "--indexer-log-interval", - type=int, - default=1000, - help="After how many batches should the indexer report progress", - ) - return parser - - -def _add_vision_args(parser): - group = parser.add_argument_group(title="vision") - - # general vision arguments - group.add_argument( - "--num-classes", - type=int, - default=1000, - help="num of classes in vision classificaiton task", - ) - group.add_argument( - "--img-h", - type=int, - default=224, - help="Image height for vision classification task", - ) - group.add_argument( - "--img-w", - type=int, - default=224, - help="Image height for vision classification task", - ) - group.add_argument( - "--num-channels", - type=int, - default=3, - help="Number of channels in input image data", - ) - group.add_argument("--patch-dim", type=int, default=16, help="patch dimension") - group.add_argument( - "--classes-fraction", - type=float, - default=1.0, - help="training with fraction of classes.", - ) - group.add_argument( - "--data-per-class-fraction", - type=float, - default=1.0, - help="training with fraction of data per class.", - ) - group.add_argument( - "--no-data-sharding", - action="store_false", - help="Disable data sharding.", - dest="data_sharding", - ) - group.add_argument( - "--head-lr-mult", - type=float, - default=1.0, - help="learning rate multiplier for head during finetuning", - ) - - # pretraining type and backbone selection` - group.add_argument( - "--vision-pretraining", - action="store_true", - help="flag to indicate vision pretraining", - ) - group.add_argument( - "--vision-pretraining-type", - type=str, - default="classify", - choices=["classify", "inpaint", "dino"], - help="pretraining objectives", - ) - group.add_argument( - "--vision-backbone-type", - type=str, - default="vit", - choices=["vit", "mit", "swin"], - help="backbone types types", - ) - group.add_argument( - "--swin-backbone-type", - type=str, - default="tiny", - choices=["tiny", "base", "h3"], - help="pretraining objectives", - ) - - # inpainting arguments - group.add_argument( - "--mask-type", - type=str, - default="random", - choices=["random", "row"], - help="mask types", - ) - group.add_argument("--mask-factor", type=float, default=1.0, help="mask size scaling parameter") - - # dino arguments - group.add_argument("--iter-per-epoch", type=int, default=1250, help="iterations per epoch") - group.add_argument( - "--dino-local-img-size", - type=int, - default=96, - help="Image size for vision classification task", - ) - group.add_argument( - "--dino-local-crops-number", type=int, default=10, help="Number of local crops" - ) - group.add_argument( - "--dino-head-hidden-size", - type=int, - default=2048, - help="Hidden dimension size in dino head", - ) - group.add_argument( - "--dino-bottleneck-size", - type=int, - default=256, - help="Bottle neck dimension in dino head ", - ) - group.add_argument( - "--dino-freeze-last-layer", - 
type=float, - default=1, - help="Freezing last layer weights", - ) - group.add_argument( - "--dino-norm-last-layer", - action="store_true", - help="Disable Norm in last layer.", - ) - group.add_argument( - "--dino-warmup-teacher-temp", - type=float, - default=0.04, - help="warump teacher temperature", - ) - group.add_argument("--dino-teacher-temp", type=float, default=0.07, help="teacher temperature") - group.add_argument( - "--dino-warmup-teacher-temp-epochs", - type=int, - default=30, - help="warmup teacher temperaure epochs", - ) - - return parser diff --git a/apex/transformer/testing/commons.py b/apex/transformer/testing/commons.py deleted file mode 100644 index c880f892b..000000000 --- a/apex/transformer/testing/commons.py +++ /dev/null @@ -1,315 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -import datetime -import os -import random -from typing import Optional, Union, List, Tuple, Callable, Dict - -import numpy -import torch -import torch.nn as nn - -from apex import transformer -from apex.transformer.tensor_parallel import ( - ColumnParallelLinear, - RowParallelLinear, - scatter_to_sequence_parallel_region, -) -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, -) -from apex.transformer.pipeline_parallel.schedules.common import ( - Batch, -) -from apex.transformer.testing import global_vars -from apex.transformer._ucc_util import HAS_UCC - -TEST_SUCCESS_MESSAGE = ">> passed the test :-)" - - -# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes. -class MyLayer(nn.Module): - def __init__(self, hidden_size: int, pre_process: bool, post_process: bool): - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.layer = nn.Linear(hidden_size, hidden_size) - - def forward(self, x): - return self.layer(x) - - -class MyModel(nn.Module): - def __init__( - self, - hidden_size: int, - pre_process: bool = False, - post_process: bool = False, - *, - add_encoder: bool = False, - add_decoder: bool = False, - ) -> None: - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.layer = MyLayer( - hidden_size=hidden_size, pre_process=pre_process, post_process=post_process - ) - self.input_tensor = None - - def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None: - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - self.input_tensor = input_tensor[0] - - def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor: - if self.input_tensor is None: - return self.layer(x) - return self.layer(self.input_tensor) - - -class ToyParallelMLP(nn.Module): - def __init__( - self, - hidden_size: int, - pre_process: bool = False, - post_process: bool = False, - *, - sequence_parallel_enabled: bool = False, - # TODO(mkozuki): Support these two? 
- add_encoder: bool = False, - add_decoder: bool = False, - ) -> None: - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.sequence_parallel_enabled = sequence_parallel_enabled - - ffn_hidden_size = 4 * hidden_size - self.dense_h_to_4h = ColumnParallelLinear( - hidden_size, - ffn_hidden_size, - gather_output=False, - # init_method=init_method, - skip_bias_add=True, - # use_cpu_initialization=use_cpu_initialization, - bias=True, - sequence_parallel_enabled=sequence_parallel_enabled, - no_async_tensor_model_parallel_allreduce=True, - ) - self.dense_4h_to_h = RowParallelLinear( - ffn_hidden_size, - hidden_size, - input_is_parallel=True, - # init_method=output_layer_init_method, - skip_bias_add=False, - # use_cpu_initialization=use_cpu_initialization, - bias=True, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - self.activation_func = torch.nn.GELU() - - def set_input_tensor( - self, - input_tensor: Union[torch.Tensor, List[torch.Tensor]], - ) -> None: - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - self.input_tensor = input_tensor[0] - - def forward( - self, - x: Optional[torch.Tensor], - ) -> torch.Tensor: - """Forward of Simplified ParallelMLP. - - Args: - x: :obj:`None` if pipeline rank != pippeline first rank. When :obj:`None`, - `self.input_tensor` is taken care of by `forward_step` defined in - apex/transformer/pipeline_parallel/schedules/common.py - """ - # [s, b, h] - if self.input_tensor is None: - input = x - else: - input = self.input_tensor - intermediate_parallel, bias_parallel = self.dense_h_to_4h(input) - - if bias_parallel is not None: - intermediate_parallel += bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) - return output - - -def model_provider_func( - hidden_size: int, - pre_process: bool, - post_process: bool, - *, - add_encoder: bool = False, - add_decoder: bool = False, -) -> MyModel: - return MyModel( - hidden_size, - pre_process, - post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - ) - - -def mlp_provider_func( - hidden_size: int, - pre_process: bool, - post_process: bool, - *, - add_encoder: bool = False, - add_decoder: bool = False, - sequence_parallel_enabled: bool = False, -) -> ToyParallelMLP: - return ToyParallelMLP( - hidden_size, - pre_process, - post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - -def process_batch(batch): - if isinstance(batch, list): - x = batch[0] - else: - x = batch - return x - - -def fwd_step_func(batch, model): - x = process_batch(batch) - y = model(x) - - # note (mkozuki): I don't think this function is nice but I do think this is enough for now - # just to check the sanity of ported pipeline functions. 
- def loss_func(x): - loss = torch.sum(x) - averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {"avg": averaged_loss} - - return y, loss_func - - -@dataclass(frozen=True) -class ToyParallelMLPFwdBwdStepFunc: - sequence_parallel_enabled: bool - - def __call__( - self, - batch: Batch, - model: torch.nn.Module, - ) -> Tuple[ - torch.Tensor, - Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]], - ]: - x = batch[0] if isinstance(batch, list) else batch - if isinstance(x, torch.Tensor): - x = x.transpose(0, 1).contiguous() - if self.sequence_parallel_enabled: - x = scatter_to_sequence_parallel_region(x) - y = model(x) - - # note (mkozuki): I don't think this function is nice but I do think this is enough for now - # just to check the sanity of ported pipeline functions. - def loss_func(x): - loss = torch.sum(x) - averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {"avg": averaged_loss} - - return y, loss_func - - -class IdentityLayer(torch.nn.Module): - def __init__(self, size, scale=1.0): - super(IdentityLayer, self).__init__() - self.weight = torch.nn.Parameter(scale * torch.randn(size)) - - def forward(self): - return self.weight - - -def set_random_seed(seed): - """Set random seed for reproducibility.""" - random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) - transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed) - - -def initialize_distributed(backend="nccl"): - """Initialize torch.distributed.""" - # Get local rank in case it is provided. - # parser = argparse.ArgumentParser() - # parser.add_argument('--local_rank', type=int, default=None, - # help='local rank passed from distributed launcher') - # args = parser.parse_args() - if backend not in ("nccl", "ucc"): - raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}") - if backend == "ucc": - if not HAS_UCC: - raise ImportError( - "UCC backend requires pytorch source build with UCC installed and enabled" - ) - args = global_vars.get_args() - local_rank = args.local_rank - - # Get rank and world size. - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - print( - "> initializing torch.distributed with local rank: {}, rank: {}, world size: {}".format( - local_rank, rank, world_size - ) - ) - - # Set the device id. - device = rank % torch.cuda.device_count() - if local_rank is not None: - device = local_rank - torch.cuda.set_device(device) - - # Call the init process. 
- init_method = "tcp://" - master_ip = os.getenv("MASTER_ADDR", "localhost") - master_port = os.getenv("MASTER_PORT", "6000") - init_method += master_ip + ":" + master_port - torch.distributed.init_process_group( - backend=backend, - world_size=world_size, - rank=rank, - init_method=init_method, - timeout=datetime.timedelta(seconds=60), - ) - - -def print_separator(message): - filler_len = (78 - len(message)) // 2 - filler = "-" * filler_len - string = "\n" + filler + " {} ".format(message) + filler - if torch.distributed.get_rank() == 0: - print(string, flush=True) diff --git a/apex/transformer/testing/distributed_test_base.py b/apex/transformer/testing/distributed_test_base.py deleted file mode 100644 index 10cec8287..000000000 --- a/apex/transformer/testing/distributed_test_base.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -import sys -import unittest -from packaging.version import Version, parse - -import torch -from torch import distributed as dist -from torch.utils import collect_env -from torch.testing._internal import common_utils -from torch.testing._internal import common_distributed - -from apex.transformer._ucc_util import HAS_UCC - -# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496 -_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01") -_driver_version = None -if torch.cuda.is_available(): - _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run)) -HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = ( - _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION -) - - -class DistributedTestBase(common_distributed.MultiProcessTestCase): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def setUp(self) -> None: - super().setUp() - self._setup_pre_spawn() - self._spawn_processes() - - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 4) - - @property - def init_method(self): - return f"{common_utils.FILE_SCHEMA}{self.file_name}" - - @property - def destroy_pg_upon_exit(self) -> bool: - # Overriding base test class: do not auto destroy PG upon exit. - return False - - @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): - self = cls(test_name) - self.assertTrue(torch.cuda.is_available()) - self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND")) - self.rank = rank - self.file_name = file_name - - print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}") - - try: - dist.init_process_group( - init_method=self.init_method, - backend=self.DISTRIBUTED_BACKEND, - world_size=int(self.world_size), - rank=self.rank, - ) - except RuntimeError as e: - if "recompile" in e.args[0]: - print(f"Backend of {self.DISTRIBUTED_BACKEND} not available") - sys.exit(0) - raise - - torch.cuda.set_device(self.rank % torch.cuda.device_count()) - - dist.barrier() - self.run_test(test_name, pipe) - dist.barrier() - - dist.destroy_process_group() - sys.exit(0) - - def _setup_pre_spawn(self): - pass - - -class NcclDistributedTestBase(DistributedTestBase): - DISTRIBUTED_BACKEND = "nccl" - - -@unittest.skipUnless( - HAS_UCC, - "Requires either torch ucc or pytorch build from source with native ucc installed and enabled", -) -@unittest.skipUnless( - HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, - f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. 
" - "See https://github.com/openucx/ucc/issues/496", -) -class UccDistributedTestBase(DistributedTestBase): - DISTRIBUTED_BACKEND = "ucc" - - def _setup_pre_spawn(self) -> None: - self.master_addr = "localhost" - os.environ["MASTER_ADDR"] = "localhost" - self._has_master_port = "MASTER_PORT" in os.environ - if self._has_master_port: - self.master_port = os.environ["MASTER_PORT"] - else: - try: - from caffe2.torch.fb.common.utils import get_free_port - - self.master_port = str(get_free_port()) - except ImportError: - self.master_port = "12375" - os.environ["MASTER_PORT"] = self.master_port - - self._has_ucx_tls = "UCX_TLS" in os.environ - if not self._has_ucx_tls: - os.environ["UCX_TLS"] = "tcp,cuda" - print('os.environ["UCX_TLS"] = {}'.format(os.environ["UCX_TLS"])) - - def tearDown(self) -> None: - super().tearDown() - if not self._has_master_port: - del os.environ["MASTER_PORT"] - if not self._has_ucx_tls: - del os.environ["UCX_TLS"] - - @property - def init_method(self): - return "tcp://localhost:" + os.environ["MASTER_PORT"] diff --git a/apex/transformer/testing/global_vars.py b/apex/transformer/testing/global_vars.py deleted file mode 100644 index f19283347..000000000 --- a/apex/transformer/testing/global_vars.py +++ /dev/null @@ -1,283 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron global variables.""" - -import os -import sys -import time - -import torch - -from apex.transformer.microbatches import build_num_microbatches_calculator -from .arguments import parse_args - -_GLOBAL_ARGS = None -_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None -_GLOBAL_TOKENIZER = None -_GLOBAL_TENSORBOARD_WRITER = None -_GLOBAL_ADLR_AUTORESUME = None -_GLOBAL_TIMERS = None - - -def get_args(): - """Return arguments.""" - _ensure_var_is_initialized(_GLOBAL_ARGS, "args") - return _GLOBAL_ARGS - - -def get_num_microbatches() -> int: - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() - - -def get_current_global_batch_size() -> int: - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() - - -def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None: - """Update the number of microbatches upon the number of consumed samples. - - .. note:: - This function has no effect unless ``rampup_batch_size`` is set. - - Args: - consumed_samples: The number of consumed samples so far. Basically this is equal to - :math:`num_iter * global_batch_size`. - consistency_check: If :obj:`True`, sanity checks the consumed samples, i.e., check if - ``consumed_samples`` is divisible by :math:`micro_batch_size \times data_parallel_size`. - """ - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check) - - -# def get_tokenizer(): -# """Return tokenizer.""" -# _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') -# return _GLOBAL_TOKENIZER - - -def get_tensorboard_writer(): - """Return tensorboard writer. 
It can be None so no need - to check if it is initialized.""" - return _GLOBAL_TENSORBOARD_WRITER - - -def get_adlr_autoresume(): - """ADLR autoresume object. It can be None so no need - to check if it is initialized.""" - return _GLOBAL_ADLR_AUTORESUME - - -def get_timers(): - """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") - return _GLOBAL_TIMERS - - -def set_global_variables( - extra_args_provider=None, - args_defaults={}, - override_args={}, - ignore_unknown_args=False, -): - """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" - args = _parse_args( - extra_args_provider=extra_args_provider, - defaults=args_defaults, - override_args=override_args, - ignore_unknown_args=ignore_unknown_args, - ) - # _build_num_microbatches_calculator(args) - # if args.vocab_file: - # _ = _build_tokenizer(args) - _set_tensorboard_writer(args) - _set_adlr_autoresume(args) - _set_timers() - - -def _parse_args(extra_args_provider=None, defaults={}, override_args={}, ignore_unknown_args=False): - """Parse entire arguments.""" - global _GLOBAL_ARGS - _ensure_var_is_not_initialized(_GLOBAL_ARGS, "args") - _GLOBAL_ARGS = parse_args( - extra_args_provider=extra_args_provider, - defaults=defaults, - override_args=override_args, - ignore_unknown_args=ignore_unknown_args, - ) - return _GLOBAL_ARGS - - -def _build_num_microbatches_calculator(args): - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - _ensure_var_is_not_initialized( - _GLOBAL_NUM_MICROBATCHES_CALCULATOR, "num microbatches calculator" - ) - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args) - - -# def _build_tokenizer(args): -# """Initialize tokenizer.""" -# global _GLOBAL_TOKENIZER -# _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer') -# _GLOBAL_TOKENIZER = build_tokenizer(args) -# return _GLOBAL_TOKENIZER - - -# def rebuild_tokenizer(args): -# global _GLOBAL_TOKENIZER -# _GLOBAL_TOKENIZER = None -# return _build_tokenizer(args) - - -def _set_tensorboard_writer(args): - """Set tensorboard writer.""" - global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer") - - if ( - hasattr(args, "tensorboard_dir") - and args.tensorboard_dir - and args.rank == (args.world_size - 1) - ): - try: - from torch.utils.tensorboard import SummaryWriter - - print("> setting tensorboard ...") - _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( - log_dir=args.tensorboard_dir, max_queue=args.tensorboard_queue_size - ) - except ModuleNotFoundError: - print( - "WARNING: TensorBoard writing requested but is not " - "available (are you using PyTorch 1.1.0 or later?), " - "no TensorBoard logs will be written.", - flush=True, - ) - - -def _set_adlr_autoresume(args): - """Initialize ADLR autoresume.""" - global _GLOBAL_ADLR_AUTORESUME - _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, "adlr autoresume") - - if args.adlr_autoresume: - if args.rank == 0: - print("enabling autoresume ...", flush=True) - sys.path.append(os.environ.get("SUBMIT_SCRIPTS", ".")) - try: - from userlib.auto_resume import AutoResume - except BaseException: - print("ADLR autoresume is not available, exiting ...") - sys.exit() - - _GLOBAL_ADLR_AUTORESUME = AutoResume - - -def _set_timers(): - """Initialize timers.""" - global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") - _GLOBAL_TIMERS = Timers() - - -def _ensure_var_is_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is not None, "{} is not 
initialized.".format(name) - - -def _ensure_var_is_not_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is None, "{} is already initialized.".format(name) - - -class _Timer: - """Timer.""" - - def __init__(self, name): - self.name_ = name - self.elapsed_ = 0.0 - self.started_ = False - self.start_time = time.time() - - def start(self): - """Start the timer.""" - assert not self.started_, "timer has already been started" - torch.cuda.synchronize() - self.start_time = time.time() - self.started_ = True - - def stop(self): - """Stop the timer.""" - assert self.started_, "timer is not started" - torch.cuda.synchronize() - self.elapsed_ += time.time() - self.start_time - self.started_ = False - - def reset(self): - """Reset timer.""" - self.elapsed_ = 0.0 - self.started_ = False - - def elapsed(self, reset=True): - """Calculate the elapsed time.""" - started_ = self.started_ - # If the timing in progress, end it first. - if self.started_: - self.stop() - # Get the elapsed time. - elapsed_ = self.elapsed_ - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if started_: - self.start() - return elapsed_ - - -class Timers: - """Group of timers.""" - - def __init__(self): - self.timers = {} - - def __call__(self, name): - if name not in self.timers: - self.timers[name] = _Timer(name) - return self.timers[name] - - def write(self, names, writer, iteration, normalizer=1.0, reset=False): - """Write timers to a tensorboard writer""" - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - for name in names: - value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + "-time", value, iteration) - - def log(self, names, normalizer=1.0, reset=True): - """Log a group of timers.""" - assert normalizer > 0.0 - string = "time (ms)" - for name in names: - elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += " | {}: {:.2f}".format(name, elapsed_time) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): - print(string, flush=True) - else: - print(string, flush=True) diff --git a/apex/transformer/testing/standalone_bert.py b/apex/transformer/testing/standalone_bert.py deleted file mode 100644 index ad95369fb..000000000 --- a/apex/transformer/testing/standalone_bert.py +++ /dev/null @@ -1,268 +0,0 @@ -import contextlib - -import torch - -from apex.transformer import tensor_parallel -from apex.transformer.enums import AttnMaskType -from apex.transformer.layers import FusedLayerNorm as LayerNorm -from apex.transformer.testing.global_vars import get_args -from apex.transformer.testing.standalone_transformer_lm import ( - MegatronModule, - get_language_model, - get_linear_layer, - init_method_normal, - scaled_init_method_normal, - parallel_lm_logits, -) - - -def bert_extended_attention_mask(attention_mask): - # We create a 3D attention mask from a 2D tensor mask. 
- # [b, 1, s] - attention_mask_b1s = attention_mask.unsqueeze(1) - # [b, s, 1] - attention_mask_bs1 = attention_mask.unsqueeze(2) - # [b, s, s] - attention_mask_bss = attention_mask_b1s * attention_mask_bs1 - # [b, 1, s, s] - extended_attention_mask = attention_mask_bss.unsqueeze(1) - - # Convert attention mask to binary: - extended_attention_mask = extended_attention_mask < 0.5 - - return extended_attention_mask - - -def bert_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class BertLMHead(MegatronModule): - """Masked LM head for Bert - - Arguments: - mpu_vocab_size: model parallel size of vocabulary. - hidden_size: hidden size - init_method: init method for weight initialization - layernorm_epsilon: tolerance for layer norm divisions - parallel_output: whether output logits being distributed or not. - """ - - def __init__( - self, - mpu_vocab_size, - hidden_size, - init_method, - layernorm_epsilon, - parallel_output, - ): - super(BertLMHead, self).__init__() - - args = get_args() - - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - # TODO: do we need this? - # mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - self.parallel_output = parallel_output - - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - setattr(self.dense.weight, "sequence_parallel", args.sequence_parallel) - setattr(self.dense.bias, "sequence_parallel", args.sequence_parallel) - - self.layernorm = LayerNorm( - hidden_size, - eps=layernorm_epsilon, - sequence_parallel_enabled=args.sequence_parallel, - ) - self.gelu = torch.nn.functional.gelu - if args.openai_gelu: - self.gelu = openai_gelu - elif args.onnx_safe: - self.gelu = erf_gelu - - def forward(self, hidden_states, word_embeddings_weight): - hidden_states = self.dense(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = self.layernorm(hidden_states) - output = parallel_lm_logits( - hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias - ) - return output - - -def post_language_model_processing( - lm_output, - pooled_output, - lm_head, - binary_head, - lm_labels, - logit_weights, - fp16_lm_cross_entropy, -): - # Output. 
- lm_logits = lm_head(lm_output, logit_weights) - - binary_logits = None - if binary_head is not None: - binary_logits = binary_head(pooled_output) - - if lm_labels is None: - # [s b h] => [b s h] - return lm_logits.transpose(0, 1).contiguous(), binary_logits - else: - # [b s] => [s b] - lm_labels = lm_labels.transpose(0, 1).contiguous() - # lm_logits: [s b h] lm_labels: [s b] - if fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels) - return lm_loss, binary_logits - - -class BertModel(MegatronModule): - """Bert Language model.""" - - def __init__( - self, - num_tokentypes=2, - add_binary_head=True, - parallel_output=True, - pre_process=True, - post_process=True, - cpu_offload=False, - ): - super(BertModel, self).__init__() - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.add_binary_head = add_binary_head - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=self.add_binary_head, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method, - pre_process=self.pre_process, - post_process=self.post_process, - ) - - self.initialize_word_embeddings(init_method_normal) - if self.post_process: - self.lm_head = BertLMHead( - self.word_embeddings_weight().size(0), - args.hidden_size, - init_method, - args.layernorm_epsilon, - parallel_output, - ) - self._lm_head_key = "lm_head" - self.binary_head = None - if self.add_binary_head: - self.binary_head = get_linear_layer(args.hidden_size, 2, init_method) - self._binary_head_key = "binary_head" - - self.forward_context = contextlib.nullcontext - if cpu_offload: - self.forward_context = torch.autograd.graph.save_on_cpu - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, bert_model_input, attention_mask, tokentype_ids=None, lm_labels=None): - with self.forward_context(): - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = bert_model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids, - ) - - if self.post_process and self.add_binary_head: - lm_output, pooled_output = lm_output - else: - pooled_output = None - - if self.post_process: - return post_language_model_processing( - lm_output, - pooled_output, - self.lm_head, - self.binary_head, - lm_labels, - self.word_embeddings_weight(), - self.fp16_lm_cross_entropy, - ) - else: - return lm_output - - # NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort. 
- def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars - ) - if self.post_process: - state_dict_[self._lm_head_key] = self.lm_head.state_dict_for_save_checkpoint( - destination, prefix, keep_vars - ) - if self.post_process and self.add_binary_head: - state_dict_[self._binary_head_key] = self.binary_head.state_dict( - destination, prefix, keep_vars - ) - # Save word_embeddings. - if self.post_process and not self.pre_process: - state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict( - destination, prefix, keep_vars - ) - return state_dict_ - - # NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort. - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict) - if self.post_process: - self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict) - if self.post_process and self.add_binary_head: - self.binary_head.load_state_dict(state_dict[self._binary_head_key], strict=strict) - # Load word_embeddings. - if self.post_process and not self.pre_process: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict - ) - - -def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False): - args = get_args() - num_tokentypes = 2 if args.bert_binary_head else 0 - model = BertModel( - num_tokentypes=num_tokentypes, - add_binary_head=args.bert_binary_head, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - cpu_offload=cpu_offload, - ) - return model diff --git a/apex/transformer/testing/standalone_gpt.py b/apex/transformer/testing/standalone_gpt.py deleted file mode 100644 index 3522362c2..000000000 --- a/apex/transformer/testing/standalone_gpt.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import contextlib -import torch - -from apex.transformer.enums import AttnMaskType -from apex.transformer.testing.global_vars import get_args -from apex.transformer.testing.standalone_transformer_lm import MegatronModule -from apex.transformer.testing.standalone_transformer_lm import ( - post_language_model_processing, -) -from apex.transformer.testing.standalone_transformer_lm import get_language_model -from apex.transformer.testing.standalone_transformer_lm import init_method_normal -from apex.transformer.testing.standalone_transformer_lm import ( - scaled_init_method_normal, -) - - -def gpt_model_provider( - pre_process: bool = True, - post_process: bool = True, - cpu_offload: bool = False, -) -> "GPTModel": - args = get_args() - model = GPTModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - cpu_offload=args.cpu_offload, - ) - return model - - -class GPTModel(MegatronModule): - """GPT-2 Language model.""" - - def __init__( - self, - num_tokentypes: int = 0, - parallel_output: bool = True, - pre_process: bool = True, - post_process: bool = True, - cpu_offload: bool = False, - ): - super().__init__() - args = get_args() - - self.forward_context = contextlib.nullcontext - if cpu_offload: - self.forward_context = torch.autograd.graph.save_on_cpu - - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, - init_method=init_method_normal(args.init_method_std), - scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), - pre_process=self.pre_process, - post_process=self.post_process, - ) - - self.initialize_word_embeddings(init_method_normal) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward( - self, - input_ids, - position_ids, - attention_mask, - labels=None, - tokentype_ids=None, - inference_params=None, - ): - with self.forward_context(): - lm_output = self.language_model( - input_ids, - position_ids, - attention_mask, - inference_params=inference_params, - ) - - if self.post_process: - return post_language_model_processing( - lm_output, - # note(mkozuki): Am I overlooking some order of dim change? - labels.t().contiguous(), - self.word_embeddings_weight(), - self.parallel_output, - self.fp16_lm_cross_entropy, - ) - else: - return lm_output diff --git a/apex/transformer/testing/standalone_transformer_lm.py b/apex/transformer/testing/standalone_transformer_lm.py deleted file mode 100644 index 87b26078f..000000000 --- a/apex/transformer/testing/standalone_transformer_lm.py +++ /dev/null @@ -1,1553 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""GPT-2 model.""" - -import math -import contextlib - -import torch -import torch.nn.functional as F - -import apex.transformer.utils -from apex.transformer.layers import FusedLayerNorm as LayerNorm -from apex.transformer.functional import FusedScaleMaskSoftmax -from apex.transformer import tensor_parallel -from apex.transformer.tensor_parallel.layers import ColumnParallelLinear -from apex.transformer.tensor_parallel.layers import RowParallelLinear -from apex.transformer.tensor_parallel.layers import VocabParallelEmbedding -from apex.transformer.tensor_parallel.mappings import ( - scatter_to_sequence_parallel_region, -) -from apex.transformer import parallel_state -from apex.transformer.testing.global_vars import get_args -from apex.transformer.enums import ModelType -from apex.transformer.enums import LayerType -from apex.transformer.enums import AttnType -from apex.transformer.enums import AttnMaskType -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -def param_is_not_shared(param: torch.Tensor) -> bool: - return getattr(param, "shared", False) - - -class MegatronModule(torch.nn.Module): - """Megatron specific extensions of torch Module with support for pipelining.""" - - def __init__(self, share_word_embeddings: bool = True) -> None: - super().__init__() - self.share_word_embeddings = share_word_embeddings - - def word_embeddings_weight(self): - if self.pre_process: - return self.language_model.embedding.word_embeddings.weight - else: - if not self.share_word_embeddings: - raise Exception( - "word_embeddings_weight() called for last stage, but share_word_embeddings is false" - ) - return self.word_embeddings.weight - - def initialize_word_embeddings(self, init_method_normal): - args = get_args() - if not self.share_word_embeddings: - raise Exception( - "initialize_word_embeddings() was called but share_word_embeddings is false" - ) - - # This function just initializes the word embeddings in the final stage - # when we are using pipeline parallelism. Nothing to do if we aren't - # using pipeline parallelism. - if args.pipeline_model_parallel_size == 1: - return - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - if parallel_state.is_pipeline_last_stage() and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - self._word_embeddings_for_head_key = "word_embeddings_for_head" - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. 
- self.word_embeddings = VocabParallelEmbedding( - args.padded_vocab_size, - args.hidden_size, - init_method=init_method_normal(args.init_method_std), - ) - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - - # Zero out initial weights for decoder embedding. - # NOTE: We don't currently support T5 with the interleaved schedule. - if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process: - self.language_model.embedding.zero_parameters() - - # Ensure that first and last stages have the same initial parameter - # values. - if torch.distributed.is_initialized(): - if parallel_state.is_rank_in_embedding_group(): - torch.distributed.all_reduce( - self.word_embeddings_weight(), - group=parallel_state.get_embedding_group(), - ) - - # Ensure that encoder(first stage) and decoder(split stage) position - # embeddings have the same initial parameter values - # NOTE: We don't currently support T5 with the interleaved schedule. - if ( - parallel_state.is_rank_in_position_embedding_group() - and args.pipeline_model_parallel_split_rank is not None - ): - # TODO: Support tokentype embedding. - self.language_model.embedding.cuda() - position_embeddings = self.language_model.embedding.position_embeddings - torch.distributed.all_reduce( - position_embeddings.weight, - group=parallel_state.get_position_embedding_group(), - ) - - else: - print( - "WARNING! Distributed processes aren't initialized, so " - "word embeddings in the last layer are not initialized. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong." - ) - - -def get_linear_layer(rows, columns, init_method): - """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns) - init_method(layer.weight) - with torch.no_grad(): - layer.bias.zero_() - return layer - - -# NOTE(mkozuki): Avoid inplace op. -def attention_mask_func( - attention_scores: torch.Tensor, attention_mask: torch.Tensor -) -> torch.Tensor: - # attention_scores.masked_fill_(attention_mask, -10000.0) - # return attention_scores - return attention_scores.masked_fill(attention_mask, -10000.0) - - -def init_method_normal(sigma): - """Init method based on N(0, sigma).""" - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma, num_layers): - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - -class ParallelMLP(MegatronModule): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, init_method, output_layer_init_method): - super().__init__() - args = get_args() - - # Project to 4h. - self.dense_h_to_4h = ColumnParallelLinear( - args.hidden_size, - args.ffn_hidden_size, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.bias_gelu_fusion = args.bias_gelu_fusion - self.activation_func = F.gelu - - # Project back to h. 
- self.dense_4h_to_h = RowParallelLinear( - args.ffn_hidden_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) - - # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) - return output, output_bias - - -class CoreAttention(MegatronModule): - def __init__(self, layer_number, attn_mask_type=AttnMaskType.padding): - super().__init__() - args = get_args() - self.fp16 = args.fp16 - self.bf16 = args.bf16 - - self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.sequence_parallel = args.sequence_parallel - - projection_size = args.kv_channels * args.num_attention_heads - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = apex.transformer.utils.divide(projection_size, world_size) - self.hidden_size_per_attention_head = apex.transformer.utils.divide( - projection_size, args.num_attention_heads - ) - self.num_attention_heads_per_partition = apex.transformer.utils.divide( - args.num_attention_heads, world_size - ) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.fp16, - self.bf16, - self.attn_mask_type, - args.masked_softmax_fusion, - attention_mask_func, - self.attention_softmax_in_fp32, - coeff, - ) - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout(args.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - if not self.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class ParallelAttention(MegatronModule): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [b, s, h] - and returns output of the same size. - """ - - def __init__( - self, - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding, - ): - super().__init__() - args = get_args() - self.layer_number = max(1, layer_number) - self.attention_type = attention_type - self.attn_mask_type = attn_mask_type - self.params_dtype = args.params_dtype - - projection_size = args.kv_channels * args.num_attention_heads - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = apex.transformer.utils.divide( - projection_size, args.num_attention_heads - ) - self.num_attention_heads_per_partition = apex.transformer.utils.divide( - args.num_attention_heads, world_size - ) - - # Strided linear layer. 
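The view/transpose gymnastics in CoreAttention.forward above amount to ordinary batched scaled-dot-product attention over a [sq, b, np, hn] layout. A minimal standalone sketch with made-up sizes (illustrative only, not code from the deleted file; it skips the tensor-parallel partitioning and the fused masked softmax):

import math
import torch

sq, b, np_, hn = 4, 2, 3, 8                       # seq len, batch, heads, head dim
q = torch.randn(sq, b, np_, hn)
k = torch.randn(sq, b, np_, hn)
v = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] -> [b * np, sq, hn], matching the views used above
q2 = q.view(sq, b * np_, hn).transpose(0, 1)
k2 = k.view(sq, b * np_, hn).transpose(0, 1)
v2 = v.view(sq, b * np_, hn).transpose(0, 1)

scores = torch.bmm(q2, k2.transpose(1, 2)) / math.sqrt(hn)   # [b * np, sq, sk]
causal = torch.triu(torch.ones(sq, sq), diagonal=1).bool()   # mask future positions
probs = torch.softmax(scores.masked_fill(causal, -10000.0), dim=-1)
context = torch.bmm(probs, v2)                               # [b * np, sq, hn]
context = context.transpose(0, 1).contiguous().view(sq, b, np_ * hn)  # [sq, b, hp]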
- if attention_type == AttnType.self_attn: - self.query_key_value = ColumnParallelLinear( - args.hidden_size, - 3 * projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - else: - assert attention_type == AttnType.cross_attn - self.query = ColumnParallelLinear( - args.hidden_size, - projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.key_value = ColumnParallelLinear( - args.hidden_size, - 2 * projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) - self.checkpoint_core_attention = args.recompute_granularity == "selective" - - # Output. - self.dense = RowParallelLinear( - projection_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def _checkpointed_attention_forward(self, query_layer, key_layer, value_layer, attention_mask): - """Forward method with activation checkpointing.""" - - def custom_forward(*inputs): - query_layer = inputs[0] - key_layer = inputs[1] - value_layer = inputs[2] - attention_mask = inputs[3] - output_ = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - return output_ - - hidden_states = tensor_parallel.checkpoint( - custom_forward, False, query_layer, key_layer, value_layer, attention_mask - ) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_len, batch_size): - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - - def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. 
- # ================================================= - if inference_params: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_len - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size) - inference_value_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] - - # ===================== - # Query, Key, and Value - # ===================== - - if self.attention_type == AttnType.self_attn: - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - ( - query_layer, - key_layer, - value_layer, - ) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3) - else: - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - ( - key_layer, - value_layer, - ) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_kv_layer, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query_layer, _ = self.query(hidden_states) - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query_layer = query_layer.view(*new_tensor_shape) - - # ================================== - # Adjust key and value for inference - # ================================== - - if inference_params: - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - # Copy key and values. - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - key_layer - ) - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - value_layer - ) - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention: - context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask - ) - else: - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. 
[sq, b, h] - # ================= - - output, bias = self.dense(context_layer) - - return output, bias - - -def bias_dropout_add( - x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float, - training: bool, -) -> torch.Tensor: - out = torch.nn.functional.dropout(x + bias, p=prob, training=training) - out = residual + out - return out - - -def get_bias_dropout_add(training): - def _bias_dropout_add(x, bias, residual, prob): - return bias_dropout_add(x, bias, residual, prob, training) - - return _bias_dropout_add - - -class ParallelTransformerLayer(MegatronModule): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__( - self, - init_method, - output_layer_init_method, - layer_number, - layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - drop_path_rate=0.0, - ): - args = get_args() - - super().__init__() - self.layer_number = layer_number - self.layer_type = layer_type - - self.apply_residual_connection_post_layernorm = ( - args.apply_residual_connection_post_layernorm - ) - - self.bf16 = args.bf16 - self.fp32_residual_connection = args.fp32_residual_connection - - # Layernorm on the input data. - self.input_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - # Self attention. - self.self_attention = ParallelAttention( - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type, - ) - self.hidden_dropout = args.hidden_dropout - self.bias_dropout_fusion = args.bias_dropout_fusion - # note(mkozuki) - # self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None - assert drop_path_rate <= 0.0 - self.drop_path = None - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - if self.layer_type == LayerType.decoder: - self.inter_attention = ParallelAttention( - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.cross_attn, - ) - # Layernorm on the attention output. - self.post_inter_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - # MLP - # note(mkozuki) - assert args.num_experts is None - # if args.num_experts is not None: - # self.mlp = SwitchMLP(init_method, output_layer_init_method) - # else: - # self.mlp = ParallelMLP(init_method, output_layer_init_method) - self.mlp = ParallelMLP(init_method, output_layer_init_method) - - # Set bias+dropout+add fusion grad_enable execution handler. - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = ( - contextlib.nullcontext if use_nvfuser else torch.enable_grad - ) - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - # hidden_states: [s, b, h] - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. 
- attention_output, attention_bias = self.self_attention( - layernorm_output, attention_mask, inference_params=inference_params - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - if self.drop_path is None: - bias_dropout_add_func = get_bias_dropout_add(self.training) - - with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - else: - out = torch.nn.functional.dropout( - attention_output + attention_bias, - p=self.hidden_dropout, - training=self.training, - ) - layernorm_input = residual + self.drop_path(out) - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - if self.layer_type == LayerType.decoder: - attention_output, attention_bias = self.inter_attention( - layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output - ) - # residual connection - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - - # Layer norm post the decoder attention - layernorm_output = self.post_inter_attention_layernorm(layernorm_input) - - # MLP. - mlp_output, mlp_bias = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - if self.drop_path is None: - with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - else: - out = torch.nn.functional.dropout( - mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training - ) - output = residual + self.drop_path(out) - - return output - - -class ParallelTransformer(MegatronModule): - """Transformer class.""" - - def __init__( - self, - init_method, - output_layer_init_method, - layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - post_layer_norm=True, - pre_process=True, - post_process=True, - drop_path_rate=0.0, - ): - super().__init__() - args = get_args() - - self.layer_type = layer_type - self.model_type = args.model_type - self.bf16 = args.bf16 - self.fp32_residual_connection = args.fp32_residual_connection - self.post_layer_norm = post_layer_norm - self.pre_process = pre_process - self.post_process = post_process - self.input_tensor = None - self.drop_path_rate = drop_path_rate - - # Store activation checkpoiting flag. - self.recompute_granularity = args.recompute_granularity - self.recompute_method = args.recompute_method - self.recompute_num_layers = args.recompute_num_layers - self.distribute_saved_activations = ( - args.distribute_saved_activations and not args.sequence_parallel - ) - - self.sequence_parallel = args.sequence_parallel - - # Number of layers. - self.num_layers = get_num_layers(args, args.model_type == ModelType.encoder_and_decoder) - - self.drop_path_rates = [ - rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers) - ] - - # Transformer layers. 
- def build_layer(layer_number): - return ParallelTransformerLayer( - init_method, - output_layer_init_method, - layer_number, - layer_type=layer_type, - self_attn_mask_type=self_attn_mask_type, - drop_path_rate=self.drop_path_rates[layer_number - 1], - ) - - if args.virtual_pipeline_model_parallel_size is not None: - assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, ( - "num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size" - ) - assert args.model_type != ModelType.encoder_and_decoder - # Number of layers in each model chunk is the number of layers in the stage, - # divided by the number of model chunks in a stage. - self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size - # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0] [2] [4] [6] - # Stage 1: [1] [3] [5] [7] - # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0, 1] [4, 5] - # Stage 1: [2, 3] [6, 7] - offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * ( - args.num_layers // args.virtual_pipeline_model_parallel_size - ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) - else: - # Each stage gets a contiguous set of layers. - if ( - args.model_type == ModelType.encoder_and_decoder - and parallel_state.get_pipeline_model_parallel_world_size() > 1 - ): - pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() - if layer_type == LayerType.encoder: - offset = pipeline_rank * self.num_layers - else: - num_ranks_in_enc = args.pipeline_model_parallel_split_rank - offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers - else: - offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers - - if self.num_layers == 0: - # When a standalone embedding stage is used (e.g., - # args.standalone_embedding_stage == True), virtual pipeline ranks - # on pipeline rank 0 will have zero transformer layers assigned to - # them. This results in the model's input and output tensors to be - # the same, which will cause failure for certain output tensor - # optimizations (e.g., pipeline output deallocation). To remedy - # this, we assign a 'no-op' layer on these ranks, which will - # disconnect the input tensor from the output tensor. - self.num_layers = 1 - self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) - else: - self.layers = torch.nn.ModuleList( - [build_layer(i + 1 + offset) for i in range(self.num_layers)] - ) - - if self.post_process and self.post_layer_norm: - # Final layer norm before output. 
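The offset arithmetic above is easiest to check against the numbers in the comment. A small self-contained sketch (hypothetical helper, not part of the deleted module) that reproduces the interleaved layer-to-chunk assignment:

def chunk_layers(num_layers, pp_size, vp_size, pp_rank, vp_rank):
    # 0-indexed global layer ids owned by one (pipeline rank, virtual rank) chunk.
    layers_per_stage = num_layers // pp_size
    layers_per_chunk = layers_per_stage // vp_size
    offset = vp_rank * (num_layers // vp_size) + pp_rank * layers_per_chunk
    return list(range(offset, offset + layers_per_chunk))

# 8 layers, 2 stages, 4 chunks per stage: stage 0 owns [0] [2] [4] [6], stage 1 owns [1] [3] [5] [7]
assert [chunk_layers(8, 2, 4, 0, v) for v in range(4)] == [[0], [2], [4], [6]]
assert [chunk_layers(8, 2, 4, 1, v) for v in range(4)] == [[1], [3], [5], [7]]
# 8 layers, 2 stages, 2 chunks per stage: stage 0 owns [0, 1] [4, 5]
assert [chunk_layers(8, 2, 2, 0, v) for v in range(2)] == [[0, 1], [4, 5]]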
- self.final_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def _checkpointed_forward( - self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ): - """Forward method with activation checkpointing.""" - - def custom(start, end): - def custom_forward(*inputs): - x_ = inputs[0] - attention_mask = inputs[1] - encoder_output = inputs[2] - enc_dec_attn_mask = inputs[3] - for index in range(start, end): - layer = self._get_layer(index) - x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask) - return x_ - - return custom_forward - - if self.recompute_method == "uniform": - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - l = 0 - while l < self.num_layers: - hidden_states = tensor_parallel.random.checkpoint( - custom(l, l + self.recompute_num_layers), - self.distribute_saved_activations, - hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - ) - l += self.recompute_num_layers - - elif self.recompute_method == "block": - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - for l in range(self.num_layers): - if l < self.recompute_num_layers: - hidden_states = tensor_parallel.random.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - ) - else: - hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - # hidden_states: [s, b, h] - - # Checks. - if inference_params: - assert self.recompute_granularity is None, ( - "inference does not work with activation checkpointing" - ) - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. 
That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - # hidden_states = mpu.make_viewless_tensor(hidden_states, requires_grad=True, keep_graph=True) - - if self.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = contextlib.nullcontext() - - with rng_context: - # Forward pass. - if self.recompute_granularity == "full": - hidden_states = self._checkpointed_forward( - hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ) - else: - for index in range(self.num_layers): - layer = self._get_layer(index) - hidden_states = layer( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - ) - - # Final layer norm. - if self.post_process and self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states - - -def get_num_layers(args, is_encoder_and_decoder_model): - """Compute the number of transformer layers resident on the current rank.""" - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if is_encoder_and_decoder_model: - assert args.pipeline_model_parallel_split_rank is not None - - # When a standalone embedding stage is used, a rank is taken from - # the encoder's ranks, to be used for the encoder's embedding - # layer. This way, the rank referenced by the 'split rank' remains - # the same whether or not a standalone embedding stage is used. - num_ranks_in_encoder = ( - args.pipeline_model_parallel_split_rank - 1 - if args.standalone_embedding_stage - else args.pipeline_model_parallel_split_rank - ) - num_ranks_in_decoder = ( - args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder - ) - assert args.num_layers % num_ranks_in_encoder == 0, ( - "num_layers (%d) must be divisible by number of ranks given to encoder (%d)" - % ( - args.num_layers, - num_ranks_in_encoder, - ) - ) - assert args.num_layers % num_ranks_in_decoder == 0, ( - "num_layers (%d) must be divisible by number of ranks given to decoder (%d)" - % ( - args.num_layers, - num_ranks_in_decoder, - ) - ) - if parallel_state.is_pipeline_stage_before_split(): - num_layers = ( - 0 - if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 - else args.num_layers // num_ranks_in_encoder - ) - else: - num_layers = args.num_layers // num_ranks_in_decoder - else: - assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, ( - "num_layers must be divisible by transformer_pipeline_model_parallel_size" - ) - - # When a standalone embedding stage is used, all transformer layers - # are divided among pipeline rank >= 1, while on pipeline rank 0, - # ranks either contain the input embedding layer (virtual pp rank 0), - # or no layers at all (virtual pp rank >= 1). - num_layers = ( - 0 - if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 - else args.num_layers // args.transformer_pipeline_model_parallel_size - ) - else: - num_layers = args.num_layers - return num_layers - - -class NoopTransformerLayer(MegatronModule): - """A single 'no-op' transformer layer. - - The sole purpose of this layer is for when a standalone embedding layer - is used (i.e., args.standalone_embedding_stage == True). In this case, - zero transformer layers are assigned when pipeline rank == 0. 
Additionally, - when virtual pipeline rank >= 1, zero total model parameters are created - (virtual rank 0 contains the input embedding). This results in the model's - input and output tensors being the same, which causes an error when - performing certain memory optimiations on the output tensor (e.g., - deallocating it). Thus, this layer disconnects the input from the output - via a clone. Since ranks containing a no-op layer are generally under- - utilized (both compute and memory), there's no worry of any performance - degredation. - """ - - def __init__(self, layer_number): - super().__init__() - self.layer_number = layer_number - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - return hidden_states.clone() - - -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): - """LM logits using word embedding weights.""" - args = get_args() - # Parallel logits. - if args.async_tensor_model_parallel_allreduce or args.sequence_parallel: - input_parallel = input_ - model_parallel = parallel_state.get_tensor_model_parallel_world_size() > 1 - async_grad_allreduce = ( - args.async_tensor_model_parallel_allreduce - and model_parallel - and not args.sequence_parallel - ) - else: - input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) - async_grad_allreduce = False - - # Matrix multiply. - # logits_parallel = tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.apply( - # input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion, async_grad_allreduce, args.sequence_parallel) - logits_parallel = tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce( - input_parallel, - word_embeddings_weight, - bias, - args.gradient_accumulation_fusion, - async_grad_allreduce, - args.sequence_parallel, - ) - # Gather if needed. - - if parallel_output: - return logits_parallel - - return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) - - -def get_language_model( - num_tokentypes, - add_pooler, - encoder_attn_mask_type, - init_method=None, - scaled_init_method=None, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - pre_process=True, - post_process=True, -): - """Build language model and return along with the key to save.""" - args = get_args() - - if init_method is None: - init_method = init_method_normal(args.init_method_std) - if scaled_init_method is None: - scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) - - # Language model. - language_model = TransformerLanguageModel( - init_method, - scaled_init_method, - encoder_attn_mask_type, - num_tokentypes=num_tokentypes, - add_encoder=add_encoder, - add_decoder=add_decoder, - decoder_attn_mask_type=decoder_attn_mask_type, - add_pooler=add_pooler, - pre_process=pre_process, - post_process=post_process, - ) - # key used for checkpoints. - language_model_key = "language_model" - - return language_model, language_model_key - - -class Pooler(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. - - Arguments: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. 
- """ - - def __init__(self, hidden_size, init_method): - super().__init__() - args = get_args() - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.sequence_parallel = args.sequence_parallel - - def forward(self, hidden_states, sequence_index=0): - # hidden_states: [s, b, h] - # sequence_index: index of the token to pool. - # gather data along sequence dimensions - # same pooler is run on all tensor parallel nodes - if self.sequence_parallel: - hidden_states = tensor_parallel.mappings.gather_from_sequence_parallel_region( - hidden_states - ) - pooled = hidden_states[sequence_index, :, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled - - -class Embedding(MegatronModule): - """Language model embeddings. - - Arguments: - hidden_size: hidden size - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - init_method: weight initialization method - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - init_method, - num_tokentypes=0, - ): - super().__init__() - - self.hidden_size = hidden_size - self.init_method = init_method - self.num_tokentypes = num_tokentypes - - args = get_args() - - # Word embeddings (parallel). - self.word_embeddings = VocabParallelEmbedding( - vocab_size, self.hidden_size, init_method=self.init_method - ) - self._word_embeddings_key = "word_embeddings" - - # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) - self._position_embeddings_key = "position_embeddings" - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) - - # Token type embedding. - # Add this as an optional field that can be added through - # method call so we can load a pretrain model without - # token types and add them as needed. - self._tokentype_embeddings_key = "tokentype_embeddings" - if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - else: - self.tokentype_embeddings = None - - self.fp32_residual_connection = args.fp32_residual_connection - self.sequence_parallel = args.sequence_parallel - # Embeddings dropout - self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) - - def zero_parameters(self): - """Zero out all parameters in embedding.""" - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - self.position_embeddings.weight.data.fill_(0) - self.position_embeddings.weight.shared = True - if self.num_tokentypes > 0: - self.tokentype_embeddings.weight.fill_(0) - self.tokentype_embeddings.weight.shared = True - - def add_tokentype_embeddings(self, num_tokentypes): - """Add token-type embedding. This function is provided so we can add - token-type embeddings in case the pretrained model does not have it. - This allows us to load the model normally and then add this embedding. 
- """ - if self.tokentype_embeddings is not None: - raise Exception("tokentype embeddings is already initialized") - if torch.distributed.get_rank() == 0: - print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) - self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - - def forward(self, input_ids, position_ids, tokentype_ids=None): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings - if tokentype_ids is not None: - assert self.tokentype_embeddings is not None - embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) - else: - assert self.tokentype_embeddings is None - - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.sequence_parallel: - embeddings = scatter_to_sequence_parallel_region(embeddings) - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings - - -class TransformerLanguageModel(MegatronModule): - """Transformer language model. - - Arguments: - transformer_hparams: transformer hyperparameters - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - init_method, - output_layer_init_method, - encoder_attn_mask_type, - num_tokentypes=0, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - add_pooler=False, - pre_process=True, - post_process=True, - ): - super().__init__() - args = get_args() - - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = args.hidden_size - self.num_tokentypes = num_tokentypes - self.init_method = init_method - self.add_encoder = add_encoder - self.encoder_attn_mask_type = encoder_attn_mask_type - self.add_decoder = add_decoder - self.decoder_attn_mask_type = decoder_attn_mask_type - self.add_pooler = add_pooler - self.encoder_hidden_state = None - - # Embeddings. - if self.pre_process: - self.embedding = Embedding( - self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - self.init_method, - self.num_tokentypes, - ) - self._embedding_key = "embedding" - - # Transformer. - # Encoder (usually set to True, False if part of an encoder-decoder - # architecture and in encoder-only stage). - if self.add_encoder: - self.encoder = ParallelTransformer( - self.init_method, - output_layer_init_method, - self_attn_mask_type=self.encoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._encoder_key = "encoder" - else: - self.encoder = None - - # Decoder (usually set to False, True if part of an encoder-decoder - # architecture and in decoder-only stage). 
- if self.add_decoder: - self.decoder = ParallelTransformer( - self.init_method, - output_layer_init_method, - layer_type=LayerType.decoder, - self_attn_mask_type=self.decoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._decoder_key = "decoder" - else: - self.decoder = None - - if self.post_process: - # Pooler. - if self.add_pooler: - self.pooler = Pooler(self.hidden_size, self.init_method) - self._pooler_key = "pooler" - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - if self.add_encoder and self.add_decoder: - assert len(input_tensor) == 1, ( - "input_tensor should only be length 1 for stage with both encoder and decoder" - ) - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_encoder: - assert len(input_tensor) == 1, ( - "input_tensor should only be length 1 for stage with only encoder" - ) - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_decoder: - if len(input_tensor) == 2: - self.decoder.set_input_tensor(input_tensor[0]) - self.encoder_hidden_state = input_tensor[1] - elif len(input_tensor) == 1: - self.decoder.set_input_tensor(None) - self.encoder_hidden_state = input_tensor[0] - else: - raise Exception("input_tensor must have either length 1 or 2") - else: - raise Exception("Stage must have at least either encoder or decoder") - - def forward( - self, - enc_input_ids, - enc_position_ids, - enc_attn_mask, - dec_input_ids=None, - dec_position_ids=None, - dec_attn_mask=None, - enc_dec_attn_mask=None, - tokentype_ids=None, - inference_params=None, - pooling_sequence_index=0, - enc_hidden_states=None, - output_enc_hidden=False, - ): - args = get_args() - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding( - enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids - ) - else: - encoder_input = None - - # Run encoder. - if enc_hidden_states is None: - if self.encoder is not None: - encoder_output = self.encoder( - encoder_input, enc_attn_mask, inference_params=inference_params - ) - else: - encoder_output = self.encoder_hidden_state - else: - encoder_output = enc_hidden_states.to(encoder_input.dtype) - - if self.post_process: - if self.add_pooler: - pooled_output = self.pooler(encoder_output, pooling_sequence_index) - - # output_enc_hidden refers to when we just need the encoder's - # output. For example, it is helpful to compute - # similarity between two sequences by average pooling - if not self.add_decoder or output_enc_hidden: - if self.add_pooler and self.post_process: - return encoder_output, pooled_output - else: - return encoder_output - - # Decoder embedding. - if self.pre_process: - decoder_input = self.embedding(dec_input_ids, dec_position_ids) - else: - decoder_input = None - - # Run decoder. - decoder_output = self.decoder( - decoder_input, - dec_attn_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - ) - - if self.add_pooler and self.post_process: - return decoder_output, encoder_output, pooled_output - else: - return decoder_output, encoder_output - - -def post_language_model_processing( - lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy -): - # Output. 
- output = parallel_lm_logits(lm_output, logit_weights, parallel_output) - - if labels is None: - return output - else: - if fp16_lm_cross_entropy: - assert output.dtype == torch.half - loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) - else: - loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) - return loss - - -def module_size(m: torch.nn.Module, only_trainable: bool = False): - """ - returns the total number of parameters used by `m` (only counting - shared parameters once); if `only_trainable` is True, then only - includes parameters with `requires_grad = True` - """ - parameters = list(m.parameters()) - if only_trainable: - parameters = [p for p in parameters if p.requires_grad] - unique = {p.data_ptr(): p for p in parameters}.values() - return sum(p.numel() for p in unique) diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py deleted file mode 100644 index 9d4b779df..000000000 --- a/apex/transformer/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`""" - -import torch - -from apex.transformer import parallel_state - -# `all_gather_into_tensor` is new placeholders for `_all_gather_base`. -# It requires the most recent version of PyTorch. -# The following 4 lines are for backward comparability with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -def ensure_divisibility(numerator, denominator): - """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) - - -def divide(numerator, denominator): - """Ensure that numerator is divisible by the denominator and return - the division value.""" - ensure_divisibility(numerator, denominator) - return numerator // denominator - - -def split_tensor_into_1d_equal_chunks(tensor): - """Break a tensor into equal 1D chunks.""" - data = tensor.view(-1) - partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size() - start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() - end_index = start_index + partition_size - return data[start_index:end_index] - - -def gather_split_1d_tensor(tensor): - """Opposite of above function, gather values from model parallel ranks.""" - world_size = parallel_state.get_tensor_model_parallel_world_size() - numel = torch.numel(tensor) - numel_gathered = world_size * numel - gathered = torch.empty( - numel_gathered, - dtype=tensor.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - torch.distributed.all_gather_into_tensor( - gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() - ) - return gathered diff --git a/tests/L0/run_transformer/__init__.py b/tests/L0/run_transformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/L0/run_transformer/gpt_scaling_test.py b/tests/L0/run_transformer/gpt_scaling_test.py deleted file mode 100644 index f65a1cefc..000000000 --- a/tests/L0/run_transformer/gpt_scaling_test.py +++ /dev/null @@ -1,118 +0,0 @@ -import subprocess -import os - -from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE - - -def run_gpt(cmd): - args = list(cmd.split(" ")) - p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - outs, errs = p.communicate() - outs = list(str((outs).decode("utf-8")).splitlines()) - success = False 
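The 1D split/gather helpers deleted above (split_tensor_into_1d_equal_chunks and gather_split_1d_tensor) follow a simple contract: each tensor-model-parallel rank keeps one equal slice of the flattened tensor, and an all-gather restores the original data. A single-process sketch of that round trip (illustrative only; the real helpers use torch.distributed.all_gather_into_tensor across the tensor-parallel group):

import torch

world_size = 4                                    # pretend tensor-parallel size
t = torch.arange(32, dtype=torch.float32).reshape(4, 8)

flat = t.view(-1)
partition = flat.numel() // world_size            # the helpers assume divisibility
chunks = [flat[r * partition:(r + 1) * partition] for r in range(world_size)]

gathered = torch.cat(chunks)                      # stands in for the all-gather
assert torch.equal(gathered.view_as(t), t)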
- runtime = 0 - num_params = 0 - for out in outs: - out = str(out) - if "Average Iteration Time:" in str(out): - slicey = out[out.find(":") + 2 :] - try: - runtime = float(slicey) - except: - print(slicey) - quit() - if "Number of Parameters:" in str(out): - slicey = out[out.find(":") + 2 :] - try: - num_params = int(slicey) - except: - print(slicey) - quit() - if str(out) == str(TEST_SUCCESS_MESSAGE): - success = True - return runtime, round(float(int(num_params)) / 10.0**9, 3), success, errs - - -def plot(runtimes): - import matplotlib.pyplot as plt - - for distributed_setting in runtimes.keys(): - plt.scatter( - runtimes[distributed_setting].keys(), - runtimes[distributed_setting].values(), - label=distributed_setting, - ) - plt.legend() - plt.xlabel("Parameters (Billions)") - plt.ylabel("Training Iteration time (s)") - plt.title(str("GPT Scaling w/ Offloading")) - plt.savefig("offload_gpt_scaling.png") - plt.close() - if not os.path.exists("/my_workspace/"): - os.system("mkdir /my_workspace/") - os.system("cp *.png /my_workspace/") - - -def main(): - runtimes = {} - nlist = ( - list(range(2000, 10000, 2000)) - + list(range(10000, 50000, 5000)) - + list(range(50000, 100000, 10000)) - ) - print("N-List:", nlist) - for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]: - for offload in [True, False]: - dist_setting = ( - "ddp=" - + str(data_parr) - + ", tensor_parr=" - + str(tens_parr) - + ", pipe_parr=" - + str(pipe_parr) - + ", offload=" - + str(offload) - ) - runtimes[dist_setting] = {} - print("Beginning Testing for", dist_setting) - for n in nlist: - cmd = ( - "python3 -m torch.distributed.launch --nproc_per_node=8 run_gpt_minimal_test.py" - ) - cmd += ( - " --micro-batch-size 1 --num-layers " - + str(n) - + " --hidden-size 128 --num-attention-heads 16" - ) - cmd += ( - " --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size " - + str(tens_parr) - ) - cmd += ( - " --pipeline-model-parallel-size " - + str(pipe_parr) - + (" --cpu-offload" if offload else "") - ) - print(cmd) - runtime, bill_params, success, errs = run_gpt(cmd) - if success: - runtimes[dist_setting][bill_params] = runtime - print( - str(runtime) + "s per training iter for", - str(bill_params) + "B parameter GPT-2", - ) - if n >= 10000: - plot(runtimes) - else: - print("GPT-2 w/", n, "layers failed using", dist_setting) - print("Moving on to the next distributed setting...") - print("#" * (25)) - print() - plot(runtimes) - break - print(runtimes) - plot(runtimes) - - -if __name__ == "__main__": - main() diff --git a/tests/L0/run_transformer/test_batch_sampler.py b/tests/L0/run_transformer/test_batch_sampler.py deleted file mode 100644 index f3273b931..000000000 --- a/tests/L0/run_transformer/test_batch_sampler.py +++ /dev/null @@ -1,169 +0,0 @@ -import torch -from torch.testing._internal import common_utils -from torch.utils.data import Dataset -from torch.utils.data import DataLoader - -from apex.transformer.pipeline_parallel.utils import ( - _split_batch_into_microbatch as split_batch_into_microbatch, -) - - -class MyIterableDataset(Dataset): - def __init__(self, start, end): - super().__init__() - assert end > start, "this example code only works with end >= start" - self.start = start - self.end = end - self.samples = list(range(self.start, self.end)) - - def __iter__(self): - return iter(range(self.start, self.end)) - - def __getitem__(self, index): - return self.samples[index] - - -class MegatronPretrainingRandomSampler: - def __init__( - self, - total_samples, - 
consumed_samples, - micro_batch_size, - data_parallel_rank, - data_parallel_size, - ): - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size - self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size - - # Sanity checks. - assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, ( - "data_parallel_rank should be smaller than data size: {}, {}".format( - self.data_parallel_rank, data_parallel_size - ) - ) - - def __len__(self): - return self.total_samples - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - - # data sharding and random sampling - bucket_size = ( - self.total_samples // self.micro_batch_times_data_parallel_size - ) * self.micro_batch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.micro_batch_size: - self.consumed_samples += self.micro_batch_times_data_parallel_size - yield batch - batch = [] - - -# Samples 8 tensors in total. -# First sample 4 tensors twice, then sample 2 tensors fourth. 
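The test class below checks that, for a fixed seed, two micro-batches of four samples cover exactly the same indices as four micro-batches of two. A rough sketch of why that holds, assuming data_parallel_size == 1 and epoch 0 (illustrative only, not part of the deleted test): the sampler draws from randperm(bucket_size) seeded by the epoch, and bucket_size works out to 100 for both micro-batch sizes, so the first eight drawn indices agree.

import torch

g = torch.Generator().manual_seed(0)              # the sampler seeds with the epoch
order = torch.randperm(100, generator=g).tolist()

as_micro4 = [order[0:4], order[4:8]]              # two micro-batches of 4
as_micro2 = [order[0:2], order[2:4], order[4:6], order[6:8]]
assert sum(as_micro4, []) == sum(as_micro2, [])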
-class TestBatchSamplerBehavior(common_utils.TestCase): - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - def test_batch_sampler_behavior(self): - dataset = MyIterableDataset(0, 100) - - for num_workers in (1, 2, 4): - torch.manual_seed(42) - loader = DataLoader( - dataset, - batch_sampler=MegatronPretrainingRandomSampler(100, 0, 4, 0, 1), - num_workers=num_workers, - ) - samples = [] - for i, batch in enumerate(loader): - samples.append(batch) - if i == 2 - 1: - break - - torch.manual_seed(42) - loader = DataLoader( - dataset, - batch_sampler=MegatronPretrainingRandomSampler(100, 0, 2, 0, 1), - num_workers=num_workers, - ) - samples2 = [] - for i, batch in enumerate(loader): - samples2.append(batch) - if i == 4 - 1: - break - self.assertEqual( - torch.cat(samples), - torch.cat(samples2), - msg=f"num_workers={num_workers}", - ) - - def test_split_batch(self): - class MyIterableDataset(Dataset): - def __init__(self, start, end): - super().__init__() - assert end > start, "this example code only works with end >= start" - self.start = start - self.end = end - self.samples = list(range(self.start, self.end)) - - def __len__(self): - return self.end - self.start - - def __iter__(self): - return iter(range(self.start, self.end)) - - def __getitem__(self, index): - return ( - torch.tensor([index, index]), - torch.tensor([index // 2, index // 2]), - ) - - dataset = MyIterableDataset(0, 100) - torch.manual_seed(42) - global_batch_size = 16 - loader = DataLoader( - dataset, - batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), - num_workers=2, - ) - batch = next(iter(loader)) - - for _micro_batch_size in (1, 2, 4, 8): - microbatches = list( - split_batch_into_microbatch( - batch, - _micro_batch_size=_micro_batch_size, - _global_batch_size=global_batch_size, - ) - ) - self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size) - self.assertEqual(len(microbatches[0][0]), _micro_batch_size) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_bert_minimal.py b/tests/L0/run_transformer/test_bert_minimal.py deleted file mode 100644 index 9fd5b5728..000000000 --- a/tests/L0/run_transformer/test_bert_minimal.py +++ /dev/null @@ -1,262 +0,0 @@ -import torch -import unittest -from apex.transformer.testing import global_vars -from apex.transformer.testing.standalone_bert import bert_model_provider -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, - build_model, -) -from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, - unwrap_model, - setup_microbatch_calculator, -) -from apex.transformer.log_util import set_logging_level -from apex.transformer import tensor_parallel, parallel_state -from apex.transformer.enums import ModelType -from apex.transformer._ucc_util import HAS_UCC -from apex.transformer.testing.distributed_test_base import ( - UccDistributedTestBase, - NcclDistributedTestBase, -) -import logging - -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - - -logging.getLogger("apex").setLevel(logging.WARNING) - - -set_logging_level("WARNING") - - -class BertTestBase: - def _download_fancy_data(self): - text = """ - An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. 
Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum. - """ - text = text * 1024 - encoded = text.encode("ascii", "replace") - ints = [int(encoded[i]) for i in range(len(encoded))] - return torch.tensor(ints) - - # build a batch given sequence_len and batch size - def _generate_fancy_data_labels(self, sequence_len, batch_size): - temps = [] - for i in range(batch_size): - if self.inds is None or self.data_idx >= len(self.inds): - # hack as use of RNG will fall out of sync due to pipelines being different - torch.manual_seed(self.MANUAL_SEED) - self.inds = torch.randperm(self.effective_length, device="cuda") - self.masks = ( - torch.rand( - len(self.inds) // batch_size + 1, - batch_size, - sequence_len, - device="cuda", - ) - >= self.MASK_PROB - ).long() - self.MANUAL_SEED += 1 - self.data_idx = 0 - if self.rank == 0: - print("new epoch", len(self.inds)) - print("my start", self.inds[0:5]) - print("masks_checksum:", torch.sum(self.masks)) - if self.EASY_MODE: - data_idx_ = self.data_idx % self.EASY_MODE_SIZ - else: - data_idx_ = self.data_idx - offset = self.inds[data_idx_] # * SEQUENCE_LEN - self.data_idx += 1 - - curr = self.fancy_data[offset : offset + sequence_len].clone().detach() - temps.append(curr) - temp = torch.stack(temps, dim=0).cuda() - mask = self.masks[self.data_idx // batch_size] - mask_not = torch.logical_not(mask).long() - data = mask * temp + mask_not * 124 - label = temp - if parallel_state.get_tensor_model_parallel_rank() == 0: - data_dict = {"text": data, "label": label, "mask_not": mask_not} - else: - data_dict = None - keys = ["text", "label", "mask_not"] - broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long) - return ( - broadcasted_data["text"].long(), - broadcasted_data["label"].long(), - broadcasted_data["mask_not"], - ) - - def _fwd_step_func(self, batch, model): - data, label, loss_mask = batch - y = model(data, torch.ones_like(data), lm_labels=label) - - def loss_func(output_tensor): - output_tensor, _ = output_tensor - lm_loss_ = output_tensor.float() - lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() - averaged_loss = average_losses_across_data_parallel_group([lm_loss]) - if self.data_idx >= 1536: - # NOTE (patwang): Loss cutoff might be excessively high but roughly one in five - # unlucky random seeds do cause loss to spike to just under 8.0 - self.assertLess(averaged_loss, 8.0) - return lm_loss, {"avg": averaged_loss} - - return y, loss_func - - def _train( - self, - model, - optim, - virtual_pipeline_model_parallel_size, - pipeline_model_parallel_size, - async_comm, - ): - args = global_vars.get_args() - sequence_len = args.seq_length - micro_batch_size = args.micro_batch_size - hidden_size = args.hidden_size - global_batch_size = args.global_batch_size - forward_backward_func = get_forward_backward_func( - virtual_pipeline_model_parallel_size, pipeline_model_parallel_size - ) - tensor_shape = (sequence_len, micro_batch_size, hidden_size) - for _ in range(16): - batch = self._generate_fancy_data_labels(sequence_len, global_batch_size) - optim.zero_grad() - forward_backward_func( - self._fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=tensor_shape, - async_comm=async_comm, - sequence_parallel_enabled=args.sequence_parallel, - ) - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and 
args.sequence_parallel: - for model_module in model: - unwrapped_model = unwrap_model(model_module) - for param in unwrapped_model.parameters(): - if getattr(param, "sequence_parallel_enabled", False): - grad = param.grad - torch.distributed.all_reduce( - grad, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - optim.step() - - @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus") - def test_bert_without_interleaving(self): - self._test_bert(virtual_pipeline_model_parallel_size=None) - - @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus") - def test_bert_with_interleaving(self): - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest("skip interleaving with ucc") - self._test_bert(virtual_pipeline_model_parallel_size=2) - - def _test_bert(self, virtual_pipeline_model_parallel_size): - self.MANUAL_SEED = 42 - self.inds = None - self.masks = None - self.data_idx = 0 - self.MASK_PROB = 0.1 - self.EASY_MODE = False - self.EASY_MODE_SIZ = 32 - - tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size > 4 else 1 - pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size - - override_args = { - "micro_batch_size": 2, - "num_layers": 16, - "hidden_size": 256, - "num_attention_heads": 8, - "max_position_embeddings": 512, - "seq_length": 512, - "global_batch_size": 128, - "pipeline_model_parallel_size": pipeline_model_parallel_size, - "tensor_model_parallel_size": tensor_model_parallel_size, - "bert_binary_head": False, - "world_size": self.world_size, - "rank": self.rank, - } - - global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True) - args = global_vars.get_args() - - self.fancy_data = self._download_fancy_data() - self.effective_length = self.fancy_data.size(0) // args.seq_length - self.effective_length = self.fancy_data.size(0) - args.seq_length - - if self.rank == 0: - print( - f"testing backend: {self.DISTRIBUTED_BACKEND} with virtual_pipeline_model_parallel_size: {virtual_pipeline_model_parallel_size}" - ) - async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None - self.data_idx = 0 - args.padded_vocab_size = 128 # needed in standalone gpt - args.model_type = ModelType.encoder_or_decoder - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - args.data_parallel_size, - ) - parallel_state.initialize_model_parallel( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - default_backend="nccl", - p2p_backend=self.DISTRIBUTED_BACKEND, - ) - - tensor_parallel.random.model_parallel_cuda_manual_seed(0) - model = build_model( - bert_model_provider, - wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - cpu_offload=args.cpu_offload, - ) - assert isinstance(model, list) - assert len(model) == ( - 1 - if virtual_pipeline_model_parallel_size is None - else virtual_pipeline_model_parallel_size - ) - _param_groups = _get_params_for_weight_decay_optimization(model) - optim = torch.optim.Adam(_param_groups) - self._train( - model, - optim, - virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_size, - async_comm, - ) - torch.cuda.synchronize() - - -class NcclBertTest(BertTestBase, NcclDistributedTestBase): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - 
-@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc") -class UccBertTest(BertTestBase, UccDistributedTestBase): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_cross_entropy.py b/tests/L0/run_transformer/test_cross_entropy.py deleted file mode 100644 index d12512005..000000000 --- a/tests/L0/run_transformer/test_cross_entropy.py +++ /dev/null @@ -1,109 +0,0 @@ -import logging -from typing import Tuple - -import torch -import torch.nn.functional as F -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer import tensor_parallel -from apex.transformer.tensor_parallel import cross_entropy -from apex.transformer.testing.commons import set_random_seed, IdentityLayer -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -def torch_cross_entropy( - batch_size: int, - seq_length: int, - vocab_size: int, - logits_scale: float, - seed: int, - label_smoothing: float = 0.0, -) -> Tuple[torch.Tensor, torch.Tensor]: - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda() - logits = identity() - target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) - loss = ( - F.cross_entropy( - logits.view(-1, logits.size()[-1]), - target.view(-1), - reduction="none", - label_smoothing=label_smoothing, - ) - .view_as(target) - .mean() - ) - loss.backward() - return loss, identity.weight.grad - - -def tensor_sharded_cross_entropy( - batch_size, seq_length, vocab_size, logits_scale, seed, label_smoothing=0.0 -): - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda() - logits = identity() - logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits) - target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) - logits_parallel_ = logits_parallel.clone().detach() - loss = cross_entropy.vocab_parallel_cross_entropy( - logits_parallel, target, label_smoothing=label_smoothing - ).mean() - loss.backward() - # check for mutation - assert torch.equal(logits_parallel_, logits_parallel) - return loss, identity.weight.grad - - -class VocabParallelCrossEntropyTestBase: - def test_cross_entropy(self): - batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11 - logits_scale = 1000.0 - seed = 1234 - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size - loss_torch, grad_torch = torch_cross_entropy( - batch_size, sequence_length, vocab_size, logits_scale, seed - ) - ( - loss_tensor_parallel, - grad_tensor_parallel, - ) = tensor_sharded_cross_entropy( - batch_size, sequence_length, vocab_size, logits_scale, seed - ) - - self.assertEqual( - loss_torch, - loss_tensor_parallel, - msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}", - ) - self.assertEqual( - grad_torch, - grad_tensor_parallel, - 
msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}", - ) - - parallel_state.destroy_model_parallel() - - -class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase): - pass - - -class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_data.py b/tests/L0/run_transformer/test_data.py deleted file mode 100644 index f77ab0309..000000000 --- a/tests/L0/run_transformer/test_data.py +++ /dev/null @@ -1,66 +0,0 @@ -import logging - -import torch.testing -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import data as data_utils -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("torch").setLevel(logging.WARNING) - - -class BroadcastDataTestBase: - def test_broadcast_data(self): - tensor_model_parallel_world_size: int = self.world_size // (1 + self.world_size > 1) - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - target_key_size = { - "key1": [7, 11], - "key2": [8, 2, 1], - "key3": [13], - "key4": [5, 1, 2], - "key5": [5, 12], - } - keys = [k for k in target_key_size] - - data = {} - data_t = {} - with torch.no_grad(): - for key in target_key_size: - data[key] = torch.randint(0, 1000, size=target_key_size[key]) - data_t[key] = data[key].clone() - # "key_x" is supposed to be ignored. - data["key_x"] = torch.rand(5) - data_t["key_x"] = data["key_x"].clone() - if parallel_state.get_tensor_model_parallel_rank() != 0: - data = None - - data_utils._check_data_types(keys, data_t, torch.int64) - key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data) - - for key in keys: - self.assertEqual(target_key_size[key], key_size[key]) - - broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64) - for key in keys: - self.assertEqual(broadcasted_data[key], data_t[key].cuda()) - - parallel_state.destroy_model_parallel() - - -class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): - pass - - -class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_dynamic_batchsize.py b/tests/L0/run_transformer/test_dynamic_batchsize.py deleted file mode 100644 index 9e78d6099..000000000 --- a/tests/L0/run_transformer/test_dynamic_batchsize.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Tuple, List - -import torch -import unittest - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, - build_model, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.utils import ( - setup_microbatch_calculator, - _reconfigure_microbatch_calculator, - update_num_microbatches, -) -from apex.transformer.testing import global_vars -from apex.transformer.testing.commons import ( - print_separator, - 
fwd_step_func, - model_provider_func, -) -from apex.transformer.log_util import get_transformer_logger -from apex.transformer._data import ( - MegatronPretrainingRandomSampler, - MegatronPretrainingSampler, -) -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -from torch.testing._internal import common_utils - -# note(mkozuki): To see warmup, steady, cooldown iterations, uncomment the line below -# set_logging_level("INFO") -_logger = get_transformer_logger("pipeline_parallel_test") -# note(mkozuki): To see if local batch size increases, uncomment the line below -# _logger.setLevel("INFO") - - -NUM_ITERATIONS = 20 -NUM_SAMPLES = 16384 // 2 -HIDDEN_SIZE = 16 - - -def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: - return [ - ( - torch.randn(HIDDEN_SIZE, HIDDEN_SIZE), - torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2), - ) - for _ in range(num_samples) - ] - - -# Run forward & backward with dynamic batch size. -def run_interleaved_with_dynamic_batch_size( - pipeline_model_parallel_size: int, - forward_only: bool, - BatchSamplerCls, -) -> None: - args = global_vars.get_args() - _reconfigure_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - 1, # args.data_parallel_size, - ) - virtual_pipeline_model_parallel_size = 2 - # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling - # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and - # used ubiquitously but this test uses custom model so it's safe to abuse. - parallel_state.initialize_model_parallel( - 1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size - ) - pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() - - print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}") - - model = build_model( - model_provider_func, - wrap_with_ddp=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=HIDDEN_SIZE, - ) - assert isinstance(model, list) - assert len(model) == virtual_pipeline_model_parallel_size - optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model)) - - initial_local_minibatch_size = get_num_microbatches() * args.micro_batch_size - dataset = Dataset(NUM_SAMPLES) - data_loader = torch.utils.data.DataLoader( - dataset, - batch_sampler=BatchSamplerCls( - NUM_SAMPLES, - 0, - initial_local_minibatch_size, - parallel_state.get_data_parallel_rank(), - parallel_state.get_data_parallel_world_size(), - ), - ) - data_iter = iter(data_loader) - - def get_num_samples(batch): - if isinstance(batch, torch.Tensor): - return len(batch) - assert isinstance(batch, (list, tuple)) - return [get_num_samples(b) for b in batch] - - tensor_shape = [args.micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE] - consumed_samples = 0 - for i in range(NUM_ITERATIONS): - update_num_microbatches(consumed_samples, consistency_check=False) - local_batch_size = get_num_microbatches() * args.micro_batch_size - data_iter._index_sampler.local_minibatch_size = local_batch_size - local_mini_batch = next(data_iter) - - _logger.info( - f"iter: {i} / {NUM_ITERATIONS} " - f"local batchsize: {get_num_samples(local_mini_batch)} " - f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}" - ) - _forward_backward_pipelining_with_interleaving( - fwd_step_func, - local_mini_batch, - model, - forward_only=forward_only, - 
tensor_shape=tensor_shape, - ) - - consumed_samples += ( - parallel_state.get_data_parallel_world_size() - * get_num_microbatches() - * args.micro_batch_size - ) - - if not forward_only: - for m in model: - for p in m.parameters(): - if p.grad is None: - raise RuntimeError("grad not found") - else: - optimizer.zero_grad(set_to_none=True) - - torch.cuda.synchronize() - - -class DynamicBatchsizeTestBase: - @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus") - def test_dynamic_batchsize(self): - n_tests = 0 - failures = [] - - override_args = { - "micro_batch_size": 2, - "num_layers": 16, - "hidden_size": 256, - "num_attention_heads": 8, - "max_position_embeddings": 512, - "seq_length": 512, - "global_batch_size": 128, - "use_cpu_initialization": True, - "world_size": self.world_size, - "rank": self.rank, - } - - global_vars.set_global_variables( - args_defaults={ - "global_batch_size": 512, - "rampup_batch_size": [64, 64, 1000], - }, - ignore_unknown_args=True, - override_args=override_args, - ) - - args = global_vars.get_args() - - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - 1, # args.data_parallel_size, - ) - for BatchSamplerCls in ( - MegatronPretrainingSampler, - MegatronPretrainingRandomSampler, - ): - for forward_only in (False, True): - n_tests += 1 - pipeline_model_parallel_size = self.world_size - try: - run_interleaved_with_dynamic_batch_size( - pipeline_model_parallel_size, - forward_only, - BatchSamplerCls, - ) - except Exception as e: - msg = ( - f"\tforward_only: {forward_only}\n" - f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, " - f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n" - f"{str(e)}" - ) - raise RuntimeError(msg) - finally: - parallel_state.destroy_model_parallel() - if failures: - print_separator("TEST FAILED:") - print("\n".join(failures)) - msg = f"{len(failures)} / {n_tests} cases failed" - raise RuntimeError(msg) - else: - if torch.distributed.get_rank() == 0: - print_separator("TEST RESULT: ### PASS!") - - -class NcclDynamicBatchsizeTest(DynamicBatchsizeTestBase, NcclDistributedTestBase): - pass - - -# TODO: (Fuzzkatt) UCC still doesn't work with fwd_bwd_pipelining_with_interleaving - - -if __name__ == "__main__": - torch.backends.cuda.matmul.allow_tf32 = False - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_fused_rope.py b/tests/L0/run_transformer/test_fused_rope.py deleted file mode 100644 index e37e6b695..000000000 --- a/tests/L0/run_transformer/test_fused_rope.py +++ /dev/null @@ -1,329 +0,0 @@ -"""Test for fused RoPE functions. - -Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py -""" # NOQA - -import itertools - -import torch -from torch.testing._internal import common_utils -from apex.transformer.functional import ( - fused_apply_rotary_pos_emb, - fused_apply_rotary_pos_emb_cached, - fused_apply_rotary_pos_emb_thd, - fused_apply_rotary_pos_emb_2d, -) - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Change sign so the last dimension becomes [-odd, +even] - - Args: - x (Tensor): Input tensor - - Returns: - Tensor: Tensor rotated half - """ - - x1, x2 = torch.chunk(x, 2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -# Copied from Megatron-Core for testing. 
-# https://github.com/NVIDIA/Megatron-LM/blob/5f2877d85cb26e47ce6dcdae4b80adf376abf4e8/megatron/core/models/common/embeddings/rotary_pos_embedding.py#L139 -def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T. - - check https://kexue.fm/archives/8265 for detailed formulas - - Args: - t (Tensor): Input tensor T is of shape [seq_length, ... , dim] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] - - Returns: - Tensor: The input tensor after applying RoPE - """ - rot_dim = freqs.shape[-1] - - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) - - -def apply_rotary_pos_emb_thd( - t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - """A baseline implementation of applying RoPE for `thd` format. - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - - Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. - """ - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return torch.cat( - [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)] - ).squeeze(1) - - -def apply_rotary_pos_emb_2d(q, img_h, img_w, cos_h, sin_h, cos_w, sin_w): - q = q.view(q.shape[0], img_h, img_w, q.shape[2], q.shape[3]) - q1, q2 = q.chunk(2, dim=-1) - cos_h = cos_h[:, :img_h].unsqueeze(2) # [1, H, 1, 1, D//2] - sin_h = sin_h[:, :img_h].unsqueeze(2) # [1, H, 1, 1, D//2] - q1 = (q1 * cos_h) + (_rotate_half(q1) * sin_h) - cos_w = cos_w[:, :img_w].unsqueeze(1) # [1, 1, W, 1, D//2] - sin_w = sin_w[:, :img_w].unsqueeze(1) # [1, 1, W, 1, D//2] - q2 = (q2 * cos_w) + (_rotate_half(q2) * sin_w) - return torch.cat([q1, q2], dim=-1).view(q.shape[0], -1, q.shape[3], q.shape[4]) - - -class TestFusedRoPE(common_utils.TestCase): - def setUp(self): - super().setUp() - self.batch_size = 2 - self.head_num = 64 - self.seq_length = [2048, 4096] - self.hidden_size = [128, 256] - self.rotary_percent = [0.5, 1.0] - self.dtype = [torch.float32, torch.bfloat16, torch.float16] - self.transpose = [None, (0, 1), (2, 3)] - self.transpose_output_memory = [False, True] - self.loss_func = [self._overlapping_grad, self._non_overlapping_grad] - self.cached = [False, True] - self.device = torch.cuda.current_device() - # for 2D RoPE - self.img_h = [32, 64] - self.img_w = [32, 64] - - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - def _overlapping_grad(self, output) -> torch.Tensor: - return output.sum() * 2 - - def _non_overlapping_grad(self, output) -> torch.Tensor: - t = torch.ones_like(output) - return torch.sum(output * t) - - def test_forward_backward(self): - for ( - dtype, - seq_length, - hidden_size, - rotary_percent, - transpose, - transpose_output_memory, - loss_func, - cached, - ) in itertools.product( - self.dtype, - self.seq_length, - self.hidden_size, - self.rotary_percent, - self.transpose, - self.transpose_output_memory, - 
self.loss_func, - self.cached, - ): - t = torch.rand( - (seq_length, self.batch_size, self.head_num, hidden_size), - dtype=dtype, - device=self.device, - ) - if transpose: - t = t.transpose(*transpose).contiguous().transpose(*transpose) - t.requires_grad = True - - emb = torch.rand( - (seq_length, 1, 1, int(hidden_size * rotary_percent)), - dtype=torch.float32, - device=self.device, - ) - - # unfused - output_unfused = apply_rotary_pos_emb(t, emb) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - if cached: - cos, sin = emb.cos(), emb.sin() - output_fused = fused_apply_rotary_pos_emb_cached( - t, cos, sin, transpose_output_memory=transpose_output_memory - ) - else: - output_fused = fused_apply_rotary_pos_emb( - t, emb, transpose_output_memory=transpose_output_memory - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None - - self.assertEqual( - output_unfused, - output_fused, - msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, " - f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}", - ) - self.assertEqual( - grad_unfused, - grad_fused, - msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, " - f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}", - ) - assert output_fused.transpose(0, 1).is_contiguous() is transpose_output_memory - - def test_thd_forward_backward(self): - cu_seqlens = torch.tensor( - [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048], - dtype=torch.int32, - device=self.device, - ) - for ( - dtype, - hidden_size, - rotary_percent, - transpose, - loss_func, - ) in itertools.product( - self.dtype, - self.hidden_size, - self.rotary_percent, - [None, [1, 2]], - self.loss_func, - ): - t = torch.rand( - (cu_seqlens[-1], self.head_num, hidden_size), - dtype=dtype, - device=self.device, - ) - if transpose: - t = t.transpose(*transpose).contiguous().transpose(*transpose) - t.requires_grad = True - - emb = torch.rand( - (cu_seqlens[-1], 1, 1, int(hidden_size * rotary_percent)), - dtype=torch.float32, - device=self.device, - ) - - # unfused - output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = fused_apply_rotary_pos_emb_thd( - t, - cu_seqlens, - emb, - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None - - self.assertEqual( - output_unfused, - output_fused, - msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, " - f"{transpose=}, loss_func={loss_func.__name__}", - ) - self.assertEqual( - grad_unfused, - grad_fused, - msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, " - f"{transpose=}, loss_func={loss_func.__name__}", - ) - - def test_2d_forward_backward(self): - for ( - dtype, - img_h, - img_w, - hidden_size, - transpose, - loss_func, - margin, - ) in itertools.product( - self.dtype, - self.img_h, - self.img_w, - self.hidden_size, - self.transpose, - self.loss_func, - [0, 3], - ): - t = torch.rand( - (self.batch_size, img_h * img_w, self.head_num, hidden_size), - dtype=dtype, - device=self.device, - ) - if transpose: - t = t.transpose(*transpose).contiguous().transpose(*transpose) - t.requires_grad = True - - emb_h = torch.rand( - (1, img_h + margin, 1, hidden_size // 2), - 
dtype=torch.float32, - device=self.device, - ) - cos_h, sin_h = emb_h.cos().to(dtype), emb_h.sin().to(dtype) - - emb_w = torch.rand( - (1, img_w + margin, 1, hidden_size // 2), - dtype=torch.float32, - device=self.device, - ) - cos_w, sin_w = emb_w.cos().to(dtype), emb_w.sin().to(dtype) - - # unfused - output_unfused = apply_rotary_pos_emb_2d(t, img_h, img_w, cos_h, sin_h, cos_w, sin_w) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = fused_apply_rotary_pos_emb_2d( - t, img_h, img_w, cos_h, sin_h, cos_w, sin_w - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None - - self.assertEqual( - output_unfused, - output_fused, - msg=f"{dtype=}, {img_h=}, {img_w=}, {hidden_size=}, " - f"{transpose=}, loss_func={loss_func.__name__}", - ) - self.assertEqual( - grad_unfused, - grad_fused, - msg=f"{dtype=}, {img_h=}, {img_w=}, {hidden_size=}, " - f"{transpose=}, loss_func={loss_func.__name__}", - ) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_fused_softmax.py b/tests/L0/run_transformer/test_fused_softmax.py deleted file mode 100644 index ba4add713..000000000 --- a/tests/L0/run_transformer/test_fused_softmax.py +++ /dev/null @@ -1,398 +0,0 @@ -"""Test for fused softmax functions. - -Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py -""" # NOQA - -import itertools - -import torch -from torch.testing._internal import common_utils - -from apex.transformer import AttnMaskType -from apex.transformer.functional import FusedScaleMaskSoftmax - - -def attention_mask_func(attention_scores, attention_mask): - return attention_scores.masked_fill(attention_mask, -10000.0) - - -def forward_torch_softmax(input, mask, scale): - input = input * scale - mask_output = attention_mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - all_k_masked = mask.all(axis=-1) - zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None] - probs = probs * zero_attention_mask - return probs - - -autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) - - -class TestFusedScaleMaskSoftmax(common_utils.TestCase): - def _setup_fused_softmax( - self, - input_in_fp16, - input_in_bf16, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - ): - fused_fn = FusedScaleMaskSoftmax( - input_in_fp16=input_in_fp16, - input_in_bf16=input_in_bf16, - mask_func=attention_mask_func, - scale=scale, - softmax_in_fp32=softmax_in_fp32, - attn_mask_type=attn_mask_type, - scaled_masked_softmax_fusion=True, - ) - torch_fn = FusedScaleMaskSoftmax( - input_in_fp16=input_in_fp16, - input_in_bf16=input_in_bf16, - mask_func=attention_mask_func, - scale=scale, - softmax_in_fp32=softmax_in_fp32, - attn_mask_type=attn_mask_type, - scaled_masked_softmax_fusion=False, - ) - return fused_fn, torch_fn - - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - def test_fused_scale_mask_softmax(self): - """ - attention_scores.shape = [4, 12, 24, 24] - mask.shape = [4, 1, 24, 24] - """ - for dtype, scale, softmax_in_fp32, shape in itertools.product( - (torch.half, torch.bfloat16), - (None, 2.0), - (False, True), - ((4, 12, 24, 24), (32, 12, 4, 214)), - ): - msg = f"{dtype}-{scale}-{softmax_in_fp32}" - input_in_fp16 
= dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - if not (scale is None or softmax_in_fp32): - with self.assertRaises(RuntimeError, msg=msg): - self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - return - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - - attention_scores_0 = ( - torch.randn(shape).to(device="cuda", dtype=dtype).requires_grad_(True) - ) - with torch.no_grad(): - attention_scores_1 = attention_scores_0.clone().requires_grad_(True) - mask_shape = (shape[0],) + (1,) + shape[2:] - mask = torch.randint(0, 2, mask_shape, device="cuda").bool() - expected = fused_fn(attention_scores_0, mask) - actual = torch_fn(attention_scores_1, mask) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_autocast_fused_scale_mask_softmax(self): - for dtype in autocast_dtypes: - msg = f"dtype: {dtype}" - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding - ) - - attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) - with torch.no_grad(): - attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True) - mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda() - - expected = torch_fn(attention_scores_1, mask) - with torch.amp.autocast("cuda", dtype=dtype): - actual = fused_fn(attention_scores_0, mask) - self.assertEqual(actual.dtype, dtype, msg=msg) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_fused_scale_softmax(self): - """ - attention_scores.shape = [4, 12, 24, 24] - mask = None - """ - for dtype, scale, softmax_in_fp32, shape in itertools.product( - (torch.half, torch.bfloat16), - (None, 2.0), - (False, True), - ((4, 12, 24, 24), (32, 12, 4, 214)), - ): - msg = f"{dtype}-{scale}-{softmax_in_fp32}" - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - if not (scale is None or softmax_in_fp32): - with self.assertRaises(RuntimeError, msg=msg): - self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - return - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - - attention_scores_0 = ( - torch.randn(shape).to(device="cuda", dtype=dtype).requires_grad_(True) - ) - with torch.no_grad(): - attention_scores_1 = attention_scores_0.clone().requires_grad_(True) - mask = None - - expected = fused_fn(attention_scores_0, mask) - actual = torch_fn(attention_scores_1, mask) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_autocast_fused_scale_softmax(self): - for dtype in autocast_dtypes: - msg = f"dtype: {dtype}" - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding - ) - - attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) - 
with torch.no_grad(): - attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True) - mask = None - - expected = torch_fn(attention_scores_1, mask) - with torch.amp.autocast("cuda", dtype=dtype): - actual = fused_fn(attention_scores_0, mask) - self.assertEqual(actual.dtype, dtype, msg=msg) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_fused_upper_triangle_mask_softmax(self): - """ - attn_weights.shape: [4, 12, 24, 24] - total_mask.shape: [4, 1, 24, 24] - - total_mask[0, 0], a 24x24 matrix is like a lower triangular matrix, but - upper elements are True and lower elements and diagonal are False. - """ - for dtype, scale, softmax_in_fp32 in itertools.product( - (torch.half, torch.bfloat16), - (None, 2.0), - (False, True), - ): - msg = f"{dtype}-{scale}-{softmax_in_fp32}" - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - if not (scale is None or softmax_in_fp32): - with self.assertRaises(RuntimeError, msg=msg): - self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.causal, - ) - return - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.causal, - ) - - attn_weights_0 = ( - torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True) - ) - with torch.no_grad(): - attn_weights_1 = attn_weights_0.clone().requires_grad_(True) - total_mask = ~(torch.tril(torch.randn((24, 24), device="cuda")).bool()).unsqueeze( - 0 - ).unsqueeze(0) - total_mask = total_mask.repeat((4, 1, 1, 1)) - expected = fused_fn(attn_weights_0, total_mask) - actual = torch_fn(attn_weights_1, total_mask) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.randn_like(actual) - with torch.no_grad(): - g1 = g0.clone() - actual.backward(g0) - expected.backward(g1) - - def test_autocast_fused_upper_triangle_mask_softmax(self): - for dtype in autocast_dtypes: - msg = f"dtype: {dtype}" - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal - ) - - attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) - with torch.no_grad(): - attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True) - total_mask = ~(torch.tril(torch.randn((24, 24), device="cuda")).bool()).unsqueeze( - 0 - ).unsqueeze(0) - - with torch.amp.autocast("cuda", dtype=dtype): - actual = fused_fn(attn_weights_0, total_mask) - self.assertEqual(actual.dtype, dtype, msg=msg) - expected = torch_fn(attn_weights_1, total_mask) - self.assertEqual(actual, expected, msg=msg) - - g0 = torch.randn_like(actual) - with torch.no_grad(): - g1 = g0.clone() - actual.backward(g0) - expected.backward(g1) - - -class TestGenericFusedSoftmaxKernel(common_utils.TestCase): - def setUp(self): - super().setUp() - self.batch = 2 - self.attn = 16 - self.scale_t = 1.0 - self.dtype = torch.float16 - self.device = torch.cuda.current_device() - self.thresh = {"atol": 1e-3, "rtol": 1e-3} - - qlen = [1, 2] - klen = [1, 2, 3, 4, 5, 8, 10, 11, 13, 128, 256, 1200, 1234] - available_cuda_mem = torch.cuda.memory.mem_get_info(self.device)[0] / (1024**3) - if available_cuda_mem > 40: - qlen.extend([1234, 2322, 2348]) - klen.extend([2048, 3123, 4096, 4128, 7234, 8192]) - - self.q_k_lens = itertools.product(qlen, 
klen) - - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - def test_forward(self, allmasked: bool = False): - import generic_scaled_masked_softmax_cuda - - for qlen, klen in self.q_k_lens: - inputs = torch.normal( - 0, - 2, - (self.batch, self.attn, qlen, klen), - dtype=self.dtype, - device=self.device, - ) - masks = ( - torch.randint( - 0, - 2, - (self.batch, 1, qlen, klen), - dtype=torch.bool, - device=self.device, - ) - if not allmasked - else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device) - ) - softmax_results = generic_scaled_masked_softmax_cuda.forward( - inputs, masks, self.scale_t - ) - softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t) - self.assertEqual( - softmax_results_torch.to(self.dtype), - softmax_results, - **self.thresh, - msg=f"(q, k) = ({qlen, klen})", - ) - - def test_backward(self, allmasked: bool = False): - import generic_scaled_masked_softmax_cuda - - prev_thresh = self.thresh - self.thresh = {"atol": 1.5e-1, "rtol": 5e-3} - for qlen, klen in self.q_k_lens: - inputs = torch.normal( - 0, - 2, - (self.batch, self.attn, qlen, klen), - dtype=self.dtype, - device=self.device, - ) - backward = torch.rand_like(inputs, dtype=torch.float16, device=self.device) - masks = ( - torch.randint( - 0, - 2, - (self.batch, 1, qlen, klen), - dtype=torch.bool, - device=self.device, - ) - if not allmasked - else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device) - ) - softmax_results = generic_scaled_masked_softmax_cuda.forward( - inputs, masks, self.scale_t - ) - back_grad = generic_scaled_masked_softmax_cuda.backward( - backward, softmax_results, self.scale_t - ) - inputs.requires_grad = True - softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t) - softmax_results_torch.backward(backward) - self.assertEqual(back_grad, inputs.grad, **self.thresh, msg=f"(q, k) = ({qlen, klen})") - self.thresh = prev_thresh - - def test_allmasked(self): - self.test_forward(True) - - def test_allmask_backward(self): - self.test_backward(True) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_gpt_minimal.py b/tests/L0/run_transformer/test_gpt_minimal.py deleted file mode 100644 index a8f80b509..000000000 --- a/tests/L0/run_transformer/test_gpt_minimal.py +++ /dev/null @@ -1,238 +0,0 @@ -from functools import partial -from typing import List -import time - -import torch - -import unittest - -from apex.transformer._ucc_util import HAS_UCC -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, - unwrap_model, - setup_microbatch_calculator, - get_ltor_masks_and_position_ids, -) -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, - build_model, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.standalone_gpt import gpt_model_provider -from apex.transformer.testing import global_vars - -from apex.transformer.testing.distributed_test_base import ( - UccDistributedTestBase, - NcclDistributedTestBase, -) - -from torch.testing._internal import common_utils - - -class GptTestBase: - def _download_fancy_data(self): - text = """ 
- An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum. - """ - text = text * 1024 - encoded = text.encode("ascii", "replace") - ints = [int(encoded[i]) for i in range(len(encoded))] - return torch.tensor(ints) - - # build a batch given sequence_len and batch size - def _generate_fancy_data_labels(self, sequence_len, batch_size): - temps = list() - for i in range(batch_size): - if self.inds is None or self.data_idx >= len(self.inds): - # hack as use of RNG will fall out of sync due to pipelines being different - model_parallel_cuda_manual_seed(self.MANUAL_SEED) - self.inds = torch.randperm(effective_length, device="cuda") - self.MANUAL_SEED += 1 - self.data_idx = 0 - data_idx_ = self.data_idx - offset = self.inds[data_idx_] - self.data_idx += 1 - curr = fancy_data[offset : offset + sequence_len + 1].clone().detach() - temps.append(curr) - temp = torch.stack(temps, dim=0).cuda() - return temp - - def _get_batch(self, int_tensors: List[torch.Tensor]): - data = int_tensors[0] - # Unpack. - tokens_ = data.long() - labels = tokens_[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() - # Get the masks and position ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, - self.N_VOCAB, # tokenizer.eod, - False, # args.reset_position_ids, - False, # args.reset_attention_mask, - False, # args.eod_mask_loss, - ) - return tokens, labels, loss_mask, attention_mask, position_ids - - # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75 - def _loss_func(self, loss_mask, output_tensor): - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. 
- averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {"lm loss": averaged_loss[0]} - - # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86 - def _fwd_step_func(self, batch, model): - """Forward step.""" - tokens, labels, loss_mask, attention_mask, position_ids = self._get_batch(batch) - output_tensor = model(tokens, position_ids, attention_mask, labels=labels) - return output_tensor, partial(self._loss_func, loss_mask) - - def _train(self, model, optim, pipeline_model_parallel_size, async_comm): - args = global_vars.get_args() - fwd_bwd_func = forward_backward_pipelining_without_interleaving - - tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) - runtime = 0 - # training loop - for i in range(3): - since = time.time() - if torch.distributed.get_rank() == 0: - print("begin iter", i) - batch = [ - self._generate_fancy_data_labels(args.seq_length, args.global_batch_size) - for _ in range(pipeline_model_parallel_size) - ] - if torch.distributed.get_rank() == 0: - print("finished making batch...") - optim.zero_grad() - fwd_bwd_func( - self._fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=tensor_shape, - async_comm=async_comm, - sequence_parallel_enabled=args.sequence_parallel, - ) - if torch.distributed.get_rank() == 0: - print("finished forward step") - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if ( - parallel_state.get_tensor_model_parallel_world_size() > 1 - and global_vars.get_args().sequence_parallel - ): - for model_module in model: - unwrapped_model = unwrap_model(model_module) - for param in unwrapped_model.parameters(): - if getattr(param, "sequence_parallel_enabled", False): - grad = param.grad - torch.distributed.all_reduce( - grad, - group=parallel_state.get_tensor_model_parallel_group(), - ) - optim.step() - if torch.distributed.get_rank() == 0: - print("finished iter", i) - runtime += time.time() - since - return runtime / 3.0 - - @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus") - def test_gpt(self): - self.MANUAL_SEED = 42 - self.inds = None - self.data_idx = 0 - self.N_VOCAB = 128 - init = True - - tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size >= 4 else 1 - pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size - - override_args = { - "micro_batch_size": 2, - "num_layers": 16, - "hidden_size": 256, - "num_attention_heads": 8, - "max_position_embeddings": 512, - "seq_length": 512, - "global_batch_size": 128, - "pipeline_model_parallel_size": pipeline_model_parallel_size, - "tensor_model_parallel_size": tensor_model_parallel_size, - "world_size": self.world_size, - "rank": self.rank, - } - - global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True) - args = global_vars.get_args() - - for async_comm in (False,) if args.sequence_parallel else (False, True): - global fancy_data - global effective_length - - if init: - init = False - - fancy_data = self._download_fancy_data() - args = global_vars.get_args() - args.model_type = ModelType.encoder_or_decoder - effective_length = fancy_data.size(0) // args.seq_length - effective_length = fancy_data.size(0) - args.seq_length - - args.padded_vocab_size = 128 - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - args.data_parallel_size, - ) - - 
print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE") - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=args.tensor_model_parallel_size, - pipeline_model_parallel_size_=args.pipeline_model_parallel_size, - default_backend="nccl", - p2p_backend=self.DISTRIBUTED_BACKEND, - ) - - model_parallel_cuda_manual_seed(0) - model = build_model( - gpt_model_provider, - wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1, - virtual_pipeline_model_parallel_size=None, - cpu_offload=args.cpu_offload, - ) - assert isinstance(model, list), model - _param_groups = _get_params_for_weight_decay_optimization(model) - optim = torch.optim.Adam(_param_groups) - runtime = self._train(model, optim, args.pipeline_model_parallel_size, async_comm) - - parallel_state.destroy_model_parallel() - torch.cuda.synchronize() - - -class NcclGptTest(GptTestBase, NcclDistributedTestBase): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - -@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc") -class UccGptTest(GptTestBase, UccDistributedTestBase): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_layers.py b/tests/L0/run_transformer/test_layers.py deleted file mode 100644 index 719d351c3..000000000 --- a/tests/L0/run_transformer/test_layers.py +++ /dev/null @@ -1,575 +0,0 @@ -import logging -import unittest -import typing - -import torch -import torch.nn as nn -from torch.testing._internal import common_utils - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import layers -from apex.transformer.testing.commons import set_random_seed -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - -# N.B.(mkozuki): Disable TF32 matrix multiply. -# Matrices used in this test are so small that TF32 matmul -# can be less precise so that `self.assertEqual` raises. 
-torch.backends.cuda.matmul.allow_tf32 = False - - -class TensorParallelLayerTestBase: - BATCH_SIZE: int = 8 - SEQUENCE_LENGTH: int = 128 - VOCAB_SIZE: int = 1024 - HIDDEN_SIZE: int = 256 - INPUT_SIZE_COEFF: int = 256 - OUTPUT_SIZE_COEFF: int = 256 - SEED: int = 123456 - - @property - def tensor_shape(self) -> typing.Sequence[int]: - return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE] - - @torch.no_grad() - @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs") - def test_all_gather_parity(self) -> None: - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest( - "torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15" - ) - from torch.distributed.distributed_c10d import all_gather, _all_gather_base # NOQA - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() - cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}") - with torch.no_grad(): - tensor = tensor_model_parallel_rank * torch.ones( - self.tensor_shape, - dtype=torch.float32, - device=cur_tensor_model_device, - ) - numel = tensor.numel() - numel_gathered = tensor_model_parallel_world_size * numel - gathered = torch.empty( - torch.Size((numel_gathered,)), - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - chunks = [ - gathered[i * numel : (i + 1) * numel] - for i in range(tensor_model_parallel_world_size) - ] - all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()) - - gathered_for_base = torch.empty( - torch.Size((numel_gathered,)), - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - _all_gather_base( - gathered_for_base, - tensor, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - self.assertEqual(gathered, gathered_for_base, msg=msg) - parallel_state.destroy_model_parallel() - - @torch.no_grad() - @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs") - def test_reduce_scatter_parity(self) -> None: - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest( - "torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15" - ) - from torch.distributed.distributed_c10d import ( - reduce_scatter, - _reduce_scatter_base, - ) # NOQA - - for tensor_model_parallel_world_size in range(2, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() - cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}") - with torch.no_grad(): - input = torch.cat( - [ - i - * torch.ones( - self.tensor_shape, - dtype=torch.float32, - device=cur_tensor_model_device, - ) - for i in range(tensor_model_parallel_world_size) - ] - ) - input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)] - output = torch.empty( - self.tensor_shape, - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - reduce_scatter( - output, - input_list, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - 
output_for_base = torch.empty( - self.tensor_shape, - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - _reduce_scatter_base( - output_for_base, - input, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - self.assertEqual(output, output_for_base, msg=msg) - self.assertEqual(input, torch.cat(input_list), msg=msg) - parallel_state.destroy_model_parallel() - - def test_parallel_embedding(self) -> None: - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - set_random_seed(self.SEED + 1) - input_tensor = torch.randint( - 0, - self.VOCAB_SIZE, - ( - self.BATCH_SIZE, - self.SEQUENCE_LENGTH, - ), - device="cuda", - ) - loss_weight = torch.randn( - ( - self.BATCH_SIZE, - self.SEQUENCE_LENGTH, - self.HIDDEN_SIZE, - ), - device="cuda", - ) - - set_random_seed(self.SEED) - embedding_torch = nn.Embedding( - self.VOCAB_SIZE, - self.HIDDEN_SIZE, - ).cuda() - output_torch = embedding_torch(input_tensor) - loss_torch = torch.mul(output_torch, loss_weight).sum() - loss_torch.backward() - - # N.B.(mkozuki): With affine weight initialization on GPU, - # it's super difficult to keep the consistency with nn.Embedding. - # Thus, turning on `use_cpu_initialization`. - set_random_seed(self.SEED) - embedding_vocab_parallel = layers.VocabParallelEmbedding( - self.VOCAB_SIZE, - self.HIDDEN_SIZE, - init_method=nn.init.normal_, - use_cpu_initialization=True, - ).cuda() - output_vocab_parallel = embedding_vocab_parallel(input_tensor) - loss_vocab_parallel = torch.mul(output_vocab_parallel, loss_weight).sum() - loss_vocab_parallel.backward() - - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - self.assertEqual(output_torch, output_vocab_parallel, msg=msg) - self.assertEqual(loss_torch, loss_vocab_parallel, msg=msg) - - splitted_weight_torch = torch.split( - embedding_torch.weight.grad, - self.VOCAB_SIZE // tensor_model_parallel_world_size, - 0, - )[parallel_state.get_tensor_model_parallel_rank()] - self.assertEqual( - splitted_weight_torch, - embedding_vocab_parallel.weight.grad, - msg=msg, - ) - - parallel_state.destroy_model_parallel() - - def _affine_weight_init_test_impl(self, init_device: str, is_column_parallel: bool) -> None: - dim = int(not is_column_parallel) - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size - output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size - - weight_shape = ( - (self.OUTPUT_SIZE_COEFF, input_size) - if is_column_parallel - else (output_size, self.INPUT_SIZE_COEFF) - ) - weight = torch.empty(weight_shape) - set_random_seed(self.SEED) - - sharding_dim_size = ( - self.OUTPUT_SIZE_COEFF if is_column_parallel else self.INPUT_SIZE_COEFF - ) - - if init_device == "cpu": - layers._initialize_affine_weight_cpu( - weight, - output_size, - input_size, - sharding_dim_size, - dim, - nn.init.normal_, - params_dtype=torch.float32, - ) - else: - layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, dim) - # Target - 
set_random_seed(self.SEED) - if init_device == "cpu": - main_weight = torch.empty(output_size, input_size) - nn.init.normal_(main_weight) - curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[ - parallel_state.get_tensor_model_parallel_rank() - ] - else: - curr_weight = torch.empty(*weight_shape) - nn.init.normal_(curr_weight) - - self.assertEqual( - curr_weight, - weight, - msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}", - ) - parallel_state.destroy_model_parallel() - - def test_affine_weight_init_column_parallel_cpu(self) -> None: - self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True) - - def test_affine_weight_init_column_parallel_gpu(self) -> None: - self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True) - - def test_affine_weight_init_row_parallel_cpu(self) -> None: - self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False) - - def test_affine_weight_init_row_parallel_gpu(self) -> None: - self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False) - - def test_row_parallel_linear(self) -> None: - self._row_parallel_linear_test_impl(False, False, False) - - def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None: - self._row_parallel_linear_test_impl(True, False, False) - - def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None: - self._row_parallel_linear_test_impl(True, True, False) - - # fails on native ucc and torch ucc: ucc does not support reduce scatter - @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs") - def test_row_parallel_linear_sequence_parallel(self) -> None: - self._row_parallel_linear_test_impl(False, False, True) - - # TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl` - # Note that `input_is_parallel` is unique to `RowParallelLinear` which could make the merge complicated. - def _row_parallel_linear_test_impl( - self, - gradient_accumulation_fusion: bool, - accumulation_in_fp16: bool, - sequence_parallel_enabled: bool, - ) -> None: - tensor_shape = ( - self.SEQUENCE_LENGTH, - self.BATCH_SIZE, - self.HIDDEN_SIZE, - ) - for tensor_model_parallel_world_size in range( - 1 + int(sequence_parallel_enabled), self.world_size + 1 - ): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - set_random_seed(self.SEED) - - linear = layers.RowParallelLinear( - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - keep_master_weight_for_test=True, - params_dtype=torch.float32, - use_cpu_initialization=True, - gradient_accumulation_fusion=gradient_accumulation_fusion, - accumulation_in_fp16=accumulation_in_fp16, - sequence_parallel_enabled=sequence_parallel_enabled, - # n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True` - # by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\ - # db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204 - input_is_parallel=True, - ).cuda() - if accumulation_in_fp16: - linear = linear.half() - # Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled. 
- if gradient_accumulation_fusion: - with torch.no_grad(): - linear.weight.main_grad = torch.zeros_like(linear.weight) - - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - - with torch.no_grad(): - orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda") - orig_loss_weight = torch.randn(tensor_shape, device="cuda") - input_tensor = orig_input_tensor.chunk( - chunks=tensor_model_parallel_world_size, - dim=2, - )[parallel_state.get_tensor_model_parallel_rank()].contiguous() - if sequence_parallel_enabled: - loss_weight = orig_loss_weight.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()] - else: - loss_weight = orig_loss_weight - if accumulation_in_fp16: - orig_input_tensor = orig_input_tensor.half() - input_tensor = input_tensor.half() - loss_weight = loss_weight.half() - input_tensor.requires_grad_() - output, _ = linear(input_tensor) - loss = torch.mul(output, loss_weight).sum() - loss.backward() - self.assertIsNotNone(input_tensor.grad, msg=msg) - - ref_linear = nn.Linear( - in_features=self.HIDDEN_SIZE, - out_features=self.HIDDEN_SIZE, - bias=False, - device="cuda", - ) - with torch.no_grad(): - dldy = orig_loss_weight.clone() - x = orig_input_tensor.clone() - ref_linear.weight.copy_(linear.master_weight) - if accumulation_in_fp16: - ref_linear = ref_linear.half() - x.requires_grad_() - expected_output = ref_linear(x) - expected_loss = torch.mul(expected_output, dldy).sum() - expected_loss.backward() - - if not accumulation_in_fp16: - if sequence_parallel_enabled: - self.assertEqual( - x=output, - y=expected_output.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - msg=msg, - ) - else: - self.assertEqual( - x=output, - y=expected_output, - msg=msg, - ) - - grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad" - # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel. - if tensor_model_parallel_world_size == 1: - self.assertEqual( - x=getattr(linear.weight, grad_attr_name), - y=ref_linear.weight.grad.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - msg=msg, - ) - - parallel_state.destroy_model_parallel() - - def test_column_parallel_linear(self): - self._column_parallel_linear_test_impl(False, False, False, False) - - def test_column_parallel_linear_async(self): - self._column_parallel_linear_test_impl(True, False, False, False) - - def test_column_parallel_linear_gradient_accumulation_fusion(self): - self._column_parallel_linear_test_impl(False, True, False, False) - - def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self): - self._column_parallel_linear_test_impl(False, True, True, False) - - def test_column_parallel_linear_sequence_parallel(self): - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest("Backward's reduce_scatter fails. 
as of 2022/06/15") - self._column_parallel_linear_test_impl(False, False, False, True) - - @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs") - def test_column_parallel_linear_exception(self): - with self.assertRaisesRegex( - RuntimeError, - "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.", - ): - self._column_parallel_linear_test_impl(True, False, False, True) - - def _column_parallel_linear_test_impl( - self, - async_tensor_model_parallel_allreduce: bool, - gradient_accumulation_fusion: bool, - accumulation_in_fp16: bool, - sequence_parallel_enabled: bool, - ): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if async_tensor_model_parallel_allreduce and sequence_parallel_enabled: - if tensor_model_parallel_world_size == 1: - continue - if self.world_size % tensor_model_parallel_world_size: - continue - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - - input_tensor_shape = self.tensor_shape - expected_output_shape = self.tensor_shape - # When sequence parallel, `gather_output` is disabled, i.e., - # output of matmul isn't gathered in dimension of feature/hidden (last dim). - if sequence_parallel_enabled: - expected_output_shape[-1] //= tensor_model_parallel_world_size - - # tensor's shape is [sequence length, batch size, hidden size] - set_random_seed(self.SEED) - linear = layers.ColumnParallelLinear( - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - bias=False, - keep_master_weight_for_test=True, - params_dtype=torch.float32, - use_cpu_initialization=True, - gather_output=not sequence_parallel_enabled, - no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce, - gradient_accumulation_fusion=gradient_accumulation_fusion, - accumulation_in_fp16=accumulation_in_fp16, - sequence_parallel_enabled=sequence_parallel_enabled, - ).cuda() - if accumulation_in_fp16: - linear = linear.half() - - # Simulate the situation where fusion of weight grad calculation and gradient accumulation happens. 
- if gradient_accumulation_fusion: - with torch.no_grad(): - linear.weight.main_grad = torch.zeros_like(linear.weight) - - orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True) - if accumulation_in_fp16: - orig_input_tensor = orig_input_tensor.half() - if sequence_parallel_enabled: - input_tensor = list( - orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0) - )[parallel_state.get_tensor_model_parallel_rank()] - else: - input_tensor = orig_input_tensor - output, _ = linear(input_tensor) - # The order of dimension is expected to be (sequence, batch, hidden) - self.assertEqual(output.shape, expected_output_shape, msg=msg) - - orig_loss_weight = torch.randn(input_tensor_shape, device="cuda") - if accumulation_in_fp16: - orig_loss_weight = orig_loss_weight.half() - if sequence_parallel_enabled: - loss_weight = orig_loss_weight.chunk( - tensor_model_parallel_world_size, - dim=2, - )[parallel_state.get_tensor_model_parallel_rank()] - else: - loss_weight = orig_loss_weight - loss = torch.mul(output, loss_weight).sum() - loss.backward() - - with torch.no_grad(): - dldy = orig_loss_weight.clone() - x = orig_input_tensor.clone() - ref_linear = nn.Linear( - in_features=self.HIDDEN_SIZE, - out_features=self.HIDDEN_SIZE, - bias=False, - device="cuda", - ) - if accumulation_in_fp16: - ref_linear = ref_linear.half() - # NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set. - ref_linear.weight.copy_(linear.master_weight) - x.requires_grad_() - expected_output = ref_linear(x) - if sequence_parallel_enabled: - chunk = expected_output.chunk( - tensor_model_parallel_world_size, - dim=2, - )[parallel_state.get_tensor_model_parallel_rank()] - self.assertEqual( - x=output, - y=chunk, - msg=msg, - ) - else: - self.assertEqual( - x=output, - y=expected_output, - msg=msg, - ) - - expected_loss = torch.mul(expected_output, dldy).sum() - expected_loss.backward() - grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad" - # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel. 
- if tensor_model_parallel_world_size == 1: - self.assertEqual( - x=getattr(linear.weight, grad_attr_name), - y=ref_linear.weight.grad.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - msg=msg, - ) - - parallel_state.destroy_model_parallel() - - -class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase): - pass - - -class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_mapping.py b/tests/L0/run_transformer/test_mapping.py deleted file mode 100644 index 537bb1094..000000000 --- a/tests/L0/run_transformer/test_mapping.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import mappings -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - -class MappingTestBase: - def test_reduce(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}") - expected = torch.full( - (10, 10, 10, 10), - 50 * tensor_model_paralell_world_size, - device=f"cuda:{self.rank}", - ) - self.assertTrue( - torch.equal(mappings._reduce(t), expected), - msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", - ) - parallel_state.destroy_model_parallel() - - def test_split(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - - tensors = [torch.randn(10, 1) for _ in range(tensor_model_paralell_world_size)] - x = torch.cat(tensors, 1) - out = mappings._split_along_last_dim(x) - self.assertTrue( - torch.equal(out, tensors[parallel_state.get_tensor_model_parallel_rank()]), - msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", - ) - parallel_state.destroy_model_parallel() - - def test_gather(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - device = f"cuda:{self.rank}" - gathered = mappings._gather_along_last_dim( - torch.tensor([parallel_state.get_tensor_model_parallel_rank()], device=device) - ) - expected = torch.tensor( - [rank for rank in range(tensor_model_paralell_world_size)], - device=device, - ) - self.assertTrue( - torch.equal(gathered, expected), - msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", - ) - parallel_state.destroy_model_parallel() - - -class NcclMappingTest(MappingTestBase, NcclDistributedTestBase): - pass - - -class UccMappingTest(MappingTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git 
a/tests/L0/run_transformer/test_microbatches.py b/tests/L0/run_transformer/test_microbatches.py deleted file mode 100644 index 3fb904570..000000000 --- a/tests/L0/run_transformer/test_microbatches.py +++ /dev/null @@ -1,95 +0,0 @@ -import logging -from typing import List, Optional - -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_micro_batch_size, - get_num_microbatches, - get_current_global_batch_size, - update_num_microbatches, -) -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class MicrobatchCalculatorTestBase: - GLOBAL_BATCH_SIZE: int = 1024 - MICRO_BATCH_SIZE: int = 1 - - def _test(self, rampup_batch_size: Optional[List[int]]) -> None: - for data_parallel_size in range(1, self.world_size + 1): - expected_global_batch_size = self.GLOBAL_BATCH_SIZE - expected_micro_batch_size = self.MICRO_BATCH_SIZE - if rampup_batch_size: - expected_global_batch_size = rampup_batch_size[0] - num_consumed_samples = 0 - step_of_global_batch_size = rampup_batch_size[1] - threshold = rampup_batch_size[2] - - if data_parallel_size > 1 and data_parallel_size % 2 != 0: - continue - if self.world_size % data_parallel_size != 0: - continue - msg = f"data_parallel_size: {data_parallel_size}" - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=self.world_size // data_parallel_size, - pipeline_model_parallel_size_=1, - ) - self.assertEqual( - data_parallel_size, - parallel_state.get_data_parallel_world_size(), - msg=msg, - ) - - _reconfigure_microbatch_calculator( - self.rank, - rampup_batch_size, - self.GLOBAL_BATCH_SIZE, - self.MICRO_BATCH_SIZE, - data_parallel_size, - ) - - self.assertEqual(get_micro_batch_size(), expected_micro_batch_size, msg=msg) - self.assertEqual( - get_num_microbatches(), - expected_global_batch_size / expected_micro_batch_size / data_parallel_size, - msg=msg, - ) - current_global_batch_size = get_current_global_batch_size() - self.assertEqual(current_global_batch_size, expected_global_batch_size, msg=msg) - - # Make sure `global_batch_size` equals to the final global batch size after - # certain number of updates. 
- if rampup_batch_size: - update_num_microbatches(current_global_batch_size) - for i in range(100): - current_global_batch_size = get_current_global_batch_size() - update_num_microbatches(current_global_batch_size) - current_global_batch_size = get_current_global_batch_size() - self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE, msg=msg) - parallel_state.destroy_model_parallel() - - def test_constant_microbatch_calculator(self): - self._test(rampup_batch_size=None) - - def test_dynamic_microbatch_calculator(self): - self._test(rampup_batch_size=[256, 128, 500]) - - -class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): - pass - - -class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_p2p_comm.py b/tests/L0/run_transformer/test_p2p_comm.py deleted file mode 100644 index 861c2144d..000000000 --- a/tests/L0/run_transformer/test_p2p_comm.py +++ /dev/null @@ -1,129 +0,0 @@ -import logging -import unittest - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.DEBUG) - - -# [P2P Ops Involved in Pipeline Model Parallel forward/backward] -# **forward_backward_pipelining_without_interleaving** -# - send_forward / recv_forward -# - send_backward / recv_backward -# - send_forward_recv_backward -# - send_backward_recv_forward -# **forward_backward_pipelining_with_interleaving** -# - send_backward_recv_backward -# - recv_backward -# - recv_forward -# - send_forward_backward_recv_forward_backward -# - send_forward_recv_forward -class P2PCommTestBase: - numel = 4 - shape = (2, 2) - dtype = torch.float32 - - @property - def world_size(self): - return min(2, torch.cuda.device_count()) - - def _init_model_parallel(self): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=1, - pipeline_model_parallel_size_=self.world_size, - virtual_pipeline_model_parallel_size_=None, - ) - - def create_tensor(self, value: int = None): - return ( - torch.tensor([value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype) - ) - - # Brief: Simulate warm-up. - # Brief: test `recv_forward` & `send_forward`. - def test_no_interleaving_warmup(self): - self.assertEqual(self.world_size, 2) - self._init_model_parallel() - input_tensor = None - if parallel_state.is_pipeline_first_stage(): - tensor = self.create_tensor(self.rank) - print(tensor) - p2p_communication.send_forward( - output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype - ) - else: - input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype) - - if parallel_state.is_pipeline_first_stage(): - self.assertIsNone(input_tensor) - else: - expected_input_tensor = self.create_tensor(self.rank - 1) - self.assertEqual(input_tensor, expected_input_tensor) - - # Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`. 
- def test_send_forward_recv_forward(self): - self._init_model_parallel() - prev_tensor = None - tensor = self.create_tensor(self.rank) - if parallel_state.is_pipeline_first_stage(): - p2p_communication.send_forward( - output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype - ) - elif parallel_state.is_pipeline_last_stage(): - prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype) - else: - prev_tensor = p2p_communication.send_forward_recv_forward( - output_tensor=tensor, - recv_prev=True, - tensor_shape=self.shape, - dtype=self.dtype, - ) - - if parallel_state.is_pipeline_first_stage(): - self.assertIsNone(prev_tensor) - else: - expected_prev_tensor = self.create_tensor(self.rank - 1) - self.assertEqual(prev_tensor, expected_prev_tensor) - - # Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`. - def test_send_backward_recv_backward(self): - self._init_model_parallel() - tensor = self.create_tensor(self.rank) - - next_tensor = None - if parallel_state.is_pipeline_first_stage(): - next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype) - elif parallel_state.is_pipeline_last_stage(): - p2p_communication.send_backward( - input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype - ) - else: - next_tensor = p2p_communication.send_backward_recv_backward( - input_tensor_grad=tensor, - recv_next=True, - tensor_shape=self.shape, - dtype=self.dtype, - ) - - if parallel_state.is_pipeline_last_stage(): - self.assertIsNone(next_tensor) - else: - expected_next_tensor = self.create_tensor(self.rank + 1) - self.assertEqual(next_tensor, expected_next_tensor) - - -# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo. -@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >= 2 GPUs") -class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_parallel_state.py b/tests/L0/run_transformer/test_parallel_state.py deleted file mode 100644 index 1d73739cf..000000000 --- a/tests/L0/run_transformer/test_parallel_state.py +++ /dev/null @@ -1,183 +0,0 @@ -import logging -import os - -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -os.environ["BACKEND"] = "NCCL" -DATA_PARALLEL_WORLD_SIZE: int = 1 - - -def calc_expected_tensor_model_paralell_rank( - rank: int, - tensor_model_parallel_world_size: int, -) -> int: - return rank % tensor_model_parallel_world_size - - -class ParallelStateTestBase: - def test_initialize_model_parallel(self) -> None: - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - msg = f"tensor_model_parallel_world_siz: {tensor_model_parallel_world_size}" - if self.world_size % tensor_model_parallel_world_size: - continue - - pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_world_size - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - ) - self.assertEqual( - 
tensor_model_parallel_world_size, - parallel_state.get_tensor_model_parallel_world_size(), - msg=msg, - ) - expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank( - self.rank, tensor_model_parallel_world_size - ) - self.assertEqual( - expected_tensor_model_parallel_rank, - parallel_state.get_tensor_model_parallel_rank(), - msg=msg, - ) - - expected_tensor_model_parallel_src_rank = ( - self.rank // tensor_model_parallel_world_size - ) * tensor_model_parallel_world_size - self.assertEqual( - expected_tensor_model_parallel_src_rank, - parallel_state.get_tensor_model_parallel_src_rank(), - msg=msg, - ) - - parallel_state.destroy_model_parallel() - self.assertFalse(parallel_state.model_parallel_is_initialized(), msg=msg) - - def test_initialize_model_parallel_with_virtual_and_split(self) -> None: - if self.world_size < 4: - self.skipTest("requires >= 4 GPUs") - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - tensor_model_parallel_world_size = 1 + int(self.world_size > 4) - pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_world_size - virtual_pipeline_model_parallel_world_size = 2 - pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2 - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size, - pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, - ) - self.assertEqual( - calc_expected_tensor_model_paralell_rank(self.rank, tensor_model_parallel_world_size), - parallel_state.get_tensor_model_parallel_rank(), - ) - self.assertEqual( - pipeline_model_parallel_world_size, - parallel_state.get_pipeline_model_parallel_world_size(), - ) - self.assertEqual( - virtual_pipeline_model_parallel_world_size, - parallel_state.get_virtual_pipeline_model_parallel_world_size(), - ) - - expected_pipeline_rank = ( - self.rank - (self.rank % tensor_model_parallel_world_size) - ) % pipeline_model_parallel_world_size - self.assertEqual( - expected_pipeline_rank, - parallel_state.get_pipeline_model_parallel_rank(), - ) - # virtual pipeline model parallel rank is lazily set, i.e., right after the call of - # `initialize_model_parallel`, it's set to 0. 
- self.assertEqual( - 0, - parallel_state.get_virtual_pipeline_model_parallel_rank(), - ) - self.assertEqual( - pipeline_model_parallel_split_rank, - parallel_state.get_pipeline_model_parallel_split_rank(), - ) - - fake_split_rank = 77 - parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank) - self.assertEqual(fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()) - - # relative position embedding groups check - self.assertEqual( - expected_pipeline_rank < pipeline_model_parallel_split_rank, - parallel_state.is_rank_in_encoder_relative_position_embedding_group(), - ) - self.assertEqual( - expected_pipeline_rank >= pipeline_model_parallel_split_rank, - parallel_state.is_rank_in_decoder_relative_position_embedding_group(), - ) - - parallel_state.destroy_model_parallel() - - def test_initialize_model_parallel_decoder_only(self) -> None: - """Initialize model parallelism for decoder-only Transformers like GPT-3""" - - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - if self.world_size % tensor_model_parallel_world_size: - continue - - pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_world_size - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - pipeline_model_parallel_split_rank_=0, - ) - self.assertEqual( - tensor_model_parallel_world_size, - parallel_state.get_tensor_model_parallel_world_size(), - msg=msg, - ) - expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank( - self.rank, tensor_model_parallel_world_size - ) - self.assertEqual( - expected_tensor_model_parallel_rank, - parallel_state.get_tensor_model_parallel_rank(), - msg=msg, - ) - - expected_tensor_model_parallel_src_rank = ( - self.rank // tensor_model_parallel_world_size - ) * tensor_model_parallel_world_size - self.assertEqual( - expected_tensor_model_parallel_src_rank, - parallel_state.get_tensor_model_parallel_src_rank(), - msg=msg, - ) - - parallel_state.destroy_model_parallel() - self.assertFalse(parallel_state.model_parallel_is_initialized(), msg=msg) - - -class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): - pass - - -class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py deleted file mode 100644 index daf263e39..000000000 --- a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py +++ /dev/null @@ -1,891 +0,0 @@ -import contextlib -import logging -import itertools -import os -from datetime import datetime -from packaging.version import parse, Version -import re -from typing import Optional, Tuple, List -import unittest - -import torch -from torch.testing._internal import common_utils - -from apex._autocast_utils import _get_autocast_dtypes -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel import utils as pp_utils -from apex.transformer.pipeline_parallel.schedules.common import ( - FwdStepFunc, - build_model, - _get_params_for_weight_decay_optimization, -) -from 
apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase -from apex.transformer.testing.distributed_test_base import ( - HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, -) -from apex.transformer.testing import commons as testing_utils -from apex.transformer._ucc_util import HAS_UCC - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - -weight_coeff = 1024 - - -# Guard for https://github.com/pytorch/pytorch/pull/82450 -def get_nvidia_pytorch_version(): - ver = os.getenv("NVIDIA_PYTORCH_VERSION", "22.08") - if "master" in ver: - ver = datetime.today().strftime("%y.%m") - elif "update_for_" in ver: - ver = ver.replace("update_for_", "") - return ver - - -CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = False -ngc_container_2209, pytorch_113 = Version("22.09"), Version("1.13") -if parse(torch.__version__) >= pytorch_113: - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = True -elif parse(get_nvidia_pytorch_version()) >= ngc_container_2209: - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = True -else: - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = False - - -def get_init_weights_func(offset: int = 0): - @torch.no_grad() - def init_weights(m): - rank = parallel_state.get_pipeline_model_parallel_rank() - if isinstance(m, torch.nn.Linear): - m.weight.fill_((rank + offset + 1.0) / weight_coeff) - m.bias.fill_(1.0) - - return init_weights - - -def get_dtype_for_comparison(): - if torch.cuda.get_device_capability() >= (8, 0): - return torch.float64 - return torch.float32 - - -def get_target_loss_and_model( - global_batch_shape: tuple, hidden_size: int, total_layers: int -) -> Tuple[torch.Tensor, List[torch.Tensor]]: - model = [] - dtype = get_dtype_for_comparison() - data = torch.ones(global_batch_shape, dtype=dtype) - for i in range(total_layers): - w = torch.ones((hidden_size, hidden_size), dtype=dtype) * (i + 1.0) / weight_coeff - b = torch.ones(hidden_size, dtype=dtype) - - w.requires_grad_() - b.requires_grad_() - - # don't need to care about transpose semantics as all values are the same - data = torch.matmul(w, data) + b - model.append([w, b]) - - loss = data.sum() / global_batch_shape[0] - loss.backward() - - return loss, model - - -def _get_default_world_sizes_model_parallel_world_size( - pipeline_model_parallel_world_size: Optional[int] = None, -) -> Tuple[int, int, int]: - # TODO: revisit if we can fold this into the class for skip logic / avoid duplication - # of world size computation - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0) - - if pipeline_model_parallel_world_size is None: - pipeline_model_parallel_world_size = world_size // ( - tensor_model_parallel_world_size * data_parallel_size - ) - else: - data_parallel_size = world_size // ( - tensor_model_parallel_world_size * pipeline_model_parallel_world_size - ) - - return ( - tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size, - ) - - -class 
PipelineParallelForwardBackwardTestBase: - GLOBAL_BATCH_SIZE = 16 - MICRO_BATCH_SIZE = 2 - HIDDEN_SIZE = 32 - - deallocate_options = (True, False) - # If :obj:`None`, (torch.float32, torch.float16, torch.bfloat16) are dtype options on Ampere. - # You can limit the options by overriding the following `dtypes`. - dtypes = None - - def _forward_backward_test_impl( - self, - forward_only: bool, - fwd_bwd_func: FwdStepFunc, - pipeline_model_parallel_world_size: Optional[int], - virtual_pipeline_model_parallel_size: Optional[int], - async_comm: bool = False, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, - sync_batch_comm: bool = True, - ) -> None: - if fwd_bwd_func == _forward_backward_pipelining_with_interleaving: - self.assertIsNotNone(virtual_pipeline_model_parallel_size) - self.assertGreater(virtual_pipeline_model_parallel_size, 1) - dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes() - - for dtype, deallocate_pipeline_outputs in itertools.product( - dtype_options, - self.deallocate_options, - ): - grad_scaler = ( - torch.amp.GradScaler("cuda", init_scale=4.0) if dtype == torch.half else None - ) - - ( - tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size, - ) = _get_default_world_sizes_model_parallel_world_size( - pipeline_model_parallel_world_size - ) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, - default_backend=default_backend, - p2p_backend=p2p_backend, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(), - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ) - - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),) - - model = build_model( - testing_utils.model_provider_func, - # Use DDP only when it's better to have - wrap_with_ddp=data_parallel_size > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=self.HIDDEN_SIZE, - ) - - offset = ( - pipeline_model_parallel_world_size - if virtual_pipeline_model_parallel_size is not None - else 0 - ) - for idx, model_module in enumerate(model): - model_module = model_module.to(dtype) - model_module.apply(get_init_weights_func(idx * offset)) - - _param_groups = _get_params_for_weight_decay_optimization(model) - optimizer = torch.optim.Adam(_param_groups, lr=1e-3) - - pp_utils.update_num_microbatches(0) - - loss = fwd_bwd_func( - testing_utils.fwd_step_func, - batch, - model, - forward_only=forward_only, - # `tensor_shape` is the shape of micro batch. 
- tensor_shape=( - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ), - dtype=dtype, - async_comm=async_comm, - grad_scaler=grad_scaler, - deallocate_pipeline_output=deallocate_pipeline_outputs, - sync_batch_comm=sync_batch_comm, - ) - - if dtype == get_dtype_for_comparison(): - torch.cuda.synchronize() - hidden_size = self.HIDDEN_SIZE - microbatch_size = self.MICRO_BATCH_SIZE - total_layers = pipeline_model_parallel_world_size - if virtual_pipeline_model_parallel_size is not None: - total_layers *= virtual_pipeline_model_parallel_size - target_loss, target_model = get_target_loss_and_model( - global_batch_shape, hidden_size, total_layers - ) - - for loss_item in loss: - x = loss_item["avg"] - self.assertEqual(x.item() / microbatch_size, target_loss.item()) - - if not forward_only: - for vm_id, model_module in enumerate(model): - params = list(model_module.parameters()) - rank = params[0].get_device() - offset = pipeline_model_parallel_world_size - param_id = rank // data_parallel_size + vm_id * offset - target_params = target_model[param_id] - - self.assertEqual(params[0].cpu(), target_params[0]) - self.assertEqual(params[1].cpu(), target_params[1]) - self.assertEqual( - params[0].grad.cpu() / microbatch_size, - target_params[0].grad, - ) - self.assertEqual( - params[1].grad.cpu() / microbatch_size, - target_params[1].grad, - ) - - if not forward_only: - for m in model: - for p in m.parameters(): - self.assertIsNotNone(p.grad) - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - parallel_state.destroy_model_parallel() - - def test_learning_no_pipelining(self): - self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None) - - def test_inference_no_pipelining(self): - self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None) - - def test_learning_pipelining_without_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - False, - forward_backward_pipelining_without_interleaving, - None, - None, - sync_batch_comm=sync_batch_comm, - ) - - def test_inference_pipelining_without_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - True, - forward_backward_pipelining_without_interleaving, - None, - None, - sync_batch_comm=sync_batch_comm, - ) - - def test_learning_async_pipelining_without_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - False, - forward_backward_pipelining_without_interleaving, - None, - None, - async_comm=True, - sync_batch_comm=sync_batch_comm, - ) - - def test_inference_async_pipelining_without_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - True, - forward_backward_pipelining_without_interleaving, - None, - None, - async_comm=True, - sync_batch_comm=sync_batch_comm, - ) - - # fails on native ucc: times out - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_learning_pipelining_with_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - False, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - sync_batch_comm=sync_batch_comm, - ) - - # fails on native ucc: times out - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def 
test_inference_pipelining_with_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - True, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - sync_batch_comm=sync_batch_comm, - ) - - # fails on native ucc: times out - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_learning_async_pipelining_with_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - False, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - async_comm=True, - sync_batch_comm=sync_batch_comm, - ) - - # fails on native ucc: times out - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_inference_async_pipelining_with_interleaving(self, sync_batch_comm: bool = True): - self._forward_backward_test_impl( - True, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - async_comm=True, - sync_batch_comm=sync_batch_comm, - ) - - -class NcclPipelineParallelForwardBackwardTest( - NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase -): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - def _run_hybrid_distributed_backend(self, forward_only: bool) -> None: - self._forward_backward_test_impl( - forward_only, - forward_backward_pipelining_without_interleaving, - None, - None, - default_backend="nccl", - p2p_backend="ucc", - ) - - @unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01") - def _test_hybrid_backends(self, forward_only: bool) -> None: - if HAS_UCC: - self._run_hybrid_distributed_backend(forward_only) - else: - with self.assertRaisesRegex( - ImportError, - re.escape( - "UCC backend requires pytorch source build with UCC installed and enabled" - ), - ): - self._run_hybrid_distributed_backend(forward_only) - - def test_learning_pipelining_without_interleaving_ucc_for_p2p(self): - self._test_hybrid_backends(False) - - def test_inference_pipelining_without_interleaving_ucc_for_p2p(self): - self._test_hybrid_backends(True) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_learning_pipelining_without_interleaving_skyp_sync_after_batch_isend_irecv( - self, - ): - self.test_learning_pipelining_without_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_inference_pipelining_without_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_inference_pipelining_without_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_learning_async_pipelining_without_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_learning_async_pipelining_without_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_inference_async_pipelining_without_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - 
self.test_inference_async_pipelining_without_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_learning_pipelining_with_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_learning_pipelining_with_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_inference_pipelining_with_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_inference_pipelining_with_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_learning_async_pipelining_with_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_learning_async_pipelining_with_interleaving(sync_batch_comm=False) - - @unittest.skipUnless( - CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV, - "Requires https://github.com/pytorch/pytorch/pull/82450", - ) - def test_inference_async_pipelining_with_interleaving_skip_sync_after_batch_isend_irecv( - self, - ): - self.test_inference_async_pipelining_with_interleaving(sync_batch_comm=False) - - -# n.b.(mkozuki): pipeline parallel w/o interleaving with UCX_TLS=tcp,sm fails. -class UccPipelineParallelForwardBackwardTest( - UccDistributedTestBase, PipelineParallelForwardBackwardTestBase -): - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - deallocate_options = (False,) - dtypes = (torch.float32,) - - -# Sanity checking the functionality of `forward_backward_pipelining_without_interleaving` with -# `model_type=ModelType.encoder_and_decoder` which is used for pipeline training of transformer -# models such as T5. -@unittest.skipIf(torch.cuda.device_count() < 4, "Requires >= 4 GPUs") -class NcclPipelineParallelWithToyParallelMLP(NcclDistributedTestBase): - GLOBAL_BATCH_SIZE: int = 16 - MICRO_BATCH_SIZE: int = 2 - HIDDEN_SIZE: int = 64 - # TODO(mkozuki): Change `DECODER_SEQUENCE_LENGTH` to a value different from `ENCODER_SEQUENCE_LENGTH`. - # To test forward_backward_pipelining_without_interleaving with `model_type=ModelType.encoder_and_decoder`, - # `decoder_seq_length` is necessary and ideally should be different from `encoder_sequence_length` - # but my laziness let me use the same value. - # Note that you may have to either update `MyModel` def or define another `MyModel`. - # to support different `DECODER_SEQUENCE_LENGTH`. - ENCODER_SEQUENCE_LENGTH: int = 32 - DECODER_SEQUENCE_LENGTH: int = 32 - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - # TODO(mkozuki): Set `tensor_model_parallel>1` for encoder_and_decoder as well if there's enough GPUs - # in order to let `sequence_parallel_enabled` have an effect on tensor shape logic. - def _forward_backward_test_impl( - self, - *, - forward_only: bool, - sequence_parallel_enabled: bool, - model_type: ModelType, - dtype: torch.dtype = torch.float32, - ) -> None: - # N.B.(mkozuki): It might be better to set `tensor_model_parallel_size` to >1 - # if `self.world_size > 5`. Otherwise, `pipeline_model_parallel_split_rank` - # can be 1, which can be too far real usecase. 
- tensor_model_parallel_size = 1 + int(self.world_size >= 4) - pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_size - if model_type == ModelType.encoder_and_decoder: - pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2 - else: - pipeline_model_parallel_split_rank = None - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=None, - pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, - ) - testing_utils.set_random_seed(567) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - # TODO(mkozuki): Call `build_model` with `model_type`. - model = build_model( - testing_utils.mlp_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=None, - hidden_size=self.HIDDEN_SIZE, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - model = [m.to(dtype=dtype) for m in model] - - if parallel_state.is_pipeline_first_stage(): - batch: Tuple[torch.Tensor] = ( - torch.ones( - ( - self.GLOBAL_BATCH_SIZE, - self.ENCODER_SEQUENCE_LENGTH, - self.HIDDEN_SIZE, - ), - dtype=dtype, - device="cuda", - ), - ) - else: - batch = None - - forward_backward_pipelining_without_interleaving( - forward_step_func=testing_utils.ToyParallelMLPFwdBwdStepFunc( - sequence_parallel_enabled=sequence_parallel_enabled, - ), - batch=batch, - model=model, - forward_only=forward_only, - tensor_shape=( - self.ENCODER_SEQUENCE_LENGTH, - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - ), - model_type=model_type, - decoder_sequence_length=self.DECODER_SEQUENCE_LENGTH, - async_comm=False, - grad_scaler=None, - deallocate_pipeline_outputs=False, - dtype=dtype, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - def test_pipelining_without_interleaving_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl( - forward_only=False, - sequence_parallel_enabled=False, - model_type=ModelType.encoder_and_decoder, - ) - - def test_pipelining_without_interleaving_inferenc_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl( - forward_only=True, - sequence_parallel_enabled=False, - model_type=ModelType.encoder_and_decoder, - ) - - def test_pipelining_without_interleaving_sequence_paralle_encoder_and_decoder( - self, - ) -> None: - self._forward_backward_test_impl( - forward_only=False, - sequence_parallel_enabled=True, - model_type=ModelType.encoder_and_decoder, - ) - - def test_pipelining_without_interleaving_inference_sequence_paralle_encoder_and_decoder( - self, - ) -> None: - self._forward_backward_test_impl( - forward_only=True, - sequence_parallel_enabled=True, - model_type=ModelType.encoder_and_decoder, - ) - - def test_pipelining_without_interleaving_encoder_or_decoder(self) -> None: - self._forward_backward_test_impl( - forward_only=False, - sequence_parallel_enabled=False, - model_type=ModelType.encoder_or_decoder, - ) - - def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder( - self, - ) -> None: - self._forward_backward_test_impl( - forward_only=False, - sequence_parallel_enabled=True, - model_type=ModelType.encoder_or_decoder, - ) - - def 
test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder_half( - self, - ) -> None: - self._forward_backward_test_impl( - forward_only=False, - sequence_parallel_enabled=True, - model_type=ModelType.encoder_or_decoder, - dtype=torch.half, - ) - - -class NcclPipelineParallelWithCustomSyncContextHandler(NcclDistributedTestBase): - GLOBAL_BATCH_SIZE = 32 - MICRO_BATCH_SIZE = 1 - HIDDEN_SIZE = 1 - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - @unittest.skipIf( - torch.cuda.device_count() < 2 or torch.cuda.device_count() % 2 != 0, - "Requires >= 2 GPUs", - ) - def test_pipelining_without_interleaving_with_custom_sync_context_handler( - self, - ) -> None: - # Parallel configuration - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 2 if world_size > 2 else 1 - pipeline_model_parallel_world_size = world_size // data_parallel_size - - # Initialize pipeline parallelism - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - pp_utils.update_num_microbatches(0) - - # Construct synthetic data - dtype = get_dtype_for_comparison() - hidden_size = self.HIDDEN_SIZE - microbatch_size = self.MICRO_BATCH_SIZE - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(), - hidden_size, - hidden_size, - ) - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),) - - # Construct model - model = build_model( - testing_utils.model_provider_func, - wrap_with_ddp=True, - hidden_size=hidden_size, - )[0] - model = model.to(dtype) - model.module.apply(get_init_weights_func(0)) - - # Construct context that destroys all grads on exit - has_entered_grad_sync_context = False - has_exited_grad_sync_context = False - has_called_grad_sync_func = False - - @contextlib.contextmanager - def custom_grad_sync_context(): - try: - nonlocal has_entered_grad_sync_context - has_entered_grad_sync_context = True - yield - finally: - nonlocal has_exited_grad_sync_context - has_exited_grad_sync_context = True - for param in model.parameters(): - param.grad = None - - def custom_grad_sync_func(): - nonlocal has_called_grad_sync_func - has_called_grad_sync_func = True - - # Training step with pipeline parallelism - loss = forward_backward_pipelining_without_interleaving( - testing_utils.fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=(microbatch_size, hidden_size, hidden_size), - dtype=dtype, - async_comm=False, - grad_scaler=None, - deallocate_pipeline_outputs=False, - sequence_parallel_enabled=False, - custom_sync_context_handler=custom_grad_sync_context, - custom_grad_sync_func=custom_grad_sync_func, - ) - torch.cuda.synchronize() - - # Check if model has initialized gradients - has_any_grads = any(param.grad is not None for param in model.parameters()) - has_all_grads = all(param.grad is not None for param in model.parameters()) - - # Check context behavior - self.assertTrue(has_entered_grad_sync_context, "Has not entered custom sync context") - self.assertTrue(has_exited_grad_sync_context, "Has 
not exited custom sync context") - self.assertEqual( - has_any_grads, - has_all_grads, - "Expected gradients to all be uninitialized or all be initialized", - ) - self.assertEqual( - has_all_grads, - parallel_state.is_pipeline_first_stage(), - "Expected gradients to be initialized only in first pipeline stage", - ) - - # Clean up - parallel_state.destroy_model_parallel() - - @unittest.skipIf( - torch.cuda.device_count() < 4 or torch.cuda.device_count() % 2 != 0, - "Requires >= 4 GPUs", - ) - def test_pipelining_with_interleaving_with_custom_sync_context_handler( - self, - ) -> None: - # Parallel configuration - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 2 if world_size > 4 else 1 - pipeline_model_parallel_world_size = world_size // data_parallel_size - virtual_pipeline_model_parallel_size = 2 - - # Initialize pipeline parallelism - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - pp_utils.update_num_microbatches(0) - - # Construct synthetic data - dtype = get_dtype_for_comparison() - hidden_size = self.HIDDEN_SIZE - microbatch_size = self.MICRO_BATCH_SIZE - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(), - hidden_size, - hidden_size, - ) - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),) - - # Construct model - model = build_model( - testing_utils.model_provider_func, - wrap_with_ddp=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=hidden_size, - ) - for module in model: - module.to(dtype) - module.module.apply(get_init_weights_func(0)) - - # Construct context that keeps track whenever entered/exited - grad_sync_context_enter_count = 0 - grad_sync_context_exit_count = 0 - - @contextlib.contextmanager - def custom_grad_sync_context(): - try: - nonlocal grad_sync_context_enter_count - grad_sync_context_enter_count += 1 - yield - finally: - nonlocal grad_sync_context_exit_count - grad_sync_context_exit_count += 1 - for module in model: - for param in module.parameters(): - param.grad = None - - # Training step with pipeline parallelism - loss = _forward_backward_pipelining_with_interleaving( - testing_utils.fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=(microbatch_size, hidden_size, hidden_size), - dtype=dtype, - async_comm=False, - grad_scaler=None, - deallocate_pipeline_outputs=False, - sequence_parallel_enabled=False, - custom_sync_context_handler=custom_grad_sync_context, - ) - torch.cuda.synchronize() - - # Check context behavior - self.assertTrue( - grad_sync_context_enter_count > 0, - "Has not entered custom sync context", - ) - self.assertEqual( - grad_sync_context_enter_count, - grad_sync_context_exit_count, - "Has not entered and exited custom sync context the same number of times", - ) - self.assertEqual( - grad_sync_context_exit_count, - virtual_pipeline_model_parallel_size + 1, - "Expected to exit custom sync context once per model chunk " - 
"and once at the function end", - ) - - # Clean up - parallel_state.destroy_model_parallel() - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_random.py b/tests/L0/run_transformer/test_random.py deleted file mode 100644 index 3b268dba7..000000000 --- a/tests/L0/run_transformer/test_random.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer import tensor_parallel -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class TransformerRandomTestBase: - def test_set_cuda_rng_state(self): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - size, seed = 123, 1234 - torch.cuda.manual_seed(seed) - tensor = torch.cuda.FloatTensor(size) - - rng_state = torch.cuda.get_rng_state() - rng_state_clone = rng_state.clone() - - for _ in range(5): - torch.randn(size, out=tensor) - result_1 = tensor.clone() - - self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg) - self.assertGreater( - torch.cuda.get_rng_state().sub(rng_state_clone).max(), - 0, - msg=msg, - ) - - new_rng_state = torch.cuda.get_rng_state() - self.assertGreater(new_rng_state.sub(rng_state).max(), 0, msg=msg) - - tensor_parallel.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - tensor_parallel.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - result_2 = tensor.clone() - - self.assertEqual(result_2, result_1, msg=msg) - - self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg) - - parallel_state.destroy_model_parallel() - - def test_cuda_rng_tracker(self): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}" - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - seed_1, seed_2, size = 1234, 4321, [12, 21] - tensor = torch.cuda.FloatTensor(size) - - torch.cuda.manual_seed(seed_1) - torch.randn(size, out=tensor) - target_11 = tensor.clone() - torch.randn(size, out=tensor) - target_12 = tensor.clone() - - torch.cuda.manual_seed(seed_2) - torch.randn(size, out=tensor) - targt_21 = tensor.clone() - torch.randn(size, out=tensor) - target_22 = tensor.clone() - - torch.cuda.manual_seed(seed_1) - tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2) - - torch.randn(size, out=tensor) - result_11 = tensor.clone() - - with tensor_parallel.random.get_cuda_rng_tracker().fork("test"): - torch.randn(size, out=tensor) - result_21 = tensor.clone() - - torch.randn(size, out=tensor) - result_12 = tensor.clone() - - with tensor_parallel.random.get_cuda_rng_tracker().fork("test"): - torch.randn(size, out=tensor) - result_22 = tensor.clone() - - self.assertEqual(target_11, result_11, msg=msg) - self.assertEqual(target_12, result_12, 
msg=msg) - self.assertEqual(targt_21, result_21, msg=msg) - self.assertEqual(target_22, result_22, msg=msg) - self.assertNotEqual(result_11, result_21, msg=msg) - self.assertNotEqual(result_21, result_22, msg=msg) - - tensor_parallel.random.get_cuda_rng_tracker().reset() - parallel_state.destroy_model_parallel() - - -class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): - pass - - -class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_transformer_utils.py b/tests/L0/run_transformer/test_transformer_utils.py deleted file mode 100644 index d0af2ed22..000000000 --- a/tests/L0/run_transformer/test_transformer_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import utils -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class TransformerUtilsTest(NcclDistributedTestBase): - def test_split_tensor_along_last_dim(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - - device = "cpu" - input_tensor = torch.randn((100, 100, 100), device=device) - splits = utils.split_tensor_along_last_dim(input_tensor, 10) - last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits]) - - self.assertTrue( - torch.equal( - last_dim_shapes, - torch.full((10,), 10), - ), - msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", - ) - - parallel_state.destroy_model_parallel() - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py b/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py deleted file mode 100644 index 6bf60fffa..000000000 --- a/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py +++ /dev/null @@ -1,265 +0,0 @@ -import os -import logging -import itertools -from typing import Optional, Tuple -import unittest - -import torch -from torch.testing._internal import common_utils -from torch.testing._internal import common_distributed - -from apex._autocast_utils import _get_autocast_dtypes -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import utils as pp_utils -from apex.transformer.pipeline_parallel.schedules.common import ( - FwdStepFunc, - build_model, - _get_params_for_weight_decay_optimization, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase -from apex.transformer.testing import commons as testing_utils - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - 
-def _get_default_world_sizes_model_parallel_world_size( - pipeline_model_parallel_world_size: Optional[int] = None, -) -> Tuple[int, int, int]: - # TODO: revisit if we can fold this into the class for skip logic / avoid duplication - # of world size computation - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0) - - if pipeline_model_parallel_world_size is None: - pipeline_model_parallel_world_size = world_size // ( - tensor_model_parallel_world_size * data_parallel_size - ) - else: - data_parallel_size = world_size // ( - tensor_model_parallel_world_size * pipeline_model_parallel_world_size - ) - - return ( - tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size, - ) - - -class UccPipelineParallelForwardBackwardProf(UccDistributedTestBase): - # The purpose of this class is to test and confirm asynchronous communication via profiling. - # Having that in mind, it is safe to skip all the numerical checks. - # For unit testing with numerical checks please refer to `tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py`. - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.GLOBAL_BATCH_SIZE = 1024 - self.MICRO_BATCH_SIZE = 64 - self.HIDDEN_SIZE = 256 - self.NUM_FWD_BWD_ITERATIONS = 4 - self.deallocate_options = (False,) - self.dtypes = (torch.float32,) - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - def _forward_backward_test_impl( - self, - forward_only: bool, - fwd_bwd_func: FwdStepFunc, - pipeline_model_parallel_world_size: Optional[int], - virtual_pipeline_model_parallel_size: Optional[int], - async_comm: bool = False, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, - ) -> None: - if fwd_bwd_func == _forward_backward_pipelining_with_interleaving: - self.assertIsNotNone(virtual_pipeline_model_parallel_size) - self.assertGreater(virtual_pipeline_model_parallel_size, 1) - dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes() - - for dtype, deallocate_pipeline_outputs in itertools.product( - dtype_options, - self.deallocate_options, - ): - grad_scaler = ( - torch.amp.GradScaler("cuda", init_scale=4.0) if dtype == torch.half else None - ) - - ( - tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size, - ) = _get_default_world_sizes_model_parallel_world_size( - pipeline_model_parallel_world_size - ) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, - default_backend=default_backend, - p2p_backend=p2p_backend, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(), - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ) - - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),) - - model = build_model( - testing_utils.model_provider_func, - # Use DDP only when it's better to have - 
wrap_with_ddp=data_parallel_size > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=self.HIDDEN_SIZE, - ) - - offset = ( - pipeline_model_parallel_world_size - if virtual_pipeline_model_parallel_size is not None - else 0 - ) - for idx, model_module in enumerate(model): - model_module = model_module.to(dtype) - - _param_groups = _get_params_for_weight_decay_optimization(model) - optimizer = torch.optim.Adam(_param_groups, lr=1e-3) - - pp_utils.update_num_microbatches(0) - - for _ in range(self.NUM_FWD_BWD_ITERATIONS): - loss = fwd_bwd_func( - testing_utils.fwd_step_func, - batch, - model, - forward_only=forward_only, - # `tensor_shape` is the shape of micro batch. - tensor_shape=( - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ), - dtype=dtype, - async_comm=async_comm, - grad_scaler=grad_scaler, - deallocate_pipeline_output=deallocate_pipeline_outputs, - ) - - parallel_state.destroy_model_parallel() - - def test_learning_no_pipelining(self): - self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None) - - def test_inference_no_pipelining(self): - self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None) - - def test_learning_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_inference_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_learning_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, - forward_backward_pipelining_without_interleaving, - None, - None, - async_comm=True, - ) - - def test_inference_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, - forward_backward_pipelining_without_interleaving, - None, - None, - async_comm=True, - ) - - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_learning_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - ) - - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_inference_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - ) - - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_learning_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - async_comm=True, - ) - - @unittest.skipUnless( - _get_default_world_sizes_model_parallel_world_size()[-1] > 2, - "Interleaved schedule requires pipeline_model_parallel_world_size > 2", - ) - def test_inference_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, - _forward_backward_pipelining_with_interleaving, - None, - virtual_pipeline_model_parallel_size=2, - async_comm=True, - ) - - -if __name__ == "__main__": - 
os.environ["UCC_TLS"] = "ucp,cuda" - common_distributed.TIMEOUT_DEFAULT = 500 - common_utils.run_tests() From d3a1636e26fecf0ca2217e2fcb44b740631b5dd1 Mon Sep 17 00:00:00 2001 From: syedshazli Date: Tue, 23 Dec 2025 18:28:29 -0500 Subject: [PATCH 5/5] remove transformer based tests in Apex --- tests/L0/run_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index 1d92e67ec..7011c37ea 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -22,7 +22,6 @@ "run_optimizers", "run_fused_layer_norm", "run_mlp", - "run_transformer", ] DEFAULT_TEST_DIRS = [ "run_optimizers",