Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
b5bdad3
initial commit
Jun 4, 2024
b67d712
cvxpy bug for resharding cost term in intra operator pass
pgimenes Jun 6, 2024
82b706e
update resharding cost model and ILP constraints
pgimenes Jun 11, 2024
e1cd29a
some refactoring
pgimenes Jun 11, 2024
b3ecaac
attach to runtime
pgimenes Jun 11, 2024
df4a02b
handle resharding between nodes, improve logging
pgimenes Jun 12, 2024
74c2712
make sharding decision dict based instead of positional
pgimenes Jun 12, 2024
357e8c3
run autosharding on huggingface bert
pgimenes Jun 13, 2024
bc8fef1
refactoring to fix circular imports
pgimenes Jun 13, 2024
dbc6591
need to fix batch dimension sharding and sharding along multiple mesh…
pgimenes Jun 13, 2024
04ebd34
autosharding works on patched bert without batch dimension sharding
pgimenes Jun 14, 2024
b1baaf3
insert inference timing and lower logging level
pgimenes Jun 14, 2024
6e14c21
pipeline for distributed inference + report parallelization pass
pgimenes Jun 19, 2024
ad299bc
layer norm and baddbmm ops
pgimenes Jun 25, 2024
4566af9
[REFACTOR] Export strategies for each node using torch distributed fo…
pgimenes Jun 25, 2024
4b62b88
[REFACTOR] enumerate sharding strategies for reshape nodes: view, exp…
pgimenes Jun 27, 2024
41a0eb8
[REFACTOR] unfinished: pointwise ops (add, gelu)
pgimenes Jun 27, 2024
f76a6d0
[REFACTOR] enumerate strategies for pointwise add
pgimenes Jun 27, 2024
b15bf87
[REFACTOR] unfinished: pointwise truediv
pgimenes Jun 27, 2024
de857ae
[REFACTOR] fix for truediv strategy enumeration
pgimenes Jun 28, 2024
a444aad
[REFACTOR] finished enumerating strategies for all BERT ops. To do: c…
pgimenes Jun 28, 2024
f4f4ce7
slight refactoring
pgimenes Jun 28, 2024
68146a4
add tensor meta for placeholder and transpose ops
pgimenes Jul 2, 2024
c9afa87
include tensormeta for all ops
pgimenes Jul 2, 2024
d9ecb19
include resharding cost, ILP now too complex
pgimenes Jul 2, 2024
4f37d01
ILP is solvable after replacing inf values in resharding matrix
pgimenes Jul 2, 2024
b0b8631
refactoring
pgimenes Jul 2, 2024
1d3fe4c
skip fully replicated strategies for placeholder ops
pgimenes Jul 2, 2024
4a46f47
start docs
pgimenes Jul 2, 2024
3c4b4a5
unnecessary imports
pgimenes Jul 3, 2024
80f50fb
export solution and optimizer profiling
pgimenes Jul 3, 2024
3e1fd15
vectorize constraint for linearized resharding cost variable and enab…
pgimenes Jul 9, 2024
a50e3cc
mark sharding and run checks for linearised variable constraints
pgimenes Jul 9, 2024
4e384d2
[ATTACH]: distribute get_attr nodes, bug in forward pass
pgimenes Jul 9, 2024
f36a790
fix
pgimenes Jul 9, 2024
65abdb9
enabling import/export autosharding solutions
pgimenes Jul 10, 2024
8a9b043
common metadata for OPT at call_function granularity
pgimenes Jul 10, 2024
276c9c4
handle embedding op in autosharding
pgimenes Jul 11, 2024
6fc504f
tensormeta for embedding op
pgimenes Jul 11, 2024
c111fa1
support autosharding for OPT
pgimenes Jul 15, 2024
17374f6
support activation modules, remove legacy stuff, some refactoring
pgimenes Jul 16, 2024
8e738b7
extrapolate sharding from single layer solution
pgimenes Jul 17, 2024
1eed07e
layout for extended docs
pgimenes Jul 17, 2024
6f740a2
make dist barrier asynchronous for distributed timing, and account fo…
pgimenes Jul 18, 2024
663049e
Merge branch 'merging-fixes' into research/alpa-light
pgimenes Jul 18, 2024
014c6db
some docs
pgimenes Jul 18, 2024
9379b32
Merge branch 'research/alpa-light' of https://github.com/DeepWok/mase…
pgimenes Jul 18, 2024
5cfeb5d
fixes to extrapolate single layer solution, improved reporting for ex…
pgimenes Jul 19, 2024
c5dd6b8
Merge branch 'research/alpa-light' of https://github.com/DeepWok/mase…
pgimenes Jul 19, 2024
2e3440d
get solution extrapolation working for GPT2
pgimenes Jul 23, 2024
166fbcf
migrate DTensor API to chop/distributed/tensor
pgimenes Jul 23, 2024
bf626bf
patch for redistribute
pgimenes Jul 23, 2024
3fa6744
remove logging
pgimenes Jul 23, 2024
aeee73e
Merge branch 'main' into research/alpa-light
pgimenes Jul 23, 2024
519d09c
fix circular import and remove "setting verbosity to debug" message a…
pgimenes Jul 23, 2024
e8849a8
remove breakpoints
pgimenes Jul 23, 2024
550da9c
formatting
pgimenes Jul 23, 2024
64acb52
fix circular import
pgimenes Jul 23, 2024
5c8f36a
remove unfinished emit verilog tests for llama/mistral
pgimenes Jul 23, 2024
5eee310
remove deprecated stuff
pgimenes Jul 24, 2024
264683d
reduce ILP complexity by skipping placeholder/get_attr candidate shar…
pgimenes Jul 24, 2024
b594c76
refactor add_common_metadata such that args/kwargs ordering is preserved
pgimenes Jul 24, 2024
419dc5f
[UNFINISHED] profile ops with local tensor shapes to formulate comput…
pgimenes Jul 24, 2024
bc5ff83
[UNFINISHED]: simplify DTensor OpDispatcher
pgimenes Jul 24, 2024
9d95ace
find and replace call_method nodes with call_functional for arg order…
pgimenes Jul 26, 2024
25f4c30
insert resharding nodes
pgimenes Jul 26, 2024
996d2fa
simplify mase launcher and get resharding nodes working with non full…
pgimenes Jul 29, 2024
47eb41b
simplify op dispatcher
pgimenes Jul 30, 2024
c1e3a7c
add torch.mm as an op
pgimenes Jul 30, 2024
ee517df
get refactored op dispatcher working on single layer gpt2
pgimenes Jul 30, 2024
0747681
remove logging which was slowing down runtime and remove redistribute…
pgimenes Jul 31, 2024
45f9c71
DTensor: remove duplicated op call for out tensor meta propagation
pgimenes Aug 5, 2024
7ccd4eb
finish OpDispatcher refactoring + fix bug in pointwise_strategy + ins…
pgimenes Aug 8, 2024
9fcb595
include fully replicated backend for autosharding
pgimenes Aug 8, 2024
574689a
include DTensorCache to bypass DTensor construction + remove high ove…
pgimenes Aug 9, 2024
ea6bbe5
remove deprecated files in src/chop/distributed/tensor
pgimenes Aug 15, 2024
13efcf2
include compute cost in ILP
pgimenes Aug 15, 2024
c3a43a6
support sdpa strategy
pgimenes Aug 15, 2024
98ed095
directory refactoring
pgimenes Aug 15, 2024
b34e692
Merge branch 'main' into research/alpa-light
pgimenes Aug 15, 2024
935e876
remove breakpoint
pgimenes Aug 15, 2024
082e6d0
remove deprecated pass
pgimenes Aug 15, 2024
6cc1256
remove deprecated tests
pgimenes Aug 15, 2024
0de6731
set benchmarking device for compute cost estimation in intra operator…
pgimenes Aug 16, 2024
f323066
remove MLIR CI
pgimenes Aug 16, 2024
7ab2d94
update torch.distributed.tensor imports since _tensor has been added …
pgimenes Aug 19, 2024
d3f43f9
remove breakpoint
pgimenes Aug 19, 2024
97aa774
revert changes to support SDPA which had been incorrectly merged
pgimenes Aug 19, 2024
b5759a8
fix torch.distributed.tensor imports
pgimenes Aug 29, 2024
6e43341
remove deprecated files
pgimenes Aug 29, 2024
611c835
module level autosharding for vllm
pgimenes Aug 29, 2024
d31ce94
module level autosharding: update cost modelling for allgather/allred…
pgimenes Aug 30, 2024
fcabe9a
fix cost modelling time units
pgimenes Aug 30, 2024
dc6e03e
allgather/allreduce: replace cost db with regression model
pgimenes Sep 2, 2024
459d790
move all op benchmarking to use real vLLM classes
pgimenes Sep 4, 2024
b24a519
remove pynvml
pgimenes Sep 5, 2024
a7c3bf7
include layer norm and residual in cost modelling
pgimenes Sep 5, 2024
4ccbbbc
include residual resharding cost
pgimenes Sep 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions .github/workflows/testTorchMLIR.yml

