diff --git a/docs/README.md b/docs/README.md
index cfa8724172..14ace1b664 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -38,3 +38,17 @@ the docs.
 Then force add the `build/html` directory: `git add -f build/html`
 
 Commit and push the changes to the `gh-pages` branch.
+
+## Doc Development Setup
+
+To enable live refresh of the docs while writing:
+
+Install `sphinx-autobuild`:
+```
+pip install sphinx-autobuild
+```
+
+Run the auto-build on the `docs/src` folder:
+```
+sphinx-autobuild ./src ./build/html
+```
diff --git a/docs/src/_static/tp_inference/all-to-sharded-linear.png b/docs/src/_static/tp_inference/all-to-sharded-linear.png
new file mode 100644
index 0000000000..0974abb02c
Binary files /dev/null and b/docs/src/_static/tp_inference/all-to-sharded-linear.png differ
diff --git a/docs/src/_static/tp_inference/column-row-tp.png b/docs/src/_static/tp_inference/column-row-tp.png
new file mode 100644
index 0000000000..8ace8e0393
Binary files /dev/null and b/docs/src/_static/tp_inference/column-row-tp.png differ
diff --git a/docs/src/_static/tp_inference/llama-transformer.png b/docs/src/_static/tp_inference/llama-transformer.png
new file mode 100644
index 0000000000..f0756a2bf2
Binary files /dev/null and b/docs/src/_static/tp_inference/llama-transformer.png differ
diff --git a/docs/src/_static/tp_inference/sharded-to-all-linear.png b/docs/src/_static/tp_inference/sharded-to-all-linear.png
new file mode 100644
index 0000000000..b4a50136ba
Binary files /dev/null and b/docs/src/_static/tp_inference/sharded-to-all-linear.png differ
diff --git a/docs/src/examples/data_parallelism.rst b/docs/src/examples/data_parallelism.rst
new file mode 100644
index 0000000000..b1592611d6
--- /dev/null
+++ b/docs/src/examples/data_parallelism.rst
@@ -0,0 +1,91 @@
+.. _data_parallelism:
+
+Data Parallelism
+================
+
+MLX enables efficient data parallel distributed training through its
+distributed communication primitives.
+
+.. _training_example:
+
+Training Example
+----------------
+
+In this section we will adapt an MLX training loop to support data parallel
+distributed training. Namely, we will average the gradients across a set of
+hosts before applying them to the model.
+
+Our training loop looks like the following code snippet if we omit the model,
+dataset, and optimizer initialization.
+
+.. code:: python
+
+    model = ...
+    optimizer = ...
+    dataset = ...
+
+    def step(model, x, y):
+        loss, grads = loss_grad_fn(model, x, y)
+        optimizer.update(model, grads)
+        return loss
+
+    for x, y in dataset:
+        loss = step(model, x, y)
+        mx.eval(loss, model.parameters())
+
+All we have to do to average the gradients across machines is perform an
+:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
+have to :func:`mlx.utils.tree_map` the gradients with the following function.
+
+.. code:: python
+
+    def all_avg(x):
+        return mx.distributed.all_sum(x) / mx.distributed.init().size()
+
+Putting everything together, our training loop step looks as follows, with
+everything else remaining the same.
+
+.. code:: python
+
+    from mlx.utils import tree_map
+
+    def all_reduce_grads(grads):
+        N = mx.distributed.init().size()
+        if N == 1:
+            return grads
+        return tree_map(
+            lambda x: mx.distributed.all_sum(x) / N,
+            grads
+        )
+
+    def step(model, x, y):
+        loss, grads = loss_grad_fn(model, x, y)
+        grads = all_reduce_grads(grads)  # <--- This line was added
+        optimizer.update(model, grads)
+        return loss
+
+Using ``nn.average_gradients``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Although the code example above works correctly, it performs one communication
+per gradient. It is significantly more efficient to aggregate several gradients
+together and perform fewer communication steps.
+
+This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
+almost identical to the example above:
+
+.. code:: python
+
+    model = ...
+    optimizer = ...
+    dataset = ...
+
+    def step(model, x, y):
+        loss, grads = loss_grad_fn(model, x, y)
+        grads = mlx.nn.average_gradients(grads)  # <---- This line was added
+        optimizer.update(model, grads)
+        return loss
+
+    for x, y in dataset:
+        loss = step(model, x, y)
+        mx.eval(loss, model.parameters())
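+
+For reference, the sketch below strings these pieces together into a complete,
+runnable script on a toy problem. The model, optimizer settings, and random
+data are illustrative placeholders (they are not part of the example above),
+and the script name in the comment is hypothetical. The same code should also
+run on a single process, since averaging over a group of size 1 leaves the
+gradients unchanged.
+
+.. code:: python
+
+    # data_parallel_sketch.py -- illustrative only; launch with e.g.:
+    #   mlx.launch -n 2 data_parallel_sketch.py
+    import mlx.core as mx
+    import mlx.nn as nn
+    import mlx.optimizers as optim
+
+    world = mx.distributed.init()
+    mx.random.seed(world.rank())  # stand-in for giving each rank its own data shard
+
+    model = nn.Linear(8, 1)  # toy model
+    optimizer = optim.SGD(learning_rate=0.1)
+
+    def loss_fn(model, x, y):
+        return nn.losses.mse_loss(model(x), y)
+
+    loss_grad_fn = nn.value_and_grad(model, loss_fn)
+
+    def step(model, x, y):
+        loss, grads = loss_grad_fn(model, x, y)
+        grads = nn.average_gradients(grads)  # average gradients across all processes
+        optimizer.update(model, grads)
+        return loss
+
+    for _ in range(100):  # stands in for iterating over a real dataset
+        x = mx.random.normal((32, 8))
+        y = mx.random.normal((32, 1))
+        loss = step(model, x, y)
+        mx.eval(loss, model.parameters())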
diff --git a/docs/src/examples/tensor_parallelism.rst b/docs/src/examples/tensor_parallelism.rst
new file mode 100644
index 0000000000..04717678ec
--- /dev/null
+++ b/docs/src/examples/tensor_parallelism.rst
@@ -0,0 +1,239 @@
+.. _tensor_parallelism:
+
+Tensor Parallelism
+==================
+
+In this example, we will explore how tensor parallelism (TP) works in MLX. We
+will start with an overview of the distributed layers in ``mlx.nn`` and then
+show how to apply tensor parallelism to Llama-style transformer models.
+
+Sharded Layers
+--------------
+
+:class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This layer replicates a common input and shards the weight matrix along the
+output dimension across all devices in the :class:`mlx.core.distributed.Group`.
+The layer produces a sharded output.
+
+For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
+and a device group with 2 devices. The layer shards the weight matrix along the
+output dimension across the two devices, where each device receives the full
+input and computes a partial output.
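+
+A minimal sketch of this behaviour is shown below, assuming two processes
+started with ``mlx.launch``; the random input and the printed shapes are
+illustrative only, and in a real model the shards would come from a single
+pretrained weight matrix, for instance via
+:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`.
+
+.. code:: python
+
+    import mlx.core as mx
+    import mlx.nn as nn
+
+    group = mx.distributed.init()
+
+    mx.random.seed(0)  # every rank builds the same full (4, 2) input
+    x = mx.random.normal((4, 2))
+
+    layer = nn.AllToShardedLinear(2, 2, bias=False)
+    y = layer(x)
+
+    # With 2 devices, each rank holds a (4, 1) shard of the (4, 2) output.
+    print(group.rank(), y.shape)
+
+The figure below illustrates the same computation.
+
+.. raw:: html
+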
+ column-wise tensor parallelism +
+
+This layer does not automatically gather all outputs from each device. This is
+an intended and :ref:`useful design choice <useful_design_choices>`.
+
+:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
+the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.
+
+
+:class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This layer expects inputs that are sharded along the feature dimension and
+shards the weight matrix along the input dimension across all devices in the
+:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
+results using :func:`mlx.core.distributed.all_sum`, so all devices in the
+group will have the same result.
+
+For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``
+that is sharded along the feature dimension (each device holds a ``(4, 1)``
+shard), and a device group with 2 devices. The layer shards the weight matrix
+along the input dimension across the two devices. Each device computes a
+partial ``(4, 2)`` output, which is then summed with the outputs of the other
+devices to get the layer output.
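+
+As a minimal sketch (again assuming two processes and purely illustrative
+data), note that the sharded input here is exactly what an all-to-sharded
+layer produces, so no extra gather or scatter code is needed:
+
+.. code:: python
+
+    import mlx.core as mx
+    import mlx.nn as nn
+
+    group = mx.distributed.init()
+
+    x_shard = mx.random.normal((4, 1))  # this rank's (4, 1) shard of a (4, 2) input
+
+    layer = nn.ShardedToAllLinear(2, 2, bias=False)
+    y = layer(x_shard)
+
+    # The partial products are summed across ranks, so every rank ends up
+    # with the same (4, 2) result.
+    print(group.rank(), y.shape)
+
+.. raw:: html
+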
+ row-wise tensor parallelism +
+
+This layer does not automatically shard the inputs along the feature dimension
+for you. It is necessary to create a "partial" input structure to feed into the
+layer. This is an intended and :ref:`useful design choice
+<useful_design_choices>`.
+
+:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
+the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.
+
+
+Shard Utility Functions
+-----------------------
+
+:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Converts a regular linear layer into a tensor parallel layer that distributes
+computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
+or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
+(either :class:`mlx.nn.AllToShardedLinear` or
+:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
+original layer is not modified.
+
+:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Splits the parameters of an existing layer across multiple devices by modifying
+the layer in-place. Unlike :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
+layer or add distributed communication. The layer itself must handle
+distributed communication if needed.
+
+
+.. _useful_design_choices:
+
+Useful Design Choices
+---------------------
+
+The choices above about which operations are performed automatically are
+intentional and make model training and inference easier.
+
+All-to-sharded and sharded-to-all layers naturally go together because the
+output of the former layer is exactly the input needed for the latter. This
+removes the need for an intermediate gather step between the layers, reducing
+communication overhead.
+
+This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
+automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
+automatically: the two layers can be placed back to back and work together
+easily.
+
+We can demonstrate this with a simple model that uses our two types of
+distributed layers.
+
+.. code-block:: python
+
+    x = ...  # some (4, 2) model input: batch size 4, feature size 2
+
+    l1 = nn.AllToShardedLinear(2, 2, bias=False)  # initialize the layer
+    l1_out = l1(x)  # (4, 1) output per device
+
+    l2 = nn.ShardedToAllLinear(2, 2, bias=False)
+    l2_out = l2(l1_out)  # (4, 2) output, identical on every device
+
+.. raw:: html
+
+ two layer tensor parallelism +

A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.

+
+
+
+LLM Inference with Tensor Parallelism
+-------------------------------------
+
+We can apply these TP techniques to LLMs in order to enable inference for much
+larger models, by sharding the parameters of very large layers across multiple
+devices.
+
+To demonstrate this, let's apply TP to the Transformer block of our
+:doc:`Llama Inference <llama-inference>` example, reusing the inference script
+from that example, which can be found in `mlx-examples`_.
+
+Our first edit is to initialize the distributed communication group and get the
+current process rank:
+
+.. code-block:: python
+
+    world = mx.distributed.init()
+    rank = world.rank()
+
+Next, let's look at the current architecture of the transformer block and see
+how we can apply tensor parallelism:
+
+.. raw:: html
+
+ llama transformer example +
+
+This architecture has two natural places where tensor parallelism can be
+applied: the attention block and the FFN block. Both follow the same pattern:
+multiple parallel linear layers operating on the same input, followed by a
+single output linear layer. In the attention block, the Q, K, and V projections
+are sharded along the output dimension (all-to-sharded), and the output
+projection is sharded along the input dimension (sharded-to-all). Similarly, in
+the FFN block, the gate and up projections become all-to-sharded layers, and
+the down projection becomes a sharded-to-all layer.
+
+The intermediate operations between the linear layers (RoPE, softmax, scaled
+dot-product attention in the attention block, and element-wise multiplication
+in the FFN block) do not impede the use of our TP paradigm. These operations
+are either:
+
+- **Element-wise operations** (RoPE, element-wise multiplication): These
+  operate independently on each element or position, preserving the sharding
+  pattern without requiring cross-device communication.
+
+- **Operations on non-sharded dimensions** (softmax, scaled dot-product
+  attention): These operate along dimensions that are not sharded (such as the
+  sequence length or head dimensions), so they can be computed independently on
+  each device. The attention computations ``Q @ K^T`` and ``scores @ V`` work
+  correctly with sharded Q, K, V tensors because the matrix multiplications are
+  performed along the sharded feature dimension, and the results remain
+  properly sharded for the subsequent sharded-to-all layer.
+
+To implement sharding in our Llama inference, we use :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers that
+include the distributed communication. This is easier than using
+:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and
+implementing the communication manually in the :code:`__call__` function.
+
+The following code shows how to shard the Attention block. The Q, K, and V
+projection layers are converted to all-to-sharded layers, while the output
+projection is converted to a sharded-to-all layer. The number of heads is also
+adjusted to account for the sharding:
+
+.. code-block:: python
+
+    # ... in Attention class
+    def shard(self, group: mx.distributed.Group):
+        self.n_heads = self.n_heads // group.size()
+        self.n_kv_heads = self.n_kv_heads // group.size()
+
+        self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
+        self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
+        self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
+        self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)
+
+Similarly, the FeedForward block is sharded by converting the gate (w1) and up
+(w3) projections to all-to-sharded layers, and the down projection (w2) to
+a sharded-to-all layer:
+
+.. code-block:: python
+
+    # ... in FeedForward class
+    def shard(self, group: mx.distributed.Group):
+        self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group)
+        self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
+        self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)
+
+Finally, in our :code:`load_model` function, we need to apply our sharding
+functions to all transformer layers when using multiple devices:
+
+.. code-block:: python
+
+    # ... in load_model function
+    if world.size() > 1:
+        # convert the Linear layers in the attention and FFN blocks to the
+        # appropriate sharded layers
+        for layer in model.layers:
+            layer.attention.shard(group=world)
+            layer.feed_forward.shard(group=world)
+
+This allows us to use the Llama inference script as normal when running
+:code:`python llama.py`, but now we can also run it across two (or more)
+devices via :code:`mlx.launch -n 2 llama.py`.
+
+.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
diff --git a/docs/src/index.rst b/docs/src/index.rst
index 2148d94baf..74c52aaa2b 100644
--- a/docs/src/index.rst
+++ b/docs/src/index.rst
@@ -54,6 +54,8 @@ are the CPU and GPU.
    examples/linear_regression
    examples/mlp
    examples/llama-inference
+   examples/data_parallelism
+   examples/tensor_parallelism
 
 .. toctree::
    :caption: Python API Reference
diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 5a7344d9f5..00f7f95456 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -183,3 +183,4 @@ In detail:
    nn/functions
    nn/losses
    nn/init
+   nn/distributed
diff --git a/docs/src/python/nn/distributed.rst b/docs/src/python/nn/distributed.rst
new file mode 100644
index 0000000000..07dd8e2308
--- /dev/null
+++ b/docs/src/python/nn/distributed.rst
@@ -0,0 +1,30 @@
+.. _nn_distributed:
+
+Distributed
+-----------
+
+Helper Routines
+^^^^^^^^^^^^^^^
+
+The :code:`mlx.nn.layers.distributed` package contains helpful routines to
+create sharded layers from existing :class:`Modules <mlx.nn.Module>`.
+
+.. currentmodule:: mlx.nn.layers.distributed
+.. autosummary::
+   :toctree: _autosummary
+
+   shard_linear
+   shard_inplace
+
+Layers
+^^^^^^
+
+.. currentmodule:: mlx.nn
+.. autosummary::
+   :toctree: _autosummary
+   :template: nn-module-template.rst
+
+   AllToShardedLinear
+   ShardedToAllLinear
+   QuantizedAllToShardedLinear
+   QuantizedShardedToAllLinear
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 1469481477..b9544bae51 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -10,6 +10,7 @@ Layers
    :template: nn-module-template.rst
 
    ALiBi
+   AllToShardedLinear
    AvgPool1d
    AvgPool2d
    AvgPool3d
@@ -46,8 +47,10 @@ Layers
    Mish
    MultiHeadAttention
    PReLU
+   QuantizedAllToShardedLinear
    QuantizedEmbedding
    QuantizedLinear
+   QuantizedShardedToAllLinear
    RMSNorm
    ReLU
    ReLU2
@@ -56,6 +59,7 @@ Layers
    RoPE
    SELU
    Sequential
+   ShardedToAllLinear
    Sigmoid
    SiLU
    SinusoidalPositionalEncoding
diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst
index 54d6617ffe..1f271835cc 100644
--- a/docs/src/usage/distributed.rst
+++ b/docs/src/usage/distributed.rst
@@ -117,89 +117,11 @@ The following examples aim to clarify the backend initialization logic in MLX:
     world_ring = mx.distributed.init(backend="ring")
     world_any = mx.distributed.init() # same as MPI because it was initialized first!
 
-.. _training_example:
+Distributed Program Examples
+----------------------------
 
-Training Example
-----------------
-
-In this section we will adapt an MLX training loop to support data parallel
-distributed training. Namely, we will average the gradients across a set of
-hosts before applying them to the model.
-
-Our training loop looks like the following code snippet if we omit the model,
-dataset and optimizer initialization.
-
-.. code:: python
-
-    model = ...
-    optimizer = ...
-    dataset = ...
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        optimizer.update(model, grads)
-        return loss
-
-    for x, y in dataset:
-        loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
-
-All we have to do to average the gradients across machines is perform an
-:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
-have to :func:`mlx.utils.tree_map` the gradients with following function.
-
-.. code:: python
-
-    def all_avg(x):
-        return mx.distributed.all_sum(x) / mx.distributed.init().size()
-
-Putting everything together our training loop step looks as follows with
-everything else remaining the same.
-
-.. code:: python
-
-    from mlx.utils import tree_map
-
-    def all_reduce_grads(grads):
-        N = mx.distributed.init().size()
-        if N == 1:
-            return grads
-        return tree_map(
-            lambda x: mx.distributed.all_sum(x) / N,
-            grads
-        )
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        grads = all_reduce_grads(grads) # <--- This line was added
-        optimizer.update(model, grads)
-        return loss
-
-Utilizing ``nn.average_gradients``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Although the code example above works correctly; it performs one communication
-per gradient. It is significantly more efficient to aggregate several gradients
-together and perform fewer communication steps.
-
-This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
-almost identical to the example above:
-
-.. code:: python
-
-    model = ...
-    optimizer = ...
-    dataset = ...
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        grads = mx.nn.average_gradients(grads) # <---- This line was added
-        optimizer.update(model, grads)
-        return loss
-
-    for x, y in dataset:
-        loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
+- :ref:`Data Parallelism <data_parallelism>`
+- :ref:`Tensor Parallelism <tensor_parallelism>`
 
 .. _ring_section:
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 0e520d1517..15bbfa76dd 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -295,9 +295,8 @@ class QQLinear(Module):
     Compared to the :class:`mlx.nn.QuantizedLinear` layer, this layer quantizes
     the input as well and includes weights in gradient computations.
 
-    :obj:`QQLinear` also provides:
-    - the class method :meth:`from_linear` to convert :class:`mlx.nn.Linear`
-      layers to :obj:`QQLinear` layers.
+    :obj:`QQLinear` also provides the class method :meth:`from_linear` to
+    convert :class:`mlx.nn.Linear` layers to :obj:`QQLinear` layers.
 
     Note: This layer does not support a bias term yet.