14 changes: 14 additions & 0 deletions docs/README.md
@@ -38,3 +38,17 @@ the docs. Then force add the `build/html` directory:
`git add -f build/html`

Commit and push the changes to the `gh-pages` branch.

## Doc Development Setup

To enable live refresh of docs while writing:

Install `sphinx-autobuild`:
```
pip install sphinx-autobuild
```

Run the auto build on the `docs/src` folder:
```
sphinx-autobuild ./src ./build/html
```
Comment on lines +42 to +54
Member: That's pretty cool.

Binary file added docs/src/_static/tp_inference/column-row-tp.png
91 changes: 91 additions & 0 deletions docs/src/examples/data_parallelism.rst
@@ -0,0 +1,91 @@
.. _data_parallelism:

Data Parallelism
================

MLX enables efficient data parallel distributed training through its
distributed communication primitives.
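
For example, the following minimal sketch (assuming two processes started with
``mlx.launch -n 2 script.py``; the script name is illustrative) initializes the
default communication group and sums a value across all processes with
:func:`all_sum`:

.. code:: python

    import mlx.core as mx

    # Initialize the default communication group.
    world = mx.distributed.init()

    # Every process contributes its rank; all_sum returns the same total everywhere.
    total = mx.distributed.all_sum(mx.array(world.rank()))
    print(world.rank(), total.item())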

.. _training_example:

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset, and optimizer initialization.

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

    def all_avg(x):
        return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

    from mlx.utils import tree_map

    def all_reduce_grads(grads):
        N = mx.distributed.init().size()
        if N == 1:
            return grads
        return tree_map(
            lambda x: mx.distributed.all_sum(x) / N,
            grads
        )

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = all_reduce_grads(grads)  # <--- This line was added
        optimizer.update(model, grads)
        return loss

Using ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = mlx.nn.average_gradients(grads)  # <---- This line was added
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())
239 changes: 239 additions & 0 deletions docs/src/examples/tensor_parallelism.rst
@@ -0,0 +1,239 @@
.. _tensor_parallelism:

Tensor Parallelism
==================

In this example, we will explore how tensor parallelism (TP) works in MLX. We
will start with an overview of the distributed layers in ``mlx.nn`` and then
show how to apply tensor parallelism to Llama-style transformer models.

Sharded Layers
--------------

:class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This layer expects the same input replicated on every device and shards the
weight matrix along the output dimension across all devices in the
:class:`mlx.core.distributed.Group`. The layer produces a sharded output.

For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along the
output dimension across the two devices, where each device receives the full
input and computes a partial output.

.. raw:: html

    <div>
    <img src="../_static/tp_inference/all-to-sharded-linear.png" alt="column-wise tensor parallelism" style="width: 100%">
    </div>
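
As a minimal sketch of this example (assuming the default group returned by
:func:`mlx.core.distributed.init` has 2 devices, and using a random input for
illustration):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()   # assume a group of 2 devices
    x = mx.random.normal((4, 2))    # the full input, replicated on every device

    layer = nn.AllToShardedLinear(2, 2, bias=False)
    y = layer(x)                    # each device holds a (4, 1) shard of the output
    print(group.rank(), y.shape)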

This layer does not automatically gather all outputs from each device. This is
an intended and :ref:`useful design choice <useful_design_choices>`.

:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.


:class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This layer expects inputs that are sharded along the feature dimension and
shards the weight matrix along the input dimension across all devices in the
:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
results using :func:`mlx.core.distributed.all_sum`, so all devices in the
group will have the same result.

For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along the
input dimension across the two devices. Each device computes a ``(4, 2)``
partial output, which is then summed with the outputs of all other devices to
get the layer output.

.. raw:: html

    <div>
    <img src="../_static/tp_inference/sharded-to-all-linear.png" alt="row-wise tensor parallelism" style="width: 100%">
    </div>
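
A minimal sketch of the same setup (again assuming a default group of 2
devices; the random shard stands in for the output of a preceding
all-to-sharded layer):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()       # assume a group of 2 devices
    x_shard = mx.random.normal((4, 1))  # this device's shard of the (4, 2) input

    layer = nn.ShardedToAllLinear(2, 2, bias=False)
    y = layer(x_shard)                  # (4, 2), identical on every device after the all_sum
    print(group.rank(), y.shape)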

This layer does not automatically shard the inputs along the feature dimension
for you. It is necessary to create a "partial" input structure to feed into the
layer. This is an intended and :ref:`useful design choice
<useful_design_choices>`.

:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.


Shard Utility Functions
-----------------------

:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Converts a regular linear layer into a tensor parallel layer that distributes
computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
(either :class:`mlx.nn.AllToShardedLinear` or
:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
original layer is not modified.
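
A minimal sketch (the layer sizes are illustrative):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()
    linear = nn.Linear(4096, 4096, bias=False)  # an existing dense layer

    # Returns a new AllToShardedLinear; `linear` itself is left unmodified.
    sharded = nn.layers.distributed.shard_linear(linear, "all-to-sharded", group=group)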

:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Splits the parameters of an existing layer across multiple devices by modifying
the layer in-place. Unlike :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
layer or add distributed communication. The layer itself must handle
distributed communication if needed.
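
A sketch of how this might be used, assuming it accepts the same layer,
sharding-type, and ``group`` arguments as :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>` (check the API reference for the
exact signature):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()
    linear = nn.Linear(4096, 4096, bias=False)

    # Assumed call pattern, mirroring shard_linear; verify against the API reference.
    nn.layers.distributed.shard_inplace(linear, "all-to-sharded", group=group)
    # `linear` now holds only this device's shard of the weights; any all_sum or
    # gather needed to combine results must be done by the surrounding module.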


.. _useful_design_choices:

Useful Design Choices
---------------------

The design choices above, regarding which operations are performed
automatically, are intentional and make model training and inference easier.

All-to-sharded and sharded-to-all layers naturally go together because the
output of the former layer is exactly the input needed for the latter.
This removes the need for an intermediate gather step between the layers,
reducing communication overhead.

This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
automatically: the two layers can be placed back to back and work together
without any extra communication.

We can demonstrate this through a simple model using our two types of
distributed layers.

.. code-block:: python

    x = ...  # some (4, 2) model input: batch size 4, feature size 2

    l1 = nn.AllToShardedLinear(2, 2, bias=False)  # shards the weight along the output dimension
    l1_out = l1(x)  # (4, 1) output on each device (2 devices)

    l2 = nn.ShardedToAllLinear(2, 2, bias=False)  # shards the weight along the input dimension
    l2_out = l2(l1_out)  # (4, 2) output, identical on all devices

.. raw:: html

    <div>
    <img src="../_static/tp_inference/column-row-tp.png" alt="two layer tensor parallelism" style="width: 100%">
    <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.</small></p>
    </div>


LLM Inference with Tensor Parallelism
-------------------------------------

We can apply these TP techniques to LLMs in order to enable inference for much
larger models by sharding the parameters of their largest layers across
multiple devices.

To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
Inference <llama-inference>` example. In this example, we will use the same
inference script as the Llama Inference example, which can be found in
`mlx-examples`_.

Our first edit is to initialize the distributed communication group and get the
current process rank:

.. code-block:: python

    world = mx.distributed.init()
    rank = world.rank()

Next, let's look at the current architecture of the transformer block and see how we can apply tensor parallelism:

.. raw:: html

    <div>
    <img src="../_static/tp_inference/llama-transformer.png" alt="llama transformer example" style="width: 100%">
    </div>


This architecture has two natural places where tensor parallelism can be
applied: the attention block and the FFN block. Both follow the same pattern:
multiple parallel linear layers operating on the same input, followed by a
single output linear layer. In the attention block, the Q, K, and V projections
are sharded along the output dimension (all-to-sharded), and the output
projection is sharded along the input dimension (sharded-to-all). Similarly, in
the FFN block, the gate and up projections become all-to-sharded layers, and
the down projection becomes a sharded-to-all layer.

The intermediate operations between the linear layers (RoPE, softmax, scaled
dot-product attention in the attention block, and element-wise multiplication
in the FFN block) do not impede the use of our TP paradigm. These operations
are either:

- **Element-wise operations** (RoPE, element-wise multiplication): These
  operate independently on each element or position, preserving the sharding
  pattern without requiring cross-device communication.

- **Operations on non-sharded dimensions** (softmax, scaled dot-product
  attention): These operate along dimensions that are not sharded (such as the
  sequence length or head dimensions), so they can be computed independently on
  each device. The attention computations ``Q @ K^T`` and ``scores @ V`` work
  correctly with sharded Q, K, and V tensors because the sharded heads act as
  independent batch dimensions in these matrix multiplications, and the results
  remain properly sharded for the subsequent sharded-to-all layer (see the
  shape sketch below).
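
The following shape-only sketch (the sizes are illustrative and not taken from
the example model) shows why the per-device attention computation needs no
communication: with 2 devices, each process simply attends over its own half of
the heads.

.. code-block:: python

    import mlx.core as mx

    B, L, n_heads, head_dim = 1, 8, 32, 128  # illustrative sizes
    heads_per_device = n_heads // 2          # 2 devices -> 16 heads handled locally

    # Per-device shards produced by the all-to-sharded Q/K/V projections.
    q = mx.random.normal((B, heads_per_device, L, head_dim))
    k = mx.random.normal((B, heads_per_device, L, head_dim))
    v = mx.random.normal((B, heads_per_device, L, head_dim))

    # Purely local attention over this device's heads.
    scores = (q * head_dim**-0.5) @ k.transpose(0, 1, 3, 2)  # (B, 16, L, L)
    out = mx.softmax(scores, axis=-1) @ v                    # (B, 16, L, head_dim)

    # Reshape back to a feature-sharded activation: exactly the input that the
    # sharded-to-all output projection expects.
    out = out.transpose(0, 2, 1, 3).reshape(B, L, heads_per_device * head_dim)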

To implement sharding in our Llama inference, we use :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
distributed communication built in. This is easier than using
:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and
implementing the communication steps manually in the :code:`__call__` function.

The following code shows how to shard the Attention block. The Q, K, and V
projection layers are converted to all-to-sharded layers, while the output
projection is converted to a sharded-to-all layer. The number of heads is also
adjusted to account for the sharding:

.. code-block:: python

    # ... in Attention class
    def shard(self, group: mx.distributed.Group):
        self.n_heads = self.n_heads // group.size()
        self.n_kv_heads = self.n_kv_heads // group.size()

        self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
        self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
        self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
        self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)

Similarly, the FeedForward block is sharded by converting the gate (w1) and up
(w3) projections to all-to-sharded layers, and the down projection (w2) to
a sharded-to-all layer:

.. code-block:: python

    # ... in FeedForward class
    def shard(self, group: mx.distributed.Group):
        self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group)
        self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
        self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)

Finally, in our :code:`load_model` function, we need to apply our sharding
functions to all transformer layers when using multiple devices:

.. code-block:: python

    # ... in load_model function
    if world.size() > 1:
        # convert Linear layers in the attention and FFN blocks to the appropriate sharded layers
        for layer in model.layers:
            layer.attention.shard(group=world)
            layer.feed_forward.shard(group=world)

This allows us to run the Llama inference script as usual with
:code:`python llama.py`, but now we can also run it across two (or more)
devices via :code:`mlx.launch -n 2 llama.py`.

.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
2 changes: 2 additions & 0 deletions docs/src/index.rst
@@ -54,6 +54,8 @@ are the CPU and GPU.
   examples/linear_regression
   examples/mlp
   examples/llama-inference
   examples/data_parallelism
   examples/tensor_parallelism

.. toctree::
   :caption: Python API Reference
1 change: 1 addition & 0 deletions docs/src/python/nn.rst
@@ -183,3 +183,4 @@ In detail:
   nn/functions
   nn/losses
   nn/init
   nn/distributed
30 changes: 30 additions & 0 deletions docs/src/python/nn/distributed.rst
@@ -0,0 +1,30 @@
.. _nn_distributed:

Distributed
-----------

Helper Routines
^^^^^^^^^^^^^^^

The :code:`mlx.nn.layers.distributed` package contains helpful routines to
create sharded layers from existing :class:`Modules <mlx.nn.Module>`.

.. currentmodule:: mlx.nn.layers.distributed
.. autosummary::
   :toctree: _autosummary

   shard_linear
   shard_inplace

Layers
^^^^^^

.. currentmodule:: mlx.nn
.. autosummary::
   :toctree: _autosummary
   :template: nn-module-template.rst

   AllToShardedLinear
   ShardedToAllLinear
   QuantizedAllToShardedLinear
   QuantizedShardedToAllLinear