This file was deleted.

3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def get_system():
"sphinx-glpi-theme",
"prettytable",
"pyyaml",
"pynvml",
"bitstring>=4.2",
"myst_parser",
"cvxpy",
Expand All @@ -98,7 +97,7 @@ def get_system():
author="Aaron Zhao, Jianyi Cheng, Cheng Zhang, Pedro Gimenes",
author_email="a.zhao@imperial.ac.uk, jianyi.cheng17@imperial.ac.uk, chengzhang98@outlook.com, pedro.gimenes19@imperial.ac.uk",
license_files=("LICENSE",),
python_requires=">=3.11.9",
python_requires=">=3.11.4",
package_dir={
"": "src",
},
Expand Down
1 change: 0 additions & 1 deletion src/chop/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .launcher import MaseLauncher
183 changes: 26 additions & 157 deletions src/chop/distributed/launcher.py
Original file line number Diff line number Diff line change
@@ -1,186 +1,55 @@
import os
from functools import partial
from time import time

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed._tensor import (
DeviceMesh,
Replicate,
Shard,
)

from chop.distributed.tensor import distribute_module, distribute_tensor

from chop.distributed.utils import rlog
from chop.distributed.utils import _get_mesh_from_world_size
from ..tools import get_logger

logger = get_logger(__name__)
logger.setLevel("DEBUG")


