diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9132a1478..4ed8dc0a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,4 +79,3 @@ pybind11-stubgen magi_attention.magi_attn_ext -o . > [!IMPORTANT] > Failure to update stubs after modifying C++ code may cause type checking errors during CI. -``` diff --git a/README.md b/README.md index 7112e99b0..121b0afb1 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ For more usage instructions, you can refer to our [docs](https://SandAI-org.gith import magi_attention from magi_attention.api import ( - magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions + magi_attn_flex_key, dispatch, calc_attn, undispatch, # interface functions compute_pad_size, # helper functions ) from magi_attention.common import AttnRanges @@ -288,19 +288,23 @@ For more usage instructions, you can refer to our [docs](https://SandAI-org.gith # 1. the dispatched local token embedding may be shuffled along seqlen dim, # so it's safe for token-wise operations such as matmul, layer-norm, etc # while for sample-wise operations like RoPE, you might need to be more careful - # 2. the `magi_attn_runtime_key` holds some inner meta data as one argument for many other magi_attention APIs, - # which users don’t have to bother with - local_x, magi_attn_runtime_key = magi_attn_flex_dispatch( - x, + # 2. the `magi_attn_runtime_key` holds some inner meta data, + # as a required argument for many APIs of ``magi_attention``, + # which users don't have to bother with + magi_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp) ) + local_x = dispatch(x, key=magi_attn_runtime_key) # --- Simulate QKV projection --- # @@ -405,11 +409,12 @@ We provide additional [magi_attn_extensions](https://github.com/SandAI-org/MagiA - [ ] **[WIP]** Support Ampere, Blackwell as well as other GPU architectures. - [ ] **[WIP]** Optimize `Flex-Flash-Attention` kernels to improve performance and better support sparse attention (*such as [NSA](https://arxiv.org/pdf/2502.11089)*). - [ ] **[WIP]** Optimize `DistAttnSolver` to reduce CPU overhead for meta info calculation and support better comp-/comm- overlapping. +- [ ] **[WIP]** Optimize `DynamicAttnSolver` for hybrid attention model or dynamic mask scenarios like sparse attention. - [ ] **[WIP]** Provide a more comprehensive documentation with tutorials, and a more detailed technical blog / paper. - [ ] Support other attention patterns including cross-attention, and inference scenarios involving KV cache (*w.r.t. [Paged Attention](https://arxiv.org/abs/2309.06180)*). - [ ] Provide more example codes and recipes for various training scenarios. - [ ] Upgrade `MagiAttention` to a distributed native `Flex-Flash-Attention` kernel (*as a major version update*). -- [x] Support `Dynamic DistAttnSolver` with query/output communication pattern, one for either hybrid attention model or dynamic mask scenarios like sparse attention, the other for reducing communication overhead for many cases when only communicating key/value is not the best choice. +- [x] Support `DynamicAttnSolver` with query/output communication pattern, for reducing communication overhead for many cases when only communicating key/value is not the best choice. - [x] Support native `GroupCast` and `GroupReduce` communication kernels with inter-/intra-node hierarchical optimization (*similar to [DeepEP](https://github.com/deepseek-ai/DeepEP)*). - [x] Support learnable attention sink (*w.r.t. [StreamingLLM](https://arxiv.org/abs/2309.17453)*). - [x] Refactor `Distributed Attention Solver` to support all mask types with all kinds of overlap. diff --git a/docs/source/env_variables.md b/docs/source/env_variables.md index deec9a4b8..5561d691d 100644 --- a/docs/source/env_variables.md +++ b/docs/source/env_variables.md @@ -109,6 +109,16 @@ This feature is experimental and under active development for now. If the C++ extension is not found or this variable is set to `0`, it will fall back to the Python implementation. ``` +**MAGI_ATTENTION_AUTO_RANGE_MERGE** + +Toggle this env variable to ``1`` to enable automatic range merging for flex-flash-attention, +to improve performance by reducing the number of attention ranges. The default value is `0`. + +```{note} +This feature is experimental and under active development for now, +thus please do NOT enable it unless you know exactly what you are doing. +``` + **MAGI_ATTENTION_DIST_ATTN_RUNTIME_DICT_SIZE** Set the value of this env variable to control the size of `dist_attn_runtime_dict`. The default value is `100`. diff --git a/docs/source/magi_api.md b/docs/source/magi_api.md index 9138e5271..796af5443 100644 --- a/docs/source/magi_api.md +++ b/docs/source/magi_api.md @@ -25,18 +25,12 @@ To support computing irregular-shaped masks, we implemented a `flexible_flash_at ### Varlen Dispatch -If you're using a mask defined by `cu_seqlens`, such as a varlen full or varlen causal mask, we've designed a similar interface inspired by FlashAttention's API, making it easy for you to get started quickly. In the function named `magi_attn_varlen_dispatch`, you can obtain the dispatched `x` and `key`. +If you're using a mask defined by `cu_seqlens`, such as a varlen full or varlen causal mask, we've designed a similar interface `magi_attn_varlen_key` inspired by FlashAttention's API as follows, making it easy for you to get started quickly. ```{eval-rst} .. currentmodule:: magi_attention.api.magi_attn_interface ``` -```{eval-rst} -.. autofunction:: magi_attn_varlen_dispatch -``` - -The logic of the `magi_attn_varlen_dispatch` function mainly consists of two parts: it first calls `magi_attn_varlen_key` to compute a key value, and then uses this key to dispatch the input x. The description of `magi_attn_varlen_key` is as follows. - ```{eval-rst} .. autofunction:: magi_attn_varlen_key ``` @@ -52,18 +46,12 @@ Then the new mask will reuse the same dispatch solution as the mask used for dis ### Flexible Dispatch -If the masks you're using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the following API. By calling `magi_attn_flex_dispatch`, you can obtain the dispatched x and key. +If the masks you're using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the `magi_attn_flex_key` as follows. ```{eval-rst} .. currentmodule:: magi_attention.api.magi_attn_interface ``` -```{eval-rst} -.. autofunction:: magi_attn_flex_dispatch -``` - -Similar to the logic of `magi_attn_varlen_dispatch`, `magi_attn_flex_dispatch` first calls `magi_attn_flex_key` to obtain a key, and then uses this key to dispatch x. The description of `magi_attn_flex_key` is as follows. - ```{eval-rst} .. autofunction:: magi_attn_flex_key ``` @@ -78,7 +66,7 @@ Then the new mask will reuse the same dispatch solution as the mask used for dis ### Dispatch Function -If you already have the key, you can call `dispatch` function to get the padded and dispatched local tensor. +When you get the dist attn runtime key, you can call `dispatch` function to dispatch the global input tensor(s) to get the padded local tensor(s) along the seqlen dim. ```{eval-rst} .. currentmodule:: magi_attention.api.magi_attn_interface @@ -90,7 +78,7 @@ If you already have the key, you can call `dispatch` function to get the padded ## Calculate Attention -After dispatch and projection, you should obtain the query, key, and value needed for computation. Using the key obtained from the dispatch function mentioned above, you can perform the computation by calling `calc_attn`, which returns the results out and meta (containing lse). +After dispatch and QKV projection, you should obtain the local query, key, and value. Then you can calculate the distributed attention by calling `calc_attn` with the dist attn runtime key to get the local attention output tensor. ```{eval-rst} .. currentmodule:: magi_attention.api.magi_attn_interface @@ -105,7 +93,7 @@ After dispatch and projection, you should obtain the query, key, and value neede ### Undispatch Function -When you need to recover the complete global tensor from the local tensor like computing the loss, you can call `undispatch` function to unpad and undispatch the local tensor along the seqlen dim. +When you need to recover the global output tensor(s) from the local one(s), to compute the loss or some reason else, you can call `undispatch` function to undispatch the padded local ouput tensor(s) back to the unpadded global tensor along the seqlen dim. ```{eval-rst} .. currentmodule:: magi_attention.api.magi_attn_interface diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index a3db6be97..fcb979cb3 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -84,7 +84,7 @@ import torch.distributed as dist import magi_attention from magi_attention.api import ( - magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions + magi_attn_flex_key, dispatch, calc_attn, undispatch, # interface functions compute_pad_size, # helper functions ) from magi_attention.common import AttnRanges @@ -164,19 +164,23 @@ pad_size = compute_pad_size( # pad embeds along seqlen dim for better performanc # 1. the dispatched local token embedding may be shuffled along seqlen dim, # so it's safe for token-wise operations such as matmul, layer-norm, etc # while for sample-wise operations like RoPE, you might need to be more careful -# 2. the `magi_attn_runtime_key` holds some inner meta data as one argument for many other magi_attention APIs, -# which users don’t have to bother with -local_x, magi_attn_runtime_key = magi_attn_flex_dispatch( - x, +# 2. the `magi_attn_runtime_key` holds some inner meta data, +# as a required argument for many APIs of ``magi_attention``, +# which users don't have to bother with +magi_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp) ) +local_x = dispatch(x, key=magi_attn_runtime_key) # --- Simulate QKV projection --- # diff --git a/examples/torch_native/README.md b/examples/torch_native/README.md index 4244b48f4..1877ba00c 100644 --- a/examples/torch_native/README.md +++ b/examples/torch_native/README.md @@ -108,7 +108,14 @@ def train(model, optimizer, lr_scheduler, device_mesh, train_iter): ): # dispatched input and prepare magi_attn key. input, dist_attn_runtime_key = prepare_magi_attention( - input, cu_seqlens_q, cu_seqlens_k, pad_size, CHUNK_SIZE, device_mesh.get_group("cp") + input=input, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=model.config.num_attention_heads, + num_heads_kv=model.config.num_key_value_heads, + head_dim=model.config.head_dim, + pad_size=pad_size, + cp_group=device_mesh.get_group("cp"), ) output = model(input, dist_attn_runtime_key) @@ -160,15 +167,26 @@ def prepare_data(device_mesh, train_iter): **Prepare magi_attn_key:** Dispatch input data along cp dim and get dist_attn_runtime_key. ```python -def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group): - # --- magi_attn_flex_dispatch --- # +def prepare_magi_attention( + input: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, + pad_size: int, + cp_group: dist.ProcessGroup, +): + # --- magi_attn_varlen_dispatch --- # + dist_attn_config = DistAttnConfig() - # you can also use fa_varlen-like varlen dispatch interface directly - x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch( - input, + dist_attn_runtime_key = magi_attn_varlen_key( cu_seqlens_q, cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=CHUNK_SIZE, cp_group_or_mesh=cp_group, @@ -176,6 +194,8 @@ def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group dist_attn_config=dist_attn_config, ) + x_padded = dispatch(input, key=dist_attn_runtime_key) + return x_padded, dist_attn_runtime_key ``` diff --git a/examples/torch_native/main.py b/examples/torch_native/main.py index 2bc9da48f..52337c41b 100644 --- a/examples/torch_native/main.py +++ b/examples/torch_native/main.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from configuration_llama import LlamaConfig from llama_pretrain_config import data_config, parallel_config, train_config -from modeling_llama import LlamaDecoderLayer, build_llama3_1b_model +from modeling_llama import LlamaDecoderLayer, LlamaModel, build_llama3_1b_model from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor, Partial, Shard, distribute_tensor @@ -28,8 +28,9 @@ from magi_attention.api import ( DistAttnConfig, compute_pad_size, + dispatch, infer_varlen_mask_from_batch, - magi_attn_varlen_dispatch, + magi_attn_varlen_key, squash_batch_dim, undispatch, ) @@ -238,16 +239,26 @@ def prepare_data(device_mesh, train_iter): return local_input, local_label, cu_seqlens_q, cu_seqlens_k, pad_size -def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group): - # --- magi_attn_flex_dispatch --- # - # an example of distattnconfig +def prepare_magi_attention( + input: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, + pad_size: int, + cp_group: dist.ProcessGroup, +): + # --- magi_attn_varlen_dispatch --- # + dist_attn_config = DistAttnConfig() - # you can also use fa_varlen-like varlen dispatch interface directly - x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch( - input, + dist_attn_runtime_key = magi_attn_varlen_key( cu_seqlens_q, cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=CHUNK_SIZE, cp_group_or_mesh=cp_group, @@ -255,6 +266,8 @@ def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group dist_attn_config=dist_attn_config, ) + x_padded = dispatch(input, key=dist_attn_runtime_key) + return x_padded, dist_attn_runtime_key @@ -292,7 +305,7 @@ def loss_func( return loss -def train(model, optimizer, lr_scheduler, device_mesh, train_iter): +def train(model: LlamaModel, optimizer, lr_scheduler, device_mesh, train_iter): """main training loop""" model.train() @@ -306,12 +319,19 @@ def train(model, optimizer, lr_scheduler, device_mesh, train_iter): dist_attn_runtime_key = None if ( - parallel_config["context_parallel_size"] > 1 + parallel_config["context_parallel_size"] > 1 # type: ignore[operator] and parallel_config["context_parallel_backend"] == "magi_attention" ): # dispatched input input, dist_attn_runtime_key = prepare_magi_attention( - input, cu_seqlens_q, cu_seqlens_k, pad_size, device_mesh.get_group("cp") + input=input, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=model.config.num_attention_heads, + num_heads_kv=model.config.num_key_value_heads, + head_dim=model.config.head_dim, + pad_size=pad_size, + cp_group=device_mesh.get_group("cp"), ) output = model(input, dist_attn_runtime_key) diff --git a/examples/transformers/README.md b/examples/transformers/README.md index 54cce4158..d66d90314 100644 --- a/examples/transformers/README.md +++ b/examples/transformers/README.md @@ -93,11 +93,13 @@ def _prepare_inputs(): + ) + + local_input, magi_attn_key = self._prepare_magi_attention( -+ local_input, -+ cu_seqlens_q, -+ cu_seqlens_k, -+ pad_size, -+ self.model.config.head_dim, ++ inputs=local_input, ++ cu_seqlens_q=cu_seqlens_q, ++ cu_seqlens_k=cu_seqlens_k, ++ num_heads_q=self.model.config.num_attention_heads, ++ num_heads_kv=self.model.config.num_key_value_heads, ++ head_dim=self.model.config.head_dim, ++ pad_size=pad_size, + ) + position_ids = get_position_ids(magi_attn_key).unsqueeze(0) + @@ -108,25 +110,39 @@ def _prepare_inputs(): # dispatch data and prepare key + def _prepare_magi_attention( -+ self, inputs, cu_seqlens_q, cu_seqlens_k, pad_size, head_dim ++ self, ++ inputs: torch.Tensor, ++ cu_seqlens_q: torch.Tensor, ++ cu_seqlens_k: torch.Tensor, ++ num_heads_q: int, ++ num_heads_kv: int, ++ head_dim: int, ++ pad_size: int, + ): -+ # --- magi_attn_flex_dispatch --- # -+ dist_attn_config = DistAttnConfig() -+ cp_group = self._build_cp_group() -+ inputs = squash_batch_dim(inputs) ++ # --- magi_attn_varlen_dispatch --- # + -+ x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch( -+ inputs, -+ cu_seqlens_q, -+ cu_seqlens_k, -+ pad_size=pad_size, -+ cp_group_or_mesh=cp_group, -+ causal=True, -+ dist_attn_config=dist_attn_config, -+ ) -+ x_padded = x_padded.unsqueeze(0) ++ dist_attn_config = DistAttnConfig() ++ cp_group = self._build_cp_group() + -+ return x_padded, dist_attn_runtime_key ++ inputs = squash_batch_dim(inputs) ++ ++ dist_attn_runtime_key = magi_attn_varlen_key( ++ cu_seqlens_q=cu_seqlens_q, ++ cu_seqlens_k=cu_seqlens_k, ++ num_heads_q=num_heads_q, ++ num_heads_kv=num_heads_kv, ++ head_dim=head_dim, ++ chunk_size=512, ++ pad_size=pad_size, ++ cp_group_or_mesh=cp_group, ++ causal=True, ++ dist_attn_config=dist_attn_config, ++ ) ++ ++ x_padded = dispatch(inputs, key=dist_attn_runtime_key) ++ x_padded = x_padded.unsqueeze(0) ++ ++ return x_padded, dist_attn_runtime_key ``` Override `compute_loss` because we need to undispatch logits first: diff --git a/examples/transformers/magi_trainer.py b/examples/transformers/magi_trainer.py index e3af712ed..12a0cdbf8 100644 --- a/examples/transformers/magi_trainer.py +++ b/examples/transformers/magi_trainer.py @@ -40,10 +40,11 @@ from magi_attention.api import ( DistAttnConfig, compute_pad_size, + dispatch, get_most_recent_key, get_position_ids, infer_varlen_mask_from_batch, - magi_attn_varlen_dispatch, + magi_attn_varlen_key, squash_batch_dim, undispatch, ) @@ -291,24 +292,36 @@ def _build_cp_group(self): return cp_group def _prepare_magi_attention( - self, inputs, cu_seqlens_q, cu_seqlens_k, pad_size, head_dim + self, + inputs: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, + pad_size: int, ): - # --- magi_attn_flex_dispatch --- # + # --- magi_attn_varlen_dispatch --- # + dist_attn_config = DistAttnConfig() cp_group = self._build_cp_group() inputs = squash_batch_dim(inputs) - x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch( - inputs, - cu_seqlens_q, - cu_seqlens_k, + dist_attn_runtime_key = magi_attn_varlen_key( + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=512, pad_size=pad_size, cp_group_or_mesh=cp_group, causal=True, dist_attn_config=dist_attn_config, ) + + x_padded = dispatch(inputs, key=dist_attn_runtime_key) x_padded = x_padded.unsqueeze(0) return x_padded, dist_attn_runtime_key @@ -378,11 +391,13 @@ def _prepare_inputs( ) local_input, magi_attn_key = self._prepare_magi_attention( - local_input, - cu_seqlens_q, - cu_seqlens_k, - pad_size, - self.model.config.head_dim, + inputs=local_input, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=self.model.config.num_attention_heads, + num_heads_kv=self.model.config.num_key_value_heads, + head_dim=self.model.config.head_dim, + pad_size=pad_size, ) position_ids = get_position_ids(magi_attn_key).unsqueeze(0) diff --git a/exps/dist_attn/Dockerfile.benchmark b/exps/dist_attn/Dockerfile.benchmark new file mode 100644 index 000000000..2a0c52481 --- /dev/null +++ b/exps/dist_attn/Dockerfile.benchmark @@ -0,0 +1,118 @@ +FROM nvcr.io/nvidia/pytorch:25.10-py3 + +ARG https_proxy +ARG http_proxy + +# Package Version +ARG NCCL_TESTS_VERSION=v2.16.4 +ARG CUTLASS_DSL_VERSION=4.3.4 +ARG FLASH_ATTENTION_VERSION=v2.8.3 +ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf" # FIXME: only this commit (>v2.8.3, > /root/.ssh/authorized_keys && \ + echo "Host *" > /root/.ssh/config && \ + echo " StrictHostKeyChecking no" >> /root/.ssh/config && \ + chmod 700 /root/.ssh && \ + chmod 600 /root/.ssh/authorized_keys && \ + chmod 600 /root/.ssh/config + +# 5. Expose SSH port +EXPOSE 4022 + +# 6. Start SSH service when container starts +CMD ["/usr/sbin/sshd", "-D"] + +# Install Flash Attention2/3/4 + +RUN pip install nvidia-cutlass-dsl==${CUTLASS_DSL_VERSION} + +RUN pip uninstall -y flash-attn && \ + mkdir -p /workspace/flash-attention && \ + cd /workspace/flash-attention && \ + \ + # 1. Clone the repository without checking out a specific branch \ + git init && \ + git remote add origin https://github.com/Dao-AILab/flash-attention.git && \ + git fetch origin ${FLASH_ATTENTION_COMMIT_ID} --depth 1 && \ + \ + # 2. Checkout the specific commit ID (This will be in a detached HEAD state) \ + git checkout ${FLASH_ATTENTION_COMMIT_ID} && \ + \ + # 3. Initialize/update submodules after checkout (Crucial for Flash-Attention) \ + git submodule update --init --recursive && \ + \ + # 4. Continue with installation \ + python setup.py install && \ + cd /workspace/flash-attention/hopper && python setup.py install && \ + \ + python_path=$(python -c "import site; print(site.getsitepackages()[0])") && \ + mkdir -p ${python_path}/flash_attn_3 && \ + # 5. Copy the file from the local checked-out repository (Safest method) \ + cp /workspace/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \ + cd /workspace && rm -rf /workspace/flash-attention + +# NOTE: Some packages are not included in the NGC container, so we need to install them manually +# Install FIO (for file-system performance test) +RUN mkdir /workspace/fio && cd /workspace/fio && \ + wget -q -O - https://github.com/axboe/fio/archive/refs/tags/${FIO_VERSION}.tar.gz | tar --strip-components=1 -xzf - && \ + ./configure && make -j100 && make install && \ + cd /workspace && rm -rf /workspace/fio + + +ENV DEBIAN_FRONTEND=noninteractive + +# Install benchmark required deps +RUN pip install --no-cache-dir \ + seaborn==0.13.2 \ + py3nvml==0.2.7 \ + pandas==2.3.3 \ + Megatron==0.5.1 \ + -e git+https://github.com/NVIDIA/Megatron-LM.git@dev#egg=megatron-core + +# Install MagiAttention from blackwell_benchmark branch +RUN set -eux; \ + git clone --depth 1 --branch blackwell_benchmark https://github.com/SandAI-org/MagiAttention.git /tmp/MagiAttention && \ + cd /tmp/MagiAttention && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + bash scripts/install_flash_attn_cute.sh && \ + export MAGI_ATTENTION_PREBUILD_FFA=0 && \ + export MAGI_ATTENTION_SKIP_FFA_UTILS_BUILD=0 && \ + export MAGI_ATTENTION_SKIP_MAGI_ATTN_EXT_BUILD=0 && \ + export MAGI_ATTENTION_SKIP_MAGI_ATTN_COMM_BUILD=0 && \ + pip install -e . -v --no-build-isolation --force-reinstall && \ + pip show magi_attention && \ + python -c "from magi_attention import magi_attn_comm; print(magi_attn_comm)" + +CMD ["/bin/bash"] diff --git a/exps/dist_attn/Dockerfile.benchmark2 b/exps/dist_attn/Dockerfile.benchmark2 new file mode 100644 index 000000000..9a9984be1 --- /dev/null +++ b/exps/dist_attn/Dockerfile.benchmark2 @@ -0,0 +1,21 @@ +FROM magi-attn-benchmark:25.10.5 + +RUN pip uninstall -y magi_attention + +# Install MagiAttention from blackwell_benchmark branch +RUN set -eux; \ + git clone --depth 1 --branch blackwell_benchmark https://github.com/SandAI-org/MagiAttention.git /tmp/MagiAttention && \ + cd /tmp/MagiAttention && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + bash scripts/install_flash_attn_cute.sh && \ + export MAGI_ATTENTION_PREBUILD_FFA=0 && \ + export MAGI_ATTENTION_SKIP_FFA_UTILS_BUILD=0 && \ + export MAGI_ATTENTION_SKIP_MAGI_ATTN_EXT_BUILD=0 && \ + export MAGI_ATTENTION_SKIP_MAGI_ATTN_COMM_BUILD=0 && \ + export MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY=100 && \ + pip install -e . -v --no-build-isolation --force-reinstall && \ + pip show magi_attention && \ + python -c "from magi_attention import magi_attn_comm; print(magi_attn_comm)" + +CMD ["/bin/bash"] diff --git a/exps/dist_attn/baselines/shard.py b/exps/dist_attn/baselines/shard.py index 6df8b758d..06c4afa94 100644 --- a/exps/dist_attn/baselines/shard.py +++ b/exps/dist_attn/baselines/shard.py @@ -61,7 +61,6 @@ def set_seed(seed): # init distribute environment # create DeviceMesh for all pg def init_distributed(world_size, pg_meta={}): - print(f"world_size: {world_size}, meta info: {pg_meta}") if not dist.is_initialized(): local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) diff --git a/exps/dist_attn/benchmark_conf.py b/exps/dist_attn/benchmark_conf.py index 9a520a161..ca0c36e09 100644 --- a/exps/dist_attn/benchmark_conf.py +++ b/exps/dist_attn/benchmark_conf.py @@ -56,33 +56,33 @@ class ENVVAR_CONFIG: EXTEND_ENVVAR_CONFIG = { AttnImpl.MAGI_ATTENTION: { "envvars": { - # CUDA + # --- CUDA --- # "CUDA_DEVICE_MAX_CONNECTIONS": [8], # for better parallelism - # NCCL - "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap - # Torch - "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap - # MagiAttention comm - "MAGI_ATTENTION_HIERARCHICAL_COMM": [ - 0 - ], # turn it to `1` to enable a2av-based hierarchical comm - "MAGI_ATTENTION_NATIVE_GRPCOLL": [ - 0 - ], # turn it to `1` to enable native grpcoll - "MAGI_ATTENTION_QO_COMM": [ - 0 - ], # turn it to `1` to enable dynamic solver with QO comm - # MagiAttention blackwell - "MAGI_ATTENTION_FA4_BACKEND": [ - 0 - ], # turn it to `1` to enable FA4 backend for Blackwell - "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [ - 3 - ], # only used when enabling FA4 backend + # --- NCCL --- # + "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap with a2av backend + # --- Torch --- # + "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap with a2av backend + # --- MagiAttention comm --- # + # turn it to `1` to enable a2av-based hierarchical comm + "MAGI_ATTENTION_HIERARCHICAL_COMM": [0], + # turn it to `1` to enable native grpcoll + "MAGI_ATTENTION_NATIVE_GRPCOLL": [0], + # turn it to `1` to enable dynamic solver with QO comm + "MAGI_ATTENTION_QO_COMM": [0], + # turn it to `1` to flatten query head groups + # for better dynamic solver partitioning + "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS": [0], + # --- MagiAttention blackwell --- # + # turn it to `1` to enable FA4 backend for Blackwell + "MAGI_ATTENTION_FA4_BACKEND": [0], + # set the maximum odd number of functions for HSFU representations + # which only takes effect when enabling FA4 backend + "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [3], }, - "extend_labels": [ - "exp0" - ], # optionally set the extended label for each envvar combination + # optionally set the extended label for each envvar combination + # which only works when `use_extend_labels` is True + # e.g. ["label0", "label1", ...] + "extend_labels": ["label0"], } } use_extend_labels = False @@ -187,17 +187,19 @@ class DATA_CONFIG: Data configuration. - seqlen_per_rank: sequence length per rank, total seqlen = seqlen_per_rank * world_size. - embed_dim: embedding dimension. - - hidden_size: hidden size. - - heads_q: number of query heads. - - heads_kv: number of key/value heads. + - head_dim: head dimension. + - num_heads_q: number of query heads. + - num_heads_kv: number of key/value heads. - dtype: data dtype. """ - seqlen_per_rank = 8 * 1024 + seqlen_per_rank = 8 * 1024 # for H100 + # seqlen_per_rank = 16 * 1024 # for H200/B200 + # seqlen_per_rank = 32 * 1024 # for B300 embed_dim = 1024 - hidden_size = 128 - heads_q = 64 - heads_kv = 8 + head_dim = 128 + num_heads_q = 64 + num_heads_kv = 8 dtype = torch.bfloat16 @@ -216,12 +218,15 @@ class ATTN_CONFIG: """ # ----- cp baselie dist-attn conf ---- # - attn_backend = AttnBackend.FA3 + + attn_backend = AttnBackend.FA3 # for Hopper + # attn_backend = AttnBackend.TE # for Blackwell dropout = 0.0 softmax_scale = None deterministic = False # ----- magi-attention conf ---- # + chunk_size = 2048 dispatch_alg = MinHeapDispatchAlg @@ -229,15 +234,14 @@ class ATTN_CONFIG: overlap_mode = AttnOverlapMode.STATIC degree = 2 min_chunk_size = 512 - max_num_chunks = 64 + max_num_chunks = 4096 # ----- magi-attention native grpcoll conf ---- # - num_sms = 88 - nvl_chunk_size = 8 - nvl_buffer_size = 256 - rdma_chunk_size = 16 - rdma_buffer_size = 128 - num_nvl_bytes = int(3e9) # ~3GB - # only valid for internode - num_rdma_bytes = int(1e9) # ~1GB + num_sms = 24 + nvl_chunk_size = 4 + nvl_buffer_size = 128 + rdma_chunk_size = 16 + rdma_buffer_size = 256 + num_nvl_bytes = int(5e9) # ~5GB + num_rdma_bytes = int(5e9) # ~5GB, only valid for internode diff --git a/exps/dist_attn/build_benchmark_docker.sh b/exps/dist_attn/build_benchmark_docker.sh new file mode 100644 index 000000000..246ad5abe --- /dev/null +++ b/exps/dist_attn/build_benchmark_docker.sh @@ -0,0 +1,24 @@ +#! /bin/bash + +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: This assumes that the base image already has MagiAttention installed. + +# Make sure to log in before pushing, if necessary. +# docker login + +docker build --network host -f ./Dockerfile.benchmark --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t magi-attn-benchmark:25.10.4 . + +docker push magi-attn-benchmark:25.10.4 diff --git a/exps/dist_attn/build_benchmark_docker2.sh b/exps/dist_attn/build_benchmark_docker2.sh new file mode 100644 index 000000000..c4a962e71 --- /dev/null +++ b/exps/dist_attn/build_benchmark_docker2.sh @@ -0,0 +1,24 @@ +#! /bin/bash + +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: This assumes that the base image already has MagiAttention installed. + +# Make sure to log in before pushing, if necessary. +# docker login + +docker build --network host -f ./Dockerfile.benchmark2 --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t magi-attn-benchmark:25.10.6 . + +docker push magi-attn-benchmark:25.10.6 diff --git a/exps/dist_attn/configs/b200_all.py b/exps/dist_attn/configs/b200_all.py new file mode 100644 index 000000000..f41c6144f --- /dev/null +++ b/exps/dist_attn/configs/b200_all.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from exps.dist_attn.baselines.interface import AttnImpl +from exps.dist_attn.baselines.utils_cp import AttnBackend +from exps.dist_attn.benchmark.enums import FlashMaskType +from magi_attention.common.enum import AttnOverlapMode +from magi_attention.meta.solver.dispatch_solver import MinHeapDispatchAlg + +""" +This file defines all cp benchmark configurations. +Each config name should start with a capital letter to be cnocluded, +all other keys will be discarded. The config will be loaded by `run_benchmark.py` +and specified via `--config` in `run_benchmark.sh`. +""" + + +SEED = 42 + + +@dataclass +class ENVVAR_CONFIG: + """ + Define env vars for MagiAttention here and dynamic switch in `run_benchmark.py`, avoid modify bash script + and simplify benchmarking process. + - EXTEND_ENVVAR_CONFIG: Specifies all environment variable product combinations used in benchmarking, + use extend_labels to assign a custom name to each extension. + If provided, the number of extend_labels must exactly match the number of generated extensions. + If not provided, the benchmark will assign default suffixes such as -0, -1, etc. + - use_extend_labels: specifies whether the values in extend_labels are appended to the result labels. + This option is only valid when each baseline has exactly one environment-variable extension; + otherwise, an error will be raised. + By default, use_extend_labels is set to False, which means no label extensions are applied to any baseline. + + Example of using extensions: + 1. Define multiple values for certain environment variables, e.g., NCCL_CGA_CLUSTER_SIZE: [1, 4]. + 2. Define multiple extend_labels, e.g., ["exp0", "exp1"], or leave it empty. + 3. Enable use_extend_labels to automatically extend labels for each configuration combination. + """ + + EXTEND_ENVVAR_CONFIG = { + AttnImpl.MAGI_ATTENTION: { + "envvars": { + # --- CUDA --- # + "CUDA_DEVICE_MAX_CONNECTIONS": [8], # for better parallelism + # --- NCCL --- # + "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap with a2av backend + # --- Torch --- # + "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap with a2av backend + # --- MagiAttention comm --- # + # turn it to `1` to enable a2av-based hierarchical comm + "MAGI_ATTENTION_HIERARCHICAL_COMM": [0], + # turn it to `1` to enable native grpcoll + "MAGI_ATTENTION_NATIVE_GRPCOLL": [0, 1], + # turn it to `1` to enable dynamic solver with QO comm + "MAGI_ATTENTION_QO_COMM": [0], + # turn it to `1` to flatten query head groups + # for better dynamic solver partitioning + "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS": [0], + # --- MagiAttention blackwell --- # + # turn it to `1` to enable FA4 backend for Blackwell + "MAGI_ATTENTION_FA4_BACKEND": [1], + # set the maximum odd number of functions for HSFU representations + # which only takes effect when enabling FA4 backend + "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [3], + }, + # optionally set the extended label for each envvar combination + # which only works when `use_extend_labels` is True + # e.g. ["label0", "label1", ...] + "extend_labels": ["a2av", "native"], + } + } + use_extend_labels = True + + +@dataclass +class BENCH_MODE: + """ + Benchmark runtime mode configuration. + - enable_profile: whether to enable nsys profiling. + - profile_only: if True, only profile the benchmark; skip flops/memory recording and skip plotting results. + - stat_warmup_iters: number of warmup iterations for statistical benchmark recording. + - stat_iters: number of iterations for statistical benchmark recording (flops/memory). + - profile_iters: number of iterations for profiling. + - profile_warmup_iters: number of warmup iterations for profiling. + """ + + enable_profile = True + profile_only = False + stat_warmup_iters = 5 + stat_iters = 20 + profile_warmup_iters = 1 + profile_iters = 3 + + +@dataclass +class BENCH_CONFIG: + """ + Benchmark combination configuration. + - quantiles: quantile points used for summarizing latency/throughput results. + - bench_flops: whether to benchmark flops. + - bench_mem: whether to benchmark memory. + - bench_mode: mode to summarize latency/throughput results (mean, median, min, max). + - output_path: output folder. + - mask_pattern: + list of attention masks to run evaluate (FULL, CAUSAL, Varlen-FULL, Varlen-CAUSAL). + - dist_attn_impl: + distributed attention implementations to evaluate (Ulysess, Ring-P2P, Ring-AllGather, + USP, LoongTrain, MagiAttention). + - workload: + pipeline schedule modes to evaluate ("fwd"=forward only, "bwd"=backward only, "1f1b"=forward+backward). + + e.g. + for mask in mask_patterm: + for dist_attn in dist_attn_impl: + for wd in workload: + do_bench + """ + + quantiles = [0.5, 0.2, 0.8] + bench_flops = True + bench_mem = False + bench_mode = "mean" + output_path = "./outs/b200_all" + mask_pattern = [ + FlashMaskType.FULL, + FlashMaskType.CAUSAL, + FlashMaskType.FULL_DOCUMENT, + FlashMaskType.CAUSAL_DOCUMENT, + ] + dist_attn_impl = [ + AttnImpl.ULYSSES, + AttnImpl.RING_P2P, + AttnImpl.RING_ALLGATHER, + AttnImpl.USP, + AttnImpl.LOONGTRAIN, + AttnImpl.MAGI_ATTENTION, + # AttnImpl.HYBRID_DCP, + ] + workload = [ + "fwd", + "bwd", + # "1f1b", + ] + + +@dataclass +class SAMPLE_CONFIG: + """ + Mask sampler configuration. + - dataset_path: path to the csv or json file of dataset length distribution. + - pack_num: number of data packs to evaluate. + - chunk_ratio: ratio used to determine chunk size; sequences longer than `pack_len * chunk_ratio` + are split into chunks of this length. + - is_binned: whether the dataset statistics are provided as intervals with counts (binned) + or as individual lengths with counts. + - to_attn_ranges: convert to attn_ranges. + - drop_thres: whether to drop large sample to generate short-varlen. + """ + + dataset_path = "./benchmark/datasets/default/doc_length_distribution.csv" + pack_num = 20 + chunk_ratio = 0.25 + is_binned = True + to_attn_ranges = True + drop_thres = -1 + + +@dataclass +class DATA_CONFIG: + """ + Data configuration. + - seqlen_per_rank: sequence length per rank, total seqlen = seqlen_per_rank * world_size. + - embed_dim: embedding dimension. + - head_dim: head dimension. + - num_heads_q: number of query heads. + - num_heads_kv: number of key/value heads. + - dtype: data dtype. + """ + + seqlen_per_rank = 16 * 1024 + embed_dim = 1024 + head_dim = 128 + num_heads_q = 64 + num_heads_kv = 8 + dtype = torch.bfloat16 + + +@dataclass +class ATTN_CONFIG: + """ + Baseline impl configuration. + - attn_backend: baseline attention backend to use (FA3, TE) + - dropout: dropout rate. + - softmax_scale: softmax scale. + - deterministic: whether to use deterministic mode. + MagiAttention impl configuration. + - chunk_size + - dispatch_alg + - OverlapConfig + """ + + # ----- cp baselie dist-attn conf ---- # + + attn_backend = AttnBackend.TE + dropout = 0.0 + softmax_scale = None + deterministic = False + + # ----- magi-attention conf ---- # + + chunk_size = 2048 + dispatch_alg = MinHeapDispatchAlg + + enable_overlap = True + overlap_mode = AttnOverlapMode.STATIC + degree = 2 + min_chunk_size = 512 + max_num_chunks = 4096 + + # ----- magi-attention native grpcoll conf ---- # + + num_sms = 48 + nvl_chunk_size = 4 + nvl_buffer_size = 256 + rdma_chunk_size = 16 + rdma_buffer_size = 256 + num_nvl_bytes = int(10e9) # ~10GB + num_rdma_bytes = int(10e9) # ~10GB, only valid for internode diff --git a/exps/dist_attn/configs/b200_baseline.py b/exps/dist_attn/configs/b200_baseline.py new file mode 100644 index 000000000..49a4cf952 --- /dev/null +++ b/exps/dist_attn/configs/b200_baseline.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from exps.dist_attn.baselines.interface import AttnImpl +from exps.dist_attn.baselines.utils_cp import AttnBackend +from exps.dist_attn.benchmark.enums import FlashMaskType +from magi_attention.common.enum import AttnOverlapMode +from magi_attention.meta.solver.dispatch_solver import MinHeapDispatchAlg + +""" +This file defines all cp benchmark configurations. +Each config name should start with a capital letter to be cnocluded, +all other keys will be discarded. The config will be loaded by `run_benchmark.py` +and specified via `--config` in `run_benchmark.sh`. +""" + + +SEED = 42 + + +@dataclass +class ENVVAR_CONFIG: + """ + Define env vars for MagiAttention here and dynamic switch in `run_benchmark.py`, avoid modify bash script + and simplify benchmarking process. + - EXTEND_ENVVAR_CONFIG: Specifies all environment variable product combinations used in benchmarking, + use extend_labels to assign a custom name to each extension. + If provided, the number of extend_labels must exactly match the number of generated extensions. + If not provided, the benchmark will assign default suffixes such as -0, -1, etc. + - use_extend_labels: specifies whether the values in extend_labels are appended to the result labels. + This option is only valid when each baseline has exactly one environment-variable extension; + otherwise, an error will be raised. + By default, use_extend_labels is set to False, which means no label extensions are applied to any baseline. + + Example of using extensions: + 1. Define multiple values for certain environment variables, e.g., NCCL_CGA_CLUSTER_SIZE: [1, 4]. + 2. Define multiple extend_labels, e.g., ["exp0", "exp1"], or leave it empty. + 3. Enable use_extend_labels to automatically extend labels for each configuration combination. + """ + + EXTEND_ENVVAR_CONFIG = { + AttnImpl.MAGI_ATTENTION: { + "envvars": { + # --- CUDA --- # + "CUDA_DEVICE_MAX_CONNECTIONS": [8], # for better parallelism + # --- NCCL --- # + "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap with a2av backend + # --- Torch --- # + "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap with a2av backend + # --- MagiAttention comm --- # + # turn it to `1` to enable a2av-based hierarchical comm + "MAGI_ATTENTION_HIERARCHICAL_COMM": [0], + # turn it to `1` to enable native grpcoll + "MAGI_ATTENTION_NATIVE_GRPCOLL": [0], + # turn it to `1` to enable dynamic solver with QO comm + "MAGI_ATTENTION_QO_COMM": [0], + # turn it to `1` to flatten query head groups + # for better dynamic solver partitioning + "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS": [0], + # --- MagiAttention blackwell --- # + # turn it to `1` to enable FA4 backend for Blackwell + "MAGI_ATTENTION_FA4_BACKEND": [0], + # set the maximum odd number of functions for HSFU representations + # which only takes effect when enabling FA4 backend + "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [3], + }, + # optionally set the extended label for each envvar combination + # which only works when `use_extend_labels` is True + # e.g. ["label0", "label1", ...] + "extend_labels": ["label0"], + } + } + use_extend_labels = False + + +@dataclass +class BENCH_MODE: + """ + Benchmark runtime mode configuration. + - enable_profile: whether to enable nsys profiling. + - profile_only: if True, only profile the benchmark; skip flops/memory recording and skip plotting results. + - stat_warmup_iters: number of warmup iterations for statistical benchmark recording. + - stat_iters: number of iterations for statistical benchmark recording (flops/memory). + - profile_iters: number of iterations for profiling. + - profile_warmup_iters: number of warmup iterations for profiling. + """ + + enable_profile = True + profile_only = False + stat_warmup_iters = 5 + stat_iters = 20 + profile_warmup_iters = 1 + profile_iters = 3 + + +@dataclass +class BENCH_CONFIG: + """ + Benchmark combination configuration. + - quantiles: quantile points used for summarizing latency/throughput results. + - bench_flops: whether to benchmark flops. + - bench_mem: whether to benchmark memory. + - bench_mode: mode to summarize latency/throughput results (mean, median, min, max). + - output_path: output folder. + - mask_pattern: + list of attention masks to run evaluate (FULL, CAUSAL, Varlen-FULL, Varlen-CAUSAL). + - dist_attn_impl: + distributed attention implementations to evaluate (Ulysess, Ring-P2P, Ring-AllGather, + USP, LoongTrain, MagiAttention). + - workload: + pipeline schedule modes to evaluate ("fwd"=forward only, "bwd"=backward only, "1f1b"=forward+backward). + + e.g. + for mask in mask_patterm: + for dist_attn in dist_attn_impl: + for wd in workload: + do_bench + """ + + quantiles = [0.5, 0.2, 0.8] + bench_flops = True + bench_mem = False + bench_mode = "mean" + output_path = "./outs/b200_baseline" + mask_pattern = [ + FlashMaskType.FULL, + FlashMaskType.CAUSAL, + FlashMaskType.FULL_DOCUMENT, + FlashMaskType.CAUSAL_DOCUMENT, + ] + dist_attn_impl = [ + AttnImpl.ULYSSES, + AttnImpl.RING_P2P, + AttnImpl.RING_ALLGATHER, + AttnImpl.USP, + AttnImpl.LOONGTRAIN, + # AttnImpl.MAGI_ATTENTION, + # AttnImpl.HYBRID_DCP, + ] + workload = [ + "fwd", + "bwd", + # "1f1b", + ] + + +@dataclass +class SAMPLE_CONFIG: + """ + Mask sampler configuration. + - dataset_path: path to the csv or json file of dataset length distribution. + - pack_num: number of data packs to evaluate. + - chunk_ratio: ratio used to determine chunk size; sequences longer than `pack_len * chunk_ratio` + are split into chunks of this length. + - is_binned: whether the dataset statistics are provided as intervals with counts (binned) + or as individual lengths with counts. + - to_attn_ranges: convert to attn_ranges. + - drop_thres: whether to drop large sample to generate short-varlen. + """ + + dataset_path = "./benchmark/datasets/default/doc_length_distribution.csv" + pack_num = 20 + chunk_ratio = 0.25 + is_binned = True + to_attn_ranges = True + drop_thres = -1 + + +@dataclass +class DATA_CONFIG: + """ + Data configuration. + - seqlen_per_rank: sequence length per rank, total seqlen = seqlen_per_rank * world_size. + - embed_dim: embedding dimension. + - head_dim: head dimension. + - num_heads_q: number of query heads. + - num_heads_kv: number of key/value heads. + - dtype: data dtype. + """ + + seqlen_per_rank = 16 * 1024 + embed_dim = 1024 + head_dim = 128 + num_heads_q = 64 + num_heads_kv = 8 + dtype = torch.bfloat16 + + +@dataclass +class ATTN_CONFIG: + """ + Baseline impl configuration. + - attn_backend: baseline attention backend to use (FA3, TE) + - dropout: dropout rate. + - softmax_scale: softmax scale. + - deterministic: whether to use deterministic mode. + MagiAttention impl configuration. + - chunk_size + - dispatch_alg + - OverlapConfig + """ + + # ----- cp baselie dist-attn conf ---- # + + attn_backend = AttnBackend.TE + dropout = 0.0 + softmax_scale = None + deterministic = False + + # ----- magi-attention conf ---- # + + chunk_size = 2048 + dispatch_alg = MinHeapDispatchAlg + + enable_overlap = True + overlap_mode = AttnOverlapMode.STATIC + degree = 2 + min_chunk_size = 512 + max_num_chunks = 4096 + + # ----- magi-attention native grpcoll conf ---- # + + num_sms = 48 + nvl_chunk_size = 4 + nvl_buffer_size = 256 + rdma_chunk_size = 16 + rdma_buffer_size = 256 + num_nvl_bytes = int(10e9) # ~10GB + num_rdma_bytes = int(10e9) # ~10GB, only valid for internode diff --git a/exps/dist_attn/configs/b200_magi.py b/exps/dist_attn/configs/b200_magi.py new file mode 100644 index 000000000..8f1531109 --- /dev/null +++ b/exps/dist_attn/configs/b200_magi.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from exps.dist_attn.baselines.interface import AttnImpl +from exps.dist_attn.baselines.utils_cp import AttnBackend +from exps.dist_attn.benchmark.enums import FlashMaskType +from magi_attention.common.enum import AttnOverlapMode +from magi_attention.meta.solver.dispatch_solver import MinHeapDispatchAlg + +""" +This file defines all cp benchmark configurations. +Each config name should start with a capital letter to be cnocluded, +all other keys will be discarded. The config will be loaded by `run_benchmark.py` +and specified via `--config` in `run_benchmark.sh`. +""" + + +SEED = 42 + + +@dataclass +class ENVVAR_CONFIG: + """ + Define env vars for MagiAttention here and dynamic switch in `run_benchmark.py`, avoid modify bash script + and simplify benchmarking process. + - EXTEND_ENVVAR_CONFIG: Specifies all environment variable product combinations used in benchmarking, + use extend_labels to assign a custom name to each extension. + If provided, the number of extend_labels must exactly match the number of generated extensions. + If not provided, the benchmark will assign default suffixes such as -0, -1, etc. + - use_extend_labels: specifies whether the values in extend_labels are appended to the result labels. + This option is only valid when each baseline has exactly one environment-variable extension; + otherwise, an error will be raised. + By default, use_extend_labels is set to False, which means no label extensions are applied to any baseline. + + Example of using extensions: + 1. Define multiple values for certain environment variables, e.g., NCCL_CGA_CLUSTER_SIZE: [1, 4]. + 2. Define multiple extend_labels, e.g., ["exp0", "exp1"], or leave it empty. + 3. Enable use_extend_labels to automatically extend labels for each configuration combination. + """ + + EXTEND_ENVVAR_CONFIG = { + AttnImpl.MAGI_ATTENTION: { + "envvars": { + # --- CUDA --- # + "CUDA_DEVICE_MAX_CONNECTIONS": [8], # for better parallelism + # --- NCCL --- # + "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap with a2av backend + # --- Torch --- # + "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap with a2av backend + # --- MagiAttention comm --- # + # turn it to `1` to enable a2av-based hierarchical comm + "MAGI_ATTENTION_HIERARCHICAL_COMM": [0], + # turn it to `1` to enable native grpcoll + "MAGI_ATTENTION_NATIVE_GRPCOLL": [0, 1], + # turn it to `1` to enable dynamic solver with QO comm + "MAGI_ATTENTION_QO_COMM": [0], + # turn it to `1` to flatten query head groups + # for better dynamic solver partitioning + "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS": [0], + # --- MagiAttention blackwell --- # + # turn it to `1` to enable FA4 backend for Blackwell + "MAGI_ATTENTION_FA4_BACKEND": [1], + # set the maximum odd number of functions for HSFU representations + # which only takes effect when enabling FA4 backend + "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [3], + }, + # optionally set the extended label for each envvar combination + # which only works when `use_extend_labels` is True + # e.g. ["label0", "label1", ...] + "extend_labels": ["a2av", "native"], + } + } + use_extend_labels = True + + +@dataclass +class BENCH_MODE: + """ + Benchmark runtime mode configuration. + - enable_profile: whether to enable nsys profiling. + - profile_only: if True, only profile the benchmark; skip flops/memory recording and skip plotting results. + - stat_warmup_iters: number of warmup iterations for statistical benchmark recording. + - stat_iters: number of iterations for statistical benchmark recording (flops/memory). + - profile_iters: number of iterations for profiling. + - profile_warmup_iters: number of warmup iterations for profiling. + """ + + enable_profile = True + profile_only = False + stat_warmup_iters = 5 + stat_iters = 20 + profile_warmup_iters = 1 + profile_iters = 3 + + +@dataclass +class BENCH_CONFIG: + """ + Benchmark combination configuration. + - quantiles: quantile points used for summarizing latency/throughput results. + - bench_flops: whether to benchmark flops. + - bench_mem: whether to benchmark memory. + - bench_mode: mode to summarize latency/throughput results (mean, median, min, max). + - output_path: output folder. + - mask_pattern: + list of attention masks to run evaluate (FULL, CAUSAL, Varlen-FULL, Varlen-CAUSAL). + - dist_attn_impl: + distributed attention implementations to evaluate (Ulysess, Ring-P2P, Ring-AllGather, + USP, LoongTrain, MagiAttention). + - workload: + pipeline schedule modes to evaluate ("fwd"=forward only, "bwd"=backward only, "1f1b"=forward+backward). + + e.g. + for mask in mask_patterm: + for dist_attn in dist_attn_impl: + for wd in workload: + do_bench + """ + + quantiles = [0.5, 0.2, 0.8] + bench_flops = True + bench_mem = False + bench_mode = "mean" + output_path = "./outs/b200_magi" + mask_pattern = [ + FlashMaskType.FULL, + FlashMaskType.CAUSAL, + FlashMaskType.FULL_DOCUMENT, + FlashMaskType.CAUSAL_DOCUMENT, + ] + dist_attn_impl = [ + # AttnImpl.ULYSSES, + # AttnImpl.RING_P2P, + # AttnImpl.RING_ALLGATHER, + # AttnImpl.USP, + # AttnImpl.LOONGTRAIN, + AttnImpl.MAGI_ATTENTION, + # AttnImpl.HYBRID_DCP, + ] + workload = [ + "fwd", + "bwd", + # "1f1b", + ] + + +@dataclass +class SAMPLE_CONFIG: + """ + Mask sampler configuration. + - dataset_path: path to the csv or json file of dataset length distribution. + - pack_num: number of data packs to evaluate. + - chunk_ratio: ratio used to determine chunk size; sequences longer than `pack_len * chunk_ratio` + are split into chunks of this length. + - is_binned: whether the dataset statistics are provided as intervals with counts (binned) + or as individual lengths with counts. + - to_attn_ranges: convert to attn_ranges. + - drop_thres: whether to drop large sample to generate short-varlen. + """ + + dataset_path = "./benchmark/datasets/default/doc_length_distribution.csv" + pack_num = 20 + chunk_ratio = 0.25 + is_binned = True + to_attn_ranges = True + drop_thres = -1 + + +@dataclass +class DATA_CONFIG: + """ + Data configuration. + - seqlen_per_rank: sequence length per rank, total seqlen = seqlen_per_rank * world_size. + - embed_dim: embedding dimension. + - head_dim: head dimension. + - num_heads_q: number of query heads. + - num_heads_kv: number of key/value heads. + - dtype: data dtype. + """ + + seqlen_per_rank = 16 * 1024 + embed_dim = 1024 + head_dim = 128 + num_heads_q = 64 + num_heads_kv = 8 + dtype = torch.bfloat16 + + +@dataclass +class ATTN_CONFIG: + """ + Baseline impl configuration. + - attn_backend: baseline attention backend to use (FA3, TE) + - dropout: dropout rate. + - softmax_scale: softmax scale. + - deterministic: whether to use deterministic mode. + MagiAttention impl configuration. + - chunk_size + - dispatch_alg + - OverlapConfig + """ + + # ----- cp baselie dist-attn conf ---- # + + attn_backend = AttnBackend.TE + dropout = 0.0 + softmax_scale = None + deterministic = False + + # ----- magi-attention conf ---- # + + chunk_size = 2048 + dispatch_alg = MinHeapDispatchAlg + + enable_overlap = True + overlap_mode = AttnOverlapMode.STATIC + degree = 2 + min_chunk_size = 512 + max_num_chunks = 4096 + + # ----- magi-attention native grpcoll conf ---- # + + num_sms = 48 + nvl_chunk_size = 4 + nvl_buffer_size = 256 + rdma_chunk_size = 16 + rdma_buffer_size = 256 + num_nvl_bytes = int(10e9) # ~10GB + num_rdma_bytes = int(10e9) # ~10GB, only valid for internode diff --git a/exps/dist_attn/configs/b200_magi_native.py b/exps/dist_attn/configs/b200_magi_native.py new file mode 100644 index 000000000..5707a036c --- /dev/null +++ b/exps/dist_attn/configs/b200_magi_native.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from exps.dist_attn.baselines.interface import AttnImpl +from exps.dist_attn.baselines.utils_cp import AttnBackend +from exps.dist_attn.benchmark.enums import FlashMaskType +from magi_attention.common.enum import AttnOverlapMode +from magi_attention.meta.solver.dispatch_solver import MinHeapDispatchAlg + +""" +This file defines all cp benchmark configurations. +Each config name should start with a capital letter to be cnocluded, +all other keys will be discarded. The config will be loaded by `run_benchmark.py` +and specified via `--config` in `run_benchmark.sh`. +""" + + +SEED = 42 + + +@dataclass +class ENVVAR_CONFIG: + """ + Define env vars for MagiAttention here and dynamic switch in `run_benchmark.py`, avoid modify bash script + and simplify benchmarking process. + - EXTEND_ENVVAR_CONFIG: Specifies all environment variable product combinations used in benchmarking, + use extend_labels to assign a custom name to each extension. + If provided, the number of extend_labels must exactly match the number of generated extensions. + If not provided, the benchmark will assign default suffixes such as -0, -1, etc. + - use_extend_labels: specifies whether the values in extend_labels are appended to the result labels. + This option is only valid when each baseline has exactly one environment-variable extension; + otherwise, an error will be raised. + By default, use_extend_labels is set to False, which means no label extensions are applied to any baseline. + + Example of using extensions: + 1. Define multiple values for certain environment variables, e.g., NCCL_CGA_CLUSTER_SIZE: [1, 4]. + 2. Define multiple extend_labels, e.g., ["exp0", "exp1"], or leave it empty. + 3. Enable use_extend_labels to automatically extend labels for each configuration combination. + """ + + EXTEND_ENVVAR_CONFIG = { + AttnImpl.MAGI_ATTENTION: { + "envvars": { + # --- CUDA --- # + "CUDA_DEVICE_MAX_CONNECTIONS": [8], # for better parallelism + # --- NCCL --- # + "NCCL_CGA_CLUSTER_SIZE": [1], # for better overlap with a2av backend + # --- Torch --- # + "TORCH_NCCL_HIGH_PRIORITY": [1], # for better overlap with a2av backend + # --- MagiAttention comm --- # + # turn it to `1` to enable a2av-based hierarchical comm + "MAGI_ATTENTION_HIERARCHICAL_COMM": [0], + # turn it to `1` to enable native grpcoll + "MAGI_ATTENTION_NATIVE_GRPCOLL": [1], + # turn it to `1` to enable dynamic solver with QO comm + "MAGI_ATTENTION_QO_COMM": [0], + # turn it to `1` to flatten query head groups + # for better dynamic solver partitioning + "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS": [0], + # --- MagiAttention blackwell --- # + # turn it to `1` to enable FA4 backend for Blackwell + "MAGI_ATTENTION_FA4_BACKEND": [1], + # set the maximum odd number of functions for HSFU representations + # which only takes effect when enabling FA4 backend + "MAGI_ATTENTION_FA4_HSFU_MAX_NUM_FUNCS": [3], + }, + # optionally set the extended label for each envvar combination + # which only works when `use_extend_labels` is True + # e.g. ["label0", "label1", ...] + "extend_labels": ["native"], + } + } + use_extend_labels = False + + +@dataclass +class BENCH_MODE: + """ + Benchmark runtime mode configuration. + - enable_profile: whether to enable nsys profiling. + - profile_only: if True, only profile the benchmark; skip flops/memory recording and skip plotting results. + - stat_warmup_iters: number of warmup iterations for statistical benchmark recording. + - stat_iters: number of iterations for statistical benchmark recording (flops/memory). + - profile_iters: number of iterations for profiling. + - profile_warmup_iters: number of warmup iterations for profiling. + """ + + enable_profile = True + profile_only = False + stat_warmup_iters = 5 + stat_iters = 20 + profile_warmup_iters = 1 + profile_iters = 3 + + +@dataclass +class BENCH_CONFIG: + """ + Benchmark combination configuration. + - quantiles: quantile points used for summarizing latency/throughput results. + - bench_flops: whether to benchmark flops. + - bench_mem: whether to benchmark memory. + - bench_mode: mode to summarize latency/throughput results (mean, median, min, max). + - output_path: output folder. + - mask_pattern: + list of attention masks to run evaluate (FULL, CAUSAL, Varlen-FULL, Varlen-CAUSAL). + - dist_attn_impl: + distributed attention implementations to evaluate (Ulysess, Ring-P2P, Ring-AllGather, + USP, LoongTrain, MagiAttention). + - workload: + pipeline schedule modes to evaluate ("fwd"=forward only, "bwd"=backward only, "1f1b"=forward+backward). + + e.g. + for mask in mask_patterm: + for dist_attn in dist_attn_impl: + for wd in workload: + do_bench + """ + + quantiles = [0.5, 0.2, 0.8] + bench_flops = True + bench_mem = False + bench_mode = "mean" + output_path = "./outs/b200_magi_native" + mask_pattern = [ + FlashMaskType.FULL, + FlashMaskType.CAUSAL, + FlashMaskType.FULL_DOCUMENT, + FlashMaskType.CAUSAL_DOCUMENT, + ] + dist_attn_impl = [ + # AttnImpl.ULYSSES, + # AttnImpl.RING_P2P, + # AttnImpl.RING_ALLGATHER, + # AttnImpl.USP, + # AttnImpl.LOONGTRAIN, + AttnImpl.MAGI_ATTENTION, + # AttnImpl.HYBRID_DCP, + ] + workload = [ + "fwd", + "bwd", + # "1f1b", + ] + + +@dataclass +class SAMPLE_CONFIG: + """ + Mask sampler configuration. + - dataset_path: path to the csv or json file of dataset length distribution. + - pack_num: number of data packs to evaluate. + - chunk_ratio: ratio used to determine chunk size; sequences longer than `pack_len * chunk_ratio` + are split into chunks of this length. + - is_binned: whether the dataset statistics are provided as intervals with counts (binned) + or as individual lengths with counts. + - to_attn_ranges: convert to attn_ranges. + - drop_thres: whether to drop large sample to generate short-varlen. + """ + + dataset_path = "./benchmark/datasets/default/doc_length_distribution.csv" + pack_num = 20 + chunk_ratio = 0.25 + is_binned = True + to_attn_ranges = True + drop_thres = -1 + + +@dataclass +class DATA_CONFIG: + """ + Data configuration. + - seqlen_per_rank: sequence length per rank, total seqlen = seqlen_per_rank * world_size. + - embed_dim: embedding dimension. + - head_dim: head dimension. + - num_heads_q: number of query heads. + - num_heads_kv: number of key/value heads. + - dtype: data dtype. + """ + + seqlen_per_rank = 16 * 1024 + embed_dim = 1024 + head_dim = 128 + num_heads_q = 64 + num_heads_kv = 8 + dtype = torch.bfloat16 + + +@dataclass +class ATTN_CONFIG: + """ + Baseline impl configuration. + - attn_backend: baseline attention backend to use (FA3, TE) + - dropout: dropout rate. + - softmax_scale: softmax scale. + - deterministic: whether to use deterministic mode. + MagiAttention impl configuration. + - chunk_size + - dispatch_alg + - OverlapConfig + """ + + # ----- cp baselie dist-attn conf ---- # + + attn_backend = AttnBackend.TE + dropout = 0.0 + softmax_scale = None + deterministic = False + + # ----- magi-attention conf ---- # + + chunk_size = 2048 + dispatch_alg = MinHeapDispatchAlg + + enable_overlap = True + overlap_mode = AttnOverlapMode.STATIC + degree = 2 + min_chunk_size = 512 + max_num_chunks = 4096 + + # ----- magi-attention native grpcoll conf ---- # + + num_sms = 48 + nvl_chunk_size = 4 + nvl_buffer_size = 256 + rdma_chunk_size = 16 + rdma_buffer_size = 256 + num_nvl_bytes = int(10e9) # ~10GB + num_rdma_bytes = int(10e9) # ~10GB, only valid for internode diff --git a/exps/dist_attn/dyn_simulate/test_solver.py b/exps/dist_attn/dyn_simulate/test_solver.py index f12cf3c08..f20fa50b3 100644 --- a/exps/dist_attn/dyn_simulate/test_solver.py +++ b/exps/dist_attn/dyn_simulate/test_solver.py @@ -255,7 +255,6 @@ def simulate_solver_and_measure_cost(): q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, - flatten_head_groups=magi_attention.is_flatten_head_groups_enable(), ) solve_end = time.time() solve_times.append(solve_end - solve_start) diff --git a/exps/dist_attn/main.py b/exps/dist_attn/main.py index b11673c60..41d54a0ae 100644 --- a/exps/dist_attn/main.py +++ b/exps/dist_attn/main.py @@ -182,11 +182,11 @@ attn_mask_type=[AttnMaskType.FULL] * len(q_ranges), total_seqlen_q=total_seqlen, total_seqlen_k=total_seqlen, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_k, + head_dim=head_dim, chunk_size=chunk_size, cp_group=nccl_groups[0], - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, dist_attn_config=dist_attn_config, ) # HACK: seperate cp group for group-reduce diff --git a/exps/dist_attn/run_benchmark.py b/exps/dist_attn/run_benchmark.py index 90f409362..17178717f 100644 --- a/exps/dist_attn/run_benchmark.py +++ b/exps/dist_attn/run_benchmark.py @@ -42,7 +42,7 @@ from exps.dist_attn.baselines.utils_cp import AttnBackend from exps.dist_attn.benchmark.enums import FlashMaskType from exps.dist_attn.benchmark.mask import MaskIterator -from magi_attention.api import calc_attn, compute_pad_size, magi_attn_flex_dispatch +from magi_attention.api import calc_attn, compute_pad_size, dispatch, magi_attn_flex_key from magi_attention.benchmarking.bench import Benchmark, do_bench, perf_report from magi_attention.comm.primitive.grpcoll._config import GrpCollConfig from magi_attention.common import AttnRanges @@ -254,9 +254,9 @@ def init_dist_environment( def run_dist_attn( total_seqlen: int, embed_dim: int, - q_heads: int, - kv_heads: int, - hidden_size: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, dtype, q_ranges: AttnRanges, k_ranges: AttnRanges, @@ -317,16 +317,16 @@ def run_dist_attn( x = torch.randn(total_seqlen, embed_dim, dtype=dtype, device=device) q_proj = torch.nn.Linear( - embed_dim, q_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_q * head_dim, dtype=dtype, device=device ) k_proj = torch.nn.Linear( - embed_dim, kv_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device ) v_proj = torch.nn.Linear( - embed_dim, kv_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device ) dout_proj = torch.nn.Linear( - embed_dim, q_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_q * head_dim, dtype=dtype, device=device ) # ----- dispatch ---- # @@ -348,10 +348,10 @@ def run_dist_attn( dout_local_samples: List[torch.Tensor] = [] for x_local in x_local_samples: x_local = x_local.view(-1, embed_dim) - q_local = q_proj(x_local).view(-1, q_heads, hidden_size) - k_local = k_proj(x_local).view(-1, kv_heads, hidden_size) - v_local = v_proj(x_local).view(-1, kv_heads, hidden_size) - dout_local = dout_proj(x_local).view(-1, q_heads, hidden_size) + q_local = q_proj(x_local).view(-1, num_heads_q, head_dim) + k_local = k_proj(x_local).view(-1, num_heads_kv, head_dim) + v_local = v_proj(x_local).view(-1, num_heads_kv, head_dim) + dout_local = dout_proj(x_local).view(-1, num_heads_q, head_dim) q_local.requires_grad_(True) k_local.requires_grad_(True) @@ -376,8 +376,8 @@ def run_dist_attn( ) if attn_impl == AttnImpl.ULYSSES: - assert world_size % kv_heads == 0 or kv_heads % world_size == 0 - H = world_size // kv_heads + assert world_size % num_heads_kv == 0 or num_heads_kv % world_size == 0 + H = world_size // num_heads_kv if H > 1: k_local = torch.repeat_interleave(k_local, H, dim=1) v_local = torch.repeat_interleave(v_local, H, dim=1) @@ -417,7 +417,7 @@ def fn(): if "CUDA out of memory" not in str(e): print( f"Error occured before running {attn_impl} with {attn_mask_type} mask " - f"when {total_seqlen=}, {q_heads=} during {wd}: {e=}" + f"when {total_seqlen=}, {num_heads_q=} during {wd}: {e=}" ) raise e global already_known_oom_before_run @@ -465,9 +465,9 @@ def fn(): def run_magi_attn( total_seqlen: int, embed_dim: int, - q_heads: int, - kv_heads: int, - hidden_size: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, dtype: torch.dtype, q_ranges: AttnRanges, k_ranges: AttnRanges, @@ -486,19 +486,19 @@ def run_magi_attn( x = torch.randn(total_seqlen, embed_dim, dtype=dtype, device=device) q_proj = torch.nn.Linear( - embed_dim, q_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_q * head_dim, dtype=dtype, device=device ) k_proj = torch.nn.Linear( - embed_dim, kv_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device ) v_proj = torch.nn.Linear( - embed_dim, kv_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device ) dout_proj = torch.nn.Linear( - embed_dim, q_heads * hidden_size, dtype=dtype, device=device + embed_dim, num_heads_q * head_dim, dtype=dtype, device=device ) - # ----- init dispatch mata ----- # + # ----- init dist attn config ----- # pad_size = compute_pad_size( total_seqlen_q=total_seqlen, @@ -513,8 +513,54 @@ def run_magi_attn( num_nvl_bytes = int(getattr(ATTN_CONFIG, "num_nvl_bytes", int(3e9))) # ~3GB # only valid for internode num_rdma_bytes = int(getattr(ATTN_CONFIG, "num_rdma_bytes", int(1e9))) # ~1GB + if world_size <= 8: # single node num_rdma_bytes = 0 + min_num_nvl_bytes = GrpCollConfig.get_min_num_bytes_intranode( + num_sms=num_sms, + num_ranks=world_size, + hidden_size=num_heads_q * head_dim, + nvl_buffer_size=nvl_buffer_size, + dtype=torch.float32, + transfer_lse=True, + num_heads=num_heads_q, + num_groups=3, + ) + min_num_rdma_bytes = 0 + else: # multi node + assert ( + world_size % 8 == 0 + ), "world_size must be multiple of 8 for internode native grpcoll." + assert ( + num_rdma_bytes > 0 + ), "num_rdma_bytes must be positive for internode native grpcoll." + ( + min_num_rdma_bytes, + min_num_nvl_bytes, + ) = GrpCollConfig.get_min_num_bytes_internode( + num_sms=num_sms, + num_rdma_ranks=world_size // 8, + num_nvl_ranks=8, + hidden_size=num_heads_q * head_dim, + rdma_buffer_size=rdma_buffer_size, + nvl_buffer_size=nvl_buffer_size, + dtype=torch.float32, + transfer_lse=True, + num_heads=num_heads_q, + num_groups=3, + ) + + assert num_nvl_bytes >= min_num_nvl_bytes, ( + f"{num_nvl_bytes=} ({num_nvl_bytes / 1024**3:.2f} GB) " + "is insufficient for native grpcoll, " + f"since {min_num_nvl_bytes=} ({min_num_nvl_bytes / 1024**3:.2f} GB)." + ) + assert num_rdma_bytes >= min_num_rdma_bytes, ( + f"{num_rdma_bytes=} ({num_rdma_bytes / 1024**3:.2f} GB) " + "is insufficient for native grpcoll, " + f"since {min_num_rdma_bytes=} ({min_num_rdma_bytes / 1024**3:.2f} GB)." + ) + grpcoll_config = GrpCollConfig( num_sms=num_sms, nvl_chunk_size=nvl_chunk_size, @@ -542,30 +588,28 @@ def run_magi_attn( # ----- dispatch ---- # - ( - x_local, - magi_attn_runtime_key, - ) = magi_attn_flex_dispatch( # local_x with shape (total_seqlen_q + pad_size) / cp_size, h) - x, + magi_attn_runtime_key = magi_attn_flex_key( # local_x with shape (total_seqlen_q + pad_size) / cp_size, h) q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen, total_seqlen_k=total_seqlen, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=cp_group_or_mesh, dist_attn_config=dist_attn_config, - num_heads_q=q_heads, - num_heads_kv=kv_heads, ) + x_local = dispatch(x, key=magi_attn_runtime_key) # ----- projection ----- # - q_local = q_proj(x_local).view(-1, q_heads, hidden_size) - k_local = k_proj(x_local).view(-1, kv_heads, hidden_size) - v_local = v_proj(x_local).view(-1, kv_heads, hidden_size) - dout_local = dout_proj(x_local).view(-1, q_heads, hidden_size) + q_local = q_proj(x_local).view(-1, num_heads_q, head_dim) + k_local = k_proj(x_local).view(-1, num_heads_kv, head_dim) + v_local = v_proj(x_local).view(-1, num_heads_kv, head_dim) + dout_local = dout_proj(x_local).view(-1, num_heads_q, head_dim) q_local.requires_grad_(True) k_local.requires_grad_(True) @@ -583,7 +627,7 @@ def fn(): if "CUDA out of memory" not in str(e): print( f"Error occured before running magi-attention with {attn_mask_type} mask " - f"when {total_seqlen=}, {q_heads=} during {wd}: {e=}" + f"when {total_seqlen=}, {num_heads_q=} during {wd}: {e=}" ) raise e global already_known_oom_before_run @@ -748,6 +792,9 @@ def maybe_extend_xvals( for attn_impl in dist_attn_impl: if attn_impl in EXTENSIONS.keys() and len(EXTENSIONS[attn_impl]) > 0: for setting_key in EXTENSIONS[attn_impl].keys(): + assert ( + "-" not in setting_key + ), "Remind: setting_key should not contain '-' character." xvals.append(f"{attn_impl.value}-{setting_key}") else: xvals.append(attn_impl.value) @@ -767,8 +814,11 @@ def maybe_switch_envvars(attn_impl_key: str): assert ( extension is not None ), f"{exp_keys[0]} found specific exp setting key {exp_keys[1]}, but no extension." + enable_value_dict = {k: str(v) for k, v in extension.items() if v is not None} switch_back = switch_envvars( - envvar_name_list=list(extension.keys()), enable_dict=extension + envvar_name_list=list(extension.keys()), + enable_dict=extension, + enable_value_dict=enable_value_dict, ) return switch_back, attn_impl @@ -880,6 +930,12 @@ def run_benchmark( for mask_idx, (q_ranges, k_ranges, attn_mask_type, _) in enumerate( mask_iterator ): + if attn_impl_key == "ulysses" and DATA_CONFIG.num_heads_q % WORLD_SIZE != 0: + perf_dict_total = { + "flops": [-1 * mask_nums] * output_n, + "mem": [-1 * mask_nums] * output_n, + } + break global already_known_oom_before_run already_known_oom_before_run = False @@ -887,9 +943,9 @@ def run_benchmark( fn = run_dist_attn( total_seqlen=seqlen, embed_dim=DATA_CONFIG.embed_dim, - q_heads=DATA_CONFIG.heads_q, - kv_heads=DATA_CONFIG.heads_kv, - hidden_size=DATA_CONFIG.hidden_size, + num_heads_q=DATA_CONFIG.num_heads_q, + num_heads_kv=DATA_CONFIG.num_heads_kv, + head_dim=DATA_CONFIG.head_dim, dtype=DATA_CONFIG.dtype, q_ranges=q_ranges, k_ranges=k_ranges, @@ -910,9 +966,9 @@ def run_benchmark( fn = run_magi_attn( total_seqlen=TOTAL_SEQLEN, embed_dim=DATA_CONFIG.embed_dim, - q_heads=DATA_CONFIG.heads_q, - kv_heads=DATA_CONFIG.heads_kv, - hidden_size=DATA_CONFIG.hidden_size, + num_heads_q=DATA_CONFIG.num_heads_q, + num_heads_kv=DATA_CONFIG.num_heads_kv, + head_dim=DATA_CONFIG.head_dim, dtype=DATA_CONFIG.dtype, q_ranges=q_ranges, k_ranges=k_ranges, @@ -941,8 +997,8 @@ def run_benchmark( k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=seqlen, - num_heads_q=DATA_CONFIG.heads_q, - head_dim=DATA_CONFIG.hidden_size, + num_heads_q=DATA_CONFIG.num_heads_q, + head_dim=DATA_CONFIG.head_dim, ) attn_flops = attn_flops_dict[wd] diff --git a/exps/dist_attn/run_benchmark.sh b/exps/dist_attn/run_benchmark.sh index 26062323f..85798d620 100644 --- a/exps/dist_attn/run_benchmark.sh +++ b/exps/dist_attn/run_benchmark.sh @@ -14,15 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +if [[ -f .env ]]; then + source .env # maybe put your own master node IP here +fi + +export NCCL_SOCKET_IFNAME=${SOCKET_IFNAME:-"bond0"} +export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=${NCCL_SOCKET_IFNAME} + export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1} export MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} # replace with your own master node IP export MASTER_PORT=${MASTER_PORT:-16988} export NNODES=${NNODES:-1} export NPROC_PER_NODE=${NPROC_PER_NODE:-8} -export NODE_RANK=${NODE_RANK:-0} +export RANK=${RANK:-0} export WORLD_SIZE=$((NPROC_PER_NODE * NNODES)) -echo "MASTER_ADDR=$MASTER_ADDR, MASTER_PORT=$MASTER_PORT, NNODES=$NNODES, NPROC_PER_NODE=$NPROC_PER_NODE, NODE_RANK=$NODE_RANK" +echo "MASTER_ADDR=$MASTER_ADDR, MASTER_PORT=$MASTER_PORT, NNODES=$NNODES, NPROC_PER_NODE=$NPROC_PER_NODE, RANK=$RANK" # to provide custom profile output name by `--profile` argument export PROFILE_NAME=${PROFILE_NAME:-"cp_benchmark"} @@ -71,7 +78,7 @@ export PYTHONPATH=../../ DISTRIBUTED_ARGS=" --nproc_per_node $NPROC_PER_NODE \ --nnodes $NNODES \ - --node_rank $NODE_RANK \ + --node_rank $RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT " @@ -82,7 +89,7 @@ TORCHRUN_CMD="torchrun $DISTRIBUTED_ARGS run_benchmark.py" # generate a timestamp for the nsys output file TIMESTAMP=$(date +"%Y%m%d_%H%M%S") -OUT_NAME="${PROFILE_NAME}_node${NODE_RANK}_${TIMESTAMP}" +OUT_NAME="${PROFILE_NAME}_node${RANK}_${TIMESTAMP}" echo "Config: ${CONFIG_PATH}, Profile output(if enabled): ${OUT_NAME}.nsys-rep" CMD=" nsys profile \ diff --git a/exps/grpcoll/grpcoll_utils.py b/exps/grpcoll/grpcoll_utils.py index f91e87233..4a5972381 100644 --- a/exps/grpcoll/grpcoll_utils.py +++ b/exps/grpcoll/grpcoll_utils.py @@ -85,13 +85,33 @@ def perm_idxs2unperm_idxs(perm_idxs: torch.Tensor) -> torch.Tensor: def get_random_split_size_list( total_seqlen: int, num_splits: int, + split_alignment: int = 1, ) -> list[int]: - cu_seqlens = ( - [0] - + sorted(random.sample(range(1, total_seqlen - 1), num_splits - 1)) - + [total_seqlen] - ) - seqlens = torch.tensor(cu_seqlens, dtype=torch.int).diff().tolist() + # 1. Validation + if total_seqlen % split_alignment != 0: + raise ValueError( + f"total_seqlen ({total_seqlen}) must be divisible by split_alignment ({split_alignment})" + ) + + # Calculate how many "aligned blocks" we have in total + total_blocks = total_seqlen // split_alignment + + if total_blocks < num_splits: + raise ValueError( + f"Not enough length to satisfy {num_splits} splits with alignment {split_alignment}" + ) + + # 2. Random Sampling in "Block Space" + # We need to pick (num_splits - 1) dividers from (total_blocks - 1) possible slots + # to ensure each split has at least 1 block (size >= split_alignment) + dividers = sorted(random.sample(range(1, total_blocks), num_splits - 1)) + + cu_blocks = [0] + dividers + [total_blocks] + + # 3. Calculate diffs and scale back up + block_seqlens = torch.tensor(cu_blocks, dtype=torch.int).diff() + seqlens = (block_seqlens * split_alignment).tolist() + return seqlens diff --git a/exps/grpcoll/run_grpcoll_test.sh b/exps/grpcoll/run_grpcoll_test.sh index 3b5991963..0cb507c97 100644 --- a/exps/grpcoll/run_grpcoll_test.sh +++ b/exps/grpcoll/run_grpcoll_test.sh @@ -24,16 +24,26 @@ TEST_MODE=${TEST_MODE:-"intra_node"} # intra_node | low_latency | internode mkdir -p ${LOG_ROOT} +# Set common env vars export PYTHONPATH=$PYTHONPATH:. - -# For debug +export OMP_NUM_THREADS=1 # export CUDA_LAUNCH_BLOCKING=1 -# NOTE: grpcoll test will set the env vars in the script -# export NVSHMEM_IB_ENABLE_IBGDA=1 -# export NVSHMEM_IBGDA_NIC_HANDLER=gpu -# export NVSHMEM_DISABLE_P2P=0 # set to 0 to enable NVLink in low-latency mode -# export NVSHMEM_SYMMETRIC_SIZE=2**30 # default: 1GB +# Set nccl env vars +# export NCCL_DEBUG=INFO +export NCCL_SOCKET_IFNAME=${SOCKET_IFNAME:-"bond0"} + +# Set nvshmem env vars +# export NVSHMEM_DEBUG=INFO +# export NVSHMEM_ENABLE_NIC_PE_MAPPING=1 +# export NVSHMEM_IBGDA_ENABLE_MULTI_PORT=1 +# export NVSHMEM_HCA_LIST=mlx5_10,mlx5_11,mlx5_12,mlx5_13,mlx5_14,mlx5_15,mlx5_16,mlx5_17 +# export NVSHMEM_IB_ADDR_FAMILY=AF_INET +# export NVSHMEM_IB_ADDR_RANGE=0.0.0.0/0 +# export NVSHMEM_IB_GID_INDEX=3 +# export NVSHMEM_IB_TRAFFIC_CLASS=128 +# export NVSHMEM_BOOTSTRAP_UID_SOCK_FAMILY=AF_INET +export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=${NCCL_SOCKET_IFNAME} # ----- test-intranode ----- # @@ -70,27 +80,19 @@ else echo "Launch with node rank: $1" fi -# init dist env vars -export OMP_NUM_THREADS=1 +# Init multi-node dist env vars export MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} # replace with your own master node IP export MASTER_PORT=23457 export NNODES=2 # in deepep internode kernels, it will check num_ranks > NUM_MAX_NVL_PEERS, which equals to 8 by default export NPROC_PER_NODE=8 export RANK=$1 -echo "MASTER_ADDR=$MASTER_ADDR, MASTER_PORT=$MASTER_PORT, NNODES=$NNODES, NPROC_PER_NODE=$NPROC_PER_NODE, RANK=$RANK" - -# set nccl env vars -# export NCCL_DEBUG=INFO -export NCCL_SOCKET_IFNAME=${SOCKET_IFNAME:-"bond0"} - if [[ $RANK -ge $NNODES ]]; then echo "Error: RANK=$RANK, but NNODES=$NNODES" exit 1 fi -# self-added env variable to control low-latency mode for test_internode.py -export GRPCOLL_TEST_INTERNODE_LL_COMPATIBILITY=0 +echo "Multi-Node Distributed settings: MASTER_ADDR=$MASTER_ADDR, MASTER_PORT=$MASTER_PORT, NNODES=$NNODES, NPROC_PER_NODE=$NPROC_PER_NODE, RANK=$RANK" CMD="torchrun \ --nproc_per_node=$NPROC_PER_NODE \ diff --git a/exps/grpcoll/test_internode_grpcoll.py b/exps/grpcoll/test_internode_grpcoll.py index 9f1598b4a..b06b0a081 100644 --- a/exps/grpcoll/test_internode_grpcoll.py +++ b/exps/grpcoll/test_internode_grpcoll.py @@ -83,6 +83,7 @@ def prepare_test_func_kwargs( hidden_size: int, num_heads: int, num_input_splits: int, + split_alignment: int, num_data_groups_gc: int, num_data_groups_gr: int, dtype: torch.dtype, @@ -131,7 +132,11 @@ def prepare_test_func_kwargs( print(f"[RANK {rank}]: {x.shape=} | {x=}\n" f"{lse_shape=} | {lse=}\n", flush=True) # Random score - input_split_size_list = get_random_split_size_list(num_tokens, num_input_splits) + input_split_size_list = get_random_split_size_list( + total_seqlen=num_tokens, + num_splits=num_input_splits, + split_alignment=split_alignment, + ) dst_indices_list = get_random_dst_indices_list( num_splits=num_input_splits, num_ranks=num_ranks, @@ -311,9 +316,13 @@ def prepare_test_func_kwargs( # NOTE: we can assume num_local_experts == 1 # thus sending one token to one rank is equivalent to sending to the only one "local expert" in that rank num_local_experts=1, - input_split_size_list=input_split_size_list, + input_split_size_list=[ + split // split_alignment for split in input_split_size_list + ], dst_indices_list=dst_indices_list, - output_split_size_list=output_split_size_list, + output_split_size_list=[ + split // split_alignment for split in output_split_size_list + ], src_index_list=src_index_list, use_topk=False, use_a2a_order_output=not random_permute_output, @@ -331,7 +340,9 @@ def prepare_test_func_kwargs( # use host meta perm_to_a2av_idx = get_a2av_perm_idxs_from_group_cast_meta( - output_split_sizes=output_split_size_list, + output_split_sizes=[ + split // split_alignment for split in output_split_size_list + ], src_index=src_index_list, num_ranks=num_ranks, ) @@ -339,10 +350,10 @@ def prepare_test_func_kwargs( # use device meta perm_to_a2av_idx_device = get_a2av_perm_idxs_from_group_cast_meta( - output_split_sizes=output_split_sizes, + output_split_sizes=output_split_sizes // split_alignment, src_index=src_index, num_ranks=num_ranks, - output_seqlen=recv_x_gc_buf.shape[0], + output_seqlen=recv_x_gc_buf.shape[0] // split_alignment, ) if pass_padded_out_buffer: unperm_from_a2av_idx_device = perm_idxs2unperm_idxs( @@ -369,7 +380,7 @@ def prepare_test_func_kwargs( ) if not random_permute_output: arange_idx = torch.arange( - sum(output_split_size_list), + sum(output_split_size_list) // split_alignment, dtype=torch.int64, device="cuda", ) @@ -403,7 +414,7 @@ def prepare_test_func_kwargs( ref_num_tokens_per_rdma_rank, ref_is_token_in_rank, ) = get_native_group_cast_meta( - input_split_sizes=input_split_size_list, + input_split_sizes=[split // split_alignment for split in input_split_size_list], dst_indices=dst_indices_list, group=group, num_nodes=num_nodes, @@ -415,7 +426,7 @@ def prepare_test_func_kwargs( ref_num_tokens_per_rdma_rank_device, ref_is_token_in_rank_device, ) = get_native_group_cast_meta( - input_split_sizes=input_split_sizes, + input_split_sizes=input_split_sizes // split_alignment, dst_indices=dst_indices, group=group, num_nodes=num_nodes, @@ -433,14 +444,14 @@ def prepare_test_func_kwargs( # use host meta layout_t2r_idx = transfer_splits_and_dst_idxs_to_t2r_idx( - input_split_sizes=input_split_size_list, + input_split_sizes=[split // split_alignment for split in input_split_size_list], dst_indices=dst_indices_list, num_ranks=num_ranks, ) # use device meta layout_t2r_idx_device = transfer_splits_and_dst_idxs_to_t2r_idx( - input_split_sizes=input_split_sizes, + input_split_sizes=input_split_sizes // split_alignment, dst_indices=dst_indices, num_ranks=num_ranks, ) @@ -558,6 +569,7 @@ def test_func( acc_reduce_out_buffer: bool, acc_reduce_constant: int, min_num_dst_ranks: int, + split_alignment: int, **kwargs, ) -> dict[str, Any]: # fetch kwargs @@ -632,6 +644,42 @@ def test_func( if pass_out_buffer: recv_x_gc_buf_list.append(recv_x_gc_buf.clone()) + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + if split_alignment > 1: + x = x.view(-1, split_alignment * x.shape[-1]) + x_list = [x_i.view(-1, split_alignment * x_i.shape[-1]) for x_i in x_list] + + recv_x_gc = recv_x_gc.view(-1, split_alignment * recv_x_gc.shape[-1]) + recv_x_gc_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) for buf in recv_x_gc_list + ] + + recv_x_gc_buf = ( + recv_x_gc_buf.view(-1, split_alignment * recv_x_gc_buf.shape[-1]) + if recv_x_gc_buf is not None + else None + ) + if recv_x_gc_buf_list is not None: + recv_x_gc_buf_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) + for buf in recv_x_gc_buf_list + ] + + lse = lse.view(-1, split_alignment * lse.shape[-1]) if lse is not None else None + + recv_lse_gc = ( + recv_lse_gc.view(-1, split_alignment * recv_lse_gc.shape[-1]) + if recv_lse_gc is not None + else None + ) + + recv_lse_gc_buf = ( + recv_lse_gc_buf.view(-1, split_alignment * recv_lse_gc_buf.shape[-1]) + if recv_lse_gc_buf is not None + else None + ) + common_group_cast_args: dict[str, Any] = { # w/o handle tensors "x": x if num_data_groups_gc == 1 else x_list, "recv_x": recv_x_gc_buf if num_data_groups_gc == 1 else recv_x_gc_buf_list, @@ -931,6 +979,38 @@ def test_func( reduced_x_gr_buf_list.append(reduced_x_gr_buf_2nd.clone()) num_data_groups_gr += 1 + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + if split_alignment > 1: + reduced_x_gr = reduced_x_gr.view(-1, split_alignment * reduced_x_gr.shape[-1]) + + reduced_x_gr_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) for buf in reduced_x_gr_list + ] + + reduced_x_gr_buf = ( + reduced_x_gr_buf.view(-1, split_alignment * reduced_x_gr_buf.shape[-1]) + if reduced_x_gr_buf is not None + else None + ) + if reduced_x_gr_buf_list is not None: + reduced_x_gr_buf_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) + for buf in reduced_x_gr_buf_list + ] + + reduced_lse_gr = ( + reduced_lse_gr.view(-1, split_alignment * reduced_lse_gr.shape[-1]) + if reduced_lse_gr is not None + else None + ) + + reduced_lse_gr_buf = ( + reduced_lse_gr_buf.view(-1, split_alignment * reduced_lse_gr_buf.shape[-1]) + if reduced_lse_gr_buf is not None + else None + ) + # permute x/lse to the rank order if random_permute_output: if use_a2av_perm_idxs == "inside": @@ -1142,8 +1222,10 @@ def test_func( # For later tuning # NOTE: since we've not passed `comm_dtype` to group_reduce in tuning, # we can just calculate all the bytes based on x's dtype - group_cast_rdma_send_bytes = num_rdma_token_sent * hidden_size * x.dtype.itemsize - group_cast_nvl_recv_bytes = recv_x.numel() * recv_x_gc.dtype.itemsize + group_cast_rdma_send_bytes = ( + num_rdma_token_sent * hidden_size * split_alignment * x.dtype.itemsize + ) + group_cast_nvl_recv_bytes = recv_x_gc.numel() * recv_x_gc.dtype.itemsize group_reduce_nvl_send_bytes = group_cast_nvl_recv_bytes group_reduce_rdma_recv_bytes = group_cast_rdma_send_bytes @@ -1168,23 +1250,28 @@ def tune_func( rdma_buffer_size: int, pass_out_buffer: bool, acc_reduce_out_buffer: bool, + split_alignment: int, ) -> None: - # fetch some constant test kwargs for later usage + # Fetch some constant test kwargs for later usage x = test_kwargs["x"] num_tokens_per_rank = test_kwargs["num_tokens_per_rank"] num_tokens_per_rdma_rank = test_kwargs["num_tokens_per_rdma_rank"] is_token_in_rank = test_kwargs["is_token_in_rank"] - # fetch some constant test out for later usage + # Fetch some constant test out for later usage handle = test_out["handle"] group_cast_nvl_recv_bytes = test_out["group_cast_nvl_recv_bytes"] group_cast_rdma_recv_bytes = test_out["group_cast_rdma_recv_bytes"] group_reduce_nvl_send_bytes = test_out["group_reduce_nvl_send_bytes"] group_reduce_rdma_recv_bytes = test_out["group_reduce_rdma_recv_bytes"] + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + x = x.view(-1, split_alignment * x.shape[-1]) + # -------------- tune group_cast -------------- # - # sync before tuning + # Sync before tuning torch.cuda.synchronize() dist.barrier() @@ -1249,6 +1336,7 @@ def tune_func( dist.all_gather(all_best_results_list, best_group_cast_results, group=group) best_group_cast_results = all_best_results_list[0].tolist() + # Apply group_cast to get handle before group_reduce group_cast_config = GrpCollConfig( num_sms=best_group_cast_results[0], nvl_chunk_size=best_group_cast_results[1], @@ -1273,11 +1361,10 @@ def tune_func( # -------------- tune group_reduce -------------- # - # sync before tuning + # Sync before tuning torch.cuda.synchronize() dist.barrier() - # Tune group_reduce performance best_time, best_results = 1e10, None reduced_x_buf = torch.zeros_like(x) if pass_out_buffer else None for nvl_chunk_size in range(1, 8, 1): @@ -1345,11 +1432,7 @@ def test_main( # Settings num_tokens, hidden_size = args.num_tokens, args.hidden_size num_channels = num_sms // 2 - # NOTE: different from intranode group reduce, - # if num_heads * size(float) % 16 == 0, i.e. num_heads % 4 == 0, - # internode group reduce `kNVLReceivers` will use TMA to copy lse - # otherwise using normal unrolled warp copy - num_heads = 16 + split_alignment = args.split_alignment # choose dtype from {torch.bfloat16, torch.float16, torch.float32, torch.float64} dtype = torch.float32 # TODO: make it parameterizable @@ -1358,7 +1441,13 @@ def test_main( # Remake the hidden size to control # the communication bytes per token the same as bf16/fp16 hidden_size = hidden_size * 2 // dtype.itemsize + # NOTE: different from intranode group reduce, + # if num_heads * size(float) % 16 == 0, i.e. num_heads % 4 == 0, + # internode group reduce `kNVLReceivers` will use TMA to copy lse + # otherwise using normal unrolled warp copy + num_heads = 16 assert hidden_size % num_heads == 0 + head_dim = hidden_size // num_heads # Re-Settings for group-collective # TODO: make these parameterizable @@ -1385,7 +1474,10 @@ def test_main( pass_out_buffer = True # for both group_cast and group_reduce pass_out_lse_buffer = True # for both group_cast and group_reduce - pass_padded_out_buffer = True # set to True to use oversized buffer for group_cast output and group_reduce input + pass_padded_out_buffer = False # set to True to use oversized buffer for group_cast output and group_reduce input + assert ( + split_alignment == 1 or not pass_padded_out_buffer + ), "pass_padded_out_buffer only supports split_alignment == 1 for simplicity" acc_reduce_out_buffer = True acc_reduce_constant = rank @@ -1410,19 +1502,15 @@ def test_main( # Config num_max_nvl_chunked_send_tokens = 8 - nvl_buffer_size = num_max_nvl_chunked_recv_tokens = ( - 720 if num_ranks in (144, 160) else 512 - ) // ( # NOTE: too large NVL buffer size for triple data groups - max(2, dtype.itemsize // 2) if max_num_data_groups == 3 else 1 - ) + nvl_buffer_size = num_max_nvl_chunked_recv_tokens = 512 num_max_rdma_chunked_send_tokens = 16 - rdma_buffer_size = num_max_rdma_chunked_recv_tokens = 128 + rdma_buffer_size = num_max_rdma_chunked_recv_tokens = 1024 config = GrpCollConfig( - num_sms=num_sms, # num_sms, default 20 - nvl_chunk_size=num_max_nvl_chunked_send_tokens, # num_max_nvl_chunked_send_tokens (nvl_chunk_size), default 6 - nvl_buffer_size=num_max_nvl_chunked_recv_tokens, # num_max_nvl_chunked_recv_tokens (nvl_buffer_size), default 256 - rdma_chunk_size=num_max_rdma_chunked_send_tokens, # num_max_rdma_chunked_send_tokens, default 6 - rdma_buffer_size=num_max_rdma_chunked_recv_tokens, # num_max_rdma_chunked_recv_tokens, default 256 + num_sms=num_sms, + nvl_chunk_size=num_max_nvl_chunked_send_tokens, + nvl_buffer_size=num_max_nvl_chunked_recv_tokens, + rdma_chunk_size=num_max_rdma_chunked_send_tokens, + rdma_buffer_size=num_max_rdma_chunked_recv_tokens, ) min_num_rdma_bytes, min_num_nvl_bytes = GrpCollConfig.get_min_num_bytes_internode( num_sms=num_sms, @@ -1453,7 +1541,8 @@ def test_main( f"| {min_num_rdma_bytes=} ({min_num_rdma_bytes / 1024**2:.2f} MB)" f"| {min_num_nvl_bytes=} ({min_num_nvl_bytes / 1024**2:.2f} MB)\n" f"{num_tokens=} | {hidden_size=} | {dtype=} | {comm_dtype=}\n" - f"{num_heads=} | {num_data_groups_gc=} | {num_data_groups_gr=} | {cast_lse=} | {reduce_op=}\n" + f"{num_input_splits=} | {split_alignment=} | {num_heads=} | {head_dim=}\n" + f"{num_data_groups_gc=} | {num_data_groups_gr=} | {cast_lse=} | {reduce_op=}\n" f"{nvl_buffer_size=} | {num_max_nvl_chunked_send_tokens=} | {num_max_nvl_chunked_recv_tokens=}\n" f"{rdma_buffer_size=} | {num_max_rdma_chunked_send_tokens=} | {num_max_rdma_chunked_recv_tokens=}\n" f"{distinct_token=} | {random_permute_output=} | {sim_gemm_weight=} | {min_num_dst_ranks=}\n" @@ -1475,6 +1564,7 @@ def test_main( hidden_size=hidden_size, num_heads=num_heads, num_input_splits=num_input_splits, + split_alignment=split_alignment, num_data_groups_gc=num_data_groups_gc, num_data_groups_gr=num_data_groups_gr, dtype=dtype, @@ -1514,6 +1604,7 @@ def test_main( acc_reduce_out_buffer=acc_reduce_out_buffer, acc_reduce_constant=acc_reduce_constant, min_num_dst_ranks=min_num_dst_ranks, + split_alignment=split_alignment, # kwargs **test_kwargs, ) @@ -1531,6 +1622,7 @@ def test_main( rdma_buffer_size=rdma_buffer_size, pass_out_buffer=pass_out_buffer, acc_reduce_out_buffer=acc_reduce_out_buffer, + split_alignment=split_alignment, ) @@ -1563,8 +1655,8 @@ def test_loop(args: argparse.Namespace): num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0 ) - num_nvl_bytes = int(2e9) # ~2GB - num_rdma_bytes = int(1e9) # ~1GB + num_nvl_bytes = int(5e9) # ~5GB, to meet most of the requirements + num_rdma_bytes = int(5e9) # ~5GB, to meet most of the requirements # print config if local_rank == 0: @@ -1640,26 +1732,24 @@ def test_loop(args: argparse.Namespace): help="Number of processes to spawn (default: 8)", ) parser.add_argument( - "--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)" - ) - parser.add_argument( - # TODO: find out the relationship between hidden size and bandwidth - "--hidden_size", + # NOTE: the internode kernel performance is highly dependent on the sequence length + "--num-tokens", type=int, - default=56 * 128, - help="Hidden dimension size (default: 56x128=7168)", + default=4096, + help="Number of tokens (default: 4096)", ) parser.add_argument( - "--num-topk-groups", + "--split-alignment", type=int, - default=None, - help="Number of top-k groups (default: `min(num_nodes, 4)`)", + default=1, + help="Split alignment (default: 1)", ) parser.add_argument( - "--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)" - ) - parser.add_argument( - "--num-experts", type=int, default=256, help="Number of experts (default: 256" + # NOTE: the internode kernel performance is highly dependent on the hidden size + "--hidden_size", + type=int, + default=64 * 128, + help="Hidden dimension size (default: 64x128=8192)", ) parser.add_argument( "--test-ll-compatibility", @@ -1668,6 +1758,10 @@ def test_loop(args: argparse.Namespace): ) args = parser.parse_args() + assert ( + args.hidden_size % args.split_alignment == 0 + ), f"hidden size {args.hidden_size} must be divisible by split alignment {args.split_alignment}" + args.test_ll_compatibility = False num_processes = args.num_processes diff --git a/exps/grpcoll/test_intranode_grpcoll.py b/exps/grpcoll/test_intranode_grpcoll.py index 84eb49a38..91ac164dc 100644 --- a/exps/grpcoll/test_intranode_grpcoll.py +++ b/exps/grpcoll/test_intranode_grpcoll.py @@ -81,6 +81,7 @@ def prepare_test_func_kwargs( hidden_size: int, num_heads: int, num_input_splits: int, + split_alignment: int, num_data_groups_gc: int, num_data_groups_gr: int, dtype: torch.dtype, @@ -129,7 +130,11 @@ def prepare_test_func_kwargs( print(f"[RANK {rank}]: {x.shape=} | {x=}\n" f"{lse_shape=} | {lse=}\n", flush=True) # Random score (transfered from group-cast meta args) - input_split_size_list = get_random_split_size_list(num_tokens, num_input_splits) + input_split_size_list = get_random_split_size_list( + total_seqlen=num_tokens, + num_splits=num_input_splits, + split_alignment=split_alignment, + ) dst_indices_list = get_random_dst_indices_list( num_splits=num_input_splits, num_ranks=num_ranks, @@ -309,9 +314,13 @@ def prepare_test_func_kwargs( # NOTE: we can assume num_local_experts == 1 # thus sending one token to one rank is equivalent to sending to the only one "local expert" in that rank num_local_experts=1, - input_split_size_list=input_split_size_list, + input_split_size_list=[ + split // split_alignment for split in input_split_size_list + ], dst_indices_list=dst_indices_list, - output_split_size_list=output_split_size_list, + output_split_size_list=[ + split // split_alignment for split in output_split_size_list + ], src_index_list=src_index_list, use_topk=False, use_a2a_order_output=not random_permute_output, @@ -330,7 +339,9 @@ def prepare_test_func_kwargs( # use host meta perm_to_a2av_idx = get_a2av_perm_idxs_from_group_cast_meta( - output_split_sizes=output_split_size_list, + output_split_sizes=[ + split // split_alignment for split in output_split_size_list + ], src_index=src_index_list, num_ranks=num_ranks, ) @@ -338,10 +349,10 @@ def prepare_test_func_kwargs( # use device meta perm_to_a2av_idx_device = get_a2av_perm_idxs_from_group_cast_meta( - output_split_sizes=output_split_sizes, + output_split_sizes=output_split_sizes // split_alignment, src_index=src_index, num_ranks=num_ranks, - output_seqlen=recv_x_gc_buf.shape[0], + output_seqlen=recv_x_gc_buf.shape[0] // split_alignment, ) if pass_padded_out_buffer: unperm_from_a2av_idx_device = perm_idxs2unperm_idxs( @@ -370,7 +381,7 @@ def prepare_test_func_kwargs( if not random_permute_output: arange_idx = torch.arange( - sum(output_split_size_list), + sum(output_split_size_list) // split_alignment, dtype=torch.int64, device="cuda", ) @@ -400,7 +411,7 @@ def prepare_test_func_kwargs( _, # ref_num_tokens_per_rdma_rank, ref_is_token_in_rank, ) = get_native_group_cast_meta( - input_split_sizes=input_split_size_list, + input_split_sizes=[split // split_alignment for split in input_split_size_list], dst_indices=dst_indices_list, group=group, num_nodes=1, @@ -412,7 +423,7 @@ def prepare_test_func_kwargs( _, # ref_num_tokens_per_rdma_rank_device, ref_is_token_in_rank_device, ) = get_native_group_cast_meta( - input_split_sizes=input_split_sizes, + input_split_sizes=input_split_sizes // split_alignment, dst_indices=dst_indices, group=group, num_nodes=1, @@ -426,14 +437,14 @@ def prepare_test_func_kwargs( # use host meta layout_t2r_idx = transfer_splits_and_dst_idxs_to_t2r_idx( - input_split_sizes=input_split_size_list, + input_split_sizes=[split // split_alignment for split in input_split_size_list], dst_indices=dst_indices_list, num_ranks=num_ranks, ) # use device meta layout_t2r_idx_device = transfer_splits_and_dst_idxs_to_t2r_idx( - input_split_sizes=input_split_sizes, + input_split_sizes=input_split_sizes // split_alignment, dst_indices=dst_indices, num_ranks=num_ranks, ) @@ -545,6 +556,7 @@ def test_func( acc_reduce_out_buffer: bool, acc_reduce_constant: int, min_num_dst_ranks: int, + split_alignment: int, **kwargs, ) -> dict[str, Any]: # fetch kwargs @@ -606,6 +618,42 @@ def test_func( if pass_out_buffer: recv_x_gc_buf_list.append(recv_x_gc_buf.clone()) + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + if split_alignment > 1: + x = x.view(-1, split_alignment * x.shape[-1]) + x_list = [x_i.view(-1, split_alignment * x_i.shape[-1]) for x_i in x_list] + + recv_x_gc = recv_x_gc.view(-1, split_alignment * recv_x_gc.shape[-1]) + recv_x_gc_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) for buf in recv_x_gc_list + ] + + recv_x_gc_buf = ( + recv_x_gc_buf.view(-1, split_alignment * recv_x_gc_buf.shape[-1]) + if recv_x_gc_buf is not None + else None + ) + if recv_x_gc_buf_list is not None: + recv_x_gc_buf_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) + for buf in recv_x_gc_buf_list + ] + + lse = lse.view(-1, split_alignment * lse.shape[-1]) if lse is not None else None + + recv_lse_gc = ( + recv_lse_gc.view(-1, split_alignment * recv_lse_gc.shape[-1]) + if recv_lse_gc is not None + else None + ) + + recv_lse_gc_buf = ( + recv_lse_gc_buf.view(-1, split_alignment * recv_lse_gc_buf.shape[-1]) + if recv_lse_gc_buf is not None + else None + ) + common_group_cast_args: dict[str, Any] = { # w/o handle tensors "x": x if num_data_groups_gc == 1 else x_list, "recv_x": recv_x_gc_buf if num_data_groups_gc == 1 else recv_x_gc_buf_list, @@ -786,31 +834,32 @@ def test_func( assert recv_lse.size(0) == recv_src_idx.size(0) num_heads = recv_lse.size(1) - if random_permute_output: - if use_a2av_perm_idxs == "no": - permed_recv_src_idx = unpermute_output( - output=recv_src_idx, - unperm_after_a2a_kwargs=range_gather_post_group_cast_kwargs, - ) - else: # "inside" or "outside" - # NOTE: we won't permute recv_src_idx inside for now - permed_recv_src_idx = recv_src_idx[unperm_from_a2av_idx] - else: - permed_recv_src_idx = recv_src_idx - - repeated_permed_recv_src_idx = ( - permed_recv_src_idx.repeat_interleave(repeats=num_heads, dim=0) - .reshape(-1, num_heads) - .to(recv_lse.dtype) - ) + if split_alignment == 1: + if random_permute_output: + if use_a2av_perm_idxs == "no": + permed_recv_src_idx = unpermute_output( + output=recv_src_idx, + unperm_after_a2a_kwargs=range_gather_post_group_cast_kwargs, + ) + else: # "inside" or "outside" + # NOTE: we won't permute recv_src_idx inside for now + permed_recv_src_idx = recv_src_idx[unperm_from_a2av_idx] + else: + permed_recv_src_idx = recv_src_idx - if pass_padded_out_buffer: - assert torch.equal( - recv_lse[:actual_gc_output_seqlen], - repeated_permed_recv_src_idx[:actual_gc_output_seqlen], + repeated_permed_recv_src_idx = ( + permed_recv_src_idx.repeat_interleave(repeats=num_heads, dim=0) + .reshape(-1, num_heads) + .to(recv_lse.dtype) ) - else: - assert torch.equal(recv_lse, repeated_permed_recv_src_idx) + + if pass_padded_out_buffer: + assert torch.equal( + recv_lse[:actual_gc_output_seqlen], + repeated_permed_recv_src_idx[:actual_gc_output_seqlen], + ) + else: + assert torch.equal(recv_lse, repeated_permed_recv_src_idx) if local_rank == 0: print( @@ -896,6 +945,38 @@ def test_func( reduced_x_gr_buf_list.append(reduced_x_gr_buf_2nd.clone()) num_data_groups_gr += 1 + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + if split_alignment > 1: + reduced_x_gr = reduced_x_gr.view(-1, split_alignment * reduced_x_gr.shape[-1]) + + reduced_x_gr_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) for buf in reduced_x_gr_list + ] + + reduced_x_gr_buf = ( + reduced_x_gr_buf.view(-1, split_alignment * reduced_x_gr_buf.shape[-1]) + if reduced_x_gr_buf is not None + else None + ) + if reduced_x_gr_buf_list is not None: + reduced_x_gr_buf_list = [ + buf.view(-1, split_alignment * buf.shape[-1]) + for buf in reduced_x_gr_buf_list + ] + + reduced_lse_gr = ( + reduced_lse_gr.view(-1, split_alignment * reduced_lse_gr.shape[-1]) + if reduced_lse_gr is not None + else None + ) + + reduced_lse_gr_buf = ( + reduced_lse_gr_buf.view(-1, split_alignment * reduced_lse_gr_buf.shape[-1]) + if reduced_lse_gr_buf is not None + else None + ) + # permute x/lse to the rank order if random_permute_output: if use_a2av_perm_idxs == "inside": @@ -1096,20 +1177,25 @@ def tune_func( nvl_buffer_size: int, pass_out_buffer: bool, acc_reduce_out_buffer: bool, + split_alignment: int, ) -> None: - # fetch some constant test kwargs for later usage + # Fetch some constant test kwargs for later usage x = test_kwargs["x"] num_tokens_per_rank = test_kwargs["num_tokens_per_rank"] is_token_in_rank = test_kwargs["is_token_in_rank"] - # fetch some constant test out for later usage + # Fetch some constant test out for later usage handle = test_out["handle"] group_cast_nvl_recv_bytes = test_out["group_cast_nvl_recv_bytes"] group_reduce_nvl_send_bytes = test_out["group_reduce_nvl_send_bytes"] + # View tensors with split alignment + # from (seqlen, hidden_dim) to (seqlen // split_alignment, split_alignment * hidden_dim) + x = x.view(-1, split_alignment * x.shape[-1]) + # -------------- tune group_cast -------------- # - # sync before tuning + # Sync before tuning torch.cuda.synchronize() dist.barrier() @@ -1122,21 +1208,20 @@ def tune_func( best_group_cast_results = None best_time, best_results = 1e10, None nvl_recv_bytes = group_cast_nvl_recv_bytes - for nvl_chunk_size in tuple(range(4, 33, 2)) + (0,): - if nvl_chunk_size > 0: - config = GrpCollConfig( - num_sms=num_sms, - nvl_chunk_size=nvl_chunk_size, - nvl_buffer_size=nvl_buffer_size, - ) - else: # Test default config as well - config = GrpCollConfig.get_default_group_cast_config(num_ranks) + for nvl_chunk_size in range(4, 33, 2): + config = GrpCollConfig( + num_sms=num_sms, + nvl_chunk_size=nvl_chunk_size, + nvl_buffer_size=nvl_buffer_size, + ) + tune_args = { "x": x, "handle": handle, "config": config, } # TODO: add other flags to tune args t = bench(lambda: buffer.group_cast(**tune_args))[0] + if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: @@ -1148,8 +1233,7 @@ def tune_func( if local_rank == 0: print( - f"[tuning] Best group_cast " - f'({"FP8" if isinstance(x, tuple) else "BF16"}): ' + f"[tuning] Best group_cast : " f"SMs {best_results[0]}, NVL chunk {best_results[1]}, " f"{nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), " f"t: {best_time * 1e6:.2f} us", @@ -1171,7 +1255,7 @@ def tune_func( dist.all_gather(all_best_results_list, best_group_cast_results, group=group) best_group_cast_results = all_best_results_list[0].tolist() - # apply group_cast to get handle before group_reduce + # Apply group_cast to get handle before group_reduce group_cast_config = GrpCollConfig( num_sms=best_group_cast_results[0], nvl_chunk_size=best_group_cast_results[1], @@ -1193,7 +1277,7 @@ def tune_func( # -------------- tune group_reduce -------------- # - # sync before tuning + # Sync before tuning torch.cuda.synchronize() dist.barrier() @@ -1205,15 +1289,13 @@ def tune_func( best_time, best_results = 1e10, None reduced_x_buf = torch.zeros_like(x) if pass_out_buffer else None - for nvl_chunk_size in tuple(range(1, 17, 1)) + (0,): - if nvl_chunk_size > 0: - config = GrpCollConfig( - num_sms=num_sms, - nvl_chunk_size=nvl_chunk_size, - nvl_buffer_size=nvl_buffer_size, - ) - else: # Test default config as well - config = GrpCollConfig.get_default_group_reduce_config(num_ranks) + for nvl_chunk_size in range(1, 17, 1): + config = GrpCollConfig( + num_sms=num_sms, + nvl_chunk_size=nvl_chunk_size, + nvl_buffer_size=nvl_buffer_size, + ) + tune_args = { "x": recv_x, "reduced_x": reduced_x_buf, @@ -1223,6 +1305,7 @@ def tune_func( "acc_reduce": acc_reduce_out_buffer, } t = bench(lambda: buffer.group_reduce(**tune_args))[0] + if local_rank == 0: print( f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' @@ -1253,9 +1336,9 @@ def test_main( group: dist.ProcessGroup, ): # Settings - num_tokens, hidden_size = args.num_tokens, args.hidden + num_tokens, hidden_size = args.num_tokens, args.hidden_size num_channels = num_sms // 2 - num_heads = 16 + split_alignment = args.split_alignment # choose dtype from {torch.bfloat16, torch.float16, torch.float32, torch.float64} dtype = torch.float32 # TODO: make it parameterizable @@ -1264,7 +1347,9 @@ def test_main( # Remake the hidden size to control # the communication bytes per token the same as bf16/fp16 hidden_size = hidden_size * 2 // dtype.itemsize + num_heads = 16 assert hidden_size % num_heads == 0 + head_dim = hidden_size // num_heads # Re-Settings for group-collective # TODO: make these parameterizable @@ -1292,6 +1377,9 @@ def test_main( pass_out_buffer = True # for both group_cast and group_reduce pass_out_lse_buffer = True # for both group_cast and group_reduce pass_padded_out_buffer = False # set to True to use oversized buffer for group_cast output and group_reduce input + assert ( + split_alignment == 1 or not pass_padded_out_buffer + ), "pass_padded_out_buffer only supports split_alignment == 1 for simplicity" acc_reduce_out_buffer = True acc_reduce_constant = rank @@ -1316,13 +1404,11 @@ def test_main( # Config num_max_nvl_chunked_send_tokens = 8 - nvl_buffer_size = num_max_nvl_chunked_recv_tokens = 256 + nvl_buffer_size = num_max_nvl_chunked_recv_tokens = 512 config = GrpCollConfig( - num_sms=num_sms, # num_sms, default 20 - nvl_chunk_size=num_max_nvl_chunked_send_tokens, # num_max_nvl_chunked_send_tokens (nvl_chunk_size), default 6 - nvl_buffer_size=num_max_nvl_chunked_recv_tokens, # num_max_nvl_chunked_recv_tokens (nvl_buffer_size), default 256 - # num_max_rdma_chunked_send_tokens, default 6 - # num_max_rdma_chunked_recv_tokens, default 256 + num_sms=num_sms, + nvl_chunk_size=num_max_nvl_chunked_send_tokens, + nvl_buffer_size=num_max_nvl_chunked_recv_tokens, ) min_num_nvl_bytes = GrpCollConfig.get_min_num_bytes_intranode( num_sms=num_sms, @@ -1345,7 +1431,8 @@ def test_main( ( f"[config] {num_sms=} | {num_channels=} | {min_num_nvl_bytes=} ({min_num_nvl_bytes / 1024**2:.2f} MB)\n" f"{num_tokens=} | {hidden_size=} | {dtype=} | {comm_dtype=}\n" - f"{num_heads=} | {num_data_groups_gc=} | {num_data_groups_gr=} | {cast_lse=} | {reduce_op=}\n" + f"{num_input_splits=} | {split_alignment=} | {num_heads=} | {head_dim=}\n" + f"{num_data_groups_gc=} | {num_data_groups_gr=} | {cast_lse=} | {reduce_op=}\n" f"{nvl_buffer_size=} | {num_max_nvl_chunked_send_tokens=} | {num_max_nvl_chunked_recv_tokens=}\n" f"{distinct_token=} | {random_permute_output=} | {sim_gemm_weight=} | {min_num_dst_ranks=}\n" f"{pass_out_buffer=} | {pass_out_lse_buffer=} | {pass_padded_out_buffer=}\n" @@ -1365,6 +1452,7 @@ def test_main( hidden_size=hidden_size, num_heads=num_heads, num_input_splits=num_input_splits, + split_alignment=split_alignment, num_data_groups_gc=num_data_groups_gc, num_data_groups_gr=num_data_groups_gr, dtype=dtype, @@ -1402,6 +1490,7 @@ def test_main( acc_reduce_out_buffer=acc_reduce_out_buffer, acc_reduce_constant=acc_reduce_constant, min_num_dst_ranks=min_num_dst_ranks, + split_alignment=split_alignment, # kwargs **test_kwargs, ) @@ -1418,6 +1507,7 @@ def test_main( nvl_buffer_size=nvl_buffer_size, pass_out_buffer=pass_out_buffer, acc_reduce_out_buffer=acc_reduce_out_buffer, + split_alignment=split_alignment, ) @@ -1439,10 +1529,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): flush=True, ) - num_nvl_bytes = int(3e9) # ~3GB num_sms = 24 num_qps_per_rank = ll_num_experts // num_ranks if test_ll_compatibility else 1 + num_nvl_bytes = int(5e9) # ~5GB, to meet most of the requirements + # print config if local_rank == 0: print( @@ -1513,7 +1604,17 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): help="Number of processes to spawn (default: 8)", ) parser.add_argument( - "--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)" + # NOTE: the intranode kernel performance is highly dependent on the sequence length + "--num-tokens", + type=int, + default=4096, + help="Number of tokens (default: 4096)", + ) + parser.add_argument( + "--split-alignment", + type=int, + default=1, + help="Split alignment (default: 1)", ) parser.add_argument( # NOTE: the intranode kernel performance is highly dependent on the hidden size @@ -1524,19 +1625,18 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): # hidden_size = 48 * 128 => bandwidth = 280~300 GB/s # hidden_size = 56 * 128 => bandwidth = 260~280 GB/s # hidden_size = 64 * 128 => bandwidth = 270~280 GB/s - "--hidden", + "--hidden-size", type=int, - default=56 * 128, - help="Hidden dimension size (default: 56x128=7168)", - ) - parser.add_argument( - "--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)" - ) - parser.add_argument( - "--num-experts", type=int, default=256, help="Number of experts (default: 256)" + default=64 * 128, + help="Hidden dimension size (default: 64x128=8192)", ) + args = parser.parse_args() + assert ( + args.hidden_size % args.split_alignment == 0 + ), f"hidden size {args.hidden_size} must be divisible by split alignment {args.split_alignment}" + num_processes = args.num_processes torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes diff --git a/magi_attention/__init__.py b/magi_attention/__init__.py index e967042be..56c716c57 100644 --- a/magi_attention/__init__.py +++ b/magi_attention/__init__.py @@ -164,11 +164,25 @@ def dist_attn_runtime_dict_size() -> int: return int(os.environ.get("MAGI_ATTENTION_DIST_ATTN_RUNTIME_DICT_SIZE", "1000")) +def dist_attn_backward_hide_tail_reduce() -> bool: + """ + Set the value of this env variable to control + whether save the last stage for backward to get better overlaping + + Default value is ``0`` + """ + return os.environ.get("MAGI_ATTENTION_BWD_HIDE_TAIL_REDUCE", "0") == "1" + + def is_auto_range_merge_enable() -> bool: """ - Toggle this env variable to ``1`` to enable automatic range for flex flash attention + Toggle this env variable to ``1`` to enable automatic range merging for flex-flash-attention, + to improve performance by reducing the number of attention ranges Default value is ``0`` + + NOTE: this feature is experimental and under active development for now, + thus please do NOT enable it unless you know exactly what you are doing """ return os.environ.get("MAGI_ATTENTION_AUTO_RANGE_MERGE", "0") == "1" diff --git a/magi_attention/api/magi_attn_interface.py b/magi_attention/api/magi_attn_interface.py index da581e3cd..0fbc9f818 100644 --- a/magi_attention/api/magi_attn_interface.py +++ b/magi_attention/api/magi_attn_interface.py @@ -17,6 +17,7 @@ import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh +from typing_extensions import deprecated import magi_attention from magi_attention.common import AttnForwardMeta, AttnRanges @@ -49,6 +50,9 @@ def magi_attn_varlen_key( cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, pad_size: int, chunk_size: int, cp_group_or_mesh: dist.ProcessGroup | DeviceMesh, @@ -57,17 +61,23 @@ def magi_attn_varlen_key( dist_attn_config: DistAttnConfig = DistAttnConfig(), ) -> DistAttnRuntimeKey: """This is a flash-attn-varlen like interface, - to generate q_ranges, k_ranges and attn_mask_type - from cu_seqlens_q, cu_seqlens_k, causal and window_size, - calculate DistAttnRuntimeKey and generate the corr. inner DistAttnRuntimeMgr. + to generate ``q_ranges``, ``k_ranges`` and ``attn_mask_type`` + from ``cu_seqlens_q``, ``cu_seqlens_k``, ``causal`` and ``window_size``, + calculate ``dist_attn_runtime_key`` and generate the corr. inner ``dist_attn_runtime_mgr``. Args: - cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries. - cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys. + cu_seqlens_q (torch.Tensor): the cumulative sequence lengths for queries. + cu_seqlens_k (torch.Tensor): the cumulative sequence lengths for keys. - pad_size (int): the size to pad along seq_dim. The seq_len need to be divisable by ``chunk_size * cp_size``. - chunk_size (int): chunk size to chunk the input tensor x along the seqlen dim for dispatch - to control the granularity of computation load-balance. + num_heads_q (int): the number of heads for query. + num_heads_kv (int): the number of heads for key/value. + head_dim (int): the dimension of each attention head. + + pad_size (int): the size to pad the global input tensor along sequence dim, + due to the constraint that the sequence length need to be divisable by ``chunk_size * cp_size``. + chunk_size (int): the size to chunk the global input tensor along the seqlen dim + for later sharding and dispatching among the cp ranks + as a granularity factor of computational load-balance. cp_group_or_mesh (dist.ProcessGroup | DeviceMesh): process group or device mesh. **NOTE**: for process group, we only support nccl backend for now, @@ -99,7 +109,7 @@ def magi_attn_varlen_key( ... ) >>> from magi_attention.common.enum import AttnOverlapMode >>> - >>> # Generate a DistAttnRuntimeKey to dispatch for flash-attn-varlen style mask + >>> # Step1. generate a dist_attn_runtime_key to store and indicate the inner meta info >>> dist_attn_runtime_key = magi_attn_varlen_key( ... cu_seqlen_q=torch.tensor( ... [0, 2048, 4096], dtype=torch.int32 @@ -107,6 +117,9 @@ def magi_attn_varlen_key( ... cu_seqlen_k=torch.tensor( ... [0, 2048, 4096], dtype=torch.int32 ... ), + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), @@ -125,22 +138,24 @@ def magi_attn_varlen_key( ... ), ... ) >>> - >>> # Dispatch several tensors with the same key + >>> # Step2. dispatch the global tensors to local tensors >>> local_x, local_label, local_rope = [ ... dispatch(tensor, dist_attn_runtime_key) ... for tensor in [total_x, total_label, total_rope] ... ] >>> - >>> # Apply QKV projection + >>> # Step3. apply QKV projection on local tensors >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention - >>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) + >>> # Step4. calculate distributed attention to get the local attention output tensor + >>> local_out, meta = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step5. undispatch local attention output to the global one if needed >>> total_out = undispatch(local_out, dist_attn_runtime_key) """ - # infer q_ranges, k_ranges and others from cu_seqlens_q, cu_seqlens_k and causal + + # Infer q_ranges, k_ranges and others + # from cu_seqlens_q, cu_seqlens_k and causal ( q_ranges, k_ranges, @@ -154,29 +169,34 @@ def magi_attn_varlen_key( window_size=window_size, ) - # call magi_attn_flex_key - # NOTE: for flash-attn-varlen, we assume - # is_same_source, is_q_permutable and is_k_permutable are all True. + # Call the API for flex key return magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=cp_group_or_mesh, dist_attn_config=dist_attn_config, - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, ) +@deprecated( + "This API is deprecated and will be removed in future versions. " + "Please use two steps calling of `magi_attn_varlen_key` + `dispatch` instead." +) def magi_attn_varlen_dispatch( x: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, pad_size: int, chunk_size: int, cp_group_or_mesh: dist.ProcessGroup | DeviceMesh, @@ -184,17 +204,21 @@ def magi_attn_varlen_dispatch( window_size: tuple[int, int] = (-1, -1), dist_attn_config: DistAttnConfig = DistAttnConfig(), ): - """This is a flash-attn-varlen like interface, to - generate q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k, causal and window_size, - further calculate DistAttnRuntimeKey, generate the corr. inner DistAttnRuntimeMgr, - finally pad and dispatch the input tensor to local tensor. + """This is a flash-attn-varlen like interface, + to generate ``q_ranges``, ``k_ranges`` and ``attn_mask_type`` + from ``cu_seqlens_q``, ``cu_seqlens_k``, ``causal`` and ``window_size``, + calculate ``dist_attn_runtime_key`` and generate the corr. inner ``dist_attn_runtime_mgr``. Args: - x (torch.Tensor): input tensor + x (torch.Tensor): the global input tensor. cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries. cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys. + num_heads_q (int): the number of heads for query. + num_heads_kv (int): the number of heads for key/value. + head_dim (int): the dimension of each attention head. + pad_size (int): the size to pad along seq_dim. The seq_len need to be divisable by ``chunk_size * cp_size``. chunk_size (int): chunk size to chunk the input tensor x along the seqlen dim for dispatch to control the granularity of computation load-balance. @@ -231,7 +255,8 @@ def magi_attn_varlen_dispatch( ... ) >>> from magi_attention.common.enum import AttnOverlapMode >>> - >>> # Generate a DistAttnRuntimeKey and dispatch the input for flash-attn-varlen style mask + >>> # Step1. dispatch the global input tensor to local tensor + >>> # with a dist_attn_runtime_key generated to store and indicate the inner meta info >>> local_x, dist_attn_runtime_key = magi_attn_varlen_dispatch( ... x=torch.randn( ... 4096, # seqlen @@ -246,6 +271,9 @@ def magi_attn_varlen_dispatch( ... cu_seqlen_k=torch.tensor( ... [0, 2048, 4096], dtype=torch.int32 ... ), + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), @@ -264,18 +292,22 @@ def magi_attn_varlen_dispatch( ... ), ... ) >>> - >>> # Apply QKV projection + >>> # Step2. apply QKV projection on local tensors >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention + >>> # Step3. calculate distributed attention to get the local attention output tensor >>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step4. undispatch local attention output to the global one if needed >>> total_out = undispatch(local_out, dist_attn_runtime_key) """ + key = magi_attn_varlen_key( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=cp_group_or_mesh, @@ -295,6 +327,9 @@ def magi_attn_flex_key( attn_mask_type: GeneralAttnMaskType, total_seqlen_q: int, total_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, pad_size: int, chunk_size: int, cp_group_or_mesh: dist.ProcessGroup | DeviceMesh, @@ -302,48 +337,49 @@ def magi_attn_flex_key( is_same_source: bool = True, is_q_permutable: bool = True, is_k_permutable: bool = True, - num_heads_q: int = 1, - num_heads_kv: int = 1, ) -> DistAttnRuntimeKey: """This is the most flexible interface, - directly passing in q_ranges, k_ranges and attn_mask_type to - calculate DistAttnRuntimeKey and generate the corr. inner DistAttnRuntimeMgr. + directly passing in ``q_ranges``, ``k_ranges`` and ``attn_mask_type`` to + generate ``dist_attn_runtime_key`` which stores and indicates the inner meta data + as a required argument for following APIs including ``dispatch``, ``undispatch``, ``calc_attn``, etc. Args: - q_ranges (AttnRanges): the global query ranges - k_ranges (AttnRanges): the global key ranges + q_ranges (AttnRanges): the global query ranges. + k_ranges (AttnRanges): the global key ranges. attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]): - the global attn mask type (list) - represented by str or enum ``AttnMaskType`` or their mixed combination + the global attn mask type (list), represented by + str or enum ``AttnMaskType`` or their mixed combination. - total_seqlen_q (int): the total seqlen of query - total_seqlen_k (int): the total seqlen of key + total_seqlen_q (int): the total seqlen of query. + total_seqlen_k (int): the total seqlen of key. - pad_size (int): the size to pad along seq_dim. The seq_len need to be divisable by ``chunk_size * cp_size``. - chunk_size (int): chunk size to chunk the input tensor x along the seqlen dim for dispatch - to control the granularity of computation load-balance. + num_heads_q (int): the number of heads for query. + num_heads_kv (int): the number of heads for key/value. + head_dim (int): the dimension of each attention head. + + pad_size (int): the size to pad the global input tensor along sequence dim, + due to the constraint that the sequence length need to be divisable by ``chunk_size * cp_size``. + chunk_size (int): the size to chunk the global input tensor along the seqlen dim + for later sharding and dispatching among the cp ranks + as a granularity factor of computational load-balance. cp_group_or_mesh (dist.ProcessGroup | DeviceMesh): process group or device mesh. **NOTE**: for process group, we only support nccl backend for now, and for device mesh, we only support 1D or 2D mesh for now. - dist_attn_config (DistAttnConfig): dist attn config - - is_same_source (bool): is query tensor and key tensor share the same source - is_q_permutable (bool): is query tensor permutable - is_k_permutable (bool): is key tensor permutable + dist_attn_config (DistAttnConfig): dist attn config. - num_heads_q (int): the number of heads for query. Defaults to ``1``. - num_heads_kv (int): the number of heads for key/value. Defaults to ``1``. - **NOTE**: the information of number of heads for query/key/value - is an optional setting for us to try to deliver better performance - by distinguishing cases among ``MHA``, ``GQA``, ``MQA``, etc, - which is under active development and will be released in the future. + is_same_source (bool): is query tensor and key tensor share the same source. + Default to ``True``. + is_q_permutable (bool): is query tensor permutable. + Default to ``True``. + is_k_permutable (bool): is key tensor permutable. + Default to ``True``. Returns: - DistAttnRuntimeKey: the key points to the inner DistAttnRuntimeMgr. + DistAttnRuntimeKey: the key stores and indicates the inner meta data. - Note: + NOTE: 1. For decoder-only transformers (e.g., GPT), it applies 'self-attn' as follows: a. ``is_same_source`` is True. @@ -374,19 +410,19 @@ def magi_attn_flex_key( >>> from magi_attention.common.enum import AttnOverlapMode >>> from magi_attention.common import AttnRanges >>> - >>> # Generate a DistAttnRuntimeKey to dispatch for arbitrary mask represented by attn-slices + >>> # Step1. generate a dist_attn_runtime_key to store and indicate the inner meta info >>> dist_attn_runtime_key = magi_attn_flex_key( ... q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]), ... k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]), ... attn_mask_type="full", ... total_seqlen_q=4096, ... total_seqlen_k=4096, + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), - ... is_same_source=True, - ... is_q_permutable=True, - ... is_k_permutable=True, ... dist_attn_config=DistAttnConfig( ... dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()), ... overlap_config=OverlapConfig( @@ -400,29 +436,30 @@ def magi_attn_flex_key( ... ), ... ) >>> - >>> # Dispatch several tensors with the same key + >>> # Step2. dispatch the global tensors to local tensors >>> local_x, local_label, local_rope = [ ... dispatch(tensor, dist_attn_runtime_key) ... for tensor in [total_x, total_label, total_rope] ... ] >>> - >>> # Apply QKV projection + >>> # Step3. apply QKV projection on local tensors >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention - >>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) + >>> # Step4. calculate distributed attention to get the local attention output tensor + >>> local_out, meta = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step5. undispatch local attention output to the global one if needed >>> total_out = undispatch(local_out, dist_attn_runtime_key) """ - # validate total_seqlen + + # Validate total_seqlen assert q_ranges.end <= total_seqlen_q and k_ranges.end <= total_seqlen_k, ( f"The maximum endpoint in ranges must be less than total_seqlen, " f"but got {q_ranges.end=} when {total_seqlen_q=}, " f"and got {k_ranges.end=} when {total_seqlen_k=}" ) - # validate and transform attn_mask_type + # Validate and transform attn_mask_type attn_mask_type = wrap_to_list(attn_mask_type, broadcast_to_length=q_ranges.size) assert is_list_type_all(attn_mask_type, (str, AttnMaskType)), ( f"attn_mask_type must be a list of str or AttnMaskType or their mixed combination, " @@ -436,7 +473,7 @@ def magi_attn_flex_key( f"but got {len(attn_mask_type)=} and {len(q_ranges)=}" ) - # validate process group (or device mesh) + # Validate process group (or device mesh) if isinstance(cp_group_or_mesh, dist.ProcessGroup): assert not magi_attention.comm.is_hierarchical_comm_enable(), ( "A 2D cp_mesh must be provided when hierarchical comm is enabled, " @@ -458,9 +495,9 @@ def magi_attn_flex_key( f"but got {type(cp_group_or_mesh)=}" ) - # apply padding + # Apply padding if pad_size > 0: - # apply padding to the mask with the empty slice + # Apply padding to the mask with the empty slice q_ranges, k_ranges, attn_mask_type = apply_padding( q_ranges=q_ranges, k_ranges=k_ranges, @@ -468,27 +505,28 @@ def magi_attn_flex_key( total_seqlen=total_seqlen_q, pad_size=pad_size, ) - # also apply padding to total_seqlen + # Apply padding to total_seqlen total_seqlen_q += pad_size total_seqlen_k += pad_size - # init dist attn runtime key + # Init dist attn runtime key key = init_dist_attn_runtime_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group=cp_group, cp_mesh=cp_mesh, dist_attn_config=dist_attn_config, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) - # init dist attn runtime mgr and map it to the key + # Init dist attn runtime mgr and map it to the key if key not in dist_attn_runtime_dict.keys(): dist_attn_runtime_dict[key] = init_dist_attn_runtime_mgr( q_ranges=q_ranges, @@ -496,20 +534,28 @@ def magi_attn_flex_key( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=cp_group, + cp_mesh=cp_mesh, + dist_attn_config=dist_attn_config, + # TODO: think through other scnearios besides self-attn and cross-attn + # and find a better way to represent these flags + # now keep it here temporarily for consistency is_same_source=is_same_source, is_q_permutable=is_q_permutable, is_k_permutable=is_k_permutable, - dist_attn_config=dist_attn_config, - cp_mesh=cp_mesh, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) return key +@deprecated( + "This API is deprecated and will be removed in future versions. " + "Please use two steps calling of `magi_attn_flex_key` + `dispatch` instead." +) def magi_attn_flex_dispatch( x: torch.Tensor, q_ranges: AttnRanges, @@ -517,6 +563,9 @@ def magi_attn_flex_dispatch( attn_mask_type: GeneralAttnMaskType, total_seqlen_q: int, total_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, pad_size: int, chunk_size: int, cp_group_or_mesh: dist.ProcessGroup | DeviceMesh, @@ -524,46 +573,46 @@ def magi_attn_flex_dispatch( is_same_source: bool = True, is_q_permutable: bool = True, is_k_permutable: bool = True, - num_heads_q: int = 1, - num_heads_kv: int = 1, ) -> tuple[torch.Tensor, DistAttnRuntimeKey]: """This is the most flexible interface, - directly passing in q_ranges, k_ranges and attn_mask_type to - calculate DistAttnRuntimeKey, generate the corr. inner DistAttnRuntimeMgr, - finally pad and dispatch the input tensor to local tensor. + directly passing in ``q_ranges``, ``k_ranges`` and ``attn_mask_type`` to + generate ``dist_attn_runtime_key`` which stores and indicates the inner meta data + and then dispatch the global input tensor to local tensor. Args: - x (torch.Tensor): input tensor + x (torch.Tensor): the global input tensor. - q_ranges (AttnRanges): the global query ranges - k_ranges (AttnRanges): the global key ranges + q_ranges (AttnRanges): the global query ranges. + k_ranges (AttnRanges): the global key ranges. attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]): - the global attn mask type (list) - represented by str or enum ``AttnMaskType`` or their mixed combination + the global attn mask type (list), represented by + str or enum ``AttnMaskType`` or their mixed combination. - total_seqlen_q (int): the total seqlen of query - total_seqlen_k (int): the total seqlen of key + total_seqlen_q (int): the total seqlen of query. + total_seqlen_k (int): the total seqlen of key. - pad_size (int): the size to pad along seq_dim. The seq_len need to be divisable by ``chunk_size * cp_size``. - chunk_size (int): chunk size to chunk the input tensor x along the seqlen dim for dispatch - to control the granularity of computation load-balance. + num_heads_q (int): the number of heads for query. + num_heads_kv (int): the number of heads for key/value. + head_dim (int): the dimension of each attention head. + + pad_size (int): the size to pad the global input tensor along sequence dim, + due to the constraint that the sequence length need to be divisable by ``chunk_size * cp_size``. + chunk_size (int): the size to chunk the global input tensor along the seqlen dim + for later sharding and dispatching among the cp ranks + as a granularity factor of computational load-balance. cp_group_or_mesh (dist.ProcessGroup | DeviceMesh): process group or device mesh. **NOTE**: for process group, we only support nccl backend for now, and for device mesh, we only support 1D or 2D mesh for now. - dist_attn_config (DistAttnConfig): dist attn config - - is_same_source (bool): is query tensor and key tensor share the same source - is_q_permutable (bool): is query tensor permutable - is_k_permutable (bool): is key tensor permutable + dist_attn_config (DistAttnConfig): dist attn config. - num_heads_q (int): the number of heads for query. Defaults to ``1``. - num_heads_kv (int): the number of heads for key/value. Defaults to ``1``. - **NOTE**: the information of number of heads for query/key/value - is an optional setting for us to try to deliver better performance - by distinguishing cases among ``MHA``, ``GQA``, ``MQA``, etc, - which is under active development and will be released in the future. + is_same_source (bool): is query tensor and key tensor share the same source. + Default to ``True``. + is_q_permutable (bool): is query tensor permutable. + Default to ``True``. + is_k_permutable (bool): is key tensor permutable. + Default to ``True``. Returns: tuple[torch.Tensor, DistAttnRuntimeKey]: @@ -601,7 +650,8 @@ def magi_attn_flex_dispatch( >>> from magi_attention.common.enum import AttnOverlapMode >>> from magi_attention.common import AttnRanges >>> - >>> # Generate a DistAttnRuntimeKey and dispatch the input for arbitrary mask represented by attn-slices + >>> # Step1. dispatch the global input tensor to local tensor + >>> # with a dist_attn_runtime_key generated to store and indicate the inner meta info >>> local_x, dist_attn_runtime_key = magi_attn_flex_dispatch( ... x = torch.randn( ... 4096, # seqlen @@ -615,7 +665,10 @@ def magi_attn_flex_dispatch( ... attn_mask_type="full", ... total_seqlen_q=4096, ... total_seqlen_k=4096, - ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chun_size + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, + ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), ... dist_attn_config=DistAttnConfig( @@ -629,38 +682,34 @@ def magi_attn_flex_dispatch( ... alg=UniformOverlapAlg(), ... ), ... ), - ... is_same_source=True, - ... is_q_permutable=True, - ... is_k_permutable=True, ... ) >>> - >>> # Apply QKV projection + >>> # Step2. apply QKV projection >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention + >>> # Step3. calculate distributed attention to get the local attention output tensor >>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step4. undispatch local attention output to the global one if needed >>> total_out = undispatch(local_out, dist_attn_runtime_key) """ + key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=cp_group_or_mesh, dist_attn_config=dist_attn_config, - # TODO: think through other scnearios besides self-attn and cross-attn - # and find a better way to represent these flags - # now keep it here temporarily for consistency is_same_source=is_same_source, is_q_permutable=is_q_permutable, is_k_permutable=is_k_permutable, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) local_x = dispatch(x, key) @@ -673,29 +722,32 @@ def dispatch( pad_value: float = 0.0, ) -> torch.Tensor: """ - Pad and dispatch the global input tensor to local tensor on each rank along the seqlen dim. + Pad and dispatch the global input tensor to local input tensor + for each cp rank along the seqlen dim. Args: - x (torch.Tensor): global input tensor. + x (torch.Tensor): the global input tensor. key (DistAttnRuntimeKey): the key that holds some inner meta data, - as one argument for many other magi_attention APIs, - which users don’t have to bother with. - pad_value (float): the specific value to pad to input tensor. Defaults to 0. + as a required argument for many APIs of ``magi_attention``, + which users don't have to bother with. + pad_value (float): the specific value to pad to input tensor. + Defaults to ``0``. Returns: - torch.Tensor: the padded and dispatched local tensor. + torch.Tensor: the padded local input tensor. Raises: - ValueError: If the provided ``key`` does not exist in ``dist_attn_runtime_dict``. + ValueError: If the provided ``key`` does not exist in cached ``dist_attn_runtime_dict``. """ + mgr = dist_attn_runtime_dict.get(key) if mgr is None: raise ValueError("The dist attn runtime key does not exist!") - pad_size = key.pad_size - padded_x = pad_at_dim(x, 0, pad_size, value=pad_value) + padded_x = pad_at_dim(x=x, dim=0, pad_size=key.pad_size, value=pad_value) + padded_local_x = mgr.dispatch_qo(padded_x) - return mgr.dispatch_qo(padded_x) + return padded_local_x def undispatch( @@ -703,29 +755,30 @@ def undispatch( key: DistAttnRuntimeKey, ) -> torch.Tensor: """ - Undispatch and unpad the local tensor to global tensor along the seqlen dim. + Undispatch and unpad the local output tensor to global output tensor + for each cp rank along the seqlen dim. Args: - x (torch.Tensor): local tensor + x (torch.Tensor): the local output tensor. key (DistAttnRuntimeKey): the key that holds some inner meta data, - as one argument for many other magi_attention APIs, - which users don’t have to bother with. + as a required argument for many APIs of ``magi_attention``, + which users don't have to bother with. Returns: - torch.Tensor: the undispatched and unpadded tensor. + torch.Tensor: the unpadded global output tensor. Raises: - ValueError: If the provided ``key`` does not exist in ``dist_attn_runtime_dict``. + ValueError: If the provided ``key`` does not exist in cached ``dist_attn_runtime_dict``. """ + mgr = dist_attn_runtime_dict.get(key) if mgr is None: raise ValueError("The dist attn runtime key does not exist!") - total_x = mgr.undispatch_qo(x) - pad_size = key.pad_size - unpad_total_x = unpad_at_dim(total_x, 0, pad_size) + global_x = mgr.undispatch_qo(x) + unpadded_global_x = unpad_at_dim(x=global_x, dim=0, pad_size=key.pad_size) - return unpad_total_x + return unpadded_global_x def calc_attn( @@ -738,26 +791,28 @@ def calc_attn( softcap: float = 0.0, ) -> tuple[torch.Tensor, AttnForwardMeta]: """ - Apply attention computation. + Calculate distributed attention with local q, k, v tensors. Args: - q (torch.Tensor): local query tensor. - k (torch.Tensor): local key tensor. - v (torch.Tensor): local value tensor. - key (DistAttnRuntimeKey): the object that holds some inner meta data - as one argument for many other magi_attention APIs, - which users don’t have to bother with. - - sink (torch.Tensor, optional): global sink tensor (replicated among cp ranks). + q (torch.Tensor): the local query tensor. + k (torch.Tensor): the local key tensor. + v (torch.Tensor): the local value tensor. + key (DistAttnRuntimeKey): the key that holds some inner meta data, + as a required argument for many APIs of ``magi_attention``, + which users don't have to bother with. + + sink (torch.Tensor, optional): the global sink tensor (replicated among cp ranks). Defaults to ``None`` to not apply attention sink. softmax_scale (float, optional): softmax scale. - Defaults to ``None`` to use: ``1/sqrt(head_dim)``. - softcap (float, optional): softcap. Defaults to ``0.0``. + Defaults to ``None`` to use the value: ``1/sqrt(head_dim)``. + softcap (float, optional): softcap. + Defaults to ``0.0``. Returns: - out (torch.Tensor): local output tensor. - meta (AttnForwardMeta): attention forward meta. + tuple[torch.Tensor, AttnForwardMeta]: + - out (torch.Tensor): local output tensor. + - meta (AttnForwardMeta): attention forward meta. Shapes: - q: [num_tokens_q_local, num_heads_q, head_dim] @@ -768,8 +823,9 @@ def calc_attn( - lse: [num_tokens_q_local, num_heads_q] Raises: - ValueError: If the provided ``key`` does not exist in ``dist_attn_runtime_dict``. + ValueError: If the provided ``key`` does not exist in cached ``dist_attn_runtime_dict``. """ + mgr = dist_attn_runtime_dict.get(key) if mgr is None: raise ValueError("The dist attn runtime key does not exist!") @@ -786,19 +842,21 @@ def calc_attn( def get_position_ids(key: DistAttnRuntimeKey) -> torch.Tensor: """ - Get the position ids of local tensor to global tensor after dispatching. + Get the global positional ids of the local tensor, + as it is sliced from the global tensor after dispatching. Args: key (DistAttnRuntimeKey): the key that holds some inner meta data, - as one argument for many other magi_attention APIs, - which users don’t have to bother with. + as a required argument for many APIs of ``magi_attention``, + which users don't have to bother with. Returns: - torch.Tensor: postion ids of local tensor w.r.t. global tensor. + torch.Tensor: the global positional ids. Raises: - ValueError: If the provided ``key`` does not exist in ``dist_attn_runtime_dict``. + ValueError: If the provided ``key`` does not exist in cached ``dist_attn_runtime_dict``. """ + mgr = dist_attn_runtime_dict.get(key) if mgr is None: raise ValueError("The dist attn runtime key does not exist!") @@ -807,15 +865,15 @@ def get_position_ids(key: DistAttnRuntimeKey) -> torch.Tensor: def get_most_recent_key() -> DistAttnRuntimeKey: - """Get the most recent inserted key. + """Get the most recent inserted dist_attn_runtime_key. - This is useful when you can not access the key through the arguments, + NOTE: This is useful when you can not access the key through the arguments, and meanwhile you only need the most recent inserted key. However, we strongly recommend you to access the key passed through the arguments, in case of unexpected inconsistency. Returns: - DistAttnRuntimeKey: the most recent inserted key. + DistAttnRuntimeKey: the most recent inserted dist_attn_runtime_key. """ key = dist_attn_runtime_dict.get_most_recent_key() @@ -846,24 +904,25 @@ def make_varlen_key_for_new_mask_after_dispatch( and optimized in communication. Args: - cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries. - cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys. + cu_seqlens_q (torch.Tensor): the cumulative sequence lengths for queries. + cu_seqlens_k (torch.Tensor): the cumulative sequence lengths for keys. + + key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch. - key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch - causal (bool, optional): whether the varlen attention mask is causal. Defaults to ``False``. + causal (bool, optional): whether the varlen attention mask is causal. + Defaults to ``False``. window_size (tuple[int, int], optional): window_size of sliding window mask which represents ``[window_size_left, window_size_right]``. The parameter is effective only when ``causal`` is ``False``; when ``causal`` is ``True``, it is required to be ``(-1, -1)``. Defaults to be ``(-1, -1)``. - dist_attn_config (DistAttnConfig, optional): the optional new dist attn config, - + dist_attn_config (DistAttnConfig, optional): the optional new dist attn config. NOTE: if not provided, we will use the same config as the ``key_for_dispatch``, and if provided, the dispatch config of the new dist attn config won't be applied to the new mask Returns: DistAttnRuntimeKey: the new dist attn runtime key - for new mask with the same dispatch solution as the ``key_for_dispatch`` + for new mask with the same dispatch solution as the ``key_for_dispatch``. Example: >>> import torch @@ -880,7 +939,7 @@ def make_varlen_key_for_new_mask_after_dispatch( ... ) >>> from magi_attention.common.enum import AttnOverlapMode >>> - >>> # Generate a DistAttnRuntimeKey to dispatch for flash-attn-varlen style mask + >>> # Step1. generate a dist_attn_runtime_key to dispatch for flash-attn-varlen style mask >>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider >>> # computation load-balance, communication optimization and computation-communication overlap >>> # according to the causal mask pattern @@ -891,6 +950,9 @@ def make_varlen_key_for_new_mask_after_dispatch( ... cu_seqlen_k=torch.tensor( ... [0, 4096], dtype=torch.int32 ... ), + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), @@ -909,13 +971,13 @@ def make_varlen_key_for_new_mask_after_dispatch( ... ), ... ) >>> - >>> # Dispatch several tensors with the same key_for_dispatch + >>> # Step2. dispatch the global tensors to local tensors with the same key_for_dispatch >>> local_x, local_label, local_rope = [ ... dispatch(tensor, key_for_dispatch) ... for tensor in [total_x, total_label, total_rope] ... ] >>> - >>> # Make a new dist attn runtime key from key_for_dispatch + >>> # Step3. make a new dist_attn_runtime_key from key_for_dispatch >>> # for a new mask, such as a sliding window causal mask below, >>> # with the same dispatch solution as the causal mask used for dispatch, >>> # i.e. this new key share the same dispatch meta as key_for_dispatch @@ -929,21 +991,25 @@ def make_varlen_key_for_new_mask_after_dispatch( ... key_for_dispatch=key_for_dispatch, ... ) >>> - >>> # Apply QKV projection + >>> # Step4. apply QKV projection on local tensors >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention for the mask used to dispatch with key_for_dispatch + >>> # Step5. calculate distributed attention + >>> # for the causal mask used to dispatch with key_for_dispatch >>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch) >>> - >>> # Calculate local attention for the new swa mask with the new key - >>> # w/o undispatching back and dispatching again to avoid OOM + >>> # Step6. calculate distributed attention + >>> # for the new swa mask with the new key + >>> # w/o undispatching back and re-dispatching again to avoid OOM >>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step7. undispatch local attention output to the global one if needed >>> total_out1 = undispatch(local_out1, key_for_dispatch) >>> total_out2 = undispatch(local_out2, new_key_for_swa_mask) """ - # infer q_ranges, k_ranges and others from cu_seqlens_q, cu_seqlens_k and causal + + # Infer q_ranges, k_ranges and others + # from cu_seqlens_q, cu_seqlens_k and causal ( q_ranges, k_ranges, @@ -957,6 +1023,7 @@ def make_varlen_key_for_new_mask_after_dispatch( window_size=window_size, ) + # Call the API for flex key return make_flex_key_for_new_mask_after_dispatch( q_ranges=q_ranges, k_ranges=k_ranges, @@ -987,16 +1054,15 @@ def make_flex_key_for_new_mask_after_dispatch( to optimize the computation and communication for each distinct mask with the same dispatch solution Args: - q_ranges (AttnRanges): the global query ranges - k_ranges (AttnRanges): the global key ranges + q_ranges (AttnRanges): the global query ranges. + k_ranges (AttnRanges): the global key ranges. attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]): - the global attn mask type (list) - represented by str or enum ``AttnMaskType`` or their mixed combination + the global attn mask type (list), represented by + str or enum ``AttnMaskType`` or their mixed combination. - key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch - - dist_attn_config (DistAttnConfig, optional): the optional new dist attn config, + key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch. + dist_attn_config (DistAttnConfig, optional): the optional new dist attn config. NOTE: if not provided, we will use the same config as the ``key_for_dispatch``, and if provided, the dispatch config of the new dist attn config won't be applied to the new mask @@ -1020,7 +1086,7 @@ def make_flex_key_for_new_mask_after_dispatch( >>> from magi_attention.common.enum import AttnOverlapMode >>> from magi_attention.common import AttnRanges >>> - >>> # Generate a DistAttnRuntimeKey to dispatch for arbitrary mask represented by attn-slices + >>> # Step1. generate a dist_attn_runtime_key to dispatch for arbitrary mask represented by attn slices >>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider >>> # computation load-balance, communication optimization and computation-communication overlap >>> # according to the causal mask pattern @@ -1030,12 +1096,12 @@ def make_flex_key_for_new_mask_after_dispatch( ... attn_mask_type="causal", ... total_seqlen_q=4096, ... total_seqlen_k=4096, + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size ... chunk_size=512, ... cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"), - ... is_same_source=True, - ... is_q_permutable=True, - ... is_k_permutable=True, ... dist_attn_config=DistAttnConfig( ... dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()), ... overlap_config=OverlapConfig( @@ -1049,13 +1115,13 @@ def make_flex_key_for_new_mask_after_dispatch( ... ), ... ) >>> - >>> # Dispatch several tensors with the same key_for_dispatch + >>> # Step2. dispatch the global tensors to local tensors with the same key_for_dispatch >>> local_x, local_label, local_rope = [ ... dispatch(tensor, key_for_dispatch) ... for tensor in [total_x, total_label, total_rope] ... ] >>> - >>> # Make a new dist attn runtime key from key_for_dispatch + >>> # Step3. make a new dist_attn_runtime_key from key_for_dispatch >>> # for a new mask, such as a sliding window causal mask below, >>> # with the same dispatch solution as the causal mask used for dispatch, >>> # i.e. this new key share the same dispatch meta as key_for_dispatch @@ -1068,21 +1134,24 @@ def make_flex_key_for_new_mask_after_dispatch( ... key_for_dispatch=key_for_dispatch, ... ) >>> - >>> # Apply QKV projection + >>> # Step4. apply QKV projection on local tensors >>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x) >>> - >>> # Calculate local attention for the mask used to dispatch with key_for_dispatch + >>> # Step5. calculate distributed attention + >>> # for the causal mask used to dispatch with key_for_dispatch >>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch) >>> - >>> # Calculate local attention for the new swa mask with the new key - >>> # w/o undispatching back and dispatching again to avoid OOM + >>> # Step6. calculate distributed attention + >>> # for the new swa mask with the new key + >>> # w/o undispatching back and re-dispatching again to avoid OOM >>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask) >>> - >>> # Gather local attention outputs to total output if needed + >>> # Step7. undispatch local attention output to the global one if needed >>> total_out1 = undispatch(local_out1, key_for_dispatch) >>> total_out2 = undispatch(local_out2, new_key_for_swa_mask) """ - # validate and transform attn_mask_type + + # Validate and transform attn_mask_type attn_mask_type = wrap_to_list(attn_mask_type, broadcast_to_length=q_ranges.size) assert is_list_type_all(attn_mask_type, (str, AttnMaskType)), ( f"attn_mask_type must be a list of str or AttnMaskType or their mixed combination, " @@ -1096,7 +1165,7 @@ def make_flex_key_for_new_mask_after_dispatch( f"but got {len(attn_mask_type)=} and {len(q_ranges)=}" ) - # extract the common attributes from the key for dispatch + # Extract the common attributes from the key for dispatch total_seqlen_q = key_for_dispatch.total_seqlen_q # already padded total_seqlen_k = key_for_dispatch.total_seqlen_k # already padded pad_size = key_for_dispatch.pad_size @@ -1110,22 +1179,25 @@ def make_flex_key_for_new_mask_after_dispatch( else key_for_dispatch.dist_attn_config.overlap_config, ) - # extract the common attributes from the mgr for dispatch + # Extract the common attributes from the mgr for dispatch mgr = dist_attn_runtime_dict.get(key_for_dispatch) if mgr is None: raise ValueError("The dist attn runtime key for dispatch does not exist!") + ref_dispatch_meta_q = mgr.dispatch_meta_q ref_dispatch_meta_k = mgr.dispatch_meta_k + is_same_source = mgr.is_same_source is_q_permutable = mgr.is_q_permutable is_k_permutable = mgr.is_k_permutable num_heads_q = mgr.num_heads_q num_heads_kv = mgr.num_heads_kv + head_dim = mgr.head_dim - # apply padding + # Apply padding if pad_size > 0: - # apply padding to the new mask with the empty slice + # Apply padding to the new mask with the empty slice q_ranges, k_ranges, attn_mask_type = apply_padding( q_ranges=q_ranges, k_ranges=k_ranges, @@ -1134,23 +1206,24 @@ def make_flex_key_for_new_mask_after_dispatch( pad_size=pad_size, ) - # init new dist attn runtime key + # Init new dist attn runtime key new_key = init_dist_attn_runtime_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group=cp_group, cp_mesh=cp_mesh, dist_attn_config=new_dist_attn_config, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) - # init new dist attn runtime mgr and map it to the new key + # Init new dist attn runtime mgr and map it to the new key if new_key not in dist_attn_runtime_dict.keys(): dist_attn_runtime_dict[new_key] = init_dist_attn_runtime_mgr( q_ranges=q_ranges, @@ -1158,17 +1231,18 @@ def make_flex_key_for_new_mask_after_dispatch( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=cp_group, + cp_mesh=cp_mesh, + dist_attn_config=new_dist_attn_config, is_same_source=is_same_source, is_q_permutable=is_q_permutable, is_k_permutable=is_k_permutable, - dist_attn_config=new_dist_attn_config, - cp_mesh=cp_mesh, ref_dispatch_meta_q=ref_dispatch_meta_q, ref_dispatch_meta_k=ref_dispatch_meta_k, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) return new_key diff --git a/magi_attention/benchmarking/image_grid.py b/magi_attention/benchmarking/image_grid.py index 7c4431fd7..70e419fa7 100644 --- a/magi_attention/benchmarking/image_grid.py +++ b/magi_attention/benchmarking/image_grid.py @@ -331,7 +331,7 @@ def __call__(self, pic): """ return to_tensor(pic) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return f"{self.__class__.__name__}()" @@ -372,7 +372,7 @@ def __call__(self, pic): """ return to_pil_image(pic, self.mode) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover format_string = self.__class__.__name__ + "(" if self.mode is not None: format_string += f"mode={self.mode}" diff --git a/magi_attention/comm/primitive/grpcoll/_a2av_grpcoll_impl.py b/magi_attention/comm/primitive/grpcoll/_a2av_grpcoll_impl.py index ba8935607..095f5616c 100644 --- a/magi_attention/comm/primitive/grpcoll/_a2av_grpcoll_impl.py +++ b/magi_attention/comm/primitive/grpcoll/_a2av_grpcoll_impl.py @@ -27,7 +27,7 @@ # ------------------ a2av group cast ------------------ # -# host meta interface +# Host meta interface @overload def a2av_group_cast_impl( input: torch.Tensor, @@ -46,7 +46,7 @@ def a2av_group_cast_impl( ... -# device meta interface +# Device meta interface @overload def a2av_group_cast_impl( input: torch.Tensor, @@ -182,7 +182,7 @@ def a2av_group_cast_impl( # ------------------ a2av group reduce ------------------ # -# host meta interface +# Host meta interface @overload def a2av_group_reduce_impl( input: torch.Tensor, @@ -203,7 +203,7 @@ def a2av_group_reduce_impl( ... -# device meta interface +# Device meta interface @overload def a2av_group_reduce_impl( input: torch.Tensor, diff --git a/magi_attention/comm/primitive/grpcoll/_buffer.py b/magi_attention/comm/primitive/grpcoll/_buffer.py index 3ca77afe9..4f566eb75 100644 --- a/magi_attention/comm/primitive/grpcoll/_buffer.py +++ b/magi_attention/comm/primitive/grpcoll/_buffer.py @@ -390,6 +390,7 @@ def group_cast( lse: torch.Tensor | None = None, recv_lse: torch.Tensor | None = None, max_num_rdma_recv_tokens: int = -1, + split_alignment: int = 1, ) -> tuple[ list[torch.Tensor], torch.Tensor | None, GrpCollIntraHandle, EventOverlap ]: @@ -430,6 +431,11 @@ def group_cast( if set to a non-negative value, we will use it to allocate some related internode handle tensors to avoid its GPU-CPU sync. + split_alignment: the split alignment to review x/recv_x/lse/recv_lse + from ``(seqlen, hidden_size)`` to ``(seqlen // split_alignment, hidden_size * split_alignment)``, + to raise up the hidden size for better performance. + Defaults to ``1``. TODO: support dynamic split_alignment varying from different dtypes. + NOTE: To fully avoid GPU-CPU sync, you can just given the ``handle`` to enable "cache mode", otherwise you have to at least provide the output tensor buffer, @@ -446,29 +452,23 @@ def group_cast( handle: the returned communication handle. event: the event after executing the kernel (valid only if `async_op` is set). """ - is_out_buf_given = recv_x is not None + # Check x = wrap_to_list(x) - num_groups = len(x) - if is_out_buf_given: - assert recv_x is not None # mypy + num_groups, hidden_shape, dtype = len(x), x[0].shape[1:], x[0].dtype + if recv_x is not None: recv_x = wrap_to_list(recv_x) assert len(recv_x) == len(x), ( "The number of groups of input and output buffer should be the same, " f"but got {len(x)=}, {len(recv_x)=}." ) - - hidden_shape = x[0].shape[1:] - hidden_size = math.prod(hidden_shape) - if is_out_buf_given: - assert recv_x is not None # mypy for i in range(num_groups): assert recv_x[i].shape[1:] == hidden_shape, ( "The hidden shape (except dim0) of input and output buffer should be the same, " f"but got {x[i].shape=}, {recv_x[i].shape=}." ) - # Default config + # Set grpcoll config config = ( GrpCollConfig.get_default_group_cast_config(self.group_size) if config is None @@ -476,21 +476,50 @@ def group_cast( ) # View input/output to 2D shape + # HACK: If non-trivial split alignment is given, + # we will re-view the input/output from (seqlen, hidden_size) to (seqlen // align, hidden_size * align) + # to raise up the hidden size for better performance + # and of course, it requires the arguments to be aligned and re-calculated accordingly + # which we've already checked and done in the higher-level programs. + hidden_size = math.prod(hidden_shape) + assert ( + hidden_size * split_alignment + ) % GrpCollBuffer.get_hidden_size_alignment(dtype) == 0, ( + "The hidden size multiplied by split alignment should be aligned to " + f"{GrpCollBuffer.get_hidden_size_alignment(dtype)} for dtype {dtype}, " + f"but got {hidden_size=}, {split_alignment=}." + ) for i in range(num_groups): - x[i] = x[i].view(-1, hidden_size) - if is_out_buf_given: - assert recv_x is not None # mypy + x[i] = x[i].view(-1, hidden_size * split_alignment) + if recv_x is not None: for i in range(num_groups): - recv_x[i] = recv_x[i].view(-1, hidden_size) + recv_x[i] = recv_x[i].view(-1, hidden_size * split_alignment) - # Internode - if self.runtime.get_num_rdma_ranks() > 1: - return self._internode_group_cast( + # Prepare lse and recv_lse + # HACK: same as above, we will re-view the lse/recv_lse + # from (seqlen, num_heads) to (seqlen // align, num_heads * align) + # if non-trivial split alignment is given + if cast_lse: + assert lse is not None, "lse should not be None when `cast_lse` is set" + num_heads = lse.shape[1] + lse = lse.view(-1, num_heads * split_alignment) + if recv_lse is not None: + recv_lse = recv_lse.view(-1, num_heads * split_alignment) + else: # no need to cast lse, even passed in + lse, recv_lse = None, None + + # Dispatch to intranode/internode group-cast + if self.runtime.get_num_rdma_ranks() > 1: # Internode + ( + recv_x, + recv_lse, + handle, + event, + ) = self._internode_group_cast( x=x, recv_x=recv_x, config=config, handle=handle, - hidden_shape=hidden_shape, num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, @@ -499,30 +528,40 @@ def group_cast( kernel_barrier=kernel_barrier, async_op=async_op, allocate_on_comm_stream=allocate_on_comm_stream, - cast_lse=cast_lse, lse=lse, recv_lse=recv_lse, max_num_rdma_recv_tokens=max_num_rdma_recv_tokens, ) + else: # Intranode + ( + recv_x, + recv_lse, + handle, + event, + ) = self._intranode_group_cast( + x=x, + recv_x=recv_x, + config=config, + handle=handle, + num_tokens_per_rank=num_tokens_per_rank, + is_token_in_rank=is_token_in_rank, + post_perm_idx=post_perm_idx, + previous_event=previous_event, + kernel_barrier=kernel_barrier, + async_op=async_op, + allocate_on_comm_stream=allocate_on_comm_stream, + lse=lse, + recv_lse=recv_lse, + ) - # Intranode - return self._intranode_group_cast( - x=x, - recv_x=recv_x, - config=config, - handle=handle, - hidden_shape=hidden_shape, - num_tokens_per_rank=num_tokens_per_rank, - is_token_in_rank=is_token_in_rank, - post_perm_idx=post_perm_idx, - previous_event=previous_event, - kernel_barrier=kernel_barrier, - async_op=async_op, - allocate_on_comm_stream=allocate_on_comm_stream, - cast_lse=cast_lse, - lse=lse, - recv_lse=recv_lse, - ) + # View output back to original hidden shape + # as well as recv_lse if given + for i in range(num_groups): + recv_x[i] = recv_x[i].view(-1, *hidden_shape) + if recv_lse is not None: + recv_lse = recv_lse.view(-1, num_heads) + + return recv_x, recv_lse, handle, event def group_reduce( self, @@ -540,6 +579,7 @@ def group_reduce( comm_dtype: torch.dtype | None = None, lse: torch.Tensor | None = None, reduced_lse: torch.Tensor | None = None, + split_alignment: int = 1, ) -> tuple[list[torch.Tensor], torch.Tensor | None, EventOverlap]: """ Group reduce tokens (addition **without** weights) from different ranks, both intranode and internode @@ -571,6 +611,11 @@ def group_reduce( with shape `[num_recv_tokens, num_heads]`, to be received and reduced along with `reduced_x`, when `reduce_op` is "lse". + split_alignment: the split alignment to review x/reduced_x/lse/reduced_lse + from ``(seqlen, hidden_size)`` to ``(seqlen // split_alignment, hidden_size * split_alignment)``, + to raise up the hidden size for better performance. + Defaults to ``1``. TODO: support dynamic split_alignment varying from different dtypes. + Returns: reduced_x: reduced tokens for each group, with the same type and number of groups as the input group `x`, @@ -580,29 +625,23 @@ def group_reduce( valid if `reduce_op` is "lse", otherwise `None`. event: the event after executing the kernel (valid only if `async_op` is set). """ - is_out_buf_given = reduced_x is not None + # Check x = wrap_to_list(x) - num_groups = len(x) - if is_out_buf_given: - assert reduced_x is not None # mypy + num_groups, hidden_shape, dtype = len(x), x[0].shape[1:], x[0].dtype + if reduced_x is not None: reduced_x = wrap_to_list(reduced_x) assert len(reduced_x) == len(x), ( "The number of groups of input and output buffer should be the same, " f"but got {len(x)=}, {len(reduced_x)=}." ) - - hidden_shape = x[0].shape[1:] - hidden_size = math.prod(hidden_shape) - if is_out_buf_given: - assert reduced_x is not None # mypy for i in range(num_groups): assert reduced_x[i].shape[1:] == hidden_shape, ( "The hidden shape (except dim0) of input and output buffer should be the same, " f"but got {x[i].shape=}, {reduced_x[i].shape=}." ) - # Default config + # Set grpcoll config config = ( GrpCollConfig.get_default_group_reduce_config(self.group_size) if config is None @@ -610,21 +649,71 @@ def group_reduce( ) # View input/output to 2D shape + # HACK: If non-trivial split alignment is given, + # we will re-view the input/output from (seqlen, hidden_size) to (seqlen // align, hidden_size * align) + # to raise up the hidden size for better performance + # and of course, it requires the arguments to be aligned and re-calculated accordingly + # which we've already checked and done in the higher-level programs. + hidden_size = math.prod(hidden_shape) + assert ( + hidden_size * split_alignment + ) % GrpCollBuffer.get_hidden_size_alignment(dtype) == 0, ( + "The hidden size multiplied by split alignment should be aligned to " + f"{GrpCollBuffer.get_hidden_size_alignment(dtype)} for dtype {dtype}, " + f"but got {hidden_size=}, {split_alignment=}." + ) for i in range(num_groups): - x[i] = x[i].view(-1, hidden_size) - if is_out_buf_given: - assert reduced_x is not None # mypy + x[i] = x[i].view(-1, hidden_size * split_alignment) + if reduced_x is not None: for i in range(num_groups): - reduced_x[i] = reduced_x[i].view(-1, hidden_size) + reduced_x[i] = reduced_x[i].view(-1, hidden_size * split_alignment) + + # Prepare lse and reduced_lse + # HACK: same as above, we will re-view the lse/reduced_lse + # from (seqlen, num_heads) to (seqlen // align, num_heads * align) + # if non-trivial split alignment is given + if reduce_op == "lse": + assert lse is not None, "lse should not be None when `reduce_op == lse`" + num_heads = lse.shape[1] + lse = lse.view(-1, num_heads * split_alignment) + if reduced_lse is not None: + reduced_lse = reduced_lse.view(-1, num_heads * split_alignment) + else: # no need to reduce lse, even passed in + lse = None + reduced_lse = None - # Internode - if self.runtime.get_num_rdma_ranks() > 1: - return self._internode_group_reduce( + # Dispatch to intranode/internode group-reduce + if self.runtime.get_num_rdma_ranks() > 1: # Internode + ( + reduced_x, + reduced_lse, + event, + ) = self._internode_group_reduce( + x=x, + reduced_x=reduced_x, + config=config, + handle=handle, + reduce_op=reduce_op, + acc_reduce=acc_reduce, + pre_perm_idx=pre_perm_idx, + previous_event=previous_event, + kernel_barrier=kernel_barrier, + async_op=async_op, + allocate_on_comm_stream=allocate_on_comm_stream, + comm_dtype=comm_dtype, + lse=lse, + reduced_lse=reduced_lse, + ) + else: # Intranode + ( + reduced_x, + reduced_lse, + event, + ) = self._intranode_group_reduce( x=x, reduced_x=reduced_x, config=config, handle=handle, - hidden_shape=hidden_shape, reduce_op=reduce_op, acc_reduce=acc_reduce, pre_perm_idx=pre_perm_idx, @@ -637,24 +726,14 @@ def group_reduce( reduced_lse=reduced_lse, ) - # Intranode - return self._intranode_group_reduce( - x=x, - reduced_x=reduced_x, - config=config, - handle=handle, - hidden_shape=hidden_shape, - reduce_op=reduce_op, - acc_reduce=acc_reduce, - pre_perm_idx=pre_perm_idx, - previous_event=previous_event, - kernel_barrier=kernel_barrier, - async_op=async_op, - allocate_on_comm_stream=allocate_on_comm_stream, - comm_dtype=comm_dtype, - lse=lse, - reduced_lse=reduced_lse, - ) + # View output back to original hidden shape + # as well as reduced lse if given + for i in range(num_groups): + reduced_x[i] = reduced_x[i].view(-1, *hidden_shape) + if reduced_lse is not None: + reduced_lse = reduced_lse.view(-1, num_heads) + + return reduced_x, reduced_lse, event def _intranode_group_cast( self, @@ -662,7 +741,6 @@ def _intranode_group_cast( recv_x: list[torch.Tensor] | None, config: GrpCollConfig, handle: GrpCollHandle | None, - hidden_shape: torch.Size, num_tokens_per_rank: torch.Tensor | None = None, is_token_in_rank: torch.Tensor | None = None, post_perm_idx: torch.Tensor | None = None, @@ -670,7 +748,6 @@ def _intranode_group_cast( kernel_barrier=None, async_op: bool = False, allocate_on_comm_stream: bool = False, - cast_lse: bool = False, lse: torch.Tensor | None = None, recv_lse: torch.Tensor | None = None, ) -> tuple[ @@ -693,13 +770,6 @@ def _intranode_group_cast( rank_prefix_matrix = None channel_prefix_matrix = None - # Prepare lse and recv_lse - if cast_lse: - assert lse is not None, "lse should not be None when `cast_lse` is set" - else: # no need to cast lse, even passed in - lse = None - recv_lse = None - # Unpack (x,recv_x) groups # HACK: this is a hacky way to pack several tensors together # w/o introducing extra H2D for the vector of ptrs @@ -777,10 +847,6 @@ def _intranode_group_cast( if num_groups > 2: recv_x.append(recv_x_3rd) - # View output to hidden shape - for i in range(num_groups): - recv_x[i] = recv_x[i].view(-1, *hidden_shape) - return ( recv_x, recv_lse, @@ -794,7 +860,6 @@ def _intranode_group_reduce( reduced_x: list[torch.Tensor] | None, config: GrpCollConfig, handle: GrpCollHandle | None, - hidden_shape: torch.Size, reduce_op: GroupReduceOp = "sum", acc_reduce: bool = False, pre_perm_idx: torch.Tensor | None = None, @@ -808,15 +873,9 @@ def _intranode_group_reduce( ) -> tuple[list[torch.Tensor], torch.Tensor | None, EventOverlap]: """Intranode group reduce implementation""" + # Check assert isinstance(handle, GrpCollIntraHandle) - # Prepare lse and reduced_lse - if reduce_op == "lse": - assert lse is not None, "lse should not be None when `reduce_op == lse`" - else: # no need to reduce lse, even passed in - lse = None - reduced_lse = None - # Unpack (x,reduced_x) groups num_groups = len(x) assert ( @@ -865,10 +924,6 @@ def _intranode_group_reduce( if num_groups > 1: reduced_x.append(reduced_x_2nd) - # View output to hidden shape - for i in range(num_groups): - reduced_x[i] = reduced_x[i].view(-1, *hidden_shape) - return (reduced_x, reduced_lse, EventOverlap(event)) def _internode_group_cast( @@ -877,7 +932,6 @@ def _internode_group_cast( recv_x: list[torch.Tensor] | None, config: GrpCollConfig, handle: GrpCollHandle | None, - hidden_shape: torch.Size, num_tokens_per_rank: torch.Tensor | None = None, num_tokens_per_rdma_rank: torch.Tensor | None = None, is_token_in_rank: torch.Tensor | None = None, @@ -886,7 +940,6 @@ def _internode_group_cast( kernel_barrier=None, async_op: bool = False, allocate_on_comm_stream: bool = False, - cast_lse: bool = False, lse: torch.Tensor | None = None, recv_lse: torch.Tensor | None = None, max_num_rdma_recv_tokens: int = -1, @@ -917,13 +970,6 @@ def _internode_group_cast( gbl_channel_prefix_matrix = None recv_gbl_rank_prefix_sum = None - # Prepare lse and recv_lse - if cast_lse: - assert lse is not None, "lse should not be None when `cast_lse` is set" - else: # no need to cast lse, even passed in - lse = None - recv_lse = None - # Unpack (x,recv_x) groups # HACK: this is a hacky way to pack several tensors together # w/o introducing extra H2D for the vector of ptrs @@ -1013,10 +1059,6 @@ def _internode_group_cast( if num_groups > 2: recv_x.append(recv_x_3rd) - # View output to hidden shape - for i in range(num_groups): - recv_x[i] = recv_x[i].view(-1, *hidden_shape) - return ( recv_x, recv_lse, @@ -1030,7 +1072,6 @@ def _internode_group_reduce( reduced_x: list[torch.Tensor] | None, config: GrpCollConfig, handle: GrpCollHandle, - hidden_shape: torch.Size, reduce_op: GroupReduceOp = "sum", acc_reduce: bool = False, pre_perm_idx: torch.Tensor | None = None, @@ -1044,15 +1085,9 @@ def _internode_group_reduce( ) -> tuple[list[torch.Tensor], torch.Tensor | None, EventOverlap]: """Internode group reduce implementation""" + # Check assert isinstance(handle, GrpCollInterHandle) - # Prepare lse and reduced_lse - if reduce_op == "lse": - assert lse is not None, "lse should not be None when `reduce_op == lse`" - else: # no need to reduce lse, even passed in - lse = None - reduced_lse = None - # Unpack (x,reduced_x) groups num_groups = len(x) assert ( @@ -1104,10 +1139,6 @@ def _internode_group_reduce( if num_groups > 1: reduced_x.append(reduced_x_2nd) - # View output to hidden shape - for i in range(num_groups): - reduced_x[i] = reduced_x[i].view(-1, *hidden_shape) - return (reduced_x, reduced_lse, EventOverlap(event)) # NOTE: remain original low-latency interface here for future potential usage, @@ -1382,3 +1413,13 @@ def get_hidden_size_alignment(dtype: torch.dtype) -> int: # thus for bf16/fp16, the hidden size alignment is: # WARP_SIZE * sizeof(int4) / sizeof(dtype) = 32 * 16 / 2 = 256 return 32 * 16 // dtype.itemsize + + @staticmethod + def get_max_supported_hidden_size(dtype: torch.dtype) -> int: + max_supported_hidden_size_bf16 = 8192 + return max_supported_hidden_size_bf16 * 2 // dtype.itemsize + + @staticmethod + def get_min_high_bw_hidden_size(dtype: torch.dtype) -> int: + min_high_bw_hidden_size_bf16 = 4096 + return min_high_bw_hidden_size_bf16 * 2 // dtype.itemsize diff --git a/magi_attention/comm/primitive/grpcoll/_config.py b/magi_attention/comm/primitive/grpcoll/_config.py index c00bfd0d3..60b3948ea 100644 --- a/magi_attention/comm/primitive/grpcoll/_config.py +++ b/magi_attention/comm/primitive/grpcoll/_config.py @@ -143,6 +143,10 @@ def get_min_num_bytes_intranode( num_groups: int = 1, alignment: int = 128, # according to `NUM_BUFFER_ALIGNMENT_BYTES` ) -> int: + """Calculate the minimum number of bytes required for intranode native grpcoll. + Returns: + min_num_nvl_bytes: minimum number of bytes required for NVL buffer. + """ if transfer_lse: assert ( num_heads is not None @@ -187,9 +191,15 @@ def get_min_num_bytes_internode( transfer_lse: bool = False, num_heads: int | None = None, num_groups: int = 1, - alignment: int = 128, # according to `NUM_BUFFER_ALIGNMENT_BYTES` rdma_decoulped: bool = True, + alignment: int = 128, # according to `NUM_BUFFER_ALIGNMENT_BYTES` ) -> tuple[int, int]: + """Calculate the minimum number of bytes required for internode native grpcoll. + Returns: + tuple[int, int]: + min_num_rdma_bytes: minimum number of bytes required for RDMA buffer. + min_num_nvl_bytes: minimum number of bytes required for NVL buffer. + """ if transfer_lse: assert ( num_heads is not None diff --git a/magi_attention/comm/primitive/grpcoll/_group_collective.py b/magi_attention/comm/primitive/grpcoll/_group_collective.py index 63cdd04d7..65168aab0 100644 --- a/magi_attention/comm/primitive/grpcoll/_group_collective.py +++ b/magi_attention/comm/primitive/grpcoll/_group_collective.py @@ -38,7 +38,7 @@ # ------------------ group cast ------------------ # -# host meta interface +# Host meta interface @overload def group_cast( input: torch.Tensor, @@ -57,7 +57,7 @@ def group_cast( ... -# device meta interface +# Device meta interface @overload def group_cast( input: torch.Tensor, @@ -208,7 +208,7 @@ def group_cast( # ------------------ group reduce ------------------ # -# host meta interface +# Host meta interface @overload def group_reduce( input: torch.Tensor, @@ -229,7 +229,7 @@ def group_reduce( ... -# device meta interface +# Device meta interface @overload def group_reduce( input: torch.Tensor, diff --git a/magi_attention/comm/primitive/grpcoll/_group_collective_hier.py b/magi_attention/comm/primitive/grpcoll/_group_collective_hier.py index c161872d0..230bf3ead 100644 --- a/magi_attention/comm/primitive/grpcoll/_group_collective_hier.py +++ b/magi_attention/comm/primitive/grpcoll/_group_collective_hier.py @@ -567,7 +567,7 @@ def init_hier_group_cast_meta_solver( ) -# host meta interface +# Host meta interface @overload def hier_group_cast_impl_with_a2av( input_tensor: torch.Tensor, @@ -586,7 +586,7 @@ def hier_group_cast_impl_with_a2av( ... -# device meta interface +# Device meta interface @overload def hier_group_cast_impl_with_a2av( input_tensor: torch.Tensor, @@ -1133,7 +1133,7 @@ def init_hier_group_reduce_meta_solver( ) -# host meta interface +# Host meta interface @overload def hier_group_reduce_impl_with_a2av( input_tensor: torch.Tensor, @@ -1154,7 +1154,7 @@ def hier_group_reduce_impl_with_a2av( ... -# device meta interface +# Device meta interface @overload def hier_group_reduce_impl_with_a2av( input_tensor: torch.Tensor, diff --git a/magi_attention/comm/primitive/grpcoll/_mgr.py b/magi_attention/comm/primitive/grpcoll/_mgr.py index af67c1074..6326d5a34 100644 --- a/magi_attention/comm/primitive/grpcoll/_mgr.py +++ b/magi_attention/comm/primitive/grpcoll/_mgr.py @@ -25,6 +25,7 @@ __all__ = ["grpcoll_buffer_mgr"] +# TODO: make (process_group, buffer_name) pair as the key for grpcoll buffer class GrpCollBufferMgr(metaclass=SingletonMeta): """ A singleton class to manage GrpCollBuffer instances by name. diff --git a/magi_attention/comm/primitive/grpcoll/_native_grpcoll_impl.py b/magi_attention/comm/primitive/grpcoll/_native_grpcoll_impl.py index 6226c7e99..e8a83bf3b 100644 --- a/magi_attention/comm/primitive/grpcoll/_native_grpcoll_impl.py +++ b/magi_attention/comm/primitive/grpcoll/_native_grpcoll_impl.py @@ -40,7 +40,7 @@ # ------------------ native group cast ------------------ # -# host meta interface +# Host meta interface @overload def native_group_cast_impl( input: torch.Tensor, @@ -59,7 +59,7 @@ def native_group_cast_impl( ... -# device meta interface +# Device meta interface @overload def native_group_cast_impl( input: torch.Tensor, @@ -94,34 +94,38 @@ def native_group_cast_impl( **kwargs, ) -> WorkWithPostProcessFn: """Native group-cast implementation""" + buffer_name = kwargs.pop("buffer_name", "default") kernel_barrier = kwargs.pop("kernel_barrier", None) - # get grpcoll config and buffer + # Get grpcoll config and buffer config: GrpCollConfig = grpcoll_buffer_mgr.get_config() buffer: GrpCollBuffer = grpcoll_buffer_mgr.get_buffer(buffer_name) assert config is not None and buffer is not None - # pack input and output + # Pack input and output input: list[torch.Tensor] = wrap_to_list(input) output: list[torch.Tensor] | None = ( wrap_to_list(output) if output is not None else output ) num_groups = len(input) - # get seqlen info + # Get seqlen info input_seqlen: int = input[0].size(0) output_seqlen: int | None = ( output[0].size(0) if output is not None else kwargs.pop("output_seqlen", None) ) internode_output_seqlen: int = kwargs.pop("internode_output_seqlen", -1) - # get meta dict and handle + # Get split alignment + split_alignment: int = kwargs.pop("split_alignment", 1) + + # Get meta dict and handle meta_dict: dict[str, Any] = kwargs.pop("native_group_cast_meta_dict", {}) handle_dict: dict[str, GrpCollHandle] = kwargs.pop("native_grpcoll_handle_dict", {}) handle: GrpCollHandle | None = handle_dict.get("group_cast", None) - # transfer to native group-cast meta args + # Transfer to native group-cast meta args if meta_dict: num_tokens_per_rank = meta_dict["num_tokens_per_rank"] num_tokens_per_rdma_rank = meta_dict["num_tokens_per_rdma_rank"] @@ -137,13 +141,13 @@ def native_group_cast_impl( dst_indices=dst_indices, group=group, input_seqlen=input_seqlen, - # HACK: leave a slot for t2r_idx + # HACK: Leave a slot for `t2r_idx` # since for now, we transfer the group_cast meta to it inside anyway # which is helpful in the token-level communication scenarios such as ep, nsa t2r_idx=kwargs.pop("t2r_idx", None), ) - # for group-cast, perm_to_a2av_idx is the post_perm_idx + # For group-cast, perm_to_a2av_idx is the post_perm_idx post_perm_idx = get_a2av_perm_idxs_from_group_cast_meta( output_split_sizes=output_split_sizes, src_index=src_index, @@ -151,7 +155,7 @@ def native_group_cast_impl( output_seqlen=output_seqlen, ) - # launch group cast kernel + # Launch group cast kernel ( recv_x, recv_lse, @@ -174,17 +178,18 @@ def native_group_cast_impl( lse=input_lse, recv_lse=output_lse, max_num_rdma_recv_tokens=internode_output_seqlen, + split_alignment=split_alignment, ) - # unpack recv_x + # Unpack recv_x if num_groups == 1: recv_x = recv_x[0] - # HACK: prepare handle for symmetric group-reduce or cached group-cast + # HACK: Prepare handle for symmetric group-reduce or cached group-cast handle_dict["group_cast"] = handle handle_dict["group_reduce"] = handle - # prepare work with post-process + # Prepare work with post-process work_with_post_process_fn = WorkWithPostProcessFn( work=GeneralWork(event), post_process_fn=( @@ -201,7 +206,7 @@ def native_group_cast_impl( # ------------------ native group reduce ------------------ # -# host meta interface +# Host meta interface @overload def native_group_reduce_impl( input: torch.Tensor, @@ -222,7 +227,7 @@ def native_group_reduce_impl( ... -# device meta interface +# Device meta interface @overload def native_group_reduce_impl( input: torch.Tensor, @@ -261,34 +266,38 @@ def native_group_reduce_impl( **kwargs, ) -> WorkWithPostProcessFn: """Native group-reduce implementation""" - # maybe lazy init buffer + + # Maybe lazy init buffer buffer_name = kwargs.pop("buffer_name", "default") kernel_barrier = kwargs.pop("kernel_barrier", None) - # get grpcoll config and buffer + # Get grpcoll config and buffer config: GrpCollConfig = grpcoll_buffer_mgr.get_config() buffer: GrpCollBuffer = grpcoll_buffer_mgr.get_buffer(buffer_name) assert config is not None and buffer is not None - # pack input and output + # Pack input and output input: list[torch.Tensor] = wrap_to_list(input) output: list[torch.Tensor] | None = ( wrap_to_list(output) if output is not None else output ) num_groups = len(input) - # get seqlen info + # Get seqlen info input_seqlen: int = input[0].size(0) output_seqlen: int | None = ( output[0].size(0) if output is not None else kwargs.pop("output_seqlen", None) ) - # get meta dict and handle + # Get split alignment + split_alignment: int = kwargs.pop("split_alignment", 1) + + # Get meta dict and handle meta_dict: dict[str, Any] = kwargs.pop("native_group_reduce_meta_dict", {}) handle_dict: dict[str, GrpCollHandle] = kwargs.pop("native_grpcoll_handle_dict", {}) handle: GrpCollHandle | None = handle_dict.get("group_reduce", None) if handle is None: - # FIXME: for now, we don't support individual group-reduce + # FIXME: For now, we don't support individual group-reduce # since the necessary handle is not known until the symmetric group-cast returns handle = get_group_reduce_handle_from_sym_group_cast( input=input[0], @@ -303,11 +312,11 @@ def native_group_reduce_impl( t2r_idx=kwargs.pop("t2r_idx", None), ) - # transfer to symmetric native group-cast meta args + # Transfer to symmetric native group-cast meta args if meta_dict: pre_perm_idx = meta_dict["pre_perm_idx"] else: - # for group-reduce, perm_to_a2av_idx is the pre_perm_idx + # For group-reduce, perm_to_a2av_idx is the pre_perm_idx # the same as the post_perm_idx for symmetric group-cast pre_perm_idx = get_a2av_perm_idxs_from_group_cast_meta( output_split_sizes=input_split_sizes, @@ -316,7 +325,7 @@ def native_group_reduce_impl( output_seqlen=input_seqlen, ) - # launch group reduce kernel + # Launch group reduce kernel ( reduced_x, reduced_lse, @@ -336,18 +345,19 @@ def native_group_reduce_impl( comm_dtype=comm_dtype, lse=input_lse, reduced_lse=output_lse, + split_alignment=split_alignment, ) - # unpack reduced_x + # Unpack reduced_x if num_groups == 1: reduced_x = reduced_x[0] - # HACK: prepare handle for symmetric group-cast or cached group-reduce + # HACK: Prepare handle for symmetric group-cast or cached group-reduce # REVIEW: should we empty the handle dict since the tensors in handle is inplace modified ? handle_dict["group_cast"] = handle handle_dict["group_reduce"] = handle - # prepare work with post-process + # Prepare work with post-process work_with_post_process_fn = WorkWithPostProcessFn( work=GeneralWork(event), post_process_fn=( diff --git a/magi_attention/common/jit/core.py b/magi_attention/common/jit/core.py index f14a0ecfb..446115561 100644 --- a/magi_attention/common/jit/core.py +++ b/magi_attention/common/jit/core.py @@ -92,7 +92,7 @@ class JitSpec: extra_objects: Optional[list[str]] = None needs_device_linking: bool = False - def __repr__(self): + def __repr__(self) -> str: # pragma: no cover def _fmt_list(values, indent: str = " ", max_items: int = 8) -> str: if values is None: return "None" diff --git a/magi_attention/common/mask.py b/magi_attention/common/mask.py index 49146635c..e2a73e4fc 100644 --- a/magi_attention/common/mask.py +++ b/magi_attention/common/mask.py @@ -441,7 +441,7 @@ def visualize(self, save_path: str | None = None) -> None: save_path=save_path, ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover repr_str = [""] repr_str.append( diff --git a/magi_attention/common/range.py b/magi_attention/common/range.py index 09ec48715..92f64fbf7 100644 --- a/magi_attention/common/range.py +++ b/magi_attention/common/range.py @@ -181,7 +181,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash((self._start, self._end)) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return f"[{self._start}, {self._end})" diff --git a/magi_attention/common/ranges.py b/magi_attention/common/ranges.py index 6925c6850..0d4780373 100644 --- a/magi_attention/common/ranges.py +++ b/magi_attention/common/ranges.py @@ -248,6 +248,42 @@ def merge(self) -> "AttnRanges": return _merged_ranges + @nvtx.instrument_nvtx + def merge_with_split_alignment(self, split_alignment: int = 1) -> "AttnRanges": + """Merge the attn_ranges for the overlapped / tangent parts with split alignment + in ascending order by 'attn_range.start' + + Args: + split_alignment (int): The alignment of the split, default is 1 + + Returns: + AttnRanges: The merged attn_ranges with split alignment, all start and end are aligned to the split_alignment + """ + + _ranges = self.sort()._ranges # required to be sorted first + + _merged_ranges = AttnRanges() + + start, end = None, None + for attn_range in _ranges: + attn_range_start = attn_range.start // split_alignment * split_alignment + attn_range_end = ( + (attn_range.end + split_alignment - 1) + // split_alignment + * split_alignment + ) + if start is None: + start, end = attn_range_start, attn_range_end + _merged_ranges.append(AttnRange(start=start, end=end)) + elif attn_range_start > end: # type: ignore[operator] + start, end = attn_range_start, attn_range_end + _merged_ranges.append(AttnRange(start=start, end=end)) + elif attn_range_end > end: # type: ignore[operator] + end = attn_range_end + _merged_ranges[-1].end = end + + return _merged_ranges + @nvtx.instrument_nvtx def chunk(self, chunk_size: int, check: bool = True) -> list["AttnRanges"]: if check: # required to be non-overlap @@ -777,7 +813,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(tuple(self._ranges)) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover if self.is_empty(): # to prevent repr as "[]" to mix up with empty list return "[[,)]" return f"{self._ranges}" diff --git a/magi_attention/common/rectangle.py b/magi_attention/common/rectangle.py index 79ca8c684..0e0ab0f38 100644 --- a/magi_attention/common/rectangle.py +++ b/magi_attention/common/rectangle.py @@ -507,7 +507,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash((self._q_range, self._k_range, self._d_range)) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return f"{self._q_range} x {self._k_range} x {self._d_range}" diff --git a/magi_attention/common/rectangles.py b/magi_attention/common/rectangles.py index e5ae8a50d..6339bfb0b 100644 --- a/magi_attention/common/rectangles.py +++ b/magi_attention/common/rectangles.py @@ -247,7 +247,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(tuple(self._rects)) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover if self.is_empty(): return "[-1, -1) x [-1, -1): None" return f"{self._rects}" diff --git a/magi_attention/csrc/comm/grpcoll/buffer.cpp b/magi_attention/csrc/comm/grpcoll/buffer.cpp index 18e4c1fc4..5b1854cae 100644 --- a/magi_attention/csrc/comm/grpcoll/buffer.cpp +++ b/magi_attention/csrc/comm/grpcoll/buffer.cpp @@ -374,7 +374,7 @@ namespace magi_attn_comm::grpcoll { // Buffer Initialization /////////////////////////////////////////////////////////////////////////////////////////////////// -Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy) +Buffer::Buffer(int rank, int num_ranks, size_t num_nvl_bytes, size_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy) : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), @@ -389,8 +389,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // host signal ptr array to each signal for each nvl rank // Common checks - GRPCOLL_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= INT_MAX or num_rdma_bytes == 0)); - GRPCOLL_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= INT_MAX)); + GRPCOLL_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0); + GRPCOLL_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0); GRPCOLL_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_PEERS or low_latency_mode)); GRPCOLL_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) @@ -431,7 +431,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ auto local_nvl_buffer_byte_ptr = static_cast(buffer_ptrs[nvl_rank]); // Set the host ptr to the local nvl signal - int64_t local_nvl_buffer_byte_offs = num_nvl_bytes; + size_t local_nvl_buffer_byte_offs = num_nvl_bytes; barrier_signal_ptrs[nvl_rank] = reinterpret_cast(local_nvl_buffer_byte_ptr + local_nvl_buffer_byte_offs); // Set the device ptr to the buffer ptr array @@ -757,7 +757,7 @@ Buffer::intranode_group_cast( auto channel_prefix_matrix = torch::Tensor(); // Notify - int num_memset_int = num_channels * num_ranks * 4; // clean channel start/end offset, head and tail + size_t num_memset_int = num_channels * num_ranks * 4; // clean channel start/end offset, head and tail if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; rank_prefix_matrix = cached_rank_prefix_matrix.value(); @@ -1111,12 +1111,13 @@ Buffer::intranode_group_reduce( // Launch barrier and reset queue head and tail // TODO: support notify_group_reduce when the group_reduce kernel is individually used // without relying on the symmetric group_cast called first and necessary handle given + size_t num_memset_int = num_channels * num_ranks * 2; // clean queue head and tail intranode::cached_notify_group_reduce( /*buffer_ptrs=*/buffer_ptrs_gpu, /*send_head=*/send_head.data_ptr(), /*num_channels=*/num_channels, /*num_reduced_tokens=*/num_reduced_tokens, - /*num_memset_int=*/num_channels * num_ranks * 2, + /*num_memset_int=*/num_memset_int, /*barrier_signal_ptrs=*/barrier_signal_ptrs_gpu, /*rank=*/rank, /*num_ranks=*/num_ranks, @@ -2013,7 +2014,7 @@ py::bytearray Buffer::get_local_nvshmem_unique_id() const { torch::Tensor Buffer::get_local_buffer_tensor(const py::object& dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); - auto element_bytes = static_cast(elementSize(casted_dtype)); + auto element_bytes = elementSize(casted_dtype); auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); diff --git a/magi_attention/csrc/comm/grpcoll/buffer.hpp b/magi_attention/csrc/comm/grpcoll/buffer.hpp index b6fdf19cd..1aa60b1bc 100644 --- a/magi_attention/csrc/comm/grpcoll/buffer.hpp +++ b/magi_attention/csrc/comm/grpcoll/buffer.hpp @@ -76,12 +76,12 @@ struct Buffer { bool low_latency_mode = false; // NVLink Buffer - int64_t num_nvl_bytes; + size_t num_nvl_bytes; void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; // NVSHMEM Buffer - int64_t num_rdma_bytes; + size_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; // Device info and communication @@ -120,7 +120,7 @@ struct Buffer { int* grpcoll_recv_rdma_counter_mapped = nullptr; public: - Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy); + Buffer(int rank, int num_ranks, size_t num_nvl_bytes, size_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy); ~Buffer() noexcept(false); diff --git a/magi_attention/csrc/comm/grpcoll/generate_inst.py b/magi_attention/csrc/comm/grpcoll/generate_inst.py index 7eacd227f..4439cc4aa 100644 --- a/magi_attention/csrc/comm/grpcoll/generate_inst.py +++ b/magi_attention/csrc/comm/grpcoll/generate_inst.py @@ -1,10 +1,13 @@ import math import os +import re from collections import namedtuple +from pathlib import Path # ========================================== # Configuration # ========================================== + inst_dir = "magi_attention/csrc/comm/grpcoll/instantiations" os.makedirs(inst_dir, exist_ok=True) @@ -15,122 +18,76 @@ # ========================================== # C++ Argument Signatures (Constants) # ========================================== -# Extracting these long strings prevents linting errors and improves readability. - -INTRANODE_CAST_ARGS = """ void* recv_x, - float* recv_lse, - const void* x, - const float* lse, - void* recv_x_2nd, - const void* x_2nd, - void* recv_x_3rd, - const void* x_3rd, - int* recv_src_idx, - int* recv_channel_offset, - int* send_head, - const bool* is_token_in_rank, - const int* channel_prefix_matrix, - const int64_t* post_perm_idx, - int num_tokens, - int hidden_int4, - int num_heads, - void** buffer_ptrs, - int rank, - cudaStream_t stream, - int num_sms, - int num_max_send_tokens, - int num_recv_buffer_tokens, - std::optional& kernel_barrier""" - -INTRANODE_REDUCE_ARGS = """ void* reduced_x, - float* reduced_lse, - const void* x, - const float* lse, - void* reduced_x_2nd, - const void* x_2nd, - int* send_head, - const int* src_idx, - const int* rank_prefix_matrix, - const int* channel_prefix_matrix, - const int64_t* pre_perm_idx, - int num_reduced_tokens, - int hidden_size, - int num_heads, - void** buffer_ptrs, - int rank, - cudaStream_t stream, - int num_sms, - int num_max_send_tokens, - int num_recv_buffer_tokens, - ReduceOp reduce_op, - std::optional& kernel_barrier""" - -INTERNODE_CAST_ARGS = """ void* recv_x, - float* recv_lse, - const void* x, - const float* lse, - void* recv_x_2nd, - const void* x_2nd, - void* recv_x_3rd, - const void* x_3rd, - void* recv_src_meta, - int* send_rdma_head, - int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, - int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, - const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, - const int* recv_gbl_rank_prefix_sum, - const bool* is_token_in_rank, - const int64_t* post_perm_idx, - int num_tokens, - int hidden_int4, - int num_heads, - void* rdma_buffer_ptr, - int num_max_rdma_chunked_send_tokens, - int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, - int num_max_nvl_chunked_send_tokens, - int num_max_nvl_chunked_recv_tokens, - int rank, - int num_ranks, - int num_channels, - bool is_cached_group_cast, - cudaStream_t stream, - std::optional& kernel_barrier""" - -INTERNODE_REDUCE_ARGS = """ void* reduced_x, - float* reduced_lse, - const void* x, - const float* lse, - void* reduced_x_2nd, - const void* x_2nd, - const bool* is_reduced_token_in_rank, - const int* reduced_rdma_head, - const int* reduced_nvl_head, - const void* src_meta, - const int* rdma_channel_prefix_matrix, - const int* rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, - const int* gbl_rank_prefix_sum, - const int64_t* pre_perm_idx, - int num_reduced_tokens, - int hidden_size, - int num_heads, - void* rdma_buffer_ptr, - int num_max_rdma_chunked_send_tokens, - int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, - int num_max_nvl_chunked_send_tokens, - int num_max_nvl_chunked_recv_tokens, - int rank, - int num_ranks, - cudaStream_t stream, - int num_channels, - std::optional& kernel_barrier, - bool acc_reduce, - ReduceOp reduce_op""" + +# Get the kernel directory +kernel_dir = Path(__file__).parent / "kernels" + + +def extract_function_params(file_name, func_name): + """ + Extracts the parameter list of a specific function from a .cuh/.h file. + Cleans up comments and formats the parameters into a multi-line string. + """ + + file_path = kernel_dir / file_name + + if not os.path.exists(file_path): + return f"Error: {file_path} not found." + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Search for the function name followed by an opening parenthesis + # \b ensures we match the exact function name + pattern = rf"\b{re.escape(func_name)}\b\s*\(" + match = re.search(pattern, content) + + if not match: + return f"Error: Function '{func_name}' not found." + + # Use a counter to find the matching closing parenthesis + # This handles nested parentheses like std::optional correctly + start_index = match.end() + paren_count = 1 + current_index = start_index + raw_params = "" + + while paren_count > 0 and current_index < len(content): + char = content[current_index] + if char == "(": + paren_count += 1 + elif char == ")": + paren_count -= 1 + + if paren_count > 0: + raw_params += char + current_index += 1 + + # Remove C-style block comments /* ... */ + raw_params = re.sub(r"/\*.*?\*/", "", raw_params, flags=re.DOTALL) + # Remove C++-style line comments // ... + raw_params = re.sub(r"//.*", "", raw_params) + + # Split by comma and clean up whitespace for each parameter + param_list = raw_params.split(",") + cleaned_params = [] + for p in param_list: + p_clean = p.strip() + if p_clean: + # Flatten multiple spaces/newlines into a single space + p_clean = " ".join(p_clean.split()) + cleaned_params.append(p_clean) + + return " " + ",\n ".join(cleaned_params) + + +INTRANODE_CAST_ARGS = extract_function_params("intranode.cuh", "launch_group_cast") + +INTRANODE_REDUCE_ARGS = extract_function_params("intranode.cuh", "launch_group_reduce") + +INTERNODE_CAST_ARGS = extract_function_params("internode.cuh", "launch_group_cast") + +INTERNODE_REDUCE_ARGS = extract_function_params("internode.cuh", "launch_group_reduce") # ========================================== @@ -138,6 +95,18 @@ # ========================================== +def write_if_different(path: Path, content: str) -> bool: + if path.exists(): + with open(path, "r") as f: + if f.read() == content: + return False + else: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + f.write(content) + return True + + def write_batched_file(category_name, file_index, headers, namespace, content_list): """Generates a single .cu file containing a batch of instantiations.""" if not content_list: @@ -148,32 +117,35 @@ def write_batched_file(category_name, file_index, headers, namespace, content_li print(f"Generating {filepath} ({len(content_list)} instantiations)...") - with open(filepath, "w") as f: - # File Header - f.write( - "/**********************************************************************************\n" - ) - f.write(" * Copyright (c) 2025-2026 SandAI. All Rights Reserved.\n") - f.write(" * Auto-generated by generate_inst.py (Batched Version)\n") - f.write( - " *********************************************************************************/\n\n" - ) - - # Includes - for h in headers: - f.write(f'#include "{h}"\n') - f.write("\n") - - # Namespace Start - f.write(f"namespace {namespace} {{\n\n") - - # Instantiations - for content in content_list: - f.write(content) - f.write("\n") - - # Namespace End - f.write(f"}} // namespace {namespace}\n") + content = "" + + # File Header + content += "/**********************************************************************************\n" + content += " * Copyright (c) 2025-2026 SandAI. All Rights Reserved.\n" + content += " * Auto-generated by generate_inst.py (Batched Version)\n" + content += " *********************************************************************************/\n\n" + + # Includes + for h in headers: + content += f'#include "{h}"\n' + content += "\n" + + # Namespace Start + content += f"namespace {namespace} {{\n\n" + + # Instantiations + for c in content_list: + content += f"{c}\n" + + # Namespace End + content += f"}} // namespace {namespace}\n" + + is_different = write_if_different(Path(filepath), content) + + if is_different: + print(" File written/updated.") + else: + print(" No changes detected. Skipping write.") def process_batch(category_name, headers, namespace, all_instantiations): diff --git a/magi_attention/csrc/comm/grpcoll/kernels/api.cuh b/magi_attention/csrc/comm/grpcoll/kernels/api.cuh index c30bbcb1d..82f81457e 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/api.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/api.cuh @@ -101,7 +101,7 @@ void notify_group_cast( const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -111,7 +111,7 @@ void notify_group_cast( void cached_notify_group_cast( const int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -154,7 +154,7 @@ void cached_notify_group_reduce( int* send_head, int num_channels, int num_reduced_tokens, - int num_memset_int, + size_t num_memset_int, int** barrier_signal_ptrs, int rank, int num_ranks, @@ -218,8 +218,8 @@ void notify_group_cast( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool require_recv_count); template @@ -281,8 +281,8 @@ void cached_notify( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool is_cached_group_cast); template < diff --git a/magi_attention/csrc/comm/grpcoll/kernels/buffer.cuh b/magi_attention/csrc/comm/grpcoll/kernels/buffer.cuh index 65d1e37a9..b0bce2b54 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/buffer.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/buffer.cuh @@ -53,11 +53,11 @@ struct Buffer { uint8_t* ptr; public: - int total_bytes; + size_t total_bytes; DEVICE_INLINE Buffer() : ptr(nullptr), total_bytes(0) {} - DEVICE_INLINE Buffer(void*& gbl_ptr, int num_elems, int elem_offset = 0) { + DEVICE_INLINE Buffer(void*& gbl_ptr, size_t num_elems, int elem_offset = 0) { total_bytes = num_elems * sizeof(dtype_t); // the total bytes of this block ptr = reinterpret_cast(gbl_ptr) + elem_offset * sizeof(dtype_t); // the start ptr within this block @@ -96,29 +96,31 @@ template struct AsymBuffer { private: uint8_t* ptrs[kNumRanks]; - int num_bytes; + size_t num_bytes; public: - int total_bytes; + size_t total_bytes; DEVICE_INLINE AsymBuffer() : ptrs{nullptr}, num_bytes(0), total_bytes(0) {} - DEVICE_INLINE AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + DEVICE_INLINE AsymBuffer(void*& gbl_ptr, size_t num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { GRPCOLL_STATIC_ASSERT(kNumRanks == 1, "This API is only available for single rank case"); - num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + num_bytes = num_elems * sizeof(dtype_t); + size_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; + ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; } - DEVICE_INLINE AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + DEVICE_INLINE AsymBuffer(void** gbl_ptrs, size_t num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { GRPCOLL_STATIC_ASSERT(kNumRanks > 1, "This API is only available for multi rank case"); - num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + num_bytes = num_elems * sizeof(dtype_t); + size_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; + for (int r = 0; r < kNumRanks; ++r) { ptrs[r] = reinterpret_cast(gbl_ptrs[r]) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptrs[r] = reinterpret_cast(gbl_ptrs[r]) + total_bytes; @@ -159,20 +161,19 @@ struct SymBuffer { private: uint8_t* send_ptr; uint8_t* recv_ptr; // NOTE: for coupled case, `recv_ptr` is not used - int num_bytes; + size_t num_bytes; public: - int total_bytes; + size_t total_bytes; DEVICE_INLINE SymBuffer() : send_ptr(nullptr), recv_ptr(nullptr), num_bytes(0), total_bytes(0) {} // TODO: fix the parameter names - DEVICE_INLINE SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { + DEVICE_INLINE SymBuffer(void*& gbl_ptr, size_t num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { num_bytes = num_elems * sizeof(dtype_t); - - const int per_channel_bytes = num_bytes * num_ranks; - + const size_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); + send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; recv_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; diff --git a/magi_attention/csrc/comm/grpcoll/kernels/internode.cuh b/magi_attention/csrc/comm/grpcoll/kernels/internode.cuh index 9e7ca30db..9bfc259a8 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/internode.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/internode.cuh @@ -97,8 +97,8 @@ void launch_group_cast( constexpr int kNumWarps = kNumThreads / WARP_SIZE; GRPCOLL_STATIC_ASSERT(kNumWarps == kNumSenderWarps + 1 + NUM_MAX_NVL_PEERS, "Invalid number of warps"); - constexpr int kNumTMABytesPerWarp = 16384; // 16KB - constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; // 128KB + constexpr int kNumTMABytesPerWarp = 27 * 1024; // 27KB + constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; // 27KB * 8 = 216KB < 224KB, can hardly be raised up const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_heads); GRPCOLL_HOST_ASSERT(num_bytes_per_token + /*mbarrier*/ sizeof(uint64_t) <= kNumTMABytesPerWarp); @@ -231,12 +231,12 @@ void launch_group_reduce( GRPCOLL_STATIC_ASSERT(kNumForwarders > NUM_MAX_NVL_PEERS and kNumForwarders <= kNumForwarderWarps, "Invalid number of active forwarder warps"); GRPCOLL_STATIC_ASSERT(num_warps == kNumForwarders + 1, "Invalid number of warps"); - constexpr int kNumTMABytesPerSenderWarp = 1024 * 27; // 27KB REVIEW: tune this value + constexpr int kNumTMABytesPerSenderWarp = 27 * 1024; // 27KB constexpr int kNumTMALoadBytes = sizeof(int4) * WARP_SIZE; // 512B, as a warp-copy unit, each lane for one int4 constexpr int kNumTMABufferBytesPerStage = align(kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + /*mbarrier*/ sizeof(uint64_t), sizeof(int4)); // 4624B constexpr int kNumTMABytesPerForwarderWarp = kNumTMAStages * kNumTMABufferBytesPerStage; constexpr int smem_size = std::max( - kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, // 27KB * 8 = 128KB, can still be raised up + kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, // 27KB * 8 = 216KB < 224KB, can hardly be raised up kNumTMABytesPerForwarderWarp * kNumForwarderWarps // 9248B * 24 = 216.75KB < 224KB, can hardly be raised up ); @@ -249,7 +249,9 @@ void launch_group_reduce( } const int hidden_int4 = hidden_size / (sizeof(int4) / sizeof(dtype_t)); - const int num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_heads); // NOTE: we still need enough TMA load buffer for original dtype + // NOTE: we still need enough TMA load buffer for original dtype + // before downcasting to comm_dtype_t, thus the maximum num_bytes_per_token is halved + const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_heads); GRPCOLL_HOST_ASSERT(num_bytes_per_token + /*mbarrier*/ sizeof(uint64_t) <= kNumTMABytesPerSenderWarp); GRPCOLL_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % kNumRDMARanks == 0); GRPCOLL_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / kNumRDMARanks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); diff --git a/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cu b/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cu index aaa532612..2aeff173e 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cu +++ b/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cu @@ -31,10 +31,10 @@ __global__ void notify_group_cast_kernel( const bool* is_token_in_rank, int num_tokens, int num_channels, - const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, + const size_t rdma_clean_offset, + const size_t rdma_num_int_clean, + const size_t nvl_clean_offset, + const size_t nvl_num_int_clean, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, @@ -79,7 +79,7 @@ __global__ void notify_group_cast_kernel( // Clean up RDMA buffer of this rank for later meta data switch GRPCOLL_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); #pragma unroll - for (int i = thread_id; i < rdma_num_int_clean; i += kNumThreads) + for (size_t i = thread_id; i < rdma_num_int_clean; i += kNumThreads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; // Copy send meta data of this RDMA rank to its local send buffer @@ -149,7 +149,7 @@ __global__ void notify_group_cast_kernel( auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); GRPCOLL_DEVICE_ASSERT(nvl_send_num_tokens_per_rank.total_bytes <= nvl_clean_offset * sizeof(int)); #pragma unroll - for (int i = thread_id; i < nvl_num_int_clean; i += kNumThreads) + for (size_t i = thread_id; i < nvl_num_int_clean; i += kNumThreads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; __syncthreads(); @@ -301,10 +301,10 @@ void notify_group_cast( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool require_recv_count) { - constexpr int kNumThreads = 512; + constexpr int kNumThreads = 1024; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Get clean meta @@ -316,10 +316,6 @@ void notify_group_cast( GRPCOLL_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); GRPCOLL_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - // REVIEW: why limited to INT_MAX ? - GRPCOLL_HOST_ASSERT(num_rdma_bytes < INT_MAX); - GRPCOLL_HOST_ASSERT(num_nvl_bytes < INT_MAX); - // Launch kernel SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); RDMA_RANKS_SWITCH(num_rdma_ranks, kNumRDMARanks, [&] { @@ -340,10 +336,10 @@ void notify_group_cast( is_token_in_rank, num_tokens, num_channels, - rdma_clean_meta.first, - rdma_clean_meta.second, - nvl_clean_meta.first, - nvl_clean_meta.second, + /*rdma_clean_offset=*/rdma_clean_meta.first, + /*rdma_num_int_clean=*/rdma_clean_meta.second, + /*nvl_clean_offset=*/nvl_clean_meta.first, + /*nvl_num_int_clean=*/nvl_clean_meta.second, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, @@ -359,10 +355,10 @@ void notify_group_cast( template __global__ void cached_notify_kernel( - const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, + const size_t rdma_clean_offset, + const size_t rdma_num_int_clean, + const size_t nvl_clean_offset, + const size_t nvl_num_int_clean, int* reduced_rdma_head, int num_reduced_tokens, int num_channels, @@ -377,7 +373,7 @@ __global__ void cached_notify_kernel( bool is_cached_group_cast, const nvshmem_team_t rdma_team) { const auto sm_id = static_cast(blockIdx.x), thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - const auto warp_id = thread_id / WARP_SIZE, lane_id = get_lane_id(); + const auto warp_id = thread_id / WARP_SIZE, lane_id = get_lane_id(), num_warps = num_threads / WARP_SIZE; const auto nvl_rank = rank % NUM_MAX_NVL_PEERS, num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS, rdma_rank = rank / NUM_MAX_NVL_PEERS; if (sm_id == 0) { // the first SM is responsible to wait all previous inflight WRs finished and then clean the RDMA/NVL buffer @@ -390,13 +386,13 @@ __global__ void cached_notify_kernel( // Clean RDMA buffer of this RDMA rank auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); #pragma unroll - for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + for (size_t i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; // Clean NVL buffer of this NVL rank auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); #pragma unroll - for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + for (size_t i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; __syncthreads(); @@ -412,18 +408,19 @@ __global__ void cached_notify_kernel( // Reset the rdma head, iterating in reverse order // each warp is responsible for one channel // and each lane in any warp is responsible for one rdma rank of the corr. channel - if (lane_id < num_rdma_ranks and warp_id < num_channels) { - int token_start_idx, token_end_idx; - get_channel_task_range(num_reduced_tokens, num_channels, warp_id, token_start_idx, token_end_idx); - - // NOTE: `1 << 25` is a heuristic large number - int last_head = 1 << 25; - for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { - auto current_head = __ldg(reduced_rdma_head + token_idx * num_rdma_ranks + lane_id); - if (current_head < 0) { - reduced_rdma_head[token_idx * num_rdma_ranks + lane_id] = encode(last_head); - } else { - last_head = current_head; + int last_head = 1 << 25; // NOTE: `1 << 25` is a heuristic large number + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + if (lane_id < num_rdma_ranks) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_reduced_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { + auto current_head = __ldg(reduced_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + reduced_rdma_head[token_idx * num_rdma_ranks + lane_id] = encode(last_head); + } else { + last_head = current_head; + } } } } @@ -433,30 +430,32 @@ __global__ void cached_notify_kernel( if (is_cached_group_cast) return; - if (warp_id < num_channels) { - const auto rest_sm_id = sm_id - 2, num_rest_sms = num_channels * 2 - 2; - constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t); - constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS; - constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token; - GRPCOLL_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16"); - - // Prepare TMA buffer and init mbarrier - extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; - auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp; - auto tma_mbarrier = reinterpret_cast(tma_buffer + tma_batch_size); - uint32_t tma_phase = 0; - if (lane_id == 0) { - mbarrier_init(tma_mbarrier, /*arrive_count=*/1); // only lane0 participates - fence_view_async_shared(); - fence_barrier_init(); - } - __syncwarp(); + const auto rest_sm_id = sm_id - 2, num_rest_sms = num_channels * 2 - 2; + constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS; + constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t); + constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token; + GRPCOLL_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16"); + GRPCOLL_STATIC_ASSERT(num_bytes_per_token + /*mbarrier*/ sizeof(uint64_t) <= kNumTMABytesPerWarp, "TMA buffer size per warp is not enough"); + + // Prepare TMA buffer and init mbarrier + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + tma_batch_size); + uint32_t tma_phase = 0; + if (lane_id == 0) { + mbarrier_init(tma_mbarrier, /*arrive_count=*/1); // only lane0 participates + fence_view_async_shared(); + fence_barrier_init(); + } + __syncwarp(); + // Each warp is responsible for one channel + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { // Each rest SM for one dst RDMA peer for (int dst_rdma_rank = rest_sm_id; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_rest_sms) { // Iterate in reverse order - int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; - int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int token_start_idx = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int rank_prefix = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; token_start_idx += rank_prefix, token_end_idx += rank_prefix; @@ -533,14 +532,14 @@ void cached_notify( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool is_cached_group_cast) { - const int num_threads = std::max(128, WARP_SIZE * num_channels), num_warps = num_threads / WARP_SIZE; - const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - const int kNumTMABytesPerWarp = 8192; - const int smem_size = kNumTMABytesPerWarp * num_warps; + constexpr int num_threads = 512, num_warps = num_threads / WARP_SIZE; + constexpr int kNumTMABytesPerWarp = 8192; // 8KB + constexpr int smem_size = kNumTMABytesPerWarp * num_warps; // 8KB * 16 = 128KB < 224KB const int num_sms = num_channels * 2; + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Get clean meta auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_heads, num_groups, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); @@ -551,15 +550,10 @@ void cached_notify( GRPCOLL_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); GRPCOLL_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - // REVIEW: why limited to INT_MAX ? - GRPCOLL_HOST_ASSERT(num_rdma_bytes < INT_MAX); - GRPCOLL_HOST_ASSERT(num_nvl_bytes < INT_MAX); - GRPCOLL_HOST_ASSERT(num_sms > 3); // first to barrier, second to reset RDMA head, rest to reset NVL head GRPCOLL_HOST_ASSERT(num_warps > 1); // for `barrier_all` if (!is_cached_group_cast) { // for rdma head reset before group_reduce - GRPCOLL_HOST_ASSERT(num_warps >= num_channels); GRPCOLL_HOST_ASSERT(num_rdma_ranks <= WARP_SIZE); // for nvl head reset before group_reduce @@ -574,10 +568,10 @@ void cached_notify( LAUNCH_KERNEL( &cfg, cached_notify_func, - rdma_clean_meta.first, - rdma_clean_meta.second, - nvl_clean_meta.first, - nvl_clean_meta.second, + /*rdma_clean_offset=*/rdma_clean_meta.first, + /*rdma_num_int_clean=*/rdma_clean_meta.second, + /*nvl_clean_offset=*/nvl_clean_meta.first, + /*nvl_num_int_clean=*/nvl_clean_meta.second, reduced_rdma_head, num_reduced_tokens, num_channels, diff --git a/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cuh b/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cuh index 978a545b1..9602d717a 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/internode_notify_kernel.cuh @@ -31,10 +31,10 @@ __global__ void notify_group_cast_kernel( const bool* is_token_in_rank, int num_tokens, int num_channels, - const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, + const size_t rdma_clean_offset, + const size_t rdma_num_int_clean, + const size_t nvl_clean_offset, + const size_t nvl_num_int_clean, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, @@ -68,16 +68,16 @@ void notify_group_cast( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool require_recv_count); template __global__ void cached_notify_kernel( - const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, + const size_t rdma_clean_offset, + const size_t rdma_num_int_clean, + const size_t nvl_clean_offset, + const size_t nvl_num_int_clean, int* reduced_rdma_head, int num_reduced_tokens, int num_channels, @@ -110,8 +110,8 @@ void cached_notify( int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + size_t num_rdma_bytes, + size_t num_nvl_bytes, bool is_cached_group_cast); } // namespace magi_attn_comm::grpcoll::internode diff --git a/magi_attention/csrc/comm/grpcoll/kernels/internode_utils.cuh b/magi_attention/csrc/comm/grpcoll/kernels/internode_utils.cuh index 0f78384a8..40d8b0940 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/internode_utils.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/internode_utils.cuh @@ -89,18 +89,18 @@ constexpr int get_num_threads_group_reduce(const int num_group_reduce_forwarder_ return (num_group_reduce_forwarder_warps + 1) * WARP_SIZE; } -HOST_DEVICE_INLINE int get_num_bytes_per_token(int hidden_int4, int num_heads) { - return static_cast(align( +HOST_DEVICE_INLINE size_t get_num_bytes_per_token(int hidden_int4, int num_heads) { + return align( /*hidden_states=*/hidden_int4 * sizeof(int4) + /*lse*/ num_heads * sizeof(float) + /*source_meta=*/sizeof(SourceMeta), - sizeof(int4))); + sizeof(int4)); } // Get data buffer size and meta buffer size for RDMA buffer, all in `int32_t` // NOTE: summing them together to get the required minimum RDMA buffer size template -HOST_DEVICE_INLINE std::pair get_rdma_clean_meta( +HOST_DEVICE_INLINE std::pair get_rdma_clean_meta( int hidden_int4, int num_heads, int num_groups, @@ -117,7 +117,7 @@ HOST_DEVICE_INLINE std::pair get_rdma_clean_meta( // Get data buffer size and meta buffer size for NVL buffer, all in `int32_t` // NOTE: summing them together to get the required minimum NVL buffer size -HOST_DEVICE_INLINE std::pair get_nvl_clean_meta( +HOST_DEVICE_INLINE std::pair get_nvl_clean_meta( int hidden_int4, int num_heads, int num_groups, diff --git a/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cu b/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cu index 3feb19afe..9f97c5ee3 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cu +++ b/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cu @@ -28,7 +28,7 @@ __global__ void notify_group_cast_kernel( const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { @@ -107,7 +107,7 @@ __global__ void notify_group_cast_kernel( #pragma unroll // Extra memset for later channel metadata // including channel start/end offset, head and tail - for (int i = thread_id; i < num_memset_int; i += num_threads) + for (size_t i = thread_id; i < num_memset_int; i += num_threads) buffer_ptr_after_rank_prefix[i] = 0; // Barrier @@ -145,7 +145,7 @@ void notify_group_cast( const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -177,7 +177,7 @@ void notify_group_cast( } template -__global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { +__global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { // A simplified version for cached handles barrier_block(barrier_signal_ptrs, rank); @@ -188,7 +188,7 @@ __global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, i for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) ptr[i] = rank_prefix_matrix[i]; #pragma unroll - for (int i = thread_id; i < num_memset_int; i += num_threads) + for (size_t i = thread_id; i < num_memset_int; i += num_threads) ptr[kNumRanks * kNumRanks + i] = 0; // Barrier after cleaning @@ -197,7 +197,7 @@ __global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, i void cached_notify_group_cast( const int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -216,7 +216,7 @@ __global__ void cached_notify_group_reduce_kernel( int* send_head, int num_channels, int num_reduced_tokens, - int num_memset_int, + size_t num_memset_int, int** barrier_signal_ptrs, int rank) { const auto sm_id = static_cast(blockIdx.x); @@ -228,7 +228,7 @@ __global__ void cached_notify_group_reduce_kernel( auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); auto ptr = static_cast(buffer_ptrs[rank]); #pragma unroll - for (int i = thread_id; i < num_memset_int; i += num_threads) + for (size_t i = thread_id; i < num_memset_int; i += num_threads) ptr[i] = 0; // Barrier after cleaning @@ -275,7 +275,7 @@ void cached_notify_group_reduce( int* send_head, int num_channels, int num_reduced_tokens, - int num_memset_int, + size_t num_memset_int, int** barrier_signal_ptrs, int rank, int num_ranks, diff --git a/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cuh b/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cuh index c8bd48c8b..5b382e809 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cuh +++ b/magi_attention/csrc/comm/grpcoll/kernels/intranode_notify_kernel.cuh @@ -34,7 +34,7 @@ __global__ void notify_group_cast_kernel( const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank); @@ -47,7 +47,7 @@ void notify_group_cast( const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -56,11 +56,11 @@ void notify_group_cast( bool require_recv_count); template -__global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank); +__global__ void cached_notify_group_cast_kernel(const int* rank_prefix_matrix, size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank); void cached_notify_group_cast( const int* rank_prefix_matrix, - int num_memset_int, + size_t num_memset_int, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, @@ -73,7 +73,7 @@ __global__ void cached_notify_group_reduce_kernel( int* send_head, int num_channels, int num_reduced_tokens, - int num_memset_int, + size_t num_memset_int, int** barrier_signal_ptrs, int rank); @@ -82,7 +82,7 @@ void cached_notify_group_reduce( int* send_head, int num_channels, int num_reduced_tokens, - int num_memset_int, + size_t num_memset_int, int** barrier_signal_ptrs, int rank, int num_ranks, diff --git a/magi_attention/csrc/comm/grpcoll/kernels/layout.cu b/magi_attention/csrc/comm/grpcoll/kernels/layout.cu index 8f15a85ee..1f1e93bd8 100644 --- a/magi_attention/csrc/comm/grpcoll/kernels/layout.cu +++ b/magi_attention/csrc/comm/grpcoll/kernels/layout.cu @@ -167,72 +167,83 @@ void get_group_cast_meta( template __global__ void get_a2av_perm_idx(const int64_t* output_split_sizes, const int64_t* src_idx, int64_t* perm_to_a2av_idx, int num_ranks, int num_splits) { - auto thread_id = static_cast(threadIdx.x); + // Declare dynamic shared memory + extern __shared__ int64_t s_mem[]; + + // Partition the shared memory: + // rank_split_sizes occupies the first (kNumThreads * kMaxNumRanks) elements + int64_t* rank_split_sizes = s_mem; + // curr_offset_per_rank starts after rank_split_sizes + int64_t* curr_offset_per_rank = &s_mem[kNumThreads * kMaxNumRanks]; - __shared__ int64_t rank_split_sizes[kNumThreads][kMaxNumRanks]; - __shared__ int64_t curr_offset_per_rank[kMaxNumRanks + 1]; + auto thread_id = static_cast(threadIdx.x); -// init rank_split_sizes + // Initialize rank_split_sizes (flattened 2D array) #pragma unroll for (int i = 0; i < num_ranks; ++i) - rank_split_sizes[thread_id][i] = 0; + rank_split_sizes[thread_id * kMaxNumRanks + i] = 0; - // init curr_offset_per_rank + // Initialize curr_offset_per_rank if (thread_id < num_ranks + 1) curr_offset_per_rank[thread_id] = 0; __syncthreads(); -// per-thread count partial rank_split_sizes -// rank_split_sizes[tid][rid]: the partial sum of split sizes recved from rank rid -// counted by thread tid #pragma unroll + // Per-thread partial count of rank_split_sizes + // rank_split_sizes[tid][rid]: the partial sum of split sizes + // recved from rank rid, counted by thread tid for (int i = thread_id; i < num_splits; i += kNumThreads) { auto rank = src_idx[i]; auto split_size = output_split_sizes[i]; - rank_split_sizes[thread_id][rank] += split_size; + rank_split_sizes[thread_id * kMaxNumRanks + rank] += split_size; } __syncthreads(); - // sum up rank_split_sizes - // rank_split_sizes[rid][rid]: the total sum of split sizes recved from rank rid - if (thread_id < num_ranks) { +#pragma unroll + // Sum up partial results from all threads for each rank + // rank_split_sizes[rid][rid]: the total sum of split sizes + // recved from rank rid + for (int r = thread_id; r < num_ranks; r += kNumThreads) { int64_t sum = 0; - -// sum up for partial results in each thread #pragma unroll + // sum up for partial results in each thread for (int i = 0; i < kNumThreads; ++i) - sum += rank_split_sizes[i][thread_id]; - rank_split_sizes[thread_id][thread_id] = sum; + sum += rank_split_sizes[i * kMaxNumRanks + r]; + + // Store total sum in the "diagonal" for the prefix sum step + rank_split_sizes[r * kMaxNumRanks + r] = sum; } __syncthreads(); - // prefix sum for each rank by thread 0 + // Prefix sum across ranks performed by thread 0 // rank_split_sizes[rid][rid]: the start offset of the a2av split buffer recved from rank rid // NOTE: since num_ranks are usually small, we don't need to use Blelloch scan algorithm if (thread_id == 0) { int64_t prefix_sum = 0; #pragma unroll for (int rid = 0; rid < num_ranks; ++rid) { - auto current = rank_split_sizes[rid][rid]; - rank_split_sizes[rid][rid] = prefix_sum; + auto current = rank_split_sizes[rid * kMaxNumRanks + rid]; + rank_split_sizes[rid * kMaxNumRanks + rid] = prefix_sum; prefix_sum += current; } } __syncthreads(); -// TODO: find a better way to parallelize -// especially when all the split sizes are small thus the number of splits is too large -// compute perm_to_a2av_idx, where output[perm_to_a2av_idx] => a2a_output #pragma unroll + // Compute final permutation indices + // TODO: find a better way to parallelize + // especially when all the split sizes are small thus the number of splits is too large + // compute perm_to_a2av_idx, where output[perm_to_a2av_idx] => a2a_output for (int i = 0; i < num_splits; ++i) { - // all threads process one split together auto rank = src_idx[i]; auto split_size = output_split_sizes[i]; - auto a2av_offset_this_rank = rank_split_sizes[rank][rank]; + auto a2av_offset_this_rank = rank_split_sizes[rank * kMaxNumRanks + rank]; auto a2av_offset_this_split = a2av_offset_this_rank + curr_offset_per_rank[rank]; auto start_token_idx = curr_offset_per_rank[num_ranks]; - __syncthreads(); // make sure each thread's read the same curr_offset_per_rank + + // Ensure all threads read the same offsets before processing the split + __syncthreads(); #pragma unroll for (int j = thread_id; j < split_size; j += kNumThreads) { @@ -241,22 +252,35 @@ __global__ void get_a2av_perm_idx(const int64_t* output_split_sizes, const int64 perm_to_a2av_idx[a2av_token_idx] = token_idx; } - // update the current offset by thread0 + // Update global offsets (performed by thread 0) if (thread_id == 0) { curr_offset_per_rank[num_ranks] += split_size; // start_token_idx curr_offset_per_rank[rank] += split_size; } - __syncthreads(); // make sure each thread'll read the latest curr_offset_per_rank in next iter + + // Sync before moving to the next split to read updated offsets + __syncthreads(); } } void get_a2av_perm_idx(const int64_t* output_split_sizes, const int64_t* src_idx, int64_t* perm_to_a2av_idx, int num_ranks, int num_splits, cudaStream_t stream) { - constexpr int num_sms = 1, kNumThreads = 256, kMaxNumRanks = 16; - GRPCOLL_STATIC_ASSERT(kNumThreads >= kMaxNumRanks, "kNumThreads should NOT less than kMaxNumRanks"); + constexpr int num_sms = 1; + constexpr int kMaxNumRanks = 32 * 8; // 8 ranks * 32 nodes = 256 + constexpr int kNumThreads = 108; // we will consume (108 * 256 + 256 + 1) * 8 = ~218KB shared memory + GRPCOLL_HOST_ASSERT(num_ranks <= kMaxNumRanks); + // Calculate required dynamic shared memory size in bytes + // Space for rank_split_sizes matrix + curr_offset_per_rank vector + constexpr int smem_size = (kNumThreads * kMaxNumRanks + kMaxNumRanks + 1) * sizeof(int64_t); + + auto kernel_func = get_a2av_perm_idx; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); - LAUNCH_KERNEL(&cfg, (get_a2av_perm_idx), output_split_sizes, src_idx, perm_to_a2av_idx, num_ranks, num_splits); + SET_SHARED_MEMORY_FOR_TMA(kernel_func); + + // Launch kernel + LAUNCH_KERNEL(&cfg, kernel_func, output_split_sizes, src_idx, perm_to_a2av_idx, num_ranks, num_splits); } } // namespace layout diff --git a/magi_attention/csrc/extensions/attn_ranges.hpp b/magi_attention/csrc/extensions/attn_ranges.hpp index fa52b619c..5acabd49c 100644 --- a/magi_attention/csrc/extensions/attn_ranges.hpp +++ b/magi_attention/csrc/extensions/attn_ranges.hpp @@ -1011,6 +1011,34 @@ struct AttnRanges { return _merged_ranges; } + AttnRanges merge_with_split_alignment(int split_alignment = 1) const { + AttnRanges sorted_ranges_obj = sort_ranges(); + const auto& _ranges = sorted_ranges_obj.ranges; + AttnRanges _merged_ranges; + + int start = std::numeric_limits::min(), end = std::numeric_limits::min(); + for (size_t i = 0; i < _ranges.size(); i++) { + const AttnRange& attn_range = _ranges[i]; + int attn_range_start = (attn_range.start / split_alignment) * split_alignment; + int attn_range_end = ((attn_range.end + split_alignment - 1) / split_alignment) * split_alignment; + + if (start == std::numeric_limits::min()) { + start = attn_range_start; + end = attn_range_end; + _merged_ranges.append(AttnRange(start, end)); + } else if (attn_range_start > end) { + start = attn_range_start; + end = attn_range_end; + _merged_ranges.append(AttnRange(start, end)); + } else if (attn_range_end > end) { + end = attn_range_end; + _merged_ranges.at(_merged_ranges.size() - 1).end = end; + } + } + + return _merged_ranges; + } + std::pair make_range_local(const AttnRange& other_attn_range, bool is_self_merged = false, const std::vector* prefix_offset_ptr = nullptr) const { AttnRanges merged_ranges_obj; diff --git a/magi_attention/csrc/extensions/magi_attn_ext.cpp b/magi_attention/csrc/extensions/magi_attn_ext.cpp index dd185c585..9ae6f9557 100644 --- a/magi_attention/csrc/extensions/magi_attn_ext.cpp +++ b/magi_attention/csrc/extensions/magi_attn_ext.cpp @@ -264,6 +264,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("to_naive_ranges", &magi_attn_ext::AttnRanges::to_naive_ranges) .def_property_readonly("total_seqlen", &magi_attn_ext::AttnRanges::total_seqlen) .def("merge", &magi_attn_ext::AttnRanges::merge) + .def("merge_with_split_alignment", &magi_attn_ext::AttnRanges::merge_with_split_alignment, py::arg("split_alignment") = 1) .def("sort", &magi_attn_ext::AttnRanges::sort_ranges) .def("sort_ranges", &magi_attn_ext::AttnRanges::sort_ranges) .def("is_valid", &magi_attn_ext::AttnRanges::is_valid) diff --git a/magi_attention/dist_attn_runtime_mgr.py b/magi_attention/dist_attn_runtime_mgr.py index eb4a15473..757432b14 100644 --- a/magi_attention/dist_attn_runtime_mgr.py +++ b/magi_attention/dist_attn_runtime_mgr.py @@ -55,13 +55,15 @@ class DistAttnRuntimeKey: attn_mask_type: tuple[AttnMaskType, ...] total_seqlen_q: int total_seqlen_k: int + num_heads_q: int + num_heads_kv: int + head_dim: int pad_size: int chunk_size: int cp_group: dist.ProcessGroup cp_mesh: DeviceMesh | None dist_attn_config: DistAttnConfig - num_heads_q: int - num_heads_kv: int + # flags that might influence the runtime behavior is_deterministic_mode_enable: bool is_hierarchical_comm_enable: bool @@ -80,6 +82,9 @@ def __init__( dist_attn_config: DistAttnConfig, attn_solver: BaseDistAttnSolver, dist_attn_runtime: DistAttnRuntime, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, cp_group: dist.ProcessGroup, *, ref_q_ranges: AttnRanges, @@ -87,26 +92,25 @@ def __init__( is_same_source: bool, is_q_permutable: bool, is_k_permutable: bool, - num_heads_q: int, - num_heads_kv: int, ): - self.cp_group = cp_group self.dispatch_meta_q = dispatch_meta_q self.dispatch_meta_k = dispatch_meta_k self.dist_attn_config = dist_attn_config self.attn_solver = attn_solver - self.dist_attn_runtime = dist_attn_runtime + self.num_heads_q = num_heads_q + self.num_heads_kv = num_heads_kv + self.head_dim = head_dim + + self.cp_group = cp_group + self.ref_q_ranges = ref_q_ranges self.ref_k_ranges = ref_k_ranges self.is_same_source = is_same_source self.is_q_permutable = is_q_permutable self.is_k_permutable = is_k_permutable - self.num_heads_q = num_heads_q - self.num_heads_kv = num_heads_kv - self._q_position_ids: None | torch.Tensor = None self._k_position_ids: None | torch.Tensor = None @@ -387,13 +391,14 @@ def init_dist_attn_runtime_key( attn_mask_type: list[AttnMaskType], total_seqlen_q: int, total_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, pad_size: int, chunk_size: int, cp_group: dist.ProcessGroup, cp_mesh: DeviceMesh | None, dist_attn_config: DistAttnConfig, - num_heads_q: int, - num_heads_kv: int, ) -> DistAttnRuntimeKey: """Initialize DistAttnRuntimeKey""" @@ -406,13 +411,14 @@ def init_dist_attn_runtime_key( attn_mask_type=tuple(attn_mask_type), total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group=cp_group, cp_mesh=cp_mesh, dist_attn_config=dist_attn_config, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, # auto set other flags that might influence the runtime behavior is_deterministic_mode_enable=magi_attention.is_deterministic_mode_enable(), is_hierarchical_comm_enable=magi_attention.comm.is_hierarchical_comm_enable(), @@ -444,68 +450,74 @@ def init_dist_attn_runtime_mgr( attn_mask_type: list[AttnMaskType], total_seqlen_q: int, total_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, chunk_size: int, cp_group: dist.ProcessGroup, - is_same_source: bool, - is_q_permutable: bool, - is_k_permutable: bool, - dist_attn_config: DistAttnConfig = DistAttnConfig(), cp_mesh: DeviceMesh | None = None, - num_heads_q: int = 1, - num_heads_kv: int = 1, + dist_attn_config: DistAttnConfig = DistAttnConfig(), + is_same_source: bool = True, + is_q_permutable: bool = True, + is_k_permutable: bool = True, ref_dispatch_meta_q: DispatchMeta | None = None, ref_dispatch_meta_k: DispatchMeta | None = None, ) -> DistAttnRuntimeMgr: """ Args: - q_ranges (AttnRanges): the global query ranges - k_ranges (AttnRanges): the global key ranges - attn_mask_type (list[AttnMaskType]): the global attn mask type list + q_ranges (AttnRanges): the global query ranges. + k_ranges (AttnRanges): the global key ranges. + attn_mask_type (list[AttnMaskType]): the global attn mask type list. - total_seqlen_q (int): the total seqlen of query - total_seqlen_k (int): the total seqlen of key + total_seqlen_q (int): the total seqlen of query. + total_seqlen_k (int): the total seqlen of key. - chunk_size (int): chunk size to chunk the permutable tensor + num_heads_q (int): number of heads of query. + num_heads_kv (int): number of heads of key/value. + head_dim (int): dimension of each head. - cp_group (dist.ProcessGroup): process group, only support nccl backend for now - - is_same_source (bool): is query tensor and key tensor share the same source - is_q_permutable (bool): is query tensor permutable - is_k_permutable (bool): is key tensor permutable - NOTE: e.g. - 1. for decoder-only transformer like gpt, it applies 'self-attn' as follows: - a) is_same_source is True - b) both q and k are permutable, as long as they are permuted in the same way. - 2. for encoder-decoder transformer like t5, it applies 'cross-attn' as follows: - a) is_same_source is False - b) q is permutable but k is not - 3. for multi-modal transformer with external encoders, it applies 'cross-attn' as follows: - a) is_same_source is False - b) q is unpermutable cuz of self-attn, but k is permutable even in a different way - - dist_attn_config (DistAttnConfig): dist attn config + chunk_size (int): chunk size to chunk the permutable tensor. + cp_group (dist.ProcessGroup): process group, only support nccl backend for now. cp_mesh (DeviceMesh): process mesh, only support 1D or 2D mesh for now. - num_heads_q (int): number of heads of query. Default: 1 - num_heads_kv (int): number of heads of key/value. Default: 1 + is_same_source (bool): is query tensor and key tensor share the same source. + Default to ``True``. + is_q_permutable (bool): is query tensor permutable. + Default to ``True``. + is_k_permutable (bool): is key tensor permutable. + Default to ``True``. + + NOTE: + 1. for decoder-only transformer like gpt, it applies 'self-attn' as follows: + a) is_same_source is True + b) both q and k are permutable, as long as they are permuted in the same way. + 2. for encoder-decoder transformer like t5, it applies 'cross-attn' as follows: + a) is_same_source is False + b) q is permutable but k is not + 3. for multi-modal transformer with external encoders, it applies 'cross-attn' as follows: + a) is_same_source is False + b) q is unpermutable cuz of self-attn, but k is permutable even in a different way + + dist_attn_config (DistAttnConfig): dist attn config. Returns: - DistAttnRuntimeMgr: dist attn runtime mgr + DistAttnRuntimeMgr: dist attn runtime manager. Example:: + >>> # Step1. initialize the dist attn runtime manager >>> dist_attn_runtime_mgr = init_dist_attn_runtime_mgr( ... q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]), ... k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]), ... attn_mask_type=[AttnMaskType.FULL, AttnMaskType.CAUSAL], ... total_seqlen_q=4096, ... total_seqlen_k=4096, + ... num_heads_q=16, + ... num_heads_kv=4, + ... head_dim=128, ... chunk_size=512, ... cp_group=dist.new_group(list(range(4)), backend="nccl"), - ... is_same_source=True, - ... is_q_permutable=True, - ... is_k_permutable=True, ... dist_attn_config=DistAttnConfig( ... dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()), ... overlap_config=OverlapConfig( @@ -517,23 +529,29 @@ def init_dist_attn_runtime_mgr( ... alg=OverlapAlgType.UNIFORM, ... ), ... ), + ... is_same_source=True, + ... is_q_permutable=True, + ... is_k_permutable=True, ... ) - >>> # Dispatch global query tensor to local query tensor + >>> + >>> # Step2. dispatch global query tensor to local query tensor >>> local_q = dist_attn_runtime_mgr.dispatch_qo(total_q) - >>> # Dispatch global key tensor to local key tensor + >>> + >>> # Step3. dispatch global key/value tensor to local key/value tensor >>> local_k = dist_attn_runtime_mgr.dispatch_kv(total_k) - >>> # Dispatch global value tensor to local value tensor >>> local_v = dist_attn_runtime_mgr.dispatch_kv(total_v) - >>> # Calculate local attention result + >>> + >>> # Step4. calculate distributed attention >>> local_out, meta = dist_attn_runtime_mgr.calc_attn(local_q, local_k, local_v) - >>> # Gather local attention results to global result + >>> + >>> # Step5. undispatch local attention output to the global one if needed >>> total_out = dist_attn_runtime_mgr.undispatch_qo(local_out) """ cp_size = dist.get_world_size(cp_group) cp_rank = dist.get_rank(cp_group) - # make dispatch meta + # Make dispatch meta # to determine which rank should hold which chunks of seqlen dispatch_config: DispatchConfig = dist_attn_config.dispatch_config if ref_dispatch_meta_q is None or ref_dispatch_meta_k is None: @@ -564,23 +582,24 @@ def init_dist_attn_runtime_mgr( dispatch_meta_q = ref_dispatch_meta_q dispatch_meta_k = ref_dispatch_meta_k - # make comm meta and calc meta + # Make comm meta and calc meta # to organize the dist-attn calculation and communication overlap_config: OverlapConfig = dist_attn_config.overlap_config comm_meta, calc_meta, attn_solver = make_attn_meta_from_dispatch_meta( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, dispatch_meta_q=dispatch_meta_q, dispatch_meta_k=dispatch_meta_k, - cp_group=cp_group, overlap_config=overlap_config, + cp_group=cp_group, cp_mesh=cp_mesh, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) - # init grpcoll buffer manager for native grpcoll kernels + # Init grpcoll buffer manager for native grpcoll grpcoll_config: GrpCollConfig = dist_attn_config.grpcoll_config init_grpcoll_buffer_mgr( comm_meta=comm_meta, @@ -590,7 +609,7 @@ def init_dist_attn_runtime_mgr( cp_group=cp_group, ) - # init dist attn runtime + # Init dist attn runtime dist_attn_runtime = DistAttnRuntime( comm_meta=comm_meta, calc_meta=calc_meta, @@ -598,21 +617,22 @@ def init_dist_attn_runtime_mgr( cp_group_gr=cp_group, # TODO: support interface to set distinct cp group for group-reduce ) - # init dist attn runtime mgr + # Init dist attn runtime mgr dist_attn_runtime_mgr = DistAttnRuntimeMgr( dispatch_meta_q=dispatch_meta_q, dispatch_meta_k=dispatch_meta_k, dist_attn_config=dist_attn_config, attn_solver=attn_solver, dist_attn_runtime=dist_attn_runtime, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, cp_group=cp_group, ref_q_ranges=q_ranges, ref_k_ranges=k_ranges, is_same_source=is_same_source, is_q_permutable=is_q_permutable, is_k_permutable=is_k_permutable, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, ) return dist_attn_runtime_mgr diff --git a/magi_attention/functional/dist_attn.py b/magi_attention/functional/dist_attn.py index 66e2b6681..cbaffd8f1 100644 --- a/magi_attention/functional/dist_attn.py +++ b/magi_attention/functional/dist_attn.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging - # mypy: disable-error-code="union-attr,list-item" +import logging import warnings from logging import getLogger from typing import Any, TypeAlias @@ -59,10 +58,6 @@ class DistAttnRuntime: cp_group_gr (dist.ProcessGroup): the cp process group for group-reduce """ - num_heads_q: int - num_heads_kv: int - num_heads_per_group: int - remote_q_work_with_buffer_per_stage: list[WorkWithBuffer] remote_kv_work_with_buffer_per_stage: list[WorkWithBuffer] partial_out_lse_reduce_work_per_stage: list[WorkWithPostProcessFn] @@ -481,12 +476,42 @@ def save_tensors_for_bwd( local_kv: FusedOrTupleTensor, local_out: torch.Tensor, local_lse: torch.Tensor, + last_stage_q: torch.Tensor | None, + last_stage_kv: FusedOrTupleTensor | None, global_sink: torch.Tensor | None, ) -> None: - if self.concat_kv: # local_kv is a fused tensor - ctx.save_for_backward(local_q, local_kv, local_out, local_lse, global_sink) - else: # local_kv are tupled tensors - ctx.save_for_backward(local_q, *local_kv, local_out, local_lse, global_sink) + if last_stage_kv is None: + self.save_last_stage_for_backward = False + if self.concat_kv: # local_kv is a fused tensor + ctx.save_for_backward( + local_q, local_kv, local_out, local_lse, global_sink + ) + else: # local_kv are tupled tensors + ctx.save_for_backward( + local_q, *local_kv, local_out, local_lse, global_sink + ) + else: + self.save_last_stage_for_backward = True + if self.concat_kv: # local_kv is a fused tensor + ctx.save_for_backward( + local_q, + local_kv, + local_out, + local_lse, + last_stage_q, + last_stage_kv, + global_sink, + ) + else: # local_kv are tupled tensors + ctx.save_for_backward( + local_q, + *local_kv, + local_out, + local_lse, + last_stage_q, + *last_stage_kv, + global_sink, + ) # ---------- API for bwd --------- # @@ -656,6 +681,12 @@ def get_curr_qo_do_kv_lse_and_fetch_next( # pre-fetch remote qo_do,kv,lse for next stage(s) if self.prefetch_stage_by_stage and not is_last_remote_stage: + if ( + magi_attention.dist_attn_backward_hide_tail_reduce() + and self.is_penultimate_stage(overlap_stage) + ): + return curr_qo_do, curr_kv, curr_lse + ( self.remote_kv_work_with_buffer_per_stage[next_stage] ) = self._fetch_remote_kv( @@ -678,9 +709,17 @@ def get_curr_qo_do_kv_lse_and_fetch_next( # we issue all fetch-remote comms in advance of ffa bwd # and ffa bwd can still overlap with these comms # with the support of `sm_margin`, thanks to persistent kernel design + if ( + magi_attention.dist_attn_backward_hide_tail_reduce() + and self.overlap_degree > 0 + ): + degree = self.overlap_degree - 1 + else: + degree = self.overlap_degree + self.remote_kv_work_with_buffer_per_stage = [ self._fetch_remote_kv(local_kv=local_kv, overlap_stage=ith_stage) - for ith_stage in range(self.overlap_degree) + for ith_stage in range(degree) ] self.remote_qo_do_lse_work_with_buffer_per_stage = [ self._fetch_remote_qo_do_lse( @@ -688,7 +727,7 @@ def get_curr_qo_do_kv_lse_and_fetch_next( local_lse=local_lse, overlap_stage=ith_stage, ) - for ith_stage in range(self.overlap_degree) + for ith_stage in range(degree) ] return curr_qo_do, curr_kv, curr_lse @@ -854,21 +893,50 @@ def load_tensors_from_fwd( torch.Tensor, torch.Tensor, torch.Tensor | None, + FusedOrTupleTensor | None, + torch.Tensor | None, ]: - if self.concat_kv: # local kv is a fused tensor - local_q, local_kv, local_out, local_lse, global_sink = ctx.saved_tensors - else: # local kv are tupled tensors - ( - local_q, - local_k, - local_v, - local_out, - local_lse, - global_sink, - ) = ctx.saved_tensors - local_kv = (local_k, local_v) - - return local_q, local_kv, local_out, local_lse, global_sink + if self.save_last_stage_for_backward: + if self.concat_kv: # local kv is a fused tensor + ( + local_q, + local_kv, + local_out, + local_lse, + last_q, + last_kv, + global_sink, + ) = ctx.saved_tensors + else: # local kv are tupled tensors + ( + local_q, + local_k, + local_v, + local_out, + local_lse, + last_q, + last_k, + last_v, + global_sink, + ) = ctx.saved_tensors + local_kv = (local_k, local_v) + last_kv = (last_k, last_v) + else: + last_q, last_kv = None, None + if self.concat_kv: # local kv is a fused tensor + local_q, local_kv, local_out, local_lse, global_sink = ctx.saved_tensors + else: # local kv are tupled tensors + ( + local_q, + local_k, + local_v, + local_out, + local_lse, + global_sink, + ) = ctx.saved_tensors + local_kv = (local_k, local_v) + + return local_q, local_kv, local_out, local_lse, last_q, last_kv, global_sink # ---------- common API --------- # @@ -1028,6 +1096,12 @@ def is_first_remote_stage(self, overlap_stage: int) -> bool: """ return overlap_stage == 0 + def is_penultimate_stage(self, overlap_stage: int | None) -> bool: + """ + Check if the given overlap stage is the penultimate stage + """ + return self.get_next_stage(overlap_stage) == self.overlap_degree - 1 + def get_next_stage(self, overlap_stage: int | None) -> int: """ Get the next overlap stage @@ -1987,6 +2061,232 @@ def _init_dq_dkv_dsink_skipped_host_stage( return dq, dkv, dsink + def _init_dq_skipped_host_stage( + self, + qo_do: FusedOrTupleTensor, + ) -> torch.Tensor: + q, _, _ = self._maybe_chunk(qo_do, num_chunks=3) + + # NOTE: if local_dq and local_dkv calculation are skipped, + # we need to zero-initialize them since they might be reduced later + dq = torch.zeros_like( + q, + dtype=self._maybe_hp_dtype(q.dtype, not self.bwd_local_dq_lp_init), + ) + + return dq + + def _init_dkv_skipped_host_stage( + self, + kv: FusedOrTupleTensor, + ) -> FusedOrTupleTensor: + k, _ = self._maybe_chunk(kv, num_chunks=2) + if self.concat_kv: # kv is a fused tensor + dkv_shape = kv.shape + else: # kv are tupled tensors + dkv_shape = (k.shape[0] * 2, *k.shape[1:]) + + dkv = torch.zeros( + dkv_shape, + dtype=self._maybe_hp_dtype(k.dtype, not self.bwd_local_dkv_lp_init), + device=k.device, + ) + if not self.concat_dkv: # make partial_dkv tupled tensors + dkv = self._maybe_chunk(dkv, num_chunks=2) + + return dkv + + # TODO: unify this specific scheduling with the original one + def _hide_tail_stage_reduce_backward( + self, ctx, grad_output: torch.Tensor, *args + ): # pragma: no cover + ( + local_q, + local_kv, + local_out, + local_lse, + last_stage_q, + last_stage_kv, + global_sink, + ) = self.load_tensors_from_fwd(ctx) + softmax_scale: float | None = ctx.softmax_scale + softcap: float = ctx.softcap + save_last_stage = magi_attention.dist_attn_backward_hide_tail_reduce() + assert ( + not save_last_stage or not self.enable_qo_comm + ), "save_last_stage and enable_qo_comm can not be both True" + assert self.overlap_degree > 0, ( + f"when self.overlap_degree == 0, this branch should not be entered, " + f"but got {self.overlap_degree=}" + ) + + kernel_barrier_fetch = KernelBarrier(self.bwd_kernel_barrier_fetch_target) + kernel_barrier_reduce = KernelBarrier(self.bwd_kernel_barrier_reduce_target) + + # get local qo_do,kv,lse and pre-fetch qo_do,kv,lse for remote stage(s) + ( + local_qo_do, + local_kv, + local_lse, + ) = self.get_curr_qo_do_kv_lse_and_fetch_next( + local_qo_do=(local_q, local_out, grad_output), + local_kv=local_kv, + local_lse=local_lse, + overlap_stage=None, + kernel_barrier=kernel_barrier_fetch, + ) + + if not self.is_penultimate_stage(None): + kernel_barrier_fetch.synchronize() + kernel_barrier_reduce.reset() + + # apply bwd partial attn with ith remote qo_do,kv,lse + # overlapped with (i+1)th pre-fetch + ( + partial_local_dq, + partial_remote_dkv, + _, # partial_global_dsink + ) = self.apply_bwd_partial_attn( + qo_do=local_qo_do, + kv=last_stage_kv, + lse=local_lse, + dq_acc=None, + overlap_stage=self.overlap_degree - 1, + softmax_scale=softmax_scale, + softcap=softcap, + sink=global_sink, + ) + if partial_local_dq is None: + partial_local_dq = self._init_dq_skipped_host_stage(local_qo_do) + partial_local_dkv = self._init_dkv_skipped_host_stage(local_kv) + + # reduce ith partial dq,dkv + # overlapped with (i+1)th bwd partial attn and maybe (i+2)th pre-fetch + self.reduce_partial_dq_dkv( + partial_remote_dq=None, + partial_local_dq=partial_local_dq, + ref_remote_qo_do=local_qo_do, + partial_remote_dkv=partial_remote_dkv, + partial_local_dkv=partial_local_dkv, + ref_remote_kv=last_stage_kv, + overlap_stage=self.overlap_degree - 1, + kernel_barrier=kernel_barrier_reduce, + ) + num_of_degree = self.overlap_degree - 1 + + # loop into remote stages + for ith_overlap_stage in range(num_of_degree): + kernel_barrier_fetch.reset() + + # wait for ith remote qo_do,kv,lse prepared + # and pre-fetch (i+1)th remote qo_do,kv,lse + ( + curr_remote_qo_do, + curr_remote_kv, + curr_remote_lse, + ) = self.get_curr_qo_do_kv_lse_and_fetch_next( + local_qo_do=local_qo_do, + local_kv=local_kv, + local_lse=local_lse, + overlap_stage=ith_overlap_stage, + kernel_barrier=kernel_barrier_fetch, + ) + if not self.is_penultimate_stage(ith_overlap_stage): + kernel_barrier_fetch.synchronize() + + kernel_barrier_reduce.synchronize() + kernel_barrier_reduce.reset() + + # apply bwd partial attn with ith remote qo_do,kv,lse + # overlapped with (i+1)th pre-fetch + ( + partial_remote_dq, + partial_remote_dkv, + _, # partial_global_dsink + ) = self.apply_bwd_partial_attn( + qo_do=curr_remote_qo_do, + kv=curr_remote_kv, + lse=curr_remote_lse, + dq_acc=partial_local_dq if self.bwd_dq_use_acc else None, + overlap_stage=ith_overlap_stage, + softmax_scale=softmax_scale, + softcap=softcap, + sink=global_sink, + ) + + # reduce ith partial dq,dkv + # overlapped with (i+1)th bwd partial attn and maybe (i+2)th pre-fetch + self.reduce_partial_dq_dkv( + partial_remote_dq=partial_remote_dq, + partial_local_dq=partial_local_dq, + ref_remote_qo_do=curr_remote_qo_do, + partial_remote_dkv=partial_remote_dkv, + partial_local_dkv=partial_local_dkv, + ref_remote_kv=curr_remote_kv, + overlap_stage=ith_overlap_stage, + kernel_barrier=kernel_barrier_reduce, + ) + + kernel_barrier_reduce.synchronize() + ( + partial_host_dq, + partial_host_dkv, + partial_global_dsink, + ) = self.apply_bwd_partial_attn( + qo_do=local_qo_do, + kv=local_kv, + lse=local_lse, + dq_acc=partial_local_dq if self.bwd_dq_use_acc else None, + overlap_stage=None, + softmax_scale=softmax_scale, + softcap=softcap, + sink=global_sink, + ) + assert global_sink is None or partial_global_dsink is not None + + # reduce partial global dsink if required + self.reduce_partial_dsink( + partial_global_dsink=partial_global_dsink, + ) + + # if only one remote stage, num_of_degree = 0, get last remote stage work + # else, get self.overlap_degree - 1 remote stage + self.partial_dkv_reduce_work_per_stage[num_of_degree - 1]._wait_work() + if not self.bwd_dq_use_acc and partial_host_dq is not None: + partial_local_dq.add_(partial_host_dq) + if partial_host_dkv is not None: + if self.concat_dkv: + partial_local_dkv.add_(partial_host_dkv) + else: + for local_dkv, host_dkv in zip(partial_local_dkv, partial_host_dkv): + local_dkv.add_(host_dkv) + + # prepare reduced local dq,dk,dv and maybe global dsink + # before returning from backward + ( + local_dq, + local_dk, + local_dv, + global_dsink, + ) = self.prepare_reduced_local_dqkv_global_dsink( + partial_local_dq=partial_local_dq, + partial_local_dkv=partial_local_dkv, + partial_global_dsink=partial_global_dsink, + ref_local_dq=local_q, + ref_local_dkv=local_kv, + ) + + return ( + local_dq, + local_dk, + local_dv, + global_dsink, + None, # dist_attn_runtime + None, # softmax_scale + None, # softcap + None, # return_max_logits + ) + def _maybe_concat( self, *x: torch.Tensor, @@ -2051,13 +2351,12 @@ def _maybe_flatten_local_qkv_head_groups( assert isinstance( local_kv, tuple ), "local_kv should be tupled tensors for this API" - - # HACK: store the info about number of heads into runtime - # to conveniently access them later - self.num_heads_q = local_q.shape[1] - self.num_heads_kv = local_kv[0].shape[1] - assert self.num_heads_q % self.num_heads_kv == 0 - self.num_heads_per_group = self.num_heads_q // self.num_heads_kv + assert ( + local_q.size(1) == self.comm_meta.num_heads_q + ), f"local_q.num_heads ({local_q.size(1)}) != comm_meta.num_heads_q ({self.comm_meta.num_heads_q})" + assert ( + local_kv[0].size(1) == self.comm_meta.num_heads_kv + ), f"local_k.num_heads ({local_kv[0].size(1)}) != comm_meta.num_heads_kv ({self.comm_meta.num_heads_kv})" if not self.flatten_head_groups: return local_q, local_kv @@ -2068,8 +2367,8 @@ def _maybe_flatten_local_qkv_head_groups( local_q = rearrange( local_q, "n (g h) d -> (g n) h d", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ).contiguous() # Transpose local_k and local_v: flatten groups (heads) into sequence dimension @@ -2112,16 +2411,16 @@ def _maybe_unflatten_local_out_lse_head_groups( local_out = rearrange( local_out, "(g n) h d -> n (g h) d", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ).contiguous() # local_lse: [(g * n_q), h_per_group] -> [n_q, num_heads_q] local_lse = rearrange( local_lse, "(g n) h -> n (g h)", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ).contiguous() return local_out, local_lse @@ -2166,8 +2465,8 @@ def _maybe_flatten_local_qo_do_lse_head_groups( rearrange( x, "n (g h) d -> (g n) h d", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ).contiguous() for x in [local_out, local_do] ] @@ -2177,8 +2476,8 @@ def _maybe_flatten_local_qo_do_lse_head_groups( local_lse = rearrange( local_lse, "n (g h) -> (g n) h", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ).contiguous() return local_qo_do, local_lse @@ -2217,8 +2516,8 @@ def _maybe_unflatten_local_dqkv_head_groups( local_dq = rearrange( local_dq, "(g n) h d -> n (g h) d", - g=self.num_heads_kv, - h=self.num_heads_per_group, + g=self.comm_meta.num_heads_kv, + h=self.comm_meta.num_heads_per_group, ) # local_dk/local_dv: [(num_heads_kv * n_kv), 1, d] -> [n_kv, num_heads_kv, d] @@ -2226,7 +2525,7 @@ def _maybe_unflatten_local_dqkv_head_groups( rearrange( x, "(h n) 1 d -> n h d", - h=self.num_heads_kv, + h=self.comm_meta.num_heads_kv, ) for x in [local_dk, local_dv] ] @@ -2434,12 +2733,22 @@ def forward( else: local_max_logits = None + if ( + magi_attention.dist_attn_backward_hide_tail_reduce() + and dist_attn_runtime.overlap_degree > 0 + ): + last_stage_q, last_stage_kv = curr_remote_q, curr_remote_kv + else: + last_stage_q, last_stage_kv = None, None + dist_attn_runtime.save_tensors_for_bwd( ctx, local_q=local_q, local_kv=local_kv, local_out=local_out, local_lse=local_lse, + last_stage_q=last_stage_q, + last_stage_kv=last_stage_kv, global_sink=global_sink, ) ctx.dist_attn_runtime = dist_attn_runtime @@ -2451,11 +2760,21 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor, *args): # pragma: no cover dist_attn_runtime: DistAttnRuntime = ctx.dist_attn_runtime + if ( + magi_attention.dist_attn_backward_hide_tail_reduce() + and dist_attn_runtime.overlap_degree > 0 + ): + return dist_attn_runtime._hide_tail_stage_reduce_backward( + ctx, grad_output, *args + ) + ( local_q, local_kv, local_out, local_lse, + _, + _, global_sink, ) = dist_attn_runtime.load_tensors_from_fwd(ctx) softmax_scale: float | None = ctx.softmax_scale diff --git a/magi_attention/meta/_make_attn_meta.py b/magi_attention/meta/_make_attn_meta.py index 780364f11..25df91968 100644 --- a/magi_attention/meta/_make_attn_meta.py +++ b/magi_attention/meta/_make_attn_meta.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from logging import getLogger import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh @@ -31,55 +33,59 @@ from magi_attention.meta.solver.overlap_solver import OverlapConfig from magi_attention.utils import nvtx +logger = getLogger(__name__) + @nvtx.instrument_nvtx def make_attn_meta_from_dispatch_meta( q_ranges: AttnRanges, k_ranges: AttnRanges, attn_mask_type: list[AttnMaskType], + num_heads_q: int, + num_heads_kv: int, + head_dim: int, dispatch_meta_q: DispatchMeta, dispatch_meta_k: DispatchMeta, - cp_group: dist.ProcessGroup, overlap_config: OverlapConfig, + cp_group: dist.ProcessGroup, cp_mesh: DeviceMesh | None = None, - num_heads_q: int = 1, - num_heads_kv: int = 1, ) -> tuple[CommMeta, CalcMeta, BaseDistAttnSolver]: """Make the communication and calculation meta from the dispatch meta Args: - q_ranges (AttnRanges): global query ranges in the ref attn mask - k_ranges (AttnRanges): global key ranges in the ref attn mask - attn_mask_type (list[AttnMaskType]): attn mask type (list) - - dispatch_meta_q (DispatchMeta): The dispatch meta for query - dispatch_meta_k (DispatchMeta): The dispatch meta for key + q_ranges (AttnRanges): the global query ranges. + k_ranges (AttnRanges): the global key ranges. + attn_mask_type (list[AttnMaskType]): the attn mask type list. - cp_group (dist.ProcessGroup): The NCCL process group + num_heads_q (int): the number of heads of query. + num_heads_kv (int): the number of heads of key/value. + head_dim (int): the dimension of each attention head. - overlap_config (OverlapConfig): The overlap config, including the overlap mode, overlap degree, overlap chunk size, etc + dispatch_meta_q (DispatchMeta): the dispatch meta for query. + dispatch_meta_k (DispatchMeta): the dispatch meta for key/value. + overlap_config (OverlapConfig): the overlap config. - cp_mesh (DeviceMesh): process mesh, only support 1D or 2D mesh for now - - num_heads_q (int): number of heads of query. Default: 1 - num_heads_kv (int): number of heads of key/value. Default: 1 + cp_group (dist.ProcessGroup): the process group. + cp_mesh (DeviceMesh, optional): the process mesh. Defaults to ``None``. Returns: tuple[CommMeta, CalcMeta, BaseDistAttnSolver]: - the communication meta, calculation meta and the attn solver + the communication meta, calculation meta and the attn solver. """ + # Solve attention attn_solver: BaseDistAttnSolver if magi_attention.comm.is_qo_comm_enable(): # NOTE: for now, we use dynamic attn solver when and only when enabling qo comm # however, we will unify the static/dynamic attn solver in the future attn_solver = DynamicAttnSolver( algorithm=BinaryGreedyParallelDynamicAttnAlgorithm(), + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, cp_group=cp_group, dispatch_meta_q=dispatch_meta_q, dispatch_meta_k=dispatch_meta_k, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, cp_mesh=cp_mesh, ) attn_solver.solve( @@ -87,19 +93,20 @@ def make_attn_meta_from_dispatch_meta( k_ranges=k_ranges, attn_mask_type=attn_mask_type, ) - # TODO: add a flag to control visualization - # only for debug: visualize the buckets - # if cp_group.rank() == 0: - # attn_solver.output_solve_result( - # visualize=True, save_path=".", before_dispatch=True - # ) + # Visualize the buckets only for debug + if logger.isEnabledFor(logging.DEBUG) and cp_group.rank() == 0: + logger.debug("Visualizing the buckets...") + attn_solver.output_solve_result( + visualize=True, save_path="./dyn_solver_buckets.png" + ) else: attn_solver = DistAttnSolver( - cp_group=cp_group, - overlap_config=overlap_config, - cp_mesh=cp_mesh, num_heads_q=num_heads_q, num_heads_kv=num_heads_kv, + head_dim=head_dim, + overlap_config=overlap_config, + cp_group=cp_group, + cp_mesh=cp_mesh, ) attn_solver.solve( q_ranges=q_ranges, @@ -109,10 +116,12 @@ def make_attn_meta_from_dispatch_meta( dispatch_meta_k=dispatch_meta_k, ) + # Make comm/calc meta assert attn_solver.is_solved comm_meta = attn_solver.make_comm_meta() calc_meta = attn_solver.make_calc_meta() + # Sanity check assert comm_meta.overlap_degree == calc_meta.overlap_degree, ( "The overlap degree is inconsistent between " f"comm meta ({comm_meta.overlap_degree=}) and calc meta ({calc_meta.overlap_degree=})." diff --git a/magi_attention/meta/collection/calc_meta.py b/magi_attention/meta/collection/calc_meta.py index 939648944..4a645db72 100644 --- a/magi_attention/meta/collection/calc_meta.py +++ b/magi_attention/meta/collection/calc_meta.py @@ -232,7 +232,7 @@ def to_ffa_args(self, is_bwd: bool = False) -> dict: def can_skip(self, is_bwd: bool = False) -> bool: return self.skip_attn_bwd if is_bwd else self.skip_attn_fwd - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover indent = "" repr_str = "AttnArg(\n" @@ -605,7 +605,7 @@ def _print_nonzero_by_col(cls, tensor: torch.Tensor, name: str = "tensor") -> No ] print(f"col {col}: {', '.join(entries)}", flush=True) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover indent = "" repr_str = "FA4AttnArg(\n" @@ -683,7 +683,7 @@ def __post_init__(self): seqlen_k=self.seqlen_k_per_remote_stage[stage], ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover indent = "" repr_str = f"CalcMeta(overlap_degree={self.overlap_degree},\n" diff --git a/magi_attention/meta/collection/comm_meta.py b/magi_attention/meta/collection/comm_meta.py index ab4617476..c81b73d4b 100644 --- a/magi_attention/meta/collection/comm_meta.py +++ b/magi_attention/meta/collection/comm_meta.py @@ -52,6 +52,8 @@ class GroupCollectiveArg: deterministic: bool = False + split_alignment: int = 1 + def __post_init__(self): pass @@ -80,7 +82,7 @@ def to_packed_group_reduce_args(self, packed_times: int = 1) -> dict: src_indices=self.dst_indices_list * packed_times, ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover indent = "" repr_str = "GroupCollectiveArg(\n" @@ -88,6 +90,7 @@ def __repr__(self) -> str: repr_str += f"{indent} world_size={self.world_size},\n" repr_str += f"{indent} device_mesh={repr(self.device_mesh)},\n" repr_str += f"{indent} deterministic={self.deterministic},\n" + repr_str += f"{indent} split_alignment={self.split_alignment},\n" repr_str += f"{indent} input_split_size_list={self.input_split_size_list},\n" repr_str += ( @@ -257,7 +260,7 @@ def to_group_reduce_args(self) -> dict: ) return self._group_reduce_args_dict_packed - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover # Get the representation of the base class base_repr_str = super().__repr__() @@ -281,6 +284,7 @@ def __repr__(self) -> str: repr_str += f"{indent} init_group_reduce={self.init_group_reduce},\n" repr_str += f"{indent} # Generated by __post_init__:\n" + repr_str += f"{indent} device={self.device},\n" repr_str += format_dict_field( "_group_cast_args_dict_packed", self._group_cast_args_dict_packed, indent ) @@ -315,8 +319,15 @@ def __post_init__(self): not magi_attention.comm.is_hierarchical_comm_enable() ), "This arg dataclass is not supported for hierarchical comm for now." + self._check_split_alignment() + self.device = torch.cuda.current_device() + # transfer group-cast meta args to dispatch meta args + # HACK: for now, we only support internode grpcoll + # with intranode world size of 8 + self.num_nodes = max(1, self.group.size() // 8) + # ---- original group cast args dict ---- # self._group_cast_args_dict = super().to_group_cast_args() @@ -331,11 +342,22 @@ def __post_init__(self): self._init_meta_kwargs_for_native_group_cast() self._init_meta_kwargs_for_native_group_reduce() + def _check_split_alignment(self): + if self.split_alignment > 1: # non-trivial alignment + for idx, split in enumerate(self.input_split_size_list): + assert split % self.split_alignment == 0, ( + f"Each input split size must be multiple of {self.split_alignment} " + f"for better performance, but got {self.input_split_size_list=}, where the {idx}-th {split=}" + ) + for idx, split in enumerate(self.output_split_size_list): + assert split % self.split_alignment == 0, ( + f"Each output split size must be multiple of {self.split_alignment} " + f"for better performance, but got {self.output_split_size_list=}, where the {idx}-th {split=}" + ) + def _init_meta_kwargs_for_native_group_cast(self): - # transfer group-cast meta args to dispatch meta args - # HACK: for now, we only support internode grpcoll - # with intranode world size of 8 - num_nodes = max(1, self.group.size() // 8) + self._preprocess_args_for_split_alignment() + ( num_tokens_per_rank, num_tokens_per_rdma_rank, @@ -344,11 +366,11 @@ def _init_meta_kwargs_for_native_group_cast(self): input_split_sizes=self._group_cast_args_dict["input_split_sizes"], dst_indices=self._group_cast_args_dict["dst_indices"], group=self.group, - num_nodes=num_nodes, + num_nodes=self.num_nodes, ) if num_tokens_per_rdma_rank is not None: - assert num_tokens_per_rdma_rank.size(0) == num_nodes + assert num_tokens_per_rdma_rank.size(0) == self.num_nodes # NOTE: for internode grpcoll, besides providing output buffer, # we have to pass extra `internode_output_seqlen` to fully avoid GPU-CPU sync self._group_cast_args_dict[ @@ -391,13 +413,33 @@ def _init_meta_kwargs_for_native_group_reduce(self): "native_grpcoll_handle_dict" ] = self._group_cast_args_dict["native_grpcoll_handle_dict"] + def _preprocess_args_for_split_alignment(self): + if self.split_alignment > 1: + self._group_cast_args_dict["input_split_sizes"] = [ + split // self.split_alignment + for split in self._group_cast_args_dict["input_split_sizes"] + ] + self._group_cast_args_dict["output_split_sizes"] = [ + split // self.split_alignment + for split in self._group_cast_args_dict["output_split_sizes"] + ] + self._group_reduce_args_dict[ + "input_split_sizes" + ] = self._group_cast_args_dict["output_split_sizes"] + self._group_reduce_args_dict[ + "output_split_sizes" + ] = self._group_cast_args_dict["input_split_sizes"] + + self._group_cast_args_dict["split_alignment"] = self.split_alignment + self._group_reduce_args_dict["split_alignment"] = self.split_alignment + def to_group_cast_args(self) -> dict: return self._group_cast_args_dict def to_group_reduce_args(self) -> dict: return self._group_reduce_args_dict - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover # Get the representation of the base class base_repr_str = super().__repr__() @@ -417,6 +459,8 @@ def __repr__(self) -> str: repr_str = f"{base_repr_str_without_closing.rstrip(',')},\n" # Remove trailing comma if exists and add our own repr_str += f"{indent} # Generated by __post_init__:\n" + repr_str += f"{indent} device={self.device},\n" + repr_str += f"{indent} num_nodes={self.num_nodes},\n" repr_str += format_dict_field( "_group_cast_args_dict", self._group_cast_args_dict, indent ) @@ -443,10 +487,18 @@ class CommMeta: num_remote_qo_tokens_per_stage: list[int] qo_group_collective_args_list: list[GroupCollectiveArg] + num_heads_q: int + num_heads_kv: int + head_dim: int + @property def overlap_degree(self) -> int: return len(self.num_remote_kv_tokens_per_stage) + @property + def num_heads_per_group(self) -> int: + return self._num_heads_per_group + def __post_init__(self): assert ( len(self.num_remote_kv_tokens_per_stage) @@ -464,6 +516,12 @@ def __post_init__(self): self.overlap_degree >= 0 ), f"Overlap degree must be >= 0, but got {self.overlap_degree=}" + assert self.num_heads_q % self.num_heads_kv == 0, ( + f"num_heads_q must be divisible by num_heads_kv, " + f"but got {self.num_heads_q=} and {self.num_heads_kv=}" + ) + self._num_heads_per_group = self.num_heads_q // self.num_heads_kv + if magi_attention.comm.is_native_grpcoll_enable(): self._init_native_grpcoll_args() else: @@ -581,12 +639,13 @@ def _init_native_grpcoll_args(self): **qo_group_collective_kwargs, ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover indent = "" repr_str = f"CommMeta(overlap_degree={self.overlap_degree},\n" # num_remote_kv_tokens_per_stage repr_str += f"{indent} num_remote_kv_tokens_per_stage={self.num_remote_kv_tokens_per_stage},\n" + # kv_group_collective_args_list repr_str += format_list_field( "kv_group_collective_args_list", self.kv_group_collective_args_list, indent @@ -594,33 +653,40 @@ def __repr__(self) -> str: # num_remote_qo_tokens_per_stage repr_str += f"{indent} num_remote_qo_tokens_per_stage={self.num_remote_qo_tokens_per_stage},\n" + # qo_group_collective_args_list repr_str += format_list_field( "qo_group_collective_args_list", self.qo_group_collective_args_list, indent ) + # num_heads_q, num_heads_kv and num_heads_per_group + repr_str += f"{indent} num_heads_q={self.num_heads_q},\n" + repr_str += f"{indent} num_heads_kv={self.num_heads_kv},\n" + repr_str += f"{indent} head_dim={self.head_dim},\n" + # Generated fields from __post_init__ repr_str += f"{indent} # Generated by __post_init__:\n" + repr_str += f"{indent} num_heads_per_group={self.num_heads_per_group},\n" + if not magi_attention.comm.is_native_grpcoll_enable(): + # num_remote_qo_do_tokens_per_stage + repr_str += f"{indent} num_remote_qo_do_tokens_per_stage={self.num_remote_qo_do_tokens_per_stage},\n" + # qo_do_group_collective_args_list + repr_str += format_list_field( + "qo_do_group_collective_args_list", + self.qo_do_group_collective_args_list, + indent, + ) - # num_remote_qo_do_tokens_per_stage - repr_str += f"{indent} num_remote_qo_do_tokens_per_stage={self.num_remote_qo_do_tokens_per_stage},\n" - # qo_do_group_collective_args_list - repr_str += format_list_field( - "qo_do_group_collective_args_list", - self.qo_do_group_collective_args_list, - indent, - ) + # num_remote_out_lse_tokens_per_stage + repr_str += f"{indent} num_remote_out_lse_tokens_per_stage={self.num_remote_out_lse_tokens_per_stage},\n" + # out_lse_group_collective_args_list + repr_str += format_list_field( + "out_lse_group_collective_args_list", + self.out_lse_group_collective_args_list, + indent, + ) - # num_remote_out_lse_tokens_per_stage - repr_str += f"{indent} num_remote_out_lse_tokens_per_stage={self.num_remote_out_lse_tokens_per_stage},\n" - # out_lse_group_collective_args_list - repr_str += format_list_field( - "out_lse_group_collective_args_list", - self.out_lse_group_collective_args_list, - indent, - ) + # Remove trailing comma before final paren + repr_str = repr_str.rstrip(",\n") + "\n)" - repr_str = ( - repr_str.rstrip(",\n") + "\n)" - ) # Remove trailing comma before final paren return repr_str diff --git a/magi_attention/meta/container/slice.py b/magi_attention/meta/container/slice.py index 2b81739bf..384e04573 100644 --- a/magi_attention/meta/container/slice.py +++ b/magi_attention/meta/container/slice.py @@ -87,7 +87,7 @@ def __eq__(self, other: object) -> bool: and self.mask_type == other.mask_type ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return ( f"AttnSlice(slice_id={self.slice_id}, " f"q_range={self.q_range}, k_range={self.k_range}, " @@ -147,7 +147,7 @@ def __post_init__(self): f"but got {len(self.mask_types)} and {len(self.k_ranges)}" ) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return ( f"MultiKAttnSlice(slice_id={self.slice_id}, " f"q_range={self.q_range}, k_ranges={self.k_ranges}, " diff --git a/magi_attention/meta/solver/dispatch_solver.py b/magi_attention/meta/solver/dispatch_solver.py index 5bcd1bd68..ce2feb99a 100644 --- a/magi_attention/meta/solver/dispatch_solver.py +++ b/magi_attention/meta/solver/dispatch_solver.py @@ -384,7 +384,7 @@ def update(self: T, other: T) -> None: """Update self affinity with other affinity in-place""" @abstractmethod - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover """A string representation of the affinity""" def make_cmp_key(self: T) -> Callable: @@ -467,7 +467,7 @@ def update(self, other: "SampleIDAffinity") -> None: for sample_id, count in other.sample_id_cnt_dict.items(): self.sample_id_cnt_dict[sample_id] += count - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover sample_id_to_count = dict(self.sample_id_cnt_dict) return f"sample id affinity: {sample_id_to_count=}" @@ -504,7 +504,7 @@ def update(self, other: "IOUAffinity") -> None: """Update self affinity with other affinity in-place""" self.iou_ranges.extend(other.iou_ranges) - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return f"iou affinity: ranges={self.iou_ranges}" diff --git a/magi_attention/meta/solver/dist_attn_solver.py b/magi_attention/meta/solver/dist_attn_solver.py index e0928530a..ebaa3e390 100644 --- a/magi_attention/meta/solver/dist_attn_solver.py +++ b/magi_attention/meta/solver/dist_attn_solver.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +import warnings from abc import ABC, abstractmethod from bisect import bisect_left from collections import defaultdict from dataclasses import replace from itertools import chain -from typing import Any, Union +from logging import getLogger +from typing import Any, Literal, Union import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh import magi_attention +from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer from magi_attention.comm.primitive.grpcoll.utils import ( sanity_check_for_group_cast_meta_args_per_rank, ) @@ -51,7 +55,9 @@ ) from magi_attention.utils._utils import ( argsort, + find_factors_in_range, flatten_nested_list, + get_factors, perm_idxs2unperm_idxs, ) @@ -69,6 +75,9 @@ pass +logger = getLogger(__name__) + + class BaseDistAttnSolver(ABC): """The base abstract dist-attn solver class to provide necessary abstract methods as common interfaces for sub-classes to implement @@ -91,6 +100,77 @@ def make_calc_meta(self) -> CalcMeta: def is_solved(self) -> bool: ... + @classmethod + def calc_split_alignment( + cls, + chunk_size: int, + num_heads: int, + head_dim: int, + strategy: Literal["min", "max"] = "min", + ) -> int: + """Calculate the split alignment automatically + to adjust the comm args for better performance of native grpcoll kernels. + + Args: + chunk_size (int): The chunk size used in dispatch meta. + num_heads (int): The number of attention heads. + head_dim (int): The dimension of each attention head. + strategy (Literal["min", "max", "auto"], optional): + The strategy to choose split alignment. Defaults to "min". + """ + if not magi_attention.comm.is_native_grpcoll_enable(): + # a2a-v backend does not need split alignment + return 1 + + dtype = ( + torch.float64 if magi_attention.is_sdpa_backend_enable() else torch.bfloat16 + ) + hidden_size = num_heads * head_dim + max_supported_hidden_size = GrpCollBuffer.get_max_supported_hidden_size(dtype) + min_high_bw_hidden_size = GrpCollBuffer.get_min_high_bw_hidden_size(dtype) + min_split_alignment = math.ceil(min_high_bw_hidden_size / hidden_size) + max_split_alignment = math.floor(max_supported_hidden_size / hidden_size) + + chunk_size_factors = get_factors(chunk_size) + valid_split_alignments = find_factors_in_range( + chunk_size_factors, min_split_alignment, max_split_alignment + ) + if len(valid_split_alignments) == 0: + warnings.warn( + f"Cannot find valid split alignment in range [{min_split_alignment}, {max_split_alignment}] " + f"within the factors of chunk_size={chunk_size}: [{', '.join(map(str, chunk_size_factors))}], " + f"for the settings: {num_heads=}, {head_dim=}, {dtype=}. " + f"Then we have to choose some smaller split alignment than recommended, " + "which might results in degraded performance of native grpcoll. " + "For better performance, you had better adjust the chunk size " + "to contain valid split alignment factors." + ) + + min_split_alignment = 1 + valid_split_alignments = find_factors_in_range( + chunk_size_factors, min_split_alignment, max_split_alignment + ) + + match strategy: + case "min": + split_alignment = min(valid_split_alignments) + strategy_str = "minimum" + case "max": + split_alignment = max(valid_split_alignments) + strategy_str = "maximum" + case _: + raise ValueError(f"Unknown strategy: {strategy}") + + logger.debug( + f"Found valid split alignment: {valid_split_alignments} " + f"in range [{min_split_alignment}, {max_split_alignment}] " + f"and chose {strategy_str} {split_alignment} within " + f"the factors of chunk_size={chunk_size}: [{', '.join(map(str, chunk_size_factors))}], " + f"for the settings: {num_heads=}, {head_dim=}, {dtype=}." + ) + + return split_alignment + def __eq__(self, other: Any) -> bool: if not isinstance(other, BaseDistAttnSolver): return False @@ -122,11 +202,12 @@ class DistAttnSolver(BaseDistAttnSolver): @nvtx.instrument_nvtx def __init__( self, - cp_group: dist.ProcessGroup, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, overlap_config: OverlapConfig, + cp_group: dist.ProcessGroup, cp_mesh: DeviceMesh | None = None, - num_heads_q: int = 1, - num_heads_kv: int = 1, ): assert ( not magi_attention.comm.is_qo_comm_enable() @@ -136,13 +217,17 @@ def __init__( self.cp_size = dist.get_world_size(cp_group) self.cp_group = cp_group self.cp_mesh = cp_mesh + self.deterministic = magi_attention.is_deterministic_mode_enable() self.overlap_config = overlap_config self.overlap_solver = OverlapSolver(alg=self.overlap_config.alg) + self.org_num_heads_q = num_heads_q + self.org_num_heads_kv = num_heads_kv self.num_heads_q = num_heads_q self.num_heads_kv = num_heads_kv self.num_heads_group = 1 + self.head_dim = head_dim # NOTE: the real overlap degree should be determined in the later code: # 1. if overlap mode is static, then its real value equals to the one in the overlap config @@ -210,6 +295,7 @@ def solve( dispatch_meta_q: DispatchMeta, dispatch_meta_k: DispatchMeta, ) -> None: + # Apply flatten head groups if enabled flatten_head_groups = magi_attention.is_flatten_head_groups_enable() if flatten_head_groups: self.num_heads_group = self.num_heads_kv @@ -224,7 +310,17 @@ def solve( dispatch_meta_q = self._expand_dispatch_meta(dispatch_meta_q) dispatch_meta_k = self._expand_dispatch_meta(dispatch_meta_k) - # normalize attn_mask_type to list[AttnMaskType] + # Calculate kv split alignment for native grpcoll + if self.cp_size == 1: # cp1 shortcut + self.split_alignment_kv = 1 + else: + self.split_alignment_kv = self.calc_split_alignment( + chunk_size=dispatch_meta_q.chunk_size, + num_heads=self.num_heads_kv, + head_dim=self.head_dim, + ) + + # Normalize attn_mask_type to list[AttnMaskType] if isinstance(attn_mask_type, (AttnMaskType, int)): # HACK: for one mask type, wrap to list if isinstance(attn_mask_type, int): @@ -243,7 +339,7 @@ def solve( else: raise TypeError(f"Unsupported attn_mask_type type: {type(attn_mask_type)}") - # init bucket this rank from dispatch_meta_q + # Init bucket this rank from dispatch_meta_q # assuming it is self-attn scenarios and the partitions of q,k are the same if magi_attention.is_sanity_check_enable(): assert dispatch_meta_q.partitions == dispatch_meta_k.partitions @@ -254,7 +350,7 @@ def solve( dispatch_meta=dispatch_meta_q, ) - # init host / remote q/k ranges global for this rank + # Init host / remote q/k ranges global for this rank ( host_q_ranges_global_this_rank, host_k_ranges_global_this_rank, @@ -265,7 +361,7 @@ def solve( bucket_this_rank=bucket_this_rank, ) - # set some attributes that might be fetched from outside + # Set some attributes that might be fetched from outside self.bucket = bucket_this_rank self.host_q_ranges_global = host_q_ranges_global_this_rank self.host_k_ranges_global = host_k_ranges_global_this_rank @@ -274,7 +370,7 @@ def solve( self.shard_seqlen_q = dispatch_meta_q.shard_seqlen self.total_seqlen_k = dispatch_meta_k.total_seqlen - # init host rank entry for this rank + # Init host rank entry for this rank self.host_rank_entry_this_rank = self._init_host_rank_entry_this_rank( host_q_ranges_global=host_q_ranges_global_this_rank, host_k_ranges_global=host_k_ranges_global_this_rank, @@ -282,7 +378,7 @@ def solve( attn_calc_slice_global_list=bucket_this_rank.attn_slices, ) - # init remote rank entry for each stage for this rank + # Init remote rank entry for each stage for this rank # with the shape of [overlap_degree,] self.remote_rank_entry_per_stage_this_rank = ( self._init_remote_rank_entry_per_stage_this_rank( @@ -290,7 +386,7 @@ def solve( ) ) - # init remote rank entry for each rank for each stage + # Init remote rank entry for each rank for each stage # with the shape of [overlap_degree, cp_size] self.remote_rank_entry_per_rank_per_stage = ( self._init_remote_rank_entry_per_rank_per_stage( @@ -298,7 +394,7 @@ def solve( ) ) - # init transfer table per stage + # Init transfer table per stage # with the shape of [overlap_degree,] self.transfer_table_per_stage: list[ TransferTable @@ -345,17 +441,17 @@ def _init_host_remote_ranges_global_this_rank( ) remote_k_ranges_global_this_rank = AttnRanges() else: - # init host q_ranges global for this rank + # Init host q_ranges global for this rank host_q_ranges_global_this_rank = dispatch_meta_q.host_ranges_per_rank[ self.cp_rank ].merge() - # init host k_ranges global for this rank + # Init host k_ranges global for this rank host_k_ranges_global_this_rank = dispatch_meta_k.host_ranges_per_rank[ self.cp_rank ].merge() - # init remote k_ranges global for this rank + # Init remote k_ranges global for this rank # NOTE: this only contains the remote k ranges that we need to calculate from remote_k_ranges_global_this_rank = ( bucket_this_rank.k_ranges.find_hole_ranges( @@ -364,6 +460,25 @@ def _init_host_remote_ranges_global_this_rank( ) ) + # Apply split alignment + if self.split_alignment_kv > 1: + split_aligment = self.split_alignment_kv + for i, attn_range in enumerate(remote_k_ranges_global_this_rank): + if ( + attn_range.start % split_aligment != 0 + or attn_range.end % split_aligment != 0 + ): + remote_k_ranges_global_this_rank[i] = AttnRange( + start=(attn_range.start // split_aligment) * split_aligment, + end=( + (attn_range.end + split_aligment - 1) // split_aligment + ) + * split_aligment, + ) + remote_k_ranges_global_this_rank = ( + remote_k_ranges_global_this_rank.merge() + ) + # sanity check if magi_attention.is_sanity_check_enable(): # check if merged successfully @@ -517,7 +632,7 @@ def _chunk_remote_k_ranges_global( self.overlap_chunk_size = 0 remote_k_ranges_global_per_chunk: list[AttnRanges] = [] # empty list else: - # determine the chunk size constrainted by min_chunk_size and max_num_chunks + # Determine the chunk size constrainted by min_chunk_size and max_num_chunks total_remote_k_seqlen = remote_k_ranges_global.total_seqlen num_chunks = ( total_remote_k_seqlen + self.overlap_config.min_chunk_size - 1 @@ -534,7 +649,18 @@ def _chunk_remote_k_ranges_global( total_remote_k_seqlen + self.overlap_chunk_size - 1 ) // self.overlap_chunk_size - # chunk the remote k ranges global for multi-stage overlapping + # Adjust the overlap chunk size and num chunks with split alignment + if self.split_alignment_kv > 1: + split_aligment = self.split_alignment_kv + + self.overlap_chunk_size = ( + (self.overlap_chunk_size + split_aligment - 1) // split_aligment + ) * split_aligment + self.overlap_num_chunks = ( + total_remote_k_seqlen + self.overlap_chunk_size - 1 + ) // self.overlap_chunk_size + + # Chunk the remote k ranges global for multi-stage overlapping remote_k_ranges_global_per_chunk = remote_k_ranges_global.chunk( self.overlap_chunk_size, check=magi_attention.is_sanity_check_enable() ) @@ -1567,6 +1693,9 @@ def make_comm_meta(self) -> CommMeta: kv_group_collective_args_list=kv_group_collective_args_list, num_remote_qo_tokens_per_stage=num_remote_qo_tokens_per_stage, qo_group_collective_args_list=qo_group_collective_args_list, + num_heads_q=self.org_num_heads_q, + num_heads_kv=self.org_num_heads_kv, + head_dim=self.head_dim, ) return comm_meta @@ -1652,6 +1781,7 @@ def _calc_kv_group_collective_arg( group=self.cp_group, device_mesh=self.cp_mesh, deterministic=self.deterministic, + split_alignment=self.split_alignment_kv, ) # sanity check for group-cast arg per rank diff --git a/magi_attention/meta/solver/dynamic_attn_solver.py b/magi_attention/meta/solver/dynamic_attn_solver.py index 4dbd97ef6..ac46405d0 100644 --- a/magi_attention/meta/solver/dynamic_attn_solver.py +++ b/magi_attention/meta/solver/dynamic_attn_solver.py @@ -50,6 +50,9 @@ class DynamicAttnSolver(BaseDistAttnSolver): def __init__( self, algorithm: DynamicAttnAlgorithm, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, cp_group: dist.ProcessGroup, dispatch_meta_q: DispatchMeta | None = None, dispatch_meta_k: DispatchMeta | None = None, @@ -57,8 +60,6 @@ def __init__( total_seqlen_k: int | None = None, host_ranges_q: list[AttnRanges] | None = None, host_ranges_k: list[AttnRanges] | None = None, - num_heads_q: int = 1, - num_heads_kv: int = 1, cp_rank: int | None = None, cp_size: int | None = None, cp_mesh: DeviceMesh | None = None, @@ -82,6 +83,7 @@ def __init__( self.host_ranges_k = [ hr.merge() for hr in dispatch_meta_k.host_ranges_per_rank ] + self.dispatch_chunk_size = dispatch_meta_q.chunk_size else: assert total_seqlen_q is not None and total_seqlen_k is not None assert host_ranges_q is not None and host_ranges_k is not None @@ -93,10 +95,18 @@ def __init__( self.total_seqlen_k = total_seqlen_k self.host_ranges_q = [host_ranges.merge() for host_ranges in host_ranges_q] self.host_ranges_k = [host_ranges.merge() for host_ranges in host_ranges_k] + self.dispatch_chunk_size = total_seqlen_q + assert ( + num_heads_q % num_heads_kv == 0 + ), f"num_heads_q ({num_heads_q}) must be divisible by num_heads_kv ({num_heads_kv})" + + self.org_num_heads_q = num_heads_q + self.org_num_heads_kv = num_heads_kv self.num_heads_q = num_heads_q self.num_heads_kv = num_heads_kv self.num_heads_group = 1 + self.head_dim = head_dim # set some attributes that might be fetched from outside self.host_q_ranges_global = self.host_ranges_q[self.cp_rank] @@ -181,6 +191,26 @@ def solve( self._is_solved = True + # Calculate kv split alignment for native grpcoll + if self.cp_size == 1: # cp1 shortcut + self.split_alignment_kv = 1 + else: + self.split_alignment_kv = self.calc_split_alignment( + chunk_size=self.dispatch_chunk_size, + num_heads=self.num_heads_kv, + head_dim=self.head_dim, + ) + + # Calculate qo split alignment for native grpcoll + if self.cp_size == 1: # cp1 shortcut + self.split_alignment_qo = 1 + else: + self.split_alignment_qo = self.calc_split_alignment( + chunk_size=self.dispatch_chunk_size, + num_heads=self.num_heads_q, + head_dim=self.head_dim, + ) + @property def is_solved(self) -> bool: return self._is_solved @@ -333,6 +363,17 @@ def _calc_intersection( j += 1 return intersections + def make_split_alignment(self, ranges: AttnRanges, calc_kv: bool) -> AttnRanges: + if calc_kv and self.split_alignment_kv > 1: + return ranges.merge_with_split_alignment( + split_alignment=self.split_alignment_kv + ) + elif not calc_kv and self.split_alignment_qo > 1: + return ranges.merge_with_split_alignment( + split_alignment=self.split_alignment_qo + ) + return ranges + @nvtx.instrument_nvtx def _calc_group_collective_arg( self, @@ -356,6 +397,10 @@ def _calc_group_collective_arg( if calc_kv else self.remote_bucket_this_rank.get_qo_ranges_union() ) + + # make split_alignment for group collective optimization + local_calc_ranges = self.make_split_alignment(local_calc_ranges, calc_kv) + # local_calc_ranges is sorted and merged intersections = self._calc_intersection_with_index( local_calc_ranges, indexed_remote_hold_ranges @@ -388,6 +433,10 @@ def _calc_group_collective_arg( if calc_kv else self.bucket_per_rank[remote_rank].get_qo_ranges_union() ) + + # make split_alignment for group collective optimization + remote_calc_ranges = self.make_split_alignment(remote_calc_ranges, calc_kv) + intersections = self._calc_intersection( host_ranges_this_rank, remote_calc_ranges ) @@ -437,6 +486,9 @@ def _calc_group_collective_arg( group=self.cp_group, device_mesh=self.cp_mesh, deterministic=self.deterministic, + split_alignment=self.split_alignment_kv + if calc_kv + else self.split_alignment_qo, ) return group_collective_arg @@ -478,6 +530,9 @@ def make_comm_meta(self) -> CommMeta: kv_group_collective_args_list=kv_group_collective_args_list, num_remote_qo_tokens_per_stage=num_remote_qo_tokens_per_stage, qo_group_collective_args_list=qo_group_collective_args_list, + num_heads_q=self.org_num_heads_q, + num_heads_kv=self.org_num_heads_kv, + head_dim=self.head_dim, ) return comm_meta @@ -579,10 +634,18 @@ def make_calc_meta( local_attn_arg_k_ranges = self.host_k_ranges_global.make_ranges_local( local_attn_arg_k_ranges ) - remote_attn_arg_q_ranges = remote_attn_arg_q_ranges.make_ranges_local( + # make split_alignment for remote q ranges + remote_q_ranges_global = self.make_split_alignment( + remote_attn_arg_q_ranges, calc_kv=False + ) + remote_attn_arg_q_ranges = remote_q_ranges_global.make_ranges_local( remote_attn_arg_q_ranges ) - remote_attn_arg_k_ranges = remote_attn_arg_k_ranges.make_ranges_local( + # make split_alignment for remote k ranges + remote_k_ranges_global = self.make_split_alignment( + remote_attn_arg_k_ranges, calc_kv=True + ) + remote_attn_arg_k_ranges = remote_k_ranges_global.make_ranges_local( remote_attn_arg_k_ranges ) diff --git a/magi_attention/meta/solver/dynamic_solver_vis.py b/magi_attention/meta/solver/dynamic_solver_vis.py index 249f5ff84..8ed4e0241 100644 --- a/magi_attention/meta/solver/dynamic_solver_vis.py +++ b/magi_attention/meta/solver/dynamic_solver_vis.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import Sequence import matplotlib.pyplot as plt @@ -33,6 +31,7 @@ def visualize_buckets( title: str = "DynamicAttnSolver buckets", max_size: int | None = None, save_path: str | None = None, + before_dispatch: bool = False, ) -> None: """ Visualize attention rectangles for each rank. @@ -42,14 +41,14 @@ def visualize_buckets( - Vertical axis: q (from q.start to q.end) Note: - q increases downward, so the q axis needs to be inverted. + q increases downwards, so the q-axis needs to be inverted (smaller at the top, larger at the bottom). """ if not bucket_per_rank: return # Prepare figure - # Slightly widen the figure and leave space for the legend on the right + # Slightly wider to leave space for the legend on the right fig, ax = plt.subplots(figsize=(10, 8)) # Use different colors for different ranks @@ -106,13 +105,13 @@ def visualize_buckets( ax.set_xlim(k_min, k_max) ax.set_ylim(q_min, q_max) - # Make q and k unit lengths the same physical size to avoid distortion + # Ensure q and k unit lengths are displayed as the same physical length to avoid distortion ax.set_aspect("equal", adjustable="box") - # q increases downward -> invert the y-axis + # q increases downwards -> invert y-axis ax.invert_yaxis() - # Place k-axis on top + # Move k-axis to the top ax.set_xlabel("k (key index)") ax.xaxis.set_label_position("top") ax.xaxis.tick_top() @@ -133,7 +132,7 @@ def visualize_buckets( ) ) labels.append(f"rank {rank}") - # Move the legend to the upper right outside the plot to avoid overlap + # Place legend in the upper right, moved outside the plot area to avoid obscuring content ax.legend( handles, labels, @@ -144,7 +143,7 @@ def visualize_buckets( ) ax.grid(True, linestyle="--", alpha=0.3) - # Reserve fixed space on the right for the legend + # Leave fixed space for the legend on the right to avoid squashing the main plot area fig.subplots_adjust(right=0.8, top=0.9) if save_path is not None: diff --git a/magi_attention/utils/_utils.py b/magi_attention/utils/_utils.py index b78723af3..c1400ed0b 100644 --- a/magi_attention/utils/_utils.py +++ b/magi_attention/utils/_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import bisect import hashlib +import math import os import random from contextlib import contextmanager @@ -332,6 +334,49 @@ def pad_and_pack_tensors( return packed_tensor +def get_factors(x: int) -> list[int]: + """Get the factors of a given integer x + in ascending sorted order. + """ + if x < 1: + return [] + + small_factors = [] + large_factors = [] + + limit = int(math.sqrt(x)) + for i in range(1, limit + 1): + if x % i == 0: + small_factors.append(i) + if i * i != x: + large_factors.append(x // i) + + return small_factors + large_factors[::-1] + + +def find_factors_in_range( + factors: list[int], min_val: int, max_val: int, multiple_of: int = 1 +) -> list[int]: + """Find all factors within the specified range [min_val, max_val], + optionally filtering by a ``multiple_of`` condition. + """ + if len(factors) == 0: + return [] + + # Find the first position where factors[i] >= min_val (left boundary) + left_idx = bisect.bisect_left(factors, min_val) + + # Find the last position where factors[i] <= max_val (right boundary) + # bisect_right returns the insertion position, so slicing up to this position includes max_val + right_idx = bisect.bisect_right(factors, max_val) + + # Return the slice result within the range + if multiple_of == 1: + return factors[left_idx:right_idx] + + return [f for f in factors[left_idx:right_idx] if f % multiple_of == 0] + + def transpose_matrix(matrix: list[list[Any]]) -> list[list[Any]]: """ Transposes a 2D list (matrix) where each cell contains custom objects. @@ -850,6 +895,18 @@ def make_attn_mask_from_ffa_args( return mask +def fp_dtype_bits( + dtype: torch.dtype, +) -> int: + if dtype == torch.float4_e2m1fn_x2: + # NOTE: _x2 suffix for packed representation + # of two float4_e2m1f values into one byte + # see issue: https://github.com/pytorch/pytorch/issues/146414 + return 4 + + return torch.finfo(dtype).bits + + def is_fp_dtype_at_least( tensor: torch.Tensor, lowest_precision: torch.dtype, diff --git a/scripts/install_nvshmem_roce.sh b/scripts/install_nvshmem_roce.sh new file mode 100644 index 000000000..7512bf9ff --- /dev/null +++ b/scripts/install_nvshmem_roce.sh @@ -0,0 +1,136 @@ +#!/bin/bash + +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stop the script immediately if any command fails +set -e + +# --- Configuration Variables --- + +# We will install everything directly into /opt/nvshmem to keep paths clean +NVSHMEM_DIR="/opt/nvshmem" +NVSHMEM_VERSION="3.4.5-0" +# Temporary directory for source code +SOURCE_BASE="/tmp/nvshmem_build" +NVSHMEM_SRC_DIR="${SOURCE_BASE}/nvshmem-${NVSHMEM_VERSION}" +INSTALL_DIR="${NVSHMEM_DIR}" +CUDA_ARCHITECTURES="80-real;89-real;90-real;100-real;120" + +# --- Handle Local File Argument --- + +LOCAL_TARBALL=$1 + +echo "=== Starting NVSHMEM v${NVSHMEM_VERSION} Custom Build ===" + +# 1. Setup Directories +mkdir -p "${NVSHMEM_DIR}" +mkdir -p "${SOURCE_BASE}" + +# 2. Source Code Acquisition +if [ -n "$LOCAL_TARBALL" ]; then + if [ -f "$LOCAL_TARBALL" ]; then + echo "--- Using local tarball: $LOCAL_TARBALL ---" + ABS_TARBALL=$(realpath "$LOCAL_TARBALL") + cp "$ABS_TARBALL" "${SOURCE_BASE}/v${NVSHMEM_VERSION}.tar.gz" + else + echo "Error: Local file $LOCAL_TARBALL not found!" + exit 1 + fi +else + echo "--- No local file provided. Attempting to download... ---" + cd "${SOURCE_BASE}" + wget -O "v${NVSHMEM_VERSION}.tar.gz" "https://github.com/NVIDIA/nvshmem/archive/refs/tags/v${NVSHMEM_VERSION}.tar.gz" +fi + +# 3. Extraction +echo "--- Extracting Source ---" + +cd "${SOURCE_BASE}" +tar -zxf "v${NVSHMEM_VERSION}.tar.gz" +cd "${NVSHMEM_SRC_DIR}" + +# 4. Environment Fixes (Symlink for libmlx5) +echo "--- Applying System Library Fixes ---" + +MLX5_SO_PATH="/usr/lib/x86_64-linux-gnu/libmlx5.so" +MLX5_SO_REAL="/usr/lib/x86_64-linux-gnu/libmlx5.so.1" + +if [ ! -f "$MLX5_SO_PATH" ]; then + echo "Creating symlink for libmlx5.so..." + if command -v sudo >/dev/null 2>&1; then + sudo ln -s "$MLX5_SO_REAL" "$MLX5_SO_PATH" || true + else + ln -s "$MLX5_SO_REAL" "$MLX5_SO_PATH" || true + fi +fi + +# 5. Patching CMake and Source Code +echo "--- Applying Source Code Patches ---" + +find . -name "CMakeLists.txt" -print0 | xargs -0 sed -i 's/\$/\${MPI_CXX_INCLUDE_DIRS}/g' +sed -i "s/nvshmemi_call_rdxn_on_stream_kernel/nvshmemi_reduce_on_stream/g" src/host/team/team_internal.cpp + +# 6. Configuration via CMake +echo "--- Configuring NVSHMEM with IBGDA Support ---" + +NVSHMEM_SHMEM_SUPPORT=0 \ +NVSHMEM_UCX_SUPPORT=0 \ +NVSHMEM_USE_NCCL=1 \ +NVSHMEM_IBGDA_SUPPORT=1 \ +NVSHMEM_PMIX_SUPPORT=0 \ +NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ +NVSHMEM_USE_GDRCOPY=1 \ +MPI_HOME=/opt/hpcx/ompi \ +cmake -S . -B build/ \ + -DCMAKE_INSTALL_PREFIX="${INSTALL_DIR}" \ + -DCMAKE_CUDA_ARCHITECTURES="${CUDA_ARCHITECTURES}" \ + -DMLX5_lib="$MLX5_SO_PATH" + +# 7. Build and Install +echo "--- Compiling and Installing ---" + +cd build +make -j 16 +make install + +# 8. Inject Environment Variables into .bashrc +echo "--- Updating ~/.bashrc ---" + +# Define the lines to be injected +# We use a marker to easily identify and avoid duplicate injections +BASHRC_FILE="$HOME/.bashrc" +MARKER="# >>> NVSHMEM CUSTOM BUILD SETTINGS <<<" + +if ! grep -q "$MARKER" "$BASHRC_FILE"; then + echo "Injecting NVSHMEM variables into $BASHRC_FILE..." + cat >> "$BASHRC_FILE" << EOF + +$MARKER +export NVSHMEM_DIR=${INSTALL_DIR} +export NVSHMEM_HOME=${INSTALL_DIR} +export LD_LIBRARY_PATH="\${NVSHMEM_DIR}/lib:\$LD_LIBRARY_PATH" +export PATH="\${NVSHMEM_DIR}/bin:\$PATH" +# >>> END OF NVSHMEM SETTINGS <<< +EOF +else + echo "NVSHMEM settings already exist in $BASHRC_FILE. Skipping injection." +fi + +echo "=== NVSHMEM Successfully Installed to: ${INSTALL_DIR} ===" +echo "IMPORTANT: To apply changes to your current shell, run:" +echo " source ~/.bashrc" + +# Optional: Clean up build files +# rm -rf "${SOURCE_BASE}" diff --git a/setup.py b/setup.py index 60c9c44f6..dd612fd62 100644 --- a/setup.py +++ b/setup.py @@ -140,7 +140,9 @@ def get_cuda_bare_metal_version(cuda_dir) -> tuple[str, Version]: return raw_output, bare_metal_version -def get_device_compute_capability(with_minor: bool = True, with_a: bool = False) -> str: +def get_device_compute_capability( + with_minor: bool = True, with_a: bool = False, default_cap: str | None = None +) -> str: """Get the compute capability of the current CUDA device. Example: '80', '90', '100', etc. @@ -149,6 +151,8 @@ def get_device_compute_capability(with_minor: bool = True, with_a: bool = False) Defaults to ``True``. with_a (bool): Whether to append 'a' suffix to the capability. Defaults to ``False``. + default_cap (str | None): The default capability to return if CUDA is not available. + Defaults to ``None`` to raise an error if CUDA is not available. Returns: str: The compute capability of the current CUDA device. @@ -163,7 +167,10 @@ def get_device_compute_capability(with_minor: bool = True, with_a: bool = False) if with_a: # include suffix 'a' like 90a, 100a capability += "a" else: - raise RuntimeError("CUDA device is not available to get compute capability") + if default_cap is not None: + capability = default_cap + else: + raise RuntimeError("CUDA device is not available to get compute capability") return capability @@ -330,7 +337,11 @@ def build_magi_attn_comm_module( # NOTE: we've found the compilation fails with `sm103` # thus we only use the major version with minor as `0`, # i.e. only `sm80`, `sm90`, `sm100`, etc. - capability = get_device_compute_capability(with_minor=False, with_a=False) + capability = os.environ.get("MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY", "") + if capability == "": + capability = get_device_compute_capability( + with_minor=False, with_a=False, default_cap="90" + ) # --- for grpcoll submodule --- # @@ -381,8 +392,6 @@ def build_magi_attn_comm_module( # Generate instantiations inst_dir_abs = grpcoll_dir_abs / "instantiations" - if inst_dir_abs.exists(): - shutil.rmtree(inst_dir_abs) inst_dir_abs.mkdir(parents=True, exist_ok=True) gen_script = grpcoll_dir_abs / "generate_inst.py" @@ -452,6 +461,8 @@ def build_magi_attn_comm_module( "-gencode", # Explicitly specify for current device compute capability f"arch=compute_{capability},code=sm_{capability}", + # "-Xcompiler", # Uncomment for profiling compilation time + # "-ftime-report", # Uncomment for profiling compilation time ] # Initialize lists for linking configuration diff --git a/tests/test_api/test_interface.py b/tests/test_api/test_interface.py index f65a4bfb4..b09b79c0f 100644 --- a/tests/test_api/test_interface.py +++ b/tests/test_api/test_interface.py @@ -34,9 +34,8 @@ dispatch, dist_attn_runtime_dict, get_position_ids, - magi_attn_flex_dispatch, magi_attn_flex_key, - magi_attn_varlen_dispatch, + magi_attn_varlen_key, make_flex_key_for_new_mask_after_dispatch, make_varlen_key_for_new_mask_after_dispatch, undispatch, @@ -127,7 +126,7 @@ def timeout(self) -> int: @property def seed(self) -> int: - return 42 + return 42 + self.world_size @with_comms @parameterize( @@ -492,10 +491,12 @@ def test_interface( batch_size, attn_config["total_seqlen_q"] // batch_size ) - _, dist_attn_runtime_key = magi_attn_varlen_dispatch( - x, - cu_seqlens_q, - cu_seqlens_k, + dist_attn_runtime_key = magi_attn_varlen_key( + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.device_mesh @@ -514,10 +515,12 @@ def test_interface( cu_seqlens_q = attn_config["cu_seqlens_q"] cu_seqlens_k = attn_config["cu_seqlens_k"] - _, dist_attn_runtime_key = magi_attn_varlen_dispatch( - x, - cu_seqlens_q, - cu_seqlens_k, + dist_attn_runtime_key = magi_attn_varlen_key( + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.device_mesh @@ -528,8 +531,7 @@ def test_interface( ) case "magi_attn_flex": use_str_masktype: bool = attn_config["use_str_masktype"] - local_x_padded, dist_attn_runtime_key = magi_attn_flex_dispatch( - x, + dist_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=[masktype.value for masktype in attn_mask_type] @@ -537,6 +539,9 @@ def test_interface( else attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.device_mesh @@ -544,16 +549,19 @@ def test_interface( else self.nccl_group, dist_attn_config=dist_attn_config, ) + local_x_padded = dispatch(x, key=dist_attn_runtime_key) case "set_mesh_and_group": if magi_attention.comm.is_hierarchical_comm_enable(): with pytest.raises(AssertionError): - _, dist_attn_runtime_key = magi_attn_flex_dispatch( - x, + dist_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.nccl_group, @@ -561,13 +569,15 @@ def test_interface( ) else: with pytest.raises(ValueError): - _, dist_attn_runtime_key = magi_attn_flex_dispatch( - x, + dist_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.device_mesh, @@ -577,13 +587,15 @@ def test_interface( case "test_for_invalid_mask": invalid_mask_type = attn_config["attn_mask_type"] with pytest.raises(ValueError): - _, dist_attn_runtime_key = magi_attn_flex_dispatch( - x, + dist_attn_runtime_key = magi_attn_flex_key( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=invalid_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.device_mesh @@ -618,13 +630,13 @@ def test_interface( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q + pad_size, total_seqlen_k=total_seqlen_k + pad_size, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, - dist_attn_config=dist_attn_config, cp_mesh=self.device_mesh, + dist_attn_config=dist_attn_config, ) # ------- check mgr equality to ref -------- # @@ -725,13 +737,13 @@ def test_interface( attn_mask_type=new_attn_mask_type, total_seqlen_q=total_seqlen_q + pad_size, total_seqlen_k=total_seqlen_k + pad_size, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, - dist_attn_config=new_dist_attn_config, cp_mesh=self.device_mesh, + dist_attn_config=new_dist_attn_config, ref_dispatch_meta_q=dist_attn_runtime_mgr.dispatch_meta_q, ref_dispatch_meta_k=dist_attn_runtime_mgr.dispatch_meta_k, ) @@ -886,6 +898,9 @@ def test_compiled_magiattn(self): attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, pad_size=pad_size, chunk_size=chunk_size, cp_group_or_mesh=self.nccl_group, # assuming we only have 1-dim context parallelism (cp) diff --git a/tests/test_attn/test_dist_attn.py b/tests/test_attn/test_dist_attn.py index f9f91f0de..cb7028972 100644 --- a/tests/test_attn/test_dist_attn.py +++ b/tests/test_attn/test_dist_attn.py @@ -220,6 +220,9 @@ def test_full_attn( # TODO: support qo comm meta calculation num_remote_qo_tokens_per_stage=[0], qo_group_collective_args_list=[None], # type: ignore[list-item] + num_heads_q=nhq, + num_heads_kv=nhk, + head_dim=head_dim, ) dist_attn_runtime = DistAttnRuntime( comm_meta=comm_meta, diff --git a/tests/test_attn_solver/test_dist_attn_solver.py b/tests/test_attn_solver/test_dist_attn_solver.py index 263fc429e..f95cbc4c6 100644 --- a/tests/test_attn_solver/test_dist_attn_solver.py +++ b/tests/test_attn_solver/test_dist_attn_solver.py @@ -511,6 +511,7 @@ def test_init_host_remote_ranges_global(self, testcase): test_solver_class = SimpleNamespace() test_solver_class.cp_rank = rank test_solver_class.cp_size = cp_size + test_solver_class.split_alignment_kv = 1 _make_bucket_this_rank = types.MethodType( DistAttnSolver._make_bucket_this_rank, test_solver_class ) @@ -1366,6 +1367,7 @@ def test_init_host_rank_entry(self, testcase): test_solver_class = SimpleNamespace() test_solver_class.cp_rank = rank test_solver_class.cp_size = cp_size + test_solver_class.split_alignment_kv = 1 _make_bucket_this_rank = types.MethodType( DistAttnSolver._make_bucket_this_rank, test_solver_class ) @@ -2461,6 +2463,7 @@ def test_init_remote_rank_entry_per_stage_this_rank(self, testcase): ) test_solver_class.cp_rank = rank test_solver_class.cp_size = cp_size + test_solver_class.split_alignment_kv = 1 # -------------- compute meta -------------- # @@ -2759,6 +2762,7 @@ def test_init_remote_rank_entry_this_rank_with_numpy( ) test_solver_class.cp_rank = rank test_solver_class.cp_size = cp_size + test_solver_class.split_alignment_kv = 1 _make_bucket_this_rank = types.MethodType( DistAttnSolver._make_bucket_this_rank, test_solver_class ) diff --git a/tests/test_attn_solver/test_dynamic_attn_solver.py b/tests/test_attn_solver/test_dynamic_attn_solver.py index a85a2aed2..56ea156e6 100644 --- a/tests/test_attn_solver/test_dynamic_attn_solver.py +++ b/tests/test_attn_solver/test_dynamic_attn_solver.py @@ -40,7 +40,7 @@ SEED = 42 # MaskIterator settings -TOTAL_SEQLEN = 100 +TOTAL_SEQLEN = 128 NUM_ITERATIONS = 1 @@ -116,6 +116,7 @@ def test_dynamic_attn_solver_solve_function( rank = self.rank cp_size = self.world_size manual_seed = self.seed + head_dim = 128 torch.manual_seed(manual_seed) # -------------- init parameters -------------- # @@ -186,13 +187,14 @@ def create_ranges_per_rank( solver = DynamicAttnSolver( algorithm=algorithm, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, cp_group=self.process_group, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, host_ranges_q=host_q_ranges_global_this_rank, host_ranges_k=host_k_ranges_global_this_rank, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, cp_rank=rank, cp_size=cp_size, calc_local_range=True, @@ -226,8 +228,8 @@ def create_ranges_per_rank( ) # all q k attn position should be add 1 and equal2 in solver result - for i in range(cp_size): - rects = solver.bucket_per_rank[i] + for r in range(cp_size): + rects = solver.bucket_per_rank[r] for rect in rects: for i in range(rect.q_range.start, rect.q_range.end): for j in range(rect.k_range.start, rect.k_range.end): diff --git a/tests/test_common/test_attn_ranges.py b/tests/test_common/test_attn_ranges.py index 458683de3..4f2f0224a 100644 --- a/tests/test_common/test_attn_ranges.py +++ b/tests/test_common/test_attn_ranges.py @@ -555,6 +555,30 @@ def test_merge(self): self.assertEqual(merged_ranges, AttnRanges.from_ranges([(0, 10)])) self.assertTrue(merged_ranges.is_merged()) + def test_merge_with_split_alignment(self): + # case1: basic alignment + ranges = AttnRanges.from_ranges([(1, 9), (11, 19)]) + # alignment=10 -> [(0, 10), (10, 20)] -> merged to [(0, 20)] + merged = ranges.merge_with_split_alignment(split_alignment=10) + self.assertEqual(merged, AttnRanges.from_ranges([(0, 20)])) + + # case2: disjoint after alignment + ranges = AttnRanges.from_ranges([(1, 9), (21, 29)]) + # alignment=10 -> [(0, 10), (20, 30)] + merged = ranges.merge_with_split_alignment(split_alignment=10) + self.assertEqual(merged, AttnRanges.from_ranges([(0, 10), (20, 30)])) + + # case3: overlap after alignment + ranges = AttnRanges.from_ranges([(5, 15), (12, 25)]) + # alignment=10 -> [(0, 20), (10, 30)] -> merged to [(0, 30)] + merged = ranges.merge_with_split_alignment(split_alignment=10) + self.assertEqual(merged, AttnRanges.from_ranges([(0, 30)])) + + # case4: already aligned + ranges = AttnRanges.from_ranges([(0, 10), (10, 20), (25, 30)]) + merged = ranges.merge_with_split_alignment(split_alignment=5) + self.assertEqual(merged, AttnRanges.from_ranges([(0, 20), (25, 30)])) + def test_non_overlap(self): attn_ranges = AttnRanges.from_ranges([(8, 14), (5, 10)]) self.assertFalse(attn_ranges.is_non_overlap()) diff --git a/tests/test_dist_runtime_mgr/test_dist_runtime_mgr.py b/tests/test_dist_runtime_mgr/test_dist_runtime_mgr.py index 1b6937215..f0f1f74a1 100644 --- a/tests/test_dist_runtime_mgr/test_dist_runtime_mgr.py +++ b/tests/test_dist_runtime_mgr/test_dist_runtime_mgr.py @@ -420,19 +420,26 @@ def test_update_xattn_k_ranges( total_seqlen_xattn_k: int = test_config["total_seqlen_xattn_k"] chunk_size: int = test_config["chunk_size"] + num_heads_q = 1 + num_heads_kv = 1 + head_dim = 128 + dist_attn_runtime_mgr = init_dist_attn_runtime_mgr( q_ranges=q_ranges, k_ranges=k_ranges, attn_mask_type=[AttnMaskType.FULL] * len(q_ranges), total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, cp_mesh=self.device_mesh, + dist_attn_config=DistAttnConfig(), is_same_source=True, is_q_permutable=True, is_k_permutable=True, - dist_attn_config=DistAttnConfig(), ) host_xattn_attn_arg: AttnArg = dist_attn_runtime_mgr.get_xattn_args( @@ -444,22 +451,22 @@ def test_update_xattn_k_ranges( total_q = torch.randn( total_seqlen_q, - 1, - 128, + num_heads_q, + head_dim, device=torch.cuda.current_device(), dtype=torch.float16, ) xattn_k = torch.randn( total_seqlen_xattn_k, - 1, - 128, + num_heads_kv, + head_dim, device=torch.cuda.current_device(), dtype=torch.float16, ) xattn_v = torch.randn( total_seqlen_xattn_k, - 1, - 128, + num_heads_kv, + head_dim, device=torch.cuda.current_device(), dtype=torch.float16, ) @@ -700,6 +707,10 @@ def test_ref_dispatch_meta( total_seqlen_k: int = test_config["total_seqlen_k"] chunk_size: int = test_config["chunk_size"] + num_heads_q = 16 + num_heads_kv = 4 + head_dim = 128 + # use dispatch mask to init dist attn runtime mgr dispatch_dist_attn_runtime_mgr = init_dist_attn_runtime_mgr( q_ranges=dispatch_q_ranges, @@ -707,6 +718,9 @@ def test_ref_dispatch_meta( attn_mask_type=dispatch_attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, cp_mesh=self.device_mesh, @@ -728,6 +742,9 @@ def test_ref_dispatch_meta( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, cp_mesh=self.device_mesh, @@ -760,6 +777,9 @@ def test_ref_dispatch_meta( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, dist_attn_runtime_mgr=dist_attn_runtime_mgr, dtype=torch.float64, test_case=test_config["test_case"], @@ -772,10 +792,10 @@ def _calc_attn_with_mgr_and_assert_close_to_ref( attn_mask_type: list[AttnMaskType], total_seqlen_q: int, total_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, dist_attn_runtime_mgr: DistAttnRuntimeMgr, - num_heads_q: int = 16, - num_heads_kv: int = 4, - head_dim: int = 128, dtype: torch.dtype = torch.float64, run_bwd: bool = True, test_case: str = "", diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8394d5cda..bb9f7fd83 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -24,7 +24,6 @@ import magi_attention from magi_attention import init_dist_attn_runtime_mgr -from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_buffer_mgr from magi_attention.common.enum import AttnMaskType, AttnOverlapMode, AttnSinkLayout from magi_attention.common.ranges import AttnRanges @@ -147,6 +146,7 @@ def init_pg(self) -> None: "enable_native_grpcoll": "MAGI_ATTENTION_NATIVE_GRPCOLL", "fwd_hp_reduce": "MAGI_ATTENTION_FORWARD_HIGH_PRECISION_REDUCE", "bwd_hp_reduce": "MAGI_ATTENTION_BACKWARD_HIGH_PRECISION_REDUCE", + "bwd_hide_tail_reduce": "MAGI_ATTENTION_BWD_HIDE_TAIL_REDUCE", } # init flag generator and its iterator @@ -184,6 +184,7 @@ def init_pg(self) -> None: # TODO: support qo comm for fa4 backend else [False] ), + "bwd_hide_tail_reduce": [True, False], }, defaults={ "device_max_connections": 8, @@ -214,7 +215,7 @@ def world_size(self) -> int: @property def seed(self) -> int: - return 42 + return 42 + self.world_size @with_comms @parameterize( @@ -637,6 +638,10 @@ def test_pipeline( if magi_attention.comm.is_hierarchical_comm_enable(): return + # TODO: support hiding backward tail reduce for qo comm + if magi_attention.dist_attn_backward_hide_tail_reduce(): + return + # ----- skip for native grpcoll ---- # if magi_attention.comm.is_native_grpcoll_enable(): @@ -649,10 +654,6 @@ def test_pipeline( if magi_attention.is_deterministic_mode_enable(): return - hidden_size_kv = num_heads[1] * head_dim - if hidden_size_kv % GrpCollBuffer.get_hidden_size_alignment(dtype) != 0: - return - # ----- skip for flatten head groups ---- # if magi_attention.is_flatten_head_groups_enable(): @@ -776,15 +777,13 @@ def test_pipeline( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, - dist_attn_config=dist_attn_config, cp_mesh=self.device_mesh, - num_heads_q=num_heads_q, - num_heads_kv=num_heads_kv, + dist_attn_config=dist_attn_config, ) # HACK: seperate cp group for group-reduce dist_attn_runtime_mgr.dist_attn_runtime.cp_group_gr = self.nccl_groups[1] diff --git a/tests/test_pipeline_sdpa.py b/tests/test_pipeline_sdpa.py index 775f6d0ed..dc5a581c3 100644 --- a/tests/test_pipeline_sdpa.py +++ b/tests/test_pipeline_sdpa.py @@ -23,7 +23,6 @@ import magi_attention from magi_attention import init_dist_attn_runtime_mgr -from magi_attention.comm.primitive.grpcoll._buffer import GrpCollBuffer from magi_attention.comm.primitive.grpcoll._mgr import grpcoll_buffer_mgr from magi_attention.common.enum import AttnMaskType, AttnOverlapMode, AttnSinkLayout from magi_attention.common.ranges import AttnRanges @@ -123,6 +122,7 @@ def init_pg(self) -> None: "fwd_hp_reduce": "MAGI_ATTENTION_FORWARD_HIGH_PRECISION_REDUCE", "bwd_hp_reduce": "MAGI_ATTENTION_BACKWARD_HIGH_PRECISION_REDUCE", "flatten_head_groups": "MAGI_ATTENTION_FLATTEN_HEAD_GROUPS", + "bwd_hide_tail_reduce": "MAGI_ATTENTION_BWD_HIDE_TAIL_REDUCE", } # init flag generator and its iterator @@ -136,6 +136,7 @@ def init_pg(self) -> None: # disable native grpcoll if not registered successfully else [False] ), + "bwd_hide_tail_reduce": [True, False], }, defaults={ "device_max_connections": 8, @@ -164,6 +165,10 @@ def nccl_group(self) -> dist.ProcessGroup: def world_size(self) -> int: return 1 + @property + def seed(self) -> int: + return 42 + self.world_size + @property def dtype(self) -> torch.dtype: return torch.float64 @@ -786,6 +791,10 @@ def test_pipeline_sdpa( if magi_attention.comm.is_hierarchical_comm_enable(): return + # TODO: support hiding backward tail reduce for qo comm + if magi_attention.dist_attn_backward_hide_tail_reduce(): + return + # ----- skip for native grpcoll ---- # if magi_attention.comm.is_native_grpcoll_enable(): @@ -798,13 +807,6 @@ def test_pipeline_sdpa( if magi_attention.comm.is_qo_comm_enable(): return - hidden_size_kv = num_heads[1] * head_dim - if ( - hidden_size_kv % GrpCollBuffer.get_hidden_size_alignment(self.dtype) - != 0 - ): - return - # ----- skip for flatten head groups ---- # if magi_attention.is_flatten_head_groups_enable(): @@ -900,13 +902,13 @@ def test_pipeline_sdpa( attn_mask_type=attn_mask_type, total_seqlen_q=total_seqlen_q, total_seqlen_k=total_seqlen_k, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + head_dim=head_dim, chunk_size=chunk_size, cp_group=self.nccl_group, - is_same_source=True, - is_q_permutable=True, - is_k_permutable=True, - dist_attn_config=dist_attn_config, cp_mesh=self.device_mesh, + dist_attn_config=dist_attn_config, ) # HACK: seperate cp group for group-reduce dist_attn_runtime_mgr.dist_attn_runtime.cp_group_gr = self.nccl_groups[1] diff --git a/tests/test_utils/test_common_utils.py b/tests/test_utils/test_common_utils.py index fbd60d93a..410b5f6e7 100644 --- a/tests/test_utils/test_common_utils.py +++ b/tests/test_utils/test_common_utils.py @@ -304,7 +304,7 @@ class CustomObject: def __init__(self, value): self.value = value - def __repr__(self): + def __repr__(self) -> str: return f"CustomObject({self.value})" def __eq__(self, other) -> bool: