From b5bdad39459efaf6046dcf3cdde7db6289a0460b Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 4 Jun 2024 13:36:24 +0100 Subject: [PATCH 01/93] initial commit --- src/chop/passes/__init__.py | 2 ++ .../module/analysis/autosharding/__init__.py | 2 ++ .../analysis/autosharding/autosharding.py | 18 +++++++++++++ .../autosharding/test_autosharding.py | 26 +++++++++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 src/chop/passes/module/analysis/autosharding/__init__.py create mode 100644 src/chop/passes/module/analysis/autosharding/autosharding.py create mode 100644 test/passes/module/analysis/autosharding/test_autosharding.py diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index b6eec63df..1b4cf7b9c 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -42,3 +42,5 @@ from .onnx.analysis import ( export_fx_graph_analysis_pass, ) + +from .module.analysis.autosharding import autosharding_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/module/analysis/autosharding/__init__.py b/src/chop/passes/module/analysis/autosharding/__init__.py new file mode 100644 index 000000000..2e53f199b --- /dev/null +++ b/src/chop/passes/module/analysis/autosharding/__init__.py @@ -0,0 +1,2 @@ + +from .autosharding import autosharding_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/module/analysis/autosharding/autosharding.py b/src/chop/passes/module/analysis/autosharding/autosharding.py new file mode 100644 index 000000000..1776c27c7 --- /dev/null +++ b/src/chop/passes/module/analysis/autosharding/autosharding.py @@ -0,0 +1,18 @@ + + +def alpa_intra_op_sharding_pass(model): + + """ + A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 + """ + + for name, cls in model.named_children(): + print(cls) + + return model, {} + +def autosharding_analysis_pass(model): + + model = alpa_intra_op_sharding_pass(model) + + return model, {} \ 
No newline at end of file diff --git a/test/passes/module/analysis/autosharding/test_autosharding.py b/test/passes/module/analysis/autosharding/test_autosharding.py new file mode 100644 index 000000000..51a4b76aa --- /dev/null +++ b/test/passes/module/analysis/autosharding/test_autosharding.py @@ -0,0 +1,26 @@ + +import torch +import torch.nn as nn + +from chop.ir import MaseGraph +import chop.passes as passes + +class MLP(nn.Module): + def __init__(self, in_features = 64, hidden_dimension=512, out_features=64): + super().__init__() + self.l1 = nn.Linear(in_features = in_features, out_features=hidden_dimension) + self.l2 = nn.Linear(in_features = hidden_dimension, out_features=hidden_dimension) + self.l3 = nn.Linear(in_features = hidden_dimension, out_features=out_features) + + def forward(self, x): + out = self.l1(x) + out = self.l2(out) + out = self.l3(out) + return out + +def test_autosharding(): + model = MLP() + model, _ = passes.autosharding_analysis_pass(model) + +if __name__ == "__main__": + test_autosharding() \ No newline at end of file From b67d712076609a9fe12b8dbad99f4cbad4899ddd Mon Sep 17 00:00:00 2001 From: pgimenes Date: Thu, 6 Jun 2024 20:41:25 +0100 Subject: [PATCH 02/93] cvxpy bug for resharding cost term in intra operator pass --- src/chop/passes/__init__.py | 2 +- .../analysis/autosharding/__init__.py | 0 .../analysis/autosharding/autosharding.py | 113 ++++++++++++++++++ .../autosharding/autosharding_layers.py | 64 ++++++++++ .../analysis/autosharding/autosharding.py | 18 --- .../autosharding/test_autosharding.py | 42 +++++++ .../autosharding/test_autosharding.py | 26 ---- 7 files changed, 220 insertions(+), 45 deletions(-) rename src/chop/passes/{module => graph}/analysis/autosharding/__init__.py (100%) create mode 100644 src/chop/passes/graph/analysis/autosharding/autosharding.py create mode 100644 src/chop/passes/graph/analysis/autosharding/autosharding_layers.py delete mode 100644 src/chop/passes/module/analysis/autosharding/autosharding.py 
create mode 100644 test/passes/graph/analysis/autosharding/test_autosharding.py delete mode 100644 test/passes/module/analysis/autosharding/test_autosharding.py diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index 1b4cf7b9c..70e7246ec 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -43,4 +43,4 @@ export_fx_graph_analysis_pass, ) -from .module.analysis.autosharding import autosharding_analysis_pass \ No newline at end of file +from .graph.analysis.autosharding import autosharding_analysis_pass diff --git a/src/chop/passes/module/analysis/autosharding/__init__.py b/src/chop/passes/graph/analysis/autosharding/__init__.py similarity index 100% rename from src/chop/passes/module/analysis/autosharding/__init__.py rename to src/chop/passes/graph/analysis/autosharding/__init__.py diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py new file mode 100644 index 000000000..a7f09a718 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -0,0 +1,113 @@ +import numpy as np +import cvxpy as cp +import pulp + +from chop.tools import get_logger + +from .autosharding_layers import SHARDING_ALGOS, Shard + +logger = get_logger(__name__) + + +def alpa_intra_op_sharding_pass(mg): + """ + Intra-operator auto parallelization pass. 
+ """ + + # Setup for the ILP optimization + expr = 0 + constr = [] + variables = [] + + # Write cost vectors into metadata for each operator + # This will later be used to solve the ILP optimization + for node in mg.fx_graph.nodes: + + # Extract the target + if isinstance(node.target, str): + target = getattr(node.meta["mase"].model, node.target, None) + target_cls = type(target) + else: + target = node.target + + if target_cls in SHARDING_ALGOS.keys(): + # Enumerate shardings and costs for this operator + ( + input_shardings, + output_shardings, + compute_cost_vector, + communication_cost_vector, + ) = SHARDING_ALGOS[target_cls]() + + # Formulate optimization variables + num_shardings = len(input_shardings) + opt_var = cp.Variable(num_shardings, boolean=True) + variables.append(opt_var) + + # Constrain choice to be a onehot vector + constr += [ + cp.sum(opt_var) == 1, + ] + + # Consider compute and communication cost + cost_sum = np.array(compute_cost_vector) + np.array( + communication_cost_vector + ) + expr += variables[-1].T @ (cost_sum) + + # Consider resharding cost + for in_node in node.all_input_nodes: + in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + resharding_costs = np.random.randint( + 1, 10, size=(opt_var.shape + in_opt_var.shape) + ) + flattened_resharding_cost = np.matrix.flatten(resharding_costs) + + e_var = cp.Variable(opt_var.shape + in_opt_var.shape, boolean=True) + expr += cp.vec(e_var).T @ flattened_resharding_cost + + constr += [ + cp.sum(e_var) == 1, + # e_var == cp.vec(cp.outer(opt_var, in_opt_var)), + ] + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": input_shardings, + "valid_output_shardings": output_shardings, + "compute_cost_vector": compute_cost_vector, + "communication_cost_vector": communication_cost_vector, + "opt_var": opt_var, + } + + elif node.op == "placeholder": + # Inputs to the whole graph are always replicated across all devices + rank = 
len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": [(Shard.R,) * rank], + "valid_output_shardings": [(Shard.R,) * rank], + "compute_cost_vector": [0], + "communication_cost_vector": [0], + "opt_var": np.array([1]), + } + + else: + logger.warning(f"No sharding algorithm found for operator: {target_cls}") + + # Solve the ILP optimization + prob = cp.Problem(cp.Minimize(expr), constr) + prob.solve() + + breakpoint() + + return mg, {} + + +def autosharding_analysis_pass(mg): + """ + A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 + """ + + mg, _ = alpa_intra_op_sharding_pass(mg) + + return mg, {} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py b/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py new file mode 100644 index 000000000..51a122b32 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py @@ -0,0 +1,64 @@ +from enum import Enum +import torch.nn as nn +import itertools +import random + + +class Shard(Enum): + R = 1 + S_0 = 2 + S_1 = 3 + S_01 = 4 + + def __repr__(self): + return self.name + + +VALID_2D_TENSOR_SHARDINGS = [ + (Shard.R, Shard.R), + (Shard.R, Shard.S_0), + (Shard.R, Shard.S_1), + (Shard.R, Shard.S_01), + (Shard.S_0, Shard.R), + (Shard.S_0, Shard.S_1), + (Shard.S_1, Shard.R), + (Shard.S_1, Shard.S_0), + (Shard.S_01, Shard.R), +] + + +def get_valid_linear_shardings(): + """ + Return every valid combination of shardings for the input tensors. For an operator + sharding to be valid, the inner dimension must have the same sharding. + E.g. ((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. 
+ """ + input_shardings, output_shardings = [], [] + compute_cost_vector, communication_cost_vector = [], [] + + permutations = list(itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2)) + for p in permutations: + output_sharding = (p[0][0], p[1][1]) + if p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: + input_shardings.append(p) + output_shardings.append(output_sharding) + + compute_cost_vector.append(random.random()) + + # TO DO: derive communication cost from the sharding + communication_cost_vector.append(random.random()) + + for i, in_shard in enumerate(input_shardings): + print(f"Sharding {i}: {in_shard} -> {output_shardings[i]}") + + return ( + input_shardings, + output_shardings, + compute_cost_vector, + communication_cost_vector, + ) + + +SHARDING_ALGOS = { + nn.Linear: get_valid_linear_shardings, +} diff --git a/src/chop/passes/module/analysis/autosharding/autosharding.py b/src/chop/passes/module/analysis/autosharding/autosharding.py deleted file mode 100644 index 1776c27c7..000000000 --- a/src/chop/passes/module/analysis/autosharding/autosharding.py +++ /dev/null @@ -1,18 +0,0 @@ - - -def alpa_intra_op_sharding_pass(model): - - """ - A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 - """ - - for name, cls in model.named_children(): - print(cls) - - return model, {} - -def autosharding_analysis_pass(model): - - model = alpa_intra_op_sharding_pass(model) - - return model, {} \ No newline at end of file diff --git a/test/passes/graph/analysis/autosharding/test_autosharding.py b/test/passes/graph/analysis/autosharding/test_autosharding.py new file mode 100644 index 000000000..f8a3adc38 --- /dev/null +++ b/test/passes/graph/analysis/autosharding/test_autosharding.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn + +from chop.ir import MaseGraph +import chop.passes as passes + +import sys, pdb, traceback + + +def excepthook(exc_type, exc_value, exc_traceback): + 
traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + + +# Set the custom exception hook +sys.excepthook = excepthook + + +class MLP(nn.Module): + def __init__(self, in_features=64, hidden_dimension=128, out_features=64): + super().__init__() + self.l1 = nn.Linear(in_features, hidden_dimension) + self.l2 = nn.Linear(hidden_dimension, out_features) + + def forward(self, x): + out = self.l1(x) + return self.l2(out) + + +def test_autosharding(): + model = MLP() + mg = MaseGraph(model) + mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.add_common_metadata_analysis_pass( + mg, pass_args={"dummy_in": {"x": torch.randn((5, 16, 64))}, "add_value": False} + ) + mg, _ = passes.autosharding_analysis_pass(mg) + + +if __name__ == "__main__": + test_autosharding() diff --git a/test/passes/module/analysis/autosharding/test_autosharding.py b/test/passes/module/analysis/autosharding/test_autosharding.py deleted file mode 100644 index 51a4b76aa..000000000 --- a/test/passes/module/analysis/autosharding/test_autosharding.py +++ /dev/null @@ -1,26 +0,0 @@ - -import torch -import torch.nn as nn - -from chop.ir import MaseGraph -import chop.passes as passes - -class MLP(nn.Module): - def __init__(self, in_features = 64, hidden_dimension=512, out_features=64): - super().__init__() - self.l1 = nn.Linear(in_features = in_features, out_features=hidden_dimension) - self.l2 = nn.Linear(in_features = hidden_dimension, out_features=hidden_dimension) - self.l3 = nn.Linear(in_features = hidden_dimension, out_features=out_features) - - def forward(self, x): - out = self.l1(x) - out = self.l2(out) - out = self.l3(out) - return out - -def test_autosharding(): - model = MLP() - model, _ = passes.autosharding_analysis_pass(model) - -if __name__ == "__main__": - test_autosharding() \ No newline at end of file From 82b706e92d0fad3def961ca8cb9ee1622ca7cb2d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 11 
Jun 2024 16:20:09 +0000 Subject: [PATCH 03/93] update resharding cost model and ILP constraints --- .../analysis/autosharding/autosharding.py | 73 +++++++++----- .../autosharding/autosharding_layers.py | 38 ++----- .../graph/analysis/autosharding/common.py | 25 +++++ .../analysis/autosharding/cost_modelling.py | 99 +++++++++++++++++++ .../graph/analysis/autosharding/mesh.py | 53 ++++++++++ .../autosharding/test_autosharding.py | 10 +- 6 files changed, 243 insertions(+), 55 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/common.py create mode 100644 src/chop/passes/graph/analysis/autosharding/cost_modelling.py create mode 100644 src/chop/passes/graph/analysis/autosharding/mesh.py diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index a7f09a718..0e5d439ed 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,15 +1,18 @@ + +import torch import numpy as np import cvxpy as cp -import pulp from chop.tools import get_logger +from .mesh import Mesh from .autosharding_layers import SHARDING_ALGOS, Shard +from .cost_modelling import get_resharding_matrix logger = get_logger(__name__) -def alpa_intra_op_sharding_pass(mg): +def alpa_intra_op_sharding_pass(mg, mesh): """ Intra-operator auto parallelization pass. 
""" @@ -37,13 +40,23 @@ def alpa_intra_op_sharding_pass(mg): output_shardings, compute_cost_vector, communication_cost_vector, - ) = SHARDING_ALGOS[target_cls]() + ) = SHARDING_ALGOS[target_cls](node.meta, mesh) + # Formulate optimization variables num_shardings = len(input_shardings) opt_var = cp.Variable(num_shardings, boolean=True) variables.append(opt_var) + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": input_shardings, + "valid_output_shardings": output_shardings, + "compute_cost_vector": compute_cost_vector, + "communication_cost_vector": communication_cost_vector, + "opt_var": opt_var, + } + # Constrain choice to be a onehot vector constr += [ cp.sum(opt_var) == 1, @@ -58,29 +71,30 @@ def alpa_intra_op_sharding_pass(mg): # Consider resharding cost for in_node in node.all_input_nodes: in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - resharding_costs = np.random.randint( - 1, 10, size=(opt_var.shape + in_opt_var.shape) - ) - flattened_resharding_cost = np.matrix.flatten(resharding_costs) - e_var = cp.Variable(opt_var.shape + in_opt_var.shape, boolean=True) - expr += cp.vec(e_var).T @ flattened_resharding_cost + resharding_costs = get_resharding_matrix( + mesh, + src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], + dest_shardings = [sharding[0] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], + dest_node_meta = node.meta["mase"] + ).flatten() + + e_var = cp.Variable(opt_var.shape[0] * in_opt_var.shape[0], boolean=True) + expr += e_var.T @ resharding_costs constr += [ cp.sum(e_var) == 1, - # e_var == cp.vec(cp.outer(opt_var, in_opt_var)), ] - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": input_shardings, - "valid_output_shardings": output_shardings, - "compute_cost_vector": compute_cost_vector, - "communication_cost_vector": 
communication_cost_vector, - "opt_var": opt_var, - } + # Scalar construction of the inequality constraints for the linearized variable + for i in range(e_var.shape[0]): + constr += [ + e_var[i] <= opt_var[i // in_opt_var.shape[0]], + e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 + ] - elif node.op == "placeholder": + elif node.op == "placeholder" or node.op == "output": # Inputs to the whole graph are always replicated across all devices rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) node.meta["mase"]["software"]["autosharding"] = { @@ -98,16 +112,29 @@ def alpa_intra_op_sharding_pass(mg): prob = cp.Problem(cp.Minimize(expr), constr) prob.solve() - breakpoint() - return mg, {} -def autosharding_analysis_pass(mg): +def autosharding_analysis_pass(mg, pass_args: dict = {}): """ A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 """ - mg, _ = alpa_intra_op_sharding_pass(mg) + assert "mesh_shape" in pass_args, "Logical description for device cluster was not specified." 
+ assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" + assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" + + # Initialize representation of device mesh, used for cost estimation + mesh = Mesh(pass_args["mesh_shape"]) + + # Communication cost model depends + mesh.set_cost_model_parameters( + intra_node_bandwidth=pass_args["intra_node_bandwidth"], + inter_node_bandwidth=pass_args["inter_node_bandwidth"], + backend = pass_args.get("communications_backend", "default") + ) + + # Run intra-operator pass + mg, _ = alpa_intra_op_sharding_pass(mg, mesh) return mg, {} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py b/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py index 51a122b32..c25cd2a67 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py @@ -1,33 +1,11 @@ -from enum import Enum -import torch.nn as nn import itertools -import random - - -class Shard(Enum): - R = 1 - S_0 = 2 - S_1 = 3 - S_01 = 4 - def __repr__(self): - return self.name - - -VALID_2D_TENSOR_SHARDINGS = [ - (Shard.R, Shard.R), - (Shard.R, Shard.S_0), - (Shard.R, Shard.S_1), - (Shard.R, Shard.S_01), - (Shard.S_0, Shard.R), - (Shard.S_0, Shard.S_1), - (Shard.S_1, Shard.R), - (Shard.S_1, Shard.S_0), - (Shard.S_01, Shard.R), -] +import torch.nn as nn +from .common import Shard, VALID_2D_TENSOR_SHARDINGS +from .cost_modelling import get_communication_cost -def get_valid_linear_shardings(): +def get_valid_linear_shardings(node_meta, mesh): """ Return every valid combination of shardings for the input tensors. For an operator sharding to be valid, the inner dimension must have the same sharding. 
@@ -39,14 +17,14 @@ def get_valid_linear_shardings(): permutations = list(itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2)) for p in permutations: output_sharding = (p[0][0], p[1][1]) - if p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: + if p != ((Shard.R, Shard.R), (Shard.R, Shard.R)) and p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: input_shardings.append(p) output_shardings.append(output_sharding) - compute_cost_vector.append(random.random()) + compute_cost_vector.append(0) - # TO DO: derive communication cost from the sharding - communication_cost_vector.append(random.random()) + cost = get_communication_cost(p, node_meta["mase"], mesh) + communication_cost_vector.append(cost) for i, in_shard in enumerate(input_shardings): print(f"Sharding {i}: {in_shard} -> {output_shardings[i]}") diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py new file mode 100644 index 000000000..2148b2a81 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/common.py @@ -0,0 +1,25 @@ +from enum import Enum + +class Shard(Enum): + S_0 = 0 + S_1 = 1 + R = 3 + + def __repr__(self): + return self.name + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + +VALID_2D_TENSOR_SHARDINGS = [ + (Shard.R, Shard.R), + (Shard.R, Shard.S_0), + (Shard.R, Shard.S_1), + (Shard.S_0, Shard.R), + (Shard.S_0, Shard.S_1), + (Shard.S_1, Shard.R), + (Shard.S_1, Shard.S_0), +] \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/cost_modelling.py new file mode 100644 index 000000000..bbb38a989 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/cost_modelling.py @@ -0,0 +1,99 @@ + +import numpy as np +from functools import lru_cache + +from chop.ir.graph import MaseMetadata + +from .common import 
Shard +from .mesh import Mesh + +BYTES_PER_ELEMENT = 4 + +def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: Mesh): + assert sharding[0][1] == sharding[1][0], f"Inconsistent sharding for node: {node_meta.node}" + inner_dim_sharding = sharding[1][0] + + out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] + + if inner_dim_sharding == Shard.R: + return 0 + + else: + ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 + return mesh.all_reduce_cost(num_bytes = BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim = ar_dim) + +@lru_cache(maxsize=None) +def get_resharding_cost(mesh: Mesh, src: tuple, dest: tuple, dest_node_meta: MaseMetadata): + """ + Obtain the resharding cost given a source and destination sharding profile for a tensor. + The mesh object is assumed to have been initialized with alpha, beta parameters so that + the communication cost can be estimated for each MPI operator. + """ + + + # If original sharding is fully replicated, no resharding is required + if src == dest or src == (Shard.R, Shard.R): + return 0 + + num_bytes = BYTES_PER_ELEMENT * np.prod(dest_node_meta["common"]["args"]["data_in_0"]["shape"]) + + # No cost (simple split along given mesh dimension) + if ( + # Keep dim 0, split dim 1 + # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) + (src[0] == dest[0]) and (src[1] == Shard.R) and (dest[1] in [Shard.S_0, Shard.S_1]) + # Split dim 0, keep dim 1 + # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) + or (src[1] == dest[1]) and (src[0] == Shard.R) and (dest[0] in [Shard.S_0, Shard.S_1]) + ): + return 0 + + # Split -> Replicate (All Gather) + elif ( + # Keep dim 0, gather along dim 1 + # E.g. (S_1, S_0) -> (S_1, R) + (src[0] == dest[0]) and (src[1] in [Shard.S_0, Shard.S_1]) and (dest[1] == Shard.R) + # Gather along dim 0, keep dim 1 + # E.g. 
(S_0, S_1) -> (R, S_1) + or (src[1] == dest[1]) and (src[0] in [Shard.S_0, Shard.S_1]) and (dest[0] == Shard.R) + ): + ag_dim = 1 if src[0] == dest[0] else 0 + return mesh.all_gather_cost( + num_bytes = num_bytes, + mesh_dim = ag_dim, + ) + + # All-to-all + # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) + elif (src[0] == dest[1] and src[1] == dest[0] and (Shard.R in src)): + # all to all + a2a_dim = src[0].value if src[0] != Shard.R else src[1].value + return mesh.all_to_all_cost( + num_bytes = num_bytes, + mesh_dim = a2a_dim, + ) + + # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, + # must first gather along the first non-replicated dimension, then recursively compute the cost for the + # reduced sharding + else: + # Reduce one dimension and re-compute + if (src[0] != Shard.R): + new_src = (Shard.R, src[1]) + ag_dim = src[0].value + else: + new_src = (Shard.R, Shard.R) + ag_dim = src[1].value + + return mesh.all_gather_cost( + num_bytes = num_bytes, + mesh_dim = ag_dim + ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) + +def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): + mat = np.zeros((len(dest_shardings), len(src_shardings))) + for src_idx, src in enumerate(src_shardings): + for dest_idx, dest in enumerate(dest_shardings): + mat[dest_idx, src_idx] = get_resharding_cost(mesh, src, dest, dest_node_meta) + + return mat diff --git a/src/chop/passes/graph/analysis/autosharding/mesh.py b/src/chop/passes/graph/analysis/autosharding/mesh.py new file mode 100644 index 000000000..f30804715 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/mesh.py @@ -0,0 +1,53 @@ +import torch +import numpy as np + +class Mesh(): + def __init__(self, mesh_shape, mesh_alpha = None, mesh_beta = None): + self.mesh_shape = mesh_shape + + num_devices = np.prod(mesh_shape) + self.id_mesh = torch.arange(0, num_devices).reshape(mesh_shape) + + # Alpha/beta model is used to 
estimate communication cost between devices + self.mesh_alpha = [0] * 2 if mesh_alpha is None else mesh_alpha + self.mesh_beta = [None] * 2 if mesh_beta is None else mesh_beta + + def set_cost_model_parameters(self, intra_node_bandwidth: int, inter_node_bandwidth:int, backend:str = "default"): + # Assign differently depending if backend is NVLink, Infiniband, etc + if (backend == "default"): + # Assuming a setup with ethernet-connected nodes and devices connected through + # PCIe within each node + self.mesh_beta = [ + 1 / inter_node_bandwidth, + 1 / intra_node_bandwidth + ] + + def all_gather_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices * num_bytes + 0.1) + + def all_reduce_cost(self, num_bytes, mesh_dim, num_devices = None): + """ + The term multiplied by beta represents the total number of bytes + transferred over the full transaction. For the ring implementation + of all reduce there are 2 rounds of (n-1) transfers, hence 2(n-1). + In each case num_bytes/num_devices is transferred, where num_bytes + is the number of bytes for the full tensor on each device. 
+ """ + if num_devices is None: + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * + (num_devices - 1) / num_devices * num_bytes + 0.01) + + def reduce_scatter_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices * num_bytes + 0.001) + + def all_to_all_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + penalty_factor = num_devices / 2.0 + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices / num_devices * num_bytes * + penalty_factor + 0.001) \ No newline at end of file diff --git a/test/passes/graph/analysis/autosharding/test_autosharding.py b/test/passes/graph/analysis/autosharding/test_autosharding.py index f8a3adc38..3f207018b 100644 --- a/test/passes/graph/analysis/autosharding/test_autosharding.py +++ b/test/passes/graph/analysis/autosharding/test_autosharding.py @@ -33,9 +33,15 @@ def test_autosharding(): mg = MaseGraph(model) mg, _ = passes.init_metadata_analysis_pass(mg) mg, _ = passes.add_common_metadata_analysis_pass( - mg, pass_args={"dummy_in": {"x": torch.randn((5, 16, 64))}, "add_value": False} + mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} ) - mg, _ = passes.autosharding_analysis_pass(mg) + mg, _ = passes.autosharding_analysis_pass( + mg, + pass_args = { + "mesh_shape": (2, 4), + "inter_node_bandwidth": 10e9, + "intra_node_bandwidth": 100e9 + }) if __name__ == "__main__": From e1cd29afe0e04badae2a0423ce4621f902f6aa06 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 11 Jun 2024 17:09:10 +0000 Subject: [PATCH 04/93] some refactoring --- .../graph/analysis/autosharding/alpa.py | 112 +++++++++++++++++ ...st_modelling.py => alpa_cost_modelling.py} | 6 +- ...{autosharding_layers.py => alpa_layers.py} | 17 ++- .../analysis/autosharding/autosharding.py | 119 
++---------------- .../autosharding/{mesh.py => mesh_model.py} | 2 +- 5 files changed, 135 insertions(+), 121 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/alpa.py rename src/chop/passes/graph/analysis/autosharding/{cost_modelling.py => alpa_cost_modelling.py} (95%) rename src/chop/passes/graph/analysis/autosharding/{autosharding_layers.py => alpa_layers.py} (76%) rename src/chop/passes/graph/analysis/autosharding/{mesh.py => mesh_model.py} (99%) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py new file mode 100644 index 000000000..36455f889 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -0,0 +1,112 @@ +import numpy as np +import cvxpy as cp + +from chop.tools import get_logger + +from .common import Shard +from .alpa_layers import ALPA_LAYERS +from .alpa_cost_modelling import get_resharding_matrix + +logger = get_logger(__name__) + +def alpa_intra_op_sharding_pass(mg, mesh): + """ + Intra-operator auto parallelization pass. 
+ """ + + # Setup for the ILP optimization + expr = 0 + constr = [] + + # Write cost vectors into metadata for each operator + # This will later be used to solve the ILP optimization + for node in mg.fx_graph.nodes: + + # Extract the target + if isinstance(node.target, str): + target = getattr(node.meta["mase"].model, node.target, None) + target_cls = type(target) + else: + target = node.target + + if target_cls in ALPA_LAYERS.keys(): + # Enumerate shardings and costs for this operator + ( + input_shardings, + output_shardings, + compute_cost_vector, + communication_cost_vector, + ) = ALPA_LAYERS[target_cls](node.meta, mesh) + + # Formulate optimization variable and consider compute/communication cost + opt_var = cp.Variable(len(input_shardings), boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + expr += opt_var.T @ (compute_cost_vector + communication_cost_vector) + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": input_shardings, + "valid_output_shardings": output_shardings, + "compute_cost_vector": compute_cost_vector, + "communication_cost_vector": communication_cost_vector, + "opt_var": opt_var, + } + + # Consider resharding cost + for in_node in node.all_input_nodes: + in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + + resharding_costs = get_resharding_matrix( + mesh, + src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], + dest_shardings = [sharding[0] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], + dest_node_meta = node.meta["mase"] + ).flatten() + + # Formulate resharding cost term with linearized variable + e_var = cp.Variable(opt_var.shape[0] * in_opt_var.shape[0], boolean=True) + expr += e_var.T @ resharding_costs + constr += [ + cp.sum(e_var) == 1, + ] + + # Scalar construction of the inequality constraints for the linearized variable + for i in range(e_var.shape[0]): + constr += [ + 
e_var[i] <= opt_var[i // in_opt_var.shape[0]], + e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 + ] + + # Inputs to the whole graph are always replicated across all devices + elif node.op == "placeholder" or node.op == "output": + rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": [(Shard.R,) * rank], + "valid_output_shardings": [(Shard.R,) * rank], + "compute_cost_vector": [0], + "communication_cost_vector": [0], + "opt_var": np.array([1]), + } + + else: + logger.warning(f"No sharding algorithm found for operator: {target_cls}") + + # Solve the ILP problem + prob = cp.Problem(cp.Minimize(expr), constr) + prob.solve() + + for node in mg.fx_graph.nodes: + if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray): + chosen_idx = 0 + else: + chosen_idx = np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] + node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] + node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] + + return mg, {} + +def alpa_autosharding_pass(mg, mesh): + return alpa_intra_op_sharding_pass(mg, mesh) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py similarity index 95% rename from src/chop/passes/graph/analysis/autosharding/cost_modelling.py rename to src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py index bbb38a989..2294e43d7 100644 --- a/src/chop/passes/graph/analysis/autosharding/cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ 
-5,11 +5,11 @@ from chop.ir.graph import MaseMetadata from .common import Shard -from .mesh import Mesh +from .mesh_model import MeshModel BYTES_PER_ELEMENT = 4 -def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: Mesh): +def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): assert sharding[0][1] == sharding[1][0], f"Inconsistent sharding for node: {node_meta.node}" inner_dim_sharding = sharding[1][0] @@ -23,7 +23,7 @@ def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: Mesh) return mesh.all_reduce_cost(num_bytes = BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim = ar_dim) @lru_cache(maxsize=None) -def get_resharding_cost(mesh: Mesh, src: tuple, dest: tuple, dest_node_meta: MaseMetadata): +def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata): """ Obtain the resharding cost given a source and destination sharding profile for a tensor. The mesh object is assumed to have been initialized with alpha, beta parameters so that diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py similarity index 76% rename from src/chop/passes/graph/analysis/autosharding/autosharding_layers.py rename to src/chop/passes/graph/analysis/autosharding/alpa_layers.py index c25cd2a67..6733bea78 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -1,9 +1,13 @@ import itertools - +import numpy as np import torch.nn as nn +from chop.tools import get_logger + from .common import Shard, VALID_2D_TENSOR_SHARDINGS -from .cost_modelling import get_communication_cost +from .alpa_cost_modelling import get_communication_cost + +logger = get_logger(__name__) def get_valid_linear_shardings(node_meta, mesh): """ @@ -26,17 +30,18 @@ def get_valid_linear_shardings(node_meta, mesh): cost = 
get_communication_cost(p, node_meta["mase"], mesh) communication_cost_vector.append(cost) + logger.debug(f"Valid shardings for linear layer:") for i, in_shard in enumerate(input_shardings): - print(f"Sharding {i}: {in_shard} -> {output_shardings[i]}") + logger.debug(f"Sharding {i}: {in_shard} -> {output_shardings[i]}") return ( input_shardings, output_shardings, - compute_cost_vector, - communication_cost_vector, + np.array(compute_cost_vector), + np.array(communication_cost_vector), ) -SHARDING_ALGOS = { +ALPA_LAYERS = { nn.Linear: get_valid_linear_shardings, } diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 0e5d439ed..ea2c44243 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,120 +1,14 @@ -import torch import numpy as np import cvxpy as cp from chop.tools import get_logger -from .mesh import Mesh -from .autosharding_layers import SHARDING_ALGOS, Shard -from .cost_modelling import get_resharding_matrix +from .mesh_model import MeshModel +from .alpa import alpa_autosharding_pass logger = get_logger(__name__) - -def alpa_intra_op_sharding_pass(mg, mesh): - """ - Intra-operator auto parallelization pass. 
- """ - - # Setup for the ILP optimization - expr = 0 - constr = [] - variables = [] - - # Write cost vectors into metadata for each operator - # This will later be used to solve the ILP optimization - for node in mg.fx_graph.nodes: - - # Extract the target - if isinstance(node.target, str): - target = getattr(node.meta["mase"].model, node.target, None) - target_cls = type(target) - else: - target = node.target - - if target_cls in SHARDING_ALGOS.keys(): - # Enumerate shardings and costs for this operator - ( - input_shardings, - output_shardings, - compute_cost_vector, - communication_cost_vector, - ) = SHARDING_ALGOS[target_cls](node.meta, mesh) - - - # Formulate optimization variables - num_shardings = len(input_shardings) - opt_var = cp.Variable(num_shardings, boolean=True) - variables.append(opt_var) - - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": input_shardings, - "valid_output_shardings": output_shardings, - "compute_cost_vector": compute_cost_vector, - "communication_cost_vector": communication_cost_vector, - "opt_var": opt_var, - } - - # Constrain choice to be a onehot vector - constr += [ - cp.sum(opt_var) == 1, - ] - - # Consider compute and communication cost - cost_sum = np.array(compute_cost_vector) + np.array( - communication_cost_vector - ) - expr += variables[-1].T @ (cost_sum) - - # Consider resharding cost - for in_node in node.all_input_nodes: - in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - - resharding_costs = get_resharding_matrix( - mesh, - src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], - dest_shardings = [sharding[0] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], - dest_node_meta = node.meta["mase"] - ).flatten() - - e_var = cp.Variable(opt_var.shape[0] * in_opt_var.shape[0], boolean=True) - expr += e_var.T @ resharding_costs - - constr += [ - cp.sum(e_var) == 1, - ] - - 
# Scalar construction of the inequality constraints for the linearized variable - for i in range(e_var.shape[0]): - constr += [ - e_var[i] <= opt_var[i // in_opt_var.shape[0]], - e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 - ] - - elif node.op == "placeholder" or node.op == "output": - # Inputs to the whole graph are always replicated across all devices - rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": [(Shard.R,) * rank], - "valid_output_shardings": [(Shard.R,) * rank], - "compute_cost_vector": [0], - "communication_cost_vector": [0], - "opt_var": np.array([1]), - } - - else: - logger.warning(f"No sharding algorithm found for operator: {target_cls}") - - # Solve the ILP optimization - prob = cp.Problem(cp.Minimize(expr), constr) - prob.solve() - - return mg, {} - - def autosharding_analysis_pass(mg, pass_args: dict = {}): """ A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 @@ -124,8 +18,10 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" - # Initialize representation of device mesh, used for cost estimation - mesh = Mesh(pass_args["mesh_shape"]) + # Initialize device mesh model, used for cost estimation + mesh = MeshModel(pass_args["mesh_shape"]) + + algo = pass_args.get("sharding_algo", "alpa") # Communication cost model depends mesh.set_cost_model_parameters( @@ -135,6 +31,7 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): ) # Run intra-operator pass - mg, _ = alpa_intra_op_sharding_pass(mg, mesh) + if algo == "alpa": + mg, _ = alpa_autosharding_pass(mg, mesh) return mg, {} diff --git 
a/src/chop/passes/graph/analysis/autosharding/mesh.py b/src/chop/passes/graph/analysis/autosharding/mesh_model.py similarity index 99% rename from src/chop/passes/graph/analysis/autosharding/mesh.py rename to src/chop/passes/graph/analysis/autosharding/mesh_model.py index f30804715..79aa93ce9 100644 --- a/src/chop/passes/graph/analysis/autosharding/mesh.py +++ b/src/chop/passes/graph/analysis/autosharding/mesh_model.py @@ -1,7 +1,7 @@ import torch import numpy as np -class Mesh(): +class MeshModel(): def __init__(self, mesh_shape, mesh_alpha = None, mesh_beta = None): self.mesh_shape = mesh_shape From b3ecaacaab63298168df92db5e00916cff8e3bee Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 11 Jun 2024 23:49:38 +0000 Subject: [PATCH 05/93] attach to runtime --- src/chop/distributed/__init__.py | 2 + src/chop/distributed/launcher.py | 84 +++++++++++++++++++ src/chop/distributed/utils.py | 22 +++++ .../graph/analysis/autosharding/alpa.py | 41 +++++---- .../autosharding/alpa_cost_modelling.py | 25 +++--- .../analysis/autosharding/alpa_layers.py | 4 +- .../analysis/autosharding/autosharding.py | 4 +- .../graph/analysis/autosharding/common.py | 16 ++-- src/chop/tools/__init__.py | 2 + src/chop/tools/utils.py | 8 ++ .../autosharding/test_autosharding.py | 8 +- 11 files changed, 175 insertions(+), 41 deletions(-) create mode 100644 src/chop/distributed/__init__.py create mode 100644 src/chop/distributed/launcher.py create mode 100644 src/chop/distributed/utils.py diff --git a/src/chop/distributed/__init__.py b/src/chop/distributed/__init__.py new file mode 100644 index 000000000..cc807c9ad --- /dev/null +++ b/src/chop/distributed/__init__.py @@ -0,0 +1,2 @@ + +from .launcher import MaseLauncher \ No newline at end of file diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py new file mode 100644 index 000000000..7bf1707d6 --- /dev/null +++ b/src/chop/distributed/launcher.py @@ -0,0 +1,84 @@ +import os +from functools import partial + 
+import torch +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp + +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + Replicate, + Shard, +) + +from chop.tools import deepsetattr, get_logger + +from .utils import placement_from_sharding_config + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def rlog(rank, msg): + """ + Only log on rank 0 to avoid repeated messages. + """ + if rank == 0: + logger.info(msg) + +def dist_model_fn( + name: str, module: nn.Module, device_mesh: DeviceMesh, rank: int, module_map={} +) -> None: + """ + This function gets called by torch.distributed._tensor.distribute_module on each module in the model. + Each tensor in each module is distributed according to the sharding configuration in module_map. + """ + rlog(rank, f"Processing module {module}") + if module in module_map: + for parameter, sharding_config in module_map[module].items(): + rlog(rank, f" Parameter: {parameter} has sharding config {sharding_config}") + if not hasattr(module, parameter): + rlog(rank, f" Module does not have parameter {parameter}") + continue + placement = placement_from_sharding_config(sharding_config) + rlog(rank, f" Placement: {placement}") + deepsetattr(module, parameter, torch.nn.Parameter(distribute_tensor(getattr(module, parameter), device_mesh, placement))) + + +def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inputs=[]): + """ + This function gets called on each GPU device to set up the distributed environment and distribute the model, + following the SPMD model. 
+ """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + + mesh = DeviceMesh("cuda", mesh=device_mesh) + model = distribute_module( + model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None + ) + + # TO DO: read from module_map with keys matching forward function signature + inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] + out = model(*inputs) + + # TO DO: how to return output? + + rlog(rank, f"Module distribution done.") + + dist.destroy_process_group() + +class MaseLauncher(): + def __init__(self, mase_graph, world_size = None, device_mesh=None): + self.mg = mase_graph + self.model = mase_graph.model + self.world_size = world_size + self.device_mesh = device_mesh + + def run(self, module_map = {}, inputs=[]): + mp.spawn(partial(device_fn, model=self.model, device_mesh=self.device_mesh, module_map=module_map, inputs=inputs), args=(self.world_size,), nprocs=self.world_size, join=True) \ No newline at end of file diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py new file mode 100644 index 000000000..86ef538d0 --- /dev/null +++ b/src/chop/distributed/utils.py @@ -0,0 +1,22 @@ + +from torch.distributed._tensor import ( + Replicate, + Shard, +) + +from chop.passes.graph.analysis.autosharding.common import SpmdShard + +def placement_from_sharding_config(sharding_config): + """ + Sharding config is given as a tuple such as (R, S_0) where a symbol S_x at index i indicates + that tensor dimension i is sharded along the x-th dimension of the device mesh. However, + the distribute_tensor API expects a tuple of Shard() and Replicate() objects where a Shard(x) + at index i indicates that tensor dimension x is sharded along device mesh dimension i. 
+ """ + placement = [Replicate(), Replicate()] + for shard_type in [SpmdShard.S_0, SpmdShard.S_1]: + if shard_type in sharding_config: + idx = sharding_config.index(shard_type) + placement[shard_type.value] = Shard(idx) + return placement + \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 36455f889..0683faa58 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -3,17 +3,25 @@ from chop.tools import get_logger -from .common import Shard +from .common import SpmdShard from .alpa_layers import ALPA_LAYERS from .alpa_cost_modelling import get_resharding_matrix logger = get_logger(__name__) +def get_node_target(node): + if isinstance(node.target, str): + return getattr(node.meta["mase"].model, node.target, None) + else: + return node.target + def alpa_intra_op_sharding_pass(mg, mesh): """ Intra-operator auto parallelization pass. 
""" + module_map = {} + # Setup for the ILP optimization expr = 0 constr = [] @@ -22,12 +30,8 @@ def alpa_intra_op_sharding_pass(mg, mesh): # This will later be used to solve the ILP optimization for node in mg.fx_graph.nodes: - # Extract the target - if isinstance(node.target, str): - target = getattr(node.meta["mase"].model, node.target, None) - target_cls = type(target) - else: - target = node.target + target = get_node_target(node) + target_cls = type(target) if target_cls in ALPA_LAYERS.keys(): # Enumerate shardings and costs for this operator @@ -84,8 +88,8 @@ def alpa_intra_op_sharding_pass(mg, mesh): elif node.op == "placeholder" or node.op == "output": rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": [(Shard.R,) * rank], - "valid_output_shardings": [(Shard.R,) * rank], + "valid_input_shardings": [(SpmdShard.R,) * rank], + "valid_output_shardings": [(SpmdShard.R,) * rank], "compute_cost_vector": [0], "communication_cost_vector": [0], "opt_var": np.array([1]), @@ -99,14 +103,21 @@ def alpa_intra_op_sharding_pass(mg, mesh): prob.solve() for node in mg.fx_graph.nodes: - if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray): - chosen_idx = 0 - else: - chosen_idx = np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] + chosen_idx = 0 if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray) else np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] + + # Write into module map (used by distributed launcher) + target = get_node_target(node) + if target is 
not None: + module_map[target] = { + "input": node.meta["mase"]["software"]["autosharding"]["input_sharding"][0], + "weight": node.meta["mase"]["software"]["autosharding"]["input_sharding"][1], + "output": node.meta["mase"]["software"]["autosharding"]["output_sharding"], + } - return mg, {} + return mg, module_map def alpa_autosharding_pass(mg, mesh): - return alpa_intra_op_sharding_pass(mg, mesh) \ No newline at end of file + mg, module_map = alpa_intra_op_sharding_pass(mg, mesh) + return mg, module_map \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py index 2294e43d7..3e66e4697 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ -4,7 +4,7 @@ from chop.ir.graph import MaseMetadata -from .common import Shard +from .common import SpmdShard from .mesh_model import MeshModel BYTES_PER_ELEMENT = 4 @@ -15,7 +15,7 @@ def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshM out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] - if inner_dim_sharding == Shard.R: + if inner_dim_sharding == SpmdShard.R: return 0 else: @@ -32,7 +32,7 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta # If original sharding is fully replicated, no resharding is required - if src == dest or src == (Shard.R, Shard.R): + if src == dest or src == (SpmdShard.R, SpmdShard.R): return 0 num_bytes = BYTES_PER_ELEMENT * np.prod(dest_node_meta["common"]["args"]["data_in_0"]["shape"]) @@ -41,10 +41,10 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta if ( # Keep dim 0, split dim 1 # E.g. 
(R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) - (src[0] == dest[0]) and (src[1] == Shard.R) and (dest[1] in [Shard.S_0, Shard.S_1]) + (src[0] == dest[0]) and (src[1] == SpmdShard.R) and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) # Split dim 0, keep dim 1 # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) - or (src[1] == dest[1]) and (src[0] == Shard.R) and (dest[0] in [Shard.S_0, Shard.S_1]) + or (src[1] == dest[1]) and (src[0] == SpmdShard.R) and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) ): return 0 @@ -52,10 +52,10 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta elif ( # Keep dim 0, gather along dim 1 # E.g. (S_1, S_0) -> (S_1, R) - (src[0] == dest[0]) and (src[1] in [Shard.S_0, Shard.S_1]) and (dest[1] == Shard.R) + (src[0] == dest[0]) and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[1] == SpmdShard.R) # Gather along dim 0, keep dim 1 # E.g. (S_0, S_1) -> (R, S_1) - or (src[1] == dest[1]) and (src[0] in [Shard.S_0, Shard.S_1]) and (dest[0] == Shard.R) + or (src[1] == dest[1]) and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[0] == SpmdShard.R) ): ag_dim = 1 if src[0] == dest[0] else 0 return mesh.all_gather_cost( @@ -65,9 +65,9 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta # All-to-all # E.g. 
(R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) - elif (src[0] == dest[1] and src[1] == dest[0] and (Shard.R in src)): + elif (src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src)): # all to all - a2a_dim = src[0].value if src[0] != Shard.R else src[1].value + a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value return mesh.all_to_all_cost( num_bytes = num_bytes, mesh_dim = a2a_dim, @@ -78,11 +78,11 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta # reduced sharding else: # Reduce one dimension and re-compute - if (src[0] != Shard.R): - new_src = (Shard.R, src[1]) + if (src[0] != SpmdShard.R): + new_src = (SpmdShard.R, src[1]) ag_dim = src[0].value else: - new_src = (Shard.R, Shard.R) + new_src = (SpmdShard.R, SpmdShard.R) ag_dim = src[1].value return mesh.all_gather_cost( @@ -95,5 +95,4 @@ def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): for src_idx, src in enumerate(src_shardings): for dest_idx, dest in enumerate(dest_shardings): mat[dest_idx, src_idx] = get_resharding_cost(mesh, src, dest, dest_node_meta) - return mat diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index 6733bea78..03393c191 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -4,7 +4,7 @@ from chop.tools import get_logger -from .common import Shard, VALID_2D_TENSOR_SHARDINGS +from .common import SpmdShard, VALID_2D_TENSOR_SHARDINGS from .alpa_cost_modelling import get_communication_cost logger = get_logger(__name__) @@ -21,7 +21,7 @@ def get_valid_linear_shardings(node_meta, mesh): permutations = list(itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2)) for p in permutations: output_sharding = (p[0][0], p[1][1]) - if p != ((Shard.R, Shard.R), (Shard.R, Shard.R)) and p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: + if 
p != ((SpmdShard.R, SpmdShard.R), (SpmdShard.R, SpmdShard.R)) and p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: input_shardings.append(p) output_shardings.append(output_sharding) diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index ea2c44243..d1a0503f0 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -32,6 +32,6 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): # Run intra-operator pass if algo == "alpa": - mg, _ = alpa_autosharding_pass(mg, mesh) + mg, module_map = alpa_autosharding_pass(mg, mesh) - return mg, {} + return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py index 2148b2a81..e0b98001a 100644 --- a/src/chop/passes/graph/analysis/autosharding/common.py +++ b/src/chop/passes/graph/analysis/autosharding/common.py @@ -1,6 +1,6 @@ from enum import Enum -class Shard(Enum): +class SpmdShard(Enum): S_0 = 0 S_1 = 1 R = 3 @@ -15,11 +15,11 @@ def __gt__(self, other): VALID_2D_TENSOR_SHARDINGS = [ - (Shard.R, Shard.R), - (Shard.R, Shard.S_0), - (Shard.R, Shard.S_1), - (Shard.S_0, Shard.R), - (Shard.S_0, Shard.S_1), - (Shard.S_1, Shard.R), - (Shard.S_1, Shard.S_0), + (SpmdShard.R, SpmdShard.R), + (SpmdShard.R, SpmdShard.S_0), + (SpmdShard.R, SpmdShard.S_1), + (SpmdShard.S_0, SpmdShard.R), + (SpmdShard.S_0, SpmdShard.S_1), + (SpmdShard.S_1, SpmdShard.R), + (SpmdShard.S_1, SpmdShard.S_0), ] \ No newline at end of file diff --git a/src/chop/tools/__init__.py b/src/chop/tools/__init__.py index 5d7787b4d..0d9d707a9 100644 --- a/src/chop/tools/__init__.py +++ b/src/chop/tools/__init__.py @@ -9,3 +9,5 @@ from .logger import root_logger, set_logging_verbosity, get_logger from .get_input import get_cf_args, get_dummy_input + +from .utils import deepsetattr \ No newline at end 
of file diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py index 7dd292e82..e38f4d7e7 100644 --- a/src/chop/tools/utils.py +++ b/src/chop/tools/utils.py @@ -256,3 +256,11 @@ def parse_accelerator(accelerator: str): else: raise RuntimeError(f"Unsupported accelerator {accelerator}") return device + +def deepsetattr(obj, attr, value): + """Recurses through an attribute chain to set the ultimate value.""" + attrs = attr.split(".") + if len(attrs) > 1: + deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) + else: + setattr(obj, attr, value) \ No newline at end of file diff --git a/test/passes/graph/analysis/autosharding/test_autosharding.py b/test/passes/graph/analysis/autosharding/test_autosharding.py index 3f207018b..6b6b870a6 100644 --- a/test/passes/graph/analysis/autosharding/test_autosharding.py +++ b/test/passes/graph/analysis/autosharding/test_autosharding.py @@ -2,6 +2,7 @@ import torch.nn as nn from chop.ir import MaseGraph +from chop.distributed import MaseLauncher import chop.passes as passes import sys, pdb, traceback @@ -16,6 +17,7 @@ def excepthook(exc_type, exc_value, exc_traceback): # Set the custom exception hook sys.excepthook = excepthook +WORLD_SIZE = 8 class MLP(nn.Module): def __init__(self, in_features=64, hidden_dimension=128, out_features=64): @@ -35,7 +37,7 @@ def test_autosharding(): mg, _ = passes.add_common_metadata_analysis_pass( mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} ) - mg, _ = passes.autosharding_analysis_pass( + mg, module_map = passes.autosharding_analysis_pass( mg, pass_args = { "mesh_shape": (2, 4), @@ -43,6 +45,10 @@ def test_autosharding(): "intra_node_bandwidth": 100e9 }) + launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + inputs = [torch.randn((16, 64))] + launcher.run(module_map, inputs) + if __name__ == "__main__": test_autosharding() From df4a02bff1d1aee11fa648ee8daaf4a85ce2da90 Mon Sep 17 00:00:00 2001 From: Pedro 
Gimenes Date: Wed, 12 Jun 2024 11:35:19 +0000 Subject: [PATCH 06/93] handle resharding between nodes, improve logging --- src/chop/distributed/launcher.py | 28 ++++------ src/chop/distributed/utils.py | 9 +++- src/chop/passes/__init__.py | 2 +- .../analysis/autosharding/autosharding.py | 8 +++ src/chop/passes/module/__init__.py | 2 +- src/chop/passes/module/transforms/__init__.py | 1 + .../transforms/autosharding/__init__.py | 2 + .../transforms/autosharding/resharding.py | 53 +++++++++++++++++++ .../autosharding/test_autosharding.py | 22 +++++--- 9 files changed, 99 insertions(+), 28 deletions(-) create mode 100644 src/chop/passes/module/transforms/autosharding/__init__.py create mode 100644 src/chop/passes/module/transforms/autosharding/resharding.py diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 7bf1707d6..9874938fc 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -14,20 +14,13 @@ Shard, ) -from chop.tools import deepsetattr, get_logger - +from chop.distributed.utils import rlog +from ..tools import get_logger from .utils import placement_from_sharding_config logger = get_logger(__name__) logger.setLevel("DEBUG") -def rlog(rank, msg): - """ - Only log on rank 0 to avoid repeated messages. - """ - if rank == 0: - logger.info(msg) - def dist_model_fn( name: str, module: nn.Module, device_mesh: DeviceMesh, rank: int, module_map={} ) -> None: @@ -35,16 +28,14 @@ def dist_model_fn( This function gets called by torch.distributed._tensor.distribute_module on each module in the model. Each tensor in each module is distributed according to the sharding configuration in module_map. 
""" - rlog(rank, f"Processing module {module}") if module in module_map: for parameter, sharding_config in module_map[module].items(): - rlog(rank, f" Parameter: {parameter} has sharding config {sharding_config}") if not hasattr(module, parameter): - rlog(rank, f" Module does not have parameter {parameter}") + rlog(logger, rank, f"Module {module} does not have parameter {parameter}", level="warning") continue placement = placement_from_sharding_config(sharding_config) - rlog(rank, f" Placement: {placement}") - deepsetattr(module, parameter, torch.nn.Parameter(distribute_tensor(getattr(module, parameter), device_mesh, placement))) + rlog(logger, rank, f"Distributing parameter {parameter} of module {module} to {placement}", level="debug") + setattr(module, parameter, torch.nn.Parameter(distribute_tensor(getattr(module, parameter), device_mesh, placement))) def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inputs=[]): @@ -54,23 +45,21 @@ def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inp """ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" + os.environ["RANK"] = str(rank) dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device("cuda", rank) torch.cuda.set_device(device) mesh = DeviceMesh("cuda", mesh=device_mesh) + rlog(logger, rank, f"Distributing module parameters...", level="info") model = distribute_module( model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None ) + rlog(logger, rank, f"Module distribution done.") - # TO DO: read from module_map with keys matching forward function signature inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] out = model(*inputs) - # TO DO: how to return output? 
- - rlog(rank, f"Module distribution done.") - dist.destroy_process_group() class MaseLauncher(): @@ -81,4 +70,5 @@ def __init__(self, mase_graph, world_size = None, device_mesh=None): self.device_mesh = device_mesh def run(self, module_map = {}, inputs=[]): + logger.info(f"Launching model with world size {self.world_size}.") mp.spawn(partial(device_fn, model=self.model, device_mesh=self.device_mesh, module_map=module_map, inputs=inputs), args=(self.world_size,), nprocs=self.world_size, join=True) \ No newline at end of file diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index 86ef538d0..42568cecf 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -19,4 +19,11 @@ def placement_from_sharding_config(sharding_config): idx = sharding_config.index(shard_type) placement[shard_type.value] = Shard(idx) return placement - \ No newline at end of file + +def rlog(logger, rank, msg, level="info"): + """ + Only log on rank 0 to avoid repeated messages. 
+ """ + log_fn = getattr(logger, level, logger.info) + if rank == 0: + log_fn(msg) \ No newline at end of file diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index 70e7246ec..58119b592 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -37,7 +37,7 @@ tensorrt_fake_quantize_transform_pass, ) from .module.analysis import calculate_avg_bits_module_analysis_pass -from .module.transforms import quantize_module_transform_pass +from .module.transforms import quantize_module_transform_pass, resharding_transform_pass from .onnx.analysis import ( export_fx_graph_analysis_pass, diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index d1a0503f0..3f6ae4b60 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,6 +1,7 @@ import numpy as np import cvxpy as cp +from time import time from chop.tools import get_logger @@ -8,6 +9,7 @@ from .alpa import alpa_autosharding_pass logger = get_logger(__name__) +logger.setLevel("DEBUG") def autosharding_analysis_pass(mg, pass_args: dict = {}): """ @@ -18,6 +20,9 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" + # Timing + start_time = time() + # Initialize device mesh model, used for cost estimation mesh = MeshModel(pass_args["mesh_shape"]) @@ -34,4 +39,7 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): if algo == "alpa": mg, module_map = alpa_autosharding_pass(mg, mesh) + end_time = time() + logger.info(f"Autosharding pass complete. 
Time taken: {end_time - start_time} seconds.") + return mg, module_map diff --git a/src/chop/passes/module/__init__.py b/src/chop/passes/module/__init__.py index 400d5122d..946827339 100644 --- a/src/chop/passes/module/__init__.py +++ b/src/chop/passes/module/__init__.py @@ -1,5 +1,5 @@ from .analysis import calculate_avg_bits_module_analysis_pass -from .transforms import quantize_module_transform_pass +from .transforms import quantize_module_transform_pass, resharding_transform_pass ANALYSIS_PASSES = ["calculate_avg_bits_module_analysis_pass"] diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 3fcc8c5b3..754c4f0ce 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1 +1,2 @@ from .quantize import quantize_module_transform_pass +from .autosharding import resharding_transform_pass \ No newline at end of file diff --git a/src/chop/passes/module/transforms/autosharding/__init__.py b/src/chop/passes/module/transforms/autosharding/__init__.py new file mode 100644 index 000000000..95de26f74 --- /dev/null +++ b/src/chop/passes/module/transforms/autosharding/__init__.py @@ -0,0 +1,2 @@ + +from .resharding import resharding_transform_pass \ No newline at end of file diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py new file mode 100644 index 000000000..a9a8cc3d5 --- /dev/null +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -0,0 +1,53 @@ + +import torch +import torch.nn as nn + +from torch.distributed._tensor import ( + DeviceMesh, +) + +from torch.distributed._tensor.api import Redistribute + +from chop.distributed.utils import placement_from_sharding_config, rlog +from chop.tools import get_logger + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +class ReshardingWrapper(nn.Module): + def __init__(self, device_mesh, module, 
resharding_config): + super().__init__() + self.module = module + self.resharding_config = resharding_config + self.device_mesh = device_mesh + + def forward(self, x): + rank = torch.distributed.get_rank() + device_mesh = DeviceMesh("cuda", self.device_mesh) + + required_placement = placement_from_sharding_config(self.resharding_config["input"]) + if (x.placements != required_placement): + rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="debug") + x = Redistribute.apply(x, device_mesh, required_placement) + + return self.module(x) + +def resharding_transform_pass(mg, pass_args={}): + """ + This pass inserts a wrapper around each module in the graph to handle resharding + activation tensors when the output of the previous module has a different sharding + profile to the one assigned to the current module. + """ + + module_map = pass_args.get("module_map", None) + device_mesh = pass_args.get("device_mesh", None) + if module_map is None or device_mesh is None: + raise ValueError("module_map and device_mesh are required for resharding_transform_pass") + + for node in mg.fx_graph.nodes: + module = getattr(mg.model, node.target, None) + if module is not None: + resharding_config = module_map[module] + setattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) + + return mg, {} \ No newline at end of file diff --git a/test/passes/graph/analysis/autosharding/test_autosharding.py b/test/passes/graph/analysis/autosharding/test_autosharding.py index 6b6b870a6..d5f1bdc2a 100644 --- a/test/passes/graph/analysis/autosharding/test_autosharding.py +++ b/test/passes/graph/analysis/autosharding/test_autosharding.py @@ -1,12 +1,12 @@ +import sys, pdb, traceback, os + import torch import torch.nn as nn from chop.ir import MaseGraph from chop.distributed import MaseLauncher import chop.passes as passes - -import sys, pdb, traceback - +from chop.tools import get_logger def 
excepthook(exc_type, exc_value, exc_traceback): traceback.print_exception(exc_type, exc_value, exc_traceback) @@ -17,8 +17,11 @@ def excepthook(exc_type, exc_value, exc_traceback): # Set the custom exception hook sys.excepthook = excepthook -WORLD_SIZE = 8 +logger = get_logger(__name__) +logger.setLevel("DEBUG") +WORLD_SIZE = 8 +DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] class MLP(nn.Module): def __init__(self, in_features=64, hidden_dimension=128, out_features=64): super().__init__() @@ -29,14 +32,17 @@ def forward(self, x): out = self.l1(x) return self.l2(out) - def test_autosharding(): + + # Initialize model and MaseGraph model = MLP() mg = MaseGraph(model) mg, _ = passes.init_metadata_analysis_pass(mg) mg, _ = passes.add_common_metadata_analysis_pass( mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} ) + + # Run autosharding pass to decide sharding configuration mg, module_map = passes.autosharding_analysis_pass( mg, pass_args = { @@ -45,7 +51,11 @@ def test_autosharding(): "intra_node_bandwidth": 100e9 }) - launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + # Insert resharding wrappers around each module + mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) + + # Launch model in distributed cluster + launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) inputs = [torch.randn((16, 64))] launcher.run(module_map, inputs) From 74c27129aa9987047299633b31e4d9b0bb75a121 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 12 Jun 2024 13:43:34 +0000 Subject: [PATCH 07/93] make sharding decision dict based instead of positional --- .../graph/analysis/autosharding/alpa.py | 7 +- .../analysis/autosharding/alpa_layers.py | 13 ++-- .../autosharding/test_autosharding_bert.py | 73 +++++++++++++++++++ ...harding.py => test_autosharding_linear.py} | 3 +- 4 files changed, 83 insertions(+), 13 deletions(-) create mode 100644 
test/passes/graph/analysis/autosharding/test_autosharding_bert.py rename test/passes/graph/analysis/autosharding/{test_autosharding.py => test_autosharding_linear.py} (95%) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 0683faa58..a23c75aa3 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -65,7 +65,7 @@ def alpa_intra_op_sharding_pass(mg, mesh): resharding_costs = get_resharding_matrix( mesh, src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], - dest_shardings = [sharding[0] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], + dest_shardings = [sharding["input"] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], dest_node_meta = node.meta["mase"] ).flatten() @@ -111,10 +111,9 @@ def alpa_intra_op_sharding_pass(mg, mesh): target = get_node_target(node) if target is not None: module_map[target] = { - "input": node.meta["mase"]["software"]["autosharding"]["input_sharding"][0], - "weight": node.meta["mase"]["software"]["autosharding"]["input_sharding"][1], - "output": node.meta["mase"]["software"]["autosharding"]["output_sharding"], + key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys() } + module_map[target]["output"] = node.meta["mase"]["software"]["autosharding"]["output_sharding"] return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index 03393c191..f382266bd 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -22,17 +22,14 @@ def get_valid_linear_shardings(node_meta, mesh): for p in permutations: 
output_sharding = (p[0][0], p[1][1]) if p != ((SpmdShard.R, SpmdShard.R), (SpmdShard.R, SpmdShard.R)) and p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: - input_shardings.append(p) + input_shardings.append({ + "input": p[0], + "weight": p[1] + }) output_shardings.append(output_sharding) compute_cost_vector.append(0) - - cost = get_communication_cost(p, node_meta["mase"], mesh) - communication_cost_vector.append(cost) - - logger.debug(f"Valid shardings for linear layer:") - for i, in_shard in enumerate(input_shardings): - logger.debug(f"Sharding {i}: {in_shard} -> {output_shardings[i]}") + communication_cost_vector.append(get_communication_cost(p, node_meta["mase"], mesh)) return ( input_shardings, diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py new file mode 100644 index 000000000..65a4e3297 --- /dev/null +++ b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py @@ -0,0 +1,73 @@ +import sys, pdb, traceback, os + +import torch +import torch.nn as nn + +from chop.ir import MaseGraph +from chop.distributed import MaseLauncher +import chop.passes as passes +from chop.tools import get_logger + +from chop.models.patched.bert import BertConfig, BertModel +from chop.models.patched.bert.modeling_bert import BertSelfAttention + +def excepthook(exc_type, exc_value, exc_traceback): + traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + + +# Set the custom exception hook +sys.excepthook = excepthook + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +WORLD_SIZE = 8 +DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] + +# * Define custom ops (leaf submodules during tracing) +BERT_CUSTOM_OPS = { + "modules": { + BertSelfAttention: {}, + }, + "functions": {}, +} + +def test_autosharding(): + + # Define config + config = BertConfig() + config.num_hidden_layers = 3 + 
config.hidden_size = 96 + config.intermediate_size = 384 + config_sequence_length = 4 + + # Initialize model and MaseGraph + model = BertModel(config, custom_ops=BERT_CUSTOM_OPS) + mg = MaseGraph(model) + mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.add_common_metadata_analysis_pass( + mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} + ) + + # Run autosharding pass to decide sharding configuration + mg, module_map = passes.autosharding_analysis_pass( + mg, + pass_args = { + "mesh_shape": (2, 4), + "inter_node_bandwidth": 10e9, + "intra_node_bandwidth": 100e9 + }) + + # Insert resharding wrappers around each module to handle inter-operator communication + mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) + + # Launch model in distributed cluster + launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) + inputs = [torch.randn((16, 64))] + launcher.run(module_map, inputs) + + +if __name__ == "__main__": + test_autosharding() diff --git a/test/passes/graph/analysis/autosharding/test_autosharding.py b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py similarity index 95% rename from test/passes/graph/analysis/autosharding/test_autosharding.py rename to test/passes/graph/analysis/autosharding/test_autosharding_linear.py index d5f1bdc2a..a337942ca 100644 --- a/test/passes/graph/analysis/autosharding/test_autosharding.py +++ b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py @@ -22,6 +22,7 @@ def excepthook(exc_type, exc_value, exc_traceback): WORLD_SIZE = 8 DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] + class MLP(nn.Module): def __init__(self, in_features=64, hidden_dimension=128, out_features=64): super().__init__() @@ -51,7 +52,7 @@ def test_autosharding(): "intra_node_bandwidth": 100e9 }) - # Insert resharding wrappers around each module + # Insert resharding wrappers around each module to handle 
inter-operator communication mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) # Launch model in distributed cluster From 357e8c393486ac42f8381004caa504a012de3d78 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 13 Jun 2024 15:03:10 +0000 Subject: [PATCH 08/93] run autosharding on huggingface bert --- src/chop/distributed/debug.py | 177 ++++++++++++++++++ src/chop/distributed/launcher.py | 15 +- src/chop/distributed/utils.py | 6 + .../add_metadata/common_metadata_layers.py | 2 + .../graph/analysis/autosharding/alpa.py | 61 ++++-- .../autosharding/alpa_cost_modelling.py | 15 +- .../analysis/autosharding/alpa_layers.py | 70 ++++++- src/chop/passes/graph/common.py | 2 + .../transforms/autosharding/resharding.py | 28 ++- src/chop/tools/__init__.py | 13 +- src/chop/tools/utils.py | 12 +- .../autosharding/test_autosharding_bert.py | 40 ++-- 12 files changed, 377 insertions(+), 64 deletions(-) create mode 100644 src/chop/distributed/debug.py diff --git a/src/chop/distributed/debug.py b/src/chop/distributed/debug.py new file mode 100644 index 000000000..e771fbb78 --- /dev/null +++ b/src/chop/distributed/debug.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +from typing import List, Sequence, Tuple + +import numpy as np + +from torch._prims_common import ShapeType +from torch.distributed._tensor import DeviceMesh + +from torch.distributed._tensor.placement_types import Placement, Shard + + +def _mesh_to_coordinate(mesh, device_type): + """ + Given a n-dimensional list of device mesh, this function creates a map of + device and its coordinate + """ + # Convert the n-dimensional list to a NumPy array + np_mesh = np.array(mesh.mesh.tolist()) + + # Create a dictionary to map each value to its coordinate + device_to_coordinate_map = {} + for coord, value in np.ndenumerate(np_mesh): + # device is unique in device_mesh + device_to_coordinate_map[f"{device_type}:{str(value)}"] = list(coord) + + return 
device_to_coordinate_map + + +def _convert_offset_to_ranges(all_offsets): + """ + Using tabulate package to create a table is easier when we specify row and col ranges + This function converts offsets to ranges. + """ + converted_blocks = [] + + for offset in all_offsets: + shape, offset, value = offset + + # Calculate row_range and column_range + row_range = (offset[0], offset[0] + shape[0] - 1) + column_range = (offset[1], offset[1] + shape[1] - 1) + + # Convert value to string to match your desired format + converted_block = { + "row_range": row_range, + "column_range": column_range, + "value": str(value), + } + converted_blocks.append(converted_block) + + return converted_blocks + + +def _create_table(blocks): + """ + Creates a tabulate table given row and column ranges with device name + """ + try: + from tabulate import tabulate + except ImportError as e: + raise ImportError("tabulate package is required to visualize sharding") from e + + # Extract unique row and column ranges + row_ranges = sorted({block["row_range"] for block in blocks}) + col_ranges = sorted({block["column_range"] for block in blocks}) + + # Create a matrix initialized with empty strings + matrix = [["" for _ in col_ranges] for _ in row_ranges] + + # Fill the matrix with values + for block in blocks: + row_index = row_ranges.index(block["row_range"]) + col_index = col_ranges.index(block["column_range"]) + if matrix[row_index][col_index] == "": + matrix[row_index][col_index] = block["value"] + else: + matrix[row_index][col_index] += ", " + block["value"] + + # Prepare headers + row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges] + col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges] + + return tabulate(matrix, headers=col_headers, showindex=row_headers) + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + my_coordinate: List[int], +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Same as 
torch.distributed._tensor._utils.compute_local_shape_and_global_offset but + with custom my_coordinate input. This is the modified implementation for visualize_sharding. + """ + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. 
+ if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) + + +def visualize_sharding(dtensor, header=""): + """ + Visualizes sharding in 1D-2D dtensors + Requires tabulate, install with `pip install tabulate` + + note: no sharding info will be printed for empty tensors + """ + if dtensor.numel() == 0: # we do not print for empty dtensors + return + + if len(dtensor.shape) >= 3: + raise RuntimeError( + "visualize sharding is only implemented for 1D or 2D dtensor" + ) + placements = dtensor.placements + device_mesh = dtensor.device_mesh + device_type = dtensor.device_mesh.device_type + + if device_mesh.get_coordinate() is None: # current rank is not in the mesh + return + + # Only display the visualization once for each DTensor, on the rank whose + # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh, + # we will only print on rank 0. 
+ local_rank_zero_on_all_dim = all( + device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim) + ) + if not local_rank_zero_on_all_dim: + return + + device_map = _mesh_to_coordinate(device_mesh, device_type) + all_offsets = [] + for device in device_map: + local_shape, global_offset = compute_local_shape_and_global_offset( + dtensor.shape, device_mesh, placements, device_map[device] + ) + all_offsets.append([local_shape, global_offset, device]) + + # Convert offsets to blocks with row_ranges for tabulate + blocks = _convert_offset_to_ranges(all_offsets) + + # Print the table + print(header) + print(_create_table(blocks)) \ No newline at end of file diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 9874938fc..a3d18fdb1 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -1,5 +1,6 @@ import os from functools import partial +from time import time import torch import torch.nn as nn @@ -30,12 +31,20 @@ def dist_model_fn( """ if module in module_map: for parameter, sharding_config in module_map[module].items(): + if parameter in ["data_in_0", "output", "data_out_0"]: + continue if not hasattr(module, parameter): rlog(logger, rank, f"Module {module} does not have parameter {parameter}", level="warning") continue + placement = placement_from_sharding_config(sharding_config) + rlog(logger, rank, f"Distributing parameter {parameter} of module {module} to {placement}", level="debug") - setattr(module, parameter, torch.nn.Parameter(distribute_tensor(getattr(module, parameter), device_mesh, placement))) + try: + distributed_tensor = distribute_tensor(getattr(module, parameter), device_mesh, placement) + setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) + except Exception as e: + rlog(logger, rank, f"Error distributing parameter {parameter} of module {module}: {e}", level="error") def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, 
inputs=[]): @@ -52,10 +61,12 @@ def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inp mesh = DeviceMesh("cuda", mesh=device_mesh) rlog(logger, rank, f"Distributing module parameters...", level="info") + start = time() model = distribute_module( model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None ) - rlog(logger, rank, f"Module distribution done.") + end = time() + rlog(logger, rank, f"Module distribution done. Time taken: {end - start} seconds.") inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] out = model(*inputs) diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index 42568cecf..55104edef 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -17,7 +17,13 @@ def placement_from_sharding_config(sharding_config): for shard_type in [SpmdShard.S_0, SpmdShard.S_1]: if shard_type in sharding_config: idx = sharding_config.index(shard_type) + # Preserve batch dimension + if (len(sharding_config) > 2): + idx = idx - (len(sharding_config) - 2) placement[shard_type.value] = Shard(idx) + + if placement == [Shard(1), Shard(1)]: + print(f"Warning: Invalid sharding config {sharding_config}") return placement def rlog(logger, rank, msg, level="info"): diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index a5e30a04d..7eb147a27 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -158,6 +158,7 @@ "getattr": {"a": "data_in", "b": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.ones.html "ones": {"size": "config", "device": "config"}, + "finfo": {"dtype": "config"}, } module_data = { @@ -247,6 +248,7 @@ "transpose": {"dim_0": "config", "dim_1": "config"}, # 
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous "contiguous": {}, + "masked_fill": {"mask": "data_in", "value": "data_in"}, } diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index a23c75aa3..924aa55c9 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -1,3 +1,7 @@ + +import functools + +import torch.nn as nn import numpy as np import cvxpy as cp @@ -8,13 +12,39 @@ from .alpa_cost_modelling import get_resharding_matrix logger = get_logger(__name__) +import sys, pdb, traceback + +def excepthook(exc_type, exc_value, exc_traceback): + traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + +# Set the custom exception hook +sys.excepthook = excepthook + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default def get_node_target(node): if isinstance(node.target, str): - return getattr(node.meta["mase"].model, node.target, None) + return deepgetattr(node.meta["mase"].model, node.target, None) else: return node.target +def assign_default_sharding(node): + rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": [{"data_in_0": (SpmdShard.R,) * rank}], + "valid_output_shardings": [(SpmdShard.R,) * rank], + "compute_cost_vector": [0], + "communication_cost_vector": [0], + "opt_var": np.array([1]), + } + def alpa_intra_op_sharding_pass(mg, mesh): """ Intra-operator auto parallelization pass. 
@@ -32,15 +62,19 @@ def alpa_intra_op_sharding_pass(mg, mesh): target = get_node_target(node) target_cls = type(target) + num_params = len([i for i in target.parameters()]) if isinstance(target, nn.Module) else 0 + + if node.op != "call_module" or num_params == 0: + assign_default_sharding(node) - if target_cls in ALPA_LAYERS.keys(): + elif target_cls in ALPA_LAYERS.keys(): # Enumerate shardings and costs for this operator ( input_shardings, output_shardings, compute_cost_vector, communication_cost_vector, - ) = ALPA_LAYERS[target_cls](node.meta, mesh) + ) = ALPA_LAYERS[target_cls](node.meta, mesh, target) # Formulate optimization variable and consider compute/communication cost opt_var = cp.Variable(len(input_shardings), boolean=True) @@ -65,7 +99,7 @@ def alpa_intra_op_sharding_pass(mg, mesh): resharding_costs = get_resharding_matrix( mesh, src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], - dest_shardings = [sharding["input"] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], + dest_shardings = [sharding["data_in_0"] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], dest_node_meta = node.meta["mase"] ).flatten() @@ -84,19 +118,14 @@ def alpa_intra_op_sharding_pass(mg, mesh): e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 ] - # Inputs to the whole graph are always replicated across all devices - elif node.op == "placeholder" or node.op == "output": - rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": [(SpmdShard.R,) * rank], - "valid_output_shardings": [(SpmdShard.R,) * rank], - "compute_cost_vector": [0], - "communication_cost_vector": [0], - "opt_var": np.array([1]), - } + # No sharding algorithm found for this operator, but this has parameter attributes + # (i.e. 
not an elementwise or implicit function) + elif (len([i for i in target.parameters()]) > 0): + logger.warning(f"No sharding algorithm found for operator: {target_cls}, but the parameter count is non-zero.") + logger.warning(f" MaseLauncher will fully replicate the parameters of this module.") else: - logger.warning(f"No sharding algorithm found for operator: {target_cls}") + logger.debug(f"Skipping implicit/elementwise operator: {target_cls}") # Solve the ILP problem prob = cp.Problem(cp.Minimize(expr), constr) @@ -109,7 +138,7 @@ def alpa_intra_op_sharding_pass(mg, mesh): # Write into module map (used by distributed launcher) target = get_node_target(node) - if target is not None: + if node.op == "call_module" and target is not None: module_map[target] = { key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys() } diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py index 3e66e4697..7e051cb5b 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ -10,7 +10,7 @@ BYTES_PER_ELEMENT = 4 def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): - assert sharding[0][1] == sharding[1][0], f"Inconsistent sharding for node: {node_meta.node}" + assert sharding[0][-1] == sharding[1][-2], f"Inconsistent sharding for node: {node_meta.node}" inner_dim_sharding = sharding[1][0] out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] @@ -32,7 +32,7 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta # If original sharding is fully replicated, no resharding is required - if src == dest or src == (SpmdShard.R, SpmdShard.R): + if src == dest or all(i == SpmdShard.R for i in src): return 0 num_bytes = BYTES_PER_ELEMENT * 
np.prod(dest_node_meta["common"]["args"]["data_in_0"]["shape"]) @@ -68,10 +68,13 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta elif (src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src)): # all to all a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value - return mesh.all_to_all_cost( - num_bytes = num_bytes, - mesh_dim = a2a_dim, - ) + try: + return mesh.all_to_all_cost( + num_bytes = num_bytes, + mesh_dim = a2a_dim, + ) + except: + breakpoint() # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, # must first gather along the first non-replicated dimension, then recursively compute the cost for the diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index f382266bd..74651a624 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -3,33 +3,52 @@ import torch.nn as nn from chop.tools import get_logger +from chop.models.patched.bert.modeling_bert import BertSelfAttention from .common import SpmdShard, VALID_2D_TENSOR_SHARDINGS from .alpa_cost_modelling import get_communication_cost + logger = get_logger(__name__) -def get_valid_linear_shardings(node_meta, mesh): +def is_valid_2d_sharding(sharding): + if len(sharding) > 2: + return sharding[1:] in VALID_2D_TENSOR_SHARDINGS + else: + return sharding in VALID_2D_TENSOR_SHARDINGS + +def is_valid_sharding_pair(sharding_pair): + return sharding_pair[0][-1] == sharding_pair[1][-2] + +def is_fully_replicated(sharding_pair): + return all(all(dimp == SpmdShard.R for dimp in subp) for subp in sharding_pair) + +def get_valid_2d_shardings(node_meta, mesh, module): """ Return every valid combination of shardings for the input tensors. For an operator sharding to be valid, the inner dimension must have the same sharding. E.g. 
((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. """ - input_shardings, output_shardings = [], [] - compute_cost_vector, communication_cost_vector = [], [] + input_shardings = [] + output_shardings = [] + compute_cost_vector = [] + communication_cost_vector = [] - permutations = list(itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2)) - for p in permutations: - output_sharding = (p[0][0], p[1][1]) - if p != ((SpmdShard.R, SpmdShard.R), (SpmdShard.R, SpmdShard.R)) and p[0][1] == p[1][0] and output_sharding in VALID_2D_TENSOR_SHARDINGS: + out_rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + + for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): + if out_rank > 2: + perm = tuple((SpmdShard.R,) * (out_rank - 2) + p for p in perm) + output_sharding = tuple((SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1])) + if not is_fully_replicated(perm) and is_valid_sharding_pair(perm) and is_valid_2d_sharding(output_sharding): input_shardings.append({ - "input": p[0], - "weight": p[1] + "data_in_0": perm[0], + "weight": perm[1] }) output_shardings.append(output_sharding) compute_cost_vector.append(0) - communication_cost_vector.append(get_communication_cost(p, node_meta["mase"], mesh)) + communication_cost_vector.append(get_communication_cost(perm, node_meta["mase"], mesh)) return ( input_shardings, @@ -38,7 +57,38 @@ def get_valid_linear_shardings(node_meta, mesh): np.array(communication_cost_vector), ) +def get_valid_linear_shardings(node_meta, mesh, module): + return get_valid_2d_shardings(node_meta, mesh, module) + +def get_valid_layernorm_shardings(node_meta, mesh, module): + rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * rank}] + valid_output_shardings = [(SpmdShard.R,) * rank] + compute_cost_vector = [0] + communication_cost_vector = [0] + return ( + valid_input_shardings, + valid_output_shardings, + 
np.array(compute_cost_vector), + np.array(communication_cost_vector), + ) + +def get_valid_embedding_shardings(node_meta, mesh, module): + weight_rank = len(module.weight.shape) + data_in_rank = len(node_meta["mase"]["common"]["args"]["data_in_0"]["shape"]) + valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * data_in_rank, "weight": (SpmdShard.R,) * weight_rank}] + valid_output_shardings = [(SpmdShard.R,) * data_in_rank] + compute_cost_vector = [0] + communication_cost_vector = [0] + return ( + valid_input_shardings, + valid_output_shardings, + np.array(compute_cost_vector), + np.array(communication_cost_vector), + ) ALPA_LAYERS = { nn.Linear: get_valid_linear_shardings, + nn.LayerNorm: get_valid_layernorm_shardings, + nn.Embedding: get_valid_embedding_shardings, } diff --git a/src/chop/passes/graph/common.py b/src/chop/passes/graph/common.py index 8c441b83a..c219b7bd0 100644 --- a/src/chop/passes/graph/common.py +++ b/src/chop/passes/graph/common.py @@ -48,6 +48,8 @@ "full", "ones", "dim", + "finfo", + "masked_fill" ] MASE_MODULE_RELATED_FUNCS = [ diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py index a9a8cc3d5..ecfc8ad78 100644 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -1,4 +1,6 @@ +import functools + import torch import torch.nn as nn @@ -14,6 +16,21 @@ logger = get_logger(__name__) logger.setLevel("DEBUG") +def deepsetattr(obj, attr, value): + """Recurses through an attribute chain to set the ultimate value.""" + attrs = attr.split(".") + if len(attrs) > 1: + deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) + else: + setattr(obj, attr, value) + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default + 
class ReshardingWrapper(nn.Module): def __init__(self, device_mesh, module, resharding_config): super().__init__() @@ -25,7 +42,7 @@ def forward(self, x): rank = torch.distributed.get_rank() device_mesh = DeviceMesh("cuda", self.device_mesh) - required_placement = placement_from_sharding_config(self.resharding_config["input"]) + required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) if (x.placements != required_placement): rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="debug") x = Redistribute.apply(x, device_mesh, required_placement) @@ -45,9 +62,14 @@ def resharding_transform_pass(mg, pass_args={}): raise ValueError("module_map and device_mesh are required for resharding_transform_pass") for node in mg.fx_graph.nodes: - module = getattr(mg.model, node.target, None) + if node.op != "call_module": + continue + module = deepgetattr(mg.model, node.target, None) if module is not None: resharding_config = module_map[module] - setattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) + logger.info(f"Inserting resharding wrapper around node: {node}") + deepsetattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) + + mg.model.recompile() return mg, {} \ No newline at end of file diff --git a/src/chop/tools/__init__.py b/src/chop/tools/__init__.py index 0d9d707a9..47b22b690 100644 --- a/src/chop/tools/__init__.py +++ b/src/chop/tools/__init__.py @@ -10,4 +10,15 @@ from .get_input import get_cf_args, get_dummy_input -from .utils import deepsetattr \ No newline at end of file +from .utils import ( + set_excepthook, + deepsetattr, + deepgetattr, + get_checkpoint_file, + copy_weights, + to_numpy, + to_numpy_if_tensor, + to_tensor, + to_tensor_if_numpy, + is_tensor, +) diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py index e38f4d7e7..0131f0355 100644 --- a/src/chop/tools/utils.py +++ 
b/src/chop/tools/utils.py @@ -1,9 +1,8 @@ import numpy as np import os -import pickle import torch +import functools -import colorlog import torch import subprocess @@ -263,4 +262,11 @@ def deepsetattr(obj, attr, value): if len(attrs) > 1: deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) else: - setattr(obj, attr, value) \ No newline at end of file + setattr(obj, attr, value) + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py index 65a4e3297..b0982601f 100644 --- a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py +++ b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py @@ -1,4 +1,4 @@ -import sys, pdb, traceback, os +import sys, pdb, traceback import torch import torch.nn as nn @@ -8,17 +8,7 @@ import chop.passes as passes from chop.tools import get_logger -from chop.models.patched.bert import BertConfig, BertModel -from chop.models.patched.bert.modeling_bert import BertSelfAttention - -def excepthook(exc_type, exc_value, exc_traceback): - traceback.print_exception(exc_type, exc_value, exc_traceback) - print("\nEntering debugger...") - pdb.post_mortem(exc_traceback) - - -# Set the custom exception hook -sys.excepthook = excepthook +from transformers.models.bert import BertConfig, BertModel logger = get_logger(__name__) logger.setLevel("DEBUG") @@ -26,14 +16,6 @@ def excepthook(exc_type, exc_value, exc_traceback): WORLD_SIZE = 8 DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] -# * Define custom ops (leaf submodules during tracing) -BERT_CUSTOM_OPS = { - "modules": { - BertSelfAttention: {}, - }, - "functions": {}, -} - def test_autosharding(): # Define config @@ -41,14 +23,22 @@ def test_autosharding(): 
config.num_hidden_layers = 3 config.hidden_size = 96 config.intermediate_size = 384 + config._attn_implementation = "eager" config_sequence_length = 4 # Initialize model and MaseGraph - model = BertModel(config, custom_ops=BERT_CUSTOM_OPS) + model = BertModel(config) mg = MaseGraph(model) mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "bert.txt"}) mg, _ = passes.add_common_metadata_analysis_pass( - mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} + mg, + pass_args={ + "dummy_in": { + "input_ids": torch.randint(0, 10, (1, config_sequence_length)), + }, + "add_value": False, + }, ) # Run autosharding pass to decide sharding configuration @@ -63,9 +53,13 @@ def test_autosharding(): # Insert resharding wrappers around each module to handle inter-operator communication mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) + # dump print model to a file + with open("model.txt", "w") as f: + print(mg.model, file=f) + # Launch model in distributed cluster launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) - inputs = [torch.randn((16, 64))] + inputs = [torch.randint(0, 10, (1, config_sequence_length))] launcher.run(module_map, inputs) From bc8fef1b3f1deadd4735cdfac66c4d3559bc0b89 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 13 Jun 2024 17:21:01 +0000 Subject: [PATCH 09/93] refactoring to fix circular imports --- setup.py | 11 ----------- src/chop/__init__.py | 2 -- src/chop/{passes/graph => ir}/common.py | 3 --- src/chop/ir/graph/mase_graph.py | 7 +++---- src/chop/ir/graph/mase_metadata.py | 7 +++++-- src/chop/{tools => ir/onnx}/onnx_operators.py | 0 src/chop/ir/onnx/utils.py | 2 +- src/chop/models/__init__.py | 18 +++++++++--------- .../models/patched/bert/configuration_bert.py | 5 ----- src/chop/models/patched/bert/modeling_bert.py | 4 ---- src/chop/nn/__init__.py | 5 +++++ 
.../add_metadata/add_common_metadata.py | 2 +- src/chop/passes/graph/analysis/utils.py | 2 +- src/chop/passes/graph/patching/__init__.py | 13 ------------- src/chop/tools/__init__.py | 1 - 15 files changed, 25 insertions(+), 57 deletions(-) rename src/chop/{passes/graph => ir}/common.py (99%) rename src/chop/{tools => ir/onnx}/onnx_operators.py (100%) diff --git a/setup.py b/setup.py index 250652a0e..b132e4bdf 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,6 @@ def is_cuda_available(): "ipdb", "sentencepiece", "einops", - "deepspeed", "pybind11", "tabulate", "tensorboard", @@ -73,16 +72,6 @@ def is_cuda_available(): "bitstring>=4.2", ] -if is_cuda_available(): - requirements += [ - "pycuda", - "onnxruntime-gpu", - "torch-tensorRT; platform_system == 'Linux'", - "tensorRT; platform_system == 'Linux'", - "cuda-python; platform_system == 'Linux'", - "pytorch-quantization; platform_system == 'Linux'", - ] - setup( name="mase-tools", version="1.0.0", diff --git a/src/chop/__init__.py b/src/chop/__init__.py index 0202dc9a1..99e464c5a 100644 --- a/src/chop/__init__.py +++ b/src/chop/__init__.py @@ -1,5 +1,3 @@ -from .tools.logger import root_logger - from .ir.graph.mase_graph import MaseGraph from .ir.onnx.mase_onnx_graph import MaseOnnxGraph diff --git a/src/chop/passes/graph/common.py b/src/chop/ir/common.py similarity index 99% rename from src/chop/passes/graph/common.py rename to src/chop/ir/common.py index c219b7bd0..fd0cd1b6a 100644 --- a/src/chop/passes/graph/common.py +++ b/src/chop/ir/common.py @@ -1,5 +1,3 @@ -import torch.nn.functional as F - MASE_TYPES = [ "module", "module_related_func", @@ -10,7 +8,6 @@ "output", ] - MASE_IMPLICIT_FUNCS = [ # possibly are just constants "size", diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py index 39f1ab2d3..5826cb22c 100644 --- a/src/chop/ir/graph/mase_graph.py +++ b/src/chop/ir/graph/mase_graph.py @@ -11,9 +11,8 @@ import torch.fx as fx from torch.fx.passes.graph_drawer import 
FxGraphDrawer -from chop.passes.graph.common import MASE_IMPLICIT_FUNCS -from chop.passes.graph.transforms import utils as utils_passes -from chop.passes.graph.patching import MASE_LEAF_FUNCTIONS, MASE_LEAF_LAYERS +from chop.ir.common import MASE_IMPLICIT_FUNCS +from chop.nn import MASE_LEAF_LAYERS from chop.nn.quantized import ( quantized_func_map, quantized_module_map, @@ -64,7 +63,7 @@ def __init__( self.param_shapes_constant = param_shapes_constant super().__init__( self.custom_leaf_modules + (math,), - self.custom_leaf_functions + MASE_LEAF_FUNCTIONS, + self.custom_leaf_functions, self.param_shapes_constant, ) diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py index 303292b90..770ae8269 100644 --- a/src/chop/ir/graph/mase_metadata.py +++ b/src/chop/ir/graph/mase_metadata.py @@ -2,10 +2,13 @@ from torch import nn -from ...passes.graph.utils import get_module_by_name - logger = logging.getLogger(__name__) +def get_module_by_name(model, request_name): + for name, layer in model.named_modules(): + if name == request_name: + return layer + return None class MaseMetadata: """ diff --git a/src/chop/tools/onnx_operators.py b/src/chop/ir/onnx/onnx_operators.py similarity index 100% rename from src/chop/tools/onnx_operators.py rename to src/chop/ir/onnx/onnx_operators.py diff --git a/src/chop/ir/onnx/utils.py b/src/chop/ir/onnx/utils.py index 38b248ed5..bb65aa8b2 100644 --- a/src/chop/ir/onnx/utils.py +++ b/src/chop/ir/onnx/utils.py @@ -1,6 +1,6 @@ import torch -from chop.tools.onnx_operators import ( +from chop.ir.onnx.onnx_operators import ( onnx_gemm, onnx_slice, onnx_squeeze, diff --git a/src/chop/models/__init__.py b/src/chop/models/__init__.py index aa4cad195..0174373ce 100644 --- a/src/chop/models/__init__.py +++ b/src/chop/models/__init__.py @@ -12,15 +12,15 @@ get_patched_model_config_cls, get_patched_model_tokenizer_cls, ) -from .manual import ( - is_manual_model, - get_manual_model, - get_manual_model_cls, - 
get_manual_model_config_cls, - get_manual_model_tokenizer_cls, - get_manual_model_info, - get_manual_model_tokenizer, -) +# from .manual import ( +# is_manual_model, +# get_manual_model, +# get_manual_model_cls, +# get_manual_model_config_cls, +# get_manual_model_tokenizer_cls, +# get_manual_model_info, +# get_manual_model_tokenizer, +# ) from .huggingface_nlp_models import ( is_hf_nlp_model, diff --git a/src/chop/models/patched/bert/configuration_bert.py b/src/chop/models/patched/bert/configuration_bert.py index 6ef356b67..a3f82a8d2 100644 --- a/src/chop/models/patched/bert/configuration_bert.py +++ b/src/chop/models/patched/bert/configuration_bert.py @@ -25,11 +25,6 @@ logger = logging.get_logger(__name__) -from transformers.models.deprecated._archive_maps import ( - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, -) # noqa: F401, E402 - - class BertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to diff --git a/src/chop/models/patched/bert/modeling_bert.py b/src/chop/models/patched/bert/modeling_bert.py index 10c7088a9..cdd6335ee 100644 --- a/src/chop/models/patched/bert/modeling_bert.py +++ b/src/chop/models/patched/bert/modeling_bert.py @@ -80,10 +80,6 @@ _SEQ_CLASS_EXPECTED_LOSS = 0.01 -from transformers.models.deprecated._archive_maps import ( - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, -) # noqa: F401, E402 - def load_tf_weights_in_bert(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py index e69de29bb..345a92d90 100644 --- a/src/chop/nn/__init__.py +++ b/src/chop/nn/__init__.py @@ -0,0 +1,5 @@ +from .quantized import quantized_module_map + +MASE_LEAF_LAYERS = tuple( + quantized_module_map.values() +) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index 
53785b105..c304445c2 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -13,7 +13,7 @@ get_input_nodes, get_output_nodes, ) -from chop.passes.graph.common import ( +from chop.ir.common import ( MASE_BUILTIN_FUNCS, MASE_IMPLICIT_FUNCS, MASE_MODULE_RELATED_FUNCS, diff --git a/src/chop/passes/graph/analysis/utils.py b/src/chop/passes/graph/analysis/utils.py index ea0675511..0d22a952c 100644 --- a/src/chop/passes/graph/analysis/utils.py +++ b/src/chop/passes/graph/analysis/utils.py @@ -3,7 +3,7 @@ import torch import regex as re -from chop.passes.graph.common import MASE_IMPLICIT_FUNCS +from chop.ir.common import MASE_IMPLICIT_FUNCS # from ..session.plt_wrapper.nlp.classification import NLPClassificationModelWrapper # from ..session.plt_wrapper.nlp.lm import NLPLanguageModelingModelWrapper diff --git a/src/chop/passes/graph/patching/__init__.py b/src/chop/passes/graph/patching/__init__.py index 6180b0b01..e69de29bb 100644 --- a/src/chop/passes/graph/patching/__init__.py +++ b/src/chop/passes/graph/patching/__init__.py @@ -1,13 +0,0 @@ -from .mase_op_wrapper import torch_arange, torch_ones, torch_zeros -from chop.nn.quantized import quantized_module_map, quantized_func_map - -MASE_LEAF_FUNCTIONS = ( - # tensor constructors - torch_arange, - torch_ones, - torch_zeros, -) # + tuple(quantized_func_map.keys()) # add this if there is a case where quantized module is traced again - -MASE_LEAF_LAYERS = () + tuple( - quantized_module_map.values() -) # add this if there is a case where quantized module is traced again diff --git a/src/chop/tools/__init__.py b/src/chop/tools/__init__.py index 47b22b690..bd62682f8 100644 --- a/src/chop/tools/__init__.py +++ b/src/chop/tools/__init__.py @@ -11,7 +11,6 @@ from .get_input import get_cf_args, get_dummy_input from .utils import ( - set_excepthook, deepsetattr, deepgetattr, get_checkpoint_file, From 
dbc6591b6d4bbedcbed7efd5080802258918f8db Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 13 Jun 2024 19:51:51 +0000 Subject: [PATCH 10/93] need to fix batch dimension sharding and sharding along multiple mesh dimensions --- .../passes/graph/analysis/autosharding/common.py | 4 ++-- .../module/transforms/autosharding/resharding.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py index e0b98001a..7d76468a9 100644 --- a/src/chop/passes/graph/analysis/autosharding/common.py +++ b/src/chop/passes/graph/analysis/autosharding/common.py @@ -19,7 +19,7 @@ def __gt__(self, other): (SpmdShard.R, SpmdShard.S_0), (SpmdShard.R, SpmdShard.S_1), (SpmdShard.S_0, SpmdShard.R), - (SpmdShard.S_0, SpmdShard.S_1), + # (SpmdShard.S_0, SpmdShard.S_1), (SpmdShard.S_1, SpmdShard.R), - (SpmdShard.S_1, SpmdShard.S_0), + # (SpmdShard.S_1, SpmdShard.S_0), ] \ No newline at end of file diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py index ecfc8ad78..b7835960b 100644 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -10,12 +10,20 @@ from torch.distributed._tensor.api import Redistribute -from chop.distributed.utils import placement_from_sharding_config, rlog +from chop.distributed.utils import placement_from_sharding_config from chop.tools import get_logger logger = get_logger(__name__) logger.setLevel("DEBUG") +def rlog(logger, rank, msg, level="info"): + """ + Only log on rank 0 to avoid repeated messages. 
+ """ + log_fn = getattr(logger, level, logger.info) + if (rank == 0): + log_fn(f"[RANK: {rank}]: {msg}") + def deepsetattr(obj, attr, value): """Recurses through an attribute chain to set the ultimate value.""" attrs = attr.split(".") @@ -44,10 +52,12 @@ def forward(self, x): required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) if (x.placements != required_placement): - rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="debug") + rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="info") x = Redistribute.apply(x, device_mesh, required_placement) - return self.module(x) + out = self.module(x) + + return out def resharding_transform_pass(mg, pass_args={}): """ From 04ebd3459f3879608a02613dd210477c8ce7133c Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 14 Jun 2024 13:39:45 +0000 Subject: [PATCH 11/93] autosharding works on patched bert without batch dimension sharding --- src/chop/distributed/launcher.py | 7 +- src/chop/distributed/utils.py | 11 ++- .../graph/analysis/autosharding/alpa.py | 24 +++++-- .../analysis/autosharding/alpa_layers.py | 4 +- .../graph/analysis/autosharding/common.py | 4 +- .../analysis/autosharding/debug_utilities.py | 72 +++++++++++++++++++ .../transforms/autosharding/resharding.py | 5 +- 7 files changed, 106 insertions(+), 21 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/debug_utilities.py diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index a3d18fdb1..17bc786c9 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -30,7 +30,8 @@ def dist_model_fn( Each tensor in each module is distributed according to the sharding configuration in module_map. 
""" if module in module_map: - for parameter, sharding_config in module_map[module].items(): + node_name = module_map[module]["node"] + for parameter, sharding_config in module_map[module]["sharding"].items(): if parameter in ["data_in_0", "output", "data_out_0"]: continue if not hasattr(module, parameter): @@ -39,12 +40,12 @@ def dist_model_fn( placement = placement_from_sharding_config(sharding_config) - rlog(logger, rank, f"Distributing parameter {parameter} of module {module} to {placement}", level="debug") try: + rlog(logger, rank, f"Distributing parameter {parameter} of module {node_name} to {placement}", level="debug") distributed_tensor = distribute_tensor(getattr(module, parameter), device_mesh, placement) setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) except Exception as e: - rlog(logger, rank, f"Error distributing parameter {parameter} of module {module}: {e}", level="error") + rlog(logger, rank, f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", level="error") def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inputs=[]): diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index 55104edef..19b1e6dbf 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -6,6 +6,8 @@ from chop.passes.graph.analysis.autosharding.common import SpmdShard +import torch + def placement_from_sharding_config(sharding_config): """ Sharding config is given as a tuple such as (R, S_0) where a symbol S_x at index i indicates @@ -13,18 +15,13 @@ def placement_from_sharding_config(sharding_config): the distribute_tensor API expects a tuple of Shard() and Replicate() objects where a Shard(x) at index i indicates that tensor dimension x is sharded along device mesh dimension i. 
""" - placement = [Replicate(), Replicate()] + placement = [Replicate()] * 2 for shard_type in [SpmdShard.S_0, SpmdShard.S_1]: if shard_type in sharding_config: idx = sharding_config.index(shard_type) - # Preserve batch dimension - if (len(sharding_config) > 2): - idx = idx - (len(sharding_config) - 2) placement[shard_type.value] = Shard(idx) - if placement == [Shard(1), Shard(1)]: - print(f"Warning: Invalid sharding config {sharding_config}") - return placement + return tuple(placement) def rlog(logger, rank, msg, level="info"): """ diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 924aa55c9..4f4fd32a0 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -1,4 +1,4 @@ - +import sys, pdb, traceback import functools import torch.nn as nn @@ -10,9 +10,10 @@ from .common import SpmdShard from .alpa_layers import ALPA_LAYERS from .alpa_cost_modelling import get_resharding_matrix +from .debug_utilities import debug_shardings, are_layers_equal logger = get_logger(__name__) -import sys, pdb, traceback +logger.setLevel("DEBUG") def excepthook(exc_type, exc_value, exc_traceback): traceback.print_exception(exc_type, exc_value, exc_traceback) @@ -45,7 +46,8 @@ def assign_default_sharding(node): "opt_var": np.array([1]), } -def alpa_intra_op_sharding_pass(mg, mesh): + +def alpa_intra_op_sharding_pass(mg, mesh, debug=False): """ Intra-operator auto parallelization pass. 
""" @@ -58,6 +60,7 @@ def alpa_intra_op_sharding_pass(mg, mesh): # Write cost vectors into metadata for each operator # This will later be used to solve the ILP optimization + debugged_layers = [] for node in mg.fx_graph.nodes: target = get_node_target(node) @@ -76,6 +79,11 @@ def alpa_intra_op_sharding_pass(mg, mesh): communication_cost_vector, ) = ALPA_LAYERS[target_cls](node.meta, mesh, target) + # Debug each possible sharding by running an inference step over the device mesh + if debug and not any([are_layers_equal(layer, target) for layer in debugged_layers]): + debug_shardings(target, input_shardings, mesh) + debugged_layers.append(target) + # Formulate optimization variable and consider compute/communication cost opt_var = cp.Variable(len(input_shardings), boolean=True) constr += [ @@ -135,14 +143,18 @@ def alpa_intra_op_sharding_pass(mg, mesh): chosen_idx = 0 if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray) else np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] - + chosen_sharding = {key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys()} + # Write into module map (used by distributed launcher) target = get_node_target(node) if node.op == "call_module" and target is not None: module_map[target] = { - key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys() + "node": node.name, + "sharding": chosen_sharding } - module_map[target]["output"] = 
node.meta["mase"]["software"]["autosharding"]["output_sharding"] + module_map[target]["sharding"]["output"] = node.meta["mase"]["software"]["autosharding"]["output_sharding"] + + logger.info(f"Chosen sharding for node {node.name}: {chosen_sharding}") return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index 74651a624..27298cad1 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -38,7 +38,9 @@ def get_valid_2d_shardings(node_meta, mesh, module): for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): if out_rank > 2: - perm = tuple((SpmdShard.R,) * (out_rank - 2) + p for p in perm) + # Always replicate along batch dimension, and assume weights are always 2D + # TO DO: handle sharding along batch dimension + perm = ((SpmdShard.R, ) + perm[0], perm[1]) output_sharding = tuple((SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1])) if not is_fully_replicated(perm) and is_valid_sharding_pair(perm) and is_valid_2d_sharding(output_sharding): input_shardings.append({ diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py index 7d76468a9..e0b98001a 100644 --- a/src/chop/passes/graph/analysis/autosharding/common.py +++ b/src/chop/passes/graph/analysis/autosharding/common.py @@ -19,7 +19,7 @@ def __gt__(self, other): (SpmdShard.R, SpmdShard.S_0), (SpmdShard.R, SpmdShard.S_1), (SpmdShard.S_0, SpmdShard.R), - # (SpmdShard.S_0, SpmdShard.S_1), + (SpmdShard.S_0, SpmdShard.S_1), (SpmdShard.S_1, SpmdShard.R), - # (SpmdShard.S_1, SpmdShard.S_0), + (SpmdShard.S_1, SpmdShard.S_0), ] \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/debug_utilities.py b/src/chop/passes/graph/analysis/autosharding/debug_utilities.py new file mode 100644 index 000000000..10b4b1c67 --- /dev/null 
+++ b/src/chop/passes/graph/analysis/autosharding/debug_utilities.py @@ -0,0 +1,72 @@ + + +import torch.nn as nn + +from chop.tools import get_logger + +from chop import MaseGraph +import chop.passes as passes +import torch + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def are_layers_equal(layer1, layer2): + # Check if both layers are instances of the same class + if type(layer1) != type(layer2): + return False + + # Compare their attributes + for attr in dir(layer1): + # Skip methods and special attributes + l1_attr = getattr(layer1, attr) + if callable(getattr(layer1, attr)) or attr.startswith("_") or isinstance(l1_attr, torch.Tensor): + continue + # Check if both layers have the same attribute and their values are equal + if hasattr(layer2, attr): + if getattr(layer1, attr) != getattr(layer2, attr): + return False + else: + return False + + return True + +def debug_shardings(layer, input_shardings, world_size, device_mesh): + + from chop.distributed import MaseLauncher + + class WrapperModule(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, x): + return self.layer(x) + + logger.info(f"Generating subgraph for layer: {layer}") + mg = MaseGraph(WrapperModule(layer)) + mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.add_common_metadata_analysis_pass( + mg, + pass_args={ + "dummy_in": { + "x": torch.randn((1, layer.in_features)), + }, + "add_value": False, + }, + ) + + for idx, sharding in enumerate(input_shardings): + module_map = { + "node": "---", + "sharding": { + layer: { + key: sharding[key] for key in sharding.keys() + } + } + } + logger.info(f"[{idx}/{len(input_shardings)}] Testing shading: {sharding}") + launcher = MaseLauncher(mg, world_size=world_size, device_mesh=device_mesh) + # inputs = [torch.randint(0, 10, (1, config_sequence_length))] + inputs = [torch.randn((1, layer.in_features))] + launcher.run(module_map, inputs) \ No newline at end of file diff --git 
a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py index b7835960b..618c97ab7 100644 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -43,7 +43,8 @@ class ReshardingWrapper(nn.Module): def __init__(self, device_mesh, module, resharding_config): super().__init__() self.module = module - self.resharding_config = resharding_config + self.resharding_config = resharding_config["sharding"] + self.node = resharding_config["node"] self.device_mesh = device_mesh def forward(self, x): @@ -52,7 +53,7 @@ def forward(self, x): required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) if (x.placements != required_placement): - rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="info") + rlog(logger, rank, f"For module {self.node}, resharding tensor x from {x.placements} to {required_placement}", level="info") x = Redistribute.apply(x, device_mesh, required_placement) out = self.module(x) From b1baaf3c5e82a341f7a90800fed8b6be8a42cbd4 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 14 Jun 2024 15:39:31 +0000 Subject: [PATCH 12/93] insert inference timing and lower logging level --- src/chop/distributed/launcher.py | 26 +++++++++++++------ .../graph/analysis/autosharding/alpa.py | 4 +-- .../transforms/autosharding/resharding.py | 6 ++--- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 17bc786c9..31095d7bb 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -20,7 +20,15 @@ from .utils import placement_from_sharding_config logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") + +def distributed_timing(fn, *args, **kwargs): + dist.barrier() + start = time() + result = 
fn(*args, **kwargs) + dist.barrier() + end = time() + return result, (end - start) def dist_model_fn( name: str, module: nn.Module, device_mesh: DeviceMesh, rank: int, module_map={} @@ -56,21 +64,23 @@ def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inp os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" os.environ["RANK"] = str(rank) + + # Initialize dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device("cuda", rank) torch.cuda.set_device(device) + # Distribute model parameters according to sharding configuration mesh = DeviceMesh("cuda", mesh=device_mesh) rlog(logger, rank, f"Distributing module parameters...", level="info") - start = time() - model = distribute_module( - model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None - ) - end = time() - rlog(logger, rank, f"Module distribution done. Time taken: {end - start} seconds.") + model, dist_time = distributed_timing(distribute_module, model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None) + rlog(logger, rank, f"Module distribution done. Time taken: {dist_time} seconds.") + # Run forward pass + rlog(logger, rank, f"Starting forward pass.", level="info") inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] - out = model(*inputs) + out, time_taken = distributed_timing(model, *inputs) + rlog(logger, rank, f"Forward pass finished. 
Time taken: {time_taken}", level="info") dist.destroy_process_group() diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 4f4fd32a0..946b4ea49 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -13,7 +13,7 @@ from .debug_utilities import debug_shardings, are_layers_equal logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") def excepthook(exc_type, exc_value, exc_traceback): traceback.print_exception(exc_type, exc_value, exc_traceback) @@ -154,8 +154,6 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): } module_map[target]["sharding"]["output"] = node.meta["mase"]["software"]["autosharding"]["output_sharding"] - logger.info(f"Chosen sharding for node {node.name}: {chosen_sharding}") - return mg, module_map def alpa_autosharding_pass(mg, mesh): diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py index 618c97ab7..4a8617c4d 100644 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -14,7 +14,7 @@ from chop.tools import get_logger logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") def rlog(logger, rank, msg, level="info"): """ @@ -53,7 +53,7 @@ def forward(self, x): required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) if (x.placements != required_placement): - rlog(logger, rank, f"For module {self.node}, resharding tensor x from {x.placements} to {required_placement}", level="info") + rlog(logger, rank, f"For module {self.node}, resharding tensor x from {x.placements} to {required_placement}", level="debug") x = Redistribute.apply(x, device_mesh, required_placement) out = self.module(x) @@ -78,7 +78,7 @@ def resharding_transform_pass(mg, pass_args={}): 
module = deepgetattr(mg.model, node.target, None) if module is not None: resharding_config = module_map[module] - logger.info(f"Inserting resharding wrapper around node: {node}") + logger.debug(f"Inserting resharding wrapper around node: {node}") deepsetattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) mg.model.recompile() From 6e14c2138c6f86623a1d77b81994ef2c6a02a275 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 19 Jun 2024 21:33:35 +0000 Subject: [PATCH 13/93] pipeline for distributed inference + report parallelization pass --- src/chop/__init__.py | 2 ++ src/chop/ir/common.py | 4 +++ .../add_metadata/common_metadata_layers.py | 31 +++++++++++++++++-- .../passes/graph/analysis/report/__init__.py | 1 + .../analysis/report/report_parallelization.py | 26 ++++++++++++++++ src/chop/pipelines/__init__.py | 2 ++ src/chop/pipelines/auto_pipeline.py | 26 ++++++++++++++++ src/chop/pipelines/distributed_inference.py | 19 ++++++++++++ 8 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 src/chop/passes/graph/analysis/report/report_parallelization.py create mode 100644 src/chop/pipelines/__init__.py create mode 100644 src/chop/pipelines/auto_pipeline.py create mode 100644 src/chop/pipelines/distributed_inference.py diff --git a/src/chop/__init__.py b/src/chop/__init__.py index 99e464c5a..6e33393f4 100644 --- a/src/chop/__init__.py +++ b/src/chop/__init__.py @@ -2,3 +2,5 @@ from .ir.onnx.mase_onnx_graph import MaseOnnxGraph from . 
import passes + +from .pipelines import AutoPipelineForDistributedInference \ No newline at end of file diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index fd0cd1b6a..3c3b08fb5 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -86,6 +86,7 @@ MASE_BUILTIN_FUNCS = [ "mul", + "addmm", "sub", "add", "matmul", @@ -105,6 +106,7 @@ "cosh", "tanh", "greater", + "gt", "less", "le", # less or equal "sigmoid", @@ -112,8 +114,10 @@ "min", "neg", "log", + "arange", "range", "gelu", + "scaled_dot_product_attention", ] diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 7eb147a27..5455ec83e 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -15,6 +15,16 @@ # The following information is fetched from pytorch documentation func_data = { + # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + "scaled_dot_product_attention": { + "query": "data_in", + "key": "data_in", + "value": "data_in", + "attn_mask": "data_in", + "dropout_p": "config", + "is_causal": "config", + "scale": "config", + }, # https://pytorch.org/docs/stable/generated/torch.flatten.html#torch.flatten "flatten": {"input": "data_in", "start_dim": "config", "end_dim": "config"}, # https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html @@ -43,6 +53,8 @@ "softsign": {"input": "data_in", "inplace": "config"}, # https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html "softplus": {"input": "data_in", "inplace": "config"}, + # https://pytorch.org/docs/stable/generated/torch.addmm.html + "addmm": {"input": "data_in", "mat1": "data_in", "mat2": "data_in", "beta": "config", "alpha": "config"}, # https://pytorch.org/docs/stable/generated/torch.add.html "add": {"input": "data_in", "other": "data_in"}, # 
https://pytorch.org/docs/stable/generated/torch.mul.html @@ -106,8 +118,10 @@ "tan": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.tanh.html "tanh": {"input": "data_in"}, - # https://pytorch.org/docs/stable/generated/torch.gt.html#torch.gt + # https://pytorch.org/docs/stable/generated/torch.greater.html "greater": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.gt.html + "gt": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.abs.html "abs": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.sigmoid.html @@ -124,6 +138,8 @@ "less": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.le.html "lessorequal": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.le.html + "le": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.min.html "min": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.neg.html @@ -132,6 +148,8 @@ "log": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.mean.html "mean": {"input": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.arange.html + "arange": {"start": "config", "end": "config", "step": "config", "dtype": "config", "device": "config"}, # https://pytorch.org/docs/stable/generated/torch.range.html "range": {"start": "config", "end": "config", "step": "config"}, # https://pytorch.org/docs/stable/generated/torch.where.html @@ -249,6 +267,10 @@ # https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous "contiguous": {}, "masked_fill": {"mask": "data_in", "value": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.Tensor.unsqueeze.html#torch.Tensor.unsqueeze + "unsqueeze": {"input": "data_in", "dim": "config"}, + # 
https://pytorch.org/docs/stable/generated/torch.Tensor.split.html#torch.Tensor.split + "split": {"input": "data_in", "split_size_or_sections": "config", "dim": "config"}, } @@ -275,7 +297,9 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value): meta_kwargs[n] = args[i] def get_shape(x): - if isinstance(x, torch.Tensor): + if x is None: + return None + elif isinstance(x, torch.Tensor): return list(x.shape) elif isinstance(x, int): return [1] @@ -287,8 +311,9 @@ def get_shape(x): for k, v in kwargs.items(): if data[k] == "data_in": # rename this to mase data_in_number + shape = get_shape(v) arg_meta = { - "shape": get_shape(v), + "shape": shape, "torch_dtype": v.dtype if isinstance(v, torch.Tensor) else type(v), "type": "float", "precision": [32], diff --git a/src/chop/passes/graph/analysis/report/__init__.py b/src/chop/passes/graph/analysis/report/__init__.py index 61a83b786..057db9740 100644 --- a/src/chop/passes/graph/analysis/report/__init__.py +++ b/src/chop/passes/graph/analysis/report/__init__.py @@ -5,3 +5,4 @@ report_node_shape_analysis_pass, report_node_type_analysis_pass, ) +from .report_parallelization import report_parallelization_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/report/report_parallelization.py b/src/chop/passes/graph/analysis/report/report_parallelization.py new file mode 100644 index 000000000..17d97f905 --- /dev/null +++ b/src/chop/passes/graph/analysis/report/report_parallelization.py @@ -0,0 +1,26 @@ + +from tabulate import tabulate + +def report_parallelization_analysis_pass(mg, pass_args={}): + fname = pass_args.get("file_name", "report_parallelization.txt") + + headers = ["Node", "Node op", "Mase op", "Args", "Kwargs", "Valid Input Shardings", "Input Sharding", "Output Sharding"] + info = [] + for node in mg.fx_graph.nodes: + sharding_config = node.meta['mase']['software']['autosharding'] + info.append([ + node.name, + node.op, + node.meta['mase']['common']['mase_op'], + node.args, 
+ node.kwargs, + sharding_config["valid_input_shardings"], + sharding_config["input_sharding"], + sharding_config["output_sharding"] + + ]) + + with open(fname, "w") as f: + f.write(f"{tabulate(info, headers)}\n") + + return mg, {} \ No newline at end of file diff --git a/src/chop/pipelines/__init__.py b/src/chop/pipelines/__init__.py new file mode 100644 index 000000000..16a0e5e6f --- /dev/null +++ b/src/chop/pipelines/__init__.py @@ -0,0 +1,2 @@ +from .auto_pipeline import AutoPipeline +from .distributed_inference import AutoPipelineForDistributedInference \ No newline at end of file diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py new file mode 100644 index 000000000..147240753 --- /dev/null +++ b/src/chop/pipelines/auto_pipeline.py @@ -0,0 +1,26 @@ +from chop.ir import MaseGraph +from chop.tools.logger import get_logger + +logger = get_logger(__name__) + + +class AutoPipeline: + def __init__(self, pass_list=[]) -> None: + self.pass_list = pass_list + self.pass_outputs = {} + + def __call__(self, mg: MaseGraph, pass_args: dict, skip_passes: list = []): + for pass_fn in self.pass_list: + if pass_fn in skip_passes: + logger.debug(f"Skipping pass: {pass_fn.__name__}") + continue + logger.debug(f"Running pass: {pass_fn.__name__}") + args = pass_args.get(pass_fn.__name__, {}) + + for k, v in args.items(): + if isinstance(v, str) and v.startswith("self/"): + args[k] = self.pass_outputs[v[5:]] + + mg, pass_output = pass_fn(mg, pass_args=args) + self.pass_outputs[pass_fn.__name__] = pass_output + return mg \ No newline at end of file diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py new file mode 100644 index 000000000..6907fb65f --- /dev/null +++ b/src/chop/pipelines/distributed_inference.py @@ -0,0 +1,19 @@ +import chop.passes as passes + +from .auto_pipeline import AutoPipeline + + +class AutoPipelineForDistributedInference(AutoPipeline): + def __init__(self) -> None: + + 
pass_list = [ + passes.init_metadata_analysis_pass, + passes.report_graph_analysis_pass, + passes.add_common_metadata_analysis_pass, + passes.report_node_meta_param_analysis_pass, + passes.autosharding_analysis_pass, + passes.resharding_transform_pass, + passes.graph.analysis.report.report_parallelization_analysis_pass, + ] + + super().__init__(pass_list) \ No newline at end of file From ad299bc27040628c2a26f66f17d1928b84f5bbc4 Mon Sep 17 00:00:00 2001 From: pgimenes Date: Tue, 25 Jun 2024 12:04:18 +0100 Subject: [PATCH 14/93] layer norm and baddbmm ops --- src/chop/ir/common.py | 3 +- .../add_metadata/common_metadata_layers.py | 32 +++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 3c3b08fb5..8d6cd309b 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -46,7 +46,7 @@ "ones", "dim", "finfo", - "masked_fill" + "masked_fill", ] MASE_MODULE_RELATED_FUNCS = [ @@ -86,6 +86,7 @@ MASE_BUILTIN_FUNCS = [ "mul", + "baddbmm", "addmm", "sub", "add", diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 5455ec83e..3135e9736 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -54,7 +54,21 @@ # https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html "softplus": {"input": "data_in", "inplace": "config"}, # https://pytorch.org/docs/stable/generated/torch.addmm.html - "addmm": {"input": "data_in", "mat1": "data_in", "mat2": "data_in", "beta": "config", "alpha": "config"}, + "baddbmm": { + "input": "data_in", + "batch1": "data_in", + "batch2": "data_in", + "beta": "config", + "alpha": "config", + }, + # https://pytorch.org/docs/stable/generated/torch.addmm.html + "addmm": { + "input": "data_in", + "mat1": "data_in", + "mat2": "data_in", + "beta": "config", + 
"alpha": "config", + }, # https://pytorch.org/docs/stable/generated/torch.add.html "add": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.mul.html @@ -149,7 +163,13 @@ # https://pytorch.org/docs/stable/generated/torch.mean.html "mean": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.arange.html - "arange": {"start": "config", "end": "config", "step": "config", "dtype": "config", "device": "config"}, + "arange": { + "start": "config", + "end": "config", + "step": "config", + "dtype": "config", + "device": "config", + }, # https://pytorch.org/docs/stable/generated/torch.range.html "range": {"start": "config", "end": "config", "step": "config"}, # https://pytorch.org/docs/stable/generated/torch.where.html @@ -177,6 +197,14 @@ # https://pytorch.org/docs/stable/generated/torch.ones.html "ones": {"size": "config", "device": "config"}, "finfo": {"dtype": "config"}, + # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm + "layer_norm": { + "input": "data_in", + "normalized_shape": "config", + "weight": "data_in", + "bias": "data_in", + "eps": "config", + }, } module_data = { From 4566af95de6b0a2e0e561aac12f0053e87086a44 Mon Sep 17 00:00:00 2001 From: pgimenes Date: Tue, 25 Jun 2024 12:56:04 +0100 Subject: [PATCH 15/93] [REFACTOR] Export strategies for each node using torch distributed format. 
Supported ops: mm, bmm, addmm, baddbmm --- .../analysis/autosharding/alpa_layers.py | 95 ++------- .../autosharding/ops/basic_strategy.py | 180 ++++++++++++++++++ .../analysis/autosharding/ops/matrix_ops.py | 133 +++++++++++++ .../graph/analysis/autosharding/utils.py | 20 ++ 4 files changed, 346 insertions(+), 82 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py create mode 100644 src/chop/passes/graph/analysis/autosharding/utils.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index 27298cad1..45a5f8952 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -1,96 +1,27 @@ import itertools import numpy as np +import torch import torch.nn as nn from chop.tools import get_logger from chop.models.patched.bert.modeling_bert import BertSelfAttention -from .common import SpmdShard, VALID_2D_TENSOR_SHARDINGS from .alpa_cost_modelling import get_communication_cost +from .ops.matrix_ops import ( + transpose_strategy, + mm_strategy, + addmm_strategy, + bmm_strategy, + baddmm_strategy, +) logger = get_logger(__name__) -def is_valid_2d_sharding(sharding): - if len(sharding) > 2: - return sharding[1:] in VALID_2D_TENSOR_SHARDINGS - else: - return sharding in VALID_2D_TENSOR_SHARDINGS - -def is_valid_sharding_pair(sharding_pair): - return sharding_pair[0][-1] == sharding_pair[1][-2] - -def is_fully_replicated(sharding_pair): - return all(all(dimp == SpmdShard.R for dimp in subp) for subp in sharding_pair) - -def get_valid_2d_shardings(node_meta, mesh, module): - """ - Return every valid combination of shardings for the input tensors. For an operator - sharding to be valid, the inner dimension must have the same sharding. - E.g. 
((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. - """ - input_shardings = [] - output_shardings = [] - compute_cost_vector = [] - communication_cost_vector = [] - - out_rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - - for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): - if out_rank > 2: - # Always replicate along batch dimension, and assume weights are always 2D - # TO DO: handle sharding along batch dimension - perm = ((SpmdShard.R, ) + perm[0], perm[1]) - output_sharding = tuple((SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1])) - if not is_fully_replicated(perm) and is_valid_sharding_pair(perm) and is_valid_2d_sharding(output_sharding): - input_shardings.append({ - "data_in_0": perm[0], - "weight": perm[1] - }) - output_shardings.append(output_sharding) - - compute_cost_vector.append(0) - communication_cost_vector.append(get_communication_cost(perm, node_meta["mase"], mesh)) - - return ( - input_shardings, - output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - -def get_valid_linear_shardings(node_meta, mesh, module): - return get_valid_2d_shardings(node_meta, mesh, module) - -def get_valid_layernorm_shardings(node_meta, mesh, module): - rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * rank}] - valid_output_shardings = [(SpmdShard.R,) * rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - -def get_valid_embedding_shardings(node_meta, mesh, module): - weight_rank = len(module.weight.shape) - data_in_rank = len(node_meta["mase"]["common"]["args"]["data_in_0"]["shape"]) - valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * data_in_rank, "weight": (SpmdShard.R,) * weight_rank}] - valid_output_shardings = 
[(SpmdShard.R,) * data_in_rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - ALPA_LAYERS = { - nn.Linear: get_valid_linear_shardings, - nn.LayerNorm: get_valid_layernorm_shardings, - nn.Embedding: get_valid_embedding_shardings, + torch.transpose: transpose_strategy, + torch.mm: mm_strategy, + torch.addmm: addmm_strategy, + torch.bmm: bmm_strategy, + torch.baddbmm: baddmm_strategy, } diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py new file mode 100644 index 000000000..416f3eda0 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -0,0 +1,180 @@ +import itertools +from dataclasses import dataclass +from typing import List, Set, Tuple + +from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + _Partial, + Placement, + Replicate, + Shard, +) + + +@dataclass +class EinsumDims: + contracting_dims: List[str] + batch_dims: List[str] + lhs_out_only_dims: List[str] + rhs_out_only_dims: List[str] + + @classmethod + def parse_equation(cls, equation: str) -> Tuple[List[str], str]: + # parse einop equation and extract arg specs + """ + Parse the einsum equation str to input dim chars and output dim char + """ + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + + # NOTE: only support at most two inputs, and single output + # extend to support more inputs if needed in future + assert len(input_dims) <= 2, "Only support at most two inputs" + assert len(output_dims) == 1, "Only support single output" + output_dim = output_dims[0] + return input_dims, output_dim + + @classmethod + def parse_dims(cls, input_dims: List[str], output_dim: str) -> 
"EinsumDims": + """ + Parse the dims and extract the contracting, batch, and free dimensions + for the left and right hand sides. + """ + dim_char_set: Set[str] = set() + for input_dim in input_dims: + dim_char_set.update(input_dim) + + # get a determinisitc order of all dim chars + all_dim_chars = sorted(dim_char_set) + + # parse input and output dimensions + lhs_out_only_dims, rhs_out_only_dims = [], [] + batch_dims, contracting_dims = [], [] + + for dim_char in all_dim_chars: + if dim_char not in output_dim: + contracting_dims.append(dim_char) + else: + is_batch_dim = True + for input_dim in input_dims: + is_batch_dim = is_batch_dim and dim_char in input_dim + + if is_batch_dim: + batch_dims.append(dim_char) + else: + assert ( + len(input_dims) == 2 + ), "free dimension only supported for two inputs!" + lhs, rhs = input_dims + if dim_char in lhs: + lhs_out_only_dims.append(dim_char) + elif dim_char in rhs: + rhs_out_only_dims.append(dim_char) + else: + raise RuntimeError("Invalid dimension character") + + return cls( + contracting_dims=contracting_dims, + batch_dims=batch_dims, + lhs_out_only_dims=lhs_out_only_dims, + rhs_out_only_dims=rhs_out_only_dims, + ) + + +def gen_einsum_strategies( + equation: str, + mesh: tuple, + *, + linearity: bool = False, +) -> OpStrategy: + """ + Generate a strategy list for the ops that follow einsum style notation. + """ + # parse einop equation and extract dims + input_dims, output_dim = EinsumDims.parse_equation(equation) + edims = EinsumDims.parse_dims(input_dims, output_dim) + + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim + for mesh_dim in range(len(mesh)): + mesh_dim_strategies = [] + + # placement list stores placements of [output, input1, input2, ...] 
+ # first we always have replicate all for inputs and output + placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1) + mesh_dim_strategies.append(placement_list) + + if mesh[mesh_dim] <= 1: + # only replicate strategy for mesh dim with size 1 + # TODO: see if this is valid for the submesh case + continue + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + mesh_dim_strategies.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + placement_list = [_Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + mesh_dim_strategies.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim = output_dim.index(lhs_dim) + # this means split the lhs input and output + # i.e. 
S(0), R -> S(0) + lhs_placement_list: List[Placement] = [ + Shard(lhs_free_dim), + Shard(lhs_free_dim), + Replicate(), + ] + mesh_dim_strategies.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim = output_dim.index(rhs_dim) + rhs_placement_list: List[Placement] = [ + Shard(rhs_free_dim), + Replicate(), + Shard(rhs_free_dim), + ] + mesh_dim_strategies.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: List[Placement] = [_Partial()] + for input_dim in input_dims: + linearity_placement_list.append(_Partial()) + mesh_dim_strategies.append(linearity_placement_list) + + all_mesh_dim_strategies.append(mesh_dim_strategies) + + # generate strategies for entire mesh + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + # TODO: filter out invalid strategies, at this point we generate + # all possible strategies without considering the whether the tensor + # dim could be sharded or not, we would need to filter out invalid + # strategies base on the actual tensor shape + # (i.e. 
for Shard, tensor dim size must > mesh size) + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:]) + all_strategies.append(strat) + + return OpStrategy(all_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py new file mode 100644 index 000000000..6826b0429 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -0,0 +1,133 @@ +# Adapted from https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py + +import itertools +from typing import List, Optional + +import torch +from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy +from .basic_strategy import gen_einsum_strategies +from torch.distributed._tensor.ops.utils import ( + infer_broadcast_dims_map, + map_placements_after_broadcast, + register_op_strategy, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + +from ..utils import is_tensor_shardable + +from chop.ir.graph import MaseMetadata + +aten = torch.ops.aten + + +def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + transpose_strategies = [] + for input_strategy in self_strategy.strategies: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements + output_placements = [ + Shard(1 - p.dim) if isinstance(p, Shard) else p + for p in input_spec.placements + ] + transpose_strategy = PlacementStrategy( + output_specs=DTensorSpec( + mesh=input_strategy.output_spec.mesh, + placements=tuple(output_placements), + ), + 
input_specs=(input_strategy.output_spec,), + ) + transpose_strategies.append(transpose_strategy) + + return OpStrategy(strategies=transpose_strategies) + + +def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple) -> OpStrategy: + self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()] + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + if is_tensor_shardable(self_shape, self_spec) and is_tensor_shardable( + mat2_shape, mat2_spec + ): + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _addmm_like_strategy( + mm_equation: str, meta: MaseMetadata, mesh: tuple +) -> OpStrategy: + + self_shape, mat1_shape, mat2_shape = [ + arg["shape"] for arg in meta["common"]["args"].values() + ] + + mm_out_shape = torch.Size( + [ + mat2_shape[-1] if i == len(mat1_shape) - 1 else dim_size + for i, dim_size in enumerate(mat1_shape) + ] + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + # construct new strategy by consider the self arg + assert strtg.input_specs is not None + mat1_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + self_spec = 
DTensorSpec(mesh=mesh, placements=self_placements) + + if is_tensor_shardable(mat1_shape, mat1_spec) and is_tensor_shardable( + mat2_shape, mat2_spec + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def mm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: + return _mm_like_strategy("mk,kn->mn", meta, mesh) + + +def addmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: + return _addmm_like_strategy("mk,kn->mn", meta, mesh) + + +def bmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: + return _mm_like_strategy("bmk,bkn->bmn", meta, mesh) + + +def baddmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: + return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/utils.py b/src/chop/passes/graph/analysis/autosharding/utils.py new file mode 100644 index 000000000..8a08be311 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/utils.py @@ -0,0 +1,20 @@ +from typing import cast, Iterable, List, Sequence, Tuple, Union +from torch.distributed._tensor.placement_types import DTensorSpec, Shard + + +def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh[i] + + for i, dim_size in enumerate(shape): + # TODO: maybe we should determine is_shardable based on + # whether it's evenly sharded or not + if shards_map[i] > 1 and dim_size < shards_map[i]: + return False + + return True From 4b62b881f738c20a9651895e394e0b2987db573b Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 27 Jun 2024 14:41:02 +0000 
Subject: [PATCH 16/93] [REFACTOR] enumerate sharding strategies for reshape nodes: view, expand, permute, etc --- .../add_metadata/add_common_metadata.py | 5 + .../add_metadata/common_metadata_layers.py | 2 + .../graph/analysis/autosharding/alpa.py | 238 ++++--- .../analysis/autosharding/alpa_layers.py | 12 +- .../autosharding/ops/basic_strategy.py | 2 +- .../analysis/autosharding/ops/matrix_ops.py | 6 +- .../analysis/autosharding/ops/view_ops.py | 633 ++++++++++++++++++ 7 files changed, 800 insertions(+), 98 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/view_ops.py diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index c304445c2..575f22f62 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -257,6 +257,11 @@ def graph_iterator_for_metadata( ) env[node.name] = result + # For call_method nodes, the input tensor is not kept in meta["common"]["args"] + # so we keep a copy under the "self" key. This is used in autosharding spec propagation. 
+ if add_value and node.op == "call_method": + node.meta["mase"]["common"]["self"] = self_obj + return graph diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 3135e9736..368fb5375 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -205,6 +205,8 @@ "bias": "data_in", "eps": "config", }, + # https://pytorch.org/docs/stable/generated/torch.transpose.html + "transpose": {"input": "data_in", "dim_0": "config", "dim_1": "config"}, } module_data = { diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 946b4ea49..48ca08075 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -8,20 +8,17 @@ from chop.tools import get_logger from .common import SpmdShard -from .alpa_layers import ALPA_LAYERS +from .alpa_layers import ALPA_FUNCTIONS, ALPA_METHODS from .alpa_cost_modelling import get_resharding_matrix -from .debug_utilities import debug_shardings, are_layers_equal logger = get_logger(__name__) -logger.setLevel("INFO") +logger.setLevel("DEBUG") -def excepthook(exc_type, exc_value, exc_traceback): - traceback.print_exception(exc_type, exc_value, exc_traceback) - print("\nEntering debugger...") - pdb.post_mortem(exc_traceback) - -# Set the custom exception hook -sys.excepthook = excepthook +import operator +IGNORE_FUNCS = [operator.getitem] +IGNORE_METHODS = [ + "size" +] def deepgetattr(obj, attr, default=None): """Recurses through an attribute chain to get the ultimate value.""" @@ -30,12 +27,14 @@ def deepgetattr(obj, attr, default=None): except AttributeError: return default + def get_node_target(node): if isinstance(node.target, str): return deepgetattr(node.meta["mase"].model, node.target, None) else: return 
node.target + def assign_default_sharding(node): rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) node.meta["mase"]["software"]["autosharding"] = { @@ -47,6 +46,60 @@ def assign_default_sharding(node): } +def mark_choices(mg): + """ + Once the metadata has already been filled for each op with the possible shardings and costs, + and the ILP has been solved, this function marks the chosen sharding for each op. + """ + for node in mg.fx_graph.nodes: + chosen_idx = ( + 0 + if isinstance( + node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray + ) + else np.where( + node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1 + )[0][0] + ) + node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta[ + "mase" + ]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] + node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta[ + "mase" + ]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] + chosen_sharding = { + key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] + for key in node.meta["mase"]["software"]["autosharding"][ + "input_sharding" + ].keys() + } + + # Write into module map (used by distributed launcher) + target = get_node_target(node) + if node.op == "call_module" and target is not None: + module_map[target] = {"node": node.name, "sharding": chosen_sharding} + module_map[target]["sharding"]["output"] = node.meta["mase"]["software"][ + "autosharding" + ]["output_sharding"] + + return mg + +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec +import itertools + +def placeholder_or_getattr_strategy(meta, mesh): + ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] + shardings = [] + for sharding in itertools.product(opts, repeat=2): + spec = 
DTensorSpec(mesh, sharding) + shardings.append(PlacementStrategy( + input_specs=spec, + output_specs=spec + )) + return OpStrategy(shardings) + def alpa_intra_op_sharding_pass(mg, mesh, debug=False): """ Intra-operator auto parallelization pass. @@ -60,102 +113,99 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): # Write cost vectors into metadata for each operator # This will later be used to solve the ILP optimization - debugged_layers = [] for node in mg.fx_graph.nodes: - target = get_node_target(node) - target_cls = type(target) - num_params = len([i for i in target.parameters()]) if isinstance(target, nn.Module) else 0 + if (node.op == "call_function" and node.target in IGNORE_FUNCS) or (node.op == "call_method" and node.target in IGNORE_METHODS): + logger.debug(f"Ignoring {node.op} node {node.name} with target {node.target}") + continue + + # Obtain strategy according to node op + # ================================================ + + if node.op in ["placeholder", "get_attr"]: + logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") + op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh.mesh_shape) - if node.op != "call_module" or num_params == 0: - assign_default_sharding(node) + elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): + logger.debug(f"Obtaining strategy for node {node.name}") + op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh.mesh_shape) - elif target_cls in ALPA_LAYERS.keys(): + elif node.op == "call_function" and node.target in ALPA_FUNCTIONS.keys(): # Enumerate shardings and costs for this operator - ( - input_shardings, - output_shardings, - compute_cost_vector, - communication_cost_vector, - ) = ALPA_LAYERS[target_cls](node.meta, mesh, target) - - # Debug each possible sharding by running an inference step over the device mesh - if debug and not any([are_layers_equal(layer, target) for layer in debugged_layers]): - 
debug_shardings(target, input_shardings, mesh) - debugged_layers.append(target) - - # Formulate optimization variable and consider compute/communication cost - opt_var = cp.Variable(len(input_shardings), boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - expr += opt_var.T @ (compute_cost_vector + communication_cost_vector) - - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": input_shardings, - "valid_output_shardings": output_shardings, - "compute_cost_vector": compute_cost_vector, - "communication_cost_vector": communication_cost_vector, - "opt_var": opt_var, - } - - # Consider resharding cost - for in_node in node.all_input_nodes: - in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - - resharding_costs = get_resharding_matrix( - mesh, - src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], - dest_shardings = [sharding["data_in_0"] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], - dest_node_meta = node.meta["mase"] - ).flatten() - - # Formulate resharding cost term with linearized variable - e_var = cp.Variable(opt_var.shape[0] * in_opt_var.shape[0], boolean=True) - expr += e_var.T @ resharding_costs - constr += [ - cp.sum(e_var) == 1, - ] - - # Scalar construction of the inequality constraints for the linearized variable - for i in range(e_var.shape[0]): - constr += [ - e_var[i] <= opt_var[i // in_opt_var.shape[0]], - e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 - ] - - # No sharding algorithm found for this operator, but this has parameter attributes - # (i.e. 
not an elementwise or implicit function) - elif (len([i for i in target.parameters()]) > 0): - logger.warning(f"No sharding algorithm found for operator: {target_cls}, but the parameter count is non-zero.") - logger.warning(f" MaseLauncher will fully replicate the parameters of this module.") + # ( + # input_shardings, + # output_shardings, + # compute_cost_vector, + # communication_cost_vector, + # ) = ALPA_FUNCTIONS[node.target](node.meta, mesh) + logger.debug(f"Obtaining strategy for node {node.name}") + op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh.mesh_shape) else: - logger.debug(f"Skipping implicit/elementwise operator: {target_cls}") + logger.warning(f"Unknown node {node.name} with op {node.op}") + continue + breakpoint() + + # Formulate optimization variable and consider compute/communication cost + opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + # expr += opt_var.T @ (compute_cost_vector + communication_cost_vector) + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": opt_var, + } + + # Consider resharding cost + # for in_node in node.all_input_nodes: + # in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + + # resharding_costs = get_resharding_matrix( + # mesh, + # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ + # "valid_output_shardings" + # ], + # dest_shardings=[ + # sharding["data_in_0"] + # for sharding in node.meta["mase"]["software"]["autosharding"][ + # "valid_input_shardings" + # ] + # ], + # dest_node_meta=node.meta["mase"], + # ).flatten() + + # # Formulate resharding cost term with linearized variable + # e_var = cp.Variable( + # opt_var.shape[0] * in_opt_var.shape[0], boolean=True + # ) + # expr += e_var.T @ resharding_costs + # constr += [ + # cp.sum(e_var) == 1, + # ] + + # # Scalar construction of the inequality constraints for the linearized variable + 
# for i in range(e_var.shape[0]): + # constr += [ + # e_var[i] <= opt_var[i // in_opt_var.shape[0]], + # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + # e_var[i] + # >= opt_var[i // in_opt_var.shape[0]] + # + in_opt_var[i % in_opt_var.shape[0]] + # - 1, + # ] # Solve the ILP problem - prob = cp.Problem(cp.Minimize(expr), constr) - prob.solve() - - for node in mg.fx_graph.nodes: - chosen_idx = 0 if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray) else np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] - node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] - node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] - chosen_sharding = {key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys()} + # prob = cp.Problem(cp.Minimize(expr), constr) + # prob.solve() - # Write into module map (used by distributed launcher) - target = get_node_target(node) - if node.op == "call_module" and target is not None: - module_map[target] = { - "node": node.name, - "sharding": chosen_sharding - } - module_map[target]["sharding"]["output"] = node.meta["mase"]["software"]["autosharding"]["output_sharding"] + # mg = mark_choices(mg) return mg, module_map + def alpa_autosharding_pass(mg, mesh): mg, module_map = alpa_intra_op_sharding_pass(mg, mesh) - return mg, module_map \ No newline at end of file + return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index 45a5f8952..b0c1df206 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -1,7 +1,9 @@ import 
itertools import numpy as np + import torch import torch.nn as nn +import torch.nn.functional as F from chop.tools import get_logger from chop.models.patched.bert.modeling_bert import BertSelfAttention @@ -16,12 +18,20 @@ baddmm_strategy, ) +from .ops.view_ops import get_reshape_strategy + logger = get_logger(__name__) -ALPA_LAYERS = { +ALPA_FUNCTIONS = { torch.transpose: transpose_strategy, torch.mm: mm_strategy, torch.addmm: addmm_strategy, torch.bmm: bmm_strategy, torch.baddbmm: baddmm_strategy, } + +ALPA_METHODS = { + "view": get_reshape_strategy(torch.Tensor.view), + "expand": get_reshape_strategy(torch.Tensor.expand), + "permute": get_reshape_strategy(torch.permute), +} diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py index 416f3eda0..f96666ef1 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed._tensor.placement_types import ( DTensorSpec, _Partial, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index 6826b0429..e2390815f 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -4,7 +4,7 @@ from typing import List, Optional import torch -from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from .basic_strategy import gen_einsum_strategies from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, @@ -27,7 
+27,9 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: - self_strategy = op_schema.args_schema[0] + + parent_node = meta.node.args[0] + self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] assert isinstance(self_strategy, OpStrategy) transpose_strategies = [] diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py new file mode 100644 index 000000000..2297c0bf0 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -0,0 +1,633 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from dataclasses import dataclass +from typing import ( + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, +) +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, + normalize_dim, + normalize_dims, + prod, +) +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed.device_mesh import DeviceMesh + + +aten = torch.ops.aten + +Shape = Tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = Tuple[DimSpec, ...] 
+ + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton.""" + + pass + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + + Note that input_dim itself could be a Flattened set of input dims. 
+ """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] + group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implement broadcast on multiple dimensions.""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. 
check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return cast(Shape, sizes[0]) # type: ignore[redundant-cast] + else: + raise RuntimeError("Size must be int... or tuple") + + +def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + # only flattening dims from start_dim to end_dim (inclusive) + # other dims are passed through + if end_dim < 0: + end_dim += ndim + results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] + results.append( + Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) + ) + results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) + return tuple(results) + + +def dim_movedim( + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len(destination), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == 
-1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert ( + len(sizes) >= ndim + ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + + Infer the size of this dimension given the total_size. + """ + infers = [i for i, s in enumerate(sizes) if s == -1] + size = prod(sizes) + assert len(infers) <= 1, "can only infer one size" + if infers: + size = -size + missing_size = total_size // size + assert ( + total_size % size == 0 + ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + return tuple(s if s != -1 else missing_size for s in sizes) + assert size == total_size, f"sizes do not match {total_size} vs {size}" + return sizes + + +def view_groups(from_size: Shape, to_size: Shape) -> DimMap: + """ + Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. + + A view or reshape operation can be decomposed into a set of 3 types of smaller operations: + 1) Forward a dimension from input to output + 2) Flatten a set of dimensions into a single dimension + 3) Split one dimension into multiple dimensions + + view_groups identifies these operations and returns, for each output dimension, what + is operation was performed in the input dimension. 
For example: + + view_groups([2, 3, 4], [2, 12]) -> ( + InputDim(0), + Flatten((InputDim(1), InputDim(2))) + ) + + - ouptut dimension 0 maps to input dimension 0 + - output dimension 1 maps to a flattened input dimensions 1 and 2 + + + view_groups([2, 3], [3, 2]) -> ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ) + + - in the above, input is flattened into a single dimension and then split + into two separate dimensions with different sizes from the input. + """ + from_nelem = prod(from_size) + to_size = infer_size(from_nelem, normalize_sizes(to_size)) + + assert from_nelem == prod(to_size), "Total view shape does not add up" + + from_idx = 0 + to_idx = 0 + from_len = len(from_size) + to_len = len(to_size) + + result_pp = [] + + while from_idx < from_len or to_idx < to_len: + from_group_dim, to_group_shape = [], [] + + if from_idx >= from_len: + f = 1 + else: + f = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + + if to_idx >= to_len: + t = 1 + else: + t = to_size[to_idx] + to_group_shape.append(t) + to_idx += 1 + + # if any of the groups is singleton, great, we need to backtrack though + if f == 1 and t != 1: + # produces ([1], []) + to_idx -= 1 + to_group_shape = [] + elif f != 1 and t == 1: + # produces ([], [1]) + from_idx -= 1 + from_group_dim = [] + else: + # produces ([1], [1]), ([2], [2]), ([2,3], [6]) + while f != t: + if f < t: + nf = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + f *= nf + else: + nt = to_size[to_idx] + to_group_shape.append(nt) + to_idx += 1 + t *= nt + + if len(to_group_shape) > 0: + flattened = Flatten.new( + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) + ) + result_pp += [ + Split.new(flattened, tuple(to_group_shape), i) + for i in range(len(to_group_shape)) + ] + + return tuple(result_pp) + + +def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: + if len(dims) < ndim: + dims = (1,) * 
(ndim - len(dims)) + dims + return dim_repeat(ndim, dims) + + +def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + assert dim1 < ndim + assert dim2 < ndim + dimmap = [InputDim(i) for i in range(ndim)] + swapdim = dimmap[dim1] + dimmap[dim1] = dimmap[dim2] + dimmap[dim2] = swapdim + return tuple(dimmap) + + +def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: + # FIXME: this is wrong when dim=None and one of the dimensions + # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not actually a singleton. + return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_view_as_real(shape: Shape) -> DimMap: + ndim = len(shape) + results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] + # each complex number is split into two real numbers, + # resulting in one more dimension of size 2 + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) + return tuple(results) + + +def dim_reduction( + ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool +) -> DimMap: + """ + General fallback for reduction ops where Partial() does not apply. + + This will cause incoming tensor to be replicated on the reducing dimensions. 
+ """ + if dim_or_dims is None: + dim_or_dims = tuple(range(ndim)) + if isinstance(dim_or_dims, int): + dim_or_dims = (dim_or_dims,) + dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) + return tuple( + InputDim(i) if i not in dim_or_dims else Singleton() + for i in range(ndim) + if i not in dim_or_dims or keepdim + ) + + +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ), + torch.permute: lambda input, *dims: tuple( + InputDim(i) for i in normalize_dims(tuple(dims), input.ndim) + ), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), +} + + +def propagate_shape_and_sharding( + input_src_placements: Sequence[Placement], + local_in_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, +) -> Tuple[Sequence[Placement], Sequence[Placement]]: + """ + Determine input target sharding and output sharding based on + given global 
tensor shape and input source sharding. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. + - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(input_src_placements) == len(mesh_sizes) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} + + # in case an input dimension disappears (e.g. collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + seen_input_dims: Set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(local_in_shape)): + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim + + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + if isinstance(cmd, InputDim): + return cmd + elif isinstance(cmd, Flatten): + for dim in cmd.input_dims[1:]: + if isinstance(dim, InputDim): + shardable_dims[dim.input_dim] = [False] * mesh_ndim + dim0 = cmd.input_dims[0] + return dim0 if isinstance(dim0, InputDim) else None + elif isinstance(cmd, Split): + in_dim = get_in_dim_to_shard(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisible + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. 
In that case, the dimension + # needs to be divisible by the product of mesh sizes. + # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] + + # 2. here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, input_src_placements): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert ( + out_size % submesh_size == 0 + ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + + # we will only shard our first component of the split + return in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Repeat): + in_dim = get_in_dim_to_shard(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None + else: + return None + + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} + for dim, cmd in enumerate(rule): + in_dim = get_in_dim_to_shard(cmd) + if in_dim is not None: + shard_dim_map[in_dim.input_dim] = dim + + input_tgt_placements = [ + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] + + return input_tgt_placements, output_placements + + +def get_reshape_strategy(op): + dim_map = dim_maps[op] + + # def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + def reshape_strategy(meta, mesh): + breakpoint() + assert meta.node.op == "call_method", "Node should have call_method op." 
+ args_schema = [meta["common"]["self"]] + [i for i in meta["common"]["args"].values()] + rules = dim_map(*args_schema) + parent_node = meta.node.args[0] + # input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + input_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + global_in_shape = meta["common"]["self"].shape + assert global_in_shape is not None, "Shape required." + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh, + ) + + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + ) + ) + + return output_strategy + + return reshape_strategy \ No newline at end of file From 41a0eb8ee6d53a6e09e8b03c54c6ce31b0320b00 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 27 Jun 2024 15:57:29 +0000 Subject: [PATCH 17/93] [REFACTOR] unfinished: pointwise ops (add, gelu) --- .../graph/analysis/autosharding/alpa.py | 31 +++- .../analysis/autosharding/alpa_layers.py | 5 +- .../autosharding/ops/pointwise_ops.py | 145 ++++++++++++++++++ .../analysis/autosharding/ops/view_ops.py | 1 - 4 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 
48ca08075..a61f36a32 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -100,6 +100,20 @@ def placeholder_or_getattr_strategy(meta, mesh): )) return OpStrategy(shardings) +def fully_replicated_strategy(meta, mesh): + """ + Output of ops like size, getitem etc are always fully replicated + """ + sharding = [Replicate(), Replicate()] + spec = DTensorSpec(mesh, sharding) + shardings = [ + PlacementStrategy( + input_specs=spec, + output_specs=spec + ) + ] + return OpStrategy(shardings) + def alpa_intra_op_sharding_pass(mg, mesh, debug=False): """ Intra-operator auto parallelization pass. @@ -111,12 +125,21 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): expr = 0 constr = [] - # Write cost vectors into metadata for each operator - # This will later be used to solve the ILP optimization + # Find sharding strategies for each operator in the graph for node in mg.fx_graph.nodes: if (node.op == "call_function" and node.target in IGNORE_FUNCS) or (node.op == "call_method" and node.target in IGNORE_METHODS): - logger.debug(f"Ignoring {node.op} node {node.name} with target {node.target}") + logger.debug(f"Implicit {node.op} node {node.name} was assigned fully replicated sharding.") + + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh.mesh_shape) + + # Opt var is None since no decision needs to be taken + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": None + } + + breakpoint() continue # Obtain strategy according to node op @@ -143,8 +166,8 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): else: logger.warning(f"Unknown node {node.name} with op {node.op}") - continue breakpoint() + continue # Formulate optimization variable and consider compute/communication cost opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py 
b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py index b0c1df206..5d0d71cca 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -1,5 +1,5 @@ import itertools -import numpy as np +import operator import torch import torch.nn as nn @@ -19,6 +19,7 @@ ) from .ops.view_ops import get_reshape_strategy +from .ops.pointwise_ops import linear_pointwise_strategy logger = get_logger(__name__) @@ -28,6 +29,8 @@ torch.addmm: addmm_strategy, torch.bmm: bmm_strategy, torch.baddbmm: baddmm_strategy, + torch.add: linear_pointwise_strategy, + operator.add: linear_pointwise_strategy } ALPA_METHODS = { diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py new file mode 100644 index 000000000..ba69509bb --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from typing import List, Sequence, Tuple + +import torch +from torch.distributed._tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, + infer_broadcast_dims_map, + map_placements_after_broadcast, + normalize_dim, + register_op_strategy, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + + +aten = torch.ops.aten + +def pointwise_strategy( + meta, mesh, linearity = False +): + max_shards_strategy_index = -1 + max_shards = -1 + followed_strategy = None + + # if _is_inplace_op(op_schema.op): + # # inplace op should follow the first arg strategy + # followed_strategy = op_schema.args_schema[0] + # elif _is_out_variant_op(op_schema.op): + # # out variant op should follow the out kwarg strategy + # followed_strategy = op_schema.kwargs_schema["out"] + # else: + + # normal pointwise op, we choose to follow the arg with + # the max shards in case operands needs reshard + for idx, arg in enumerate(meta.node.args): + if not isinstance(arg, torch.fx.Node): + continue + arg_strategy = arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + + + arg_max_shards = arg_strategy.max_num_shards() + if arg_max_shards > max_shards: + max_shards_strategy_index = idx + max_shards = arg_max_shards + followed_strategy = arg_strategy + + assert isinstance( + followed_strategy, OpStrategy + ), f"no strategy to follow for {op_schema}!" 
+ + return common_pointwise_strategy( + meta, mesh, followed_strategy, linearity + ) + + +def common_pointwise_strategy( + meta, + mesh, + followed_strategy, + linearity +): + breakpoint() + # handle broadcasting + common_shape = torch.broadcast_shapes( + *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] + ) + pointwise_strategy = OpStrategy([]) + + for placement_strategy in followed_strategy.strategies: + spec_to_follow = placement_strategy.output_spec + out_placements: List[Placement] = [] + for placement in spec_to_follow.placements: + if isinstance(placement, Shard): + shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) + common_ndim = len(common_shape) + new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim + out_placements.append(Shard(new_shard_dim)) + elif isinstance(placement, Partial) and not linearity: + # clear the partial placemnet if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) + else: + out_placements.append(placement) + + input_specs: List[DTensorSpec] = [] + redistribute_costs: List[List[float]] = [] + for idx, input_arg in enumerate(args_schema): + if isinstance(input_arg, OpStrategy): + # every arg follow the out_placements, but need to handle broadcasting + input_arg_spec = input_arg.strategies[0].output_spec + input_arg_dims_map = infer_broadcast_dims_map( + common_shape, input_arg_spec.shape + ) + input_target_placements = map_placements_after_broadcast( + tuple(out_placements), + common_shape, + input_arg_dims_map, + ) + input_arg_target_spec = DTensorSpec( + mesh=mesh, + placements=input_target_placements, + tensor_meta=input_arg_spec.tensor_meta, + ) + input_specs.append(input_arg_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_arg, input_arg_target_spec) + ) + + pointwise_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec( + 
mesh=mesh, + placements=tuple(out_placements), + ), + input_specs=input_specs, + redistribute_cost=redistribute_costs, + ) + ) + return pointwise_strategy + + +def linear_pointwise_strategy(meta, mesh): + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_strategy(meta, mesh, linearity=True) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 2297c0bf0..87cdd833b 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -589,7 +589,6 @@ def get_reshape_strategy(op): # def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: def reshape_strategy(meta, mesh): - breakpoint() assert meta.node.op == "call_method", "Node should have call_method op." 
args_schema = [meta["common"]["self"]] + [i for i in meta["common"]["args"].values()] rules = dim_map(*args_schema) From f76a6d0bcba78dd76e0bad2fd7c1c213b5012cda Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 27 Jun 2024 16:47:35 +0000 Subject: [PATCH 18/93] [REFACTOR] enumerate strategies for pointwise add --- .../graph/analysis/autosharding/alpa.py | 2 -- .../autosharding/ops/pointwise_ops.py | 21 +++++++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index a61f36a32..973ff5235 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -138,8 +138,6 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): "op_strategy": op_strategy, "opt_var": None } - - breakpoint() continue # Obtain strategy according to node op diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index ba69509bb..4a7a33a41 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -75,10 +75,20 @@ def common_pointwise_strategy( followed_strategy, linearity ): - breakpoint() # handle broadcasting + parsed_args = [] + for arg in meta["common"]["args"].values(): + if isinstance(arg, torch.Size): + parsed_args.append(torch.Tensor(list(arg))) + elif isinstance(arg, (tuple, list)): + parsed_args.append(torch.Tensor(arg)) + elif isinstance(arg, torch.Tensor): + parsed_args.append(arg) + else: + raise ValueError("Unrecognized arg type") + common_shape = torch.broadcast_shapes( - *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] + *[arg.shape for arg in parsed_args] ) pointwise_strategy = OpStrategy([]) @@ -101,12 +111,15 @@ def common_pointwise_strategy( input_specs: List[DTensorSpec] = [] 
redistribute_costs: List[List[float]] = [] - for idx, input_arg in enumerate(args_schema): + for arg_node in meta.node.args: + if not isinstance(arg_node, torch.fx.Node): + continue + input_arg = arg_node.meta["mase"]["software"]["autosharding"]["op_strategy"] if isinstance(input_arg, OpStrategy): # every arg follow the out_placements, but need to handle broadcasting input_arg_spec = input_arg.strategies[0].output_spec input_arg_dims_map = infer_broadcast_dims_map( - common_shape, input_arg_spec.shape + common_shape, arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"] ) input_target_placements = map_placements_after_broadcast( tuple(out_placements), From b15bf87af38ad6b31433ade7168d5a5362e58b1c Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 27 Jun 2024 17:25:03 +0000 Subject: [PATCH 19/93] [REFACTOR] unfinished: pointwise truediv --- .../graph/analysis/autosharding/alpa.py | 67 ++-------------- .../analysis/autosharding/alpa_layers.py | 40 ---------- .../graph/analysis/autosharding/layers.py | 80 +++++++++++++++++++ .../graph/analysis/autosharding/mesh_model.py | 9 +++ .../autosharding/ops/basic_strategy.py | 2 +- .../autosharding/ops/pointwise_ops.py | 18 +++-- .../analysis/autosharding/ops/view_ops.py | 2 +- 7 files changed, 108 insertions(+), 110 deletions(-) delete mode 100644 src/chop/passes/graph/analysis/autosharding/alpa_layers.py create mode 100644 src/chop/passes/graph/analysis/autosharding/layers.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 973ff5235..4c8b1e398 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -7,19 +7,12 @@ from chop.tools import get_logger -from .common import SpmdShard -from .alpa_layers import ALPA_FUNCTIONS, ALPA_METHODS +from .layers import ALPA_FUNCTIONS, ALPA_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, 
fully_replicated_strategy from .alpa_cost_modelling import get_resharding_matrix logger = get_logger(__name__) logger.setLevel("DEBUG") -import operator -IGNORE_FUNCS = [operator.getitem] -IGNORE_METHODS = [ - "size" -] - def deepgetattr(obj, attr, default=None): """Recurses through an attribute chain to get the ultimate value.""" try: @@ -35,17 +28,6 @@ def get_node_target(node): return node.target -def assign_default_sharding(node): - rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - node.meta["mase"]["software"]["autosharding"] = { - "valid_input_shardings": [{"data_in_0": (SpmdShard.R,) * rank}], - "valid_output_shardings": [(SpmdShard.R,) * rank], - "compute_cost_vector": [0], - "communication_cost_vector": [0], - "opt_var": np.array([1]), - } - - def mark_choices(mg): """ Once the metadata has already been filled for each op with the possible shardings and costs, @@ -84,36 +66,6 @@ def mark_choices(mg): return mg -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec -import itertools - -def placeholder_or_getattr_strategy(meta, mesh): - ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) - opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] - shardings = [] - for sharding in itertools.product(opts, repeat=2): - spec = DTensorSpec(mesh, sharding) - shardings.append(PlacementStrategy( - input_specs=spec, - output_specs=spec - )) - return OpStrategy(shardings) - -def fully_replicated_strategy(meta, mesh): - """ - Output of ops like size, getitem etc are always fully replicated - """ - sharding = [Replicate(), Replicate()] - spec = DTensorSpec(mesh, sharding) - shardings = [ - PlacementStrategy( - input_specs=spec, - output_specs=spec - ) - ] - return OpStrategy(shardings) - def alpa_intra_op_sharding_pass(mg, mesh, debug=False): """ Intra-operator auto parallelization pass. 
@@ -128,10 +80,10 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): # Find sharding strategies for each operator in the graph for node in mg.fx_graph.nodes: - if (node.op == "call_function" and node.target in IGNORE_FUNCS) or (node.op == "call_method" and node.target in IGNORE_METHODS): + if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or (node.op == "call_method" and node.target in IMPLICIT_METHODS): logger.debug(f"Implicit {node.op} node {node.name} was assigned fully replicated sharding.") - op_strategy = fully_replicated_strategy(node.meta["mase"], mesh.mesh_shape) + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) # Opt var is None since no decision needs to be taken node.meta["mase"]["software"]["autosharding"] = { @@ -145,22 +97,15 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): if node.op in ["placeholder", "get_attr"]: logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") - op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh.mesh_shape) + op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh.mesh_shape) + op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh) elif node.op == "call_function" and node.target in ALPA_FUNCTIONS.keys(): - # Enumerate shardings and costs for this operator - # ( - # input_shardings, - # output_shardings, - # compute_cost_vector, - # communication_cost_vector, - # ) = ALPA_FUNCTIONS[node.target](node.meta, mesh) logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh.mesh_shape) + op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh) else: logger.warning(f"Unknown node {node.name} with op {node.op}") diff 
--git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py deleted file mode 100644 index 5d0d71cca..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ /dev/null @@ -1,40 +0,0 @@ -import itertools -import operator - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from chop.tools import get_logger -from chop.models.patched.bert.modeling_bert import BertSelfAttention - -from .alpa_cost_modelling import get_communication_cost - -from .ops.matrix_ops import ( - transpose_strategy, - mm_strategy, - addmm_strategy, - bmm_strategy, - baddmm_strategy, -) - -from .ops.view_ops import get_reshape_strategy -from .ops.pointwise_ops import linear_pointwise_strategy - -logger = get_logger(__name__) - -ALPA_FUNCTIONS = { - torch.transpose: transpose_strategy, - torch.mm: mm_strategy, - torch.addmm: addmm_strategy, - torch.bmm: bmm_strategy, - torch.baddbmm: baddmm_strategy, - torch.add: linear_pointwise_strategy, - operator.add: linear_pointwise_strategy -} - -ALPA_METHODS = { - "view": get_reshape_strategy(torch.Tensor.view), - "expand": get_reshape_strategy(torch.Tensor.expand), - "permute": get_reshape_strategy(torch.permute), -} diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py new file mode 100644 index 000000000..842ca167c --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -0,0 +1,80 @@ +import itertools +import operator + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec + +from chop.tools import get_logger +from chop.models.patched.bert.modeling_bert import BertSelfAttention + +from .alpa_cost_modelling import get_communication_cost + +from .ops.matrix_ops import ( + 
transpose_strategy, + mm_strategy, + addmm_strategy, + bmm_strategy, + baddmm_strategy, +) + +from .ops.view_ops import get_reshape_strategy +from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy + + +logger = get_logger(__name__) + +ALPA_FUNCTIONS = { + torch.transpose: transpose_strategy, + torch.mm: mm_strategy, + torch.addmm: addmm_strategy, + torch.bmm: bmm_strategy, + torch.baddbmm: baddmm_strategy, + torch.add: linear_pointwise_strategy, + operator.add: linear_pointwise_strategy, + operator.truediv: pointwise_strategy, + torch.matmul: bmm_strategy +} + +ALPA_METHODS = { + "view": get_reshape_strategy(torch.Tensor.view), + "expand": get_reshape_strategy(torch.Tensor.expand), + "permute": get_reshape_strategy(torch.permute), + "transpose": get_reshape_strategy(torch.transpose) +} + +IMPLICIT_FUNCS = [ + operator.getitem +] + +IMPLICIT_METHODS = [ + "size" +] + +def placeholder_or_getattr_strategy(meta, mesh): + ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] + shardings = [] + for sharding in itertools.product(opts, repeat=2): + spec = DTensorSpec(mesh, sharding) + shardings.append(PlacementStrategy( + input_specs=spec, + output_specs=spec + )) + return OpStrategy(shardings) + +def fully_replicated_strategy(meta, mesh): + """ + Output of ops like size, getitem etc are always fully replicated + """ + sharding = [Replicate(), Replicate()] + spec = DTensorSpec(mesh, sharding) + shardings = [ + PlacementStrategy( + input_specs=spec, + output_specs=spec + ) + ] + return OpStrategy(shardings) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/mesh_model.py b/src/chop/passes/graph/analysis/autosharding/mesh_model.py index 79aa93ce9..311da28ff 100644 --- a/src/chop/passes/graph/analysis/autosharding/mesh_model.py +++ b/src/chop/passes/graph/analysis/autosharding/mesh_model.py @@ -12,6 +12,15 @@ def __init__(self, mesh_shape, mesh_alpha = 
None, mesh_beta = None): self.mesh_alpha = [0] * 2 if mesh_alpha is None else mesh_alpha self.mesh_beta = [None] * 2 if mesh_beta is None else mesh_beta + def __getitem__(self, key): + return self.mesh_shape[key] + + def size(self, dim: None): + if dim is None: + return np.prod(self.mesh_shape) + else: + return self.mesh_shape[dim] + def set_cost_model_parameters(self, intra_node_bandwidth: int, inter_node_bandwidth:int, backend:str = "default"): # Assign differently depending if backend is NVLink, Infiniband, etc if (backend == "default"): diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py index f96666ef1..6db629c89 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -98,7 +98,7 @@ def gen_einsum_strategies( all_mesh_dim_strategies = [] # generate strategies for each mesh dim - for mesh_dim in range(len(mesh)): + for mesh_dim in range(len(mesh.mesh_shape)): mesh_dim_strategies = [] # placement list stores placements of [output, input1, input2, ...] 
diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index 4a7a33a41..9434d5574 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -53,7 +53,6 @@ def pointwise_strategy( continue arg_strategy = arg.meta["mase"]["software"]["autosharding"]["op_strategy"] - arg_max_shards = arg_strategy.max_num_shards() if arg_max_shards > max_shards: max_shards_strategy_index = idx @@ -78,13 +77,18 @@ def common_pointwise_strategy( # handle broadcasting parsed_args = [] for arg in meta["common"]["args"].values(): - if isinstance(arg, torch.Size): + if isinstance(arg, dict): + parsed_args.append(torch.zeros(arg["shape"])) + elif isinstance(arg, torch.Size): parsed_args.append(torch.Tensor(list(arg))) elif isinstance(arg, (tuple, list)): parsed_args.append(torch.Tensor(arg)) elif isinstance(arg, torch.Tensor): parsed_args.append(arg) + elif isinstance(arg, float): + parsed_args.append(torch.Tensor([arg])) else: + breakpoint() raise ValueError("Unrecognized arg type") common_shape = torch.broadcast_shapes( @@ -110,7 +114,7 @@ def common_pointwise_strategy( out_placements.append(placement) input_specs: List[DTensorSpec] = [] - redistribute_costs: List[List[float]] = [] + # redistribute_costs: List[List[float]] = [] for arg_node in meta.node.args: if not isinstance(arg_node, torch.fx.Node): continue @@ -132,9 +136,9 @@ def common_pointwise_strategy( tensor_meta=input_arg_spec.tensor_meta, ) input_specs.append(input_arg_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_arg, input_arg_target_spec) - ) + # redistribute_costs.append( + # generate_redistribute_costs(input_arg, input_arg_target_spec) + # ) pointwise_strategy.strategies.append( PlacementStrategy( @@ -143,7 +147,7 @@ def common_pointwise_strategy( placements=tuple(out_placements), ), input_specs=input_specs, - 
redistribute_cost=redistribute_costs, + # redistribute_cost=redistribute_costs, ) ) return pointwise_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 87cdd833b..b96556f15 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -606,7 +606,7 @@ def reshape_strategy(meta, mesh): input_src_spec.placements, tuple(global_in_shape), rules, - mesh, + mesh.mesh_shape, ) # TODO: optimize this. we shouldn't simply blindly replicate From de857aeb14491877fc175e1d381b7819a28842c4 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 28 Jun 2024 10:04:19 +0000 Subject: [PATCH 20/93] [REFACTOR] fix for truediv strategy enumeration --- .../analysis/autosharding/ops/pointwise_ops.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index 9434d5574..7aeb57bd9 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -64,7 +64,7 @@ def pointwise_strategy( ), f"no strategy to follow for {op_schema}!" 
return common_pointwise_strategy( - meta, mesh, followed_strategy, linearity + meta, mesh, followed_strategy, linearity, max_shards_strategy_index ) @@ -72,7 +72,8 @@ def common_pointwise_strategy( meta, mesh, followed_strategy, - linearity + linearity, + followed_strategy_index = 0 ): # handle broadcasting parsed_args = [] @@ -94,16 +95,20 @@ def common_pointwise_strategy( common_shape = torch.broadcast_shapes( *[arg.shape for arg in parsed_args] ) - pointwise_strategy = OpStrategy([]) + + # Extract followed argument shape + followed_shape = parsed_args[followed_strategy_index].shape + # Iterate through followed argument's strategies to cast output shardings + pointwise_strategy = OpStrategy([]) for placement_strategy in followed_strategy.strategies: spec_to_follow = placement_strategy.output_spec out_placements: List[Placement] = [] for placement in spec_to_follow.placements: if isinstance(placement, Shard): - shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) + shard_dim = normalize_dim(placement.dim, len(followed_shape)) common_ndim = len(common_shape) - new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim + new_shard_dim = common_ndim - len(followed_shape) + shard_dim out_placements.append(Shard(new_shard_dim)) elif isinstance(placement, Partial) and not linearity: # clear the partial placemnet if op does not support linearity From a444aada3f132f064fe40fa7d62fd6d24eef7abc Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 28 Jun 2024 13:47:54 +0000 Subject: [PATCH 21/93] [REFACTOR] finished enumerating strategies for all BERT ops. 
To do: consider resharding cost, solve ILP --- .../add_metadata/common_metadata_layers.py | 7 + .../graph/analysis/autosharding/alpa.py | 22 +- .../graph/analysis/autosharding/layers.py | 8 +- .../analysis/autosharding/ops/math_ops.py | 200 ++++++++++++++++++ .../analysis/autosharding/ops/view_ops.py | 1 + .../analysis/report/report_parallelization.py | 7 +- 6 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/math_ops.py diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 368fb5375..066553327 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -268,6 +268,13 @@ "shape_2": "data_in", "shape_3": "data_in", }, + # https://pytorch.org/docs/stable/generated/torch.Tensor.reshape.html#torch.Tensor.reshape + "reshape": { + "shape_0": "data_in", + "shape_1": "data_in", + "shape_2": "data_in", + "shape_3": "data_in", + }, # https://pytorch.org/docs/stable/generated/torch.Tensor.addmm.html#torch.Tensor.addmm "addm": {"mat1": "data_in", "mat2": "data_in", "beta": "config", "alpha": "config"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.size.html#torch.Tensor.size diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 4c8b1e398..5db6bb582 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -88,7 +88,9 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): # Opt var is None since no decision needs to be taken node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, - "opt_var": None + "opt_var": None, + "input": None, + "output": None, } continue @@ -99,6 +101,16 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): 
logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) + elif node.op == "output": + logger.debug(f"Op strategy from node {node.args[0]} is propagated to {node} node.") + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], + "opt_var": None, + "input": None, + "output": None, + } + continue + elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): logger.debug(f"Obtaining strategy for node {node.name}") op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh) @@ -109,6 +121,12 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): else: logger.warning(f"Unknown node {node.name} with op {node.op}") + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), + "opt_var": None, + "input": None, + "output": None, + } breakpoint() continue @@ -123,6 +141,8 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, "opt_var": opt_var, + "input": None, + "output": None, } # Consider resharding cost diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 842ca167c..8ed7347cf 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -23,6 +23,7 @@ from .ops.view_ops import get_reshape_strategy from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy +from .ops.math_ops import softmax_strategy, layer_norm_strategy logger = get_logger(__name__) @@ -35,11 +36,16 @@ torch.add: linear_pointwise_strategy, operator.add: linear_pointwise_strategy, operator.truediv: pointwise_strategy, - torch.matmul: bmm_strategy + F.gelu: pointwise_strategy, + 
torch.matmul: bmm_strategy, + torch.softmax: softmax_strategy, + F.softmax: softmax_strategy, + F.layer_norm: layer_norm_strategy } ALPA_METHODS = { "view": get_reshape_strategy(torch.Tensor.view), + "reshape": get_reshape_strategy(torch.Tensor.reshape), "expand": get_reshape_strategy(torch.Tensor.expand), "permute": get_reshape_strategy(torch.permute), "transpose": get_reshape_strategy(torch.transpose) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py new file mode 100644 index 000000000..a83b00e99 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -0,0 +1,200 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import math +from dataclasses import dataclass +from enum import Enum +from typing import cast, List, Optional, Sequence, Tuple, Union + +import torch +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + TupleStrategy, +) +from torch.distributed._tensor.ops.utils import ( + as_list, + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + is_tensor_evenly_shardable, + normalize_dim, + normalize_dims, + normalize_to_torch_size, + register_op_strategy, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + + +aten = torch.ops.aten + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: Tuple[Placement, ...], reduction_dims: List[int] 
+) -> Tuple[Placement, ...]: + # replicate the reduction dims if not reduction_linear + new_placements: List[Placement] = [] + + for p in placements: + if p.is_partial(): + new_placements.append(Replicate()) + elif isinstance(p, Shard) and p.dim in reduction_dims: + new_placements.append(Replicate()) + else: + new_placements.append(p) + + return tuple(new_placements) + + +def softmax_strategy(meta, mesh): + parent_node = meta.node.args[0] + input_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + ndim = len(meta["common"]["args"]["data_in_0"]["shape"]) + + softmax_dim = meta["common"]["args"]["dim"] + + input_strategy = cast(OpStrategy, input_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, ndim) + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + # redistribute_costs.append( + # generate_redistribute_costs(input_strategy, input_target_spec) + # ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=[input_target_spec], + # redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +def layer_norm_strategy(meta, mesh): + + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). 
+ assert len(meta["common"]["args"].keys()) == 5 + + input_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"] + normalized_shape = meta["common"]["args"]["normalized_shape"] + weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"]["autosharding"]["op_strategy"] + bias_strategy = meta.node.kwargs["bias"].meta["mase"]["software"]["autosharding"]["op_strategy"] + + # the current layer norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + assert isinstance(input_strategy, OpStrategy) + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = len(meta["common"]["args"]["data_in_0"]["shape"]) + axis = input_ndim - len(normalized_size) + + # we use OpStrategy because the output (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the input tensor, we replicate it on the inner dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + # redistribute_costs.append( + # generate_redistribute_costs(input_strategy, input_target_spec) + # ) + + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + try: + # patching: weight and bias sharding strategy is currently always replicate + # So just take strategy at index 0 + # TO DO: when sharding decomposed layer norm, cross product weight strategies + # with input/bias strategies for final OpStrategy + weight_src_spec = 
weight_strategy.strategies[0].output_spec + except: + breakpoint() + + # for the weight tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + weight_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_target_spec) + # redistribute_costs.append( + # generate_redistribute_costs(weight_strategy, weight_target_spec) + # ) + + if bias_strategy is not None: + assert isinstance(bias_strategy, OpStrategy) + bias_src_spec = bias_strategy.strategies[0].output_spec + + # for the bias tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + bias_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(bias_src_spec.placements), + tensor_meta=bias_src_spec.tensor_meta, + ) + op_args_target_specs.append(bias_target_spec) + # redistribute_costs.append( + # generate_redistribute_costs(bias_strategy, bias_target_spec) + # ) + + # the output spec is the same as input spec + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=op_args_target_specs, + # redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index b96556f15..216d32c40 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -469,6 +469,7 @@ def dim_reduction( ), torch.ravel: lambda tensor: dim_flatten(tensor.ndim), Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + Tensor.reshape: lambda self, *shape: 
view_groups(self.shape, shape), torch.reshape: lambda input, shape: view_groups(input.shape, shape), torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), torch.tile: lambda input, dims: dim_tile(input.ndim, dims), diff --git a/src/chop/passes/graph/analysis/report/report_parallelization.py b/src/chop/passes/graph/analysis/report/report_parallelization.py index 17d97f905..37be5c19b 100644 --- a/src/chop/passes/graph/analysis/report/report_parallelization.py +++ b/src/chop/passes/graph/analysis/report/report_parallelization.py @@ -4,7 +4,7 @@ def report_parallelization_analysis_pass(mg, pass_args={}): fname = pass_args.get("file_name", "report_parallelization.txt") - headers = ["Node", "Node op", "Mase op", "Args", "Kwargs", "Valid Input Shardings", "Input Sharding", "Output Sharding"] + headers = ["Node", "Node op", "Mase op", "Args", "Kwargs", "Input Sharding", "Output Sharding"] info = [] for node in mg.fx_graph.nodes: sharding_config = node.meta['mase']['software']['autosharding'] @@ -14,9 +14,8 @@ def report_parallelization_analysis_pass(mg, pass_args={}): node.meta['mase']['common']['mase_op'], node.args, node.kwargs, - sharding_config["valid_input_shardings"], - sharding_config["input_sharding"], - sharding_config["output_sharding"] + sharding_config["input"], + sharding_config["output"] ]) From f4f4ce74d274844d5bac7cc390ed5b0c6bdb938d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 28 Jun 2024 14:07:04 +0000 Subject: [PATCH 22/93] slight refactoring --- .../graph/analysis/autosharding/alpa.py | 135 +--------------- .../autosharding/alpa_intra_operator.py | 150 ++++++++++++++++++ 2 files changed, 151 insertions(+), 134 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 5db6bb582..6e67c4f14 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ 
b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -1,14 +1,6 @@ -import sys, pdb, traceback -import functools - -import torch.nn as nn -import numpy as np -import cvxpy as cp - from chop.tools import get_logger -from .layers import ALPA_FUNCTIONS, ALPA_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, fully_replicated_strategy -from .alpa_cost_modelling import get_resharding_matrix +from .alpa_intra_operator import alpa_intra_op_sharding_pass logger = get_logger(__name__) logger.setLevel("DEBUG") @@ -66,131 +58,6 @@ def mark_choices(mg): return mg -def alpa_intra_op_sharding_pass(mg, mesh, debug=False): - """ - Intra-operator auto parallelization pass. - """ - - module_map = {} - - # Setup for the ILP optimization - expr = 0 - constr = [] - - # Find sharding strategies for each operator in the graph - for node in mg.fx_graph.nodes: - - if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or (node.op == "call_method" and node.target in IMPLICIT_METHODS): - logger.debug(f"Implicit {node.op} node {node.name} was assigned fully replicated sharding.") - - op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) - - # Opt var is None since no decision needs to be taken - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": None, - "input": None, - "output": None, - } - continue - - # Obtain strategy according to node op - # ================================================ - - if node.op in ["placeholder", "get_attr"]: - logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") - op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) - - elif node.op == "output": - logger.debug(f"Op strategy from node {node.args[0]} is propagated to {node} node.") - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], - "opt_var": None, - 
"input": None, - "output": None, - } - continue - - elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh) - - elif node.op == "call_function" and node.target in ALPA_FUNCTIONS.keys(): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh) - - else: - logger.warning(f"Unknown node {node.name} with op {node.op}") - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), - "opt_var": None, - "input": None, - "output": None, - } - breakpoint() - continue - - # Formulate optimization variable and consider compute/communication cost - opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - # expr += opt_var.T @ (compute_cost_vector + communication_cost_vector) - - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": opt_var, - "input": None, - "output": None, - } - - # Consider resharding cost - # for in_node in node.all_input_nodes: - # in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - - # resharding_costs = get_resharding_matrix( - # mesh, - # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ - # "valid_output_shardings" - # ], - # dest_shardings=[ - # sharding["data_in_0"] - # for sharding in node.meta["mase"]["software"]["autosharding"][ - # "valid_input_shardings" - # ] - # ], - # dest_node_meta=node.meta["mase"], - # ).flatten() - - # # Formulate resharding cost term with linearized variable - # e_var = cp.Variable( - # opt_var.shape[0] * in_opt_var.shape[0], boolean=True - # ) - # expr += e_var.T @ resharding_costs - # constr += [ - # cp.sum(e_var) == 1, - # ] - - # # Scalar construction of the inequality constraints for the linearized 
variable - # for i in range(e_var.shape[0]): - # constr += [ - # e_var[i] <= opt_var[i // in_opt_var.shape[0]], - # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - # e_var[i] - # >= opt_var[i // in_opt_var.shape[0]] - # + in_opt_var[i % in_opt_var.shape[0]] - # - 1, - # ] - - # Solve the ILP problem - # prob = cp.Problem(cp.Minimize(expr), constr) - # prob.solve() - - # mg = mark_choices(mg) - - return mg, module_map - def alpa_autosharding_pass(mg, mesh): mg, module_map = alpa_intra_op_sharding_pass(mg, mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py new file mode 100644 index 000000000..b32864a2c --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -0,0 +1,150 @@ +import sys, pdb, traceback +import functools + +import torch.nn as nn +import numpy as np +import cvxpy as cp + +from chop.tools import get_logger + +from .layers import ALPA_FUNCTIONS, ALPA_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, fully_replicated_strategy +from .alpa_cost_modelling import get_resharding_matrix + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def _enumerate_sharding_strategies(mg, mesh): + """ + For each node in the graph, assign an OpStrategy object which contains all possible + sharding algorithms. Also assign opt_var instance which is one-hot vector used to + solve ILP. + + Return list of constraints associated with ILP. The constraints at this stage only + enforce that each optimizer variable is a one-hot boolean vector. 
+ """ + + # Setup for the ILP optimization + constr = [] + + # Find sharding strategies for each operator in the graph + for node in mg.fx_graph.nodes: + + if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or (node.op == "call_method" and node.target in IMPLICIT_METHODS): + logger.debug(f"Implicit {node.op} node {node.name} was assigned fully replicated sharding.") + + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) + + # Opt var is None since no decision needs to be taken + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": None, + "input": None, + "output": None, + } + continue + + # Obtain strategy according to node op + # ================================================ + + if node.op in ["placeholder", "get_attr"]: + logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") + op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) + + elif node.op == "output": + logger.debug(f"Op strategy from node {node.args[0]} is propagated to {node} node.") + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], + "opt_var": None, + "input": None, + "output": None, + } + continue + + elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): + logger.debug(f"Obtaining strategy for node {node.name}") + op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh) + + elif node.op == "call_function" and node.target in ALPA_FUNCTIONS.keys(): + logger.debug(f"Obtaining strategy for node {node.name}") + op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh) + + else: + logger.warning(f"Unknown node {node.name} with op {node.op}") + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), + "opt_var": None, + "input": None, + "output": None, + } + continue + + # 
Formulate optimization variable and consider compute/communication cost + opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "op_strategy": op_strategy, + "opt_var": opt_var, + "input": None, + "output": None, + } + + return mg, constr + +def alpa_intra_op_sharding_pass(mg, mesh, debug=False): + """ + Intra-operator auto parallelization pass. + """ + + module_map = {} + + mg, constr = _enumerate_sharding_strategies(mg, mesh) + + # Consider resharding cost + # for in_node in node.all_input_nodes: + # in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + + # resharding_costs = get_resharding_matrix( + # mesh, + # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ + # "valid_output_shardings" + # ], + # dest_shardings=[ + # sharding["data_in_0"] + # for sharding in node.meta["mase"]["software"]["autosharding"][ + # "valid_input_shardings" + # ] + # ], + # dest_node_meta=node.meta["mase"], + # ).flatten() + + # # Formulate resharding cost term with linearized variable + # e_var = cp.Variable( + # opt_var.shape[0] * in_opt_var.shape[0], boolean=True + # ) + # expr += e_var.T @ resharding_costs + # constr += [ + # cp.sum(e_var) == 1, + # ] + + # # Scalar construction of the inequality constraints for the linearized variable + # for i in range(e_var.shape[0]): + # constr += [ + # e_var[i] <= opt_var[i // in_opt_var.shape[0]], + # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + # e_var[i] + # >= opt_var[i // in_opt_var.shape[0]] + # + in_opt_var[i % in_opt_var.shape[0]] + # - 1, + # ] + + # Solve the ILP problem + # prob = cp.Problem(cp.Minimize(expr), constr) + # prob.solve() + + # mg = mark_choices(mg) + + return mg, module_map \ No newline at end of file From 68146a44dc5156e8aeb4102c804572c3d3b164bc Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 09:49:09 +0000 Subject: [PATCH 
23/93] add tensor meta for placeholder and transpose ops --- .../autosharding/alpa_intra_operator.py | 10 +-- .../graph/analysis/autosharding/layers.py | 75 ++++++++++++++++--- .../graph/analysis/autosharding/mesh_model.py | 7 +- .../analysis/autosharding/ops/matrix_ops.py | 7 ++ 4 files changed, 82 insertions(+), 17 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index b32864a2c..41c8aba91 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -7,7 +7,7 @@ from chop.tools import get_logger -from .layers import ALPA_FUNCTIONS, ALPA_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, fully_replicated_strategy +from .layers import AUTOSHARDING_FUNCTIONS, AUTOSHARDING_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, fully_replicated_strategy from .alpa_cost_modelling import get_resharding_matrix logger = get_logger(__name__) @@ -60,13 +60,13 @@ def _enumerate_sharding_strategies(mg, mesh): } continue - elif node.op == "call_method" and node.target in ALPA_METHODS.keys(): + elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_METHODS[node.target](node.meta["mase"], mesh) + op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) - elif node.op == "call_function" and node.target in ALPA_FUNCTIONS.keys(): + elif node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys(): logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = ALPA_FUNCTIONS[node.target](node.meta["mase"], mesh) + op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) else: logger.warning(f"Unknown node {node.name} with op {node.op}") diff --git 
a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 8ed7347cf..c4b8c6600 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec +from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec, TensorMeta from chop.tools import get_logger from chop.models.patched.bert.modeling_bert import BertSelfAttention @@ -27,7 +27,7 @@ logger = get_logger(__name__) -ALPA_FUNCTIONS = { +AUTOSHARDING_FUNCTIONS = { torch.transpose: transpose_strategy, torch.mm: mm_strategy, torch.addmm: addmm_strategy, @@ -43,7 +43,7 @@ F.layer_norm: layer_norm_strategy } -ALPA_METHODS = { +AUTOSHARDING_METHODS = { "view": get_reshape_strategy(torch.Tensor.view), "reshape": get_reshape_strategy(torch.Tensor.reshape), "expand": get_reshape_strategy(torch.Tensor.expand), @@ -62,13 +62,26 @@ def placeholder_or_getattr_strategy(meta, mesh): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] + + tensor_meta = TensorMeta( + shape = meta["common"]["results"]["data_out_0"]["shape"], + stride = None, + dtype = meta["common"]["results"]["data_out_0"]["torch_dtype"] + ) + shardings = [] for sharding in itertools.product(opts, repeat=2): - spec = DTensorSpec(mesh, sharding) - shardings.append(PlacementStrategy( - input_specs=spec, - output_specs=spec - )) + spec = DTensorSpec( + mesh = mesh, + placements = sharding, + tensor_meta = tensor_meta + ) + shardings.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec + ) + ) return OpStrategy(shardings) def fully_replicated_strategy(meta, mesh): @@ -76,11 +89,51 @@ def 
fully_replicated_strategy(meta, mesh): Output of ops like size, getitem etc are always fully replicated """ sharding = [Replicate(), Replicate()] - spec = DTensorSpec(mesh, sharding) + + # call_method nodes don't list input tensor in the args list, but + # tensor is copied into meta["common"]["self"] when add_value = True + # is passed to add_common_metadata_pass + if meta.node.op == "call_method": + in_shape = meta["common"]["self"].shape + in_dtype = meta["common"]["self"].dtype + else: + first_arg_key = "data_in_0" if "data_in_0" in meta["common"]["args"] else [i for i in meta["common"]["args"].keys()][0] + arg = meta["common"]["args"][first_arg_key] + if isinstance(arg, dict): + in_shape = arg["shape"] + in_dtype = arg["torch_dtype"] + else: + arg = torch.Tensor(arg) + in_shape = arg.shape + in_dtype = arg.dtype + + in_spec = DTensorSpec( + mesh, + sharding, + tensor_meta = TensorMeta ( + shape = in_shape, + stride = None, + dtype = in_dtype + ) + ) + + dtype_key = "torch_dtype" if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() else "type" + out_dtype = meta["common"]["results"]["data_out_0"][dtype_key] + out_spec = DTensorSpec( + mesh, + sharding, + tensor_meta = TensorMeta ( + shape = meta["common"]["results"]["data_out_0"]["shape"], + stride = None, + dtype = out_dtype + ) + ) + shardings = [ PlacementStrategy( - input_specs=spec, - output_specs=spec + input_specs=in_spec, + output_specs=out_spec ) ] + return OpStrategy(shardings) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/mesh_model.py b/src/chop/passes/graph/analysis/autosharding/mesh_model.py index 311da28ff..763e151c9 100644 --- a/src/chop/passes/graph/analysis/autosharding/mesh_model.py +++ b/src/chop/passes/graph/analysis/autosharding/mesh_model.py @@ -11,11 +11,16 @@ def __init__(self, mesh_shape, mesh_alpha = None, mesh_beta = None): # Alpha/beta model is used to estimate communication cost between devices self.mesh_alpha = [0] * 2 if 
mesh_alpha is None else mesh_alpha self.mesh_beta = [None] * 2 if mesh_beta is None else mesh_beta + + # For compatibility with torch DeviceMesh when building MeshTopoInfo object + # for sharding redistribution cost estimation + self.device_type = "cuda" + self.ndim = 2 def __getitem__(self, key): return self.mesh_shape[key] - def size(self, dim: None): + def size(self, dim = None): if dim is None: return np.prod(self.mesh_shape) else: diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index e2390815f..cdbc75688 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -16,6 +16,7 @@ Placement, Replicate, Shard, + TensorMeta ) from torch.distributed.device_mesh import DeviceMesh @@ -30,6 +31,7 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: parent_node = meta.node.args[0] self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + assert isinstance(self_strategy, OpStrategy) transpose_strategies = [] @@ -44,6 +46,11 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: output_specs=DTensorSpec( mesh=input_strategy.output_spec.mesh, placements=tuple(output_placements), + tensor_meta= TensorMeta( + shape = meta["common"]["results"]["data_out_0"]["shape"], + stride = None, + dtype = meta["common"]["results"]["data_out_0"]["torch_dtype"] + ) ), input_specs=(input_strategy.output_spec,), ) From c9afa871acc6b1d0bc765f50c8ae3749c8bcb6b6 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 11:25:40 +0000 Subject: [PATCH 24/93] include tensormeta for all ops --- .../autosharding/alpa_intra_operator.py | 145 ++++++++++++------ .../analysis/autosharding/ops/matrix_ops.py | 54 ++++++- .../autosharding/ops/pointwise_ops.py | 37 ++--- .../analysis/autosharding/ops/view_ops.py | 35 ++++- 4 files changed, 192 insertions(+), 
79 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 41c8aba91..9cb16793d 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -7,12 +7,20 @@ from chop.tools import get_logger -from .layers import AUTOSHARDING_FUNCTIONS, AUTOSHARDING_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, placeholder_or_getattr_strategy, fully_replicated_strategy +from .layers import ( + AUTOSHARDING_FUNCTIONS, + AUTOSHARDING_METHODS, + IMPLICIT_FUNCS, + IMPLICIT_METHODS, + placeholder_or_getattr_strategy, + fully_replicated_strategy, +) from .alpa_cost_modelling import get_resharding_matrix logger = get_logger(__name__) logger.setLevel("DEBUG") + def _enumerate_sharding_strategies(mg, mesh): """ For each node in the graph, assign an OpStrategy object which contains all possible @@ -29,8 +37,12 @@ def _enumerate_sharding_strategies(mg, mesh): # Find sharding strategies for each operator in the graph for node in mg.fx_graph.nodes: - if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or (node.op == "call_method" and node.target in IMPLICIT_METHODS): - logger.debug(f"Implicit {node.op} node {node.name} was assigned fully replicated sharding.") + if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or ( + node.op == "call_method" and node.target in IMPLICIT_METHODS + ): + logger.debug( + f"Implicit {node.op} node {node.name} was assigned fully replicated sharding." 
+ ) op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) @@ -47,13 +59,19 @@ def _enumerate_sharding_strategies(mg, mesh): # ================================================ if node.op in ["placeholder", "get_attr"]: - logger.debug(f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()") + logger.debug( + f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" + ) op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) elif node.op == "output": - logger.debug(f"Op strategy from node {node.args[0]} is propagated to {node} node.") + logger.debug( + f"Op strategy from node {node.args[0]} is propagated to {node} node." + ) node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], + "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ], "opt_var": None, "input": None, "output": None, @@ -64,7 +82,9 @@ def _enumerate_sharding_strategies(mg, mesh): logger.debug(f"Obtaining strategy for node {node.name}") op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) - elif node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys(): + elif ( + node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() + ): logger.debug(f"Obtaining strategy for node {node.name}") op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) @@ -92,8 +112,43 @@ def _enumerate_sharding_strategies(mg, mesh): "output": None, } + import torch + import torch.fx as fx + from torch.distributed._tensor._collective_utils import redistribute_cost + + for arg_idx, in_node in enumerate(node.all_input_nodes): + if not isinstance(in_node, fx.Node) or not isinstance( + in_node.meta["mase"]["common"]["results"]["data_out_0"]["value"], + torch.Tensor, + ): + continue + print(f"Parsing arg {in_node} of node {node}") + + 
node_op_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + arg_op_strategy = in_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + + arg_out_specs = [ + strategy.output_specs for strategy in arg_op_strategy.strategies + ] + node_in_specs = [ + strategy.input_specs[arg_idx] + for strategy in node_op_strategy.strategies + ] + + for out_spec in arg_out_specs: + for in_spec in node_in_specs: + cost = redistribute_cost(out_spec, in_spec) + # print( + # f"Cost for {out_spec} -> {in_spec}: {cost}" + # ) + return mg, constr + def alpa_intra_op_sharding_pass(mg, mesh, debug=False): """ Intra-operator auto parallelization pass. @@ -103,43 +158,43 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): mg, constr = _enumerate_sharding_strategies(mg, mesh) - # Consider resharding cost - # for in_node in node.all_input_nodes: - # in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - - # resharding_costs = get_resharding_matrix( - # mesh, - # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ - # "valid_output_shardings" - # ], - # dest_shardings=[ - # sharding["data_in_0"] - # for sharding in node.meta["mase"]["software"]["autosharding"][ - # "valid_input_shardings" - # ] - # ], - # dest_node_meta=node.meta["mase"], - # ).flatten() - - # # Formulate resharding cost term with linearized variable - # e_var = cp.Variable( - # opt_var.shape[0] * in_opt_var.shape[0], boolean=True - # ) - # expr += e_var.T @ resharding_costs - # constr += [ - # cp.sum(e_var) == 1, - # ] - - # # Scalar construction of the inequality constraints for the linearized variable - # for i in range(e_var.shape[0]): - # constr += [ - # e_var[i] <= opt_var[i // in_opt_var.shape[0]], - # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - # e_var[i] - # >= opt_var[i // in_opt_var.shape[0]] - # + in_opt_var[i % in_opt_var.shape[0]] - # - 1, - # ] + # Consider resharding cost + # for in_node in node.all_input_nodes: + # in_opt_var = 
in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + + # resharding_costs = get_resharding_matrix( + # mesh, + # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ + # "valid_output_shardings" + # ], + # dest_shardings=[ + # sharding["data_in_0"] + # for sharding in node.meta["mase"]["software"]["autosharding"][ + # "valid_input_shardings" + # ] + # ], + # dest_node_meta=node.meta["mase"], + # ).flatten() + + # # Formulate resharding cost term with linearized variable + # e_var = cp.Variable( + # opt_var.shape[0] * in_opt_var.shape[0], boolean=True + # ) + # expr += e_var.T @ resharding_costs + # constr += [ + # cp.sum(e_var) == 1, + # ] + + # # Scalar construction of the inequality constraints for the linearized variable + # for i in range(e_var.shape[0]): + # constr += [ + # e_var[i] <= opt_var[i // in_opt_var.shape[0]], + # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + # e_var[i] + # >= opt_var[i // in_opt_var.shape[0]] + # + in_opt_var[i % in_opt_var.shape[0]] + # - 1, + # ] # Solve the ILP problem # prob = cp.Problem(cp.Minimize(expr), constr) @@ -147,4 +202,4 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): # mg = mark_choices(mg) - return mg, module_map \ No newline at end of file + return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index cdbc75688..601cb6eed 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -16,7 +16,7 @@ Placement, Replicate, Shard, - TensorMeta + TensorMeta, ) from torch.distributed.device_mesh import DeviceMesh @@ -28,10 +28,10 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: - + parent_node = meta.node.args[0] self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] - + assert isinstance(self_strategy, OpStrategy) transpose_strategies = [] @@ 
-46,11 +46,11 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: output_specs=DTensorSpec( mesh=input_strategy.output_spec.mesh, placements=tuple(output_placements), - tensor_meta= TensorMeta( - shape = meta["common"]["results"]["data_out_0"]["shape"], - stride = None, - dtype = meta["common"]["results"]["data_out_0"]["torch_dtype"] - ) + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), ), input_specs=(input_strategy.output_spec,), ) @@ -70,6 +70,23 @@ def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple) -> OpSt assert strtg.input_specs is not None self_spec = strtg.input_specs[0] mat2_spec = strtg.input_specs[1] + + self_spec.tensor_meta = TensorMeta( + shape=self_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_0"]["torch_dtype"], + ) + mat2_spec.tensor_meta = TensorMeta( + shape=mat2_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_1"]["torch_dtype"], + ) + strtg.output_spec.tensor_meta = TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + if is_tensor_shardable(self_shape, self_spec) and is_tensor_shardable( mat2_shape, mat2_spec ): @@ -114,6 +131,27 @@ def _addmm_like_strategy( ) self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + self_spec.tensor_meta = TensorMeta( + shape=self_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_0"]["torch_dtype"], + ) + mat1_spec.tensor_meta = TensorMeta( + shape=mat1_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_1"]["torch_dtype"], + ) + mat2_spec.tensor_meta = TensorMeta( + shape=mat2_shape, + stride=None, + dtype=meta["common"]["args"]["data_in_2"]["torch_dtype"], + ) + strtg.output_spec.tensor_meta = TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + 
dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + if is_tensor_shardable(mat1_shape, mat1_spec) and is_tensor_shardable( mat2_shape, mat2_spec ): diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index 7aeb57bd9..959f99037 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -25,15 +25,15 @@ Placement, Replicate, Shard, + TensorMeta, ) from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten -def pointwise_strategy( - meta, mesh, linearity = False -): + +def pointwise_strategy(meta, mesh, linearity=False): max_shards_strategy_index = -1 max_shards = -1 followed_strategy = None @@ -45,7 +45,7 @@ def pointwise_strategy( # # out variant op should follow the out kwarg strategy # followed_strategy = op_schema.kwargs_schema["out"] # else: - + # normal pointwise op, we choose to follow the arg with # the max shards in case operands needs reshard for idx, arg in enumerate(meta.node.args): @@ -59,9 +59,7 @@ def pointwise_strategy( max_shards = arg_max_shards followed_strategy = arg_strategy - assert isinstance( - followed_strategy, OpStrategy - ), f"no strategy to follow for {op_schema}!" + assert isinstance(followed_strategy, OpStrategy), f"no strategy to follow!" 
return common_pointwise_strategy( meta, mesh, followed_strategy, linearity, max_shards_strategy_index @@ -69,11 +67,7 @@ def pointwise_strategy( def common_pointwise_strategy( - meta, - mesh, - followed_strategy, - linearity, - followed_strategy_index = 0 + meta, mesh, followed_strategy, linearity, followed_strategy_index=0 ): # handle broadcasting parsed_args = [] @@ -92,10 +86,8 @@ def common_pointwise_strategy( breakpoint() raise ValueError("Unrecognized arg type") - common_shape = torch.broadcast_shapes( - *[arg.shape for arg in parsed_args] - ) - + common_shape = torch.broadcast_shapes(*[arg.shape for arg in parsed_args]) + # Extract followed argument shape followed_shape = parsed_args[followed_strategy_index].shape @@ -128,7 +120,8 @@ def common_pointwise_strategy( # every arg follow the out_placements, but need to handle broadcasting input_arg_spec = input_arg.strategies[0].output_spec input_arg_dims_map = infer_broadcast_dims_map( - common_shape, arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"] + common_shape, + arg_node.meta["mase"]["common"]["results"]["data_out_0"]["shape"], ) input_target_placements = map_placements_after_broadcast( tuple(out_placements), @@ -145,11 +138,19 @@ def common_pointwise_strategy( # generate_redistribute_costs(input_arg, input_arg_target_spec) # ) + dtype = meta["common"]["results"]["data_out_0"].get( + "torch_dtype", torch.float32 + ) pointwise_strategy.strategies.append( PlacementStrategy( output_specs=DTensorSpec( mesh=mesh, placements=tuple(out_placements), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=dtype, + ), ), input_specs=input_specs, # redistribute_cost=redistribute_costs, @@ -164,4 +165,4 @@ def linear_pointwise_strategy(meta, mesh): For example, c = add(a, b); if a is pending sum, then c will be pending sum as well without any communication overhead. 
""" - return pointwise_strategy(meta, mesh, linearity=True) \ No newline at end of file + return pointwise_strategy(meta, mesh, linearity=True) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 216d32c40..75051a17f 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -30,7 +30,12 @@ normalize_dims, prod, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Placement, + Replicate, + TensorMeta, +) from torch.distributed.device_mesh import DeviceMesh @@ -572,9 +577,11 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: shard_dim_map[in_dim.input_dim] = dim input_tgt_placements = [ - Replicate() - if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] - else p + ( + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + ) for mesh_dim, p in enumerate(input_src_placements) ] output_placements = [ @@ -591,11 +598,15 @@ def get_reshape_strategy(op): # def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: def reshape_strategy(meta, mesh): assert meta.node.op == "call_method", "Node should have call_method op." - args_schema = [meta["common"]["self"]] + [i for i in meta["common"]["args"].values()] + args_schema = [meta["common"]["self"]] + [ + i for i in meta["common"]["args"].values() + ] rules = dim_map(*args_schema) parent_node = meta.node.args[0] # input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - input_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] + input_strategy = parent_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] global_in_shape = meta["common"]["self"].shape assert global_in_shape is not None, "Shape required." 
@@ -620,7 +631,15 @@ def reshape_strategy(meta, mesh): tensor_meta=input_src_spec.tensor_meta, ) - output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_spec = DTensorSpec( + mesh=mesh, + placements=tuple(output_placements), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), + ) output_strategy.strategies.append( PlacementStrategy( output_specs=output_spec, @@ -630,4 +649,4 @@ def reshape_strategy(meta, mesh): return output_strategy - return reshape_strategy \ No newline at end of file + return reshape_strategy From d9ecb19944854c7032aa8c36a5c9c526beb9b054 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 11:53:12 +0000 Subject: [PATCH 25/93] include resharding cost, ILP now too complex --- .../autosharding/alpa_intra_operator.py | 73 +++++++++++++------ 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 9cb16793d..6421f3cbc 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -1,7 +1,6 @@ -import sys, pdb, traceback -import functools - -import torch.nn as nn +import torch +import torch.fx as fx +from torch.distributed._tensor._collective_utils import redistribute_cost import numpy as np import cvxpy as cp @@ -15,7 +14,7 @@ placeholder_or_getattr_strategy, fully_replicated_strategy, ) -from .alpa_cost_modelling import get_resharding_matrix + logger = get_logger(__name__) logger.setLevel("DEBUG") @@ -33,6 +32,7 @@ def _enumerate_sharding_strategies(mg, mesh): # Setup for the ILP optimization constr = [] + expr = 0 # Find sharding strategies for each operator in the graph for node in mg.fx_graph.nodes: @@ -98,7 +98,7 @@ def 
_enumerate_sharding_strategies(mg, mesh): } continue - # Formulate optimization variable and consider compute/communication cost + # Formulate optimization variable opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) constr += [ cp.sum(opt_var) == 1, @@ -112,41 +112,68 @@ def _enumerate_sharding_strategies(mg, mesh): "output": None, } - import torch - import torch.fx as fx - from torch.distributed._tensor._collective_utils import redistribute_cost - + # Consider resharding cost for each of the node's arguments for arg_idx, in_node in enumerate(node.all_input_nodes): + + # Skip constant nodes if not isinstance(in_node, fx.Node) or not isinstance( in_node.meta["mase"]["common"]["results"]["data_out_0"]["value"], torch.Tensor, ): continue - print(f"Parsing arg {in_node} of node {node}") + logger.debug(f"Parsing arg {in_node} of node {node}") + # Fetch this node's input specs node_op_strategy = node.meta["mase"]["software"]["autosharding"][ "op_strategy" ] + node_in_specs = [ + strategy.input_specs[arg_idx] + for strategy in node_op_strategy.strategies + ] + + # Fetch the argument node's output specs + in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] arg_op_strategy = in_node.meta["mase"]["software"]["autosharding"][ "op_strategy" ] - arg_out_specs = [ strategy.output_specs for strategy in arg_op_strategy.strategies ] - node_in_specs = [ - strategy.input_specs[arg_idx] - for strategy in node_op_strategy.strategies + + # Formulate resharding cost matrix + resharding_costs = np.zeros((len(node_in_specs), len(arg_out_specs))) + for dest_idx, dest_spec in enumerate(node_in_specs): + for src_idx, src_spec in enumerate(arg_out_specs): + resharding_costs[dest_idx, src_idx] = redistribute_cost( + src_spec, dest_spec + ) + resharding_costs = resharding_costs.flatten() + + # Formulate linearized variable for resharding cost + e_var = cp.Variable(resharding_costs.shape[0], boolean=True) + expr += e_var.T @ resharding_costs + constr += [ + 
cp.sum(e_var) == 1, ] - for out_spec in arg_out_specs: - for in_spec in node_in_specs: - cost = redistribute_cost(out_spec, in_spec) - # print( - # f"Cost for {out_spec} -> {in_spec}: {cost}" - # ) + # Scalar construction of the inequality constraints for the linearized variable + for i in range(e_var.shape[0]): + constr += [ + e_var[i] <= opt_var[i // in_opt_var.shape[0]], + e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + e_var[i] + >= opt_var[i // in_opt_var.shape[0]] + + in_opt_var[i % in_opt_var.shape[0]] + - 1, + ] + + # Solve the ILP problem + breakpoint() + prob = cp.Problem(cp.Minimize(expr), constr) + prob.solve() - return mg, constr + return mg, {} def alpa_intra_op_sharding_pass(mg, mesh, debug=False): @@ -156,7 +183,7 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): module_map = {} - mg, constr = _enumerate_sharding_strategies(mg, mesh) + mg, _ = _enumerate_sharding_strategies(mg, mesh) # Consider resharding cost # for in_node in node.all_input_nodes: From 4f37d019d3c5326069455de37e3652e6896553c7 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 14:26:11 +0000 Subject: [PATCH 26/93] ILP is solvable after replacing inf values in resharding matrix --- .../autosharding/alpa_intra_operator.py | 51 ++----------------- 1 file changed, 4 insertions(+), 47 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 6421f3cbc..86555d3cc 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -145,9 +145,11 @@ def _enumerate_sharding_strategies(mg, mesh): resharding_costs = np.zeros((len(node_in_specs), len(arg_out_specs))) for dest_idx, dest_spec in enumerate(node_in_specs): for src_idx, src_spec in enumerate(arg_out_specs): - resharding_costs[dest_idx, src_idx] = redistribute_cost( - src_spec, dest_spec + cost = 
redistribute_cost(src_spec, dest_spec) + resharding_costs[dest_idx, src_idx] = ( + 1000000 if cost == float("inf") else cost ) + resharding_costs = resharding_costs.flatten() # Formulate linearized variable for resharding cost @@ -169,7 +171,6 @@ def _enumerate_sharding_strategies(mg, mesh): ] # Solve the ILP problem - breakpoint() prob = cp.Problem(cp.Minimize(expr), constr) prob.solve() @@ -185,48 +186,4 @@ def alpa_intra_op_sharding_pass(mg, mesh, debug=False): mg, _ = _enumerate_sharding_strategies(mg, mesh) - # Consider resharding cost - # for in_node in node.all_input_nodes: - # in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - - # resharding_costs = get_resharding_matrix( - # mesh, - # src_shardings=in_node.meta["mase"]["software"]["autosharding"][ - # "valid_output_shardings" - # ], - # dest_shardings=[ - # sharding["data_in_0"] - # for sharding in node.meta["mase"]["software"]["autosharding"][ - # "valid_input_shardings" - # ] - # ], - # dest_node_meta=node.meta["mase"], - # ).flatten() - - # # Formulate resharding cost term with linearized variable - # e_var = cp.Variable( - # opt_var.shape[0] * in_opt_var.shape[0], boolean=True - # ) - # expr += e_var.T @ resharding_costs - # constr += [ - # cp.sum(e_var) == 1, - # ] - - # # Scalar construction of the inequality constraints for the linearized variable - # for i in range(e_var.shape[0]): - # constr += [ - # e_var[i] <= opt_var[i // in_opt_var.shape[0]], - # e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - # e_var[i] - # >= opt_var[i // in_opt_var.shape[0]] - # + in_opt_var[i % in_opt_var.shape[0]] - # - 1, - # ] - - # Solve the ILP problem - # prob = cp.Problem(cp.Minimize(expr), constr) - # prob.solve() - - # mg = mark_choices(mg) - return mg, module_map From b0b86312e09f26521f947830fe22216d33332d73 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 14:33:13 +0000 Subject: [PATCH 27/93] refactoring --- src/chop/distributed/utils.py | 9 +- .../{ => 
deprecated}/alpa_cost_modelling.py | 88 +++++++++++------- .../autosharding/{ => deprecated}/common.py | 0 .../{ => deprecated}/debug_utilities.py | 0 .../graph/analysis/autosharding/layers.py | 93 +++++++++---------- 5 files changed, 101 insertions(+), 89 deletions(-) rename src/chop/passes/graph/analysis/autosharding/{ => deprecated}/alpa_cost_modelling.py (54%) rename src/chop/passes/graph/analysis/autosharding/{ => deprecated}/common.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ => deprecated}/debug_utilities.py (100%) diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index 19b1e6dbf..91c35c47e 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -1,13 +1,13 @@ - from torch.distributed._tensor import ( Replicate, Shard, ) -from chop.passes.graph.analysis.autosharding.common import SpmdShard +from chop.passes.graph.analysis.autosharding.deprecated.common import SpmdShard import torch + def placement_from_sharding_config(sharding_config): """ Sharding config is given as a tuple such as (R, S_0) where a symbol S_x at index i indicates @@ -22,11 +22,12 @@ def placement_from_sharding_config(sharding_config): placement[shard_type.value] = Shard(idx) return tuple(placement) - + + def rlog(logger, rank, msg, level="info"): """ Only log on rank 0 to avoid repeated messages. 
""" log_fn = getattr(logger, level, logger.info) if rank == 0: - log_fn(msg) \ No newline at end of file + log_fn(msg) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py similarity index 54% rename from src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py rename to src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py index 7e051cb5b..467caa13c 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ b/src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py @@ -1,77 +1,93 @@ - import numpy as np -from functools import lru_cache +from functools import lru_cache from chop.ir.graph import MaseMetadata from .common import SpmdShard -from .mesh_model import MeshModel +from ..mesh_model import MeshModel BYTES_PER_ELEMENT = 4 + def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): - assert sharding[0][-1] == sharding[1][-2], f"Inconsistent sharding for node: {node_meta.node}" + assert ( + sharding[0][-1] == sharding[1][-2] + ), f"Inconsistent sharding for node: {node_meta.node}" inner_dim_sharding = sharding[1][0] out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] if inner_dim_sharding == SpmdShard.R: return 0 - + else: - ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 - return mesh.all_reduce_cost(num_bytes = BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim = ar_dim) + ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 + return mesh.all_reduce_cost( + num_bytes=BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim=ar_dim + ) + @lru_cache(maxsize=None) -def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata): +def get_resharding_cost( + mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata +): """ Obtain the resharding cost given a source and destination sharding 
profile for a tensor. The mesh object is assumed to have been initialized with alpha, beta parameters so that the communication cost can be estimated for each MPI operator. """ - # If original sharding is fully replicated, no resharding is required if src == dest or all(i == SpmdShard.R for i in src): return 0 - - num_bytes = BYTES_PER_ELEMENT * np.prod(dest_node_meta["common"]["args"]["data_in_0"]["shape"]) - + + num_bytes = BYTES_PER_ELEMENT * np.prod( + dest_node_meta["common"]["args"]["data_in_0"]["shape"] + ) + # No cost (simple split along given mesh dimension) if ( - # Keep dim 0, split dim 1 - # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) - (src[0] == dest[0]) and (src[1] == SpmdShard.R) and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) - # Split dim 0, keep dim 1 - # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) - or (src[1] == dest[1]) and (src[0] == SpmdShard.R) and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) - ): + # Keep dim 0, split dim 1 + # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) + (src[0] == dest[0]) + and (src[1] == SpmdShard.R) + and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) + # Split dim 0, keep dim 1 + # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) + or (src[1] == dest[1]) + and (src[0] == SpmdShard.R) + and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) + ): return 0 # Split -> Replicate (All Gather) elif ( - # Keep dim 0, gather along dim 1 - # E.g. (S_1, S_0) -> (S_1, R) - (src[0] == dest[0]) and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[1] == SpmdShard.R) - # Gather along dim 0, keep dim 1 - # E.g. (S_0, S_1) -> (R, S_1) - or (src[1] == dest[1]) and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[0] == SpmdShard.R) - ): + # Keep dim 0, gather along dim 1 + # E.g. (S_1, S_0) -> (S_1, R) + (src[0] == dest[0]) + and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) + and (dest[1] == SpmdShard.R) + # Gather along dim 0, keep dim 1 + # E.g. 
(S_0, S_1) -> (R, S_1) + or (src[1] == dest[1]) + and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) + and (dest[0] == SpmdShard.R) + ): ag_dim = 1 if src[0] == dest[0] else 0 return mesh.all_gather_cost( - num_bytes = num_bytes, - mesh_dim = ag_dim, + num_bytes=num_bytes, + mesh_dim=ag_dim, ) # All-to-all # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) - elif (src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src)): + elif src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src): # all to all a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value try: return mesh.all_to_all_cost( - num_bytes = num_bytes, - mesh_dim = a2a_dim, + num_bytes=num_bytes, + mesh_dim=a2a_dim, ) except: breakpoint() @@ -81,7 +97,7 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta # reduced sharding else: # Reduce one dimension and re-compute - if (src[0] != SpmdShard.R): + if src[0] != SpmdShard.R: new_src = (SpmdShard.R, src[1]) ag_dim = src[0].value else: @@ -89,13 +105,15 @@ def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta ag_dim = src[1].value return mesh.all_gather_cost( - num_bytes = num_bytes, - mesh_dim = ag_dim + num_bytes=num_bytes, mesh_dim=ag_dim ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) + def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): mat = np.zeros((len(dest_shardings), len(src_shardings))) for src_idx, src in enumerate(src_shardings): for dest_idx, dest in enumerate(dest_shardings): - mat[dest_idx, src_idx] = get_resharding_cost(mesh, src, dest, dest_node_meta) + mat[dest_idx, src_idx] = get_resharding_cost( + mesh, src, dest, dest_node_meta + ) return mat diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/deprecated/common.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/common.py rename to 
src/chop/passes/graph/analysis/autosharding/deprecated/common.py diff --git a/src/chop/passes/graph/analysis/autosharding/debug_utilities.py b/src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/debug_utilities.py rename to src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index c4b8c6600..1fc4f397e 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -5,12 +5,17 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import Replicate, Shard, DTensorSpec, TensorMeta +from torch.distributed._tensor.placement_types import ( + Replicate, + Shard, + DTensorSpec, + TensorMeta, +) from chop.tools import get_logger from chop.models.patched.bert.modeling_bert import BertSelfAttention -from .alpa_cost_modelling import get_communication_cost +from .deprecated.alpa_cost_modelling import get_communication_cost from .ops.matrix_ops import ( transpose_strategy, @@ -40,7 +45,7 @@ torch.matmul: bmm_strategy, torch.softmax: softmax_strategy, F.softmax: softmax_strategy, - F.layer_norm: layer_norm_strategy + F.layer_norm: layer_norm_strategy, } AUTOSHARDING_METHODS = { @@ -48,42 +53,31 @@ "reshape": get_reshape_strategy(torch.Tensor.reshape), "expand": get_reshape_strategy(torch.Tensor.expand), "permute": get_reshape_strategy(torch.permute), - "transpose": get_reshape_strategy(torch.transpose) + "transpose": get_reshape_strategy(torch.transpose), } -IMPLICIT_FUNCS = [ - operator.getitem -] +IMPLICIT_FUNCS = [operator.getitem] + +IMPLICIT_METHODS = ["size"] -IMPLICIT_METHODS = [ - "size" -] def placeholder_or_getattr_strategy(meta, mesh): 
ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] - + tensor_meta = TensorMeta( - shape = meta["common"]["results"]["data_out_0"]["shape"], - stride = None, - dtype = meta["common"]["results"]["data_out_0"]["torch_dtype"] + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ) - + shardings = [] for sharding in itertools.product(opts, repeat=2): - spec = DTensorSpec( - mesh = mesh, - placements = sharding, - tensor_meta = tensor_meta - ) - shardings.append( - PlacementStrategy( - input_specs=spec, - output_specs=spec - ) - ) + spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) + shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) return OpStrategy(shardings) + def fully_replicated_strategy(meta, mesh): """ Output of ops like size, getitem etc are always fully replicated @@ -97,7 +91,11 @@ def fully_replicated_strategy(meta, mesh): in_shape = meta["common"]["self"].shape in_dtype = meta["common"]["self"].dtype else: - first_arg_key = "data_in_0" if "data_in_0" in meta["common"]["args"] else [i for i in meta["common"]["args"].keys()][0] + first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) arg = meta["common"]["args"][first_arg_key] if isinstance(arg, dict): in_shape = arg["shape"] @@ -108,32 +106,27 @@ def fully_replicated_strategy(meta, mesh): in_dtype = arg.dtype in_spec = DTensorSpec( - mesh, + mesh, sharding, - tensor_meta = TensorMeta ( - shape = in_shape, - stride = None, - dtype = in_dtype - ) + tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype), + ) + + dtype_key = ( + "torch_dtype" + if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() + else "type" ) - - dtype_key = "torch_dtype" if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() else 
"type" out_dtype = meta["common"]["results"]["data_out_0"][dtype_key] out_spec = DTensorSpec( - mesh, + mesh, sharding, - tensor_meta = TensorMeta ( - shape = meta["common"]["results"]["data_out_0"]["shape"], - stride = None, - dtype = out_dtype - ) + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=out_dtype, + ), ) - - shardings = [ - PlacementStrategy( - input_specs=in_spec, - output_specs=out_spec - ) - ] - - return OpStrategy(shardings) \ No newline at end of file + + shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] + + return OpStrategy(shardings) From 1d3fe4c009d0af44489a75a29a6e779fbe573005 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 2 Jul 2024 15:01:45 +0000 Subject: [PATCH 28/93] skip fully replicated strategies for placeholder ops --- .../graph/analysis/autosharding/alpa.py | 5 +-- .../autosharding/alpa_intra_operator.py | 31 +++++++++++++++---- .../analysis/autosharding/autosharding.py | 14 ++++++--- .../graph/analysis/autosharding/layers.py | 4 ++- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 6e67c4f14..5145a793f 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -5,6 +5,7 @@ logger = get_logger(__name__) logger.setLevel("DEBUG") + def deepgetattr(obj, attr, default=None): """Recurses through an attribute chain to get the ultimate value.""" try: @@ -59,6 +60,6 @@ def mark_choices(mg): return mg -def alpa_autosharding_pass(mg, mesh): - mg, module_map = alpa_intra_op_sharding_pass(mg, mesh) +def alpa_autosharding_pass(mg, mesh, pass_args={}): + mg, module_map = alpa_intra_op_sharding_pass(mg, mesh, pass_args=pass_args) return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py 
b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 86555d3cc..673c2d248 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -20,7 +20,7 @@ logger.setLevel("DEBUG") -def _enumerate_sharding_strategies(mg, mesh): +def _extract_ilp(mg, mesh, pass_args={}): """ For each node in the graph, assign an OpStrategy object which contains all possible sharding algorithms. Also assign opt_var instance which is one-hot vector used to @@ -62,7 +62,11 @@ def _enumerate_sharding_strategies(mg, mesh): logger.debug( f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" ) - op_strategy = placeholder_or_getattr_strategy(node.meta["mase"], mesh) + op_strategy = placeholder_or_getattr_strategy( + node.meta["mase"], + mesh, + skip_fully_replicated=pass_args.get("skip_fully_replicated", False), + ) elif node.op == "output": logger.debug( @@ -172,18 +176,33 @@ def _enumerate_sharding_strategies(mg, mesh): # Solve the ILP problem prob = cp.Problem(cp.Minimize(expr), constr) - prob.solve() + return mg, prob + + +def _mark_sharding(mg): + for node in mg.fx_graph.nodes: + opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] - return mg, {} + if opt_var is None: + continue + + idx = np.where(opt_var.value == 1) -def alpa_intra_op_sharding_pass(mg, mesh, debug=False): +def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): """ Intra-operator auto parallelization pass. 
""" module_map = {} - mg, _ = _enumerate_sharding_strategies(mg, mesh) + # Formulate and solve the ILP + logger.info(f"Formulating the ILP...") + mg, problem = _extract_ilp(mg, mesh, pass_args) + + logger.info(f"Solving the ILP...") + problem.solve() + + mg, _ = _mark_sharding(mg) return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 3f6ae4b60..31dd28afe 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,4 +1,3 @@ - import numpy as np import cvxpy as cp from time import time @@ -11,12 +10,15 @@ logger = get_logger(__name__) logger.setLevel("DEBUG") + def autosharding_analysis_pass(mg, pass_args: dict = {}): """ A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 """ - assert "mesh_shape" in pass_args, "Logical description for device cluster was not specified." + assert ( + "mesh_shape" in pass_args + ), "Logical description for device cluster was not specified." assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" @@ -32,14 +34,16 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): mesh.set_cost_model_parameters( intra_node_bandwidth=pass_args["intra_node_bandwidth"], inter_node_bandwidth=pass_args["inter_node_bandwidth"], - backend = pass_args.get("communications_backend", "default") + backend=pass_args.get("communications_backend", "default"), ) # Run intra-operator pass if algo == "alpa": - mg, module_map = alpa_autosharding_pass(mg, mesh) + mg, module_map = alpa_autosharding_pass(mg, mesh, pass_args) end_time = time() - logger.info(f"Autosharding pass complete. Time taken: {end_time - start_time} seconds.") + logger.info( + f"Autosharding pass complete. 
Time taken: {end_time - start_time} seconds." + ) return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 1fc4f397e..b718d1383 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -61,7 +61,7 @@ IMPLICIT_METHODS = ["size"] -def placeholder_or_getattr_strategy(meta, mesh): +def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] @@ -73,6 +73,8 @@ def placeholder_or_getattr_strategy(meta, mesh): shardings = [] for sharding in itertools.product(opts, repeat=2): + if skip_fully_replicated and sharding == (Replicate(), Replicate()): + continue spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) return OpStrategy(shardings) From 4a46f47c56304bf5615abd3238d33ebd07a16bc3 Mon Sep 17 00:00:00 2001 From: pgimenes Date: Tue, 2 Jul 2024 16:32:04 +0100 Subject: [PATCH 29/93] start docs --- .../modules/api/analysis/autosharding.rst | 17 ++++++++++++++ docs/source/modules/api/passes.rst | 1 + .../graph/analysis/autosharding/alpa.py | 11 ++++++++++ .../autosharding/alpa_intra_operator.py | 14 ++++++++++-- .../analysis/autosharding/autosharding.py | 22 +++++++++++++++++-- 5 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 docs/source/modules/api/analysis/autosharding.rst diff --git a/docs/source/modules/api/analysis/autosharding.rst b/docs/source/modules/api/analysis/autosharding.rst new file mode 100644 index 000000000..12d404226 --- /dev/null +++ b/docs/source/modules/api/analysis/autosharding.rst @@ -0,0 +1,17 @@ +chop.passes.graph.analysis.autosharding +======================================== + +autosharding\_analysis\_pass 
+------------------------------------- + +.. autofunction:: chop.passes.graph.analysis.autosharding.autosharding_analysis_pass + +alpa\_autosharding\_pass +--------------------------------------- + +.. autofunction:: chop.passes.graph.analysis.autosharding.alpa.alpa_autosharding_pass + +alpa\_intra\_op\_sharding\_pass +--------------------------------------- + +.. autofunction:: chop.passes.graph.analysis.autosharding.alpa_intra_operator.alpa_intra_op_sharding_pass diff --git a/docs/source/modules/api/passes.rst b/docs/source/modules/api/passes.rst index bd801fa03..df0cb6952 100644 --- a/docs/source/modules/api/passes.rst +++ b/docs/source/modules/api/passes.rst @@ -56,6 +56,7 @@ MaseGraph Analysis Passes :maxdepth: 2 analysis/add_metadata + analysis/autosharding analysis/init_metadata analysis/report analysis/statistical_profiler diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 5145a793f..8fe3844bf 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -1,6 +1,7 @@ from chop.tools import get_logger from .alpa_intra_operator import alpa_intra_op_sharding_pass +from .mesh_model import MeshModel logger = get_logger(__name__) logger.setLevel("DEBUG") @@ -61,5 +62,15 @@ def mark_choices(mg): def alpa_autosharding_pass(mg, mesh, pass_args={}): + """A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 + + Args: + mg (MaseGraph): Input MaseGraph. + mesh (MeshModel): Input MeshModel. + pass_args (dict, optional): pass arguments. Defaults to {}. + + Returns: + MaseGraph: MaseGraph with sharding strategy annotated for each operator. 
+ """ mg, module_map = alpa_intra_op_sharding_pass(mg, mesh, pass_args=pass_args) return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 673c2d248..616042b70 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -19,6 +19,8 @@ logger = get_logger(__name__) logger.setLevel("DEBUG") +from .mesh_model import MeshModel + def _extract_ilp(mg, mesh, pass_args={}): """ @@ -190,8 +192,16 @@ def _mark_sharding(mg): def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): - """ - Intra-operator auto parallelization pass. + """Intra-operator auto parallelization pass from the Alpa paper: https://arxiv.org/abs/2201.12023 + + Args: + mg (MaseGraph): Input MaseGraph. + mesh (MeshModel): mesh description. + pass_args (dict, optional): pass arguments. Defaults to {}. + debug (bool, optional): enable debug. Defaults to False. + + Returns: + MaseGraph: annotated MaseGraph. """ module_map = {} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 31dd28afe..79e6d7d1a 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -12,8 +12,26 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): - """ - A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 + """Annotate the metadata of each operator in the graph with a parallelization strategy. + + Args: + mg (MaseGraph): input mase graph. + pass_args (dict, optional): pass arguments. Defaults to {}. + + Returns: + MaseGraph: annotated mase graph. + + The pass_args dictionary expects the following elements. + + - algo (optional) -> str : Sharding algorithm to use. Default is "alpa". 
+ - mesh_shape -> tuple : Shape of the device cluster. Should be a 2-dimensional tuple. + - inter_node_bandwidth -> int : Inter-node bandwidth, i.e. between GPU nodes. + - intra_node_bandwidth -> int : Intra-node bandwidth, i.e. between GPU devices in each node. + + Additionally, the following elements can be passed. + + - communications_backend (optional) -> str : Communications backend to use, e.g. "nccl" or "gloo". Default is "nccl". + - skip_fully_replicated (optional) -> bool : If set to true, do not consider fully replicated sharding as an option for any operator. """ assert ( From 3c4b4a5cf2ba2025b2a195519391f892039ad269 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 3 Jul 2024 10:49:00 +0000 Subject: [PATCH 30/93] unnecessary imports --- .../analysis/autosharding/ops/math_ops.py | 46 +++++++------------ .../analysis/autosharding/ops/matrix_ops.py | 9 ---- .../autosharding/ops/pointwise_ops.py | 10 ---- .../analysis/autosharding/ops/view_ops.py | 29 ------------ 4 files changed, 16 insertions(+), 78 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py index a83b00e99..028c94dbf 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -1,40 +1,23 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. 
and affiliates -import math -from dataclasses import dataclass -from enum import Enum from typing import cast, List, Optional, Sequence, Tuple, Union import torch from torch.distributed._tensor._op_schema import ( - OpSchema, OpStrategy, PlacementStrategy, - RuntimeSchemaInfo, - TupleStrategy, ) from torch.distributed._tensor.ops.utils import ( - as_list, - expand_to_full_mesh_op_strategy, generate_redistribute_costs, - is_tensor_evenly_shardable, normalize_dim, - normalize_dims, normalize_to_torch_size, - register_op_strategy, ) from torch.distributed._tensor.placement_types import ( DTensorSpec, - Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh -aten = torch.ops.aten - def _replicate_dims_start_at( placements: Sequence[Placement], start_dim: int = 0 ) -> Tuple[Placement, ...]: @@ -104,17 +87,23 @@ def softmax_strategy(meta, mesh): def layer_norm_strategy(meta, mesh): - + # args must be: input, normalized_shape, weight, bias, eps # for None weight and bias, their corresponding objects will # be None as well. layer_norm_strategy returns one OpStrategy # for the triple return values (out, mean, rstd). 
assert len(meta["common"]["args"].keys()) == 5 - input_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"] + input_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] normalized_shape = meta["common"]["args"]["normalized_shape"] - weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"]["autosharding"]["op_strategy"] - bias_strategy = meta.node.kwargs["bias"].meta["mase"]["software"]["autosharding"]["op_strategy"] + weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"][ + "autosharding" + ]["op_strategy"] + bias_strategy = meta.node.kwargs["bias"].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] # the current layer norm implementation requires that all # input DTensor's sharding must be in form of OpStrategy @@ -148,14 +137,11 @@ def layer_norm_strategy(meta, mesh): if weight_strategy is not None: assert isinstance(weight_strategy, OpStrategy) - try: - # patching: weight and bias sharding strategy is currently always replicate - # So just take strategy at index 0 - # TO DO: when sharding decomposed layer norm, cross product weight strategies - # with input/bias strategies for final OpStrategy - weight_src_spec = weight_strategy.strategies[0].output_spec - except: - breakpoint() + # patching: weight and bias sharding strategy is currently always replicate + # So just take strategy at index 0 + # TO DO: when sharding decomposed layer norm, cross product weight strategies + # with input/bias strategies for final OpStrategy + weight_src_spec = weight_strategy.strategies[0].output_spec # for the weight tensor, we replicate it on all dims if necessary # TODO: we can avoid forcing the redistribution once we figure out @@ -197,4 +183,4 @@ def layer_norm_strategy(meta, mesh): ) ) - return output_strategy \ No newline at end of file + return output_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py 
b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index 601cb6eed..1fdfb5151 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -1,31 +1,22 @@ # Adapted from https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py -import itertools -from typing import List, Optional - import torch from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from .basic_strategy import gen_einsum_strategies from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, - register_op_strategy, ) from torch.distributed._tensor.placement_types import ( DTensorSpec, - Placement, - Replicate, Shard, TensorMeta, ) -from torch.distributed.device_mesh import DeviceMesh from ..utils import is_tensor_shardable from chop.ir.graph import MaseMetadata -aten = torch.ops.aten - def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index 959f99037..be9da28ae 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -1,23 +1,17 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import List, Sequence, Tuple import torch from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, - OpSchema, OpStrategy, PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, - TupleStrategy, ) from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, - register_op_strategy, ) from torch.distributed._tensor.placement_types import ( DTensorSpec, @@ -27,10 +21,6 @@ Shard, TensorMeta, ) -from torch.distributed.device_mesh import DeviceMesh - - -aten = torch.ops.aten def pointwise_strategy(meta, mesh, linearity=False): diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 75051a17f..a4e779539 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass from typing import ( Callable, @@ -17,11 +15,8 @@ import torch from torch import Tensor from torch.distributed._tensor._op_schema import ( - OpSchema, OpStrategy, PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, ) from torch.distributed._tensor.api import Shard from torch.distributed._tensor.ops.utils import ( @@ -36,10 +31,6 @@ Replicate, TensorMeta, ) -from torch.distributed.device_mesh import DeviceMesh - - -aten = torch.ops.aten Shape = Tuple[int, ...] @@ -439,26 +430,6 @@ def dim_view_as_real(shape: Shape) -> DimMap: return tuple(results) -def dim_reduction( - ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool -) -> DimMap: - """ - General fallback for reduction ops where Partial() does not apply. - - This will cause incoming tensor to be replicated on the reducing dimensions. 
- """ - if dim_or_dims is None: - dim_or_dims = tuple(range(ndim)) - if isinstance(dim_or_dims, int): - dim_or_dims = (dim_or_dims,) - dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) - return tuple( - InputDim(i) if i not in dim_or_dims else Singleton() - for i in range(ndim) - if i not in dim_or_dims or keepdim - ) - - dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), From 80f50fbda2630f54927dd82a69afa1de3093066d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 3 Jul 2024 10:49:15 +0000 Subject: [PATCH 31/93] export solution and optimizer profiling --- .../autosharding/alpa_intra_operator.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 616042b70..2e6912048 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -176,11 +176,56 @@ def _extract_ilp(mg, mesh, pass_args={}): - 1, ] + # Below speeds up compilation but the number of constraints is the same? 
+ + # # Reshape e_var to match the dimensions of opt_var and in_opt_var + # e_var_reshaped = cp.reshape(e_var, (opt_var.shape[0], in_opt_var.shape[0])) + + # # Create broadcasted versions of opt_var and in_opt_var + # opt_var_broadcast = cp.reshape(opt_var, (opt_var.shape[0], 1)) + # in_opt_var_broadcast = cp.reshape(in_opt_var, (1, in_opt_var.shape[0])) + + # # Define the vectorized constraints + # constr += [ + # e_var_reshaped <= opt_var_broadcast, + # e_var_reshaped <= in_opt_var_broadcast, + # e_var_reshaped >= opt_var_broadcast + in_opt_var_broadcast - 1, + # ] + # Solve the ILP problem prob = cp.Problem(cp.Minimize(expr), constr) return mg, prob +def _export_solution(mg): + + nodes = [node for node in mg.fx_graph.nodes] + node_names = [node.name for node in nodes] + opt_vars = [ + node.meta["mase"]["software"]["autosharding"]["opt_var"] for node in nodes + ] + opt_vals = [i.value if i is not None else None for i in opt_vars] + choices = [np.argmax(i) for i in opt_vals] + + strategies = [ + i.meta["mase"]["software"]["autosharding"]["op_strategy"].strategies + for i in nodes + ] + shardings = [strat[choices[idx]] for idx, strat in enumerate(strategies)] + map = [ + { + "node": nodes[idx].name, + "input_specs": strat.input_specs, + "output_specs": strat.output_specs, + } + for idx, strat in enumerate(shardings) + ] + + breakpoint() + + return mg, {} + + def _mark_sharding(mg): for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] @@ -190,6 +235,8 @@ def _mark_sharding(mg): idx = np.where(opt_var.value == 1) + return mg, {} + def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): """Intra-operator auto parallelization pass from the Alpa paper: https://arxiv.org/abs/2201.12023 @@ -211,8 +258,9 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): mg, problem = _extract_ilp(mg, mesh, pass_args) logger.info(f"Solving the ILP...") - problem.solve() + problem.solve(verbose=True, 
scipy_options={"disp": True}) + mg, _ = _export_solution(mg) mg, _ = _mark_sharding(mg) return mg, module_map From 3e1fd159d6f97dcae453f8561ed1e0d6ce38d11c Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 9 Jul 2024 14:27:55 +0000 Subject: [PATCH 32/93] vectorize constraint for linearized resharding cost variable and enable sweep with multithreading --- src/chop/__init__.py | 3 +- src/chop/ir/__init__.py | 1 + .../graph/analysis/autosharding/alpa.py | 4 +- .../autosharding/alpa_intra_operator.py | 49 +++++++------------ .../analysis/autosharding/autosharding.py | 7 +-- src/chop/pipelines/auto_pipeline.py | 3 +- 6 files changed, 29 insertions(+), 38 deletions(-) diff --git a/src/chop/__init__.py b/src/chop/__init__.py index 6e33393f4..e8ce982a1 100644 --- a/src/chop/__init__.py +++ b/src/chop/__init__.py @@ -1,6 +1,7 @@ from .ir.graph.mase_graph import MaseGraph + from .ir.onnx.mase_onnx_graph import MaseOnnxGraph from . import passes -from .pipelines import AutoPipelineForDistributedInference \ No newline at end of file +from .pipelines import AutoPipelineForDistributedInference diff --git a/src/chop/ir/__init__.py b/src/chop/ir/__init__.py index 9b987d139..834757faa 100644 --- a/src/chop/ir/__init__.py +++ b/src/chop/ir/__init__.py @@ -1,2 +1,3 @@ from .graph.mase_graph import MaseGraph, MaseTracer + from .onnx.mase_onnx_graph import MaseOnnxGraph diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 8fe3844bf..51d62d41f 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -72,5 +72,5 @@ def alpa_autosharding_pass(mg, mesh, pass_args={}): Returns: MaseGraph: MaseGraph with sharding strategy annotated for each operator. 
""" - mg, module_map = alpa_intra_op_sharding_pass(mg, mesh, pass_args=pass_args) - return mg, module_map + mg, pass_outs = alpa_intra_op_sharding_pass(mg, mesh, pass_args=pass_args) + return mg, pass_outs diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 2e6912048..b3859abd1 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -148,7 +148,7 @@ def _extract_ilp(mg, mesh, pass_args={}): ] # Formulate resharding cost matrix - resharding_costs = np.zeros((len(node_in_specs), len(arg_out_specs))) + resharding_costs = np.zeros((opt_var.shape[0], in_opt_var.shape[0])) for dest_idx, dest_spec in enumerate(node_in_specs): for src_idx, src_spec in enumerate(arg_out_specs): cost = redistribute_cost(src_spec, dest_spec) @@ -165,32 +165,14 @@ def _extract_ilp(mg, mesh, pass_args={}): cp.sum(e_var) == 1, ] - # Scalar construction of the inequality constraints for the linearized variable - for i in range(e_var.shape[0]): - constr += [ - e_var[i] <= opt_var[i // in_opt_var.shape[0]], - e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], - e_var[i] - >= opt_var[i // in_opt_var.shape[0]] - + in_opt_var[i % in_opt_var.shape[0]] - - 1, - ] - - # Below speeds up compilation but the number of constraints is the same? 
- - # # Reshape e_var to match the dimensions of opt_var and in_opt_var - # e_var_reshaped = cp.reshape(e_var, (opt_var.shape[0], in_opt_var.shape[0])) - - # # Create broadcasted versions of opt_var and in_opt_var - # opt_var_broadcast = cp.reshape(opt_var, (opt_var.shape[0], 1)) - # in_opt_var_broadcast = cp.reshape(in_opt_var, (1, in_opt_var.shape[0])) - - # # Define the vectorized constraints - # constr += [ - # e_var_reshaped <= opt_var_broadcast, - # e_var_reshaped <= in_opt_var_broadcast, - # e_var_reshaped >= opt_var_broadcast + in_opt_var_broadcast - 1, - # ] + # Constraints s.t. e_var = outer(opt_var, in_opt_var) + indices = np.arange(e_var.shape[0]) + opt_indices, in_opt_indices = np.divmod(indices, in_opt_var.shape[0]) + constr += [ + e_var <= opt_var[opt_indices], + e_var <= in_opt_var[in_opt_indices], + e_var >= opt_var[opt_indices] + in_opt_var[in_opt_indices] - 1, + ] # Solve the ILP problem prob = cp.Problem(cp.Minimize(expr), constr) @@ -221,8 +203,6 @@ def _export_solution(mg): for idx, strat in enumerate(shardings) ] - breakpoint() - return mg, {} @@ -258,9 +238,16 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): mg, problem = _extract_ilp(mg, mesh, pass_args) logger.info(f"Solving the ILP...") - problem.solve(verbose=True, scipy_options={"disp": True}) + problem.solve( + verbose=True, + scipy_options={ + "disp": True, + "time_limit": pass_args.get("time_limit", None), + "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, + }, + ) mg, _ = _export_solution(mg) mg, _ = _mark_sharding(mg) - return mg, module_map + return mg, {"module_map": module_map, "solution": problem.value} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 79e6d7d1a..24df38e74 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -57,11 +57,12 @@ def 
autosharding_analysis_pass(mg, pass_args: dict = {}): # Run intra-operator pass if algo == "alpa": - mg, module_map = alpa_autosharding_pass(mg, mesh, pass_args) + mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) end_time = time() + time_taken = end_time - start_time logger.info( - f"Autosharding pass complete. Time taken: {end_time - start_time} seconds." + f"Autosharding pass complete. Time taken: {time_taken} seconds. Solution: {pass_outs['solution']}" ) - return mg, module_map + return mg, {"autosharding_time": time_taken, **pass_outs} diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index 147240753..02ffe27e1 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -23,4 +23,5 @@ def __call__(self, mg: MaseGraph, pass_args: dict, skip_passes: list = []): mg, pass_output = pass_fn(mg, pass_args=args) self.pass_outputs[pass_fn.__name__] = pass_output - return mg \ No newline at end of file + + return mg, self.pass_outputs From a50e3cccf29230ae9c6cb054cbf6ee186772d04f Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 9 Jul 2024 17:39:33 +0000 Subject: [PATCH 33/93] mark sharding and run checks for linearised variable constraints --- .../graph/analysis/autosharding/alpa.py | 54 ----------------- .../autosharding/alpa_intra_operator.py | 58 +++++++++++++++++-- 2 files changed, 54 insertions(+), 58 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index 51d62d41f..d00661128 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -7,60 +7,6 @@ logger.setLevel("DEBUG") -def deepgetattr(obj, attr, default=None): - """Recurses through an attribute chain to get the ultimate value.""" - try: - return functools.reduce(getattr, attr.split("."), obj) - except AttributeError: - return default - - -def get_node_target(node): - if 
isinstance(node.target, str): - return deepgetattr(node.meta["mase"].model, node.target, None) - else: - return node.target - - -def mark_choices(mg): - """ - Once the metadata has already been filled for each op with the possible shardings and costs, - and the ILP has been solved, this function marks the chosen sharding for each op. - """ - for node in mg.fx_graph.nodes: - chosen_idx = ( - 0 - if isinstance( - node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray - ) - else np.where( - node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1 - )[0][0] - ) - node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta[ - "mase" - ]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] - node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta[ - "mase" - ]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] - chosen_sharding = { - key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] - for key in node.meta["mase"]["software"]["autosharding"][ - "input_sharding" - ].keys() - } - - # Write into module map (used by distributed launcher) - target = get_node_target(node) - if node.op == "call_module" and target is not None: - module_map[target] = {"node": node.name, "sharding": chosen_sharding} - module_map[target]["sharding"]["output"] = node.meta["mase"]["software"][ - "autosharding" - ]["output_sharding"] - - return mg - - def alpa_autosharding_pass(mg, mesh, pass_args={}): """A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index b3859abd1..2e5f70b43 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -1,6 +1,7 @@ import torch import torch.fx as fx from 
torch.distributed._tensor._collective_utils import redistribute_cost +from torch.distributed._tensor._op_schema import DTensorSpec import numpy as np import cvxpy as cp @@ -119,6 +120,7 @@ def _extract_ilp(mg, mesh, pass_args={}): } # Consider resharding cost for each of the node's arguments + e_var_checks = [] for arg_idx, in_node in enumerate(node.all_input_nodes): # Skip constant nodes @@ -165,6 +167,10 @@ def _extract_ilp(mg, mesh, pass_args={}): cp.sum(e_var) == 1, ] + # After solving the ILP, verify constraints were correctly formulated + if pass_args.get("run_checks", False): + e_var_checks.append((opt_var, in_opt_var, e_var)) + # Constraints s.t. e_var = outer(opt_var, in_opt_var) indices = np.arange(e_var.shape[0]) opt_indices, in_opt_indices = np.divmod(indices, in_opt_var.shape[0]) @@ -174,6 +180,9 @@ def _extract_ilp(mg, mesh, pass_args={}): e_var >= opt_var[opt_indices] + in_opt_var[in_opt_indices] - 1, ] + if pass_args.get("run_checks", False): + node.meta["mase"]["software"]["autosharding"]["e_var_checks"] = e_var_checks + # Solve the ILP problem prob = cp.Problem(cp.Minimize(expr), constr) return mg, prob @@ -206,14 +215,53 @@ def _export_solution(mg): return mg, {} -def _mark_sharding(mg): +def _run_checks(mg, pass_args): + for node in mg.fx_graph.nodes: + check_list = node.meta["mase"]["software"]["autosharding"].get( + "e_var_checks", [] + ) + + # Check that the constraints on the linearised variable for resharding cost + # are correctly formulated + for opt_var, in_opt_var, e_var in check_list: + idx1 = np.where(opt_var.value == 1)[0][0] + idx2 = np.where(in_opt_var.value == 1)[0][0] + idx3 = np.where(e_var.value == 1)[0][0] + assert ( + idx3 == idx1 * in_opt_var.shape[0] + idx2 + ), f"Linearized variable for resharding cost is not consistent for node {node}." 
+ + +def _mark_sharding(mg, pass_args): for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] if opt_var is None: continue - idx = np.where(opt_var.value == 1) + idx = np.where(opt_var.value == 1)[0][0] + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ].strategies[idx] + + arg_specs = chosen_strategy.input_specs + out_specs = chosen_strategy.output_specs + + if isinstance(arg_specs, DTensorSpec): + arg_specs = (arg_specs,) + + # Annotate arg metadata with chosen strategy + if node.op not in ["placeholder", "get_attr", "output"]: + arg_list = [i for i in node.meta["mase"]["common"]["args"].keys()] + + for arg_idx, arg_spec in enumerate(arg_specs): + arg_meta = node.meta["mase"]["common"]["args"][arg_list[arg_idx]] + if not isinstance(arg_meta, dict): + continue + arg_meta["dtensor_spec"] = arg_spec + + # Annotate output metadata with chosen strategy + node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_specs return mg, {} @@ -247,7 +295,9 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): }, ) - mg, _ = _export_solution(mg) - mg, _ = _mark_sharding(mg) + if pass_args.get("run_checks", False): + _run_checks(mg, pass_args) + + mg, _ = _mark_sharding(mg, pass_args) return mg, {"module_map": module_map, "solution": problem.value} From 4e384d272eb704d865645427fbfce16c8bb58d39 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 9 Jul 2024 18:41:06 +0000 Subject: [PATCH 34/93] [ATTACH]: distribute get_attr nodes, bug in forward pass --- src/chop/distributed/launcher.py | 91 ++++++++++++++----- .../autosharding/alpa_intra_operator.py | 34 ++++++- 2 files changed, 99 insertions(+), 26 deletions(-) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 31095d7bb..b6af31744 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -17,10 +17,10 @@ from chop.distributed.utils import rlog 
from ..tools import get_logger -from .utils import placement_from_sharding_config logger = get_logger(__name__) -logger.setLevel("INFO") +logger.setLevel("DEBUG") + def distributed_timing(fn, *args, **kwargs): dist.barrier() @@ -30,33 +30,59 @@ def distributed_timing(fn, *args, **kwargs): end = time() return result, (end - start) + def dist_model_fn( - name: str, module: nn.Module, device_mesh: DeviceMesh, rank: int, module_map={} + name: str, + module: nn.Module, + device_mesh: DeviceMesh, + rank: int, + tensor_sharding_map={}, ) -> None: """ This function gets called by torch.distributed._tensor.distribute_module on each module in the model. - Each tensor in each module is distributed according to the sharding configuration in module_map. + Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. """ - if module in module_map: - node_name = module_map[module]["node"] - for parameter, sharding_config in module_map[module]["sharding"].items(): + if module in tensor_sharding_map: + node_name = tensor_sharding_map[module]["node"] + for parameter, sharding_config in tensor_sharding_map[module][ + "sharding" + ].items(): if parameter in ["data_in_0", "output", "data_out_0"]: continue if not hasattr(module, parameter): - rlog(logger, rank, f"Module {module} does not have parameter {parameter}", level="warning") + rlog( + logger, + rank, + f"Module {module} does not have parameter {parameter}", + level="warning", + ) continue - - placement = placement_from_sharding_config(sharding_config) + + placement = sharding_config.placements try: - rlog(logger, rank, f"Distributing parameter {parameter} of module {node_name} to {placement}", level="debug") - distributed_tensor = distribute_tensor(getattr(module, parameter), device_mesh, placement) + rlog( + logger, + rank, + f"Distributing parameter {parameter} of module {node_name} to {placement}", + level="debug", + ) + distributed_tensor = distribute_tensor( + getattr(module, 
parameter), device_mesh, placement + ) setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) except Exception as e: - rlog(logger, rank, f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", level="error") + rlog( + logger, + rank, + f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", + level="error", + ) -def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inputs=[]): +def device_fn( + rank, world_size, model=None, device_mesh=None, tensor_sharding_map={}, inputs=[] +): """ This function gets called on each GPU device to set up the distributed environment and distribute the model, following the SPMD model. @@ -73,24 +99,47 @@ def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inp # Distribute model parameters according to sharding configuration mesh = DeviceMesh("cuda", mesh=device_mesh) rlog(logger, rank, f"Distributing module parameters...", level="info") - model, dist_time = distributed_timing(distribute_module, model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None) + model, dist_time = distributed_timing( + distribute_module, + model, + mesh, + partial(dist_model_fn, rank=rank, tensor_sharding_map=tensor_sharding_map), + input_fn=None, + output_fn=None, + ) rlog(logger, rank, f"Module distribution done. Time taken: {dist_time} seconds.") # Run forward pass rlog(logger, rank, f"Starting forward pass.", level="info") - inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] - out, time_taken = distributed_timing(model, *inputs) + inputs = [ + distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) + for in_tensor in inputs + ] + # out, time_taken = distributed_timing(model, *inputs) + time_taken = 0 rlog(logger, rank, f"Forward pass finished. 
Time taken: {time_taken}", level="info") dist.destroy_process_group() -class MaseLauncher(): - def __init__(self, mase_graph, world_size = None, device_mesh=None): + +class MaseLauncher: + def __init__(self, mase_graph, world_size=None, device_mesh=None): self.mg = mase_graph self.model = mase_graph.model self.world_size = world_size self.device_mesh = device_mesh - def run(self, module_map = {}, inputs=[]): + def run(self, tensor_sharding_map={}, inputs=[]): logger.info(f"Launching model with world size {self.world_size}.") - mp.spawn(partial(device_fn, model=self.model, device_mesh=self.device_mesh, module_map=module_map, inputs=inputs), args=(self.world_size,), nprocs=self.world_size, join=True) \ No newline at end of file + mp.spawn( + partial( + device_fn, + model=self.model, + device_mesh=self.device_mesh, + tensor_sharding_map=tensor_sharding_map, + inputs=inputs, + ), + args=(self.world_size,), + nprocs=self.world_size, + join=True, + ) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 2e5f70b43..553b91886 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -232,7 +232,18 @@ def _run_checks(mg, pass_args): ), f"Linearized variable for resharding cost is not consistent for node {node}." 
+def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + import functools + + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default + + def _mark_sharding(mg, pass_args): + tensor_sharding_map = {} for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] @@ -247,6 +258,21 @@ def _mark_sharding(mg, pass_args): arg_specs = chosen_strategy.input_specs out_specs = chosen_strategy.output_specs + if node.op == "get_attr": + module_str = ".".join(node.target.split(".")[:-1]) + attr = node.target.split(".")[-1] + module = deepgetattr(node.meta["mase"].model, module_str) + + if module not in tensor_sharding_map: + tensor_sharding_map[module] = { + "node": node.name, + "sharding": { + attr: out_specs, + }, + } + else: + tensor_sharding_map[module]["sharding"][attr] = out_specs + if isinstance(arg_specs, DTensorSpec): arg_specs = (arg_specs,) @@ -263,7 +289,7 @@ def _mark_sharding(mg, pass_args): # Annotate output metadata with chosen strategy node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_specs - return mg, {} + return mg, tensor_sharding_map def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): @@ -279,8 +305,6 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): MaseGraph: annotated MaseGraph. 
""" - module_map = {} - # Formulate and solve the ILP logger.info(f"Formulating the ILP...") mg, problem = _extract_ilp(mg, mesh, pass_args) @@ -298,6 +322,6 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): if pass_args.get("run_checks", False): _run_checks(mg, pass_args) - mg, _ = _mark_sharding(mg, pass_args) + mg, tensor_sharding_map = _mark_sharding(mg, pass_args) - return mg, {"module_map": module_map, "solution": problem.value} + return mg, {"tensor_sharding_map": tensor_sharding_map, "solution": problem.value} From f36a790c2210c975af8efe5129f61942dd18a487 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 9 Jul 2024 19:02:33 +0000 Subject: [PATCH 35/93] fix --- src/chop/distributed/launcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index b6af31744..9b9954e0f 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -115,8 +115,7 @@ def device_fn( distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs ] - # out, time_taken = distributed_timing(model, *inputs) - time_taken = 0 + out, time_taken = distributed_timing(model, *inputs) rlog(logger, rank, f"Forward pass finished. 
Time taken: {time_taken}", level="info") dist.destroy_process_group() From 65abdb96933a0a097caa8dcd2b49d179d4449fd6 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 10 Jul 2024 12:48:54 +0000 Subject: [PATCH 36/93] enabling import/export autosharding solutions --- .../autosharding/alpa_intra_operator.py | 91 ++++---- .../analysis/autosharding/autosharding.py | 206 ++++++++++++++++-- 2 files changed, 219 insertions(+), 78 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 553b91886..da1e2904f 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -20,8 +20,6 @@ logger = get_logger(__name__) logger.setLevel("DEBUG") -from .mesh_model import MeshModel - def _extract_ilp(mg, mesh, pass_args={}): """ @@ -31,6 +29,15 @@ def _extract_ilp(mg, mesh, pass_args={}): Return list of constraints associated with ILP. The constraints at this stage only enforce that each optimizer variable is a one-hot boolean vector. + + Args: + mg (MaseGraph): input mase graph. + mesh (MeshModel): mesh model. + pass_args (dict, optional): pass arguments. Defaults to {}. + + Returns: + MaseGraph: input mase graph. + cp.Problem: optimization problem.
""" # Setup for the ILP optimization @@ -188,34 +195,18 @@ def _extract_ilp(mg, mesh, pass_args={}): return mg, prob -def _export_solution(mg): - - nodes = [node for node in mg.fx_graph.nodes] - node_names = [node.name for node in nodes] - opt_vars = [ - node.meta["mase"]["software"]["autosharding"]["opt_var"] for node in nodes - ] - opt_vals = [i.value if i is not None else None for i in opt_vars] - choices = [np.argmax(i) for i in opt_vals] - - strategies = [ - i.meta["mase"]["software"]["autosharding"]["op_strategy"].strategies - for i in nodes - ] - shardings = [strat[choices[idx]] for idx, strat in enumerate(strategies)] - map = [ - { - "node": nodes[idx].name, - "input_specs": strat.input_specs, - "output_specs": strat.output_specs, - } - for idx, strat in enumerate(shardings) - ] +def _run_checks(mg, pass_args): + """ + Run checks on the ILP solution to ensure that the constraints were correctly formulated. - return mg, {} + Args: + mg (MaseGraph): input mase graph. + pass_args (dict): pass arguments. + Returns: + None + """ -def _run_checks(mg, pass_args): for node in mg.fx_graph.nodes: check_list = node.meta["mase"]["software"]["autosharding"].get( "e_var_checks", [] @@ -232,18 +223,20 @@ def _run_checks(mg, pass_args): ), f"Linearized variable for resharding cost is not consistent for node {node}." -def deepgetattr(obj, attr, default=None): - """Recurses through an attribute chain to get the ultimate value.""" - import functools +def _mark_sharding(mg, pass_args): + """ + After solving the ILP, annotate the metadata of each operator in the graph with the chosen + parallelization strategy. - try: - return functools.reduce(getattr, attr.split("."), obj) - except AttributeError: - return default + Args: + mg (MaseGraph): input mase graph. + pass_args (dict): pass arguments. + Returns: + MaseGraph: input mase graph. + dict: tensor sharding map. 
+ """ -def _mark_sharding(mg, pass_args): - tensor_sharding_map = {} for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] @@ -255,24 +248,14 @@ def _mark_sharding(mg, pass_args): "op_strategy" ].strategies[idx] + # Annotate chosen placement strategy + node.meta["mase"]["software"]["autosharding"][ + "placement_strategy" + ] = chosen_strategy + arg_specs = chosen_strategy.input_specs out_specs = chosen_strategy.output_specs - if node.op == "get_attr": - module_str = ".".join(node.target.split(".")[:-1]) - attr = node.target.split(".")[-1] - module = deepgetattr(node.meta["mase"].model, module_str) - - if module not in tensor_sharding_map: - tensor_sharding_map[module] = { - "node": node.name, - "sharding": { - attr: out_specs, - }, - } - else: - tensor_sharding_map[module]["sharding"][attr] = out_specs - if isinstance(arg_specs, DTensorSpec): arg_specs = (arg_specs,) @@ -289,7 +272,7 @@ def _mark_sharding(mg, pass_args): # Annotate output metadata with chosen strategy node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_specs - return mg, tensor_sharding_map + return mg, {} def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): @@ -322,6 +305,6 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): if pass_args.get("run_checks", False): _run_checks(mg, pass_args) - mg, tensor_sharding_map = _mark_sharding(mg, pass_args) + mg, _ = _mark_sharding(mg, pass_args) - return mg, {"tensor_sharding_map": tensor_sharding_map, "solution": problem.value} + return mg, {"solution": problem.value} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 24df38e74..08d71e2e4 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,6 +1,11 @@ import numpy as np import cvxpy as cp from time import time +import csv 
+import dill + +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate from chop.tools import get_logger @@ -11,6 +16,130 @@ logger.setLevel("DEBUG") +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + import functools + + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default + + +def _import_solution(mg, solution: dict, mesh: MeshModel): + """Import an autosharding solution into the metadata of the MaseGraph. + + Args: + mg (MaseGraph): input mase graph. + solution (dict): autosharding solution. + + Returns: + MaseGraph: input mase graph. + dict: empty dictionary. + """ + for node in mg.fx_graph.nodes: + if node.name not in solution.keys(): + continue + + for arg, arg_spec in solution[node.name].get("args", {}).items(): + node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = DTensorSpec( + mesh=mesh, placements=arg_spec + ) + + for result, result_spec in solution[node.name].get("results", {}).items(): + node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = ( + DTensorSpec(mesh=mesh, placements=result_spec) + ) + + return mg, {} + + +def _export_solution(mg, export_file: str = "ilp_solution.csv"): + """Export the ILP solution to a csv file. + + Args: + mg (MaseGraph): input mase graph. + export_file (str, optional): output file name. Defaults to "ilp_solution.csv". + + Returns: + MaseGraph: input mase graph. + dict: empty dictionary. 
+ """ + # Reduce metadata to autosharding solution + out_dict = {} + for node in mg.fx_graph.nodes: + node_name = node.name + out_dict[node_name] = { + "args": {}, + "results": {}, + } + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): + if not isinstance(arg_info, dict): + continue + out_dict[node_name]["args"][arg] = arg_info.get( + "dtensor_spec", DTensorSpec(None, (Replicate(), Replicate())) + ).placements + + for result, result_info in node.meta["mase"]["common"]["results"].items(): + if not isinstance(result_info, dict): + continue + out_dict[node_name]["results"][result] = result_info.get( + "dtensor_spec", DTensorSpec(None, (Replicate(), Replicate())) + ).placements + + with open(export_file.replace(".csv", ".pkl"), "wb") as file: + dill.dump(out_dict, file) + + return mg, {} + + +def _get_sharding_map(mg): + """ + Export the tensor sharding map to a dictionary, to be used by the MaseLauncher for + distributed deployment. + + Args: + mg (MaseGraph): input mase graph. + + Returns: + MaseGraph: input mase graph. + dict: tensor sharding map. + + The tensor sharding map is a dictionary with the following structure. + { + module: { + node: node_name, + sharding: { + attr: out_specs, + }, + }, + } + """ + + tensor_sharding_map = {} + for node in mg.fx_graph.nodes: + if node.op == "get_attr": + module_str = ".".join(node.target.split(".")[:-1]) + attr = node.target.split(".")[-1] + module = deepgetattr(node.meta["mase"].model, module_str) + + out_specs = node.meta["mase"]["common"]["results"]["data_out_0"][ + "dtensor_spec" + ] + + if module not in tensor_sharding_map: + tensor_sharding_map[module] = { + "node": node.name, + "sharding": { + attr: out_specs, + }, + } + else: + tensor_sharding_map[module]["sharding"][attr] = out_specs + + return tensor_sharding_map + + def autosharding_analysis_pass(mg, pass_args: dict = {}): """Annotate the metadata of each operator in the graph with a parallelization strategy. 
@@ -23,15 +152,20 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): The pass_args dictionary expects the following elements. - - algo (optional) -> str : Sharding algorithm to use. Default is "alpa". - mesh_shape -> tuple : Shape of the device cluster. Should be a 2-dimensional tuple. - inter_node_bandwidth -> int : Inter-node bandwidth, i.e. between GPU nodes. - intra_node_bandwidth -> int : Intra-node bandwidth, i.e. between GPU devices in each node. Additionally, the following elements can be passed. + - algo (optional) -> str : Sharding algorithm to use. Default is "alpa". - communications_backend (optional) -> str : Communications backend to use, e.g. "nccl" or "gloo". Default is "nccl". - skip_fully_replicated (optional) -> bool : If set to true, do not consider fully replicated sharding as an option for any operator. + - time_limit (optional) -> int : Time limit for the ILP solver, in seconds. Default is 10000. + - mip_rel_gap (optional) -> int : MIP relative gap for the ILP solver. Default is 0 (i.e. obtain full solution). + - run_checks (optional) -> bool : If set to true, run checks on the autosharding solution. Default is False. + - preload_solution (optional) -> bool : If set to true, preload autosharding solution from file. + - ilp_solution_file (optional) -> str : File to export the autosharding solution to. Defaults to: "ilp_solution.pkl". 
""" assert ( @@ -40,29 +174,53 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" - # Timing - start_time = time() - # Initialize device mesh model, used for cost estimation mesh = MeshModel(pass_args["mesh_shape"]) - algo = pass_args.get("sharding_algo", "alpa") - - # Communication cost model depends - mesh.set_cost_model_parameters( - intra_node_bandwidth=pass_args["intra_node_bandwidth"], - inter_node_bandwidth=pass_args["inter_node_bandwidth"], - backend=pass_args.get("communications_backend", "default"), - ) - - # Run intra-operator pass - if algo == "alpa": - mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) - - end_time = time() - time_taken = end_time - start_time - logger.info( - f"Autosharding pass complete. Time taken: {time_taken} seconds. Solution: {pass_outs['solution']}" - ) - - return mg, {"autosharding_time": time_taken, **pass_outs} + # Preload autosharding solution + if pass_args.get("preload_solution", False): + fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + logger.info(f"Preloading autosharding solution from: {fname}") + with open(fname, "rb") as file: + solution = dill.load(file) + + # Annotate the metadata of each operator with the autosharding solution + mg, pass_outs = _import_solution(mg, solution, mesh) + autosharding_time = 0 + + # Run autosharding pass + else: + # Define autosharding backend + algo = pass_args.get("sharding_algo", "alpa") + + # Communication cost model depends + mesh.set_cost_model_parameters( + intra_node_bandwidth=pass_args["intra_node_bandwidth"], + inter_node_bandwidth=pass_args["inter_node_bandwidth"], + backend=pass_args.get("communications_backend", "default"), + ) + + # Run intra-operator pass + start_time = time() + if algo == "alpa": + mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) + + end_time = 
time() + autosharding_time = end_time - start_time + logger.info( + f"Autosharding pass complete. Time taken: {autosharding_time} seconds. Solution: {pass_outs['solution']}" + ) + + # Export solution + fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + logger.info(f"Exporting solution to {fname}") + mg, _ = _export_solution(mg, export_file=fname) + + if not pass_args.get(f"skip_forward", False): + tensor_sharding_map = _get_sharding_map(mg) + + return mg, { + "autosharding_time": autosharding_time, + "tensor_sharding_map": tensor_sharding_map, + **pass_outs, + } From 8a9b043bba49cdce9c4e64193a253d476887b3ef Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 10 Jul 2024 17:54:16 +0000 Subject: [PATCH 37/93] common metadata for OPT at call_function granularity --- src/chop/ir/common.py | 5 ++++ .../add_metadata/common_metadata_layers.py | 28 ++++++++++++++++++- .../autosharding/ops/pointwise_ops.py | 5 ++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 8d6cd309b..154f3b181 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -14,6 +14,7 @@ "view", # Memory ops and tensor reshapes "to", + "bool", "flatten", "squeeze", "unsqueeze", @@ -23,6 +24,7 @@ "contiguous", "dropout", "eq", + "ne", "gemm", "ge", "where", @@ -47,6 +49,7 @@ "dim", "finfo", "masked_fill", + "masked_fill_", ] MASE_MODULE_RELATED_FUNCS = [ @@ -110,6 +113,7 @@ "gt", "less", "le", # less or equal + "lt", "sigmoid", "not", "min", @@ -119,6 +123,7 @@ "range", "gelu", "scaled_dot_product_attention", + "embedding", ] diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 066553327..3a64c8d5f 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -150,6 +150,8 @@ "tile": {"input": "data_in", "dims": 
"config"}, # https://pytorch.org/docs/stable/generated/torch.lt.html#torch.lt "less": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.lt.html#torch.lt + "lt": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.le.html "lessorequal": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.le.html @@ -176,6 +178,8 @@ "where": {"condition": "config", "input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.equal.html "eq": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.ne.html + "ne": {"input": "data_in", "other": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.cumsum.html "cumsum": {"input": "data_in", "dim": "config"}, # onnx_gemm (custom implementation) @@ -189,7 +193,7 @@ "transB": "config", }, # https://pytorch.org/docs/stable/generated/torch.full.html - "full": {"size": "config", "fill_value": "data_in"}, + "full": {"size": "config", "fill_value": "data_in", "device": "config"}, # get item "getitem": {"a": "data_in", "b": "data_in"}, # getattr @@ -207,6 +211,16 @@ }, # https://pytorch.org/docs/stable/generated/torch.transpose.html "transpose": {"input": "data_in", "dim_0": "config", "dim_1": "config"}, + # https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html + "embedding": { + "input": "data_in", + "weight": "data_in", + "padding_idx": "config", + "max_norm": "config", + "norm_type": "config", + "scale_grad_by_freq": "config", + "sparse": "config", + }, } module_data = { @@ -303,11 +317,20 @@ "transpose": {"dim_0": "config", "dim_1": "config"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous "contiguous": {}, + # https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill.html#torch.Tensor.masked_fill "masked_fill": {"mask": "data_in", "value": "data_in"}, + # 
https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_ + "masked_fill_": {"mask": "data_in", "value": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.unsqueeze.html#torch.Tensor.unsqueeze "unsqueeze": {"input": "data_in", "dim": "config"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.split.html#torch.Tensor.split "split": {"input": "data_in", "split_size_or_sections": "config", "dim": "config"}, + # https://pytorch.org/docs/stable/generated/torch.Tensor.bool.html + "bool": {"memory_format": "config"}, + # https://pytorch.org/docs/stable/generated/torch.Tensor.long.html + "long": {"memory_format": "config"}, + # https://pytorch.org/docs/stable/generated/torch.Tensor.type_as.html + "type_as": {"tensor": "data_in"}, } @@ -478,7 +501,10 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True): mase_op = meta.parameters["common"]["mase_op"] meta = analyse_result(meta, result, add_value) + # try: meta = match_args_and_kwargs(meta, args, kwargs, method_data[mase_op], add_value) + # except: + # breakpoint() return meta diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index be9da28ae..acea27692 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -70,11 +70,10 @@ def common_pointwise_strategy( parsed_args.append(torch.Tensor(arg)) elif isinstance(arg, torch.Tensor): parsed_args.append(arg) - elif isinstance(arg, float): + elif isinstance(arg, (float, int)): parsed_args.append(torch.Tensor([arg])) else: - breakpoint() - raise ValueError("Unrecognized arg type") + raise ValueError(f"Unrecognized arg type: {type(arg)}") common_shape = torch.broadcast_shapes(*[arg.shape for arg in parsed_args]) From 
276c9c43483d63575ed357cb4a28668d7bdb646c Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 11 Jul 2024 11:31:54 +0000 Subject: [PATCH 38/93] handle embedding op in autosharding --- .../graph/analysis/autosharding/layers.py | 72 ++--- .../autosharding/ops/embedding_ops.py | 264 ++++++++++++++++++ 2 files changed, 302 insertions(+), 34 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index b718d1383..1cf0050a8 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -2,7 +2,6 @@ import operator import torch -import torch.nn as nn import torch.nn.functional as F from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed._tensor.placement_types import ( @@ -13,9 +12,6 @@ ) from chop.tools import get_logger -from chop.models.patched.bert.modeling_bert import BertSelfAttention - -from .deprecated.alpa_cost_modelling import get_communication_cost from .ops.matrix_ops import ( transpose_strategy, @@ -24,42 +20,13 @@ bmm_strategy, baddmm_strategy, ) - from .ops.view_ops import get_reshape_strategy from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy - from .ops.math_ops import softmax_strategy, layer_norm_strategy +from .ops.embedding_ops import embedding_strategy logger = get_logger(__name__) -AUTOSHARDING_FUNCTIONS = { - torch.transpose: transpose_strategy, - torch.mm: mm_strategy, - torch.addmm: addmm_strategy, - torch.bmm: bmm_strategy, - torch.baddbmm: baddmm_strategy, - torch.add: linear_pointwise_strategy, - operator.add: linear_pointwise_strategy, - operator.truediv: pointwise_strategy, - F.gelu: pointwise_strategy, - torch.matmul: bmm_strategy, - torch.softmax: softmax_strategy, - F.softmax: softmax_strategy, - F.layer_norm: layer_norm_strategy, -} - 
-AUTOSHARDING_METHODS = { - "view": get_reshape_strategy(torch.Tensor.view), - "reshape": get_reshape_strategy(torch.Tensor.reshape), - "expand": get_reshape_strategy(torch.Tensor.expand), - "permute": get_reshape_strategy(torch.permute), - "transpose": get_reshape_strategy(torch.transpose), -} - -IMPLICIT_FUNCS = [operator.getitem] - -IMPLICIT_METHODS = ["size"] - def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) @@ -132,3 +99,40 @@ def fully_replicated_strategy(meta, mesh): shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] return OpStrategy(shardings) + + +AUTOSHARDING_FUNCTIONS = { + torch.transpose: transpose_strategy, + torch.mm: mm_strategy, + torch.addmm: addmm_strategy, + torch.bmm: bmm_strategy, + torch.baddbmm: baddmm_strategy, + torch.add: linear_pointwise_strategy, + operator.add: linear_pointwise_strategy, + operator.truediv: pointwise_strategy, + F.gelu: pointwise_strategy, + torch.sub: pointwise_strategy, + torch.gt: pointwise_strategy, + operator.gt: pointwise_strategy, + operator.sub: pointwise_strategy, + torch.matmul: bmm_strategy, + torch.softmax: softmax_strategy, + F.softmax: softmax_strategy, + F.layer_norm: layer_norm_strategy, + torch.ones: fully_replicated_strategy, + torch.full: fully_replicated_strategy, + getattr: fully_replicated_strategy, + F.embedding: embedding_strategy, +} + +AUTOSHARDING_METHODS = { + "view": get_reshape_strategy(torch.Tensor.view), + "reshape": get_reshape_strategy(torch.Tensor.reshape), + "expand": get_reshape_strategy(torch.Tensor.expand), + "permute": get_reshape_strategy(torch.permute), + "transpose": get_reshape_strategy(torch.transpose), +} + +IMPLICIT_FUNCS = [operator.getitem] + +IMPLICIT_METHODS = ["size"] diff --git a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py new file mode 100644 index 
000000000..27533e453 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from dataclasses import dataclass, field +from typing import cast, List, Optional +import itertools + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + StrategyType, + DTensorSpec, + PlacementStrategy, +) +from torch.distributed._tensor.ops.utils import ( + is_tensor_shardable, + generate_redistribute_costs, +) +from torch.distributed._tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + + +aten = torch.ops.aten + + +@dataclass +class MaskBuffer: + data: Optional[torch.Tensor] = None + + def materialize_mask(self, mask): + if self.data is not None: + raise RuntimeError("MaskBuffer has already been materialized") + self.data = mask + + def release_mask(self): + # TODO: evaluate if we need to release the mask buffer or the buffer + # can just have the same lifetime as the Partial placement + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + self.data = None + + def apply_mask(self, tensor): + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + + # NOTE: _MaskPartial is being used by the embedding op and the gather op. + # For gather, the mask has the same dimension as the output tensor, whereas + # the output of the embedding op has an additional dimension compare to the input, + # hence the output masking logic below having two different cases. 
+ if tensor.ndim == self.data.ndim: + tensor[self.data] = 0.0 + else: + tensor[self.data, :] = 0.0 + + +@dataclass(frozen=True) +class _MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + logical_dim_size: int = -1 + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( + self.logical_dim_size, + num_chunks, + mesh.get_local_rank(mesh_dim), + return_offset=True, + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # 
perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.logical_dim_size == other.logical_dim_size + ) + + def __hash__(self) -> int: + return 1 + hash( + (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + +def expand_to_full_mesh_op_strategy( + meta, + mesh: DeviceMesh, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
+ all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + + input_specs = spec_list[input_index:] + # input_args_strategy = op_schema.args_strategy + input_args_strategy = tuple( + arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + for arg in meta.node.args + ) + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=( + tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] + ), + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) + + +def embedding_strategy(meta, mesh) -> StrategyType: + """ + This strategy handles embedding op. 
We have two possible embedding shardings: + rowwise and colwise + """ + weight_shape = meta["common"]["args"]["data_in_0"]["shape"] + indices_shape = meta["common"]["args"]["data_in_1"]["shape"] + output_emd_dim = len(indices_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(meta, mesh, single_mesh_dim_strategies) From 6fc504ff558d6aab2754e95c3f0604b2ac28b937 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 11 Jul 2024 12:01:09 +0000 Subject: [PATCH 39/93] tensormeta for embedding op --- .../analysis/autosharding/alpa_intra_operator.py | 6 +++++- .../passes/graph/analysis/autosharding/layers.py | 1 + .../analysis/autosharding/ops/embedding_ops.py | 13 ++++++++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git 
a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index da1e2904f..a6ffce8da 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -143,7 +143,11 @@ def _extract_ilp(mg, mesh, pass_args={}): "op_strategy" ] node_in_specs = [ - strategy.input_specs[arg_idx] + ( + [strategy.input_specs][arg_idx] + if isinstance(strategy.input_specs, DTensorSpec) + else strategy.input_specs[arg_idx] + ) for strategy in node_op_strategy.strategies ] diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 1cf0050a8..1e07f32d9 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -123,6 +123,7 @@ def fully_replicated_strategy(meta, mesh): torch.full: fully_replicated_strategy, getattr: fully_replicated_strategy, F.embedding: embedding_strategy, + torch.finfo: fully_replicated_strategy, } AUTOSHARDING_METHODS = { diff --git a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py index 27533e453..15dd9c592 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -23,6 +23,7 @@ Placement, Replicate, Shard, + TensorMeta, ) from torch.distributed.device_mesh import DeviceMesh @@ -185,7 +186,17 @@ def expand_to_full_mesh_op_strategy( for strategy_comb in strategy_combs: spec_list = [] for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) + spec_list.append( + DTensorSpec( + mesh, + tuple(specs), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), 
+ ) + ) input_specs = spec_list[input_index:] # input_args_strategy = op_schema.args_strategy From c111fa1851012324a7ae4b4436cc7b85c290007a Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 15 Jul 2024 14:47:21 +0000 Subject: [PATCH 40/93] support autosharding for OPT --- src/chop/ir/graph/mase_graph.py | 15 +- .../autosharding/alpa_intra_operator.py | 33 ++- .../graph/analysis/autosharding/layers.py | 273 ++++++++++++++++-- .../autosharding/ops/basic_strategy.py | 3 + .../autosharding/ops/embedding_ops.py | 6 +- .../analysis/autosharding/ops/math_ops.py | 3 + .../analysis/autosharding/ops/matrix_ops.py | 3 +- .../autosharding/ops/pointwise_ops.py | 3 + .../analysis/autosharding/ops/tensor_ops.py | 111 +++++++ .../analysis/autosharding/ops/view_ops.py | 17 +- 10 files changed, 426 insertions(+), 41 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py index 5826cb22c..b8fff2ba9 100644 --- a/src/chop/ir/graph/mase_graph.py +++ b/src/chop/ir/graph/mase_graph.py @@ -93,6 +93,7 @@ def __init__( model, cf_args: Optional[Dict[str, Any]] = None, custom_ops: dict = None, + hf_input_names: list = None, ) -> None: """Mase takes a torch.fx graph representation of a model and translates it into a customised representation (Mase graph IR). 
The Mase graph @@ -110,7 +111,12 @@ def __init__( self.model.patched_custom_layers = [] self.model.additional_inputs = [] elif isinstance(model, torch.nn.Module): - self.model = self.trace_torch_module(model, cf_args, custom_ops) + self.model = self.trace_torch_module( + model, + cf_args, + custom_ops, + hf_input_names=hf_input_names, + ) else: raise ValueError( f"Expected fx.GraphModule or nn.Module, but received model: {type(model)}" @@ -123,6 +129,7 @@ def trace_torch_module( model: torch.nn.Module, cf_args: Optional[Dict[str, Any]] = None, custom_ops: dict = None, + hf_input_names: list = None, ): # * HuggingFace model if isinstance(model, PreTrainedModel): @@ -153,7 +160,11 @@ def is_leaf_module( wrap_is_leaf_module(tracer_cls.is_leaf_module), ) - graph_module = hf_symbolic_trace(model, tracer_cls=tracer_cls) + graph_module = hf_symbolic_trace( + model, + tracer_cls=tracer_cls, + input_names=hf_input_names, + ) graph_module.custom_ops = custom_ops # ! TO DO: remove this legacy stuff diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index a6ffce8da..4b88993db 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -56,10 +56,15 @@ def _extract_ilp(mg, mesh, pass_args={}): op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + # Opt var is None since no decision needs to be taken node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, - "opt_var": None, + "opt_var": opt_var, "input": None, "output": None, } @@ -80,12 +85,12 @@ def _extract_ilp(mg, mesh, pass_args={}): elif node.op == "output": logger.debug( - f"Op strategy from node {node.args[0]} is propagated to {node} node." 
+ f"Op strategy from node {node.all_input_nodes[0]} is propagated to {node} node." ) node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.args[0].meta["mase"]["software"]["autosharding"][ - "op_strategy" - ], + "op_strategy": node.all_input_nodes[0].meta["mase"]["software"][ + "autosharding" + ]["op_strategy"], "opt_var": None, "input": None, "output": None, @@ -104,9 +109,14 @@ def _extract_ilp(mg, mesh, pass_args={}): else: logger.warning(f"Unknown node {node.name} with op {node.op}") + op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] node.meta["mase"]["software"]["autosharding"] = { "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), - "opt_var": None, + "opt_var": opt_var, "input": None, "output": None, } @@ -162,6 +172,7 @@ def _extract_ilp(mg, mesh, pass_args={}): # Formulate resharding cost matrix resharding_costs = np.zeros((opt_var.shape[0], in_opt_var.shape[0])) + for dest_idx, dest_spec in enumerate(node_in_specs): for src_idx, src_spec in enumerate(arg_out_specs): cost = redistribute_cost(src_spec, dest_spec) @@ -258,13 +269,17 @@ def _mark_sharding(mg, pass_args): ] = chosen_strategy arg_specs = chosen_strategy.input_specs - out_specs = chosen_strategy.output_specs + out_spec = chosen_strategy.output_specs if isinstance(arg_specs, DTensorSpec): arg_specs = (arg_specs,) # Annotate arg metadata with chosen strategy - if node.op not in ["placeholder", "get_attr", "output"]: + if node.op in ["placeholder", "get_attr", "call_method", "output"]: + pass + + # call_function nodes + else: arg_list = [i for i in node.meta["mase"]["common"]["args"].keys()] for arg_idx, arg_spec in enumerate(arg_specs): @@ -274,7 +289,7 @@ def _mark_sharding(mg, pass_args): arg_meta["dtensor_spec"] = arg_spec # Annotate output metadata with chosen strategy - node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_specs + 
node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_spec return mg, {} diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 1e07f32d9..833fb3d5d 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -24,10 +24,36 @@ from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy from .ops.math_ops import softmax_strategy, layer_norm_strategy from .ops.embedding_ops import embedding_strategy +from .ops.tensor_ops import tensor_op_strategy, tensor_equal_strategy logger = get_logger(__name__) +def find_shape_and_dtype(arg): + + if isinstance(arg, dict): + in_shape = arg["shape"] + in_dtype = arg["torch_dtype"] + elif isinstance(arg, (tuple, list)): + arg = torch.Tensor(arg) + in_shape = arg.shape + in_dtype = arg.dtype + elif isinstance(arg, torch.Size): + arg = torch.Tensor(list(arg)) + in_shape = arg.shape + in_dtype = arg.dtype + elif isinstance(arg, (float, int)): + arg = torch.Tensor([arg]) + in_shape = arg.shape + in_dtype = arg.dtype + else: + logger.warning(f"Unknown type for arg: {arg}") + in_shape = tuple() + in_dtype = type(arg) + + return in_shape, in_dtype + + def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] @@ -66,18 +92,16 @@ def fully_replicated_strategy(meta, mesh): else [i for i in meta["common"]["args"].keys()][0] ) arg = meta["common"]["args"][first_arg_key] - if isinstance(arg, dict): - in_shape = arg["shape"] - in_dtype = arg["torch_dtype"] - else: - arg = torch.Tensor(arg) - in_shape = arg.shape - in_dtype = arg.dtype + in_shape, in_dtype = find_shape_and_dtype(arg) in_spec = DTensorSpec( mesh, sharding, - tensor_meta=TensorMeta(shape=in_shape, stride=None, dtype=in_dtype), + tensor_meta=TensorMeta( + 
shape=in_shape, + stride=None, + dtype=in_dtype, + ), ) dtype_key = ( @@ -102,38 +126,243 @@ def fully_replicated_strategy(meta, mesh): AUTOSHARDING_FUNCTIONS = { + # embedding_ops.py + F.embedding: embedding_strategy, + # math_ops.py + torch.softmax: softmax_strategy, + F.softmax: softmax_strategy, + torch.log_softmax: softmax_strategy, + F.log_softmax: softmax_strategy, + F.layer_norm: layer_norm_strategy, + # matrix_ops.py torch.transpose: transpose_strategy, torch.mm: mm_strategy, + torch.matmul: bmm_strategy, torch.addmm: addmm_strategy, torch.bmm: bmm_strategy, torch.baddbmm: baddmm_strategy, + # pointwise_ops.py torch.add: linear_pointwise_strategy, operator.add: linear_pointwise_strategy, - operator.truediv: pointwise_strategy, - F.gelu: pointwise_strategy, - torch.sub: pointwise_strategy, + torch.Tensor.add_: linear_pointwise_strategy, + torch.Tensor.to: linear_pointwise_strategy, + torch.abs: pointwise_strategy, + torch.acos: pointwise_strategy, + torch.acosh: pointwise_strategy, + torch.addcdiv: pointwise_strategy, + torch.addcmul: pointwise_strategy, + torch.angle: pointwise_strategy, + torch.asin: pointwise_strategy, + torch.asinh: pointwise_strategy, + torch.atan: pointwise_strategy, + torch.atan2: pointwise_strategy, + torch.atanh: pointwise_strategy, + torch.bitwise_and: pointwise_strategy, + torch.bitwise_left_shift: pointwise_strategy, + torch.bitwise_not: pointwise_strategy, + torch.bitwise_or: pointwise_strategy, + torch.bitwise_right_shift: pointwise_strategy, + torch.bitwise_xor: pointwise_strategy, + torch.ceil: pointwise_strategy, + torch.clamp: pointwise_strategy, + torch.clip: pointwise_strategy, + torch.conj_physical: pointwise_strategy, + torch.copysign: pointwise_strategy, + torch.cos: pointwise_strategy, + torch.cosh: pointwise_strategy, + torch.deg2rad: pointwise_strategy, + torch.digamma: pointwise_strategy, + torch.div: pointwise_strategy, + torch.eq: pointwise_strategy, + # operator.eq: pointwise_strategy, + torch.erf: 
pointwise_strategy, + torch.erfc: pointwise_strategy, + torch.erfinv: pointwise_strategy, + torch.exp: pointwise_strategy, + torch.exp2: pointwise_strategy, + torch.expm1: pointwise_strategy, + torch.float_power: pointwise_strategy, + torch.floor: pointwise_strategy, + torch.fmod: pointwise_strategy, + torch.frac: pointwise_strategy, + torch.ge: pointwise_strategy, torch.gt: pointwise_strategy, operator.gt: pointwise_strategy, + torch.hypot: pointwise_strategy, + torch.i0: pointwise_strategy, + torch.igamma: pointwise_strategy, + torch.igammac: pointwise_strategy, + torch.isnan: pointwise_strategy, + torch.ldexp: pointwise_strategy, + torch.lt: pointwise_strategy, + operator.lt: pointwise_strategy, + torch.le: pointwise_strategy, + torch.lerp: pointwise_strategy, + torch.lgamma: pointwise_strategy, + torch.log: pointwise_strategy, + torch.log10: pointwise_strategy, + torch.log1p: pointwise_strategy, + torch.log2: pointwise_strategy, + torch.logaddexp: pointwise_strategy, + torch.logaddexp2: pointwise_strategy, + torch.logical_and: pointwise_strategy, + torch.logical_not: pointwise_strategy, + torch.logical_or: pointwise_strategy, + torch.logical_xor: pointwise_strategy, + torch.logit: pointwise_strategy, + torch.masked_fill: pointwise_strategy, + torch.maximum: pointwise_strategy, + torch.mul: pointwise_strategy, + operator.mul: pointwise_strategy, + torch.mvlgamma: pointwise_strategy, + torch.nan_to_num: pointwise_strategy, + torch.ne: pointwise_strategy, + operator.ne: pointwise_strategy, + torch.neg: pointwise_strategy, + torch.nextafter: pointwise_strategy, + torch.polygamma: pointwise_strategy, + torch.positive: pointwise_strategy, + torch.pow: pointwise_strategy, + torch.reciprocal: pointwise_strategy, + torch.rad2deg: pointwise_strategy, + torch.relu: pointwise_strategy, + torch.remainder: pointwise_strategy, + torch.round: pointwise_strategy, + torch.rsqrt: pointwise_strategy, + torch.rsub: pointwise_strategy, + torch.sgn: pointwise_strategy, + 
torch.sigmoid: pointwise_strategy, + torch.sign: pointwise_strategy, + torch.signbit: pointwise_strategy, + torch.sin: pointwise_strategy, + torch.sinc: pointwise_strategy, + torch.sinh: pointwise_strategy, + torch.sqrt: pointwise_strategy, + torch.square: pointwise_strategy, + torch.sub: pointwise_strategy, operator.sub: pointwise_strategy, - torch.matmul: bmm_strategy, - torch.softmax: softmax_strategy, - F.softmax: softmax_strategy, - F.layer_norm: layer_norm_strategy, + torch.tan: pointwise_strategy, + torch.tanh: pointwise_strategy, + torch.true_divide: pointwise_strategy, + torch.trunc: pointwise_strategy, + torch.where: pointwise_strategy, + torch.xlogy: pointwise_strategy, + F.gelu: pointwise_strategy, + F.relu: pointwise_strategy, + F.sigmoid: pointwise_strategy, + F.silu: pointwise_strategy, + F.tanh: pointwise_strategy, + torch.Tensor.abs_: pointwise_strategy, + torch.Tensor.acos_: pointwise_strategy, + torch.Tensor.acosh_: pointwise_strategy, + torch.Tensor.add_: pointwise_strategy, + torch.Tensor.addcdiv_: pointwise_strategy, + torch.Tensor.addcmul_: pointwise_strategy, + torch.Tensor.asin_: pointwise_strategy, + torch.Tensor.asinh_: pointwise_strategy, + torch.Tensor.atan2_: pointwise_strategy, + torch.Tensor.atan_: pointwise_strategy, + torch.Tensor.atanh_: pointwise_strategy, + torch.Tensor.bitwise_and_: pointwise_strategy, + torch.Tensor.bitwise_left_shift_: pointwise_strategy, + torch.Tensor.bitwise_not_: pointwise_strategy, + torch.Tensor.bitwise_or_: pointwise_strategy, + torch.Tensor.bitwise_right_shift_: pointwise_strategy, + torch.Tensor.bitwise_xor_: pointwise_strategy, + torch.Tensor.ceil_: pointwise_strategy, + torch.Tensor.clamp_: pointwise_strategy, + torch.Tensor.clip_: pointwise_strategy, + torch.Tensor.conj_physical_: pointwise_strategy, + torch.Tensor.copysign_: pointwise_strategy, + torch.Tensor.cos_: pointwise_strategy, + torch.Tensor.cosh_: pointwise_strategy, + torch.Tensor.deg2rad_: pointwise_strategy, + torch.Tensor.digamma_: 
pointwise_strategy, + torch.Tensor.div_: pointwise_strategy, + torch.Tensor.erf_: pointwise_strategy, + torch.Tensor.erfc_: pointwise_strategy, + torch.Tensor.erfinv_: pointwise_strategy, + torch.Tensor.exp2_: pointwise_strategy, + torch.Tensor.exp_: pointwise_strategy, + torch.Tensor.expm1_: pointwise_strategy, + torch.Tensor.float_power_: pointwise_strategy, + torch.Tensor.floor_: pointwise_strategy, + torch.Tensor.fmod_: pointwise_strategy, + torch.Tensor.frac_: pointwise_strategy, + torch.Tensor.hypot_: pointwise_strategy, + torch.Tensor.i0_: pointwise_strategy, + torch.Tensor.igamma_: pointwise_strategy, + torch.Tensor.igammac_: pointwise_strategy, + torch.Tensor.ldexp_: pointwise_strategy, + torch.Tensor.lerp_: pointwise_strategy, + torch.Tensor.lgamma_: pointwise_strategy, + torch.Tensor.log10_: pointwise_strategy, + torch.Tensor.log1p_: pointwise_strategy, + torch.Tensor.log2_: pointwise_strategy, + torch.Tensor.log_: pointwise_strategy, + torch.Tensor.logical_and_: pointwise_strategy, + torch.Tensor.logical_not_: pointwise_strategy, + torch.Tensor.logical_or_: pointwise_strategy, + torch.Tensor.logical_xor_: pointwise_strategy, + torch.Tensor.logit_: pointwise_strategy, + torch.Tensor.mul_: pointwise_strategy, + torch.Tensor.mvlgamma_: pointwise_strategy, + torch.Tensor.nan_to_num_: pointwise_strategy, + torch.Tensor.neg_: pointwise_strategy, + torch.Tensor.nextafter_: pointwise_strategy, + torch.Tensor.polygamma_: pointwise_strategy, + torch.Tensor.pow_: pointwise_strategy, + torch.Tensor.reciprocal_: pointwise_strategy, + torch.Tensor.rad2deg_: pointwise_strategy, + torch.Tensor.relu_: pointwise_strategy, + torch.Tensor.remainder_: pointwise_strategy, + torch.Tensor.round_: pointwise_strategy, + torch.Tensor.rsqrt_: pointwise_strategy, + torch.Tensor.sgn_: pointwise_strategy, + torch.Tensor.sigmoid_: pointwise_strategy, + torch.Tensor.sign_: pointwise_strategy, + torch.Tensor.sin_: pointwise_strategy, + torch.Tensor.sinc_: pointwise_strategy, + 
torch.Tensor.sinh_: pointwise_strategy, + torch.Tensor.sqrt_: pointwise_strategy, + torch.Tensor.square_: pointwise_strategy, + torch.Tensor.sub_: pointwise_strategy, + torch.Tensor.tan_: pointwise_strategy, + torch.Tensor.tanh_: pointwise_strategy, + torch.Tensor.trunc_: pointwise_strategy, + torch.Tensor.xlogy_: pointwise_strategy, + # tensor_ops.py torch.ones: fully_replicated_strategy, torch.full: fully_replicated_strategy, - getattr: fully_replicated_strategy, - F.embedding: embedding_strategy, - torch.finfo: fully_replicated_strategy, + torch.Tensor.clone: tensor_op_strategy, + torch.Tensor.contiguous: tensor_op_strategy, + torch.Tensor.copy_: tensor_op_strategy, + torch.Tensor.detach: tensor_op_strategy, + torch.Tensor.fill_: tensor_op_strategy, + torch.Tensor.zero_: tensor_op_strategy, + torch.Tensor.equal: tensor_equal_strategy, + torch.Tensor.is_same_size: tensor_equal_strategy, } AUTOSHARDING_METHODS = { + # view_ops.py "view": get_reshape_strategy(torch.Tensor.view), "reshape": get_reshape_strategy(torch.Tensor.reshape), "expand": get_reshape_strategy(torch.Tensor.expand), - "permute": get_reshape_strategy(torch.permute), - "transpose": get_reshape_strategy(torch.transpose), + "permute": get_reshape_strategy(torch.Tensor.permute), + "transpose": get_reshape_strategy(torch.Tensor.transpose), + "masked_fill": pointwise_strategy, + "masked_fill_": pointwise_strategy, + "contiguous": tensor_op_strategy, } -IMPLICIT_FUNCS = [operator.getitem] +IMPLICIT_FUNCS = [ + operator.getitem, + getattr, + torch.finfo, + torch.arange, +] -IMPLICIT_METHODS = ["size"] +IMPLICIT_METHODS = [ + "size", +] diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py index 6db629c89..b9541013a 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -1,3 +1,6 @@ +# Adapted from Pytorch Distributed 
DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/basic_strategy.py + import itertools from dataclasses import dataclass from typing import List, Set, Tuple diff --git a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py index 15dd9c592..feef6fd12 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -# implement matrix related ops for distributed tensor +# Adapted from Pytorch Distributed DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/embedding_ops.py + from dataclasses import dataclass, field from typing import cast, List, Optional import itertools diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py index 028c94dbf..4cdc2f0fe 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -1,3 +1,6 @@ +# Adapted from Pytorch Distributed DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/math_ops.py + from typing import cast, List, Optional, Sequence, Tuple, Union import torch diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index 1fdfb5151..252d5df41 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py +# Adapted from Pytorch Distributed DTensor API. 
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index acea27692..7ef6cb573 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -1,3 +1,6 @@ +# Adapted from Pytorch Distributed DTensor API. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/pointwise_ops.py + from typing import List, Sequence, Tuple import torch diff --git a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py new file mode 100644 index 000000000..0a9daee2a --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py @@ -0,0 +1,111 @@ +# Adapted from Pytorch Distributed DTensor API. 
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py + +from torch.distributed._tensor._op_schema import ( + OpStrategy, + PlacementStrategy, + StrategyType, +) +from torch.distributed._tensor.ops.utils import ( + is_tensor_partial, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Replicate, + TensorMeta, +) + + +def tensor_op_strategy(meta, mesh) -> StrategyType: + # Default strategy by default just propagate the first input strategy + select_strategy = meta.node.args[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + assert isinstance(select_strategy, OpStrategy) + + node_args = list(meta["common"]["args"].keys()) + if len(node_args) > 0: + first_arg_name = node_args[0] + arg_shape, arg_dtype = ( + meta["common"]["args"][first_arg_name]["shape"], + meta["common"]["args"][first_arg_name]["torch_dtype"], + ) + + else: + arg_shape, arg_dtype = ( + meta["common"]["self"].shape, + meta["common"]["self"].dtype, + ) + + first_result = list(meta["common"]["results"].keys())[0] + result_shape, result_dtype = ( + meta["common"]["results"][first_result]["shape"], + meta["common"]["results"][first_result]["torch_dtype"], + ) + + default_strategy = [] + for strategy in select_strategy.strategies: + # we create new DTensorSpecs even for default strategy to assure that + # the tensor metas are distinct between the arguments and outputs + default_strategy.append( + PlacementStrategy( + input_specs=( + DTensorSpec( + mesh=strategy.output_spec.mesh, + placements=strategy.output_spec.placements, + tensor_meta=TensorMeta( + shape=arg_shape, dtype=arg_dtype, stride=None + ), + ), + ) + * len(meta.node.args), + output_specs=DTensorSpec( + mesh=strategy.output_spec.mesh, + placements=strategy.output_spec.placements, + tensor_meta=TensorMeta( + shape=result_shape, dtype=result_dtype, stride=None + ), + ), + ) + ) + return OpStrategy(default_strategy) + + +def tensor_equal_strategy(meta, mesh) 
-> StrategyType: + # equal_strategy deals with ops that comparing two tensor, we need to make sure + # sharding layout the same with two operands, we choose to follow the arg with max + # num of shards, still keep is_same_size here for completeness as they share the + # same strategy in theory. + self_strategy, other_strategy = ( + meta.node.args[0].meta["mase"]["software"]["autosharding"]["op_strategy"], + meta.node.args[1].meta["mase"]["software"]["autosharding"]["op_strategy"], + ) + assert isinstance(self_strategy, OpStrategy) + assert isinstance(other_strategy, OpStrategy) + + select_strategy = ( + self_strategy + if self_strategy.max_num_shards() >= other_strategy.max_num_shards() + else other_strategy + ) + equal_strategy = OpStrategy([]) + + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # if the arg_spec have partial, reshard to replicate + # otherwise local shard tensor comparison would be invalid + output_spec = DTensorSpec( + mesh=arg_spec.mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + equal_strategy.strategies.append( + PlacementStrategy(output_specs=output_spec) + ) + else: + equal_strategy.strategies.append(PlacementStrategy(arg_spec)) + return equal_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index a4e779539..83c638bfb 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -1,3 +1,6 @@ +# Adapted from Pytorch Distributed DTensor API. 
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/view_ops.py + from dataclasses import dataclass from typing import ( Callable, @@ -435,7 +438,6 @@ def dim_view_as_real(shape: Shape) -> DimMap: torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), torch.broadcast_to: lambda input, shape: expand(input.shape, shape), - Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), torch.flatten: lambda tensor: dim_flatten(tensor.ndim), torch.movedim: lambda input, source, destination: dim_movedim( input.ndim, source, destination @@ -444,16 +446,23 @@ def dim_view_as_real(shape: Shape) -> DimMap: InputDim(i) for i in normalize_dims(tuple(dims), input.ndim) ), torch.ravel: lambda tensor: dim_flatten(tensor.ndim), - Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), - Tensor.reshape: lambda self, *shape: view_groups(self.shape, shape), torch.reshape: lambda input, shape: view_groups(input.shape, shape), torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), torch.tile: lambda input, dims: dim_tile(input.ndim, dims), torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), - Tensor.view: lambda input, *shape: view_groups(input.shape, shape), torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), torch.view_as_real: lambda input: dim_view_as_real(input.shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + Tensor.reshape: lambda self, *shape: view_groups(self.shape, shape), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + # here + Tensor.permute: lambda input, *dims: tuple( + InputDim(i) for i in normalize_dims(tuple(dims), input.ndim) + ), + Tensor.transpose: lambda input, dim0, dim1: 
dim_transpose(input.ndim, dim0, dim1), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), } From 17374f6b3ff785f2697d370b91856fa685a0bcac Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 16 Jul 2024 11:14:51 +0000 Subject: [PATCH 41/93] support activation modules, remove legacy stuff, some refactoring --- src/chop/distributed/utils.py | 26 ---- .../autosharding/alpa_intra_operator.py | 13 +- .../deprecated/alpa_cost_modelling.py | 119 ----------------- .../autosharding/deprecated/common.py | 25 ---- .../deprecated/debug_utilities.py | 72 ----------- .../graph/analysis/autosharding/layers.py | 122 ++---------------- .../{ops => strategies}/basic_strategy.py | 0 .../autosharding/strategies/common.py | 111 ++++++++++++++++ .../{ops => strategies}/embedding_ops.py | 0 .../{ops => strategies}/math_ops.py | 0 .../{ops => strategies}/matrix_ops.py | 0 .../{ops => strategies}/pointwise_ops.py | 11 +- .../{ops => strategies}/tensor_ops.py | 0 .../{ops => strategies}/view_ops.py | 0 .../transforms/autosharding/resharding.py | 72 +---------- 15 files changed, 150 insertions(+), 421 deletions(-) delete mode 100644 src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py delete mode 100644 src/chop/passes/graph/analysis/autosharding/deprecated/common.py delete mode 100644 src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/basic_strategy.py (100%) create mode 100644 src/chop/passes/graph/analysis/autosharding/strategies/common.py rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/embedding_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/math_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/matrix_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/pointwise_ops.py (95%) rename 
src/chop/passes/graph/analysis/autosharding/{ops => strategies}/tensor_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ops => strategies}/view_ops.py (100%) diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index 91c35c47e..c7cd7c1c7 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -1,29 +1,3 @@ -from torch.distributed._tensor import ( - Replicate, - Shard, -) - -from chop.passes.graph.analysis.autosharding.deprecated.common import SpmdShard - -import torch - - -def placement_from_sharding_config(sharding_config): - """ - Sharding config is given as a tuple such as (R, S_0) where a symbol S_x at index i indicates - that tensor dimension i is sharded along the x-th dimension of the device mesh. However, - the distribute_tensor API expects a tuple of Shard() and Replicate() objects where a Shard(x) - at index i indicates that tensor dimension x is sharded along device mesh dimension i. - """ - placement = [Replicate()] * 2 - for shard_type in [SpmdShard.S_0, SpmdShard.S_1]: - if shard_type in sharding_config: - idx = sharding_config.index(shard_type) - placement[shard_type.value] = Shard(idx) - - return tuple(placement) - - def rlog(logger, rank, msg, level="info"): """ Only log on rank 0 to avoid repeated messages. 
diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 4b88993db..345dd5593 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -6,14 +6,18 @@ import cvxpy as cp from chop.tools import get_logger +from chop.tools.utils import deepgetattr from .layers import ( + AUTOSHARDING_MODULES, AUTOSHARDING_FUNCTIONS, AUTOSHARDING_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, - placeholder_or_getattr_strategy, +) +from .strategies.common import ( fully_replicated_strategy, + placeholder_or_getattr_strategy, ) @@ -97,6 +101,13 @@ def _extract_ilp(mg, mesh, pass_args={}): } continue + elif node.op == "call_module" and isinstance( + deepgetattr(mg.model, node.target), tuple(AUTOSHARDING_MODULES.keys()) + ): + logger.debug(f"Obtaining strategy for node {node.name}") + module_cls = type(deepgetattr(mg.model, node.target)) + op_strategy = AUTOSHARDING_MODULES[module_cls](node.meta["mase"], mesh) + elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): logger.debug(f"Obtaining strategy for node {node.name}") op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py deleted file mode 100644 index 467caa13c..000000000 --- a/src/chop/passes/graph/analysis/autosharding/deprecated/alpa_cost_modelling.py +++ /dev/null @@ -1,119 +0,0 @@ -import numpy as np -from functools import lru_cache - -from chop.ir.graph import MaseMetadata - -from .common import SpmdShard -from ..mesh_model import MeshModel - -BYTES_PER_ELEMENT = 4 - - -def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): - assert ( - sharding[0][-1] == sharding[1][-2] - ), f"Inconsistent sharding 
for node: {node_meta.node}" - inner_dim_sharding = sharding[1][0] - - out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] - - if inner_dim_sharding == SpmdShard.R: - return 0 - - else: - ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 - return mesh.all_reduce_cost( - num_bytes=BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim=ar_dim - ) - - -@lru_cache(maxsize=None) -def get_resharding_cost( - mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata -): - """ - Obtain the resharding cost given a source and destination sharding profile for a tensor. - The mesh object is assumed to have been initialized with alpha, beta parameters so that - the communication cost can be estimated for each MPI operator. - """ - - # If original sharding is fully replicated, no resharding is required - if src == dest or all(i == SpmdShard.R for i in src): - return 0 - - num_bytes = BYTES_PER_ELEMENT * np.prod( - dest_node_meta["common"]["args"]["data_in_0"]["shape"] - ) - - # No cost (simple split along given mesh dimension) - if ( - # Keep dim 0, split dim 1 - # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) - (src[0] == dest[0]) - and (src[1] == SpmdShard.R) - and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) - # Split dim 0, keep dim 1 - # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) - or (src[1] == dest[1]) - and (src[0] == SpmdShard.R) - and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) - ): - return 0 - - # Split -> Replicate (All Gather) - elif ( - # Keep dim 0, gather along dim 1 - # E.g. (S_1, S_0) -> (S_1, R) - (src[0] == dest[0]) - and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[1] == SpmdShard.R) - # Gather along dim 0, keep dim 1 - # E.g. (S_0, S_1) -> (R, S_1) - or (src[1] == dest[1]) - and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[0] == SpmdShard.R) - ): - ag_dim = 1 if src[0] == dest[0] else 0 - return mesh.all_gather_cost( - num_bytes=num_bytes, - mesh_dim=ag_dim, - ) - - # All-to-all - # E.g. 
(R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) - elif src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src): - # all to all - a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value - try: - return mesh.all_to_all_cost( - num_bytes=num_bytes, - mesh_dim=a2a_dim, - ) - except: - breakpoint() - - # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, - # must first gather along the first non-replicated dimension, then recursively compute the cost for the - # reduced sharding - else: - # Reduce one dimension and re-compute - if src[0] != SpmdShard.R: - new_src = (SpmdShard.R, src[1]) - ag_dim = src[0].value - else: - new_src = (SpmdShard.R, SpmdShard.R) - ag_dim = src[1].value - - return mesh.all_gather_cost( - num_bytes=num_bytes, mesh_dim=ag_dim - ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) - - -def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): - mat = np.zeros((len(dest_shardings), len(src_shardings))) - for src_idx, src in enumerate(src_shardings): - for dest_idx, dest in enumerate(dest_shardings): - mat[dest_idx, src_idx] = get_resharding_cost( - mesh, src, dest, dest_node_meta - ) - return mat diff --git a/src/chop/passes/graph/analysis/autosharding/deprecated/common.py b/src/chop/passes/graph/analysis/autosharding/deprecated/common.py deleted file mode 100644 index e0b98001a..000000000 --- a/src/chop/passes/graph/analysis/autosharding/deprecated/common.py +++ /dev/null @@ -1,25 +0,0 @@ -from enum import Enum - -class SpmdShard(Enum): - S_0 = 0 - S_1 = 1 - R = 3 - - def __repr__(self): - return self.name - - def __gt__(self, other): - if self.__class__ is other.__class__: - return self.value > other.value - return NotImplemented - - -VALID_2D_TENSOR_SHARDINGS = [ - (SpmdShard.R, SpmdShard.R), - (SpmdShard.R, SpmdShard.S_0), - (SpmdShard.R, SpmdShard.S_1), - (SpmdShard.S_0, SpmdShard.R), - (SpmdShard.S_0, SpmdShard.S_1), - (SpmdShard.S_1, 
SpmdShard.R), - (SpmdShard.S_1, SpmdShard.S_0), -] \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py b/src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py deleted file mode 100644 index 10b4b1c67..000000000 --- a/src/chop/passes/graph/analysis/autosharding/deprecated/debug_utilities.py +++ /dev/null @@ -1,72 +0,0 @@ - - -import torch.nn as nn - -from chop.tools import get_logger - -from chop import MaseGraph -import chop.passes as passes -import torch - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - -def are_layers_equal(layer1, layer2): - # Check if both layers are instances of the same class - if type(layer1) != type(layer2): - return False - - # Compare their attributes - for attr in dir(layer1): - # Skip methods and special attributes - l1_attr = getattr(layer1, attr) - if callable(getattr(layer1, attr)) or attr.startswith("_") or isinstance(l1_attr, torch.Tensor): - continue - # Check if both layers have the same attribute and their values are equal - if hasattr(layer2, attr): - if getattr(layer1, attr) != getattr(layer2, attr): - return False - else: - return False - - return True - -def debug_shardings(layer, input_shardings, world_size, device_mesh): - - from chop.distributed import MaseLauncher - - class WrapperModule(nn.Module): - def __init__(self, layer): - super().__init__() - self.layer = layer - - def forward(self, x): - return self.layer(x) - - logger.info(f"Generating subgraph for layer: {layer}") - mg = MaseGraph(WrapperModule(layer)) - mg, _ = passes.init_metadata_analysis_pass(mg) - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "x": torch.randn((1, layer.in_features)), - }, - "add_value": False, - }, - ) - - for idx, sharding in enumerate(input_shardings): - module_map = { - "node": "---", - "sharding": { - layer: { - key: sharding[key] for key in sharding.keys() - } - } - } - 
logger.info(f"[{idx}/{len(input_shardings)}] Testing shading: {sharding}") - launcher = MaseLauncher(mg, world_size=world_size, device_mesh=device_mesh) - # inputs = [torch.randint(0, 10, (1, config_sequence_length))] - inputs = [torch.randn((1, layer.in_features))] - launcher.run(module_map, inputs) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 833fb3d5d..03f99d8ee 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -1,129 +1,29 @@ -import itertools import operator import torch import torch.nn.functional as F -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( - Replicate, - Shard, - DTensorSpec, - TensorMeta, -) from chop.tools import get_logger -from .ops.matrix_ops import ( +from .strategies.common import fully_replicated_strategy +from .strategies.matrix_ops import ( transpose_strategy, mm_strategy, addmm_strategy, bmm_strategy, baddmm_strategy, ) -from .ops.view_ops import get_reshape_strategy -from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy -from .ops.math_ops import softmax_strategy, layer_norm_strategy -from .ops.embedding_ops import embedding_strategy -from .ops.tensor_ops import tensor_op_strategy, tensor_equal_strategy +from .strategies.view_ops import get_reshape_strategy +from .strategies.pointwise_ops import pointwise_strategy, linear_pointwise_strategy +from .strategies.math_ops import softmax_strategy, layer_norm_strategy +from .strategies.embedding_ops import embedding_strategy +from .strategies.tensor_ops import tensor_op_strategy, tensor_equal_strategy logger = get_logger(__name__) - -def find_shape_and_dtype(arg): - - if isinstance(arg, dict): - in_shape = arg["shape"] - in_dtype = arg["torch_dtype"] - elif isinstance(arg, (tuple, list)): - arg = 
torch.Tensor(arg) - in_shape = arg.shape - in_dtype = arg.dtype - elif isinstance(arg, torch.Size): - arg = torch.Tensor(list(arg)) - in_shape = arg.shape - in_dtype = arg.dtype - elif isinstance(arg, (float, int)): - arg = torch.Tensor([arg]) - in_shape = arg.shape - in_dtype = arg.dtype - else: - logger.warning(f"Unknown type for arg: {arg}") - in_shape = tuple() - in_dtype = type(arg) - - return in_shape, in_dtype - - -def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): - ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) - opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] - - tensor_meta = TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], - stride=None, - dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], - ) - - shardings = [] - for sharding in itertools.product(opts, repeat=2): - if skip_fully_replicated and sharding == (Replicate(), Replicate()): - continue - spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) - shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) - return OpStrategy(shardings) - - -def fully_replicated_strategy(meta, mesh): - """ - Output of ops like size, getitem etc are always fully replicated - """ - sharding = [Replicate(), Replicate()] - - # call_method nodes don't list input tensor in the args list, but - # tensor is copied into meta["common"]["self"] when add_value = True - # is passed to add_common_metadata_pass - if meta.node.op == "call_method": - in_shape = meta["common"]["self"].shape - in_dtype = meta["common"]["self"].dtype - else: - first_arg_key = ( - "data_in_0" - if "data_in_0" in meta["common"]["args"] - else [i for i in meta["common"]["args"].keys()][0] - ) - arg = meta["common"]["args"][first_arg_key] - in_shape, in_dtype = find_shape_and_dtype(arg) - - in_spec = DTensorSpec( - mesh, - sharding, - tensor_meta=TensorMeta( - shape=in_shape, - stride=None, - dtype=in_dtype, - ), - ) - - dtype_key = ( - 
"torch_dtype" - if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() - else "type" - ) - out_dtype = meta["common"]["results"]["data_out_0"][dtype_key] - out_spec = DTensorSpec( - mesh, - sharding, - tensor_meta=TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], - stride=None, - dtype=out_dtype, - ), - ) - - shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] - - return OpStrategy(shardings) - +AUTOSHARDING_MODULES = { + torch.nn.ReLU: pointwise_strategy, +} AUTOSHARDING_FUNCTIONS = { # embedding_ops.py @@ -174,7 +74,7 @@ def fully_replicated_strategy(meta, mesh): torch.digamma: pointwise_strategy, torch.div: pointwise_strategy, torch.eq: pointwise_strategy, - # operator.eq: pointwise_strategy, + operator.eq: pointwise_strategy, torch.erf: pointwise_strategy, torch.erfc: pointwise_strategy, torch.erfinv: pointwise_strategy, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py rename to src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py new file mode 100644 index 000000000..0e27cfc6c --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -0,0 +1,111 @@ +import itertools + +import torch +import torch.nn.functional as F +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( + Replicate, + Shard, + DTensorSpec, + TensorMeta, +) + +from chop.tools import get_logger + +logger = get_logger(__name__) + + +def find_shape_and_dtype(arg): + + if isinstance(arg, dict): + in_shape = arg["shape"] + in_dtype = arg["torch_dtype"] + elif 
isinstance(arg, (tuple, list)): + arg = torch.Tensor(arg) + in_shape = arg.shape + in_dtype = arg.dtype + elif isinstance(arg, torch.Size): + arg = torch.Tensor(list(arg)) + in_shape = arg.shape + in_dtype = arg.dtype + elif isinstance(arg, (float, int)): + arg = torch.Tensor([arg]) + in_shape = arg.shape + in_dtype = arg.dtype + else: + logger.warning(f"Unknown type for arg: {arg}") + in_shape = tuple() + in_dtype = type(arg) + + return in_shape, in_dtype + + +def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): + ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] + + tensor_meta = TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ) + + shardings = [] + for sharding in itertools.product(opts, repeat=2): + if skip_fully_replicated and sharding == (Replicate(), Replicate()): + continue + spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) + shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) + return OpStrategy(shardings) + + +def fully_replicated_strategy(meta, mesh): + """ + Output of ops like size, getitem etc are always fully replicated + """ + sharding = [Replicate(), Replicate()] + + # call_method nodes don't list input tensor in the args list, but + # tensor is copied into meta["common"]["self"] when add_value = True + # is passed to add_common_metadata_pass + if meta.node.op == "call_method": + in_shape = meta["common"]["self"].shape + in_dtype = meta["common"]["self"].dtype + else: + first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) + arg = meta["common"]["args"][first_arg_key] + in_shape, in_dtype = find_shape_and_dtype(arg) + + in_spec = DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + 
dtype=in_dtype, + ), + ) + + dtype_key = ( + "torch_dtype" + if "torch_dtype" in meta["common"]["results"]["data_out_0"].keys() + else "type" + ) + out_dtype = meta["common"]["results"]["data_out_0"][dtype_key] + out_spec = DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=out_dtype, + ), + ) + + shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] + + return OpStrategy(shardings) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/math_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py similarity index 95% rename from src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index 7ef6cb573..fe6ac93ef 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -25,6 +25,12 @@ 
TensorMeta, ) +from chop.tools import get_logger + +from .common import fully_replicated_strategy + +logger = get_logger(__name__) + def pointwise_strategy(meta, mesh, linearity=False): max_shards_strategy_index = -1 @@ -76,7 +82,10 @@ def common_pointwise_strategy( elif isinstance(arg, (float, int)): parsed_args.append(torch.Tensor([arg])) else: - raise ValueError(f"Unrecognized arg type: {type(arg)}") + logger.warning( + f"Unrecognized arg type: {type(arg)}, defaulting to fully replicated strategy." + ) + return fully_replicated_strategy(meta, mesh) common_shape = torch.broadcast_shapes(*[arg.shape for arg in parsed_args]) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/ops/view_ops.py rename to src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py index 4a8617c4d..3646064dd 100644 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -1,64 +1,8 @@ - -import functools - -import torch -import torch.nn as nn - -from torch.distributed._tensor import ( - DeviceMesh, -) - -from torch.distributed._tensor.api import Redistribute - -from chop.distributed.utils import placement_from_sharding_config from chop.tools import get_logger logger = get_logger(__name__) logger.setLevel("INFO") -def rlog(logger, rank, msg, level="info"): - """ - Only log on rank 0 to 
avoid repeated messages. - """ - log_fn = getattr(logger, level, logger.info) - if (rank == 0): - log_fn(f"[RANK: {rank}]: {msg}") - -def deepsetattr(obj, attr, value): - """Recurses through an attribute chain to set the ultimate value.""" - attrs = attr.split(".") - if len(attrs) > 1: - deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) - else: - setattr(obj, attr, value) - -def deepgetattr(obj, attr, default=None): - """Recurses through an attribute chain to get the ultimate value.""" - try: - return functools.reduce(getattr, attr.split("."), obj) - except AttributeError: - return default - -class ReshardingWrapper(nn.Module): - def __init__(self, device_mesh, module, resharding_config): - super().__init__() - self.module = module - self.resharding_config = resharding_config["sharding"] - self.node = resharding_config["node"] - self.device_mesh = device_mesh - - def forward(self, x): - rank = torch.distributed.get_rank() - device_mesh = DeviceMesh("cuda", self.device_mesh) - - required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) - if (x.placements != required_placement): - rlog(logger, rank, f"For module {self.node}, resharding tensor x from {x.placements} to {required_placement}", level="debug") - x = Redistribute.apply(x, device_mesh, required_placement) - - out = self.module(x) - - return out def resharding_transform_pass(mg, pass_args={}): """ @@ -70,17 +14,13 @@ def resharding_transform_pass(mg, pass_args={}): module_map = pass_args.get("module_map", None) device_mesh = pass_args.get("device_mesh", None) if module_map is None or device_mesh is None: - raise ValueError("module_map and device_mesh are required for resharding_transform_pass") + raise ValueError( + "module_map and device_mesh are required for resharding_transform_pass" + ) for node in mg.fx_graph.nodes: - if node.op != "call_module": - continue - module = deepgetattr(mg.model, node.target, None) - if module is not None: - resharding_config = 
module_map[module] - logger.debug(f"Inserting resharding wrapper around node: {node}") - deepsetattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) + pass mg.model.recompile() - - return mg, {} \ No newline at end of file + + return mg, {} From 8e738b7fc954cb294161d49933a87507a0b4d4b8 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 17 Jul 2024 09:35:12 +0000 Subject: [PATCH 42/93] extrapolate sharding from single layer solution --- .../autosharding/alpa_intra_operator.py | 6 ++- .../analysis/autosharding/autosharding.py | 39 ++++++++++++++++--- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 345dd5593..b431095d6 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -269,7 +269,11 @@ def _mark_sharding(mg, pass_args): if opt_var is None: continue - idx = np.where(opt_var.value == 1)[0][0] + try: + idx = np.where(opt_var.value == 1)[0][0] + except: + idx = np.argmax(opt_var.value) + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ "op_strategy" ].strategies[idx] diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 08d71e2e4..66778a33d 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -26,26 +26,47 @@ def deepgetattr(obj, attr, default=None): return default -def _import_solution(mg, solution: dict, mesh: MeshModel): +def _import_solution( + mg, + solution: dict, + mesh: MeshModel, + extrapolate_sharding: bool = True, +): """Import an autosharding solution into the metadata of the MaseGraph. Args: mg (MaseGraph): input mase graph. solution (dict): autosharding solution. 
+ extrapolate (bool): extrapolate solution from the 1st layer to the rest. Returns: MaseGraph: input mase graph. dict: empty dictionary. """ for node in mg.fx_graph.nodes: - if node.name not in solution.keys(): - continue + logger.debug(f"Importing solution for node: {node.name}: {solution[node.name]}") + + if node.name not in solution.keys() and extrapolate_sharding: + layer_num = int([i for i in node.name.split("_") if i.isdigit()][0]) + extrapolate_node = node.name.replace(f"_{layer_num}_", "_0_") + if "decoder_layers" in node.name and extrapolate_node in solution.keys(): + logger.warning( + f"Node: {node.name} not found in solution. Extrapolating from solution for: {extrapolate_node}" + ) + solution[node.name] = solution[extrapolate_node] + else: + logger.debug( + f"Node: {node.name} not found in solution, and cannot extrapolate." + ) + continue + # Annotate the metadata for each argument for arg, arg_spec in solution[node.name].get("args", {}).items(): node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = DTensorSpec( mesh=mesh, placements=arg_spec ) + # Annotate the metadata for each result for result, result_spec in solution[node.name].get("results", {}).items(): node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = ( DTensorSpec(mesh=mesh, placements=result_spec) @@ -54,12 +75,12 @@ def _import_solution(mg, solution: dict, mesh: MeshModel): return mg, {} -def _export_solution(mg, export_file: str = "ilp_solution.csv"): - """Export the ILP solution to a csv file. +def _export_solution(mg, export_file: str = "ilp_solution.pkl"): + """Export the ILP solution to a pickle file. Args: mg (MaseGraph): input mase graph. - export_file (str, optional): output file name. Defaults to "ilp_solution.csv". + export_file (str, optional): output file name. Defaults to "ilp_solution.pkl". Returns: MaseGraph: input mase graph. 
@@ -116,6 +137,8 @@ def _get_sharding_map(mg): } """ + logger.info(f"Exporting tensor sharding map from MaseGraph for MaseLauncher.") + tensor_sharding_map = {} for node in mg.fx_graph.nodes: if node.op == "get_attr": @@ -127,6 +150,10 @@ def _get_sharding_map(mg): "dtensor_spec" ] + logger.debug( + f"Exporting sharding map for {node.name} with spec: {out_specs}" + ) + if module not in tensor_sharding_map: tensor_sharding_map[module] = { "node": node.name, From 1eed07e26e2eba7876ff4d34d47d6e6000100bc5 Mon Sep 17 00:00:00 2001 From: pgimenes Date: Wed, 17 Jul 2024 13:39:02 +0100 Subject: [PATCH 43/93] layout for extended docs --- docs/source/index.rst | 8 +++++++- docs/source/modules/api/actions.rst | 2 +- docs/source/modules/api/datasets.md | 1 + docs/source/modules/api/distributed.md | 1 + docs/source/modules/api/ir.md | 2 +- docs/source/modules/api/models.md | 1 + docs/source/modules/api/nn.md | 1 + docs/source/modules/api/passes.rst | 2 +- docs/source/modules/api/pipelines.md | 1 + docs/source/modules/api/tools.md | 1 + .../passes/graph/analysis/autosharding/autosharding.py | 2 +- 11 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 docs/source/modules/api/datasets.md create mode 100644 docs/source/modules/api/distributed.md create mode 100644 docs/source/modules/api/models.md create mode 100644 docs/source/modules/api/nn.md create mode 100644 docs/source/modules/api/pipelines.md create mode 100644 docs/source/modules/api/tools.md diff --git a/docs/source/index.rst b/docs/source/index.rst index 7d58105b1..af1670e39 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -36,8 +36,14 @@ For more, you can watch this :caption: Machop API modules/api/actions - modules/api/passes + modules/api/datasets + modules/api/distributed modules/api/ir + modules/api/models + modules/api/nn + modules/api/passes + modules/api/pipelines + modules/api/tools .. 
toctree:: :maxdepth: 1 diff --git a/docs/source/modules/api/actions.rst b/docs/source/modules/api/actions.rst index d1b3d4b48..7f1ee8380 100644 --- a/docs/source/modules/api/actions.rst +++ b/docs/source/modules/api/actions.rst @@ -1,4 +1,4 @@ -Chop Actions +chop.actions ==================== chop.actions.train diff --git a/docs/source/modules/api/datasets.md b/docs/source/modules/api/datasets.md new file mode 100644 index 000000000..51c06f667 --- /dev/null +++ b/docs/source/modules/api/datasets.md @@ -0,0 +1 @@ +# chop.datasets \ No newline at end of file diff --git a/docs/source/modules/api/distributed.md b/docs/source/modules/api/distributed.md new file mode 100644 index 000000000..5d02f918a --- /dev/null +++ b/docs/source/modules/api/distributed.md @@ -0,0 +1 @@ +# chop.distributed \ No newline at end of file diff --git a/docs/source/modules/api/ir.md b/docs/source/modules/api/ir.md index fc32c5d44..2983488d4 100644 --- a/docs/source/modules/api/ir.md +++ b/docs/source/modules/api/ir.md @@ -1,4 +1,4 @@ -# IR +# chop.ir ## MaseTracer diff --git a/docs/source/modules/api/models.md b/docs/source/modules/api/models.md new file mode 100644 index 000000000..9373c556b --- /dev/null +++ b/docs/source/modules/api/models.md @@ -0,0 +1 @@ +# chop.models \ No newline at end of file diff --git a/docs/source/modules/api/nn.md b/docs/source/modules/api/nn.md new file mode 100644 index 000000000..69f49affb --- /dev/null +++ b/docs/source/modules/api/nn.md @@ -0,0 +1 @@ +# chop.nn \ No newline at end of file diff --git a/docs/source/modules/api/passes.rst b/docs/source/modules/api/passes.rst index df0cb6952..9b2e945e5 100644 --- a/docs/source/modules/api/passes.rst +++ b/docs/source/modules/api/passes.rst @@ -1,4 +1,4 @@ -Chop Passes +chop.passes ============================ All passes, no matter analysis or transform, take a standard form: diff --git a/docs/source/modules/api/pipelines.md b/docs/source/modules/api/pipelines.md new file mode 100644 index 000000000..0fd043dcd --- 
/dev/null +++ b/docs/source/modules/api/pipelines.md @@ -0,0 +1 @@ +# chop.pipelines \ No newline at end of file diff --git a/docs/source/modules/api/tools.md b/docs/source/modules/api/tools.md new file mode 100644 index 000000000..18df45d51 --- /dev/null +++ b/docs/source/modules/api/tools.md @@ -0,0 +1 @@ +# chop.tools \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 66778a33d..472bbcfae 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -44,7 +44,7 @@ def _import_solution( dict: empty dictionary. """ for node in mg.fx_graph.nodes: - logger.debug(f"Importing solution for node: {node.name}: {solution[node.name]}") + logger.debug(f"Importing solution for node: {node.name}") if node.name not in solution.keys() and extrapolate_sharding: layer_num = int([i for i in node.name.split("_") if i.isdigit()][0]) From 6f740a2f739329badc24a0dad02b5b25afd31ad5 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 18 Jul 2024 09:24:02 +0000 Subject: [PATCH 44/93] make dist barrier asynchronous for distributed timing, and account for GPU warmup --- src/chop/distributed/launcher.py | 27 ++++++++++++++++--- .../analysis/autosharding/autosharding.py | 11 +++++--- .../graph/analysis/autosharding/megatron.py | 23 ++++++++++++++++ 3 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/megatron.py diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 9b9954e0f..48de9a073 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -19,18 +19,32 @@ from ..tools import get_logger logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") def distributed_timing(fn, *args, **kwargs): - dist.barrier() + dist.barrier(async_op=True) start = time() 
result = fn(*args, **kwargs) - dist.barrier() + dist.barrier(async_op=True) end = time() + return result, (end - start) +def distributed_average_timing(fn, repeat, args): + times = [] + for _ in range(repeat): + dist.barrier(async_op=True) + start = time() + result = fn(*args) + dist.barrier(async_op=True) + end = time() + times.append(end - start) + + return result, sum(times) / len(times) + + def dist_model_fn( name: str, module: nn.Module, @@ -115,7 +129,11 @@ def device_fn( distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs ] - out, time_taken = distributed_timing(model, *inputs) + _, time_taken = distributed_average_timing( + fn=model, + repeat=5, + args=inputs, + ) rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info") dist.destroy_process_group() @@ -130,6 +148,7 @@ def __init__(self, mase_graph, world_size=None, device_mesh=None): def run(self, tensor_sharding_map={}, inputs=[]): logger.info(f"Launching model with world size {self.world_size}.") + mp.spawn( partial( device_fn, diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 66778a33d..8a96ee3cd 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -11,9 +11,10 @@ from .mesh_model import MeshModel from .alpa import alpa_autosharding_pass +from .megatron import megatron_autosharding_pass logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") def deepgetattr(obj, attr, default=None): @@ -44,7 +45,7 @@ def _import_solution( dict: empty dictionary. 
""" for node in mg.fx_graph.nodes: - logger.debug(f"Importing solution for node: {node.name}: {solution[node.name]}") + logger.debug(f"Importing solution for node: {node.name}") if node.name not in solution.keys() and extrapolate_sharding: layer_num = int([i for i in node.name.split("_") if i.isdigit()][0]) @@ -218,7 +219,7 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): # Run autosharding pass else: # Define autosharding backend - algo = pass_args.get("sharding_algo", "alpa") + algo = pass_args.get("algo", "alpa") # Communication cost model depends mesh.set_cost_model_parameters( @@ -231,6 +232,10 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): start_time = time() if algo == "alpa": mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) + elif algo == "megatron": + mg, pass_outs = megatron_autosharding_pass(mg, mesh, pass_args) + else: + raise ValueError(f"Autosharding algorithm {algo} not recognized") end_time = time() autosharding_time = end_time - start_time diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py new file mode 100644 index 000000000..30cd36f7e --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/megatron.py @@ -0,0 +1,23 @@ +from chop.ir import MaseGraph +from .mesh_model import MeshModel + + +def megatron_autosharding_pass( + mg: MaseGraph, + mesh: MeshModel, + pass_args: dict, +): + for node in mg.fx_graph.nodes: + meta = node.meta["mase"]["common"] + + for arg, arg_spec in meta["args"].items(): + if not isinstance(arg_spec, dict): + continue + arg_spec["dtensor_spec"] = None + + for result, result_spec in meta["results"].items(): + if not isinstance(result_spec, dict): + continue + result_spec["dtensor_spec"] = None + + return mg, {"solution": {}} From 014c6db8e2c45dcc9e077cde87ec36028de08306 Mon Sep 17 00:00:00 2001 From: pgimenes Date: Thu, 18 Jul 2024 13:17:47 +0100 Subject: [PATCH 45/93] some docs --- 
docs/source/modules/api/distributed.md | 8 +++++++- docs/source/modules/api/pipelines.md | 16 +++++++++++++++- src/chop/distributed/launcher.py | 11 +++++++++++ src/chop/pipelines/auto_pipeline.py | 12 ++++++++++++ src/chop/pipelines/distributed_inference.py | 20 +++++++++++++++++--- 5 files changed, 62 insertions(+), 5 deletions(-) diff --git a/docs/source/modules/api/distributed.md b/docs/source/modules/api/distributed.md index 5d02f918a..3b32e33d7 100644 --- a/docs/source/modules/api/distributed.md +++ b/docs/source/modules/api/distributed.md @@ -1 +1,7 @@ -# chop.distributed \ No newline at end of file +# chop.distributed + +## MaseLauncher + +```{eval-rst} +.. autoclass:: chop.distributed.MaseLauncher +``` \ No newline at end of file diff --git a/docs/source/modules/api/pipelines.md b/docs/source/modules/api/pipelines.md index 0fd043dcd..db9de6d8a 100644 --- a/docs/source/modules/api/pipelines.md +++ b/docs/source/modules/api/pipelines.md @@ -1 +1,15 @@ -# chop.pipelines \ No newline at end of file +# chop.pipelines + +Mase pipelines are pass managers that let you do things quicker for common use cases, without the faff of managing individual compiler passes. + +## AutoPipeline (base class) + +```{eval-rst} +.. autoclass:: chop.pipelines.AutoPipeline +``` + +## AutoPipeline for Distributed Inference + +```{eval-rst} +.. autoclass:: chop.pipelines.AutoPipelineForDistributedInference +``` \ No newline at end of file diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 9b9954e0f..143747d6a 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -122,7 +122,18 @@ def device_fn( class MaseLauncher: + """ + MaseLauncher launches an optimized model on multiple GPUs using torch.distributed. + """ + def __init__(self, mase_graph, world_size=None, device_mesh=None): + """Initialize the MaseLauncher. + + Args: + mase_graph (MaseGraph): The MaseGraph object containing the model. 
+ world_size (int, optional): Number of GPUs to use. Defaults to None. + device_mesh (list, optional): List of GPUs to use. Defaults to None. + """ self.mg = mase_graph self.model = mase_graph.model self.world_size = world_size diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index 02ffe27e1..967f0f2f6 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -5,7 +5,19 @@ class AutoPipeline: + """This is the base class for the AutoPipeline. + + It takes a list of passes and runs them in order. + + The output of each pass is stored in a dictionary and can be accessed by the next pass. + """ + def __init__(self, pass_list=[]) -> None: + """Initializes the AutoPipeline. + + Args: + pass_list (list, optional): List of passes to run. Defaults to []. + """ self.pass_list = pass_list self.pass_outputs = {} diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index 6907fb65f..bc06ebab5 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -4,16 +4,30 @@ class AutoPipelineForDistributedInference(AutoPipeline): + """This pipeline is used for distributed inference. 
+ + It runs the following passes: + + - init_metadata_analysis_pass + + - report_graph_analysis_pass + + - add_common_metadata_analysis_pass + + - autosharding_analysis_pass + + - resharding_transform_pass + """ + def __init__(self) -> None: + """Initializes the AutoPipeline.""" pass_list = [ passes.init_metadata_analysis_pass, passes.report_graph_analysis_pass, passes.add_common_metadata_analysis_pass, - passes.report_node_meta_param_analysis_pass, passes.autosharding_analysis_pass, passes.resharding_transform_pass, - passes.graph.analysis.report.report_parallelization_analysis_pass, ] - super().__init__(pass_list) \ No newline at end of file + super().__init__(pass_list) From 5cfeb5d749a817e46ca683a97643e9c2933ffbf0 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 19 Jul 2024 14:28:48 +0000 Subject: [PATCH 46/93] fixes to extrapolate single layer solution, improved reporting for exporting solution, refactor shape inference for getitem nodes, sdpa strategy --- .../analysis/autosharding/autosharding.py | 64 +++++++-- .../autosharding/strategies/common.py | 20 ++- .../autosharding/strategies/matrix_ops.py | 126 ++++++++++++++++-- 3 files changed, 177 insertions(+), 33 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 8a96ee3cd..04b3218fe 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,7 +1,6 @@ import numpy as np import cvxpy as cp from time import time -import csv import dill from torch.distributed._tensor._op_schema import DTensorSpec @@ -48,9 +47,14 @@ def _import_solution( logger.debug(f"Importing solution for node: {node.name}") if node.name not in solution.keys() and extrapolate_sharding: + + # Expect the layer number to be the first digit in the node name layer_num = int([i for i in node.name.split("_") if i.isdigit()][0]) - extrapolate_node = 
node.name.replace(f"_{layer_num}_", "_0_") - if "decoder_layers" in node.name and extrapolate_node in solution.keys(): + + # Only replace the first digit to find the equivalent node in the first layer + extrapolate_node = node.name.replace(f"_{layer_num}_", "_0_", 1) + + if extrapolate_node in solution.keys(): logger.warning( f"Node: {node.name} not found in solution. Extrapolating from solution for: {extrapolate_node}" ) @@ -64,13 +68,17 @@ def _import_solution( # Annotate the metadata for each argument for arg, arg_spec in solution[node.name].get("args", {}).items(): node.meta["mase"]["common"]["args"][arg]["dtensor_spec"] = DTensorSpec( - mesh=mesh, placements=arg_spec + mesh=mesh, + placements=arg_spec, ) # Annotate the metadata for each result for result, result_spec in solution[node.name].get("results", {}).items(): node.meta["mase"]["common"]["results"][result]["dtensor_spec"] = ( - DTensorSpec(mesh=mesh, placements=result_spec) + DTensorSpec( + mesh=mesh, + placements=result_spec, + ) ) return mg, {} @@ -98,18 +106,38 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): for arg, arg_info in node.meta["mase"]["common"]["args"].items(): if not isinstance(arg_info, dict): continue - out_dict[node_name]["args"][arg] = arg_info.get( - "dtensor_spec", DTensorSpec(None, (Replicate(), Replicate())) - ).placements + + if "dtensor_spec" not in arg_info: + logger.warning( + f"DTensor spec not found for arg: {arg} in node: {node_name}. Assigning fully-replicated solution." 
+ ) + spec = DTensorSpec( + None, + (Replicate(), Replicate()), + ) + else: + spec = arg_info["dtensor_spec"] + + out_dict[node_name]["args"][arg] = spec.placements for result, result_info in node.meta["mase"]["common"]["results"].items(): if not isinstance(result_info, dict): continue - out_dict[node_name]["results"][result] = result_info.get( - "dtensor_spec", DTensorSpec(None, (Replicate(), Replicate())) - ).placements - with open(export_file.replace(".csv", ".pkl"), "wb") as file: + # TO DO: add warning when dtensor_spec not found + if "dtensor_spec" not in result_info: + logger.warning( + f"DTensor spec not found for result: {result} in node: {node_name}. Assigning fully-replicated solution." + ) + spec = DTensorSpec( + None, + (Replicate(), Replicate()), + ) + else: + spec = result_info["dtensor_spec"] + out_dict[node_name]["results"][result] = spec.placements + + with open(export_file, "wb") as file: dill.dump(out_dict, file) return mg, {} @@ -147,9 +175,17 @@ def _get_sharding_map(mg): attr = node.target.split(".")[-1] module = deepgetattr(node.meta["mase"].model, module_str) - out_specs = node.meta["mase"]["common"]["results"]["data_out_0"][ + if ( "dtensor_spec" - ] + not in node.meta["mase"]["common"]["results"]["data_out_0"] + ): + raise ValueError( + f"Couldn't find DTensor sharding specification in solution for node: {node.name}" + ) + else: + out_specs = node.meta["mase"]["common"]["results"]["data_out_0"][ + "dtensor_spec" + ] logger.debug( f"Exporting sharding map for {node.name} with spec: {out_specs}" diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 0e27cfc6c..0959a6bf3 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -17,21 +17,19 @@ def find_shape_and_dtype(arg): + # If the argument in meta["common"]["args"][key] is correctly + # formulated 
with data, just extract shape and dtype if isinstance(arg, dict): in_shape = arg["shape"] in_dtype = arg["torch_dtype"] - elif isinstance(arg, (tuple, list)): - arg = torch.Tensor(arg) - in_shape = arg.shape - in_dtype = arg.dtype - elif isinstance(arg, torch.Size): - arg = torch.Tensor(list(arg)) - in_shape = arg.shape - in_dtype = arg.dtype + + # Otherwise, depends on the type of argument + elif isinstance(arg, torch.Size) or isinstance(arg, (tuple, list)): + in_shape = (len(arg),) + in_dtype = type(arg[0]) elif isinstance(arg, (float, int)): - arg = torch.Tensor([arg]) - in_shape = arg.shape - in_dtype = arg.dtype + in_shape = (1,) + in_dtype = type(arg) else: logger.warning(f"Unknown type for arg: {arg}") in_shape = tuple() diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index 252d5df41..76f30db47 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -2,7 +2,12 @@ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor._op_schema import ( + OpStrategy, + PlacementStrategy, + PlacementList, +) +from torch.distributed._tensor.placement_types import Replicate, Shard, Placement from .basic_strategy import gen_einsum_strategies from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, @@ -19,7 +24,10 @@ from chop.ir.graph import MaseMetadata -def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def transpose_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: parent_node = meta.node.args[0] self_strategy = parent_node.meta["mase"]["software"]["autosharding"]["op_strategy"] @@ -51,7 +59,11 @@ def transpose_strategy(meta: MaseMetadata, mesh: tuple) -> 
OpStrategy: return OpStrategy(strategies=transpose_strategies) -def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def _mm_like_strategy( + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: self_shape, mat2_shape = [arg["shape"] for arg in meta["common"]["args"].values()] # generate all possible strategies for mm mm_strategy = gen_einsum_strategies(mm_equation, mesh) @@ -90,7 +102,9 @@ def _mm_like_strategy(mm_equation: str, meta: MaseMetadata, mesh: tuple) -> OpSt def _addmm_like_strategy( - mm_equation: str, meta: MaseMetadata, mesh: tuple + mm_equation: str, + meta: MaseMetadata, + mesh: tuple, ) -> OpStrategy: self_shape, mat1_shape, mat2_shape = [ @@ -156,17 +170,113 @@ def _addmm_like_strategy( return mm_strategy -def mm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def mm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _mm_like_strategy("mk,kn->mn", meta, mesh) -def addmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def addmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _addmm_like_strategy("mk,kn->mn", meta, mesh) -def bmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def bmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _mm_like_strategy("bmk,bkn->bmn", meta, mesh) -def baddmm_strategy(meta: MaseMetadata, mesh: tuple) -> OpStrategy: +def baddmm_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: return _addmm_like_strategy("bmk,bkn->bmn", meta, mesh) + + +def scaled_dot_product_flash_attention_strategy( + meta: MaseMetadata, + mesh: tuple, +) -> OpStrategy: + breakpoint() + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. 
+ return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + q_input_strategy = op_schema.args_schema[0] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + qkv_shape = q_input_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + Replicate(), + Replicate(), + Replicate(), + Replicate(), + ] + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + if return_debug_mask: + debug_attn_mask_sharding: Placement = Shard(1) # num head dim + else: + # empty debug mask, replicated + debug_attn_mask_sharding = Replicate() + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Context Parallelism: shards on the sequence dim + single_mesh_dim_strategies.append( + [ + Shard(2), # output + Shard(2), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + Shard(2), # debugattn + Shard(2), # q + Shard(2), # k + Shard(2), # v + ] + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) 
From 2e3440d74c00e5ad7179bc6f7d3a78485bf2e8d5 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 10:04:28 +0000 Subject: [PATCH 47/93] get solution extrapolation working for GPT2 --- src/chop/distributed/launcher.py | 20 +++++++++++++++---- .../analysis/autosharding/autosharding.py | 7 +++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index a1ad11ce5..e75c59e7c 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -19,7 +19,7 @@ from ..tools import get_logger logger = get_logger(__name__) -logger.setLevel("INFO") +logger.setLevel("DEBUG") def distributed_timing(fn, *args, **kwargs): @@ -34,15 +34,27 @@ def distributed_timing(fn, *args, **kwargs): def distributed_average_timing(fn, repeat, args): times = [] - for _ in range(repeat): + for itr in range(repeat): + rlog( + logger, + dist.get_rank(), + f"Running teration {itr}", + "debug", + ) dist.barrier(async_op=True) start = time() result = fn(*args) dist.barrier(async_op=True) end = time() times.append(end - start) + rlog( + logger, + dist.get_rank(), + f"Time taken: {end - start}s", + "debug", + ) - return result, sum(times) / len(times) + return result, sum(times[2:]) / len(times[2:]) def dist_model_fn( @@ -131,7 +143,7 @@ def device_fn( ] _, time_taken = distributed_average_timing( fn=model, - repeat=5, + repeat=10, args=inputs, ) rlog(logger, rank, f"Forward pass finished. 
Time taken: {time_taken}", level="info") diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 04b3218fe..78a2175ce 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -46,6 +46,13 @@ def _import_solution( for node in mg.fx_graph.nodes: logger.debug(f"Importing solution for node: {node.name}") + # Only import solution for getattr nodes + # TO DO: this is hard-coded for GPT2 + # Figure out how to generalize + if not node.name.startswith("transformer_"): + continue + + # Extrapolate from first layer by string matching if node.name not in solution.keys() and extrapolate_sharding: # Expect the layer number to be the first digit in the node name From 166fbcfcd76671222f0b2a07c6d9ded06ad45959 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 11:14:34 +0000 Subject: [PATCH 48/93] migrate DTensor API to chop/distributed/tensor --- src/chop/distributed/launcher.py | 4 +- src/chop/distributed/tensor/__init__.py | 358 ++++++ src/chop/distributed/tensor/_dispatch.py | 434 +++++++ src/chop/distributed/tensor/_redistribute.py | 340 ++++++ src/chop/distributed/tensor/_sharding_prop.py | 489 ++++++++ src/chop/distributed/tensor/_utils.py | 226 ++++ src/chop/distributed/tensor/api.py | 847 +++++++++++++ src/chop/distributed/tensor/ops/__init__.py | 10 + .../distributed/tensor/ops/basic_strategy.py | 181 +++ .../distributed/tensor/ops/common_rules.py | 288 +++++ src/chop/distributed/tensor/ops/conv_ops.py | 109 ++ .../distributed/tensor/ops/embedding_ops.py | 251 ++++ .../tensor/ops/experimental_ops.py | 26 + src/chop/distributed/tensor/ops/math_ops.py | 1056 +++++++++++++++++ src/chop/distributed/tensor/ops/matrix_ops.py | 459 +++++++ .../distributed/tensor/ops/pointwise_ops.py | 663 +++++++++++ src/chop/distributed/tensor/ops/random_ops.py | 38 + 
src/chop/distributed/tensor/ops/tensor_ops.py | 791 ++++++++++++ src/chop/distributed/tensor/ops/utils.py | 300 +++++ src/chop/distributed/tensor/ops/view_ops.py | 669 +++++++++++ .../graph/analysis/autosharding/alpa.py | 1 - 21 files changed, 7537 insertions(+), 3 deletions(-) create mode 100644 src/chop/distributed/tensor/__init__.py create mode 100644 src/chop/distributed/tensor/_dispatch.py create mode 100644 src/chop/distributed/tensor/_redistribute.py create mode 100644 src/chop/distributed/tensor/_sharding_prop.py create mode 100644 src/chop/distributed/tensor/_utils.py create mode 100644 src/chop/distributed/tensor/api.py create mode 100644 src/chop/distributed/tensor/ops/__init__.py create mode 100644 src/chop/distributed/tensor/ops/basic_strategy.py create mode 100644 src/chop/distributed/tensor/ops/common_rules.py create mode 100644 src/chop/distributed/tensor/ops/conv_ops.py create mode 100644 src/chop/distributed/tensor/ops/embedding_ops.py create mode 100644 src/chop/distributed/tensor/ops/experimental_ops.py create mode 100644 src/chop/distributed/tensor/ops/math_ops.py create mode 100644 src/chop/distributed/tensor/ops/matrix_ops.py create mode 100644 src/chop/distributed/tensor/ops/pointwise_ops.py create mode 100644 src/chop/distributed/tensor/ops/random_ops.py create mode 100644 src/chop/distributed/tensor/ops/tensor_ops.py create mode 100644 src/chop/distributed/tensor/ops/utils.py create mode 100644 src/chop/distributed/tensor/ops/view_ops.py diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index e75c59e7c..8157636d6 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -9,12 +9,12 @@ from torch.distributed._tensor import ( DeviceMesh, - distribute_module, - distribute_tensor, Replicate, Shard, ) +from chop.distributed.tensor import distribute_module, distribute_tensor + from chop.distributed.utils import rlog from ..tools import get_logger diff --git 
a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py new file mode 100644 index 000000000..3d6067d28 --- /dev/null +++ b/src/chop/distributed/tensor/__init__.py @@ -0,0 +1,358 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Optional, Sequence + +# Import all builtin dist tensor ops +import torch +import torch.distributed._tensor.random as random + +from torch.distributed._tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh + +import chop.distributed.tensor.ops +from chop.distributed.tensor._utils import compute_local_shape +from chop.distributed.tensor.api import distribute_module, distribute_tensor, DTensor +from chop.distributed.tensor.ops.utils import normalize_to_torch_size + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "DeviceMesh", + "distribute_tensor", + "distribute_module", + "init_device_mesh,", + "Shard", + "Replicate", + "Partial", +] + + +def _dtensor_init_helper( + init_op, + size: torch.Size, + device_mesh=None, + placements=None, + **kwargs, +) -> DTensor: + from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh againts placements + assert device_mesh.ndim == len( + placements + ), "mesh dimension does not match the length of placements" + + assert kwargs["layout"] == torch.strided, "layout value not supported!" 
+ torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape = compute_local_shape(size, device_mesh, placements) + # initialize the local tensor + if init_op == torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op == torch.rand or init_op == torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker() + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) + + +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 
+ layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. 
+ device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match + ``device_mesh.device_type``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 
+ placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 
+ placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 
+ placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. 
+ device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py new file mode 100644 index 000000000..03e9139c0 --- /dev/null +++ b/src/chop/distributed/tensor/_dispatch.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import functools +import logging +import operator +import warnings +from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed._tensor.random as random +from torch.distributed._tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpInfo, + OpSchema, + OutputSpecType, +) +from torch.distributed._tensor._redistribute import redistribute_local_tensor +from torch.distributed._tensor._tp_conv import ( + convolution_backward_handler, + convolution_handler, +) +from torch.distributed._tensor._utils import try_find_mesh_from_args +from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta +from torch.distributed._tensor.random import is_rng_supported_mesh + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + +import chop.distributed.tensor.api as dtensor +from chop.distributed.tensor._sharding_prop import ShardingPropagator + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +def decompose_handler( + op_call: torch._ops.OpOverload, 
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Decomposes an op to core ATen ops; this handler is mostly here
    for inference mode usage where the ops are not core aten ops.
    """
    r = op_call.decompose(*args, **kwargs)
    if r is not NotImplemented:
        return r
    else:
        raise RuntimeError("Decomposition failed")


def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    # Special-cased handler: aten.is_same_size only needs a shape comparison
    # of the first two args, so no sharding propagation is required.
    lhs = cast(torch.Tensor, args[0])
    rhs = cast(torch.Tensor, args[1])
    return lhs.shape == rhs.shape


def rlog(msg):
    # Debug helper added by this fork: print only on rank 0 so output from
    # all processes is not interleaved.
    rank = torch.distributed.get_rank()
    if rank == 0:
        print(msg)


class OpDispatcher:
    """
    Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
    propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
    also handles any op specific logic if necessary.
    """

    def __init__(self) -> None:
        self.sharding_propagator = ShardingPropagator()
        # Ops that draw from the RNG; dispatch() runs these inside a
        # distribute-region context so the random state is coordinated
        # across ranks.
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        # Ops that bypass the generic sharding-propagation path entirely
        # and are handled by the functions above.
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
        # as implicitly replicated or we throw error to user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
        # it as False by default.
        # NOTE(review): the assignment just below sets this flag to True,
        # contradicting the note above and deviating from upstream DTensor —
        # confirm the deviation is intentional for this fork.
        # NOTE(review): upstream DTensor defaults this to False; with True,
        # every plain torch.Tensor argument is silently wrapped as a
        # replicated DTensor (see try_get_replicate_spec) — confirm intentional.
        self._allow_implicit_replication = True

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic
        """
        # operators that do not need to go through sharding propagation

        # debug: prints every dispatched op on rank 0 (rlog); noisy on hot paths
        rlog(f"Dispatching op call: {op_call}")

        if op_call in self._custom_op_handlers:
            # custom handlers (decompose / is_same_size / convolution) bypass
            # the generic sharding-propagation path below
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensor and sharding infos to a OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
        logger.debug("Dispatching op_call: %s", op_info.schema)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        logger.debug("output_sharding for %s: %s", op_call, output_sharding)
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is None:
            # For a non-participating device, we do:
            # 1. if the return type is scalar, set the local result to None.
            #    The local results from all devices will then be all-gathered
            #    and a reduce op will be performed on the list of results
            #    with appropriate operators:
            #        for bool type, we by default use AND to reduce;
            #        we can extend for more ops if necessary.
            # 2. if the return type is Tensor or List[Tensor], return empty
            #    tensor(s) with correct dtype.
+ spec = output_sharding.output_spec + ret_list = op_info.schema.op._schema.returns + + if spec is None: + # For a scalar return type, the non-participating device has None + # as its local result + local_results: object = None + else: + + def default_tensor(spec: DTensorSpec) -> torch.Tensor: + if spec.tensor_meta is not None: + shape = spec.tensor_meta.shape + dtype = spec.tensor_meta.dtype + if len(shape) == 0: + # scalar tensor + return torch.zeros((), dtype=dtype) + else: + # non-scalar tensor + return torch.tensor([], dtype=dtype) + else: + raise RuntimeError(f"{spec} has no tensor metadata.") + + if isinstance(spec, DTensorSpec): + # return a Tensor value + local_results = default_tensor(spec) + elif isinstance(spec, Sequence): + # return a List[Tensor] value + local_results = [ + default_tensor(s) if s is not None else None for s in spec + ] + assert isinstance(local_results, List) + if None in local_results: + ret_type = str(ret_list[0].type) + raise NotImplementedError( + f"return type {ret_type} in DTensor op is not supported" + ) + else: + if output_sharding.needs_redistribute: + # compute locally with redistribute first if needed + assert output_sharding.redistribute_schema is not None + rlog(f"Op: {op_call} needs redistribute") + self.redistribute_local_args( + op_info, output_sharding.redistribute_schema + ) + + local_tensor_args = ( + pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + if op_info.args_tree_spec + else op_info.local_args + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + if op_call in self._random_ops: + if not random._rng_tracker and is_rng_supported_mesh(mesh): + # Default to `OffsetBasedRNGTracker` if the parallelism API + # did not already construct one + random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) + + first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( + 
torch.Tensor, local_tensor_args[0] + ) + rng_context = ( + random._rng_tracker._distribute_region(first_arg._spec) + if random._rng_tracker and not first_local_arg.is_meta + else contextlib.nullcontext() + ) + + # For DTensor random operator, run it within a distribute region + with rng_context: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + else: + # rlog(f"Calling {op_call} with args: {local_tensor_args} and kwargs: {op_info.local_kwargs}") + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + # communicate the result to all ranks for some operators that return scalar value + if output_sharding.output_spec is None: + if op_call == aten.equal.default: + obj_list = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] + obj_list = list(filter(lambda x: x is not None, obj_list)) + # perform reduce on the collection with AND op + local_results = functools.reduce(operator.and_, obj_list, True) + + if _is_inplace_op(op_call): + # inplace op should return self instead of re-wrapping + if output_sharding.output_spec is not None: + return args[0] + else: + return None + elif _is_out_variant_op(op_call): + # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + spec_idx = 0 + for argument in op_call._schema.arguments: + if argument.is_out: + out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + out_dts.append(out_dt) + spec_idx += 1 + + assert len(out_dts) >= 1, "out variant should have at least one out arg" + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + + @staticmethod + def redistribute_local_args( + op_info: OpInfo, + suggested_input_schema: OpSchema, + ) -> None: + # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it + + # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten + # Need to fix all the ops before doing this. 
+ if op_info.args_tree_spec is not None: + flatten_args_schema_to_reshard = tuple( + pytree.tree_leaves(suggested_input_schema.args_schema) + ) + else: + flatten_args_schema_to_reshard = suggested_input_schema.args_schema + + new_local_args: List[object] = [] + for i, arg_spec in enumerate(op_info.flat_args_schema): + reshard_arg_spec = flatten_args_schema_to_reshard[i] + if isinstance(arg_spec, DTensorSpec): + local_tensor = cast(torch.Tensor, op_info.local_args[i]) + if arg_spec != reshard_arg_spec: + resharded_local_tensor = redistribute_local_tensor( + local_tensor, arg_spec, reshard_arg_spec + ) + new_local_args.append(resharded_local_tensor) + else: + new_local_args.append(local_tensor) + else: + new_local_args.append(reshard_arg_spec) + + op_info.local_args = tuple(new_local_args) + + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> OpInfo: + # get runtime schema to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: List[object] = [] + kwargs_schema: Dict[str, object] = {} + local_args: List[object] = [] + local_kwargs: Dict[str, object] = {} + mesh: Optional[DeviceMesh] = None + + def try_get_replicate_spec( + tensor_arg: torch.Tensor, mesh: "DeviceMesh" + ) -> DTensorSpec: + # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg. + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. 
" + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed enviroment." + ) + + # if the arg.numel() == 1, arg.ndim could be 0 or 1. + if ( + tensor_arg.ndim <= 1 + and tensor_arg.numel() == 1 + or self._allow_implicit_replication + ): + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + mesh, + (Replicate(),) * mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + else: + raise RuntimeError( + f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" + " torch.Tensor to DTensor before calling distributed operators!" + ) + return replication_spec + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + args_schema.append(arg._spec) + local_args.append(arg._local_tensor) + if mesh is not None: + if mesh != arg.device_mesh: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet!" + f"Got meshes: {mesh} {arg.device_mesh}" + ) + else: + mesh = arg.device_mesh + elif isinstance(arg, torch.Tensor): + mesh = mesh or try_find_mesh_from_args(op_call, args_list) + args_schema.append(try_get_replicate_spec(arg, mesh)) + local_args.append(arg) + else: + args_schema.append(arg) + local_args.append(arg) + + for k, v in kwargs.items(): + if isinstance(v, dtensor.DTensor): + kwargs_schema[k] = v._spec + local_kwargs[k] = v._local_tensor + if mesh is not None: + if mesh != v.device_mesh: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet!" + ) + else: + mesh = v.device_mesh + elif isinstance(v, torch.Tensor): + mesh = mesh or try_find_mesh_from_args(op_call, args_list) + kwargs_schema[k] = try_get_replicate_spec(v, mesh) + local_kwargs[k] = v + else: + kwargs_schema[k] = v + local_kwargs[k] = v + + assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" 
+ op_info = OpInfo( + mesh, + OpSchema( + op_call, + ( + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema) + ), + kwargs_schema, + schema_info=runtime_schema_info, + ), + args_schema, + tuple(local_args), + local_kwargs, + args_spec, + ) + return op_info + + @staticmethod + def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + if spec is not None: + assert isinstance( + spec, DTensorSpec + ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + else: + # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor + assert res.ndim == 0, "output tensor should be scalar!" + return res + elif isinstance(res, (list, tuple)): + assert spec is not None and isinstance( + spec, (list, tuple) + ), f"output spec does not match with output! Expected list/tuple, got {spec}." + res_list = [] + for e, s in zip(res, spec): + res_list.append(OpDispatcher.wrap(e, s)) + + return tuple(res_list) if isinstance(res, tuple) else res_list + else: + # if the res contains only non tensor values (i.e. int/float/none), we simply return it + # without rewrapping to DTensor. + return res diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py new file mode 100644 index 000000000..fc66b219b --- /dev/null +++ b/src/chop/distributed/tensor/_redistribute.py @@ -0,0 +1,340 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates
from functools import lru_cache
from typing import cast, Dict, List, NamedTuple, Tuple

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Placement,
    Replicate,
    Shard,
    TensorMeta,
)

import chop.distributed.tensor.api as dtensor


class _TransformInfo(NamedTuple):
    # One redistribution step on a single mesh dimension.
    mesh_dim: int
    src_dst_placements: Tuple[Placement, Placement]
    # logical_shape on this mesh dimension
    logical_shape: List[int]


def _replicate_then_shard(val: _TransformInfo) -> int:
    """
    This is a helper function to allow reordering _TransformInfo list. The high level
    idea is that we want to reorder the sharding redistributions so that the DTensor
    redistribution is consistent with its full tensor. This is built on top of two simple
    assumptions:
    1. Replication happens from inner to outer dimension, i.e. Shard -> Replicate
    2. Sharding happens from outer to inner dimension, i.e. Replicate -> Shard

    So we always put the replication first and put sharding later.
    """
    mesh_dim = val.mesh_dim
    src, dst = val.src_dst_placements
    # Negative keys sort replications (Shard -> Replicate/Partial) before the
    # zero-keyed no-ops, and positive keys push shardings after them.
    if (dst.is_replicate() or dst.is_partial()) and src.is_shard():
        return -mesh_dim
    elif (src.is_replicate() or src.is_partial()) and dst.is_shard():
        return mesh_dim
    else:
        return 0


# NOTE(review): unbounded cache keyed on (src_spec, dst_spec); cached specs
# (and the lists inside the returned _TransformInfo entries) are retained for
# the lifetime of the process.
@lru_cache(maxsize=None)
def _gen_transform_infos(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
    """
    Generate the transform infos from the source placements to the target placements.

    To transform from source to target placement it might have multiple steps, i.e. it
    might decompose Si -> Sj into Si -> R -> Sj.
    This would detect if there're mis-aligned shardings between src/dst placements,
    i.e. (Shard(0), Shard(0)) -> (Replicate(), Shard(0)); in this case Shard(0) -> Shard(0)
    for mesh dimension 1 actually needs reshard, because in the first case it's a sub-sharding
    of an already sharded tensor dimension 0, and in the second case, it's the first sharding
    on tensor dimension 0.
    """
    src_dim_counts: Dict[int, int] = {}
    dst_dim_counts: Dict[int, int] = {}
    transform_infos: List[_TransformInfo] = []

    src_placements = src_spec.placements
    dst_placements = dst_spec.placements
    device_mesh = src_spec.device_mesh
    my_coordinate = device_mesh.get_coordinate()
    assert my_coordinate is not None

    # logical shape records the logic tensor shape on the mesh dimension
    # this is useful to ensure uneven sharding gets correct output shape
    initial_logical_shape = list(src_spec.shape)
    mesh_dims_to_logical_shape = [initial_logical_shape]
    mesh_ndim = len(src_placements)

    for i, (src, dst) in enumerate(zip(src_placements, dst_placements)):
        # detect mis-aligned sharding and build logical shapes
        current_logical_shape = mesh_dims_to_logical_shape[i]
        if isinstance(src, Shard):
            src_dim_counts[src.dim] = src_dim_counts.get(src.dim, 0) + 1

            if i < mesh_ndim - 1:
                # calculate and save the logical shape for this sharding
                mesh_dim_size = device_mesh.size(mesh_dim=i)
                local_shard_size, _ = src._local_shard_size_on_dim(
                    current_logical_shape[src.dim],
                    mesh_dim_size,
                    my_coordinate[i],
                )
                new_logical_shape = list(current_logical_shape)
                new_logical_shape[src.dim] = local_shard_size
                mesh_dims_to_logical_shape.append(new_logical_shape)
            else:
                mesh_dims_to_logical_shape.append(current_logical_shape)

        if isinstance(dst, Shard):
            dst_dim_counts[dst.dim] = dst_dim_counts.get(dst.dim, 0) + 1

        if (
            isinstance(src, Shard)
            and isinstance(dst, Shard)
            and (mesh_ndim > 1 or src_dim_counts[src.dim] != dst_dim_counts[dst.dim])
        ):
            # for the case when mesh ndim > 1 or shard dim counts are different
            # TODO: see if we can optimize the
mesh_ndim > 1 case + # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j) + transform_infos.append( + _TransformInfo( + mesh_dim=i, + src_dst_placements=(src, Replicate()), + logical_shape=mesh_dims_to_logical_shape[i], + ) + ) + transform_infos.append( + _TransformInfo( + mesh_dim=i, + src_dst_placements=(Replicate(), dst), + logical_shape=mesh_dims_to_logical_shape[i], + ) + ) + else: + transform_infos.append( + _TransformInfo( + mesh_dim=i, + src_dst_placements=(src, dst), + logical_shape=mesh_dims_to_logical_shape[i], + ) + ) + + # sort the pairs by first perform replication then sharding + transform_infos.sort(key=_replicate_then_shard) + return transform_infos + + +def redistribute_local_tensor( + local_tensor: torch.Tensor, + current_spec: DTensorSpec, + target_spec: DTensorSpec, + *, + async_op: bool = False, + is_backward: bool = False, +) -> torch.Tensor: + """ + This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to + the target DTensorSpec, which involves the necessary collective calls to transform + the local shard of the DTensor from its current spec to the target spec. 
+ """ + + if current_spec.mesh != target_spec.mesh: + # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + new_local_tensor = None + device_mesh = current_spec.mesh + + my_coordinate = device_mesh.get_coordinate() + + if my_coordinate is None: + # if rank is not part of mesh, we skip redistribute and simply return local_tensor, + # which should be an empty tensor + return local_tensor + + transform_infos = _gen_transform_infos(current_spec, target_spec) + + for transform_info in transform_infos: + i = transform_info.mesh_dim + current, target = transform_info.src_dst_placements + num_chunks = device_mesh.size(mesh_dim=i) + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_value( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + target_dim = target_placement.dim + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_shard_value( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + new_local_tensor = target_placement._replicate_to_shard( + local_tensor, device_mesh, i, my_coordinate[i] + ) + else: + assert ( + current.is_shard() + ), f"Current placement should be shard but found {current}" + shard_spec = 
cast(Shard, current) + if shard_spec.dim != target_placement.dim: + new_local_tensor = shard_spec._to_new_shard_dim( + local_tensor, + device_mesh, + i, + transform_info.logical_shape, + target_placement.dim, + ) + elif target.is_partial(): + if current.is_replicate(): + partial_spec = cast(Partial, target) + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. + new_local_tensor = ( + partial_spec._partition_value(local_tensor, device_mesh, i) + if not is_backward + else local_tensor + ) + elif current.is_shard(): + if not is_backward: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + # for backward shard -> partial, we just need to convert the shard to replicate + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + # partial -> partial no op, should never hit + new_local_tensor = local_tensor + + assert new_local_tensor is not None + local_tensor = new_local_tensor + + assert new_local_tensor is not None, "redistribute failed!" + + if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor): + new_local_tensor = new_local_tensor.wait() + + return new_local_tensor + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. 
+ ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + async_op: bool = False, + ): + current_spec = input._spec + ctx.current_spec = current_spec + ctx.async_op = async_op + + if current_spec.placements != placements: + target_spec = DTensorSpec( + device_mesh, placements, tensor_meta=input._spec.tensor_meta + ) + + local_tensor = input._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. + output = input._local_tensor + target_spec = current_spec + + return dtensor.DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + current_spec = grad_output._spec + async_op = ctx.async_op + + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + # normalize the target placement to replicate if it is partial + normalized_placements: List[Placement] = [] + for previous_placement in previous_spec.placements: + if previous_placement.is_partial(): + # keep target placement to replicate instead of partial in this case + normalized_placements.append(Replicate()) + else: + normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=grad_output.dtype, + ), + ) + output_dtensor = dtensor.DTensor( + output, + spec, + requires_grad=grad_output.requires_grad, + ) + + return ( + output_dtensor, + None, + None, + None, + ) diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py new file mode 100644 index 000000000..6d7d6ea11 --- 
/dev/null +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -0,0 +1,489 @@ +# mypy: allow-untyped-defs +from functools import lru_cache +from itertools import chain +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch._ops import OpOverload +from torch._subclasses import FakeTensorMode +from torch.distributed._tensor._op_schema import ( + OpInfo, + OpSchema, + OpStrategy, + OutputSharding, + OutputSpecType, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed._tensor._utils import ( + compute_local_shape, + compute_local_stride, + try_find_mesh_from_args, +) +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.device_mesh import DeviceMesh + + +aten = torch.ops.aten + + +def _length(obj) -> int: + if obj is None: + return 0 + if not isinstance(obj, Sequence): + return 1 + return len(obj) + + +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + print(msg) + + +class ShardingPropagator: + def __init__(self) -> None: + self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} + self.op_strategy_funcs: Dict[ + OpOverload, + Callable[[DeviceMesh, OpSchema], StrategyType], + ] = {} + # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop + self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} + self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] + # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + self.op_to_shape_and_stride_idx: Dict[ + OpOverload, Union[int, Tuple[int, int]] + ] = { + # new factory ops + aten.new_empty.default: 1, + aten.new_full.default: 1, + aten.new_ones.default: 1, + aten.new_zeros.default: 1, + aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + 
aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, + } + + def register_sharding_prop_rule( + self, + op_overload: OpOverload, + rule_func: Callable[[OpSchema], OutputSharding], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding propagation rule for an operator. + """ + self.op_to_rules[op_overload] = rule_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def register_op_strategy( + self, + op_overload: OpOverload, + strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding strategy generator for an operator. + """ + self.op_strategy_funcs[op_overload] = strategy_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas + """ + if op_schema.op == aten.equal.default: + # data dependent ops can't be used for fake propagation + return None + + # NOTE: We must call the tracing in fake tensor mode so that it + # avoids materializing memory + with FakeTensorMode(): + fake_args = op_schema.gen_fake_args() + fake_kwargs = op_schema.gen_fake_kwargs() + fake_out = op_schema.op(*fake_args, **fake_kwargs) + + if isinstance(fake_out, torch.Tensor): + return TensorMeta( + shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype + ) + + elif isinstance(fake_out, (tuple, list)): + tensor_meta_list: List[Optional[TensorMeta]] = [] + for fake_out_item in fake_out: + if isinstance(fake_out_item, torch.Tensor): + tensor_meta_list.append( + TensorMeta( + shape=fake_out_item.shape, + stride=fake_out_item.stride(), + dtype=fake_out_item.dtype, + ) + ) + else: + tensor_meta_list.append(None) + return ( + 
tuple(tensor_meta_list) + if isinstance(fake_out, tuple) + else tensor_meta_list + ) + else: + # if fake is not a tensor or tuple of tensor, return as none + return None + + def _wrap_output_spec_tensor_meta( + self, + op: OpOverload, + output_specs: OutputSpecType, + output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]], + ) -> None: + """ + Wrap the output_specs with the tensor metadata from the output. + """ + + if isinstance(output_specs, DTensorSpec): + if not isinstance(output_tensor_meta, TensorMeta): + # Either error due to ShardingPropagator or due to incorrect OutputSpec + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + "ShardingPropagator error: output does not have an associated TensorMeta" + ) + raise ValueError( + f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " + f"number of op outputs: {len(output_tensor_meta)}." + ) + output_specs.tensor_meta = output_tensor_meta + elif isinstance(output_specs, (tuple, list)): + if not isinstance(output_tensor_meta, (tuple, list)) or len( + output_specs + ) != len(output_tensor_meta): + raise ValueError( + f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " + f"number of op outputs {_length(output_tensor_meta)}." + ) + for i, spec in enumerate(output_specs): + if isinstance(spec, DTensorSpec): + output_tensor_meta_i = output_tensor_meta[i] + if not isinstance(output_tensor_meta_i, TensorMeta): + raise ValueError( + f"ShardingPropagator error: output {i} does not have an associated TensorMeta" + ) + spec.tensor_meta = output_tensor_meta_i + + def propagate(self, op_info: OpInfo) -> None: + # We cannot use an lru cache if we know that inputs will have dynamic shapes, + # because SymInts are not hashable. + # This is generally ok because this only happens during tracing in torch.compile, + # and tracing does not need to be as fast as eagermode DTensor usages. 
+ if op_info.schema.has_symints: + output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) + else: + output_sharding = self.propagate_op_sharding(op_info.schema) + op_info.output_sharding = output_sharding + + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: + """ + Propagate the sharding for an operator given the op_schema. + """ + # special case op, we don't need to propagate for local + # scalar. TODO: figure out a better way to handle this + if op_schema.op is aten._local_scalar_dense.default: + return OutputSharding(None, op_schema) + + out_tensor_meta = self._propagate_tensor_meta(op_schema) + + def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([PlacementStrategy(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + if op_schema.op in self.op_strategy_funcs: + # generate op strategy for the op. 
+ mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema) + # swap the args spec with args strategies + args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] + + kwargs_op_strategy = { + k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() + } + + # construct a new OpSchema on args for strategy based propagation + strategy_schema: OpSchema = OpSchema( + op=op_schema.op, + args_schema=tuple(args_op_strategy), + kwargs_schema=kwargs_op_strategy, + ) + + op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema) + + if isinstance(op_strategy, OpStrategy): + # single Op strategy + output_strategy = self._select_strategy(op_strategy) + + # check if we need to redistribute the input + needs_redistribute = False + expected_input_specs = [] + + # in case where the op does not specify input_specs and output_specs + # is a DTensorSpec, we use output_specs as the spec for each DTensor + # input arg. + if output_strategy.input_specs is None: + assert isinstance(output_strategy.output_specs, DTensorSpec) + + for idx, input_spec in enumerate(op_schema.args_spec): + if "layer_norm" in str(op_schema.op): + rlog(f" arg {idx}, input_spec: {input_spec}") + desired_spec = ( + output_strategy.output_spec + if output_strategy.input_specs is None + else output_strategy.input_specs[idx] + ) + expected_input_specs.append( + desired_spec.shallow_copy_with_tensor_meta( + input_spec.tensor_meta + ) + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(expected_input_specs), {} + ) + suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) + + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: + assert isinstance(output_strategy.output_spec, DTensorSpec) + # It happens when the output has the 
same shape as the input + # and the input placements are not all Replicate(). + if output_strategy.output_spec.is_sharded(): + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec, mesh + ) + needs_redistribute = True + + # construct output spec for the op + if op_schema.return_type_tuple_tensor_like(): + # for ops that return multiple tensors and the output_specs is not + # a tuple, we use a tuple of that single output spec as the new + # output_specs + output_specs: OutputSpecType = output_strategy.output_specs + if isinstance(output_specs, DTensorSpec): + output_specs = tuple( + [ + # create a new DTensorSpec with the same placement as the + # output_specs in output_strategy + DTensorSpec( + mesh=output_specs.mesh, + placements=output_specs.placements, + tensor_meta=output_specs.tensor_meta, + ) + for _ in range(len(op_schema.op._schema.returns)) + ] + ) + elif op_schema.return_type_tensor(): + output_specs = output_strategy.output_specs + else: + output_specs = None + + output_sharding = OutputSharding( + output_specs, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + elif isinstance(op_strategy, TupleStrategy): + # tuple strategy output sharding processing + # runtime selected placement strategy for each TupleStrategy input arg + selected_strategies: List[PlacementStrategy] = [] + out_spec_list: List[DTensorSpec] = [] + for strategy in op_strategy.childs: + assert isinstance(strategy, OpStrategy) + selected_strategy = self._select_strategy(strategy) + selected_strategies.append(selected_strategy) + out_spec_list.append(selected_strategy.output_spec) + + needs_redistribute = False + suggestion_args: List[object] = [] + tensor_or_list_tensor_arg_idx = 0 + + for arg in op_schema.args_schema: + if ( + arg + and isinstance(arg, (list, tuple)) + and isinstance(arg[0], DTensorSpec) + ): + 
expected_input_spec_list: List[DTensorSpec] = [] + for idx, arg_spec in enumerate(arg): + expected_input_spec = selected_strategies[idx].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) + ) + if arg_spec.placements != expected_input_spec.placements: + needs_redistribute = True + expected_input_spec_list.append(expected_input_spec) + suggestion_args.append( + tuple(expected_input_spec_list) + if isinstance(arg, tuple) + else expected_input_spec_list + ) + tensor_or_list_tensor_arg_idx += 1 + + elif isinstance(arg, DTensorSpec): + expected_input_spec = selected_strategies[0].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) + ) + if arg.placements != expected_input_spec.placements: + needs_redistribute = True + suggestion_args.append(expected_input_spec) + tensor_or_list_tensor_arg_idx += 1 + else: + suggestion_args.append(arg) + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema + ) + + output_sharding = OutputSharding( + tuple(out_spec_list) if out_tensor_meta is not None else None, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + else: + raise ValueError("Unsupported op strategy type") + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + return output_sharding + elif op_schema.op in self.op_to_rules: + # propagate the sharding with rule + sharding_prop_func = self.op_to_rules[op_schema.op] + + # step 1. 
there's sharding propagation rule, run + # sharding propagation to get the output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except NotImplementedError as e: + raise e + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}" + ) from e + + # step 2. if can't get output_spec from sharding + # propagation (i.e. no rules apply for input + # placements), we return the output sharding + # with schema suggestions, which can be used to + # decide how to do redistribute on inputs + if output_sharding.output_spec is None: + if output_sharding.redistribute_schema is None: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}!" + ) + else: + # we do auto redistribute on inputs if necessary + # run sharding propagation again with suggested schema + propagation_res = sharding_prop_func( + output_sharding.redistribute_schema + ) + # we set the output sharding with the new propagation result + # so that dispatching know both output_spec and redistribute_schema + # exist, which indicates a reshard is needed + output_sharding.output_spec = propagation_res.output_spec + output_sharding.needs_redistribute = True + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + + return output_sharding + else: + raise NotImplementedError( + f"Operator {op_schema.op} does not have a sharding strategy registered." + ) + + def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: + if len(strategy.strategies) == 1: + # short cut with only one possible strategy + return strategy.strategies[0] + + strategy_costs: List[float] = [] + for strtg in strategy.strategies: + assert ( + strtg.redistribute_cost is not None + ), "must set redistribute cost each strategy!" 
+ redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) + strategy_costs.append(redistribute_cost) + + # for eager execution, we just select the one with the minimal redistribute cost + return strategy.strategies[strategy_costs.index(min(strategy_costs))] + + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + mesh: DeviceMesh, + ) -> OpSchema: + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx + else: + shape_idx = shape_stride_idx + stride_idx = None + + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx] = compute_local_shape( + out_tensor_meta.shape, mesh, spec.placements + ) + + # adjust the stride arg for aten.new_empty_strided.default + if stride_idx: + expected_input_schema[stride_idx] = compute_local_stride( + out_tensor_meta.stride, mesh, spec.placements + ) + + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/src/chop/distributed/tensor/_utils.py b/src/chop/distributed/tensor/_utils.py new file mode 100644 index 000000000..a3cc8ee5a --- /dev/null +++ b/src/chop/distributed/tensor/_utils.py @@ -0,0 +1,226 @@ +from typing import cast, List, Sequence, Tuple + +import torch +import torch.distributed._tensor.api as dtensor +from torch._prims_common import ShapeType +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + + +# TODO: audit existing code base to see if we can safely remove this API. 
+def compute_local_shape( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ + Compute the shape of a local shard of the given DTensor on its current + coordinate of the mesh. + """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty shape + return (0,) + else: + local_shape = list(global_shape) # start with global shape + ndim = len(global_shape) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" + local_shard_size, _ = placement._local_shard_size_on_dim( + local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] + ) + assert isinstance(local_shard_size, int) + local_shape[shard_dim] = local_shard_size + + return tuple(local_shape) + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + Example (2 host with 4GPUs each): + # Below is a DeviceMesh with mesh_shape of (2, 4) + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ], + ) + + Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh + with a placements of [Shard(0), Shard(0)]. 
+ The local shape and global offset will be as follows: + rank0 -- local_shape:[1, 4], global_offset:[0, 0] + rank1 -- local_shape:[1, 4], global_offset:[1, 0] + rank2 -- local_shape:[1, 4], global_offset:[2, 0] + rank5 -- local_shape:[1, 4], global_offset:[5, 0] + rank3 -- local_shape:[1, 4], global_offset:[3, 0] + rank4 -- local_shape:[1, 4], global_offset:[4, 0] + rank6 -- local_shape:[1, 4], global_offset:[6, 0] + rank7 -- local_shape:[1, 4], global_offset:[7, 0] + + Let's say we distribute a global_tensor of shape (2) over the above DeviceMesh with + a placements of [Shard(0)]. We will not have non-empty local tensor for all the ranks. + The local shape and global offset will be as follows: + rank0 -- local_shape:[1,], global_offset:[0,] + rank1 -- local_shape:[1,], global_offset:[1,] + rank2 -- local_shape:[0,], global_offset:[2,] + rank5 -- local_shape:[0,], global_offset:[2,] + rank3 -- local_shape:[0,], global_offset:[2,] + rank4 -- local_shape:[0,], global_offset:[2,] + rank6 -- local_shape:[0,], global_offset:[2,] + rank7 -- local_shape:[0,], global_offset:[2,] + """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been 
already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) + + +def compute_global_tensor_info( + tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[List[int], List[int]]: + """ + Compute the global size and stride of a DTensor from the given local tensor. + The local size is multiplited by `world_size` per Sharding dim. + The local stride is multiplited by `world_size` per Sharding dim, as long as the + dimension is outside sharding dim. + + For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8). + If the DTensor placements are [Shard(2)] and world_size is 2; + then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8). + + Args: + tensor (:class:`torch.Tensor`): + Local tensor which DTensor will be constructed from. + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Return: + tensor_shape: A List of int which specifies the size of DTensor which build + on top of the local tensor. + tensor_stride: A List of int which specifies the stride of DTensor. 
+ """ + tensor_shape = list(tensor.size()) + tensor_stride = list(tensor.stride()) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if placement.is_shard(): + shard_placement = cast(Shard, placement) + if shard_placement.dim < 0: + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) + shard_dim = shard_placement.dim + + assert ( + shard_dim < tensor.ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." + + local_dim_size = tensor_shape[shard_dim] + tensor_shape[shard_dim] = local_dim_size * mesh_dim_size + + # recover tensor stride by modifying the stride that larger than + # the current stride on the shard_dim + for i in range(len(tensor_stride)): + if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]: + # rescale the stride by the shard size + tensor_stride[i] = tensor_stride[i] * mesh_dim_size + elif not isinstance(placement, (Replicate, Partial)): + raise RuntimeError(f"placement type {type(placement)} not supported!") + return tensor_shape, tensor_stride + + +def try_find_mesh_from_args( + op_call: torch._ops.OpOverload, args: Sequence[object] +) -> DeviceMesh: + """ + Find the device mesh object from args. + It returns None if no mesh is found. + NOTE: we can optimize this search if needed + """ + for arg in args: + if isinstance(arg, (dtensor.DTensor, DTensorSpec)): + return arg.device_mesh + elif ( + isinstance(arg, (list, tuple)) + and len(arg) > 0 + and isinstance(arg[0], (dtensor.DTensor, DTensorSpec)) + ): + return arg[0].device_mesh + + raise ValueError(f"Cannot find device mesh from args for op : {op_call}.") + + +def compute_local_stride( + global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ + Compute the stride of a local tensor shard, given the global stride of the DTensor. 
+ NOTE: Currently this function is assuming the DTensor is evenly shardable. + """ + stride_divisors = [1] * len(global_stride) + for mesh_idx, p in enumerate(placements): + if p.is_shard(): + i = cast(Shard, p).dim + # tensor dimension i is sharded on mesh dimension mesh_idx, + # so we need to divide all the strides larger than stride[i] + # (by the submesh size) + for j in range(len(global_stride)): + if global_stride[j] > global_stride[i]: + stride_divisors[j] *= mesh.size(mesh_idx) + return tuple( + global_stride[i] // stride_divisors[i] for i in range(len(global_stride)) + ) diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py new file mode 100644 index 000000000..47a778bea --- /dev/null +++ b/src/chop/distributed/tensor/api.py @@ -0,0 +1,847 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import inspect +import warnings +from typing import Any, Callable, cast, Optional, Sequence, Tuple + +import torch +import torch.distributed._tensor.random as random +import torch.nn as nn +from torch.distributed._tensor._collective_utils import ( + check_tensor_meta, + mesh_broadcast, +) +from torch.distributed._tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed._tensor._utils import compute_global_tensor_info +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, + TensorMeta, +) +from torch.distributed._tensor.random import ( + is_rng_supported_mesh, + OffsetBasedRNGTracker, +) +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh + +import chop.distributed.tensor._dispatch as op_dispatch + +__all__ = ["DTensor", "distribute_tensor", "distribute_module"] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. 
from_local, to_local) to ensure our DTensor +# works together with torch.Tensor within autograd engine. This +# allows DistributedTensor to exist on part of the module hierarchy +# and still able to calculate gradients across the torch.Tensor and +# DistributedTensor boundary. +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DistributedTensor params, we would need to make the following +# flow to work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input -> Sharded Module B -> DTensor output +# -> output (torch.Tensor) -> Module C -> output (torch.Tensor) +# +# We need the conversion from Module A to DTensor input, which is +# `from_local`, and conversion from DTensor output to output, which +# is `to_local`, thus these two functions must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. 
+ return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, 
requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + dist_tensor = DTensor( + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. 
+ return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # class attribute that handles operator placements propagation + # rules, keyed by aten op name, value is propagation func + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + Note: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using `DTensor.from_local`, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using `distribute_tensor`. + """ + if local_tensor.requires_grad and not requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" 
+ r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + + r._spec = spec + r._local_tensor = local_tensor + return r + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert ( + flatten_spec is not None + ), "Expecting spec to be not None from `__tensor_flatten__` return value!" 
+ local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec): + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + run_check: bool = False, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. 
+ device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks. If not, the behavior of the created + DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. 
+ """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. 
+ `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or `AsyncCollectiveTensor` object. it represents the + local tensor on its current rank. + + .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + async_op: bool = False, + ) -> "DTensor": + """ + `redistribute` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from is current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as `device_mesh.ndim`. 
+ + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + + Returns: + A :class:`DTensor` object + + .. note:: `redistribute` is differentiable. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial(): + raise RuntimeError( + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply(self, device_mesh, placements, async_op) + + def full_tensor( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntatic sugar of the following code: + + `dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. 
This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: `full_tensor` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: device_mesh is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> Sequence[Placement]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: placements is a read-only property, it can not be set. 
+ """ + return self._spec.placements + + def __create_write_items__(self, fqn: str, object: Any): + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Distribute a leaf torch.Tensor (i.e. nn.Parameter) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be + the same. If you want to construct a DTensor in the middle of the Autograd computation, + please use ``DTensor.from_local`` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. 
def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a leaf ``torch.Tensor`` (i.e. an ``nn.Parameter``) to ``device_mesh``
    according to the ``placements`` specified. ``device_mesh.ndim`` and
    ``len(placements)`` must match. To construct a DTensor in the middle of an
    autograd computation, use ``DTensor.from_local`` instead.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use ``torch.chunk``
            semantics to shard the tensor and scatter the shards.
        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
            tensor on; if not specified, must be called under a DeviceMesh
            context manager. Default: None
        placements (List[:class:`Placement`], optional): the placements describing
            how to place the tensor on the DeviceMesh; must have the same number
            of elements as ``device_mesh.ndim``. If not specified, the tensor is
            replicated across the ``device_mesh`` from the first rank of each
            dimension of the ``device_mesh``.

    Returns:
        A :class:`DTensor` or ``XLAShardedTensor`` object.

    Note:
        When the DeviceMesh is initialized with the ``xla`` device_type,
        ``distribute_tensor`` returns ``XLAShardedTensor`` instead; see
        https://github.com/pytorch/pytorch/issues/92909. The XLA integration is
        experimental and subject to change.
    """
    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")

    # get default device mesh if there's nothing specified
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # call PyTorch/XLA SPMD for `xla` backend type device mesh.
            # This returns XLAShardedTensor
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_tensor,
            )

            return xla_distribute_tensor(
                tensor, device_mesh, placements
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    # instantiate an RNG tracker if there isn't one yet. By default DTensor uses
    # an OffsetBasedRNGTracker to perform random operators.
    # TODO: assigning to a global variable is not ideal; replace in the future.
    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
        random._rng_tracker = OffsetBasedRNGTracker(device_type)

    if not tensor.is_leaf:
        raise RuntimeError(
            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
        )

    # convert tensor to the corresponding device type if it's not in that device type
    if device_type != tensor.device.type and not tensor.is_meta:
        tensor = tensor.to(device_type)

    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(device_mesh.ndim)]

    if len(placements) != device_mesh.ndim:
        raise ValueError(
            f"`placements` must have the same length as `device_mesh.ndim`! "
            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
        )
    if isinstance(tensor, DTensor):
        # if the tensor is already a DTensor, we need to check:
        # 1. whether we could further shard this DTensor if the two device
        #    meshes belong to the same parent mesh and further sharding is possible.
        # 2. whether the device mesh and placements are the same
        if tensor.device_mesh != device_mesh:
            raise ValueError(
                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
                f"to a different device mesh {device_mesh}."
            )
        if tensor.placements != tuple(placements):
            raise ValueError(
                f"Cannot distribute a DTensor with placements {tensor.placements} "
                f"to a different placements {placements}. do you want to call "
                f"`redistribute` instead?"
            )
        return tensor

    local_tensor = tensor.detach()

    # distribute the tensor according to the placements.
    placements = list(placements)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            placement = cast(Shard, placement)
            if placement.dim < 0:
                # normalize shard placement dim
                placement = Shard(placement.dim + tensor.ndim)
                placements[idx] = placement
            local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
        elif placement.is_replicate():
            placement = cast(Replicate, placement)
            local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
        else:
            raise RuntimeError(
                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
            )
    placements = tuple(placements)

    assert local_tensor is not None, "distributing a tensor should not be None"
    # detach the local tensor passed to DTensor since after the construction
    # of DTensor, autograd would work on top of DTensor instead of local tensor
    spec = DTensorSpec(
        mesh=device_mesh,
        placements=placements,
        tensor_meta=TensorMeta(
            shape=tensor.size(),
            stride=tensor.stride(),
            dtype=tensor.dtype,
        ),
    )
    return DTensor(
        local_tensor.requires_grad_(tensor.requires_grad),
        spec,
        requires_grad=tensor.requires_grad,
    )


def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
    """
    Expose three ways to control the tensors inside a module:
    1. Perform sharding on the module before runtime execution via ``partition_fn``
       (i.e. allow the user to convert module parameters to :class:`DTensor`
       parameters according to the ``partition_fn`` specified).
    2. Control the inputs or outputs of the module during runtime execution via
       ``input_fn`` and ``output_fn`` (e.g. convert the input to :class:`DTensor`,
       convert the output back to ``torch.Tensor``).

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module on.
        partition_fn (Callable): function to partition parameters (i.e. shard
            certain parameters across the ``device_mesh``). If not specified,
            all module parameters are replicated across the mesh by default.
        input_fn (Callable): specifies the input distribution, i.e. controls how
            the input of the module is sharded; installed as a module
            ``forward_pre_hook`` (pre-forward hook).
        output_fn (Callable): specifies the output distribution, i.e. controls
            how the output is sharded or converted back to ``torch.Tensor``;
            installed as a module ``forward_hook`` (post-forward hook).

    Returns:
        A module whose parameters/buffers are all ``DTensor``s.

    Note:
        When the DeviceMesh is initialized with the ``xla`` device_type,
        ``distribute_module`` returns an ``nn.Module`` with PyTorch/XLA SPMD
        annotated parameters; see
        https://github.com/pytorch/pytorch/issues/92909. The XLA integration is
        experimental and subject to change.
    """
    torch._C._log_api_usage_once("torch.dtensor.distribute_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # This annotates all module parameters for auto-partitioning with
            # PyTorch/XLA SPMD, or explicitly partitions to
            # :class:`XLAShardedTensor` parameters per the `partition_fn`.
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_module,
            )

            return xla_distribute_module(
                module, device_mesh, partition_fn, input_fn, output_fn
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
        # Loop over the immediate module parameters and buffers and replicate
        # any non-DTensor params/buffers that the partition_fn did not touch.
        # We can't easily use `module._apply` here because partition_fn may
        # have done anything (e.g. installed hooks) and we want to preserve it.
        full_replicate = [Replicate()] * mesh.ndim
        for key, param in m._parameters.items():
            if param is not None and not isinstance(param, DTensor):
                m.register_parameter(
                    key,
                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
                )
        for key, buffer in m._buffers.items():
            if buffer is not None and not isinstance(buffer, DTensor):
                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)

    if partition_fn is None:
        # if partition_fn is not specified, replicate all module params/buffers
        for name, submod in module.named_modules():
            replicate_module_params_buffers(submod, device_mesh)
    else:
        # apply partition_fn to submodules
        for name, submod in module.named_modules():
            partition_fn(name, submod, device_mesh)
            replicate_module_params_buffers(submod, device_mesh)

    # register input_fn as a module forward pre-hook
    if input_fn is not None:
        # check the input_fn signature
        num_args = len(inspect.signature(input_fn).parameters)
        if num_args == 2:
            # input_fn only takes in inputs and device mesh
            warnings.warn(
                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
        elif num_args == 3:
            # input_fn takes in module, inputs, device mesh
            module.register_forward_pre_hook(
                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
            )
        else:
            raise ValueError(
                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
            )
    # register output_fn as a module forward hook
    if output_fn is not None:
        num_args = len(inspect.signature(output_fn).parameters)
        if num_args == 2:
            # output_fn only takes in outputs and device mesh
            warnings.warn(
                "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
                "please use output_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
            )
        elif num_args == 3:
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
            )
        else:
            raise ValueError(
                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
            )

    return module
and affiliates +from .conv_ops import * # noqa: F403 +from .embedding_ops import * # noqa: F403 +from .experimental_ops import * # noqa: F403 +from .math_ops import * # noqa: F403 +from .matrix_ops import * # noqa: F403 +from .pointwise_ops import * # noqa: F403 +from .random_ops import * # noqa: F403 +from .tensor_ops import * # noqa: F403 +from .view_ops import * # noqa: F403 diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py new file mode 100644 index 000000000..97dd43b15 --- /dev/null +++ b/src/chop/distributed/tensor/ops/basic_strategy.py @@ -0,0 +1,181 @@ +import itertools +from dataclasses import dataclass +from typing import List, Set, Tuple + +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh + + +@dataclass +class EinsumDims: + contracting_dims: List[str] + batch_dims: List[str] + lhs_out_only_dims: List[str] + rhs_out_only_dims: List[str] + + @classmethod + def parse_equation(cls, equation: str) -> Tuple[List[str], str]: + # parse einop equation and extract arg specs + """ + Parse the einsum equation str to input dim chars and output dim char + """ + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + + # NOTE: only support at most two inputs, and single output + # extend to support more inputs if needed in future + assert len(input_dims) <= 2, "Only support at most two inputs" + assert len(output_dims) == 1, "Only support single output" + output_dim = output_dims[0] + return input_dims, output_dim + + @classmethod + def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": + """ + Parse the dims and extract the contracting, batch, and free dimensions + for the left and right hand sides. 
+ """ + dim_char_set: Set[str] = set() + for input_dim in input_dims: + dim_char_set.update(input_dim) + + # get a determinisitc order of all dim chars + all_dim_chars = sorted(dim_char_set) + + # parse input and output dimensions + lhs_out_only_dims, rhs_out_only_dims = [], [] + batch_dims, contracting_dims = [], [] + + for dim_char in all_dim_chars: + if dim_char not in output_dim: + contracting_dims.append(dim_char) + else: + is_batch_dim = True + for input_dim in input_dims: + is_batch_dim = is_batch_dim and dim_char in input_dim + + if is_batch_dim: + batch_dims.append(dim_char) + else: + assert ( + len(input_dims) == 2 + ), "free dimension only supported for two inputs!" + lhs, rhs = input_dims + if dim_char in lhs: + lhs_out_only_dims.append(dim_char) + elif dim_char in rhs: + rhs_out_only_dims.append(dim_char) + else: + raise RuntimeError("Invalid dimension character") + + return cls( + contracting_dims=contracting_dims, + batch_dims=batch_dims, + lhs_out_only_dims=lhs_out_only_dims, + rhs_out_only_dims=rhs_out_only_dims, + ) + + +def gen_einsum_strategies( + equation: str, + mesh: DeviceMesh, + *, + linearity: bool = False, +) -> OpStrategy: + """ + Generate a strategy list for the ops that follow einsum style notation. + """ + # parse einop equation and extract dims + input_dims, output_dim = EinsumDims.parse_equation(equation) + edims = EinsumDims.parse_dims(input_dims, output_dim) + + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim + for mesh_dim in range(mesh.ndim): + mesh_dim_strategies = [] + + # placement list stores placements of [output, input1, input2, ...] 
+ # first we always have replicate all for inputs and output + placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1) + mesh_dim_strategies.append(placement_list) + + if mesh.size(mesh_dim) <= 1: + # only replicate strategy for mesh dim with size 1 + # TODO: see if this is valid for the submesh case + continue + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + mesh_dim_strategies.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + placement_list = [Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + mesh_dim_strategies.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim = output_dim.index(lhs_dim) + # this means split the lhs input and output + # i.e. 
S(0), R -> S(0) + lhs_placement_list: List[Placement] = [ + Shard(lhs_free_dim), + Shard(lhs_free_dim), + Replicate(), + ] + mesh_dim_strategies.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim = output_dim.index(rhs_dim) + rhs_placement_list: List[Placement] = [ + Shard(rhs_free_dim), + Replicate(), + Shard(rhs_free_dim), + ] + mesh_dim_strategies.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: List[Placement] = [Partial()] + for input_dim in input_dims: + linearity_placement_list.append(Partial()) + mesh_dim_strategies.append(linearity_placement_list) + + all_mesh_dim_strategies.append(mesh_dim_strategies) + + # generate strategies for entire mesh + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + # TODO: filter out invalid strategies, at this point we generate + # all possible strategies without considering the whether the tensor + # dim could be sharded or not, we would need to filter out invalid + # strategies base on the actual tensor shape + # (i.e. for Shard, tensor dim size must > mesh size) + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:]) + all_strategies.append(strat) + + return OpStrategy(all_strategies) diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py new file mode 100644 index 000000000..f70b27076 --- /dev/null +++ b/src/chop/distributed/tensor/ops/common_rules.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from typing import cast, Dict, List, Optional, Tuple + +import torch +from torch.distributed._tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpSchema, + OutputSharding, +) +from torch.distributed._tensor._utils import compute_local_shape +from torch.distributed._tensor.ops.utils import prod +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + + +def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: + return string[:idx] + new_char + string[idx + 1 :] + + +def _gen_reshard_suggestions( + op_schema: OpSchema, + input_dims: List[str], + input_specs: Tuple[DTensorSpec, ...], + dim_to_sharding: Dict[str, int], + pending_sum: List[int], +) -> OutputSharding: + suggested_arg_specs: List[DTensorSpec] = [] + for input_dim, input_spec in zip(input_dims, input_specs): + dim_map = [dim_to_sharding[dim] for dim in input_dim] + suggested_arg_specs.append( + DTensorSpec.from_dim_map( + mesh=input_spec.mesh, + dim_map=dim_map, + sums=pending_sum, + tensor_meta=input_spec.tensor_meta, + ) + ) + suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {}) + suggested_schema._inplace_rewrap_schema_suggestion(op_schema) + return OutputSharding( + None, + redistribute_schema=suggested_schema, + ) + + +def einop_rule( + equation: str, + op_schema: OpSchema, + *, + linearity: bool = False, + enforce_sharding: Optional[Dict[str, int]] = None, +) -> OutputSharding: + """ + Propagate the sharding of inputs to output for ops whose data moves according to einsum notation. + + This is mostly borrowed from @zdevito's sharding simulator. Examples: + mk,kn->mn - einsum + ij,ij->ij - addition + ij,j->ij - broadcasted addition + ij->i - reduction + Other ops could use this propagation algorithm when applied, note + that einsum propagation only deal with list of specs (DTensor specs) + as it only works on list of tensors! 
+ + linearity in einop_rule means that the calling op `f` follows this rule: + f(a + b) = f(a) + f(b) + + In this case we can propagate the partial sum, note that linearity in einop + only applies to partial sum, not other operations like min/max (which are + associative but not linear). + """ + # parse einop equation and extract arg specs + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + input_specs = op_schema.args_spec + # NOTE: only support single output unless needed in future + output_dim = output_dims[0] + + dim_to_sharding: Dict[str, int] = {} + dim_to_size: Dict[str, int] = {} + # record pending sum, key is mesh dimension, value is pending sum + # counter across input specs + pending_sums_counter: Dict[int, int] = {} + seen_shardings: Dict[int, str] = {} + needs_reshard = False + + def merge_sharding(dim: str, a: int, b: int) -> int: + # merge the sharding of inputs if it's able to merge, i.e. we can merge + # replicate and shard to shard, but this will trigger an reshard operation + if a != b: + if a == -1 or b == -1: + # reshard the replicate to match the sharded one + nonlocal needs_reshard + needs_reshard = True + return a if a != -1 else b + else: + # TODO: further merge the sharding properly (i.e. 
reshard one input to replicate) + raise RuntimeError( + f"{equation}: dim {dim} sharded two different ways: {a} and {b}" + ) + else: + return a + + for input_dim, input_spec in zip(input_dims, input_specs): + # deal with partial sums + input_sums = input_spec.sums + for sum_dim in input_sums: + if sum_dim not in pending_sums_counter: + seen_shardings[sum_dim] = "+" + # update pending sum counter for pending sum mesh + # dimension with the occurrence from each input + pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1 + + for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)): + if enforce_sharding and dim in enforce_sharding: + if enforce_sharding[dim] != mesh_dim: + needs_reshard = True + dim_to_sharding[dim] = enforce_sharding[dim] + dim_to_size[dim] = input_spec.shape[idx] + elif dim not in dim_to_sharding: + dim_to_sharding[dim] = mesh_dim + dim_to_size[dim] = input_spec.shape[idx] + else: + dim_to_sharding[dim] = merge_sharding( + dim, dim_to_sharding[dim], mesh_dim + ) + assert dim_to_size[dim] == input_spec.shape[idx] + + # after merging sharding, we check if there're multiple + # sharding on the same mesh dim. 
+ merged_sharding_for_dim = dim_to_sharding[dim] + if merged_sharding_for_dim != -1: + if ( + merged_sharding_for_dim in seen_shardings + and dim != seen_shardings[merged_sharding_for_dim] + ): + needs_reshard = True + seen_shardings[merged_sharding_for_dim] += dim + else: + seen_shardings[merged_sharding_for_dim] = dim + + if pending_sums_counter and not linearity: + # return reshard suggestion with no pending sum, because we already properly + # merge the sharding, this reshard suggestion is legit to use + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, [] + ) + else: + # It's a op that support linearity, but not all input arguments are partial + # we fail the sharding propagation with suggestion to make all inputs be + # partial on the corresponding mesh dim (all inputs should be partial for + # the mesh dims in order to execute locally and delay the sum reduction) + for value in pending_sums_counter.values(): + if value != len(input_specs): + needs_reshard = True + + for mesh_dim, dims in seen_shardings.items(): + if len(dims) > 1: + # we found different input dims are being sharded on the same mesh dim + # in order to perform local op computation, we need to reshard inputs + # base on some simple heuristics, now we simply pick the one with least comm + # volume. (i.e. 
the input with least size) + # TODO: consider a more advanced heuristic to pick the best sharding + costs = [] + for d in dims: + cost = 0 + for input_dim, input_spec in zip(input_dims, input_specs): + if ( + d in input_dim + and input_spec.dim_map[input_dim.index(d)] == mesh_dim + ): + assert input_spec.tensor_meta is not None + global_shape = input_spec.tensor_meta.shape + local_shape = compute_local_shape( + global_shape, input_spec.mesh, input_spec.placements + ) + cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) + costs.append(cost) + d_to_keep_sharding = dims[costs.index(max(costs))] + for d in dims: + # update dim_to_sharding to keep the sharding of the dim with + # highest comm and make the rest of the dims to replicate + if d != d_to_keep_sharding: + dim_to_sharding[d] = -1 + + pending_sums = list(pending_sums_counter.keys()) + if needs_reshard: + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, pending_sums + ) + + # generate output pending sum if a dim is sharded, and it appears in input + # but not output + for dim, shard_on_mesh in dim_to_sharding.items(): + if dim not in output_dims[0] and shard_on_mesh != -1: + pending_sums.append(shard_on_mesh) + + # if no need to reshard, we directly generate the output sharding + output_dim_map = [] + output_shape = [] + for dim in output_dim: + if dim == "1": + # find output dim that is a singleton dimension, mark sharding and shape + output_dim_map.append(-1) + output_shape.append(1) + else: + output_dim_map.append(dim_to_sharding[dim]) + output_shape.append(dim_to_size[dim]) + + # XXX: since we still need to have intermediate shape calculation, we need + # to pass in the shape here. 
We should remove this once sharding decomp works + # for ops like addmm + assert input_specs[0].tensor_meta is not None + tensor_meta = TensorMeta( + torch.Size(output_shape), + input_specs[0].tensor_meta.stride, + input_specs[0].tensor_meta.dtype, + ) + return OutputSharding( + DTensorSpec.from_dim_map( + input_specs[0].mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, + ) + ) + + +def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding: + """ + Propagate the sharding for pointwise operations. + + Examples: + ij,ij->ij - addition/mul + ij,j->ij - broadcasted addition + """ + alphabet = "abcdefghijklmnopqrstuvwxyz" + # find the max_dim first in case we need to broadcasting + input_specs = op_schema.args_spec + max_dim = max(input.ndim for input in input_specs) + dimchars = [] + singleton_counter: List[int] = [0] * max_dim + for input in input_specs: + start_dim = max_dim - input.ndim + p = alphabet[start_dim:max_dim] + # handle the "broadcasting to a common shape case" + # see https://pytorch.org/docs/stable/notes/broadcasting.html + # If any of the dimensions is singleton dimension (i.e. 1). + # we mark the dim char as a special "1" to distinguish with + # the non-singleton dimension, so that sharding propagation + # should just ignore the singleton dimension. + if len(input_specs) > 1: + for i in range(max_dim): + if i < start_dim: + # treat the leading miss dim chars as singleton + singleton_counter[i] += 1 + elif input.shape[i - start_dim] == 1: + # mark singleton dim char as a special "1" in einop rule + singleton_counter[i] += 1 + p = _replace_char_in_str(p, "1", (i - start_dim)) + + dimchars.append(p) + out_dimchars = alphabet[:max_dim] + # check if we replace the all inputs dim char with singleton dimension, + # if we replace all inputs, we also need to replace the output dimension. 
+ for output_dim_idx in range(len(out_dimchars)): + out_dimchar = out_dimchars[output_dim_idx] + if singleton_counter[output_dim_idx] == len(input_specs): + out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx) + + fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}" + + enforce_sharding: Dict[str, int] = {} + if _is_inplace_op(op_schema.op): + # inplace op should keep the input sharding it writes to + for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map): + enforce_sharding[out_dimchar] = mesh_dim + elif _is_out_variant_op(op_schema.op): + out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) + for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map): + enforce_sharding[out_dimchar] = mesh_dim + + return einop_rule( + fmt, + op_schema, + linearity=linearity, + enforce_sharding=enforce_sharding, + ) diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py new file mode 100644 index 000000000..7bf13241d --- /dev/null +++ b/src/chop/distributed/tensor/ops/conv_ops.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. 
# implement convolution-related sharding propagation rules for distributed tensor
from typing import List

import torch
from torch.distributed._tensor._op_schema import OpSchema, OutputSharding
from chop.distributed.tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta


aten = torch.ops.aten


@register_prop_rule(aten.convolution.default)
def convolution_rules(op_schema: OpSchema) -> OutputSharding:
    """
    Sharding propagation for ``aten.convolution.default``.

    Computes the NCHW output shape/stride from the input and weight metadata
    and propagates the input's dim map and pending sums to the output.
    """
    (
        input_spec,
        weight_spec,
        bias_spec,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    ) = op_schema.args_schema

    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(weight_spec, DTensorSpec)
    # NOTE(review): this asserts bias_spec is always a DTensorSpec; a
    # convolution without bias (bias=None) would trip it — confirm callers.
    assert isinstance(bias_spec, DTensorSpec)
    assert input_spec.tensor_meta is not None
    assert weight_spec.tensor_meta is not None
    in_shape = input_spec.tensor_meta.shape
    weight_shape = weight_spec.tensor_meta.shape
    assert isinstance(stride, List)
    assert isinstance(padding, List)
    assert isinstance(dilation, List)
    assert isinstance(weight_shape, torch.Size)
    # standard 2D convolution output-size arithmetic (floor division)
    N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
    C_out = weight_shape[0]
    H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
        0
    ] + 1
    W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[
        1
    ] + 1
    output_shape = [N, C_out, H_out, W_out]
    output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1)
    # output keeps the input's sharding and pending partial sums
    output_dim_map = input_spec.dim_map
    pending_sums = input_spec.sums

    tensor_meta = TensorMeta(
        torch.Size(output_shape),
        output_stride,
        input_spec.tensor_meta.dtype,
    )
    return OutputSharding(
        DTensorSpec.from_dim_map(
            input_spec.mesh,
            output_dim_map,
            pending_sums,
            tensor_meta=tensor_meta,
        )
    )


@register_prop_rule(aten.convolution_backward.default)
def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
    """
    Sharding propagation for ``aten.convolution_backward.default``.

    Returns specs for (grad_input, grad_weight, grad_bias): grad_input keeps
    the input's spec; grad_weight and grad_bias are replicated with a pending
    sum on mesh dim 0.
    """
    (
        grad_output_spec,
        input_spec,
        weight_spec,
        bias_shape_opt,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        output_mask,
    ) = op_schema.args_schema

    assert isinstance(grad_output_spec, DTensorSpec)
    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(weight_spec, DTensorSpec)
    assert isinstance(bias_shape_opt, List)
    assert input_spec.tensor_meta is not None
    weight_tensor_meta = weight_spec.tensor_meta
    bias_tensor_meta = TensorMeta(
        torch.Size(bias_shape_opt),
        (1,),
        input_spec.tensor_meta.dtype,
    )

    grad_input_spec = input_spec
    grad_weight_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1, -1, -1, -1],
        [0],
        tensor_meta=weight_tensor_meta,
    )
    grad_bias_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1],
        [0],
        tensor_meta=bias_tensor_meta,
    )
    return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
and affiliates +# implement matrix related ops for distributed tensor +from dataclasses import dataclass, field +from typing import cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + StrategyType, +) +from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed._tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh +from chop.distributed.tensor.ops.utils import register_op_strategy + +aten = torch.ops.aten + + +@dataclass +class MaskBuffer: + data: Optional[torch.Tensor] = None + + def materialize_mask(self, mask): + if self.data is not None: + raise RuntimeError("MaskBuffer has already been materialized") + self.data = mask + + def release_mask(self): + # TODO: evaluate if we need to release the mask buffer or the buffer + # can just have the same lifetime as the Partial placement + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + self.data = None + + def apply_mask(self, tensor): + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + + # NOTE: _MaskPartial is being used by the embedding op and the gather op. + # For gather, the mask has the same dimension as the output tensor, whereas + # the output of the embedding op has an additional dimension compare to the input, + # hence the output masking logic below having two different cases. 
+ if tensor.ndim == self.data.ndim: + tensor[self.data] = 0.0 + else: + tensor[self.data, :] = 0.0 + + +@dataclass(frozen=True) +class _MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + logical_dim_size: int = -1 + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( + self.logical_dim_size, + num_chunks, + mesh.get_local_rank(mesh_dim), + return_offset=True, + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # 
perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.logical_dim_size == other.logical_dim_size + ) + + def __hash__(self) -> int: + return 1 + hash( + (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + +@register_op_strategy(aten.embedding.default) +def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. 
We have two possible embedding shardings: + rowwise and colwise + """ + weight_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + + weight_shape = weight_strategy.shape + indices_shape = indices_strategy.shape + output_emd_dim = len(indices_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding: PlacementList = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [ + Shard(input_dim), + Replicate(), + Shard(input_dim), + ] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) + + +@register_op_strategy(aten.embedding_dense_backward.default) +def embedding_dense_backward_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> StrategyType: + """ + This strategy handles embedding op. 
We have two possible embedding shardings: + rowwise and colwise + """ + grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + + grad_out_shape = grad_out_strategy.shape + indices_shape = indices_strategy.shape + grad_out_ndim = len(grad_out_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/src/chop/distributed/tensor/ops/experimental_ops.py b/src/chop/distributed/tensor/ops/experimental_ops.py new file mode 100644 index 000000000..432fbede8 --- /dev/null +++ b/src/chop/distributed/tensor/ops/experimental_ops.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# implement matrix related ops for distributed tensor + +import torch +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + StrategyType, +) +from torch.distributed._tensor.device_mesh import DeviceMesh +from chop.distributed.tensor.ops.utils import register_op_strategy +from torch.distributed._tensor.placement_types import DTensorSpec, Replicate + + +aten = torch.ops.aten + + +@register_op_strategy(aten.slice_backward.default) +def slice_backward_rules(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """ + slice_backward is a new_zeros + slice_scatter, we only allow replication + on the input/output for now since new_zeros would produce replication + """ + replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OpStrategy([PlacementStrategy(replicate_spec)]) diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py new file mode 100644 index 000000000..d18396873 --- /dev/null +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -0,0 +1,1056 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import math +from dataclasses import dataclass +from enum import Enum +from typing import cast, List, Optional, Sequence, Tuple, Union + +import torch +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + PlacementStrategy, + RuntimeSchemaInfo, + TupleStrategy, +) +from torch.distributed._tensor.ops.utils import ( + as_list, + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + is_tensor_evenly_shardable, + normalize_dim, + normalize_dims, + normalize_to_torch_size, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh +from chop.distributed.tensor.ops.utils import register_op_strategy + +aten = torch.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +@dataclass(frozen=True) +class NormReduction: + norm_type: Union[int, float, str] + + +ReductionOpType = Union[NormReduction, str] + + +@dataclass(frozen=True) +class _NormPartial(Partial): + """ + This placement is used for partial vector norm. + + For p-norms (where p not inf or -inf), the p-norm over n elements computes + (sum_i x_i^p)^(1/p) + where the sum is from i=1 to n. The reduction op is the p-norm itself. + For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm: + Rank 0: [t1, t2] | Rank 1: [t3, t4] + After computing 2-norm per gradient (partial placement): + Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)] + Converting from partial to replicate wants to ultimately get: + Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)] + This can be achieved by computing 2-norm on each rank's result. This holds + similarly for inf and -inf norm. For 0-norm, the reduction op is sum. 
+ """ + + norm_type: Union[int, float, str] = 2 + + def __post_init__(self): + """Set the appropriate reduce op based on the norm type.""" + # Use `object.__setattr__` to bypass frozen checks + if self.norm_type in (float("inf"), "inf"): + object.__setattr__(self, "reduce_op", "max") + elif self.norm_type in (float("-inf"), "-inf"): + object.__setattr__(self, "reduce_op", "min") + elif isinstance(self.norm_type, (int, float)): + object.__setattr__(self, "reduce_op", "sum") + else: + raise NotImplementedError(f"Unsupported norm type: {self.norm_type}") + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm: + Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3) + To convert from replicated to partial, we want f(x) such that + sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2) + = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2). + One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and + p-norm as f(x) = x / d^(1/p). 
+ """ + if self.reduce_op in ("max", "min"): + return tensor + elif self.reduce_op == "sum": + if self.norm_type == 0: + raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}") + elif self.norm_type == 1: + return tensor / mesh.size(mesh_dim) + assert isinstance(self.norm_type, (int, float)) + return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type) + raise NotImplementedError(self.reduce_op) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + assert isinstance(shard_spec, Shard), f"{shard_spec}" + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return self._post_reduce_transform(reduced_tensor) + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim) + return self._post_reduce_transform(reduced_tensor) + + def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor**self.norm_type + return tensor + + def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor ** (1.0 / self.norm_type) + return tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _NormPartial): + return False + return self.norm_type == other.norm_type + + def __hash__(self) -> int: + return 1 + hash(self.norm_type) + + +def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]: + if dims_arg is None: + return None + dims = cast(List[int], 
as_list(dims_arg)) + dims = cast(List[int], normalize_dims(dims, ndim)) + empty_dims = [[0], [-1], []] + if ndim == 0 and dims_arg in empty_dims: + return None + return dims + + +def _infer_reduce_dims_map( + reduction_dims: List[int], input_ndim: int, keep_dim=False +) -> List[int]: + reduction_dims_map = [] + new_dim_count = 0 + for input_dim in range(input_ndim): + if input_dim in reduction_dims and not keep_dim: + # if input dim in reduction dims, mark it as -1 + reduction_dims_map.append(-1) + else: + # otherwise mark it as the new dim + reduction_dims_map.append(new_dim_count) + new_dim_count += 1 + + return reduction_dims_map + + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +# return new_placements which align with placements but skip the skipped_dim +def _skip_dim( + placements: Tuple[Placement, ...], skipped_dim: int +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if isinstance(p, Shard) and p.dim >= skipped_dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: Tuple[Placement, ...], reduction_dims: List[int] +) -> Tuple[Placement, ...]: + # replicate the reduction dims if not reduction_linear + new_placements: List[Placement] = [] + + for p in placements: + if p.is_partial(): + new_placements.append(Replicate()) + elif isinstance(p, Shard) and p.dim in reduction_dims: + new_placements.append(Replicate()) + else: + new_placements.append(p) + + return tuple(new_placements) + + +def map_placements_after_reduction( + placements: Tuple[Placement, ...], + 
reduction_dims: List[int], + reduction_dims_map: List[int], + reduction_op: ReductionOpType, +) -> Tuple[Placement, ...]: + """ + Map each placement based on the output shape after reduction. + """ + new_placements: List[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + assert isinstance(placement, Shard) + shard_dim = placement.dim + new_shard_dim = reduction_dims_map[shard_dim] + if new_shard_dim == -1 or shard_dim in reduction_dims: + # if new_shard_dim collapsed or its in the reduction dims + # (i.e. for the case where keepdims=True), we generate partial + new_placements.append(get_placement_from_reduction_op(reduction_op)) + else: + new_placements.append(Shard(new_shard_dim)) + return tuple(new_placements) + + +def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: + if isinstance(reduction_op, NormReduction): + return _NormPartial(norm_type=reduction_op.norm_type) + return Partial(reduction_op) + + +def common_reduction_strategy( + mesh: DeviceMesh, + input_strategy: OpStrategy, + reduce_dims: List[int], + keep_dim: bool = False, + reduction_linear: bool = True, + reduction_op: ReductionOpType = "sum", +) -> OpStrategy: + """ + reduction_linear means that the reduction `f` follows this rule: + f([f(a), f(b)]) = f([a, b]) + + reduction linear should be super set of linearity. 
+ """ + # by default follow reduction input strategy + reduction_strategy = OpStrategy([]) + + for strtg in input_strategy.strategies: + if not reduction_linear: + # input placements for this strategy should clear out pending sum and sharding + # on the reduction dimension + input_placements = replicate_reduction_dims( + strtg.output_spec.placements, reduce_dims + ) + else: + input_placements = strtg.output_spec.placements + + input_spec = DTensorSpec( + mesh=mesh, + placements=input_placements, + tensor_meta=strtg.output_spec.tensor_meta, + ) + + reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) + out_placements = map_placements_after_reduction( + input_spec.placements, reduce_dims, reduce_dims_map, reduction_op + ) + redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] + reduction_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec( + mesh=mesh, + placements=out_placements, + ), + input_specs=(input_spec,), + redistribute_cost=redistribute_cost, + ) + ) + + return reduction_strategy + + +LINEAR_REDUCTION_OP_MAP = { + aten.all.default: "sum", + aten.all.dim: "sum", + aten.sum.default: "sum", + aten.sum.dim_IntList: "sum", + aten.prod.default: "product", + aten.prod.dim_int: "product", + aten.prod.int_out: "product", + aten.mean.default: "avg", + aten.mean.dim: "avg", + aten.mean.out: "avg", + aten.max.default: "max", + aten.max.dim: "max", + aten.max.out: "max", + aten.min.default: "min", + aten.min.dim: "min", + aten.min.out: "min", +} + + +@register_op_strategy( + list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1) +) +def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = 
list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) + reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] + return common_reduction_strategy( + mesh, + input_strategy, + reduce_dims, + keep_dim=keep_dim, + reduction_linear=True, + reduction_op=reduction_op, + ) + + +@register_op_strategy( + [aten.var.correction, aten.var.correction_out], + schema_info=RuntimeSchemaInfo(1, ["keepdim"]), +) +def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return common_reduction_strategy( + mesh, input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False + ) + + +@register_op_strategy( + [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1) +) +def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), f"{norm_type}" + dim = args_schema[2] if len(args_schema) > 2 else None + keepdim = args_schema[3] if len(args_schema) > 3 else False + dims = _infer_reduction_dims(dim, input_strategy.ndim) + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + return common_reduction_strategy( + mesh, + input_strategy, + reduce_dims, + keep_dim=cast(bool, keepdim), + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + + +@register_op_strategy( + [aten._foreach_norm.Scalar], 
schema_info=RuntimeSchemaInfo(1, needs_pytree=True) +) +def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), f"{norm_type}" + output_tuple_strategy_childs: List[OpStrategy] = [] + for op_strategy in input_tuple_strategy.childs: + assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" + reduce_dims = list(range(op_strategy.ndim)) + output_strategy = common_reduction_strategy( + mesh, + op_strategy, + reduce_dims, + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + output_tuple_strategy_childs.append(output_strategy) + return TupleStrategy(output_tuple_strategy_childs) + + +@register_op_strategy([aten._linalg_svd.default], schema_info=RuntimeSchemaInfo(1)) +def linalg_svd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # Since we do not have a simple way to compute a sharded SVD, always fall + # back to replicate + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + output_strategies: List[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) + replicate_spec = DTensorSpec( + mesh=mesh, + placements=replicate_placements, + tensor_meta=placement_strategy.output_spec.tensor_meta, + ) + redistribute_cost = [ + generate_redistribute_costs(input_strategy, replicate_spec) + ] + replicate_strategy = PlacementStrategy( + output_specs=replicate_spec, + input_specs=(replicate_spec,), + redistribute_cost=redistribute_cost, + ) + output_strategies.append(replicate_strategy) + return OpStrategy(output_strategies) + + +@register_op_strategy( + [aten._log_softmax.default, 
aten._softmax.default], schema_info=RuntimeSchemaInfo(1) +) +def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + input_strategy, softmax_dim, _ = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=[input_target_spec], + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [ + aten._log_softmax_backward_data.default, + aten._softmax_backward_data.default, + ], + schema_info=RuntimeSchemaInfo(2), +) +def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + out_strategy = cast(OpStrategy, out_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim) + + grad_in_strategy = OpStrategy([]) + for grad_out_placement_strat, out_placement_strat in zip( + grad_out_strategy.strategies, out_strategy.strategies + ): + # follow the sharding of the grad_out or out depending on which has more shards + grad_out_src_spec = grad_out_placement_strat.output_spec + out_src_spec = 
out_placement_strat.output_spec + src_spec = ( + grad_out_src_spec + if grad_out_src_spec.num_shards >= out_src_spec.num_shards + else out_src_spec + ) + + # make sure inputs are replicated along the softmax dim + tgt_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), + ) + redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) + redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) + grad_in_strategy.strategies.append( + PlacementStrategy( + output_specs=tgt_spec, + redistribute_cost=[redist_grad_out_cost, redist_out_cost], + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default], + schema_info=RuntimeSchemaInfo(3), +) +def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + assert len(op_schema.args_schema) == 5 + ( + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + ) = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = 
target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + if reduction == Reduction.NONE.value: + output_expected_spec = target_expected_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, placements=tuple([Replicate()] * mesh.ndim) + ) + else: + if reduction == Reduction.MEAN.value: + reduction_op = "avg" + if not is_tensor_evenly_shardable( + target_expected_spec.shape, target_expected_spec + ): + raise ValueError( + "The intermediate results of nll_loss cannot be evenly sharded, \ + resulting in biased mean result." 
+ ) + else: # reduction == Reduction.SUM.value: + reduction_op = "sum" + reduce_dims = list(range(target_expected_spec.ndim)) + reduce_dims_map = _infer_reduce_dims_map( + reduce_dims, target_expected_spec.ndim, keep_dim=False + ) + out_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + reduction_op, + ) + output_expected_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + ) + + # whether reduction is sum or mean, the total weight has to be summed up if not replicated + total_weight_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + "sum", + ) + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=total_weight_placements, + ) + + output_strategy.strategies.append( + PlacementStrategy( + output_specs=(output_expected_spec, total_weight_expected_spec), + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default], + schema_info=RuntimeSchemaInfo(4), +) +def nll_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + assert len(op_schema.args_schema) == 7 + ( + grad_out_strategy, + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + total_weight_strategy, + ) = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + total_weight_strategy = cast(OpStrategy, total_weight_strategy) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + grad_in_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is 
replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # grad_out follows target if there is no reduction; + # otherwise, it should be a replicated scalar. + grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec + if reduction == Reduction.NONE.value: + grad_out_expected_spec = target_expected_spec + else: + grad_out_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(grad_out_src_spec.placements), + tensor_meta=grad_out_src_spec.tensor_meta, + ) + op_args_target_specs.insert(0, grad_out_expected_spec) + redistribute_costs.insert( + 0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + 
generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + # total_weight should always be replicated + total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(total_weight_src_spec.placements), + tensor_meta=total_weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(total_weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs( + total_weight_strategy, total_weight_expected_spec + ) + ) + + grad_in_expected_spec = input_expected_spec + grad_in_strategy.strategies.append( + PlacementStrategy( + output_specs=grad_in_expected_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return grad_in_strategy + + +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + print(msg) + + +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). 
+ assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + + # the current layer norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + assert isinstance(input_strategy, OpStrategy) + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + + rlog(f"=========== LAYER NORM STRATEGY ===========") + + rlog(f" normalized_size: {normalized_size}") + rlog(f" input_ndim: {input_ndim}") + rlog(f" axis: {axis}") + rlog(f" input_strategy: {input_strategy}") + + # we use OpStrategy because the output (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + rlog(f"strategy {idx}: {input_placement_strategy}") + op_args_target_specs = [] + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the input tensor, we replicate it on the inner dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + + # for the weight tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + weight_target_spec = DTensorSpec( + mesh=mesh, + 
placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_target_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_target_spec) + ) + + if bias_strategy is not None: + assert isinstance(bias_strategy, OpStrategy) + bias_src_spec = bias_strategy.strategies[idx].output_spec + + # for the bias tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + bias_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(bias_src_spec.placements), + tensor_meta=bias_src_spec.tensor_meta, + ) + op_args_target_specs.append(bias_target_spec) + redistribute_costs.append( + generate_redistribute_costs(bias_strategy, bias_target_spec) + ) + + # the output spec is the same as input spec + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + rlog(f"LAYER_NORM_STRATEGY output: {output_strategy}") + rlog(f"===========================================") + + return output_strategy + + +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # args must be: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. 
+ assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_out_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(mean_strategy, OpStrategy) + assert isinstance(rstd_strategy, OpStrategy) + + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + outer_dims = list(range(axis)) + + assert isinstance(output_mask, List) and len(output_mask) == 3 + + # output triple: (d_input, d_weight, d_bias) + out_tuple_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + # args for PlacementStrategy + output_specs_list: List[Optional[DTensorSpec]] = [] + op_args_target_specs = [] + redistribute_costs = [] + + input_src_spec = input_placement_strategy.output_spec + # arg: grad_out + # TODO: change the strategy to the following rule. + # d_input is basically a product of element-wise mul of + # grad_out, rstd, and normalized input, among which rstd + # and normalized input (x_hat) should have the same sharding + # placements, and grad_out's sharding is determined by the + # pointwise result of x_hat and weight/bias. + if output_mask[0]: + # TODO: now grad_out spec follows input spec. we may need + # to change it to apply a pointwise rule over grad_out, + # input, and weight. 
+ grad_out_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(grad_out_target_spec) + redistribute_costs.append( + generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) + ) + output_specs_list.append(grad_out_target_spec) + else: + output_specs_list.append(None) + + # arg: input + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + # arg: mean, rstd + mean_src_spec = mean_strategy.strategies[idx].output_spec + op_args_target_specs.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + rstd_src_spec = rstd_strategy.strategies[idx].output_spec + op_args_target_specs.append(rstd_src_spec) + redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) + + # arg: weight + # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + if output_mask[1]: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + # no need to redistribute weight since they should be replicated + # in forward pass + op_args_target_specs.append(weight_src_spec) + redistribute_costs.append([0.0 for _ in weight_strategy.strategies]) + # TODO: now d_weight spec follows input spec w/ a reduction. + # we may need to change to a pointwise rule over grad_out and + # input, then apply a reduction. 
+ inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, input_src_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + output_specs_list.append( + DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=weight_src_spec.tensor_meta, + ) + ) + else: + output_specs_list.append(None) + + # arg: bias + # d_bias = sum(grad_out, outer_dim, keepdim=False) + if output_mask[2]: + assert isinstance(bias_strategy, OpStrategy) + bias_src_spec = bias_strategy.strategies[idx].output_spec + # no need to redistribute weight since they should be replicated + # in forward pass + op_args_target_specs.append(bias_src_spec) + redistribute_costs.append([0.0 for _ in bias_strategy.strategies]) + # Currently we do not support the case where output_mask[0] is False while + # output_mask[1] is True. But it's easy to support that by accessing + # grad_out_spec via a local variable rather than the list. We just don't + # see the case. 
+ grad_out_spec = output_specs_list[0] + assert isinstance(grad_out_spec, DTensorSpec) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at(grad_out_spec.placements, axis) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + output_specs_list.append( + DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + ) + else: + output_specs_list.append(None) + + out_tuple_strategy.strategies.append( + PlacementStrategy( + output_specs=tuple(output_specs_list), + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return out_tuple_strategy + + +@register_op_strategy( + [aten.topk.default], + schema_info=RuntimeSchemaInfo(2), +) +def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + k = cast(int, op_schema.args_schema[1]) + input_shape = input_strategy.shape + topk_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 + ) + topk_dim = normalize_dim(topk_dim, input_strategy.ndim) + + single_mesh_dim_strategies = [] + + # two outputs (values, indices), 1 input + # replicate always works + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # every dim except topk dim should work + for dim in range(input_strategy.ndim): + if dim != topk_dim: + dim_shardings: PlacementList = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + # TODO: topk on sharded dim requries non-trival reduction, address it later + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py new file mode 100644 
index 000000000..77484de7d --- /dev/null +++ b/src/chop/distributed/tensor/ops/matrix_ops.py @@ -0,0 +1,459 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + +import torch +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + PlacementStrategy, +) +from torch.distributed._tensor.ops.utils import ( + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + infer_broadcast_dims_map, + is_tensor_shardable, + map_placements_after_broadcast, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh +from chop.distributed.tensor.ops.basic_strategy import gen_einsum_strategies +from chop.distributed.tensor.ops.utils import register_op_strategy + +aten = torch.ops.aten + + +@register_op_strategy(aten.t.default) +def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + transpose_strategies = [] + for input_strategy in self_strategy.strategies: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements + output_placements = [ + Shard(1 - p.dim) if isinstance(p, Shard) else p + for p in input_spec.placements + ] + transpose_strategy = PlacementStrategy( + output_specs=DTensorSpec( + mesh=input_strategy.output_spec.mesh, + placements=tuple(output_placements), + ), + input_specs=(input_strategy.output_spec,), + ) + transpose_strategies.append(transpose_strategy) + + return OpStrategy(strategies=transpose_strategies) + + +def _mm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat2_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + # generate all possible strategies 
for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _addmm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat1_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + self_shape = self_strategy.shape + mm_out_shape = torch.Size( + [ + mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size + for i, dim_size in enumerate(mat1_strategy.shape) + ] + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + # construct new strategy by consider the self arg + assert strtg.input_specs is not None + mat1_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + 
self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + + if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + + # associate costs + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat1_strategy, mat1_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +@register_op_strategy(aten.mm.default) +def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.addmm.default) +def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _addmm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.bmm.default) +def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten.baddbmm.default) +def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten._scaled_dot_product_flash_attention.default) +def scaled_dot_product_flash_attention_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. 
+ return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + q_input_strategy = op_schema.args_schema[0] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + qkv_shape = q_input_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + Replicate(), + Replicate(), + Replicate(), + Replicate(), + ] + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + if return_debug_mask: + debug_attn_mask_sharding: Placement = Shard(1) # num head dim + else: + # empty debug mask, replicated + debug_attn_mask_sharding = Replicate() + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Context Parallelism: shards on the sequence dim + single_mesh_dim_strategies.append( + [ + Shard(2), # output + Shard(2), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + Shard(2), # debugattn + Shard(2), # q + Shard(2), # k + Shard(2), # v + ] + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) 
+ + +@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) +def scaled_dot_product_flash_attention_backward_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + q_input_strategy = op_schema.args_schema[1] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + qkv_shape = q_input_strategy.shape + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + num_tensor_inputs = len(tensor_input_indices) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 3 tensor outputs and 6 to 10 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs) + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + grad_output_sharding = Shard(1) # num head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + grad_qkv_sharding = Shard(1) # num head dim + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + output_sharding, + logsumexp_sharding, + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Context Parallelism: shards on the sequence dim + seq_dim_sharding: PlacementList = [ + Shard(2), # grad_q + Shard(2), # grad_k + Shard(2), # grad_v + Shard(2), # grad_output + Shard(2), # q + Shard(2), # k + 
Shard(2), # v + Shard(2), # output + Shard(2), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(seq_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +@register_op_strategy(aten.constant_pad_nd.default) +def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # TODO(d4l3k); implement a more correct strategy for constant_pad_nd + return OpStrategy( + [ + PlacementStrategy( + output_specs=DTensorSpec(mesh, (Replicate(),)), + input_specs=( + DTensorSpec(mesh, (Replicate(),)), + DTensorSpec(mesh, (Replicate(),)), + ), + redistribute_cost=[[1]], + ) + ] + ) + + +@register_op_strategy(aten._scaled_dot_product_efficient_attention.default) +def scaled_dot_product_efficient_attention_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + q_input_strategy = op_schema.args_schema[0] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + qkv_shape = q_input_strategy.shape + has_attn_bias = op_schema.args_schema[3] is not None + compute_log_sumexp = op_schema.args_schema[4] + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, + None, + Replicate(), + Replicate(), + Replicate(), + ] + if has_attn_bias: + all_replicate.append(Replicate()) # attn bias + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of 
tensor parallelism, which + # shard on the heads dimension + qkv_sharding = Shard(1) + output_sharding = Shard(1) + if compute_log_sumexp: + logsumexp_sharding: Placement = Shard(1) + else: + # empty logsumexp, replicated + logsumexp_sharding = Replicate() + + num_heads_dim_sharding = [ + output_sharding, + logsumexp_sharding, + None, + None, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + if has_attn_bias: + num_heads_dim_sharding.append(Shard(1)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) + + +@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) +def scaled_dot_product_efficient_attention_backward_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + q_input_strategy = op_schema.args_schema[1] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + qkv_shape = q_input_strategy.shape + has_attn_bias = op_schema.args_schema[4] is not None + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 4 tensor outputs and 8 or 9 tensor inputs + # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present; + # otherwise grad_bias will be empty and its DTensorSpec will be removed. 
+ # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias) + + if not has_attn_bias: + all_replicate[3] = None # grad bias is None if attn_bias is not present + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the heads dimension + grad_output_sharding = Shard(1) + qkv_sharding = Shard(1) + output_sharding = Shard(1) + logsumexp_sharding = Shard(1) + grad_qkv_sharding = Shard(1) + grad_bias_sharding = Shard(1) if has_attn_bias else None + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_bias_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + # the place for optional input attn_bias, + output_sharding, + logsumexp_sharding, + ] + # input sharding of attn_bias on heads dim if present + if has_attn_bias: + num_heads_dim_sharding.insert(8, Shard(1)) + # accept replicate on the rest scalar tensor inputs + # namely philox_seed and philox_offset + num_heads_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py new file mode 100644 index 000000000..c3e1f082f --- /dev/null +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -0,0 +1,663 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from typing import List, Sequence, Tuple + +import torch +from torch.distributed._tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed._tensor.ops.utils import ( + generate_redistribute_costs, + infer_broadcast_dims_map, + map_placements_after_broadcast, + normalize_dim, +) +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh +from chop.distributed.tensor.ops.utils import register_op_strategy + +aten = torch.ops.aten +# leave the remaining pointwise_ops list here for convenience, +# Below ops are some pointwise ops that are yet to be supported, +# they might not be a complete list. +# pointwise_ops = [ +# "fake_quantize_per_channel_affine", +# "fake_quantize_per_tensor_affine", +# "floor_divide", # floor_divide is deprecated +# "frexp", # multiple output pointwise op, need to add support +# "gradient", # need investigation on this op +# "imag", # complex data type only +# "quantized_batch_norm", +# "quantized_max_pool1d", +# "quantized_max_pool2d", +# "real", # complex data type only +# ] + + +linear_pointwise_ops = [ + aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. 
+ aten.to.dtype, + aten.add.Tensor, + aten.add_.Tensor, +] + + +pointwise_ops = [ + # please keep the entries below alphabetically sorted + aten.__ilshift__.Scalar, + aten.__ilshift__.Tensor, + aten.__irshift__.Scalar, + aten.__irshift__.Tensor, + aten.__lshift__.Scalar, + aten.__lshift__.Tensor, + aten.__rshift__.Scalar, + aten.__rshift__.Tensor, + aten._conj.default, + aten.abs.default, + aten.abs.out, + aten.abs_.default, + aten.acos.default, + aten.acos.out, + aten.acos_.default, + aten.acosh.default, + aten.acosh.out, + aten.acosh_.default, + aten.add.Scalar, + aten.add.out, + aten.add_.Scalar, + aten.addcdiv.default, + aten.addcdiv.out, + aten.addcdiv_.default, + aten.addcmul.default, + aten.addcmul.out, + aten.addcmul_.default, + aten.angle.default, + aten.angle.out, + aten.asin.default, + aten.asin.out, + aten.asin_.default, + aten.asinh.default, + aten.asinh.out, + aten.asinh_.default, + aten.atan.default, + aten.atan.out, + aten.atan2.default, + aten.atan2.out, + aten.atan2_.default, + aten.atan_.default, + aten.atanh.default, + aten.atanh.out, + aten.atanh_.default, + aten.bitwise_and.Scalar, + aten.bitwise_and.Scalar_Tensor, + aten.bitwise_and.Scalar_out, + aten.bitwise_and.Tensor, + aten.bitwise_and.Tensor_out, + aten.bitwise_and_.Scalar, + aten.bitwise_and_.Tensor, + aten.bitwise_left_shift.Scalar_Tensor, + aten.bitwise_left_shift.Tensor, + aten.bitwise_left_shift.Tensor_Scalar, + aten.bitwise_left_shift.Tensor_Scalar_out, + aten.bitwise_left_shift.Tensor_out, + aten.bitwise_left_shift_.Tensor, + aten.bitwise_left_shift_.Tensor_Scalar, + aten.bitwise_not.default, + aten.bitwise_not.out, + aten.bitwise_not_.default, + aten.bitwise_or.Scalar, + aten.bitwise_or.Scalar_Tensor, + aten.bitwise_or.Scalar_out, + aten.bitwise_or.Tensor, + aten.bitwise_or.Tensor_out, + aten.bitwise_or_.Scalar, + aten.bitwise_or_.Tensor, + aten.bitwise_right_shift.Scalar_Tensor, + aten.bitwise_right_shift.Tensor, + aten.bitwise_right_shift.Tensor_Scalar, + 
aten.bitwise_right_shift.Tensor_Scalar_out, + aten.bitwise_right_shift.Tensor_out, + aten.bitwise_right_shift_.Tensor, + aten.bitwise_right_shift_.Tensor_Scalar, + aten.bitwise_xor.Scalar, + aten.bitwise_xor.Scalar_Tensor, + aten.bitwise_xor.Scalar_out, + aten.bitwise_xor.Tensor, + aten.bitwise_xor.Tensor_out, + aten.bitwise_xor_.Scalar, + aten.bitwise_xor_.Tensor, + aten.ceil.default, + aten.ceil.out, + aten.ceil_.default, + aten.clamp.default, + aten.clamp.out, + aten.clamp_.default, + aten.clip.default, + aten.clip.out, + aten.clip_.default, + aten.conj_physical.default, + aten.conj_physical.out, + aten.conj_physical_.default, + aten.copysign.Scalar, + aten.copysign.Scalar_out, + aten.copysign.Tensor, + aten.copysign.out, + aten.copysign_.Scalar, + aten.copysign_.Tensor, + aten.cos.default, + aten.cos.out, + aten.cos_.default, + aten.cosh.default, + aten.cosh.out, + aten.cosh_.default, + aten.deg2rad.default, + aten.deg2rad.out, + aten.deg2rad_.default, + aten.digamma.default, + aten.digamma.out, + aten.digamma_.default, + aten.div.Tensor, + aten.div.Tensor_mode, + aten.div.out, + aten.div.out_mode, + aten.div_.Tensor, + aten.div_.Tensor_mode, + aten.eq.Tensor, + aten.eq.Tensor_out, + aten.eq.Scalar, + aten.eq.Scalar_out, + aten.erf.default, + aten.erf.out, + aten.erf_.default, + aten.erfc.default, + aten.erfc.out, + aten.erfc_.default, + aten.erfinv.default, + aten.erfinv.out, + aten.erfinv_.default, + aten.exp.default, + aten.exp.out, + aten.exp2.default, + aten.exp2.out, + aten.exp2_.default, + aten.exp_.default, + aten.expm1.default, + aten.expm1.out, + aten.expm1_.default, + aten.float_power.Scalar, + aten.float_power.Scalar_out, + aten.float_power.Tensor_Scalar, + aten.float_power.Tensor_Scalar_out, + aten.float_power.Tensor_Tensor, + aten.float_power.Tensor_Tensor_out, + aten.float_power_.Scalar, + aten.float_power_.Tensor, + aten.floor.default, + aten.floor.out, + aten.floor_.default, + aten.fmod.Scalar, + aten.fmod.Scalar_out, + aten.fmod.Tensor, + 
aten.fmod.Tensor_out, + aten.fmod_.Scalar, + aten.fmod_.Tensor, + aten.frac.default, + aten.frac.out, + aten.frac_.default, + aten.ge.Scalar, + aten.ge.Tensor, + aten.gelu.default, + aten.gt.Tensor, + aten.gt.Tensor_out, + aten.gt.Scalar, + aten.gt.Scalar_out, + aten.gt.Scalar, + aten.gt.Tensor, + aten.hypot.default, + aten.hypot.out, + aten.hypot_.default, + aten.i0.default, + aten.i0.out, + aten.i0_.default, + aten.igamma.default, + aten.igamma.out, + aten.igamma_.default, + aten.igammac.default, + aten.igammac.out, + aten.igammac_.default, + aten.isnan.default, + aten.ldexp.default, + aten.ldexp.out, + aten.ldexp_.default, + aten.lt.Tensor, + aten.lt.Tensor_out, + aten.lt.Scalar, + aten.lt.Scalar_out, + aten.le.Scalar, + aten.le.Tensor, + aten.lerp.Scalar, + aten.lerp.Scalar_out, + aten.lerp.Tensor, + aten.lerp.Tensor_out, + aten.lerp_.Scalar, + aten.lerp_.Tensor, + aten.lgamma.default, + aten.lgamma.out, + aten.lgamma_.default, + aten.log.default, + aten.log.out, + aten.log10.default, + aten.log10.out, + aten.log10_.default, + aten.log1p.default, + aten.log1p.out, + aten.log1p_.default, + aten.log2.default, + aten.log2.out, + aten.log2_.default, + aten.log_.default, + aten.logaddexp.default, + aten.logaddexp.out, + aten.logaddexp2.default, + aten.logaddexp2.out, + aten.logical_and.default, + aten.logical_and.out, + aten.logical_and_.default, + aten.logical_not.default, + aten.logical_not.out, + aten.logical_not_.default, + aten.logical_or.default, + aten.logical_or.out, + aten.logical_or_.default, + aten.logical_xor.default, + aten.logical_xor.out, + aten.logical_xor_.default, + aten.logit.default, + aten.logit.out, + aten.logit_.default, + aten.masked_fill.Scalar, + aten.maximum.out, + aten.mul.Scalar, + aten.mul.Tensor, + aten.mul.out, + aten.mul_.Scalar, + aten.mul_.Tensor, + aten.mvlgamma.default, + aten.mvlgamma.out, + aten.mvlgamma_.default, + aten.native_dropout_backward.default, + aten.native_dropout_backward.out, + aten.nan_to_num.default, + 
aten.nan_to_num.out, + aten.nan_to_num_.default, + aten.ne.Scalar, + aten.neg.default, + aten.neg.out, + aten.neg_.default, + aten.nextafter.default, + aten.nextafter.out, + aten.nextafter_.default, + aten.polygamma.default, + aten.polygamma.out, + aten.polygamma_.default, + aten.positive.default, + aten.pow.Scalar, + aten.pow.Scalar_out, + aten.pow.Tensor_Scalar, + aten.pow.Tensor_Scalar_out, + aten.pow.Tensor_Tensor, + aten.pow.Tensor_Tensor_out, + aten.pow_.Scalar, + aten.pow_.Tensor, + aten.reciprocal.default, + aten.reciprocal.out, + aten.reciprocal_.default, + aten.rad2deg.default, + aten.rad2deg.out, + aten.rad2deg_.default, + aten.relu.default, + aten.relu_.default, + aten.remainder.Scalar, + aten.remainder.Scalar_Tensor, + aten.remainder.Scalar_out, + aten.remainder.Tensor, + aten.remainder.Tensor_out, + aten.remainder_.Scalar, + aten.remainder_.Tensor, + aten.round.decimals, + aten.round.decimals_out, + aten.round.default, + aten.round.out, + aten.round_.decimals, + aten.round_.default, + aten.rsqrt.default, + aten.rsqrt.out, + aten.rsqrt_.default, + aten.rsub.Scalar, + aten.sgn.default, + aten.sgn.out, + aten.sgn_.default, + aten.sigmoid.default, + aten.sigmoid.out, + aten.sigmoid_.default, + aten.sign.default, + aten.sign.out, + aten.sign_.default, + aten.signbit.default, + aten.signbit.out, + aten.silu.default, + aten.silu.out, + aten.sin.default, + aten.sin.out, + aten.sin_.default, + aten.sinc.default, + aten.sinc.out, + aten.sinc_.default, + aten.sinh.default, + aten.sinh.out, + aten.sinh_.default, + aten.sqrt.default, + aten.sqrt.out, + aten.sqrt_.default, + aten.square.default, + aten.square.out, + aten.square_.default, + aten.sub.Scalar, + aten.sub.Tensor, + aten.sub.out, + aten.sub_.Scalar, + aten.sub_.Tensor, + aten.tan.default, + aten.tan.out, + aten.tan_.default, + aten.tanh.default, + aten.tanh.out, + aten.tanh_.default, + aten.true_divide.Tensor, + aten.trunc.default, + aten.trunc.out, + aten.trunc_.default, + aten.where.self, + 
def pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> OpStrategy:
    """Generate a sharding strategy for an element-wise op.

    Selects a single operand whose strategy the output (and every other
    operand) should follow:
      * inplace ops must follow their first argument,
      * out-variant ops must follow the ``out=`` kwarg,
      * otherwise the operand with the most shards is followed, so the
        cheaper operands are the ones resharded if needed.
    The per-operand specs and redistribute costs are then computed by
    ``common_pointwise_strategy``.
    """
    if _is_inplace_op(op_schema.op):
        # inplace op should follow the first arg strategy
        followed_strategy = op_schema.args_schema[0]
    elif _is_out_variant_op(op_schema.op):
        # out variant op should follow the out kwarg strategy
        followed_strategy = op_schema.kwargs_schema["out"]
    else:
        # normal pointwise op: follow the arg with the max number of shards,
        # in case the other operands need to be resharded
        best_idx, best_shards = -1, -1
        for candidate_idx, candidate in enumerate(op_schema.args_schema):
            if not isinstance(candidate, OpStrategy):
                continue
            candidate_shards = candidate.max_num_shards()
            if candidate_shards > best_shards:
                best_idx, best_shards = candidate_idx, candidate_shards
        followed_strategy = op_schema.args_schema[best_idx]

    assert isinstance(
        followed_strategy, OpStrategy
    ), f"no strategy to follow for {op_schema}!"
    return common_pointwise_strategy(
        mesh, op_schema.args_schema, followed_strategy, linearity
    )
def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Pointwise strategy for operators that are linear in their inputs.

    Linearity means op(a) + op(b) == op(a + b), so a pending (partial-sum)
    reduction on an input can flow through to the output unchanged: for
    c = add(a, b), if a is pending sum, then c is pending sum as well,
    without any communication overhead.
    """
    return pointwise_strategy(mesh, op_schema, linearity=True)
def list_pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> StrategyType:
    """
    Apply the pointwise strategy to the zipped arguments. For example, if we
    run a foreach add of two lists l1 and l2, then we apply the pointwise
    strategy on each pair (l1[i], l2[i]). If the first argument is a list but
    the second (or later) one is a tensor, then we broadcast the tensor by
    replicating it into a list with the length of the first argument.

    Args:
        mesh (DeviceMesh): device mesh for pointwise ops
        op_schema (OpSchema): schema of the operator to generate strategy for
        linearity (bool): specify whether op(a) + op(b) = op(a + b)

    Returns:
        OpStrategy: generated strategy
    """

    def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]:
        # Normalize every argument into a TupleStrategy with the same number
        # of children as the first (list) argument, broadcasting a lone
        # tensor strategy by repeating it.
        first_arg = args_schema[0]
        assert isinstance(first_arg, TupleStrategy)
        strategy_len = len(first_arg.childs)
        tuple_strategies: List[TupleStrategy] = []
        for arg_idx, arg in enumerate(args_schema):
            if isinstance(arg, TupleStrategy):
                # every tuple strategy should have the same length
                assert len(arg.childs) == strategy_len
                tuple_strategies.append(arg)
            elif isinstance(arg, OpStrategy):
                if arg_idx > 0:  # implicitly broadcast
                    tuple_strategies.append(
                        TupleStrategy([arg for _ in range(strategy_len)])
                    )
                else:
                    # the leading argument must be a list; a bare tensor
                    # strategy here means this is not a supported list op
                    raise RuntimeError(
                        f"list op only supports tuple strategy! {op_schema}"
                    )
        return tuple_strategies

    args_strategies = args_tuple_strategies(op_schema.args_schema)
    follow_strategy: TupleStrategy = args_strategies[0]
    list_strategy: List[OpStrategy] = []
    # zip the normalized tuple strategies element-wise and reuse the scalar
    # pointwise logic on each "column" of arguments
    for child_idx, child_strtgy in enumerate(follow_strategy.childs):
        assert isinstance(child_strtgy, OpStrategy)
        # NOTE(review): these loop locals shadow the module-level
        # `pointwise_strategy` function and reuse the name `args_schema`
        args_schema: List[StrategyType] = [
            arg_strategy.childs[child_idx] for arg_strategy in args_strategies
        ]
        pointwise_strategy: OpStrategy = common_pointwise_strategy(
            mesh, args_schema, child_strtgy, linearity
        )
        list_strategy.append(pointwise_strategy)
    return TupleStrategy(list_strategy)
@register_op_strategy(
    [
        aten.normal_.default,
        aten.uniform_.default,
        aten.native_dropout.default,
        aten.bernoulli_.float,
        aten.bernoulli.default,
    ]
)
def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Strategy for random ops: the output keeps the self argument's placements.

    Every candidate placement of the first argument is propagated unchanged
    to the output; inputs with a pending partial reduction are rejected.
    """
    self_strategy = op_schema.args_schema[0]
    assert isinstance(self_strategy, OpStrategy)

    random_strategy = OpStrategy([])
    for arg_strategy in self_strategy.strategies:
        arg_spec = arg_strategy.output_spec
        if is_tensor_partial(arg_spec):
            # TODO: figure out how inplace random op should behave when it's partial
            raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!")
        # propagate the input placement as-is to the output
        random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec))

    return random_strategy
@register_op_strategy(
    [
        aten.equal.default,
        aten.is_same_size.default,
    ]
)
def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Strategy for ops that compare two tensors (aten.equal / aten.is_same_size)."""
    # equal_strategy deals with ops that compare two tensors; we need to make
    # sure the sharding layout is the same for both operands, so we choose to
    # follow the arg with the max num of shards. is_same_size is kept here for
    # completeness as it shares the same strategy in theory.
    self_strategy, other_strategy = op_schema.args_schema
    assert isinstance(self_strategy, OpStrategy)
    assert isinstance(other_strategy, OpStrategy)

    select_strategy = (
        self_strategy
        if self_strategy.max_num_shards() >= other_strategy.max_num_shards()
        else other_strategy
    )
    equal_strategy = OpStrategy([])

    for arg_strategy in select_strategy.strategies:
        arg_spec = arg_strategy.output_spec
        if is_tensor_partial(arg_spec):
            # if the arg_spec has partial, reshard to replicate,
            # otherwise local shard tensor comparison would be invalid
            output_spec = DTensorSpec(
                mesh=arg_spec.mesh,
                placements=tuple(
                    Replicate() if isinstance(p, Partial) else p
                    for p in arg_spec.placements
                ),
            )
            equal_strategy.strategies.append(
                PlacementStrategy(output_specs=output_spec)
            )
        else:
            equal_strategy.strategies.append(PlacementStrategy(arg_spec))
    return equal_strategy
@register_op_strategy(
    [
        aten.new_empty.default,
        aten.new_full.default,
        aten.new_ones.default,
        aten.new_zeros.default,
        aten.new_empty_strided.default,
    ],
    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Strategy for new_* factory ops that take a reference tensor.

    For each input placement we emit up to two candidates:
      1. replicate the output (always valid),
      2. follow the input placement when input and requested output shapes
         match (slightly encouraged via a negative redistribute cost).
    """
    # Currently there are two strategies:
    # 1. let the output be replicated
    # 2. let the output follow the input if input and output have the same shape
    input_strategy = op_schema.args_schema[0]
    assert isinstance(input_strategy, OpStrategy)
    input_shape = input_strategy.shape
    # args_schema[1] is the requested output shape of the new tensor
    output_shape = op_schema.args_schema[1]
    assert isinstance(output_shape, list)

    new_factory_strategy = OpStrategy([])
    for arg_strategy in input_strategy.strategies:
        input_spec = arg_strategy.output_spec
        replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
        new_factory_strategy.strategies.append(
            PlacementStrategy(
                output_specs=replica_spec,
                input_specs=(input_spec,),
                redistribute_cost=[[0.0] * mesh.ndim],
            )
        )

        if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded():
            # NOTE: for new_empty_strided, currently the non-replicate sharding
            # is supported only when the shape is evenly shardable
            if (
                op_schema.op == aten.new_empty_strided.default
                and not is_tensor_evenly_shardable(input_shape, input_spec)
            ):
                continue

            new_factory_strategy.strategies.append(
                PlacementStrategy(
                    output_specs=input_spec,
                    input_specs=(input_spec,),
                    # encouraging new tensor placement to be the same as input
                    redistribute_cost=[[-0.1] * mesh.ndim],
                )
            )

    return new_factory_strategy
def unshard_tensor_dim(
    placements: Sequence[Placement], dim: int
) -> Tuple[Placement, ...]:
    """Return *placements* with any sharding over tensor dimension *dim* removed.

    Each placement that shards exactly dimension ``dim`` is replaced by
    ``Replicate()``; every other placement is passed through untouched.
    """
    result: List[Placement] = []
    for placement in placements:
        if isinstance(placement, Shard) and placement.dim == dim:
            result.append(Replicate())
        else:
            result.append(placement)
    return tuple(result)
@register_op_strategy(
    [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src],
    schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Sharding strategy for aten.scatter / aten.scatter_ (value and src variants).

    Only proposes the fully-replicated placement for every operand and the
    output; the single-mesh-dim candidate is expanded to the full mesh by
    ``expand_to_full_mesh_op_strategy``, with inplace variants forwarded via
    the ``inplace_op`` flag.
    """
    # fix: the original assigned `input_strategy = cast(OpStrategy,
    # op_schema.args_schema[0])` but never used it — dead local removed
    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index, src]
    # first we always have replicate all for inputs and output
    if len(op_schema.args_strategy) < 3:
        # scatter_.src/scatter.src with src be float number instead of tensor,
        # so there is no src operand to place: [output, input, index] only
        all_replicate: PlacementList = [Replicate()] * 3
    else:
        all_replicate = [Replicate()] * 4
    single_mesh_dim_strategies.append(all_replicate)

    # TODO: see if we can support input sharding pattern
    inplace_op = _is_inplace_op(op_schema.op)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op
    )
def _derive_follow_placements_from_tuple_strategy(
    tuple_strategy: TupleStrategy,
) -> Sequence[Placement]:
    """
    derive the placements to follow from the tuple strategy, mainly used by
    aten.stack, aten.cat, where each operand have the same shape, and correspondingly
    expecting the same sharding
    """

    def merge_placement(
        cur_placement: Placement, new_placement: Placement
    ) -> Placement:
        # semantic if we already have a follow placement, we
        # check each placement for the current arg placement
        # to see if we want to merge/adjust the placement to follow
        # the priority: Partial -> Shard -> Replicate
        if cur_placement == new_placement:
            return cur_placement

        if cur_placement.is_partial():
            if new_placement.is_shard():
                # follow new placement
                return new_placement
            elif new_placement.is_partial():
                # different partial types, we can't merge and have to replicate all here
                return Replicate()
            else:
                # follow partial
                return cur_placement
        elif cur_placement.is_shard():
            if new_placement.is_shard():
                # cur/new placement are different sharding (i.e. different shard dim)
                # currently fallback to replicate all args
                return Replicate()
            else:
                # for partial/replicate, follow the current shard placement
                return cur_placement
        else:
            # current replicate, just follow new placement
            return new_placement

    follow_placements: Optional[List[Placement]] = None
    # fold every strategy of every operand into one placement list, merging
    # per mesh dimension with merge_placement above
    for arg_strategy in tuple_strategy.childs:
        assert isinstance(arg_strategy, OpStrategy)
        for placement_strategy in arg_strategy.strategies:
            arg_placements = placement_strategy.output_spec.placements
            if follow_placements is None:
                # first operand/strategy seen: start from its placements
                follow_placements = list(arg_placements)
                continue
            mesh_ndim = len(follow_placements)
            assert follow_placements is not None
            for mesh_idx in range(mesh_ndim):
                # merge placements with the priority
                follow_placements[mesh_idx] = merge_placement(
                    follow_placements[mesh_idx], arg_placements[mesh_idx]
                )
    assert follow_placements is not None, "follow placements should not be None!"
    return follow_placements
@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """Sharding strategy for aten.cat: all inputs follow one merged placement.

    The placements to follow are derived from every operand's strategies; the
    concatenation dim itself is force-unsharded before being applied to all
    inputs and the output.
    """
    schema_args = op_schema.args_schema
    tuple_strategy = schema_args[0]
    assert isinstance(tuple_strategy, TupleStrategy), f"{tuple_strategy}"
    leading_strategy = tuple_strategy.childs[0]
    assert isinstance(leading_strategy, OpStrategy), f"{leading_strategy}"

    # the cat dim defaults to 0; normalize it against the common input ndim
    cat_dim = cast(int, schema_args[1]) if len(schema_args) > 1 else 0
    cat_dim = normalize_dim(cat_dim, leading_strategy.ndim)

    follow_placements = _derive_follow_placements_from_tuple_strategy(tuple_strategy)
    # for cat we unshard the cat dim if it is sharded
    follow_placements = unshard_tensor_dim(follow_placements, cat_dim)

    # one candidate: every input and the output share the follow placements
    # (distinct DTensorSpec objects per input, as tensor metas may differ)
    input_specs = tuple(
        DTensorSpec(mesh, tuple(follow_placements))
        for _ in range(len(tuple_strategy.childs))
    )
    single_strategy = PlacementStrategy(
        output_specs=DTensorSpec(mesh, tuple(follow_placements)),
        input_specs=input_specs,
    )
    return OpStrategy([single_strategy])
All index tensors need to be sharded + # in a compatible way, following the pointwise rule (including resolving Partial + # into either sharded or replicated) + + values_spec, multi_indices_spec = op_schema.args_schema + assert isinstance(values_spec, DTensorSpec) + assert isinstance(multi_indices_spec, list) + multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec) + valid_indices_spec: List[Tuple[int, DTensorSpec]] = [ + (i, a) for i, a in enumerate(multi_indices_spec) if a is not None + ] + + # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. + # Here, we piggyback on the pointwise sharding rule for indices. + indices_out = pointwise_rule( + OpSchema( + op=op_schema.op, + args_schema=tuple(v[1] for v in valid_indices_spec), + kwargs_schema={}, + ) + ) + need_reshard_on_indices = indices_out.output_spec is None + + if not need_reshard_on_indices: + # this means that our inputs are already sharded properly and we will use that as our indices_spec + assert isinstance(indices_out.output_spec, DTensorSpec) + indices_spec: DTensorSpec = indices_out.output_spec + else: + assert indices_out.redistribute_schema is not None + valid_indices_suggestion = indices_out.redistribute_schema + for i, v in enumerate(valid_indices_suggestion.args_spec): + multi_indices_spec[valid_indices_spec[i][0]] = v + # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then + # use that to compute our ideal values_spec + indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec + assert isinstance(indices_output_spec, DTensorSpec) + indices_spec = indices_output_spec + + lookup_dims = {v[0] for v in valid_indices_spec} + + need_reshard_on_values = tuple( + (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + + if not need_reshard_on_indices and not any(need_reshard_on_values): + value_placements = 
values_spec.placements + + all_dims_consecutive = all( + b[0] - a[0] == 1 + for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) + ) + if all_dims_consecutive: + # if all index vectors are consecutives, insert at the dimension of the first index + insert_dim: int = valid_indices_spec[0][0] + else: + # else, insert on the first dimension + insert_dim = 0 + + def place(vp: Placement, ip: Placement) -> Placement: + if isinstance(vp, Shard): + return Shard( + vp.dim + if vp.dim < insert_dim + # accounts for the offset in output dimensions + else vp.dim + + indices_spec.ndim + - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) + ) + if isinstance(ip, Shard): + return Shard(ip.dim + insert_dim) + # Partial or Replicated + return vp + + value_placements = tuple( + place(vp, ip) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + result = OutputSharding( + output_spec=DTensorSpec( + mesh=values_spec.mesh, + placements=value_placements, + ) + ) + return result + else: + result = OutputSharding( + output_spec=None, + redistribute_schema=OpSchema( + op=op_schema.op, + args_schema=( + DTensorSpec( + mesh=values_spec.mesh, + placements=tuple( + [ + Replicate() if need_reshard_on_values[i] else v + for i, v in enumerate(values_spec.placements) + ] + ), + tensor_meta=values_spec.tensor_meta, + ), + multi_indices_spec, + ), + kwargs_schema=op_schema.kwargs_schema, + ), + ) + return result + + +@register_prop_rule( + [ + aten.split.Tensor, + aten.split_with_sizes.default, + aten.split_with_sizes_copy.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def split_rule(op_schema: OpSchema) -> OutputSharding: + output_spec_list: List[DTensorSpec] = [] + input_spec = cast(DTensorSpec, op_schema.args_schema[0]) + ndim = input_spec.ndim + split_size_or_sections = op_schema.args_schema[1] + dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 + dim = normalize_dim(dim, ndim) + + # TODO: tensor to split cannot have 
Partial + # in its placements for now. Will need to + # support in future. + if input_spec.sums: + raise NotImplementedError( + f"splitting distributed tensor with " + f"Partial placement is not implemented!\n" + f"DTensorSpec={input_spec}" + ) + + # TODO: just like slice op, split replicates before + # splitting on a sharded dimension + need_reshard = False + if is_tensor_dim_sharded(input_spec, dim=dim): + need_reshard = True + input_spec = DTensorSpec( + mesh=input_spec.mesh, + placements=unshard_tensor_dim(input_spec.placements, dim=dim), + tensor_meta=input_spec.tensor_meta, + ) + + if need_reshard: + return OutputSharding( + None, + redistribute_schema=OpSchema( + op=op_schema.op, + args_schema=(input_spec,) + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ), + ) + + def size_split(N, i): + # Last chunk will be smaller if the tensor size N + # along the given dimension dim is not divisible by i. + assert i > 0 + return [i] * (N // i) + ([N % i] if N % i != 0 else []) + + output_size_list = ( + size_split(input_spec.shape[dim], split_size_or_sections) + if isinstance(split_size_or_sections, int) + else split_size_or_sections + ) + output_spec_list = [ + DTensorSpec( + mesh=input_spec.mesh, + placements=input_spec.placements, + ) + for _ in range(len(output_size_list)) + ] + return OutputSharding(output_spec_list) diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py new file mode 100644 index 000000000..27cf89224 --- /dev/null +++ b/src/chop/distributed/tensor/ops/utils.py @@ -0,0 +1,300 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import functools +import itertools +import operator +from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union + +import torch +from torch.distributed._tensor._collective_utils import redistribute_cost +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + PlacementStrategy, + RuntimeSchemaInfo, +) +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + Shard, +) + +from chop.distributed.tensor.api import DTensor + + +# convenient wrapper to register sharding propagation rules +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. +def register_prop_rule(op, schema_info=None): + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def wrapper(impl): + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +def register_op_strategy(op, schema_info=None): + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. 
+ arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl): + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper + + +def as_list( + x: Union[List[object], object] + # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. +) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] + # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, + # which is an object but treated as a list by the tracer. Therefore, keep + # `immutable_list` intact here as well. 
+ if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list): + return x + else: + return [x] + + +def normalize_dim(dim: int, ndim: int) -> int: + return dim if dim >= 0 else dim + ndim + + +def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]: + """Normalize a dim or a sequence of dims, so that they are all positive.""" + if isinstance(dims, int): + dims = (normalize_dim(dims, ndim),) + elif isinstance(dims, list): + dims = [normalize_dim(dim, ndim) for dim in dims] + elif isinstance(dims, tuple): + dims = tuple([normalize_dim(dim, ndim) for dim in dims]) + return dims + + +def normalize_to_torch_size(size) -> torch.Size: + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) + + +def prod(xs: Iterable[int]) -> int: + return functools.reduce(operator.mul, xs, 1) + + +def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + # TODO: maybe we should determine is_shardable based on + # whether it's evenly sharded or not + if shards_map[i] > 1 and dim_size < shards_map[i]: + return False + + return True + + +def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is evenly shardable according to the spec.""" + # number of shards in each tensor dimension + 
shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + if shards_map[i] > 1 and (dim_size % shards_map[i] != 0): + return False + + return True + + +def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool: + """Return True if tensor dim is sharded.""" + return any(p.is_shard(dim) for p in spec.placements) + + +def is_tensor_partial(spec: DTensorSpec) -> bool: + """Return True if tensor is partial on the mesh.""" + return any(p.is_partial() for p in spec.placements) + + +def infer_broadcast_dims_map( + common_shape: torch.Size, input_shape: torch.Size +) -> List[int]: + # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim + # this is aligned with the broadcast semantics + common_ndim = len(common_shape) + input_ndim = len(input_shape) + broadcast_dims_map = [-1] * common_ndim + for idx in range(-1, -1 - input_ndim, -1): + if input_shape[idx] == common_shape[idx]: + broadcast_dims_map[common_ndim + idx] = input_ndim + idx + return broadcast_dims_map + + +def map_placements_after_broadcast( + placements: Tuple[Placement, ...], + shape: torch.Size, + broadcast_dims_map: List[int], +) -> Tuple[Placement, ...]: + """Map each placement based on the output shape after broadcast.""" + new_placements: List[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + assert isinstance(placement, Shard) + shard_dim = normalize_dim(placement.dim, len(shape)) + new_shard_dim = broadcast_dims_map[shard_dim] + if new_shard_dim != -1: + # there's a map from the common shape shard dim to + # the input shape shard dim before broadcasting, + # use that instead + new_placements.append(Shard(new_shard_dim)) + else: + # there's no map between common shape shard dim and + # the input 
shape shard dim before broadcasting, + # in this case it means implicit broadcasting happen + # in this dim, so we can just mark it as replicate + # and implict broadcast will broadcast automatically + # to the sharded shape + new_placements.append(Replicate()) + + return tuple(new_placements) + + +def generate_redistribute_costs( + src_strategy: OpStrategy, dst_spec: DTensorSpec +) -> List[float]: + redistribute_costs: List[float] = [] + for strat in src_strategy.strategies: + redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec)) + + return redistribute_costs + + +def expand_to_full_mesh_op_strategy( + mesh: DeviceMesh, + op_schema: OpSchema, + single_mesh_dim_strategies: List[PlacementList], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. + all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list: List[Optional[DTensorSpec]] = [] + for specs in zip(*strategy_comb): + if specs[0] is not None: + spec_list.append(DTensorSpec(mesh, specs)) + else: + spec_list.append(None) + + input_specs: List[DTensorSpec] = [ + s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) + ] + + input_args_strategy = op_schema.args_strategy + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are 
shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + strategy = PlacementStrategy( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py new file mode 100644 index 000000000..dc8103b08 --- /dev/null +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -0,0 +1,669 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from dataclasses import dataclass +from typing import ( + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, +) +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed.device_mesh import DeviceMesh + +from chop.distributed.tensor.ops.utils import register_op_strategy +from chop.distributed.tensor.ops.utils import ( + generate_redistribute_costs, + normalize_dim, + normalize_dims, + prod, +) + +aten = torch.ops.aten + +Shape = Tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = Tuple[DimSpec, ...] 
+ + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton.""" + + pass + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + + Note that input_dim itself could be a Flattened set of input dims. 
+ """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] + group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implement broadcast on multiple dimensions.""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. 
check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return cast(Shape, sizes[0]) # type: ignore[redundant-cast] + else: + raise RuntimeError("Size must be int... or tuple") + + +def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + # only flattening dims from start_dim to end_dim (inclusive) + # other dims are passed through + if end_dim < 0: + end_dim += ndim + results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] + results.append( + Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) + ) + results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) + return tuple(results) + + +def dim_movedim( + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len(destination), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == 
-1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert ( + len(sizes) >= ndim + ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + + Infer the size of this dimension given the total_size. + """ + infers = [i for i, s in enumerate(sizes) if s == -1] + size = prod(sizes) + assert len(infers) <= 1, "can only infer one size" + if infers: + size = -size + missing_size = total_size // size + assert ( + total_size % size == 0 + ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + return tuple(s if s != -1 else missing_size for s in sizes) + assert size == total_size, f"sizes do not match {total_size} vs {size}" + return sizes + + +def view_groups(from_size: Shape, to_size: Shape) -> DimMap: + """ + Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. + + A view or reshape operation can be decomposed into a set of 3 types of smaller operations: + 1) Forward a dimension from input to output + 2) Flatten a set of dimensions into a single dimension + 3) Split one dimension into multiple dimensions + + view_groups identifies these operations and returns, for each output dimension, what + is operation was performed in the input dimension. 
For example: + + view_groups([2, 3, 4], [2, 12]) -> ( + InputDim(0), + Flatten((InputDim(1), InputDim(2))) + ) + + - ouptut dimension 0 maps to input dimension 0 + - output dimension 1 maps to a flattened input dimensions 1 and 2 + + + view_groups([2, 3], [3, 2]) -> ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ) + + - in the above, input is flattened into a single dimension and then split + into two separate dimensions with different sizes from the input. + """ + from_nelem = prod(from_size) + to_size = infer_size(from_nelem, normalize_sizes(to_size)) + + assert from_nelem == prod(to_size), "Total view shape does not add up" + + from_idx = 0 + to_idx = 0 + from_len = len(from_size) + to_len = len(to_size) + + result_pp = [] + + while from_idx < from_len or to_idx < to_len: + from_group_dim, to_group_shape = [], [] + + if from_idx >= from_len: + f = 1 + else: + f = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + + if to_idx >= to_len: + t = 1 + else: + t = to_size[to_idx] + to_group_shape.append(t) + to_idx += 1 + + # if any of the groups is singleton, great, we need to backtrack though + if f == 1 and t != 1: + # produces ([1], []) + to_idx -= 1 + to_group_shape = [] + elif f != 1 and t == 1: + # produces ([], [1]) + from_idx -= 1 + from_group_dim = [] + else: + # produces ([1], [1]), ([2], [2]), ([2,3], [6]) + while f != t: + if f < t: + nf = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + f *= nf + else: + nt = to_size[to_idx] + to_group_shape.append(nt) + to_idx += 1 + t *= nt + + if len(to_group_shape) > 0: + flattened = Flatten.new( + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) + ) + result_pp += [ + Split.new(flattened, tuple(to_group_shape), i) + for i in range(len(to_group_shape)) + ] + + return tuple(result_pp) + + +def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: + if len(dims) < ndim: + dims = (1,) * 
(ndim - len(dims)) + dims + return dim_repeat(ndim, dims) + + +def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + assert dim1 < ndim + assert dim2 < ndim + dimmap = [InputDim(i) for i in range(ndim)] + swapdim = dimmap[dim1] + dimmap[dim1] = dimmap[dim2] + dimmap[dim2] = swapdim + return tuple(dimmap) + + +def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: + # FIXME: this is wrong when dim=None and one of the dimensions + # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not actually a singleton. + return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_view_as_real(shape: Shape) -> DimMap: + ndim = len(shape) + results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] + # each complex number is split into two real numbers, + # resulting in one more dimension of size 2 + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) + return tuple(results) + + +def dim_reduction( + ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool +) -> DimMap: + """ + General fallback for reduction ops where Partial() does not apply. + + This will cause incoming tensor to be replicated on the reducing dimensions. 
+ """ + if dim_or_dims is None: + dim_or_dims = tuple(range(ndim)) + if isinstance(dim_or_dims, int): + dim_or_dims = (dim_or_dims,) + dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) + return tuple( + InputDim(i) if i not in dim_or_dims else Singleton() + for i in range(ndim) + if i not in dim_or_dims or keepdim + ) + + +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ), + torch.permute: lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) + ), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), +} + + +def propagate_shape_and_sharding( + input_src_placements: Sequence[Placement], + local_in_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, +) -> Tuple[Sequence[Placement], Sequence[Placement]]: + """ + Determine input target sharding and output sharding based on + given global tensor shape 
and input source sharding. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. + - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(input_src_placements) == len(mesh_sizes) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} + + # in case an input dimension disappears (e.g. collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + seen_input_dims: Set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(local_in_shape)): + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim + + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + if isinstance(cmd, InputDim): + return cmd + elif isinstance(cmd, Flatten): + for dim in cmd.input_dims[1:]: + if isinstance(dim, InputDim): + shardable_dims[dim.input_dim] = [False] * mesh_ndim + dim0 = cmd.input_dims[0] + return dim0 if isinstance(dim0, InputDim) else None + elif isinstance(cmd, Split): + in_dim = get_in_dim_to_shard(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisible + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. In that case, the dimension + # needs to be divisible by the product of mesh sizes. 
+ # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] + + # 2. here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, input_src_placements): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert ( + out_size % submesh_size == 0 + ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + + # we will only shard our first component of the split + return in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Repeat): + in_dim = get_in_dim_to_shard(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None + else: + return None + + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} + for dim, cmd in enumerate(rule): + in_dim = get_in_dim_to_shard(cmd) + if in_dim is not None: + shard_dim_map[in_dim.input_dim] = dim + + input_tgt_placements = [ + ( + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + ) + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] + + return input_tgt_placements, output_placements + + +def register_op_strategy_map( + aten_op_overload: torch._ops.OpOverload, + local_op_name: Callable[..., torch.Tensor], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> None: + dim_map: Callable[..., DimMap] = dim_maps[local_op_name] + + @register_op_strategy(aten_op_overload, schema_info=schema_info) + def reshape_strategy(mesh: DeviceMesh, 
op_schema: OpSchema) -> StrategyType: + rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + global_in_shape = input_strategy.shape + assert global_in_shape is not None, "Shape required." + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh.shape, + ) + + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs = [ + generate_redistribute_costs(input_strategy, input_tgt_spec) + ] + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map( + aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.permute.default, 
torch.permute, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) +register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py index d00661128..d51872c71 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -1,7 +1,6 @@ from chop.tools import get_logger from .alpa_intra_operator import alpa_intra_op_sharding_pass -from .mesh_model import MeshModel logger = get_logger(__name__) logger.setLevel("DEBUG") From bf626bf1ac4aa210019d5c88915708cf8b75e8ac Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 11:31:22 +0000 Subject: [PATCH 49/93] patch for redistribute --- src/chop/distributed/tensor/_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 03e9139c0..8dc16f7ff 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -16,7 +16,6 @@ OpSchema, OutputSpecType, ) -from torch.distributed._tensor._redistribute import redistribute_local_tensor from torch.distributed._tensor._tp_conv import ( convolution_backward_handler, convolution_handler, @@ -36,6 +35,7 @@ import chop.distributed.tensor.api as dtensor from chop.distributed.tensor._sharding_prop import ShardingPropagator +from chop.distributed.tensor._redistribute import redistribute_local_tensor aten = torch.ops.aten logger = logging.getLogger(__name__) From 3fa67440d92b387901ce66bb94fc5906218eee3d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 11:38:52 +0000 
Subject: [PATCH 50/93] remove logging --- src/chop/distributed/tensor/_dispatch.py | 4 ---- src/chop/distributed/tensor/_sharding_prop.py | 2 -- src/chop/distributed/tensor/ops/math_ops.py | 11 ----------- 3 files changed, 17 deletions(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 8dc16f7ff..b9316729e 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -118,8 +118,6 @@ def dispatch( """ # operators that does not need to go through sharding propagation - rlog(f"Dispatching op call: {op_call}") - if op_call in self._custom_op_handlers: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] @@ -183,7 +181,6 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: if output_sharding.needs_redistribute: # compute locally with redistribute first if needed assert output_sharding.redistribute_schema is not None - rlog(f"Op: {op_call} needs redistribute") self.redistribute_local_args( op_info, output_sharding.redistribute_schema ) @@ -217,7 +214,6 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: with rng_context: local_results = op_call(*local_tensor_args, **op_info.local_kwargs) else: - # rlog(f"Calling {op_call} with args: {local_tensor_args} and kwargs: {op_info.local_kwargs}") local_results = op_call(*local_tensor_args, **op_info.local_kwargs) # communicate the result to all ranks for some operators that return scalar value diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 6d7d6ea11..81224a952 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -254,8 +254,6 @@ def spec_to_strategy(spec: object) -> object: assert isinstance(output_strategy.output_specs, DTensorSpec) for idx, input_spec in enumerate(op_schema.args_spec): - if "layer_norm" in str(op_schema.op): - rlog(f" arg {idx}, input_spec: 
{input_spec}") desired_spec = ( output_strategy.output_spec if output_strategy.input_specs is None diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index d18396873..c6ebe49be 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -787,18 +787,10 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: input_ndim = input_strategy.ndim axis = input_ndim - len(normalized_size) - rlog(f"=========== LAYER NORM STRATEGY ===========") - - rlog(f" normalized_size: {normalized_size}") - rlog(f" input_ndim: {input_ndim}") - rlog(f" axis: {axis}") - rlog(f" input_strategy: {input_strategy}") - # we use OpStrategy because the output (out, mean, rstd) # should have the same placements output_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): - rlog(f"strategy {idx}: {input_placement_strategy}") op_args_target_specs = [] redistribute_costs = [] input_src_spec = input_placement_strategy.output_spec @@ -860,9 +852,6 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: ) ) - rlog(f"LAYER_NORM_STRATEGY output: {output_strategy}") - rlog(f"===========================================") - return output_strategy From 519d09cdf0d9c20a40d4cd6e3a4f4512029b261d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 12:16:31 +0000 Subject: [PATCH 51/93] fix circular import and remove "setting verbosity to debug" message at every import --- src/chop/nn/quantized/__init__.py | 1 - src/chop/nn/quantized/modules/__init__.py | 3 +- src/chop/nn/quantized/modules/attention.py | 89 ------------------- .../add_metadata/common_metadata_layers.py | 5 +- .../transforms/verilog/emit_vivado_project.py | 1 - src/mase_components/linter.py | 1 - src/mase_components/synth_runner.py | 1 - 7 files changed, 3 insertions(+), 98 deletions(-) diff --git a/src/chop/nn/quantized/__init__.py 
b/src/chop/nn/quantized/__init__.py index 1d02146e4..f9f1389e7 100644 --- a/src/chop/nn/quantized/__init__.py +++ b/src/chop/nn/quantized/__init__.py @@ -2,7 +2,6 @@ quantized_module_map, BertSelfAttentionInteger, BertSelfAttentionHeadInteger, - LlamaSdpaAttentionInteger, LinearInteger, LayerNormInteger, GELUInteger, diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index 64a971aa8..defc0d76f 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -1,5 +1,5 @@ from .attention_head import BertSelfAttentionHeadInteger -from .attention import BertSelfAttentionInteger, LlamaSdpaAttentionInteger +from .attention import BertSelfAttentionInteger # from .add import AddInteger from .conv1d import ( @@ -266,5 +266,4 @@ "batch_norm1d_linear": BatchNorm1dInteger, "bert_self_attention_head_integer": BertSelfAttentionHeadInteger, "bert_self_attention_integer": BertSelfAttentionInteger, - "llama_sdpa_attention_integer": LlamaSdpaAttentionInteger, } diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py index b769647a5..45819db75 100644 --- a/src/chop/nn/quantized/modules/attention.py +++ b/src/chop/nn/quantized/modules/attention.py @@ -5,8 +5,6 @@ from torch.nn import functional as F from transformers.models.bert.modeling_bert import BertSelfAttention -from chop.models.patched.llama.modeling_llama import LlamaSdpaAttention -from chop.models.patched.llama.configuration_llama import LlamaConfig from chop.nn.quantized.modules.linear import ( LinearInteger, @@ -58,47 +56,6 @@ def forward( return out -class _LlamaSdpaAttentionBase(LlamaSdpaAttention): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: LlamaConfig, - layer_idx: Optional[int] = None, - q_config: dict = None, - out_q_config: dict = None, - output_tensor_only=False, - ): - super().__init__(config, layer_idx) - self.bypass = 
False - self.q_config = q_config - self.out_q_config = out_q_config - self.output_tensor_only = output_tensor_only - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[int] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - out = super().forward( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - ) - if self.output_tensor_only: - return out[0] - return out - - class BertSelfAttentionInteger(_BertSelfAttentionBase): def __init__( self, @@ -156,49 +113,3 @@ def __init__( out_config=out_q_config, floor=floor, ) - - -class LlamaSdpaAttentionInteger(_LlamaSdpaAttentionBase): - def __init__( - self, - config: LlamaConfig, - layer_idx: Optional[int] = None, - q_config: dict = None, - out_q_config: dict = None, - output_tensor_only=False, - ): - super().__init__( - config, - layer_idx, - q_config, - out_q_config, - output_tensor_only=output_tensor_only, - ) - self.q_proj = LinearInteger( - self.hidden_size, - self.num_heads * self.head_dim, - bias=config.attention_bias, - config=q_config, - out_config=out_q_config, - ) - self.k_proj = LinearInteger( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - config=q_config, - out_config=out_q_config, - ) - self.v_proj = LinearInteger( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - config=q_config, - out_config=out_q_config, - ) - self.o_proj = LinearInteger( - self.hidden_size, - self.hidden_size, - bias=config.attention_bias, - config=q_config, - out_config=out_q_config, - ) diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py 
b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 85134b8fa..0db2aab6c 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -5,8 +5,7 @@ import inspect from chop.tools.utils import to_numpy_if_tensor as to_numpy from chop.passes.graph.utils import vf, get_node_by_name -from chop.passes.graph.patching import MASE_LEAF_FUNCTIONS, MASE_LEAF_LAYERS -import traceback +from chop.nn.quantized.modules import quantized_module_map from functools import reduce @@ -340,7 +339,7 @@ def get_type_and_precision(meta): # * Fetch type and precision from q_config for quantized modules - if isinstance(meta.module, MASE_LEAF_LAYERS): + if isinstance(meta.module, tuple(quantized_module_map.values())): cf = ( meta.module.q_config if hasattr(meta.module, "q_config") diff --git a/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py index 09e6ef3f7..3219a42ac 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py +++ b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py @@ -8,7 +8,6 @@ from mase_components.deps import MASE_HW_DEPS logger = get_logger(f"emit_vivado_project") -set_logging_verbosity("debug") COMPONENTS_PATH = Path(mase_components.__file__).parents[0] diff --git a/src/mase_components/linter.py b/src/mase_components/linter.py index 5de03967f..b9b516265 100644 --- a/src/mase_components/linter.py +++ b/src/mase_components/linter.py @@ -7,7 +7,6 @@ from mase_components.deps import MASE_HW_DEPS logger = get_logger(f"linter") -set_logging_verbosity("debug") COMPONENTS_PATH = Path(__file__).parents[0] diff --git a/src/mase_components/synth_runner.py b/src/mase_components/synth_runner.py index bb852f78f..3381d016c 100644 --- a/src/mase_components/synth_runner.py +++ b/src/mase_components/synth_runner.py @@ -7,7 +7,6 @@ from 
mase_components.deps import MASE_HW_DEPS logger = get_logger(f"linter") -set_logging_verbosity("debug") COMPONENTS_PATH = Path(__file__).parents[0] From e8849a850a1f7ba145cfb7f5a43ea1b483680dfc Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 12:22:06 +0000 Subject: [PATCH 52/93] remove breakpoints --- .../graph/analysis/add_metadata/common_metadata_layers.py | 3 --- .../graph/analysis/autosharding/strategies/matrix_ops.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 0db2aab6c..39a1336ee 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -566,10 +566,7 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True): mase_op = meta.parameters["common"]["mase_op"] meta = analyse_result(meta, result, add_value) - # try: meta = match_args_and_kwargs(meta, args, kwargs, method_data[mase_op], add_value) - # except: - # breakpoint() return meta diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index 76f30db47..d552f8171 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -202,7 +202,6 @@ def scaled_dot_product_flash_attention_strategy( meta: MaseMetadata, mesh: tuple, ) -> OpStrategy: - breakpoint() # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation # as it involves: matmul, pointwise, reduction ops together. 
From 550da9cc828a017ead7f3ea8270a07b530de256d Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 12:55:46 +0000 Subject: [PATCH 53/93] formatting --- src/chop/ir/graph/mase_metadata.py | 2 + src/chop/models/__init__.py | 1 + src/chop/nn/__init__.py | 4 +- .../passes/graph/analysis/report/__init__.py | 2 +- .../analysis/report/report_parallelization.py | 41 +++++++++++-------- src/chop/passes/module/transforms/__init__.py | 2 +- 6 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/chop/ir/graph/mase_metadata.py b/src/chop/ir/graph/mase_metadata.py index 770ae8269..cc58d64ad 100644 --- a/src/chop/ir/graph/mase_metadata.py +++ b/src/chop/ir/graph/mase_metadata.py @@ -4,12 +4,14 @@ logger = logging.getLogger(__name__) + def get_module_by_name(model, request_name): for name, layer in model.named_modules(): if name == request_name: return layer return None + class MaseMetadata: """ The metadata of a Mase node in a Mase graph describes the constraints of the diff --git a/src/chop/models/__init__.py b/src/chop/models/__init__.py index 0174373ce..0c28c23b7 100644 --- a/src/chop/models/__init__.py +++ b/src/chop/models/__init__.py @@ -12,6 +12,7 @@ get_patched_model_config_cls, get_patched_model_tokenizer_cls, ) + # from .manual import ( # is_manual_model, # get_manual_model, diff --git a/src/chop/nn/__init__.py b/src/chop/nn/__init__.py index 345a92d90..7ab651324 100644 --- a/src/chop/nn/__init__.py +++ b/src/chop/nn/__init__.py @@ -1,5 +1,3 @@ from .quantized import quantized_module_map -MASE_LEAF_LAYERS = tuple( - quantized_module_map.values() -) \ No newline at end of file +MASE_LEAF_LAYERS = tuple(quantized_module_map.values()) diff --git a/src/chop/passes/graph/analysis/report/__init__.py b/src/chop/passes/graph/analysis/report/__init__.py index 057db9740..972a68e8c 100644 --- a/src/chop/passes/graph/analysis/report/__init__.py +++ b/src/chop/passes/graph/analysis/report/__init__.py @@ -5,4 +5,4 @@ report_node_shape_analysis_pass, 
report_node_type_analysis_pass, ) -from .report_parallelization import report_parallelization_analysis_pass \ No newline at end of file +from .report_parallelization import report_parallelization_analysis_pass diff --git a/src/chop/passes/graph/analysis/report/report_parallelization.py b/src/chop/passes/graph/analysis/report/report_parallelization.py index 37be5c19b..49ac1bf82 100644 --- a/src/chop/passes/graph/analysis/report/report_parallelization.py +++ b/src/chop/passes/graph/analysis/report/report_parallelization.py @@ -1,25 +1,34 @@ - from tabulate import tabulate + def report_parallelization_analysis_pass(mg, pass_args={}): fname = pass_args.get("file_name", "report_parallelization.txt") - headers = ["Node", "Node op", "Mase op", "Args", "Kwargs", "Input Sharding", "Output Sharding"] + headers = [ + "Node", + "Node op", + "Mase op", + "Args", + "Kwargs", + "Input Sharding", + "Output Sharding", + ] info = [] for node in mg.fx_graph.nodes: - sharding_config = node.meta['mase']['software']['autosharding'] - info.append([ - node.name, - node.op, - node.meta['mase']['common']['mase_op'], - node.args, - node.kwargs, - sharding_config["input"], - sharding_config["output"] - - ]) - + sharding_config = node.meta["mase"]["software"]["autosharding"] + info.append( + [ + node.name, + node.op, + node.meta["mase"]["common"]["mase_op"], + node.args, + node.kwargs, + sharding_config["input"], + sharding_config["output"], + ] + ) + with open(fname, "w") as f: f.write(f"{tabulate(info, headers)}\n") - - return mg, {} \ No newline at end of file + + return mg, {} diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 812b95d5e..efbb0ed14 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,3 +1,3 @@ from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass -from .autosharding import resharding_transform_pass \ 
No newline at end of file +from .autosharding import resharding_transform_pass From 64acb520bb8ad6ba4bbd3b62114551f0b89a9ae5 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 13:30:04 +0000 Subject: [PATCH 54/93] fix circular import --- src/chop/models/__init__.py | 18 +++++++-------- .../add_metadata/common_metadata_layers.py | 2 -- src/chop/tools/get_input.py | 23 ++++++++++++++++++- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/chop/models/__init__.py b/src/chop/models/__init__.py index 0c28c23b7..df22bbd32 100644 --- a/src/chop/models/__init__.py +++ b/src/chop/models/__init__.py @@ -13,15 +13,15 @@ get_patched_model_tokenizer_cls, ) -# from .manual import ( -# is_manual_model, -# get_manual_model, -# get_manual_model_cls, -# get_manual_model_config_cls, -# get_manual_model_tokenizer_cls, -# get_manual_model_info, -# get_manual_model_tokenizer, -# ) +from .manual import ( + is_manual_model, + get_manual_model, + get_manual_model_cls, + get_manual_model_config_cls, + get_manual_model_tokenizer_cls, + get_manual_model_info, + get_manual_model_tokenizer, +) from .huggingface_nlp_models import ( is_hf_nlp_model, diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 39a1336ee..2784fceb6 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -3,8 +3,6 @@ import torch import inspect -from chop.tools.utils import to_numpy_if_tensor as to_numpy -from chop.passes.graph.utils import vf, get_node_by_name from chop.nn.quantized.modules import quantized_module_map from functools import reduce diff --git a/src/chop/tools/get_input.py b/src/chop/tools/get_input.py index f5875cdda..18966b932 100644 --- a/src/chop/tools/get_input.py +++ b/src/chop/tools/get_input.py @@ -1,6 +1,27 @@ import inspect from typing import Literal -from 
..models.utils import ModelSource +from enum import Enum + + +class ModelSource(Enum): + """ + The source of the model, must be one of the following: + - HF: HuggingFace + - MANUAL: manually implemented + - PATCHED: patched HuggingFace + - TOY: toy model for testing and debugging + - PHYSICAL: model that perform classification using physical data point vectors + - NERF: model that estimates neural radiance field (NeRF) of a 3D scene + """ + + HF_TRANSFORMERS = "hf_transformers" + MANUAL = "manual" + PATCHED = "patched" + TOY = "toy" + TORCHVISION = "torchvision" + VISION_OTHERS = "vision_others" + PHYSICAL = "physical" + NERF = "nerf" def _get_default_args(func): From 5c8f36a99aa74542a339eac80bc63fb6e87beac6 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 23 Jul 2024 13:47:38 +0000 Subject: [PATCH 55/93] remove unfinished emit verilog tests for llama/mistral --- src/chop/tools/onnx_operators.py | 160 +++++++++++++ .../verilog/test_emit_verilog_llama.py | 211 ------------------ .../verilog/test_emit_verilog_mistral.py | 153 ------------- 3 files changed, 160 insertions(+), 364 deletions(-) create mode 100644 src/chop/tools/onnx_operators.py delete mode 100644 test/passes/graph/transforms/verilog/test_emit_verilog_llama.py delete mode 100644 test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py diff --git a/src/chop/tools/onnx_operators.py b/src/chop/tools/onnx_operators.py new file mode 100644 index 000000000..9250d46ce --- /dev/null +++ b/src/chop/tools/onnx_operators.py @@ -0,0 +1,160 @@ +import torch + +""" + This module contains a collection of ONNX operators implemented + using Pytorch primitives. 
+""" + + +def onnx_gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=False, transB=False): + # Transpose matrices A and B if needed + A = A.transpose() if transA else A + B = B.transpose() if transB else B + + # Perform matrix multiplication + result = alpha * torch.matmul(A, B) + + # Add optional matrix C + if C is not None: + result += beta * C + + return result + + +def onnx_slice(data, starts, ends, axes=None, steps=None): + assert len(starts) == len(ends), "Starts and ends must have the same length" + starts = starts.to(torch.int64) + ends = ends.to(torch.int64) + + rank = len(data.shape) + + if axes is None: + axes = list(range(rank)) + else: + axes = axes.to(torch.int64) + + if steps is None: + steps = [1] * rank + else: + steps = steps.to(torch.int64) + + # Default slices define entire range in each dimension + slices = [slice(0, data.shape[i], 1) for i in range(rank)] + for idx, dim in enumerate(axes): + slices[dim] = slice(starts[idx], ends[idx], steps[idx]) + + return data[slices] + + +def onnx_squeeze(input, dim): + if isinstance(dim, torch.nn.parameter.Parameter): + dim = dim.item() + return torch.squeeze(input, dim) + + +def onnx_unsqueeze(input, dim): + for i in dim: + input = torch.unsqueeze(input, i) + return input + + +def onnx_gather(input, dim, index): + """Gather operator with support for broadcasting. 
+ See https://github.com/pytorch/pytorch/issues/9407 + + Args: + input (_type_): _description_ + dim (_type_): _description_ + index (_type_): _description_ + + Returns: + _type_: _description_ + """ + if not isinstance(input, torch.Tensor): + input = torch.tensor(list(input)) + + # expand_shape = list(index.shape[:-1]) + list(input.shape) + # tmp_inp = input.expand(expand_shape) + + n_dims = len(input.shape) + idx_list = [ + torch.arange(input.shape[i])[(None,) * i + (...,) + (None,) * (n_dims - i - 1)] + for i in range(n_dims) + ] + idx_list[dim] = index.squeeze()[ + (None,) * dim + (...,) + (None,) * (n_dims - dim - 1) + ] + return input[idx_list] + + +def onnx_shape(input): + return torch.Tensor([i for i in input.shape]) + + +def onnx_reshape(input, shape): + if isinstance(shape, torch.Tensor): + shape = tuple(shape.to(torch.int64).tolist()) + return torch.reshape(input, shape) + + +def onnx_identity(input): + return input + + +def onnx_expand(input, size): + if isinstance(size, torch.Size): + size = tuple(size) + elif isinstance(size, torch.Tensor): + size = tuple(size.to(torch.int64).tolist()) + return input.expand(size=size) + + +def onnx_where(condition, input, other): + cond = condition + pre_input_shape = input.shape + pre_other_shape = other.shape + + if len(input.shape) == 0: + input = input.unsqueeze(dim=0) + + # Two-way broadcasting of input tensors + input, other = torch.broadcast_tensors(input, other) + + assert ( + condition.shape == input.shape == other.shape + ), "Condition tensor has incorrect shape." 
+ + # Convert condition to a boolean tensor + condition = torch.where( + condition == 0, + torch.full(input.shape, False, dtype=torch.bool), + torch.full(input.shape, True, dtype=torch.bool), + ).to(torch.bool) + return torch.where(condition, input, other) + + +def onnx_full(size, fill_value): + if isinstance(size, torch.Tensor): + size = tuple(size.to(torch.int64).tolist()) + if isinstance(fill_value, torch.Tensor): + fill_value = fill_value.item() + return torch.full(size, fill_value) + + +def onnx_min(*args, **kwargs): + input = torch.broadcast_tensors(*kwargs["input"]) + if len(input) <= 1: + raise ValueError(f"Expected 2 or more inputs, but received {len(input)}.") + + # minimum only accepts two inputs, so maintain a running minimum + result = input[0] + for i in range(1, len(input)): + result = torch.minimum(result, input[i]) + return result + + +def onnx_permute(input, dims): + input = input.squeeze() + if dims is None: + dims = [i for i in reversed(range(len(input.shape)))] + return torch.permute(input, dims) diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py deleted file mode 100644 index af958c885..000000000 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py +++ /dev/null @@ -1,211 +0,0 @@ -import os, operator -import pytest - -import torch -import torch.nn as nn - -from chop import AutoPipelineForEmitVerilog -import chop.passes as passes -import chop.actions as actions -from chop.ir import MaseGraph -from chop.passes.graph.utils import deepsetattr - -from chop.models.patched.llama import LlamaConfig, LlamaModel -from chop.models.patched.llama.modeling_llama import LlamaSdpaAttention, LlamaRMSNorm - -from chop.nn.quantized import ( - LlamaSdpaAttentionInteger, - LinearInteger, - RMSNormInteger, - SiLUInteger, -) - -from chop.tools import get_logger, set_excepthook - -from mase_components import get_module_dependencies -from 
mase_components.helper.generate_memory import generate_sv_lut - -logger = get_logger(__name__) -logger.setLevel("DEBUG") -set_excepthook() - -# Temporary: fix data coherency checks -os.environ["COCOTB_RESOLVE_X"] = "ZEROS" - -SMOKE_TEST_SCALE_FACTOR = 8 - -# * Define custom ops (leaf submodules during tracing) -# * This is useful so we can write a single optimised verilog file for self attention, -# * instead of relying on emit_verilog to instantiate each submodule -LLAMA_CUSTOM_OPS = { - "modules": { - LlamaSdpaAttention: { - "args": { - "hidden_states": "data_in", - "attention_mask": None, - "position_ids": None, - "past_key_value": None, - "output_attentions": None, - "use_cache": None, - "cache_position": None, - }, - "toolchain": "INTERNAL_RTL", - "module": "fixed_self_attention_single_precision_wrapper", - "dependence_files": get_module_dependencies( - "attention/fixed_self_attention_single_precision_wrapper" - ), - }, - RMSNormInteger: { - "args": { - "hidden_states": "data_in", - }, - "toolchain": "INTERNAL_RTL", - "module": "norm", - "dependence_files": get_module_dependencies("norm/norm"), - }, - SiLUInteger: { - "args": { - "input": "data_in", - }, - "toolchain": "INTERNAL_RTL", - "module": "silu", - "dependence_files": get_module_dependencies("silu/silu"), - }, - }, - "functions": {}, -} - - -def llama_module_level_quantize(model, model_config, q_config): - for name, module in model.named_modules(): - if isinstance(module, LlamaSdpaAttention): - new_module = LlamaSdpaAttentionInteger( - config=model_config, - q_config=q_config, - output_tensor_only=True, - ) - elif isinstance(module, nn.Linear): - new_module = LinearInteger( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - config=q_config, - ) - elif isinstance(module, LlamaRMSNorm): - new_module = RMSNormInteger( - normalized_shape=None, - eps=module.variance_epsilon, - config=q_config, - ) - elif isinstance(module, nn.SiLU): - new_module = 
SiLUInteger( - inplace=module.inplace, - config=q_config, - ) - else: - continue - logger.info(f"Replacing module: {name}") - deepsetattr(model, name, new_module) - return model - - -def emit_verilog_llama( - config, - q_config, - config_sequence_length, - config_batch_size, - wait_count=15, - wait_unit="ms", - max_parallelism=4, -): - # * Get model and quantize self attention, linear and layer norm layers - model = LlamaModel(config) - model = llama_module_level_quantize(model, config, q_config) - logger.info(f"Quantized Llama model: {model}") - - # * Trace the model - mg = MaseGraph(model, custom_ops=LLAMA_CUSTOM_OPS) - - pipeline = AutoPipelineForEmitVerilog() - mg = pipeline( - mg, - pass_args={ - "report_graph_analysis_pass": {"file_name": "llama.txt"}, - "add_common_metadata_analysis_pass": { - "dummy_in": { - "input_ids": torch.randn( - (config_batch_size, config_sequence_length, config.hidden_size) - ) - }, - "add_value": False, - }, - "patch_metadata_transform_pass": { - "q_config": q_config, - }, - "add_hardware_metadata_analysis_pass": { - "max_parallelism": [max_parallelism] * 4, - }, - "report_node_meta_param_analysis_pass": { - "which": ["common", "hardware"], - "save_path": "llama_graph_meta_params.txt", - }, - "emit_cocotb_transform_pass": { - "wait_time": wait_count, - "wait_unit": wait_unit, - }, - }, - ) - - actions.simulate( - skip_build=False, skip_test=False, gui=True, waves=False, simulator="questa" - ) - - -def get_default_qconfig(): - return { - "data_in_width": 8, - "data_in_frac_width": 3, - "weight_width": 8, - "weight_frac_width": 3, - "bias_width": 8, - "bias_frac_width": 3, - "data_out_width": 8, - "data_out_frac_width": 3, - } - - -@pytest.mark.skip(reason="Fixing needed") -def test_emit_verilog_llama_smoke(): - config = LlamaConfig() - config.num_hidden_layers = 1 - config.hidden_size //= SMOKE_TEST_SCALE_FACTOR - config.intermediate_size //= SMOKE_TEST_SCALE_FACTOR - config.max_position_embeddings = 4096 - config.rms_norm_eps = 
1e-5 - - config_batch_size = 5 - config_sequence_length = 4 - - q_config = get_default_qconfig() - - emit_verilog_llama( - config, - q_config, - config_sequence_length, - config_batch_size, - wait_count=10, - max_parallelism=2, - ) - - -if __name__ == "__main__": - generate_sv_lut( - "silu", - 8, - 3, - data_width=8, - f_width=3, - path="./src/mase_components/activation_layers/rtl", - path_with_dtype=False, - ) - test_emit_verilog_llama_smoke() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py deleted file mode 100644 index 8bafc9248..000000000 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py +++ /dev/null @@ -1,153 +0,0 @@ -import sys, os - -import torch -import torch.nn as nn -import pytest - -from transformers.activations import GELUActivation - -import chop.passes as passes -import chop.actions as actions -from chop.ir import MaseGraph -from chop.models.patched.mistral import MistralConfig, MistralModel -from chop.models.patched.mistral.modeling_mistral import MistralAttention -from chop.passes.graph.utils import deepsetattr - -# from chop.nn.quantized import MistralAttentionInteger -from chop.tools import get_logger, set_excepthook - -from mase_components import get_module_dependencies -from mase_components.helper.generate_memory import generate_sv_lut - -import operator -from functools import partial - -logger = get_logger(__name__) -logger.setLevel("DEBUG") -set_excepthook() - -# * Define custom ops (leaf submodules during tracing) -# * This is useful so we can write a single optimised verilog file for self attention, -# * instead of relying on emit_verilog to instantiate each submodule -MISTRAL_CUSTOM_OPS = { - "modules": {}, - "functions": {}, -} - - -def mistral_module_level_quantize(model, model_config, q_config): - return model - - -def mistral_update_metadata(mg, q_config): - """ - The following processing is a temporary hot fix to 
get emit verilog working on the mistral model. We - update the type and precision for the add, getitem and split (fork) nodes which are currently - inserted in the patched model code. In the (near) future, inserting forking nodes and setting their - precision correctly will be handled automatedly as a preprocessing step for the emit verilog pass, - so this function will be unnecessary. - """ - return mg, {} - - -def emit_verilog_mistral( - config, - q_config, - config_sequence_length, - wait_count=15, - wait_unit="ms", - max_parallelism=4, -): - # * Get model and quantize self attention, linear and layer norm layers - model = MistralModel(config) - model = mistral_module_level_quantize(model, config, q_config) - logger.info(f"Quantized mistral model: {model}") - - # * Trace the model - mg = MaseGraph(model, custom_ops=MISTRAL_CUSTOM_OPS) - mg, _ = passes.init_metadata_analysis_pass(mg) - - mg, _ = passes.report_graph_analysis_pass( - mg, pass_args={"file_name": "mistral.txt"} - ) - - # * Add metadata analysis passes - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": torch.randn( - (1, config_sequence_length, config.hidden_size) - ) - }, - "add_value": False, - }, - ) - - mg, _ = mistral_update_metadata(mg, q_config) - - mg, _ = passes.add_hardware_metadata_analysis_pass( - mg, - pass_args={ - "max_parallelism": [max_parallelism] * 4, - }, - ) - - # * Save the metadata to a file for debugging - mg, _ = passes.report_node_meta_param_analysis_pass( - mg, - pass_args={ - "which": ["common", "hardware"], - "save_path": "mistral_graph_meta_params.txt", - }, - ) - - mg, _ = passes.emit_verilog_top_transform_pass(mg) - mg, _ = passes.emit_bram_transform_pass(mg) - mg, _ = passes.emit_internal_rtl_transform_pass(mg) - mg, _ = passes.emit_cocotb_transform_pass( - mg, - pass_args={ - "wait_time": wait_count, - "wait_unit": wait_unit, - }, - ) - mg, _ = passes.emit_vivado_project_transform_pass(mg) - - # Temporary: fix data 
coherency checks - os.environ["COCOTB_RESOLVE_X"] = "ZEROS" - - actions.simulate( - skip_build=False, skip_test=False, gui=False, waves=False, simulator="questa" - ) - - -def get_default_qconfig(): - return { - "data_in_width": 8, - "data_in_frac_width": 3, - "weight_width": 8, - "weight_frac_width": 3, - "bias_width": 8, - "bias_frac_width": 3, - "data_out_width": 8, - "data_out_frac_width": 3, - } - - -@pytest.mark.skip(reason="Fixing needed") -def test_emit_verilog_mistral_smoke(): - config = MistralConfig() - config.num_hidden_layers = 3 - config.hidden_size = 96 - config.intermediate_size = 384 - config_sequence_length = 4 - q_config = get_default_qconfig() - emit_verilog_mistral( - config, q_config, config_sequence_length, wait_count=10, max_parallelism=2 - ) - - -if __name__ == "__main__": - generate_sv_lut("silu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) - test_emit_verilog_mistral_smoke() From 5eee310b82a13e9e72edc3e621d2419d5b6f578a Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 24 Jul 2024 16:22:42 +0000 Subject: [PATCH 56/93] remove deprecated stuff --- .../autosharding/alpa_cost_modelling.py | 119 ------------------ .../analysis/autosharding/alpa_layers.py | 112 ----------------- .../graph/analysis/autosharding/common.py | 26 ---- 3 files changed, 257 deletions(-) delete mode 100644 src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py delete mode 100644 src/chop/passes/graph/analysis/autosharding/alpa_layers.py delete mode 100644 src/chop/passes/graph/analysis/autosharding/common.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py deleted file mode 100644 index d9fea1aa5..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py +++ /dev/null @@ -1,119 +0,0 @@ -import numpy as np -from functools import lru_cache - -from chop.ir.graph import MaseMetadata - -from .common import SpmdShard -from 
.mesh_model import MeshModel - -BYTES_PER_ELEMENT = 4 - - -def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): - assert ( - sharding[0][-1] == sharding[1][-2] - ), f"Inconsistent sharding for node: {node_meta.node}" - inner_dim_sharding = sharding[1][0] - - out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] - - if inner_dim_sharding == SpmdShard.R: - return 0 - - else: - ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 - return mesh.all_reduce_cost( - num_bytes=BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim=ar_dim - ) - - -@lru_cache(maxsize=None) -def get_resharding_cost( - mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata -): - """ - Obtain the resharding cost given a source and destination sharding profile for a tensor. - The mesh object is assumed to have been initialized with alpha, beta parameters so that - the communication cost can be estimated for each MPI operator. - """ - - # If original sharding is fully replicated, no resharding is required - if src == dest or all(i == SpmdShard.R for i in src): - return 0 - - num_bytes = BYTES_PER_ELEMENT * np.prod( - dest_node_meta["common"]["args"]["data_in_0"]["shape"] - ) - - # No cost (simple split along given mesh dimension) - if ( - # Keep dim 0, split dim 1 - # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) - (src[0] == dest[0]) - and (src[1] == SpmdShard.R) - and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) - # Split dim 0, keep dim 1 - # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) - or (src[1] == dest[1]) - and (src[0] == SpmdShard.R) - and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) - ): - return 0 - - # Split -> Replicate (All Gather) - elif ( - # Keep dim 0, gather along dim 1 - # E.g. (S_1, S_0) -> (S_1, R) - (src[0] == dest[0]) - and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[1] == SpmdShard.R) - # Gather along dim 0, keep dim 1 - # E.g. 
(S_0, S_1) -> (R, S_1) - or (src[1] == dest[1]) - and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) - and (dest[0] == SpmdShard.R) - ): - ag_dim = 1 if src[0] == dest[0] else 0 - return mesh.all_gather_cost( - num_bytes=num_bytes, - mesh_dim=ag_dim, - ) - - # All-to-all - # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) - elif src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src): - # all to all - a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value - try: - return mesh.all_to_all_cost( - num_bytes=num_bytes, - mesh_dim=a2a_dim, - ) - except: - assert False - - # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, - # must first gather along the first non-replicated dimension, then recursively compute the cost for the - # reduced sharding - else: - # Reduce one dimension and re-compute - if src[0] != SpmdShard.R: - new_src = (SpmdShard.R, src[1]) - ag_dim = src[0].value - else: - new_src = (SpmdShard.R, SpmdShard.R) - ag_dim = src[1].value - - return mesh.all_gather_cost( - num_bytes=num_bytes, mesh_dim=ag_dim - ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) - - -def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): - mat = np.zeros((len(dest_shardings), len(src_shardings))) - for src_idx, src in enumerate(src_shardings): - for dest_idx, dest in enumerate(dest_shardings): - mat[dest_idx, src_idx] = get_resharding_cost( - mesh, src, dest, dest_node_meta - ) - return mat diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py deleted file mode 100644 index 986079b07..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py +++ /dev/null @@ -1,112 +0,0 @@ -import itertools -import numpy as np -import torch.nn as nn - -from chop.tools import get_logger -from chop.models.patched.bert.modeling_bert import BertSelfAttention - -from .common import SpmdShard, 
VALID_2D_TENSOR_SHARDINGS -from .alpa_cost_modelling import get_communication_cost - - -logger = get_logger(__name__) - - -def is_valid_2d_sharding(sharding): - if len(sharding) > 2: - return sharding[1:] in VALID_2D_TENSOR_SHARDINGS - else: - return sharding in VALID_2D_TENSOR_SHARDINGS - - -def is_valid_sharding_pair(sharding_pair): - return sharding_pair[0][-1] == sharding_pair[1][-2] - - -def is_fully_replicated(sharding_pair): - return all(all(dimp == SpmdShard.R for dimp in subp) for subp in sharding_pair) - - -def get_valid_2d_shardings(node_meta, mesh, module): - """ - Return every valid combination of shardings for the input tensors. For an operator - sharding to be valid, the inner dimension must have the same sharding. - E.g. ((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. - """ - input_shardings = [] - output_shardings = [] - compute_cost_vector = [] - communication_cost_vector = [] - - out_rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - - for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): - if out_rank > 2: - perm = tuple((SpmdShard.R,) * (out_rank - 2) + p for p in perm) - output_sharding = tuple( - (SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1]) - ) - if ( - not is_fully_replicated(perm) - and is_valid_sharding_pair(perm) - and is_valid_2d_sharding(output_sharding) - ): - input_shardings.append({"data_in_0": perm[0], "weight": perm[1]}) - output_shardings.append(output_sharding) - - compute_cost_vector.append(0) - communication_cost_vector.append( - get_communication_cost(perm, node_meta["mase"], mesh) - ) - - return ( - input_shardings, - output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -def get_valid_linear_shardings(node_meta, mesh, module): - return get_valid_2d_shardings(node_meta, mesh, module) - - -def get_valid_layernorm_shardings(node_meta, mesh, module): - rank = 
len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) - valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * rank}] - valid_output_shardings = [(SpmdShard.R,) * rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -def get_valid_embedding_shardings(node_meta, mesh, module): - weight_rank = len(module.weight.shape) - data_in_rank = len(node_meta["mase"]["common"]["args"]["data_in_0"]["shape"]) - valid_input_shardings = [ - { - "data_in_0": (SpmdShard.R,) * data_in_rank, - "weight": (SpmdShard.R,) * weight_rank, - } - ] - valid_output_shardings = [(SpmdShard.R,) * data_in_rank] - compute_cost_vector = [0] - communication_cost_vector = [0] - return ( - valid_input_shardings, - valid_output_shardings, - np.array(compute_cost_vector), - np.array(communication_cost_vector), - ) - - -ALPA_LAYERS = { - nn.Linear: get_valid_linear_shardings, - nn.LayerNorm: get_valid_layernorm_shardings, - nn.Embedding: get_valid_embedding_shardings, -} diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py deleted file mode 100644 index e4fd59a48..000000000 --- a/src/chop/passes/graph/analysis/autosharding/common.py +++ /dev/null @@ -1,26 +0,0 @@ -from enum import Enum - - -class SpmdShard(Enum): - S_0 = 0 - S_1 = 1 - R = 3 - - def __repr__(self): - return self.name - - def __gt__(self, other): - if self.__class__ is other.__class__: - return self.value > other.value - return NotImplemented - - -VALID_2D_TENSOR_SHARDINGS = [ - (SpmdShard.R, SpmdShard.R), - (SpmdShard.R, SpmdShard.S_0), - (SpmdShard.R, SpmdShard.S_1), - (SpmdShard.S_0, SpmdShard.R), - (SpmdShard.S_0, SpmdShard.S_1), - (SpmdShard.S_1, SpmdShard.R), - (SpmdShard.S_1, SpmdShard.S_0), -] From 264683dbd5458ed97a797ede8d638ac03e80e8f8 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 24 
Jul 2024 16:37:17 +0000 Subject: [PATCH 57/93] reduce ILP complexity by skipping placeholder/get_attr candidate shardings when a tensor dimension is not large enough to be sharded along a given mesh dimension --- .../autosharding/strategies/common.py | 48 +++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 0959a6bf3..bdc5b9736 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -1,4 +1,5 @@ import itertools +import numpy as np import torch import torch.nn.functional as F @@ -40,20 +41,61 @@ def find_shape_and_dtype(arg): def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + tensor_shape = meta["common"]["results"]["data_out_0"]["shape"] opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] tensor_meta = TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], + shape=tensor_shape, stride=None, dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ) shardings = [] for sharding in itertools.product(opts, repeat=2): + # Skip fully replicated shardings since this sometimes forces the ILP + # to choose a fully replicated strategy for the entire model when + # the computation cost term is not formulated if skip_fully_replicated and sharding == (Replicate(), Replicate()): continue - spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) - shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) + + # Skip sharding if any dimension is sharded to 0 + skip_sharding = False + for dim in range(ndims): + # Find all device mesh dimensions along which this tensor dimension is sharded + mesh_sharded_dims = [ + idx for idx, shard in enumerate(sharding) if shard == 
Shard(dim) + ] + + # This tensor dimension is not sharded + if len(mesh_sharded_dims) == 0: + continue + + elif len(mesh_sharded_dims) == 1: + num_gpus = mesh.mesh_shape[mesh_sharded_dims[0]] + + else: + num_gpus = np.prod(mesh.mesh_shape) + + dim_size_after_sharding = tensor_shape[dim] // num_gpus + if dim_size_after_sharding == 0: + skip_sharding = True + continue + + if skip_sharding is True: + continue + + spec = DTensorSpec( + mesh=mesh, + placements=sharding, + tensor_meta=tensor_meta, + ) + shardings.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec, + ) + ) + return OpStrategy(shardings) From b594c76a068f20c5bf9452ae5ce8e6a39744ba9b Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 24 Jul 2024 16:38:00 +0000 Subject: [PATCH 58/93] refactor add_common_metadata such that args/kwargs ordering is preserved --- src/chop/ir/__init__.py | 1 + .../add_metadata/common_metadata_layers.py | 220 ++++++++++++------ 2 files changed, 152 insertions(+), 69 deletions(-) diff --git a/src/chop/ir/__init__.py b/src/chop/ir/__init__.py index 834757faa..d0b968ab7 100644 --- a/src/chop/ir/__init__.py +++ b/src/chop/ir/__init__.py @@ -1,3 +1,4 @@ from .graph.mase_graph import MaseGraph, MaseTracer +from .graph.mase_graph_metadata import MaseGraphMetadata from .onnx.mase_onnx_graph import MaseOnnxGraph diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 2784fceb6..88e9a9f71 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -1,11 +1,15 @@ import inspect -import math +from collections import OrderedDict +from functools import reduce import torch -import inspect + from chop.nn.quantized.modules import quantized_module_map -from functools import reduce +from chop.ir import MaseGraphMetadata +from chop.tools import get_logger +logger = 
get_logger(__name__) +logger.setLevel("INFO") # ---------------------------------------------------------- # Utility @@ -334,6 +338,10 @@ "type_as": {"tensor": "data_in"}, } +# ---------------------------------------------------------- +# Helpers +# ---------------------------------------------------------- + def get_type_and_precision(meta): # * Fetch type and precision from q_config for quantized modules @@ -354,16 +362,65 @@ def get_type_and_precision(meta): return arg_type, arg_precision -def match_args_and_kwargs(meta, args, kwargs, data, add_value): - ordered_func_data = [(k, v) for k, v in data.items()] - meta.parameters["common"]["args"] = {} - meta_kwargs = {} +def get_shape(x): + if x is None: + return None + elif isinstance(x, torch.Tensor): + return list(x.shape) + elif isinstance(x, int): + return [1] + elif isinstance(x, (list, tuple, torch.Size)): + return [len(x)] + else: + return [0] + + +def deepgetattr(obj, attr): + """Recurses through an attribute chain to get the ultimate value.""" + return reduce(getattr, attr.split("."), obj) + + +# ---------------------------------------------------------- +# Metadata annotators +# ---------------------------------------------------------- + + +def _annotate_arg_metadata( + meta: MaseGraphMetadata, + args: list, + kwargs: dict, + func_data: dict, + add_value: bool, +): + """ + Analyse target args and kwargs received from shape propagation to annotate combined meta["mase"]["args"] + dictionary with metadata about each argument. The order of the args and kwargs must be preserved in the + combined dictionary (this is expected by downstream passes). However, arguments with the 'data_in' flag + in func_data are renamed to 'data_in_{itr}' where itr = 0 ... the number of data_in arguments. + + This function should not be called directly, but rather through the `annotate_common_parameters_` function. + The value in the meta["common"]["args"] dictionary should always be a dictionary, not a tensor. 
+ + Args: + meta (MaseGraphMetadata): The metadata object. + args (list): List of args passed to the target. + kwargs (dict): Dictionary of kwargs passed to the target. + func_data (dict): Dictionary defining whether each argument is data_in or config. + add_value (bool): indicate whether to add the value of the tensor to the metadata. + + Returns: + MaseGraphMetadata: metadata object with annotated args. + """ + ordered_func_data = [(k, v) for k, v in func_data.items()] + meta["common"]["args"] = OrderedDict() + data_in_itr = 0 arg_type, arg_precision = get_type_and_precision(meta) - # * Assign metadata for each argument - j = 0 + # * Handle args for i, x in enumerate(args): + + # Input data tensor if isinstance(x, torch.Tensor) and ordered_func_data[i][1] == "data_in": arg_meta = { "shape": list(x.shape), @@ -373,9 +430,10 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value): } if add_value: arg_meta["value"] = x - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 - # check if it's a tuple of tensors + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 + + # Tuple of tensors elif isinstance(x, tuple) and all([isinstance(x, torch.Tensor) for x in x]): for k, x in enumerate(x): arg_meta = { @@ -386,27 +444,32 @@ def match_args_and_kwargs(meta, args, kwargs, data, add_value): } if add_value: arg_meta["value"] = x - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 - else: - # this is not an data_in, but just actually an named arg - n, vtype = ordered_func_data[i] - meta_kwargs[n] = args[i] - - def get_shape(x): - if x is None: - return None - elif isinstance(x, torch.Tensor): - return list(x.shape) - elif isinstance(x, int): - return [1] - elif isinstance(x, list): - return [len(x)] + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 + + # Unknown data_in type or config argument else: - raise ValueError(f"Unknown type {type(x)}") + # Don't increment the 
iterator for config arguments, but + # preserve order in meta["common"]["args"] + arg_name, arg_flag = ordered_func_data[i] + + if arg_flag == "data_in": + arg_name = f"data_in_{data_in_itr}" + data_in_itr += 1 + + meta["common"]["args"][arg_name] = { + "torch_dtype": x.dtype if isinstance(x, torch.Tensor) else None, + "type": type(args[i]), + "precision": arg_precision, + "shape": get_shape(args[i]), + } + if add_value: + meta["common"]["args"][arg_name]["value"] = args[i] + + # * Handle kwargs for k, v in kwargs.items(): - if data[k] == "data_in": + if func_data[k] == "data_in": # rename this to mase data_in_number shape = get_shape(v) arg_meta = { @@ -417,47 +480,71 @@ def get_shape(x): } if add_value: arg_meta["value"] = v - meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta - j += 1 + meta["common"]["args"][f"data_in_{data_in_itr}"] = arg_meta + data_in_itr += 1 else: # otherwise this must be a configuration parameter in meta - meta_kwargs[k] = v - # merge configuratipn args - meta.parameters["common"]["args"] = meta.parameters["common"]["args"] | meta_kwargs + # meta_kwargs[k] = v + meta["common"]["args"][k] = { + "type": type(v), + "precision": arg_precision, + "shape": get_shape(v), + } + if add_value: + meta["common"]["args"][k]["value"] = v + return meta -def analyse_result(meta, result, add_value): +def _annotate_result_metadata( + meta: MaseGraphMetadata, + result, + add_value: bool, +) -> MaseGraphMetadata: + """ + Analyse the result from running the target to annotate the meta["mase"]["results"] dictionary with metadata. + + Args: + meta (MaseGraphMetadata): The metadata object. + result (_type_): The result object. + add_value (bool): indicate whether to add the value of the tensor to the metadata. + + Returns: + MaseGraphMetadata: metadata object with annotated results. 
+ """ # deal with results - meta.parameters["common"]["results"] = {} + meta["common"]["results"] = OrderedDict() result_type, result_precision = get_type_and_precision(meta) if isinstance(result, torch.Tensor): - meta.parameters["common"]["results"]["data_out_0"] = { + meta["common"]["results"]["data_out_0"] = { "type": result_type, "precision": result_precision, "shape": list(result.shape), "torch_dtype": result.dtype, } if add_value: - meta.parameters["common"]["results"]["data_out_0"]["value"] = result + meta["common"]["results"]["data_out_0"]["value"] = result # check if it's a tuple of tensors elif isinstance(result, tuple) and all( [isinstance(x, torch.Tensor) for x in result] ): for i, x in enumerate(result): - meta.parameters["common"]["results"][f"data_out_{i}"] = { + meta["common"]["results"][f"data_out_{i}"] = { "type": result_type, "precision": result_precision, "shape": list(x.shape), "torch_dtype": x.dtype, } if add_value: - meta.parameters["common"]["results"][f"data_out_{i}"]["value"] = x + meta["common"]["results"][f"data_out_{i}"]["value"] = x else: - meta.parameters["common"]["results"]["data_out_0"] = { + logger.debug( + f"Expected result to be a tensor or tuple of tensors, but found: {type(result)}. Will annotate with default value, but this may cause issues downstream." 
+ ) + meta["common"]["results"]["data_out_0"] = { "type": type(result), "shape": [1], "value": result, @@ -478,19 +565,19 @@ def analyse_common_parameters_placeholder(meta, result, args, kwargs, add_value= var_name = meta.node.target # deal with model specific inputs, normally these are not numerical values/tensors if var_name in meta.model.additional_inputs: - meta.parameters["common"]["args"] = {} - meta.parameters["common"]["results"] = {} - meta.parameters["common"]["results"]["data_out_0"] = { + meta["common"]["args"] = {} + meta["common"]["results"] = {} + meta["common"]["results"]["data_out_0"] = { "type": "model_specific_input", "shape": result.shape, "torhc_dtype": result.dtype, } if add_value: - meta.parameters["common"]["results"]["data_out_0"]["value"] = result + meta["common"]["results"]["data_out_0"]["value"] = result return meta - meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta @@ -501,12 +588,12 @@ def analyse_common_parameters_placeholder(meta, result, args, kwargs, add_value= def analyse_common_parameters_function(meta, result, args, kwargs, add_value=True): # fetch mase info - mase_op = meta.parameters["common"]["mase_op"] + mase_op = meta["common"]["mase_op"] # deal with result - meta = analyse_result(meta, result, add_value) + meta = _annotate_result_metadata(meta, result, add_value) # deal with args and kwargs - meta = match_args_and_kwargs(meta, args, kwargs, func_data[mase_op], add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, func_data[mase_op], add_value) return meta @@ -516,13 +603,8 @@ def analyse_common_parameters_function(meta, result, args, kwargs, add_value=Tru # ---------------------------------------------------------- -def deepgetattr(obj, attr): - """Recurses through an attribute chain to get the ultimate value.""" - return reduce(getattr, attr.split("."), obj) - - def 
analyse_common_parameters_module(meta, result, args, kwargs, add_value=True): - mase_op = meta.parameters["common"]["mase_op"] + mase_op = meta["common"]["mase_op"] node_module = deepgetattr(meta.model, meta.node.target) if mase_op == "user_defined_module": @@ -533,13 +615,13 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) else: module_args = module_data[mase_op] - meta = match_args_and_kwargs(meta, args, kwargs, module_args, add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, module_args, add_value) arg_type, arg_precision = get_type_and_precision(meta) for name, parameter in meta.module.named_parameters(): name = name.replace(".", "_") - meta.parameters["common"]["args"][name] = { + meta["common"]["args"][name] = { "type": arg_type, "precision": arg_precision, "shape": ( @@ -550,21 +632,21 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) "from": None, } if add_value: - meta.parameters["common"]["args"][name]["value"] = parameter + meta["common"]["args"][name]["value"] = parameter - meta = analyse_result(meta, result, add_value) + meta = _annotate_result_metadata(meta, result, add_value) return meta # ---------------------------------------------------------- -# Module +# Method # ---------------------------------------------------------- def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True): - mase_op = meta.parameters["common"]["mase_op"] - meta = analyse_result(meta, result, add_value) - meta = match_args_and_kwargs(meta, args, kwargs, method_data[mase_op], add_value) + mase_op = meta["common"]["mase_op"] + meta = _annotate_result_metadata(meta, result, add_value) + meta = _annotate_arg_metadata(meta, args, kwargs, method_data[mase_op], add_value) return meta @@ -574,8 +656,8 @@ def analyse_common_parameters_method(meta, result, args, kwargs, add_value=True) def analyse_common_parameters_attr(meta, result, args, kwargs, add_value=True): - 
meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta @@ -585,6 +667,6 @@ def analyse_common_parameters_attr(meta, result, args, kwargs, add_value=True): def analyse_common_parameters_output(meta, result, args, kwargs, add_value=True): - meta.parameters["common"]["args"] = {} - meta = analyse_result(meta, result, add_value) + meta["common"]["args"] = {} + meta = _annotate_result_metadata(meta, result, add_value) return meta From 419dc5f024ac933d3de88d145db6da5e464a3292 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 24 Jul 2024 17:47:43 +0000 Subject: [PATCH 59/93] [UNFINISHED] profile ops with local tensor shapes to formulate compute costs and avoid dummy ILP solutions --- .../autosharding/alpa_intra_operator.py | 136 ++++++++++++++++-- .../graph/analysis/autosharding/layers.py | 1 + .../autosharding/strategies/math_ops.py | 21 +-- .../autosharding/strategies/pointwise_ops.py | 31 ++-- .../autosharding/strategies/view_ops.py | 4 +- 5 files changed, 163 insertions(+), 30 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index b431095d6..077a4869b 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -1,12 +1,17 @@ +import math +import numpy as np +import cvxpy as cp +from copy import copy + import torch import torch.fx as fx from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import DTensorSpec -import numpy as np -import cvxpy as cp +from torch.distributed._tensor._op_schema import DTensorSpec, OpStrategy +from torch.distributed._tensor.placement_types import Shard from chop.tools import get_logger from chop.tools.utils import deepgetattr +from 
.mesh_model import MeshModel from .layers import ( AUTOSHARDING_MODULES, @@ -25,6 +30,101 @@ logger.setLevel("DEBUG") +def _get_computation_cost_from_strategy( + node: fx.Node, + strategy: OpStrategy, + mesh: MeshModel, + repeat: int = 5, + warmup_iters: int = 2, + profiling_device: int = 0, +): + """ + ... + + Args: + node (fx.Node): _description_ + strategy (OpStrategy): _description_ + repeat (int, optional): _description_. Defaults to 5. + warmup_iters (int, optional): _description_. Defaults to 1. + + Returns: + _type_: _description_ + """ + arg_specs = strategy.input_specs + arg_specs = [arg_specs] if isinstance(arg_specs, DTensorSpec) else arg_specs + + # Formulate list of arguments to run the target with + args = [] + for arg_idx, arg_spec in enumerate(arg_specs): + + # If tensor meta is None, this is not a sharded argument + if arg_spec.tensor_meta is None: + key = list(node.meta["mase"]["common"]["args"].keys())[arg_idx] + arg_value = node.meta["mase"]["common"]["args"][key]["value"] + args.append(arg_value) + continue + + # If it is a sharded argument, find the local tensor shape + else: + global_shape = copy(arg_spec.tensor_meta.shape) + local_shape = copy(arg_spec.tensor_meta.shape) + + # Check if each tensor dimension is sharded to update local_shape + for dim in range(len(global_shape)): + # Get device mesh dimensions along which dimension 'dim' of the tensor is sharded + sharded_mesh_dims = [ + idx + for idx, plac in enumerate(arg_spec.placements) + if plac == Shard(dim) + ] + + # This tensor dimension is not sharded + if len(sharded_mesh_dims) == 0: + continue + + # This tensor dimension is fully sharded + elif len(sharded_mesh_dims) == 2: + num_gpus = np.prod(mesh.mesh_shape) + + # This tensor dimension is sharded along one mesh dimension + elif len(sharded_mesh_dims) == 1: + num_gpus = mesh.mesh_shape[sharded_mesh_dims[0]] + + # Define the local shape with minimum == 1 + local_shape[dim] = math.ceil(global_shape[dim] / num_gpus) + + # Generate a 
random tensor with the local shape + args.append( + torch.randn( + local_shape, + device=f"cuda:{profiling_device}", + ) + ) + + # Get target function + fn = node.target + + # Run the function with the arguments + start_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + end_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + + torch.cuda.empty_cache() + + for idx in range(repeat): + start_event[idx].record() + _ = fn(*args) + end_event[idx].record() + torch.cuda.synchronize() + + elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] + + return np.mean(elapsed[warmup_iters:]) + + def _extract_ilp(mg, mesh, pass_args={}): """ For each node in the graph, assign an OpStrategy object which contains all possible @@ -102,24 +202,27 @@ def _extract_ilp(mg, mesh, pass_args={}): continue elif node.op == "call_module" and isinstance( - deepgetattr(mg.model, node.target), tuple(AUTOSHARDING_MODULES.keys()) + deepgetattr(mg.model, node.target), + tuple(AUTOSHARDING_MODULES.keys()), ): - logger.debug(f"Obtaining strategy for node {node.name}") + logger.debug(f"Obtaining strategy for call_module node: {node.name}") module_cls = type(deepgetattr(mg.model, node.target)) op_strategy = AUTOSHARDING_MODULES[module_cls](node.meta["mase"], mesh) elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): - logger.debug(f"Obtaining strategy for node {node.name}") + logger.debug(f"Obtaining strategy for call_method node: {node.name}") op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) elif ( node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() ): - logger.debug(f"Obtaining strategy for node {node.name}") + logger.debug(f"Obtaining strategy for call_function node: {node.name}") op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) else: - logger.warning(f"Unknown node {node.name} with op 
{node.op}") + logger.warning( + f"Unknown node {node.name} with op {node.op} with be allocated fully replicated strategy." + ) op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) opt_var = cp.Variable(1, boolean=True) constr += [ @@ -147,6 +250,23 @@ def _extract_ilp(mg, mesh, pass_args={}): "output": None, } + # Consider computation cost (c_v term) for each of the node's strategies + # placeholder/get_attr/output nodes have no computation cost + if node.op not in [ + "placeholder", + "get_attr", + "output", + # todo: decide how to handle call_method nodes + "call_method", + ]: + cost_vector = [] + try: + for strategy in op_strategy.strategies: + cost = _get_computation_cost_from_strategy(node, strategy, mesh) + cost_vector.append(cost) + except: + print(f"Op {node} failed to compute cost") + # Consider resharding cost for each of the node's arguments e_var_checks = [] for arg_idx, in_node in enumerate(node.all_input_nodes): diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 03f99d8ee..ff43a3cdc 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -251,6 +251,7 @@ "expand": get_reshape_strategy(torch.Tensor.expand), "permute": get_reshape_strategy(torch.Tensor.permute), "transpose": get_reshape_strategy(torch.Tensor.transpose), + "unsqueeze": get_reshape_strategy(torch.Tensor.unsqueeze), "masked_fill": pointwise_strategy, "masked_fill_": pointwise_strategy, "contiguous": tensor_op_strategy, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py index 4cdc2f0fe..3d0241580 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py @@ -100,7 +100,7 @@ def layer_norm_strategy(meta, mesh): input_strategy = 
meta.node.args[0].meta["mase"]["software"]["autosharding"][ "op_strategy" ] - normalized_shape = meta["common"]["args"]["normalized_shape"] + normalized_shape = meta["common"]["args"]["normalized_shape"]["value"] weight_strategy = meta.node.kwargs["weight"].meta["mase"]["software"][ "autosharding" ]["op_strategy"] @@ -120,9 +120,8 @@ def layer_norm_strategy(meta, mesh): # we use OpStrategy because the output (out, mean, rstd) # should have the same placements output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): + for input_placement_strategy in input_strategy.strategies: op_args_target_specs = [] - redistribute_costs = [] input_src_spec = input_placement_strategy.output_spec # for the input tensor, we replicate it on the inner dims if necessary @@ -134,9 +133,15 @@ def layer_norm_strategy(meta, mesh): tensor_meta=input_src_spec.tensor_meta, ) op_args_target_specs.append(input_target_spec) - # redistribute_costs.append( - # generate_redistribute_costs(input_strategy, input_target_spec) - # ) + + # Add replicate spec for normalized_shape + normalized_shape_spec = DTensorSpec( + mesh=mesh, + placements=(Replicate(),) * 2, + # todo: check that it's safe not to assign tensor meta here + tensor_meta=None, + ) + op_args_target_specs.append(normalized_shape_spec) if weight_strategy is not None: assert isinstance(weight_strategy, OpStrategy) @@ -172,9 +177,6 @@ def layer_norm_strategy(meta, mesh): tensor_meta=bias_src_spec.tensor_meta, ) op_args_target_specs.append(bias_target_spec) - # redistribute_costs.append( - # generate_redistribute_costs(bias_strategy, bias_target_spec) - # ) # the output spec is the same as input spec output_target_spec = input_target_spec @@ -182,7 +184,6 @@ def layer_norm_strategy(meta, mesh): PlacementStrategy( output_specs=output_target_spec, input_specs=op_args_target_specs, - # redistribute_cost=redistribute_costs, ) ) diff --git 
a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index fe6ac93ef..10f3f854e 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -32,7 +32,11 @@ logger = get_logger(__name__) -def pointwise_strategy(meta, mesh, linearity=False): +def pointwise_strategy( + meta, + mesh, + linearity=False, +): max_shards_strategy_index = -1 max_shards = -1 followed_strategy = None @@ -66,7 +70,11 @@ def pointwise_strategy(meta, mesh, linearity=False): def common_pointwise_strategy( - meta, mesh, followed_strategy, linearity, followed_strategy_index=0 + meta, + mesh, + followed_strategy, + linearity, + followed_strategy_index=0, ): # handle broadcasting parsed_args = [] @@ -112,7 +120,6 @@ def common_pointwise_strategy( out_placements.append(placement) input_specs: List[DTensorSpec] = [] - # redistribute_costs: List[List[float]] = [] for arg_node in meta.node.args: if not isinstance(arg_node, torch.fx.Node): continue @@ -134,10 +141,8 @@ def common_pointwise_strategy( placements=input_target_placements, tensor_meta=input_arg_spec.tensor_meta, ) - input_specs.append(input_arg_target_spec) - # redistribute_costs.append( - # generate_redistribute_costs(input_arg, input_arg_target_spec) - # ) + # input_specs.append(input_arg_target_spec) + input_specs = [input_arg_target_spec] * 2 dtype = meta["common"]["results"]["data_out_0"].get( "torch_dtype", torch.float32 @@ -154,16 +159,22 @@ def common_pointwise_strategy( ), ), input_specs=input_specs, - # redistribute_cost=redistribute_costs, ) ) return pointwise_strategy -def linear_pointwise_strategy(meta, mesh): +def linear_pointwise_strategy( + meta, + mesh, +): """ Linear pointwise operators can propagate pending reductions. 
For example, c = add(a, b); if a is pending sum, then c will be pending sum as well without any communication overhead. """ - return pointwise_strategy(meta, mesh, linearity=True) + return pointwise_strategy( + meta, + mesh, + linearity=True, + ) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 83c638bfb..668cbf8e6 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -463,6 +463,7 @@ def dim_view_as_real(shape: Shape) -> DimMap: ), Tensor.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + Tensor.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), } @@ -575,11 +576,10 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: def get_reshape_strategy(op): dim_map = dim_maps[op] - # def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: def reshape_strategy(meta, mesh): assert meta.node.op == "call_method", "Node should have call_method op." 
args_schema = [meta["common"]["self"]] + [ - i for i in meta["common"]["args"].values() + i["value"] for i in meta["common"]["args"].values() ] rules = dim_map(*args_schema) parent_node = meta.node.args[0] From bc5ff83941a671d9c58c49c35742586d350574fc Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 24 Jul 2024 17:48:26 +0000 Subject: [PATCH 60/93] [UNFINISHED]: simplify DTensor OpDispatcher --- src/chop/distributed/tensor/_dispatch.py | 278 +++++++++--------- src/chop/distributed/tensor/_sharding_prop.py | 8 + 2 files changed, 146 insertions(+), 140 deletions(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index b9316729e..b35a1fc84 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -38,7 +38,6 @@ from chop.distributed.tensor._redistribute import redistribute_local_tensor aten = torch.ops.aten -logger = logging.getLogger(__name__) def decompose_handler( @@ -123,134 +122,136 @@ def dispatch( # extract local tensor and sharding infos to a OpInfo op_info = self.unwrap_to_op_info(op_call, args, kwargs) - logger.debug("Dispatching op_call: %s", op_info.schema) + rlog(f"Dispatching op_call: {op_call.name}") - self.sharding_propagator.propagate(op_info) - output_sharding = op_info.output_sharding - logger.debug("output_sharding for %s: %s", op_call, output_sharding) - assert output_sharding is not None, "output sharding should not be None" - - mesh = op_info.mesh - if mesh.get_coordinate() is None: - # For a non-participating device, we do: - # 1. if the return type is scalar, set the local result to None. - # The local results from all devices will then be all-gathered - # and a reduce op will be performed on the list of results - # with appropriate operators: - # for bool type, we by default use AND to reduce; - # we can extend for more ops if necessary. - # 2. if the return type is Tensor or List[Tensor], return empty - # tensor(s) with correct dtype. 
- spec = output_sharding.output_spec - ret_list = op_info.schema.op._schema.returns - - if spec is None: - # For a scalar return type, the non-participating device has None - # as its local result - local_results: object = None - else: - - def default_tensor(spec: DTensorSpec) -> torch.Tensor: - if spec.tensor_meta is not None: - shape = spec.tensor_meta.shape - dtype = spec.tensor_meta.dtype - if len(shape) == 0: - # scalar tensor - return torch.zeros((), dtype=dtype) - else: - # non-scalar tensor - return torch.tensor([], dtype=dtype) - else: - raise RuntimeError(f"{spec} has no tensor metadata.") - - if isinstance(spec, DTensorSpec): - # return a Tensor value - local_results = default_tensor(spec) - elif isinstance(spec, Sequence): - # return a List[Tensor] value - local_results = [ - default_tensor(s) if s is not None else None for s in spec - ] - assert isinstance(local_results, List) - if None in local_results: - ret_type = str(ret_list[0].type) - raise NotImplementedError( - f"return type {ret_type} in DTensor op is not supported" - ) - else: - if output_sharding.needs_redistribute: - # compute locally with redistribute first if needed - assert output_sharding.redistribute_schema is not None - self.redistribute_local_args( - op_info, output_sharding.redistribute_schema - ) + # self.sharding_propagator.propagate(op_info) + # output_sharding = op_info.output_sharding - local_tensor_args = ( - pytree.tree_unflatten( - cast(List[object], op_info.local_args), op_info.args_tree_spec - ) - if op_info.args_tree_spec - else op_info.local_args - ) + output_sharding = self.sharding_propagator.propagate(op_info) - # run local op computation with potentially modified args/kwargs - local_tensor_args = cast(Tuple[object, ...], local_tensor_args) - if op_call in self._random_ops: - if not random._rng_tracker and is_rng_supported_mesh(mesh): - # Default to `OffsetBasedRNGTracker` if the parallelism API - # did not already construct one - random._rng_tracker = 
random.OffsetBasedRNGTracker(mesh.device_type) - - first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - torch.Tensor, local_tensor_args[0] - ) - rng_context = ( - random._rng_tracker._distribute_region(first_arg._spec) - if random._rng_tracker and not first_local_arg.is_meta - else contextlib.nullcontext() - ) + assert output_sharding is not None, "output sharding should not be None" - # For DTensor random operator, run it within a distribute region - with rng_context: - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - else: - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + # mesh = op_info.mesh + # if mesh.get_coordinate() is None: + # # For a non-participating device, we do: + # # 1. if the return type is scalar, set the local result to None. + # # The local results from all devices will then be all-gathered + # # and a reduce op will be performed on the list of results + # # with appropriate operators: + # # for bool type, we by default use AND to reduce; + # # we can extend for more ops if necessary. + # # 2. if the return type is Tensor or List[Tensor], return empty + # # tensor(s) with correct dtype. 
+ # spec = output_sharding.output_spec + # ret_list = op_info.schema.op._schema.returns + + # if spec is None: + # # For a scalar return type, the non-participating device has None + # # as its local result + # local_results: object = None + # else: + + # def default_tensor(spec: DTensorSpec) -> torch.Tensor: + # if spec.tensor_meta is not None: + # shape = spec.tensor_meta.shape + # dtype = spec.tensor_meta.dtype + # if len(shape) == 0: + # # scalar tensor + # return torch.zeros((), dtype=dtype) + # else: + # # non-scalar tensor + # return torch.tensor([], dtype=dtype) + # else: + # raise RuntimeError(f"{spec} has no tensor metadata.") + + # if isinstance(spec, DTensorSpec): + # # return a Tensor value + # local_results = default_tensor(spec) + # elif isinstance(spec, Sequence): + # # return a List[Tensor] value + # local_results = [ + # default_tensor(s) if s is not None else None for s in spec + # ] + # assert isinstance(local_results, List) + # if None in local_results: + # ret_type = str(ret_list[0].type) + # raise NotImplementedError( + # f"return type {ret_type} in DTensor op is not supported" + # ) + # else: + if output_sharding.needs_redistribute: + # compute locally with redistribute first if needed + assert output_sharding.redistribute_schema is not None + self.redistribute_local_args(op_info, output_sharding.redistribute_schema) + + # local_tensor_args = ( + # pytree.tree_unflatten( + # cast(List[object], op_info.local_args), op_info.args_tree_spec + # ) + # if op_info.args_tree_spec + # else op_info.local_args + # ) + + local_tensor_args = op_info.local_args + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + # if op_call in self._random_ops: + # if not random._rng_tracker and is_rng_supported_mesh(mesh): + # # Default to `OffsetBasedRNGTracker` if the parallelism API + # # did not already construct one + # random._rng_tracker = 
random.OffsetBasedRNGTracker(mesh.device_type) + + # first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( + # torch.Tensor, local_tensor_args[0] + # ) + # rng_context = ( + # random._rng_tracker._distribute_region(first_arg._spec) + # if random._rng_tracker and not first_local_arg.is_meta + # else contextlib.nullcontext() + # ) + + # # For DTensor random operator, run it within a distribute region + # with rng_context: + # local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + # else: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) # communicate the result to all ranks for some operators that return scalar value - if output_sharding.output_spec is None: - if op_call == aten.equal.default: - obj_list = [None for _ in range(dist.get_world_size())] - dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] - obj_list = list(filter(lambda x: x is not None, obj_list)) - # perform reduce on the collection with AND op - local_results = functools.reduce(operator.and_, obj_list, True) - - if _is_inplace_op(op_call): - # inplace op should return self instead of re-wrapping - if output_sharding.output_spec is not None: - return args[0] - else: - return None - elif _is_out_variant_op(op_call): - # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) - output_specs = ( - (output_sharding.output_spec,) - if not isinstance(output_sharding.output_spec, tuple) - else output_sharding.output_spec - ) - out_dts = [] - spec_idx = 0 - for argument in op_call._schema.arguments: - if argument.is_out: - out_dt = cast(dtensor.DTensor, kwargs[argument.name]) - out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) - out_dts.append(out_dt) - spec_idx += 1 - - assert len(out_dts) >= 1, "out variant should have at least one out arg" - return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] - else: - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + # if output_sharding.output_spec is None: + # if op_call == aten.equal.default: + # obj_list = [None for _ in range(dist.get_world_size())] + # dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] + # obj_list = list(filter(lambda x: x is not None, obj_list)) + # # perform reduce on the collection with AND op + # local_results = functools.reduce(operator.and_, obj_list, True) + + # if _is_inplace_op(op_call): + # # inplace op should return self instead of re-wrapping + # if output_sharding.output_spec is not None: + # return args[0] + # else: + # return None + # elif _is_out_variant_op(op_call): + # # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) + # output_specs = ( + # (output_sharding.output_spec,) + # if not isinstance(output_sharding.output_spec, tuple) + # else output_sharding.output_spec + # ) + # out_dts = [] + # spec_idx = 0 + # for argument in op_call._schema.arguments: + # if argument.is_out: + # out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + # out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + # out_dts.append(out_dt) + # spec_idx += 1 + + # assert len(out_dts) >= 1, "out variant should have at least one out arg" + # return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + # else: + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] @staticmethod def redistribute_local_args( @@ -296,12 +297,13 @@ def unwrap_to_op_info( op_call, None ) - if runtime_schema_info is not None and runtime_schema_info.needs_pytree: - # flatten args/kwargs when necessary - tree_args, args_spec = pytree.tree_flatten(args) - args_list: Sequence[object] = tree_args - else: - args_list, args_spec = args, None + # if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # # flatten args/kwargs when necessary + # print(f"needs pytree...") + # tree_args, args_spec = pytree.tree_flatten(args) + # args_list: Sequence[object] = tree_args + # else: + args_list, args_spec = args, None args_schema: List[object] = [] kwargs_schema: Dict[str, object] = {} @@ -385,21 +387,17 @@ def try_get_replicate_spec( assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" 
op_info = OpInfo( - mesh, - OpSchema( + mesh=mesh, + schema=OpSchema( op_call, - ( - pytree.tree_unflatten(args_schema, args_spec) - if args_spec - else tuple(args_schema) - ), + tuple(args_schema), kwargs_schema, schema_info=runtime_schema_info, ), - args_schema, - tuple(local_args), - local_kwargs, - args_spec, + flat_args_schema=args_schema, + local_args=tuple(local_args), + local_kwargs=local_kwargs, + args_tree_spec=args_spec, ) return op_info diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 81224a952..045465412 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -192,6 +192,8 @@ def propagate(self, op_info: OpInfo) -> None: output_sharding = self.propagate_op_sharding(op_info.schema) op_info.output_sharding = output_sharding + return output_sharding + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: """ Propagate the sharding for an operator given the op_schema. 
@@ -283,11 +285,15 @@ def spec_to_strategy(spec: object) -> object: if output_strategy.output_spec.is_sharded(): schema = suggestion_schema or op_schema assert isinstance(out_tensor_meta, TensorMeta) + rlog(f"Need the out tensor meta here!") suggestion_schema = self._adjust_shape_and_stride_args( out_tensor_meta, schema, output_strategy.output_spec, mesh ) needs_redistribute = True + else: + rlog(f"Don't need it because it's not sharded") + # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): # for ops that return multiple tensors and the output_specs is not @@ -318,6 +324,7 @@ def spec_to_strategy(spec: object) -> object: needs_redistribute=needs_redistribute, ) elif isinstance(op_strategy, TupleStrategy): + rlog(f"Strategy returned a TupleStrategy") # tuple strategy output sharding processing # runtime selected placement strategy for each TupleStrategy input arg selected_strategies: List[PlacementStrategy] = [] @@ -394,6 +401,7 @@ def spec_to_strategy(spec: object) -> object: ) return output_sharding elif op_schema.op in self.op_to_rules: + rlog(f"Op {op_schema.op} has no strategy, using rule") # propagate the sharding with rule sharding_prop_func = self.op_to_rules[op_schema.op] From 9d95ace0547e6266c752f00580588f72870d3bf0 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 26 Jul 2024 21:00:30 +0000 Subject: [PATCH 61/93] find and replace call_method nodes with call_functional for arg ordering consistency --- src/chop/ir/common.py | 9 +++ src/chop/nn/functional/tensor.py | 80 +++++++++++++++++++ src/chop/passes/__init__.py | 4 +- .../add_metadata/common_metadata_layers.py | 45 +++++++++++ .../analysis/autosharding/autosharding.py | 3 + .../autosharding/strategies/pointwise_ops.py | 2 +- src/chop/passes/graph/transforms/__init__.py | 4 + .../graph/transforms/find_replace/__init__.py | 0 .../find_replace/method_to_function.py | 63 +++++++++++++++ src/chop/pipelines/distributed_inference.py | 1 + 10 files changed, 209 
insertions(+), 2 deletions(-) create mode 100644 src/chop/nn/functional/tensor.py create mode 100644 src/chop/passes/graph/transforms/find_replace/__init__.py create mode 100644 src/chop/passes/graph/transforms/find_replace/method_to_function.py diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 98579739b..7cd6f219b 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -49,6 +49,15 @@ "finfo", "masked_fill", "masked_fill_", + # Inserted ops from the replace_method_with_function pass + "torch_size", + "torch_contiguous", + "torch_expand", + "torch_view", + "torch_reshape", + "torch_split", + "torch_permute", + "torch_transpose", ] MASE_MODULE_RELATED_FUNCS = [ diff --git a/src/chop/nn/functional/tensor.py b/src/chop/nn/functional/tensor.py new file mode 100644 index 000000000..6f423c8f2 --- /dev/null +++ b/src/chop/nn/functional/tensor.py @@ -0,0 +1,80 @@ +import torch +import torch.fx as fx + +# This file contains functional equivalent of some torch.Tensor methods +# which can be casted to call_function nodes by the replace_method_with_function pass. +# They must have the same signature as their torch.Tensor equivalents with an added +# input node at position 0. + + +@fx.wrap +def torch_size( + input: torch.Tensor, + dim: int = None, +): + return input.size(dim) + + +@fx.wrap +def torch_expand( + input: torch.Tensor, + *sizes, +): + return input.expand(*sizes) + + +@fx.wrap +def torch_view( + input: torch.Tensor, + *shape, +): + return input.view(*shape) + + +@fx.wrap +def torch_contiguous( + input: torch.Tensor, + memory_format: torch.memory_format = torch.contiguous_format, +): + return input.contiguous(memory_format=memory_format) + + +# The following functions exist in torch functional land, +# however their functional implementation does not accept +# arbitrary argument counts i.e. *args, **kwargs, so we +# reimplement them here. 
+# ============================================================ + + +@fx.wrap +def torch_reshape( + input: torch.Tensor, + *shape, +): + return input.reshape(*shape) + + +@fx.wrap +def torch_split( + input: torch.Tensor, + split_size: int, + dim: int = 0, +): + return input.split(split_size, dim) + + +@fx.wrap +def torch_permute( + input: torch.Tensor, + *dims, +): + return input.permute(*dims) + + +@fx.wrap +def torch_transpose( + input: torch.Tensor, + dim0: int, + dim1: int, +): + return input.transpose(dim0, dim1) diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index 8b0053113..b8d90219b 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -33,9 +33,11 @@ emit_vivado_project_transform_pass, raise_granularity_transform_pass, patch_metadata_transform_pass, + resharding_transform_pass, + replace_method_with_function, ) from .module.analysis import calculate_avg_bits_module_analysis_pass -from .module.transforms import quantize_module_transform_pass, resharding_transform_pass +from .module.transforms import quantize_module_transform_pass from .onnx.analysis import ( export_fx_graph_analysis_pass, diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 88e9a9f71..b3213b2b1 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -225,6 +225,51 @@ "scale_grad_by_freq": "config", "sparse": "config", }, + # Inserted ops from the replace_method_with_function pass + "torch_size": {"input": "data_in", "dim": "config"}, + "torch_contiguous": { + "input": "data_in", + "memory_format": "config", + }, + # arbitrary length - support up to 4 + "torch_expand": { + "input": "data_in", + "size_0": "config", + "size_1": "config", + "size_2": "config", + "size_3": "config", + }, + "torch_view": { + "input": "data_in", + 
"shape_0": "config", + "shape_1": "config", + "shape_2": "config", + "shape_3": "config", + }, + "torch_reshape": { + "input": "data_in", + "shape_0": "config", + "shape_1": "config", + "shape_2": "config", + "shape_3": "config", + }, + "torch_split": { + "input": "data_in", + "split_size": "config", + "dim": "config", + }, + "torch_permute": { + "input": "data_in", + "dim_0": "config", + "dim_1": "config", + "dim_2": "config", + "dim_3": "config", + }, + "torch_transpose": { + "input": "data_in", + "dim0": "config", + "dim1": "config", + }, } module_data = { diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 78a2175ce..d46ba4a2c 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -214,6 +214,9 @@ def _get_sharding_map(mg): def autosharding_analysis_pass(mg, pass_args: dict = {}): """Annotate the metadata of each operator in the graph with a parallelization strategy. + For the autosharding pass to work, the fx graph must contain only placeholder, get_attr, + call_functional and output nodes. call_method and call_module nodes are not allowed. + Args: mg (MaseGraph): input mase graph. pass_args (dict, optional): pass arguments. Defaults to {}. 
diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index 10f3f854e..a7504d2d8 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -142,7 +142,7 @@ def common_pointwise_strategy( tensor_meta=input_arg_spec.tensor_meta, ) # input_specs.append(input_arg_target_spec) - input_specs = [input_arg_target_spec] * 2 + input_specs = [input_arg_target_spec] * len(meta.node.args) dtype = meta["common"]["results"]["data_out_0"].get( "torch_dtype", torch.float32 diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py index 612773262..09667c182 100644 --- a/src/chop/passes/graph/transforms/__init__.py +++ b/src/chop/passes/graph/transforms/__init__.py @@ -20,3 +20,7 @@ from .granularity import raise_granularity_transform_pass from .patching import patch_metadata_transform_pass + +from .resharding import resharding_transform_pass + +from .find_replace.method_to_function import replace_method_with_function diff --git a/src/chop/passes/graph/transforms/find_replace/__init__.py b/src/chop/passes/graph/transforms/find_replace/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/chop/passes/graph/transforms/find_replace/method_to_function.py b/src/chop/passes/graph/transforms/find_replace/method_to_function.py new file mode 100644 index 000000000..5cc70a672 --- /dev/null +++ b/src/chop/passes/graph/transforms/find_replace/method_to_function.py @@ -0,0 +1,63 @@ +import torch + +from chop.tools import get_logger +from chop.nn.functional.tensor import ( + torch_size, + torch_expand, + torch_view, + torch_contiguous, + torch_reshape, + torch_split, + torch_permute, + torch_transpose, +) + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + + +REPLACE_METHODS = { + "size": torch_size, + 
"reshape": torch_reshape, + "expand": torch_expand, + "split": torch_split, + "view": torch_view, + "permute": torch_permute, + "transpose": torch_transpose, + "contiguous": torch_contiguous, +} + + +def replace_method_with_function(mg, pass_args={}): + """Replaces call_method calls with call_function calls in the graph. + + Args: + graph (MaseGraph): The input graph. + + Returns: + MaseGraph: The graph with method calls replaced with function calls. + """ + for node in mg.fx_graph.nodes: + if node.op != "call_method": + continue + + if node.target in REPLACE_METHODS: + + with mg.fx_graph.inserting_after(node): + logger.debug(f"Replacing {node.target} with function call.") + new_node = mg.fx_graph.call_function( + REPLACE_METHODS[node.target], + node.args, + node.kwargs, + ) + node.replace_all_uses_with(new_node) + mg.fx_graph.erase_node(node) + + else: + raise NotImplementedError( + f"Method {node.target} not implemented in replace_method_with_function." + ) + + mg.model.recompile() + + return mg, {} diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index bc06ebab5..51354583b 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -23,6 +23,7 @@ def __init__(self) -> None: """Initializes the AutoPipeline.""" pass_list = [ + passes.replace_method_with_function, passes.init_metadata_analysis_pass, passes.report_graph_analysis_pass, passes.add_common_metadata_analysis_pass, From 25f4c30f2bade993c9ad6456ea40f515c845cedd Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 26 Jul 2024 21:03:58 +0000 Subject: [PATCH 62/93] insert resharding nodes --- src/chop/ir/common.py | 2 + src/chop/nn/functional/dtensor.py | 87 +++++++++++++++ .../add_metadata/common_metadata_layers.py | 12 +++ .../autosharding/alpa_intra_operator.py | 100 +++++++++--------- .../analysis/autosharding/autosharding.py | 8 +- .../graph/analysis/autosharding/layers.py | 20 ++++ 
.../autosharding/strategies/common.py | 31 +++--- .../autosharding/strategies/embedding_ops.py | 12 ++- .../autosharding/strategies/math_ops.py | 8 ++ .../autosharding/strategies/matrix_ops.py | 7 +- .../autosharding/strategies/view_ops.py | 18 ++-- .../graph/analysis/report/report_graph.py | 8 +- .../passes/graph/transforms/resharding.py | 95 +++++++++++++++++ src/chop/passes/module/__init__.py | 2 +- src/chop/passes/module/transforms/__init__.py | 2 - .../transforms/autosharding/__init__.py | 1 - .../transforms/autosharding/resharding.py | 26 ----- src/chop/pipelines/distributed_inference.py | 2 +- 18 files changed, 334 insertions(+), 107 deletions(-) create mode 100644 src/chop/nn/functional/dtensor.py create mode 100644 src/chop/passes/graph/transforms/resharding.py delete mode 100644 src/chop/passes/module/transforms/autosharding/__init__.py delete mode 100644 src/chop/passes/module/transforms/autosharding/resharding.py diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 7cd6f219b..8cf46ebf2 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -58,6 +58,8 @@ "torch_split", "torch_permute", "torch_transpose", + # dtensor ops (return DTensor) + "dtensor_arange", ] MASE_MODULE_RELATED_FUNCS = [ diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py new file mode 100644 index 000000000..58888bee3 --- /dev/null +++ b/src/chop/nn/functional/dtensor.py @@ -0,0 +1,87 @@ +from typing import Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed._tensor import DTensor +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor._redistribute import redistribute_local_tensor + +from torch.distributed._tensor.placement_types import Placement + + +def dtensor_arange( + start: int, + end: int, + step: int = 1, + out: torch.Tensor = None, + dtype: torch.dtype = None, + layout: torch.layout = torch.strided, + device: torch.device = None, + 
requires_grad: bool = False, + device_mesh: DeviceMesh = None, +): + """Returns a fully replicated DTensor with behaviour akin to `torch.arange`. + + Args: + start (int): _description_ + end (int): _description_ + step (int, optional): _description_. Defaults to 1. + out (torch.Tensor, optional): _description_. Defaults to None. + dtype (torch.dtype, optional): _description_. Defaults to None. + layout (torch.layout, optional): _description_. Defaults to torch.strided. + device (torch.device, optional): _description_. Defaults to None. + requires_grad (bool, optional): _description_. Defaults to False. + """ + return DTensor.from_local( + torch.arange( + start, + end, + step, + out=out, + dtype=dtype, + layout=layout, + device=device, + ), + device_mesh=device_mesh, + ) + + +def redistribute_dtensor( + input: DTensor, + placements: Tuple[Placement, ...], + async_op: bool = False, +): + """ + Redistribute a DTensor to a new set of placements. + + Args: + input (DTensor): The input DTensor to redistribute. + placements (Tuple[Placement, ...]): The new placements for the output DTensor. + async_op (bool, optional): Whether to perform the redistribution asynchronously. Defaults to False. + + Returns: + DTensor: The redistributed DTensor. + """ + current_spec = input._spec + + if current_spec.placements != placements: + target_spec = DTensorSpec( + None, + placements, + tensor_meta=input._spec.tensor_meta, + ) + + local_tensor = input._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. 
+ output = input._local_tensor + target_spec = current_spec + + return DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index b3213b2b1..5e0913d2f 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -270,6 +270,18 @@ "dim0": "config", "dim1": "config", }, + # DTensor ops + "dtensor_arange": { + "device_mesh": "config", + "start": "config", + "end": "config", + "step": "config", + "out": "config", + "dtype": "config", + "layout": "config", + "device": "config", + "requires_grad": "config", + }, } module_data = { diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 077a4869b..d97478de1 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -159,7 +159,6 @@ def _extract_ilp(mg, mesh, pass_args={}): ) op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) - opt_var = cp.Variable(1, boolean=True) constr += [ cp.sum(opt_var) == 1, @@ -169,49 +168,41 @@ def _extract_ilp(mg, mesh, pass_args={}): node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, "opt_var": opt_var, - "input": None, - "output": None, } continue - # Obtain strategy according to node op - # ================================================ - - if node.op in ["placeholder", "get_attr"]: - logger.debug( - f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" - ) - op_strategy = placeholder_or_getattr_strategy( - node.meta["mase"], - mesh, - skip_fully_replicated=pass_args.get("skip_fully_replicated", False), - ) - + # Output nodes simply propagate 
op_strategy from their input nodes elif node.op == "output": logger.debug( f"Op strategy from node {node.all_input_nodes[0]} is propagated to {node} node." ) + opt_var = cp.Variable(1, boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] node.meta["mase"]["software"]["autosharding"] = { "op_strategy": node.all_input_nodes[0].meta["mase"]["software"][ "autosharding" ]["op_strategy"], - "opt_var": None, - "input": None, - "output": None, + "opt_var": opt_var, } continue - elif node.op == "call_module" and isinstance( - deepgetattr(mg.model, node.target), - tuple(AUTOSHARDING_MODULES.keys()), - ): - logger.debug(f"Obtaining strategy for call_module node: {node.name}") - module_cls = type(deepgetattr(mg.model, node.target)) - op_strategy = AUTOSHARDING_MODULES[module_cls](node.meta["mase"], mesh) + # Obtain strategy according to node op + # ================================================ - elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): - logger.debug(f"Obtaining strategy for call_method node: {node.name}") - op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) + elif node.op in [ + "placeholder", + "get_attr", + ]: + logger.debug( + f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" + ) + op_strategy = placeholder_or_getattr_strategy( + node.meta["mase"], + mesh, + skip_fully_replicated=pass_args.get("skip_fully_replicated", False), + ) elif ( node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() @@ -229,10 +220,11 @@ def _extract_ilp(mg, mesh, pass_args={}): cp.sum(opt_var) == 1, ] node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), + "op_strategy": fully_replicated_strategy( + node.meta["mase"], + mesh, + ), "opt_var": opt_var, - "input": None, - "output": None, } continue @@ -246,8 +238,6 @@ def _extract_ilp(mg, mesh, pass_args={}): node.meta["mase"]["software"]["autosharding"] 
= { "op_strategy": op_strategy, "opt_var": opt_var, - "input": None, - "output": None, } # Consider computation cost (c_v term) for each of the node's strategies @@ -383,35 +373,49 @@ def _mark_sharding(mg, pass_args): dict: tensor sharding map. """ + logger.info( + f"Autosharding optimization finished, annotating graph with chosen sharding strategies for each node." + ) + for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] if opt_var is None: - continue - - try: - idx = np.where(opt_var.value == 1)[0][0] - except: - idx = np.argmax(opt_var.value) + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] + else: + # Get the strategy chosen by the ILP + try: + idx = np.where(opt_var.value == 1)[0][0] + except: + idx = np.argmax(opt_var.value) - chosen_strategy = node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ].strategies[idx] + chosen_strategy = node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ].strategies[idx] # Annotate chosen placement strategy node.meta["mase"]["software"]["autosharding"][ "placement_strategy" ] = chosen_strategy + # Annotate arg metadata with chosen strategy arg_specs = chosen_strategy.input_specs - out_spec = chosen_strategy.output_specs - if isinstance(arg_specs, DTensorSpec): arg_specs = (arg_specs,) - # Annotate arg metadata with chosen strategy - if node.op in ["placeholder", "get_attr", "call_method", "output"]: - pass + if not node.op in ["placeholder", "get_attr", "output"]: + assert len(arg_specs) == len( + node.meta["mase"]["common"]["args"].keys() + ), "Number of arguments do not match metadata." 
+ + out_spec = chosen_strategy.output_specs + + if node.op in ["placeholder", "get_attr", "output"]: + node.meta["mase"]["common"]["results"]["data_out_0"][ + "dtensor_spec" + ] = out_spec # call_function nodes else: diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index d46ba4a2c..98a80dd46 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -46,12 +46,6 @@ def _import_solution( for node in mg.fx_graph.nodes: logger.debug(f"Importing solution for node: {node.name}") - # Only import solution for getattr nodes - # TO DO: this is hard-coded for GPT2 - # Figure out how to generalize - if not node.name.startswith("transformer_"): - continue - # Extrapolate from first layer by string matching if node.name not in solution.keys() and extrapolate_sharding: @@ -62,7 +56,7 @@ def _import_solution( extrapolate_node = node.name.replace(f"_{layer_num}_", "_0_", 1) if extrapolate_node in solution.keys(): - logger.warning( + logger.debug( f"Node: {node.name} not found in solution. 
Extrapolating from solution for: {extrapolate_node}" ) solution[node.name] = solution[extrapolate_node] diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index ff43a3cdc..42c883224 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -4,6 +4,16 @@ import torch.nn.functional as F from chop.tools import get_logger +from chop.nn.functional.tensor import ( + torch_size, + torch_expand, + torch_view, + torch_contiguous, + torch_reshape, + torch_split, + torch_permute, + torch_transpose, +) from .strategies.common import fully_replicated_strategy from .strategies.matrix_ops import ( @@ -242,6 +252,15 @@ torch.Tensor.zero_: tensor_op_strategy, torch.Tensor.equal: tensor_equal_strategy, torch.Tensor.is_same_size: tensor_equal_strategy, + # chop.nn.functional.tensor functions + torch_expand: get_reshape_strategy(torch.Tensor.expand), + torch_view: get_reshape_strategy(torch.Tensor.view), + torch_contiguous: tensor_op_strategy, + torch_reshape: get_reshape_strategy(torch.Tensor.reshape), + # torch_split: + torch_permute: get_reshape_strategy(torch.Tensor.permute), + torch_transpose: transpose_strategy, + torch.unsqueeze: get_reshape_strategy(torch.unsqueeze), } AUTOSHARDING_METHODS = { @@ -262,6 +281,7 @@ getattr, torch.finfo, torch.arange, + torch_size, ] IMPLICIT_METHODS = [ diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index bdc5b9736..b5c0923e4 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -120,15 +120,17 @@ def fully_replicated_strategy(meta, mesh): arg = meta["common"]["args"][first_arg_key] in_shape, in_dtype = find_shape_and_dtype(arg) - in_spec = DTensorSpec( - mesh, - sharding, - tensor_meta=TensorMeta( - 
shape=in_shape, - stride=None, - dtype=in_dtype, - ), - ) + in_spec = [ + DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), + ) + ] * len(meta["common"]["args"].keys()) dtype_key = ( "torch_dtype" @@ -146,6 +148,11 @@ def fully_replicated_strategy(meta, mesh): ), ) - shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] - - return OpStrategy(shardings) + return OpStrategy( + [ + PlacementStrategy( + input_specs=in_spec, + output_specs=out_spec, + ) + ] + ) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py index feef6fd12..e8dc24012 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py @@ -217,6 +217,16 @@ def expand_to_full_mesh_op_strategy( for inp, s in zip(input_args_strategy, input_specs) ) + # extend input_specs to include fully replicated sharding for constant nodes + extended_input_specs = input_specs + [ + DTensorSpec( + mesh, + (Replicate(), Replicate()), + # todo: may need to set tensor meta + tensor_meta=None, + ) + ] * (len(meta["common"]["args"].keys()) - len(input_specs)) + # only add to the all_strategies list when all inputs are shardable if inputs_shardable: redistribute_cost = [ @@ -227,7 +237,7 @@ def expand_to_full_mesh_op_strategy( output_specs=( tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] ), - input_specs=input_specs, + input_specs=extended_input_specs, redistribute_cost=redistribute_cost, ) all_strategies.append(strategy) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py index 3d0241580..984295c58 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py +++ 
b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py @@ -178,6 +178,14 @@ def layer_norm_strategy(meta, mesh): ) op_args_target_specs.append(bias_target_spec) + # add fully replicated strategy for eps + eps_spec = DTensorSpec( + mesh=mesh, + placements=(Replicate(),) * 2, + tensor_meta=None, + ) + op_args_target_specs.append(eps_spec) + # the output spec is the same as input spec output_target_spec = input_target_spec output_strategy.strategies.append( diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index d552f8171..ede2167ff 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -34,6 +34,10 @@ def transpose_strategy( assert isinstance(self_strategy, OpStrategy) + fully_replicated_spec = DTensorSpec( + mesh=mesh, placements=[Replicate(), Replicate()], tensor_meta=None + ) + transpose_strategies = [] for input_strategy in self_strategy.strategies: input_spec = input_strategy.output_spec @@ -52,7 +56,8 @@ def transpose_strategy( dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ), ), - input_specs=(input_strategy.output_spec,), + # include 2 fully replicated inputs for dim_0 and dim_1 arguments + input_specs=(input_strategy.output_spec,) + (fully_replicated_spec,) * 2, ) transpose_strategies.append(transpose_strategy) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 668cbf8e6..525a49897 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -577,17 +577,14 @@ def get_reshape_strategy(op): dim_map = dim_maps[op] def reshape_strategy(meta, mesh): - assert meta.node.op == "call_method", "Node should have call_method op." 
- args_schema = [meta["common"]["self"]] + [ - i["value"] for i in meta["common"]["args"].values() - ] + args_schema = [i["value"] for i in meta["common"]["args"].values()] rules = dim_map(*args_schema) parent_node = meta.node.args[0] # input_strategy = cast(OpStrategy, op_schema.args_schema[0]) input_strategy = parent_node.meta["mase"]["software"]["autosharding"][ "op_strategy" ] - global_in_shape = meta["common"]["self"].shape + global_in_shape = meta["common"]["args"]["data_in_0"]["shape"] assert global_in_shape is not None, "Shape required." output_strategy = OpStrategy([]) @@ -611,6 +608,15 @@ def reshape_strategy(meta, mesh): tensor_meta=input_src_spec.tensor_meta, ) + replicate_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + # todo: may need to set tensor meta + tensor_meta=None, + ) + # add fully replicated spec for all constant args + input_specs = (input_tgt_spec,) + (replicate_spec,) * (len(args_schema) - 1) + output_spec = DTensorSpec( mesh=mesh, placements=tuple(output_placements), @@ -623,7 +629,7 @@ def reshape_strategy(meta, mesh): output_strategy.strategies.append( PlacementStrategy( output_specs=output_spec, - input_specs=(input_tgt_spec,), + input_specs=input_specs, ) ) diff --git a/src/chop/passes/graph/analysis/report/report_graph.py b/src/chop/passes/graph/analysis/report/report_graph.py index bfc7086ef..00ba219da 100644 --- a/src/chop/passes/graph/analysis/report/report_graph.py +++ b/src/chop/passes/graph/analysis/report/report_graph.py @@ -64,7 +64,13 @@ def report_graph_analysis_pass(graph, pass_args={"file_name": None}): {count} Layer types: -{layer_types}""" +{layer_types} + +===================== Code Gen ===================== + +{graph.model.code} + +""" if file_name is None: print(buff) else: diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py new file mode 100644 index 000000000..413e256ee --- /dev/null +++ 
b/src/chop/passes/graph/transforms/resharding.py @@ -0,0 +1,95 @@ +import torch.fx as fx +from chop.tools import get_logger +from chop.nn.functional.dtensor import redistribute_dtensor +from chop.ir.graph import MaseMetadata + +from torch.distributed._tensor.placement_types import Replicate, Shard + +logger = get_logger(__name__) +logger.setLevel("INFO") + + +def resharding_transform_pass(mg, pass_args={}): + """Insert resharding nodes""" + logger.info( + f"Running resharding_transform_pass to insert resharding nodes along necessary edges." + ) + for node in mg.fx_graph.nodes: + + if node.op == "call_function" and node.target == redistribute_dtensor: + continue + + flattened_args = node.args + tuple(node.kwargs.values()) + + # Number of arguments should match metadata + if node.op != "output" and len(flattened_args) != len( + node.meta["mase"]["common"]["args"] + ): + logger.warning( + f"Skipping node: {node.name} because number of arguments do not match metadata." + ) + continue + + for arg_idx, arg_name in enumerate(node.meta["mase"]["common"]["args"].keys()): + + # Check if argument is an FX node, otherwise it's a constant + arg_obj = flattened_args[arg_idx] + if not isinstance(arg_obj, fx.Node): + logger.debug( + f"Skipping node: {node.name}, argument: {arg_name} because it is a constant." 
+ ) + continue + + # Check if the parent node output spec is different from the arg input spec + arg_info = node.meta["mase"]["common"]["args"][arg_name] + arg_specs = arg_info.get("dtensor_spec", None) + + parent_out_specs = arg_obj.meta["mase"]["common"]["results"][ + "data_out_0" + ].get("dtensor_spec", None) + + if arg_specs is None or parent_out_specs is None: + logger.warning( + f"Skipping edge {arg_obj} -> {node}.{arg_name} because dtensor_spec was not found" + ) + continue + + if arg_specs.placements != parent_out_specs.placements: + logger.info( + f"Inserting resharding node along edge {arg_obj} -> {node.name} due to arg {arg_idx}: {arg_name}" + ) + + # Create resharding node + with mg.fx_graph.inserting_before(node): + resharding_node = mg.fx_graph.call_function( + redistribute_dtensor, + args=(arg_obj, arg_specs.placements), + kwargs={ + "async_op": False, + }, + ) + + resharding_node.meta["mase"] = MaseMetadata( + node=resharding_node, + model=mg.model, + ) + + # Update the current node's argument + updated_args = list(node.args) + updated_args[arg_idx] = resharding_node + node.args = tuple(updated_args) + + # Insert DTensor import at the top of code + def insert_imports(body): + return [ + "from torch.distributed._tensor.placement_types import Replicate, Shard \n", + *body, + ] + + mg.fx_graph.on_generate_code(lambda _: insert_imports) + + # Check the model is valid + mg.fx_graph.lint() + mg.model.recompile() + + return mg, {} diff --git a/src/chop/passes/module/__init__.py b/src/chop/passes/module/__init__.py index 566b306a8..f2ffd33ca 100644 --- a/src/chop/passes/module/__init__.py +++ b/src/chop/passes/module/__init__.py @@ -1,5 +1,5 @@ from .analysis import calculate_avg_bits_module_analysis_pass -from .transforms import quantize_module_transform_pass, resharding_transform_pass +from .transforms import quantize_module_transform_pass ANALYSIS_PASSES = ["calculate_avg_bits_module_analysis_pass"] diff --git 
a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index efbb0ed14..3fcc8c5b3 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,3 +1 @@ -from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass -from .autosharding import resharding_transform_pass diff --git a/src/chop/passes/module/transforms/autosharding/__init__.py b/src/chop/passes/module/transforms/autosharding/__init__.py deleted file mode 100644 index 699587b80..000000000 --- a/src/chop/passes/module/transforms/autosharding/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .resharding import resharding_transform_pass diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py deleted file mode 100644 index 3646064dd..000000000 --- a/src/chop/passes/module/transforms/autosharding/resharding.py +++ /dev/null @@ -1,26 +0,0 @@ -from chop.tools import get_logger - -logger = get_logger(__name__) -logger.setLevel("INFO") - - -def resharding_transform_pass(mg, pass_args={}): - """ - This pass inserts a wrapper around each module in the graph to handle resharding - activation tensors when the output of the previous module has a different sharding - profile to the one assigned to the current module. 
- """ - - module_map = pass_args.get("module_map", None) - device_mesh = pass_args.get("device_mesh", None) - if module_map is None or device_mesh is None: - raise ValueError( - "module_map and device_mesh are required for resharding_transform_pass" - ) - - for node in mg.fx_graph.nodes: - pass - - mg.model.recompile() - - return mg, {} diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index 51354583b..9bac59864 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -25,10 +25,10 @@ def __init__(self) -> None: pass_list = [ passes.replace_method_with_function, passes.init_metadata_analysis_pass, - passes.report_graph_analysis_pass, passes.add_common_metadata_analysis_pass, passes.autosharding_analysis_pass, passes.resharding_transform_pass, + passes.report_graph_analysis_pass, ] super().__init__(pass_list) From 996d2fa53922808f8b8ac53e661e846a14c10857 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 29 Jul 2024 14:46:51 +0000 Subject: [PATCH 63/93] simplify mase launcher and get resharding nodes working with non fully replicated model --- src/chop/distributed/__init__.py | 1 - src/chop/distributed/launcher.py | 178 +++--------------- src/chop/distributed/tensor/_dispatch.py | 2 +- src/chop/distributed/tensor/_sharding_prop.py | 10 +- src/chop/distributed/utils.py | 112 +++++++++++ src/chop/nn/functional/dtensor.py | 41 +++- .../add_metadata/add_common_metadata.py | 3 +- .../analysis/autosharding/autosharding.py | 8 +- .../passes/graph/transforms/resharding.py | 41 +++- 9 files changed, 214 insertions(+), 182 deletions(-) diff --git a/src/chop/distributed/__init__.py b/src/chop/distributed/__init__.py index 9f188fa07..e69de29bb 100644 --- a/src/chop/distributed/__init__.py +++ b/src/chop/distributed/__init__.py @@ -1 +0,0 @@ -from .launcher import MaseLauncher diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 
8157636d6..86f99227b 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -1,162 +1,22 @@ -import os -from functools import partial -from time import time - -import torch -import torch.nn as nn -import torch.distributed as dist import torch.multiprocessing as mp -from torch.distributed._tensor import ( - DeviceMesh, - Replicate, - Shard, -) - -from chop.distributed.tensor import distribute_module, distribute_tensor - -from chop.distributed.utils import rlog from ..tools import get_logger logger = get_logger(__name__) logger.setLevel("DEBUG") -def distributed_timing(fn, *args, **kwargs): - dist.barrier(async_op=True) - start = time() - result = fn(*args, **kwargs) - dist.barrier(async_op=True) - end = time() - - return result, (end - start) - - -def distributed_average_timing(fn, repeat, args): - times = [] - for itr in range(repeat): - rlog( - logger, - dist.get_rank(), - f"Running teration {itr}", - "debug", - ) - dist.barrier(async_op=True) - start = time() - result = fn(*args) - dist.barrier(async_op=True) - end = time() - times.append(end - start) - rlog( - logger, - dist.get_rank(), - f"Time taken: {end - start}s", - "debug", - ) - - return result, sum(times[2:]) / len(times[2:]) - - -def dist_model_fn( - name: str, - module: nn.Module, - device_mesh: DeviceMesh, - rank: int, - tensor_sharding_map={}, -) -> None: - """ - This function gets called by torch.distributed._tensor.distribute_module on each module in the model. - Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. 
- """ - if module in tensor_sharding_map: - node_name = tensor_sharding_map[module]["node"] - for parameter, sharding_config in tensor_sharding_map[module][ - "sharding" - ].items(): - if parameter in ["data_in_0", "output", "data_out_0"]: - continue - if not hasattr(module, parameter): - rlog( - logger, - rank, - f"Module {module} does not have parameter {parameter}", - level="warning", - ) - continue - - placement = sharding_config.placements - - try: - rlog( - logger, - rank, - f"Distributing parameter {parameter} of module {node_name} to {placement}", - level="debug", - ) - distributed_tensor = distribute_tensor( - getattr(module, parameter), device_mesh, placement - ) - setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) - except Exception as e: - rlog( - logger, - rank, - f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", - level="error", - ) - - -def device_fn( - rank, world_size, model=None, device_mesh=None, tensor_sharding_map={}, inputs=[] -): - """ - This function gets called on each GPU device to set up the distributed environment and distribute the model, - following the SPMD model. - """ - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - os.environ["RANK"] = str(rank) - - # Initialize - dist.init_process_group("nccl", rank=rank, world_size=world_size) - device = torch.device("cuda", rank) - torch.cuda.set_device(device) - - # Distribute model parameters according to sharding configuration - mesh = DeviceMesh("cuda", mesh=device_mesh) - rlog(logger, rank, f"Distributing module parameters...", level="info") - model, dist_time = distributed_timing( - distribute_module, - model, - mesh, - partial(dist_model_fn, rank=rank, tensor_sharding_map=tensor_sharding_map), - input_fn=None, - output_fn=None, - ) - rlog(logger, rank, f"Module distribution done. 
Time taken: {dist_time} seconds.") - - # Run forward pass - rlog(logger, rank, f"Starting forward pass.", level="info") - inputs = [ - distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) - for in_tensor in inputs - ] - _, time_taken = distributed_average_timing( - fn=model, - repeat=10, - args=inputs, - ) - rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info") - - dist.destroy_process_group() - - class MaseLauncher: """ MaseLauncher launches an optimized model on multiple GPUs using torch.distributed. """ - def __init__(self, mase_graph, world_size=None, device_mesh=None): + def __init__( + self, + world_size=None, + device_mesh=None, + device_fn=None, + ): """Initialize the MaseLauncher. Args: @@ -164,23 +24,27 @@ def __init__(self, mase_graph, world_size=None, device_mesh=None): world_size (int, optional): Number of GPUs to use. Defaults to None. device_mesh (list, optional): List of GPUs to use. Defaults to None. """ - self.mg = mase_graph - self.model = mase_graph.model self.world_size = world_size self.device_mesh = device_mesh - - def run(self, tensor_sharding_map={}, inputs=[]): + self.device_fn = device_fn + + def run( + self, + model_class=None, + model_config=None, + cli_args=None, + ): logger.info(f"Launching model with world size {self.world_size}.") mp.spawn( - partial( - device_fn, - model=self.model, - device_mesh=self.device_mesh, - tensor_sharding_map=tensor_sharding_map, - inputs=inputs, + self.device_fn, + args=( + self.world_size, + self.device_mesh, + model_class, + model_config, + cli_args, ), - args=(self.world_size,), nprocs=self.world_size, join=True, ) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index b35a1fc84..5ea308804 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -122,7 +122,7 @@ def dispatch( # extract local tensor and sharding infos to a OpInfo op_info = 
self.unwrap_to_op_info(op_call, args, kwargs) - rlog(f"Dispatching op_call: {op_call.name}") + # rlog(f"Dispatching op_call: {op_call.name}") # self.sharding_propagator.propagate(op_info) # output_sharding = op_info.output_sharding diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 045465412..5b2996919 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -285,14 +285,14 @@ def spec_to_strategy(spec: object) -> object: if output_strategy.output_spec.is_sharded(): schema = suggestion_schema or op_schema assert isinstance(out_tensor_meta, TensorMeta) - rlog(f"Need the out tensor meta here!") + # rlog(f"Need the out tensor meta here!") suggestion_schema = self._adjust_shape_and_stride_args( out_tensor_meta, schema, output_strategy.output_spec, mesh ) needs_redistribute = True - else: - rlog(f"Don't need it because it's not sharded") + # else: + # rlog(f"Don't need it because it's not sharded") # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): @@ -324,7 +324,7 @@ def spec_to_strategy(spec: object) -> object: needs_redistribute=needs_redistribute, ) elif isinstance(op_strategy, TupleStrategy): - rlog(f"Strategy returned a TupleStrategy") + # rlog(f"Strategy returned a TupleStrategy") # tuple strategy output sharding processing # runtime selected placement strategy for each TupleStrategy input arg selected_strategies: List[PlacementStrategy] = [] @@ -401,7 +401,7 @@ def spec_to_strategy(spec: object) -> object: ) return output_sharding elif op_schema.op in self.op_to_rules: - rlog(f"Op {op_schema.op} has no strategy, using rule") + # rlog(f"Op {op_schema.op} has no strategy, using rule") # propagate the sharding with rule sharding_prop_func = self.op_to_rules[op_schema.op] diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index c7cd7c1c7..f7eef248f 100644 --- a/src/chop/distributed/utils.py +++ 
def distributed_timing(fn, *args, **kwargs):
    """Time a single call to ``fn`` across the process group.

    A blocking barrier is issued before and after the call so the measured
    interval reflects the slowest rank rather than local clock skew.

    Args:
        fn (callable): function to execute.
        *args, **kwargs: forwarded to ``fn``.

    Returns:
        tuple: (result of ``fn``, elapsed wall-clock seconds).
    """
    # NOTE(review): the original called dist.barrier(async_op=True) without
    # waiting on the returned work handle, which returns immediately and so
    # provides no synchronization; a blocking barrier is required for timing.
    dist.barrier()
    start = time()
    result = fn(*args, **kwargs)
    dist.barrier()
    end = time()

    return result, (end - start)


def distributed_average_timing(
    fn,
    args,
    repeat=10,
    warmup_iters=2,
):
    """Time repeated calls to ``fn(*args)`` and return the mean latency.

    The first ``warmup_iters`` iterations are excluded from the average.

    Args:
        fn (callable): function to execute.
        args (Sequence): positional arguments for ``fn``.
        repeat (int, optional): total number of iterations. Defaults to 10.
        warmup_iters (int, optional): leading iterations excluded from the
            average. Defaults to 2.

    Returns:
        tuple: (result of the last call, mean elapsed seconds over the
        non-warmup iterations).

    Raises:
        ValueError: if ``repeat`` does not exceed ``warmup_iters`` (the
        original code would have divided by zero in this case).
    """
    if repeat <= warmup_iters:
        raise ValueError(
            f"repeat ({repeat}) must be greater than warmup_iters ({warmup_iters})"
        )

    times = []
    for itr in range(repeat):
        rlog(
            logger,
            dist.get_rank(),
            f"Running iteration {itr}",
            "info",
        )
        # Blocking barriers (see distributed_timing) so every rank measures
        # the same synchronized interval.
        dist.barrier()
        start = time()
        result = fn(*args)
        dist.barrier()
        end = time()
        times.append(end - start)
        rlog(
            logger,
            dist.get_rank(),
            f"Time taken: {end - start}s",
            "info",
        )

    return result, sum(times[warmup_iters:]) / len(times[warmup_iters:])
def dist_model_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
    rank: int,
    tensor_sharding_map=None,
) -> None:
    """
    This function gets called by torch.distributed._tensor.distribute_module on
    each module in the model. Each parameter tensor in the module is distributed
    according to the sharding configuration in tensor_sharding_map.

    Args:
        name (str): name of the submodule (unused; required by the
            distribute_module callback signature).
        module (nn.Module): submodule whose parameters are distributed in place.
        device_mesh (DeviceMesh): mesh over which the parameters are sharded.
        rank (int): rank of the calling process (used for rank-0 logging).
        tensor_sharding_map (dict, optional): maps module objects to their node
            name and per-parameter sharding configs. Defaults to None, meaning
            no sharding is applied.
    """
    # Fix: the original used a mutable default argument (tensor_sharding_map={}),
    # which is shared across calls; use None as the sentinel instead.
    if tensor_sharding_map is None:
        tensor_sharding_map = {}

    # Modules without an entry in the sharding map are left untouched.
    if module not in tensor_sharding_map:
        return

    node_name = tensor_sharding_map[module]["node"]
    for parameter, sharding_config in tensor_sharding_map[module][
        "sharding"
    ].items():
        # These entries describe activations flowing through the module,
        # not parameters to distribute.
        if parameter in ["data_in_0", "output", "data_out_0"]:
            continue
        if not hasattr(module, parameter):
            rlog(
                logger,
                rank,
                f"Module {module} does not have parameter {parameter}",
                level="warning",
            )
            continue

        placement = sharding_config.placements

        try:
            rlog(
                logger,
                rank,
                f"Distributing parameter {parameter} of module {node_name} to {placement}",
                level="debug",
            )
            distributed_tensor = distribute_tensor(
                getattr(module, parameter), device_mesh, placement
            )
            # Re-wrap as Parameter so the module still registers it.
            setattr(module, parameter, torch.nn.Parameter(distributed_tensor))
        except Exception as e:
            # Best-effort: log and continue with the remaining parameters.
            rlog(
                logger,
                rank,
                f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}",
                level="error",
            )
""" + + # If we are not in a distributed setting, we can skip redistribution. + try: + rank = torch.distributed.get_rank() + except: + rank = 0 + + if not isinstance(input, DTensor): + rlog( + logger, + rank, + f"Skipping redistribution because received {type(input)} instead of DTensor", + level="warning", + ) + return input + current_spec = input._spec + rlog( + logger, + rank, + f"Redistributing tensor from {current_spec.placements} to {placements}", + level="info", + ) + if current_spec.placements != placements: target_spec = DTensorSpec( - None, + input._spec.mesh, placements, tensor_meta=input._spec.tensor_meta, ) local_tensor = input._local_tensor output = redistribute_local_tensor( - local_tensor, current_spec, target_spec, async_op=async_op + local_tensor, + current_spec, + target_spec, + async_op=async_op, ) else: # use the same local tensor if placements are the same. diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index a8209f2a0..a947cc630 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -209,14 +209,13 @@ def graph_iterator_for_metadata( model, fx_graph, modules = graph.model, graph.fx_graph, graph.modules env = {} - prev_result = None # force everything to be on device="meta" if force_device_meta: dummy_in = {k: v.to("meta") for k, v in dummy_in.items()} model = model.to("meta") - for node in graph.fx_graph.nodes: + for node in fx_graph.nodes: args, kwargs = None, None if node.op == "placeholder": result = dummy_in[node.name] diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 98a80dd46..d994b1939 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -1,5 +1,4 @@ -import 
numpy as np -import cvxpy as cp +import os from time import time import dill @@ -246,8 +245,9 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): mesh = MeshModel(pass_args["mesh_shape"]) # Preload autosharding solution - if pass_args.get("preload_solution", False): - fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") + # check if solution file exists + if pass_args.get("preload_solution", False) and os.path.exists(fname): logger.info(f"Preloading autosharding solution from: {fname}") with open(fname, "rb") as file: solution = dill.load(file) diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index 413e256ee..e143c059a 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -1,15 +1,18 @@ +from copy import copy + +import torch import torch.fx as fx +from torch.distributed._tensor.placement_types import Replicate, Shard + from chop.tools import get_logger from chop.nn.functional.dtensor import redistribute_dtensor from chop.ir.graph import MaseMetadata -from torch.distributed._tensor.placement_types import Replicate, Shard - logger = get_logger(__name__) logger.setLevel("INFO") -def resharding_transform_pass(mg, pass_args={}): +def _insert_resharding_nodes(mg, pass_args={}): """Insert resharding nodes""" logger.info( f"Running resharding_transform_pass to insert resharding nodes along necessary edges." 
@@ -20,6 +23,7 @@ def resharding_transform_pass(mg, pass_args={}): continue flattened_args = node.args + tuple(node.kwargs.values()) + kwarg_keys = list(node.kwargs.keys()) # Number of arguments should match metadata if node.op != "output" and len(flattened_args) != len( @@ -43,7 +47,6 @@ def resharding_transform_pass(mg, pass_args={}): # Check if the parent node output spec is different from the arg input spec arg_info = node.meta["mase"]["common"]["args"][arg_name] arg_specs = arg_info.get("dtensor_spec", None) - parent_out_specs = arg_obj.meta["mase"]["common"]["results"][ "data_out_0" ].get("dtensor_spec", None) @@ -56,7 +59,7 @@ def resharding_transform_pass(mg, pass_args={}): if arg_specs.placements != parent_out_specs.placements: logger.info( - f"Inserting resharding node along edge {arg_obj} -> {node.name} due to arg {arg_idx}: {arg_name}" + f"Inserting resharding node along edge {arg_obj} -> {node.name} because arg {arg_name} requires placement {arg_specs.placements} but parent node {arg_obj.name} has placement {parent_out_specs.placements}." 
def resharding_transform_pass(mg, pass_args=None):
    """Public entry point: insert resharding nodes along graph edges whose
    producer/consumer DTensor placements disagree.

    Thin wrapper around ``_insert_resharding_nodes`` so the pass name follows
    the naming convention of the other graph transform passes.

    Args:
        mg (MaseGraph): graph to transform.
        pass_args (dict, optional): pass configuration. Defaults to None.

    Returns:
        tuple: (transformed MaseGraph, empty info dict).
    """
    # Fix: avoid the shared mutable default argument (pass_args={});
    # normalize None to a fresh dict per call instead.
    return _insert_resharding_nodes(mg, pass_args if pass_args is not None else {})
op_info.mesh - # if mesh.get_coordinate() is None: - # # For a non-participating device, we do: - # # 1. if the return type is scalar, set the local result to None. - # # The local results from all devices will then be all-gathered - # # and a reduce op will be performed on the list of results - # # with appropriate operators: - # # for bool type, we by default use AND to reduce; - # # we can extend for more ops if necessary. - # # 2. if the return type is Tensor or List[Tensor], return empty - # # tensor(s) with correct dtype. - # spec = output_sharding.output_spec - # ret_list = op_info.schema.op._schema.returns - - # if spec is None: - # # For a scalar return type, the non-participating device has None - # # as its local result - # local_results: object = None - # else: - - # def default_tensor(spec: DTensorSpec) -> torch.Tensor: - # if spec.tensor_meta is not None: - # shape = spec.tensor_meta.shape - # dtype = spec.tensor_meta.dtype - # if len(shape) == 0: - # # scalar tensor - # return torch.zeros((), dtype=dtype) - # else: - # # non-scalar tensor - # return torch.tensor([], dtype=dtype) - # else: - # raise RuntimeError(f"{spec} has no tensor metadata.") - - # if isinstance(spec, DTensorSpec): - # # return a Tensor value - # local_results = default_tensor(spec) - # elif isinstance(spec, Sequence): - # # return a List[Tensor] value - # local_results = [ - # default_tensor(s) if s is not None else None for s in spec - # ] - # assert isinstance(local_results, List) - # if None in local_results: - # ret_type = str(ret_list[0].type) - # raise NotImplementedError( - # f"return type {ret_type} in DTensor op is not supported" - # ) - # else: - if output_sharding.needs_redistribute: - # compute locally with redistribute first if needed - assert output_sharding.redistribute_schema is not None - self.redistribute_local_args(op_info, output_sharding.redistribute_schema) - - # local_tensor_args = ( - # pytree.tree_unflatten( - # cast(List[object], op_info.local_args), 
op_info.args_tree_spec - # ) - # if op_info.args_tree_spec - # else op_info.local_args - # ) - - local_tensor_args = op_info.local_args - # run local op computation with potentially modified args/kwargs + local_tensor_args = op_info.local_args local_tensor_args = cast(Tuple[object, ...], local_tensor_args) - # if op_call in self._random_ops: - # if not random._rng_tracker and is_rng_supported_mesh(mesh): - # # Default to `OffsetBasedRNGTracker` if the parallelism API - # # did not already construct one - # random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) - - # first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - # torch.Tensor, local_tensor_args[0] - # ) - # rng_context = ( - # random._rng_tracker._distribute_region(first_arg._spec) - # if random._rng_tracker and not first_local_arg.is_meta - # else contextlib.nullcontext() - # ) - - # # For DTensor random operator, run it within a distribute region - # with rng_context: - # local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - # else: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - # communicate the result to all ranks for some operators that return scalar value - # if output_sharding.output_spec is None: - # if op_call == aten.equal.default: - # obj_list = [None for _ in range(dist.get_world_size())] - # dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] - # obj_list = list(filter(lambda x: x is not None, obj_list)) - # # perform reduce on the collection with AND op - # local_results = functools.reduce(operator.and_, obj_list, True) - - # if _is_inplace_op(op_call): - # # inplace op should return self instead of re-wrapping - # if output_sharding.output_spec is not None: - # return args[0] - # else: - # return None - # elif _is_out_variant_op(op_call): - # # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) - # output_specs = ( - # (output_sharding.output_spec,) - # if not isinstance(output_sharding.output_spec, tuple) - # else output_sharding.output_spec - # ) - # out_dts = [] - # spec_idx = 0 - # for argument in op_call._schema.arguments: - # if argument.is_out: - # out_dt = cast(dtensor.DTensor, kwargs[argument.name]) - # out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) - # out_dts.append(out_dt) - # spec_idx += 1 - - # assert len(out_dts) >= 1, "out variant should have at least one out arg" - # return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] - # else: + # rlog( + # f"Reshape {op_call.name} outputs type {type(local_results)}, shape {local_results.shape}" + # ) + + # if "aten.view" in str(op_call.name): + rlog(f"op call: {str(op_call.name)}") + rlog(f"local tensor args: {local_tensor_args}") + if isinstance(local_results, (torch.Tensor, dtensor.DTensor)): + rlog(f"op call output: {local_results.shape}") + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] @staticmethod From c1e3a7c8d0e23ec638dddda44dff00ae32ed52ac Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 30 Jul 2024 17:25:17 +0000 Subject: [PATCH 65/93] add torch.mm as an op --- src/chop/ir/common.py | 1 + .../graph/analysis/add_metadata/common_metadata_layers.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/chop/ir/common.py b/src/chop/ir/common.py index 8cf46ebf2..ca0d2379a 100644 --- a/src/chop/ir/common.py +++ b/src/chop/ir/common.py @@ -104,6 +104,7 @@ "sub", "add", "matmul", + "mm", "bmm", "mean", "pow", diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index 5e0913d2f..17ac46a89 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -79,6 +79,8 @@ "sub": {"input": "data_in", "other": "data_in"}, # 
https://pytorch.org/docs/stable/generated/torch.matmul.html "matmul": {"input": "data_in", "other": "data_in"}, + # https://pytorch.org/docs/stable/generated/torch.mm.html + "mm": {"input": "data_in", "mat2": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.bmm.html "bmm": {"input": "data_in", "mat2": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.squeeze.html From ee517dfab3c5909e5e471e89d76cb1fbcef20946 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Tue, 30 Jul 2024 17:25:38 +0000 Subject: [PATCH 66/93] get refactored op dispatcher working on single layer gpt2 --- src/chop/nn/functional/dtensor.py | 19 ++- .../autosharding/alpa_intra_operator.py | 136 ++++++++++++++---- .../graph/analysis/autosharding/layers.py | 8 +- .../autosharding/strategies/common.py | 42 +++--- 4 files changed, 147 insertions(+), 58 deletions(-) diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index ffb18ace5..04a7710ba 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -89,14 +89,13 @@ def redistribute_dtensor( current_spec = input._spec - rlog( - logger, - rank, - f"Redistributing tensor from {current_spec.placements} to {placements}", - level="info", - ) - if current_spec.placements != placements: + rlog( + logger, + rank, + f"Redistributing tensor from {current_spec.placements} to {placements}", + level="info", + ) target_spec = DTensorSpec( input._spec.mesh, placements, @@ -112,6 +111,12 @@ def redistribute_dtensor( ) else: # use the same local tensor if placements are the same. 
+ rlog( + logger, + rank, + f"Skipping redistribution because placements are the same: {placements}", + level="info", + ) output = input._local_tensor target_spec = current_spec diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index d97478de1..dccb5fbbb 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -6,8 +6,12 @@ import torch import torch.fx as fx from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import DTensorSpec, OpStrategy -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor._op_schema import ( + DTensorSpec, + OpStrategy, + PlacementStrategy, +) +from torch.distributed._tensor.placement_types import Shard, Replicate from chop.tools import get_logger from chop.tools.utils import deepgetattr @@ -19,6 +23,7 @@ AUTOSHARDING_METHODS, IMPLICIT_FUNCS, IMPLICIT_METHODS, + FULLY_REPLICATED_FUNCS, ) from .strategies.common import ( fully_replicated_strategy, @@ -125,6 +130,55 @@ def _get_computation_cost_from_strategy( return np.mean(elapsed[warmup_iters:]) +def _no_tensor_args(node): + has_tensor_args = False + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): + if isinstance(arg_info["value"], torch.Tensor): + has_tensor_args = True + break + return not has_tensor_args + + +def _inherit_strategy(node, parent_strategy): + """ + Inherit the sharding strategy from the parent node. The main data + argument is assigned the ouput sharding of the parent node, with + all other arguments casted to fully replicated placement. The output + sharding of the parent node is also assigned to the output spec of + each strategy since implicit nodes don't change the tensor shardings + + Args: + node (fx.Node): input node. 
def _inherit_strategy(node, parent_strategy):
    """
    Inherit the sharding strategy from the parent node. The main data
    argument is assigned the output sharding of the parent node, with
    all other arguments casted to fully replicated placement. The output
    sharding of the parent node is also assigned to the output spec of
    each strategy since implicit nodes don't change the tensor shardings.

    Args:
        node (fx.Node): input node.
        parent_strategy (OpStrategy): parent node's sharding strategy.

    Returns:
        OpStrategy: inherited sharding strategy.
    """

    strategies = []

    # One PlacementStrategy per parent strategy: the first input spec mirrors
    # the parent's output spec; every remaining argument gets a replicated
    # spec with no tensor metadata.
    for strategy in parent_strategy.strategies:
        # NOTE(review): (Replicate(), Replicate()) hard-codes a 2-D device
        # mesh — confirm against the MeshModel shape before reusing this on
        # other mesh dimensionalities.
        spec = [strategy.output_specs] + [
            DTensorSpec(
                mesh=strategy.output_specs.mesh,
                placements=(Replicate(), Replicate()),
                tensor_meta=None,
            )
        ] * (len(node.meta["mase"]["common"]["args"]) - 1)
        strategies.append(
            PlacementStrategy(
                input_specs=spec,
                # Output sharding is unchanged by implicit nodes, so reuse
                # the parent's output spec directly.
                output_specs=spec[0],
            )
        )

    return OpStrategy(strategies)
) op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) @@ -168,42 +235,39 @@ def _extract_ilp(mg, mesh, pass_args={}): node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, "opt_var": opt_var, + "is_implicit": False, } continue - # Output nodes simply propagate op_strategy from their input nodes - elif node.op == "output": + # Output nodes, implicit nodes and nodes with only non-tensor arguments + # inherit the sharding strategy from their parent node + elif ( + node.op == "output" + or node.op == "call_function" + and node.target in IMPLICIT_FUNCS + or _no_tensor_args(node) + ): logger.debug( - f"Op strategy from node {node.all_input_nodes[0]} is propagated to {node} node." + f"Node {node.name} will inherit sharding strategy from its parent, {node.all_input_nodes[0].name}." ) opt_var = cp.Variable(1, boolean=True) constr += [ cp.sum(opt_var) == 1, ] node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.all_input_nodes[0].meta["mase"]["software"][ - "autosharding" - ]["op_strategy"], + "op_strategy": _inherit_strategy( + node, + node.all_input_nodes[0].meta["mase"]["software"]["autosharding"][ + "op_strategy" + ], + ), "opt_var": opt_var, + "is_implicit": True, + "inherited_from": node.all_input_nodes[0], } continue - # Obtain strategy according to node op - # ================================================ - - elif node.op in [ - "placeholder", - "get_attr", - ]: - logger.debug( - f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" - ) - op_strategy = placeholder_or_getattr_strategy( - node.meta["mase"], - mesh, - skip_fully_replicated=pass_args.get("skip_fully_replicated", False), - ) - + # For general call_function nodes, evaluate strategy based on the target elif ( node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() ): @@ -225,6 +289,7 @@ def _extract_ilp(mg, mesh, pass_args={}): mesh, ), "opt_var": opt_var, + "is_implicit": False, } 
continue @@ -238,6 +303,7 @@ def _extract_ilp(mg, mesh, pass_args={}): node.meta["mase"]["software"]["autosharding"] = { "op_strategy": op_strategy, "opt_var": opt_var, + "is_implicit": False, } # Consider computation cost (c_v term) for each of the node's strategies @@ -380,12 +446,19 @@ def _mark_sharding(mg, pass_args): for node in mg.fx_graph.nodes: opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] - if opt_var is None: + # Get the strategy chosen by the ILP + if node.meta["mase"]["software"]["autosharding"].get("is_implicit", False): + parent_node = node.meta["mase"]["software"]["autosharding"][ + "inherited_from" + ] + idx = parent_node.meta["mase"]["software"]["autosharding"][ + "chosen_strategy_idx" + ] chosen_strategy = node.meta["mase"]["software"]["autosharding"][ "op_strategy" - ] + ].strategies[idx] + else: - # Get the strategy chosen by the ILP try: idx = np.where(opt_var.value == 1)[0][0] except: @@ -396,6 +469,7 @@ def _mark_sharding(mg, pass_args): ].strategies[idx] # Annotate chosen placement strategy + node.meta["mase"]["software"]["autosharding"]["chosen_strategy_idx"] = idx node.meta["mase"]["software"]["autosharding"][ "placement_strategy" ] = chosen_strategy diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 42c883224..bb4c21e85 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -276,11 +276,17 @@ "contiguous": tensor_op_strategy, } +FULLY_REPLICATED_FUNCS = [ + F.embedding, + torch.arange, +] + +# Implicit functions inherit their parent's strategy +# and do not change the sharding profile of their input tensors IMPLICIT_FUNCS = [ operator.getitem, getattr, torch.finfo, - torch.arange, torch_size, ] diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index b5c0923e4..24628deb4 
100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -112,25 +112,29 @@ def fully_replicated_strategy(meta, mesh): in_shape = meta["common"]["self"].shape in_dtype = meta["common"]["self"].dtype else: - first_arg_key = ( - "data_in_0" - if "data_in_0" in meta["common"]["args"] - else [i for i in meta["common"]["args"].keys()][0] - ) - arg = meta["common"]["args"][first_arg_key] - in_shape, in_dtype = find_shape_and_dtype(arg) - - in_spec = [ - DTensorSpec( - mesh, - sharding, - tensor_meta=TensorMeta( - shape=in_shape, - stride=None, - dtype=in_dtype, - ), - ) - ] * len(meta["common"]["args"].keys()) + if len(meta["common"]["args"]) > 0: + first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) + arg = meta["common"]["args"][first_arg_key] + in_shape, in_dtype = find_shape_and_dtype(arg) + + in_spec = [ + DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), + ) + ] * len(meta["common"]["args"].keys()) + + else: + in_spec = [] dtype_key = ( "torch_dtype" From 0747681c7a5d961614abc6d1ad1f34330b3794c1 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 31 Jul 2024 09:33:03 +0000 Subject: [PATCH 67/93] remove logging which was slowing down runtime and remove redistribute nodes for non-tensor arguments --- src/chop/distributed/tensor/_dispatch.py | 10 ------ src/chop/distributed/utils.py | 2 +- src/chop/nn/functional/dtensor.py | 36 +++++++++---------- .../autosharding/alpa_intra_operator.py | 4 --- .../passes/graph/transforms/resharding.py | 6 ++-- 5 files changed, 23 insertions(+), 35 deletions(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index f63de14ed..45fa52d49 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -137,16 
+137,6 @@ def dispatch( local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - # rlog( - # f"Reshape {op_call.name} outputs type {type(local_results)}, shape {local_results.shape}" - # ) - - # if "aten.view" in str(op_call.name): - rlog(f"op call: {str(op_call.name)}") - rlog(f"local tensor args: {local_tensor_args}") - if isinstance(local_results, (torch.Tensor, dtensor.DTensor)): - rlog(f"op call output: {local_results.shape}") - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] @staticmethod diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index f7eef248f..e16fb911f 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -10,7 +10,7 @@ from chop.distributed.tensor import distribute_tensor logger = get_logger(__name__) -logger.setLevel("INFO") +logger.setLevel("DEBUG") def rlog(logger, rank, msg, level="info"): diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index 04a7710ba..be7ccacde 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -79,23 +79,23 @@ def redistribute_dtensor( rank = 0 if not isinstance(input, DTensor): - rlog( - logger, - rank, - f"Skipping redistribution because received {type(input)} instead of DTensor", - level="warning", - ) + # rlog( + # logger, + # rank, + # f"Skipping redistribution because received {type(input)} instead of DTensor", + # level="warning", + # ) return input current_spec = input._spec if current_spec.placements != placements: - rlog( - logger, - rank, - f"Redistributing tensor from {current_spec.placements} to {placements}", - level="info", - ) + # rlog( + # logger, + # rank, + # f"Redistributing tensor from {current_spec.placements} to {placements}", + # level="info", + # ) target_spec = DTensorSpec( input._spec.mesh, placements, @@ -111,12 +111,12 @@ def redistribute_dtensor( ) else: # use the same local tensor if placements are the same. 
- rlog( - logger, - rank, - f"Skipping redistribution because placements are the same: {placements}", - level="info", - ) + # rlog( + # logger, + # rank, + # f"Skipping redistribution because placements are the same: {placements}", + # level="info", + # ) output = input._local_tensor target_spec = current_spec diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index dccb5fbbb..fc183a389 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -157,10 +157,6 @@ def _inherit_strategy(node, parent_strategy): strategies = [] - logger.warning( - f"Node {node.name} will inherit sharding strategy from its parent, {node.all_input_nodes[0].name}." - ) - logger.warning(f"Args: {node.meta['mase']['common']['args'].keys()}") for strategy in parent_strategy.strategies: spec = [strategy.output_specs] + [ DTensorSpec( diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index e143c059a..785555026 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -38,14 +38,16 @@ def _insert_resharding_nodes(mg, pass_args={}): # Check if argument is an FX node, otherwise it's a constant arg_obj = flattened_args[arg_idx] - if not isinstance(arg_obj, fx.Node): + arg_info = node.meta["mase"]["common"]["args"][arg_name] + if not isinstance(arg_obj, fx.Node) or not isinstance( + arg_info["value"], torch.Tensor + ): logger.debug( f"Skipping node: {node.name}, argument: {arg_name} because it is a constant." 
) continue # Check if the parent node output spec is different from the arg input spec - arg_info = node.meta["mase"]["common"]["args"][arg_name] arg_specs = arg_info.get("dtensor_spec", None) parent_out_specs = arg_obj.meta["mase"]["common"]["results"][ "data_out_0" From 45f9c7107e3c94faeb4f4065575450a4b1de0222 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 5 Aug 2024 12:22:07 +0000 Subject: [PATCH 68/93] DTensor: remove duplicated op call for out tensor meta propagation --- src/chop/distributed/tensor/_dispatch.py | 142 +++++++++++------- src/chop/distributed/tensor/_sharding_prop.py | 118 ++------------- src/chop/distributed/tensor/api.py | 6 - src/chop/distributed/tensor/ops/tensor_ops.py | 10 +- src/chop/nn/functional/dtensor.py | 25 +-- 5 files changed, 113 insertions(+), 188 deletions(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 45fa52d49..52a34ed9f 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -21,7 +21,12 @@ convolution_handler, ) from torch.distributed._tensor._utils import try_find_mesh_from_args -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Replicate, + Shard, + TensorMeta, +) from torch.distributed._tensor.random import is_rng_supported_mesh @@ -40,6 +45,30 @@ aten = torch.ops.aten +def try_get_replicate_spec(tensor_arg: torch.Tensor, mesh: "DeviceMesh") -> DTensorSpec: + # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg. + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed enviroment." 
+ ) + + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + mesh, + (Replicate(),) * mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + + return replication_spec + + def decompose_handler( op_call: torch._ops.OpOverload, args: Tuple[object, ...], @@ -72,6 +101,23 @@ def rlog(msg): print(msg) +def _get_global_shape(local_shape, dtensor_spec): + if dtensor_spec is None: + return local_shape + + placements = dtensor_spec.placements + global_shape = list(local_shape) + mesh_shape = dtensor_spec.mesh.shape + + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + global_shape[placement.dim] = ( + global_shape[placement.dim] * mesh_shape[mesh_dim] + ) + + return torch.Size(global_shape) + + class OpDispatcher: """ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding @@ -102,8 +148,6 @@ def __init__(self) -> None: # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. - # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave - # it as False by default. 
self._allow_implicit_replication = True def dispatch( @@ -122,13 +166,8 @@ def dispatch( # extract local tensor and sharding infos to a OpInfo op_info = self.unwrap_to_op_info(op_call, args, kwargs) - # rlog(f"Dispatching op_call: {op_call.name}") - - # self.sharding_propagator.propagate(op_info) - # output_sharding = op_info.output_sharding output_sharding = self.sharding_propagator.propagate(op_info) - assert output_sharding is not None, "output sharding should not be None" # run local op computation with potentially modified args/kwargs @@ -137,7 +176,40 @@ def dispatch( local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + # Getting tensor meta after running the op to avoid running it twice + if isinstance(local_results, (tuple, list)): + out_tensor_meta = [ + TensorMeta( + shape=_get_global_shape( + r.shape, + output_sharding.output_spec[out_idx], + ), + stride=r.stride(), + dtype=r.dtype, + ) + for out_idx, r in enumerate(local_results) + ] + else: + out_tensor_meta = TensorMeta( + shape=_get_global_shape( + local_results.shape, + output_sharding.output_spec, + ), + stride=local_results.stride(), + dtype=local_results.dtype, + ) + + # Annotate output DTensorSpec with TensorMeta object + self.sharding_propagator._wrap_output_spec_tensor_meta( + op_call, + output_sharding.output_spec, + out_tensor_meta, + ) + + return self.wrap( + local_results, + output_sharding.output_spec, + ) @staticmethod def redistribute_local_args( @@ -183,12 +255,6 @@ def unwrap_to_op_info( op_call, None ) - # if runtime_schema_info is not None and runtime_schema_info.needs_pytree: - # # flatten args/kwargs when necessary - # print(f"needs pytree...") - # tree_args, args_spec = pytree.tree_flatten(args) - # args_list: Sequence[object] = tree_args - # else: args_list, args_spec = args, None args_schema: List[object] = [] @@ -197,41 +263,6 @@ def unwrap_to_op_info( 
local_kwargs: Dict[str, object] = {} mesh: Optional[DeviceMesh] = None - def try_get_replicate_spec( - tensor_arg: torch.Tensor, mesh: "DeviceMesh" - ) -> DTensorSpec: - # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg. - if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: - warnings.warn( - "Found a non-scalar tensor with numel=1 and ndim!=0, " - "we are implicitly creating a replicated DTensor for it. " - "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed enviroment." - ) - - # if the arg.numel() == 1, arg.ndim could be 0 or 1. - if ( - tensor_arg.ndim <= 1 - and tensor_arg.numel() == 1 - or self._allow_implicit_replication - ): - # scalar tensor can be safely treated as replicated - replication_spec = DTensorSpec( - mesh, - (Replicate(),) * mesh.ndim, - tensor_meta=TensorMeta( - shape=tensor_arg.shape, - stride=tensor_arg.stride(), - dtype=tensor_arg.dtype, - ), - ) - else: - raise RuntimeError( - f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" - " torch.Tensor to DTensor before calling distributed operators!" - ) - return replication_spec - for arg in args_list: if isinstance(arg, dtensor.DTensor): args_schema.append(arg._spec) @@ -288,13 +319,20 @@ def try_get_replicate_spec( return op_info @staticmethod - def wrap(res: object, spec: OutputSpecType) -> object: + def wrap( + res: object, + spec: OutputSpecType, + ) -> object: if isinstance(res, torch.Tensor): if spec is not None: assert isinstance( spec, DTensorSpec ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." - return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + return dtensor.DTensor( + res, + spec, + requires_grad=res.requires_grad, + ) else: # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor assert res.ndim == 0, "output tensor should be scalar!" 
diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 5b2996919..7f3a12545 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -165,13 +165,19 @@ def _wrap_output_spec_tensor_meta( ) output_specs.tensor_meta = output_tensor_meta elif isinstance(output_specs, (tuple, list)): - if not isinstance(output_tensor_meta, (tuple, list)) or len( - output_specs - ) != len(output_tensor_meta): + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + f"For the op {op.name()}, `output_specs` has type {type(output_specs)} but output_tensor_meta has type {type(output_tensor_meta)}" + f"Both should be tuple or list." + ) + + if len(output_specs) != len(output_tensor_meta): + raise ValueError( f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " f"number of op outputs {_length(output_tensor_meta)}." ) + for i, spec in enumerate(output_specs): if isinstance(spec, DTensorSpec): output_tensor_meta_i = output_tensor_meta[i] @@ -203,8 +209,6 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if op_schema.op is aten._local_scalar_dense.default: return OutputSharding(None, op_schema) - out_tensor_meta = self._propagate_tensor_meta(op_schema) - def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): return OpStrategy([PlacementStrategy(spec)]) @@ -245,55 +249,12 @@ def spec_to_strategy(spec: object) -> object: # single Op strategy output_strategy = self._select_strategy(op_strategy) - # check if we need to redistribute the input - needs_redistribute = False - expected_input_specs = [] - # in case where the op does not specify input_specs and output_specs # is a DTensorSpec, we use output_specs as the spec for each DTensor # input arg. 
if output_strategy.input_specs is None: assert isinstance(output_strategy.output_specs, DTensorSpec) - for idx, input_spec in enumerate(op_schema.args_spec): - desired_spec = ( - output_strategy.output_spec - if output_strategy.input_specs is None - else output_strategy.input_specs[idx] - ) - expected_input_specs.append( - desired_spec.shallow_copy_with_tensor_meta( - input_spec.tensor_meta - ) - ) - if input_spec.placements != desired_spec.placements: - needs_redistribute = True - - suggestion_schema = None - if needs_redistribute: - suggestion_schema = OpSchema( - op_schema.op, tuple(expected_input_specs), {} - ) - suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) - - # shape and stride args need to be modified for - # view ops and new factory ops, potentially - if op_schema.op in self.op_to_shape_and_stride_idx: - assert isinstance(output_strategy.output_spec, DTensorSpec) - # It happens when the output has the same shape as the input - # and the input placements are not all Replicate(). 
- if output_strategy.output_spec.is_sharded(): - schema = suggestion_schema or op_schema - assert isinstance(out_tensor_meta, TensorMeta) - # rlog(f"Need the out tensor meta here!") - suggestion_schema = self._adjust_shape_and_stride_args( - out_tensor_meta, schema, output_strategy.output_spec, mesh - ) - needs_redistribute = True - - # else: - # rlog(f"Don't need it because it's not sharded") - # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): # for ops that return multiple tensors and the output_specs is not @@ -320,11 +281,10 @@ def spec_to_strategy(spec: object) -> object: output_sharding = OutputSharding( output_specs, - suggestion_schema, - needs_redistribute=needs_redistribute, + redistribute_schema=None, + needs_redistribute=False, ) elif isinstance(op_strategy, TupleStrategy): - # rlog(f"Strategy returned a TupleStrategy") # tuple strategy output sharding processing # runtime selected placement strategy for each TupleStrategy input arg selected_strategies: List[PlacementStrategy] = [] @@ -335,7 +295,6 @@ def spec_to_strategy(spec: object) -> object: selected_strategies.append(selected_strategy) out_spec_list.append(selected_strategy.output_spec) - needs_redistribute = False suggestion_args: List[object] = [] tensor_or_list_tensor_arg_idx = 0 @@ -355,8 +314,6 @@ def spec_to_strategy(spec: object) -> object: arg_spec.tensor_meta ) ) - if arg_spec.placements != expected_input_spec.placements: - needs_redistribute = True expected_input_spec_list.append(expected_input_spec) suggestion_args.append( tuple(expected_input_spec_list) @@ -374,34 +331,21 @@ def spec_to_strategy(spec: object) -> object: arg.tensor_meta ) ) - if arg.placements != expected_input_spec.placements: - needs_redistribute = True suggestion_args.append(expected_input_spec) tensor_or_list_tensor_arg_idx += 1 else: suggestion_args.append(arg) - suggestion_schema = None - if needs_redistribute: - suggestion_schema = OpSchema( - op_schema.op, tuple(suggestion_args), 
op_schema.kwargs_schema - ) - output_sharding = OutputSharding( - tuple(out_spec_list) if out_tensor_meta is not None else None, - suggestion_schema, - needs_redistribute=needs_redistribute, + tuple(out_spec_list), + redistribute_schema=None, + needs_redistribute=False, ) else: raise ValueError("Unsupported op strategy type") - # associate the output sharding with the output tensor metadata - self._wrap_output_spec_tensor_meta( - op_schema.op, output_sharding.output_spec, out_tensor_meta - ) return output_sharding elif op_schema.op in self.op_to_rules: - # rlog(f"Op {op_schema.op} has no strategy, using rule") # propagate the sharding with rule sharding_prop_func = self.op_to_rules[op_schema.op] @@ -438,11 +382,6 @@ def spec_to_strategy(spec: object) -> object: output_sharding.output_spec = propagation_res.output_spec output_sharding.needs_redistribute = True - # associate the output sharding with the output tensor metadata - self._wrap_output_spec_tensor_meta( - op_schema.op, output_sharding.output_spec, out_tensor_meta - ) - return output_sharding else: raise NotImplementedError( @@ -464,32 +403,3 @@ def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: # for eager execution, we just select the one with the minimal redistribute cost return strategy.strategies[strategy_costs.index(min(strategy_costs))] - - def _adjust_shape_and_stride_args( - self, - out_tensor_meta: TensorMeta, - schema: OpSchema, - spec: DTensorSpec, - mesh: DeviceMesh, - ) -> OpSchema: - shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] - if isinstance(shape_stride_idx, tuple): - shape_idx, stride_idx = shape_stride_idx - else: - shape_idx = shape_stride_idx - stride_idx = None - - expected_input_schema = list(schema.args_schema) - # adjust shape to be the same as that of the _local_tensor - # of the DTensor input arg at index 0, which is inferred - expected_input_schema[shape_idx] = compute_local_shape( - out_tensor_meta.shape, mesh, spec.placements - ) - - # 
adjust the stride arg for aten.new_empty_strided.default - if stride_idx: - expected_input_schema[stride_idx] = compute_local_stride( - out_tensor_meta.stride, mesh, spec.placements - ) - - return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index 47a778bea..69a9e62b8 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -226,12 +226,6 @@ def __new__( """ Construct a DTensor from a local tensor, device mesh, and placement and other tensor properties (i.e. shape, requires_grad, strides, etc). - Note: This is not a public API and it's only supposed to be used by the - operator implementations and internals. If you want to construct a - DTensor from a local tensor, consider using `DTensor.from_local`, if - you want to construct a DTensor from a "global" tensor (where you - already have tensor initialized and want to shard this tensor), - consider using `distribute_tensor`. 
""" if local_tensor.requires_grad and not requires_grad: warnings.warn( diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py index 8e64ff514..6147a16fc 100644 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ b/src/chop/distributed/tensor/ops/tensor_ops.py @@ -756,7 +756,10 @@ def split_rule(op_schema: OpSchema) -> OutputSharding: need_reshard = True input_spec = DTensorSpec( mesh=input_spec.mesh, - placements=unshard_tensor_dim(input_spec.placements, dim=dim), + placements=unshard_tensor_dim( + input_spec.placements, + dim=dim, + ), tensor_meta=input_spec.tensor_meta, ) @@ -777,7 +780,10 @@ def size_split(N, i): return [i] * (N // i) + ([N % i] if N % i != 0 else []) output_size_list = ( - size_split(input_spec.shape[dim], split_size_or_sections) + size_split( + input_spec.shape[dim], + split_size_or_sections, + ) if isinstance(split_size_or_sections, int) else split_size_or_sections ) diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index be7ccacde..e773aaa95 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -3,7 +3,7 @@ import torch import torch.fx as fx from torch.distributed.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed._tensor._redistribute import redistribute_local_tensor from torch.distributed._tensor.placement_types import Placement @@ -73,29 +73,12 @@ def redistribute_dtensor( """ # If we are not in a distributed setting, we can skip redistribution. 
- try: - rank = torch.distributed.get_rank() - except: - rank = 0 - if not isinstance(input, DTensor): - # rlog( - # logger, - # rank, - # f"Skipping redistribution because received {type(input)} instead of DTensor", - # level="warning", - # ) return input current_spec = input._spec if current_spec.placements != placements: - # rlog( - # logger, - # rank, - # f"Redistributing tensor from {current_spec.placements} to {placements}", - # level="info", - # ) target_spec = DTensorSpec( input._spec.mesh, placements, @@ -111,12 +94,6 @@ def redistribute_dtensor( ) else: # use the same local tensor if placements are the same. - # rlog( - # logger, - # rank, - # f"Skipping redistribution because placements are the same: {placements}", - # level="info", - # ) output = input._local_tensor target_spec = current_spec From 7ccd4eb33ad50873bbafae98be03a219d38f3bdc Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 8 Aug 2024 17:52:53 +0000 Subject: [PATCH 69/93] finish OpDispatcher refactoring + fix bug in pointwise_strategy + insert_dtensor_wrapper_transform_pass + error when preload file not found --- setup.py | 2 +- src/chop/distributed/launcher.py | 2 + src/chop/distributed/tensor/_dispatch.py | 249 +++--------------- src/chop/distributed/tensor/_sharding_prop.py | 45 ++-- src/chop/distributed/tensor/api.py | 15 +- src/chop/distributed/tensor/ops/math_ops.py | 12 +- .../distributed/tensor/ops/pointwise_ops.py | 25 +- src/chop/distributed/tensor/ops/view_ops.py | 6 +- src/chop/distributed/utils.py | 7 +- src/chop/nn/functional/dtensor.py | 19 +- src/chop/passes/__init__.py | 1 + src/chop/passes/graph/__init__.py | 1 + .../analysis/autosharding/autosharding.py | 7 +- .../autosharding/strategies/pointwise_ops.py | 3 +- src/chop/passes/graph/transforms/__init__.py | 1 + .../transforms/insert_dtensor_wrapper.py | 100 +++++++ .../passes/graph/transforms/resharding.py | 4 + src/chop/pipelines/auto_pipeline.py | 2 + src/chop/pipelines/distributed_inference.py | 1 + 19 files 
changed, 211 insertions(+), 291 deletions(-) create mode 100644 src/chop/passes/graph/transforms/insert_dtensor_wrapper.py diff --git a/setup.py b/setup.py index 021884586..5fdc54272 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ def get_system(): author="Aaron Zhao, Jianyi Cheng, Cheng Zhang, Pedro Gimenes", author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk", license_files=("LICENSE",), - python_requires=">=3.11.9", + python_requires=">=3.11.4", package_dir={ "": "src", }, diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 86f99227b..3f3cb524d 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -13,6 +13,7 @@ class MaseLauncher: def __init__( self, + mg=None, world_size=None, device_mesh=None, device_fn=None, @@ -24,6 +25,7 @@ def __init__( world_size (int, optional): Number of GPUs to use. Defaults to None. device_mesh (list, optional): List of GPUs to use. Defaults to None. """ + self.mg = mg self.world_size = world_size self.device_mesh = device_mesh self.device_fn = device_fn diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 52a34ed9f..3e77a015a 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -1,17 +1,8 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates -import contextlib -import functools -import logging -import operator -import warnings from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch -import torch.distributed as dist -import torch.distributed._tensor.random as random from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, OpInfo, OpSchema, OutputSpecType, @@ -30,8 +21,7 @@ from torch.distributed._tensor.random import is_rng_supported_mesh -if TYPE_CHECKING: - from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.device_mesh import DeviceMesh try: from torch.utils import _cxx_pytree as pytree @@ -45,16 +35,7 @@ aten = torch.ops.aten -def try_get_replicate_spec(tensor_arg: torch.Tensor, mesh: "DeviceMesh") -> DTensorSpec: - # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg. - if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: - warnings.warn( - "Found a non-scalar tensor with numel=1 and ndim!=0, " - "we are implicitly creating a replicated DTensor for it. " - "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed enviroment." 
- ) - +def get_replicate_spec(tensor_arg: torch.Tensor, mesh: "DeviceMesh") -> DTensorSpec: # scalar tensor can be safely treated as replicated replication_spec = DTensorSpec( mesh, @@ -95,29 +76,6 @@ def is_same_size_handler( return lhs.shape == rhs.shape -def rlog(msg): - rank = torch.distributed.get_rank() - if rank == 0: - print(msg) - - -def _get_global_shape(local_shape, dtensor_spec): - if dtensor_spec is None: - return local_shape - - placements = dtensor_spec.placements - global_shape = list(local_shape) - mesh_shape = dtensor_spec.mesh.shape - - for mesh_dim, placement in enumerate(placements): - if isinstance(placement, Shard): - global_shape[placement.dim] = ( - global_shape[placement.dim] * mesh_shape[mesh_dim] - ) - - return torch.Size(global_shape) - - class OpDispatcher: """ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding @@ -164,159 +122,47 @@ def dispatch( if op_call in self._custom_op_handlers: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] - # extract local tensor and sharding infos to a OpInfo - op_info = self.unwrap_to_op_info(op_call, args, kwargs) - - output_sharding = self.sharding_propagator.propagate(op_info) - assert output_sharding is not None, "output sharding should not be None" - # run local op computation with potentially modified args/kwargs - local_tensor_args = op_info.local_args - local_tensor_args = cast(Tuple[object, ...], local_tensor_args) - - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + local_tensor_args = [ + arg._local_tensor if isinstance(arg, dtensor.DTensor) else arg + for arg in args + ] + + local_tensor_kwargs = { + k: v.local_tensor if isinstance(v, dtensor.DTensor) else v + for k, v in kwargs.items() + } - # Getting tensor meta after running the op to avoid running it twice - if isinstance(local_results, (tuple, list)): - out_tensor_meta = [ - TensorMeta( - shape=_get_global_shape( - r.shape, - 
output_sharding.output_spec[out_idx], + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # We still need to wrap the local result in a DTensor here in two cases + # 1. When creating a nn.Parameter from a DTensor, it must call tensor.detach + # and the return type must match the input type (DTensor). + # 2. When a single FX op decomposes into multiple aten ops (e.g. torch.embedding) + if op_call._name == "aten::detach": + return self.wrap( + local_results, + DTensorSpec( + mesh=DeviceMesh( + "cuda", + mesh=torch.Tensor( + # todo: generalize + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ] + ), + ), + placements=args[0]._spec.placements, + tensor_meta=TensorMeta( + shape=args[0]._spec.tensor_meta.shape, + stride=args[0]._spec.tensor_meta.stride, + dtype=args[0]._spec.tensor_meta.dtype, ), - stride=r.stride(), - dtype=r.dtype, - ) - for out_idx, r in enumerate(local_results) - ] - else: - out_tensor_meta = TensorMeta( - shape=_get_global_shape( - local_results.shape, - output_sharding.output_spec, ), - stride=local_results.stride(), - dtype=local_results.dtype, - ) - - # Annotate output DTensorSpec with TensorMeta object - self.sharding_propagator._wrap_output_spec_tensor_meta( - op_call, - output_sharding.output_spec, - out_tensor_meta, - ) - - return self.wrap( - local_results, - output_sharding.output_spec, - ) - - @staticmethod - def redistribute_local_args( - op_info: OpInfo, - suggested_input_schema: OpSchema, - ) -> None: - # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it - - # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten - # Need to fix all the ops before doing this. 
- if op_info.args_tree_spec is not None: - flatten_args_schema_to_reshard = tuple( - pytree.tree_leaves(suggested_input_schema.args_schema) ) - else: - flatten_args_schema_to_reshard = suggested_input_schema.args_schema - new_local_args: List[object] = [] - for i, arg_spec in enumerate(op_info.flat_args_schema): - reshard_arg_spec = flatten_args_schema_to_reshard[i] - if isinstance(arg_spec, DTensorSpec): - local_tensor = cast(torch.Tensor, op_info.local_args[i]) - if arg_spec != reshard_arg_spec: - resharded_local_tensor = redistribute_local_tensor( - local_tensor, arg_spec, reshard_arg_spec - ) - new_local_args.append(resharded_local_tensor) - else: - new_local_args.append(local_tensor) - else: - new_local_args.append(reshard_arg_spec) - - op_info.local_args = tuple(new_local_args) - - def unwrap_to_op_info( - self, - op_call: torch._ops.OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], - ) -> OpInfo: - # get runtime schema to determine whether to use pytree to flatten inputs - runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( - op_call, None - ) - - args_list, args_spec = args, None - - args_schema: List[object] = [] - kwargs_schema: Dict[str, object] = {} - local_args: List[object] = [] - local_kwargs: Dict[str, object] = {} - mesh: Optional[DeviceMesh] = None - - for arg in args_list: - if isinstance(arg, dtensor.DTensor): - args_schema.append(arg._spec) - local_args.append(arg._local_tensor) - if mesh is not None: - if mesh != arg.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" 
- f"Got meshes: {mesh} {arg.device_mesh}" - ) - else: - mesh = arg.device_mesh - elif isinstance(arg, torch.Tensor): - mesh = mesh or try_find_mesh_from_args(op_call, args_list) - args_schema.append(try_get_replicate_spec(arg, mesh)) - local_args.append(arg) - else: - args_schema.append(arg) - local_args.append(arg) - - for k, v in kwargs.items(): - if isinstance(v, dtensor.DTensor): - kwargs_schema[k] = v._spec - local_kwargs[k] = v._local_tensor - if mesh is not None: - if mesh != v.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - ) - else: - mesh = v.device_mesh - elif isinstance(v, torch.Tensor): - mesh = mesh or try_find_mesh_from_args(op_call, args_list) - kwargs_schema[k] = try_get_replicate_spec(v, mesh) - local_kwargs[k] = v - else: - kwargs_schema[k] = v - local_kwargs[k] = v - - assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" - op_info = OpInfo( - mesh=mesh, - schema=OpSchema( - op_call, - tuple(args_schema), - kwargs_schema, - schema_info=runtime_schema_info, - ), - flat_args_schema=args_schema, - local_args=tuple(local_args), - local_kwargs=local_kwargs, - args_tree_spec=args_spec, - ) - return op_info + return local_results @staticmethod def wrap( @@ -324,23 +170,12 @@ def wrap( spec: OutputSpecType, ) -> object: if isinstance(res, torch.Tensor): - if spec is not None: - assert isinstance( - spec, DTensorSpec - ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." - return dtensor.DTensor( - res, - spec, - requires_grad=res.requires_grad, - ) - else: - # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor - assert res.ndim == 0, "output tensor should be scalar!" - return res + return dtensor.DTensor( + res, + spec, + requires_grad=res.requires_grad, + ) elif isinstance(res, (list, tuple)): - assert spec is not None and isinstance( - spec, (list, tuple) - ), f"output spec does not match with output! 
Expected list/tuple, got {spec}." res_list = [] for e, s in zip(res, spec): res_list.append(OpDispatcher.wrap(e, s)) diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 7f3a12545..c03046319 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -37,6 +37,24 @@ def _length(obj) -> int: return len(obj) +def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([PlacementStrategy(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + def rlog(msg): rank = torch.distributed.get_rank() if rank == 0: @@ -209,23 +227,6 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if op_schema.op is aten._local_scalar_dense.default: return OutputSharding(None, op_schema) - def spec_to_strategy(spec: object) -> object: - if isinstance(spec, DTensorSpec): - return OpStrategy([PlacementStrategy(spec)]) - elif ( - isinstance(spec, (list, tuple)) - and len(spec) > 0 - and isinstance(spec[0], DTensorSpec) - ): - # tensor list create tuple strategy - tuple_strategy = [spec_to_strategy(s) for s in spec] - tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) - return TupleStrategy( - tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy - ) - else: - return spec - if op_schema.op in self.op_strategy_funcs: # generate op strategy for the op. 
mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema) @@ -249,12 +250,6 @@ def spec_to_strategy(spec: object) -> object: # single Op strategy output_strategy = self._select_strategy(op_strategy) - # in case where the op does not specify input_specs and output_specs - # is a DTensorSpec, we use output_specs as the spec for each DTensor - # input arg. - if output_strategy.input_specs is None: - assert isinstance(output_strategy.output_specs, DTensorSpec) - # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): # for ops that return multiple tensors and the output_specs is not @@ -290,7 +285,6 @@ def spec_to_strategy(spec: object) -> object: selected_strategies: List[PlacementStrategy] = [] out_spec_list: List[DTensorSpec] = [] for strategy in op_strategy.childs: - assert isinstance(strategy, OpStrategy) selected_strategy = self._select_strategy(strategy) selected_strategies.append(selected_strategy) out_spec_list.append(selected_strategy.output_spec) @@ -395,9 +389,6 @@ def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: strategy_costs: List[float] = [] for strtg in strategy.strategies: - assert ( - strtg.redistribute_cost is not None - ), "must set redistribute cost each strategy!" 
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) strategy_costs.append(redistribute_cost) diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index 69a9e62b8..1d52818fd 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -205,7 +205,7 @@ def backward(ctx, grad_output: "DTensor"): # type: ignore[override] return grad_output.to_local(), None, None, None, None, None -class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ +class DTensor(torch.Tensor): _local_tensor: torch.Tensor _spec: DTensorSpec __slots__ = ["_local_tensor", "_spec"] @@ -227,15 +227,7 @@ def __new__( Construct a DTensor from a local tensor, device mesh, and placement and other tensor properties (i.e. shape, requires_grad, strides, etc). """ - if local_tensor.requires_grad and not requires_grad: - warnings.warn( - "To construct DTensor from torch.Tensor, it's recommended to " - "use local_tensor.detach() and make requires_grad consistent." - ) - # new method instruct wrapper tensor from local_tensor and add - # placement spec, it does not do actual distribution - assert spec.tensor_meta is not None, "TensorMeta should not be None!" r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, spec.tensor_meta.shape, @@ -254,7 +246,10 @@ def __new__( # pyre-fixme[3]: Return type must be annotated. 
def __repr__(self): # TODO: consider all_gather the local tensors for better debugging - return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + if self._spec is None: + return f"DTensor(local_tensor={self._local_tensor}, device_mesh=None, placements=None)" + else: + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" def __tensor_flatten__(self): """ diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index c6ebe49be..1b2211ffd 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -792,7 +792,6 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: output_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): op_args_target_specs = [] - redistribute_costs = [] input_src_spec = input_placement_strategy.output_spec # for the input tensor, we replicate it on the inner dims if necessary @@ -804,9 +803,6 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: tensor_meta=input_src_spec.tensor_meta, ) op_args_target_specs.append(input_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_target_spec) - ) if weight_strategy is not None: assert isinstance(weight_strategy, OpStrategy) @@ -821,9 +817,6 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: tensor_meta=weight_src_spec.tensor_meta, ) op_args_target_specs.append(weight_target_spec) - redistribute_costs.append( - generate_redistribute_costs(weight_strategy, weight_target_spec) - ) if bias_strategy is not None: assert isinstance(bias_strategy, OpStrategy) @@ -838,9 +831,6 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: tensor_meta=bias_src_spec.tensor_meta, ) 
op_args_target_specs.append(bias_target_spec) - redistribute_costs.append( - generate_redistribute_costs(bias_strategy, bias_target_spec) - ) # the output spec is the same as input spec output_target_spec = input_target_spec @@ -848,7 +838,7 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: PlacementStrategy( output_specs=output_target_spec, input_specs=op_args_target_specs, - redistribute_cost=redistribute_costs, + redistribute_cost=[0], ) ) diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py index c3e1f082f..f69a31577 100644 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -13,7 +13,6 @@ TupleStrategy, ) from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, @@ -29,22 +28,6 @@ from chop.distributed.tensor.ops.utils import register_op_strategy aten = torch.ops.aten -# leave the remaining pointwise_ops list here for convenience, -# Below ops are some pointwise ops that are yet to be supported, -# they might not be a complete list. -# pointwise_ops = [ -# "fake_quantize_per_channel_affine", -# "fake_quantize_per_tensor_affine", -# "floor_divide", # floor_divide is deprecated -# "frexp", # multiple output pointwise op, need to add support -# "gradient", # need investigation on this op -# "imag", # complex data type only -# "quantized_batch_norm", -# "quantized_max_pool1d", -# "quantized_max_pool2d", -# "real", # complex data type only -# ] - linear_pointwise_ops = [ aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. 
@@ -474,8 +457,7 @@ def common_pointwise_strategy( out_placements.append(placement) input_specs: List[DTensorSpec] = [] - redistribute_costs: List[List[float]] = [] - for idx, input_arg in enumerate(args_schema): + for input_arg in args_schema: if isinstance(input_arg, OpStrategy): # every arg follow the out_placements, but need to handle broadcasting input_arg_spec = input_arg.strategies[0].output_spec @@ -493,9 +475,6 @@ def common_pointwise_strategy( tensor_meta=input_arg_spec.tensor_meta, ) input_specs.append(input_arg_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_arg, input_arg_target_spec) - ) pointwise_strategy.strategies.append( PlacementStrategy( @@ -504,7 +483,7 @@ def common_pointwise_strategy( placements=tuple(out_placements), ), input_specs=input_specs, - redistribute_cost=redistribute_costs, + redistribute_cost=[], ) ) return pointwise_strategy diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py index dc8103b08..0e1402a2a 100644 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -29,7 +29,6 @@ from chop.distributed.tensor.ops.utils import register_op_strategy from chop.distributed.tensor.ops.utils import ( - generate_redistribute_costs, normalize_dim, normalize_dims, prod, @@ -621,16 +620,13 @@ def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: mesh=input_src_spec.mesh, tensor_meta=input_src_spec.tensor_meta, ) - redistribute_costs = [ - generate_redistribute_costs(input_strategy, input_tgt_spec) - ] output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) output_strategy.strategies.append( PlacementStrategy( output_specs=output_spec, input_specs=(input_tgt_spec,), - redistribute_cost=redistribute_costs, + redistribute_cost=[], ) ) diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index e16fb911f..c119ed9c6 100644 --- a/src/chop/distributed/utils.py 
+++ b/src/chop/distributed/utils.py @@ -109,7 +109,11 @@ def dist_model_fn( distributed_tensor = distribute_tensor( getattr(module, parameter), device_mesh, placement ) - setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) + setattr( + module, + parameter, + torch.nn.Parameter(distributed_tensor), + ) except Exception as e: rlog( logger, @@ -117,3 +121,4 @@ def dist_model_fn( f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}", level="error", ) + raise e diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index e773aaa95..c5e29b73d 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -8,14 +8,20 @@ from torch.distributed._tensor.placement_types import Placement +from chop.ir.graph import MaseMetadata from chop.distributed.tensor import DTensor from chop.tools import get_logger -from chop.distributed.utils import rlog logger = get_logger(__name__) logger.setLevel("DEBUG") +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + print(msg) + + @fx.wrap def dtensor_arange( start: int, @@ -59,6 +65,7 @@ def redistribute_dtensor( input: DTensor, placements: Tuple[Placement, ...], async_op: bool = False, + input_tensor_mesh=None, ): """ Redistribute a DTensor to a new set of placements. 
@@ -76,16 +83,24 @@ def redistribute_dtensor( if not isinstance(input, DTensor): return input + torch_mesh = DeviceMesh( + "cuda", + mesh=torch.Tensor(input_tensor_mesh), + ) + current_spec = input._spec if current_spec.placements != placements: target_spec = DTensorSpec( - input._spec.mesh, + torch_mesh, placements, tensor_meta=input._spec.tensor_meta, ) local_tensor = input._local_tensor + + assert not isinstance(local_tensor, DTensor) + output = redistribute_local_tensor( local_tensor, current_spec, diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index b8d90219b..a66399732 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -35,6 +35,7 @@ patch_metadata_transform_pass, resharding_transform_pass, replace_method_with_function, + insert_dtensor_wrapper_transform_pass, ) from .module.analysis import calculate_avg_bits_module_analysis_pass from .module.transforms import quantize_module_transform_pass diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 0c54e01fd..ee671b21a 100644 --- a/src/chop/passes/graph/__init__.py +++ b/src/chop/passes/graph/__init__.py @@ -31,6 +31,7 @@ onnx_annotate_transform_pass, partition_to_multi_device_transform_pass, raise_granularity_transform_pass, + insert_dtensor_wrapper_transform_pass, ) from .interface import ( diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index d994b1939..b617dc221 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -29,7 +29,7 @@ def _import_solution( mg, solution: dict, mesh: MeshModel, - extrapolate_sharding: bool = True, + extrapolate_sharding: bool = False, ): """Import an autosharding solution into the metadata of the MaseGraph. 
@@ -247,7 +247,10 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): # Preload autosharding solution fname = pass_args.get("ilp_solution_file", "ilp_solution.pkl") # check if solution file exists - if pass_args.get("preload_solution", False) and os.path.exists(fname): + if pass_args.get("preload_solution", False): + if not os.path.exists(fname): + raise FileNotFoundError(f"Solution file {fname} not found.") + logger.info(f"Preloading autosharding solution from: {fname}") with open(fname, "rb") as file: solution = dill.load(file) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index a7504d2d8..a198b3e93 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -141,8 +141,7 @@ def common_pointwise_strategy( placements=input_target_placements, tensor_meta=input_arg_spec.tensor_meta, ) - # input_specs.append(input_arg_target_spec) - input_specs = [input_arg_target_spec] * len(meta.node.args) + input_specs.append(input_arg_target_spec) dtype = meta["common"]["results"]["data_out_0"].get( "torch_dtype", torch.float32 diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py index 09667c182..eaf089de3 100644 --- a/src/chop/passes/graph/transforms/__init__.py +++ b/src/chop/passes/graph/transforms/__init__.py @@ -22,5 +22,6 @@ from .patching import patch_metadata_transform_pass from .resharding import resharding_transform_pass +from .insert_dtensor_wrapper import insert_dtensor_wrapper_transform_pass from .find_replace.method_to_function import replace_method_with_function diff --git a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py new file mode 100644 index 000000000..f44504903 --- /dev/null +++ 
b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py @@ -0,0 +1,100 @@ +import torch +from torch.distributed._tensor.api import DTensorSpec, TensorMeta +from torch.distributed import DeviceMesh +from copy import deepcopy + + +from chop.tools import get_logger +from chop.distributed.tensor import DTensor + + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + + +def _create_dtensor(local_tensor, result_meta, torch_mesh): + return DTensor( + local_tensor=local_tensor, + spec=DTensorSpec( + mesh=torch_mesh, + placements=result_meta["dtensor_spec"].placements, + tensor_meta=TensorMeta( + shape=result_meta["value"].shape, + stride=result_meta["value"].stride(), + dtype=local_tensor.dtype, + ), + ), + requires_grad=local_tensor.requires_grad, + ) + + +def create_wrapper(node): + + target_fn = deepcopy(node.target) + + torch_mesh = DeviceMesh( + "cuda", + mesh=torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), + ) + + result_names = list(node.meta["mase"]["common"]["results"].keys()) + + def dtensor_wrapper_fn(*args, **kwargs): + out = target_fn(*args, **kwargs) + + # In the event the OpDispatcher already wrapped a DTensor around + # the local result, avoid reaching recursive depth limit + if isinstance(out, (tuple, list)): + outs = [] + for r_idx, r in enumerate(out): + if isinstance(r, DTensor): + outs.append(r) + elif isinstance(r, torch.Tensor): + outs.append( + _create_dtensor( + r, + result_meta=node.meta["mase"]["common"]["results"][ + result_names[r_idx] + ], + torch_mesh=torch_mesh, + ) + ) + else: + outs.append(r) + + wrapped_out = tuple(outs) + + elif isinstance(out, DTensor): + wrapped_out = out + + elif isinstance(out, torch.Tensor): + wrapped_out = _create_dtensor( + out, + result_meta=node.meta["mase"]["common"]["results"][result_names[0]], + torch_mesh=torch_mesh, + ) + + else: + wrapped_out = out + + return wrapped_out + + return dtensor_wrapper_fn + + +def insert_dtensor_wrapper_transform_pass(mg, pass_args={}): + + logger.info("Inserting DTensor 
wrappers for call_function nodes") + + for node in mg.nodes: + if node.op == "call_function": + + logger.info(f"Inserting DTensor wrapper for {node.name}") + node.target = create_wrapper(node) + + else: + logger.warning( + f"Skipping node {node.name} because it is not a call_function" + ) + + return mg, {} diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index 785555026..07151e310 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -17,6 +17,9 @@ def _insert_resharding_nodes(mg, pass_args={}): logger.info( f"Running resharding_transform_pass to insert resharding nodes along necessary edges." ) + + device_mesh = pass_args.get("device_mesh", None) + for node in mg.fx_graph.nodes: if node.op == "call_function" and node.target == redistribute_dtensor: @@ -71,6 +74,7 @@ def _insert_resharding_nodes(mg, pass_args={}): args=(arg_obj, arg_specs.placements), kwargs={ "async_op": False, + "input_tensor_mesh": device_mesh, }, ) diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py index 967f0f2f6..82b7ddfbc 100644 --- a/src/chop/pipelines/auto_pipeline.py +++ b/src/chop/pipelines/auto_pipeline.py @@ -36,4 +36,6 @@ def __call__(self, mg: MaseGraph, pass_args: dict, skip_passes: list = []): mg, pass_output = pass_fn(mg, pass_args=args) self.pass_outputs[pass_fn.__name__] = pass_output + mg.model.recompile() + return mg, self.pass_outputs diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index 9bac59864..a271c5298 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -27,6 +27,7 @@ def __init__(self) -> None: passes.init_metadata_analysis_pass, passes.add_common_metadata_analysis_pass, passes.autosharding_analysis_pass, + passes.insert_dtensor_wrapper_transform_pass, passes.resharding_transform_pass, 
passes.report_graph_analysis_pass, ] From 9fcb595689229293e85e7e1ffea2710de358ad35 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 8 Aug 2024 17:59:29 +0000 Subject: [PATCH 70/93] include fully replicated backend for autosharding --- .../analysis/autosharding/autosharding.py | 5 +++- .../analysis/autosharding/fully_replicated.py | 28 +++++++++++++++++++ .../graph/analysis/autosharding/megatron.py | 15 +--------- 3 files changed, 33 insertions(+), 15 deletions(-) create mode 100644 src/chop/passes/graph/analysis/autosharding/fully_replicated.py diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index b617dc221..ad2ef006d 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -10,6 +10,7 @@ from .mesh_model import MeshModel from .alpa import alpa_autosharding_pass from .megatron import megatron_autosharding_pass +from .fully_replicated import fully_replicated_autosharding_pass logger = get_logger(__name__) logger.setLevel("INFO") @@ -273,7 +274,9 @@ def autosharding_analysis_pass(mg, pass_args: dict = {}): # Run intra-operator pass start_time = time() - if algo == "alpa": + if algo == "fully_replicated": + mg, pass_outs = fully_replicated_autosharding_pass(mg, mesh, pass_args) + elif algo == "alpa": mg, pass_outs = alpa_autosharding_pass(mg, mesh, pass_args) elif algo == "megatron": mg, pass_outs = megatron_autosharding_pass(mg, mesh, pass_args) diff --git a/src/chop/passes/graph/analysis/autosharding/fully_replicated.py b/src/chop/passes/graph/analysis/autosharding/fully_replicated.py new file mode 100644 index 000000000..8911c3720 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/fully_replicated.py @@ -0,0 +1,28 @@ +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate + +from chop.ir import MaseGraph 
+from .mesh_model import MeshModel + + +def fully_replicated_autosharding_pass( + mg: MaseGraph, + mesh: MeshModel, + pass_args: dict, +): + spec = DTensorSpec( + None, + (Replicate(), Replicate()), + None, + ) + + for node in mg.nodes: + meta = node.meta["mase"] + + for arg, arg_info in meta["common"]["args"].items(): + arg_info["dtensor_spec"] = spec + + for result, result_info in meta["common"]["results"].items(): + result_info["dtensor_spec"] = spec + + return mg, {"solution": {}} diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py index 30cd36f7e..33372e101 100644 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ b/src/chop/passes/graph/analysis/autosharding/megatron.py @@ -7,17 +7,4 @@ def megatron_autosharding_pass( mesh: MeshModel, pass_args: dict, ): - for node in mg.fx_graph.nodes: - meta = node.meta["mase"]["common"] - - for arg, arg_spec in meta["args"].items(): - if not isinstance(arg_spec, dict): - continue - arg_spec["dtensor_spec"] = None - - for result, result_spec in meta["results"].items(): - if not isinstance(result_spec, dict): - continue - result_spec["dtensor_spec"] = None - - return mg, {"solution": {}} + raise NotImplementedError("Megatron autosharding pass is not implemented yet.") From 574689a19821163e902fe6f17713c5b57cc5fb78 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 9 Aug 2024 13:54:03 +0000 Subject: [PATCH 71/93] include DTensorCache to bypass DTensor construction + remove high overhead @torch._disable_dynamo decorator + clean up logs --- src/chop/distributed/tensor/_dispatch.py | 15 --- src/chop/distributed/tensor/api.py | 4 - .../find_replace/method_to_function.py | 2 +- .../transforms/insert_dtensor_wrapper.py | 124 +++++++++++++----- .../passes/graph/transforms/resharding.py | 9 +- 5 files changed, 95 insertions(+), 59 deletions(-) diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py 
index 3e77a015a..688a11868 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -3,34 +3,22 @@ import torch from torch.distributed._tensor._op_schema import ( - OpInfo, - OpSchema, OutputSpecType, ) from torch.distributed._tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed._tensor._utils import try_find_mesh_from_args from torch.distributed._tensor.placement_types import ( DTensorSpec, Replicate, - Shard, TensorMeta, ) -from torch.distributed._tensor.random import is_rng_supported_mesh - from torch.distributed.device_mesh import DeviceMesh -try: - from torch.utils import _cxx_pytree as pytree -except ImportError: - from torch.utils import _pytree as pytree # type: ignore[no-redef] - import chop.distributed.tensor.api as dtensor from chop.distributed.tensor._sharding_prop import ShardingPropagator -from chop.distributed.tensor._redistribute import redistribute_local_tensor aten = torch.ops.aten @@ -119,9 +107,6 @@ def dispatch( """ # operators that does not need to go through sharding propagation - if op_call in self._custom_op_handlers: - return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] - # run local op computation with potentially modified args/kwargs local_tensor_args = [ arg._local_tensor if isinstance(arg, dtensor.DTensor) else arg diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index 1d52818fd..e9cabae4e 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -215,7 +215,6 @@ class DTensor(torch.Tensor): _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() @staticmethod - @torch._disable_dynamo def __new__( cls, local_tensor: torch.Tensor, @@ -297,9 +296,6 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec): ) @classmethod - @torch._disable_dynamo - # pyre-fixme[3]: Return type must be annotated. 
- # pyre-fixme[2]: Parameter must be annotated. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return DTensor._op_dispatcher.dispatch( func, diff --git a/src/chop/passes/graph/transforms/find_replace/method_to_function.py b/src/chop/passes/graph/transforms/find_replace/method_to_function.py index 5cc70a672..34f7d5dd5 100644 --- a/src/chop/passes/graph/transforms/find_replace/method_to_function.py +++ b/src/chop/passes/graph/transforms/find_replace/method_to_function.py @@ -13,7 +13,7 @@ ) logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") REPLACE_METHODS = { diff --git a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py index f44504903..cda5508d2 100644 --- a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py +++ b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py @@ -9,32 +9,86 @@ logger = get_logger(__name__) -logger.setLevel("DEBUG") - - -def _create_dtensor(local_tensor, result_meta, torch_mesh): - return DTensor( - local_tensor=local_tensor, - spec=DTensorSpec( - mesh=torch_mesh, - placements=result_meta["dtensor_spec"].placements, - tensor_meta=TensorMeta( - shape=result_meta["value"].shape, - stride=result_meta["value"].stride(), - dtype=local_tensor.dtype, +logger.setLevel("INFO") + + +def rlog(msg): + rank = torch.distributed.get_rank() + if rank == 0: + logger.info(msg) + + +class DTensorCache: + _dtensor_dict: dict = {} + + def __init__(self): + """ + This cache is needed to avoid expensive calls to _make_wrapper_subclass + at runtime when wrapping local Tensor results in DTensor objects. 
+ """ + pass + + +def _create_dtensor( + local_tensor, + node_name, + node_meta, + result_name, + torch_mesh, +): + cached_name = f"{node_name}_{result_name}" + cached_dtensor = DTensorCache._dtensor_dict.get(cached_name, None) + + # The DTensor is not found in the cache the first time each FX node is called + if cached_dtensor is None: + result_meta = node_meta["common"]["results"][result_name] + + dtensor = DTensor( + local_tensor=local_tensor, + spec=DTensorSpec( + mesh=torch_mesh, + placements=result_meta["dtensor_spec"].placements, + tensor_meta=TensorMeta( + shape=result_meta["value"].shape, + stride=result_meta["value"].stride(), + dtype=local_tensor.dtype, + ), ), - ), - requires_grad=local_tensor.requires_grad, - ) + requires_grad=local_tensor.requires_grad, + ) + + DTensorCache._dtensor_dict[cached_name] = dtensor + + return dtensor + + # If the DTensor is found in the cache, replace the local tensor + else: + # Replace local tensor without constructing a new dtensor + cached_dtensor._local_tensor = local_tensor + + # if DEBUG_MODE: + # assert cached dtensor has the same meta + # assert cached_dtensor._spec.placements == result_meta["dtensor_spec"].placements + # assert cached_dtensor._spec.tensor_meta.shape == result_meta["value"].shape + # assert cached_dtensor._spec.tensor_meta.stride == result_meta["value"].stride() + # assert cached_dtensor._spec.tensor_meta.dtype == local_tensor.dtype + + return cached_dtensor def create_wrapper(node): target_fn = deepcopy(node.target) + # todo: generalize torch_mesh = DeviceMesh( "cuda", - mesh=torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), + mesh=torch.Tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ] + ), ) result_names = list(node.meta["mase"]["common"]["results"].keys()) @@ -42,20 +96,18 @@ def create_wrapper(node): def dtensor_wrapper_fn(*args, **kwargs): out = target_fn(*args, **kwargs) - # In the event the OpDispatcher already wrapped a DTensor around - # the local result, avoid reaching recursive depth limit if 
isinstance(out, (tuple, list)): outs = [] for r_idx, r in enumerate(out): - if isinstance(r, DTensor): - outs.append(r) - elif isinstance(r, torch.Tensor): + # if isinstance(r, DTensor): + # outs.append(r) + if isinstance(r, torch.Tensor): outs.append( _create_dtensor( - r, - result_meta=node.meta["mase"]["common"]["results"][ - result_names[r_idx] - ], + local_tensor=r, + node_name=node.name, + node_meta=node.meta["mase"], + result_name=result_names[r_idx], torch_mesh=torch_mesh, ) ) @@ -64,13 +116,17 @@ def dtensor_wrapper_fn(*args, **kwargs): wrapped_out = tuple(outs) - elif isinstance(out, DTensor): - wrapped_out = out + # In the event the OpDispatcher already wrapped a DTensor around + # the local result, avoid reaching recursive depth limit + # elif isinstance(out, DTensor): + # wrapped_out = out elif isinstance(out, torch.Tensor): wrapped_out = _create_dtensor( - out, - result_meta=node.meta["mase"]["common"]["results"][result_names[0]], + local_tensor=out, + node_name=node.name, + node_meta=node.meta["mase"], + result_name=result_names[0], torch_mesh=torch_mesh, ) @@ -84,17 +140,15 @@ def dtensor_wrapper_fn(*args, **kwargs): def insert_dtensor_wrapper_transform_pass(mg, pass_args={}): - logger.info("Inserting DTensor wrappers for call_function nodes") + rlog("Inserting DTensor wrappers for call_function nodes") for node in mg.nodes: if node.op == "call_function": - logger.info(f"Inserting DTensor wrapper for {node.name}") + logger.debug(f"Inserting DTensor wrapper for {node.name}") node.target = create_wrapper(node) else: - logger.warning( - f"Skipping node {node.name} because it is not a call_function" - ) + logger.debug(f"Skipping node {node.name} because it is not a call_function") return mg, {} diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index 07151e310..51cbc61d0 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -1,4 +1,4 
@@ -from copy import copy +import operator import torch import torch.fx as fx @@ -32,9 +32,10 @@ def _insert_resharding_nodes(mg, pass_args={}): if node.op != "output" and len(flattened_args) != len( node.meta["mase"]["common"]["args"] ): - logger.warning( - f"Skipping node: {node.name} because number of arguments do not match metadata." - ) + if "getitem" not in node.name: + logger.warning( + f"Skipping node: {node.name} because number of arguments do not match metadata." + ) continue for arg_idx, arg_name in enumerate(node.meta["mase"]["common"]["args"].keys()): From ea6bbe5306075e1d3b176cff8b4d1fdd238c6259 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 10:33:51 +0000 Subject: [PATCH 72/93] remove deprecated files in src/chop/distributed/tensor --- src/chop/distributed/tensor/__init__.py | 21 +- src/chop/distributed/tensor/_dispatch.py | 17 - src/chop/distributed/tensor/_sharding_prop.py | 396 ------- src/chop/distributed/tensor/ops/__init__.py | 10 - .../distributed/tensor/ops/basic_strategy.py | 181 --- .../distributed/tensor/ops/common_rules.py | 288 ----- src/chop/distributed/tensor/ops/conv_ops.py | 109 -- .../distributed/tensor/ops/embedding_ops.py | 251 ---- .../tensor/ops/experimental_ops.py | 26 - src/chop/distributed/tensor/ops/math_ops.py | 1035 ----------------- src/chop/distributed/tensor/ops/matrix_ops.py | 459 -------- .../distributed/tensor/ops/pointwise_ops.py | 642 ---------- src/chop/distributed/tensor/ops/random_ops.py | 38 - src/chop/distributed/tensor/ops/tensor_ops.py | 797 ------------- src/chop/distributed/tensor/ops/utils.py | 300 ----- src/chop/distributed/tensor/ops/view_ops.py | 665 ----------- src/chop/pipelines/distributed_inference.py | 14 +- 17 files changed, 32 insertions(+), 5217 deletions(-) delete mode 100644 src/chop/distributed/tensor/_sharding_prop.py delete mode 100644 src/chop/distributed/tensor/ops/__init__.py delete mode 100644 src/chop/distributed/tensor/ops/basic_strategy.py delete mode 100644 
src/chop/distributed/tensor/ops/common_rules.py delete mode 100644 src/chop/distributed/tensor/ops/conv_ops.py delete mode 100644 src/chop/distributed/tensor/ops/embedding_ops.py delete mode 100644 src/chop/distributed/tensor/ops/experimental_ops.py delete mode 100644 src/chop/distributed/tensor/ops/math_ops.py delete mode 100644 src/chop/distributed/tensor/ops/matrix_ops.py delete mode 100644 src/chop/distributed/tensor/ops/pointwise_ops.py delete mode 100644 src/chop/distributed/tensor/ops/random_ops.py delete mode 100644 src/chop/distributed/tensor/ops/tensor_ops.py delete mode 100644 src/chop/distributed/tensor/ops/utils.py delete mode 100644 src/chop/distributed/tensor/ops/view_ops.py diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index 3d6067d28..eecce89dc 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -14,10 +14,8 @@ ) from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh -import chop.distributed.tensor.ops from chop.distributed.tensor._utils import compute_local_shape from chop.distributed.tensor.api import distribute_module, distribute_tensor, DTensor -from chop.distributed.tensor.ops.utils import normalize_to_torch_size # All public APIs from dtensor package @@ -33,6 +31,25 @@ ] +def normalize_to_torch_size(size) -> torch.Size: + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) + + def _dtensor_init_helper( init_op, size: torch.Size, diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 688a11868..3f9d38043 
100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -18,26 +18,10 @@ from torch.distributed.device_mesh import DeviceMesh import chop.distributed.tensor.api as dtensor -from chop.distributed.tensor._sharding_prop import ShardingPropagator aten = torch.ops.aten -def get_replicate_spec(tensor_arg: torch.Tensor, mesh: "DeviceMesh") -> DTensorSpec: - # scalar tensor can be safely treated as replicated - replication_spec = DTensorSpec( - mesh, - (Replicate(),) * mesh.ndim, - tensor_meta=TensorMeta( - shape=tensor_arg.shape, - stride=tensor_arg.stride(), - dtype=tensor_arg.dtype, - ), - ) - - return replication_spec - - def decompose_handler( op_call: torch._ops.OpOverload, args: Tuple[object, ...], @@ -72,7 +56,6 @@ class OpDispatcher: """ def __init__(self) -> None: - self.sharding_propagator = ShardingPropagator() self._random_ops = { aten.native_dropout.default, aten.normal_.default, diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py deleted file mode 100644 index c03046319..000000000 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ /dev/null @@ -1,396 +0,0 @@ -# mypy: allow-untyped-defs -from functools import lru_cache -from itertools import chain -from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union - -import torch -from torch._ops import OpOverload -from torch._subclasses import FakeTensorMode -from torch.distributed._tensor._op_schema import ( - OpInfo, - OpSchema, - OpStrategy, - OutputSharding, - OutputSpecType, - PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, - TupleStrategy, -) -from torch.distributed._tensor._utils import ( - compute_local_shape, - compute_local_stride, - try_find_mesh_from_args, -) -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.device_mesh import DeviceMesh - - -aten = torch.ops.aten - - -def _length(obj) -> int: - if obj is None: - 
return 0 - if not isinstance(obj, Sequence): - return 1 - return len(obj) - - -def spec_to_strategy(spec: object) -> object: - if isinstance(spec, DTensorSpec): - return OpStrategy([PlacementStrategy(spec)]) - elif ( - isinstance(spec, (list, tuple)) - and len(spec) > 0 - and isinstance(spec[0], DTensorSpec) - ): - # tensor list create tuple strategy - tuple_strategy = [spec_to_strategy(s) for s in spec] - tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) - return TupleStrategy( - tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy - ) - else: - return spec - - -def rlog(msg): - rank = torch.distributed.get_rank() - if rank == 0: - print(msg) - - -class ShardingPropagator: - def __init__(self) -> None: - self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} - self.op_strategy_funcs: Dict[ - OpOverload, - Callable[[DeviceMesh, OpSchema], StrategyType], - ] = {} - # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop - self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} - self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] - # op map to save indices of shape (and stride) args which may need to be modified in sharding prop - self.op_to_shape_and_stride_idx: Dict[ - OpOverload, Union[int, Tuple[int, int]] - ] = { - # new factory ops - aten.new_empty.default: 1, - aten.new_full.default: 1, - aten.new_ones.default: 1, - aten.new_zeros.default: 1, - aten.new_empty_strided.default: (1, 2), - # view ops - aten.expand.default: 1, - aten.reshape.default: 1, - aten.view.default: 1, - aten._unsafe_view.default: 1, - } - - def register_sharding_prop_rule( - self, - op_overload: OpOverload, - rule_func: Callable[[OpSchema], OutputSharding], - schema_info: Optional[RuntimeSchemaInfo] = None, - ): - """ - Register a sharding propagation rule for an operator. 
- """ - self.op_to_rules[op_overload] = rule_func - if schema_info is not None: - self.op_to_schema_info[op_overload] = schema_info - - def register_op_strategy( - self, - op_overload: OpOverload, - strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType], - schema_info: Optional[RuntimeSchemaInfo] = None, - ): - """ - Register a sharding strategy generator for an operator. - """ - self.op_strategy_funcs[op_overload] = strategy_func - if schema_info is not None: - self.op_to_schema_info[op_overload] = schema_info - - @lru_cache # noqa: B019 - def _propagate_tensor_meta( - self, op_schema: OpSchema - ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: - """ - Propagate the tensor metadata, it could either return a TensorMeta - or a list/tuple of TensorMetas - """ - if op_schema.op == aten.equal.default: - # data dependent ops can't be used for fake propagation - return None - - # NOTE: We must call the tracing in fake tensor mode so that it - # avoids materializing memory - with FakeTensorMode(): - fake_args = op_schema.gen_fake_args() - fake_kwargs = op_schema.gen_fake_kwargs() - fake_out = op_schema.op(*fake_args, **fake_kwargs) - - if isinstance(fake_out, torch.Tensor): - return TensorMeta( - shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype - ) - - elif isinstance(fake_out, (tuple, list)): - tensor_meta_list: List[Optional[TensorMeta]] = [] - for fake_out_item in fake_out: - if isinstance(fake_out_item, torch.Tensor): - tensor_meta_list.append( - TensorMeta( - shape=fake_out_item.shape, - stride=fake_out_item.stride(), - dtype=fake_out_item.dtype, - ) - ) - else: - tensor_meta_list.append(None) - return ( - tuple(tensor_meta_list) - if isinstance(fake_out, tuple) - else tensor_meta_list - ) - else: - # if fake is not a tensor or tuple of tensor, return as none - return None - - def _wrap_output_spec_tensor_meta( - self, - op: OpOverload, - output_specs: OutputSpecType, - output_tensor_meta: Union[None, TensorMeta, 
Sequence[Optional[TensorMeta]]], - ) -> None: - """ - Wrap the output_specs with the tensor metadata from the output. - """ - - if isinstance(output_specs, DTensorSpec): - if not isinstance(output_tensor_meta, TensorMeta): - # Either error due to ShardingPropagator or due to incorrect OutputSpec - if not isinstance(output_tensor_meta, (tuple, list)): - raise ValueError( - "ShardingPropagator error: output does not have an associated TensorMeta" - ) - raise ValueError( - f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " - f"number of op outputs: {len(output_tensor_meta)}." - ) - output_specs.tensor_meta = output_tensor_meta - elif isinstance(output_specs, (tuple, list)): - if not isinstance(output_tensor_meta, (tuple, list)): - raise ValueError( - f"For the op {op.name()}, `output_specs` has type {type(output_specs)} but output_tensor_meta has type {type(output_tensor_meta)}" - f"Both should be tuple or list." - ) - - if len(output_specs) != len(output_tensor_meta): - - raise ValueError( - f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " - f"number of op outputs {_length(output_tensor_meta)}." - ) - - for i, spec in enumerate(output_specs): - if isinstance(spec, DTensorSpec): - output_tensor_meta_i = output_tensor_meta[i] - if not isinstance(output_tensor_meta_i, TensorMeta): - raise ValueError( - f"ShardingPropagator error: output {i} does not have an associated TensorMeta" - ) - spec.tensor_meta = output_tensor_meta_i - - def propagate(self, op_info: OpInfo) -> None: - # We cannot use an lru cache if we know that inputs will have dynamic shapes, - # because SymInts are not hashable. - # This is generally ok because this only happens during tracing in torch.compile, - # and tracing does not need to be as fast as eagermode DTensor usages. 
- if op_info.schema.has_symints: - output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) - else: - output_sharding = self.propagate_op_sharding(op_info.schema) - op_info.output_sharding = output_sharding - - return output_sharding - - def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: - """ - Propagate the sharding for an operator given the op_schema. - """ - # special case op, we don't need to propagate for local - # scalar. TODO: figure out a better way to handle this - if op_schema.op is aten._local_scalar_dense.default: - return OutputSharding(None, op_schema) - - if op_schema.op in self.op_strategy_funcs: - # generate op strategy for the op. - mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema) - # swap the args spec with args strategies - args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] - - kwargs_op_strategy = { - k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() - } - - # construct a new OpSchema on args for strategy based propagation - strategy_schema: OpSchema = OpSchema( - op=op_schema.op, - args_schema=tuple(args_op_strategy), - kwargs_schema=kwargs_op_strategy, - ) - - op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema) - - if isinstance(op_strategy, OpStrategy): - # single Op strategy - output_strategy = self._select_strategy(op_strategy) - - # construct output spec for the op - if op_schema.return_type_tuple_tensor_like(): - # for ops that return multiple tensors and the output_specs is not - # a tuple, we use a tuple of that single output spec as the new - # output_specs - output_specs: OutputSpecType = output_strategy.output_specs - if isinstance(output_specs, DTensorSpec): - output_specs = tuple( - [ - # create a new DTensorSpec with the same placement as the - # output_specs in output_strategy - DTensorSpec( - mesh=output_specs.mesh, - placements=output_specs.placements, - tensor_meta=output_specs.tensor_meta, - ) - for 
_ in range(len(op_schema.op._schema.returns)) - ] - ) - elif op_schema.return_type_tensor(): - output_specs = output_strategy.output_specs - else: - output_specs = None - - output_sharding = OutputSharding( - output_specs, - redistribute_schema=None, - needs_redistribute=False, - ) - elif isinstance(op_strategy, TupleStrategy): - # tuple strategy output sharding processing - # runtime selected placement strategy for each TupleStrategy input arg - selected_strategies: List[PlacementStrategy] = [] - out_spec_list: List[DTensorSpec] = [] - for strategy in op_strategy.childs: - selected_strategy = self._select_strategy(strategy) - selected_strategies.append(selected_strategy) - out_spec_list.append(selected_strategy.output_spec) - - suggestion_args: List[object] = [] - tensor_or_list_tensor_arg_idx = 0 - - for arg in op_schema.args_schema: - if ( - arg - and isinstance(arg, (list, tuple)) - and isinstance(arg[0], DTensorSpec) - ): - expected_input_spec_list: List[DTensorSpec] = [] - for idx, arg_spec in enumerate(arg): - expected_input_spec = selected_strategies[idx].input_spec( - tensor_or_list_tensor_arg_idx - ) - expected_input_spec = ( - expected_input_spec.shallow_copy_with_tensor_meta( - arg_spec.tensor_meta - ) - ) - expected_input_spec_list.append(expected_input_spec) - suggestion_args.append( - tuple(expected_input_spec_list) - if isinstance(arg, tuple) - else expected_input_spec_list - ) - tensor_or_list_tensor_arg_idx += 1 - - elif isinstance(arg, DTensorSpec): - expected_input_spec = selected_strategies[0].input_spec( - tensor_or_list_tensor_arg_idx - ) - expected_input_spec = ( - expected_input_spec.shallow_copy_with_tensor_meta( - arg.tensor_meta - ) - ) - suggestion_args.append(expected_input_spec) - tensor_or_list_tensor_arg_idx += 1 - else: - suggestion_args.append(arg) - - output_sharding = OutputSharding( - tuple(out_spec_list), - redistribute_schema=None, - needs_redistribute=False, - ) - else: - raise ValueError("Unsupported op strategy type") - - 
return output_sharding - elif op_schema.op in self.op_to_rules: - # propagate the sharding with rule - sharding_prop_func = self.op_to_rules[op_schema.op] - - # step 1. there's sharding propagation rule, run - # sharding propagation to get the output sharding - try: - output_sharding = sharding_prop_func(op_schema) - except NotImplementedError as e: - raise e - except Exception as e: - raise RuntimeError( - f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}" - ) from e - - # step 2. if can't get output_spec from sharding - # propagation (i.e. no rules apply for input - # placements), we return the output sharding - # with schema suggestions, which can be used to - # decide how to do redistribute on inputs - if output_sharding.output_spec is None: - if output_sharding.redistribute_schema is None: - raise RuntimeError( - f"Sharding propagation failed on op {op_schema}!" - ) - else: - # we do auto redistribute on inputs if necessary - # run sharding propagation again with suggested schema - propagation_res = sharding_prop_func( - output_sharding.redistribute_schema - ) - # we set the output sharding with the new propagation result - # so that dispatching know both output_spec and redistribute_schema - # exist, which indicates a reshard is needed - output_sharding.output_spec = propagation_res.output_spec - output_sharding.needs_redistribute = True - - return output_sharding - else: - raise NotImplementedError( - f"Operator {op_schema.op} does not have a sharding strategy registered." 
- ) - - def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: - if len(strategy.strategies) == 1: - # short cut with only one possible strategy - return strategy.strategies[0] - - strategy_costs: List[float] = [] - for strtg in strategy.strategies: - redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) - strategy_costs.append(redistribute_cost) - - # for eager execution, we just select the one with the minimal redistribute cost - return strategy.strategies[strategy_costs.index(min(strategy_costs))] diff --git a/src/chop/distributed/tensor/ops/__init__.py b/src/chop/distributed/tensor/ops/__init__.py deleted file mode 100644 index eaccc8aa8..000000000 --- a/src/chop/distributed/tensor/ops/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .conv_ops import * # noqa: F403 -from .embedding_ops import * # noqa: F403 -from .experimental_ops import * # noqa: F403 -from .math_ops import * # noqa: F403 -from .matrix_ops import * # noqa: F403 -from .pointwise_ops import * # noqa: F403 -from .random_ops import * # noqa: F403 -from .tensor_ops import * # noqa: F403 -from .view_ops import * # noqa: F403 diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py deleted file mode 100644 index 97dd43b15..000000000 --- a/src/chop/distributed/tensor/ops/basic_strategy.py +++ /dev/null @@ -1,181 +0,0 @@ -import itertools -from dataclasses import dataclass -from typing import List, Set, Tuple - -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh - - -@dataclass -class EinsumDims: - contracting_dims: List[str] - batch_dims: List[str] - lhs_out_only_dims: List[str] - rhs_out_only_dims: List[str] - - @classmethod - def parse_equation(cls, equation: 
str) -> Tuple[List[str], str]: - # parse einop equation and extract arg specs - """ - Parse the einsum equation str to input dim chars and output dim char - """ - inputs, outputs = equation.split("->") - input_dims, output_dims = inputs.split(","), outputs.split(",") - - # NOTE: only support at most two inputs, and single output - # extend to support more inputs if needed in future - assert len(input_dims) <= 2, "Only support at most two inputs" - assert len(output_dims) == 1, "Only support single output" - output_dim = output_dims[0] - return input_dims, output_dim - - @classmethod - def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims": - """ - Parse the dims and extract the contracting, batch, and free dimensions - for the left and right hand sides. - """ - dim_char_set: Set[str] = set() - for input_dim in input_dims: - dim_char_set.update(input_dim) - - # get a determinisitc order of all dim chars - all_dim_chars = sorted(dim_char_set) - - # parse input and output dimensions - lhs_out_only_dims, rhs_out_only_dims = [], [] - batch_dims, contracting_dims = [], [] - - for dim_char in all_dim_chars: - if dim_char not in output_dim: - contracting_dims.append(dim_char) - else: - is_batch_dim = True - for input_dim in input_dims: - is_batch_dim = is_batch_dim and dim_char in input_dim - - if is_batch_dim: - batch_dims.append(dim_char) - else: - assert ( - len(input_dims) == 2 - ), "free dimension only supported for two inputs!" 
- lhs, rhs = input_dims - if dim_char in lhs: - lhs_out_only_dims.append(dim_char) - elif dim_char in rhs: - rhs_out_only_dims.append(dim_char) - else: - raise RuntimeError("Invalid dimension character") - - return cls( - contracting_dims=contracting_dims, - batch_dims=batch_dims, - lhs_out_only_dims=lhs_out_only_dims, - rhs_out_only_dims=rhs_out_only_dims, - ) - - -def gen_einsum_strategies( - equation: str, - mesh: DeviceMesh, - *, - linearity: bool = False, -) -> OpStrategy: - """ - Generate a strategy list for the ops that follow einsum style notation. - """ - # parse einop equation and extract dims - input_dims, output_dim = EinsumDims.parse_equation(equation) - edims = EinsumDims.parse_dims(input_dims, output_dim) - - all_mesh_dim_strategies = [] - - # generate strategies for each mesh dim - for mesh_dim in range(mesh.ndim): - mesh_dim_strategies = [] - - # placement list stores placements of [output, input1, input2, ...] - # first we always have replicate all for inputs and output - placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1) - mesh_dim_strategies.append(placement_list) - - if mesh.size(mesh_dim) <= 1: - # only replicate strategy for mesh dim with size 1 - # TODO: see if this is valid for the submesh case - continue - - # split batch dim - for batch_dim in edims.batch_dims: - output_batch_dim = output_dim.index(batch_dim) - placement_list = [Shard(output_batch_dim)] - for input_dim in input_dims: - input_batch_dim = input_dim.index(batch_dim) - placement_list.append(Shard(input_batch_dim)) - - mesh_dim_strategies.append(placement_list) - - # split contracting dim - for contracting_dim in edims.contracting_dims: - placement_list = [Partial()] - for input_dim in input_dims: - input_contracting_dim = input_dim.index(contracting_dim) - placement_list.append(Shard(input_contracting_dim)) - - mesh_dim_strategies.append(placement_list) - - # split lhs free dim - for lhs_dim in edims.lhs_out_only_dims: - lhs_free_dim = 
output_dim.index(lhs_dim) - # this means split the lhs input and output - # i.e. S(0), R -> S(0) - lhs_placement_list: List[Placement] = [ - Shard(lhs_free_dim), - Shard(lhs_free_dim), - Replicate(), - ] - mesh_dim_strategies.append(lhs_placement_list) - - # split rhs free dim - for rhs_dim in edims.rhs_out_only_dims: - rhs_free_dim = output_dim.index(rhs_dim) - rhs_placement_list: List[Placement] = [ - Shard(rhs_free_dim), - Replicate(), - Shard(rhs_free_dim), - ] - mesh_dim_strategies.append(rhs_placement_list) - - # linearity strategy - if linearity: - linearity_placement_list: List[Placement] = [Partial()] - for input_dim in input_dims: - linearity_placement_list.append(Partial()) - mesh_dim_strategies.append(linearity_placement_list) - - all_mesh_dim_strategies.append(mesh_dim_strategies) - - # generate strategies for entire mesh - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - # TODO: filter out invalid strategies, at this point we generate - # all possible strategies without considering the whether the tensor - # dim could be sharded or not, we would need to filter out invalid - # strategies base on the actual tensor shape - # (i.e. for Shard, tensor dim size must > mesh size) - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:]) - all_strategies.append(strat) - - return OpStrategy(all_strategies) diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py deleted file mode 100644 index f70b27076..000000000 --- a/src/chop/distributed/tensor/ops/common_rules.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -from typing import cast, Dict, List, Optional, Tuple - -import torch -from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, - OpSchema, - OutputSharding, -) -from torch.distributed._tensor._utils import compute_local_shape -from torch.distributed._tensor.ops.utils import prod -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - - -def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: - return string[:idx] + new_char + string[idx + 1 :] - - -def _gen_reshard_suggestions( - op_schema: OpSchema, - input_dims: List[str], - input_specs: Tuple[DTensorSpec, ...], - dim_to_sharding: Dict[str, int], - pending_sum: List[int], -) -> OutputSharding: - suggested_arg_specs: List[DTensorSpec] = [] - for input_dim, input_spec in zip(input_dims, input_specs): - dim_map = [dim_to_sharding[dim] for dim in input_dim] - suggested_arg_specs.append( - DTensorSpec.from_dim_map( - mesh=input_spec.mesh, - dim_map=dim_map, - sums=pending_sum, - tensor_meta=input_spec.tensor_meta, - ) - ) - suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {}) - suggested_schema._inplace_rewrap_schema_suggestion(op_schema) - return OutputSharding( - None, - redistribute_schema=suggested_schema, - ) - - -def einop_rule( - equation: str, - op_schema: OpSchema, - *, - linearity: bool = False, - enforce_sharding: Optional[Dict[str, int]] = None, -) -> OutputSharding: - """ - Propagate the sharding of inputs to output for ops whose data moves according to einsum notation. - - This is mostly borrowed from @zdevito's sharding simulator. Examples: - mk,kn->mn - einsum - ij,ij->ij - addition - ij,j->ij - broadcasted addition - ij->i - reduction - Other ops could use this propagation algorithm when applied, note - that einsum propagation only deal with list of specs (DTensor specs) - as it only works on list of tensors! 
- - linearity in einop_rule means that the calling op `f` follows this rule: - f(a + b) = f(a) + f(b) - - In this case we can propagate the partial sum, note that linearity in einop - only applies to partial sum, not other operations like min/max (which are - associative but not linear). - """ - # parse einop equation and extract arg specs - inputs, outputs = equation.split("->") - input_dims, output_dims = inputs.split(","), outputs.split(",") - input_specs = op_schema.args_spec - # NOTE: only support single output unless needed in future - output_dim = output_dims[0] - - dim_to_sharding: Dict[str, int] = {} - dim_to_size: Dict[str, int] = {} - # record pending sum, key is mesh dimension, value is pending sum - # counter across input specs - pending_sums_counter: Dict[int, int] = {} - seen_shardings: Dict[int, str] = {} - needs_reshard = False - - def merge_sharding(dim: str, a: int, b: int) -> int: - # merge the sharding of inputs if it's able to merge, i.e. we can merge - # replicate and shard to shard, but this will trigger an reshard operation - if a != b: - if a == -1 or b == -1: - # reshard the replicate to match the sharded one - nonlocal needs_reshard - needs_reshard = True - return a if a != -1 else b - else: - # TODO: further merge the sharding properly (i.e. 
reshard one input to replicate) - raise RuntimeError( - f"{equation}: dim {dim} sharded two different ways: {a} and {b}" - ) - else: - return a - - for input_dim, input_spec in zip(input_dims, input_specs): - # deal with partial sums - input_sums = input_spec.sums - for sum_dim in input_sums: - if sum_dim not in pending_sums_counter: - seen_shardings[sum_dim] = "+" - # update pending sum counter for pending sum mesh - # dimension with the occurrence from each input - pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1 - - for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)): - if enforce_sharding and dim in enforce_sharding: - if enforce_sharding[dim] != mesh_dim: - needs_reshard = True - dim_to_sharding[dim] = enforce_sharding[dim] - dim_to_size[dim] = input_spec.shape[idx] - elif dim not in dim_to_sharding: - dim_to_sharding[dim] = mesh_dim - dim_to_size[dim] = input_spec.shape[idx] - else: - dim_to_sharding[dim] = merge_sharding( - dim, dim_to_sharding[dim], mesh_dim - ) - assert dim_to_size[dim] == input_spec.shape[idx] - - # after merging sharding, we check if there're multiple - # sharding on the same mesh dim. 
- merged_sharding_for_dim = dim_to_sharding[dim] - if merged_sharding_for_dim != -1: - if ( - merged_sharding_for_dim in seen_shardings - and dim != seen_shardings[merged_sharding_for_dim] - ): - needs_reshard = True - seen_shardings[merged_sharding_for_dim] += dim - else: - seen_shardings[merged_sharding_for_dim] = dim - - if pending_sums_counter and not linearity: - # return reshard suggestion with no pending sum, because we already properly - # merge the sharding, this reshard suggestion is legit to use - return _gen_reshard_suggestions( - op_schema, input_dims, input_specs, dim_to_sharding, [] - ) - else: - # It's a op that support linearity, but not all input arguments are partial - # we fail the sharding propagation with suggestion to make all inputs be - # partial on the corresponding mesh dim (all inputs should be partial for - # the mesh dims in order to execute locally and delay the sum reduction) - for value in pending_sums_counter.values(): - if value != len(input_specs): - needs_reshard = True - - for mesh_dim, dims in seen_shardings.items(): - if len(dims) > 1: - # we found different input dims are being sharded on the same mesh dim - # in order to perform local op computation, we need to reshard inputs - # base on some simple heuristics, now we simply pick the one with least comm - # volume. (i.e. 
the input with least size) - # TODO: consider a more advanced heuristic to pick the best sharding - costs = [] - for d in dims: - cost = 0 - for input_dim, input_spec in zip(input_dims, input_specs): - if ( - d in input_dim - and input_spec.dim_map[input_dim.index(d)] == mesh_dim - ): - assert input_spec.tensor_meta is not None - global_shape = input_spec.tensor_meta.shape - local_shape = compute_local_shape( - global_shape, input_spec.mesh, input_spec.placements - ) - cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) - costs.append(cost) - d_to_keep_sharding = dims[costs.index(max(costs))] - for d in dims: - # update dim_to_sharding to keep the sharding of the dim with - # highest comm and make the rest of the dims to replicate - if d != d_to_keep_sharding: - dim_to_sharding[d] = -1 - - pending_sums = list(pending_sums_counter.keys()) - if needs_reshard: - return _gen_reshard_suggestions( - op_schema, input_dims, input_specs, dim_to_sharding, pending_sums - ) - - # generate output pending sum if a dim is sharded, and it appears in input - # but not output - for dim, shard_on_mesh in dim_to_sharding.items(): - if dim not in output_dims[0] and shard_on_mesh != -1: - pending_sums.append(shard_on_mesh) - - # if no need to reshard, we directly generate the output sharding - output_dim_map = [] - output_shape = [] - for dim in output_dim: - if dim == "1": - # find output dim that is a singleton dimension, mark sharding and shape - output_dim_map.append(-1) - output_shape.append(1) - else: - output_dim_map.append(dim_to_sharding[dim]) - output_shape.append(dim_to_size[dim]) - - # XXX: since we still need to have intermediate shape calculation, we need - # to pass in the shape here. 
We should remove this once sharding decomp works - # for ops like addmm - assert input_specs[0].tensor_meta is not None - tensor_meta = TensorMeta( - torch.Size(output_shape), - input_specs[0].tensor_meta.stride, - input_specs[0].tensor_meta.dtype, - ) - return OutputSharding( - DTensorSpec.from_dim_map( - input_specs[0].mesh, - output_dim_map, - pending_sums, - tensor_meta=tensor_meta, - ) - ) - - -def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding: - """ - Propagate the sharding for pointwise operations. - - Examples: - ij,ij->ij - addition/mul - ij,j->ij - broadcasted addition - """ - alphabet = "abcdefghijklmnopqrstuvwxyz" - # find the max_dim first in case we need to broadcasting - input_specs = op_schema.args_spec - max_dim = max(input.ndim for input in input_specs) - dimchars = [] - singleton_counter: List[int] = [0] * max_dim - for input in input_specs: - start_dim = max_dim - input.ndim - p = alphabet[start_dim:max_dim] - # handle the "broadcasting to a common shape case" - # see https://pytorch.org/docs/stable/notes/broadcasting.html - # If any of the dimensions is singleton dimension (i.e. 1). - # we mark the dim char as a special "1" to distinguish with - # the non-singleton dimension, so that sharding propagation - # should just ignore the singleton dimension. - if len(input_specs) > 1: - for i in range(max_dim): - if i < start_dim: - # treat the leading miss dim chars as singleton - singleton_counter[i] += 1 - elif input.shape[i - start_dim] == 1: - # mark singleton dim char as a special "1" in einop rule - singleton_counter[i] += 1 - p = _replace_char_in_str(p, "1", (i - start_dim)) - - dimchars.append(p) - out_dimchars = alphabet[:max_dim] - # check if we replace the all inputs dim char with singleton dimension, - # if we replace all inputs, we also need to replace the output dimension. 
- for output_dim_idx in range(len(out_dimchars)): - out_dimchar = out_dimchars[output_dim_idx] - if singleton_counter[output_dim_idx] == len(input_specs): - out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx) - - fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}" - - enforce_sharding: Dict[str, int] = {} - if _is_inplace_op(op_schema.op): - # inplace op should keep the input sharding it writes to - for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map): - enforce_sharding[out_dimchar] = mesh_dim - elif _is_out_variant_op(op_schema.op): - out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) - for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map): - enforce_sharding[out_dimchar] = mesh_dim - - return einop_rule( - fmt, - op_schema, - linearity=linearity, - enforce_sharding=enforce_sharding, - ) diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py deleted file mode 100644 index 7bf13241d..000000000 --- a/src/chop/distributed/tensor/ops/conv_ops.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -# implement matrix related ops for distributed tensor -from typing import List - -import torch -from torch.distributed._tensor._op_schema import OpSchema, OutputSharding -from chop.distributed.tensor.ops.utils import register_prop_rule -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - - -aten = torch.ops.aten - - -@register_prop_rule(aten.convolution.default) -def convolution_rules(op_schema: OpSchema) -> OutputSharding: - ( - input_spec, - weight_spec, - bias_spec, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) = op_schema.args_schema - - assert isinstance(input_spec, DTensorSpec) - assert isinstance(weight_spec, DTensorSpec) - assert isinstance(bias_spec, DTensorSpec) - assert input_spec.tensor_meta is not None - assert weight_spec.tensor_meta is not None - in_shape = input_spec.tensor_meta.shape - weight_shape = weight_spec.tensor_meta.shape - assert isinstance(stride, List) - assert isinstance(padding, List) - assert isinstance(dilation, List) - assert isinstance(weight_shape, torch.Size) - N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3] - C_out = weight_shape[0] - H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[ - 0 - ] + 1 - W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[ - 1 - ] + 1 - output_shape = [N, C_out, H_out, W_out] - output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1) - output_dim_map = input_spec.dim_map - pending_sums = input_spec.sums - - tensor_meta = TensorMeta( - torch.Size(output_shape), - output_stride, - input_spec.tensor_meta.dtype, - ) - return OutputSharding( - DTensorSpec.from_dim_map( - input_spec.mesh, - output_dim_map, - pending_sums, - tensor_meta=tensor_meta, - ) - ) - - -@register_prop_rule(aten.convolution_backward.default) -def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: - input_spec = op_schema.args_schema[0] - ( - 
grad_output_spec, - input_spec, - weight_spec, - bias_shape_opt, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_mask, - ) = op_schema.args_schema - - assert isinstance(grad_output_spec, DTensorSpec) - assert isinstance(input_spec, DTensorSpec) - assert isinstance(weight_spec, DTensorSpec) - assert isinstance(bias_shape_opt, List) - assert input_spec.tensor_meta is not None - weight_tensor_meta = weight_spec.tensor_meta - bias_tensor_meta = TensorMeta( - torch.Size(bias_shape_opt), - (1,), - input_spec.tensor_meta.dtype, - ) - - grad_input_spec = input_spec - grad_weight_spec = DTensorSpec.from_dim_map( - input_spec.mesh, - [-1, -1, -1, -1], - [0], - tensor_meta=weight_tensor_meta, - ) - grad_bias_spec = DTensorSpec.from_dim_map( - input_spec.mesh, - [-1], - [0], - tensor_meta=bias_tensor_meta, - ) - return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/src/chop/distributed/tensor/ops/embedding_ops.py b/src/chop/distributed/tensor/ops/embedding_ops.py deleted file mode 100644 index d89ec651b..000000000 --- a/src/chop/distributed/tensor/ops/embedding_ops.py +++ /dev/null @@ -1,251 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. 
and affiliates -# implement matrix related ops for distributed tensor -from dataclasses import dataclass, field -from typing import cast, Optional - -import torch -import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementList, - StrategyType, -) -from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy -from torch.distributed._tensor.placement_types import ( - Partial, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh -from chop.distributed.tensor.ops.utils import register_op_strategy - -aten = torch.ops.aten - - -@dataclass -class MaskBuffer: - data: Optional[torch.Tensor] = None - - def materialize_mask(self, mask): - if self.data is not None: - raise RuntimeError("MaskBuffer has already been materialized") - self.data = mask - - def release_mask(self): - # TODO: evaluate if we need to release the mask buffer or the buffer - # can just have the same lifetime as the Partial placement - if self.data is None: - raise RuntimeError("MaskBuffer has not been materialized") - self.data = None - - def apply_mask(self, tensor): - if self.data is None: - raise RuntimeError("MaskBuffer has not been materialized") - - # NOTE: _MaskPartial is being used by the embedding op and the gather op. - # For gather, the mask has the same dimension as the output tensor, whereas - # the output of the embedding op has an additional dimension compare to the input, - # hence the output masking logic below having two different cases. 
- if tensor.ndim == self.data.ndim: - tensor[self.data] = 0.0 - else: - tensor[self.data, :] = 0.0 - - -@dataclass(frozen=True) -class _MaskPartial(Partial): - """ - A partial mask placement devised for rowwise sharded embedding op, where we need - to mask and adjust the indices to the local embedding shard, embedding masking - is a special type of the Partial placement - - NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor - lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. - """ - - logical_dim_size: int = -1 - mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) - - def _partition_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # override parent logic to perform partial mask for embedding - num_chunks = mesh.size(mesh_dim) - # get local shard size and offset on the embedding_dim - local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( - self.logical_dim_size, - num_chunks, - mesh.get_local_rank(mesh_dim), - return_offset=True, - ) - # Build the input mask and save it for the current partial placement - # this is so that the output of embedding op can reuse the same partial - # placement saved mask to perform mask + reduction - mask = (tensor < local_offset_on_dim) | ( - tensor >= local_offset_on_dim + local_shard_size - ) - # mask the input tensor - masked_tensor = tensor.clone() - local_offset_on_dim - masked_tensor[mask] = 0 - # materialize the mask buffer to be used for reduction - self.mask_buffer.materialize_mask(mask) - return masked_tensor - - def _reduce_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask - assert self.mask_buffer.data is not None - - # apply the mask to the tensor that pending reduction - self.mask_buffer.apply_mask(tensor) - - # clear the mask buffer - self.mask_buffer.release_mask() - - # 
perform sum reduction - return funcol.all_reduce( - tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) - ) - - def _reduce_shard_value( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_spec: Placement, - ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask - assert self.mask_buffer.data is not None - - # apply the mask to the tensor that pending reduction - self.mask_buffer.apply_mask(tensor) - - # clear the mask buffer - self.mask_buffer.release_mask() - - # call reduce_shard_tensor of the shard_spec. - shard_spec = cast(Shard, shard_spec) - return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _MaskPartial): - return False - - # if either data is not None, we invalidate the sharding cache, as this indicates - # the current MaskPartial placement is still in use and should not be used for cache hit. - if self.mask_buffer.data is not None or other.mask_buffer.data is not None: - return False - - return ( - self.reduce_op == other.reduce_op - and self.logical_dim_size == other.logical_dim_size - ) - - def __hash__(self) -> int: - return 1 + hash( - (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op) - ) - - def __repr__(self) -> str: - """ - machine readable representation of the MaskPartial placement - """ - return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" - - def __str__(self) -> str: - """ - human readable representation of the MaskPartial placement - """ - return "MaskP" - - -@register_op_strategy(aten.embedding.default) -def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """ - This strategy handles embedding op. 
We have two possible embedding shardings: - rowwise and colwise - """ - weight_strategy = cast(OpStrategy, op_schema.args_schema[0]) - indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) - - weight_shape = weight_strategy.shape - indices_shape = indices_strategy.shape - output_emd_dim = len(indices_shape) - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate - colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial - embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) - - # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates - # from the input indices and use it for output reduction - rowwise_sharding: PlacementList = [ - embedding_partial_placement, - Shard(0), - embedding_partial_placement, - ] - single_mesh_dim_strategies.append(rowwise_sharding) - - # batch dim sharding, weight replicated, input can shard on any dim, output follows input - for input_dim in range(len(indices_shape)): - batch_sharding: PlacementList = [ - Shard(input_dim), - Replicate(), - Shard(input_dim), - ] - single_mesh_dim_strategies.append(batch_sharding) - - return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) - - -@register_op_strategy(aten.embedding_dense_backward.default) -def embedding_dense_backward_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> StrategyType: - """ - This strategy handles embedding op. 
We have two possible embedding shardings: - rowwise and colwise - """ - grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) - indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) - - grad_out_shape = grad_out_strategy.shape - indices_shape = indices_strategy.shape - grad_out_ndim = len(grad_out_shape) - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding backward, grad_out shard on last dim, input replicate, - # weight grad shard colwise - colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # batch dim sharding, weight replicated, grad_out/input have same sharding - # that can shard on any dim, weight grad partial - for input_dim in range(len(indices_shape)): - batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) - - # grad_out partial, input replicate, weight grad keep partial - partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] - single_mesh_dim_strategies.append(partial_sharding) - - return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/src/chop/distributed/tensor/ops/experimental_ops.py b/src/chop/distributed/tensor/ops/experimental_ops.py deleted file mode 100644 index 432fbede8..000000000 --- a/src/chop/distributed/tensor/ops/experimental_ops.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -# implement matrix related ops for distributed tensor - -import torch -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) -from torch.distributed._tensor.device_mesh import DeviceMesh -from chop.distributed.tensor.ops.utils import register_op_strategy -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate - - -aten = torch.ops.aten - - -@register_op_strategy(aten.slice_backward.default) -def slice_backward_rules(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """ - slice_backward is a new_zeros + slice_scatter, we only allow replication - on the input/output for now since new_zeros would produce replication - """ - replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replicate_spec)]) diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py deleted file mode 100644 index 1b2211ffd..000000000 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ /dev/null @@ -1,1035 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. 
and affiliates -import math -from dataclasses import dataclass -from enum import Enum -from typing import cast, List, Optional, Sequence, Tuple, Union - -import torch -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementList, - PlacementStrategy, - RuntimeSchemaInfo, - TupleStrategy, -) -from torch.distributed._tensor.ops.utils import ( - as_list, - expand_to_full_mesh_op_strategy, - generate_redistribute_costs, - is_tensor_evenly_shardable, - normalize_dim, - normalize_dims, - normalize_to_torch_size, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh -from chop.distributed.tensor.ops.utils import register_op_strategy - -aten = torch.ops.aten - - -class Reduction(Enum): - NONE = 0 - MEAN = 1 - SUM = 2 - - -@dataclass(frozen=True) -class NormReduction: - norm_type: Union[int, float, str] - - -ReductionOpType = Union[NormReduction, str] - - -@dataclass(frozen=True) -class _NormPartial(Partial): - """ - This placement is used for partial vector norm. - - For p-norms (where p not inf or -inf), the p-norm over n elements computes - (sum_i x_i^p)^(1/p) - where the sum is from i=1 to n. The reduction op is the p-norm itself. - For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm: - Rank 0: [t1, t2] | Rank 1: [t3, t4] - After computing 2-norm per gradient (partial placement): - Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)] - Converting from partial to replicate wants to ultimately get: - Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)] - This can be achieved by computing 2-norm on each rank's result. This holds - similarly for inf and -inf norm. For 0-norm, the reduction op is sum. 
- """ - - norm_type: Union[int, float, str] = 2 - - def __post_init__(self): - """Set the appropriate reduce op based on the norm type.""" - # Use `object.__setattr__` to bypass frozen checks - if self.norm_type in (float("inf"), "inf"): - object.__setattr__(self, "reduce_op", "max") - elif self.norm_type in (float("-inf"), "-inf"): - object.__setattr__(self, "reduce_op", "min") - elif isinstance(self.norm_type, (int, float)): - object.__setattr__(self, "reduce_op", "sum") - else: - raise NotImplementedError(f"Unsupported norm type: {self.norm_type}") - - def _partition_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - """ - For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm: - Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3) - To convert from replicated to partial, we want f(x) such that - sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2) - = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2). - One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and - p-norm as f(x) = x / d^(1/p). 
- """ - if self.reduce_op in ("max", "min"): - return tensor - elif self.reduce_op == "sum": - if self.norm_type == 0: - raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}") - elif self.norm_type == 1: - return tensor / mesh.size(mesh_dim) - assert isinstance(self.norm_type, (int, float)) - return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type) - raise NotImplementedError(self.reduce_op) - - def _reduce_shard_value( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_spec: Placement, - ) -> torch.Tensor: - assert isinstance(shard_spec, Shard), f"{shard_spec}" - tensor = self._pre_reduce_transform(tensor) - reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) - return self._post_reduce_transform(reduced_tensor) - - def _reduce_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - tensor = self._pre_reduce_transform(tensor) - reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim) - return self._post_reduce_transform(reduced_tensor) - - def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: - if self.reduce_op == "sum": - assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" - if self.norm_type != 0 and self.norm_type != 1: - return tensor**self.norm_type - return tensor - - def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: - if self.reduce_op == "sum": - assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" - if self.norm_type != 0 and self.norm_type != 1: - return tensor ** (1.0 / self.norm_type) - return tensor - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _NormPartial): - return False - return self.norm_type == other.norm_type - - def __hash__(self) -> int: - return 1 + hash(self.norm_type) - - -def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]: - if dims_arg is None: - return None - dims = cast(List[int], 
as_list(dims_arg)) - dims = cast(List[int], normalize_dims(dims, ndim)) - empty_dims = [[0], [-1], []] - if ndim == 0 and dims_arg in empty_dims: - return None - return dims - - -def _infer_reduce_dims_map( - reduction_dims: List[int], input_ndim: int, keep_dim=False -) -> List[int]: - reduction_dims_map = [] - new_dim_count = 0 - for input_dim in range(input_ndim): - if input_dim in reduction_dims and not keep_dim: - # if input dim in reduction dims, mark it as -1 - reduction_dims_map.append(-1) - else: - # otherwise mark it as the new dim - reduction_dims_map.append(new_dim_count) - new_dim_count += 1 - - return reduction_dims_map - - -def _replicate_dims_start_at( - placements: Sequence[Placement], start_dim: int = 0 -) -> Tuple[Placement, ...]: - new_placements: List[Placement] = [] - for p in placements: - if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): - new_placements.append(Replicate()) # make it replicate - else: - new_placements.append(p) # keep the placement - return tuple(new_placements) - - -# return new_placements which align with placements but skip the skipped_dim -def _skip_dim( - placements: Tuple[Placement, ...], skipped_dim: int -) -> Tuple[Placement, ...]: - new_placements: List[Placement] = [] - for p in placements: - if isinstance(p, Shard) and p.dim >= skipped_dim: - new_placements.append(Shard(p.dim - 1)) - else: - new_placements.append(p) - return tuple(new_placements) - - -def replicate_reduction_dims( - placements: Tuple[Placement, ...], reduction_dims: List[int] -) -> Tuple[Placement, ...]: - # replicate the reduction dims if not reduction_linear - new_placements: List[Placement] = [] - - for p in placements: - if p.is_partial(): - new_placements.append(Replicate()) - elif isinstance(p, Shard) and p.dim in reduction_dims: - new_placements.append(Replicate()) - else: - new_placements.append(p) - - return tuple(new_placements) - - -def map_placements_after_reduction( - placements: Tuple[Placement, ...], - 
reduction_dims: List[int], - reduction_dims_map: List[int], - reduction_op: ReductionOpType, -) -> Tuple[Placement, ...]: - """ - Map each placement based on the output shape after reduction. - """ - new_placements: List[Placement] = [] - for placement in placements: - if isinstance(placement, (Replicate, Partial)): - new_placements.append(placement) - else: - assert isinstance(placement, Shard) - shard_dim = placement.dim - new_shard_dim = reduction_dims_map[shard_dim] - if new_shard_dim == -1 or shard_dim in reduction_dims: - # if new_shard_dim collapsed or its in the reduction dims - # (i.e. for the case where keepdims=True), we generate partial - new_placements.append(get_placement_from_reduction_op(reduction_op)) - else: - new_placements.append(Shard(new_shard_dim)) - return tuple(new_placements) - - -def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: - if isinstance(reduction_op, NormReduction): - return _NormPartial(norm_type=reduction_op.norm_type) - return Partial(reduction_op) - - -def common_reduction_strategy( - mesh: DeviceMesh, - input_strategy: OpStrategy, - reduce_dims: List[int], - keep_dim: bool = False, - reduction_linear: bool = True, - reduction_op: ReductionOpType = "sum", -) -> OpStrategy: - """ - reduction_linear means that the reduction `f` follows this rule: - f([f(a), f(b)]) = f([a, b]) - - reduction linear should be super set of linearity. 
- """ - # by default follow reduction input strategy - reduction_strategy = OpStrategy([]) - - for strtg in input_strategy.strategies: - if not reduction_linear: - # input placements for this strategy should clear out pending sum and sharding - # on the reduction dimension - input_placements = replicate_reduction_dims( - strtg.output_spec.placements, reduce_dims - ) - else: - input_placements = strtg.output_spec.placements - - input_spec = DTensorSpec( - mesh=mesh, - placements=input_placements, - tensor_meta=strtg.output_spec.tensor_meta, - ) - - reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) - out_placements = map_placements_after_reduction( - input_spec.placements, reduce_dims, reduce_dims_map, reduction_op - ) - redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] - reduction_strategy.strategies.append( - PlacementStrategy( - output_specs=DTensorSpec( - mesh=mesh, - placements=out_placements, - ), - input_specs=(input_spec,), - redistribute_cost=redistribute_cost, - ) - ) - - return reduction_strategy - - -LINEAR_REDUCTION_OP_MAP = { - aten.all.default: "sum", - aten.all.dim: "sum", - aten.sum.default: "sum", - aten.sum.dim_IntList: "sum", - aten.prod.default: "product", - aten.prod.dim_int: "product", - aten.prod.int_out: "product", - aten.mean.default: "avg", - aten.mean.dim: "avg", - aten.mean.out: "avg", - aten.max.default: "max", - aten.max.dim: "max", - aten.max.out: "max", - aten.min.default: "min", - aten.min.dim: "min", - aten.min.out: "min", -} - - -@register_op_strategy( - list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1) -) -def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - args_schema = op_schema.args_schema - input_strategy = args_schema[0] - assert isinstance(input_strategy, OpStrategy) - dims = None - if len(op_schema.args_schema) > 1: - dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) - - reduce_dims = 
list(range(input_strategy.ndim)) if dims is None else dims - - keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) - reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] - return common_reduction_strategy( - mesh, - input_strategy, - reduce_dims, - keep_dim=keep_dim, - reduction_linear=True, - reduction_op=reduction_op, - ) - - -@register_op_strategy( - [aten.var.correction, aten.var.correction_out], - schema_info=RuntimeSchemaInfo(1, ["keepdim"]), -) -def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - args_schema = op_schema.args_schema - input_strategy = args_schema[0] - assert isinstance(input_strategy, OpStrategy) - dims = None - if len(op_schema.args_schema) > 1: - dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) - - reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims - - keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) - return common_reduction_strategy( - mesh, input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False - ) - - -@register_op_strategy( - [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1) -) -def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - args_schema = op_schema.args_schema - input_strategy = args_schema[0] - assert isinstance(input_strategy, OpStrategy) - norm_type = args_schema[1] if len(args_schema) > 1 else 2 - assert isinstance(norm_type, (int, float, str)), f"{norm_type}" - dim = args_schema[2] if len(args_schema) > 2 else None - keepdim = args_schema[3] if len(args_schema) > 3 else False - dims = _infer_reduction_dims(dim, input_strategy.ndim) - reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims - return common_reduction_strategy( - mesh, - input_strategy, - reduce_dims, - keep_dim=cast(bool, keepdim), - reduction_linear=True, - reduction_op=NormReduction(norm_type), - ) - - -@register_op_strategy( - [aten._foreach_norm.Scalar], 
schema_info=RuntimeSchemaInfo(1, needs_pytree=True) -) -def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy: - args_schema = op_schema.args_schema - input_tuple_strategy = args_schema[0] - assert isinstance(input_tuple_strategy, TupleStrategy) - norm_type = args_schema[1] if len(args_schema) > 1 else 2 - assert isinstance(norm_type, (int, float, str)), f"{norm_type}" - output_tuple_strategy_childs: List[OpStrategy] = [] - for op_strategy in input_tuple_strategy.childs: - assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" - reduce_dims = list(range(op_strategy.ndim)) - output_strategy = common_reduction_strategy( - mesh, - op_strategy, - reduce_dims, - reduction_linear=True, - reduction_op=NormReduction(norm_type), - ) - output_tuple_strategy_childs.append(output_strategy) - return TupleStrategy(output_tuple_strategy_childs) - - -@register_op_strategy([aten._linalg_svd.default], schema_info=RuntimeSchemaInfo(1)) -def linalg_svd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - # Since we do not have a simple way to compute a sharded SVD, always fall - # back to replicate - args_schema = op_schema.args_schema - input_strategy = args_schema[0] - assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" - output_strategies: List[PlacementStrategy] = [] - for placement_strategy in input_strategy.strategies: - replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) - replicate_spec = DTensorSpec( - mesh=mesh, - placements=replicate_placements, - tensor_meta=placement_strategy.output_spec.tensor_meta, - ) - redistribute_cost = [ - generate_redistribute_costs(input_strategy, replicate_spec) - ] - replicate_strategy = PlacementStrategy( - output_specs=replicate_spec, - input_specs=(replicate_spec,), - redistribute_cost=redistribute_cost, - ) - output_strategies.append(replicate_strategy) - return OpStrategy(output_strategies) - - -@register_op_strategy( - [aten._log_softmax.default, 
aten._softmax.default], schema_info=RuntimeSchemaInfo(1) -) -def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - input_strategy, softmax_dim, _ = op_schema.args_schema - input_strategy = cast(OpStrategy, input_strategy) - softmax_dim = cast(int, softmax_dim) - softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) - - output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - redistribute_costs = [] - input_src_spec = input_placement_strategy.output_spec - - # make sure input is replicated along the softmax dim - input_target_spec = DTensorSpec( - mesh=mesh, - placements=replicate_reduction_dims( - input_src_spec.placements, [softmax_dim] - ), - tensor_meta=input_src_spec.tensor_meta, - ) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_target_spec) - ) - output_target_spec = input_target_spec - output_strategy.strategies.append( - PlacementStrategy( - output_specs=output_target_spec, - input_specs=[input_target_spec], - redistribute_cost=redistribute_costs, - ) - ) - - return output_strategy - - -@register_op_strategy( - [ - aten._log_softmax_backward_data.default, - aten._softmax_backward_data.default, - ], - schema_info=RuntimeSchemaInfo(2), -) -def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema - grad_out_strategy = cast(OpStrategy, grad_out_strategy) - out_strategy = cast(OpStrategy, out_strategy) - softmax_dim = cast(int, softmax_dim) - softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim) - - grad_in_strategy = OpStrategy([]) - for grad_out_placement_strat, out_placement_strat in zip( - grad_out_strategy.strategies, out_strategy.strategies - ): - # follow the sharding of the grad_out or out depending on which has more shards - grad_out_src_spec = grad_out_placement_strat.output_spec - out_src_spec = 
out_placement_strat.output_spec - src_spec = ( - grad_out_src_spec - if grad_out_src_spec.num_shards >= out_src_spec.num_shards - else out_src_spec - ) - - # make sure inputs are replicated along the softmax dim - tgt_spec = DTensorSpec( - mesh=mesh, - placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), - ) - redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) - redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) - grad_in_strategy.strategies.append( - PlacementStrategy( - output_specs=tgt_spec, - redistribute_cost=[redist_grad_out_cost, redist_out_cost], - ) - ) - - return grad_in_strategy - - -@register_op_strategy( - [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default], - schema_info=RuntimeSchemaInfo(3), -) -def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - assert len(op_schema.args_schema) == 5 - ( - input_strategy, - target_strategy, - weight_strategy, - reduction, - _, - ) = op_schema.args_schema - input_strategy = cast(OpStrategy, input_strategy) - target_strategy = cast(OpStrategy, target_strategy) - reduction = cast(int, reduction) - - input_shape = input_strategy.shape - channel_dim = 1 if len(input_shape) >= 2 else 0 - - output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - op_args_target_specs = [] - redistribute_costs = [] - - # make sure input is replicated along the channel dim - input_src_spec = input_placement_strategy.output_spec - input_expected_spec = DTensorSpec( - mesh=mesh, - placements=replicate_reduction_dims( - input_src_spec.placements, [channel_dim] - ), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(input_expected_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_expected_spec) - ) - - # target doesn't have channel dim, and it follows input on other dims - target_src_spec = 
target_strategy.strategies[idx].output_spec - target_expected_spec = DTensorSpec( - mesh=mesh, - placements=_skip_dim(input_expected_spec.placements, channel_dim), - tensor_meta=target_src_spec.tensor_meta, - ) - op_args_target_specs.append(target_expected_spec) - redistribute_costs.append( - generate_redistribute_costs(target_strategy, target_expected_spec) - ) - - # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] - # make sure it is replicated - if weight_strategy is not None: - assert isinstance(weight_strategy, OpStrategy) - weight_src_spec = weight_strategy.strategies[idx].output_spec - weight_expected_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(weight_src_spec.placements), - tensor_meta=weight_src_spec.tensor_meta, - ) - op_args_target_specs.append(weight_expected_spec) - redistribute_costs.append( - generate_redistribute_costs(weight_strategy, weight_expected_spec) - ) - - if reduction == Reduction.NONE.value: - output_expected_spec = target_expected_spec - total_weight_expected_spec = DTensorSpec( - mesh=mesh, placements=tuple([Replicate()] * mesh.ndim) - ) - else: - if reduction == Reduction.MEAN.value: - reduction_op = "avg" - if not is_tensor_evenly_shardable( - target_expected_spec.shape, target_expected_spec - ): - raise ValueError( - "The intermediate results of nll_loss cannot be evenly sharded, \ - resulting in biased mean result." 
- ) - else: # reduction == Reduction.SUM.value: - reduction_op = "sum" - reduce_dims = list(range(target_expected_spec.ndim)) - reduce_dims_map = _infer_reduce_dims_map( - reduce_dims, target_expected_spec.ndim, keep_dim=False - ) - out_placements = map_placements_after_reduction( - target_expected_spec.placements, - reduce_dims, - reduce_dims_map, - reduction_op, - ) - output_expected_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - ) - - # whether reduction is sum or mean, the total weight has to be summed up if not replicated - total_weight_placements = map_placements_after_reduction( - target_expected_spec.placements, - reduce_dims, - reduce_dims_map, - "sum", - ) - total_weight_expected_spec = DTensorSpec( - mesh=mesh, - placements=total_weight_placements, - ) - - output_strategy.strategies.append( - PlacementStrategy( - output_specs=(output_expected_spec, total_weight_expected_spec), - input_specs=op_args_target_specs, - redistribute_cost=redistribute_costs, - ) - ) - - return output_strategy - - -@register_op_strategy( - [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default], - schema_info=RuntimeSchemaInfo(4), -) -def nll_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - assert len(op_schema.args_schema) == 7 - ( - grad_out_strategy, - input_strategy, - target_strategy, - weight_strategy, - reduction, - _, - total_weight_strategy, - ) = op_schema.args_schema - grad_out_strategy = cast(OpStrategy, grad_out_strategy) - input_strategy = cast(OpStrategy, input_strategy) - target_strategy = cast(OpStrategy, target_strategy) - reduction = cast(int, reduction) - total_weight_strategy = cast(OpStrategy, total_weight_strategy) - - input_shape = input_strategy.shape - channel_dim = 1 if len(input_shape) >= 2 else 0 - - grad_in_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - op_args_target_specs = [] - redistribute_costs = [] - - # make sure input is 
replicated along the channel dim - input_src_spec = input_placement_strategy.output_spec - input_expected_spec = DTensorSpec( - mesh=mesh, - placements=replicate_reduction_dims( - input_src_spec.placements, [channel_dim] - ), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(input_expected_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_expected_spec) - ) - - # target doesn't have channel dim, and it follows input on other dims - target_src_spec = target_strategy.strategies[idx].output_spec - target_expected_spec = DTensorSpec( - mesh=mesh, - placements=_skip_dim(input_expected_spec.placements, channel_dim), - tensor_meta=target_src_spec.tensor_meta, - ) - op_args_target_specs.append(target_expected_spec) - redistribute_costs.append( - generate_redistribute_costs(target_strategy, target_expected_spec) - ) - - # grad_out follows target if there is no reduction; - # otherwise, it should be a replicated scalar. - grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec - if reduction == Reduction.NONE.value: - grad_out_expected_spec = target_expected_spec - else: - grad_out_expected_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(grad_out_src_spec.placements), - tensor_meta=grad_out_src_spec.tensor_meta, - ) - op_args_target_specs.insert(0, grad_out_expected_spec) - redistribute_costs.insert( - 0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec) - ) - - # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] - # make sure it is replicated - if weight_strategy is not None: - assert isinstance(weight_strategy, OpStrategy) - weight_src_spec = weight_strategy.strategies[idx].output_spec - weight_expected_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(weight_src_spec.placements), - tensor_meta=weight_src_spec.tensor_meta, - ) - op_args_target_specs.append(weight_expected_spec) - redistribute_costs.append( - 
generate_redistribute_costs(weight_strategy, weight_expected_spec) - ) - - # total_weight should always be replicated - total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec - total_weight_expected_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(total_weight_src_spec.placements), - tensor_meta=total_weight_src_spec.tensor_meta, - ) - op_args_target_specs.append(total_weight_expected_spec) - redistribute_costs.append( - generate_redistribute_costs( - total_weight_strategy, total_weight_expected_spec - ) - ) - - grad_in_expected_spec = input_expected_spec - grad_in_strategy.strategies.append( - PlacementStrategy( - output_specs=grad_in_expected_spec, - input_specs=op_args_target_specs, - redistribute_cost=redistribute_costs, - ) - ) - - return grad_in_strategy - - -def rlog(msg): - rank = torch.distributed.get_rank() - if rank == 0: - print(msg) - - -@register_op_strategy( - [aten.native_layer_norm.default], - schema_info=RuntimeSchemaInfo(1), -) -def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - # args must be: input, normalized_shape, weight, bias, eps - # for None weight and bias, their corresponding objects will - # be None as well. layer_norm_strategy returns one OpStrategy - # for the triple return values (out, mean, rstd). 
- assert len(op_schema.args_schema) == 5 - ( - input_strategy, - normalized_shape, - weight_strategy, - bias_strategy, - _, - ) = op_schema.args_schema - - # the current layer norm implementation requires that all - # input DTensor's sharding must be in form of OpStrategy - assert isinstance(input_strategy, OpStrategy) - assert isinstance(normalized_shape, (int, Sequence, torch.Size)) - normalized_size = normalize_to_torch_size(normalized_shape) - - input_ndim = input_strategy.ndim - axis = input_ndim - len(normalized_size) - - # we use OpStrategy because the output (out, mean, rstd) - # should have the same placements - output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - op_args_target_specs = [] - input_src_spec = input_placement_strategy.output_spec - - # for the input tensor, we replicate it on the inner dims if necessary - # TODO: we can avoid forcing the redistribution once we figure out - # how to decompose layer norm - input_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(input_target_spec) - - if weight_strategy is not None: - assert isinstance(weight_strategy, OpStrategy) - weight_src_spec = weight_strategy.strategies[idx].output_spec - - # for the weight tensor, we replicate it on all dims if necessary - # TODO: we can avoid forcing the redistribution once we figure out - # how to decompose layer norm - weight_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(weight_src_spec.placements), - tensor_meta=weight_src_spec.tensor_meta, - ) - op_args_target_specs.append(weight_target_spec) - - if bias_strategy is not None: - assert isinstance(bias_strategy, OpStrategy) - bias_src_spec = bias_strategy.strategies[idx].output_spec - - # for the bias tensor, we replicate it on all dims if necessary - # TODO: we can avoid forcing the 
redistribution once we figure out - # how to decompose layer norm - bias_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(bias_src_spec.placements), - tensor_meta=bias_src_spec.tensor_meta, - ) - op_args_target_specs.append(bias_target_spec) - - # the output spec is the same as input spec - output_target_spec = input_target_spec - output_strategy.strategies.append( - PlacementStrategy( - output_specs=output_target_spec, - input_specs=op_args_target_specs, - redistribute_cost=[0], - ) - ) - - return output_strategy - - -@register_op_strategy( - [aten.native_layer_norm_backward.default], - schema_info=RuntimeSchemaInfo(2), -) -def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - # args must be: grad_out, input, normalized_shape, mean, rstd, - # weight, bias, output_mask. For None weight and bias, their - # corresponding objects will be None as well. - assert len(op_schema.args_schema) == 8 - ( - grad_out_strategy, - input_strategy, - normalized_shape, - mean_strategy, - rstd_strategy, - weight_strategy, - bias_strategy, - output_mask, - ) = op_schema.args_schema - - assert isinstance(grad_out_strategy, OpStrategy) - assert isinstance(input_strategy, OpStrategy) - assert isinstance(mean_strategy, OpStrategy) - assert isinstance(rstd_strategy, OpStrategy) - - assert isinstance(normalized_shape, (int, Sequence, torch.Size)) - normalized_size = normalize_to_torch_size(normalized_shape) - input_ndim = input_strategy.ndim - axis = input_ndim - len(normalized_size) - outer_dims = list(range(axis)) - - assert isinstance(output_mask, List) and len(output_mask) == 3 - - # output triple: (d_input, d_weight, d_bias) - out_tuple_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - # args for PlacementStrategy - output_specs_list: List[Optional[DTensorSpec]] = [] - op_args_target_specs = [] - redistribute_costs = [] - - input_src_spec = 
input_placement_strategy.output_spec - # arg: grad_out - # TODO: change the strategy to the following rule. - # d_input is basically a product of element-wise mul of - # grad_out, rstd, and normalized input, among which rstd - # and normalized input (x_hat) should have the same sharding - # placements, and grad_out's sharding is determined by the - # pointwise result of x_hat and weight/bias. - if output_mask[0]: - # TODO: now grad_out spec follows input spec. we may need - # to change it to apply a pointwise rule over grad_out, - # input, and weight. - grad_out_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(grad_out_target_spec) - redistribute_costs.append( - generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) - ) - output_specs_list.append(grad_out_target_spec) - else: - output_specs_list.append(None) - - # arg: input - input_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(input_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_target_spec) - ) - - # arg: mean, rstd - mean_src_spec = mean_strategy.strategies[idx].output_spec - op_args_target_specs.append(mean_src_spec) - redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) - rstd_src_spec = rstd_strategy.strategies[idx].output_spec - op_args_target_specs.append(rstd_src_spec) - redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) - - # arg: weight - # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) - if output_mask[1]: - assert isinstance(weight_strategy, OpStrategy) - weight_src_spec = weight_strategy.strategies[idx].output_spec - # no need to redistribute weight since they should be replicated - # in forward pass - 
op_args_target_specs.append(weight_src_spec) - redistribute_costs.append([0.0 for _ in weight_strategy.strategies]) - # TODO: now d_weight spec follows input spec w/ a reduction. - # we may need to change to a pointwise rule over grad_out and - # input, then apply a reduction. - inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, input_src_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - output_specs_list.append( - DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=weight_src_spec.tensor_meta, - ) - ) - else: - output_specs_list.append(None) - - # arg: bias - # d_bias = sum(grad_out, outer_dim, keepdim=False) - if output_mask[2]: - assert isinstance(bias_strategy, OpStrategy) - bias_src_spec = bias_strategy.strategies[idx].output_spec - # no need to redistribute weight since they should be replicated - # in forward pass - op_args_target_specs.append(bias_src_spec) - redistribute_costs.append([0.0 for _ in bias_strategy.strategies]) - # Currently we do not support the case where output_mask[0] is False while - # output_mask[1] is True. But it's easy to support that by accessing - # grad_out_spec via a local variable rather than the list. We just don't - # see the case. 
- grad_out_spec = output_specs_list[0] - assert isinstance(grad_out_spec, DTensorSpec) - # d_bias spec follows a reduction over grad_out - inp_placements = _replicate_dims_start_at(grad_out_spec.placements, axis) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, grad_out_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - output_specs_list.append( - DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=bias_src_spec.tensor_meta, - ) - ) - else: - output_specs_list.append(None) - - out_tuple_strategy.strategies.append( - PlacementStrategy( - output_specs=tuple(output_specs_list), - input_specs=op_args_target_specs, - redistribute_cost=redistribute_costs, - ) - ) - - return out_tuple_strategy - - -@register_op_strategy( - [aten.topk.default], - schema_info=RuntimeSchemaInfo(2), -) -def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - k = cast(int, op_schema.args_schema[1]) - input_shape = input_strategy.shape - topk_dim = ( - cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 - ) - topk_dim = normalize_dim(topk_dim, input_strategy.ndim) - - single_mesh_dim_strategies = [] - - # two outputs (values, indices), 1 input - # replicate always works - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # every dim except topk dim should work - for dim in range(input_strategy.ndim): - if dim != topk_dim: - dim_shardings: PlacementList = [Shard(dim)] * 3 - single_mesh_dim_strategies.append(dim_shardings) - # TODO: topk on sharded dim requries non-trival reduction, address it later - - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=2 - ) diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py deleted file mode 100644 
index 77484de7d..000000000 --- a/src/chop/distributed/tensor/ops/matrix_ops.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# implement matrix related ops for distributed tensor - -import torch -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementList, - PlacementStrategy, -) -from torch.distributed._tensor.ops.utils import ( - expand_to_full_mesh_op_strategy, - generate_redistribute_costs, - infer_broadcast_dims_map, - is_tensor_shardable, - map_placements_after_broadcast, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh -from chop.distributed.tensor.ops.basic_strategy import gen_einsum_strategies -from chop.distributed.tensor.ops.utils import register_op_strategy - -aten = torch.ops.aten - - -@register_op_strategy(aten.t.default) -def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - self_strategy = op_schema.args_schema[0] - assert isinstance(self_strategy, OpStrategy) - - transpose_strategies = [] - for input_strategy in self_strategy.strategies: - input_spec = input_strategy.output_spec - # follow the input spec but transpose the Shard placements - output_placements = [ - Shard(1 - p.dim) if isinstance(p, Shard) else p - for p in input_spec.placements - ] - transpose_strategy = PlacementStrategy( - output_specs=DTensorSpec( - mesh=input_strategy.output_spec.mesh, - placements=tuple(output_placements), - ), - input_specs=(input_strategy.output_spec,), - ) - transpose_strategies.append(transpose_strategy) - - return OpStrategy(strategies=transpose_strategies) - - -def _mm_like_strategy( - mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - self_strategy, mat2_strategy = op_schema.args_schema - assert isinstance(self_strategy, OpStrategy) - assert isinstance(mat2_strategy, OpStrategy) - # generate all possible strategies 
for mm - mm_strategy = gen_einsum_strategies(mm_equation, mesh) - # filter out invalid strategies and associate costs - strategies = mm_strategy.strategies - filtered_strategies = [] - for strtg in strategies: - assert strtg.input_specs is not None - self_spec = strtg.input_specs[0] - mat2_spec = strtg.input_specs[1] - if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( - mat2_strategy.shape, mat2_spec - ): - redistribute_cost = [ - generate_redistribute_costs(self_strategy, self_spec), - generate_redistribute_costs(mat2_strategy, mat2_spec), - ] - strtg.redistribute_cost = redistribute_cost - filtered_strategies.append(strtg) - - mm_strategy.strategies = filtered_strategies - - return mm_strategy - - -def _addmm_like_strategy( - mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema - assert isinstance(self_strategy, OpStrategy) - assert isinstance(mat1_strategy, OpStrategy) - assert isinstance(mat2_strategy, OpStrategy) - self_shape = self_strategy.shape - mm_out_shape = torch.Size( - [ - mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size - for i, dim_size in enumerate(mat1_strategy.shape) - ] - ) - # generate all possible strategies for mm - mm_strategy = gen_einsum_strategies(mm_equation, mesh) - # filter out invalid strategies and associate costs - strategies = mm_strategy.strategies - filtered_strategies = [] - for strtg in strategies: - # construct new strategy by consider the self arg - assert strtg.input_specs is not None - mat1_spec = strtg.input_specs[0] - mat2_spec = strtg.input_specs[1] - out_spec = strtg.output_spec - - # self arg's spec should follow the output of mm, but need - # to consider broadcast for the self arg - broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) - self_placements = map_placements_after_broadcast( - out_spec.placements, mm_out_shape, broadcast_dims_map - ) - 
self_spec = DTensorSpec(mesh=mesh, placements=self_placements) - - if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable( - mat2_strategy.shape, mat2_spec - ): - # update input specs with new self spec - strtg.input_specs = (self_spec, mat1_spec, mat2_spec) - - # associate costs - redistribute_cost = [ - generate_redistribute_costs(self_strategy, self_spec), - generate_redistribute_costs(mat1_strategy, mat1_spec), - generate_redistribute_costs(mat2_strategy, mat2_spec), - ] - strtg.redistribute_cost = redistribute_cost - filtered_strategies.append(strtg) - - mm_strategy.strategies = filtered_strategies - - return mm_strategy - - -@register_op_strategy(aten.mm.default) -def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - return _mm_like_strategy("mk,kn->mn", mesh, op_schema) - - -@register_op_strategy(aten.addmm.default) -def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - return _addmm_like_strategy("mk,kn->mn", mesh, op_schema) - - -@register_op_strategy(aten.bmm.default) -def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema) - - -@register_op_strategy(aten.baddbmm.default) -def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) - - -@register_op_strategy(aten._scaled_dot_product_flash_attention.default) -def scaled_dot_product_flash_attention_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - # NOTE: currently we only support some simple strategies to support tensor parallelism - # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation - # as it involves: matmul, pointwise, reduction ops together. 
- return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] - q_input_strategy = op_schema.args_schema[0] - assert isinstance(q_input_strategy, OpStrategy) - # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape - - single_mesh_dim_strategies = [] - - # placement list stores placements of [outputs, inputs] - # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs - # first we can always accept full replication for both inputs and outputs - all_replicate: PlacementList = [ - Replicate(), - Replicate(), - None, # cum_seq_q - None, # cum_seq_k - None, # max_q - None, # max_k - None, # philox_seed - None, # philox_offset - Replicate(), - Replicate(), - Replicate(), - Replicate(), - ] - single_mesh_dim_strategies.append(all_replicate) - - # second we can accept the sharding pattern of tensor parallelism, which - # shard on the num of head dim - qkv_sharding = Shard(1) # num head dim - output_sharding = Shard(1) # num head dim - logsumexp_sharding = Shard(1) # num head dim - if return_debug_mask: - debug_attn_mask_sharding: Placement = Shard(1) # num head dim - else: - # empty debug mask, replicated - debug_attn_mask_sharding = Replicate() - - num_heads_dim_sharding: PlacementList = [ - output_sharding, - logsumexp_sharding, - None, # cum_seq_q - None, # cum_seq_k - None, # max_q - None, # max_k - None, # philox_seed - None, # philox_offset - debug_attn_mask_sharding, - qkv_sharding, - qkv_sharding, - qkv_sharding, - ] - single_mesh_dim_strategies.append(num_heads_dim_sharding) - - # Context Parallelism: shards on the sequence dim - single_mesh_dim_strategies.append( - [ - Shard(2), # output - Shard(2), # logsumexp - None, # cum_seq_q - None, # cum_seq_k - None, # max_q - None, # max_k - None, # philox_seed - None, # philox_offset - Shard(2), # debugattn - Shard(2), # q - Shard(2), # k - Shard(2), # v - ] - ) - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=9 - ) 
- - -@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) -def scaled_dot_product_flash_attention_backward_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - q_input_strategy = op_schema.args_schema[1] - assert isinstance(q_input_strategy, OpStrategy) - # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape - - tensor_input_indices = [ - i - for i, arg_spec in enumerate(op_schema.args_schema) - if isinstance(arg_spec, OpStrategy) - ] - num_tensor_inputs = len(tensor_input_indices) - - single_mesh_dim_strategies = [] - - # placement list stores placements of [outputs, inputs] - # in the spda backward case, we have 3 tensor outputs and 6 to 10 tensor inputs - # first we can always accept full replication for both inputs and outputs - all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs) - - single_mesh_dim_strategies.append(all_replicate) - - # second we can accept the sharding pattern of tensor parallelism, which - # shard on the num of head dim - grad_output_sharding = Shard(1) # num head dim - qkv_sharding = Shard(1) # num head dim - output_sharding = Shard(1) # num head dim - logsumexp_sharding = Shard(1) # num head dim - grad_qkv_sharding = Shard(1) # num head dim - - num_heads_dim_sharding: PlacementList = [ - grad_qkv_sharding, - grad_qkv_sharding, - grad_qkv_sharding, - grad_output_sharding, - qkv_sharding, - qkv_sharding, - qkv_sharding, - output_sharding, - logsumexp_sharding, - ] - # accept replicate on the rest tensor inputs, potentially - # cum_seq_q, cum_seq_k, philox_seed, philox_offset - # at indices 6, 7, 12, 13, respectively - num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) - single_mesh_dim_strategies.append(num_heads_dim_sharding) - - # Context Parallelism: shards on the sequence dim - seq_dim_sharding: PlacementList = [ - Shard(2), # grad_q - Shard(2), # grad_k - Shard(2), # grad_v - Shard(2), # grad_output - Shard(2), # q - Shard(2), # k - 
Shard(2), # v - Shard(2), # output - Shard(2), # logsumexp - ] - # accept replicate on the rest tensor inputs, potentially - # cum_seq_q, cum_seq_k, philox_seed, philox_offset - # at indices 6, 7, 12, 13, respectively - seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) - single_mesh_dim_strategies.append(seq_dim_sharding) - - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=3 - ) - - -@register_op_strategy(aten.constant_pad_nd.default) -def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: - # TODO(d4l3k); implement a more correct strategy for constant_pad_nd - return OpStrategy( - [ - PlacementStrategy( - output_specs=DTensorSpec(mesh, (Replicate(),)), - input_specs=( - DTensorSpec(mesh, (Replicate(),)), - DTensorSpec(mesh, (Replicate(),)), - ), - redistribute_cost=[[1]], - ) - ] - ) - - -@register_op_strategy(aten._scaled_dot_product_efficient_attention.default) -def scaled_dot_product_efficient_attention_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - # NOTE: currently we only support some simple strategies to support tensor parallelism - q_input_strategy = op_schema.args_schema[0] - assert isinstance(q_input_strategy, OpStrategy) - # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape - has_attn_bias = op_schema.args_schema[3] is not None - compute_log_sumexp = op_schema.args_schema[4] - - single_mesh_dim_strategies = [] - - # placement list stores placements of [outputs, inputs] - # in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs - # first we can always accept full replication for both inputs and outputs - all_replicate: PlacementList = [ - Replicate(), - Replicate(), - None, - None, - Replicate(), - Replicate(), - Replicate(), - ] - if has_attn_bias: - all_replicate.append(Replicate()) # attn bias - single_mesh_dim_strategies.append(all_replicate) - - # second we can accept the sharding pattern of 
tensor parallelism, which - # shard on the heads dimension - qkv_sharding = Shard(1) - output_sharding = Shard(1) - if compute_log_sumexp: - logsumexp_sharding: Placement = Shard(1) - else: - # empty logsumexp, replicated - logsumexp_sharding = Replicate() - - num_heads_dim_sharding = [ - output_sharding, - logsumexp_sharding, - None, - None, - qkv_sharding, - qkv_sharding, - qkv_sharding, - ] - if has_attn_bias: - num_heads_dim_sharding.append(Shard(1)) - single_mesh_dim_strategies.append(num_heads_dim_sharding) - - return expand_to_full_mesh_op_strategy( - mesh, - op_schema, - single_mesh_dim_strategies, - input_index=4, - ) - - -@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) -def scaled_dot_product_efficient_attention_backward_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> OpStrategy: - q_input_strategy = op_schema.args_schema[1] - assert isinstance(q_input_strategy, OpStrategy) - # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape - has_attn_bias = op_schema.args_schema[4] is not None - - tensor_input_indices = [ - i - for i, arg_spec in enumerate(op_schema.args_schema) - if isinstance(arg_spec, OpStrategy) - ] - - single_mesh_dim_strategies = [] - - # placement list stores placements of [outputs, inputs] - # in the spda backward case, we have 4 tensor outputs and 8 or 9 tensor inputs - # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present; - # otherwise grad_bias will be empty and its DTensorSpec will be removed. 
- # first we can always accept full replication for both inputs and outputs - all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias) - - if not has_attn_bias: - all_replicate[3] = None # grad bias is None if attn_bias is not present - - single_mesh_dim_strategies.append(all_replicate) - - # second we can accept the sharding pattern of tensor parallelism, which - # shard on the heads dimension - grad_output_sharding = Shard(1) - qkv_sharding = Shard(1) - output_sharding = Shard(1) - logsumexp_sharding = Shard(1) - grad_qkv_sharding = Shard(1) - grad_bias_sharding = Shard(1) if has_attn_bias else None - - num_heads_dim_sharding: PlacementList = [ - grad_qkv_sharding, - grad_qkv_sharding, - grad_qkv_sharding, - grad_bias_sharding, - grad_output_sharding, - qkv_sharding, - qkv_sharding, - qkv_sharding, - # the place for optional input attn_bias, - output_sharding, - logsumexp_sharding, - ] - # input sharding of attn_bias on heads dim if present - if has_attn_bias: - num_heads_dim_sharding.insert(8, Shard(1)) - # accept replicate on the rest scalar tensor inputs - # namely philox_seed and philox_offset - num_heads_dim_sharding.extend([Replicate(), Replicate()]) - single_mesh_dim_strategies.append(num_heads_dim_sharding) - - return expand_to_full_mesh_op_strategy( - mesh, - op_schema, - single_mesh_dim_strategies, - input_index=4, - ) diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py deleted file mode 100644 index f69a31577..000000000 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ /dev/null @@ -1,642 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -from typing import List, Sequence, Tuple - -import torch -from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, - OpSchema, - OpStrategy, - PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, - TupleStrategy, -) -from torch.distributed._tensor.ops.utils import ( - infer_broadcast_dims_map, - map_placements_after_broadcast, - normalize_dim, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh -from chop.distributed.tensor.ops.utils import register_op_strategy - -aten = torch.ops.aten - -linear_pointwise_ops = [ - aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. - aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. - aten.to.dtype, - aten.add.Tensor, - aten.add_.Tensor, -] - - -pointwise_ops = [ - # please keep the entries below alphabetically sorted - aten.__ilshift__.Scalar, - aten.__ilshift__.Tensor, - aten.__irshift__.Scalar, - aten.__irshift__.Tensor, - aten.__lshift__.Scalar, - aten.__lshift__.Tensor, - aten.__rshift__.Scalar, - aten.__rshift__.Tensor, - aten._conj.default, - aten.abs.default, - aten.abs.out, - aten.abs_.default, - aten.acos.default, - aten.acos.out, - aten.acos_.default, - aten.acosh.default, - aten.acosh.out, - aten.acosh_.default, - aten.add.Scalar, - aten.add.out, - aten.add_.Scalar, - aten.addcdiv.default, - aten.addcdiv.out, - aten.addcdiv_.default, - aten.addcmul.default, - aten.addcmul.out, - aten.addcmul_.default, - aten.angle.default, - aten.angle.out, - aten.asin.default, - aten.asin.out, - aten.asin_.default, - aten.asinh.default, - aten.asinh.out, - aten.asinh_.default, - aten.atan.default, - aten.atan.out, - aten.atan2.default, - aten.atan2.out, - aten.atan2_.default, - aten.atan_.default, - aten.atanh.default, 
- aten.atanh.out, - aten.atanh_.default, - aten.bitwise_and.Scalar, - aten.bitwise_and.Scalar_Tensor, - aten.bitwise_and.Scalar_out, - aten.bitwise_and.Tensor, - aten.bitwise_and.Tensor_out, - aten.bitwise_and_.Scalar, - aten.bitwise_and_.Tensor, - aten.bitwise_left_shift.Scalar_Tensor, - aten.bitwise_left_shift.Tensor, - aten.bitwise_left_shift.Tensor_Scalar, - aten.bitwise_left_shift.Tensor_Scalar_out, - aten.bitwise_left_shift.Tensor_out, - aten.bitwise_left_shift_.Tensor, - aten.bitwise_left_shift_.Tensor_Scalar, - aten.bitwise_not.default, - aten.bitwise_not.out, - aten.bitwise_not_.default, - aten.bitwise_or.Scalar, - aten.bitwise_or.Scalar_Tensor, - aten.bitwise_or.Scalar_out, - aten.bitwise_or.Tensor, - aten.bitwise_or.Tensor_out, - aten.bitwise_or_.Scalar, - aten.bitwise_or_.Tensor, - aten.bitwise_right_shift.Scalar_Tensor, - aten.bitwise_right_shift.Tensor, - aten.bitwise_right_shift.Tensor_Scalar, - aten.bitwise_right_shift.Tensor_Scalar_out, - aten.bitwise_right_shift.Tensor_out, - aten.bitwise_right_shift_.Tensor, - aten.bitwise_right_shift_.Tensor_Scalar, - aten.bitwise_xor.Scalar, - aten.bitwise_xor.Scalar_Tensor, - aten.bitwise_xor.Scalar_out, - aten.bitwise_xor.Tensor, - aten.bitwise_xor.Tensor_out, - aten.bitwise_xor_.Scalar, - aten.bitwise_xor_.Tensor, - aten.ceil.default, - aten.ceil.out, - aten.ceil_.default, - aten.clamp.default, - aten.clamp.out, - aten.clamp_.default, - aten.clip.default, - aten.clip.out, - aten.clip_.default, - aten.conj_physical.default, - aten.conj_physical.out, - aten.conj_physical_.default, - aten.copysign.Scalar, - aten.copysign.Scalar_out, - aten.copysign.Tensor, - aten.copysign.out, - aten.copysign_.Scalar, - aten.copysign_.Tensor, - aten.cos.default, - aten.cos.out, - aten.cos_.default, - aten.cosh.default, - aten.cosh.out, - aten.cosh_.default, - aten.deg2rad.default, - aten.deg2rad.out, - aten.deg2rad_.default, - aten.digamma.default, - aten.digamma.out, - aten.digamma_.default, - aten.div.Tensor, - 
aten.div.Tensor_mode, - aten.div.out, - aten.div.out_mode, - aten.div_.Tensor, - aten.div_.Tensor_mode, - aten.eq.Tensor, - aten.eq.Tensor_out, - aten.eq.Scalar, - aten.eq.Scalar_out, - aten.erf.default, - aten.erf.out, - aten.erf_.default, - aten.erfc.default, - aten.erfc.out, - aten.erfc_.default, - aten.erfinv.default, - aten.erfinv.out, - aten.erfinv_.default, - aten.exp.default, - aten.exp.out, - aten.exp2.default, - aten.exp2.out, - aten.exp2_.default, - aten.exp_.default, - aten.expm1.default, - aten.expm1.out, - aten.expm1_.default, - aten.float_power.Scalar, - aten.float_power.Scalar_out, - aten.float_power.Tensor_Scalar, - aten.float_power.Tensor_Scalar_out, - aten.float_power.Tensor_Tensor, - aten.float_power.Tensor_Tensor_out, - aten.float_power_.Scalar, - aten.float_power_.Tensor, - aten.floor.default, - aten.floor.out, - aten.floor_.default, - aten.fmod.Scalar, - aten.fmod.Scalar_out, - aten.fmod.Tensor, - aten.fmod.Tensor_out, - aten.fmod_.Scalar, - aten.fmod_.Tensor, - aten.frac.default, - aten.frac.out, - aten.frac_.default, - aten.ge.Scalar, - aten.ge.Tensor, - aten.gelu.default, - aten.gt.Tensor, - aten.gt.Tensor_out, - aten.gt.Scalar, - aten.gt.Scalar_out, - aten.gt.Scalar, - aten.gt.Tensor, - aten.hypot.default, - aten.hypot.out, - aten.hypot_.default, - aten.i0.default, - aten.i0.out, - aten.i0_.default, - aten.igamma.default, - aten.igamma.out, - aten.igamma_.default, - aten.igammac.default, - aten.igammac.out, - aten.igammac_.default, - aten.isnan.default, - aten.ldexp.default, - aten.ldexp.out, - aten.ldexp_.default, - aten.lt.Tensor, - aten.lt.Tensor_out, - aten.lt.Scalar, - aten.lt.Scalar_out, - aten.le.Scalar, - aten.le.Tensor, - aten.lerp.Scalar, - aten.lerp.Scalar_out, - aten.lerp.Tensor, - aten.lerp.Tensor_out, - aten.lerp_.Scalar, - aten.lerp_.Tensor, - aten.lgamma.default, - aten.lgamma.out, - aten.lgamma_.default, - aten.log.default, - aten.log.out, - aten.log10.default, - aten.log10.out, - aten.log10_.default, - 
aten.log1p.default, - aten.log1p.out, - aten.log1p_.default, - aten.log2.default, - aten.log2.out, - aten.log2_.default, - aten.log_.default, - aten.logaddexp.default, - aten.logaddexp.out, - aten.logaddexp2.default, - aten.logaddexp2.out, - aten.logical_and.default, - aten.logical_and.out, - aten.logical_and_.default, - aten.logical_not.default, - aten.logical_not.out, - aten.logical_not_.default, - aten.logical_or.default, - aten.logical_or.out, - aten.logical_or_.default, - aten.logical_xor.default, - aten.logical_xor.out, - aten.logical_xor_.default, - aten.logit.default, - aten.logit.out, - aten.logit_.default, - aten.masked_fill.Scalar, - aten.maximum.out, - aten.mul.Scalar, - aten.mul.Tensor, - aten.mul.out, - aten.mul_.Scalar, - aten.mul_.Tensor, - aten.mvlgamma.default, - aten.mvlgamma.out, - aten.mvlgamma_.default, - aten.native_dropout_backward.default, - aten.native_dropout_backward.out, - aten.nan_to_num.default, - aten.nan_to_num.out, - aten.nan_to_num_.default, - aten.ne.Scalar, - aten.neg.default, - aten.neg.out, - aten.neg_.default, - aten.nextafter.default, - aten.nextafter.out, - aten.nextafter_.default, - aten.polygamma.default, - aten.polygamma.out, - aten.polygamma_.default, - aten.positive.default, - aten.pow.Scalar, - aten.pow.Scalar_out, - aten.pow.Tensor_Scalar, - aten.pow.Tensor_Scalar_out, - aten.pow.Tensor_Tensor, - aten.pow.Tensor_Tensor_out, - aten.pow_.Scalar, - aten.pow_.Tensor, - aten.reciprocal.default, - aten.reciprocal.out, - aten.reciprocal_.default, - aten.rad2deg.default, - aten.rad2deg.out, - aten.rad2deg_.default, - aten.relu.default, - aten.relu_.default, - aten.remainder.Scalar, - aten.remainder.Scalar_Tensor, - aten.remainder.Scalar_out, - aten.remainder.Tensor, - aten.remainder.Tensor_out, - aten.remainder_.Scalar, - aten.remainder_.Tensor, - aten.round.decimals, - aten.round.decimals_out, - aten.round.default, - aten.round.out, - aten.round_.decimals, - aten.round_.default, - aten.rsqrt.default, - aten.rsqrt.out, - 
aten.rsqrt_.default, - aten.rsub.Scalar, - aten.sgn.default, - aten.sgn.out, - aten.sgn_.default, - aten.sigmoid.default, - aten.sigmoid.out, - aten.sigmoid_.default, - aten.sign.default, - aten.sign.out, - aten.sign_.default, - aten.signbit.default, - aten.signbit.out, - aten.silu.default, - aten.silu.out, - aten.sin.default, - aten.sin.out, - aten.sin_.default, - aten.sinc.default, - aten.sinc.out, - aten.sinc_.default, - aten.sinh.default, - aten.sinh.out, - aten.sinh_.default, - aten.sqrt.default, - aten.sqrt.out, - aten.sqrt_.default, - aten.square.default, - aten.square.out, - aten.square_.default, - aten.sub.Scalar, - aten.sub.Tensor, - aten.sub.out, - aten.sub_.Scalar, - aten.sub_.Tensor, - aten.tan.default, - aten.tan.out, - aten.tan_.default, - aten.tanh.default, - aten.tanh.out, - aten.tanh_.default, - aten.true_divide.Tensor, - aten.trunc.default, - aten.trunc.out, - aten.trunc_.default, - aten.where.self, - aten.where.self_out, - aten.xlogy.OutScalar_Self, - aten.xlogy.OutScalar_Other, - aten.xlogy.OutTensor, - aten.xlogy.Scalar_Other, - aten.xlogy.Scalar_Self, - aten.xlogy.Tensor, - aten.xlogy_.Scalar_Other, - aten.xlogy_.Tensor, - # backward point-wise ops - # please keep the entries below alphabetically sorted - aten.gelu_backward.default, - aten.sigmoid_backward.default, - aten.silu_backward.default, - aten.tanh_backward.default, - aten.threshold_backward.default, -] - - -def pointwise_strategy( - mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False -) -> OpStrategy: - max_shards_strategy_index = -1 - max_shards = -1 - - if _is_inplace_op(op_schema.op): - # inplace op should follow the first arg strategy - followed_strategy = op_schema.args_schema[0] - elif _is_out_variant_op(op_schema.op): - # out variant op should follow the out kwarg strategy - followed_strategy = op_schema.kwargs_schema["out"] - else: - # normal pointwise op, we choose to follow the arg with - # the max shards in case operands needs reshard - for idx, arg_strategy in 
enumerate(op_schema.args_schema): - if not isinstance(arg_strategy, OpStrategy): - continue - - arg_max_shards = arg_strategy.max_num_shards() - if arg_max_shards > max_shards: - max_shards_strategy_index = idx - max_shards = arg_max_shards - - followed_strategy = op_schema.args_schema[max_shards_strategy_index] - - assert isinstance( - followed_strategy, OpStrategy - ), f"no strategy to follow for {op_schema}!" - return common_pointwise_strategy( - mesh, op_schema.args_schema, followed_strategy, linearity - ) - - -def common_pointwise_strategy( - mesh: DeviceMesh, - args_schema: Sequence[object], - followed_strategy: OpStrategy, - linearity: bool, -) -> OpStrategy: - # handle broadcasting - common_shape = torch.broadcast_shapes( - *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] - ) - pointwise_strategy = OpStrategy([]) - - for placement_strategy in followed_strategy.strategies: - spec_to_follow = placement_strategy.output_spec - out_placements: List[Placement] = [] - for placement in spec_to_follow.placements: - if isinstance(placement, Shard): - shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) - common_ndim = len(common_shape) - new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim - out_placements.append(Shard(new_shard_dim)) - elif isinstance(placement, Partial) and not linearity: - # clear the partial placemnet if op does not support linearity - # by default we just replicate the partial, need to see if this - # is optimal for all cases - out_placements.append(Replicate()) - else: - out_placements.append(placement) - - input_specs: List[DTensorSpec] = [] - for input_arg in args_schema: - if isinstance(input_arg, OpStrategy): - # every arg follow the out_placements, but need to handle broadcasting - input_arg_spec = input_arg.strategies[0].output_spec - input_arg_dims_map = infer_broadcast_dims_map( - common_shape, input_arg_spec.shape - ) - input_target_placements = map_placements_after_broadcast( - 
tuple(out_placements), - common_shape, - input_arg_dims_map, - ) - input_arg_target_spec = DTensorSpec( - mesh=mesh, - placements=input_target_placements, - tensor_meta=input_arg_spec.tensor_meta, - ) - input_specs.append(input_arg_target_spec) - - pointwise_strategy.strategies.append( - PlacementStrategy( - output_specs=DTensorSpec( - mesh=mesh, - placements=tuple(out_placements), - ), - input_specs=input_specs, - redistribute_cost=[], - ) - ) - return pointwise_strategy - - -def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """ - Linear pointwise operators can propagate pending reductions. - For example, c = add(a, b); if a is pending sum, then c will be - pending sum as well without any communication overhead. - """ - return pointwise_strategy(mesh, op_schema, linearity=True) - - -for op in linear_pointwise_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( - linear_pointwise_strategy - ) - -for op in pointwise_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( - pointwise_strategy - ) - - -# TODO: add all for_each ops -for_each_ops = [ - aten._foreach_abs.default, - aten._foreach_abs_.default, - aten._foreach_addcdiv_.Scalar, - aten._foreach_addcdiv_.ScalarList, - aten._foreach_addcdiv_.Tensor, - aten._foreach_addcmul.Scalar, - aten._foreach_addcmul_.Scalar, - aten._foreach_addcmul_.ScalarList, - aten._foreach_addcmul_.Tensor, - aten._foreach_clamp_max_.Scalar, - aten._foreach_clamp_min_.Scalar, - aten._foreach_div_.List, - aten._foreach_div_.ScalarList, - aten._foreach_lerp_.Scalar, - aten._foreach_maximum_.List, - aten._foreach_mul.Scalar, - aten._foreach_mul.List, - aten._foreach_mul_.Scalar, - aten._foreach_mul_.ScalarList, - aten._foreach_mul_.Tensor, - aten._foreach_mul_.List, - aten._foreach_neg.default, - aten._foreach_neg_.default, - aten._foreach_reciprocal_.default, - aten._foreach_sub.List, - aten._foreach_sub_.Scalar, - 
aten._foreach_sqrt.default, - aten._foreach_sqrt_.default, - aten._foreach_zero_.default, -] - -for_each_linearity_ops = [ - aten._foreach_add.Scalar, - aten._foreach_add_.Scalar, - aten._foreach_add_.ScalarList, - aten._foreach_add.List, - aten._foreach_add_.List, -] - - -def list_pointwise_strategy( - mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False -) -> StrategyType: - """ - Apply the pointwise strategy to the zipped arguments. For example, if we - run a foreach add of two lists l1 and l2, then we apply the pointwise - strategy on each pair (l1[i], l2[i]). If the first argument is a list but - the second (or later) one is a tensor, then we broadcast the tensor by - replicating it into a list with the length of the first argument. - - Args: - mesh (DeviceMesh): device mesh for pointwise ops - op_schema (OpSchema): schema of the operator to generate strategy for - linearity (bool): specify whether op(a) + op(b) = op(a + b) - - Returns: - OpStrategy: generated strategy - """ - - def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]: - first_arg = args_schema[0] - assert isinstance(first_arg, TupleStrategy) - strategy_len = len(first_arg.childs) - tuple_strategies: List[TupleStrategy] = [] - for arg_idx, arg in enumerate(args_schema): - if isinstance(arg, TupleStrategy): - # every tuple strategy should have the same length - assert len(arg.childs) == strategy_len - tuple_strategies.append(arg) - elif isinstance(arg, OpStrategy): - if arg_idx > 0: # implicitly broadcast - tuple_strategies.append( - TupleStrategy([arg for _ in range(strategy_len)]) - ) - else: - raise RuntimeError( - f"list op only supports tuple strategy! 
{op_schema}" - ) - return tuple_strategies - - args_strategies = args_tuple_strategies(op_schema.args_schema) - follow_strategy: TupleStrategy = args_strategies[0] - list_strategy: List[OpStrategy] = [] - for child_idx, child_strtgy in enumerate(follow_strategy.childs): - assert isinstance(child_strtgy, OpStrategy) - args_schema: List[StrategyType] = [ - arg_strategy.childs[child_idx] for arg_strategy in args_strategies - ] - pointwise_strategy: OpStrategy = common_pointwise_strategy( - mesh, args_schema, child_strtgy, linearity - ) - list_strategy.append(pointwise_strategy) - return TupleStrategy(list_strategy) - - -def list_linear_pointwise_strategy( - mesh: DeviceMesh, op_schema: OpSchema -) -> StrategyType: - """ - for each list op stratgy that supports linearity - """ - return list_pointwise_strategy(mesh, op_schema, linearity=True) - - -for op in for_each_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_pointwise_strategy - ) - -for op in for_each_linearity_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_linear_pointwise_strategy - ) - -fused_ops = [ - aten._fused_adam_.default, - aten._fused_adam.default, - aten._fused_adam.tensor_lr, - aten._fused_adam_.tensor_lr, - aten._fused_adamw_.default, - aten._fused_adamw.default, - aten._fused_adamw.tensor_lr, - aten._fused_adamw_.tensor_lr, -] - -for op in fused_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_pointwise_strategy - ) diff --git a/src/chop/distributed/tensor/ops/random_ops.py b/src/chop/distributed/tensor/ops/random_ops.py deleted file mode 100644 index 7eefa30fc..000000000 --- a/src/chop/distributed/tensor/ops/random_ops.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -import torch -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) -from torch.distributed._tensor.ops.utils import is_tensor_partial -from chop.distributed.tensor.ops.utils import register_op_strategy -from torch.distributed.device_mesh import DeviceMesh - - -aten = torch.ops.aten - - -@register_op_strategy( - [ - aten.normal_.default, - aten.uniform_.default, - aten.native_dropout.default, - aten.bernoulli_.float, - aten.bernoulli.default, - ] -) -def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - self_strategy = op_schema.args_schema[0] - assert isinstance(self_strategy, OpStrategy) - - random_strategy = OpStrategy([]) - for arg_strategy in self_strategy.strategies: - arg_spec = arg_strategy.output_spec - if is_tensor_partial(arg_spec): - # TODO: figure out how inplace random op should behave when it's partial - raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") - random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec)) - - return random_strategy diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py deleted file mode 100644 index 6147a16fc..000000000 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ /dev/null @@ -1,797 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. 
and affiliates -from typing import cast, List, Optional, Sequence, Tuple - -import torch -from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - OpSchema, - OpStrategy, - OutputSharding, - PlacementList, - PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, - TupleStrategy, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh - -from chop.distributed.tensor.ops.utils import register_op_strategy -from chop.distributed.tensor.ops.common_rules import pointwise_rule -from chop.distributed.tensor.ops.embedding_ops import _MaskPartial -from chop.distributed.tensor.ops.utils import ( - expand_to_full_mesh_op_strategy, - is_tensor_dim_sharded, - is_tensor_evenly_shardable, - is_tensor_partial, - normalize_dim, - register_prop_rule, -) - -aten = torch.ops.aten - - -def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - # Default strategy by default just propagate the first input strategy - select_strategy = op_schema.args_schema[0] - assert isinstance(select_strategy, OpStrategy) - default_strategy = [] - for strategy in select_strategy.strategies: - # we create new DTensorSpecs even for default strategy to assure that - # the tensor metas are distinct between the arguments and outputs - default_strategy.append( - PlacementStrategy( - output_specs=DTensorSpec( - mesh=strategy.output_spec.mesh, - placements=strategy.output_spec.placements, - ) - ) - ) - return OpStrategy(default_strategy) - - -register_op_strategy( - [ - aten.clone.default, - aten.contiguous.default, - aten.copy_.default, - aten.detach.default, - aten.fill_.Scalar, - aten.zero_.default, - ] -)(default_strategy) - -register_op_strategy( - aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(default_strategy) - - -@register_op_strategy( - [ - aten.equal.default, - aten.is_same_size.default, - ] -) -def 
equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - # equal_strategy deals with ops that comparing two tensor, we need to make sure - # sharding layout the same with two operands, we choose to follow the arg with max - # num of shards, still keep is_same_size here for completeness as they share the - # same strategy in theory. - self_strategy, other_strategy = op_schema.args_schema - assert isinstance(self_strategy, OpStrategy) - assert isinstance(other_strategy, OpStrategy) - - select_strategy = ( - self_strategy - if self_strategy.max_num_shards() >= other_strategy.max_num_shards() - else other_strategy - ) - equal_strategy = OpStrategy([]) - - for arg_strategy in select_strategy.strategies: - arg_spec = arg_strategy.output_spec - if is_tensor_partial(arg_spec): - # if the arg_spec have partial, reshard to replicate - # otherwise local shard tensor comparison would be invalid - output_spec = DTensorSpec( - mesh=arg_spec.mesh, - placements=tuple( - Replicate() if isinstance(p, Partial) else p - for p in arg_spec.placements - ), - ) - equal_strategy.strategies.append( - PlacementStrategy(output_specs=output_spec) - ) - else: - equal_strategy.strategies.append(PlacementStrategy(arg_spec)) - return equal_strategy - - -@register_op_strategy( - [ - aten.empty_like.default, - aten.ones_like.default, - aten.rand_like.default, - aten.randn_like.default, - aten.zeros_like.default, - ], - schema_info=RuntimeSchemaInfo(1, ["dtype"]), -) -@register_op_strategy( - [aten.full_like.default], - schema_info=RuntimeSchemaInfo(2, ["dtype"]), -) -@register_op_strategy( - [ - aten.randint_like.default, - aten.randint_like.low_dtype, - aten.randint_like.low_dtype_out, - ], - schema_info=RuntimeSchemaInfo(3, ["dtype"]), -) -def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - # create_like_strategy deals with ops that creating tensors with same - # shape as input, but with specific content that does not depend on - # the input, we can 
propagate sharding, but we have to make sure we - # move from partial to replicated. - select_strategy = op_schema.args_schema[0] - create_like_strategy = OpStrategy([]) - assert isinstance(select_strategy, OpStrategy) - for arg_strategy in select_strategy.strategies: - arg_spec = arg_strategy.output_spec - if is_tensor_partial(arg_spec): - # if the arg_spec have partial, accept partial - # in the input_specs but output replicate for - # those corresponding mesh dims - output_spec = DTensorSpec( - mesh=arg_spec.mesh, - placements=tuple( - Replicate() if isinstance(p, Partial) else p - for p in arg_spec.placements - ), - ) - create_like_strategy.strategies.append( - PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,)) - ) - - else: - create_like_strategy.strategies.append(PlacementStrategy(arg_spec)) - - return create_like_strategy - - -@register_op_strategy( - [ - aten.new_empty.default, - aten.new_full.default, - aten.new_ones.default, - aten.new_zeros.default, - aten.new_empty_strided.default, - ], - schema_info=RuntimeSchemaInfo(1, ["dtype"]), -) -def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - # Currently there are two strategies: - # 1. let the output be replicated - # 2. 
let the output follow the input if input and output have the same shape - input_strategy = op_schema.args_schema[0] - assert isinstance(input_strategy, OpStrategy) - input_shape = input_strategy.shape - output_shape = op_schema.args_schema[1] - assert isinstance(output_shape, list) - - new_factory_strategy = OpStrategy([]) - for arg_strategy in input_strategy.strategies: - input_spec = arg_strategy.output_spec - replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - new_factory_strategy.strategies.append( - PlacementStrategy( - output_specs=replica_spec, - input_specs=(input_spec,), - redistribute_cost=[[0.0] * mesh.ndim], - ) - ) - - if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded(): - # NOTE: for new_empty_strided, currently the non-replicate sharding - # is supported only when the shape is evenly shardable - if ( - op_schema.op == aten.new_empty_strided.default - and not is_tensor_evenly_shardable(input_shape, input_spec) - ): - continue - - new_factory_strategy.strategies.append( - PlacementStrategy( - output_specs=input_spec, - input_specs=(input_spec,), - # encouraging new tensor placement to be the same as input - redistribute_cost=[[-0.1] * mesh.ndim], - ) - ) - - return new_factory_strategy - - -@register_op_strategy(aten.bucketize.Tensor) -def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """Just propagate input sharding, but expect replicated for boundaries input.""" - input_strategy = op_schema.args_schema[0] - bucketize_strategy = OpStrategy([]) - assert isinstance(input_strategy, OpStrategy) - for arg_strategy in input_strategy.strategies: - arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements) - replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - bucketize_strategy.strategies.append( - PlacementStrategy( - output_specs=arg_spec, input_specs=(arg_spec, replica_spec) - ) - ) - - return bucketize_strategy - - -@register_op_strategy(aten.slice.Tensor, 
schema_info=RuntimeSchemaInfo(1)) -def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """Forward all shardings except the slice dimension.""" - defaults = (None, 0, None, None, 1) - input_strategy, dim, start, end, step = ( - op_schema.args_schema + defaults[len(op_schema.args_schema) :] - ) - assert isinstance(input_strategy, OpStrategy) - input_shape = input_strategy.shape - input_ndim = input_strategy.ndim - assert isinstance(dim, int) - if start is None: - start = 0 - if end is None or end > input_shape[dim]: - end = input_shape[dim] - assert isinstance(start, int) - assert isinstance(end, int) - assert isinstance(step, int) - - # normalize args - slice_dim = normalize_dim(dim, input_ndim) - start = normalize_dim(start, input_shape[dim]) - end = normalize_dim(end, input_shape[dim]) - - redundant_slice = start == 0 and end == input_shape[dim] and step == 1 - - slice_strategy = OpStrategy([]) - - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: - # only add the strategy if the slice dim is not sharded - out_spec = DTensorSpec(mesh, arg_spec.placements) - slice_strategy.strategies.append(PlacementStrategy(output_specs=out_spec)) - if not slice_strategy.strategies: - # if all strategies are filtered out, unsharding all specs on slice dim - # of the input strategy, and use that as the op strategy - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - unshard_spec = DTensorSpec( - mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim) - ) - slice_strategy.strategies.append( - PlacementStrategy(output_specs=unshard_spec) - ) - return slice_strategy - - -def unshard_tensor_dim( - placements: Sequence[Placement], dim: int -) -> Tuple[Placement, ...]: - """Disallow the given tensor dimension to be sharded.""" - return tuple( - p if (not isinstance(p, Shard) or p.dim != dim) else Replicate() - 
for p in placements - ) - - -def replicate_tensor_dim( - placements: Sequence[Placement], dim: int -) -> Tuple[Placement, ...]: - """Force the given tensor dimension to be replicated.""" - # Not using p.is_shard() to avoid mypy complain about Placement not having - # attribute dim. - return tuple( - Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p - for p in placements - ) - - -@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2)) -def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - # 1. number of dimensions in input and src need to match. - # 2. number of elements on all non-dim need to match between input and src. - # 3. numer of elements in src in dim need to match the slice size. - # Given the above: - # - We suggest for src to follow the sharding of input, except on the scatter dimension, - # where our best bet for now is to make them replicated as a fall-back. - # TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding. 
- - input_strategy = op_schema.args_schema[0] - assert isinstance(input_strategy, OpStrategy) - input_ndim = input_strategy.ndim - slice_dim = ( - cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 - ) - slice_dim = normalize_dim(slice_dim, input_ndim) - - slice_scatter_strategy = OpStrategy([]) - # by default follow the input strategy for both input and src - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - if not ( - is_tensor_dim_sharded(arg_spec, dim=slice_dim) - or is_tensor_partial(arg_spec) - ): - # only add the strategy if the slice_scatter dim is not sharded or partial - slice_scatter_strategy.strategies.append( - PlacementStrategy(output_specs=arg_spec) - ) - - if not slice_scatter_strategy.strategies: - # if all strategies are filtered out, replicating all specs on slice_scatter dim - # of the input strategy, and use that as the op strategy - for arg_strategy in input_strategy.strategies: - arg_spec = arg_strategy.output_spec - replicate_spec = DTensorSpec( - mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim) - ) - slice_scatter_strategy.strategies.append( - PlacementStrategy(output_specs=replicate_spec) - ) - return slice_scatter_strategy - - -@register_op_strategy(aten._local_scalar_dense.default) -def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - """Only allow replication on the input/output.""" - replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replicate_spec)]) - - -@register_op_strategy( - [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], - schema_info=RuntimeSchemaInfo(1), -) -def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index, src] - # first we always have replicate 
all for inputs and output - if len(op_schema.args_strategy) < 3: - # scatter_.src/scatter.src with src be float number instead of tensor - all_replicate: PlacementList = [Replicate()] * 3 - else: - all_replicate = [Replicate()] * 4 - single_mesh_dim_strategies.append(all_replicate) - - # TODO: see if we can support input sharding pattern - inplace_op = _is_inplace_op(op_schema.op) - - op_strategy = expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op - ) - return op_strategy - - -@register_op_strategy(aten.gather.default) -def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - dim = cast(int, op_schema.args_schema[1]) - index_strategy = cast(OpStrategy, op_schema.args_schema[2]) - - input_shape = input_strategy.shape - index_shape = index_strategy.shape - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # input sharding, input sharded, index accepts mask partial, output follows index - # this only works when the input is sharded on the gather dimension, and - # index has size 1 on the gather dimension - if index_shape[dim] == 1: - index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) - input_sharding: PlacementList = [ - index_partial_placement, - Shard(dim), - index_partial_placement, - ] - single_mesh_dim_strategies.append(input_sharding) - - # index sharding, input replicated, index sharded, output follows index - # this only works when the sharding dimension is the gather dimension - index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] - single_mesh_dim_strategies.append(index_sharding) - - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, 
input_index=1 - ) - - -def _derive_follow_placements_from_tuple_strategy( - tuple_strategy: TupleStrategy, -) -> Sequence[Placement]: - """ - derive the placements to follow from the tuple strategy, mainly used by - aten.stack, aten.cat, where each operand have the same shape, and correspondingly - expecting the same sharding - """ - - def merge_placement( - cur_placement: Placement, new_placement: Placement - ) -> Placement: - # semantic if we already have a follow placement, we - # check each placement for the current arg placement - # to see if we want to merge/adjust the placement to follow - # the priority: Partial -> Shard -> Replicate - if cur_placement == new_placement: - return cur_placement - - if cur_placement.is_partial(): - if new_placement.is_shard(): - # follow new placement - return new_placement - elif new_placement.is_partial(): - # different partial types, we can't merge and have to replicate all here - return Replicate() - else: - # follow partial - return cur_placement - elif cur_placement.is_shard(): - if new_placement.is_shard(): - # cur/new placement are different sharding (i.e. 
different shard dim) - # currently fallback to replicate all args - return Replicate() - else: - # for partial/replicate, follow the current shard placement - return cur_placement - else: - # current replicate, just follow new placement - return new_placement - - follow_placements: Optional[List[Placement]] = None - for arg_strategy in tuple_strategy.childs: - assert isinstance(arg_strategy, OpStrategy) - for placement_strategy in arg_strategy.strategies: - arg_placements = placement_strategy.output_spec.placements - if follow_placements is None: - follow_placements = list(arg_placements) - continue - mesh_ndim = len(follow_placements) - assert follow_placements is not None - for mesh_idx in range(mesh_ndim): - # merge placements with the priority - follow_placements[mesh_idx] = merge_placement( - follow_placements[mesh_idx], arg_placements[mesh_idx] - ) - assert follow_placements is not None, "follow placements should not be None!" - return follow_placements - - -def normalize_shard_for_stack( - placements: Sequence[Placement], insert_dim: int = 0 -) -> Sequence[Placement]: - # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to - # be normalized with the new Shard placement - normalized_placements: List[Placement] = [] - for placement in placements: - if isinstance(placement, Shard) and placement.dim >= insert_dim: - normalized_placements.append(Shard(placement.dim + 1)) - else: - normalized_placements.append(placement) - return normalized_placements - - -@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True)) -def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - args_schema = op_schema.args_schema - input_tuple_strategy = args_schema[0] - assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" - first_input_strategy = input_tuple_strategy.childs[0] - assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" - common_input_ndim = 
first_input_strategy.ndim - dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 - # normalize the dim to be within the common input ndim - dim = normalize_dim(dim, common_input_ndim) - - follow_placements = _derive_follow_placements_from_tuple_strategy( - input_tuple_strategy - ) - - # create op strategy base on the follow placements - op_strategy = OpStrategy([]) - - input_specs = tuple( - DTensorSpec(mesh, tuple(follow_placements)) - for _ in range(len(input_tuple_strategy.childs)) - ) - - follow_placements = normalize_shard_for_stack(follow_placements, dim) - - op_strategy.strategies.append( - PlacementStrategy( - output_specs=DTensorSpec(mesh, tuple(follow_placements)), - input_specs=input_specs, - ) - ) - return op_strategy - - -@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True)) -def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - args_schema = op_schema.args_schema - input_tuple_strategy = args_schema[0] - assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" - first_input_strategy = input_tuple_strategy.childs[0] - assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" - common_input_ndim = first_input_strategy.ndim - dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 - # normalize the dim to be within the common input ndim - dim = normalize_dim(dim, common_input_ndim) - - follow_placements = _derive_follow_placements_from_tuple_strategy( - input_tuple_strategy - ) - # for cat we unshard the cat dim if it is sharded - follow_placements = unshard_tensor_dim(follow_placements, dim) - - # create op strategy base on the follow placements - op_strategy = OpStrategy([]) - - input_specs = tuple( - DTensorSpec(mesh, tuple(follow_placements)) - for _ in range(len(input_tuple_strategy.childs)) - ) - op_strategy.strategies.append( - PlacementStrategy( - output_specs=DTensorSpec(mesh, tuple(follow_placements)), - input_specs=input_specs, - ) - 
) - return op_strategy - - -@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1)) -def prop_index_select(op_schema: OpSchema) -> OutputSharding: - values_spec, dim, indices_spec = op_schema.args_schema - - assert isinstance(values_spec, DTensorSpec) - assert isinstance(dim, int) - assert isinstance(indices_spec, DTensorSpec) - - all_indices_spec: List[Optional[DTensorSpec]] = [ - indices_spec if dim == i else None for i in range(values_spec.ndim) - ] - - result = prop_index( - OpSchema( - op=op_schema.op, - args_schema=(values_spec, all_indices_spec), - kwargs_schema=op_schema.kwargs_schema, - ) - ) - if result.redistribute_schema: - schema_suggestion = result.redistribute_schema - result.redistribute_schema = OpSchema( - op=op_schema.op, - args_schema=( - schema_suggestion.args_schema[0], - dim, - schema_suggestion.args_schema[1][dim], - ), - kwargs_schema=op_schema.kwargs_schema, - ) - return result - - -@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) -def prop_index(op_schema: OpSchema) -> OutputSharding: - """ - Expect replicated on the first input; _mostly_ pointwise on the second input. - - TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first. - """ - # Current sharding constraints: - # For values: - # 1. We currently require that the dimension of values_spec be replicated or partial - # if they are being indexed on. - # 2. Other dimensions of values_spec can remain sharded if they are so. - # For indices: - # Indices can be either sharded or replicated. 
All index tensors need to be sharded - # in a compatible way, following the pointwise rule (including resolving Partial - # into either sharded or replicated) - - values_spec, multi_indices_spec = op_schema.args_schema - assert isinstance(values_spec, DTensorSpec) - assert isinstance(multi_indices_spec, list) - multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec) - valid_indices_spec: List[Tuple[int, DTensorSpec]] = [ - (i, a) for i, a in enumerate(multi_indices_spec) if a is not None - ] - - # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. - # Here, we piggyback on the pointwise sharding rule for indices. - indices_out = pointwise_rule( - OpSchema( - op=op_schema.op, - args_schema=tuple(v[1] for v in valid_indices_spec), - kwargs_schema={}, - ) - ) - need_reshard_on_indices = indices_out.output_spec is None - - if not need_reshard_on_indices: - # this means that our inputs are already sharded properly and we will use that as our indices_spec - assert isinstance(indices_out.output_spec, DTensorSpec) - indices_spec: DTensorSpec = indices_out.output_spec - else: - assert indices_out.redistribute_schema is not None - valid_indices_suggestion = indices_out.redistribute_schema - for i, v in enumerate(valid_indices_suggestion.args_spec): - multi_indices_spec[valid_indices_spec[i][0]] = v - # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then - # use that to compute our ideal values_spec - indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec - assert isinstance(indices_output_spec, DTensorSpec) - indices_spec = indices_output_spec - - lookup_dims = {v[0] for v in valid_indices_spec} - - need_reshard_on_values = tuple( - (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) - for vp, ip in zip(values_spec.placements, indices_spec.placements) - ) - - if not need_reshard_on_indices and not any(need_reshard_on_values): - value_placements = 
values_spec.placements - - all_dims_consecutive = all( - b[0] - a[0] == 1 - for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) - ) - if all_dims_consecutive: - # if all index vectors are consecutives, insert at the dimension of the first index - insert_dim: int = valid_indices_spec[0][0] - else: - # else, insert on the first dimension - insert_dim = 0 - - def place(vp: Placement, ip: Placement) -> Placement: - if isinstance(vp, Shard): - return Shard( - vp.dim - if vp.dim < insert_dim - # accounts for the offset in output dimensions - else vp.dim - + indices_spec.ndim - - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) - ) - if isinstance(ip, Shard): - return Shard(ip.dim + insert_dim) - # Partial or Replicated - return vp - - value_placements = tuple( - place(vp, ip) - for vp, ip in zip(values_spec.placements, indices_spec.placements) - ) - result = OutputSharding( - output_spec=DTensorSpec( - mesh=values_spec.mesh, - placements=value_placements, - ) - ) - return result - else: - result = OutputSharding( - output_spec=None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - mesh=values_spec.mesh, - placements=tuple( - [ - Replicate() if need_reshard_on_values[i] else v - for i, v in enumerate(values_spec.placements) - ] - ), - tensor_meta=values_spec.tensor_meta, - ), - multi_indices_spec, - ), - kwargs_schema=op_schema.kwargs_schema, - ), - ) - return result - - -@register_prop_rule( - [ - aten.split.Tensor, - aten.split_with_sizes.default, - aten.split_with_sizes_copy.default, - ], - schema_info=RuntimeSchemaInfo(1), -) -def split_rule(op_schema: OpSchema) -> OutputSharding: - output_spec_list: List[DTensorSpec] = [] - input_spec = cast(DTensorSpec, op_schema.args_schema[0]) - ndim = input_spec.ndim - split_size_or_sections = op_schema.args_schema[1] - dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 - dim = normalize_dim(dim, ndim) - - # TODO: tensor to split cannot have 
Partial - # in its placements for now. Will need to - # support in future. - if input_spec.sums: - raise NotImplementedError( - f"splitting distributed tensor with " - f"Partial placement is not implemented!\n" - f"DTensorSpec={input_spec}" - ) - - # TODO: just like slice op, split replicates before - # splitting on a sharded dimension - need_reshard = False - if is_tensor_dim_sharded(input_spec, dim=dim): - need_reshard = True - input_spec = DTensorSpec( - mesh=input_spec.mesh, - placements=unshard_tensor_dim( - input_spec.placements, - dim=dim, - ), - tensor_meta=input_spec.tensor_meta, - ) - - if need_reshard: - return OutputSharding( - None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=(input_spec,) + op_schema.args_schema[1:], - kwargs_schema=op_schema.kwargs_schema, - ), - ) - - def size_split(N, i): - # Last chunk will be smaller if the tensor size N - # along the given dimension dim is not divisible by i. - assert i > 0 - return [i] * (N // i) + ([N % i] if N % i != 0 else []) - - output_size_list = ( - size_split( - input_spec.shape[dim], - split_size_or_sections, - ) - if isinstance(split_size_or_sections, int) - else split_size_or_sections - ) - output_spec_list = [ - DTensorSpec( - mesh=input_spec.mesh, - placements=input_spec.placements, - ) - for _ in range(len(output_size_list)) - ] - return OutputSharding(output_spec_list) diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py deleted file mode 100644 index 27cf89224..000000000 --- a/src/chop/distributed/tensor/ops/utils.py +++ /dev/null @@ -1,300 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. 
and affiliates -import functools -import itertools -import operator -from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union - -import torch -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementList, - PlacementStrategy, - RuntimeSchemaInfo, -) -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, -) - -from chop.distributed.tensor.api import DTensor - - -# convenient wrapper to register sharding propagation rules -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def register_prop_rule(op, schema_info=None): - # pyre-fixme[53]: Captured variable `func` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def wrapper(impl): - overloads = op if isinstance(op, list) else [op] - for overload in overloads: - DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( - overload, impl, schema_info - ) - return impl - - return wrapper - - -def register_op_strategy(op, schema_info=None): - # pyre-fixme[53]: Captured variable `func` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - - # For every ATen op that accepts any args in this list, - # the arg itself can impact the strides (and potentially the sharding strategy) - # of the output tensor. - # thus, we will detect ATen schemas with any of these args and ensure - # that they get specialized here. 
- arg_names_that_require_specializing_cache_strategy = [ - "memory_format", - ] - - def wrapper(impl): - if isinstance(op, list): - overloads = op - else: - overloads = [op] - - for overload in overloads: - curr_schema_info = None - if schema_info is None: - specialized_args = [ - a.name - for a in overload._schema.arguments - if a.name in arg_names_that_require_specializing_cache_strategy - ] - if any(specialized_args): - curr_schema_info = RuntimeSchemaInfo( - static_kwargkey=specialized_args - ) - else: - curr_schema_info = schema_info - DTensor._op_dispatcher.sharding_propagator.register_op_strategy( - overload, impl, curr_schema_info - ) - return impl - - return wrapper - - -def as_list( - x: Union[List[object], object] - # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. -) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] - # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, - # which is an object but treated as a list by the tracer. Therefore, keep - # `immutable_list` intact here as well. 
- if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list): - return x - else: - return [x] - - -def normalize_dim(dim: int, ndim: int) -> int: - return dim if dim >= 0 else dim + ndim - - -def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]: - """Normalize a dim or a sequence of dims, so that they are all positive.""" - if isinstance(dims, int): - dims = (normalize_dim(dims, ndim),) - elif isinstance(dims, list): - dims = [normalize_dim(dim, ndim) for dim in dims] - elif isinstance(dims, tuple): - dims = tuple([normalize_dim(dim, ndim) for dim in dims]) - return dims - - -def normalize_to_torch_size(size) -> torch.Size: - """ - Unify variable types of size argument to torch.Size - Acceptable types include: - int, Sequence[int], Tuple[int], Tuple[Sequence[int]], - or torch.Size - """ - if isinstance(size, torch.Size): - return size - - if isinstance(size, int): - torch_size = [size] - elif len(size) == 1 and isinstance(size[0], Sequence): - torch_size = list(size[0]) - else: - torch_size = list(size) - return torch.Size(torch_size) - - -def prod(xs: Iterable[int]) -> int: - return functools.reduce(operator.mul, xs, 1) - - -def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: - """Check if the shape is shardable according to the spec.""" - # number of shards in each tensor dimension - shards_map = [1] * len(shape) - for i, placement in enumerate(spec.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - shards_map[shard_dim] *= spec.mesh.size(i) - - for i, dim_size in enumerate(shape): - # TODO: maybe we should determine is_shardable based on - # whether it's evenly sharded or not - if shards_map[i] > 1 and dim_size < shards_map[i]: - return False - - return True - - -def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: - """Check if the shape is evenly shardable according to the spec.""" - # number of shards in each tensor dimension - 
shards_map = [1] * len(shape) - for i, placement in enumerate(spec.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - shards_map[shard_dim] *= spec.mesh.size(i) - - for i, dim_size in enumerate(shape): - if shards_map[i] > 1 and (dim_size % shards_map[i] != 0): - return False - - return True - - -def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool: - """Return True if tensor dim is sharded.""" - return any(p.is_shard(dim) for p in spec.placements) - - -def is_tensor_partial(spec: DTensorSpec) -> bool: - """Return True if tensor is partial on the mesh.""" - return any(p.is_partial() for p in spec.placements) - - -def infer_broadcast_dims_map( - common_shape: torch.Size, input_shape: torch.Size -) -> List[int]: - # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim - # this is aligned with the broadcast semantics - common_ndim = len(common_shape) - input_ndim = len(input_shape) - broadcast_dims_map = [-1] * common_ndim - for idx in range(-1, -1 - input_ndim, -1): - if input_shape[idx] == common_shape[idx]: - broadcast_dims_map[common_ndim + idx] = input_ndim + idx - return broadcast_dims_map - - -def map_placements_after_broadcast( - placements: Tuple[Placement, ...], - shape: torch.Size, - broadcast_dims_map: List[int], -) -> Tuple[Placement, ...]: - """Map each placement based on the output shape after broadcast.""" - new_placements: List[Placement] = [] - for placement in placements: - if isinstance(placement, (Replicate, Partial)): - new_placements.append(placement) - else: - assert isinstance(placement, Shard) - shard_dim = normalize_dim(placement.dim, len(shape)) - new_shard_dim = broadcast_dims_map[shard_dim] - if new_shard_dim != -1: - # there's a map from the common shape shard dim to - # the input shape shard dim before broadcasting, - # use that instead - new_placements.append(Shard(new_shard_dim)) - else: - # there's no map between common shape shard dim and - # the input 
shape shard dim before broadcasting, - # in this case it means implicit broadcasting happen - # in this dim, so we can just mark it as replicate - # and implict broadcast will broadcast automatically - # to the sharded shape - new_placements.append(Replicate()) - - return tuple(new_placements) - - -def generate_redistribute_costs( - src_strategy: OpStrategy, dst_spec: DTensorSpec -) -> List[float]: - redistribute_costs: List[float] = [] - for strat in src_strategy.strategies: - redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec)) - - return redistribute_costs - - -def expand_to_full_mesh_op_strategy( - mesh: DeviceMesh, - op_schema: OpSchema, - single_mesh_dim_strategies: List[PlacementList], - *, - input_index: int = 1, - inplace_op: bool = False, -) -> OpStrategy: - # Expand the single_mesh_dim_strategies to full mesh dim strategies. - all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list: List[Optional[DTensorSpec]] = [] - for specs in zip(*strategy_comb): - if specs[0] is not None: - spec_list.append(DTensorSpec(mesh, specs)) - else: - spec_list.append(None) - - input_specs: List[DTensorSpec] = [ - s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) - ] - - input_args_strategy = op_schema.args_strategy - assert len(input_specs) == len(input_args_strategy) - self_spec = input_args_strategy[0].strategies[0].output_spec - - if inplace_op and self_spec.placements != input_specs[0].placements: - # if it's inplace op, we would only allow the placement strategy to be added when the - # input_spec matches the first argument's runtime sharding, otherwise we skip - continue - - # check inputs shardable - inputs_shardable = all( - is_tensor_shardable(inp.shape, s) - for inp, s in zip(input_args_strategy, input_specs) - ) - - # only add to the all_strategies list when all inputs are 
shardable - if inputs_shardable: - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - for input_strategy, input_spec in zip(input_args_strategy, input_specs) - ] - if input_index > 1: - output_specs = tuple(spec_list[:input_index]) - else: - if spec_list[0] is not None: - output_specs = spec_list[0] # type: ignore[assignment] - else: - raise RuntimeError("output spec is None") - strategy = PlacementStrategy( - output_specs=output_specs, - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strategy) - - return OpStrategy(all_strategies) diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py deleted file mode 100644 index 0e1402a2a..000000000 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ /dev/null @@ -1,665 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -from dataclasses import dataclass -from typing import ( - Callable, - cast, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) - -import torch -from torch import Tensor -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - RuntimeSchemaInfo, - StrategyType, -) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate -from torch.distributed.device_mesh import DeviceMesh - -from chop.distributed.tensor.ops.utils import register_op_strategy -from chop.distributed.tensor.ops.utils import ( - normalize_dim, - normalize_dims, - prod, -) - -aten = torch.ops.aten - -Shape = Tuple[int, ...] - - -@dataclass -class DimSpec: - """Specifies how an output dimension maps to an input dimension.""" - - def inputs(self) -> Iterable["DimSpec"]: - return () - - -# Rules that map each dimension of the output to dimensions of the input tensor -DimMap = Tuple[DimSpec, ...] 
- - -@dataclass -class Singleton(DimSpec): - """Output dimension is a singleton.""" - - pass - - -@dataclass -class InputDim(DimSpec): - """Output dimension maps directly to an input dimension.""" - - input_dim: int - - -@dataclass -class Broadcast(DimSpec): - """Output is the broadcast of a singleton input dimension.""" - - dim: DimSpec - dim_size: int - - @classmethod - def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: - return Broadcast(dim, dim_size) - - def inputs(self) -> Iterable[DimSpec]: - return (self.dim,) - - -@dataclass -class NewDim(DimSpec): - """This is a new dimension created by the op.""" - - size: int - - @classmethod - def new(cls, size: int) -> DimSpec: - return Singleton() if size == 1 else NewDim(size) - - -@dataclass -class Repeat(DimSpec): - """Output dimension is the input dimension repeated n-times.""" - - input_dim: DimSpec - times: int - - @classmethod - def new(cls, dim: DimSpec, times: int) -> DimSpec: - if times == 1: - return dim - elif isinstance(dim, Singleton): - # repeating a singleton is the same as broadcasting it - return Broadcast(dim, times) - else: - return Repeat(dim, times) - - def inputs(self) -> Iterable[DimSpec]: - return (self.input_dim,) - - -@dataclass -class Flatten(DimSpec): - """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" - - input_dims: Sequence[DimSpec] - - @classmethod - def new(cls, dims: Sequence[DimSpec]) -> DimSpec: - if len(dims) == 0: - # flattening a scalar leads to a singleton - return Singleton() - elif len(dims) == 1: - # flattening a single dimension is no-op - return dims[0] - else: - return Flatten(dims) - - def inputs(self) -> Iterable[DimSpec]: - return self.input_dims - - -@dataclass -class Split(DimSpec): - """ - This dimension is a member of a decomposition of the input dim. - - Note that input_dim itself could be a Flattened set of input dims. 
- """ - - input_dim: DimSpec - group_shape: Shape - split_id: int - - @classmethod - def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: - assert len(group_shape) > 0 - if len(group_shape) == 1: - # not really a group, just return the input dim back - assert idx == 0 - return dim - elif group_shape[idx] == 1: - return Singleton() - else: - # remove singletons from group - # group_mapping = [(new_index, (shape, old_index)) ...] - group_mapping = list( - enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) - ) - new_group_shape = tuple(m[1][0] for m in group_mapping) - new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] - return Split(dim, new_group_shape, new_idx) - - def inputs(self) -> Iterable[DimSpec]: - return (self.input_dim,) - - -def dim_pad_left(ndim: int, min_dims: int) -> DimMap: - return (Singleton(),) * max(0, min_dims - ndim) + tuple( - InputDim(i) for i in range(ndim) - ) - - -def dim_atleast_3d(ndim: int) -> DimMap: - if ndim == 0: - return (Singleton(), Singleton(), Singleton()) - elif ndim == 1: - return (Singleton(), InputDim(0), Singleton()) - elif ndim == 2: - return (InputDim(0), InputDim(1), Singleton()) - else: - return tuple(InputDim(i) for i in range(ndim)) - - -def expand(input_shape: Shape, shape: Shape) -> DimMap: - """Implement broadcast on multiple dimensions.""" - assert len(shape) >= len(input_shape) - - # 1. create padded input dimensions - padded_input = dim_pad_left(len(input_shape), len(shape)) - # 2. 
check that input shapes are compatible - mapping = [] - for p, desired_s in zip(padded_input, shape): - if isinstance(p, Singleton): - actual_s = 1 - assert desired_s >= 0 - else: - assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" - actual_s = input_shape[p.input_dim] - assert actual_s == 1 or desired_s == -1 or desired_s == actual_s - mapping.append( - p - if desired_s in (1, -1) or desired_s == actual_s - else Broadcast.new(p, desired_s) - ) - return tuple(mapping) - - -def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: - if isinstance(sizes[0], int): - return cast(Shape, sizes) - elif len(sizes) == 1: - return cast(Shape, sizes[0]) # type: ignore[redundant-cast] - else: - raise RuntimeError("Size must be int... or tuple") - - -def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: - if ndim == 0: - return (Singleton(),) - elif ndim == 1: - return (InputDim(0),) - else: - # only flattening dims from start_dim to end_dim (inclusive) - # other dims are passed through - if end_dim < 0: - end_dim += ndim - results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] - results.append( - Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) - ) - results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) - return tuple(results) - - -def dim_movedim( - ndim: int, - input: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]], -) -> DimMap: - input = normalize_dims(input, ndim) - destination = normalize_dims(destination, ndim) - - assert len(input) == len(destination) - input_set = set(input) - assert len(input_set) == len(input), "Found repeated input dims" - assert len(set(destination)) == len(destination), "Found repeated output dims" - assert max(input) < ndim - assert max(destination) < ndim - - dest = [-1] * ndim - for i, d in zip(input, destination): - dest[d] = i - - unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) - for i in range(ndim): - if dest[i] == 
-1: - dest[i] = next(unused_inputs_iter) - - return tuple(InputDim(i) for i in dest) - - -def dim_repeat(ndim: int, sizes: Shape) -> DimMap: - sizes = normalize_sizes(sizes) - assert ( - len(sizes) >= ndim - ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." - pad = len(sizes) - ndim - return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( - Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) - ) - - -def infer_size(total_size: int, sizes: Shape) -> Shape: - """ - One dimension input to view may be "-1". - - Infer the size of this dimension given the total_size. - """ - infers = [i for i, s in enumerate(sizes) if s == -1] - size = prod(sizes) - assert len(infers) <= 1, "can only infer one size" - if infers: - size = -size - missing_size = total_size // size - assert ( - total_size % size == 0 - ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." - return tuple(s if s != -1 else missing_size for s in sizes) - assert size == total_size, f"sizes do not match {total_size} vs {size}" - return sizes - - -def view_groups(from_size: Shape, to_size: Shape) -> DimMap: - """ - Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. - - A view or reshape operation can be decomposed into a set of 3 types of smaller operations: - 1) Forward a dimension from input to output - 2) Flatten a set of dimensions into a single dimension - 3) Split one dimension into multiple dimensions - - view_groups identifies these operations and returns, for each output dimension, what - is operation was performed in the input dimension. 
For example: - - view_groups([2, 3, 4], [2, 12]) -> ( - InputDim(0), - Flatten((InputDim(1), InputDim(2))) - ) - - - ouptut dimension 0 maps to input dimension 0 - - output dimension 1 maps to a flattened input dimensions 1 and 2 - - - view_groups([2, 3], [3, 2]) -> ( - Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), - Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), - ) - - - in the above, input is flattened into a single dimension and then split - into two separate dimensions with different sizes from the input. - """ - from_nelem = prod(from_size) - to_size = infer_size(from_nelem, normalize_sizes(to_size)) - - assert from_nelem == prod(to_size), "Total view shape does not add up" - - from_idx = 0 - to_idx = 0 - from_len = len(from_size) - to_len = len(to_size) - - result_pp = [] - - while from_idx < from_len or to_idx < to_len: - from_group_dim, to_group_shape = [], [] - - if from_idx >= from_len: - f = 1 - else: - f = from_size[from_idx] - from_group_dim.append(from_idx) - from_idx += 1 - - if to_idx >= to_len: - t = 1 - else: - t = to_size[to_idx] - to_group_shape.append(t) - to_idx += 1 - - # if any of the groups is singleton, great, we need to backtrack though - if f == 1 and t != 1: - # produces ([1], []) - to_idx -= 1 - to_group_shape = [] - elif f != 1 and t == 1: - # produces ([], [1]) - from_idx -= 1 - from_group_dim = [] - else: - # produces ([1], [1]), ([2], [2]), ([2,3], [6]) - while f != t: - if f < t: - nf = from_size[from_idx] - from_group_dim.append(from_idx) - from_idx += 1 - f *= nf - else: - nt = to_size[to_idx] - to_group_shape.append(nt) - to_idx += 1 - t *= nt - - if len(to_group_shape) > 0: - flattened = Flatten.new( - tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) - ) - result_pp += [ - Split.new(flattened, tuple(to_group_shape), i) - for i in range(len(to_group_shape)) - ] - - return tuple(result_pp) - - -def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: - if len(dims) < ndim: - dims = (1,) * 
(ndim - len(dims)) + dims - return dim_repeat(ndim, dims) - - -def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: - dim1 = normalize_dim(dim1, ndim) - dim2 = normalize_dim(dim2, ndim) - assert dim1 < ndim - assert dim2 < ndim - dimmap = [InputDim(i) for i in range(ndim)] - swapdim = dimmap[dim1] - dimmap[dim1] = dimmap[dim2] - dimmap[dim2] = swapdim - return tuple(dimmap) - - -def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: - # FIXME: this is wrong when dim=None and one of the dimensions - # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could - # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to - # removal of a dimension that is not actually a singleton. - return tuple( - InputDim(i) - for i, s in enumerate(shape) - if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) - ) - - -def dim_unsqueeze(ndim: int, dim: int) -> DimMap: - dims = tuple(InputDim(i) for i in range(ndim)) - if dim < 0: - dim += ndim + 1 - return dims[:dim] + (Singleton(),) + dims[dim:] - - -def dim_view_as_real(shape: Shape) -> DimMap: - ndim = len(shape) - results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] - # each complex number is split into two real numbers, - # resulting in one more dimension of size 2 - results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) - results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) - return tuple(results) - - -def dim_reduction( - ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool -) -> DimMap: - """ - General fallback for reduction ops where Partial() does not apply. - - This will cause incoming tensor to be replicated on the reducing dimensions. 
- """ - if dim_or_dims is None: - dim_or_dims = tuple(range(ndim)) - if isinstance(dim_or_dims, int): - dim_or_dims = (dim_or_dims,) - dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) - return tuple( - InputDim(i) if i not in dim_or_dims else Singleton() - for i in range(ndim) - if i not in dim_or_dims or keepdim - ) - - -dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { - torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), - torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), - torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), - torch.broadcast_to: lambda input, shape: expand(input.shape, shape), - Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), - torch.flatten: lambda tensor: dim_flatten(tensor.ndim), - torch.movedim: lambda input, source, destination: dim_movedim( - input.ndim, source, destination - ), - torch.permute: lambda input, dims: tuple( - InputDim(i) for i in normalize_dims(dims, input.ndim) - ), - torch.ravel: lambda tensor: dim_flatten(tensor.ndim), - Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), - torch.reshape: lambda input, shape: view_groups(input.shape, shape), - torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), - torch.tile: lambda input, dims: dim_tile(input.ndim, dims), - torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), - torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), - Tensor.view: lambda input, *shape: view_groups(input.shape, shape), - torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), - torch.view_as_real: lambda input: dim_view_as_real(input.shape), -} - - -def propagate_shape_and_sharding( - input_src_placements: Sequence[Placement], - local_in_shape: Shape, - rule: DimMap, - mesh_sizes: Shape, -) -> Tuple[Sequence[Placement], Sequence[Placement]]: - """ - Determine input target sharding and output sharding based on - given global tensor shape 
and input source sharding. - - Sharding propagation follows mapped dimensions: - - An output dimension that maps directly to an input dimension is sharded equally - - An output dimension that is a flattened set of input dimensions can only be - sharded if only the leftmost flattened dimension is sharded. - - An output dimension that is a split of the input dimension can only be sharded - if the leftmost split size is divisible by the mesh dimension - """ - assert len(input_src_placements) == len(mesh_sizes) - # for each input dim, for each mesh dim, provides a list of possible shardable dimensions - mesh_ndim = len(mesh_sizes) - shardable_dims: Dict[int, List[bool]] = {} - - # in case an input dimension disappears (e.g. collapsing, reduction) - # we cannot shard in that dimension (we need a replication fall-back rule) - seen_input_dims: Set[int] = set() - - def collect_used_inputs(cmd: DimSpec) -> None: - if isinstance(cmd, InputDim): - seen_input_dims.add(cmd.input_dim) - for inp in cmd.inputs(): - collect_used_inputs(inp) - - for cmd in rule: - collect_used_inputs(cmd) - for dim in range(len(local_in_shape)): - shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim - - def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: - if isinstance(cmd, InputDim): - return cmd - elif isinstance(cmd, Flatten): - for dim in cmd.input_dims[1:]: - if isinstance(dim, InputDim): - shardable_dims[dim.input_dim] = [False] * mesh_ndim - dim0 = cmd.input_dims[0] - return dim0 if isinstance(dim0, InputDim) else None - elif isinstance(cmd, Split): - in_dim = get_in_dim_to_shard(cmd.input_dim) - out_size = cmd.group_shape[cmd.split_id] - if cmd.split_id == 0 and in_dim is not None: - # we need to check that the input dimension is divisible - # by the size of the submesh we're sharding it on - # NOTE: it would be possible to shard the same input dimension - # on more than one mesh dimension. In that case, the dimension - # needs to be divisible by the product of mesh sizes. 
- # In order to keep the problem more tractable, we will not consider - # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) - # but we will allow it if that's the input and it's compatible - - # 1. is this dimension shardable on each individual mesh dim? - shardable_dims[in_dim.input_dim] = [ - out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes - ] - - # 2. here we special case things like [Shard(0), Shard(0)] - submesh_size = 1 - for size, shard in zip(mesh_sizes, input_src_placements): - if isinstance(shard, Shard) and shard.dim == in_dim: - submesh_size *= size - assert ( - out_size % submesh_size == 0 - ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." - - # we will only shard our first component of the split - return in_dim if cmd.split_id == 0 else None - elif isinstance(cmd, Repeat): - in_dim = get_in_dim_to_shard(cmd.input_dim) - if in_dim is not None: - shardable_dims[in_dim.input_dim] = [False] * mesh_ndim - return None - else: - return None - - # for each output dim, find the corresponding input dim in terms of sharding prop - shard_dim_map = {} - for dim, cmd in enumerate(rule): - in_dim = get_in_dim_to_shard(cmd) - if in_dim is not None: - shard_dim_map[in_dim.input_dim] = dim - - input_tgt_placements = [ - ( - Replicate() - if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] - else p - ) - for mesh_dim, p in enumerate(input_src_placements) - ] - output_placements = [ - Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p - for p in input_tgt_placements - ] - - return input_tgt_placements, output_placements - - -def register_op_strategy_map( - aten_op_overload: torch._ops.OpOverload, - local_op_name: Callable[..., torch.Tensor], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> None: - dim_map: Callable[..., DimMap] = dim_maps[local_op_name] - - @register_op_strategy(aten_op_overload, schema_info=schema_info) - def reshape_strategy(mesh: DeviceMesh, 
op_schema: OpSchema) -> StrategyType: - rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - global_in_shape = input_strategy.shape - assert global_in_shape is not None, "Shape required." - - output_strategy = OpStrategy([]) - for input_placement_strategy in input_strategy.strategies: - input_src_spec = input_placement_strategy.output_spec - - input_tgt_placements, output_placements = propagate_shape_and_sharding( - input_src_spec.placements, - tuple(global_in_shape), - rules, - mesh.shape, - ) - - # TODO: optimize this. we shouldn't simply blindly replicate - # unshardable dims ... - # FIXME: this can be wrong for situations where we have - # [Shard(0), Shard(0)] - input_tgt_spec = DTensorSpec( - placements=tuple(input_tgt_placements), - mesh=input_src_spec.mesh, - tensor_meta=input_src_spec.tensor_meta, - ) - - output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) - output_strategy.strategies.append( - PlacementStrategy( - output_specs=output_spec, - input_specs=(input_tgt_spec,), - redistribute_cost=[], - ) - ) - - return output_strategy - - -register_op_strategy_map(aten.squeeze.default, torch.squeeze) -register_op_strategy_map( - aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.repeat.default, Tensor.repeat, 
schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map( - aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) -) -register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) -register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index a271c5298..9179c5970 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -22,14 +22,26 @@ class AutoPipelineForDistributedInference(AutoPipeline): def __init__(self) -> None: """Initializes the AutoPipeline.""" + # Pre-processing pass_list = [ passes.replace_method_with_function, + ] + + # Raise to Mase IR + pass_list += [ passes.init_metadata_analysis_pass, passes.add_common_metadata_analysis_pass, + ] + + # Autosharding + pass_list += [ passes.autosharding_analysis_pass, + ] + + # Only run the following in distributed setup + pass_list += [ passes.insert_dtensor_wrapper_transform_pass, passes.resharding_transform_pass, - passes.report_graph_analysis_pass, ] super().__init__(pass_list) From 13efcf2cb6d6e80edd015707a97a75839e1c5232 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 13:40:19 +0000 Subject: [PATCH 73/93] include compute cost in ILP --- src/chop/distributed/launcher.py | 5 ++- src/chop/distributed/utils.py | 9 +++++- .../autosharding/alpa_intra_operator.py | 10 +++--- .../graph/analysis/autosharding/layers.py | 21 ------------ src/chop/pipelines/distributed_inference.py | 32 ++++++++++++++++--- 5 files changed, 43 insertions(+), 34 deletions(-) diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py index 3f3cb524d..eac339e74 100644 --- a/src/chop/distributed/launcher.py +++ b/src/chop/distributed/launcher.py @@ -1,5 +1,6 @@ import torch.multiprocessing as mp +from chop.distributed.utils import _get_mesh_from_world_size from ..tools import get_logger logger = 
get_logger(__name__) @@ -27,9 +28,11 @@ def __init__( """ self.mg = mg self.world_size = world_size - self.device_mesh = device_mesh self.device_fn = device_fn + if device_mesh is None: + self.device_mesh, _ = _get_mesh_from_world_size(world_size) + def run( self, model_class=None, diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index c119ed9c6..c95b6b405 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -1,5 +1,5 @@ from time import time - +import numpy as np import torch import torch.nn as nn @@ -122,3 +122,10 @@ def dist_model_fn( level="error", ) raise e + + +def _get_mesh_from_world_size(world_size: int = 8): + device_ids = np.arange(world_size) + mesh_shape = (2, world_size // 2) + mesh_ids = device_ids.reshape(mesh_shape) + return mesh_ids.tolist(), tuple(mesh_shape) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index fc183a389..129537cdd 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -18,11 +18,8 @@ from .mesh_model import MeshModel from .layers import ( - AUTOSHARDING_MODULES, AUTOSHARDING_FUNCTIONS, - AUTOSHARDING_METHODS, IMPLICIT_FUNCS, - IMPLICIT_METHODS, FULLY_REPLICATED_FUNCS, ) from .strategies.common import ( @@ -39,7 +36,7 @@ def _get_computation_cost_from_strategy( node: fx.Node, strategy: OpStrategy, mesh: MeshModel, - repeat: int = 5, + repeat: int = 100, warmup_iters: int = 2, profiling_device: int = 0, ): @@ -308,7 +305,6 @@ def _extract_ilp(mg, mesh, pass_args={}): "placeholder", "get_attr", "output", - # todo: decide how to handle call_method nodes "call_method", ]: cost_vector = [] @@ -316,8 +312,10 @@ def _extract_ilp(mg, mesh, pass_args={}): for strategy in op_strategy.strategies: cost = _get_computation_cost_from_strategy(node, strategy, mesh) cost_vector.append(cost) + 
+ expr += np.array(cost_vector) @ opt_var except: - print(f"Op {node} failed to compute cost") + logger.error(f"Failed to compute computation cost for node {node}") # Consider resharding cost for each of the node's arguments e_var_checks = [] diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index bb4c21e85..7395d0608 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -31,10 +31,6 @@ logger = get_logger(__name__) -AUTOSHARDING_MODULES = { - torch.nn.ReLU: pointwise_strategy, -} - AUTOSHARDING_FUNCTIONS = { # embedding_ops.py F.embedding: embedding_strategy, @@ -263,19 +259,6 @@ torch.unsqueeze: get_reshape_strategy(torch.unsqueeze), } -AUTOSHARDING_METHODS = { - # view_ops.py - "view": get_reshape_strategy(torch.Tensor.view), - "reshape": get_reshape_strategy(torch.Tensor.reshape), - "expand": get_reshape_strategy(torch.Tensor.expand), - "permute": get_reshape_strategy(torch.Tensor.permute), - "transpose": get_reshape_strategy(torch.Tensor.transpose), - "unsqueeze": get_reshape_strategy(torch.Tensor.unsqueeze), - "masked_fill": pointwise_strategy, - "masked_fill_": pointwise_strategy, - "contiguous": tensor_op_strategy, -} - FULLY_REPLICATED_FUNCS = [ F.embedding, torch.arange, @@ -289,7 +272,3 @@ torch.finfo, torch_size, ] - -IMPLICIT_METHODS = [ - "size", -] diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index 9179c5970..e67c3ef6e 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -1,12 +1,22 @@ +import torch.distributed as dist + import chop.passes as passes +from chop.tools import get_logger from .auto_pipeline import AutoPipeline +logger = get_logger(__name__) +logger.setLevel("INFO") + class AutoPipelineForDistributedInference(AutoPipeline): """This pipeline is used for distributed inference. 
- It runs the following passes: + It runs the following pre-processing passes: + + - replace_method_with_function + + Then, it raises the graph to Mase IR: - init_metadata_analysis_pass @@ -14,9 +24,16 @@ class AutoPipelineForDistributedInference(AutoPipeline): - add_common_metadata_analysis_pass + Then, it runs the following passes: + - autosharding_analysis_pass + If the distributed setup is initialized, it runs the following passes: + + - insert_dtensor_wrapper_transform_pass + - resharding_transform_pass + """ def __init__(self) -> None: @@ -39,9 +56,14 @@ def __init__(self) -> None: ] # Only run the following in distributed setup - pass_list += [ - passes.insert_dtensor_wrapper_transform_pass, - passes.resharding_transform_pass, - ] + if dist.is_initialized(): + pass_list += [ + passes.insert_dtensor_wrapper_transform_pass, + passes.resharding_transform_pass, + ] + else: + logger.info( + "Torch distributed is not initialized, so will skip the following passes: insert_dtensor_wrapper_transform_pass, resharding_transform_pass" + ) super().__init__(pass_list) From c3a43a6400286b67aa1e22914783d23e538ae254 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 15:37:17 +0000 Subject: [PATCH 74/93] support sdpa strategy --- .../autosharding/alpa_intra_operator.py | 19 ++-- .../analysis/autosharding/autosharding.py | 2 + .../graph/analysis/autosharding/layers.py | 3 + .../autosharding/strategies/common.py | 90 +++++++++++++++++++ .../autosharding/strategies/embedding_ops.py | 90 +------------------ .../autosharding/strategies/matrix_ops.py | 53 ++++++++--- .../autosharding/strategies/view_ops.py | 2 +- 7 files changed, 154 insertions(+), 105 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index 129537cdd..64b1e2f44 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ 
b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -308,14 +308,17 @@ def _extract_ilp(mg, mesh, pass_args={}): "call_method", ]: cost_vector = [] - try: - for strategy in op_strategy.strategies: + for strategy in op_strategy.strategies: + try: cost = _get_computation_cost_from_strategy(node, strategy, mesh) - cost_vector.append(cost) + except Exception as e: + logger.warning( + f"Failed to compute computation cost for node {node} strategy: {strategy} due to exception: {e}" + ) + cost = 100000.0 + cost_vector.append(cost) - expr += np.array(cost_vector) @ opt_var - except: - logger.error(f"Failed to compute computation cost for node {node}") + expr += np.array(cost_vector) @ opt_var # Consider resharding cost for each of the node's arguments e_var_checks = [] @@ -356,6 +359,10 @@ def _extract_ilp(mg, mesh, pass_args={}): for dest_idx, dest_spec in enumerate(node_in_specs): for src_idx, src_spec in enumerate(arg_out_specs): + + if isinstance(src_spec, tuple): + src_spec = src_spec[0] + cost = redistribute_cost(src_spec, dest_spec) resharding_costs[dest_idx, src_idx] = ( 1000000 if cost == float("inf") else cost diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index ad2ef006d..9701e2d84 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -136,6 +136,8 @@ def _export_solution(mg, export_file: str = "ilp_solution.pkl"): ) else: spec = result_info["dtensor_spec"] + if isinstance(spec, tuple): + spec = spec[0] out_dict[node_name]["results"][result] = spec.placements with open(export_file, "wb") as file: diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py b/src/chop/passes/graph/analysis/autosharding/layers.py index 7395d0608..79f59d23e 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ 
-22,6 +22,7 @@ addmm_strategy, bmm_strategy, baddmm_strategy, + scaled_dot_product_strategy, ) from .strategies.view_ops import get_reshape_strategy from .strategies.pointwise_ops import pointwise_strategy, linear_pointwise_strategy @@ -257,6 +258,8 @@ torch_permute: get_reshape_strategy(torch.Tensor.permute), torch_transpose: transpose_strategy, torch.unsqueeze: get_reshape_strategy(torch.unsqueeze), + # SDPA + F.scaled_dot_product_attention: scaled_dot_product_strategy, } FULLY_REPLICATED_FUNCS = [ diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 24628deb4..cda4d7e40 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -1,18 +1,27 @@ +from typing import List import itertools import numpy as np import torch import torch.nn.functional as F + +from torch.distributed.device_mesh import DeviceMesh from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed._tensor.placement_types import ( + Placement, Replicate, Shard, DTensorSpec, TensorMeta, ) +from torch.distributed._tensor.ops.utils import ( + is_tensor_shardable, + generate_redistribute_costs, +) from chop.tools import get_logger + logger = get_logger(__name__) @@ -160,3 +169,84 @@ def fully_replicated_strategy(meta, mesh): ) ] ) + + +def expand_to_full_mesh_op_strategy( + meta, + mesh: DeviceMesh, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
+ all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + try: + spec_list.append( + DTensorSpec( + mesh, + tuple(specs), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"][ + "torch_dtype" + ], + ), + ) + ) + except: + breakpoint() + + input_specs = spec_list[input_index:] + # input_args_strategy = op_schema.args_strategy + input_args_strategy = tuple( + arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + for arg in meta.node.args + ) + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # extend input_specs to include fully replicated sharding for constant nodes + extended_input_specs = input_specs + [ + DTensorSpec( + mesh, + (Replicate(), Replicate()), + # todo: may need to set tensor meta + tensor_meta=None, + ) + ] * (len(meta["common"]["args"].keys()) - len(input_specs)) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=( + tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] + ), + input_specs=extended_input_specs, + 
redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py index e8dc24012..ab1bf637e 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py @@ -7,26 +7,16 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - StrategyType, - DTensorSpec, - PlacementStrategy, -) -from torch.distributed._tensor.ops.utils import ( - is_tensor_shardable, - generate_redistribute_costs, -) +from torch.distributed._tensor._op_schema import StrategyType from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, - TensorMeta, ) from torch.distributed.device_mesh import DeviceMesh +from .common import expand_to_full_mesh_op_strategy aten = torch.ops.aten @@ -169,82 +159,6 @@ def __str__(self) -> str: return "MaskP" -def expand_to_full_mesh_op_strategy( - meta, - mesh: DeviceMesh, - single_mesh_dim_strategies: List[List[Placement]], - *, - input_index: int = 1, - inplace_op: bool = False, -) -> OpStrategy: - # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
- all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append( - DTensorSpec( - mesh, - tuple(specs), - tensor_meta=TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], - stride=None, - dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], - ), - ) - ) - - input_specs = spec_list[input_index:] - # input_args_strategy = op_schema.args_strategy - input_args_strategy = tuple( - arg.meta["mase"]["software"]["autosharding"]["op_strategy"] - for arg in meta.node.args - ) - assert len(input_specs) == len(input_args_strategy) - self_spec = input_args_strategy[0].strategies[0].output_spec - if inplace_op and self_spec.placements != input_specs[0].placements: - # if it's inplace op, we would only allow the placement strategy to be added when the - # input_spec matches the first argument's runtime sharding, otherwise we skip - continue - - # check inputs shardable - inputs_shardable = all( - is_tensor_shardable(inp.shape, s) - for inp, s in zip(input_args_strategy, input_specs) - ) - - # extend input_specs to include fully replicated sharding for constant nodes - extended_input_specs = input_specs + [ - DTensorSpec( - mesh, - (Replicate(), Replicate()), - # todo: may need to set tensor meta - tensor_meta=None, - ) - ] * (len(meta["common"]["args"].keys()) - len(input_specs)) - - # only add to the all_strategies list when all inputs are shardable - if inputs_shardable: - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - for input_strategy, input_spec in zip(input_args_strategy, input_specs) - ] - strategy = PlacementStrategy( - output_specs=( - tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] - ), - input_specs=extended_input_specs, - redistribute_cost=redistribute_cost, - ) - 
all_strategies.append(strategy) - - return OpStrategy(all_strategies) - - def embedding_strategy(meta, mesh) -> StrategyType: """ This strategy handles embedding op. We have two possible embedding shardings: diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index ede2167ff..20a2c74c6 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -19,9 +19,18 @@ TensorMeta, ) +from chop.ir.graph import MaseMetadata +from .common import expand_to_full_mesh_op_strategy from ..utils import is_tensor_shardable -from chop.ir.graph import MaseMetadata + +def _other(meta, dim): + if dim == meta["common"]["args"]["dim0"]["value"]: + return meta["common"]["args"]["dim1"]["value"] + elif dim == meta["common"]["args"]["dim1"]["value"]: + return meta["common"]["args"]["dim0"]["value"] + else: + raise ValueError(f"Invalid dim: {dim}") def transpose_strategy( @@ -35,20 +44,27 @@ def transpose_strategy( assert isinstance(self_strategy, OpStrategy) fully_replicated_spec = DTensorSpec( - mesh=mesh, placements=[Replicate(), Replicate()], tensor_meta=None + mesh=mesh, + placements=[Replicate(), Replicate()], + tensor_meta=None, ) transpose_strategies = [] for input_strategy in self_strategy.strategies: - input_spec = input_strategy.output_spec + + if isinstance(input_strategy.output_specs, tuple): + input_spec = input_strategy.output_specs[0] + else: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements output_placements = [ - Shard(1 - p.dim) if isinstance(p, Shard) else p + Shard(_other(meta, p.dim)) if isinstance(p, Shard) else p for p in input_spec.placements ] transpose_strategy = PlacementStrategy( output_specs=DTensorSpec( - mesh=input_strategy.output_spec.mesh, + mesh=input_spec.mesh, placements=tuple(output_placements), 
tensor_meta=TensorMeta( shape=meta["common"]["results"]["data_out_0"]["shape"], @@ -57,7 +73,7 @@ def transpose_strategy( ), ), # include 2 fully replicated inputs for dim_0 and dim_1 arguments - input_specs=(input_strategy.output_spec,) + (fully_replicated_spec,) * 2, + input_specs=(input_spec,) + (fully_replicated_spec,) * 2, ) transpose_strategies.append(transpose_strategy) @@ -210,11 +226,17 @@ def scaled_dot_product_flash_attention_strategy( # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation # as it involves: matmul, pointwise, reduction ops together. - return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] - q_input_strategy = op_schema.args_schema[0] + arg_names = list(meta["common"]["args"].keys()) + arg_infos = list(meta["common"]["args"].values()) + return_debug_mask = len(arg_names) >= 6 and arg_infos[5]["value"] + + # q_input_strategy = op_schema.args_schema[0] + q_parent_node = meta.node.args[0] + q_input_strategy = q_parent_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape single_mesh_dim_strategies = [] @@ -282,5 +304,16 @@ def scaled_dot_product_flash_attention_strategy( ] ) return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=9 + meta, + mesh, + single_mesh_dim_strategies, + input_index=9, ) + + +def scaled_dot_product_strategy( + meta: MaseMetadata, + mesh: tuple, +): + # todo: support efficient attention backend + return scaled_dot_product_flash_attention_strategy(meta, mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 525a49897..4902e83ca 100644 --- 
a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -609,7 +609,7 @@ def reshape_strategy(meta, mesh): ) replicate_spec = DTensorSpec( - placements=tuple(input_tgt_placements), + placements=(Replicate(), Replicate()), mesh=input_src_spec.mesh, # todo: may need to set tensor meta tensor_meta=None, From 98ed095d6c2b17c015c35fa1391c31a23deb1d5f Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 15:44:44 +0000 Subject: [PATCH 75/93] directory refactoring --- .../analysis/autosharding/{ => algos}/alpa.py | 0 .../{ => algos}/alpa_intra_operator.py | 7 +++---- .../autosharding/{ => algos}/fully_replicated.py | 2 +- .../analysis/autosharding/{ => algos}/megatron.py | 2 +- .../graph/analysis/autosharding/autosharding.py | 6 +++--- .../passes/graph/analysis/autosharding/layers.py | 14 +++++++------- .../{strategies => ops}/basic_strategy.py | 0 .../autosharding/{strategies => ops}/common.py | 0 .../{strategies => ops}/embedding_ops.py | 1 - .../autosharding/{strategies => ops}/math_ops.py | 0 .../autosharding/{strategies => ops}/matrix_ops.py | 0 .../{strategies => ops}/pointwise_ops.py | 5 +---- .../autosharding/{strategies => ops}/tensor_ops.py | 0 .../autosharding/{strategies => ops}/view_ops.py | 0 .../passes/graph/analysis/autosharding/utils.py | 2 +- 15 files changed, 17 insertions(+), 22 deletions(-) rename src/chop/passes/graph/analysis/autosharding/{ => algos}/alpa.py (100%) rename src/chop/passes/graph/analysis/autosharding/{ => algos}/alpa_intra_operator.py (99%) rename src/chop/passes/graph/analysis/autosharding/{ => algos}/fully_replicated.py (95%) rename src/chop/passes/graph/analysis/autosharding/{ => algos}/megatron.py (85%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/basic_strategy.py (100%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/common.py (100%) rename 
src/chop/passes/graph/analysis/autosharding/{strategies => ops}/embedding_ops.py (99%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/math_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/matrix_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/pointwise_ops.py (98%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/tensor_ops.py (100%) rename src/chop/passes/graph/analysis/autosharding/{strategies => ops}/view_ops.py (100%) diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/alpa.py rename to src/chop/passes/graph/analysis/autosharding/algos/alpa.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py similarity index 99% rename from src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py rename to src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py index 64b1e2f44..253ba2d66 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -14,15 +14,14 @@ from torch.distributed._tensor.placement_types import Shard, Replicate from chop.tools import get_logger -from chop.tools.utils import deepgetattr -from .mesh_model import MeshModel +from ..mesh_model import MeshModel -from .layers import ( +from ..layers import ( AUTOSHARDING_FUNCTIONS, IMPLICIT_FUNCS, FULLY_REPLICATED_FUNCS, ) -from .strategies.common import ( +from ..ops.common import ( fully_replicated_strategy, placeholder_or_getattr_strategy, ) diff --git a/src/chop/passes/graph/analysis/autosharding/fully_replicated.py b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py similarity index 95% rename 
from src/chop/passes/graph/analysis/autosharding/fully_replicated.py rename to src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py index 8911c3720..75035a39e 100644 --- a/src/chop/passes/graph/analysis/autosharding/fully_replicated.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py @@ -2,7 +2,7 @@ from torch.distributed._tensor.placement_types import Replicate from chop.ir import MaseGraph -from .mesh_model import MeshModel +from ..mesh_model import MeshModel def fully_replicated_autosharding_pass( diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/algos/megatron.py similarity index 85% rename from src/chop/passes/graph/analysis/autosharding/megatron.py rename to src/chop/passes/graph/analysis/autosharding/algos/megatron.py index 33372e101..cc13679f6 100644 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/megatron.py @@ -1,5 +1,5 @@ from chop.ir import MaseGraph -from .mesh_model import MeshModel +from ..mesh_model import MeshModel def megatron_autosharding_pass( diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 9701e2d84..38e293fca 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -8,9 +8,9 @@ from chop.tools import get_logger from .mesh_model import MeshModel -from .alpa import alpa_autosharding_pass -from .megatron import megatron_autosharding_pass -from .fully_replicated import fully_replicated_autosharding_pass +from .algos.alpa import alpa_autosharding_pass +from .algos.megatron import megatron_autosharding_pass +from .algos.fully_replicated import fully_replicated_autosharding_pass logger = get_logger(__name__) logger.setLevel("INFO") diff --git a/src/chop/passes/graph/analysis/autosharding/layers.py 
b/src/chop/passes/graph/analysis/autosharding/layers.py index 79f59d23e..8576f5ebd 100644 --- a/src/chop/passes/graph/analysis/autosharding/layers.py +++ b/src/chop/passes/graph/analysis/autosharding/layers.py @@ -15,8 +15,8 @@ torch_transpose, ) -from .strategies.common import fully_replicated_strategy -from .strategies.matrix_ops import ( +from .ops.common import fully_replicated_strategy +from .ops.matrix_ops import ( transpose_strategy, mm_strategy, addmm_strategy, @@ -24,11 +24,11 @@ baddmm_strategy, scaled_dot_product_strategy, ) -from .strategies.view_ops import get_reshape_strategy -from .strategies.pointwise_ops import pointwise_strategy, linear_pointwise_strategy -from .strategies.math_ops import softmax_strategy, layer_norm_strategy -from .strategies.embedding_ops import embedding_strategy -from .strategies.tensor_ops import tensor_op_strategy, tensor_equal_strategy +from .ops.view_ops import get_reshape_strategy +from .ops.pointwise_ops import pointwise_strategy, linear_pointwise_strategy +from .ops.math_ops import softmax_strategy, layer_norm_strategy +from .ops.embedding_ops import embedding_strategy +from .ops.tensor_ops import tensor_op_strategy, tensor_equal_strategy logger = get_logger(__name__) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py rename to src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/ops/common.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/common.py rename to src/chop/passes/graph/analysis/autosharding/ops/common.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py 
b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py similarity index 99% rename from src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py index ab1bf637e..941af13b2 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field from typing import cast, List, Optional -import itertools import torch import torch.distributed._functional_collectives as funcol diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/math_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py similarity index 98% rename from src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index a198b3e93..fe79c521a 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -1,17 +1,14 @@ # Adapted from Pytorch Distributed DTensor API. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/pointwise_ops.py -from typing import List, Sequence, Tuple +from typing import List import torch from torch.distributed._tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, OpStrategy, PlacementStrategy, ) from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py similarity index 100% rename from src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py rename to src/chop/passes/graph/analysis/autosharding/ops/view_ops.py diff --git a/src/chop/passes/graph/analysis/autosharding/utils.py b/src/chop/passes/graph/analysis/autosharding/utils.py index 8a08be311..54ff94bea 100644 --- a/src/chop/passes/graph/analysis/autosharding/utils.py +++ b/src/chop/passes/graph/analysis/autosharding/utils.py @@ -1,4 +1,4 @@ -from typing import cast, Iterable, List, Sequence, Tuple, Union +from typing import Sequence, cast from torch.distributed._tensor.placement_types import DTensorSpec, Shard From 935e876a8f94f5834a9a174a0aaff17f5439eca4 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 15:55:31 +0000 Subject: [PATCH 76/93] remove breakpoint --- .../graph/analysis/autosharding/ops/common.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/common.py b/src/chop/passes/graph/analysis/autosharding/ops/common.py index 
cda4d7e40..7e06d71dc 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/common.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/common.py @@ -188,22 +188,17 @@ def expand_to_full_mesh_op_strategy( for strategy_comb in strategy_combs: spec_list = [] for specs in zip(*strategy_comb): - try: - spec_list.append( - DTensorSpec( - mesh, - tuple(specs), - tensor_meta=TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], - stride=None, - dtype=meta["common"]["results"]["data_out_0"][ - "torch_dtype" - ], - ), - ) + spec_list.append( + DTensorSpec( + mesh, + tuple(specs), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], + ), ) - except: - breakpoint() + ) input_specs = spec_list[input_index:] # input_args_strategy = op_schema.args_strategy From 082e6d0def75a10b8bdf2c833d27345003b7af20 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 15:57:19 +0000 Subject: [PATCH 77/93] remove deprecated pass --- src/chop/passes/module/transforms/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index c9a634321..3fcc8c5b3 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,2 +1 @@ from .quantize import quantize_module_transform_pass -from .autosharding import resharding_transform_pass From 6cc125625e513c226b5c8f2f3f7a073734f736ab Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 15 Aug 2024 16:21:51 +0000 Subject: [PATCH 78/93] remove deprecated tests --- .../autosharding/test_autosharding_bert.py | 73 ------------------- .../autosharding/test_autosharding_linear.py | 73 ------------------- 2 files changed, 146 deletions(-) delete mode 100644 test/passes/graph/analysis/autosharding/test_autosharding_bert.py delete mode 100644 
test/passes/graph/analysis/autosharding/test_autosharding_linear.py diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py deleted file mode 100644 index 97b4410d9..000000000 --- a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys, pdb, traceback -import pytest - -import torch -import torch.nn as nn - -from chop.ir import MaseGraph -from chop.distributed import MaseLauncher -import chop.passes as passes -from chop.tools import get_logger - -from transformers.models.bert import BertConfig, BertModel - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - -WORLD_SIZE = 8 -DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] - - -@pytest.mark.skip(reason="Fixing needed") -def test_autosharding(): - - # Define config - config = BertConfig() - config.num_hidden_layers = 3 - config.hidden_size = 96 - config.intermediate_size = 384 - config._attn_implementation = "eager" - config_sequence_length = 4 - - # Initialize model and MaseGraph - model = BertModel(config) - mg = MaseGraph(model) - mg, _ = passes.init_metadata_analysis_pass(mg) - mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "bert.txt"}) - mg, _ = passes.add_common_metadata_analysis_pass( - mg, - pass_args={ - "dummy_in": { - "input_ids": torch.randint(0, 10, (1, config_sequence_length)), - }, - "add_value": False, - }, - ) - - # Run autosharding pass to decide sharding configuration - mg, module_map = passes.autosharding_analysis_pass( - mg, - pass_args={ - "mesh_shape": (2, 4), - "inter_node_bandwidth": 10e9, - "intra_node_bandwidth": 100e9, - }, - ) - - # Insert resharding wrappers around each module to handle inter-operator communication - mg, _ = passes.resharding_transform_pass( - mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH} - ) - - # dump print model to a file - with open("model.txt", "w") as f: - print(mg.model, 
file=f) - - # Launch model in distributed cluster - launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) - inputs = [torch.randint(0, 10, (1, config_sequence_length))] - launcher.run(module_map, inputs) - - -if __name__ == "__main__": - test_autosharding() diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_linear.py b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py deleted file mode 100644 index eae847aef..000000000 --- a/test/passes/graph/analysis/autosharding/test_autosharding_linear.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys, pdb, traceback, os -import pytest - -import torch -import torch.nn as nn - -from chop.ir import MaseGraph -from chop.distributed import MaseLauncher -import chop.passes as passes -from chop.tools import get_logger - - -def excepthook(exc_type, exc_value, exc_traceback): - traceback.print_exception(exc_type, exc_value, exc_traceback) - print("\nEntering debugger...") - pdb.post_mortem(exc_traceback) - - -# Set the custom exception hook -sys.excepthook = excepthook - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - -WORLD_SIZE = 8 -DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] - - -class MLP(nn.Module): - def __init__(self, in_features=64, hidden_dimension=128, out_features=64): - super().__init__() - self.l1 = nn.Linear(in_features, hidden_dimension) - self.l2 = nn.Linear(hidden_dimension, out_features) - - def forward(self, x): - out = self.l1(x) - return self.l2(out) - - -@pytest.mark.skip(reason="Fixing needed") -def test_autosharding(): - - # Initialize model and MaseGraph - model = MLP() - mg = MaseGraph(model) - mg, _ = passes.init_metadata_analysis_pass(mg) - mg, _ = passes.add_common_metadata_analysis_pass( - mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} - ) - - # Run autosharding pass to decide sharding configuration - mg, module_map = passes.autosharding_analysis_pass( - mg, - pass_args={ - "mesh_shape": (2, 4), - 
"inter_node_bandwidth": 10e9, - "intra_node_bandwidth": 100e9, - }, - ) - - # Insert resharding wrappers around each module to handle inter-operator communication - mg, _ = passes.resharding_transform_pass( - mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH} - ) - - # Launch model in distributed cluster - launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) - inputs = [torch.randn((16, 64))] - launcher.run(module_map, inputs) - - -if __name__ == "__main__": - test_autosharding() From 0de6731c2c40795117fcdaf18d461efa89276ef1 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 16 Aug 2024 14:26:33 +0000 Subject: [PATCH 79/93] set benchmarking device for compute cost estimation in intra operator pass --- .../autosharding/algos/alpa_intra_operator.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py index 253ba2d66..2a6bd5aea 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -37,7 +37,7 @@ def _get_computation_cost_from_strategy( mesh: MeshModel, repeat: int = 100, warmup_iters: int = 2, - profiling_device: int = 0, + profiling_device: int = None, ): """ ... 
@@ -119,7 +119,7 @@ def _get_computation_cost_from_strategy( start_event[idx].record() _ = fn(*args) end_event[idx].record() - torch.cuda.synchronize() + torch.cuda.synchronize(device=f"cuda:{profiling_device}") elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] @@ -309,7 +309,12 @@ def _extract_ilp(mg, mesh, pass_args={}): cost_vector = [] for strategy in op_strategy.strategies: try: - cost = _get_computation_cost_from_strategy(node, strategy, mesh) + cost = _get_computation_cost_from_strategy( + node, + strategy, + mesh, + profiling_device=pass_args.get("benchmarking_device", None), + ) except Exception as e: logger.warning( f"Failed to compute computation cost for node {node} strategy: {strategy} due to exception: {e}" @@ -522,6 +527,13 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): # Formulate and solve the ILP logger.info(f"Formulating the ILP...") + + # Set CUDA device for profiling + device_id = pass_args.get("benchmarking_device", None) + torch.cuda.set_device(device_id) + + logger.info(f"Setting CUDA device to: {device_id}") + mg, problem = _extract_ilp(mg, mesh, pass_args) logger.info(f"Solving the ILP...") From f3230665577bf5b6784ddd367fb59026503a4c4c Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 16 Aug 2024 14:31:23 +0000 Subject: [PATCH 80/93] remove MLIR CI --- .github/workflows/testTorchMLIR.yml | 40 ----------------------------- 1 file changed, 40 deletions(-) delete mode 100644 .github/workflows/testTorchMLIR.yml diff --git a/.github/workflows/testTorchMLIR.yml b/.github/workflows/testTorchMLIR.yml deleted file mode 100644 index 9f5ac29bb..000000000 --- a/.github/workflows/testTorchMLIR.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Torch-MLIR Test - -on: - push: - branches: [ "main*" ] - pull_request: - branches: [ "main*" ] - workflow_dispatch: - logLevel: - description: 'Log level' - required: true - default: 'warning' - type: choice - options: - - info - - warning - - debug - -jobs: 
- - torch-mlir-test: - runs-on: ubuntu-latest - container: - image: deepwok/mase-docker-cpu:latest - steps: - - # Clone the MASE repo and its submodules. - - name: Get MASE - uses: actions/checkout@v3 - with: - submodules: "true" - - - name: Set git safe - run: | - git config --global --add safe.directory $PWD - - - name: Torch-MLIR regression test - run: | - python3 scripts/test-torch-mlir.py - From 7ab2d944d035284af468b536b6fc40f66d3e2d56 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 19 Aug 2024 09:52:32 +0000 Subject: [PATCH 81/93] update torch.distributed.tensor imports since _tensor has been added to public torch namespace --- src/chop/distributed/debug.py | 6 ++--- src/chop/distributed/tensor/__init__.py | 6 ++--- src/chop/distributed/tensor/_dispatch.py | 6 ++--- src/chop/distributed/tensor/_redistribute.py | 4 ++-- src/chop/distributed/tensor/_sharding_prop.py | 6 ++--- src/chop/distributed/tensor/_utils.py | 4 ++-- src/chop/distributed/tensor/api.py | 12 +++++----- .../distributed/tensor/ops/basic_strategy.py | 4 ++-- .../distributed/tensor/ops/common_rules.py | 8 +++---- src/chop/distributed/tensor/ops/conv_ops.py | 4 ++-- .../distributed/tensor/ops/embedding_ops.py | 6 ++--- .../tensor/ops/experimental_ops.py | 6 ++--- src/chop/distributed/tensor/ops/math_ops.py | 9 +++++--- src/chop/distributed/tensor/ops/matrix_ops.py | 6 ++--- .../distributed/tensor/ops/pointwise_ops.py | 6 ++--- src/chop/distributed/tensor/ops/random_ops.py | 4 ++-- src/chop/distributed/tensor/ops/tensor_ops.py | 4 ++-- src/chop/distributed/tensor/ops/utils.py | 8 +++---- src/chop/distributed/tensor/ops/view_ops.py | 6 ++--- src/chop/distributed/utils.py | 22 ++++++++++++++----- src/chop/nn/functional/dtensor.py | 6 ++--- .../add_metadata/add_common_metadata.py | 10 +++++++-- .../autosharding/algos/alpa_intra_operator.py | 6 ++--- .../autosharding/algos/fully_replicated.py | 4 ++-- .../autosharding/alpa_intra_operator.py | 4 ++-- .../analysis/autosharding/autosharding.py | 
4 ++-- .../autosharding/ops/basic_strategy.py | 4 ++-- .../graph/analysis/autosharding/ops/common.py | 6 ++--- .../autosharding/ops/embedding_ops.py | 4 ++-- .../analysis/autosharding/ops/math_ops.py | 9 ++++---- .../analysis/autosharding/ops/matrix_ops.py | 8 +++---- .../autosharding/ops/pointwise_ops.py | 6 ++--- .../analysis/autosharding/ops/tensor_ops.py | 6 ++--- .../analysis/autosharding/ops/view_ops.py | 8 +++---- .../autosharding/strategies/basic_strategy.py | 4 ++-- .../autosharding/strategies/common.py | 4 ++-- .../autosharding/strategies/embedding_ops.py | 6 ++--- .../autosharding/strategies/math_ops.py | 9 ++++---- .../autosharding/strategies/matrix_ops.py | 8 +++---- .../autosharding/strategies/pointwise_ops.py | 6 ++--- .../autosharding/strategies/tensor_ops.py | 6 ++--- .../autosharding/strategies/view_ops.py | 8 +++---- .../graph/analysis/autosharding/utils.py | 2 +- .../transforms/insert_dtensor_wrapper.py | 2 +- .../passes/graph/transforms/resharding.py | 4 ++-- src/chop/pipelines/distributed_inference.py | 1 + 46 files changed, 152 insertions(+), 130 deletions(-) diff --git a/src/chop/distributed/debug.py b/src/chop/distributed/debug.py index 76cd8f3e9..a65d72b3d 100644 --- a/src/chop/distributed/debug.py +++ b/src/chop/distributed/debug.py @@ -4,9 +4,9 @@ import numpy as np from torch._prims_common import ShapeType -from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor import DeviceMesh -from torch.distributed._tensor.placement_types import Placement, Shard +from torch.distributed.tensor.placement_types import Placement, Shard def _mesh_to_coordinate(mesh, device_type): @@ -90,7 +90,7 @@ def compute_local_shape_and_global_offset( my_coordinate: List[int], ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """ - Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but + Same as torch.distributed.tensor._utils.compute_local_shape_and_global_offset but with custom my_coordinate input. 
This is the modified implementation for visualize_sharding. """ diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index eecce89dc..d4415d599 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -4,9 +4,9 @@ # Import all builtin dist tensor ops import torch -import torch.distributed._tensor.random as random +import torch.distributed.tensor.random as random -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, @@ -57,7 +57,7 @@ def _dtensor_init_helper( placements=None, **kwargs, ) -> DTensor: - from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta # if device_mesh is None, use the one from mesh resources device_mesh = device_mesh or _mesh_resources.get_current_mesh() diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 3f9d38043..0e0a661b8 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -2,14 +2,14 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OutputSpecType, ) -from torch.distributed._tensor._tp_conv import ( +from torch.distributed.tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Replicate, TensorMeta, diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py index fc66b219b..6f99b558c 100644 --- a/src/chop/distributed/tensor/_redistribute.py +++ b/src/chop/distributed/tensor/_redistribute.py @@ -5,8 +5,8 @@ import torch import 
torch.distributed._functional_collectives as funcol -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 81224a952..0812f3c0d 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -6,7 +6,7 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpInfo, OpSchema, OpStrategy, @@ -17,12 +17,12 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor._utils import ( +from torch.distributed.tensor._utils import ( compute_local_shape, compute_local_stride, try_find_mesh_from_args, ) -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh diff --git a/src/chop/distributed/tensor/_utils.py b/src/chop/distributed/tensor/_utils.py index a3cc8ee5a..fa278c027 100644 --- a/src/chop/distributed/tensor/_utils.py +++ b/src/chop/distributed/tensor/_utils.py @@ -1,9 +1,9 @@ from typing import cast, List, Sequence, Tuple import torch -import torch.distributed._tensor.api as dtensor +import torch.distributed.tensor.api as dtensor from torch._prims_common import ShapeType -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index e9cabae4e..3dba5fc02 100644 --- a/src/chop/distributed/tensor/api.py +++ 
b/src/chop/distributed/tensor/api.py @@ -5,18 +5,18 @@ from typing import Any, Callable, cast, Optional, Sequence, Tuple import torch -import torch.distributed._tensor.random as random +import torch.distributed.tensor.random as random import torch.nn as nn -from torch.distributed._tensor._collective_utils import ( +from torch.distributed.tensor._collective_utils import ( check_tensor_meta, mesh_broadcast, ) -from torch.distributed._tensor._redistribute import ( +from torch.distributed.tensor._redistribute import ( Redistribute, redistribute_local_tensor, ) -from torch.distributed._tensor._utils import compute_global_tensor_info -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._utils import compute_global_tensor_info +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, @@ -24,7 +24,7 @@ Shard, TensorMeta, ) -from torch.distributed._tensor.random import ( +from torch.distributed.tensor.random import ( is_rng_supported_mesh, OffsetBasedRNGTracker, ) diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py index 97dd43b15..37b1c29ff 100644 --- a/src/chop/distributed/tensor/ops/basic_strategy.py +++ b/src/chop/distributed/tensor/ops/basic_strategy.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py index f70b27076..082755077 100644 --- a/src/chop/distributed/tensor/ops/common_rules.py +++ b/src/chop/distributed/tensor/ops/common_rules.py @@ -2,15 +2,15 @@ from typing import cast, Dict, List, 
Optional, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, OutputSharding, ) -from torch.distributed._tensor._utils import compute_local_shape -from torch.distributed._tensor.ops.utils import prod -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._utils import compute_local_shape +from torch.distributed.tensor._ops.utils import prod +from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py index 7bf13241d..dfcfd4a88 100644 --- a/src/chop/distributed/tensor/ops/conv_ops.py +++ b/src/chop/distributed/tensor/ops/conv_ops.py @@ -3,9 +3,9 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding from chop.distributed.tensor.ops.utils import register_prop_rule -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta aten = torch.ops.aten diff --git a/src/chop/distributed/tensor/ops/embedding_ops.py b/src/chop/distributed/tensor/ops/embedding_ops.py index d89ec651b..11ec896e8 100644 --- a/src/chop/distributed/tensor/ops/embedding_ops.py +++ b/src/chop/distributed/tensor/ops/embedding_ops.py @@ -6,14 +6,14 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, StrategyType, ) -from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy -from torch.distributed._tensor.placement_types import ( +from 
torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/distributed/tensor/ops/experimental_ops.py b/src/chop/distributed/tensor/ops/experimental_ops.py index 432fbede8..03f94e2cd 100644 --- a/src/chop/distributed/tensor/ops/experimental_ops.py +++ b/src/chop/distributed/tensor/ops/experimental_ops.py @@ -2,15 +2,15 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.device_mesh import DeviceMesh from chop.distributed.tensor.ops.utils import register_op_strategy -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate +from torch.distributed.tensor.placement_types import DTensorSpec, Replicate aten = torch.ops.aten diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index c6ebe49be..8c19aefcb 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -6,7 +6,7 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, @@ -14,16 +14,19 @@ RuntimeSchemaInfo, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( as_list, expand_to_full_mesh_op_strategy, generate_redistribute_costs, is_tensor_evenly_shardable, normalize_dim, normalize_dims, +) + +from torch.distributed.utils import ( normalize_to_torch_size, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, 
Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py index 77484de7d..35d96c814 100644 --- a/src/chop/distributed/tensor/ops/matrix_ops.py +++ b/src/chop/distributed/tensor/ops/matrix_ops.py @@ -2,20 +2,20 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, infer_broadcast_dims_map, is_tensor_shardable, map_placements_after_broadcast, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py index c3e1f082f..2a9986ed0 100644 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -2,7 +2,7 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, @@ -12,13 +12,13 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/random_ops.py b/src/chop/distributed/tensor/ops/random_ops.py index 7eefa30fc..4b9a3303d 100644 --- a/src/chop/distributed/tensor/ops/random_ops.py +++ 
b/src/chop/distributed/tensor/ops/random_ops.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.ops.utils import is_tensor_partial +from torch.distributed.tensor._ops.utils import is_tensor_partial from chop.distributed.tensor.ops.utils import register_op_strategy from torch.distributed.device_mesh import DeviceMesh diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py index 8e64ff514..51ae573ab 100644 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ b/src/chop/distributed/tensor/ops/tensor_ops.py @@ -3,7 +3,7 @@ from typing import cast, List, Optional, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -14,7 +14,7 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py index 27cf89224..236e45fc9 100644 --- a/src/chop/distributed/tensor/ops/utils.py +++ b/src/chop/distributed/tensor/ops/utils.py @@ -6,16 +6,16 @@ from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( +from 
torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py index dc8103b08..fed462a6a 100644 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -16,15 +16,15 @@ import torch from torch import Tensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, RuntimeSchemaInfo, StrategyType, ) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed.tensor.api import Shard +from torch.distributed.tensor.placement_types import DTensorSpec, Placement, Replicate from torch.distributed.device_mesh import DeviceMesh from chop.distributed.tensor.ops.utils import register_op_strategy diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index c95b6b405..db751203c 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.distributed as dist -from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor import DeviceMesh from chop.tools import get_logger from chop.distributed.tensor import distribute_tensor @@ -47,10 +47,20 @@ def distributed_average_timing( "info", ) dist.barrier(async_op=True) - start = time() - result = fn(*args) - dist.barrier(async_op=True) - end = time() + + if isinstance(args, list): + start = time() + result = fn(*args) + dist.barrier(async_op=True) + end = time() + elif isinstance(args, dict): + start = time() + result = fn(**args) + dist.barrier(async_op=True) + end = time() + else: + raise ValueError("args must be a list or a dict") + times.append(end - start) rlog( logger, @@ -70,7 +80,7 @@ def 
dist_model_fn( tensor_sharding_map={}, ) -> None: """ - This function gets called by torch.distributed._tensor.distribute_module on each module in the model. + This function gets called by torch.distributed.tensor.distribute_module on each module in the model. Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. Args: diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index c5e29b73d..49a974af3 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -3,10 +3,10 @@ import torch import torch.fx as fx from torch.distributed.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed._tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._redistribute import redistribute_local_tensor -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import Placement from chop.ir.graph import MaseMetadata from chop.distributed.tensor import DTensor diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index fb7d43d3a..ff2639fe9 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -204,7 +204,10 @@ def graph_iterator_for_mase_ops(graph): def graph_iterator_for_metadata( - graph, dummy_in=None, add_value=True, force_device_meta=False + graph, + dummy_in=None, + add_value=True, + force_device_meta=False, ): """ largely apated from https://pytorch.org/docs/stable/fx.html @@ -229,7 +232,10 @@ def graph_iterator_for_metadata( elif node.op == "call_function": args = load_arg(node.args, env) kwargs = load_arg(node.kwargs, env) - result = 
node.target(*args, **kwargs) + try: + result = node.target(*args, **kwargs) + except: + breakpoint() analyse_fn = analyse_common_parameters_function elif node.op == "call_method": self_obj, *args = load_arg(node.args, env) diff --git a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py index 2a6bd5aea..71fc7ea6c 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -5,13 +5,13 @@ import torch import torch.fx as fx -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._op_schema import ( DTensorSpec, OpStrategy, PlacementStrategy, ) -from torch.distributed._tensor.placement_types import Shard, Replicate +from torch.distributed.tensor.placement_types import Shard, Replicate from chop.tools import get_logger from ..mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py index 75035a39e..a71cd6973 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py @@ -1,5 +1,5 @@ -from torch.distributed._tensor._op_schema import DTensorSpec -from torch.distributed._tensor.placement_types import Replicate +from torch.distributed.tensor._op_schema import DTensorSpec +from torch.distributed.tensor.placement_types import Replicate from chop.ir import MaseGraph from ..mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index b431095d6..ca33d31a8 100644 --- 
a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -1,7 +1,7 @@ import torch import torch.fx as fx -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._op_schema import DTensorSpec import numpy as np import cvxpy as cp diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 2515828e9..80baba829 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -2,8 +2,8 @@ from time import time import dill -from torch.distributed._tensor._op_schema import DTensorSpec -from torch.distributed._tensor.placement_types import Replicate +from torch.distributed.tensor._op_schema import DTensorSpec +from torch.distributed.tensor.placement_types import Replicate from chop.tools import get_logger from .mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py index b9541013a..94708b89b 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( DTensorSpec, _Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/common.py 
b/src/chop/passes/graph/analysis/autosharding/ops/common.py index 7e06d71dc..56a19598f 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/common.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/common.py @@ -6,15 +6,15 @@ import torch.nn.functional as F from torch.distributed.device_mesh import DeviceMesh -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( Placement, Replicate, Shard, DTensorSpec, TensorMeta, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( is_tensor_shardable, generate_redistribute_costs, ) diff --git a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py index 941af13b2..bf7fbe06c 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import StrategyType -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import StrategyType +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py index 984295c58..6b722763e 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -4,16 +4,17 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, 
PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, +from torch.distributed.tensor._ops.utils import ( normalize_dim, +) +from torch.distributed.tensor._utils import ( normalize_to_torch_size, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index 20a2c74c6..8948a1c98 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -2,18 +2,18 @@ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, PlacementList, ) -from torch.distributed._tensor.placement_types import Replicate, Shard, Placement +from torch.distributed.tensor.placement_types import Replicate, Shard, Placement from .basic_strategy import gen_einsum_strategies -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Shard, TensorMeta, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index fe79c521a..3cc958ba3 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -4,16 +4,16 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, 
PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py index 0a9daee2a..90869c4d6 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py @@ -1,15 +1,15 @@ # Adapted from Pytorch Distributed DTensor API. # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( is_tensor_partial, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 4902e83ca..48bb10c96 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -17,18 +17,18 @@ import torch from torch import Tensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor.api import Shard +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, ) -from 
torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py index b9541013a..94708b89b 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( DTensorSpec, _Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 0959a6bf3..02ad99c96 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -2,8 +2,8 @@ import torch import torch.nn.functional as F -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( Replicate, Shard, DTensorSpec, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py index feef6fd12..2e3872054 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py @@ -7,18 +7,18 @@ import torch import 
torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, StrategyType, DTensorSpec, PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( is_tensor_shardable, generate_redistribute_costs, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py index 4cdc2f0fe..0c6db820b 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py @@ -4,16 +4,17 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, +from torch.distributed.tensor._ops.utils import ( normalize_dim, +) +from torch.distributed.tensor._utils import ( normalize_to_torch_size, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index d552f8171..7577f573c 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -2,18 +2,18 @@ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch -from torch.distributed._tensor._op_schema import ( +from 
torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, PlacementList, ) -from torch.distributed._tensor.placement_types import Replicate, Shard, Placement +from torch.distributed.tensor.placement_types import Replicate, Shard, Placement from .basic_strategy import gen_einsum_strategies -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Shard, TensorMeta, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index fe6ac93ef..ac609e2e3 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -4,19 +4,19 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpStrategy, PlacementStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py index 0a9daee2a..90869c4d6 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py @@ -1,15 +1,15 @@ # Adapted from Pytorch Distributed DTensor API. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( is_tensor_partial, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Partial, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index 83c638bfb..f6b253ce0 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -17,18 +17,18 @@ import torch from torch import Tensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor.api import Shard +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/utils.py b/src/chop/passes/graph/analysis/autosharding/utils.py index 54ff94bea..9af911a55 100644 --- a/src/chop/passes/graph/analysis/autosharding/utils.py +++ b/src/chop/passes/graph/analysis/autosharding/utils.py @@ -1,5 +1,5 @@ from typing import Sequence, cast -from torch.distributed._tensor.placement_types import DTensorSpec, Shard +from torch.distributed.tensor.placement_types import DTensorSpec, Shard def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: diff 
--git a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py index cda5508d2..73402bc2a 100644 --- a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py +++ b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py @@ -1,5 +1,5 @@ import torch -from torch.distributed._tensor.api import DTensorSpec, TensorMeta +from torch.distributed.tensor.api import DTensorSpec, TensorMeta from torch.distributed import DeviceMesh from copy import deepcopy diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index 51cbc61d0..a5fa24e22 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -2,7 +2,7 @@ import torch import torch.fx as fx -from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed.tensor.placement_types import Replicate, Shard from chop.tools import get_logger from chop.nn.functional.dtensor import redistribute_dtensor @@ -108,7 +108,7 @@ def _insert_resharding_nodes(mg, pass_args={}): # Insert DTensor import at the top of code def insert_imports(body): return [ - "from torch.distributed._tensor.placement_types import Replicate, Shard, Partial; sum = 'sum' \n", + "from torch.distributed.tensor.placement_types import Replicate, Shard, Partial; sum = 'sum' \n", *body, ] diff --git a/src/chop/pipelines/distributed_inference.py b/src/chop/pipelines/distributed_inference.py index e67c3ef6e..cb0bc9278 100644 --- a/src/chop/pipelines/distributed_inference.py +++ b/src/chop/pipelines/distributed_inference.py @@ -47,6 +47,7 @@ def __init__(self) -> None: # Raise to Mase IR pass_list += [ passes.init_metadata_analysis_pass, + passes.report_graph_analysis_pass, passes.add_common_metadata_analysis_pass, ] From d3f43f91169e7481e6844199caf759f78ebee029 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 19 Aug 2024 13:05:27 +0000 Subject: [PATCH 82/93] 
remove breakpoint --- .../graph/analysis/add_metadata/add_common_metadata.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index ff2639fe9..df1b42f12 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -232,10 +232,7 @@ def graph_iterator_for_metadata( elif node.op == "call_function": args = load_arg(node.args, env) kwargs = load_arg(node.kwargs, env) - try: - result = node.target(*args, **kwargs) - except: - breakpoint() + result = node.target(*args, **kwargs) analyse_fn = analyse_common_parameters_function elif node.op == "call_method": self_obj, *args = load_arg(node.args, env) From 97aa774c82bf845ef07940b4b7ebdee0af41c105 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 19 Aug 2024 16:52:16 +0000 Subject: [PATCH 83/93] revert changes to support SDPA which had been incorrectly merged --- .../autosharding/algos/alpa_intra_operator.py | 4 +- .../autosharding/strategies/common.py | 177 +++++++++++++++--- .../autosharding/strategies/matrix_ops.py | 55 +++++- .../autosharding/strategies/view_ops.py | 2 +- .../passes/graph/transforms/resharding.py | 9 +- 5 files changed, 213 insertions(+), 34 deletions(-) diff --git a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py index 71fc7ea6c..0e8bed368 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -28,7 +28,7 @@ logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("INFO") def _get_computation_cost_from_strategy( @@ -540,7 +540,7 @@ def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): problem.solve( 
verbose=True, scipy_options={ - "disp": True, + "disp": pass_args.get(f"run_checks", False), "time_limit": pass_args.get("time_limit", None), "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, }, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index 02ad99c96..ef3333373 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -40,20 +40,61 @@ def find_shape_and_dtype(arg): def placeholder_or_getattr_strategy(meta, mesh, skip_fully_replicated=False): ndims = len(meta["common"]["results"]["data_out_0"]["shape"]) + tensor_shape = meta["common"]["results"]["data_out_0"]["shape"] opts = [Replicate()] + [Shard(dim) for dim in range(ndims)] tensor_meta = TensorMeta( - shape=meta["common"]["results"]["data_out_0"]["shape"], + shape=tensor_shape, stride=None, dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ) shardings = [] for sharding in itertools.product(opts, repeat=2): + # Skip fully replicated shardings since this sometimes forces the ILP + # to choose a fully replicated strategy for the entire model when + # the computation cost term is not formulated if skip_fully_replicated and sharding == (Replicate(), Replicate()): continue - spec = DTensorSpec(mesh=mesh, placements=sharding, tensor_meta=tensor_meta) - shardings.append(PlacementStrategy(input_specs=spec, output_specs=spec)) + + # Skip sharding if any dimension is sharded to 0 + skip_sharding = False + for dim in range(ndims): + # Find all device mesh dimensions along which this tensor dimension is sharded + mesh_sharded_dims = [ + idx for idx, shard in enumerate(sharding) if shard == Shard(dim) + ] + + # This tensor dimension is not sharded + if len(mesh_sharded_dims) == 0: + continue + + elif len(mesh_sharded_dims) == 1: + num_gpus = mesh.mesh_shape[mesh_sharded_dims[0]] + + else: + num_gpus = 
np.prod(mesh.mesh_shape) + + dim_size_after_sharding = tensor_shape[dim] // num_gpus + if dim_size_after_sharding == 0: + skip_sharding = True + continue + + if skip_sharding is True: + continue + + spec = DTensorSpec( + mesh=mesh, + placements=sharding, + tensor_meta=tensor_meta, + ) + shardings.append( + PlacementStrategy( + input_specs=spec, + output_specs=spec, + ) + ) + return OpStrategy(shardings) @@ -70,23 +111,29 @@ def fully_replicated_strategy(meta, mesh): in_shape = meta["common"]["self"].shape in_dtype = meta["common"]["self"].dtype else: - first_arg_key = ( - "data_in_0" - if "data_in_0" in meta["common"]["args"] - else [i for i in meta["common"]["args"].keys()][0] - ) - arg = meta["common"]["args"][first_arg_key] - in_shape, in_dtype = find_shape_and_dtype(arg) - - in_spec = DTensorSpec( - mesh, - sharding, - tensor_meta=TensorMeta( - shape=in_shape, - stride=None, - dtype=in_dtype, - ), - ) + if len(meta["common"]["args"]) > 0: + first_arg_key = ( + "data_in_0" + if "data_in_0" in meta["common"]["args"] + else [i for i in meta["common"]["args"].keys()][0] + ) + arg = meta["common"]["args"][first_arg_key] + in_shape, in_dtype = find_shape_and_dtype(arg) + + in_spec = [ + DTensorSpec( + mesh, + sharding, + tensor_meta=TensorMeta( + shape=in_shape, + stride=None, + dtype=in_dtype, + ), + ) + ] * len(meta["common"]["args"].keys()) + + else: + in_spec = [] dtype_key = ( "torch_dtype" @@ -104,6 +151,92 @@ def fully_replicated_strategy(meta, mesh): ), ) - shardings = [PlacementStrategy(input_specs=in_spec, output_specs=out_spec)] + return OpStrategy( + [ + PlacementStrategy( + input_specs=in_spec, + output_specs=out_spec, + ) + ] + ) - return OpStrategy(shardings) + +def expand_to_full_mesh_op_strategy( + meta, + mesh: DeviceMesh, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
+ all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + try: + spec_list.append( + DTensorSpec( + mesh, + tuple(specs), + tensor_meta=TensorMeta( + shape=meta["common"]["results"]["data_out_0"]["shape"], + stride=None, + dtype=meta["common"]["results"]["data_out_0"][ + "torch_dtype" + ], + ), + ) + ) + except: + breakpoint() + + input_specs = spec_list[input_index:] + # input_args_strategy = op_schema.args_strategy + input_args_strategy = tuple( + arg.meta["mase"]["software"]["autosharding"]["op_strategy"] + for arg in meta.node.args + ) + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # extend input_specs to include fully replicated sharding for constant nodes + extended_input_specs = input_specs + [ + DTensorSpec( + mesh, + (Replicate(), Replicate()), + # todo: may need to set tensor meta + tensor_meta=None, + ) + ] * (len(meta["common"]["args"].keys()) - len(input_specs)) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=( + tuple(spec_list[:input_index]) if input_index > 1 else spec_list[0] + ), + input_specs=extended_input_specs, + 
redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index 7577f573c..a5581797f 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -22,6 +22,16 @@ from ..utils import is_tensor_shardable from chop.ir.graph import MaseMetadata +from .common import expand_to_full_mesh_op_strategy + + +def _other(meta, dim): + if dim == meta["common"]["args"]["dim0"]["value"]: + return meta["common"]["args"]["dim1"]["value"] + elif dim == meta["common"]["args"]["dim1"]["value"]: + return meta["common"]["args"]["dim0"]["value"] + else: + raise ValueError(f"Invalid dim: {dim}") def transpose_strategy( @@ -34,17 +44,28 @@ def transpose_strategy( assert isinstance(self_strategy, OpStrategy) + fully_replicated_spec = DTensorSpec( + mesh=mesh, + placements=[Replicate(), Replicate()], + tensor_meta=None, + ) + transpose_strategies = [] for input_strategy in self_strategy.strategies: - input_spec = input_strategy.output_spec + + if isinstance(input_strategy.output_specs, tuple): + input_spec = input_strategy.output_specs[0] + else: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements output_placements = [ - Shard(1 - p.dim) if isinstance(p, Shard) else p + Shard(_other(meta, p.dim)) if isinstance(p, Shard) else p for p in input_spec.placements ] transpose_strategy = PlacementStrategy( output_specs=DTensorSpec( - mesh=input_strategy.output_spec.mesh, + mesh=input_spec.mesh, placements=tuple(output_placements), tensor_meta=TensorMeta( shape=meta["common"]["results"]["data_out_0"]["shape"], @@ -52,7 +73,8 @@ def transpose_strategy( dtype=meta["common"]["results"]["data_out_0"]["torch_dtype"], ), ), - 
input_specs=(input_strategy.output_spec,), + # include 2 fully replicated inputs for dim_0 and dim_1 arguments + input_specs=(input_spec,) + (fully_replicated_spec,) * 2, ) transpose_strategies.append(transpose_strategy) @@ -205,11 +227,17 @@ def scaled_dot_product_flash_attention_strategy( # NOTE: currently we only support some simple strategies to support tensor parallelism # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation # as it involves: matmul, pointwise, reduction ops together. - return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] - q_input_strategy = op_schema.args_schema[0] + arg_names = list(meta["common"]["args"].keys()) + arg_infos = list(meta["common"]["args"].values()) + return_debug_mask = len(arg_names) >= 6 and arg_infos[5]["value"] + + # q_input_strategy = op_schema.args_schema[0] + q_parent_node = meta.node.args[0] + q_input_strategy = q_parent_node.meta["mase"]["software"]["autosharding"][ + "op_strategy" + ] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape single_mesh_dim_strategies = [] @@ -277,5 +305,16 @@ def scaled_dot_product_flash_attention_strategy( ] ) return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=9 + meta, + mesh, + single_mesh_dim_strategies, + input_index=9, ) + + +def scaled_dot_product_strategy( + meta: MaseMetadata, + mesh: tuple, +): + # todo: support efficient attention backend + return scaled_dot_product_flash_attention_strategy(meta, mesh) diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index f6b253ce0..e747ef924 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -606,7 +606,7 @@ def reshape_strategy(meta, mesh): # FIXME: this can be 
wrong for situations where we have # [Shard(0), Shard(0)] input_tgt_spec = DTensorSpec( - placements=tuple(input_tgt_placements), + placements=(Replicate(), Replicate()), mesh=input_src_spec.mesh, tensor_meta=input_src_spec.tensor_meta, ) diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index a5fa24e22..1b5a007b0 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -63,7 +63,14 @@ def _insert_resharding_nodes(mg, pass_args={}): ) continue - if arg_specs.placements != parent_out_specs.placements: + arg_placements = arg_specs.placements + parent_out_placements = ( + parent_out_specs[0].placements + if isinstance(parent_out_specs, (list, tuple)) + else parent_out_specs.placements + ) + + if arg_placements != parent_out_placements: logger.info( f"Inserting resharding node along edge {arg_obj} -> {node.name} because arg {arg_name} requires placement {arg_specs.placements} but parent node {arg_obj.name} has placement {parent_out_specs.placements}." 
) From b5759a812301e29a057cfb5c305f8d68e9cb3b70 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 29 Aug 2024 13:15:33 +0000 Subject: [PATCH 84/93] fix torch.distributed.tensor imports --- src/chop/distributed/debug.py | 6 +++--- src/chop/distributed/tensor/__init__.py | 6 +++--- src/chop/distributed/tensor/_dispatch.py | 6 +++--- src/chop/distributed/tensor/_redistribute.py | 4 ++-- src/chop/distributed/tensor/_sharding_prop.py | 6 +++--- src/chop/distributed/tensor/_utils.py | 4 ++-- src/chop/distributed/tensor/api.py | 12 ++++++------ src/chop/distributed/tensor/ops/basic_strategy.py | 4 ++-- src/chop/distributed/tensor/ops/common_rules.py | 8 ++++---- src/chop/distributed/tensor/ops/conv_ops.py | 4 ++-- src/chop/distributed/tensor/ops/embedding_ops.py | 6 +++--- src/chop/distributed/tensor/ops/experimental_ops.py | 6 +++--- src/chop/distributed/tensor/ops/math_ops.py | 6 +++--- src/chop/distributed/tensor/ops/matrix_ops.py | 6 +++--- src/chop/distributed/tensor/ops/pointwise_ops.py | 6 +++--- src/chop/distributed/tensor/ops/random_ops.py | 4 ++-- src/chop/distributed/tensor/ops/tensor_ops.py | 4 ++-- src/chop/distributed/tensor/ops/utils.py | 8 ++++---- src/chop/distributed/tensor/ops/view_ops.py | 6 +++--- src/chop/distributed/utils.py | 4 ++-- src/chop/nn/functional/dtensor.py | 6 +++--- src/chop/passes/__init__.py | 2 +- .../autosharding/algos/alpa_intra_operator.py | 6 +++--- .../analysis/autosharding/algos/fully_replicated.py | 4 ++-- .../analysis/autosharding/alpa_intra_operator.py | 4 ++-- .../graph/analysis/autosharding/autosharding.py | 4 ++-- .../analysis/autosharding/ops/basic_strategy.py | 4 ++-- .../passes/graph/analysis/autosharding/ops/common.py | 6 +++--- .../graph/analysis/autosharding/ops/embedding_ops.py | 4 ++-- .../graph/analysis/autosharding/ops/math_ops.py | 8 ++++---- .../graph/analysis/autosharding/ops/matrix_ops.py | 8 ++++---- .../graph/analysis/autosharding/ops/pointwise_ops.py | 6 +++--- 
.../graph/analysis/autosharding/ops/tensor_ops.py | 6 +++--- .../graph/analysis/autosharding/ops/view_ops.py | 8 ++++---- .../autosharding/strategies/basic_strategy.py | 4 ++-- .../graph/analysis/autosharding/strategies/common.py | 4 ++-- .../autosharding/strategies/embedding_ops.py | 6 +++--- .../analysis/autosharding/strategies/math_ops.py | 8 ++++---- .../analysis/autosharding/strategies/matrix_ops.py | 8 ++++---- .../autosharding/strategies/pointwise_ops.py | 6 +++--- .../analysis/autosharding/strategies/tensor_ops.py | 6 +++--- .../analysis/autosharding/strategies/view_ops.py | 8 ++++---- src/chop/passes/graph/analysis/autosharding/utils.py | 2 +- .../graph/transforms/insert_dtensor_wrapper.py | 2 +- src/chop/passes/graph/transforms/resharding.py | 4 ++-- 45 files changed, 125 insertions(+), 125 deletions(-) diff --git a/src/chop/distributed/debug.py b/src/chop/distributed/debug.py index a65d72b3d..76cd8f3e9 100644 --- a/src/chop/distributed/debug.py +++ b/src/chop/distributed/debug.py @@ -4,9 +4,9 @@ import numpy as np from torch._prims_common import ShapeType -from torch.distributed.tensor import DeviceMesh +from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.placement_types import Placement, Shard +from torch.distributed._tensor.placement_types import Placement, Shard def _mesh_to_coordinate(mesh, device_type): @@ -90,7 +90,7 @@ def compute_local_shape_and_global_offset( my_coordinate: List[int], ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """ - Same as torch.distributed.tensor._utils.compute_local_shape_and_global_offset but + Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but with custom my_coordinate input. This is the modified implementation for visualize_sharding. 
""" diff --git a/src/chop/distributed/tensor/__init__.py b/src/chop/distributed/tensor/__init__.py index d4415d599..eecce89dc 100644 --- a/src/chop/distributed/tensor/__init__.py +++ b/src/chop/distributed/tensor/__init__.py @@ -4,9 +4,9 @@ # Import all builtin dist tensor ops import torch -import torch.distributed.tensor.random as random +import torch.distributed._tensor.random as random -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, @@ -57,7 +57,7 @@ def _dtensor_init_helper( placements=None, **kwargs, ) -> DTensor: - from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta + from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta # if device_mesh is None, use the one from mesh resources device_mesh = device_mesh or _mesh_resources.get_current_mesh() diff --git a/src/chop/distributed/tensor/_dispatch.py b/src/chop/distributed/tensor/_dispatch.py index 0e0a661b8..3f9d38043 100644 --- a/src/chop/distributed/tensor/_dispatch.py +++ b/src/chop/distributed/tensor/_dispatch.py @@ -2,14 +2,14 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OutputSpecType, ) -from torch.distributed.tensor._tp_conv import ( +from torch.distributed._tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Replicate, TensorMeta, diff --git a/src/chop/distributed/tensor/_redistribute.py b/src/chop/distributed/tensor/_redistribute.py index 6f99b558c..fc66b219b 100644 --- a/src/chop/distributed/tensor/_redistribute.py +++ b/src/chop/distributed/tensor/_redistribute.py @@ -5,8 +5,8 @@ import torch import torch.distributed._functional_collectives as funcol -from 
torch.distributed.tensor.device_mesh import DeviceMesh -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/_sharding_prop.py b/src/chop/distributed/tensor/_sharding_prop.py index 0812f3c0d..81224a952 100644 --- a/src/chop/distributed/tensor/_sharding_prop.py +++ b/src/chop/distributed/tensor/_sharding_prop.py @@ -6,7 +6,7 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpInfo, OpSchema, OpStrategy, @@ -17,12 +17,12 @@ StrategyType, TupleStrategy, ) -from torch.distributed.tensor._utils import ( +from torch.distributed._tensor._utils import ( compute_local_shape, compute_local_stride, try_find_mesh_from_args, ) -from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh diff --git a/src/chop/distributed/tensor/_utils.py b/src/chop/distributed/tensor/_utils.py index fa278c027..a3cc8ee5a 100644 --- a/src/chop/distributed/tensor/_utils.py +++ b/src/chop/distributed/tensor/_utils.py @@ -1,9 +1,9 @@ from typing import cast, List, Sequence, Tuple import torch -import torch.distributed.tensor.api as dtensor +import torch.distributed._tensor.api as dtensor from torch._prims_common import ShapeType -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/api.py b/src/chop/distributed/tensor/api.py index 3dba5fc02..e9cabae4e 100644 --- a/src/chop/distributed/tensor/api.py +++ b/src/chop/distributed/tensor/api.py @@ -5,18 +5,18 @@ from typing import Any, 
Callable, cast, Optional, Sequence, Tuple import torch -import torch.distributed.tensor.random as random +import torch.distributed._tensor.random as random import torch.nn as nn -from torch.distributed.tensor._collective_utils import ( +from torch.distributed._tensor._collective_utils import ( check_tensor_meta, mesh_broadcast, ) -from torch.distributed.tensor._redistribute import ( +from torch.distributed._tensor._redistribute import ( Redistribute, redistribute_local_tensor, ) -from torch.distributed.tensor._utils import compute_global_tensor_info -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._utils import compute_global_tensor_info +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, @@ -24,7 +24,7 @@ Shard, TensorMeta, ) -from torch.distributed.tensor.random import ( +from torch.distributed._tensor.random import ( is_rng_supported_mesh, OffsetBasedRNGTracker, ) diff --git a/src/chop/distributed/tensor/ops/basic_strategy.py b/src/chop/distributed/tensor/ops/basic_strategy.py index 37b1c29ff..97dd43b15 100644 --- a/src/chop/distributed/tensor/ops/basic_strategy.py +++ b/src/chop/distributed/tensor/ops/basic_strategy.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/common_rules.py b/src/chop/distributed/tensor/ops/common_rules.py index 082755077..f70b27076 100644 --- a/src/chop/distributed/tensor/ops/common_rules.py +++ b/src/chop/distributed/tensor/ops/common_rules.py @@ -2,15 +2,15 @@ from typing import cast, Dict, List, Optional, Tuple import torch -from torch.distributed.tensor._op_schema import ( 
+from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, OutputSharding, ) -from torch.distributed.tensor._utils import compute_local_shape -from torch.distributed.tensor._ops.utils import prod -from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor._utils import compute_local_shape +from torch.distributed._tensor.ops.utils import prod +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: diff --git a/src/chop/distributed/tensor/ops/conv_ops.py b/src/chop/distributed/tensor/ops/conv_ops.py index dfcfd4a88..7bf13241d 100644 --- a/src/chop/distributed/tensor/ops/conv_ops.py +++ b/src/chop/distributed/tensor/ops/conv_ops.py @@ -3,9 +3,9 @@ from typing import List import torch -from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from chop.distributed.tensor.ops.utils import register_prop_rule -from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta aten = torch.ops.aten diff --git a/src/chop/distributed/tensor/ops/embedding_ops.py b/src/chop/distributed/tensor/ops/embedding_ops.py index 11ec896e8..d89ec651b 100644 --- a/src/chop/distributed/tensor/ops/embedding_ops.py +++ b/src/chop/distributed/tensor/ops/embedding_ops.py @@ -6,14 +6,14 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, StrategyType, ) -from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy +from 
torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/distributed/tensor/ops/experimental_ops.py b/src/chop/distributed/tensor/ops/experimental_ops.py index 03f94e2cd..432fbede8 100644 --- a/src/chop/distributed/tensor/ops/experimental_ops.py +++ b/src/chop/distributed/tensor/ops/experimental_ops.py @@ -2,15 +2,15 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.device_mesh import DeviceMesh from chop.distributed.tensor.ops.utils import register_op_strategy -from torch.distributed.tensor.placement_types import DTensorSpec, Replicate +from torch.distributed._tensor.placement_types import DTensorSpec, Replicate aten = torch.ops.aten diff --git a/src/chop/distributed/tensor/ops/math_ops.py b/src/chop/distributed/tensor/ops/math_ops.py index 8c19aefcb..c5c3609d6 100644 --- a/src/chop/distributed/tensor/ops/math_ops.py +++ b/src/chop/distributed/tensor/ops/math_ops.py @@ -6,7 +6,7 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, @@ -14,7 +14,7 @@ RuntimeSchemaInfo, TupleStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( as_list, expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -26,7 +26,7 @@ from torch.distributed.utils import ( normalize_to_torch_size, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/matrix_ops.py b/src/chop/distributed/tensor/ops/matrix_ops.py 
index 35d96c814..77484de7d 100644 --- a/src/chop/distributed/tensor/ops/matrix_ops.py +++ b/src/chop/distributed/tensor/ops/matrix_ops.py @@ -2,20 +2,20 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, infer_broadcast_dims_map, is_tensor_shardable, map_placements_after_broadcast, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/distributed/tensor/ops/pointwise_ops.py b/src/chop/distributed/tensor/ops/pointwise_ops.py index 2a9986ed0..c3e1f082f 100644 --- a/src/chop/distributed/tensor/ops/pointwise_ops.py +++ b/src/chop/distributed/tensor/ops/pointwise_ops.py @@ -2,7 +2,7 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, @@ -12,13 +12,13 @@ StrategyType, TupleStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/random_ops.py b/src/chop/distributed/tensor/ops/random_ops.py index 4b9a3303d..7eefa30fc 100644 --- a/src/chop/distributed/tensor/ops/random_ops.py +++ b/src/chop/distributed/tensor/ops/random_ops.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed.tensor._ops.utils import is_tensor_partial +from torch.distributed._tensor.ops.utils import is_tensor_partial from chop.distributed.tensor.ops.utils import register_op_strategy from torch.distributed.device_mesh import DeviceMesh diff --git a/src/chop/distributed/tensor/ops/tensor_ops.py b/src/chop/distributed/tensor/ops/tensor_ops.py index 51ae573ab..8e64ff514 100644 --- a/src/chop/distributed/tensor/ops/tensor_ops.py +++ b/src/chop/distributed/tensor/ops/tensor_ops.py @@ -3,7 +3,7 @@ from typing import cast, List, Optional, Sequence, Tuple import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -14,7 +14,7 @@ StrategyType, TupleStrategy, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/utils.py b/src/chop/distributed/tensor/ops/utils.py index 236e45fc9..27cf89224 100644 --- a/src/chop/distributed/tensor/ops/utils.py +++ b/src/chop/distributed/tensor/ops/utils.py @@ -6,16 +6,16 @@ from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed.tensor._collective_utils import redistribute_cost -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._collective_utils import redistribute_cost +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed.tensor.device_mesh import DeviceMesh -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types 
import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/distributed/tensor/ops/view_ops.py b/src/chop/distributed/tensor/ops/view_ops.py index fed462a6a..dc8103b08 100644 --- a/src/chop/distributed/tensor/ops/view_ops.py +++ b/src/chop/distributed/tensor/ops/view_ops.py @@ -16,15 +16,15 @@ import torch from torch import Tensor -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, RuntimeSchemaInfo, StrategyType, ) -from torch.distributed.tensor.api import Shard -from torch.distributed.tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate from torch.distributed.device_mesh import DeviceMesh from chop.distributed.tensor.ops.utils import register_op_strategy diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py index db751203c..55c5293f6 100644 --- a/src/chop/distributed/utils.py +++ b/src/chop/distributed/utils.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.distributed as dist -from torch.distributed.tensor import DeviceMesh +from torch.distributed._tensor import DeviceMesh from chop.tools import get_logger from chop.distributed.tensor import distribute_tensor @@ -80,7 +80,7 @@ def dist_model_fn( tensor_sharding_map={}, ) -> None: """ - This function gets called by torch.distributed.tensor.distribute_module on each module in the model. + This function gets called by torch.distributed._tensor.distribute_module on each module in the model. Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map. 
Args: diff --git a/src/chop/nn/functional/dtensor.py b/src/chop/nn/functional/dtensor.py index 49a974af3..c5e29b73d 100644 --- a/src/chop/nn/functional/dtensor.py +++ b/src/chop/nn/functional/dtensor.py @@ -3,10 +3,10 @@ import torch import torch.fx as fx from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed._tensor._redistribute import redistribute_local_tensor -from torch.distributed.tensor.placement_types import Placement +from torch.distributed._tensor.placement_types import Placement from chop.ir.graph import MaseMetadata from chop.distributed.tensor import DTensor diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index a66399732..47b971f77 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -37,7 +37,7 @@ replace_method_with_function, insert_dtensor_wrapper_transform_pass, ) -from .module.analysis import calculate_avg_bits_module_analysis_pass +from .module.analysis import calculate_avg_bits_module_analysis_pass, autosharding_module_analysis_pass from .module.transforms import quantize_module_transform_pass from .onnx.analysis import ( diff --git a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py index 0e8bed368..de9143d56 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/alpa_intra_operator.py @@ -5,13 +5,13 @@ import torch import torch.fx as fx -from torch.distributed.tensor._collective_utils import redistribute_cost -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._collective_utils import redistribute_cost +from 
torch.distributed._tensor._op_schema import ( DTensorSpec, OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.distributed._tensor.placement_types import Shard, Replicate from chop.tools import get_logger from ..mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py index a71cd6973..75035a39e 100644 --- a/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py +++ b/src/chop/passes/graph/analysis/autosharding/algos/fully_replicated.py @@ -1,5 +1,5 @@ -from torch.distributed.tensor._op_schema import DTensorSpec -from torch.distributed.tensor.placement_types import Replicate +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate from chop.ir import MaseGraph from ..mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py index ca33d31a8..b431095d6 100644 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py @@ -1,7 +1,7 @@ import torch import torch.fx as fx -from torch.distributed.tensor._collective_utils import redistribute_cost -from torch.distributed.tensor._op_schema import DTensorSpec +from torch.distributed._tensor._collective_utils import redistribute_cost +from torch.distributed._tensor._op_schema import DTensorSpec import numpy as np import cvxpy as cp diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py index 80baba829..2515828e9 100644 --- a/src/chop/passes/graph/analysis/autosharding/autosharding.py +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -2,8 +2,8 @@ from time 
import time import dill -from torch.distributed.tensor._op_schema import DTensorSpec -from torch.distributed.tensor.placement_types import Replicate +from torch.distributed._tensor._op_schema import DTensorSpec +from torch.distributed._tensor.placement_types import Replicate from chop.tools import get_logger from .mesh_model import MeshModel diff --git a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py index 94708b89b..b9541013a 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/basic_strategy.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( DTensorSpec, _Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/common.py b/src/chop/passes/graph/analysis/autosharding/ops/common.py index 56a19598f..7e06d71dc 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/common.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/common.py @@ -6,15 +6,15 @@ import torch.nn.functional as F from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( Placement, Replicate, Shard, DTensorSpec, TensorMeta, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( is_tensor_shardable, generate_redistribute_costs, ) diff --git 
a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py index bf7fbe06c..941af13b2 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/embedding_ops.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed.tensor._op_schema import StrategyType -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._op_schema import StrategyType +from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py index 6b722763e..32e33afa2 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/math_ops.py @@ -4,17 +4,17 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( normalize_dim, ) -from torch.distributed.tensor._utils import ( +from torch.distributed._tensor._utils import ( normalize_to_torch_size, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py index 8948a1c98..20a2c74c6 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/matrix_ops.py @@ -2,18 +2,18 @@ # 
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, PlacementList, ) -from torch.distributed.tensor.placement_types import Replicate, Shard, Placement +from torch.distributed._tensor.placement_types import Replicate, Shard, Placement from .basic_strategy import gen_einsum_strategies -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Shard, TensorMeta, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py index 3cc958ba3..fe79c521a 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/pointwise_ops.py @@ -4,16 +4,16 @@ from typing import List import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py index 90869c4d6..0a9daee2a 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/tensor_ops.py @@ -1,15 +1,15 @@ # Adapted from Pytorch Distributed DTensor API. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( is_tensor_partial, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py index 48bb10c96..4902e83ca 100644 --- a/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/ops/view_ops.py @@ -17,18 +17,18 @@ import torch from torch import Tensor -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor.api import Shard -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py index 94708b89b..b9541013a 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/basic_strategy.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed.tensor.placement_types import ( +from 
torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( DTensorSpec, _Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/common.py b/src/chop/passes/graph/analysis/autosharding/strategies/common.py index ef3333373..5e98c840e 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/common.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/common.py @@ -2,8 +2,8 @@ import torch import torch.nn.functional as F -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed._tensor.placement_types import ( Replicate, Shard, DTensorSpec, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py index 2e3872054..feef6fd12 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/embedding_ops.py @@ -7,18 +7,18 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, StrategyType, DTensorSpec, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( is_tensor_shardable, generate_redistribute_costs, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py index 0c6db820b..2d2ecb200 100644 --- 
a/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/math_ops.py @@ -4,17 +4,17 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( normalize_dim, ) -from torch.distributed.tensor._utils import ( +from torch.distributed._tensor._utils import ( normalize_to_torch_size, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py index a5581797f..245e42f69 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/matrix_ops.py @@ -2,18 +2,18 @@ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/matrix_ops.py import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, PlacementList, ) -from torch.distributed.tensor.placement_types import Replicate, Shard, Placement +from torch.distributed._tensor.placement_types import Replicate, Shard, Placement from .basic_strategy import gen_einsum_strategies -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( infer_broadcast_dims_map, map_placements_after_broadcast, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Shard, TensorMeta, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py 
b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py index ac609e2e3..fe6ac93ef 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/pointwise_ops.py @@ -4,19 +4,19 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Placement, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py index 90869c4d6..0a9daee2a 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/tensor_ops.py @@ -1,15 +1,15 @@ # Adapted from Pytorch Distributed DTensor API. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.ops.utils import ( is_tensor_partial, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py index e747ef924..70e62b450 100644 --- a/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py +++ b/src/chop/passes/graph/analysis/autosharding/strategies/view_ops.py @@ -17,18 +17,18 @@ import torch from torch import Tensor -from torch.distributed.tensor._op_schema import ( +from torch.distributed._tensor._op_schema import ( OpStrategy, PlacementStrategy, ) -from torch.distributed.tensor.api import Shard -from torch.distributed.tensor._ops.utils import ( +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, ) -from torch.distributed.tensor.placement_types import ( +from torch.distributed._tensor.placement_types import ( DTensorSpec, Placement, Replicate, diff --git a/src/chop/passes/graph/analysis/autosharding/utils.py b/src/chop/passes/graph/analysis/autosharding/utils.py index 9af911a55..54ff94bea 100644 --- a/src/chop/passes/graph/analysis/autosharding/utils.py +++ b/src/chop/passes/graph/analysis/autosharding/utils.py @@ -1,5 +1,5 @@ from typing import Sequence, cast -from torch.distributed.tensor.placement_types import DTensorSpec, Shard +from torch.distributed._tensor.placement_types import DTensorSpec, Shard def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: diff 
--git a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py index 73402bc2a..cda5508d2 100644 --- a/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py +++ b/src/chop/passes/graph/transforms/insert_dtensor_wrapper.py @@ -1,5 +1,5 @@ import torch -from torch.distributed.tensor.api import DTensorSpec, TensorMeta +from torch.distributed._tensor.api import DTensorSpec, TensorMeta from torch.distributed import DeviceMesh from copy import deepcopy diff --git a/src/chop/passes/graph/transforms/resharding.py b/src/chop/passes/graph/transforms/resharding.py index 1b5a007b0..722225739 100644 --- a/src/chop/passes/graph/transforms/resharding.py +++ b/src/chop/passes/graph/transforms/resharding.py @@ -2,7 +2,7 @@ import torch import torch.fx as fx -from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.distributed._tensor.placement_types import Replicate, Shard from chop.tools import get_logger from chop.nn.functional.dtensor import redistribute_dtensor @@ -115,7 +115,7 @@ def _insert_resharding_nodes(mg, pass_args={}): # Insert DTensor import at the top of code def insert_imports(body): return [ - "from torch.distributed.tensor.placement_types import Replicate, Shard, Partial; sum = 'sum' \n", + "from torch.distributed._tensor.placement_types import Replicate, Shard, Partial; sum = 'sum' \n", *body, ] From 6e43341c1a8a8c31005be13c9c2ecec9f0c76300 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 29 Aug 2024 17:46:02 +0000 Subject: [PATCH 85/93] remove deprecated files --- .../autosharding/alpa_intra_operator.py | 344 ------------------ .../graph/analysis/autosharding/megatron.py | 23 -- 2 files changed, 367 deletions(-) delete mode 100644 src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py delete mode 100644 src/chop/passes/graph/analysis/autosharding/megatron.py diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py 
b/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py deleted file mode 100644 index b431095d6..000000000 --- a/src/chop/passes/graph/analysis/autosharding/alpa_intra_operator.py +++ /dev/null @@ -1,344 +0,0 @@ -import torch -import torch.fx as fx -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import DTensorSpec -import numpy as np -import cvxpy as cp - -from chop.tools import get_logger -from chop.tools.utils import deepgetattr - -from .layers import ( - AUTOSHARDING_MODULES, - AUTOSHARDING_FUNCTIONS, - AUTOSHARDING_METHODS, - IMPLICIT_FUNCS, - IMPLICIT_METHODS, -) -from .strategies.common import ( - fully_replicated_strategy, - placeholder_or_getattr_strategy, -) - - -logger = get_logger(__name__) -logger.setLevel("DEBUG") - - -def _extract_ilp(mg, mesh, pass_args={}): - """ - For each node in the graph, assign an OpStrategy object which contains all possible - sharding algorithms. Also assign opt_var instance which is one-hot vector used to - solve ILP. - - Return list of constraints associated with ILP. The constraints at this stage only - enforce that each optimizer variable is a one-hot boolean vector. - - Args: - mg (MaseGraph): input mase graph. - mesh (MeshModel): mesh model. - pass_args (dict, optional): pass arguments. Defaults to {}. - - Returns: - MaseGraph: input mase graph. - cp.Problem: optimization problem. - """ - - # Setup for the ILP optimization - constr = [] - expr = 0 - - # Find sharding strategies for each operator in the graph - for node in mg.fx_graph.nodes: - - if (node.op == "call_function" and node.target in IMPLICIT_FUNCS) or ( - node.op == "call_method" and node.target in IMPLICIT_METHODS - ): - logger.debug( - f"Implicit {node.op} node {node.name} was assigned fully replicated sharding." 
- ) - - op_strategy = fully_replicated_strategy(node.meta["mase"], mesh) - - opt_var = cp.Variable(1, boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - - # Opt var is None since no decision needs to be taken - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": opt_var, - "input": None, - "output": None, - } - continue - - # Obtain strategy according to node op - # ================================================ - - if node.op in ["placeholder", "get_attr"]: - logger.debug( - f"Node {node} with op {node.op} will be assigned all permutations of Shard(dims) and Replicate()" - ) - op_strategy = placeholder_or_getattr_strategy( - node.meta["mase"], - mesh, - skip_fully_replicated=pass_args.get("skip_fully_replicated", False), - ) - - elif node.op == "output": - logger.debug( - f"Op strategy from node {node.all_input_nodes[0]} is propagated to {node} node." - ) - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": node.all_input_nodes[0].meta["mase"]["software"][ - "autosharding" - ]["op_strategy"], - "opt_var": None, - "input": None, - "output": None, - } - continue - - elif node.op == "call_module" and isinstance( - deepgetattr(mg.model, node.target), tuple(AUTOSHARDING_MODULES.keys()) - ): - logger.debug(f"Obtaining strategy for node {node.name}") - module_cls = type(deepgetattr(mg.model, node.target)) - op_strategy = AUTOSHARDING_MODULES[module_cls](node.meta["mase"], mesh) - - elif node.op == "call_method" and node.target in AUTOSHARDING_METHODS.keys(): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = AUTOSHARDING_METHODS[node.target](node.meta["mase"], mesh) - - elif ( - node.op == "call_function" and node.target in AUTOSHARDING_FUNCTIONS.keys() - ): - logger.debug(f"Obtaining strategy for node {node.name}") - op_strategy = AUTOSHARDING_FUNCTIONS[node.target](node.meta["mase"], mesh) - - else: - logger.warning(f"Unknown node {node.name} with op {node.op}") - op_strategy = 
fully_replicated_strategy(node.meta["mase"], mesh) - opt_var = cp.Variable(1, boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": fully_replicated_strategy(node.meta["mase"], mesh), - "opt_var": opt_var, - "input": None, - "output": None, - } - continue - - # Formulate optimization variable - opt_var = cp.Variable(len(op_strategy.strategies), boolean=True) - constr += [ - cp.sum(opt_var) == 1, - ] - - # Write into metadata - node.meta["mase"]["software"]["autosharding"] = { - "op_strategy": op_strategy, - "opt_var": opt_var, - "input": None, - "output": None, - } - - # Consider resharding cost for each of the node's arguments - e_var_checks = [] - for arg_idx, in_node in enumerate(node.all_input_nodes): - - # Skip constant nodes - if not isinstance(in_node, fx.Node) or not isinstance( - in_node.meta["mase"]["common"]["results"]["data_out_0"]["value"], - torch.Tensor, - ): - continue - logger.debug(f"Parsing arg {in_node} of node {node}") - - # Fetch this node's input specs - node_op_strategy = node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ] - node_in_specs = [ - ( - [strategy.input_specs][arg_idx] - if isinstance(strategy.input_specs, DTensorSpec) - else strategy.input_specs[arg_idx] - ) - for strategy in node_op_strategy.strategies - ] - - # Fetch the argument node's output specs - in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] - arg_op_strategy = in_node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ] - arg_out_specs = [ - strategy.output_specs for strategy in arg_op_strategy.strategies - ] - - # Formulate resharding cost matrix - resharding_costs = np.zeros((opt_var.shape[0], in_opt_var.shape[0])) - - for dest_idx, dest_spec in enumerate(node_in_specs): - for src_idx, src_spec in enumerate(arg_out_specs): - cost = redistribute_cost(src_spec, dest_spec) - resharding_costs[dest_idx, src_idx] = ( - 1000000 if cost == float("inf") else 
cost - ) - - resharding_costs = resharding_costs.flatten() - - # Formulate linearized variable for resharding cost - e_var = cp.Variable(resharding_costs.shape[0], boolean=True) - expr += e_var.T @ resharding_costs - constr += [ - cp.sum(e_var) == 1, - ] - - # After solving the ILP, verify constraints were correctly formulated - if pass_args.get("run_checks", False): - e_var_checks.append((opt_var, in_opt_var, e_var)) - - # Constraints s.t. e_var = outer(opt_var, in_opt_var) - indices = np.arange(e_var.shape[0]) - opt_indices, in_opt_indices = np.divmod(indices, in_opt_var.shape[0]) - constr += [ - e_var <= opt_var[opt_indices], - e_var <= in_opt_var[in_opt_indices], - e_var >= opt_var[opt_indices] + in_opt_var[in_opt_indices] - 1, - ] - - if pass_args.get("run_checks", False): - node.meta["mase"]["software"]["autosharding"]["e_var_checks"] = e_var_checks - - # Solve the ILP problem - prob = cp.Problem(cp.Minimize(expr), constr) - return mg, prob - - -def _run_checks(mg, pass_args): - """ - Run checks on the ILP solution to ensure that the constraints were correctly formulated. - - Args: - mg (MaseGraph): input mase graph. - pass_args (dict): pass arguments. - - Returns: - None - """ - - for node in mg.fx_graph.nodes: - check_list = node.meta["mase"]["software"]["autosharding"].get( - "e_var_checks", [] - ) - - # Check that the constraints on the linearised variable for resharding cost - # are correctly formulated - for opt_var, in_opt_var, e_var in check_list: - idx1 = np.where(opt_var.value == 1)[0][0] - idx2 = np.where(in_opt_var.value == 1)[0][0] - idx3 = np.where(e_var.value == 1)[0][0] - assert ( - idx3 == idx1 * in_opt_var.shape[0] + idx2 - ), f"Linearized variable for resharding cost is not consistent for node {node}." - - -def _mark_sharding(mg, pass_args): - """ - After solving the ILP, annotate the metadata of each operator in the graph with the chosen - parallelization strategy. - - Args: - mg (MaseGraph): input mase graph. 
- pass_args (dict): pass arguments. - - Returns: - MaseGraph: input mase graph. - dict: tensor sharding map. - """ - - for node in mg.fx_graph.nodes: - opt_var = node.meta["mase"]["software"]["autosharding"]["opt_var"] - - if opt_var is None: - continue - - try: - idx = np.where(opt_var.value == 1)[0][0] - except: - idx = np.argmax(opt_var.value) - - chosen_strategy = node.meta["mase"]["software"]["autosharding"][ - "op_strategy" - ].strategies[idx] - - # Annotate chosen placement strategy - node.meta["mase"]["software"]["autosharding"][ - "placement_strategy" - ] = chosen_strategy - - arg_specs = chosen_strategy.input_specs - out_spec = chosen_strategy.output_specs - - if isinstance(arg_specs, DTensorSpec): - arg_specs = (arg_specs,) - - # Annotate arg metadata with chosen strategy - if node.op in ["placeholder", "get_attr", "call_method", "output"]: - pass - - # call_function nodes - else: - arg_list = [i for i in node.meta["mase"]["common"]["args"].keys()] - - for arg_idx, arg_spec in enumerate(arg_specs): - arg_meta = node.meta["mase"]["common"]["args"][arg_list[arg_idx]] - if not isinstance(arg_meta, dict): - continue - arg_meta["dtensor_spec"] = arg_spec - - # Annotate output metadata with chosen strategy - node.meta["mase"]["common"]["results"]["data_out_0"]["dtensor_spec"] = out_spec - - return mg, {} - - -def alpa_intra_op_sharding_pass(mg, mesh, pass_args={}, debug=False): - """Intra-operator auto parallelization pass from the Alpa paper: https://arxiv.org/abs/2201.12023 - - Args: - mg (MaseGraph): Input MaseGraph. - mesh (MeshModel): mesh description. - pass_args (dict, optional): pass arguments. Defaults to {}. - debug (bool, optional): enable debug. Defaults to False. - - Returns: - MaseGraph: annotated MaseGraph. 
- """ - - # Formulate and solve the ILP - logger.info(f"Formulating the ILP...") - mg, problem = _extract_ilp(mg, mesh, pass_args) - - logger.info(f"Solving the ILP...") - problem.solve( - verbose=True, - scipy_options={ - "disp": True, - "time_limit": pass_args.get("time_limit", None), - "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, - }, - ) - - if pass_args.get("run_checks", False): - _run_checks(mg, pass_args) - - mg, _ = _mark_sharding(mg, pass_args) - - return mg, {"solution": problem.value} diff --git a/src/chop/passes/graph/analysis/autosharding/megatron.py b/src/chop/passes/graph/analysis/autosharding/megatron.py deleted file mode 100644 index 30cd36f7e..000000000 --- a/src/chop/passes/graph/analysis/autosharding/megatron.py +++ /dev/null @@ -1,23 +0,0 @@ -from chop.ir import MaseGraph -from .mesh_model import MeshModel - - -def megatron_autosharding_pass( - mg: MaseGraph, - mesh: MeshModel, - pass_args: dict, -): - for node in mg.fx_graph.nodes: - meta = node.meta["mase"]["common"] - - for arg, arg_spec in meta["args"].items(): - if not isinstance(arg_spec, dict): - continue - arg_spec["dtensor_spec"] = None - - for result, result_spec in meta["results"].items(): - if not isinstance(result_spec, dict): - continue - result_spec["dtensor_spec"] = None - - return mg, {"solution": {}} From 611c835944eeb84541761b9b8f63a9e409a36ba5 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 29 Aug 2024 17:46:11 +0000 Subject: [PATCH 86/93] module level autosharding for vllm --- src/chop/passes/module/__init__.py | 2 +- src/chop/passes/module/analysis/__init__.py | 2 + .../passes/module/analysis/autosharding.py | 201 ++++++++++ .../passes/module/analysis/cost_modelling.py | 348 ++++++++++++++++++ 4 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 src/chop/passes/module/analysis/autosharding.py create mode 100644 src/chop/passes/module/analysis/cost_modelling.py diff --git a/src/chop/passes/module/__init__.py 
b/src/chop/passes/module/__init__.py index f2ffd33ca..50a70ff46 100644 --- a/src/chop/passes/module/__init__.py +++ b/src/chop/passes/module/__init__.py @@ -1,4 +1,4 @@ -from .analysis import calculate_avg_bits_module_analysis_pass +from .analysis import calculate_avg_bits_module_analysis_pass, autosharding_module_analysis_pass from .transforms import quantize_module_transform_pass ANALYSIS_PASSES = ["calculate_avg_bits_module_analysis_pass"] diff --git a/src/chop/passes/module/analysis/__init__.py b/src/chop/passes/module/analysis/__init__.py index b3b7d2ab1..6a50919c7 100644 --- a/src/chop/passes/module/analysis/__init__.py +++ b/src/chop/passes/module/analysis/__init__.py @@ -1 +1,3 @@ from .quantize import calculate_avg_bits_module_analysis_pass + +from .autosharding import autosharding_module_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py new file mode 100644 index 000000000..f7cf4f6a1 --- /dev/null +++ b/src/chop/passes/module/analysis/autosharding.py @@ -0,0 +1,201 @@ +import torch +import torch.nn.functional as F + +import numpy as np +import cvxpy as cp +from copy import copy +from collections import OrderedDict + +import vllm +from vllm.attention import Attention as VllmAttention + +from chop.tools import get_logger +from chop.distributed.utils import rlog + +from .cost_modelling import ( + _get_compute_cost_from_layer, + _get_resharding_cost_matrix, + _get_intra_op_comms_cost, +) + +VllmLinear = vllm.model_executor.layers.linear.LinearBase + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +STRATEGY_MAP = OrderedDict( + { + VllmLinear: [ + "replicated", + "column", + "row", + "data", + ], + VllmAttention: [ + "replicated", + "head", + ], + type(None): None, + } +) + + +def _linearize_resharding_cost( + opt_var, + parent_opt_var, + resharding_costs, + pass_args, +): + # Flatten resharding matrix + resharding_costs = 
resharding_costs.flatten() + + # Formulate linearized variable for resharding cost + e_var = cp.Variable(resharding_costs.shape[0], boolean=True) + expr = e_var.T @ resharding_costs + constr = [ + cp.sum(e_var) == 1, + ] + + # Constraints s.t. e_var = outer(opt_var, in_opt_var) + indices = np.arange(e_var.shape[0]) + opt_indices, in_opt_indices = np.divmod(indices, parent_opt_var.shape[0]) + constr += [ + e_var <= opt_var[opt_indices], + e_var <= parent_opt_var[in_opt_indices], + e_var >= opt_var[opt_indices] + parent_opt_var[in_opt_indices] - 1, + ] + + return expr, constr + + +def _formulate_ilp( + model: torch.nn.Module, + pass_args: dict, +): + + self_rank = torch.distributed.get_rank() + data_size = pass_args.get("data_size", None) + + if data_size is None: + raise ValueError("data_size is required for autosharding analysis") + + module_list = [] + module_strategies = [] + + # ILP variables + constr = [] + expr = 0 + + for layer in model.modules(): + + # Skip non-leaf modules + if len(list(layer.children())) > 0: + continue + rlog(logger, self_rank, f"Parsing layer {layer.__class__.__name__}") + + # Check if matches with one of the supported layer types + for layer_type, layer_strategies in STRATEGY_MAP.items(): + if isinstance(layer, layer_type): + break + + if layer_type is None or layer_strategies is None: + continue + + # Register layer and instantiate optimization variable + # ============================ + module_list.append(layer) + module_strategies.append(layer_strategies) + + opt_var = cp.Variable(len(layer_strategies), boolean=True) + setattr(layer, "opt_var", opt_var) + constr += [ + cp.sum(opt_var) == 1, + ] + + # Consider compute cost + # ============================ + compute_cost = _get_compute_cost_from_layer( + layer, + layer_strategies, + data_size=data_size, + ) + expr += compute_cost @ opt_var + + # Consider intra operator comms cost + # ============================ + comms_cost = _get_intra_op_comms_cost( + layer, + layer_strategies, + 
pass_args=pass_args, + ) + expr += comms_cost @ opt_var + + # Consider resharding cost + # ============================ + + # Skip if no parent module + if len(module_list) <= 1: + continue + + parent_module = module_list[-2] + parent_strategies = module_strategies[-2] + logger.info( + f"Consider resharding cost between {parent_module.__class__.__name__} and {layer.__class__.__name__}" + ) + + resharding_costs = _get_resharding_cost_matrix( + layer, layer_strategies, parent_module, parent_strategies, pass_args + ) + + resharding_term, resharding_constraints = _linearize_resharding_cost( + opt_var, parent_module.opt_var, resharding_costs, pass_args + ) + expr += resharding_term + constr += resharding_constraints + + return cp.Problem(cp.Minimize(expr), constr) + + +def _get_sharding_config(model): + self_rank = torch.distributed.get_rank() + + sharding_config = {} + for layer in model.modules(): + + # Skip non-leaf modules + if len(list(layer.children())) > 0: + continue + + # Check if matches with one of the supported layer types + for layer_type, layer_strategies in STRATEGY_MAP.items(): + if isinstance(layer, layer_type): + break + + if layer_type is None or layer_strategies is None: + continue + + opt_var_value = layer.opt_var.value + strategy_idx = np.where(opt_var_value)[0][0] + strategy = layer_strategies[strategy_idx] + + sharding_config[layer.prefix] = strategy + + return sharding_config + + +def autosharding_module_analysis_pass(model, pass_args={}): + problem = _formulate_ilp(model, pass_args) + problem.solve( + verbose=True, + scipy_options={ + "disp": pass_args.get(f"debug", False), + "time_limit": pass_args.get("time_limit", None), + "mip_rel_gap": pass_args.get("mip_rel_gap", 0) / 100, + }, + ) + + sharding_config = _get_sharding_config(model) + + return model, { + "sharding_config": sharding_config, + } diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py new file mode 100644 index 
000000000..3ecbbe2ee --- /dev/null +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -0,0 +1,348 @@ +import math +import numpy as np +from copy import copy +from functools import lru_cache + +import torch +from torch.nn import functional as F + +import vllm +from vllm.attention import Attention as VllmAttention + +# Utilities +# ================================ + + +def _profile_op( + fn: callable, + args: list, + repeat: int, + warmup_iters: int, +): + start_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + end_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + + for idx in range(repeat): + start_event[idx].record() + out = fn(*args) + end_event[idx].record() + torch.cuda.synchronize(device=f"cuda:0") + + elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] + + return out, np.mean(elapsed[warmup_iters:]) + + +def allreduce_cost( + bytes_gb: float, + intra_device_latency: float, + intra_device_bandwidth: float, +) -> float: + world_size = torch.distributed.get_world_size() + mesh_dim_bandwidth = intra_device_bandwidth + # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter + num_hops = 2 * world_size - 1 + + latency = 6.6 + num_hops * intra_device_latency + bw = (bytes_gb * num_hops / world_size) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def allgather_cost( + bytes_gb: float, + intra_device_latency: float, + intra_device_bandwidth: float, +) -> float: + world_size = torch.distributed.get_world_size() + num_hops = world_size - 1 + latency = 6.6 + num_hops * intra_device_latency + bw = (bytes_gb * num_hops / world_size) / intra_device_bandwidth + return latency + bw * 1e6 + + +def _get_output_shape_from_layer_type( + layer: torch.nn.Module, + data_size: int, +): + if isinstance(layer, vllm.model_executor.layers.linear.LinearBase): + return torch.Size([data_size, layer.weight.shape[0]]) + if isinstance(layer, 
VllmAttention): + return torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + + +# Compute cost +# ================================ + + +@lru_cache(maxsize=128, typed=False) +def _cached_linear_cost_from_local_shapes( + local_weight_shape: torch.Size, + local_bias_shape: torch.Size, + local_input_shape: torch.Size, + repeat: int = 100, + warmup_iters: int = 2, +): + local_weights = torch.randn(local_weight_shape).to("cuda:0") + local_bias = torch.randn(local_bias_shape).to("cuda:0") + local_input = torch.randn(local_input_shape).to("cuda:0") + + _, elapsed = _profile_op( + fn=F.linear, + args=[local_input, local_weights, local_bias], + repeat=repeat, + warmup_iters=warmup_iters, + ) + + return elapsed + + +def _get_linear_compute_cost( + layer: torch.nn.Module, + layer_strategies: list, + data_size: int, + repeat: int = 100, + warmup_iters: int = 2, +): + + world_size = torch.distributed.get_world_size() + + # Global shapes + global_weight_shape = layer.weight.shape + global_bias_shape = layer.bias.shape + global_input_shape = torch.Size([data_size, global_weight_shape[1]]) + + cost_vector = [] + for strategy in layer_strategies: + + # Default values for local tensors + # (taken for replicated strategy) + local_weight_shape = copy(global_weight_shape) + local_bias_shape = copy(global_bias_shape) + local_input_shape = copy(global_input_shape) + + if strategy == "replicated": + pass + elif strategy == "column": + local_weight_shape = torch.Size( + [global_weight_shape[0] // world_size, global_weight_shape[1]] + ) + local_bias_shape = torch.Size([global_bias_shape[0] // world_size]) + elif strategy == "row": + local_input_shape = torch.Size( + [global_input_shape[0], global_input_shape[1] // world_size] + ) + local_weight_shape = torch.Size( + [global_weight_shape[0], global_weight_shape[1] // world_size] + ) + elif strategy == "data": + local_input_shape = torch.Size( 
+ [global_input_shape[0] // world_size, global_input_shape[1]] + ) + else: + raise ValueError(f"Unknown strategy: {strategy}") + + # Create local tensors + elapsed = _cached_linear_cost_from_local_shapes( + local_weight_shape, + local_bias_shape, + local_input_shape, + repeat=repeat, + warmup_iters=warmup_iters, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + + +@lru_cache(maxsize=128, typed=False) +def _cached_attention_cost_from_local_shapes( + local_shape: torch.Size, + repeat: int = 100, + warmup_iters: int = 2, +): + local_query = torch.randn(local_shape).to("cuda:0") + local_key = torch.randn(local_shape).to("cuda:0") + local_value = torch.randn(local_shape).to("cuda:0") + + _, elapsed = _profile_op( + fn=F.scaled_dot_product_attention, + args=[local_query, local_key, local_value], + repeat=repeat, + warmup_iters=warmup_iters, + ) + + return elapsed + + +def _get_attention_compute_cost( + layer: torch.nn.Module, + layer_strategies: list, + data_size: int, + repeat: int = 100, + warmup_iters: int = 2, +): + + world_size = torch.distributed.get_world_size() + + global_shape = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + + cost_vector = [] + for strategy in layer_strategies: + local_shape = copy(global_shape) + + if strategy == "replicated": + pass + elif strategy == "column": + local_shape = torch.Size([global_shape[0], global_shape[1] // world_size]) + + elapsed = _cached_attention_cost_from_local_shapes( + local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + + +def _get_compute_cost_from_layer(layer, layer_strategies, data_size): + if isinstance(layer, vllm.model_executor.layers.linear.LinearBase): + return _get_linear_compute_cost( + layer, + layer_strategies, + data_size, + ) + if isinstance(layer, VllmAttention): + return _get_attention_compute_cost( + layer, + layer_strategies, + data_size, + ) + else: + raise 
ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + + +def _get_intra_op_comms_cost( + layer: torch.nn.Module, + layer_strategies: list, + pass_args: dict, +): + bw = pass_args.get("intra_device_bandwidth", None) + lat = pass_args.get("intra_device_latency", None) + data_size = pass_args.get("data_size", None) + + if bw is None: + raise ValueError("intra_device_bandwidth is not provided") + if lat is None: + raise ValueError("intra_device_latency is not provided") + if data_size is None: + raise ValueError("data_size is not provided") + + comms_cost = np.zeros(len(layer_strategies)) + for idx, strategy in enumerate(layer_strategies): + if strategy == "row": + comms_cost[idx] = allreduce_cost( + bytes_gb=math.prod(_get_output_shape_from_layer_type(layer, data_size)) + * 4 + / 1e9, + intra_device_latency=lat, + intra_device_bandwidth=bw, + ) + + return comms_cost * 1e-6 # convert back to seconds + + +# Resharding cost +# ================================ + + +def _get_resharding_cost( + layer: torch.nn.Module, + module_strategy: str, + parent_module: torch.nn.Module, + parent_strategy: str, + pass_args: dict, +) -> float: + + # Strategies which always return RR sharding + if parent_strategy in ["replicated", "row"]: + return 0 + + world_size = torch.distributed.get_world_size() + bw = pass_args.get("intra_device_bandwidth", None) + lat = pass_args.get("intra_device_latency", None) + data_size = pass_args.get("data_size", None) + + if bw is None: + raise ValueError("intra_device_bandwidth is not provided") + if lat is None: + raise ValueError("intra_device_latency is not provided") + if data_size is None: + raise ValueError("data_size is not provided") + + # all gather operation + skip_allgather = ( + ( + # Column parallel linear -> Row parallel linear (Megatron-LM) + parent_strategy == "column" + and module_strategy == "row" + ) + or ( + # Column parallel linear -> Head parallel attention (Megatron-LM) + parent_strategy == "column" + and 
module_strategy == "head" + ) + or ( + # Head parallel attention -> Row parallel linear (Megatron-LM) + parent_strategy == "head" + and module_strategy == "row" + ) + or ( + # Data parallel linear -> Data parallel linear + parent_strategy == "data" + and module_strategy == "data" + ) + ) + + if not skip_allgather: + cost = allgather_cost( + bytes_gb=world_size + * math.prod(_get_output_shape_from_layer_type(parent_module, data_size)) + * 4 + / 1e9, + intra_device_latency=lat, + intra_device_bandwidth=bw, + ) + # elif... + else: + cost = 0 + + return cost + + +def _get_resharding_cost_matrix( + layer, layer_strategies, parent_module, parent_strategies, pass_args +): + + resharding_costs = np.zeros([len(parent_strategies), len(layer_strategies)]) + for module_strategy_idx, module_strategy in enumerate(layer_strategies): + for parent_strategy_idx, parent_strategy in enumerate(parent_strategies): + resharding_costs[parent_strategy_idx, module_strategy_idx] = ( + _get_resharding_cost( + layer, + module_strategy, + parent_module, + parent_strategy, + pass_args, + ) + ) + + return resharding_costs * 1e-6 # convert back to seconds From d31ce94de9f92d4aed69a61e27d33c3f376591fe Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 30 Aug 2024 13:02:37 +0000 Subject: [PATCH 87/93] module level autosharding: update cost modelling for allgather/allreduce to use real profiling --- .../passes/module/analysis/cost_modelling.py | 145 +++++++++++++----- 1 file changed, 107 insertions(+), 38 deletions(-) diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py index 3ecbbe2ee..3fda36a8c 100644 --- a/src/chop/passes/module/analysis/cost_modelling.py +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -1,3 +1,4 @@ +import os import math import numpy as np from copy import copy @@ -5,10 +6,27 @@ import torch from torch.nn import functional as F +import torch.distributed as dist +import torch.multiprocessing as mp +from 
torch.multiprocessing import Queue, set_start_method import vllm from vllm.attention import Attention as VllmAttention +VllmLinear = vllm.model_executor.layers.linear.LinearBase + +_ALL_REDUCE_COST_DB = { + (8, 1536): 0.2580975145101547, + (8, 4608): 0.4603813052177429, + (8, 6144): 0.5322111986185375, +} + +_ALL_GATHER_COST_DB = { + (8, 1536): 0.26789389503629585, + (8, 4608): 0.24845608441453232, + (8, 6144): 0.3390117042943051, +} + # Utilities # ================================ @@ -37,44 +55,106 @@ def _profile_op( return out, np.mean(elapsed[warmup_iters:]) +def _profile_distributed_op( + rank, + world_size, + result_queue, + repeat, + warmup_iters, + op, + global_shape, +): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12356" + os.environ["RANK"] = str(rank) + + # Initialize + device = torch.device("cuda", rank) + dist.init_process_group("nccl", rank=rank, world_size=world_size, device_id=device) + torch.cuda.set_device(device) + + start_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + end_event = [ + torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) + ] + + for idx in range(repeat): + if op == "allgather": + output_tensor = torch.zeros(global_shape, device=device) + local_shape = [global_shape[0], global_shape[1] // world_size] + local_tensor = torch.randn(local_shape, device=device) + + dist.barrier() + start_event[idx].record() + + dist.all_gather_into_tensor(output_tensor, local_tensor) + output_tensor = output_tensor.movedim(0, 1) + output_tensor = output_tensor.reshape(global_shape) + + elif op == "allreduce": + local_tensor = torch.randn(global_shape, device=device) + + dist.barrier() + start_event[idx].record() + + dist.all_reduce(local_tensor) + + dist.barrier() + end_event[idx].record() + + torch.cuda.synchronize(device=device) + + elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] + + avg = sum(elapsed[warmup_iters:]) / 
len(elapsed[warmup_iters:]) + + if rank == 0: + result_queue.put(avg) + + dist.barrier() + dist.destroy_process_group() + + +# @lru_cache(maxsize=128, typed=False) def allreduce_cost( - bytes_gb: float, - intra_device_latency: float, - intra_device_bandwidth: float, + output_shape: list, + repeat: int = 100, + warmup_iters: int = 5, ) -> float: - world_size = torch.distributed.get_world_size() - mesh_dim_bandwidth = intra_device_bandwidth - # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter - num_hops = 2 * world_size - 1 + cost = _ALL_REDUCE_COST_DB.get(tuple(output_shape), None) + if cost is None: + raise ValueError(f"Unknown allreduce cost for shape: {output_shape}") - latency = 6.6 + num_hops * intra_device_latency - bw = (bytes_gb * num_hops / world_size) / mesh_dim_bandwidth - return latency + bw * 1e6 + return cost def allgather_cost( - bytes_gb: float, - intra_device_latency: float, - intra_device_bandwidth: float, + output_shape: list, + repeat: int = 100, + warmup_iters: int = 5, ) -> float: - world_size = torch.distributed.get_world_size() - num_hops = world_size - 1 - latency = 6.6 + num_hops * intra_device_latency - bw = (bytes_gb * num_hops / world_size) / intra_device_bandwidth - return latency + bw * 1e6 + cost = _ALL_GATHER_COST_DB.get(tuple(output_shape), None) + if cost is None: + raise ValueError(f"Unknown allgather cost for shape: {output_shape}") + + return cost def _get_output_shape_from_layer_type( layer: torch.nn.Module, data_size: int, ): - if isinstance(layer, vllm.model_executor.layers.linear.LinearBase): - return torch.Size([data_size, layer.weight.shape[0]]) - if isinstance(layer, VllmAttention): - return torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + if isinstance(layer, VllmLinear): + size = torch.Size([data_size, layer.weight.shape[0]]) + elif isinstance(layer, VllmAttention): + size = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) else: raise 
ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + return list(size) + # Compute cost # ================================ @@ -214,7 +294,7 @@ def _get_attention_compute_cost( def _get_compute_cost_from_layer(layer, layer_strategies, data_size): - if isinstance(layer, vllm.model_executor.layers.linear.LinearBase): + if isinstance(layer, VllmLinear): return _get_linear_compute_cost( layer, layer_strategies, @@ -249,13 +329,8 @@ def _get_intra_op_comms_cost( comms_cost = np.zeros(len(layer_strategies)) for idx, strategy in enumerate(layer_strategies): if strategy == "row": - comms_cost[idx] = allreduce_cost( - bytes_gb=math.prod(_get_output_shape_from_layer_type(layer, data_size)) - * 4 - / 1e9, - intra_device_latency=lat, - intra_device_bandwidth=bw, - ) + out_shape = _get_output_shape_from_layer_type(layer, data_size) + comms_cost[idx] = allreduce_cost(output_shape=out_shape) return comms_cost * 1e-6 # convert back to seconds @@ -313,14 +388,8 @@ def _get_resharding_cost( ) if not skip_allgather: - cost = allgather_cost( - bytes_gb=world_size - * math.prod(_get_output_shape_from_layer_type(parent_module, data_size)) - * 4 - / 1e9, - intra_device_latency=lat, - intra_device_bandwidth=bw, - ) + out_shape = _get_output_shape_from_layer_type(parent_module, data_size) + cost = allgather_cost(output_shape=out_shape) # elif... 
else: cost = 0 From fcabe9a8dc89dbfc83cc11665dc85ba0d0ea7792 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Fri, 30 Aug 2024 16:29:28 +0000 Subject: [PATCH 88/93] fix cost modelling time units --- .../passes/module/analysis/cost_modelling.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py index 3fda36a8c..b51a3b02a 100644 --- a/src/chop/passes/module/analysis/cost_modelling.py +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -15,12 +15,14 @@ VllmLinear = vllm.model_executor.layers.linear.LinearBase +# costs are in ms _ALL_REDUCE_COST_DB = { (8, 1536): 0.2580975145101547, (8, 4608): 0.4603813052177429, (8, 6144): 0.5322111986185375, } +# costs are in ms _ALL_GATHER_COST_DB = { (8, 1536): 0.26789389503629585, (8, 4608): 0.24845608441453232, @@ -52,7 +54,7 @@ def _profile_op( elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] - return out, np.mean(elapsed[warmup_iters:]) + return out, np.mean(elapsed[warmup_iters:]) * 1e-3 # convert back to seconds def _profile_distributed_op( @@ -127,7 +129,7 @@ def allreduce_cost( if cost is None: raise ValueError(f"Unknown allreduce cost for shape: {output_shape}") - return cost + return cost * 1e-3 # convert back to seconds def allgather_cost( @@ -139,7 +141,7 @@ def allgather_cost( if cost is None: raise ValueError(f"Unknown allgather cost for shape: {output_shape}") - return cost + return cost * 1e-3 # convert back to seconds def _get_output_shape_from_layer_type( @@ -275,12 +277,13 @@ def _get_attention_compute_cost( cost_vector = [] for strategy in layer_strategies: - local_shape = copy(global_shape) if strategy == "replicated": - pass - elif strategy == "column": + local_shape = copy(global_shape) + elif strategy == "head": local_shape = torch.Size([global_shape[0], global_shape[1] // world_size]) + else: + raise ValueError(f"Unknown 
strategy: {strategy}") elapsed = _cached_attention_cost_from_local_shapes( local_shape, @@ -332,7 +335,7 @@ def _get_intra_op_comms_cost( out_shape = _get_output_shape_from_layer_type(layer, data_size) comms_cost[idx] = allreduce_cost(output_shape=out_shape) - return comms_cost * 1e-6 # convert back to seconds + return comms_cost # Resharding cost @@ -414,4 +417,4 @@ def _get_resharding_cost_matrix( ) ) - return resharding_costs * 1e-6 # convert back to seconds + return resharding_costs From dc6e03e8dd2bf95ddb036095c3401023463b5087 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Mon, 2 Sep 2024 19:09:09 +0000 Subject: [PATCH 89/93] allgather/allreduce: replace cost db with regression model --- .../passes/module/analysis/autosharding.py | 5 +- .../passes/module/analysis/cost_modelling.py | 78 +++++++++++-------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py index f7cf4f6a1..377b352c2 100644 --- a/src/chop/passes/module/analysis/autosharding.py +++ b/src/chop/passes/module/analysis/autosharding.py @@ -119,7 +119,6 @@ def _formulate_ilp( layer_strategies, data_size=data_size, ) - expr += compute_cost @ opt_var # Consider intra operator comms cost # ============================ @@ -128,7 +127,9 @@ def _formulate_ilp( layer_strategies, pass_args=pass_args, ) - expr += comms_cost @ opt_var + + expr += (compute_cost + comms_cost) @ opt_var + # Consider resharding cost # ============================ diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py index b51a3b02a..c3ac9a746 100644 --- a/src/chop/passes/module/analysis/cost_modelling.py +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -15,19 +15,6 @@ VllmLinear = vllm.model_executor.layers.linear.LinearBase -# costs are in ms -_ALL_REDUCE_COST_DB = { - (8, 1536): 0.2580975145101547, - (8, 4608): 0.4603813052177429, - (8, 6144): 
0.5322111986185375, -} - -# costs are in ms -_ALL_GATHER_COST_DB = { - (8, 1536): 0.26789389503629585, - (8, 4608): 0.24845608441453232, - (8, 6144): 0.3390117042943051, -} # Utilities # ================================ @@ -119,17 +106,35 @@ def _profile_distributed_op( dist.destroy_process_group() -# @lru_cache(maxsize=128, typed=False) def allreduce_cost( output_shape: list, repeat: int = 100, warmup_iters: int = 5, ) -> float: - cost = _ALL_REDUCE_COST_DB.get(tuple(output_shape), None) - if cost is None: - raise ValueError(f"Unknown allreduce cost for shape: {output_shape}") + ds, hs = output_shape - return cost * 1e-3 # convert back to seconds + intercept = 0.40594790939481484 + + coeff = [ + 0.0, + -0.00019876370905316763, + -4.174260473864464e-06, + 4.019442387061491e-08, + 6.210839534401708e-07, + 4.909228531291631e-11, + ] + + cost = ( + intercept + + coeff[0] + + (coeff[1] * ds) + + (coeff[2] * hs) + + (coeff[3] * ds**2) + + (coeff[4] * ds * hs) + + (coeff[5] * hs**2) + ) + + return cost * 1e-3 def allgather_cost( @@ -137,9 +142,28 @@ def allgather_cost( repeat: int = 100, warmup_iters: int = 5, ) -> float: - cost = _ALL_GATHER_COST_DB.get(tuple(output_shape), None) - if cost is None: - raise ValueError(f"Unknown allgather cost for shape: {output_shape}") + ds, hs = output_shape + + intercept = 0.478361915750253 + + coeff = [ + 0, + -0.00025625419990716485, + -1.9612017748514218e-05, + 4.892589021040619e-08, + 3.375990357833703e-07, + 5.192329766543819e-10, + ] + + cost = ( + intercept + + coeff[0] + + (coeff[1] * ds) + + (coeff[2] * hs) + + (coeff[3] * ds**2) + + (coeff[4] * ds * hs) + + (coeff[5] * hs**2) + ) return cost * 1e-3 # convert back to seconds @@ -318,14 +342,8 @@ def _get_intra_op_comms_cost( layer_strategies: list, pass_args: dict, ): - bw = pass_args.get("intra_device_bandwidth", None) - lat = pass_args.get("intra_device_latency", None) data_size = pass_args.get("data_size", None) - if bw is None: - raise ValueError("intra_device_bandwidth 
is not provided") - if lat is None: - raise ValueError("intra_device_latency is not provided") if data_size is None: raise ValueError("data_size is not provided") @@ -355,14 +373,8 @@ def _get_resharding_cost( return 0 world_size = torch.distributed.get_world_size() - bw = pass_args.get("intra_device_bandwidth", None) - lat = pass_args.get("intra_device_latency", None) data_size = pass_args.get("data_size", None) - if bw is None: - raise ValueError("intra_device_bandwidth is not provided") - if lat is None: - raise ValueError("intra_device_latency is not provided") if data_size is None: raise ValueError("data_size is not provided") From 459d7909ea4b60fd2123f2c88c716414f8a7d542 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Wed, 4 Sep 2024 17:44:16 +0000 Subject: [PATCH 90/93] move all op benchmarking to use real vLLM classes --- .../passes/module/analysis/autosharding.py | 153 +++++- .../passes/module/analysis/cost_modelling.py | 466 ++++++++++-------- 2 files changed, 391 insertions(+), 228 deletions(-) diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py index 377b352c2..35e243c39 100644 --- a/src/chop/passes/module/analysis/autosharding.py +++ b/src/chop/passes/module/analysis/autosharding.py @@ -14,37 +14,51 @@ from .cost_modelling import ( _get_compute_cost_from_layer, - _get_resharding_cost_matrix, _get_intra_op_comms_cost, + _get_resharding_cost_matrix, + _get_memory_cost_from_layer, ) VllmLinear = vllm.model_executor.layers.linear.LinearBase logger = get_logger(__name__) -logger.setLevel("DEBUG") +logger.setLevel("WARNING") STRATEGY_MAP = OrderedDict( { - VllmLinear: [ + VllmLinear: ( "replicated", "column", "row", "data", - ], - VllmAttention: [ + ), + VllmAttention: ( "replicated", "head", - ], + ), type(None): None, } ) +def _get_output_shape_from_layer_type( + layer: torch.nn.Module, + data_size: int, +): + if isinstance(layer, VllmLinear): + size = torch.Size([data_size, 
layer.weight.shape[0]]) + elif isinstance(layer, VllmAttention): + size = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") + + return tuple(size) + + def _linearize_resharding_cost( opt_var, parent_opt_var, resharding_costs, - pass_args, ): # Flatten resharding matrix resharding_costs = resharding_costs.flatten() @@ -68,6 +82,25 @@ def _linearize_resharding_cost( return expr, constr +def _get_memory_constraint( + memory_constraint_terms: list, + self_rank: int, + pass_args: dict, +): + budget = pass_args.get("gpu_memory_budget", None) + + if budget is None: + raise ValueError("gpu_memory_budget is required for autosharding analysis") + + memory_available = torch.cuda.get_device_properties(self_rank).total_memory * budget + + mem_constr_expr = 0 + for i, (opt_var, mem_cost) in enumerate(memory_constraint_terms): + mem_constr_expr += mem_cost @ opt_var + + return [mem_constr_expr <= memory_available] + + def _formulate_ilp( model: torch.nn.Module, pass_args: dict, @@ -85,8 +118,16 @@ def _formulate_ilp( # ILP variables constr = [] expr = 0 + megatron_soln = 0 + megatron_mem_cost = 0 - for layer in model.modules(): + bad_soln = 0 + bad_soln_memory_cost = 0 + + # List of tuples: (opt_var, memory_cost) + memory_constr_terms = [] + + for name, layer in model.named_modules(): # Skip non-leaf modules if len(list(layer.children())) > 0: @@ -112,24 +153,66 @@ def _formulate_ilp( cp.sum(opt_var) == 1, ] + # Calculate Megatron solution for comparison + megatron_opt_var = np.zeros(len(layer_strategies)) + bad_soln_opt_var = np.zeros(len(layer_strategies)) + + if "attn.c_attn" in name: + megatron_opt_var[1] = 1 # column + bad_soln_opt_var[1] = 1 # column + elif "attn.attn" in name: + megatron_opt_var[1] = 1 # head + bad_soln_opt_var[1] = 1 # head + elif "attn.c_proj" in name: + megatron_opt_var[2] = 1 # row + bad_soln_opt_var[1] = 1 # column + elif "mlp.c_fc" in name: + 
megatron_opt_var[1] = 1 # column + bad_soln_opt_var[1] = 1 # column + elif "mlp.c_proj" in name: + megatron_opt_var[2] = 1 # row + bad_soln_opt_var[2] = 1 # column + else: + raise ValueError(f"Unsupported layer name: {name}") + + setattr(layer, "megatron_opt_var", megatron_opt_var) + setattr(layer, "bad_soln_opt_var", bad_soln_opt_var) + # Consider compute cost # ============================ compute_cost = _get_compute_cost_from_layer( layer, layer_strategies, data_size=data_size, + benchmarking_device=self_rank, ) # Consider intra operator comms cost # ============================ + comms_cost = _get_intra_op_comms_cost( + layer_strategies=tuple(layer_strategies), + output_shape=_get_output_shape_from_layer_type(layer, data_size), + benchmarking_device=self_rank, + ) + + expr += (compute_cost + comms_cost) @ opt_var + megatron_soln += (compute_cost + comms_cost) @ megatron_opt_var + bad_soln += (compute_cost + comms_cost) @ bad_soln_opt_var + + # Consider memory cost + # ============================ + + mem_cost = _get_memory_cost_from_layer( layer, layer_strategies, - pass_args=pass_args, + benchmarking_device=self_rank, ) - expr += (compute_cost + comms_cost) @ opt_var + memory_constr_terms.append((opt_var, mem_cost)) + megatron_mem_cost += mem_cost @ megatron_opt_var + bad_soln_memory_cost += mem_cost @ bad_soln_opt_var # Consider resharding cost # ============================ @@ -144,17 +227,52 @@ def _formulate_ilp( f"Consider resharding cost between {parent_module.__class__.__name__} and {layer.__class__.__name__}" ) + parent_out_shape = _get_output_shape_from_layer_type( + parent_module, + data_size, + ) + resharding_costs = _get_resharding_cost_matrix( - layer, layer_strategies, parent_module, parent_strategies, pass_args + layer_strategies=layer_strategies, + parent_strategies=parent_strategies, + parent_out_shape=parent_out_shape, + benchmarking_device=self_rank, ) resharding_term, resharding_constraints = _linearize_resharding_cost( - opt_var, 
parent_module.opt_var, resharding_costs, pass_args + opt_var, + parent_module.opt_var, + resharding_costs, ) expr += resharding_term constr += resharding_constraints - return cp.Problem(cp.Minimize(expr), constr) + # Add Megatron solution for comparison + megatron_resharding_term = ( + parent_module.megatron_opt_var @ resharding_costs @ megatron_opt_var + ) + megatron_soln += megatron_resharding_term + + bad_soln_resharding_term = ( + parent_module.bad_soln_opt_var @ resharding_costs @ bad_soln_opt_var + ) + bad_soln += bad_soln_resharding_term + + # After processing all layers, consider memory constraints + # ============================ + + mem_constr = _get_memory_constraint( + memory_constraint_terms=memory_constr_terms, + self_rank=self_rank, + pass_args=pass_args, + ) + constr += mem_constr + + return ( + cp.Problem(cp.Minimize(expr), constr), + (megatron_soln, megatron_mem_cost), + mem_constr, + ) def _get_sharding_config(model): @@ -185,9 +303,10 @@ def _get_sharding_config(model): def autosharding_module_analysis_pass(model, pass_args={}): - problem = _formulate_ilp(model, pass_args) + problem, megatron, mem_constr = _formulate_ilp(model, pass_args) + megatron_soln, megatron_mem_cost = megatron problem.solve( - verbose=True, + verbose=pass_args.get(f"debug", False), scipy_options={ "disp": pass_args.get(f"debug", False), "time_limit": pass_args.get("time_limit", None), @@ -197,6 +316,10 @@ def autosharding_module_analysis_pass(model, pass_args={}): sharding_config = _get_sharding_config(model) + memory_available = torch.cuda.get_device_properties( + torch.distributed.get_rank() + ).total_memory * pass_args.get("gpu_memory_budget") + return model, { "sharding_config": sharding_config, } diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py index c3ac9a746..9b421e68c 100644 --- a/src/chop/passes/module/analysis/cost_modelling.py +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -1,5 +1,6 
@@ import os import math +import gc import numpy as np from copy import copy from functools import lru_cache @@ -10,8 +11,20 @@ import torch.multiprocessing as mp from torch.multiprocessing import Queue, set_start_method + import vllm from vllm.attention import Attention as VllmAttention +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, + ColumnParallelLinear, + RowParallelLinear, + DataParallelLinear, +) +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) VllmLinear = vllm.model_executor.layers.linear.LinearBase @@ -20,12 +33,44 @@ # ================================ +def _linear_cls_from_config(config: str): + if config == "replicated": + return ReplicatedLinear + if config == "column": + return ColumnParallelLinear + if config == "row": + return RowParallelLinear + if config == "data": + return DataParallelLinear + + raise ValueError(f"Unknown linear config: {config}") + + def _profile_op( + op: str, fn: callable, - args: list, + shape: tuple, repeat: int, warmup_iters: int, + benchmarking_device: int = 0, + extra_args: list = [], ): + """ + Profile op ``repeat`` times with ``warmup_iters`` warmup iterations. + Generate random input tensors of shape ``shape`` and pass them to the function ``fn`` in each iteration. + + Args: + op (str): _description_ + fn (callable): _description_ + shape (tuple): _description_ + repeat (int): _description_ + warmup_iters (int): _description_ + benchmarking_device (int, optional): _description_. Defaults to 0. + extra_args (list, optional): _description_. Defaults to []. 
+ + Returns: + _type_: _description_ + """ start_event = [ torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) ] @@ -34,152 +79,74 @@ def _profile_op( ] for idx in range(repeat): - start_event[idx].record() - out = fn(*args) - end_event[idx].record() - torch.cuda.synchronize(device=f"cuda:0") - - elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] - - return out, np.mean(elapsed[warmup_iters:]) * 1e-3 # convert back to seconds - - -def _profile_distributed_op( - rank, - world_size, - result_queue, - repeat, - warmup_iters, - op, - global_shape, -): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12356" - os.environ["RANK"] = str(rank) - - # Initialize - device = torch.device("cuda", rank) - dist.init_process_group("nccl", rank=rank, world_size=world_size, device_id=device) - torch.cuda.set_device(device) - - start_event = [ - torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) - ] - end_event = [ - torch.cuda.Event(enable_timing=True, blocking=True) for _ in range(repeat) - ] - - for idx in range(repeat): - if op == "allgather": - output_tensor = torch.zeros(global_shape, device=device) - local_shape = [global_shape[0], global_shape[1] // world_size] - local_tensor = torch.randn(local_shape, device=device) - - dist.barrier() - start_event[idx].record() - - dist.all_gather_into_tensor(output_tensor, local_tensor) - output_tensor = output_tensor.movedim(0, 1) - output_tensor = output_tensor.reshape(global_shape) - + if op == "linear": + input_ = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [input_] + extra_args + elif op == "attention": + local_query = torch.randn(shape).to(f"cuda:{benchmarking_device}") + local_key = torch.randn(shape).to(f"cuda:{benchmarking_device}") + local_value = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [ + local_query, + local_key, + local_value, + None, # benchmark without KV cache + ] + extra_args elif 
op == "allreduce": - local_tensor = torch.randn(global_shape, device=device) - - dist.barrier() - start_event[idx].record() - - dist.all_reduce(local_tensor) + local_tensor = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [local_tensor] + elif op == "allgather": + local_tensor = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [local_tensor, -1] + else: + raise ValueError(f"Unknown op: {op}") - dist.barrier() + start_event[idx].record() + out = fn(*args) end_event[idx].record() - - torch.cuda.synchronize(device=device) + torch.cuda.synchronize(device=f"cuda:{benchmarking_device}") elapsed = [start_event[idx].elapsed_time(end_event[idx]) for idx in range(repeat)] - avg = sum(elapsed[warmup_iters:]) / len(elapsed[warmup_iters:]) - - if rank == 0: - result_queue.put(avg) - - dist.barrier() - dist.destroy_process_group() + return out, np.mean(elapsed[warmup_iters:]), elapsed +@lru_cache(maxsize=128, typed=False) def allreduce_cost( - output_shape: list, + output_shape: tuple, repeat: int = 100, warmup_iters: int = 5, + benchmarking_device: int = 0, ) -> float: - ds, hs = output_shape - - intercept = 0.40594790939481484 - - coeff = [ - 0.0, - -0.00019876370905316763, - -4.174260473864464e-06, - 4.019442387061491e-08, - 6.210839534401708e-07, - 4.909228531291631e-11, - ] - - cost = ( - intercept - + coeff[0] - + (coeff[1] * ds) - + (coeff[2] * hs) - + (coeff[3] * ds**2) - + (coeff[4] * ds * hs) - + (coeff[5] * hs**2) + _, cost, elapsed_times = _profile_op( + op="allreduce", + fn=tensor_model_parallel_all_reduce, + shape=output_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, ) - return cost * 1e-3 + return cost +@lru_cache(maxsize=128, typed=False) def allgather_cost( - output_shape: list, + local_shape: tuple, repeat: int = 100, warmup_iters: int = 5, + benchmarking_device: int = 0, ) -> float: - ds, hs = output_shape - - intercept = 0.478361915750253 - - coeff = [ - 0, - -0.00025625419990716485, 
- -1.9612017748514218e-05, - 4.892589021040619e-08, - 3.375990357833703e-07, - 5.192329766543819e-10, - ] - - cost = ( - intercept - + coeff[0] - + (coeff[1] * ds) - + (coeff[2] * hs) - + (coeff[3] * ds**2) - + (coeff[4] * ds * hs) - + (coeff[5] * hs**2) + _, cost, elapsed_times = _profile_op( + op="allgather", + fn=tensor_model_parallel_all_gather, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, ) - return cost * 1e-3 # convert back to seconds - - -def _get_output_shape_from_layer_type( - layer: torch.nn.Module, - data_size: int, -): - if isinstance(layer, VllmLinear): - size = torch.Size([data_size, layer.weight.shape[0]]) - elif isinstance(layer, VllmAttention): - size = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) - else: - raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") - - return list(size) + return cost # Compute cost @@ -188,78 +155,67 @@ def _get_output_shape_from_layer_type( @lru_cache(maxsize=128, typed=False) def _cached_linear_cost_from_local_shapes( - local_weight_shape: torch.Size, - local_bias_shape: torch.Size, - local_input_shape: torch.Size, + type: str, + data_size: int, + input_size: int, + output_size: int, repeat: int = 100, warmup_iters: int = 2, + benchmarking_device: int = 0, ): - local_weights = torch.randn(local_weight_shape).to("cuda:0") - local_bias = torch.randn(local_bias_shape).to("cuda:0") - local_input = torch.randn(local_input_shape).to("cuda:0") + cls = _linear_cls_from_config(type) + + layer = cls( + input_size=input_size, + output_size=output_size, + ) - _, elapsed = _profile_op( - fn=F.linear, - args=[local_input, local_weights, local_bias], + local_shape = (data_size, input_size) + if type == "data": + local_shape = (data_size // torch.distributed.get_world_size(), input_size) + elif type == "row": + local_shape = (data_size, input_size // torch.distributed.get_world_size()) + elif type in ["replicated", "column"]: + 
pass + else: + raise ValueError(f"Unknown type: {type}") + + _, elapsed, elapsed_list = _profile_op( + op="linear", + fn=layer, + shape=local_shape, repeat=repeat, warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, ) - return elapsed + return elapsed, elapsed_list def _get_linear_compute_cost( layer: torch.nn.Module, - layer_strategies: list, + layer_strategies: tuple, data_size: int, repeat: int = 100, warmup_iters: int = 2, + benchmarking_device: int = 0, ): - world_size = torch.distributed.get_world_size() - - # Global shapes - global_weight_shape = layer.weight.shape - global_bias_shape = layer.bias.shape - global_input_shape = torch.Size([data_size, global_weight_shape[1]]) + input_size = layer.input_size + output_size = layer.output_size cost_vector = [] for strategy in layer_strategies: - # Default values for local tensors - # (taken for replicated strategy) - local_weight_shape = copy(global_weight_shape) - local_bias_shape = copy(global_bias_shape) - local_input_shape = copy(global_input_shape) - - if strategy == "replicated": - pass - elif strategy == "column": - local_weight_shape = torch.Size( - [global_weight_shape[0] // world_size, global_weight_shape[1]] - ) - local_bias_shape = torch.Size([global_bias_shape[0] // world_size]) - elif strategy == "row": - local_input_shape = torch.Size( - [global_input_shape[0], global_input_shape[1] // world_size] - ) - local_weight_shape = torch.Size( - [global_weight_shape[0], global_weight_shape[1] // world_size] - ) - elif strategy == "data": - local_input_shape = torch.Size( - [global_input_shape[0] // world_size, global_input_shape[1]] - ) - else: - raise ValueError(f"Unknown strategy: {strategy}") - # Create local tensors - elapsed = _cached_linear_cost_from_local_shapes( - local_weight_shape, - local_bias_shape, - local_input_shape, + elapsed, elapsed_list = _cached_linear_cost_from_local_shapes( + type=strategy, + data_size=data_size, + input_size=input_size, + output_size=output_size, 
repeat=repeat, warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, ) cost_vector.append(elapsed) @@ -269,19 +225,35 @@ def _get_linear_compute_cost( @lru_cache(maxsize=128, typed=False) def _cached_attention_cost_from_local_shapes( - local_shape: torch.Size, + data_size: int, + num_heads: int, + head_size: int, repeat: int = 100, warmup_iters: int = 2, + benchmarking_device: int = 0, ): - local_query = torch.randn(local_shape).to("cuda:0") - local_key = torch.randn(local_shape).to("cuda:0") - local_value = torch.randn(local_shape).to("cuda:0") + local_shape = torch.Size([data_size, head_size * num_heads]) + + attn_meta = AttentionMetadata( + num_prefills=9, + num_prefill_tokens=data_size, + num_decode_tokens=0, + slot_mapping=None, + ) + attn_layer = VllmAttention( + num_heads=num_heads, + head_size=head_size, + scale=1.0, + ) - _, elapsed = _profile_op( - fn=F.scaled_dot_product_attention, - args=[local_query, local_key, local_value], + _, elapsed, _ = _profile_op( + op="attention", + fn=attn_layer, + shape=local_shape, repeat=repeat, warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + extra_args=[attn_meta], ) return elapsed @@ -289,30 +261,31 @@ def _cached_attention_cost_from_local_shapes( def _get_attention_compute_cost( layer: torch.nn.Module, - layer_strategies: list, + layer_strategies: tuple, data_size: int, repeat: int = 100, warmup_iters: int = 2, + benchmarking_device: int = 0, ): - world_size = torch.distributed.get_world_size() - - global_shape = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) - cost_vector = [] for strategy in layer_strategies: if strategy == "replicated": - local_shape = copy(global_shape) + num_heads = layer.impl.num_heads elif strategy == "head": - local_shape = torch.Size([global_shape[0], global_shape[1] // world_size]) + num_heads = layer.impl.num_heads // torch.distributed.get_world_size() + else: raise ValueError(f"Unknown strategy: {strategy}") elapsed = 
_cached_attention_cost_from_local_shapes( - local_shape, + data_size=data_size, + num_heads=num_heads, + head_size=layer.impl.head_size, repeat=repeat, warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, ) cost_vector.append(elapsed) @@ -320,38 +293,43 @@ def _get_attention_compute_cost( return np.array(cost_vector) -def _get_compute_cost_from_layer(layer, layer_strategies, data_size): +def _get_compute_cost_from_layer( + layer, + layer_strategies, + data_size, + benchmarking_device: int = 0, +): if isinstance(layer, VllmLinear): return _get_linear_compute_cost( layer, layer_strategies, data_size, + benchmarking_device=benchmarking_device, ) if isinstance(layer, VllmAttention): return _get_attention_compute_cost( layer, layer_strategies, data_size, + benchmarking_device=benchmarking_device, ) else: raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") +@lru_cache(maxsize=128, typed=False) def _get_intra_op_comms_cost( - layer: torch.nn.Module, - layer_strategies: list, - pass_args: dict, + layer_strategies: tuple, + output_shape: tuple, + benchmarking_device: int = 0, ): - data_size = pass_args.get("data_size", None) - - if data_size is None: - raise ValueError("data_size is not provided") - comms_cost = np.zeros(len(layer_strategies)) for idx, strategy in enumerate(layer_strategies): if strategy == "row": - out_shape = _get_output_shape_from_layer_type(layer, data_size) - comms_cost[idx] = allreduce_cost(output_shape=out_shape) + comms_cost[idx] = allreduce_cost( + output_shape=output_shape, + benchmarking_device=benchmarking_device, + ) return comms_cost @@ -361,11 +339,10 @@ def _get_intra_op_comms_cost( def _get_resharding_cost( - layer: torch.nn.Module, module_strategy: str, - parent_module: torch.nn.Module, + parent_out_shape: tuple, parent_strategy: str, - pass_args: dict, + benchmarking_device: int = 0, ) -> float: # Strategies which always return RR sharding @@ -373,10 +350,6 @@ def _get_resharding_cost( return 0 
world_size = torch.distributed.get_world_size() - data_size = pass_args.get("data_size", None) - - if data_size is None: - raise ValueError("data_size is not provided") # all gather operation skip_allgather = ( @@ -403,17 +376,23 @@ def _get_resharding_cost( ) if not skip_allgather: - out_shape = _get_output_shape_from_layer_type(parent_module, data_size) - cost = allgather_cost(output_shape=out_shape) - # elif... + local_shape = [parent_out_shape[0], parent_out_shape[1] // world_size] + cost = allgather_cost( + local_shape=tuple(local_shape), + benchmarking_device=benchmarking_device, + ) else: cost = 0 return cost +@lru_cache(maxsize=128, typed=False) def _get_resharding_cost_matrix( - layer, layer_strategies, parent_module, parent_strategies, pass_args + layer_strategies, + parent_strategies, + parent_out_shape, + benchmarking_device: int = 0, ): resharding_costs = np.zeros([len(parent_strategies), len(layer_strategies)]) @@ -421,12 +400,73 @@ def _get_resharding_cost_matrix( for parent_strategy_idx, parent_strategy in enumerate(parent_strategies): resharding_costs[parent_strategy_idx, module_strategy_idx] = ( _get_resharding_cost( - layer, module_strategy, - parent_module, + parent_out_shape, parent_strategy, - pass_args, + benchmarking_device=benchmarking_device, ) ) return resharding_costs + + +# Memory cost +# ================================ + + +def _get_gpu_memory_usage(): + torch.cuda.synchronize() + return torch.cuda.memory_allocated() + + +@lru_cache(maxsize=128, typed=False) +def _cached_get_linear_memory_cost( + input_size, + output_size, + bias, + strategies, + benchmarking_device: int = 0, +): + cost_vector = [] + peak_mems = [] + for strategy in strategies: + cls = _linear_cls_from_config(strategy) + + # Clear cache and reset stats + torch.cuda.empty_cache() + gc.collect() + torch.cuda.reset_peak_memory_stats() + + # Instantiate layer to measure memory usage + start_memory = _get_gpu_memory_usage() + _ = cls( + input_size=input_size, + 
output_size=output_size, + bias=bias is not None, + ).to(f"cuda:{benchmarking_device}") + end_memory = _get_gpu_memory_usage() + + # Record cost + cost_vector.append(end_memory - start_memory) + peak_mems.append(torch.cuda.max_memory_allocated()) + + return cost_vector + + +def _get_memory_cost_from_layer( + layer, + layer_strategies, + benchmarking_device: int = 0, +): + if isinstance(layer, VllmLinear): + return _cached_get_linear_memory_cost( + input_size=layer.input_size, + output_size=layer.output_size, + bias=layer.bias is not None, + strategies=tuple(layer_strategies), + benchmarking_device=benchmarking_device, + ) + elif isinstance(layer, VllmAttention): + return np.zeros(len(layer_strategies)) + else: + raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") From b24a51991836f2a1116bd27d258d129a8901b297 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 5 Sep 2024 14:07:36 +0100 Subject: [PATCH 91/93] remove pynvml --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 5fdc54272..c60fb62ec 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,6 @@ def get_system(): "sphinx-glpi-theme", "prettytable", "pyyaml", - "pynvml", "bitstring>=4.2", "myst_parser", "cvxpy", From a7c3bf7d9a1057f477944a86c7535f4114063199 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 5 Sep 2024 14:29:30 +0000 Subject: [PATCH 92/93] include layer norm and residual in cost modelling --- .../passes/module/analysis/autosharding.py | 19 ++ .../passes/module/analysis/cost_modelling.py | 244 ++++++++++++++++-- 2 files changed, 237 insertions(+), 26 deletions(-) diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py index 35e243c39..d8997c848 100644 --- a/src/chop/passes/module/analysis/autosharding.py +++ b/src/chop/passes/module/analysis/autosharding.py @@ -21,6 +21,9 @@ VllmLinear = vllm.model_executor.layers.linear.LinearBase +from vllm.model_executor.layers.layer_norm import 
LayerNormBase as VllmLayerNorm +from vllm.model_executor.layers.residual import ResidualBase as VllmResidual + logger = get_logger(__name__) logger.setLevel("WARNING") @@ -36,6 +39,14 @@ "replicated", "head", ), + VllmLayerNorm: ( + "replicated", + "data", + ), + VllmResidual: ( + "replicated", + "data", + ), type(None): None, } ) @@ -49,6 +60,10 @@ def _get_output_shape_from_layer_type( size = torch.Size([data_size, layer.weight.shape[0]]) elif isinstance(layer, VllmAttention): size = torch.Size([data_size, layer.impl.head_size * layer.impl.num_heads]) + elif isinstance(layer, VllmLayerNorm): + size = torch.Size([data_size, layer.normalized_shape[0]]) + elif isinstance(layer, VllmResidual): + size = torch.Size([data_size, 1]) else: raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") @@ -172,6 +187,10 @@ def _formulate_ilp( elif "mlp.c_proj" in name: megatron_opt_var[2] = 1 # row bad_soln_opt_var[2] = 1 # column + elif "ln" in name: + megatron_opt_var[0] = 1 + elif "res" in name: + megatron_opt_var[0] = 1 else: raise ValueError(f"Unsupported layer name: {name}") diff --git a/src/chop/passes/module/analysis/cost_modelling.py b/src/chop/passes/module/analysis/cost_modelling.py index 9b421e68c..ed9278039 100644 --- a/src/chop/passes/module/analysis/cost_modelling.py +++ b/src/chop/passes/module/analysis/cost_modelling.py @@ -13,7 +13,7 @@ import vllm -from vllm.attention import Attention as VllmAttention + from vllm.attention import AttentionMetadata from vllm.model_executor.layers.linear import ( ReplicatedLinear, @@ -21,11 +21,23 @@ RowParallelLinear, DataParallelLinear, ) +from vllm.model_executor.layers.layer_norm import ( + ReplicatedLayerNorm, + DataParallelLayerNorm, +) +from vllm.model_executor.layers.residual import ( + ReplicatedResidual, + DataParallelResidual, +) + from vllm.distributed.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from vllm.attention import Attention as VllmAttention 
+from vllm.model_executor.layers.layer_norm import LayerNormBase as VllmLayerNorm +from vllm.model_executor.layers.residual import ResidualBase as VllmResidual VllmLinear = vllm.model_executor.layers.linear.LinearBase @@ -45,6 +57,21 @@ def _linear_cls_from_config(config: str): raise ValueError(f"Unknown linear config: {config}") +def _layer_norm_cls_from_config(config: str): + if config == "replicated": + return ReplicatedLayerNorm + if config == "data": + return DataParallelLayerNorm + + raise ValueError(f"Unknown layer norm config: {config}") + +def _residual_cls_from_config(config: str): + if config == "replicated": + return ReplicatedResidual + if config == "data": + return DataParallelResidual + + raise ValueError(f"Unknown residual config: {config}") def _profile_op( op: str, @@ -79,7 +106,7 @@ def _profile_op( ] for idx in range(repeat): - if op == "linear": + if op in ["linear", "layer_norm"]: input_ = torch.randn(shape).to(f"cuda:{benchmarking_device}") args = [input_] + extra_args elif op == "attention": @@ -92,6 +119,10 @@ def _profile_op( local_value, None, # benchmark without KV cache ] + extra_args + elif op == "residual": + input_ = torch.randn(shape).to(f"cuda:{benchmarking_device}") + residual = torch.randn(shape).to(f"cuda:{benchmarking_device}") + args = [input_, residual] + extra_args elif op == "allreduce": local_tensor = torch.randn(shape).to(f"cuda:{benchmarking_device}") args = [local_tensor] @@ -160,7 +191,7 @@ def _cached_linear_cost_from_local_shapes( input_size: int, output_size: int, repeat: int = 100, - warmup_iters: int = 2, + warmup_iters: int = 5, benchmarking_device: int = 0, ): cls = _linear_cls_from_config(type) @@ -197,7 +228,7 @@ def _get_linear_compute_cost( layer_strategies: tuple, data_size: int, repeat: int = 100, - warmup_iters: int = 2, + warmup_iters: int = 5, benchmarking_device: int = 0, ): @@ -229,7 +260,7 @@ def _cached_attention_cost_from_local_shapes( num_heads: int, head_size: int, repeat: int = 100, - 
warmup_iters: int = 2, + warmup_iters: int = 5, benchmarking_device: int = 0, ): local_shape = torch.Size([data_size, head_size * num_heads]) @@ -258,13 +289,12 @@ def _cached_attention_cost_from_local_shapes( return elapsed - def _get_attention_compute_cost( layer: torch.nn.Module, layer_strategies: tuple, data_size: int, repeat: int = 100, - warmup_iters: int = 2, + warmup_iters: int = 5, benchmarking_device: int = 0, ): @@ -291,6 +321,133 @@ def _get_attention_compute_cost( cost_vector.append(elapsed) return np.array(cost_vector) + +@lru_cache(maxsize=128, typed=False) +def _cached_layer_norm_cost_from_local_shapes( + type: str, + normalized_shape: tuple, + eps: float, + elementwise_affine: bool, + bias: bool, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + layer = _layer_norm_cls_from_config(type)( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + ) + + local_shape = torch.Size([data_size, normalized_shape[0]]) + + _, elapsed, elapsed_list = _profile_op( + op="layer_norm", + fn=layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return elapsed + +def _get_layer_norm_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + + cost_vector = [] + for strategy in layer_strategies: + + if strategy == "replicated": + real_data_size = data_size + elif strategy == "data": + real_data_size = data_size // torch.distributed.get_world_size() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + elapsed = _cached_layer_norm_cost_from_local_shapes( + type=strategy, + normalized_shape=layer.normalized_shape, + eps=layer.eps, + elementwise_affine=layer.elementwise_affine, + bias=layer.bias is not None, + data_size=real_data_size, + repeat=repeat, + warmup_iters=warmup_iters, + 
benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) + +@lru_cache(maxsize=128, typed=False) +def _cached_residual_cost_from_local_shapes( + type: str, + data_size: int, + hidden_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + cls = _residual_cls_from_config(type) + + layer = cls( + hidden_size=hidden_size, + ) + + local_shape = torch.Size([data_size, hidden_size]) + + _, elapsed, elapsed_list = _profile_op( + op="residual", + fn=layer, + shape=local_shape, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + return elapsed + +def _get_residual_compute_cost( + layer: torch.nn.Module, + layer_strategies: tuple, + data_size: int, + repeat: int = 100, + warmup_iters: int = 5, + benchmarking_device: int = 0, +): + cost_vector = [] + for strategy in layer_strategies: + + if strategy == "replicated": + real_data_size = data_size + elif strategy == "data": + real_data_size = data_size // torch.distributed.get_world_size() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + elapsed = _cached_residual_cost_from_local_shapes( + type=strategy, + data_size=real_data_size, + hidden_size=layer.hidden_size, + repeat=repeat, + warmup_iters=warmup_iters, + benchmarking_device=benchmarking_device, + ) + + cost_vector.append(elapsed) + + return np.array(cost_vector) def _get_compute_cost_from_layer( @@ -299,20 +456,20 @@ def _get_compute_cost_from_layer( data_size, benchmarking_device: int = 0, ): + profile_kwargs = { + "layer": layer, + "layer_strategies": layer_strategies, + "data_size": data_size, + "benchmarking_device": benchmarking_device, + } if isinstance(layer, VllmLinear): - return _get_linear_compute_cost( - layer, - layer_strategies, - data_size, - benchmarking_device=benchmarking_device, - ) - if isinstance(layer, VllmAttention): - return _get_attention_compute_cost( - layer, - layer_strategies, - data_size, - 
benchmarking_device=benchmarking_device, - ) + return _get_linear_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmAttention): + return _get_attention_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmResidual): + return _get_residual_compute_cost(**profile_kwargs) + elif isinstance(layer, VllmLayerNorm): + return _get_layer_norm_compute_cost(**profile_kwargs) else: raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") @@ -421,12 +578,12 @@ def _get_gpu_memory_usage(): @lru_cache(maxsize=128, typed=False) def _cached_get_linear_memory_cost( - input_size, - output_size, - bias, - strategies, + input_size: int, + output_size: int, + bias: bool, + strategies: tuple, benchmarking_device: int = 0, -): +) -> list: cost_vector = [] peak_mems = [] for strategy in strategies: @@ -452,6 +609,33 @@ def _cached_get_linear_memory_cost( return cost_vector +@lru_cache(maxsize=128, typed=False) +def _cached_get_layer_norm_memory_cost( + normalized_shape: tuple, + elementwise_affine: bool, + bias: bool, + strategies: tuple, + benchmarking_device: int = 0, +) -> list: + cls = _layer_norm_cls_from_config("replicated") + + # Clear cache and reset stats + torch.cuda.empty_cache() + gc.collect() + torch.cuda.reset_peak_memory_stats() + + # Instantiate layer to measure memory usage + start_memory = _get_gpu_memory_usage() + _ = cls( + normalized_shape=normalized_shape, + elementwise_affine=elementwise_affine, + bias=bias is not None, + ).to(f"cuda:{benchmarking_device}") + end_memory = _get_gpu_memory_usage() + + cost = end_memory - start_memory + + return [cost] * len(strategies) def _get_memory_cost_from_layer( layer, @@ -466,7 +650,15 @@ def _get_memory_cost_from_layer( strategies=tuple(layer_strategies), benchmarking_device=benchmarking_device, ) - elif isinstance(layer, VllmAttention): + elif isinstance(layer, VllmLayerNorm): + return _cached_get_layer_norm_memory_cost( + normalized_shape=layer.normalized_shape, + 
elementwise_affine=layer.elementwise_affine, + bias=layer.bias is not None, + strategies=tuple(layer_strategies), + benchmarking_device=benchmarking_device, + ) + elif isinstance(layer, (VllmAttention, VllmResidual)): return np.zeros(len(layer_strategies)) else: raise ValueError(f"Unsupported layer type: {layer.__class__.__name__}") From 4ccbbbc6144708ddc1f46502074c4c3912f0c599 Mon Sep 17 00:00:00 2001 From: Pedro Gimenes Date: Thu, 5 Sep 2024 15:44:09 +0000 Subject: [PATCH 93/93] include residual resharding cost --- .../passes/module/analysis/autosharding.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/chop/passes/module/analysis/autosharding.py b/src/chop/passes/module/analysis/autosharding.py index d8997c848..cd4d0bd33 100644 --- a/src/chop/passes/module/analysis/autosharding.py +++ b/src/chop/passes/module/analysis/autosharding.py @@ -129,6 +129,7 @@ def _formulate_ilp( module_list = [] module_strategies = [] + last_residual = None # ILP variables constr = [] @@ -157,6 +158,8 @@ def _formulate_ilp( if layer_type is None or layer_strategies is None: continue + layer.strategies = layer_strategies + # Register layer and instantiate optimization variable # ============================ module_list.append(layer) @@ -265,7 +268,7 @@ def _formulate_ilp( ) expr += resharding_term constr += resharding_constraints - + # Add Megatron solution for comparison megatron_resharding_term = ( parent_module.megatron_opt_var @ resharding_costs @ megatron_opt_var @@ -277,6 +280,31 @@ def _formulate_ilp( ) bad_soln += bad_soln_resharding_term + # Residual layers may have an additional resharding cost for the residual path + if isinstance(layer, VllmResidual) and last_residual is not None: + last_residual_shape = _get_output_shape_from_layer_type( + last_residual, + data_size, + ) + + resharding_costs = _get_resharding_cost_matrix( + layer_strategies=layer_strategies, + parent_strategies=last_residual.strategies, + 
parent_out_shape=last_residual_shape, + benchmarking_device=self_rank, + ) + + resharding_term, resharding_constraints = _linearize_resharding_cost( + opt_var, + last_residual.opt_var, + resharding_costs, + ) + expr += resharding_term + constr += resharding_constraints + + if isinstance(layer, VllmResidual): + last_residual = layer + # After processing all layers, consider memory constraints # ============================ @@ -295,8 +323,6 @@ def _formulate_ilp( def _get_sharding_config(model): - self_rank = torch.distributed.get_rank() - sharding_config = {} for layer in model.modules():