@@ -79,6 +79,7 @@ def magi_attn_varlen_key(
7979 which represents ``[window_size_left, window_size_right]``. The parameter is effective only
8080 when ``causal`` is ``False``; when ``causal`` is ``True``, it is required to be ``(-1, -1)``.
8181 Defaults to be ``(-1, -1)``.
82+
8283 dist_attn_config (DistAttnConfig): dist attn config.
8384
8485 Returns:
@@ -208,6 +209,7 @@ def magi_attn_varlen_dispatch(
208209 which represents ``[window_size_left, window_size_right]``. The parameter is effective only
209210 when ``causal`` is ``False``; when ``causal`` is ``True``, it is required to be ``(-1, -1)``.
210211 Defaults to be ``(-1, -1)``.
212+
211213 dist_attn_config (DistAttnConfig): dist attn config.
212214
213215 Returns:
@@ -306,7 +308,6 @@ def magi_attn_flex_key(
306308 calculate DistAttnRuntimeKey and generate the corr. inner DistAttnRuntimeMgr.
307309
308310 Args:
309- x (torch.Tensor): input tensor
310311 q_ranges (AttnRanges): the global query ranges
311312 k_ranges (AttnRanges): the global key ranges
312313 attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]):
@@ -518,6 +519,7 @@ def magi_attn_flex_dispatch(
518519
519520 Args:
520521 x (torch.Tensor): input tensor
522+
521523 q_ranges (AttnRanges): the global query ranges
522524 k_ranges (AttnRanges): the global key ranges
523525 attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]):
@@ -817,12 +819,14 @@ def make_varlen_key_for_new_mask_after_dispatch(
817819 Args:
818820 cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries.
819821 cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys.
822+
820823 key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch
821824 causal (bool, optional): whether the varlen attention mask is causal. Defaults to ``False``.
822825 window_size (tuple[int, int], optional): window_size of sliding window mask
823826 which represents ``[window_size_left, window_size_right]``. The parameter is effective only
824827 when ``causal`` is ``False``; when ``causal`` is ``True``, it is required to be ``(-1, -1)``.
825828 Defaults to be ``(-1, -1)``.
829+
826830 dist_attn_config (DistAttnConfig, optional): the optional new dist attn config,
827831
828832 NOTE: if not provided, we will use the same config as the ``key_for_dispatch``,
@@ -959,7 +963,9 @@ def make_flex_key_for_new_mask_after_dispatch(
959963 attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]):
960964 the global attn mask type (list)
961965 represented by str or enum ``AttnMaskType`` or their mixed combination
966+
962967 key_for_dispatch (DistAttnRuntimeKey): the key used for dispatch
968+
963969 dist_attn_config (DistAttnConfig, optional): the optional new dist attn config,
964970
965971 NOTE: if not provided, we will use the same config as the ``key_for_dispatch``,
0 commit comments