def distributed_timing(fn, *args, **kwargs):
    """Execute ``fn(*args, **kwargs)`` once and measure its wall-clock time.

    Args:
        fn (callable): function to time.
        *args: positional arguments forwarded to ``fn``.
        **kwargs: keyword arguments forwarded to ``fn``.

    Returns:
        tuple: ``(result, seconds)`` — the return value of ``fn`` and the
        elapsed wall-clock time in seconds.
    """
    # NOTE(review): with async_op=True the barrier returns immediately and the
    # returned work handle is not waited on, so ranks are not actually
    # synchronized at the timestamps — presumably intentional, confirm.
    dist.barrier(async_op=True)
    t0 = time()
    output = fn(*args, **kwargs)
    dist.barrier(async_op=True)
    elapsed = time() - t0
    return output, elapsed


def distributed_average_timing(fn, repeat, args):
    """Run ``fn(*args)`` ``repeat`` times and return the average wall-clock time.

    The first two iterations are treated as warmup and excluded from the
    average when more than two samples are available.

    Args:
        fn (callable): function to benchmark.
        repeat (int): number of timed iterations; must be >= 1.
        args (sequence): positional arguments forwarded to ``fn``.

    Returns:
        tuple: ``(result, avg_seconds)`` — the return value of the final call
        and the average elapsed time in seconds over the non-warmup runs.
    """
    times = []
    result = None
    for itr in range(repeat):
        rlog(
            logger,
            dist.get_rank(),
            # Fixed typo: was "Running teration"
            f"Running iteration {itr}",
            "debug",
        )
        # NOTE(review): async_op=True means the barrier does not block, so the
        # timestamps are not synchronized across ranks — confirm intended.
        dist.barrier(async_op=True)
        start = time()
        result = fn(*args)
        dist.barrier(async_op=True)
        end = time()
        times.append(end - start)
        rlog(
            logger,
            dist.get_rank(),
            f"Time taken: {end - start}s",
            "debug",
        )

    # Drop the two warmup iterations only when enough samples exist; the
    # previous unconditional times[2:] raised ZeroDivisionError for repeat <= 2.
    samples = times[2:] if len(times) > 2 else times
    return result, sum(samples) / len(samples)


def dist_model_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
    rank: int,
    tensor_sharding_map={},
) -> None:
    """
    This function gets called by torch.distributed._tensor.distribute_module on each module in the model.
    Each tensor in each module is distributed according to the sharding configuration in tensor_sharding_map.

    Args:
        name: module name supplied by distribute_module (unused here).
        module: the submodule currently being visited.
        device_mesh: mesh describing the participating devices.
        rank: this process's rank, used only for rank-aware logging.
        tensor_sharding_map: maps a module to its node name and a per-parameter
            sharding config exposing a ``placements`` attribute.
            NOTE(review): mutable default argument — safe only because the
            function never mutates it; confirm.
    """
    # Modules absent from the map are left fully replicated (no-op).
    if module in tensor_sharding_map:
        node_name = tensor_sharding_map[module]["node"]
        for parameter, sharding_config in tensor_sharding_map[module][
            "sharding"
        ].items():
            # Activation entries describe dataflow, not module parameters.
            if parameter in ["data_in_0", "output", "data_out_0"]:
                continue
            if not hasattr(module, parameter):
                rlog(
                    logger,
                    rank,
                    f"Module {module} does not have parameter {parameter}",
                    level="warning",
                )
                continue

            placement = sharding_config.placements

            # Errors are logged but not re-raised, so a failed parameter
            # leaves that tensor replicated instead of aborting distribution.
            try:
                rlog(
                    logger,
                    rank,
                    f"Distributing parameter {parameter} of module {node_name} to {placement}",
                    level="debug",
                )
                distributed_tensor = distribute_tensor(
                    getattr(module, parameter), device_mesh, placement
                )
                # Re-wrap as Parameter so the module still registers it.
                setattr(module, parameter, torch.nn.Parameter(distributed_tensor))
            except Exception as e:
                rlog(
                    logger,
                    rank,
                    f"Error distributing parameter {parameter} of module {node_name} to {placement}: {e}",
                    level="error",
                )


