diff --git a/.github/workflows/testTorchMLIR.yml b/.github/workflows/testTorchMLIR.yml deleted file mode 100644 index 9f5ac29bb..000000000 --- a/.github/workflows/testTorchMLIR.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Torch-MLIR Test - -on: - push: - branches: [ "main*" ] - pull_request: - branches: [ "main*" ] - workflow_dispatch: - logLevel: - description: 'Log level' - required: true - default: 'warning' - type: choice - options: - - info - - warning - - debug - -jobs: - - torch-mlir-test: - runs-on: ubuntu-latest - container: - image: deepwok/mase-docker-cpu:latest - steps: - - # Clone the MASE repo and its submodules. - - name: Get MASE - uses: actions/checkout@v3 - with: - submodules: "true" - - - name: Set git safe - run: | - git config --global --add safe.directory $PWD - - - name: Torch-MLIR regression test - run: | - python3 scripts/test-torch-mlir.py - diff --git a/setup.py b/setup.py index 021884586..c60fb62ec 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,6 @@ def get_system(): "sphinx-glpi-theme", "prettytable", "pyyaml", - "pynvml", "bitstring>=4.2", "myst_parser", "cvxpy", @@ -98,7 +97,7 @@ def get_system(): author="Aaron Zhao, Jianyi Cheng, Cheng Zhang, Pedro Gimenes", author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk", license_files=("LICENSE",), - python_requires=">=3.11.9", + python_requires=">=3.11.4", package_dir={ "": "src", }, diff --git a/src/chop/distributed/__init__.py b/src/chop/distributed/__init__.py index 9f188fa07..e69de29bb 100644 --- a/src/chop/distributed/__init__.py +++ b/src/chop/distributed/__init__.py @@ -1 +0,0 @@ -from .launcher import MaseLauncher diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 8157636d6..eac339e74 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -1,162 +1,24 @@ -import os -from functools import partial -from time import time - -import torch -import 
torch.nn as nn -import torch.distributed as dist import torch.multiprocessing as mp -from torch.distributed._tensor import ( - DeviceMesh, - Replicate, - Shard, -) - -from chop.distributed.tensor import distribute_module, distribute_tensor - -from chop.distributed.utils import rlog +from chop.distributed.utils import _get_mesh_from_world_size from ..tools import get_logger logger = get_logger(__name__) logger.setLevel("DEBUG") -def distributed_timing(fn, *args, **kwargs): - dist.barrier(async_op=True) - start = time() - result = fn(*args, **kwargs) - dist.barrier(async_op=True) - end = time() - - return result, (end - start) - - -def distributed_average_timing(fn, repeat, args): - times = [] - for itr in range(repeat): - rlog( - logger, - dist.get_rank(), - f"Running teration {itr}", - "debug", - ) - dist.barrier(async_op=True) - start = time() - result = fn(*args) - dist.barrier(async_op=True) - end = time() - times.append(end - start) - rlog( - logger, - dist.get_rank(), - f"Time taken: {end - start}s", - "debug", - ) - - return result, sum(times[2:]) / len(times[2:]) - - -def dist_model_fn( - name: str, - module: nn.Module, - device_mesh: DeviceMesh, - rank: int, - tensor_sharding_map={}, -) -> None: - """ - This function gets called by torch.distributed._tensor.distribute_module on each module in the model. - Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. 
- """ - if module in tensor_sharding_map: - node_name = tensor_sharding_map[module]["node"] - for parameter, sharding_config in tensor_sharding_map[module][ - "sharding" - ].items(): - if parameter in ["data_in_0", "output", "data_out_0"]: - continue - if not hasattr(module, parameter): - rlog( - logger, - rank, - f"Module {module} does not have parameter {parameter}", - level="warning", - ) - continue - - placement = sharding_config.placements - - try: - rlog( - logger, - rank, - f"Distributing parameter {parameter} of module {node_name} to {placement}", - level="debug", - ) - distributed_tensor = distribute_tensor( - getattr(module, parameter), device_mesh, placement - ) - setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) - except Exception as e: - rlog( - logger, - rank, - f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", - level="error", - ) - - -def device_fn( - rank, world_size, model=None, device_mesh=None, tensor_sharding_map={}, inputs=[] -): - """ - This function gets called on each GPU device to set up the distributed environment and distribute the model, - following the SPMD model. - """ - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - os.environ["RANK"] = str(rank) - - # Initialize - dist.init_process_group("nccl", rank=rank, world_size=world_size) - device = torch.device("cuda", rank) - torch.cuda.set_device(device) - - # Distribute model parameters according to sharding configuration - mesh = DeviceMesh("cuda", mesh=device_mesh) - rlog(logger, rank, f"Distributing module parameters...", level="info") - model, dist_time = distributed_timing( - distribute_module, - model, - mesh, - partial(dist_model_fn, rank=rank, tensor_sharding_map=tensor_sharding_map), - input_fn=None, - output_fn=None, - ) - rlog(logger, rank, f"Module distribution done. 
Time taken: {dist_time} seconds.") - - # Run forward pass - rlog(logger, rank, f"Starting forward pass.", level="info") - inputs = [ - distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) - for in_tensor in inputs - ] - _, time_taken = distributed_average_timing( - fn=model, - repeat=10, - args=inputs, - ) - rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info") - - dist.destroy_process_group() - - class MaseLauncher: """ MaseLauncher launches an optimized model on multiple GPUs using torch.distributed. """ - def __init__(self, mase_graph, world_size=None, device_mesh=None): + def __init__( + self, + mg=None, + world_size=None, + device_mesh=None, + device_fn=None, + ): """Initialize the MaseLauncher. Args: @@ -164,23 +26,30 @@ def __init__(self, mase_graph, world_size=None, device_mesh=None): world_size (int, optional): Number of GPUs to use. Defaults to None. device_mesh (list, optional): List of GPUs to use. Defaults to None. """ - self.mg = mase_graph - self.model = mase_graph.model + self.mg = mg self.world_size = world_size - self.device_mesh = device_mesh + self.device_fn = device_fn + + if device_mesh is None: + self.device_mesh, _ = _get_mesh_from_world_size(world_size) - def run(self, tensor_sharding_map={}, inputs=[]): + def run( + self, + model_class=None, + model_config=None, + cli_args=None, + ): logger.info(f"Launching model with world size {self.world_size}.") mp.spawn( - partial( - device_fn, - model=self.model, - device_mesh=self.device_mesh, - tensor_sharding_map=tensor_sharding_map, - inputs=inputs, + self.device_fn, + args=( + self.world_size, + self.device_mesh, + model_class, + model_config, + cli_args, ), - args=(self.world_size,), nprocs=self.world_size, join=True, ) diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index 3d6067d28..eecce89dc 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -14,10 +14,8 
@@ ) from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh -import chop.distributed.tensor.ops from chop.distributed.tensor._utils import compute_local_shape from chop.distributed.tensor.api import distribute_module, distribute_tensor, DTensor -from chop.distributed.tensor.ops.utils import normalize_to_torch_size # All public APIs from dtensor package @@ -33,6 +31,25 @@ ] +def normalize_to_torch_size(size) -> torch.Size: + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) + + def _dtensor_init_helper( init_op, size: torch.Size, diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index b9316729e..3f9d38043 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -1,44 +1,25 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates -import contextlib -import functools -import logging -import operator -import warnings from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch -import torch.distributed as dist -import torch.distributed._tensor.random as random from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, - OpInfo, - OpSchema, OutputSpecType, ) from torch.distributed._tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed._tensor._utils import try_find_mesh_from_args -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta -from torch.distributed._tensor.random import is_rng_supported_mesh - - -if TYPE_CHECKING: - from torch.distributed.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Replicate, + TensorMeta, +) -try: - from torch.utils import _cxx_pytree as pytree -except ImportError: - from torch.utils import _pytree as pytree # type: ignore[no-redef] +from torch.distributed.device_mesh import DeviceMesh import chop.distributed.tensor.api as dtensor -from chop.distributed.tensor._sharding_prop import ShardingPropagator -from chop.distributed.tensor._redistribute import redistribute_local_tensor aten = torch.ops.aten -logger = logging.getLogger(__name__) def decompose_handler( @@ -67,12 +48,6 @@ def is_same_size_handler( return lhs.shape == rhs.shape -def rlog(msg): - rank = torch.distributed.get_rank() - if rank == 0: - print(msg) - - class OpDispatcher: """ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding @@ -81,7 +56,6 @@ class OpDispatcher: """ def __init__(self) -> None: - self.sharding_propagator = ShardingPropagator() self._random_ops = { aten.native_dropout.default, aten.normal_.default, @@ -103,8 +77,6 @@ def __init__(self) -> None: # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) 
# as implicitly replicated or we throw error to user. - # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave - # it as False by default. self._allow_implicit_replication = True def dispatch( @@ -118,307 +90,60 @@ def dispatch( """ # operators that does not need to go through sharding propagation - if op_call in self._custom_op_handlers: - return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] - - # extract local tensor and sharding infos to a OpInfo - op_info = self.unwrap_to_op_info(op_call, args, kwargs) - logger.debug("Dispatching op_call: %s", op_info.schema) - - self.sharding_propagator.propagate(op_info) - output_sharding = op_info.output_sharding - logger.debug("output_sharding for %s: %s", op_call, output_sharding) - assert output_sharding is not None, "output sharding should not be None" - - mesh = op_info.mesh - if mesh.get_coordinate() is None: - # For a non-participating device, we do: - # 1. if the return type is scalar, set the local result to None. - # The local results from all devices will then be all-gathered - # and a reduce op will be performed on the list of results - # with appropriate operators: - # for bool type, we by default use AND to reduce; - # we can extend for more ops if necessary. - # 2. if the return type is Tensor or List[Tensor], return empty - # tensor(s) with correct dtype. 
- spec = output_sharding.output_spec - ret_list = op_info.schema.op._schema.returns - - if spec is None: - # For a scalar return type, the non-participating device has None - # as its local result - local_results: object = None - else: - - def default_tensor(spec: DTensorSpec) -> torch.Tensor: - if spec.tensor_meta is not None: - shape = spec.tensor_meta.shape - dtype = spec.tensor_meta.dtype - if len(shape) == 0: - # scalar tensor - return torch.zeros((), dtype=dtype) - else: - # non-scalar tensor - return torch.tensor([], dtype=dtype) - else: - raise RuntimeError(f"{spec} has no tensor metadata.") - - if isinstance(spec, DTensorSpec): - # return a Tensor value - local_results = default_tensor(spec) - elif isinstance(spec, Sequence): - # return a List[Tensor] value - local_results = [ - default_tensor(s) if s is not None else None for s in spec - ] - assert isinstance(local_results, List) - if None in local_results: - ret_type = str(ret_list[0].type) - raise NotImplementedError( - f"return type {ret_type} in DTensor op is not supported" - ) - else: - if output_sharding.needs_redistribute: - # compute locally with redistribute first if needed - assert output_sharding.redistribute_schema is not None - self.redistribute_local_args( - op_info, output_sharding.redistribute_schema - ) - - local_tensor_args = ( - pytree.tree_unflatten( - cast(List[object], op_info.local_args), op_info.args_tree_spec - ) - if op_info.args_tree_spec - else op_info.local_args - ) - - # run local op computation with potentially modified args/kwargs - local_tensor_args = cast(Tuple[object, ...], local_tensor_args) - if op_call in self._random_ops: - if not random._rng_tracker and is_rng_supported_mesh(mesh): - # Default to `OffsetBasedRNGTracker` if the parallelism API - # did not already construct one - random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) - - first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - torch.Tensor, local_tensor_args[0] - ) - 
rng_context = ( - random._rng_tracker._distribute_region(first_arg._spec) - if random._rng_tracker and not first_local_arg.is_meta - else contextlib.nullcontext() - ) - - # For DTensor random operator, run it within a distribute region - with rng_context: - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - else: - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - - # communicate the result to all ranks for some operators that return scalar value - if output_sharding.output_spec is None: - if op_call == aten.equal.default: - obj_list = [None for _ in range(dist.get_world_size())] - dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] - obj_list = list(filter(lambda x: x is not None, obj_list)) - # perform reduce on the collection with AND op - local_results = functools.reduce(operator.and_, obj_list, True) - - if _is_inplace_op(op_call): - # inplace op should return self instead of re-wrapping - if output_sharding.output_spec is not None: - return args[0] - else: - return None - elif _is_out_variant_op(op_call): - # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) - output_specs = ( - (output_sharding.output_spec,) - if not isinstance(output_sharding.output_spec, tuple) - else output_sharding.output_spec - ) - out_dts = [] - spec_idx = 0 - for argument in op_call._schema.arguments: - if argument.is_out: - out_dt = cast(dtensor.DTensor, kwargs[argument.name]) - out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) - out_dts.append(out_dt) - spec_idx += 1 - - assert len(out_dts) >= 1, "out variant should have at least one out arg" - return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] - else: - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] - - @staticmethod - def redistribute_local_args( - op_info: OpInfo, - suggested_input_schema: OpSchema, - ) -> None: - # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it - - # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten - # Need to fix all the ops before doing this. 
- if op_info.args_tree_spec is not None: - flatten_args_schema_to_reshard = tuple( - pytree.tree_leaves(suggested_input_schema.args_schema) - ) - else: - flatten_args_schema_to_reshard = suggested_input_schema.args_schema - - new_local_args: List[object] = [] - for i, arg_spec in enumerate(op_info.flat_args_schema): - reshard_arg_spec = flatten_args_schema_to_reshard[i] - if isinstance(arg_spec, DTensorSpec): - local_tensor = cast(torch.Tensor, op_info.local_args[i]) - if arg_spec != reshard_arg_spec: - resharded_local_tensor = redistribute_local_tensor( - local_tensor, arg_spec, reshard_arg_spec - ) - new_local_args.append(resharded_local_tensor) - else: - new_local_args.append(local_tensor) - else: - new_local_args.append(reshard_arg_spec) - - op_info.local_args = tuple(new_local_args) + # run local op computation with potentially modified args/kwargs + local_tensor_args = [ + arg._local_tensor if isinstance(arg, dtensor.DTensor) else arg + for arg in args + ] - def unwrap_to_op_info( - self, - op_call: torch._ops.OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], - ) -> OpInfo: - # get runtime schema to determine whether to use pytree to flatten inputs - runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( - op_call, None - ) - - if runtime_schema_info is not None and runtime_schema_info.needs_pytree: - # flatten args/kwargs when necessary - tree_args, args_spec = pytree.tree_flatten(args) - args_list: Sequence[object] = tree_args - else: - args_list, args_spec = args, None - - args_schema: List[object] = [] - kwargs_schema: Dict[str, object] = {} - local_args: List[object] = [] - local_kwargs: Dict[str, object] = {} - mesh: Optional[DeviceMesh] = None - - def try_get_replicate_spec( - tensor_arg: torch.Tensor, mesh: "DeviceMesh" - ) -> DTensorSpec: - # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg. 
- if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: - warnings.warn( - "Found a non-scalar tensor with numel=1 and ndim!=0, " - "we are implicitly creating a replicated DTensor for it. " - "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed enviroment." - ) + local_tensor_kwargs = { + k: v.local_tensor if isinstance(v, dtensor.DTensor) else v + for k, v in kwargs.items() + } - # if the arg.numel() == 1, arg.ndim could be 0 or 1. - if ( - tensor_arg.ndim <= 1 - and tensor_arg.numel() == 1 - or self._allow_implicit_replication - ): - # scalar tensor can be safely treated as replicated - replication_spec = DTensorSpec( - mesh, - (Replicate(),) * mesh.ndim, + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # We still need to wrap the local result in a DTensor here in two cases + # 1. When creating a nn.Parameter from a DTensor, it must call tensor.detach + # and the return type must match the input type (DTensor). + # 2. When a single FX op decomposes into multiple aten ops (e.g. torch.embedding) + if op_call._name == "aten::detach": + return self.wrap( + local_results, + DTensorSpec( + mesh=DeviceMesh( + "cuda", + mesh=torch.Tensor( + # todo: generalize + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ] + ), + ), + placements=args[0]._spec.placements, tensor_meta=TensorMeta( - shape=tensor_arg.shape, - stride=tensor_arg.stride(), - dtype=tensor_arg.dtype, + shape=args[0]._spec.tensor_meta.shape, + stride=args[0]._spec.tensor_meta.stride, + dtype=args[0]._spec.tensor_meta.dtype, ), - ) - else: - raise RuntimeError( - f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" - " torch.Tensor to DTensor before calling distributed operators!" 
- ) - return replication_spec - - for arg in args_list: - if isinstance(arg, dtensor.DTensor): - args_schema.append(arg._spec) - local_args.append(arg._local_tensor) - if mesh is not None: - if mesh != arg.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - f"Got meshes: {mesh} {arg.device_mesh}" - ) - else: - mesh = arg.device_mesh - elif isinstance(arg, torch.Tensor): - mesh = mesh or try_find_mesh_from_args(op_call, args_list) - args_schema.append(try_get_replicate_spec(arg, mesh)) - local_args.append(arg) - else: - args_schema.append(arg) - local_args.append(arg) - - for k, v in kwargs.items(): - if isinstance(v, dtensor.DTensor): - kwargs_schema[k] = v._spec - local_kwargs[k] = v._local_tensor - if mesh is not None: - if mesh != v.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - ) - else: - mesh = v.device_mesh - elif isinstance(v, torch.Tensor): - mesh = mesh or try_find_mesh_from_args(op_call, args_list) - kwargs_schema[k] = try_get_replicate_spec(v, mesh) - local_kwargs[k] = v - else: - kwargs_schema[k] = v - local_kwargs[k] = v - - assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" - op_info = OpInfo( - mesh, - OpSchema( - op_call, - ( - pytree.tree_unflatten(args_schema, args_spec) - if args_spec - else tuple(args_schema) ), - kwargs_schema, - schema_info=runtime_schema_info, - ), - args_schema, - tuple(local_args), - local_kwargs, - args_spec, - ) - return op_info + ) + + return local_results @staticmethod - def wrap(res: object, spec: OutputSpecType) -> object: + def wrap( + res: object, + spec: OutputSpecType, + ) -> object: if isinstance(res, torch.Tensor): - if spec is not None: - assert isinstance( - spec, DTensorSpec - ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." 
- return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) - else: - # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor - assert res.ndim == 0, "output tensor should be scalar!" - return res + return dtensor.DTensor( + res, + spec, + requires_grad=res.requires_grad, + ) elif isinstance(res, (list, tuple)): - assert spec is not None and isinstance( - spec, (list, tuple) - ), f"output spec does not match with output! Expected list/tuple, got {spec}." res_list = [] for e, s in zip(res, spec): res_list.append(OpDispatcher.wrap(e, s)) diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index 47a778bea..e9cabae4e 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -205,7 +205,7 @@ def backward(ctx, grad_output: "DTensor"): # type: ignore[override] return grad_output.to_local(), None, None, None, None, None -class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ +class DTensor(torch.Tensor): _local_tensor: torch.Tensor _spec: DTensorSpec __slots__ = ["_local_tensor", "_spec"] @@ -215,7 +215,6 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() @staticmethod - @torch._disable_dynamo def __new__( cls, local_tensor: torch.Tensor, @@ -226,22 +225,8 @@ def __new__( """ Construct a DTensor from a local tensor, device mesh, and placement and other tensor properties (i.e. shape, requires_grad, strides, etc). - Note: This is not a public API and it's only supposed to be used by the - operator implementations and internals. If you want to construct a - DTensor from a local tensor, consider using `DTensor.from_local`, if - you want to construct a DTensor from a "global" tensor (where you - already have tensor initialized and want to shard this tensor), - consider using `distribute_tensor`. 
""" - if local_tensor.requires_grad and not requires_grad: - warnings.warn( - "To construct DTensor from torch.Tensor, it's recommended to " - "use local_tensor.detach() and make requires_grad consistent." - ) - # new method instruct wrapper tensor from local_tensor and add - # placement spec, it does not do actual distribution - assert spec.tensor_meta is not None, "TensorMeta should not be None!" r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, spec.tensor_meta.shape, @@ -260,7 +245,10 @@ def __new__( # pyre-fixme[3]: Return type must be annotated. def __repr__(self): # TODO: consider all_gather the local tensors for better debugging - return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + if self._spec is None: + return f"DTensor(local_tensor={self._local_tensor}, device_mesh=None, placements=None)" + else: + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" def __tensor_flatten__(self): """ @@ -308,9 +296,6 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec): ) @classmethod - @torch._disable_dynamo - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return DTensor._op_dispatcher.dispatch( func, diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index c6ebe49be..c5c3609d6 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -21,6 +21,9 @@ is_tensor_evenly_shardable, normalize_dim, normalize_dims, +) + +from torch.distributed.utils import ( normalize_to_torch_size, ) from torch.distributed._tensor.placement_types import ( diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index c7cd7c1c7..55c5293f6 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -1,3 +1,18 @@ +from time import time +import numpy as np + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.distributed._tensor import DeviceMesh + +from chop.tools import get_logger +from chop.distributed.tensor import distribute_tensor + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + + def rlog(logger, rank, msg, level="info"): """ Only log on rank 0 to avoid repeated messages. 
@@ -5,3 +20,122 @@ def rlog(logger, rank, msg, level="info"): log_fn = getattr(logger, level, logger.info) if rank == 0: log_fn(msg) + + +def distributed_timing(fn, *args, **kwargs): + dist.barrier(async_op=True) + start = time() + result = fn(*args, **kwargs) + dist.barrier(async_op=True) + end = time() + + return result, (end - start) + + +def distributed_average_timing( + fn, + args, + repeat=10, + warmup_iters=2, +): + times = [] + for itr in range(repeat): + rlog( + logger, + dist.get_rank(), + f"Running teration {itr}", + "info", + ) + dist.barrier(async_op=True) + + if isinstance(args, list): + start = time() + result = fn(*args) + dist.barrier(async_op=True) + end = time() + elif isinstance(args, dict): + start = time() + result = fn(**args) + dist.barrier(async_op=True) + end = time() + else: + raise ValueError("args must be a list or a dict") + + times.append(end - start) + rlog( + logger, + dist.get_rank(), + f"Time taken: {end - start}s", + "info", + ) + + return result, sum(times[warmup_iters:]) / len(times[warmup_iters:]) + + +def dist_model_fn( + name: str, + module: nn.Module, + device_mesh: DeviceMesh, + rank: int, + tensor_sharding_map={}, +) -> None: + """ + This function gets called by torch.distributed._tensor.distribute_module on each module in the model. + Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. + + Args: + name (str): _description_ + module (nn.Module): _description_ + device_mesh (DeviceMesh): _description_ + rank (int): _description_ + tensor_sharding_map (dict, optional): _description_. Defaults to {}. 
+ """ + + if module in tensor_sharding_map: + node_name = tensor_sharding_map[module]["node"] + for parameter, sharding_config in tensor_sharding_map[module][ + "sharding" + ].items(): + if parameter in ["data_in_0", "output", "data_out_0"]: + continue + if not hasattr(module, parameter): + rlog( + logger, + rank, + f"Module {module} does not have parameter {parameter}", + level="warning", + ) + continue + + placement = sharding_config.placements + + try: + rlog( + logger, + rank, + f"Distributing parameter {parameter} of module {node_name} to {placement}", + level="debug", + ) + distributed_tensor = distribute_tensor( + getattr(module, parameter), device_mesh, placement + ) + setattr( + module, + parameter, + torch.nn.Parameter(distributed_tensor), + ) + except Exception as e: + rlog( + logger, + rank, + f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", + level="error", + ) + raise e + + +def _get_mesh_from_world_size(world_size: int = 8): + device_ids = np.arange(world_size) + mesh_shape = (2, world_size // 2) + mesh_ids = device_ids.reshape(mesh_shape) + return mesh_ids.tolist(), tuple(mesh_shape) diff --git a/src/chop/ir/__init__.py b/src/chop/ir/__init__.py index 834757faa..d0b968ab7 100644 --- a/src/chop/ir/__init__.py +++ b/src/chop/ir/__init__.py @@ -1,3 +1,4 @@ from .graph.mase_graph import MaseGraph, MaseTracer +from .graph.mase_graph_metadata import MaseGraphMetadata from .onnx.mase_onnx_graph import MaseOnnxGraph diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 98579739b..ca0d2379a 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -49,6 +49,17 @@ "finfo", "masked_fill", "masked_fill_", + # Inserted ops from the replace_method_with_function pass + "torch_size", + "torch_contiguous", + "torch_expand", + "torch_view", + "torch_reshape", + "torch_split", + "torch_permute", + "torch_transpose", + # dtensor ops (return DTensor) + "dtensor_arange", ] MASE_MODULE_RELATED_FUNCS = [ @@ -93,6 
+104,7 @@ "sub", "add", "matmul", + "mm", "bmm", "mean", "pow", diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py new file mode 100644 index 000000000..c5e29b73d --- /dev/null +++ b/src/chop/nn/functional/dtensor.py @@ -0,0 +1,119 @@ +from typing import Tuple + +import torch +import torch.fx as fx +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor._redistribute import redistribute_local_tensor + +from torch.distributed._tensor.placement_types import Placement + +from chop.ir.graph import MaseMetadata +from chop.distributed.tensor import DTensor +from chop.tools import get_logger + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + + +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + print(msg) + + +@fx.wrap +def dtensor_arange( + start: int, + end: int, + step: int = 1, + out: torch.Tensor = None, + dtype: torch.dtype = None, + layout: torch.layout = torch.strided, + device: torch.device = None, + requires_grad: bool = False, + device_mesh: DeviceMesh = None, +): + """Returns a fully replicated DTensor with behaviour akin to `torch.arange`. + + Args: + start (int): _description_ + end (int): _description_ + step (int, optional): _description_. Defaults to 1. + out (torch.Tensor, optional): _description_. Defaults to None. + dtype (torch.dtype, optional): _description_. Defaults to None. + layout (torch.layout, optional): _description_. Defaults to torch.strided. + device (torch.device, optional): _description_. Defaults to None. + requires_grad (bool, optional): _description_. Defaults to False. 
+ """ + return DTensor.from_local( + torch.arange( + start, + end, + step, + out=out, + dtype=dtype, + layout=layout, + device=device, + ), + device_mesh=device_mesh, + ) + + +@fx.wrap +def redistribute_dtensor( + input: DTensor, + placements: Tuple[Placement, ...], + async_op: bool = False, + input_tensor_mesh=None, +): + """ + Redistribute a DTensor to a new set of placements. + + Args: + input (DTensor): The input DTensor to redistribute. + placements (Tuple[Placement, ...]): The new placements for the output DTensor. + async_op (bool, optional): Whether to perform the redistribution asynchronously. Defaults to False. + + Returns: + DTensor: The redistributed DTensor. + """ + + # If we are not in a distributed setting, we can skip redistribution. + if not isinstance(input, DTensor): + return input + + torch_mesh = DeviceMesh( + "cuda", + mesh=torch.Tensor(input_tensor_mesh), + ) + + current_spec = input._spec + + if current_spec.placements != placements: + target_spec = DTensorSpec( + torch_mesh, + placements, + tensor_meta=input._spec.tensor_meta, + ) + + local_tensor = input._local_tensor + + assert not isinstance(local_tensor, DTensor) + + output = redistribute_local_tensor( + local_tensor, + current_spec, + target_spec, + async_op=async_op, + ) + else: + # use the same local tensor if placements are the same. + output = input._local_tensor + target_spec = current_spec + + return DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) diff --git a/src/chop/nn/functional/tensor.py b/src/chop/nn/functional/tensor.py new file mode 100644 index 000000000..6f423c8f2 --- /dev/null +++ b/src/chop/nn/functional/tensor.py @@ -0,0 +1,80 @@ +import torch +import torch.fx as fx + +# This file contains functional equivalent of some torch.Tensor methods +# which can be casted to call_function nodes by the replace_method_with_function pass. +# They must have the same signature as their torch.Tensor equivalents with an added +# input node at position 0. 
+ + +@fx.wrap +def torch_size( + input: torch.Tensor, + dim: int = None, +): + return input.size(dim) + + +@fx.wrap +def torch_expand( + input: torch.Tensor, + *sizes, +): + return input.expand(*sizes) + + +@fx.wrap +def torch_view( + input: torch.Tensor, + *shape, +): + return input.view(*shape) + + +@fx.wrap +def torch_contiguous( + input: torch.Tensor, + memory_format: torch.memory_format = torch.contiguous_format, +): + return input.contiguous(memory_format=memory_format) + + +# The following functions exist in torch functional land, +# however their functional implementation does not accept +# arbitrary argument counts i.e. *args, **kwargs, so we +# reimplement them here. +# ============================================================ + + +@fx.wrap +def torch_reshape( + input: torch.Tensor, + *shape, +): + return input.reshape(*shape) + + +@fx.wrap +def torch_split( + input: torch.Tensor, + split_size: int, + dim: int = 0, +): + return input.split(split_size, dim) + + +@fx.wrap +def torch_permute( + input: torch.Tensor, + *dims, +): + return input.permute(*dims) + + +@fx.wrap +def torch_transpose( + input: torch.Tensor, + dim0: int, + dim1: int, +): + return input.transpose(dim0, dim1) diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index 8b0053113..47b971f77 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -33,9 +33,12 @@ emit_vivado_project_transform_pass, raise_granularity_transform_pass, patch_metadata_transform_pass, + resharding_transform_pass, + replace_method_with_function, + insert_dtensor_wrapper_transform_pass, ) -from .module.analysis import calculate_avg_bits_module_analysis_pass -from .module.transforms import quantize_module_transform_pass, resharding_transform_pass +from .module.analysis import calculate_avg_bits_module_analysis_pass, autosharding_module_analysis_pass +from .module.transforms import quantize_module_transform_pass from .onnx.analysis import ( export_fx_graph_analysis_pass, diff 
--git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 0c54e01fd..ee671b21a 100644 --- a/src/chop/passes/graph/__init__.py +++ b/src/chop/passes/graph/__init__.py @@ -31,6 +31,7 @@ onnx_annotate_transform_pass, partition_to_multi_device_transform_pass, raise_granularity_transform_pass, + insert_dtensor_wrapper_transform_pass, ) from .interface import ( diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index 346e486e9..df1b42f12 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -204,7 +204,10 @@ def graph_iterator_for_mase_ops(graph): def graph_iterator_for_metadata( - graph, dummy_in=None, add_value=True, force_device_meta=False + graph, + dummy_in=None, + add_value=True, + force_device_meta=False, ): """ largely apated from https://pytorch.org/docs/stable/fx.html @@ -212,14 +215,13 @@ def graph_iterator_for_metadata( model, fx_graph, modules = graph.model, graph.fx_graph, graph.modules env = {} - prev_result = None # force everything to be on device="meta" if force_device_meta: dummy_in = {k: v.to("meta") for k, v in dummy_in.items()} model = model.to("meta") - for node in graph.fx_graph.nodes: + for node in fx_graph.nodes: args, kwargs = None, None if node.op == "placeholder": result = dummy_in[node.name] diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 961e514f8..6071db04f 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -1,11 +1,14 @@ -import inspect -import math +from collections import OrderedDict +from functools import reduce import torch -import inspect + from chop.nn.quantized.modules import 
quantized_module_map -from functools import reduce +from chop.ir import MaseGraphMetadata +from chop.tools import get_logger +logger = get_logger(__name__) +logger.setLevel("INFO") # ---------------------------------------------------------- # Utility @@ -75,6 +78,8 @@ "sub": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.matmul.html "matmul": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.mm.html + "mm": {"input": "data_in", "mat2": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.bmm.html "bmm": {"input": "data_in", "mat2": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.squeeze.html @@ -222,6 +227,63 @@ "scale_grad_by_freq": "config", "sparse": "config", }, + # Inserted ops from the replace_method_with_function pass + "torch_size": {"input": "data_in", "dim": "config"}, + "torch_contiguous": { + "input": "data_in", + "memory_format": "config", + }, + # arbitrary length - support up to 4 + "torch_expand": { + "input": "data_in", + "size_0": "config", + "size_1": "config", + "size_2": "config", + "size_3": "config", + }, + "torch_view": { + "input": "data_in", + "shape_0": "config", + "shape_1": "config", + "shape_2": "config", + "shape_3": "config", + }, + "torch_reshape": { + "input": "data_in", + "shape_0": "config", + "shape_1": "config", + "shape_2": "config", + "shape_3": "config", + }, + "torch_split": { + "input": "data_in", + "split_size": "config", + "dim": "config", + }, + "torch_permute": { + "input": "data_in", + "dim_0": "config", + "dim_1": "config", + "dim_2": "config", + "dim_3": "config", + }, + "torch_transpose": { + "input": "data_in", + "dim0": "config", + "dim1": "config", + }, + # DTensor ops + "dtensor_arange": { + "device_mesh": "config", + "start": "config", + "end": "config", + "step": "config", + "out": "config", + "dtype": "config", + "layout": "config", + "device": "config", + "requires_grad": "config", + }, } 
module_data = { @@ -363,6 +425,10 @@ "type_as": {"tensor": "data_in"}, } +# ---------------------------------------------------------- +# Helpers +# ---------------------------------------------------------- + def get_type_and_precision(meta): # * Fetch type and precision from q_config for quantized modules @@ -383,16 +449,65 @@ def get_type_and_precision(meta): return arg_type, arg_precision -def match_args_and_kwargs(meta, args, kwargs, data, add_value): - ordered_func_data = [(k, v) for k, v in data.items()] - meta.parameters["common"]["args"] = {} - meta_kwargs = {} +def get_shape(x): + if x is None: + return None + elif isinstance(x, torch.Tensor): + return list(x.shape) + elif isinstance(x, int): + return [1] + elif isinstance(x, (list, tuple, torch.Size)): + return [len(x)] + else: + return [0] + + +def deepgetattr(obj, attr): + """Recurses through an attribute chain to get the ultimate value.""" + return reduce(getattr, attr.split("."), obj) + + +# ---------------------------------------------------------- +# Metadata annotators +# ---------------------------------------------------------- + + +def _annotate_arg_metadata( + meta: MaseGraphMetadata, + args: list, + kwargs: dict, + func_data: dict, + add_value: bool, +): + """ + Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"] + dictionary with metadata about each argument. The order of the args and kwargs must be preserved in the + combined dictionary (this is expected by downstream passes). However, arguments with the 'data_in' flag + in func_data are renamed to 'data_in_{itr}' where itr = 0 ... the number of data_in arguments. + + This function should not be called directly, but rather through the `annotate_common_parameters_` function. + The value in the meta["common"]["args"] dictionary should always be a dictionary, not a tensor. + + Args: + meta (MaseGraphMetadata): The metadata object. + args (list): List of args passed to the target. 
+ kwargs (dict): Dictionary of kwargs passed to the target. + func_data (dict): Dictionary defining whether each argument is data_in or config. + add_value (bool): indicate whether to add the value of the tensor to the metadata. + + Returns: + MaseGraphMetadata: metadata object with annotated args. + """ + ordered_func_data = [(k, v) for k, v in func_data.items()] + meta["common"]["args"] = OrderedDict() + data_in_itr = 0 arg_type, arg_precision = get_type_and_precision(meta) - # * Assign metadata for each argument - j = 0 + # * Handle args for i, x in enumerate(args): + + # Input data tensor if isinstance(x, torch.Tensor) and ordered_func_data[i][1] == "data_in": arg_meta = { "shape": list(x.shape), @@ -402,9 +517,10 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value): } if add_value: arg_meta["value"] = x - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 - # check if it's a tuple of tensors + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 + + # Tuple of tensors elif isinstance(x, tuple) and all([isinstance(x, torch.Tensor) for x in x]): for k, x in enumerate(x): arg_meta = { @@ -415,27 +531,32 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value): } if add_value: arg_meta["value"] = x - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 - else: - # this is not an data_in, but just actually an named arg - n, vtype = ordered_func_data[i] - meta_kwargs[n] = args[i] - - def get_shape(x): - if x is None: - return None - elif isinstance(x, torch.Tensor): - return list(x.shape) - elif isinstance(x, int): - return [1] - elif isinstance(x, list): - return [len(x)] + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 + + # Unknown data_in type or config argument else: - raise ValueError(f"Unknown type {type(x)}") + # Don't increment the iterator for config arguments, but + # preserve order in meta["common"]["args"] + arg_name, arg_flag = 
ordered_func_data[i] + + if arg_flag == "data_in": + arg_name = f"data_in_{data_in_itr}" + data_in_itr += 1 + meta["common"]["args"][arg_name] = { + "torch_dtype": x.dtype if isinstance(x, torch.Tensor) else None, + "type": type(args[i]), + "precision": arg_precision, + "shape": get_shape(args[i]), + } + + if add_value: + meta["common"]["args"][arg_name]["value"] = args[i] + + # * Handle kwargs for k, v in kwargs.items(): - if data[k] == "data_in": + if func_data[k] == "data_in": # rename this to mase data_in_number shape = get_shape(v) arg_meta = { @@ -446,47 +567,71 @@ def get_shape(x): } if add_value: arg_meta["value"] = v - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 else: # otherwise this must be a configuration parameter in meta - meta_kwargs[k] = v - # merge configuratipn args - meta.parameters["common"]["args"] = meta.parameters["common"]["args"] | meta_kwargs + # meta_kwargs[k] = v + meta["common"]["args"][k] = { + "type": type(v), + "precision": arg_precision, + "shape": get_shape(v), + } + if add_value: + meta["common"]["args"][k]["value"] = v + return meta -def analyse_result(meta, result, add_value): +def _annotate_result_metadata( + meta: MaseGraphMetadata, + result, + add_value: bool, +) -> MaseGraphMetadata: + """ + Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata. + + Args: + meta (MaseGraphMetadata): The metadata object. + result (_type_): The result object. + add_value (bool): indicate whether to add the value of the tensor to the metadata. + + Returns: + MaseGraphMetadata: metadata object with annotated results. 
+ """ # deal with results - meta.parameters["common"]["results"] = {} + meta["common"]["results"] = OrderedDict() result_type, result_precision = get_type_and_precision(meta) if isinstance(result, torch.Tensor): - meta.parameters["common"]["results"]["data_out_0"] = { + meta["common"]["results"]["data_out_0"] = { "type": result_type, "precision": result_precision, "shape": list(result.shape), "torch_dtype": result.dtype, } if add_value: - meta.parameters["common"]["results"]["data_out_0"]["value"] = result + meta["common"]["results"]["data_out_0"]["value"] = result # check if it's a tuple of tensors elif isinstance(result, tuple) and all( [isinstance(x, torch.Tensor) for x in result] ): for i, x in enumerate(result): - meta.parameters["common"]["results"][f"data_out_{i}"] = { + meta["common"]["results"][f"data_out_{i}"] = { "type": result_type, "precision": result_precision, "shape": list(x.shape), "torch_dtype": x.dtype, } if add_value: - meta.parameters["common"]["results"][f"data_out_{i}"]["value"] = x + meta["common"]["results"][f"data_out_{i}"]["value"] = x else: - meta.parameters["common"]["results"]["data_out_0"] = { + logger.debug( + f"Expected result to be a tensor or tuple of tensors, but found: {type(result)}. Will annotate with default value, but this may cause issues downstream." 
+ ) + meta["common"]["results"]["data_out_0"] = { "type": type(result), "shape": [1], "value": result, @@ -507,19 +652,19 @@ def analyse_common_parameters_placeholder(meta, result, args, kwargs, add_value= var_name = meta.node.target # deal with model specific inputs, normally these are not numerical values/tensors if var_name in meta.model.additional_inputs: - meta.parameters["common"]["args"] = {} - meta.parameters["common"]["results"] = {} - meta.parameters["common"]["results"]["data_out_0"] = { + meta["common"]["args"] = {} + meta["common"]["results"] = {} + meta["common"]["results"]["data_out_0"] = { "type": "model_specific_input", "shape": result.shape, "torhc_dtype": result.dtype, } if add_value: - meta.parameters["common"]["results"]["data_out_0"]["value"] = result + meta["common"]["results"]["data_out_0"]["value"] = result return meta - meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta @@ -530,12 +675,12 @@ def analyse_common_parameters_placeholder(meta, result, args, kwargs, add_value= def analyse_common_parameters_function(meta, result, args, kwargs, add_value=True): # fetch mase info - mase_op = meta.parameters["common"]["mase_op"] + mase_op = meta["common"]["mase_op"] # deal with result - meta = analyse_result(meta, result, add_value) + meta = _annotate_result_metadata(meta, result, add_value) # deal with args and kwargs - meta = match_args_and_kwargs(meta, args, kwargs, func_data[mase_op], add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, func_data[mase_op], add_value) return meta @@ -545,13 +690,8 @@ def analyse_common_parameters_function(meta, result, args, kwargs, add_value=Tru # ---------------------------------------------------------- -def deepgetattr(obj, attr): - """Recurses through an attribute chain to get the ultimate value.""" - return reduce(getattr, attr.split("."), obj) - - def 
analyse_common_parameters_module(meta, result, args, kwargs, add_value=True): - mase_op = meta.parameters["common"]["mase_op"] + mase_op = meta["common"]["mase_op"] node_module = deepgetattr(meta.model, meta.node.target) if mase_op == "user_defined_module": @@ -562,13 +702,13 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) else: module_args = module_data[mase_op] - meta = match_args_and_kwargs(meta, args, kwargs, module_args, add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, module_args, add_value) arg_type, arg_precision = get_type_and_precision(meta) for name, parameter in meta.module.named_parameters(): name = name.replace(".", "_") - meta.parameters["common"]["args"][name] = { + meta["common"]["args"][name] = { "type": arg_type, "precision": arg_precision, "shape": ( @@ -579,21 +719,21 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) "from": None, } if add_value: - meta.parameters["common"]["args"][name]["value"] = parameter + meta["common"]["args"][name]["value"] = parameter - meta = analyse_result(meta, result, add_value) + meta = _annotate_result_metadata(meta, result, add_value) return meta # ---------------------------------------------------------- -# Module +# Method # ---------------------------------------------------------- def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True): - mase_op = meta.parameters["common"]["mase_op"] - meta = analyse_result(meta, result, add_value) - meta = match_args_and_kwargs(meta, args, kwargs, method_data[mase_op], add_value) + mase_op = meta["common"]["mase_op"] + meta = _annotate_result_metadata(meta, result, add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, method_data[mase_op], add_value) return meta @@ -603,8 +743,8 @@ def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True) def analyse_common_parameters_attr(meta, result, args, kwargs, add_value=True): - 
meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta @@ -614,6 +754,6 @@ def analyse_common_parameters_attr(meta, result, args, kwargs, add_value=True): def analyse_common_parameters_output(meta, result, args, kwargs, add_value=True): - meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/alpa.py rename to src/chop/passes/graph/analysis/autosharding/algos/alpa.py diff --git a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py new file mode 100644 index 000000000..de9143d56 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -0,0 +1,554 @@ +import math +import numpy as np +import cvxpy as cp +from copy import copy + +import torch +import torch.fx as fx +from torch.distributed._tensor._collective_utils import redistribute_cost +from torch.distributed._tensor._op_schema import ( + DTensorSpec, + OpStrategy, + PlacementStrategy, +) +from torch.distributed._tensor.placement_types import Shard, Replicate + +from chop.tools import get_logger +from ..mesh_model import MeshModel + +from ..layers import ( + AUTOSHARDING_FUNCTIONS, + IMPLICIT_FUNCS, + FULLY_REPLICATED_FUNCS, +) +from ..ops.common import ( + fully_replicated_strategy, + placeholder_or_getattr_strategy, +) + + +logger = get_logger(__name__) +logger.setLevel("INFO") + + +def _get_computation_cost_from_strategy( + node: fx.Node, + strategy: OpStrategy, + mesh: MeshModel, + repeat: int = 100, + warmup_iters: 
int = 2, + profiling_device: int = None, +): + """ + Profile the mean wall-clock time of running a node's target under a given sharding strategy. + + Args: + node (fx.Node): FX node whose target function is profiled. + strategy (OpStrategy): sharding strategy whose input specs determine the local shard shapes. + repeat (int, optional): number of timed iterations. Defaults to 100. + warmup_iters (int, optional): leading iterations discarded from the average. Defaults to 2. + + Returns: + float: mean elapsed time in milliseconds over the non-warmup iterations. + """ + arg_specs = strategy.input_specs + arg_specs = [arg_specs] if isinstance(arg_specs, DTensorSpec) else arg_specs + + # Formulate list of arguments to run the target with + args = [] + for arg_idx, arg_spec in enumerate(arg_specs): + + # If tensor meta is None, this is not a sharded argument + if arg_spec.tensor_meta is None: + key = list(node.meta["mase"]["common"]["args"].keys())[arg_idx] + arg_value = node.meta["mase"]["common"]["args"][key]["value"] + args.append(arg_value) + continue + + # If it is a sharded argument, find the local tensor shape + else: + global_shape = copy(arg_spec.tensor_meta.shape) + local_shape = copy(arg_spec.tensor_meta.shape) + + # Check if each tensor dimension is sharded to update local_shape + for dim in range(len(global_shape)): + # Get device mesh dimensions along which dimension 'dim' of the tensor is sharded + sharded_mesh_dims = [ + idx + for idx, plac in enumerate(arg_spec.placements) + if plac == Shard(dim) + ] + + # This tensor dimension is not sharded + if len(sharded_mesh_dims) == 0: + continue + + # This tensor dimension is fully sharded + elif len(sharded_mesh_dims) == 2: + num_gpus = np.prod(mesh.mesh_shape) + + # This tensor dimension is sharded along one mesh dimension + elif len(sharded_mesh_dims) == 1: + num_gpus = mesh.mesh_shape[sharded_mesh_dims[0]] + + # Define the local shape with minimum == 1 + local_shape[dim] = math.ceil(global_shape[dim] / num_gpus) + + # Generate a random tensor with the local shape + args.append( + torch.randn( + local_shape, + device=f"cuda:{profiling_device}", + ) + ) + + # Get target function + fn = node.target + + # Run the function with the arguments + start_event = [
torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + end_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + + torch.cuda.empty_cache() + + for idx in range(repeat): + start_event[idx].record() + _ = fn(*args) + end_event[idx].record() + torch.cuda.synchronize(device=f"cuda:{profiling_device}") + + elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] + + return np.mean(elapsed[warmup_iters:]) + + +def _no_tensor_args(node): + has_tensor_args = False + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): + if isinstance(arg_info["value"], torch.Tensor): + has_tensor_args = True + break + return not has_tensor_args + + +def _inherit_strategy(node, parent_strategy): + """ + Inherit the sharding strategy from the parent node. The main data + argument is assigned the output sharding of the parent node, with + all other arguments cast to fully replicated placement. The output + sharding of the parent node is also assigned to the output spec of + each strategy since implicit nodes don't change the tensor shardings. + + Args: + node (fx.Node): input node. + parent_strategy (OpStrategy): parent node's sharding strategy. + + Returns: + OpStrategy: inherited sharding strategy. + """ + + strategies = [] + + for strategy in parent_strategy.strategies: + spec = [strategy.output_specs] + [ + DTensorSpec( + mesh=strategy.output_specs.mesh, + placements=(Replicate(), Replicate()), + tensor_meta=None, + ) + ] * (len(node.meta["mase"]["common"]["args"]) - 1) + strategies.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec[0], + ) + ) + + return OpStrategy(strategies) + + +def _extract_ilp(mg, mesh, pass_args={}): + """ + For each node in the graph, assign an OpStrategy object which contains all possible + sharding algorithms. Also assign opt_var instance which is one-hot vector used to + solve ILP. + + Return list of constraints associated with ILP.
The constraints at this stage only + enforce that each optimizer variable is a one-hot boolean vector. + + Args: + mg (MaseGraph): input mase graph. + mesh (MeshModel): mesh model. + pass_args (dict, optional): pass arguments. Defaults to {}. + + Returns: + MaseGraph: input mase graph. + cp.Problem: optimization problem. + """ + + # Setup for the ILP optimization + constr = [] + expr = 0 + + # Find sharding strategies for each operator in the graph + for node in mg.fx_graph.nodes: + + # Placeholder and get_attr nodes inject tensors into the graph + if node.op in [ + "placeholder", + "get_attr", + ]: + logger.debug( + f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" + ) + op_strategy = placeholder_or_getattr_strategy( + node.meta["mase"], + mesh, + skip_fully_replicated=pass_args.get("skip_fully_replicated", False), + ) + + # Constrain some nodes to have fully replicated sharding + elif node.op == "call_function" and node.target in FULLY_REPLICATED_FUNCS: + logger.debug( + f"Node {node.name} will be assigned fully replicated sharding." + ) + + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + + # Opt var is None since no decision needs to be taken + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": opt_var, + "is_implicit": False, + } + continue + + # Output nodes, implicit nodes and nodes with only non-tensor arguments + # inherit the sharding strategy from their parent node + elif ( + node.op == "output" + or node.op == "call_function" + and node.target in IMPLICIT_FUNCS + or _no_tensor_args(node) + ): + logger.debug( + f"Node {node.name} will inherit sharding strategy from its parent, {node.all_input_nodes[0].name}." 
+ ) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": _inherit_strategy( + node, + node.all_input_nodes[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ], + ), + "opt_var": opt_var, + "is_implicit": True, + "inherited_from": node.all_input_nodes[0], + } + continue + + # For general call_function nodes, evaluate strategy based on the target + elif ( + node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() + ): + logger.debug(f"Obtaining strategy for call_function node: {node.name}") + op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) + + else: + logger.warning( + f"Unknown node {node.name} with op {node.op} with be allocated fully replicated strategy." + ) + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": fully_replicated_strategy( + node.meta["mase"], + mesh, + ), + "opt_var": opt_var, + "is_implicit": False, + } + continue + + # Formulate optimization variable + opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": opt_var, + "is_implicit": False, + } + + # Consider computation cost (c_v term) for each of the node's strategies + # placeholder/get_attr/output nodes have no computation cost + if node.op not in [ + "placeholder", + "get_attr", + "output", + "call_method", + ]: + cost_vector = [] + for strategy in op_strategy.strategies: + try: + cost = _get_computation_cost_from_strategy( + node, + strategy, + mesh, + profiling_device=pass_args.get("benchmarking_device", None), + ) + except Exception as e: + logger.warning( + f"Failed to compute computation cost for node 
{node} strategy: {strategy} due to exception: {e}" + ) + cost = 100000.0 + cost_vector.append(cost) + + expr += np.array(cost_vector) @ opt_var + + # Consider resharding cost for each of the node's arguments + e_var_checks = [] + for arg_idx, in_node in enumerate(node.all_input_nodes): + + # Skip constant nodes + if not isinstance(in_node, fx.Node) or not isinstance( + in_node.meta["mase"]["common"]["results"]["data_out_0"]["value"], + torch.Tensor, + ): + continue + logger.debug(f"Parsing arg {in_node} of node {node}") + + # Fetch this node's input specs + node_op_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + node_in_specs = [ + ( + [strategy.input_specs][arg_idx] + if isinstance(strategy.input_specs, DTensorSpec) + else strategy.input_specs[arg_idx] + ) + for strategy in node_op_strategy.strategies + ] + + # Fetch the argument node's output specs + in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + arg_op_strategy = in_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + arg_out_specs = [ + strategy.output_specs for strategy in arg_op_strategy.strategies + ] + + # Formulate resharding cost matrix + resharding_costs = np.zeros((opt_var.shape[0], in_opt_var.shape[0])) + + for dest_idx, dest_spec in enumerate(node_in_specs): + for src_idx, src_spec in enumerate(arg_out_specs): + + if isinstance(src_spec, tuple): + src_spec = src_spec[0] + + cost = redistribute_cost(src_spec, dest_spec) + resharding_costs[dest_idx, src_idx] = ( + 1000000 if cost == float("inf") else cost + ) + + resharding_costs = resharding_costs.flatten() + + # Formulate linearized variable for resharding cost + e_var = cp.Variable(resharding_costs.shape[0], boolean=True) + expr += e_var.T @ resharding_costs + constr += [ + cp.sum(e_var) == 1, + ] + + # After solving the ILP, verify constraints were correctly formulated + if pass_args.get("run_checks", False): + e_var_checks.append((opt_var, in_opt_var, e_var)) + + # 
Constraints s.t. e_var = outer(opt_var, in_opt_var) + indices = np.arange(e_var.shape[0]) + opt_indices, in_opt_indices = np.divmod(indices, in_opt_var.shape[0]) + constr += [ + e_var <= opt_var[opt_indices], + e_var <= in_opt_var[in_opt_indices], + e_var >= opt_var[opt_indices] + in_opt_var[in_opt_indices] - 1, + ] + + if pass_args.get("run_checks", False): + node.meta["mase"]["software"]["autosharding"]["e_var_checks"] = e_var_checks + + # Solve the ILP problem + prob = cp.Problem(cp.Minimize(expr), constr) + return mg, prob + + +def _run_checks(mg, pass_args): + """ + Run checks on the ILP solution to ensure that the constraints were correctly formulated. + + Args: + mg (MaseGraph): input mase graph. + pass_args (dict): pass arguments. + + Returns: + None + """ + + for node in mg.fx_graph.nodes: + check_list = node.meta["mase"]["software"]["autosharding"].get( + "e_var_checks", [] + ) + + # Check that the constraints on the linearised variable for resharding cost + # are correctly formulated + for opt_var, in_opt_var, e_var in check_list: + idx1 = np.where(opt_var.value == 1)[0][0] + idx2 = np.where(in_opt_var.value == 1)[0][0] + idx3 = np.where(e_var.value == 1)[0][0] + assert ( + idx3 == idx1 * in_opt_var.shape[0] + idx2 + ), f"Linearized variable for resharding cost is not consistent for node {node}." + + +def _mark_sharding(mg, pass_args): + """ + After solving the ILP, annotate the metadata of each operator in the graph with the chosen + parallelization strategy. + + Args: + mg (MaseGraph): input mase graph. + pass_args (dict): pass arguments. + + Returns: + MaseGraph: input mase graph. + dict: tensor sharding map. + """ + + logger.info( + f"Autosharding optimization finished, annotating graph with chosen sharding strategies for each node." 
+ ) + + for node in mg.fx_graph.nodes: + opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] + + # Get the strategy chosen by the ILP + if node.meta["mase"]["software"]["autosharding"].get("is_implicit", False): + parent_node = node.meta["mase"]["software"]["autosharding"][ + "inherited_from" + ] + idx = parent_node.meta["mase"]["software"]["autosharding"][ + "chosen_strategy_idx" + ] + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ].strategies[idx] + + else: + try: + idx = np.where(opt_var.value == 1)[0][0] + except: + idx = np.argmax(opt_var.value) + + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ].strategies[idx] + + # Annotate chosen placement strategy + node.meta["mase"]["software"]["autosharding"]["chosen_strategy_idx"] = idx + node.meta["mase"]["software"]["autosharding"][ + "placement_strategy" + ] = chosen_strategy + + # Annotate arg metadata with chosen strategy + arg_specs = chosen_strategy.input_specs + if isinstance(arg_specs, DTensorSpec): + arg_specs = (arg_specs,) + + if not node.op in ["placeholder", "get_attr", "output"]: + assert len(arg_specs) == len( + node.meta["mase"]["common"]["args"].keys() + ), "Number of arguments do not match metadata." 
+ + out_spec = chosen_strategy.output_specs + + if node.op in ["placeholder", "get_attr", "output"]: + node.meta["mase"]["common"]["results"]["data_out_0"][ + "dtensor_spec" + ] = out_spec + + # call_function nodes + else: + arg_list = [i for i in node.meta["mase"]["common"]["args"].keys()] + + for arg_idx, arg_spec in enumerate(arg_specs): + arg_meta = node.meta["mase"]["common"]["args"][arg_list[arg_idx]] + if not isinstance(arg_meta, dict): + continue + arg_meta["dtensor_spec"] = arg_spec + + # Annotate output metadata with chosen strategy + node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_spec + + return mg, {} + + +def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): + """Intra-operator auto parallelization pass from the Alpa paper: https://arxiv.org/abs/2201.12023 + + Args: + mg (MaseGraph): Input MaseGraph. + mesh (MeshModel): mesh description. + pass_args (dict, optional): pass arguments. Defaults to {}. + debug (bool, optional): enable debug. Defaults to False. + + Returns: + MaseGraph: annotated MaseGraph. 
+ """ + + # Formulate and solve the ILP + logger.info(f"Formulating the ILP...") + + # Set CUDA device for profiling + device_id = pass_args.get("benchmarking_device", None) + torch.cuda.set_device(device_id) + + logger.info(f"Setting CUDA device to: {device_id}") + + mg, problem = _extract_ilp(mg, mesh, pass_args) + + logger.info(f"Solving the ILP...") + problem.solve( + verbose=True, + scipy_options={ + "disp": pass_args.get(f"run_checks", False), + "time_limit": pass_args.get("time_limit", None), + "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, + }, + ) + + if pass_args.get("run_checks", False): + _run_checks(mg, pass_args) + + mg, _ = _mark_sharding(mg, pass_args) + + return mg, {"solution": problem.value} diff --git a/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py new file mode 100644 index 000000000..75035a39e --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py @@ -0,0 +1,28 @@ +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate + +from chop.ir import MaseGraph +from ..mesh_model import MeshModel + + +def fully_replicated_autosharding_pass( + mg: MaseGraph, + mesh: MeshModel, + pass_args: dict, +): + spec = DTensorSpec( + None, + (Replicate(), Replicate()), + None, + ) + + for node in mg.nodes: + meta = node.meta["mase"] + + for arg, arg_info in meta["common"]["args"].items(): + arg_info["dtensor_spec"] = spec + + for result, result_info in meta["common"]["results"].items(): + result_info["dtensor_spec"] = spec + + return mg, {"solution": {}} diff --git a/src/chop/passes/graph/analysis/autosharding/algos/megatron.py b/src/chop/passes/graph/analysis/autosharding/algos/megatron.py new file mode 100644 index 000000000..cc13679f6 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/algos/megatron.py @@ -0,0 +1,10 @@ +from chop.ir import 
MaseGraph +from ..mesh_model import MeshModel + + +def megatron_autosharding_pass( + mg: MaseGraph, + mesh: MeshModel, + pass_args: dict, +): + raise NotImplementedError("Megatron autosharding pass is not implemented yet.") diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py deleted file mode 100644 index d9fea1aa5..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ /dev/null @@ -1,119 +0,0 @@ -import numpy as np -from functools import lru_cache - -from chop.ir.graph import MaseMetadata - -from .common import SpmdShard -from .mesh_model import MeshModel - -BYTES_PER_ELEMENT = 4 - - -def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): - assert ( - sharding[0][-1] == sharding[1][-2] - ), f"Inconsistent sharding for node: {node_meta.node}" - inner_dim_sharding = sharding[1][0] - - out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] - - if inner_dim_sharding == SpmdShard.R: - return 0 - - else: - ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 - return mesh.all_reduce_cost( - num_bytes=BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim=ar_dim - ) - - -@lru_cache(maxsize=None) -def get_resharding_cost( - mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata -): - """ - Obtain the resharding cost given a source and destination sharding profile for a tensor. - The mesh object is assumed to have been initialized with alpha, beta parameters so that - the communication cost can be estimated for each MPI operator. - """ - - # If original sharding is fully replicated, no resharding is required - if src == dest or all(i == SpmdShard.R for i in src): - return 0 - - num_bytes = BYTES_PER_ELEMENT * np.prod( - dest_node_meta["common"]["args"]["data_in_0"]["shape"] - ) - - # No cost (simple split along given mesh dimension) - if ( - # Keep dim 0, split dim 1 - # E.g. 
(R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) - (src[0] == dest[0]) - and (src[1] == SpmdShard.R) - and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) - # Split dim 0, keep dim 1 - # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) - or (src[1] == dest[1]) - and (src[0] == SpmdShard.R) - and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) - ): - return 0 - - # Split -> Replicate (All Gather) - elif ( - # Keep dim 0, gather along dim 1 - # E.g. (S_1, S_0) -> (S_1, R) - (src[0] == dest[0]) - and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[1] == SpmdShard.R) - # Gather along dim 0, keep dim 1 - # E.g. (S_0, S_1) -> (R, S_1) - or (src[1] == dest[1]) - and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[0] == SpmdShard.R) - ): - ag_dim = 1 if src[0] == dest[0] else 0 - return mesh.all_gather_cost( - num_bytes=num_bytes, - mesh_dim=ag_dim, - ) - - # All-to-all - # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) - elif src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src): - # all to all - a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value - try: - return mesh.all_to_all_cost( - num_bytes=num_bytes, - mesh_dim=a2a_dim, - ) - except: - assert False - - # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, - # must first gather along the first non-replicated dimension, then recursively compute the cost for the - # reduced sharding - else: - # Reduce one dimension and re-compute - if src[0] != SpmdShard.R: - new_src = (SpmdShard.R, src[1]) - ag_dim = src[0].value - else: - new_src = (SpmdShard.R, SpmdShard.R) - ag_dim = src[1].value - - return mesh.all_gather_cost( - num_bytes=num_bytes, mesh_dim=ag_dim - ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) - - -def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): - mat = np.zeros((len(dest_shardings), len(src_shardings))) - for src_idx, src in enumerate(src_shardings): - for dest_idx, dest in 
enumerate(dest_shardings): - mat[dest_idx, src_idx] = get_resharding_cost( - mesh, src, dest, dest_node_meta - ) - return mat diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py deleted file mode 100644 index b431095d6..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ /dev/null @@ -1,344 +0,0 @@ -import torch -import torch.fx as fx -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import DTensorSpec -import numpy as np -import cvxpy as cp - -from chop.tools import get_logger -from chop.tools.utils import deepgetattr - -from .layers import ( - AUTOSHARDING_MODULES, - AUTOSHARDING_FUNCTIONS, - AUTOSHARDING_METHODS, - IMPLICIT_FUNCS, - IMPLICIT_METHODS, -) -from .strategies.common import ( - fully_replicated_strategy, - placeholder_or_getattr_strategy, -) - - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - - -def _extract_ilp(mg, mesh, pass_args={}): - """ - For each node in the graph, assign an OpStrategy object which contains all possible - sharding algorithms. Also assign opt_var instance which is one-hot vector used to - solve ILP. - - Return list of constraints associated with ILP. The constraints at this stage only - enforce that each optimizer variable is a one-hot boolean vector. - - Args: - mg (MaseGraph): input mase graph. - mesh (MeshModel): mesh model. - pass_args (dict, optional): pass arguments. Defaults to {}. - - Returns: - MaseGraph: input mase graph. - cp.Problem: optimization problem. 
- """ - - # Setup for the ILP optimization - constr = [] - expr = 0 - - # Find sharding strategies for each operator in the graph - for node in mg.fx_graph.nodes: - - if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or ( - node.op == "call_method" and node.target in IMPLICIT_METHODS - ): - logger.debug( - f"Implicit {node.op} node {node.name} was assigned fully replicated sharding." - ) - - op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) - - opt_var = cp.Variable(1, boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - - # Opt var is None since no decision needs to be taken - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": opt_var, - "input": None, - "output": None, - } - continue - - # Obtain strategy according to node op - # ================================================ - - if node.op in ["placeholder", "get_attr"]: - logger.debug( - f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" - ) - op_strategy = placeholder_or_getattr_strategy( - node.meta["mase"], - mesh, - skip_fully_replicated=pass_args.get("skip_fully_replicated", False), - ) - - elif node.op == "output": - logger.debug( - f"Op strategy from node {node.all_input_nodes[0]} is propagated to {node} node." 
- ) - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.all_input_nodes[0].meta["mase"]["software"][ - "autosharding" - ]["op_strategy"], - "opt_var": None, - "input": None, - "output": None, - } - continue - - elif node.op == "call_module" and isinstance( - deepgetattr(mg.model, node.target), tuple(AUTOSHARDING_MODULES.keys()) - ): - logger.debug(f"Obtaining strategy for node {node.name}") - module_cls = type(deepgetattr(mg.model, node.target)) - op_strategy = AUTOSHARDING_MODULES[module_cls](node.meta["mase"], mesh) - - elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) - - elif ( - node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() - ): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) - - else: - logger.warning(f"Unknown node {node.name} with op {node.op}") - op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) - opt_var = cp.Variable(1, boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), - "opt_var": opt_var, - "input": None, - "output": None, - } - continue - - # Formulate optimization variable - opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": opt_var, - "input": None, - "output": None, - } - - # Consider resharding cost for each of the node's arguments - e_var_checks = [] - for arg_idx, in_node in enumerate(node.all_input_nodes): - - # Skip constant nodes - if not isinstance(in_node, fx.Node) or not isinstance( - 
in_node.meta["mase"]["common"]["results"]["data_out_0"]["value"], - torch.Tensor, - ): - continue - logger.debug(f"Parsing arg {in_node} of node {node}") - - # Fetch this node's input specs - node_op_strategy = node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ] - node_in_specs = [ - ( - [strategy.input_specs][arg_idx] - if isinstance(strategy.input_specs, DTensorSpec) - else strategy.input_specs[arg_idx] - ) - for strategy in node_op_strategy.strategies - ] - - # Fetch the argument node's output specs - in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - arg_op_strategy = in_node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ] - arg_out_specs = [ - strategy.output_specs for strategy in arg_op_strategy.strategies - ] - - # Formulate resharding cost matrix - resharding_costs = np.zeros((opt_var.shape[0], in_opt_var.shape[0])) - - for dest_idx, dest_spec in enumerate(node_in_specs): - for src_idx, src_spec in enumerate(arg_out_specs): - cost = redistribute_cost(src_spec, dest_spec) - resharding_costs[dest_idx, src_idx] = ( - 1000000 if cost == float("inf") else cost - ) - - resharding_costs = resharding_costs.flatten() - - # Formulate linearized variable for resharding cost - e_var = cp.Variable(resharding_costs.shape[0], boolean=True) - expr += e_var.T @ resharding_costs - constr += [ - cp.sum(e_var) == 1, - ] - - # After solving the ILP, verify constraints were correctly formulated - if pass_args.get("run_checks", False): - e_var_checks.append((opt_var, in_opt_var, e_var)) - - # Constraints s.t. 
e_var = outer(opt_var, in_opt_var) - indices = np.arange(e_var.shape[0]) - opt_indices, in_opt_indices = np.divmod(indices, in_opt_var.shape[0]) - constr += [ - e_var <= opt_var[opt_indices], - e_var <= in_opt_var[in_opt_indices], - e_var >= opt_var[opt_indices] + in_opt_var[in_opt_indices] - 1, - ] - - if pass_args.get("run_checks", False): - node.meta["mase"]["software"]["autosharding"]["e_var_checks"] = e_var_checks - - # Solve the ILP problem - prob = cp.Problem(cp.Minimize(expr), constr) - return mg, prob - - -def _run_checks(mg, pass_args): - """ - Run checks on the ILP solution to ensure that the constraints were correctly formulated. - - Args: - mg (MaseGraph): input mase graph. - pass_args (dict): pass arguments. - - Returns: - None - """ - - for node in mg.fx_graph.nodes: - check_list = node.meta["mase"]["software"]["autosharding"].get( - "e_var_checks", [] - ) - - # Check that the constraints on the linearised variable for resharding cost - # are correctly formulated - for opt_var, in_opt_var, e_var in check_list: - idx1 = np.where(opt_var.value == 1)[0][0] - idx2 = np.where(in_opt_var.value == 1)[0][0] - idx3 = np.where(e_var.value == 1)[0][0] - assert ( - idx3 == idx1 * in_opt_var.shape[0] + idx2 - ), f"Linearized variable for resharding cost is not consistent for node {node}." - - -def _mark_sharding(mg, pass_args): - """ - After solving the ILP, annotate the metadata of each operator in the graph with the chosen - parallelization strategy. - - Args: - mg (MaseGraph): input mase graph. - pass_args (dict): pass arguments. - - Returns: - MaseGraph: input mase graph. - dict: tensor sharding map. 
- """ - - for node in mg.fx_graph.nodes: - opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] - - if opt_var is None: - continue - - try: - idx = np.where(opt_var.value == 1)[0][0] - except: - idx = np.argmax(opt_var.value) - - chosen_strategy = node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ].strategies[idx] - - # Annotate chosen placement strategy - node.meta["mase"]["software"]["autosharding"][ - "placement_strategy" - ] = chosen_strategy - - arg_specs = chosen_strategy.input_specs - out_spec = chosen_strategy.output_specs - - if isinstance(arg_specs, DTensorSpec): - arg_specs = (arg_specs,) - - # Annotate arg metadata with chosen strategy - if node.op in ["placeholder", "get_attr", "call_method", "output"]: - pass - - # call_function nodes - else: - arg_list = [i for i in node.meta["mase"]["common"]["args"].keys()] - - for arg_idx, arg_spec in enumerate(arg_specs): - arg_meta = node.meta["mase"]["common"]["args"][arg_list[arg_idx]] - if not isinstance(arg_meta, dict): - continue - arg_meta["dtensor_spec"] = arg_spec - - # Annotate output metadata with chosen strategy - node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_spec - - return mg, {} - - -def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): - """Intra-operator auto parallelization pass from the Alpa paper: https://arxiv.org/abs/2201.12023 - - Args: - mg (MaseGraph): Input MaseGraph. - mesh (MeshModel): mesh description. - pass_args (dict, optional): pass arguments. Defaults to {}. - debug (bool, optional): enable debug. Defaults to False. - - Returns: - MaseGraph: annotated MaseGraph. 
- """ - - # Formulate and solve the ILP - logger.info(f"Formulating the ILP...") - mg, problem = _extract_ilp(mg, mesh, pass_args) - - logger.info(f"Solving the ILP...") - problem.solve( - verbose=True, - scipy_options={ - "disp": True, - "time_limit": pass_args.get("time_limit", None), - "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, - }, - ) - - if pass_args.get("run_checks", False): - _run_checks(mg, pass_args) - - mg, _ = _mark_sharding(mg, pass_args) - - return mg, {"solution": problem.value} diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py deleted file mode 100644 index 986079b07..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ /dev/null @@ -1,112 +0,0 @@ -import itertools -import numpy as np -import torch.nn as nn - -from chop.tools import get_logger -from chop.models.patched.bert.modeling_bert import BertSelfAttention - -from .common import SpmdShard, VALID_2D_TENSOR_SHARDINGS -from .alpa_cost_modelling import get_communication_cost - - -logger = get_logger(__name__) - - -def is_valid_2d_sharding(sharding): - if len(sharding) > 2: - return sharding[1:] in VALID_2D_TENSOR_SHARDINGS - else: - return sharding in VALID_2D_TENSOR_SHARDINGS - - -def is_valid_sharding_pair(sharding_pair): - return sharding_pair[0][-1] == sharding_pair[1][-2] - - -def is_fully_replicated(sharding_pair): - return all(all(dimp == SpmdShard.R for dimp in subp) for subp in sharding_pair) - - -def get_valid_2d_shardings(node_meta, mesh, module): - """ - Return every valid combination of shardings for the input tensors. For an operator - sharding to be valid, the inner dimension must have the same sharding. - E.g. ((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. 
- """ - input_shardings = [] - output_shardings = [] - compute_cost_vector = [] - communication_cost_vector = [] - - out_rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - - for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): - if out_rank > 2: - perm = tuple((SpmdShard.R,) * (out_rank - 2) + p for p in perm) - output_sharding = tuple( - (SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1]) - ) - if ( - not is_fully_replicated(perm) - and is_valid_sharding_pair(perm) - and is_valid_2d_sharding(output_sharding) - ): - input_shardings.append({"data_in_0": perm[0], "weight": perm[1]}) - output_shardings.append(output_sharding) - - compute_cost_vector.append(0) - communication_cost_vector.append( - get_communication_cost(perm, node_meta["mase"], mesh) - ) - - return ( - input_shardings, - output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -def get_valid_linear_shardings(node_meta, mesh, module): - return get_valid_2d_shardings(node_meta, mesh, module) - - -def get_valid_layernorm_shardings(node_meta, mesh, module): - rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * rank}] - valid_output_shardings = [(SpmdShard.R,) * rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -def get_valid_embedding_shardings(node_meta, mesh, module): - weight_rank = len(module.weight.shape) - data_in_rank = len(node_meta["mase"]["common"]["args"]["data_in_0"]["shape"]) - valid_input_shardings = [ - { - "data_in_0": (SpmdShard.R,) * data_in_rank, - "weight": (SpmdShard.R,) * weight_rank, - } - ] - valid_output_shardings = [(SpmdShard.R,) * data_in_rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - 
valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -ALPA_LAYERS = { - nn.Linear: get_valid_linear_shardings, - nn.LayerNorm: get_valid_layernorm_shardings, - nn.Embedding: get_valid_embedding_shardings, -} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 596248110..2515828e9 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,10 +1,15 @@ -import numpy as np -import cvxpy as cp +import os from time import time import dill +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate + from chop.tools import get_logger from .mesh_model import MeshModel +from .algos.alpa import alpa_autosharding_pass +from .algos.megatron import megatron_autosharding_pass +from .algos.fully_replicated import fully_replicated_autosharding_pass logger = get_logger(__name__) logger.setLevel("INFO") @@ -24,7 +29,7 @@ def _import_solution( mg, solution: dict, mesh: MeshModel, - extrapolate_sharding: bool = True, + extrapolate_sharding: bool = False, ): """Import an autosharding solution into the metadata of the MaseGraph. @@ -40,12 +45,6 @@ def _import_solution( for node in mg.fx_graph.nodes: logger.debug(f"Importing solution for node: {node.name}") - # Only import solution for getattr nodes - # TO DO: this is hard-coded for GPT2 - # Figure out how to generalize - if not node.name.startswith("transformer_"): - continue - # Extrapolate from first layer by string matching if node.name not in solution.keys() and extrapolate_sharding: @@ -56,7 +55,7 @@ def _import_solution( extrapolate_node = node.name.replace(f"_{layer_num}_", "_0_", 1) if extrapolate_node in solution.keys(): - logger.warning( + logger.debug( f"Node: {node.name} not found in solution. 
Extrapolating from solution for: {extrapolate_node}" ) solution[node.name] = solution[extrapolate_node] @@ -136,6 +135,8 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): ) else: spec = result_info["dtensor_spec"] + if isinstance(spec, tuple): + spec = spec[0] out_dict[node_name]["results"][result] = spec.placements with open(export_file, "wb") as file: @@ -208,6 +209,9 @@ def _get_sharding_map(mg): def autosharding_analysis_pass(mg, pass_args: dict = {}): """Annotate the metadata of each operator in the graph with a parallelization strategy. + For the autosharding pass to work, the fx graph must contain only placeholder, get_attr, + call_functional and output nodes. call_method and call_module nodes are not allowed. + Args: mg (MaseGraph): input mase graph. pass_args (dict, optional): pass arguments. Defaults to {}. @@ -233,11 +237,6 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): - ilp_solution_file (optional) -> str : File to export the autosharding solution to. Defaults to: "ilp_solution.pkl". """ - from torch.distributed._tensor._op_schema import DTensorSpec - from torch.distributed._tensor.placement_types import Replicate - from .alpa import alpa_autosharding_pass - from .megatron import megatron_autosharding_pass - assert ( "mesh_shape" in pass_args ), "Logical description for device cluster was not specified." 
@@ -248,8 +247,12 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): mesh = MeshModel(pass_args["mesh_shape"]) # Preload autosharding solution + fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + # check if solution file exists if pass_args.get("preload_solution", False): - fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + if not os.path.exists(fname): + raise FileNotFoundError(f"Solution file {fname} not found.") + logger.info(f"Preloading autosharding solution from: {fname}") with open(fname, "rb") as file: solution = dill.load(file) @@ -272,7 +275,9 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): # Run intra-operator pass start_time = time() - if algo == "alpa": + if algo == "fully_replicated": + mg, pass_outs = fully_replicated_autosharding_pass(mg, mesh, pass_args) + elif algo == "alpa": mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) elif algo == "megatron": mg, pass_outs = megatron_autosharding_pass(mg, mesh, pass_args) diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py deleted file mode 100644 index e4fd59a48..000000000 --- a/src/chop/passes/graph/analysis/autosharding/common.py +++ /dev/null @@ -1,26 +0,0 @@ -from enum import Enum - - -class SpmdShard(Enum): - S_0 = 0 - S_1 = 1 - R = 3 - - def __repr__(self): - return self.name - - def __gt__(self, other): - if self.__class__ is other.__class__: - return self.value > other.value - return NotImplemented - - -VALID_2D_TENSOR_SHARDINGS = [ - (SpmdShard.R, SpmdShard.R), - (SpmdShard.R, SpmdShard.S_0), - (SpmdShard.R, SpmdShard.S_1), - (SpmdShard.S_0, SpmdShard.R), - (SpmdShard.S_0, SpmdShard.S_1), - (SpmdShard.S_1, SpmdShard.R), - (SpmdShard.S_1, SpmdShard.S_0), -] diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 03f99d8ee..8576f5ebd 100644 --- 
a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -4,27 +4,34 @@ import torch.nn.functional as F from chop.tools import get_logger +from chop.nn.functional.tensor import ( + torch_size, + torch_expand, + torch_view, + torch_contiguous, + torch_reshape, + torch_split, + torch_permute, + torch_transpose, +) -from .strategies.common import fully_replicated_strategy -from .strategies.matrix_ops import ( +from .ops.common import fully_replicated_strategy +from .ops.matrix_ops import ( transpose_strategy, mm_strategy, addmm_strategy, bmm_strategy, baddmm_strategy, + scaled_dot_product_strategy, ) -from .strategies.view_ops import get_reshape_strategy -from .strategies.pointwise_ops import pointwise_strategy, linear_pointwise_strategy -from .strategies.math_ops import softmax_strategy, layer_norm_strategy -from .strategies.embedding_ops import embedding_strategy -from .strategies.tensor_ops import tensor_op_strategy, tensor_equal_strategy +from .ops.view_ops import get_reshape_strategy +from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy +from .ops.math_ops import softmax_strategy, layer_norm_strategy +from .ops.embedding_ops import embedding_strategy +from .ops.tensor_ops import tensor_op_strategy, tensor_equal_strategy logger = get_logger(__name__) -AUTOSHARDING_MODULES = { - torch.nn.ReLU: pointwise_strategy, -} - AUTOSHARDING_FUNCTIONS = { # embedding_ops.py F.embedding: embedding_strategy, @@ -242,27 +249,29 @@ torch.Tensor.zero_: tensor_op_strategy, torch.Tensor.equal: tensor_equal_strategy, torch.Tensor.is_same_size: tensor_equal_strategy, + # chop.nn.functional.tensor functions + torch_expand: get_reshape_strategy(torch.Tensor.expand), + torch_view: get_reshape_strategy(torch.Tensor.view), + torch_contiguous: tensor_op_strategy, + torch_reshape: get_reshape_strategy(torch.Tensor.reshape), + # torch_split: + torch_permute: get_reshape_strategy(torch.Tensor.permute), + 
torch_transpose: transpose_strategy, + torch.unsqueeze: get_reshape_strategy(torch.unsqueeze), + # SDPA + F.scaled_dot_product_attention: scaled_dot_product_strategy, } -AUTOSHARDING_METHODS = { - # view_ops.py - "view": get_reshape_strategy(torch.Tensor.view), - "reshape": get_reshape_strategy(torch.Tensor.reshape), - "expand": get_reshape_strategy(torch.Tensor.expand), - "permute": get_reshape_strategy(torch.Tensor.permute), - "transpose": get_reshape_strategy(torch.Tensor.transpose), - "masked_fill": pointwise_strategy, - "masked_fill_": pointwise_strategy, - "contiguous": tensor_op_strategy, -} +FULLY_REPLICATED_FUNCS = [ + F.embedding, + torch.arange, +] +# Implicit functions inherit their parent's strategy +# and do not change the sharding profile of their input tensors IMPLICIT_FUNCS = [ operator.getitem, getattr, torch.finfo, - torch.arange, -] - -IMPLICIT_METHODS = [ - "size", + torch_size, ] diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py deleted file mode 100644 index 30cd36f7e..000000000 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ /dev/null @@ -1,23 +0,0 @@ -from chop.ir import MaseGraph -from .mesh_model import MeshModel - - -def megatron_autosharding_pass( - mg: MaseGraph, - mesh: MeshModel, - pass_args: dict, -): - for node in mg.fx_graph.nodes: - meta = node.meta["mase"]["common"] - - for arg, arg_spec in meta["args"].items(): - if not isinstance(arg_spec, dict): - continue - arg_spec["dtensor_spec"] = None - - for result, result_spec in meta["results"].items(): - if not isinstance(result_spec, dict): - continue - result_spec["dtensor_spec"] = None - - return mg, {"solution": {}} diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py new file mode 100644 index 000000000..b9541013a --- /dev/null +++ 
b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -0,0 +1,183 @@ +# Adapted from Pytorch Distributed DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/basic_strategy.py + +import itertools +from dataclasses import dataclass +from typing import List, Set, Tuple + +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + _Partial, + Placement, + Replicate, + Shard, +) + + +@dataclass +class EinsumDims: + contracting_dims: List[str] + batch_dims: List[str] + lhs_out_only_dims: List[str] + rhs_out_only_dims: List[str] + + @classmethod + def parse_equation(cls, equation: str) -> Tuple[List[str], str]: + # parse einop equation and extract arg specs + """ + Parse the einsum equation str to input dim chars and output dim char + """ + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + + # NOTE: only support at most two inputs, and single output + # extend to support more inputs if needed in future + assert len(input_dims) <= 2, "Only support at most two inputs" + assert len(output_dims) == 1, "Only support single output" + output_dim = output_dims[0] + return input_dims, output_dim + + @classmethod + def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": + """ + Parse the dims and extract the contracting, batch, and free dimensions + for the left and right hand sides. 
+ """ + dim_char_set: Set[str] = set() + for input_dim in input_dims: + dim_char_set.update(input_dim) + + # get a determinisitc order of all dim chars + all_dim_chars = sorted(dim_char_set) + + # parse input and output dimensions + lhs_out_only_dims, rhs_out_only_dims = [], [] + batch_dims, contracting_dims = [], [] + + for dim_char in all_dim_chars: + if dim_char not in output_dim: + contracting_dims.append(dim_char) + else: + is_batch_dim = True + for input_dim in input_dims: + is_batch_dim = is_batch_dim and dim_char in input_dim + + if is_batch_dim: + batch_dims.append(dim_char) + else: + assert ( + len(input_dims) == 2 + ), "free dimension only supported for two inputs!" + lhs, rhs = input_dims + if dim_char in lhs: + lhs_out_only_dims.append(dim_char) + elif dim_char in rhs: + rhs_out_only_dims.append(dim_char) + else: + raise RuntimeError("Invalid dimension character") + + return cls( + contracting_dims=contracting_dims, + batch_dims=batch_dims, + lhs_out_only_dims=lhs_out_only_dims, + rhs_out_only_dims=rhs_out_only_dims, + ) + + +def gen_einsum_strategies( + equation: str, + mesh: tuple, + *, + linearity: bool = False, +) -> OpStrategy: + """ + Generate a strategy list for the ops that follow einsum style notation. + """ + # parse einop equation and extract dims + input_dims, output_dim = EinsumDims.parse_equation(equation) + edims = EinsumDims.parse_dims(input_dims, output_dim) + + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim + for mesh_dim in range(len(mesh.mesh_shape)): + mesh_dim_strategies = [] + + # placement list stores placements of [output, input1, input2, ...] 
+ # first we always have replicate all for inputs and output + placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1) + mesh_dim_strategies.append(placement_list) + + if mesh[mesh_dim] <= 1: + # only replicate strategy for mesh dim with size 1 + # TODO: see if this is valid for the submesh case + continue + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + mesh_dim_strategies.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + placement_list = [_Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + mesh_dim_strategies.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim = output_dim.index(lhs_dim) + # this means split the lhs input and output + # i.e. 
S(0), R -> S(0) + lhs_placement_list: List[Placement] = [ + Shard(lhs_free_dim), + Shard(lhs_free_dim), + Replicate(), + ] + mesh_dim_strategies.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim = output_dim.index(rhs_dim) + rhs_placement_list: List[Placement] = [ + Shard(rhs_free_dim), + Replicate(), + Shard(rhs_free_dim), + ] + mesh_dim_strategies.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: List[Placement] = [_Partial()] + for input_dim in input_dims: + linearity_placement_list.append(_Partial()) + mesh_dim_strategies.append(linearity_placement_list) + + all_mesh_dim_strategies.append(mesh_dim_strategies) + + # generate strategies for entire mesh + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + # TODO: filter out invalid strategies, at this point we generate + # all possible strategies without considering the whether the tensor + # dim could be sharded or not, we would need to filter out invalid + # strategies base on the actual tensor shape + # (i.e. 
for Shard, tensor dim size must > mesh size) + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:]) + all_strategies.append(strat) + + return OpStrategy(all_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/common.py b/src/chop/passes/graph/analysis/autosharding/ops/common.py new file mode 100644 index 000000000..7e06d71dc --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/common.py @@ -0,0 +1,247 @@ +from typing import List +import itertools +import numpy as np + +import torch +import torch.nn.functional as F + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( + Placement, + Replicate, + Shard, + DTensorSpec, + TensorMeta, +) +from torch.distributed._tensor.ops.utils import ( + is_tensor_shardable, + generate_redistribute_costs, +) + +from chop.tools import get_logger + + +logger = get_logger(__name__) + + +def find_shape_and_dtype(arg): + + # If the argument in meta["common"]["args"][key] is correctly + # formulated with data, just extract shape and dtype + if isinstance(arg, dict): + in_shape = arg["shape"] + in_dtype = arg["torch_dtype"] + + # Otherwise, depends on the type of argument + elif isinstance(arg, torch.Size) or isinstance(arg, (tuple, list)): + in_shape = (len(arg),) + in_dtype = type(arg[0]) + elif isinstance(arg, (float, int)): + in_shape = (1,) + in_dtype = type(arg) + else: + logger.warning(f"Unknown type for arg: {arg}") + in_shape = tuple() + in_dtype = type(arg) + + return in_shape, in_dtype + + +def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): + ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + tensor_shape = 
meta["common"]["results"]["data_out_0"]["shape"] + opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] + + tensor_meta = TensorMeta( + shape=tensor_shape, + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + + shardings = [] + for sharding in itertools.product(opts, repeat=2): + # Skip fully replicated shardings since this sometimes forces the ILP + # to choose a fully replicated strategy for the entire model when + # the computation cost term is not formulated + if skip_fully_replicated and sharding == (Replicate(), Replicate()): + continue + + # Skip sharding if any dimension is sharded to 0 + skip_sharding = False + for dim in range(ndims): + # Find all device mesh dimensions along which this tensor dimension is sharded + mesh_sharded_dims = [ + idx for idx, shard in enumerate(sharding) if shard == Shard(dim) + ] + + # This tensor dimension is not sharded + if len(mesh_sharded_dims) == 0: + continue + + elif len(mesh_sharded_dims) == 1: + num_gpus = mesh.mesh_shape[mesh_sharded_dims[0]] + + else: + num_gpus = np.prod(mesh.mesh_shape) + + dim_size_after_sharding = tensor_shape[dim] // num_gpus + if dim_size_after_sharding == 0: + skip_sharding = True + continue + + if skip_sharding is True: + continue + + spec = DTensorSpec( + mesh=mesh, + placements=sharding, + tensor_meta=tensor_meta, + ) + shardings.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec, + ) + ) + + return OpStrategy(shardings) + + +def fully_replicated_strategy(meta, mesh): + """ + Output of ops like size, getitem etc are always fully replicated + """ + sharding = [Replicate(), Replicate()] + + # call_method nodes don't list input tensor in the args list, but + # tensor is copied into meta["common"]["self"] when add_value = True + # is passed to add_common_metadata_pass + if meta.node.op == "call_method": + in_shape = meta["common"]["self"].shape + in_dtype = meta["common"]["self"].dtype + else: + if len(meta["common"]["args"]) > 0: + 
first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) + arg = meta["common"]["args"][first_arg_key] + in_shape, in_dtype = find_shape_and_dtype(arg) + + in_spec = [ + DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), + ) + ] * len(meta["common"]["args"].keys()) + + else: + in_spec = [] + + dtype_key = ( + "torch_dtype" + if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() + else "type" + ) + out_dtype = meta["common"]["results"]["data_out_0"][dtype_key] + out_spec = DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=out_dtype, + ), + ) + + return OpStrategy( + [ + PlacementStrategy( + input_specs=in_spec, + output_specs=out_spec, + ) + ] + ) + + +def expand_to_full_mesh_op_strategy( + meta, + mesh: DeviceMesh, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
def expand_to_full_mesh_op_strategy(
    meta,
    mesh: DeviceMesh,
    single_mesh_dim_strategies: List[List[Placement]],
    *,
    input_index: int = 1,
    inplace_op: bool = False,
) -> OpStrategy:
    """
    Expand single-mesh-dim placement lists into full-mesh strategies.

    Each entry of ``single_mesh_dim_strategies`` is a placement list laid out
    as [output(s)..., input1, input2, ...]; the first ``input_index`` entries
    are outputs. The cartesian product over mesh dims yields one candidate per
    combination; candidates whose inputs are not shardable are dropped.

    Args:
        meta: MASE metadata; result shape/dtype feed the emitted TensorMeta,
            and ``meta.node.args`` supplies each input's precomputed
            op_strategy for cost generation.
        mesh: device mesh (``mesh.ndim`` mesh dimensions).
        single_mesh_dim_strategies: per-mesh-dim candidate placement lists.
        input_index: number of leading output entries in each placement list.
        inplace_op: if True, only keep strategies whose first input placement
            matches the first argument's current sharding.

    Returns:
        OpStrategy over the surviving full-mesh placement strategies.
    """
    # Expand the single_mesh_dim_strategies to full mesh dim strategies.
    # The same candidate list is reused for every mesh dim.
    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim

    strategy_combs = itertools.product(*all_mesh_dim_strategies)

    all_strategies = []
    for strategy_comb in strategy_combs:
        spec_list = []
        # zip transposes the per-mesh-dim lists into one placements tuple per
        # tensor (output(s) first, then inputs)
        for specs in zip(*strategy_comb):
            spec_list.append(
                DTensorSpec(
                    mesh,
                    tuple(specs),
                    tensor_meta=TensorMeta(
                        shape=meta["common"]["results"]["data_out_0"]["shape"],
                        stride=None,
                        dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"],
                    ),
                )
            )

        input_specs = spec_list[input_index:]
        # input_args_strategy = op_schema.args_strategy
        # each fx argument node already carries its own OpStrategy from an
        # earlier pass over the graph
        input_args_strategy = tuple(
            arg.meta["mase"]["software"]["autosharding"]["op_strategy"]
            for arg in meta.node.args
        )
        assert len(input_specs) == len(input_args_strategy)
        self_spec = input_args_strategy[0].strategies[0].output_spec
        if inplace_op and self_spec.placements != input_specs[0].placements:
            # if it's inplace op, we would only allow the placement strategy to be added when the
            # input_spec matches the first argument's runtime sharding, otherwise we skip
            continue

        # check inputs shardable
        inputs_shardable = all(
            is_tensor_shardable(inp.shape, s)
            for inp, s in zip(input_args_strategy, input_specs)
        )

        # extend input_specs to include fully replicated sharding for constant
        # nodes (args beyond the tensor inputs covered by the placement list)
        # NOTE(review): the (Replicate(), Replicate()) pair assumes a 2-D mesh,
        # consistent with the rest of this pass
        extended_input_specs = input_specs + [
            DTensorSpec(
                mesh,
                (Replicate(), Replicate()),
                # todo: may need to set tensor meta
                tensor_meta=None,
            )
        ] * (len(meta["common"]["args"].keys()) - len(input_specs))

        # only add to the all_strategies list when all inputs are shardable
        if inputs_shardable:
            redistribute_cost = [
                generate_redistribute_costs(input_strategy, input_spec)
                for input_strategy, input_spec in zip(input_args_strategy, input_specs)
            ]
            strategy = PlacementStrategy(
                # single-output ops get a bare spec, multi-output ops a tuple
                output_specs=(
                    tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0]
                ),
                input_specs=extended_input_specs,
                redistribute_cost=redistribute_cost,
            )
            all_strategies.append(strategy)

    return OpStrategy(all_strategies)
+ if tensor.ndim == self.data.ndim: + tensor[self.data] = 0.0 + else: + tensor[self.data, :] = 0.0 + + +@dataclass(frozen=True) +class _MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + logical_dim_size: int = -1 + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( + self.logical_dim_size, + num_chunks, + mesh.get_local_rank(mesh_dim), + return_offset=True, + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # 
perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.logical_dim_size == other.logical_dim_size + ) + + def __hash__(self) -> int: + return 1 + hash( + (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + +def embedding_strategy(meta, mesh) -> StrategyType: + """ + This strategy handles embedding op. 
We have two possible embedding shardings: + rowwise and colwise + """ + weight_shape = meta["common"]["args"]["data_in_0"]["shape"] + indices_shape = meta["common"]["args"]["data_in_1"]["shape"] + output_emd_dim = len(indices_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(meta, mesh, single_mesh_dim_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py new file mode 100644 index 000000000..32e33afa2 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -0,0 +1,199 @@ +# Adapted from Pytorch Distributed DTensor API. 
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/math_ops.py + +from typing import cast, List, Optional, Sequence, Tuple, Union + +import torch +from torch.distributed._tensor._op_schema import ( + OpStrategy, + PlacementStrategy, +) +from torch.distributed._tensor.ops.utils import ( + normalize_dim, +) +from torch.distributed._tensor._utils import ( + normalize_to_torch_size, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Placement, + Replicate, + Shard, +) + + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: Tuple[Placement, ...], reduction_dims: List[int] +) -> Tuple[Placement, ...]: + # replicate the reduction dims if not reduction_linear + new_placements: List[Placement] = [] + + for p in placements: + if p.is_partial(): + new_placements.append(Replicate()) + elif isinstance(p, Shard) and p.dim in reduction_dims: + new_placements.append(Replicate()) + else: + new_placements.append(p) + + return tuple(new_placements) + + +def softmax_strategy(meta, mesh): + parent_node = meta.node.args[0] + input_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + ndim = len(meta["common"]["args"]["data_in_0"]["shape"]) + + softmax_dim = meta["common"]["args"]["dim"] + + input_strategy = cast(OpStrategy, input_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, ndim) + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make 
sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + # redistribute_costs.append( + # generate_redistribute_costs(input_strategy, input_target_spec) + # ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=[input_target_spec], + # redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +def layer_norm_strategy(meta, mesh): + + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + assert len(meta["common"]["args"].keys()) == 5 + + input_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + normalized_shape = meta["common"]["args"]["normalized_shape"]["value"] + weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"][ + "autosharding" + ]["op_strategy"] + bias_strategy = meta.node.kwargs["bias"].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + + # the current layer norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + assert isinstance(input_strategy, OpStrategy) + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = len(meta["common"]["args"]["data_in_0"]["shape"]) + axis = input_ndim - len(normalized_size) + + # we use OpStrategy because the output (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + op_args_target_specs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the 
def layer_norm_strategy(meta, mesh):
    """
    Strategy for layer_norm: follow each input strategy, replicating the
    normalized (inner) dims of the input and all dims of weight/bias, and
    emitting fully-replicated specs for the non-tensor args.
    """

    # args must be: input, normalized_shape, weight, bias, eps
    # for None weight and bias, their corresponding objects will
    # be None as well. layer_norm_strategy returns one OpStrategy
    # for the triple return values (out, mean, rstd).
    assert len(meta["common"]["args"].keys()) == 5

    input_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"][
        "op_strategy"
    ]
    normalized_shape = meta["common"]["args"]["normalized_shape"]["value"]
    weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"][
        "autosharding"
    ]["op_strategy"]
    bias_strategy = meta.node.kwargs["bias"].meta["mase"]["software"]["autosharding"][
        "op_strategy"
    ]

    # the current layer norm implementation requires that all
    # input DTensor's sharding must be in form of OpStrategy
    assert isinstance(input_strategy, OpStrategy)
    assert isinstance(normalized_shape, (int, Sequence, torch.Size))
    normalized_size = normalize_to_torch_size(normalized_shape)

    input_ndim = len(meta["common"]["args"]["data_in_0"]["shape"])
    # first tensor dim that belongs to the normalized (inner) region
    axis = input_ndim - len(normalized_size)

    # we use OpStrategy because the output (out, mean, rstd)
    # should have the same placements
    output_strategy = OpStrategy([])
    for input_placement_strategy in input_strategy.strategies:
        op_args_target_specs = []
        input_src_spec = input_placement_strategy.output_spec

        # for the input tensor, we replicate it on the inner dims if necessary
        # TODO: we can avoid forcing the redistribution once we figure out
        # how to decompose layer norm
        input_target_spec = DTensorSpec(
            mesh=mesh,
            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
            tensor_meta=input_src_spec.tensor_meta,
        )
        op_args_target_specs.append(input_target_spec)

        # Add replicate spec for normalized_shape
        # NOTE(review): the (Replicate(),) * 2 pair assumes a 2-D mesh,
        # consistent with the rest of this pass
        normalized_shape_spec = DTensorSpec(
            mesh=mesh,
            placements=(Replicate(),) * 2,
            # todo: check that it's safe not to assign tensor meta here
            tensor_meta=None,
        )
        op_args_target_specs.append(normalized_shape_spec)

        if weight_strategy is not None:
            assert isinstance(weight_strategy, OpStrategy)
            # patching: weight and bias sharding strategy is currently always replicate
            # So just take strategy at index 0
            # TO DO: when sharding decomposed layer norm, cross product weight strategies
            # with input/bias strategies for final OpStrategy
            weight_src_spec = weight_strategy.strategies[0].output_spec

            # for the weight tensor, we replicate it on all dims if necessary
            # TODO: we can avoid forcing the redistribution once we figure out
            # how to decompose layer norm
            weight_target_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(weight_src_spec.placements),
                tensor_meta=weight_src_spec.tensor_meta,
            )
            op_args_target_specs.append(weight_target_spec)
            # redistribute_costs.append(
            #     generate_redistribute_costs(weight_strategy, weight_target_spec)
            # )

        if bias_strategy is not None:
            assert isinstance(bias_strategy, OpStrategy)
            bias_src_spec = bias_strategy.strategies[0].output_spec

            # for the bias tensor, we replicate it on all dims if necessary
            # TODO: we can avoid forcing the redistribution once we figure out
            # how to decompose layer norm
            bias_target_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(bias_src_spec.placements),
                tensor_meta=bias_src_spec.tensor_meta,
            )
            op_args_target_specs.append(bias_target_spec)

        # add fully replicated strategy for eps
        eps_spec = DTensorSpec(
            mesh=mesh,
            placements=(Replicate(),) * 2,
            tensor_meta=None,
        )
        op_args_target_specs.append(eps_spec)

        # the output spec is the same as input spec
        output_target_spec = input_target_spec
        output_strategy.strategies.append(
            PlacementStrategy(
                output_specs=output_target_spec,
                input_specs=op_args_target_specs,
            )
        )

    return output_strategy
parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + + assert isinstance(self_strategy, OpStrategy) + + fully_replicated_spec = DTensorSpec( + mesh=mesh, + placements=[Replicate(), Replicate()], + tensor_meta=None, + ) + + transpose_strategies = [] + for input_strategy in self_strategy.strategies: + + if isinstance(input_strategy.output_specs, tuple): + input_spec = input_strategy.output_specs[0] + else: + input_spec = input_strategy.output_spec + + # follow the input spec but transpose the Shard placements + output_placements = [ + Shard(_other(meta, p.dim)) if isinstance(p, Shard) else p + for p in input_spec.placements + ] + transpose_strategy = PlacementStrategy( + output_specs=DTensorSpec( + mesh=input_spec.mesh, + placements=tuple(output_placements), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), + ), + # include 2 fully replicated inputs for dim_0 and dim_1 arguments + input_specs=(input_spec,) + (fully_replicated_spec,) * 2, + ) + transpose_strategies.append(transpose_strategy) + + return OpStrategy(strategies=transpose_strategies) + + +def _mm_like_strategy( + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()] + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + + self_spec.tensor_meta = TensorMeta( + shape=self_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_0"]["torch_dtype"], + ) + mat2_spec.tensor_meta = TensorMeta( + shape=mat2_shape, + stride=None, + 
dtype=meta["common"]["args"]["data_in_1"]["torch_dtype"], + ) + strtg.output_spec.tensor_meta = TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + + if is_tensor_shardable(self_shape, self_spec) and is_tensor_shardable( + mat2_shape, mat2_spec + ): + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _addmm_like_strategy( + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + + self_shape, mat1_shape, mat2_shape = [ + arg["shape"] for arg in meta["common"]["args"].values() + ] + + mm_out_shape = torch.Size( + [ + mat2_shape[-1] if i == len(mat1_shape) - 1 else dim_size + for i, dim_size in enumerate(mat1_shape) + ] + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + # construct new strategy by consider the self arg + assert strtg.input_specs is not None + mat1_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + + self_spec.tensor_meta = TensorMeta( + shape=self_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_0"]["torch_dtype"], + ) + mat1_spec.tensor_meta = TensorMeta( + shape=mat1_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_1"]["torch_dtype"], + ) + mat2_spec.tensor_meta = TensorMeta( + shape=mat2_shape, + stride=None, + 
dtype=meta["common"]["args"]["data_in_2"]["torch_dtype"], + ) + strtg.output_spec.tensor_meta = TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + + if is_tensor_shardable(mat1_shape, mat1_spec) and is_tensor_shardable( + mat2_shape, mat2_spec + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def mm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + return _mm_like_strategy("mk,kn->mn", meta, mesh) + + +def addmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + return _addmm_like_strategy("mk,kn->mn", meta, mesh) + + +def bmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + return _mm_like_strategy("bmk,bkn->bmn", meta, mesh) + + +def baddmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh) + + +def scaled_dot_product_flash_attention_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. 
def scaled_dot_product_flash_attention_strategy(
    meta: MaseMetadata,
    mesh: tuple,
) -> OpStrategy:
    """
    Strategies for flash-attention SDPA, limited to the cases needed for
    tensor/context parallelism: full replication, sharding on the num-heads
    dim, and sharding on the sequence dim.
    """
    # NOTE: currently we only support some simple strategies to support tensor parallelism
    # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
    # as it involves: matmul, pointwise, reduction ops together.
    arg_names = list(meta["common"]["args"].keys())
    arg_infos = list(meta["common"]["args"].values())
    # sixth positional arg toggles the debug attention mask output
    return_debug_mask = len(arg_names) >= 6 and arg_infos[5]["value"]

    # q_input_strategy = op_schema.args_schema[0]
    q_parent_node = meta.node.args[0]
    q_input_strategy = q_parent_node.meta["mase"]["software"]["autosharding"][
        "op_strategy"
    ]
    assert isinstance(q_input_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 3 valid tensor outputs and 3 tensor inputs;
    # None entries are the non-tensor outputs that take no placement
    # first we can always accept full replication for both inputs and outputs
    all_replicate: PlacementList = [
        Replicate(),
        Replicate(),
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        Replicate(),
        Replicate(),
        Replicate(),
        Replicate(),
    ]
    single_mesh_dim_strategies.append(all_replicate)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shard on the num of head dim
    qkv_sharding = Shard(1)  # num head dim
    output_sharding = Shard(1)  # num head dim
    logsumexp_sharding = Shard(1)  # num head dim
    if return_debug_mask:
        debug_attn_mask_sharding: Placement = Shard(1)  # num head dim
    else:
        # empty debug mask, replicated
        debug_attn_mask_sharding = Replicate()

    num_heads_dim_sharding: PlacementList = [
        output_sharding,
        logsumexp_sharding,
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        debug_attn_mask_sharding,
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
    ]
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # Context Parallelism: shards on the sequence dim
    single_mesh_dim_strategies.append(
        [
            Shard(2),  # output
            Shard(2),  # logsumexp
            None,  # cum_seq_q
            None,  # cum_seq_k
            None,  # max_q
            None,  # max_k
            None,  # philox_seed
            None,  # philox_offset
            Shard(2),  # debugattn
            Shard(2),  # q
            Shard(2),  # k
            Shard(2),  # v
        ]
    )
    # input_index=9: the first 9 entries of every placement list are outputs
    return expand_to_full_mesh_op_strategy(
        meta,
        mesh,
        single_mesh_dim_strategies,
        input_index=9,
    )


def scaled_dot_product_strategy(
    meta: MaseMetadata,
    mesh: tuple,
):
    """Generic SDPA entry point; currently defers to the flash-attention strategy."""
    # todo: support efficient attention backend
    return scaled_dot_product_flash_attention_strategy(meta, mesh)
+ arg_strategy = arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + + arg_max_shards = arg_strategy.max_num_shards() + if arg_max_shards > max_shards: + max_shards_strategy_index = idx + max_shards = arg_max_shards + followed_strategy = arg_strategy + + assert isinstance(followed_strategy, OpStrategy), f"no strategy to follow!" + + return common_pointwise_strategy( + meta, mesh, followed_strategy, linearity, max_shards_strategy_index + ) + + +def common_pointwise_strategy( + meta, + mesh, + followed_strategy, + linearity, + followed_strategy_index=0, +): + # handle broadcasting + parsed_args = [] + for arg in meta["common"]["args"].values(): + if isinstance(arg, dict): + parsed_args.append(torch.zeros(arg["shape"])) + elif isinstance(arg, torch.Size): + parsed_args.append(torch.Tensor(list(arg))) + elif isinstance(arg, (tuple, list)): + parsed_args.append(torch.Tensor(arg)) + elif isinstance(arg, torch.Tensor): + parsed_args.append(arg) + elif isinstance(arg, (float, int)): + parsed_args.append(torch.Tensor([arg])) + else: + logger.warning( + f"Unrecognized arg type: {type(arg)}, defaulting to fully replicated strategy." 
+ ) + return fully_replicated_strategy(meta, mesh) + + common_shape = torch.broadcast_shapes(*[arg.shape for arg in parsed_args]) + + # Extract followed argument shape + followed_shape = parsed_args[followed_strategy_index].shape + + # Iterate through followed argument's strategies to cast output shardings + pointwise_strategy = OpStrategy([]) + for placement_strategy in followed_strategy.strategies: + spec_to_follow = placement_strategy.output_spec + out_placements: List[Placement] = [] + for placement in spec_to_follow.placements: + if isinstance(placement, Shard): + shard_dim = normalize_dim(placement.dim, len(followed_shape)) + common_ndim = len(common_shape) + new_shard_dim = common_ndim - len(followed_shape) + shard_dim + out_placements.append(Shard(new_shard_dim)) + elif isinstance(placement, Partial) and not linearity: + # clear the partial placemnet if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) + else: + out_placements.append(placement) + + input_specs: List[DTensorSpec] = [] + for arg_node in meta.node.args: + if not isinstance(arg_node, torch.fx.Node): + continue + input_arg = arg_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + if isinstance(input_arg, OpStrategy): + # every arg follow the out_placements, but need to handle broadcasting + input_arg_spec = input_arg.strategies[0].output_spec + input_arg_dims_map = infer_broadcast_dims_map( + common_shape, + arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"], + ) + input_target_placements = map_placements_after_broadcast( + tuple(out_placements), + common_shape, + input_arg_dims_map, + ) + input_arg_target_spec = DTensorSpec( + mesh=mesh, + placements=input_target_placements, + tensor_meta=input_arg_spec.tensor_meta, + ) + input_specs.append(input_arg_target_spec) + + dtype = meta["common"]["results"]["data_out_0"].get( + "torch_dtype", torch.float32 + ) + 
pointwise_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec( + mesh=mesh, + placements=tuple(out_placements), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=dtype, + ), + ), + input_specs=input_specs, + ) + ) + return pointwise_strategy + + +def linear_pointwise_strategy( + meta, + mesh, +): + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_strategy( + meta, + mesh, + linearity=True, + ) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py new file mode 100644 index 000000000..0a9daee2a --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py @@ -0,0 +1,111 @@ +# Adapted from Pytorch Distributed DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py + +from torch.distributed._tensor._op_schema import ( + OpStrategy, + PlacementStrategy, + StrategyType, +) +from torch.distributed._tensor.ops.utils import ( + is_tensor_partial, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Replicate, + TensorMeta, +) + + +def tensor_op_strategy(meta, mesh) -> StrategyType: + # Default strategy by default just propagate the first input strategy + select_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + assert isinstance(select_strategy, OpStrategy) + + node_args = list(meta["common"]["args"].keys()) + if len(node_args) > 0: + first_arg_name = node_args[0] + arg_shape, arg_dtype = ( + meta["common"]["args"][first_arg_name]["shape"], + meta["common"]["args"][first_arg_name]["torch_dtype"], + ) + + else: + arg_shape, arg_dtype = ( + meta["common"]["self"].shape, + meta["common"]["self"].dtype, + ) 
+ + first_result = list(meta["common"]["results"].keys())[0] + result_shape, result_dtype = ( + meta["common"]["results"][first_result]["shape"], + meta["common"]["results"][first_result]["torch_dtype"], + ) + + default_strategy = [] + for strategy in select_strategy.strategies: + # we create new DTensorSpecs even for default strategy to assure that + # the tensor metas are distinct between the arguments and outputs + default_strategy.append( + PlacementStrategy( + input_specs=( + DTensorSpec( + mesh=strategy.output_spec.mesh, + placements=strategy.output_spec.placements, + tensor_meta=TensorMeta( + shape=arg_shape, dtype=arg_dtype, stride=None + ), + ), + ) + * len(meta.node.args), + output_specs=DTensorSpec( + mesh=strategy.output_spec.mesh, + placements=strategy.output_spec.placements, + tensor_meta=TensorMeta( + shape=result_shape, dtype=result_dtype, stride=None + ), + ), + ) + ) + return OpStrategy(default_strategy) + + +def tensor_equal_strategy(meta, mesh) -> StrategyType: + # equal_strategy deals with ops that comparing two tensor, we need to make sure + # sharding layout the same with two operands, we choose to follow the arg with max + # num of shards, still keep is_same_size here for completeness as they share the + # same strategy in theory. 
+ self_strategy, other_strategy = ( + meta.node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], + meta.node.args[1].meta["mase"]["software"]["autosharding"]["op_strategy"], + ) + assert isinstance(self_strategy, OpStrategy) + assert isinstance(other_strategy, OpStrategy) + + select_strategy = ( + self_strategy + if self_strategy.max_num_shards() >= other_strategy.max_num_shards() + else other_strategy + ) + equal_strategy = OpStrategy([]) + + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # if the arg_spec have partial, reshard to replicate + # otherwise local shard tensor comparison would be invalid + output_spec = DTensorSpec( + mesh=arg_spec.mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + equal_strategy.strategies.append( + PlacementStrategy(output_specs=output_spec) + ) + else: + equal_strategy.strategies.append(PlacementStrategy(arg_spec)) + return equal_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py new file mode 100644 index 000000000..4902e83ca --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -0,0 +1,638 @@ +# Adapted from Pytorch Distributed DTensor API. 
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/view_ops.py + +from dataclasses import dataclass +from typing import ( + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from torch.distributed._tensor._op_schema import ( + OpStrategy, + PlacementStrategy, +) +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, + normalize_dim, + normalize_dims, + prod, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Placement, + Replicate, + TensorMeta, +) + +Shape = Tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = Tuple[DimSpec, ...] + + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton.""" + + pass + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # 
repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + + Note that input_dim itself could be a Flattened set of input dims. + """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] 
+ group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implement broadcast on multiple dimensions.""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return cast(Shape, sizes[0]) # type: ignore[redundant-cast] + else: + raise RuntimeError("Size must be int... 
or tuple") + + +def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + # only flattening dims from start_dim to end_dim (inclusive) + # other dims are passed through + if end_dim < 0: + end_dim += ndim + results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] + results.append( + Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) + ) + results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) + return tuple(results) + + +def dim_movedim( + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len(destination), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == -1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert ( + len(sizes) >= ndim + ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + + Infer the size of this dimension given the total_size. 
+ """ + infers = [i for i, s in enumerate(sizes) if s == -1] + size = prod(sizes) + assert len(infers) <= 1, "can only infer one size" + if infers: + size = -size + missing_size = total_size // size + assert ( + total_size % size == 0 + ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + return tuple(s if s != -1 else missing_size for s in sizes) + assert size == total_size, f"sizes do not match {total_size} vs {size}" + return sizes + + +def view_groups(from_size: Shape, to_size: Shape) -> DimMap: + """ + Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. + + A view or reshape operation can be decomposed into a set of 3 types of smaller operations: + 1) Forward a dimension from input to output + 2) Flatten a set of dimensions into a single dimension + 3) Split one dimension into multiple dimensions + + view_groups identifies these operations and returns, for each output dimension, what + is operation was performed in the input dimension. For example: + + view_groups([2, 3, 4], [2, 12]) -> ( + InputDim(0), + Flatten((InputDim(1), InputDim(2))) + ) + + - ouptut dimension 0 maps to input dimension 0 + - output dimension 1 maps to a flattened input dimensions 1 and 2 + + + view_groups([2, 3], [3, 2]) -> ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ) + + - in the above, input is flattened into a single dimension and then split + into two separate dimensions with different sizes from the input. 
+ """ + from_nelem = prod(from_size) + to_size = infer_size(from_nelem, normalize_sizes(to_size)) + + assert from_nelem == prod(to_size), "Total view shape does not add up" + + from_idx = 0 + to_idx = 0 + from_len = len(from_size) + to_len = len(to_size) + + result_pp = [] + + while from_idx < from_len or to_idx < to_len: + from_group_dim, to_group_shape = [], [] + + if from_idx >= from_len: + f = 1 + else: + f = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + + if to_idx >= to_len: + t = 1 + else: + t = to_size[to_idx] + to_group_shape.append(t) + to_idx += 1 + + # if any of the groups is singleton, great, we need to backtrack though + if f == 1 and t != 1: + # produces ([1], []) + to_idx -= 1 + to_group_shape = [] + elif f != 1 and t == 1: + # produces ([], [1]) + from_idx -= 1 + from_group_dim = [] + else: + # produces ([1], [1]), ([2], [2]), ([2,3], [6]) + while f != t: + if f < t: + nf = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + f *= nf + else: + nt = to_size[to_idx] + to_group_shape.append(nt) + to_idx += 1 + t *= nt + + if len(to_group_shape) > 0: + flattened = Flatten.new( + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) + ) + result_pp += [ + Split.new(flattened, tuple(to_group_shape), i) + for i in range(len(to_group_shape)) + ] + + return tuple(result_pp) + + +def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: + if len(dims) < ndim: + dims = (1,) * (ndim - len(dims)) + dims + return dim_repeat(ndim, dims) + + +def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + assert dim1 < ndim + assert dim2 < ndim + dimmap = [InputDim(i) for i in range(ndim)] + swapdim = dimmap[dim1] + dimmap[dim1] = dimmap[dim2] + dimmap[dim2] = swapdim + return tuple(dimmap) + + +def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: + # FIXME: this is wrong when dim=None and one of the dimensions + # equals 
size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not actually a singleton. + return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_view_as_real(shape: Shape) -> DimMap: + ndim = len(shape) + results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] + # each complex number is split into two real numbers, + # resulting in one more dimension of size 2 + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) + return tuple(results) + + +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ), + torch.permute: lambda input, *dims: tuple( + InputDim(i) for i in normalize_dims(tuple(dims), input.ndim) + ), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 
2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + Tensor.reshape: lambda self, *shape: view_groups(self.shape, shape), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + # here + Tensor.permute: lambda input, *dims: tuple( + InputDim(i) for i in normalize_dims(tuple(dims), input.ndim) + ), + Tensor.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + Tensor.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), +} + + +def propagate_shape_and_sharding( + input_src_placements: Sequence[Placement], + local_in_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, +) -> Tuple[Sequence[Placement], Sequence[Placement]]: + """ + Determine input target sharding and output sharding based on + given global tensor shape and input source sharding. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. + - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(input_src_placements) == len(mesh_sizes) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} + + # in case an input dimension disappears (e.g. 
collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + seen_input_dims: Set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(local_in_shape)): + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim + + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + if isinstance(cmd, InputDim): + return cmd + elif isinstance(cmd, Flatten): + for dim in cmd.input_dims[1:]: + if isinstance(dim, InputDim): + shardable_dims[dim.input_dim] = [False] * mesh_ndim + dim0 = cmd.input_dims[0] + return dim0 if isinstance(dim0, InputDim) else None + elif isinstance(cmd, Split): + in_dim = get_in_dim_to_shard(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisible + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. In that case, the dimension + # needs to be divisible by the product of mesh sizes. + # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] + + # 2. 
here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, input_src_placements): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert ( + out_size % submesh_size == 0 + ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + + # we will only shard our first component of the split + return in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Repeat): + in_dim = get_in_dim_to_shard(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None + else: + return None + + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} + for dim, cmd in enumerate(rule): + in_dim = get_in_dim_to_shard(cmd) + if in_dim is not None: + shard_dim_map[in_dim.input_dim] = dim + + input_tgt_placements = [ + ( + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + ) + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] + + return input_tgt_placements, output_placements + + +def get_reshape_strategy(op): + dim_map = dim_maps[op] + + def reshape_strategy(meta, mesh): + args_schema = [i["value"] for i in meta["common"]["args"].values()] + rules = dim_map(*args_schema) + parent_node = meta.node.args[0] + # input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + input_strategy = parent_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + global_in_shape = meta["common"]["args"]["data_in_0"]["shape"] + assert global_in_shape is not None, "Shape required." 
+ + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh.mesh_shape, + ) + + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + + replicate_spec = DTensorSpec( + placements=(Replicate(), Replicate()), + mesh=input_src_spec.mesh, + # todo: may need to set tensor meta + tensor_meta=None, + ) + # add fully replicated spec for all constant args + input_specs = (input_tgt_spec,) + (replicate_spec,) * (len(args_schema) - 1) + + output_spec = DTensorSpec( + mesh=mesh, + placements=tuple(output_placements), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), + ) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=input_specs, + ) + ) + + return output_strategy + + return reshape_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 0959a6bf3..5e98c840e 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -40,20 +40,61 @@ def find_shape_and_dtype(arg): def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + tensor_shape = meta["common"]["results"]["data_out_0"]["shape"] opts = [Replicate()] + [Shard(dim) for dim in 
range(ndims)] tensor_meta = TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], + shape=tensor_shape, stride=None, dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ) shardings = [] for sharding in itertools.product(opts, repeat=2): + # Skip fully replicated shardings since this sometimes forces the ILP + # to choose a fully replicated strategy for the entire model when + # the computation cost term is not formulated if skip_fully_replicated and sharding == (Replicate(), Replicate()): continue - spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) - shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) + + # Skip sharding if any dimension is sharded to 0 + skip_sharding = False + for dim in range(ndims): + # Find all device mesh dimensions along which this tensor dimension is sharded + mesh_sharded_dims = [ + idx for idx, shard in enumerate(sharding) if shard == Shard(dim) + ] + + # This tensor dimension is not sharded + if len(mesh_sharded_dims) == 0: + continue + + elif len(mesh_sharded_dims) == 1: + num_gpus = mesh.mesh_shape[mesh_sharded_dims[0]] + + else: + num_gpus = np.prod(mesh.mesh_shape) + + dim_size_after_sharding = tensor_shape[dim] // num_gpus + if dim_size_after_sharding == 0: + skip_sharding = True + continue + + if skip_sharding is True: + continue + + spec = DTensorSpec( + mesh=mesh, + placements=sharding, + tensor_meta=tensor_meta, + ) + shardings.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec, + ) + ) + return OpStrategy(shardings) @@ -70,23 +111,29 @@ def fully_replicated_strategy(meta, mesh): in_shape = meta["common"]["self"].shape in_dtype = meta["common"]["self"].dtype else: - first_arg_key = ( - "data_in_0" - if "data_in_0" in meta["common"]["args"] - else [i for i in meta["common"]["args"].keys()][0] - ) - arg = meta["common"]["args"][first_arg_key] - in_shape, in_dtype = find_shape_and_dtype(arg) - - in_spec = DTensorSpec( - mesh, - sharding, - 
tensor_meta=TensorMeta( - shape=in_shape, - stride=None, - dtype=in_dtype, - ), - ) + if len(meta["common"]["args"]) > 0: + first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) + arg = meta["common"]["args"][first_arg_key] + in_shape, in_dtype = find_shape_and_dtype(arg) + + in_spec = [ + DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), + ) + ] * len(meta["common"]["args"].keys()) + + else: + in_spec = [] dtype_key = ( "torch_dtype" @@ -104,6 +151,92 @@ def fully_replicated_strategy(meta, mesh): ), ) - shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] + return OpStrategy( + [ + PlacementStrategy( + input_specs=in_spec, + output_specs=out_spec, + ) + ] + ) - return OpStrategy(shardings) + +def expand_to_full_mesh_op_strategy( + meta, + mesh: DeviceMesh, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
+ all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + try: + spec_list.append( + DTensorSpec( + mesh, + tuple(specs), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"][ + "torch_dtype" + ], + ), + ) + ) + except: + breakpoint() + + input_specs = spec_list[input_index:] + # input_args_strategy = op_schema.args_strategy + input_args_strategy = tuple( + arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + for arg in meta.node.args + ) + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # extend input_specs to include fully replicated sharding for constant nodes + extended_input_specs = input_specs + [ + DTensorSpec( + mesh, + (Replicate(), Replicate()), + # todo: may need to set tensor meta + tensor_meta=None, + ) + ] * (len(meta["common"]["args"].keys()) - len(input_specs)) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=( + tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] + ), + input_specs=extended_input_specs, + 
redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py index 4cdc2f0fe..2d2ecb200 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py @@ -9,8 +9,9 @@ PlacementStrategy, ) from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, normalize_dim, +) +from torch.distributed._tensor._utils import ( normalize_to_torch_size, ) from torch.distributed._tensor.placement_types import ( diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index d552f8171..245e42f69 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -22,6 +22,16 @@ from ..utils import is_tensor_shardable from chop.ir.graph import MaseMetadata +from .common import expand_to_full_mesh_op_strategy + + +def _other(meta, dim): + if dim == meta["common"]["args"]["dim0"]["value"]: + return meta["common"]["args"]["dim1"]["value"] + elif dim == meta["common"]["args"]["dim1"]["value"]: + return meta["common"]["args"]["dim0"]["value"] + else: + raise ValueError(f"Invalid dim: {dim}") def transpose_strategy( @@ -34,17 +44,28 @@ def transpose_strategy( assert isinstance(self_strategy, OpStrategy) + fully_replicated_spec = DTensorSpec( + mesh=mesh, + placements=[Replicate(), Replicate()], + tensor_meta=None, + ) + transpose_strategies = [] for input_strategy in self_strategy.strategies: - input_spec = input_strategy.output_spec + + if isinstance(input_strategy.output_specs, tuple): + input_spec = input_strategy.output_specs[0] + else: + input_spec = input_strategy.output_spec + # follow the 
input spec but transpose the Shard placements output_placements = [ - Shard(1 - p.dim) if isinstance(p, Shard) else p + Shard(_other(meta, p.dim)) if isinstance(p, Shard) else p for p in input_spec.placements ] transpose_strategy = PlacementStrategy( output_specs=DTensorSpec( - mesh=input_strategy.output_spec.mesh, + mesh=input_spec.mesh, placements=tuple(output_placements), tensor_meta=TensorMeta( shape=meta["common"]["results"]["data_out_0"]["shape"], @@ -52,7 +73,8 @@ def transpose_strategy( dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ), ), - input_specs=(input_strategy.output_spec,), + # include 2 fully replicated inputs for dim_0 and dim_1 arguments + input_specs=(input_spec,) + (fully_replicated_spec,) * 2, ) transpose_strategies.append(transpose_strategy) @@ -205,11 +227,17 @@ def scaled_dot_product_flash_attention_strategy( # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation # as it involves: matmul, pointwise, reduction ops together. 
- return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] - q_input_strategy = op_schema.args_schema[0] + arg_names = list(meta["common"]["args"].keys()) + arg_infos = list(meta["common"]["args"].values()) + return_debug_mask = len(arg_names) >= 6 and arg_infos[5]["value"] + + # q_input_strategy = op_schema.args_schema[0] + q_parent_node = meta.node.args[0] + q_input_strategy = q_parent_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape single_mesh_dim_strategies = [] @@ -277,5 +305,16 @@ def scaled_dot_product_flash_attention_strategy( ] ) return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=9 + meta, + mesh, + single_mesh_dim_strategies, + input_index=9, ) + + +def scaled_dot_product_strategy( + meta: MaseMetadata, + mesh: tuple, +): + # todo: support efficient attention backend + return scaled_dot_product_flash_attention_strategy(meta, mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 83c638bfb..70e62b450 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -606,7 +606,7 @@ def reshape_strategy(meta, mesh): # FIXME: this can be wrong for situations where we have # [Shard(0), Shard(0)] input_tgt_spec = DTensorSpec( - placements=tuple(input_tgt_placements), + placements=(Replicate(), Replicate()), mesh=input_src_spec.mesh, tensor_meta=input_src_spec.tensor_meta, ) diff --git a/src/chop/passes/graph/analysis/autosharding/utils.py b/src/chop/passes/graph/analysis/autosharding/utils.py index 8a08be311..54ff94bea 100644 --- a/src/chop/passes/graph/analysis/autosharding/utils.py +++ b/src/chop/passes/graph/analysis/autosharding/utils.py @@ -1,4 +1,4 @@ 
-from typing import cast, Iterable, List, Sequence, Tuple, Union +from typing import Sequence, cast from torch.distributed._tensor.placement_types import DTensorSpec, Shard diff --git a/src/chop/passes/graph/analysis/report/report_graph.py b/src/chop/passes/graph/analysis/report/report_graph.py index bfc7086ef..00ba219da 100644 --- a/src/chop/passes/graph/analysis/report/report_graph.py +++ b/src/chop/passes/graph/analysis/report/report_graph.py @@ -64,7 +64,13 @@ def report_graph_analysis_pass(graph, pass_args={"file_name": None}): {count} Layer types: -{layer_types}""" +{layer_types} + +===================== Code Gen ===================== + +{graph.model.code} + +""" if file_name is None: print(buff) else: diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py index 612773262..eaf089de3 100644 --- a/src/chop/passes/graph/transforms/__init__.py +++ b/src/chop/passes/graph/transforms/__init__.py @@ -20,3 +20,8 @@ from .granularity import raise_granularity_transform_pass from .patching import patch_metadata_transform_pass + +from .resharding import resharding_transform_pass +from .insert_dtensor_wrapper import insert_dtensor_wrapper_transform_pass + +from .find_replace.method_to_function import replace_method_with_function diff --git a/src/chop/passes/graph/transforms/find_replace/__init__.py b/src/chop/passes/graph/transforms/find_replace/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/chop/passes/graph/transforms/find_replace/method_to_function.py b/src/chop/passes/graph/transforms/find_replace/method_to_function.py new file mode 100644 index 000000000..34f7d5dd5 --- /dev/null +++ b/src/chop/passes/graph/transforms/find_replace/method_to_function.py @@ -0,0 +1,63 @@ +import torch + +from chop.tools import get_logger +from chop.nn.functional.tensor import ( + torch_size, + torch_expand, + torch_view, + torch_contiguous, + torch_reshape, + torch_split, + torch_permute, + torch_transpose, 
+) + +logger = get_logger(__name__) +logger.setLevel("INFO") + + +REPLACE_METHODS = { + "size": torch_size, + "reshape": torch_reshape, + "expand": torch_expand, + "split": torch_split, + "view": torch_view, + "permute": torch_permute, + "transpose": torch_transpose, + "contiguous": torch_contiguous, +} + + +def replace_method_with_function(mg, pass_args={}): + """Replaces call_method calls with call_function calls in the graph. + + Args: + graph (MaseGraph): The input graph. + + Returns: + MaseGraph: The graph with method calls replaced with function calls. + """ + for node in mg.fx_graph.nodes: + if node.op != "call_method": + continue + + if node.target in REPLACE_METHODS: + + with mg.fx_graph.inserting_after(node): + logger.debug(f"Replacing {node.target} with function call.") + new_node = mg.fx_graph.call_function( + REPLACE_METHODS[node.target], + node.args, + node.kwargs, + ) + node.replace_all_uses_with(new_node) + mg.fx_graph.erase_node(node) + + else: + raise NotImplementedError( + f"Method {node.target} not implemented in replace_method_with_function." 
+ ) + + mg.model.recompile() + + return mg, {} diff --git a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py new file mode 100644 index 000000000..cda5508d2 --- /dev/null +++ b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py @@ -0,0 +1,154 @@ +import torch +from torch.distributed._tensor.api import DTensorSpec, TensorMeta +from torch.distributed import DeviceMesh +from copy import deepcopy + + +from chop.tools import get_logger +from chop.distributed.tensor import DTensor + + +logger = get_logger(__name__) +logger.setLevel("INFO") + + +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + logger.info(msg) + + +class DTensorCache: + _dtensor_dict: dict = {} + + def __init__(self): + """ + This cache is needed to avoid expensive calls to _make_wrapper_subclass + at runtime when wrapping local Tensor results in DTensor objects. + """ + pass + + +def _create_dtensor( + local_tensor, + node_name, + node_meta, + result_name, + torch_mesh, +): + cached_name = f"{node_name}_{result_name}" + cached_dtensor = DTensorCache._dtensor_dict.get(cached_name, None) + + # The DTensor is not found in the cache the first time each FX node is called + if cached_dtensor is None: + result_meta = node_meta["common"]["results"][result_name] + + dtensor = DTensor( + local_tensor=local_tensor, + spec=DTensorSpec( + mesh=torch_mesh, + placements=result_meta["dtensor_spec"].placements, + tensor_meta=TensorMeta( + shape=result_meta["value"].shape, + stride=result_meta["value"].stride(), + dtype=local_tensor.dtype, + ), + ), + requires_grad=local_tensor.requires_grad, + ) + + DTensorCache._dtensor_dict[cached_name] = dtensor + + return dtensor + + # If the DTensor is found in the cache, replace the local tensor + else: + # Replace local tensor without constructing a new dtensor + cached_dtensor._local_tensor = local_tensor + + # if DEBUG_MODE: + # assert cached dtensor has the same meta + # assert 
cached_dtensor._spec.placements == result_meta["dtensor_spec"].placements + # assert cached_dtensor._spec.tensor_meta.shape == result_meta["value"].shape + # assert cached_dtensor._spec.tensor_meta.stride == result_meta["value"].stride() + # assert cached_dtensor._spec.tensor_meta.dtype == local_tensor.dtype + + return cached_dtensor + + +def create_wrapper(node): + + target_fn = deepcopy(node.target) + + # todo: generalize + torch_mesh = DeviceMesh( + "cuda", + mesh=torch.Tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ] + ), + ) + + result_names = list(node.meta["mase"]["common"]["results"].keys()) + + def dtensor_wrapper_fn(*args, **kwargs): + out = target_fn(*args, **kwargs) + + if isinstance(out, (tuple, list)): + outs = [] + for r_idx, r in enumerate(out): + # if isinstance(r, DTensor): + # outs.append(r) + if isinstance(r, torch.Tensor): + outs.append( + _create_dtensor( + local_tensor=r, + node_name=node.name, + node_meta=node.meta["mase"], + result_name=result_names[r_idx], + torch_mesh=torch_mesh, + ) + ) + else: + outs.append(r) + + wrapped_out = tuple(outs) + + # In the event the OpDispatcher already wrapped a DTensor around + # the local result, avoid reaching recursive depth limit + # elif isinstance(out, DTensor): + # wrapped_out = out + + elif isinstance(out, torch.Tensor): + wrapped_out = _create_dtensor( + local_tensor=out, + node_name=node.name, + node_meta=node.meta["mase"], + result_name=result_names[0], + torch_mesh=torch_mesh, + ) + + else: + wrapped_out = out + + return wrapped_out + + return dtensor_wrapper_fn + + +def insert_dtensor_wrapper_transform_pass(mg, pass_args={}): + + rlog("Inserting DTensor wrappers for call_function nodes") + + for node in mg.nodes: + if node.op == "call_function": + + logger.debug(f"Inserting DTensor wrapper for {node.name}") + node.target = create_wrapper(node) + + else: + logger.debug(f"Skipping node {node.name} because it is not a call_function") + + return mg, {} diff --git 
a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py new file mode 100644 index 000000000..722225739 --- /dev/null +++ b/src/chop/passes/graph/transforms/resharding.py @@ -0,0 +1,132 @@ +import operator + +import torch +import torch.fx as fx +from torch.distributed._tensor.placement_types import Replicate, Shard + +from chop.tools import get_logger +from chop.nn.functional.dtensor import redistribute_dtensor +from chop.ir.graph import MaseMetadata + +logger = get_logger(__name__) +logger.setLevel("INFO") + + +def _insert_resharding_nodes(mg, pass_args={}): + """Insert resharding nodes""" + logger.info( + f"Running resharding_transform_pass to insert resharding nodes along necessary edges." + ) + + device_mesh = pass_args.get("device_mesh", None) + + for node in mg.fx_graph.nodes: + + if node.op == "call_function" and node.target == redistribute_dtensor: + continue + + flattened_args = node.args + tuple(node.kwargs.values()) + kwarg_keys = list(node.kwargs.keys()) + + # Number of arguments should match metadata + if node.op != "output" and len(flattened_args) != len( + node.meta["mase"]["common"]["args"] + ): + if "getitem" not in node.name: + logger.warning( + f"Skipping node: {node.name} because number of arguments do not match metadata." + ) + continue + + for arg_idx, arg_name in enumerate(node.meta["mase"]["common"]["args"].keys()): + + # Check if argument is an FX node, otherwise it's a constant + arg_obj = flattened_args[arg_idx] + arg_info = node.meta["mase"]["common"]["args"][arg_name] + if not isinstance(arg_obj, fx.Node) or not isinstance( + arg_info["value"], torch.Tensor + ): + logger.debug( + f"Skipping node: {node.name}, argument: {arg_name} because it is a constant." 
+ ) + continue + + # Check if the parent node output spec is different from the arg input spec + arg_specs = arg_info.get("dtensor_spec", None) + parent_out_specs = arg_obj.meta["mase"]["common"]["results"][ + "data_out_0" + ].get("dtensor_spec", None) + + if arg_specs is None or parent_out_specs is None: + logger.warning( + f"Skipping edge {arg_obj} -> {node}.{arg_name} because dtensor_spec was not found" + ) + continue + + arg_placements = arg_specs.placements + parent_out_placements = ( + parent_out_specs[0].placements + if isinstance(parent_out_specs, (list, tuple)) + else parent_out_specs.placements + ) + + if arg_placements != parent_out_placements: + logger.info( + f"Inserting resharding node along edge {arg_obj} -> {node.name} because arg {arg_name} requires placement {arg_specs.placements} but parent node {arg_obj.name} has placement {parent_out_specs.placements}." + ) + + # Create resharding node + with mg.fx_graph.inserting_before(node): + resharding_node = mg.fx_graph.call_function( + redistribute_dtensor, + args=(arg_obj, arg_specs.placements), + kwargs={ + "async_op": False, + "input_tensor_mesh": device_mesh, + }, + ) + + resharding_node.meta["mase"] = MaseMetadata( + node=resharding_node, + model=mg.model, + ) + + # Update the current node's argument + # Node arg can be referenced in either node.args or node.kwargs so we + # infer which container to update based on the arg_idx value, which + # indexes the combined list of args and kwargs + if arg_idx < len(node.args): + updated_args = list(node.args) + updated_args[arg_idx] = resharding_node + node.args = tuple(updated_args) + else: + kwarg_idx = arg_idx - len(node.args) + arg_key = kwarg_keys[kwarg_idx] + kwarg_dict = {} + + # Reconstruct the node.kwargs dict since this is immutable + for key, value in node.kwargs.items(): + if key == arg_key: + kwarg_dict[key] = resharding_node + else: + kwarg_dict[key] = value + node.kwargs = kwarg_dict + + # Insert placement-type imports at the top of the generated code + def 
insert_imports(body): + return [ + "from torch.distributed._tensor.placement_types import Replicate, Shard, Partial; sum = 'sum' \n", + *body, + ] + + mg.fx_graph.on_generate_code(lambda _: insert_imports) + + # Check the model is valid + mg.fx_graph.lint() + mg.model.recompile() + + return mg, {} + + +def resharding_transform_pass(mg, pass_args={}): + return _insert_resharding_nodes(mg, pass_args) diff --git a/src/chop/passes/module/__init__.py b/src/chop/passes/module/__init__.py index 566b306a8..50a70ff46 100644 --- a/src/chop/passes/module/__init__.py +++ b/src/chop/passes/module/__init__.py @@ -1,5 +1,5 @@ -from .analysis import calculate_avg_bits_module_analysis_pass -from .transforms import quantize_module_transform_pass, resharding_transform_pass +from .analysis import calculate_avg_bits_module_analysis_pass, autosharding_module_analysis_pass +from .transforms import quantize_module_transform_pass ANALYSIS_PASSES = ["calculate_avg_bits_module_analysis_pass"] diff --git a/src/chop/passes/module/analysis/__init__.py b/src/chop/passes/module/analysis/__init__.py index b3b7d2ab1..6a50919c7 100644 --- a/src/chop/passes/module/analysis/__init__.py +++ b/src/chop/passes/module/analysis/__init__.py @@ -1 +1,3 @@ from .quantize import calculate_avg_bits_module_analysis_pass + +from .autosharding import autosharding_module_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py new file mode 100644 index 000000000..cd4d0bd33 --- /dev/null +++ b/src/chop/passes/module/analysis/autosharding.py @@ -0,0 +1,370 @@ +import torch +import torch.nn.functional as F + +import numpy as np +import cvxpy as cp +from copy import copy +from collections import OrderedDict + +import vllm +from vllm.attention import Attention as VllmAttention + +from chop.tools import get_logger +from chop.distributed.utils import rlog + +from .cost_modelling import ( + _get_compute_cost_from_layer, + 
_get_intra_op_comms_cost, + _get_resharding_cost_matrix, + _get_memory_cost_from_layer, +) + +VllmLinear = vllm.model_executor.layers.linear.LinearBase + +from vllm.model_executor.layers.layer_norm import LayerNormBase as VllmLayerNorm +from vllm.model_executor.layers.residual import ResidualBase as VllmResidual + +logger = get_logger(__name__) +logger.setLevel("WARNING") + +STRATEGY_MAP = OrderedDict( + { + VllmLinear: ( + "replicated", + "column", + "row", + "data", + ), + VllmAttention: ( + "replicated", + "head", + ), + VllmLayerNorm: ( + "replicated", + "data", + ), + VllmResidual: ( + "replicated", + "data", + ), + type(None): None, + } +) + + +def _get_output_shape_from_layer_type( + layer: torch.nn.Module, + data_size: int, +): + if isinstance(layer, VllmLinear): + size = torch.Size([data_size, layer.weight.shape[0]]) + elif isinstance(layer, VllmAttention): + size = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + elif isinstance(layer, VllmLayerNorm): + size = torch.Size([data_size, layer.normalized_shape[0]]) + elif isinstance(layer, VllmResidual): + size = torch.Size([data_size, 1]) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + + return tuple(size) + + +def _linearize_resharding_cost( + opt_var, + parent_opt_var, + resharding_costs, +): + # Flatten resharding matrix + resharding_costs = resharding_costs.flatten() + + # Formulate linearized variable for resharding cost + e_var = cp.Variable(resharding_costs.shape[0], boolean=True) + expr = e_var.T @ resharding_costs + constr = [ + cp.sum(e_var) == 1, + ] + + # Constraints s.t. 
e_var = outer(opt_var, in_opt_var) + indices = np.arange(e_var.shape[0]) + opt_indices, in_opt_indices = np.divmod(indices, parent_opt_var.shape[0]) + constr += [ + e_var <= opt_var[opt_indices], + e_var <= parent_opt_var[in_opt_indices], + e_var >= opt_var[opt_indices] + parent_opt_var[in_opt_indices] - 1, + ] + + return expr, constr + + +def _get_memory_constraint( + memory_constraint_terms: list, + self_rank: int, + pass_args: dict, +): + budget = pass_args.get("gpu_memory_budget", None) + + if budget is None: + raise ValueError("gpu_memory_budget is required for autosharding analysis") + + memory_available = torch.cuda.get_device_properties(self_rank).total_memory * budget + + mem_constr_expr = 0 + for i, (opt_var, mem_cost) in enumerate(memory_constraint_terms): + mem_constr_expr += mem_cost @ opt_var + + return [mem_constr_expr <= memory_available] + + +def _formulate_ilp( + model: torch.nn.Module, + pass_args: dict, +): + + self_rank = torch.distributed.get_rank() + data_size = pass_args.get("data_size", None) + + if data_size is None: + raise ValueError("data_size is required for autosharding analysis") + + module_list = [] + module_strategies = [] + last_residual = None + + # ILP variables + constr = [] + expr = 0 + megatron_soln = 0 + megatron_mem_cost = 0 + + bad_soln = 0 + bad_soln_memory_cost = 0 + + # List of tuples: (opt_var, memory_cost) + memory_constr_terms = [] + + for name, layer in model.named_modules(): + + # Skip non-leaf modules + if len(list(layer.children())) > 0: + continue + rlog(logger, self_rank, f"Parsing layer {layer.__class__.__name__}") + + # Check if matches with one of the supported layer types + for layer_type, layer_strategies in STRATEGY_MAP.items(): + if isinstance(layer, layer_type): + break + + if layer_type is None or layer_strategies is None: + continue + + layer.strategies = layer_strategies + + # Register layer and instantiate optimization variable + # ============================ + module_list.append(layer) + 
module_strategies.append(layer_strategies) + + opt_var = cp.Variable(len(layer_strategies), boolean=True) + setattr(layer, "opt_var", opt_var) + constr += [ + cp.sum(opt_var) == 1, + ] + + # Calculate Megatron solution for comparison + megatron_opt_var = np.zeros(len(layer_strategies)) + bad_soln_opt_var = np.zeros(len(layer_strategies)) + + if "attn.c_attn" in name: + megatron_opt_var[1] = 1 # column + bad_soln_opt_var[1] = 1 # column + elif "attn.attn" in name: + megatron_opt_var[1] = 1 # head + bad_soln_opt_var[1] = 1 # head + elif "attn.c_proj" in name: + megatron_opt_var[2] = 1 # row + bad_soln_opt_var[1] = 1 # column + elif "mlp.c_fc" in name: + megatron_opt_var[1] = 1 # column + bad_soln_opt_var[1] = 1 # column + elif "mlp.c_proj" in name: + megatron_opt_var[2] = 1 # row + bad_soln_opt_var[2] = 1 # column + elif "ln" in name: + megatron_opt_var[0] = 1 + elif "res" in name: + megatron_opt_var[0] = 1 + else: + raise ValueError(f"Unsupported layer name: {name}") + + setattr(layer, "megatron_opt_var", megatron_opt_var) + setattr(layer, "bad_soln_opt_var", bad_soln_opt_var) + + # Consider compute cost + # ============================ + compute_cost = _get_compute_cost_from_layer( + layer, + layer_strategies, + data_size=data_size, + benchmarking_device=self_rank, + ) + + # Consider intra operator comms cost + # ============================ + + comms_cost = _get_intra_op_comms_cost( + layer_strategies=tuple(layer_strategies), + output_shape=_get_output_shape_from_layer_type(layer, data_size), + benchmarking_device=self_rank, + ) + + expr += (compute_cost + comms_cost) @ opt_var + megatron_soln += (compute_cost + comms_cost) @ megatron_opt_var + bad_soln += (compute_cost + comms_cost) @ bad_soln_opt_var + + # Consider memory cost + # ============================ + + mem_cost = _get_memory_cost_from_layer( + layer, + layer_strategies, + benchmarking_device=self_rank, + ) + + memory_constr_terms.append((opt_var, mem_cost)) + megatron_mem_cost += mem_cost @ 
megatron_opt_var + + bad_soln_memory_cost += mem_cost @ bad_soln_opt_var + + # Consider resharding cost + # ============================ + + # Skip if no parent module + if len(module_list) <= 1: + continue + + parent_module = module_list[-2] + parent_strategies = module_strategies[-2] + logger.info( + f"Consider resharding cost between {parent_module.__class__.__name__} and {layer.__class__.__name__}" + ) + + parent_out_shape = _get_output_shape_from_layer_type( + parent_module, + data_size, + ) + + resharding_costs = _get_resharding_cost_matrix( + layer_strategies=layer_strategies, + parent_strategies=parent_strategies, + parent_out_shape=parent_out_shape, + benchmarking_device=self_rank, + ) + + resharding_term, resharding_constraints = _linearize_resharding_cost( + opt_var, + parent_module.opt_var, + resharding_costs, + ) + expr += resharding_term + constr += resharding_constraints + + # Add Megatron solution for comparison + megatron_resharding_term = ( + parent_module.megatron_opt_var @ resharding_costs @ megatron_opt_var + ) + megatron_soln += megatron_resharding_term + + bad_soln_resharding_term = ( + parent_module.bad_soln_opt_var @ resharding_costs @ bad_soln_opt_var + ) + bad_soln += bad_soln_resharding_term + + # Residual layers may have an additional resharding cost for the residual path + if isinstance(layer, VllmResidual) and last_residual is not None: + last_residual_shape = _get_output_shape_from_layer_type( + last_residual, + data_size, + ) + + resharding_costs = _get_resharding_cost_matrix( + layer_strategies=layer_strategies, + parent_strategies=last_residual.strategies, + parent_out_shape=last_residual_shape, + benchmarking_device=self_rank, + ) + + resharding_term, resharding_constraints = _linearize_resharding_cost( + opt_var, + last_residual.opt_var, + resharding_costs, + ) + expr += resharding_term + constr += resharding_constraints + + if isinstance(layer, VllmResidual): + last_residual = layer + + # After processing all layers, consider 
memory constraints + # ============================ + + mem_constr = _get_memory_constraint( + memory_constraint_terms=memory_constr_terms, + self_rank=self_rank, + pass_args=pass_args, + ) + constr += mem_constr + + return ( + cp.Problem(cp.Minimize(expr), constr), + (megatron_soln, megatron_mem_cost), + mem_constr, + ) + + +def _get_sharding_config(model): + sharding_config = {} + for layer in model.modules(): + + # Skip non-leaf modules + if len(list(layer.children())) > 0: + continue + + # Check if matches with one of the supported layer types + for layer_type, layer_strategies in STRATEGY_MAP.items(): + if isinstance(layer, layer_type): + break + + if layer_type is None or layer_strategies is None: + continue + + opt_var_value = layer.opt_var.value + strategy_idx = np.where(opt_var_value)[0][0] + strategy = layer_strategies[strategy_idx] + + sharding_config[layer.prefix] = strategy + + return sharding_config + + +def autosharding_module_analysis_pass(model, pass_args={}): + problem, megatron, mem_constr = _formulate_ilp(model, pass_args) + megatron_soln, megatron_mem_cost = megatron + problem.solve( + verbose=pass_args.get(f"debug", False), + scipy_options={ + "disp": pass_args.get(f"debug", False), + "time_limit": pass_args.get("time_limit", None), + "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, + }, + ) + + sharding_config = _get_sharding_config(model) + + memory_available = torch.cuda.get_device_properties( + torch.distributed.get_rank() + ).total_memory * pass_args.get("gpu_memory_budget") + + return model, { + "sharding_config": sharding_config, + } diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py new file mode 100644 index 000000000..ed9278039 --- /dev/null +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -0,0 +1,664 @@ +import os +import math +import gc +import numpy as np +from copy import copy +from functools import lru_cache + +import torch +from torch.nn import 
functional as F +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.multiprocessing import Queue, set_start_method + + +import vllm + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, + ColumnParallelLinear, + RowParallelLinear, + DataParallelLinear, +) +from vllm.model_executor.layers.layer_norm import ( + ReplicatedLayerNorm, + DataParallelLayerNorm, +) +from vllm.model_executor.layers.residual import ( + ReplicatedResidual, + DataParallelResidual, +) + +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + +from vllm.attention import Attention as VllmAttention +from vllm.model_executor.layers.layer_norm import LayerNormBase as VllmLayerNorm +from vllm.model_executor.layers.residual import ResidualBase as VllmResidual +VllmLinear = vllm.model_executor.layers.linear.LinearBase + + +# Utilities +# ================================ + + +def _linear_cls_from_config(config: str): + if config == "replicated": + return ReplicatedLinear + if config == "column": + return ColumnParallelLinear + if config == "row": + return RowParallelLinear + if config == "data": + return DataParallelLinear + + raise ValueError(f"Unknown linear config: {config}") + +def _layer_norm_cls_from_config(config: str): + if config == "replicated": + return ReplicatedLayerNorm + if config == "data": + return DataParallelLayerNorm + + raise ValueError(f"Unknown layer norm config: {config}") + +def _residual_cls_from_config(config: str): + if config == "replicated": + return ReplicatedResidual + if config == "data": + return DataParallelResidual + + raise ValueError(f"Unknown residual config: {config}") + +def _profile_op( + op: str, + fn: callable, + shape: tuple, + repeat: int, + warmup_iters: int, + benchmarking_device: int = 0, + extra_args: list = [], +): + """ + Profile op ``repeat`` times with ``warmup_iters`` warmup iterations. 
+ Generate random input tensors of shape ``shape`` and pass them to the function ``fn`` in each iteration. + + Args: + op (str): _description_ + fn (callable): _description_ + shape (tuple): _description_ + repeat (int): _description_ + warmup_iters (int): _description_ + benchmarking_device (int, optional): _description_. Defaults to 0. + extra_args (list, optional): _description_. Defaults to []. + + Returns: + _type_: _description_ + """ + start_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + end_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + + for idx in range(repeat): + if op in ["linear", "layer_norm"]: + input_ = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [input_] + extra_args + elif op == "attention": + local_query = torch.randn(shape).to(f"cuda:{benchmarking_device}") + local_key = torch.randn(shape).to(f"cuda:{benchmarking_device}") + local_value = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [ + local_query, + local_key, + local_value, + None, # benchmark without KV cache + ] + extra_args + elif op == "residual": + input_ = torch.randn(shape).to(f"cuda:{benchmarking_device}") + residual = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [input_, residual] + extra_args + elif op == "allreduce": + local_tensor = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [local_tensor] + elif op == "allgather": + local_tensor = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [local_tensor, -1] + else: + raise ValueError(f"Unknown op: {op}") + + start_event[idx].record() + out = fn(*args) + end_event[idx].record() + torch.cuda.synchronize(device=f"cuda:{benchmarking_device}") + + elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] + + return out, np.mean(elapsed[warmup_iters:]), elapsed + + +@lru_cache(maxsize=128, typed=False) +def allreduce_cost( + output_shape: tuple, + 
repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +) -> float: + _, cost, elapsed_times = _profile_op( + op="allreduce", + fn=tensor_model_parallel_all_reduce, + shape=output_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return cost + + +@lru_cache(maxsize=128, typed=False) +def allgather_cost( + local_shape: tuple, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +) -> float: + _, cost, elapsed_times = _profile_op( + op="allgather", + fn=tensor_model_parallel_all_gather, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return cost + + +# Compute cost +# ================================ + + +@lru_cache(maxsize=128, typed=False) +def _cached_linear_cost_from_local_shapes( + type: str, + data_size: int, + input_size: int, + output_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + cls = _linear_cls_from_config(type) + + layer = cls( + input_size=input_size, + output_size=output_size, + ) + + local_shape = (data_size, input_size) + if type == "data": + local_shape = (data_size // torch.distributed.get_world_size(), input_size) + elif type == "row": + local_shape = (data_size, input_size // torch.distributed.get_world_size()) + elif type in ["replicated", "column"]: + pass + else: + raise ValueError(f"Unknown type: {type}") + + _, elapsed, elapsed_list = _profile_op( + op="linear", + fn=layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return elapsed, elapsed_list + + +def _get_linear_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + input_size = layer.input_size + output_size = layer.output_size + + cost_vector = [] + for strategy in layer_strategies: + 
+ # Create local tensors + elapsed, elapsed_list = _cached_linear_cost_from_local_shapes( + type=strategy, + data_size=data_size, + input_size=input_size, + output_size=output_size, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + + +@lru_cache(maxsize=128, typed=False) +def _cached_attention_cost_from_local_shapes( + data_size: int, + num_heads: int, + head_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + local_shape = torch.Size([data_size, head_size * num_heads]) + + attn_meta = AttentionMetadata( + num_prefills=9, + num_prefill_tokens=data_size, + num_decode_tokens=0, + slot_mapping=None, + ) + attn_layer = VllmAttention( + num_heads=num_heads, + head_size=head_size, + scale=1.0, + ) + + _, elapsed, _ = _profile_op( + op="attention", + fn=attn_layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + extra_args=[attn_meta], + ) + + return elapsed + +def _get_attention_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + cost_vector = [] + for strategy in layer_strategies: + + if strategy == "replicated": + num_heads = layer.impl.num_heads + elif strategy == "head": + num_heads = layer.impl.num_heads // torch.distributed.get_world_size() + + else: + raise ValueError(f"Unknown strategy: {strategy}") + + elapsed = _cached_attention_cost_from_local_shapes( + data_size=data_size, + num_heads=num_heads, + head_size=layer.impl.head_size, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + +@lru_cache(maxsize=128, typed=False) +def _cached_layer_norm_cost_from_local_shapes( + type: str, + normalized_shape: tuple, + eps: float, + 
elementwise_affine: bool, + bias: bool, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + layer = _layer_norm_cls_from_config(type)( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + ) + + local_shape = torch.Size([data_size, normalized_shape[0]]) + + _, elapsed, elapsed_list = _profile_op( + op="layer_norm", + fn=layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return elapsed + +def _get_layer_norm_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + cost_vector = [] + for strategy in layer_strategies: + + if strategy == "replicated": + real_data_size = data_size + elif strategy == "data": + real_data_size = data_size // torch.distributed.get_world_size() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + elapsed = _cached_layer_norm_cost_from_local_shapes( + type=strategy, + normalized_shape=layer.normalized_shape, + eps=layer.eps, + elementwise_affine=layer.elementwise_affine, + bias=layer.bias is not None, + data_size=real_data_size, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + +@lru_cache(maxsize=128, typed=False) +def _cached_residual_cost_from_local_shapes( + type: str, + data_size: int, + hidden_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + cls = _residual_cls_from_config(type) + + layer = cls( + hidden_size=hidden_size, + ) + + local_shape = torch.Size([data_size, hidden_size]) + + _, elapsed, elapsed_list = _profile_op( + op="residual", + fn=layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return elapsed + +def 
_get_residual_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + cost_vector = [] + for strategy in layer_strategies: + + if strategy == "replicated": + real_data_size = data_size + elif strategy == "data": + real_data_size = data_size // torch.distributed.get_world_size() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + elapsed = _cached_residual_cost_from_local_shapes( + type=strategy, + data_size=real_data_size, + hidden_size=layer.hidden_size, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + + +def _get_compute_cost_from_layer( + layer, + layer_strategies, + data_size, + benchmarking_device: int = 0, +): + profile_kwargs = { + "layer": layer, + "layer_strategies": layer_strategies, + "data_size": data_size, + "benchmarking_device": benchmarking_device, + } + if isinstance(layer, VllmLinear): + return _get_linear_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmAttention): + return _get_attention_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmResidual): + return _get_residual_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmLayerNorm): + return _get_layer_norm_compute_cost(**profile_kwargs) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + + +@lru_cache(maxsize=128, typed=False) +def _get_intra_op_comms_cost( + layer_strategies: tuple, + output_shape: tuple, + benchmarking_device: int = 0, +): + comms_cost = np.zeros(len(layer_strategies)) + for idx, strategy in enumerate(layer_strategies): + if strategy == "row": + comms_cost[idx] = allreduce_cost( + output_shape=output_shape, + benchmarking_device=benchmarking_device, + ) + + return comms_cost + + +# Resharding cost +# ================================ + + +def _get_resharding_cost( + module_strategy: 
str, + parent_out_shape: tuple, + parent_strategy: str, + benchmarking_device: int = 0, +) -> float: + + # Strategies which always return RR sharding + if parent_strategy in ["replicated", "row"]: + return 0 + + world_size = torch.distributed.get_world_size() + + # all gather operation + skip_allgather = ( + ( + # Column parallel linear -> Row parallel linear (Megatron-LM) + parent_strategy == "column" + and module_strategy == "row" + ) + or ( + # Column parallel linear -> Head parallel attention (Megatron-LM) + parent_strategy == "column" + and module_strategy == "head" + ) + or ( + # Head parallel attention -> Row parallel linear (Megatron-LM) + parent_strategy == "head" + and module_strategy == "row" + ) + or ( + # Data parallel linear -> Data parallel linear + parent_strategy == "data" + and module_strategy == "data" + ) + ) + + if not skip_allgather: + local_shape = [parent_out_shape[0], parent_out_shape[1] // world_size] + cost = allgather_cost( + local_shape=tuple(local_shape), + benchmarking_device=benchmarking_device, + ) + else: + cost = 0 + + return cost + + +@lru_cache(maxsize=128, typed=False) +def _get_resharding_cost_matrix( + layer_strategies, + parent_strategies, + parent_out_shape, + benchmarking_device: int = 0, +): + + resharding_costs = np.zeros([len(parent_strategies), len(layer_strategies)]) + for module_strategy_idx, module_strategy in enumerate(layer_strategies): + for parent_strategy_idx, parent_strategy in enumerate(parent_strategies): + resharding_costs[parent_strategy_idx, module_strategy_idx] = ( + _get_resharding_cost( + module_strategy, + parent_out_shape, + parent_strategy, + benchmarking_device=benchmarking_device, + ) + ) + + return resharding_costs + + +# Memory cost +# ================================ + + +def _get_gpu_memory_usage(): + torch.cuda.synchronize() + return torch.cuda.memory_allocated() + + +@lru_cache(maxsize=128, typed=False) +def _cached_get_linear_memory_cost( + input_size: int, + output_size: int, + bias: bool, 
+ strategies: tuple, + benchmarking_device: int = 0, +) -> list: + cost_vector = [] + peak_mems = [] + for strategy in strategies: + cls = _linear_cls_from_config(strategy) + + # Clear cache and reset stats + torch.cuda.empty_cache() + gc.collect() + torch.cuda.reset_peak_memory_stats() + + # Instantiate layer to measure memory usage + start_memory = _get_gpu_memory_usage() + _ = cls( + input_size=input_size, + output_size=output_size, + bias=bias is not None, + ).to(f"cuda:{benchmarking_device}") + end_memory = _get_gpu_memory_usage() + + # Record cost + cost_vector.append(end_memory - start_memory) + peak_mems.append(torch.cuda.max_memory_allocated()) + + return cost_vector + +@lru_cache(maxsize=128, typed=False) +def _cached_get_layer_norm_memory_cost( + normalized_shape: tuple, + elementwise_affine: bool, + bias: bool, + strategies: tuple, + benchmarking_device: int = 0, +) -> list: + cls = _layer_norm_cls_from_config("replicated") + + # Clear cache and reset stats + torch.cuda.empty_cache() + gc.collect() + torch.cuda.reset_peak_memory_stats() + + # Instantiate layer to measure memory usage + start_memory = _get_gpu_memory_usage() + _ = cls( + normalized_shape=normalized_shape, + elementwise_affine=elementwise_affine, + bias=bias is not None, + ).to(f"cuda:{benchmarking_device}") + end_memory = _get_gpu_memory_usage() + + cost = end_memory - start_memory + + return [cost] * len(strategies) + +def _get_memory_cost_from_layer( + layer, + layer_strategies, + benchmarking_device: int = 0, +): + if isinstance(layer, VllmLinear): + return _cached_get_linear_memory_cost( + input_size=layer.input_size, + output_size=layer.output_size, + bias=layer.bias is not None, + strategies=tuple(layer_strategies), + benchmarking_device=benchmarking_device, + ) + elif isinstance(layer, VllmLayerNorm): + return _cached_get_layer_norm_memory_cost( + normalized_shape=layer.normalized_shape, + elementwise_affine=layer.elementwise_affine, + bias=layer.bias is not None, + 
strategies=tuple(layer_strategies), + benchmarking_device=benchmarking_device, + ) + elif isinstance(layer, (VllmAttention, VllmResidual)): + return np.zeros(len(layer_strategies)) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index efbb0ed14..3fcc8c5b3 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,3 +1 @@ -from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass -from .autosharding import resharding_transform_pass diff --git a/src/chop/passes/module/transforms/autosharding/__init__.py b/src/chop/passes/module/transforms/autosharding/__init__.py deleted file mode 100644 index 699587b80..000000000 --- a/src/chop/passes/module/transforms/autosharding/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .resharding import resharding_transform_pass diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py deleted file mode 100644 index 3646064dd..000000000 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ /dev/null @@ -1,26 +0,0 @@ -from chop.tools import get_logger - -logger = get_logger(__name__) -logger.setLevel("INFO") - - -def resharding_transform_pass(mg, pass_args={}): - """ - This pass inserts a wrapper around each module in the graph to handle resharding - activation tensors when the output of the previous module has a different sharding - profile to the one assigned to the current module. 
- """ - - module_map = pass_args.get("module_map", None) - device_mesh = pass_args.get("device_mesh", None) - if module_map is None or device_mesh is None: - raise ValueError( - "module_map and device_mesh are required for resharding_transform_pass" - ) - - for node in mg.fx_graph.nodes: - pass - - mg.model.recompile() - - return mg, {} diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index 967f0f2f6..82b7ddfbc 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -36,4 +36,6 @@ def __call__(self, mg: MaseGraph, pass_args: dict, skip_passes: list = []): mg, pass_output = pass_fn(mg, pass_args=args) self.pass_outputs[pass_fn.__name__] = pass_output + mg.model.recompile() + return mg, self.pass_outputs diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index bc06ebab5..cb0bc9278 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -1,12 +1,22 @@ +import torch.distributed as dist + import chop.passes as passes +from chop.tools import get_logger from .auto_pipeline import AutoPipeline +logger = get_logger(__name__) +logger.setLevel("INFO") + class AutoPipelineForDistributedInference(AutoPipeline): """This pipeline is used for distributed inference. 
- It runs the following passes: + It runs the following pre-processing passes: + + - replace_method_with_function + + Then, it raises the graph to Mase IR: - init_metadata_analysis_pass @@ -14,20 +24,47 @@ class AutoPipelineForDistributedInference(AutoPipeline): - add_common_metadata_analysis_pass + Then, it runs the following passes: + - autosharding_analysis_pass + If the distributed setup is initialized, it runs the following passes: + + - insert_dtensor_wrapper_transform_pass + - resharding_transform_pass + """ def __init__(self) -> None: """Initializes the AutoPipeline.""" + # Pre-processing pass_list = [ + passes.replace_method_with_function, + ] + + # Raise to Mase IR + pass_list += [ passes.init_metadata_analysis_pass, passes.report_graph_analysis_pass, passes.add_common_metadata_analysis_pass, + ] + + # Autosharding + pass_list += [ passes.autosharding_analysis_pass, - passes.resharding_transform_pass, ] + # Only run the following in distributed setup + if dist.is_initialized(): + pass_list += [ + passes.insert_dtensor_wrapper_transform_pass, + passes.resharding_transform_pass, + ] + else: + logger.info( + "Torch distributed is not initialized, so will skip the following passes: insert_dtensor_wrapper_transform_pass, resharding_transform_pass" + ) + super().__init__(pass_list) diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py deleted file mode 100644 index 97b4410d9..000000000 --- a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys, pdb, traceback -import pytest - -import torch -import torch.nn as nn - -from chop.ir import MaseGraph -from chop.distributed import MaseLauncher -import chop.passes as passes -from chop.tools import get_logger - -from transformers.models.bert import BertConfig, BertModel - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - -WORLD_SIZE = 8 -DEVICE_MESH = [[0, 1, 2, 
3], [4, 5, 6, 7]] - - -@pytest.mark.skip(reason="Fixing needed") -def test_autosharding(): - - # Define config - config = BertConfig() - config.num_hidden_layers = 3 - config.hidden_size = 96 - config.intermediate_size = 384 - config._attn_implementation = "eager" - config_sequence_length = 4 - - # Initialize model and MaseGraph - model = BertModel(config) - mg = MaseGraph(model) - mg, _ = passes.init_metadata_analysis_pass(mg) - mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "bert.txt"}) - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": torch.randint(0, 10, (1, config_sequence_length)), - }, - "add_value": False, - }, - ) - - # Run autosharding pass to decide sharding configuration - mg, module_map = passes.autosharding_analysis_pass( - mg, - pass_args={ - "mesh_shape": (2, 4), - "inter_node_bandwidth": 10e9, - "intra_node_bandwidth": 100e9, - }, - ) - - # Insert resharding wrappers around each module to handle inter-operator communication - mg, _ = passes.resharding_transform_pass( - mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH} - ) - - # dump print model to a file - with open("model.txt", "w") as f: - print(mg.model, file=f) - - # Launch model in distributed cluster - launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) - inputs = [torch.randint(0, 10, (1, config_sequence_length))] - launcher.run(module_map, inputs) - - -if __name__ == "__main__": - test_autosharding() diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_linear.py b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py deleted file mode 100644 index eae847aef..000000000 --- a/test/passes/graph/analysis/autosharding/test_autosharding_linear.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys, pdb, traceback, os -import pytest - -import torch -import torch.nn as nn - -from chop.ir import MaseGraph -from chop.distributed import MaseLauncher -import 
chop.passes as passes -from chop.tools import get_logger - - -def excepthook(exc_type, exc_value, exc_traceback): - traceback.print_exception(exc_type, exc_value, exc_traceback) - print("\nEntering debugger...") - pdb.post_mortem(exc_traceback) - - -# Set the custom exception hook -sys.excepthook = excepthook - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - -WORLD_SIZE = 8 -DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] - - -class MLP(nn.Module): - def __init__(self, in_features=64, hidden_dimension=128, out_features=64): - super().__init__() - self.l1 = nn.Linear(in_features, hidden_dimension) - self.l2 = nn.Linear(hidden_dimension, out_features) - - def forward(self, x): - out = self.l1(x) - return self.l2(out) - - -@pytest.mark.skip(reason="Fixing needed") -def test_autosharding(): - - # Initialize model and MaseGraph - model = MLP() - mg = MaseGraph(model) - mg, _ = passes.init_metadata_analysis_pass(mg) - mg, _ = passes.add_common_metadata_analysis_pass( - mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} - ) - - # Run autosharding pass to decide sharding configuration - mg, module_map = passes.autosharding_analysis_pass( - mg, - pass_args={ - "mesh_shape": (2, 4), - "inter_node_bandwidth": 10e9, - "intra_node_bandwidth": 100e9, - }, - ) - - # Insert resharding wrappers around each module to handle inter-operator communication - mg, _ = passes.resharding_transform_pass( - mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH} - ) - - # Launch model in distributed cluster - launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) - inputs = [torch.randn((16, 64))] - launcher.run(module_map, inputs) - - -if __name__ == "__main__": - test_autosharding() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py deleted file mode 100644 index d29239cc8..000000000 --- 
a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py +++ /dev/null @@ -1,157 +0,0 @@ -import sys, os - -import torch -import torch.nn as nn - -import pytest - -from transformers.activations import GELUActivation - -import chop.passes as passes -import chop.actions as actions -from chop.ir import MaseGraph -from chop.models.patched.llama import LlamaConfig, LlamaModel -from chop.models.patched.llama.modeling_llama import LlamaAttention -from chop.passes.graph.utils import deepsetattr - -# from chop.nn.quantized import LlamaAttentionInteger -from chop.tools import get_logger, set_excepthook - -from mase_components import get_module_dependencies -from mase_components.helper.generate_memory import generate_sv_lut - -import operator -from functools import partial - -logger = get_logger(__name__) -logger.setLevel("DEBUG") -set_excepthook() - -# * Define custom ops (leaf submodules during tracing) -# * This is useful so we can write a single optimised verilog file for self attention, -# * instead of relying on emit_verilog to instantiate each submodule -LLAMA_CUSTOM_OPS = { - "modules": {}, - "functions": {}, -} - - -def llama_module_level_quantize(model, model_config, q_config): - return model - - -def llama_update_metadata(mg, q_config): - """ - The following processing is a temporary hot fix to get emit verilog working on the llama model. We - update the type and precision for the add, getitem and split (fork) nodes which are currently - inserted in the patched model code. In the (near) future, inserting forking nodes and setting their - precision correctly will be handled automatedly as a preprocessing step for the emit verilog pass, - so this function will be unnecessary. 
- """ - return mg, {} - - -def emit_verilog_llama( - config, - q_config, - config_sequence_length, - wait_count=15, - wait_unit="ms", - max_parallelism=4, -): - # * Get model and quantize self attention, linear and layer norm layers - model = LlamaModel(config) - model = llama_module_level_quantize(model, config, q_config) - logger.info(f"Quantized Llama model: {model}") - - # * Trace the model - mg = MaseGraph(model, custom_ops=LLAMA_CUSTOM_OPS) - mg, _ = passes.init_metadata_analysis_pass(mg) - - mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "llama.txt"}) - - # * Add metadata analysis passes - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": torch.randn( - (1, config_sequence_length, config.hidden_size) - ) - }, - "add_value": False, - }, - ) - - mg, _ = llama_update_metadata(mg, q_config) - - mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, - ) - - # * Save the metadata to a file for debugging - mg, _ = passes.report_node_meta_param_analysis_pass( - mg, - pass_args={ - "which": ["common", "hardware"], - "save_path": "llama_graph_meta_params.txt", - }, - ) - - mg, _ = passes.emit_verilog_top_transform_pass(mg) - mg, _ = passes.emit_bram_transform_pass(mg) - mg, _ = passes.emit_internal_rtl_transform_pass(mg) - mg, _ = passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, - ) - mg, _ = passes.emit_vivado_project_transform_pass(mg) - - # Temporary: fix data coherency checks - os.environ["COCOTB_RESOLVE_X"] = "ZEROS" - - actions.simulate( - skip_build=False, skip_test=False, gui=False, waves=False, simulator="questa" - ) - - -def get_default_qconfig(): - return { - "data_in_width": 8, - "data_in_frac_width": 3, - "weight_width": 8, - "weight_frac_width": 3, - "bias_width": 8, - "bias_frac_width": 3, - "data_out_width": 8, - "data_out_frac_width": 3, - } - - 
-@pytest.mark.skip(reason="Not working") -def test_emit_verilog_llama_smoke(): - config = LlamaConfig() - # config.num_hidden_layers = 3 - # config.hidden_size = 96 - # config.intermediate_size = 384 - - # Make config match 7b model - config.max_position_embeddings = 4096 - config.rms_norm_eps = 1e-5 - config_sequence_length = 4 - - q_config = get_default_qconfig() - emit_verilog_llama( - config, q_config, config_sequence_length, wait_count=10, max_parallelism=2 - ) - - -if __name__ == "__main__": - generate_sv_lut("silu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) - test_emit_verilog_llama_smoke() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py deleted file mode 100644 index cdf1c9794..000000000 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py +++ /dev/null @@ -1,153 +0,0 @@ -import sys, os -import pytest - -import torch -import torch.nn as nn - -from transformers.activations import GELUActivation - -import chop.passes as passes -import chop.actions as actions -from chop.ir import MaseGraph -from chop.models.patched.mistral import MistralConfig, MistralModel -from chop.models.patched.mistral.modeling_mistral import MistralAttention -from chop.passes.graph.utils import deepsetattr - -# from chop.nn.quantized import MistralAttentionInteger -from chop.tools import get_logger, set_excepthook - -from mase_components import get_module_dependencies -from mase_components.helper.generate_memory import generate_sv_lut - -import operator -from functools import partial - -logger = get_logger(__name__) -logger.setLevel("DEBUG") -set_excepthook() - -# * Define custom ops (leaf submodules during tracing) -# * This is useful so we can write a single optimised verilog file for self attention, -# * instead of relying on emit_verilog to instantiate each submodule -MISTRAL_CUSTOM_OPS = { - "modules": {}, - "functions": {}, -} - - -def 
mistral_module_level_quantize(model, model_config, q_config): - return model - - -def mistral_update_metadata(mg, q_config): - """ - The following processing is a temporary hot fix to get emit verilog working on the mistral model. We - update the type and precision for the add, getitem and split (fork) nodes which are currently - inserted in the patched model code. In the (near) future, inserting forking nodes and setting their - precision correctly will be handled automatedly as a preprocessing step for the emit verilog pass, - so this function will be unnecessary. - """ - return mg, {} - - -def emit_verilog_mistral( - config, - q_config, - config_sequence_length, - wait_count=15, - wait_unit="ms", - max_parallelism=4, -): - # * Get model and quantize self attention, linear and layer norm layers - model = MistralModel(config) - model = mistral_module_level_quantize(model, config, q_config) - logger.info(f"Quantized mistral model: {model}") - - # * Trace the model - mg = MaseGraph(model, custom_ops=MISTRAL_CUSTOM_OPS) - mg, _ = passes.init_metadata_analysis_pass(mg) - - mg, _ = passes.report_graph_analysis_pass( - mg, pass_args={"file_name": "mistral.txt"} - ) - - # * Add metadata analysis passes - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": torch.randn( - (1, config_sequence_length, config.hidden_size) - ) - }, - "add_value": False, - }, - ) - - mg, _ = mistral_update_metadata(mg, q_config) - - mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, - ) - - # * Save the metadata to a file for debugging - mg, _ = passes.report_node_meta_param_analysis_pass( - mg, - pass_args={ - "which": ["common", "hardware"], - "save_path": "mistral_graph_meta_params.txt", - }, - ) - - mg, _ = passes.emit_verilog_top_transform_pass(mg) - mg, _ = passes.emit_bram_transform_pass(mg) - mg, _ = passes.emit_internal_rtl_transform_pass(mg) - mg, _ = 
passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, - ) - mg, _ = passes.emit_vivado_project_transform_pass(mg) - - # Temporary: fix data coherency checks - os.environ["COCOTB_RESOLVE_X"] = "ZEROS" - - actions.simulate( - skip_build=False, skip_test=False, gui=False, waves=False, simulator="questa" - ) - - -def get_default_qconfig(): - return { - "data_in_width": 8, - "data_in_frac_width": 3, - "weight_width": 8, - "weight_frac_width": 3, - "bias_width": 8, - "bias_frac_width": 3, - "data_out_width": 8, - "data_out_frac_width": 3, - } - - -@pytest.mark.skip(reason="Not working") -def test_emit_verilog_mistral_smoke(): - config = MistralConfig() - config.num_hidden_layers = 3 - config.hidden_size = 96 - config.intermediate_size = 384 - config_sequence_length = 4 - q_config = get_default_qconfig() - emit_verilog_mistral( - config, q_config, config_sequence_length, wait_count=10, max_parallelism=2 - ) - - -if __name__ == "__main__": - generate_sv_lut("silu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) - test_emit_verilog_mistral_smoke()