def device_fn(
    rank, world_size, model=None, device_mesh=None, tensor_sharding_map=None, inputs=None
):
    """
    This function gets called on each GPU device to set up the distributed environment and distribute the model,
    following the SPMD model.

    Args:
        rank (int): rank of this process, also used as the CUDA device index.
        world_size (int): total number of processes in the group.
        model (nn.Module, optional): model to distribute. Defaults to None.
        device_mesh (list, optional): 2D mesh layout of device ids. Defaults to None.
        tensor_sharding_map (dict, optional): per-module sharding configs,
            forwarded to dist_model_fn. Defaults to an empty dict.
        inputs (list, optional): input tensors for the timed forward pass.
            Defaults to an empty list.
    """
    # Avoid mutable default arguments (previously ={} and =[]).
    tensor_sharding_map = {} if tensor_sharding_map is None else tensor_sharding_map
    inputs = [] if inputs is None else inputs

    # NOTE(review): address/port are hard-coded, so only one launch per host
    # can run at a time — confirm acceptable.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)

    # Initialize
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)

    # Distribute model parameters according to sharding configuration
    mesh = DeviceMesh("cuda", mesh=device_mesh)
    rlog(logger, rank, f"Distributing module parameters...", level="info")
    model, dist_time = distributed_timing(
        distribute_module,
        model,
        mesh,
        partial(dist_model_fn, rank=rank, tensor_sharding_map=tensor_sharding_map),
        input_fn=None,
        output_fn=None,
    )
    rlog(logger, rank, f"Module distribution done. Time taken: {dist_time} seconds.")

    # Run forward pass with fully replicated inputs; the model's resharding
    # nodes are expected to move activations to the right placements.
    rlog(logger, rank, f"Starting forward pass.", level="info")
    inputs = [
        distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()])
        for in_tensor in inputs
    ]
    _, time_taken = distributed_average_timing(
        fn=model,
        repeat=10,
        args=inputs,
    )
    rlog(logger, rank, f"Forward pass finished. Time taken: {time_taken}", level="info")

    dist.destroy_process_group()


class MaseLauncher:
"""
MaseLauncher launches an optimized model on multiple GPUs using torch.distributed.
"""

def __init__(self, mase_graph, world_size=None, device_mesh=None):
def __init__(
self,
mg=None,
world_size=None,
device_mesh=None,
device_fn=None,
):
"""Initialize the MaseLauncher.

Args:
mase_graph (MaseGraph): The MaseGraph object containing the model.
world_size (int, optional): Number of GPUs to use. Defaults to None.
device_mesh (list, optional): List of GPUs to use. Defaults to None.
"""
self.mg = mase_graph
self.model = mase_graph.model
self.mg = mg
self.world_size = world_size
self.device_mesh = device_mesh
self.device_fn = device_fn

if device_mesh is None:
self.device_mesh, _ = _get_mesh_from_world_size(world_size)

def run(self, tensor_sharding_map={}, inputs=[]):
def run(
self,
model_class=None,
model_config=None,
cli_args=None,
):
logger.info(f"Launching model with world size {self.world_size}.")

mp.spawn(
partial(
device_fn,
model=self.model,
device_mesh=self.device_mesh,
tensor_sharding_map=tensor_sharding_map,
inputs=inputs,
self.device_fn,
args=(
self.world_size,
self.device_mesh,
model_class,
model_config,
cli_args,
),
args=(self.world_size,),
nprocs=self.world_size,
join=True,
)
21 changes: 19 additions & 2 deletions src/chop/distributed/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
)
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh

import chop.distributed.tensor.ops
from chop.distributed.tensor._utils import compute_local_shape
from chop.distributed.tensor.api import distribute_module, distribute_tensor, DTensor
from chop.distributed.tensor.ops.utils import normalize_to_torch_size


# All public APIs from dtensor package
Expand All @@ -33,6 +31,25 @@
]


def normalize_to_torch_size(size) -> torch.Size:
    """
    Unify variable types of size argument to torch.Size
    Acceptable types include:
    int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
    or torch.Size
    """
    # Already normalized — return as-is.
    if isinstance(size, torch.Size):
        return size

    # A bare int becomes a 1-element size.
    if isinstance(size, int):
        return torch.Size([size])

    # A single nested sequence (e.g. ([2, 3],)) is unwrapped one level.
    if len(size) == 1 and isinstance(size[0], Sequence):
        return torch.Size(size[0])

    # Otherwise treat the argument itself as the sequence of dimensions.
    return torch.Size(size)


def _dtensor_init_helper(
init_op,
size: torch.Size,
Expand Down
Loading