From 3ff15402c18f3e9268698c4006d687c6c2a9869c Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 4 Mar 2024 14:10:43 +0300 Subject: [PATCH 1/4] add fla from rebased --- flash_linear_attention/__init__.py | 0 flash_linear_attention/fla/__init__.py | 8 + flash_linear_attention/fla/layers/__init__.py | 7 + flash_linear_attention/fla/layers/based.py | 214 ++++++ flash_linear_attention/fla/layers/gla.py | 154 ++++ .../fla/layers/multiscale_retention.py | 100 +++ flash_linear_attention/fla/layers/rebased.py | 258 +++++++ .../fla/layers/rebased_fast.py | 229 ++++++ .../fla/modules/__init__.py | 11 + .../fla/modules/convolution.py | 195 +++++ flash_linear_attention/fla/modules/rmsnorm.py | 647 ++++++++++++++++ flash_linear_attention/fla/modules/rotary.py | 312 ++++++++ flash_linear_attention/fla/ops/__init__.py | 25 + .../semiring/cal_A/inner_chunk16_dim16x.cpp | 11 + .../cal_A/inner_chunk16_dim16x_kernel.cu | 204 ++++++ .../fla/ops/torch/__init__.py | 7 + flash_linear_attention/fla/ops/torch/based.py | 131 ++++ flash_linear_attention/fla/ops/torch/gla.py | 119 +++ .../fla/ops/torch/retention.py | 15 + .../fla/ops/triton/__init__.py | 22 + .../fla/ops/triton/abc/__init__.py | 0 .../fla/ops/triton/abc/chunk_fuse.py | 692 ++++++++++++++++++ .../fla/ops/triton/based/__init__.py | 6 + .../fla/ops/triton/based/chunk_fuse.py | 410 +++++++++++ .../fla/ops/triton/based/parallel.py | 385 ++++++++++ .../fla/ops/triton/gla/__init__.py | 7 + .../ops/triton/gla/block_parallel/__init__.py | 0 .../inter_chunk_contribution/__init__.py | 0 .../chunk_scan_triton_full.py | 212 ++++++ .../chunk_scan_triton_no_decay.py | 166 +++++ .../chunk_scan_triton_only_gk.py | 187 +++++ .../chunk_scan_triton_only_gv.py | 199 +++++ .../inter_chunk_contribution/fn.py | 49 ++ .../preprocess_cumsum_gk.py | 259 +++++++ .../preprocess_cumsum_gv.py | 216 ++++++ .../intra_chunk_contribution/__init__.py | 0 .../intra_chunk_contribution/fn.py | 28 + .../intra_chunk_contribution/fn_only_gk.py | 343 +++++++++ 
.../intra_chunk_contribution/fn_only_gv.py | 336 +++++++++ .../fla/ops/triton/gla/chunk.py | 39 + .../fla/ops/triton/gla/chunk_fuse.py | 400 ++++++++++ .../fla/ops/triton/gla/recurrent_fuse.py | 403 ++++++++++ .../fla/ops/triton/rebased/__init__.py | 4 + .../fla/ops/triton/rebased/parallel.py | 388 ++++++++++ .../fla/ops/triton/rebased_fast/__init__.py | 4 + .../fla/ops/triton/rebased_fast/parallel.py | 390 ++++++++++ .../fla/ops/triton/retention/__init__.py | 9 + .../fla/ops/triton/retention/chunk.py | 389 ++++++++++ .../fla/ops/triton/retention/chunk_fuse.py | 329 +++++++++ .../fla/ops/triton/retention/parallel.py | 341 +++++++++ .../ops/triton/retention/recurrent_fuse.py | 280 +++++++ .../fla/ops/triton/rotary.py | 252 +++++++ .../fla/ops/triton/utils.py | 27 + flash_linear_attention/setup.py | 147 ++++ 54 files changed, 9566 insertions(+) create mode 100644 flash_linear_attention/__init__.py create mode 100644 flash_linear_attention/fla/__init__.py create mode 100644 flash_linear_attention/fla/layers/__init__.py create mode 100644 flash_linear_attention/fla/layers/based.py create mode 100644 flash_linear_attention/fla/layers/gla.py create mode 100644 flash_linear_attention/fla/layers/multiscale_retention.py create mode 100644 flash_linear_attention/fla/layers/rebased.py create mode 100644 flash_linear_attention/fla/layers/rebased_fast.py create mode 100644 flash_linear_attention/fla/modules/__init__.py create mode 100644 flash_linear_attention/fla/modules/convolution.py create mode 100644 flash_linear_attention/fla/modules/rmsnorm.py create mode 100644 flash_linear_attention/fla/modules/rotary.py create mode 100644 flash_linear_attention/fla/ops/__init__.py create mode 100644 flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp create mode 100644 flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu create mode 100644 flash_linear_attention/fla/ops/torch/__init__.py create mode 100644 
flash_linear_attention/fla/ops/torch/based.py create mode 100644 flash_linear_attention/fla/ops/torch/gla.py create mode 100644 flash_linear_attention/fla/ops/torch/retention.py create mode 100644 flash_linear_attention/fla/ops/triton/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/abc/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/based/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/based/chunk_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/based/parallel.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py create 
mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/chunk.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/rebased/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/rebased/parallel.py create mode 100644 flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py create mode 100644 flash_linear_attention/fla/ops/triton/retention/__init__.py create mode 100644 flash_linear_attention/fla/ops/triton/retention/chunk.py create mode 100644 flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/retention/parallel.py create mode 100644 flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py create mode 100644 flash_linear_attention/fla/ops/triton/rotary.py create mode 100644 flash_linear_attention/fla/ops/triton/utils.py create mode 100644 flash_linear_attention/setup.py diff --git a/flash_linear_attention/__init__.py b/flash_linear_attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_linear_attention/fla/__init__.py b/flash_linear_attention/fla/__init__.py new file mode 100644 index 0000000..432f1ac --- /dev/null +++ b/flash_linear_attention/fla/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +from fla.ops.triton import (fused_chunk_based, fused_chunk_gla, + fused_chunk_retention) + +__all__ = ['fused_chunk_based', 'fused_chunk_gla', 'fused_chunk_retention'] + +__version__ = '0.0.1' diff --git a/flash_linear_attention/fla/layers/__init__.py 
b/flash_linear_attention/fla/layers/__init__.py new file mode 100644 index 0000000..5080186 --- /dev/null +++ b/flash_linear_attention/fla/layers/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .based import BasedLinearAttention +from .gla import GatedLinearAttention +from .multiscale_retention import MultiScaleRetention + +__all__ = ['GatedLinearAttention', 'MultiScaleRetention', 'BasedLinearAttention'] diff --git a/flash_linear_attention/fla/layers/based.py b/flash_linear_attention/fla/layers/based.py new file mode 100644 index 0000000..204b160 --- /dev/null +++ b/flash_linear_attention/fla/layers/based.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- + +""" +Linear attention in Based. +https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" +import math + +import opt_einsum as oe +import torch +import torch.nn as nn +from einops import rearrange + +from fla.ops.triton.based import fused_chunk_based, parallel_based + + +def init_feature_map(feature_map: str = 'none', **kwargs: any): + """ + Initialize query and key mapping for linear attention + """ + if feature_map in [None, 'none', 'identity']: + return FeatureMap(**kwargs) + # Taylor series approximations to exp(x) + elif feature_map == 'taylor_exp': + return TaylorExp(**kwargs) + else: + raise NotImplementedError( + f'Sorry "{feature_map}" feature map not implemented.') + + +class FeatureMap(nn.Module): + """ + Parent feature map; default is identity function + """ + + def __init__(self, + input_dim: int, + temp: int = None, + head_dim_idx: int = -1, + eps: float = 1e-12, + **kwargs: any): + super().__init__() + self.input_dim = input_dim + self.head_dim_idx = head_dim_idx + self.temp = 1. if temp is None else temp + self.eps = eps + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return x + + +class TaylorExp(FeatureMap): + """ + Feature map to compute 2nd-order Taylor approx. 
of exp(q^T k / sqrt(d)) + """ + + def __init__(self, input_dim: int, **kwargs: any): + super().__init__(input_dim, **kwargs) + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.input_dim) + self.rrd = math.sqrt(self.rd) + self.tril_indices = torch.tril_indices( + self.input_dim, self.input_dim, -1) + + # Running these in parallel + def forward(self, x: torch.Tensor): + # Get 2nd-order terms (rearrange(x * x), '... m n -> ... (m n)') + x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) + ).flatten(start_dim=-2) / self.r2 + return torch.cat([torch.ones(x[..., :1].shape).to(x.device), + x / self.rrd, x2 / self.rd], dim=self.head_dim_idx) + + def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute f(x) s.t. f(x)^T f(x') = 1 + x^Tx' + (x^Tx')^2 / 2 + -> Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first? + x2 = oe.contract('...m,...n->...mn', x, x) / self.rd + x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2 + x2 = x2[..., self.tril_indices[0], self.tril_indices[1]] + x = torch.cat([torch.ones(x[..., :1].shape).to(x.device), + x / self.rrd, x2d, x2], dim=-1) + return x + + +class BasedLinearAttention(nn.Module): + def __init__( + self, + d_model: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + self.d_model = d_model + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel"] + + # linear attention + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.d_model // self.num_key_value_heads + self.causal = causal + feature_map_kwargs = { + 'input_dim': self.feature_dim, + 'head_dim_idx': -1, + 'temp': 1., + 'eps': 1e-12 + } + 
self.feature_map = init_feature_map( + feature_map=self.feature_name, **feature_map_kwargs) + self.proj_q = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_k = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_v = nn.Linear( + self.d_model, self.num_key_value_heads * self.head_dim, bias=False) + self.proj_o = nn.Linear( + self.num_heads * self.head_dim, self.d_model, bias=False) + self.dropout = nn.Identity() + self.eps = eps + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + q, k, v = map(lambda x: rearrange( + x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) + if mode == "fused_chunk": + assert q.shape[-1] <= 16 + o = fused_chunk_based(q, k, v, self.eps, True, True) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_based(q, k, v, self.eps, True, True) + o = rearrange(o, "b h l d -> b l (h d)") + o = self.proj_o(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, l) + y (torch.Tensor): tensor of shape (b, d, l) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + + q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, l, self.num_key_value_heads, + self.feature_dim).transpose(1, 2) + v = v.view(b, l, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if 
self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h l d -> b l (h d)') + y = self.proj_o(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) + + +if __name__ == '__main__': + batch = 4 + seq_len = 1024 + d_model = 1024 + dtype = torch.float32 + x = torch.randn(batch, seq_len, d_model).to( + dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, d_model).to( + dtype).cuda() + model = BasedLinearAttention(d_model=d_model).to(dtype).cuda() + y = model(x) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + + proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None + proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None + proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None + + x.requires_grad_(True) + y2 = model.forward_reference(x) + y2.backward(dy) + print((y - y2).abs().max().item()) + # assert y.allclose(y2, 0, 1e-4), breakpoint() + print((x_grad - x.grad).abs().max().item()) + # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() + print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) + print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) + print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) + print("All good with based!") diff --git a/flash_linear_attention/fla/layers/gla.py b/flash_linear_attention/fla/layers/gla.py new file mode 100644 index 0000000..d0626bb --- /dev/null +++ b/flash_linear_attention/fla/layers/gla.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- + +# "Gated Linear Attention Transformers with Hardware-Efficient Training"[https://arxiv.org/abs/2312.06635] + +from __future__ import annotations + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from 
fla.modules.rmsnorm import RMSNorm +from fla.ops.triton.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + + +def get_activation_fn(activation): + if activation == 'swish': + return F.silu + elif activation == 'gelu': + return F.gelu + else: + raise NotImplementedError + + +class GatedLinearAttention(nn.Module): + + def __init__( + self, + d_model: int = 1024, + expand_v: int = 2, + expand_k: int = 1, + num_heads: int = 1, + gate_fn: str = 'swish', + layernorm_eps: float = 1e-5, + gate_logit_normalizer: int = 32, + gate_logit_multiplier: int = 1, + gate_low_rank_dim: int = 32, + mode: str = 'fused_chunk', + chunk_size: int = 64, + use_gk: bool = True, # gate associated with key, i.e., $\alpha$ in the paper + use_gv: bool = False, # gate associated with value, i.e., $\beta$ in the paper + *args, **kwargs + ) -> GatedLinearAttention: + super().__init__() + if use_gv is True: + assert mode in ['chunk', 'fused_recurrent'] + if mode == 'fused_chunk': + assert use_gk is True + if mode != 'chunk' and chunk_size != 16: + warnings.warn( + f" `chunk_size` is only used for `chunk` mode." + f" The `{mode}` mode will suppress the passed value of {chunk_size} and always use 16." + ) + self.use_gk = use_gk + self.use_gv = use_gv + self.d_model = d_model + self.mode = mode + self.chunk_size = chunk_size + self.value_dim = int(d_model * expand_v) + self.key_dim = int(d_model * expand_k) + assert mode in ['chunk', 'fused_recurrent', + 'fused_chunk'], f"Not suppoerted mode `{mode}`." 
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + self.num_heads = num_heads + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.gate_fn = get_activation_fn(activation=str(gate_fn)) + self.q_proj = nn.Linear(d_model, self.key_dim, bias=False) + self.k_proj = nn.Linear(d_model, self.key_dim, bias=False) + self.v_proj = nn.Linear(d_model, self.value_dim, bias=False) + self.g_proj = nn.Linear(d_model, self.value_dim, bias=False) + + if self.use_gk: + self.gk_proj = nn.Sequential(nn.Linear(d_model, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, self.key_dim, bias=True)) + else: + self.gk_proj = None + if self.use_gv: + self.gv_proj = nn.Sequential(nn.Linear(d_model, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, self.value_dim, + bias=True)) + else: + self.gv_proj = None + self.out_proj = nn.Linear(self.value_dim, d_model, bias=False) + self.group_norm = RMSNorm(self.head_v_dim, eps=layernorm_eps) + self.gate_logit_normalizer = gate_logit_normalizer + self.gate_logit_multiplier = gate_logit_multiplier + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -2.5) + if self.gk_proj is not None: + nn.init.xavier_uniform_(self.gk_proj[0].weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.gk_proj[1].weight, gain=2 ** -2.5) + if self.gv_proj is not None: + nn.init.xavier_uniform_(self.gv_proj[0].weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.gv_proj[1].weight, gain=2 ** -2.5) + + def forward(self, x): + mode = self.mode + 
chunk_size = self.chunk_size + + q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + + if mode == 'chunk' or mode == 'fused_recurrent': + # for numumerical stable consideration. fused_chunk has better numerical stability + if self.use_gk: + gk = self.gk_proj(x).to(torch.float32) + gk = (F.logsigmoid(gk) / self.gate_logit_normalizer).clamp_min_(-3) + gk = rearrange(gk, 'b n (h d) -> b h n d', h=self.num_heads) + else: + gk = None + if self.use_gv: + gv = self.gv_proj(x).to(torch.float32) + gv = (F.logsigmoid(gv) / self.gate_logit_normalizer).clamp_min_(-3) + gv = rearrange(gv, 'b n (h d) -> b h n d', h=self.num_heads) + else: + gv = None + if mode == 'fused_recurrent': + o = fused_recurrent_gla(q, k, v, gk=gk, gv=gv) + else: + o = chunk_gla(q, k, v, gk=gk, gv=gv, chunk_size=chunk_size) + else: + g = self.gk_proj(x).to(torch.float32) + g = F.logsigmoid(g * self.gate_logit_multiplier) / self.gate_logit_normalizer + g = rearrange(g, 'b n (h d) -> b h n d', h=self.num_heads) + o = fused_chunk_gla(q, k, v, g) + + o = self.group_norm(rearrange(o, 'b h n d -> b n h d')) + o = self.out_proj(rearrange(o, 'b n h d -> b n (h d)') + * self.gate_fn(self.g_proj(x))) + return o + + +if __name__ == '__main__': + batch = 4 + seq_len = 1023 + d_model = 1024 + x = torch.randn(batch, seq_len, d_model).to( + torch.bfloat16).cuda().requires_grad_(True) + model = GatedLinearAttention(use_gk=True, use_gv=True, mode='chunk').to(torch.bfloat16).cuda() + y = model(x) + print(y.shape) + y.sum().backward() + print(x.grad.shape) diff --git a/flash_linear_attention/fla/layers/multiscale_retention.py b/flash_linear_attention/fla/layers/multiscale_retention.py new file mode 100644 index 0000000..30650a4 --- /dev/null +++ b/flash_linear_attention/fla/layers/multiscale_retention.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +# 
Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] + +from __future__ import annotations + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from fla.modules.rmsnorm import RMSNorm +from fla.modules.rotary import RotaryEmbedding +from fla.ops.triton.retention import (fused_chunk_retention, + fused_recurrent_retention, + parallel_retention) + + +def get_activation_fn(activation): + if activation == 'swish': + return F.silu + elif activation == 'gelu': + return F.gelu + else: + raise NotImplementedError + + +class MultiScaleRetention(nn.Module): + def __init__( + self, + d_model: str = 1024, + expand_k: str = 1, + expand_v: str = 2, + num_heads: str = 4, + gate_fn: str = 'swish', + layernorm_eps: float = 1e-5, + mode: str = 'chunk', + *args, **kwargs + ) -> MultiScaleRetention: + super().__init__() + + self.d_model = d_model + self.mode = mode + self.value_dim = int(d_model * expand_v) + self.key_dim = int(d_model * expand_k) + self.num_heads = num_heads + assert mode in ['fused_chunk', 'chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.gate_fn = get_activation_fn(activation=str(gate_fn)) + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, self.value_dim, bias=False) + self.g_proj = nn.Linear(d_model, self.value_dim, bias=False) + self.out_proj = nn.Linear(self.value_dim, d_model, bias=False) + + self.group_norm = RMSNorm(self.head_v_dim, eps=layernorm_eps) + self.rotary = RotaryEmbedding(dim=self.head_qk_dim, interleaved=False) + self.reset_parameters() + + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -2.5) + + def forward(self, x): + mode = self.mode + q1 = rearrange(self.q_proj(x), '... (h d) -> ... h d', h=self.num_heads) + k1 = rearrange(self.k_proj(x), '... (h d) -> ... 
h d', h=self.num_heads) + q, k = self.rotary(q1, k1) + q, k = q.transpose(1, 2), k.transpose(1, 2) + v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + if mode == 'fused_chunk': + o = fused_chunk_retention(q, k, v) + elif mode == 'parallel': + o = parallel_retention(q, k, v) + elif mode == 'fused_recurrent': + o = fused_recurrent_retention(q, k, v) + # TODO: need fix to allow different d_head_qk and d_head_v for "chunk" form + else: + raise NotImplementedError + o = self.group_norm(rearrange(o, 'b h n d -> b n h d')) + return self.out_proj(rearrange(o, 'b n h d -> b n (h d)') * self.gate_fn(self.g_proj(x))) + + +if __name__ == '__main__': + import torch + batch = 4 + seq_len = 1024 + d_model = 1024 + x = torch.randn(batch, seq_len, d_model).to(torch.bfloat16).cuda().requires_grad_(True) + model = MultiScaleRetention().to(torch.bfloat16).cuda() + y = model(x) + print(y.shape) + y.sum().backward() + print(x.grad.shape) + print(x.grad.shape) diff --git a/flash_linear_attention/fla/layers/rebased.py b/flash_linear_attention/fla/layers/rebased.py new file mode 100644 index 0000000..1ebf3c6 --- /dev/null +++ b/flash_linear_attention/fla/layers/rebased.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- + +""" +Linear attention in Based. 
+https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" +import math + +import opt_einsum as oe +import torch +import torch.nn as nn +from einops import rearrange + +from fla.ops.triton.rebased_fast import parallel_rebased + + +def init_feature_map(feature_map: str = 'none', **kwargs: any): + """ + Initialize query and key mapping for linear attention + """ + if feature_map in [None, 'none', 'identity']: + return FeatureMap(**kwargs) + # Taylor series approximations to exp(x) + elif feature_map == 'taylor_exp': + return TaylorExp(**kwargs) + else: + raise NotImplementedError( + f'Sorry "{feature_map}" feature map not implemented.') + + +class FeatureMap(nn.Module): + """ + Parent feature map; default is identity function + """ + + def __init__(self, + input_dim: int, + temp: int = None, + head_dim_idx: int = -1, + eps: float = 1e-12, + **kwargs: any): + super().__init__() + self.input_dim = input_dim + self.head_dim_idx = head_dim_idx + self.temp = 1. if temp is None else temp + self.eps = eps + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return x + + +class TaylorExp(FeatureMap): + """ + Feature map to compute 2nd-order Taylor approx. of exp(q^T k / sqrt(d)) + """ + + def __init__(self, input_dim: int, **kwargs: any): + super().__init__(input_dim, **kwargs) + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.input_dim) + self.rrd = math.sqrt(self.rd) + self.tril_indices = torch.tril_indices( + self.input_dim, self.input_dim, -1) + + # Running these in parallel + def forward(self, x: torch.Tensor): + # Get 2nd-order terms (rearrange(x * x), '... m n -> ... 
(m n)') + x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) + ).flatten(start_dim=-2) / self.r2 + return torch.cat( + [ + (torch.ones(x[..., :1].shape).to(x.device) / self.r2), + # x / self.rrd, rebased_fast + x2 / self.rd + ], + dim=self.head_dim_idx + ) + + def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute f(x) s.t. f(x)^T f(x') = 1 + x^Tx' + (x^Tx')^2 / 2 + -> Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first? + x2 = oe.contract('...m,...n->...mn', x, x) / self.rd + x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2 + x2 = x2[..., self.tril_indices[0], self.tril_indices[1]] + x = torch.cat( + [ + (torch.ones(x[..., :1].shape).to(x.device) / self.r2), + # x / self.rrd, + x2d, + x2 + ], + dim=-1 + ) + return x + + +class ReBasedLinearAttention(nn.Module): + def __init__( + self, + d_model: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + self.d_model = d_model + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel"] + + # linear attention + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.d_model // self.num_key_value_heads + self.causal = causal + feature_map_kwargs = { + 'input_dim': self.feature_dim, + 'head_dim_idx': -1, + 'temp': 1., + 'eps': 1e-12 + } + self.feature_map = init_feature_map( + feature_map=self.feature_name, **feature_map_kwargs) + self.proj_q = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_k = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_v = nn.Linear( + self.d_model, self.num_key_value_heads * 
self.head_dim, bias=False) + self.proj_o = nn.Linear( + self.num_heads * self.head_dim, self.d_model, bias=False) + self.dropout = nn.Identity() + self.eps = eps + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + q, k, v = map(lambda x: rearrange( + x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) + if mode == "fused_chunk": + assert q.shape[-1] <= 16 + #o = fused_chunk_based(q, k, v, True, True) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_rebased(q, k, v, self.eps, True, True) + o = rearrange(o, "b h l d -> b l (h d)") + o = self.proj_o(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, l) + y (torch.Tensor): tensor of shape (b, d, l) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + + q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, l, self.num_key_value_heads, + self.feature_dim).transpose(1, 2) + v = v.view(b, l, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h l d -> b l (h d)') + y = self.proj_o(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) + + +if 
__name__ == '__main__': + batch = 4 + seq_len = 1024 + d_model = 1024 + dtype = torch.float32 + x = torch.randn(batch, seq_len, d_model).to( + dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, d_model).to( + dtype).cuda() + model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() + y = model(x) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + + proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None + proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None + proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None + + x.requires_grad_(True) + y2 = model.forward_reference(x) + y2.backward(dy) + print((y - y2).abs().max().item()) + # assert y.allclose(y2, 0, 1e-4), breakpoint() + print((x_grad - x.grad).abs().max().item()) + # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() + print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) + print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) + print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) + print("All good with rebased!") + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + + for d_model in [16, 64]: + model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() + for seq_len in [256, 1024, 4096]: + timings_f = [] + timings_b = [] + for i in range(100): + x = torch.randn(batch, seq_len, d_model).to( + dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, d_model).to( + dtype).cuda() + + starter.record() + y = model(x) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings_f.append(curr_time) + + starter.record() + y.backward(dy) + ender.record() + + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings_b.append(curr_time) + + print(f"fseq len {seq_len}, d_model {d_model}, forward time: {sum(timings_f) / len(timings_f)}, backward 
time: {sum(timings_b) / len(timings_b)}") \ No newline at end of file diff --git a/flash_linear_attention/fla/layers/rebased_fast.py b/flash_linear_attention/fla/layers/rebased_fast.py new file mode 100644 index 0000000..875808b --- /dev/null +++ b/flash_linear_attention/fla/layers/rebased_fast.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- + +""" +Linear attention in Based. +https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" +import math + +import opt_einsum as oe +import torch +import torch.nn as nn +from einops import rearrange + +from fla.ops.triton.rebased_fast import parallel_rebased + + +def init_feature_map(feature_map: str = 'none', **kwargs: any): + """ + Initialize query and key mapping for linear attention + """ + if feature_map in [None, 'none', 'identity']: + return FeatureMap(**kwargs) + # Taylor series approximations to exp(x) + elif feature_map == 'taylor_exp': + return TaylorExp(**kwargs) + else: + raise NotImplementedError( + f'Sorry "{feature_map}" feature map not implemented.') + + +class FeatureMap(nn.Module): + """ + Parent feature map; default is identity function + """ + + def __init__(self, + input_dim: int, + temp: int = None, + head_dim_idx: int = -1, + eps: float = 1e-12, + **kwargs: any): + super().__init__() + self.input_dim = input_dim + self.head_dim_idx = head_dim_idx + self.temp = 1. if temp is None else temp + self.eps = eps + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return x + + +class TaylorExp(FeatureMap): + """ + Feature map to compute 2nd-order Taylor approx. of exp(q^T k / sqrt(d)) + """ + + def __init__(self, input_dim: int, **kwargs: any): + super().__init__(input_dim, **kwargs) + self.rd = math.sqrt(self.input_dim) + self.rrd = math.sqrt(self.rd) + + # Running these in parallel + def forward(self, x: torch.Tensor): + # Get 2nd-order terms (rearrange(x * x), '... m n -> ... 
(m n)') + x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) + ).flatten(start_dim=-2) + return x2 / self.rd + + + +class ReBasedLinearAttention(nn.Module): + def __init__( + self, + d_model: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + self.d_model = d_model + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel"] + + # linear attention + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.d_model // self.num_key_value_heads + self.causal = causal + feature_map_kwargs = { + 'input_dim': self.feature_dim, + 'head_dim_idx': -1, + 'temp': 1., + 'eps': 1e-12 + } + self.feature_map = init_feature_map( + feature_map=self.feature_name, **feature_map_kwargs) + self.proj_q = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_k = nn.Linear( + self.d_model, self.feature_dim * self.num_heads, bias=False) + self.proj_v = nn.Linear( + self.d_model, self.num_key_value_heads * self.head_dim, bias=False) + self.proj_o = nn.Linear( + self.num_heads * self.head_dim, self.d_model, bias=False) + self.dropout = nn.Identity() + self.eps = eps + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + q, k, v = map(lambda x: rearrange( + x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) + if mode == "fused_chunk": + assert q.shape[-1] <= 16 + #o = fused_chunk_based(q, k, v, True, True) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_rebased(q, k, v, self.eps, True, True) + o = rearrange(o, "b h l d -> b l (h d)") + o = self.proj_o(o) + o = 
self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, l) + y (torch.Tensor): tensor of shape (b, d, l) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, l, _ = hidden_states.size() + q, k, v = self.proj_q(hidden_states), self.proj_k( + hidden_states), self.proj_v(hidden_states) + + q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, l, self.num_key_value_heads, + self.feature_dim).transpose(1, 2) + v = v.view(b, l, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h l d -> b l (h d)') + y = self.proj_o(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) + + +if __name__ == '__main__': + batch = 4 + seq_len = 1024 + d_model = 1024 + dtype = torch.float32 + x = torch.randn(batch, seq_len, d_model).to( + dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, d_model).to( + dtype).cuda() + model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() + y = model(x) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None + proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None + proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None + x.requires_grad_(True) + y2 = model.forward_reference(x) + y2.backward(dy) + print((y - y2).abs().max().item()) + # assert y.allclose(y2, 0, 
1e-4) + print((x_grad - x.grad).abs().max().item()) + # assert x_grad.allclose(x.grad, 0, 1e-4) + + print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) + print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) + print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) + + print("All good with rebased fast!") + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + + for d_model in [16, 64]: + model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() + for seq_len in [256, 1024, 4096]: + timings_f = [] + timings_b = [] + for i in range(100): + x = torch.randn(batch, seq_len, d_model).to( + dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, d_model).to( + dtype).cuda() + + starter.record() + y = model(x) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings_f.append(curr_time) + + starter.record() + y.backward(dy) + ender.record() + + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings_b.append(curr_time) + + print(f"fseq len {seq_len}, d_model {d_model}, forward time: {sum(timings_f) / len(timings_f)}, backward time: {sum(timings_b) / len(timings_b)}") \ No newline at end of file diff --git a/flash_linear_attention/fla/modules/__init__.py b/flash_linear_attention/fla/modules/__init__.py new file mode 100644 index 0000000..d4c9909 --- /dev/null +++ b/flash_linear_attention/fla/modules/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .convolution import LongConvolution, ShortConvolution, ImplicitLongConvolution +from .rmsnorm import RMSNorm +from .rotary import RotaryEmbedding + +__all__ = [ + 'LongConvolution', 'ShortConvolution', 'ImplicitLongConvolution', + 'RMSNorm', + 'RotaryEmbedding' +] diff --git a/flash_linear_attention/fla/modules/convolution.py b/flash_linear_attention/fla/modules/convolution.py new file mode 100644 index 0000000..ac3bbc5 --- /dev/null +++ 
b/flash_linear_attention/fla/modules/convolution.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- + +# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +class ShortConvolution(nn.Module): + """ + Simple wrapper around nn.Conv1d that accepts dimension last. + """ + + def __init__( + self, + d_model: int, + kernel_size: int + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=d_model, + out_channels=d_model, + kernel_size=kernel_size, + groups=d_model, + padding=kernel_size - 1, + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + l = x.size(1) + y = self.conv(x.transpose(1, 2))[..., :l].transpose(1, 2) + return y + + +class LongConvolution(nn.Module): + """ + LongConvolution applies a convolution operation on the input tensor using a fixed + filter of length l_max. + The filter is learned during training and is applied using FFT convolution. + Args: + d_model (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. 
+ Returns: + y: (b, l, d) tensor + """ + def __init__( + self, + d_model: int, + l_max: int, + **kwargs, + ): + """ + Initializes the LongConvolution module. + Args: + d_model (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + """ + super().__init__() + self.d_model = d_model + self.filter = nn.Parameter(torch.randn(self.d_model, l_max), requires_grad=True) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Applies the LongConvolution operation on the input tensor. + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) + + +class PositionalEmbedding(nn.Module): + def __init__(self, emb_dim: int, seq_len: int, **kwargs): + """Complex exponential positional embeddings for implicit long convolution filters.""" + super().__init__() + + self.seq_len = seq_len + # The time embedding fed to the filteres is normalized so that t_f = 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + + if emb_dim > 1: + bands = (emb_dim - 1) // 2 + # To compute the right embeddings we use the "proper" linspace + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + self.z = nn.Parameter(z, requires_grad=False) + + def forward(self, L): + return self.z[:, :L] + +class ImplicitLongConvolution(nn.Module): + """ + Long convolution with implicit filter parameterized by an MLP. + + Args: + d_model (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + d_emb (int, optional): The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine). Defaults to 3. 
+ d_hidden (int, optional): The number of features in the hidden layer of the MLP. Defaults to 16. + + Attributes: + pos_emb (PositionalEmbedding): The positional embedding layer. + mlp (nn.Sequential): The MLP that parameterizes the implicit filter. + + """ + + + def __init__( + self, + d_model: int, + l_max: int, + d_emb: int=3, + d_hidden: int = 16, + **kwargs, + ): + """ + Long convolution with implicit filter parameterized by an MLP. + + + """ + super().__init__() + self.d_model = d_model + self.d_emb = d_emb + + + assert ( + d_emb % 2 != 0 and d_emb >= 3 + ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" + self.pos_emb = PositionalEmbedding(d_emb, l_max) + + # final linear layer + self.mlp = nn.Sequential( + nn.Linear(d_emb, d_hidden), + torch.nn.ReLU(), + nn.Linear(d_hidden, d_model), + ) + + + def filter(self, l: int, *args, **kwargs): + k = self.mlp(self.pos_emb(l)) + + return k.transpose(1, 2) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + k = self.filter(x.shape[-1]) + y = fft_conv(x, k, dropout_mask=None, gelu=False) + + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) \ No newline at end of file diff --git a/flash_linear_attention/fla/modules/rmsnorm.py b/flash_linear_attention/fla/modules/rmsnorm.py new file mode 100644 index 0000000..e2da44e --- /dev/null +++ b/flash_linear_attention/fla/modules/rmsnorm.py @@ -0,0 +1,647 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. +# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + \ + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input 
+ Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < + N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, 
N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, + device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": 
lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
+ row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + 
row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like( + x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, + device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + ) + 
dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = 
_layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5): + # factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if 
norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to( + dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, + norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + 
x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/flash_linear_attention/fla/modules/rotary.py b/flash_linear_attention/fla/modules/rotary.py new file mode 100644 index 0000000..b326cb0 --- /dev/null +++ b/flash_linear_attention/fla/modules/rotary.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange, repeat +from fla.ops.triton.rotary import apply_rotary + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 
+ if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. 
_GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. 
+ """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, + dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, + dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, + device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to( + device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. 
+ """ + seqlen = q.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache( + max_seqlen, device=q.device, dtype=q.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache( + seqlen + seqlen_offset, device=q.device, dtype=q.dtype) + if self.scale is None: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + else: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + return q, k diff --git a/flash_linear_attention/fla/ops/__init__.py b/flash_linear_attention/fla/ops/__init__.py new file mode 100644 index 0000000..0f8c005 --- /dev/null +++ b/flash_linear_attention/fla/ops/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +from fla.ops.torch import (naive_chunk_based, naive_parallel_based, + naive_recurrent_gla, naive_retention) +from fla.ops.triton import (chunk_gla, chunk_retention, fused_chunk_based, + fused_chunk_gla, fused_chunk_retention, + fused_recurrent_gla, fused_recurrent_retention, + parallel_based, parallel_retention, parallel_rebased) + +__all__ = [ + 'naive_chunk_based', + 'naive_parallel_based', + 'naive_recurrent_gla', + 'naive_retention', + 'chunk_gla', + 'chunk_retention', + 'fused_chunk_based', + 'fused_chunk_gla', + 'fused_chunk_retention', + 'fused_recurrent_gla', + 'fused_recurrent_retention', + 'parallel_based', + 'parallel_rebased', + 'parallel_retention', +] diff --git a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp new file mode 
100644
index 0000000..57cbb98
--- /dev/null
+++ b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp
@@ -0,0 +1,11 @@
+#include <torch/extension.h>
+
+torch::Tensor fwd_cuda(torch::Tensor& Q, torch::Tensor& K, torch::Tensor& g_K);
+
+std::vector<torch::Tensor> bwd_cuda(torch::Tensor Q, torch::Tensor K,
+                                    torch::Tensor g_K, torch::Tensor DQK);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &fwd_cuda, "GLA compute A semiring (CUDA)");
+  m.def("backward", &bwd_cuda, "GLA compute A semiring (CUDA)");
+}
diff --git a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu
new file mode 100644
index 0000000..9e9ee34
--- /dev/null
+++ b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu
@@ -0,0 +1,204 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+#include "ATen/ATen.h"
+
+typedef at::BFloat16 bf16;
+
+template <typename scalar_t>
+__global__ void fwd_inner_chunk16_dim16x(int batchSize, int M, int N_K,
+                                         scalar_t* Q, scalar_t* K, float* G_K,
+                                         scalar_t* QK) {
+  // Batch index
+  const int batchIdx = blockIdx.x;
+  // allocate buffer for current block in fast shared mem
+  __shared__ float Q_tile[16][16];
+  __shared__ float K_tile[16][16];
+  __shared__ float G_tile[16][16];
+  __shared__ float G_tile_trans[16][16];
+
+  const uint threadCol = threadIdx.x % 16;
+  const uint threadRow = threadIdx.x / 16;
+
+  int K_Stride = M * N_K;
+
+  // Adjust the pointer for batch and matrix size
+  Q += batchIdx * K_Stride;
+  K += batchIdx * K_Stride;
+  G_K += batchIdx * K_Stride;
+  QK += batchIdx * M * M;
+
+  float tmp = 0.0;
+  // printf("Hello world");
+  // printf("%d, %d, %d \n", threadRow, threadCol, N_K);
+  for (int bkIdx = 0; bkIdx < N_K; bkIdx += 16) {
+    Q_tile[threadRow][threadCol] = (float)Q[threadRow * N_K + threadCol];
+    K_tile[threadRow][threadCol] = (float)K[threadRow * N_K + threadCol];
+    float tmp_gk = (float)G_K[threadRow * N_K + threadCol];
+    G_tile[threadRow][threadCol] = (float)tmp_gk;
+    G_tile_trans[threadCol][threadRow] = (float)tmp_gk;
+
+    __syncthreads();
+
+    Q += 16;
+    K += 16;
+    G_K += 16;
+
+    if (threadCol <= threadRow) {
+      for (int dotIdx = 0; dotIdx < 16; ++dotIdx) {
+        // avoid bank conflict?
+        float exp_term =
+            expf(G_tile[threadRow][dotIdx] - G_tile_trans[dotIdx][threadCol]);
+        tmp += Q_tile[threadRow][dotIdx] * K_tile[threadCol][dotIdx] * exp_term;
+      }
+    }
+    __syncthreads();
+  }
+
+  if (threadCol <= threadRow) {
+    QK[threadRow * M + threadCol] = (scalar_t)tmp;
+  } else {
+    QK[threadRow * M + threadCol] = (scalar_t)0.0;
+  }
+}
+
+template <typename scalar_t>
+__global__ void bwd_inner_chunk16_dim16x(int batchSize, int M, int N_K,
+                                         scalar_t* Q, scalar_t* K, float* G,
+                                         scalar_t* DQK, scalar_t* DQ,
+                                         scalar_t* DK, float* DG) {
+  // Batch index
+  const uint batchIdx = blockIdx.x;
+
+  // allocate buffer for current block in fast shared mem
+  __shared__ float Q_tile[16][16];
+  __shared__ float QK_tile[16][16];
+  __shared__ float K_tile[16][16];
+  __shared__ float G_tile[16][16];
+  __shared__ float G_tile_trans[16][16];
+
+  const uint threadCol = threadIdx.x % 16;
+  const uint threadRow = threadIdx.x / 16;
+
+  int K_Stride = M * N_K;
+
+  Q += batchIdx * K_Stride;
+  DQ += batchIdx * K_Stride;
+  K += batchIdx * K_Stride;
+  DK += batchIdx * K_Stride;
+  G += batchIdx * K_Stride;
+  DG += batchIdx * K_Stride;
+
+  DQK += batchIdx * M * M;
+  QK_tile[threadRow][threadCol] =
+      (threadCol <= threadRow) ? (float)DQK[threadRow * M + threadCol] : 0.0;
+  __syncthreads();
+
+  for (int bkIdx = 0; bkIdx < N_K; bkIdx += 16) {
+    Q_tile[threadRow][threadCol] = (float)Q[threadRow * N_K + threadCol];
+    K_tile[threadRow][threadCol] = (float)K[threadRow * N_K + threadCol];
+    float tmp_gk = (float)G[threadRow * N_K + threadCol];
+    G_tile[threadRow][threadCol] = tmp_gk;
+    // G_tile_trans[threadCol][threadRow] = tmp_gk;
+
+    __syncthreads();
+
+    float threadResults_dK = 0;
+    float threadResults_dQ = 0;
+
+    for (uint dotIdx = threadRow; dotIdx < 16; dotIdx += 1) {
+      float tmp =
+          QK_tile[dotIdx][threadRow] *
+          expf(G_tile[dotIdx][threadCol] - G_tile[threadRow][threadCol]) *
+          Q_tile[dotIdx][threadCol];
+      threadResults_dK += tmp;
+    }
+
+    for (uint dotIdx = 0; dotIdx <= threadRow; dotIdx += 1) {
+      float tmp =
+          QK_tile[threadRow][dotIdx] *
+          expf(G_tile[threadRow][threadCol] - G_tile[dotIdx][threadCol]) *
+          K_tile[dotIdx][threadCol];
+      threadResults_dQ += dotIdx <= threadRow ? tmp : 0;
+    }
+
+    __syncthreads();
+    DQ[threadRow * N_K + threadCol] = (scalar_t)threadResults_dQ;
+    DK[threadRow * N_K + threadCol] = (scalar_t)threadResults_dK;
+    DG[threadRow * N_K + threadCol] =
+        (threadResults_dQ * Q_tile[threadRow][threadCol] -
+         threadResults_dK * K_tile[threadRow][threadCol]);
+    Q += 16;
+    K += 16;
+    G += 16;
+    DQ += 16;
+    DK += 16;
+    DG += 16;
+    __syncthreads();
+  }
+}
+
+std::vector<torch::Tensor> bwd_cuda(torch::Tensor Q, torch::Tensor K,
+                                    torch::Tensor g_K, torch::Tensor DQK) {
+  auto DQ = torch::empty_like(Q);
+  auto DK = torch::empty_like(K);
+  auto Dg_K = torch::empty_like(g_K);
+
+  int B_size = Q.size(0);     // This is the batch size dimension.
+  int H_size = Q.size(1);     // This is the head dimension
+  int num_chunk = Q.size(2);  // This is the chunk dimension.
+  int M = Q.size(-2);
+  int N_K = Q.size(-1);
+
+  dim3 gridDim(B_size * H_size * num_chunk);
+  dim3 blockDim(256);
+
+  switch (Q.type().scalarType()) {
+    case torch::ScalarType::BFloat16:
+      bwd_inner_chunk16_dim16x<bf16><<<gridDim, blockDim>>>(
+          B_size * H_size * num_chunk, M, N_K, Q.data_ptr<bf16>(),
+          K.data_ptr<bf16>(), g_K.data_ptr<float>(), DQK.data_ptr<bf16>(),
+          DQ.data_ptr<bf16>(), DK.data_ptr<bf16>(), Dg_K.data_ptr<float>());
+      break;
+    default:
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          Q.scalar_type(), "bwd_inner_chunk16_dim16x", ([&] {
+            bwd_inner_chunk16_dim16x<scalar_t><<<gridDim, blockDim>>>(
+                B_size * H_size * num_chunk, M, N_K, Q.data_ptr<scalar_t>(),
+                K.data_ptr<scalar_t>(), g_K.data_ptr<float>(),
+                DQK.data_ptr<scalar_t>(), DQ.data_ptr<scalar_t>(),
+                DK.data_ptr<scalar_t>(), Dg_K.data_ptr<float>());
+          }));
+  };
+  return {DQ, DK, Dg_K};
+}
+
+torch::Tensor fwd_cuda(torch::Tensor& Q, torch::Tensor& K, torch::Tensor& g_K) {
+  auto QK = torch::empty(
+      {Q.size(0), Q.size(1), Q.size(2), Q.size(3), Q.size(3)}, Q.options());
+  int B_size = Q.size(0);     // This is the batch size dimension.
+  int H_size = Q.size(1);     // This is the head dimension
+  int num_chunk = Q.size(2);  // This is the chunk dimension.
+  int M = Q.size(-2);    // this is the chunk size
+  int N_K = Q.size(-1);  // this is the head_K dim
+
+  dim3 gridDim(B_size * H_size * num_chunk);
+  dim3 blockDim(256);
+  switch (Q.type().scalarType()) {
+    case torch::ScalarType::BFloat16:
+      fwd_inner_chunk16_dim16x<bf16><<<gridDim, blockDim>>>(
+          B_size * H_size * num_chunk, M, N_K, Q.data_ptr<bf16>(),
+          K.data_ptr<bf16>(), g_K.data_ptr<float>(), QK.data_ptr<bf16>());
+      break;
+    default:
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          Q.scalar_type(), "fwd_inner_chunk16_dim16x", ([&] {
+            fwd_inner_chunk16_dim16x<scalar_t><<<gridDim, blockDim>>>(
+                B_size * H_size * num_chunk, M, N_K, Q.data_ptr<scalar_t>(),
+                K.data_ptr<scalar_t>(), g_K.data_ptr<float>(),
+                QK.data_ptr<scalar_t>());
+          }));
+  };
+  return QK;
+}
\ No newline at end of file
diff --git a/flash_linear_attention/fla/ops/torch/__init__.py b/flash_linear_attention/fla/ops/torch/__init__.py
new file mode 100644
index 0000000..4b6852d
--- /dev/null
+++ b/flash_linear_attention/fla/ops/torch/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .based import naive_chunk_based, naive_parallel_based
+from .gla import naive_recurrent_gla
+from .retention import naive_retention
+
+__all__ = ['naive_chunk_based', 'naive_parallel_based', 'naive_recurrent_gla', 'naive_retention']
diff --git a/flash_linear_attention/fla/ops/torch/based.py b/flash_linear_attention/fla/ops/torch/based.py
new file mode 100644
index 0000000..ee4c6ef
--- /dev/null
+++ b/flash_linear_attention/fla/ops/torch/based.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from einops import rearrange
+
+from fla.ops.triton.based import fused_chunk_based, parallel_based
+
+
+def naive_parallel_based(q, k, v, use_scale=True, use_norm=True):
+    if use_scale:
+        q = q * (q.shape[-1] ** -0.5)
+    attn = q @ k.transpose(-2, -1)
+    attn = 1 + attn + 1/2 * (attn ** 2)
+    attn.masked_fill_(~torch.tril(torch.ones(
+        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
+    o = attn @ v
+    if use_norm:
+        z = attn.sum(-1)
+        return o / (z[..., None] + 1e-6)
+    else:
+        return o
+
+
+def naive_chunk_based(q, k, v,
chunk_size=256): + q = q * (q.shape[-1] ** -0.5) + + # compute normalizer. + k_cumsum = torch.cumsum(k, dim=-2) + kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) + # first + z = (q * k_cumsum).sum(-1) + # second order + z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 + # zero-th order + z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] + + # compute o + # constant term + _o = v.cumsum(-2) + + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) + + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + + intra_chunk_attn = q @ k.transpose(-2, -1) + intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) + intra_chunk_attn.masked_fill_( + ~torch.tril( + torch.ones(chunk_size, chunk_size, + dtype=torch.bool, device=q.device), + ), 0) + o = intra_chunk_attn @ v + + # quadractic term + kv = torch.einsum( + 'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + + o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) + + # linear term + kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) + + o = rearrange(o, 'b h n c d -> b h (n c) d') + o = o + _o + return o / (z[..., None] + 1e-6) + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 128 + # D = 15 + dtype = torch.float32 + q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) + + do = torch.randn_like(v).cuda() + ref = naive_parallel_based(q, k, v, True, True) + ref.backward(do, retain_graph=True) + ref_dq, q.grad = 
q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + # tri = naive_chunk_based(q, k, v) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + tri = fused_chunk_based(q, k, v, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + tri = parallel_based(q, k, v, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() diff --git a/flash_linear_attention/fla/ops/torch/gla.py b/flash_linear_attention/fla/ops/torch/gla.py new file mode 100644 index 0000000..8b06a24 --- /dev/null +++ b/flash_linear_attention/fla/ops/torch/gla.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn.functional as F + +from 
fla.ops.triton.gla import fused_recurrent_gla + + +def ceildiv(a, b): + return -(a // -b) + + +def naive_recurrent_gla( + q, + k, + v, + gk, + initial_state=None, + output_final_state=False, + causal=True +): + orig_dtype = q.dtype + q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) + batch_size, n_heads, seq_len, d_head_k = q.shape + _, _, _, d_head_v = v.shape + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + scale = d_head_k ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(seq_len): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o_i = (q_i[..., None] * h).sum(-2) + o[:, :, i] = o_i + + if causal: + if output_final_state: + return o.to(orig_dtype), h + else: + return o.to(orig_dtype) + else: + o_reverse = torch.zeros_like(v) + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + for i in range(seq_len-1, -1, -1): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o_i = (q_i[..., None] * h).sum(-2) + o_reverse[:, :, i] = o_i + + return o, o_reverse + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 512 + D = 128 + dtype = torch.float32 + q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True) + g = F.logsigmoid(torch.rand(B, H, L, D)).cuda( + ).clamp_min(-1).to(torch.float32).requires_grad_(True) + + do = torch.rand_like(v).cuda() + do2 = torch.rand_like(v).cuda() + intial_state = torch.rand(B, H, D, D).cuda() + + ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False) + + ref.backward(do, 
retain_graph=True) + ref_rev.backward(do2, retain_graph=True) + + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg, g.grad = g.grad.clone(), None + + tri, tri_rev = fused_recurrent_gla( + q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False) + tri.backward(do, retain_graph=True) + tri_rev.backward(do2, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None + + assert ref.allclose(tri, 0, 1e-5), breakpoint() + assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint() + assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() + assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() + assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() + assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() + + # tri = fused_chunk_gla(q, k, v, g) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + # tri_dg, g.grad = g.grad.clone(), None + + # assert ref.allclose(tri, 0, 1e-5), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() + # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() + # breakpoint() + print("Pass") diff --git a/flash_linear_attention/fla/ops/torch/retention.py b/flash_linear_attention/fla/ops/torch/retention.py new file mode 100644 index 0000000..15611bf --- /dev/null +++ b/flash_linear_attention/fla/ops/torch/retention.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import torch + + +def naive_retention(q, k, v): + orig_type = q.dtype + q, k, v = q.float(), k.float(), v.float() + _, n_heads, seq_len, d_head = q.shape + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. 
- q.new_tensor(range(n_heads), dtype=torch.float))).log2() + n = q.new_tensor(range(seq_len), dtype=torch.float) + n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) + s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) + o = torch.einsum('bhqk,bhkd->bhqd', s, v) + return o.to(orig_type) diff --git a/flash_linear_attention/fla/ops/triton/__init__.py b/flash_linear_attention/fla/ops/triton/__init__.py new file mode 100644 index 0000000..74c6286 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +from .based import fused_chunk_based, parallel_based +from .rebased import parallel_rebased +from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla +from .retention import (chunk_retention, fused_chunk_retention, + fused_recurrent_retention, parallel_retention) +from .rotary import apply_rotary + +__all__ = [ + 'fused_chunk_based', + 'parallel_based', + 'parallel_rebased', + 'chunk_gla', + 'fused_chunk_gla', + 'fused_recurrent_gla', + 'chunk_retention', + 'fused_chunk_retention', + 'fused_recurrent_retention', + 'parallel_retention', + 'apply_rotary' +] diff --git a/flash_linear_attention/fla/ops/triton/abc/__init__.py b/flash_linear_attention/fla/ops/triton/abc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py new file mode 100644 index 0000000..ea6589f --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py @@ -0,0 +1,692 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def chunk_abc_fwd_kernel_s( + q, + k, + s, + rk, # rescale term + ck, # scores normalized over a chunk + pk, # scores normalized over the sequence + s_qk_h, + s_qk_t, + 
s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BM: tl.constexpr, + DK: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_m, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_s = tl.make_block_ptr(s + (i_k * n_bh + i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_pk = tl.make_block_ptr(pk + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + b_hk = tl.zeros([BK, BM], dtype=tl.float32) + for _ in range(NT): + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BM,] + b_rk = tl.load(p_rk, boundary_check=(0,)) + # [BT, BM] + b_ck = tl.load(p_ck, boundary_check=(0, 1)) + b_pk = tl.load(p_pk, boundary_check=(0, 1)) + + # [BT, BM] + b_inter = tl.dot(b_q, b_hk.to(b_q.dtype), allow_tf32=False) * b_rk[None, :] + b_intra = tl.dot(tl.where(m_s, tl.dot(b_q, b_k, allow_tf32=False), 0).to(b_q.dtype), b_ck, allow_tf32=False) + b_s = (b_inter + b_intra) * b_pk + # [BK, BM] + b_hk = b_hk * b_rk[None, :] + tl.dot(b_k, b_ck, allow_tf32=False) + + tl.store(p_s, b_s.to(p_s.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_s = tl.advance(p_s, (BT, 0)) + p_rk = tl.advance(p_rk, (DM,)) + p_ck = tl.advance(p_ck, (BT, 0)) + p_pk = tl.advance(p_pk, (BT, 0)) + + +@triton.jit +def 
chunk_abc_fwd_kernel_o( + p, + v, + o, + rv, # rescale term + cv, # scores normalized over a chunk + pv, # scores normalized over the sequence + s_qk_h, + s_qk_t, + s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BM: tl.constexpr, + BV: tl.constexpr, + DM: tl.constexpr, + DV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_p = tl.make_block_ptr(p + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_m * n_bh + i_bh)*s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, 0), (BM, BT), (0, 1)) + p_pv = tl.make_block_ptr(pv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BM, BV] + b_hv = tl.zeros([BM, BV], dtype=tl.float32) + for _ in range(NT): + # [BT, BM] + b_p = tl.load(p_p, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BM,] + b_rv = tl.load(p_rv, boundary_check=(0,)) + # [BM, BT] + b_cv = tl.load(p_cv, boundary_check=(0, 1)) + # [BT, BM] + b_pv = tl.load(p_pv, boundary_check=(0, 1)) + + b_p = b_p * b_pv + # [BT, BV] + b_inter = tl.dot((b_p * b_rv[None, :]).to(b_v.dtype), b_hv.to(b_v.dtype), allow_tf32=False) + b_intra = tl.where(m_s, tl.dot(b_p.to(b_v.dtype), b_cv, allow_tf32=False), 0) + b_intra = tl.dot(b_intra.to(b_v.dtype), b_v, allow_tf32=False) + b_o = b_inter + b_intra + # [BM, BV] + b_hv = b_hv * b_rv[:, None] + tl.dot(b_cv, b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_p = tl.advance(p_p, 
(BT, 0)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_rv = tl.advance(p_rv, (DM,)) + p_cv = tl.advance(p_cv, (0, BT)) + p_pv = tl.advance(p_pv, (BT, 0)) + + +@triton.jit +def chunk_abc_bwd_kernel_dp( + v, + rv, # rescale term + cv, # scores normalized over a chunk + pv, # scores normalized over the sequence + do, + dp, + s_qk_h, + s_qk_t, + s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BV: tl.constexpr, + BM: tl.constexpr, + DV: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_m, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (DV, T), (s_qk_d, s_qk_t), (i_v * BV, 0), (BV, BT), (0, 1)) + p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_pv = tl.make_block_ptr(pv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_dp = tl.make_block_ptr(dp + (i_v * n_bh + i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV, BM] + b_hv = tl.zeros([BV, BM], dtype=tl.float32) + for _ in range(NT): + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BM,] + b_rv = tl.load(p_rv, boundary_check=(0,)) + # [BT, BM] + b_cv = tl.load(p_cv, boundary_check=(0, 1)) + b_pv = tl.load(p_pv, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BM] + b_inter = tl.dot(b_do, b_hv.to(b_do.dtype), allow_tf32=False) * b_rv[None, :] + b_intra = tl.dot(tl.where(m_s, tl.dot(b_do, b_v, allow_tf32=False), 0).to(b_v.dtype), b_cv, allow_tf32=False) + b_dp = (b_inter + b_intra) * b_pv + # [BV, BM] + b_hv = b_hv * 
b_rv[None, :] + tl.dot(b_v, b_cv, allow_tf32=False) + + tl.store(p_dp, b_dp.to(p_dp.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.advance(p_v, (0, BT)) + p_rv = tl.advance(p_rv, (DM,)) + p_cv = tl.advance(p_cv, (BT, 0)) + p_pv = tl.advance(p_pv, (BT, 0)) + p_do = tl.advance(p_do, (BT, 0)) + p_dp = tl.advance(p_dp, (BT, 0)) + + +@triton.jit +def chunk_abc_bwd_kernel_dq( + k, + rk, # rescale term + ck, # scores normalized over a chunk + dq, + ds, + s_qk_h, + s_qk_t, + s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BK: tl.constexpr, + BM: tl.constexpr, + DK: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_k, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, 0), (BM, BT), (0, 1)) + p_dq = tl.make_block_ptr(dq + (i_m * n_bh + i_bh)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_ds = tl.make_block_ptr(ds + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BM, BK] + b_hk = tl.zeros([BM, BK], dtype=tl.float32) + for _ in range(NT): + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BM,] + b_rk = tl.load(p_rk, boundary_check=(0,)) + # [BM, BT] + b_ck = tl.load(p_ck, boundary_check=(0, 1)) + # [BT, BM] + b_ds = tl.load(p_ds, boundary_check=(0, 1)) + + # [BT, BK] + b_inter = tl.dot((b_ds * b_rk[None, :]).to(b_k.dtype), b_hk.to(b_k.dtype), allow_tf32=False) + b_intra = tl.dot(tl.where(m_s, tl.dot(b_ds, b_ck, allow_tf32=False), 0).to(b_k.dtype), b_k, allow_tf32=False) + b_dq = b_inter + b_intra + # [BM, BK] + b_hk = b_hk * b_rk[:, None] + tl.dot(b_ck, b_k, 
allow_tf32=False) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + p_k = tl.advance(p_k, (BT, 0)) + p_rk = tl.advance(p_rk, (DM,)) + p_ck = tl.advance(p_ck, (0, BT)) + p_dq = tl.advance(p_dq, (BT, 0)) + p_ds = tl.advance(p_ds, (BT, 0)) + + +@triton.jit +def chunk_abc_bwd_kernel_dk( + q, + k, + rk, # rescale term + ck, # scores normalized over a chunk + ds, + dk, + dsk, + s_qk_h, + s_qk_t, + s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BK: tl.constexpr, + BM: tl.constexpr, + DK: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_k, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), ((NT-1)*BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, (NT-1)*BT), (BK, BT), (0, 1)) + p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + p_ds = tl.make_block_ptr(ds + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, (NT-1)*BT), (BM, BT), (0, 1)) + p_dk = tl.make_block_ptr(dk + (i_m*n_bh+i_bh)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), ((NT-1)*BT, i_k * BK), (BT, BK), (1, 0)) + p_dsk = tl.make_block_ptr(dsk + (i_k*n_bh+i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s, m_t = o_i[:, None] <= o_i[None, :], o_i[:, None] >= o_i[None, :] + + # [BM, BK] + b_dhk = tl.zeros([BM, BK], dtype=tl.float32) + for i in range(NT): + p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BM,] + b_rk = tl.load(p_rk, boundary_check=(0,)) + # [BT, BM] + 
b_ck = tl.load(p_ck, boundary_check=(0, 1)) + b_ds = tl.load(p_ds, boundary_check=(0, 1)) + + # [BT, BK] + b_inter = tl.dot((b_ck * b_rk[None, :]).to(b_q.dtype), b_dhk.to(b_q.dtype), allow_tf32=False) + b_intra = tl.dot(tl.where(m_s, tl.dot(b_ck, b_ds, allow_tf32=False), 0.).to(b_q.dtype), b_q, allow_tf32=False) + b_dk = b_inter + b_intra + + # [BM, BT] + b_inter = tl.dot(b_dhk.to(b_k.dtype), b_k, allow_tf32=False) * b_rk[:, None] + b_intra = tl.dot(b_ds, tl.where(m_t, tl.dot(b_q, b_k, allow_tf32=False), 0.).to(b_q.dtype), allow_tf32=False) + # [BT, BM] + b_dsk = b_ck * tl.trans(b_inter + b_intra) + + # [BM, BK] + b_dhk = b_dhk * b_rk[:, None] + tl.dot(b_ds, b_q, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dsk, b_dsk.to(p_dsk.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (-BT, 0)) + p_k = tl.advance(p_k, (0, -BT)) + p_ck = tl.advance(p_ck, (-BT, 0)) + p_ds = tl.advance(p_ds, (0, -BT)) + p_dk = tl.advance(p_dk, (-BT, 0)) + p_dsk = tl.advance(p_dsk, (-BT, 0)) + + +@triton.jit +def chunk_abc_bwd_kernel_dv( + do, + v, + rv, # rescale term + cv, # scores normalized over a chunk + p, + dv, + dsv, + s_qk_h, + s_qk_t, + s_qk_d, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BV: tl.constexpr, + BM: tl.constexpr, + DV: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_v, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_do = tl.make_block_ptr(do + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), ((NT-1)*BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (DV, T), (s_qk_d, s_qk_t), (i_v * BV, (NT-1)*BT), (BV, BT), (0, 1)) + p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + p_p = tl.make_block_ptr(p + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), 
(i_m * BM, (NT-1)*BT), (BM, BT), (0, 1)) + p_dv = tl.make_block_ptr(dv + (i_m*n_bh+i_bh)*s_qk_h, (T, DV), (s_qk_t, s_qk_d), ((NT-1)*BT, i_v * BV), (BT, BV), (1, 0)) + p_dsv = tl.make_block_ptr(dsv + (i_v*n_bh+i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_s, m_t = o_i[:, None] <= o_i[None, :], o_i[:, None] >= o_i[None, :] + + # [BM, BV] + b_dhv = tl.zeros([BM, BV], dtype=tl.float32) + for i in range(NT): + p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BM,] + b_rv = tl.load(p_rv, boundary_check=(0,)) + # [BT, BM] + b_cv = tl.load(p_cv, boundary_check=(0, 1)) + # [BM, BT] + b_p = tl.load(p_p, boundary_check=(0, 1)) + + # [BT, BV] + b_inter = tl.dot((b_cv * b_rv[None, :]).to(b_do.dtype), b_dhv.to(b_do.dtype), allow_tf32=False) + b_intra = tl.dot(tl.where(m_s, tl.dot(b_cv, b_p, allow_tf32=False), 0.).to(b_do.dtype), b_do, allow_tf32=False) + b_dv = b_inter + b_intra + + b_inter = tl.dot(b_dhv.to(b_v.dtype), b_v, allow_tf32=False) * b_rv[:, None] + b_intra = tl.dot(b_p, tl.where(m_t, tl.dot(b_do, b_v, allow_tf32=False), 0.).to(b_do.dtype), allow_tf32=False) + # [BT, BM] + b_dsv = b_cv * tl.trans(b_inter + b_intra) + + # [BM, BV] + b_dhv = b_dhv * b_rv[:, None] + tl.dot(b_p, b_do, allow_tf32=False) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dsv, b_dsv.to(p_dsv.dtype.element_ty), boundary_check=(0, 1)) + + p_do = tl.advance(p_do, (-BT, 0)) + p_v = tl.advance(p_v, (0, -BT)) + p_cv = tl.advance(p_cv, (-BT, 0)) + p_p = tl.advance(p_p, (0, -BT)) + p_dv = tl.advance(p_dv, (-BT, 0)) + p_dsv = tl.advance(p_dsv, (-BT, 0)) + + +@triton.jit +def chunk_abc_fwd_kernel_cum( + s, + r, + c, + p, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BM: tl.constexpr, + DM: 
tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + p_s = tl.make_block_ptr(s + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_r = tl.make_block_ptr(r + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) + p_c = tl.make_block_ptr(c + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + p_p = tl.make_block_ptr(p + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) + + b_mp = tl.zeros([BM,], dtype=tl.float32) + b_zp = tl.zeros([BM,], dtype=tl.float32) + for i in range(NT): + # [BT, BM] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + + b_m = tl.max(b_s, 0) + # workaround for some weird compiler bugs + if i == 0: + b_r = tl.exp(-b_m) + else: + b_m = tl.maximum(b_mp, b_m) + b_r = tl.exp(b_mp - b_m) + b_c = tl.exp(b_s - b_m[None, :]) + b_z = tl.cumsum(b_c, 0) + (b_zp * b_r)[None, :] + b_p = tl.exp(-tl.log(b_z)) + b_mp = b_m + b_zp = tl.max(b_z, 0) + + tl.store(p_r, b_r.to(p_r.dtype.element_ty), boundary_check=(0,)) + tl.store(p_c, b_c.to(p_c.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1)) + + p_s = tl.advance(p_s, (BT, 0)) + p_r = tl.advance(p_r, (DM,)) + p_c = tl.advance(p_c, (BT, 0)) + p_p = tl.advance(p_p, (BT, 0)) + + +@triton.jit +def chunk_abc_bwd_kernel_rcum( + s, + r, + c, + o, + s_sk_h, + s_sk_t, + s_sk_m, + T, + BT: tl.constexpr, + BM: tl.constexpr, + DM: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + p_s = tl.make_block_ptr(s + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + p_c = tl.make_block_ptr(c + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) + + o_i = tl.arange(0, BT) + # [BT, BT] + m_t = tl.where(o_i[:, None] <= o_i[None, 
:], 1., 0.) + + b_z = tl.zeros([BM,], dtype=tl.float32) + for i in range(NT): + p_r = tl.make_block_ptr(r + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) + # [BT, BM] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BM,] + b_r = tl.load(p_r, boundary_check=(0,)) + # [BT, BM] + b_c = tl.load(p_c, boundary_check=(0, 1)) + b_o = tl.load(p_o, boundary_check=(0, 1)) + + b_z = b_z * b_r + b_o -= b_c * (b_z[None, :] + tl.dot(m_t.to(b_s.dtype), b_s, allow_tf32=False)) + + # [BM,] + b_z += tl.sum(b_s, 0) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_s = tl.advance(p_s, (-BT, 0)) + p_c = tl.advance(p_c, (-BT, 0)) + p_o = tl.advance(p_o, (-BT, 0)) + + +class FusedChunkABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, sk, sv): + batch_size, n_heads, seq_len, d_head_qk, d_head_v, n_slots = *q.shape, v.shape[-1], sk.shape[-1] + scale = d_head_qk ** -0.5 + + DT, DK, DV, DM = seq_len, d_head_qk, d_head_v, n_slots + BT = 16 + if batch_size * n_heads > 100: + BK, BV, BM = min(DK, 64), min(DV, 64), min(DM, 64) + num_stages = 1 + num_warps = 2 + else: + # SM is not fully utilized so we add more parallelism in the hidden state dimension. 
+ BK, BV, BM = min(DK, 32), min(DV, 32), min(DM, 32) + num_stages = 1 + num_warps = 1 + NT, NK, NV, NM = triton.cdiv(DT, BT), triton.cdiv(DK, BK), triton.cdiv(DV, BV), triton.cdiv(DM, BM) + + rk, ck, pk = sk.new_empty(batch_size, n_heads, NT, DM), torch.empty_like(sk), torch.empty_like(sk) + grid = (NM, batch_size * n_heads) + chunk_abc_fwd_kernel_cum[grid]( + sk, rk, ck, pk, + sk.stride(1), sk.stride(2), sk.stride(3), + seq_len, + BT=BT, BM=BM, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + rv, cv, pv = sv.new_empty(batch_size, n_heads, NT, DM), torch.empty_like(sv), torch.empty_like(sv) + chunk_abc_fwd_kernel_cum[grid]( + sv, rv, cv, pv, + sv.stride(1), sv.stride(2), sv.stride(3), + seq_len, + BT=BT, BM=BM, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + + s = q.new_empty(NK, batch_size, n_heads, seq_len, n_slots) + grid = (NM, NK, batch_size * n_heads) + chunk_abc_fwd_kernel_s[grid]( + q, k, s, rk, ck, pk, + q.stride(1), q.stride(2), q.stride(3), + sk.stride(1), sk.stride(2), sk.stride(3), + seq_len, scale, + BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + s = s.sum(0) + p = s.softmax(-1, dtype=torch.float).to(q.dtype) + o = q.new_empty(NM, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NM, batch_size * n_heads) + chunk_abc_fwd_kernel_o[grid]( + p, v, o, rv, cv, pv, + q.stride(1), q.stride(2), q.stride(3), + sk.stride(1), sk.stride(2), sk.stride(3), + seq_len, + BT=BT, BM=BM, BV=BV, DM=DM, DV=DV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, o, s, p, rk, ck, pk, rv, cv, pv) + ctx.batch_size = batch_size + ctx.n_heads = n_heads + ctx.seq_len = seq_len + ctx.n_slots = n_slots + ctx.dtype = q.dtype + ctx.scale = scale + ctx.BT = BT + return o + + @staticmethod + @contiguous + def backward(ctx, do): + q, k, v, o, s, p, rk, ck, pk, rv, cv, pv = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk, d_head_v, 
n_slots = *q.shape, v.shape[-1], s.shape[-1] + scale = d_head_qk ** -0.5 + + DT, DK, DV, DM = seq_len, d_head_qk, d_head_v, n_slots + BT = ctx.BT + if batch_size * n_heads > 100: + BK, BV, BM = min(DK, 64), min(DV, 64), min(DM, 64) + num_stages = 1 + num_warps = 2 + else: + BK, BV, BM = min(DK, 32), min(DV, 32), min(DM, 32) + num_stages = 1 + num_warps = 2 + NT, NK, NV, NM = triton.cdiv(DT, BT), triton.cdiv(DK, BK), triton.cdiv(DV, BV), triton.cdiv(DM, BM) + dp = s.new_empty(NV, *s.shape) + grid = (NM, NV, batch_size * n_heads) + chunk_abc_bwd_kernel_dp[grid]( + v, rv, cv, pv, do, dp, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BV=BV, BM=BM, DV=DV, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dp = dp.sum(0) + ds = p * (dp - (o * do).sum(-1, True)) * pk + dss = ds * scale + dq, dk, dv = q.new_empty(NM, *q.shape), k.new_empty(NM, *k.shape), v.new_empty(NM, *v.shape) + dsk, dsv = s.new_empty(NK, *s.shape), s.new_empty(NV, *s.shape) + grid = (NK, NM, batch_size * n_heads) + chunk_abc_bwd_kernel_dq[grid]( + k, rk, ck, dq, dss, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + chunk_abc_bwd_kernel_dk[grid]( + q, k, rk, ck, dss, dk, dsk, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dk, dsk = dk.sum(0), dsk.sum(0) + + p = p * pv + grid = (NV, NM, batch_size * n_heads) + chunk_abc_bwd_kernel_dv[grid]( + do, v, rv, cv, p, dv, dsv, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BV=BV, BM=BM, DV=DV, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv, dsv = dv.sum(0), dsv.sum(0) + grid = (NM, batch_size * n_heads) + 
chunk_abc_bwd_kernel_rcum[grid]( + ds * s, rk, ck, dsk, + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BM=BM, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + chunk_abc_bwd_kernel_rcum[grid]( + p * dp, rv, cv, dsv, + s.stride(1), s.stride(2), s.stride(3), + seq_len, + BT=BT, BM=BM, DM=DM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + return dq, dk, dv, dsk, dsv + + +fused_chunk_abc = FusedChunkABCFunction.apply diff --git a/flash_linear_attention/fla/ops/triton/based/__init__.py b/flash_linear_attention/fla/ops/triton/based/__init__.py new file mode 100644 index 0000000..a18fb47 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/based/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .chunk_fuse import fused_chunk_based +from .parallel import parallel_based + +__all__ = ["parallel_based", "fused_chunk_based"] diff --git a/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py new file mode 100644 index 0000000..ce1627f --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_chunk_based_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L, 1] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the 
sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] + b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * 
BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [TB, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + do, # gradient of output [B, H, L, D_head_V] + dz, # gradient of normalizer [B, H, L] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 
0.5 + b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, 
(T, DK), + (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), + tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = 
b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale=1): + batch_size, n_heads, seq_len, d_head_qk = q.shape + # assert d_head_qk == 16, "currently we do not support feature dim other than 16" + d_head_v = v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(d_head_qk, 16), min(d_head_v, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32) + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(d_head_qk, 16), min(d_head_v, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, 
seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +triton_fused_chunk_based = FusedChunkBasedFunction.apply + + +def fused_chunk_based(q, k, v, eps: float = 1e-6, use_scale: bool = True, use_normalize: bool = True): + assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_fused_chunk_based(q, k, v, scale) + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + + return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/based/parallel.py b/flash_linear_attention/fla/ops/triton/based/parallel.py new file mode 100644 index 0000000..fcff0ad --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/based/parallel.py @@ -0,0 +1,385 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous +from torch.cuda.amp import custom_bwd, custom_fwd + + +@triton.jit +def parallel_based_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE 
along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have 
overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = 
tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, 
i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = 
tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_based_bwd_kernel( + q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v 
= v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, device=q.device) + z = torch.empty(NK, batch_size, n_heads, seq_len, + device=q.device) + parallel_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + 
num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_based(q, k, v, eps: float = 1e-6, use_scale: bool = True, use_normalize: bool = True, return_both: bool = False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/__init__.py b/flash_linear_attention/fla/ops/triton/gla/__init__.py new file mode 100644 index 0000000..7e41f8c --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gla +from .chunk_fuse import fused_chunk_gla +from .recurrent_fuse import fused_recurrent_gla + +__all__ = ['chunk_gla', 'fused_recurrent_gla', 'fused_chunk_gla'] diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py new file mode 100644 index 0000000..c27f6ba --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + +import torch +import triton 
+import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_recurrence( + S, + p1, + p2, + O, + NUM_BLOCK, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ + BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + + O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V + + p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K + + p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V + + acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + acc += tl.load(S) + + S += D_MODEL_K * D_MODEL_V + + tl.store(O, acc.to(O.dtype.element_ty)) + O += D_MODEL_K * D_MODEL_V + + for i in range(NUM_BLOCK-2): + p_k = tl.load(p1) + p_v = tl.load(p2) + S_i = tl.load(S) + acc = acc * p_k[:, None] * p_v[None, :] + S_i + tl.store(O, acc.to(O.dtype.element_ty)) + p1 += D_MODEL_K + p2 += D_MODEL_V + S += D_MODEL_K * D_MODEL_V + O += D_MODEL_K * D_MODEL_V + + +# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL +@triton.jit +def _bwd_recurrence( + S, + p1, + p2, + DS, + Dp1, + Dp2, + NUM_BLOCK, + NUM_SPLIT_K, + NUM_SPLIT_V, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr + +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + # skip the last chunk because it is never used + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + # start from the last chunk + DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V + + # skip the last chunk because it is never used + p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * \ + BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K + + # skip the last chunk because it is never used + p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * \ + BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V + + # skip the last chunk because it is never used + # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh + # offset_s * D_MODEL_K: find the right split in the K dimension + Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + \ + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V + + # skip the last chunk because it is never used + Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + \ + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K + + Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) 
+ + # ignore the first chunk + for i in range(NUM_BLOCK - 1): + p_key = tl.load(p1) + p_value = tl.load(p2) + S_i = tl.load(S) + DS_i = tl.load(DS) + Dacc += DS_i + dp_i = Dacc * S_i + dp_key = tl.sum(dp_i * p_value[None, :], axis=1) + tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) + dp_value = tl.sum(dp_i * p_key[:, None], axis=0) + tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) + + tl.store(S, Dacc.to(S.dtype.element_ty)) + + Dacc *= p_key[:, None] + Dacc *= p_value[None, :] + + S -= D_MODEL_K * D_MODEL_V + DS -= D_MODEL_K * D_MODEL_V + p1 -= D_MODEL_K + p2 -= D_MODEL_V + Dp1 -= D_MODEL_K * NUM_SPLIT_V + Dp2 -= D_MODEL_V * NUM_SPLIT_K + + +class Chunk_memory_update_full(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, decay_key_last, decay_value_last, to_add): + B, H, N, D_k, D_v = to_add.shape + output = torch.empty_like(to_add) + BLOCK_MODEL = 32 + + assert D_k % 32 == 0 + assert D_v % 32 == 0 + assert D_k == decay_key_last.shape[-1] + assert D_v == decay_value_last.shape[-1] + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + ctx.grid = grid + ctx.BLOCK_MODEL = BLOCK_MODEL + + _fwd_recurrence[grid]( + to_add, + decay_key_last, + decay_value_last, + output, + D_MODEL_K=D_k, D_MODEL_V=D_v, + NUM_BLOCK=N, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, 0] = 0 + ctx.save_for_backward(output, decay_key_last, decay_value_last) + + return output + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, DO): + + output, decay_key_last, decay_value_last = ctx.saved_tensors + + B, H, N, D_k, D_v = output.shape + + num_block = N + + BLOCK_MODEL = 32 + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + + # I don't want atomic_add to be used in the backward pass + # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) + # afterward, I sum over this dimension to get the correct gradient + D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, + device=DO.device, dtype=torch.float32) + D_p2 = 
torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, + device=DO.device, dtype=torch.float32) + + _bwd_recurrence[grid]( + output, decay_key_last, decay_value_last, + DO, D_p1, D_p2, + NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, + D_MODEL_K=D_k, + D_MODEL_V=D_v, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, -1] = 0 + D_p1[:, :, 0] = 0 + D_p1[:, :, -1] = 0 + D_p2[:, :, 0] = 0 + D_p2[:, :, -1] = 0 + + return D_p1.sum(-2).to(decay_key_last.dtype), D_p2.sum(-2).to(decay_key_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py new file mode 100644 index 0000000..e83b54e --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_recurrence( + S, + O, + NUM_BLOCK, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ + BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + + O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V + + acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + acc += tl.load(S) + + S += D_MODEL_K * D_MODEL_V + + tl.store(O, acc.to(O.dtype.element_ty)) + 
O += D_MODEL_K * D_MODEL_V + + for i in range(NUM_BLOCK-2): + S_i = tl.load(S) + acc = acc + S_i + tl.store(O, acc.to(O.dtype.element_ty)) + S += D_MODEL_K * D_MODEL_V + O += D_MODEL_K * D_MODEL_V + + +# NUM_SPLIT_K/V. K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL +@triton.jit +def _bwd_recurrence( + S, + DS, + NUM_BLOCK, + NUM_SPLIT_K, + NUM_SPLIT_V, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + # skip the last chunk because it is never used + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + # start from the last chunk + DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V + + # skip the last chunk because it is never used + + # skip the last chunk because it is never used + # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh + # offset_s * D_MODEL_K: find the right split in the K dimension + Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + + # ignore the first chunk + for i in range(NUM_BLOCK - 1): + # S_i = tl.load(S) + DS_i = tl.load(DS) + Dacc += DS_i + # dp_i = Dacc * S_i + + # dp_key = tl.sum(dp_i * p_value[None, :], axis=1) + # tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) + # dp_value = tl.sum(dp_i * p_key[:, None], axis=0) + # tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) + + tl.store(S, Dacc.to(S.dtype.element_ty)) + + # Dacc *= p_key[:, None] + # Dacc *= p_value[None, :] + + S -= D_MODEL_K * D_MODEL_V + DS -= D_MODEL_K * D_MODEL_V + + +class 
Chunk_memory_update_no_decay(torch.autograd.Function): + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, to_add): + B, H, N, D_k, D_v = to_add.shape + output = torch.empty_like(to_add) + BLOCK_MODEL = 32 + + assert D_k % 32 == 0 + assert D_v % 32 == 0 + # assert D_k == decay_key_last.shape[-1] + # assert D_v == decay_value_last.shape[-1] + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + ctx.grid = grid + ctx.BLOCK_MODEL = BLOCK_MODEL + + _fwd_recurrence[grid]( + to_add, + # decay_key_last, + # decay_value_last, + output, + D_MODEL_K=D_k, D_MODEL_V=D_v, + NUM_BLOCK=N, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, 0] = 0 + ctx.save_for_backward(output) + + return output.to(to_add.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, DO): + output, = ctx.saved_tensors + + B, H, N, D_k, D_v = output.shape + + num_block = N + + BLOCK_MODEL = 32 + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + + # I don't want atomic_add to be used in the backward pass + # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) + # afterward, I sum over this dimension to get the correct gradient + # D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32) + # D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32) + + _bwd_recurrence[grid]( + output, + DO, + NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, + D_MODEL_K=D_k, + D_MODEL_V=D_v, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, -1] = 0 + + return output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py new file mode 100644 index 0000000..59cb608 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py @@ -0,0 
+1,187 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_recurrence( + S, + p1, + O, + NUM_BLOCK, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ + BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + + O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V + + p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K + + acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + acc += tl.load(S) + + S += D_MODEL_K * D_MODEL_V + + tl.store(O, acc.to(O.dtype.element_ty)) + O += D_MODEL_K * D_MODEL_V + + for i in range(NUM_BLOCK-2): + p_k = tl.load(p1) + S_i = tl.load(S) + acc = acc * p_k[:, None] + S_i + tl.store(O, acc.to(O.dtype.element_ty)) + p1 += D_MODEL_K + S += D_MODEL_K * D_MODEL_V + O += D_MODEL_K * D_MODEL_V + + +# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL +@triton.jit +def _bwd_recurrence( + S, p1, + DS, Dp1, + NUM_BLOCK, + NUM_SPLIT_K, + NUM_SPLIT_V, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + # skip the last chunk because it is never used + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + # start from the last chunk + DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V + + # skip the last chunk because it is never used + p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * \ + BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K + + # skip the last chunk because it is never used + # p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V + + # skip the last chunk because it is never used + # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh + # offset_s * D_MODEL_K: find the right split in the K dimension + Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + \ + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + \ + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V + + # skip the last chunk because it is never used + # Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K + + Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + + # ignore the first chunk + 
for i in range(NUM_BLOCK - 1): + p_key = tl.load(p1) + S_i = tl.load(S) + DS_i = tl.load(DS) + Dacc += DS_i + dp_i = Dacc * S_i + dp_key = tl.sum(dp_i, axis=1) + tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) + + tl.store(S, Dacc.to(S.dtype.element_ty)) + + Dacc *= p_key[:, None] + + S -= D_MODEL_K * D_MODEL_V + DS -= D_MODEL_K * D_MODEL_V + p1 -= D_MODEL_K + Dp1 -= D_MODEL_K * NUM_SPLIT_V + + +class Chunk_memory_update_only_gk(torch.autograd.Function): + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, decay_key_last, to_add): + + B, H, N, D_k, D_v = to_add.shape + output = torch.empty_like(to_add) + BLOCK_MODEL = 32 + + assert D_k % 32 == 0 + assert D_v % 32 == 0 + assert D_k == decay_key_last.shape[-1] + # assert D_v == to_add.shape[-1] + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + ctx.grid = grid + ctx.BLOCK_MODEL = BLOCK_MODEL + + _fwd_recurrence[grid]( + to_add, + decay_key_last, + output, + D_MODEL_K=D_k, D_MODEL_V=D_v, + NUM_BLOCK=N, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, 0] = 0 + ctx.save_for_backward(output, decay_key_last) + + return output + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, DO): + output, decay_key_last = ctx.saved_tensors + + B, H, N, D_k, D_v = output.shape + + num_block = N + + BLOCK_MODEL = 32 + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + + # I don't want atomic_add to be used in the backward pass + # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) + # afterward, I sum over this dimension to get the correct gradient + D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, + device=DO.device, dtype=torch.float32) + # D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32) + + _bwd_recurrence[grid]( + output, decay_key_last, + DO, D_p1, + NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, + D_MODEL_K=D_k, + D_MODEL_V=D_v, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, -1] = 0 + D_p1[:, :, 0] = 
0 + D_p1[:, :, -1] = 0 + + return D_p1.sum(-2).to(decay_key_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py new file mode 100644 index 0000000..77d03a8 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_recurrence( + S, + p2, + O, + NUM_BLOCK, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr +): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ + BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + + O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V + + p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V + + acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + acc += tl.load(S) + + S += D_MODEL_K * D_MODEL_V + + tl.store(O, acc.to(O.dtype.element_ty)) + O += D_MODEL_K * D_MODEL_V + + for i in range(NUM_BLOCK-2): + p_v = tl.load(p2) + S_i = tl.load(S) + acc = acc * p_v[None, :] + S_i + tl.store(O, acc.to(O.dtype.element_ty)) + p2 += D_MODEL_V + S += D_MODEL_K * D_MODEL_V + O += D_MODEL_K * D_MODEL_V + + +# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL +@triton.jit +def _bwd_recurrence( + S, + p2, + DS, + Dp2, + NUM_BLOCK, + NUM_SPLIT_K, + NUM_SPLIT_V, + D_MODEL_K: tl.constexpr, + D_MODEL_V: tl.constexpr, + BLOCK_MODEL: tl.constexpr + +): + + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + # skip the last chunk because it is never used + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + # start from the last chunk + DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( + 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V + + # skip the last chunk because it is never used + # p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K + + # skip the last chunk because it is never used + p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * \ + BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V + + # skip the last chunk because it is never used + # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh + # offset_s * D_MODEL_K: find the right split in the K dimension + # Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V + + # skip the last chunk because it is never used + Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + \ + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + \ + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K + + Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) + + # ignore the first 
chunk + for i in range(NUM_BLOCK - 1): + + # p_key = tl.load(p1) + p_value = tl.load(p2) + S_i = tl.load(S) + DS_i = tl.load(DS) + Dacc += DS_i + dp_i = Dacc * S_i + # dp_key = tl.sum(dp_i * p_value[None, :], axis=1) + # tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) + dp_value = tl.sum(dp_i, axis=0) + tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) + + tl.store(S, Dacc.to(S.dtype.element_ty)) + + # Dacc *= p_key[:, None] + Dacc *= p_value[None, :] + + S -= D_MODEL_K * D_MODEL_V + DS -= D_MODEL_K * D_MODEL_V + # p1 -= D_MODEL_K + p2 -= D_MODEL_V + # Dp1 -= D_MODEL_K * NUM_SPLIT_V + Dp2 -= D_MODEL_V * NUM_SPLIT_K + + +class Chunk_memory_update_only_gv(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, decay_value_last, to_add): + B, H, N, D_k, D_v = to_add.shape + output = torch.empty_like(to_add) + BLOCK_MODEL = 32 + + assert D_k % 32 == 0 + assert D_v % 32 == 0 + # assert D_k == decay_key_last.shape[-1] + assert D_v == decay_value_last.shape[-1] + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + ctx.grid = grid + ctx.BLOCK_MODEL = BLOCK_MODEL + + _fwd_recurrence[grid]( + to_add, + decay_value_last, + output, + D_MODEL_K=D_k, D_MODEL_V=D_v, + NUM_BLOCK=N, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, 0] = 0 + ctx.save_for_backward(output, decay_value_last) + + return output + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, DO): + + output, decay_value_last = ctx.saved_tensors + + B, H, N, D_k, D_v = output.shape + + num_block = N + + BLOCK_MODEL = 32 + + grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) + + # I don't want atomic_add to be used in the backward pass + # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) + # afterward, I sum over this dimension to get the correct gradient + D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, + device=DO.device, dtype=torch.float32) + + _bwd_recurrence[grid]( + output, decay_value_last, + DO, D_p2, + NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // 
BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, + D_MODEL_K=D_k, + D_MODEL_V=D_v, + BLOCK_MODEL=BLOCK_MODEL + ) + + output[:, :, -1] = 0 + # D_p1[:, :, 0] = 0 + # D_p1[:, :, -1] = 0 + D_p2[:, :, 0] = 0 + D_p2[:, :, -1] = 0 + + return D_p2.sum(-2).to(decay_value_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py new file mode 100644 index 0000000..138eeba --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +from .chunk_scan_triton_full import Chunk_memory_update_full +from .chunk_scan_triton_no_decay import Chunk_memory_update_no_decay +from .chunk_scan_triton_only_gk import Chunk_memory_update_only_gk +from .chunk_scan_triton_only_gv import Chunk_memory_update_only_gv +from .preprocess_cumsum_gk import PreprocessCumSum_GK +from .preprocess_cumsum_gv import PreprocessCumSum_GV + + +def inter_chunk_onc(query, key, value, gk, gv): + if gk is not None: + g_key_cumsum, reduce_key, q_exp, g_key_last_exp = PreprocessCumSum_GK.apply( + query, key, gk) + else: + reduce_key = key + q_exp = None + g_key_cumsum = None + g_key_last_exp = None + + if gv is not None: + g_value_cumsum, reduce_value, g_value_cumsum_exp, g_value_last_exp = PreprocessCumSum_GV.apply( + value, gv) + else: + reduce_value = value + g_value_cumsum = None + g_value_last_exp = None + + to_add = reduce_key.transpose(-1, -2).to(query.dtype) @ reduce_value.to(value.dtype) + + if gk is not None and gv is not None: + memory_cache = Chunk_memory_update_full.apply( + g_key_last_exp, g_value_last_exp, to_add) + inter_chunk_contribution = ( + (q_exp.to(query.dtype)) @ memory_cache) * g_value_cumsum_exp + elif gk is None and gv is not None: + memory_cache = Chunk_memory_update_only_gv.apply( + g_value_last_exp, to_add) + inter_chunk_contribution = ( + (query) @ 
memory_cache) * g_value_cumsum_exp + elif gk is not None and gv is None: + memory_cache = Chunk_memory_update_only_gk.apply( + g_key_last_exp, to_add) + inter_chunk_contribution = ((q_exp.to(query.dtype)) @ memory_cache) + else: + memory_cache = Chunk_memory_update_no_decay.apply(to_add) + inter_chunk_contribution = ((query) @ memory_cache) + + return g_key_cumsum, g_value_cumsum, inter_chunk_contribution.to(query.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py new file mode 100644 index 0000000..c0f1d81 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + +# def stable_logsigmoid(x): +# # Use the identity log(sigmoid(x)) = -log(1 + exp(-x)) +# # This is stable for large negative values of x +# neg_abs_x = -torch.abs(x) +# return torch.where(x < 0, x, neg_abs_x) - torch.log1p(torch.exp(neg_abs_x)) + + +@triton.jit +def _fwd_preprocess_cumsum_gk( + Q, + K, + GK, + GK_cumsum, + Q_exp, + K_reduce, + GK_last_exp, + NUM_CHUNK, + L, + D_MODEL_K: tl.constexpr, + D_BLOCK_K: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + offset_bh = tl.program_id(0) + offset_c = tl.program_id(1) + offset_nk = tl.program_id(2) + Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + Q_exp_ptr = Q_exp + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + GK_cumsum_ptr = 
GK_cumsum + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + GK_last_exp_ptr = GK_last_exp + offset_bh * NUM_CHUNK * \ + D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + cumsum = tl.zeros([D_BLOCK_K], dtype=tl.float32) + + mask = (D_BLOCK_K * offset_nk + tl.arange(0, D_BLOCK_K)) < D_MODEL_K + + for _ in range(CHUNK_SIZE): + gk = tl.load(GK_ptr, mask=mask, other=0).to(tl.float32) + cumsum += gk + tl.store(GK_cumsum_ptr, cumsum.to(GK_cumsum_ptr.dtype.element_ty), mask=mask) + cumsum_exp = tl.exp(cumsum) + q = tl.load(Q_ptr, mask=mask, other=0) + q_exp = q * cumsum_exp + tl.store(Q_exp_ptr, q_exp, mask=mask) + Q_ptr += D_MODEL_K + Q_exp_ptr += D_MODEL_K + GK_ptr += D_MODEL_K + GK_cumsum_ptr += D_MODEL_K + + tl.store(GK_last_exp_ptr, tl.exp(cumsum).to( + GK_last_exp_ptr.dtype.element_ty), mask=mask) + + tl.debug_barrier() + + GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + K_reduce_ptr = K_reduce + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + + for _ in range(CHUNK_SIZE): + gk_cumsum = tl.load(GK_cumsum_ptr, mask=mask, other=0) + k = tl.load(K_ptr, mask=mask, other=0) + k_reduce = k * tl.exp(cumsum - gk_cumsum) + tl.store(K_reduce_ptr, k_reduce.to(K_reduce_ptr.dtype.element_ty), mask=mask) + + K_ptr += D_MODEL_K + GK_cumsum_ptr += D_MODEL_K + K_reduce_ptr += D_MODEL_K + + +@triton.jit +def _bwd_preprocess_cumsum_gk( + Q, + K, + GK, + GK_cumsum, + DQ_exp, + DK_reduce, + DGK_last_exp, + DGK_cumsum, + DQ, + DK, + DGK, + NUM_CHUNK, + L, + D_MODEL_K: tl.constexpr, + D_BLOCK_K: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + + offset_bh = tl.program_id(0) + offset_c = 
tl.program_id(1) + offset_nk = tl.program_id(2) + mask = (D_BLOCK_K * offset_nk + tl.arange(0, D_BLOCK_K)) < D_MODEL_K + + Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + + DQ_ptr = DQ + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + DK_ptr = DK + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + DQ_exp_ptr = DQ_exp + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + DK_reduce_ptr = DK_reduce + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + DGK_cumsum_ptr = DGK_cumsum + offset_bh * L * D_MODEL_K + \ + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * \ + CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + + D_GK_last_exp_ptr = DGK_last_exp + offset_bh * NUM_CHUNK * \ + D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk + # + cumsum_gradient = tl.zeros([D_BLOCK_K], dtype=tl.float32) + grad_gk_last = tl.zeros([D_BLOCK_K], dtype=tl.float32) + + gk_last = tl.load(GK_cumsum_ptr + (CHUNK_SIZE - 1) + * D_MODEL_K, mask=mask, other=0).to(tl.float32) + cumsum_gradient += tl.load(D_GK_last_exp_ptr, mask=mask, other=0) * tl.exp(gk_last) + + GK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + 
GK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + Q_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + K_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + + DQ_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + DK_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + DGK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + DQ_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + DK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + DGK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K + + for idx in range(CHUNK_SIZE - 1, -1, -1): + gk_cs = tl.load(GK_cumsum_ptr, mask=mask, other=0).to(tl.float32) + k = tl.load(K_ptr, mask=mask, other=0).to(tl.float32) + grad_k = tl.exp(gk_last - gk_cs) * \ + tl.load(DK_reduce_ptr, mask=mask, other=0).to(tl.float32) + tl.store(DK_ptr, grad_k.to(DK_ptr.dtype.element_ty), mask=mask) + grad_k *= k + cumsum_gradient -= grad_k + grad_gk_last += grad_k + + q = tl.load(Q_ptr, mask=mask, other=0).to(tl.float32) + grad_q = tl.exp(gk_cs) * tl.load(DQ_exp_ptr, mask=mask, other=0) + tl.store(DQ_ptr, grad_q.to(DK_ptr.dtype.element_ty), mask=mask) + cumsum_gradient += grad_q * q.to(tl.float32) + + # from intra-chunk contribution. + cumsum_gradient += tl.load(DGK_cumsum_ptr, mask=mask, other=0).to(tl.float32) + + tl.store(DGK_ptr, cumsum_gradient.to(DGK_ptr.dtype.element_ty), mask=mask) + + Q_ptr -= D_MODEL_K + DQ_exp_ptr -= D_MODEL_K + K_ptr -= D_MODEL_K + DK_reduce_ptr -= D_MODEL_K + GK_cumsum_ptr -= D_MODEL_K + DGK_cumsum_ptr -= D_MODEL_K + DQ_ptr -= D_MODEL_K + DK_ptr -= D_MODEL_K + DGK_ptr -= D_MODEL_K + + DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * \ + D_MODEL_K + tl.arange(0, D_BLOCK_K) + (CHUNK_SIZE - 1) * D_MODEL_K + D_BLOCK_K * offset_nk + GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * \ + D_MODEL_K + tl.arange(0, D_BLOCK_K) + (CHUNK_SIZE - 1) * D_MODEL_K + D_BLOCK_K * offset_nk + + # tl.store(D_GK_last_exp_ptr, cumsum_gradient) + + # seems stupid. just workaround some compiler bugs. + grad_gk_last = grad_gk_last + 0. 
+ for idx in range(CHUNK_SIZE - 1, -1, -1): + dgk = tl.load(DGK_ptr, mask=mask, other=0).to(tl.float32) + dgk += grad_gk_last + tl.store(DGK_ptr, dgk.to(DGK_ptr.dtype.element_ty), mask=mask) + DGK_ptr -= D_MODEL_K + GK_ptr -= D_MODEL_K + + +class PreprocessCumSum_GK(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, gk): + B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape + + D_k = k.shape[-1] + N_k = triton.cdiv(D_k, 32) + grid = (B * H, NUM_CHUNK, N_k) + + k_reduce = torch.empty_like(k) + + q_exp = torch.empty_like(q) + + gk_cumsum = torch.empty_like(gk) + + gk_last_exp = torch.empty_like(gk[:, :, :, 0], dtype=torch.float32) + + _fwd_preprocess_cumsum_gk[grid]( + q, k, gk, gk_cumsum, + q_exp, k_reduce, gk_last_exp, + CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, + D_MODEL_K=D_k, D_BLOCK_K=32, num_warps=1, num_stages=2 + ) + + ctx.grid = grid + ctx.save_for_backward(q, k, gk, gk_cumsum) + + return gk_cumsum, k_reduce, q_exp, gk_last_exp + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, dgk_cumsum, dk_reduce, dq_exp, dgk_last_exp): + q, k, gk, gk_cumsum = ctx.saved_tensors + B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape + + D_k = k.shape[-1] + N_k = triton.cdiv(D_k, 32) + grid = (B * H, NUM_CHUNK, N_k) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dgk = torch.empty_like(gk) + + B, H, NUM_CHUNK, CHUNK_SIZE, D_k = q.shape + + # D_v = v.shape[-1] + + _bwd_preprocess_cumsum_gk[grid]( + q, k, gk, gk_cumsum, + dq_exp, dk_reduce, dgk_last_exp, dgk_cumsum, + dq, dk, dgk, + CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, + D_MODEL_K=D_k, D_BLOCK_K=32, num_warps=1, num_stages=2 + ) + + return dq.to(q.dtype), dk.to(k.dtype), dgk.to(gk.dtype), None, None, None diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py 
b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py new file mode 100644 index 0000000..85604aa --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_preprocess_cumsum_gv( + V, + GV, + GV_cumsum, + GV_exp, + V_reduce, + GV_last_exp, + NUM_CHUNK, + L, + D_MODEL_V: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + + offset_bh = tl.program_id(0) + offset_c = tl.program_id(1) + + GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + GV_last_exp_ptr = GV_last_exp + offset_bh * NUM_CHUNK * \ + D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V) + + GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + GV_exp_ptr = GV_exp + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + cumsum = tl.zeros([D_MODEL_V], dtype=tl.float32) + + for _ in range(CHUNK_SIZE): + gv = tl.load(GV_ptr).to(tl.float32) + cumsum += gv + + tl.store(GV_cumsum_ptr, cumsum.to(GV_cumsum_ptr.dtype.element_ty)) + tl.store(GV_exp_ptr, tl.exp(cumsum).to(GV_cumsum_ptr.dtype.element_ty)) + + GV_cumsum_ptr += D_MODEL_V + GV_exp_ptr += D_MODEL_V + GV_ptr += D_MODEL_V + + tl.store(GV_last_exp_ptr, tl.exp(cumsum).to( + GV_last_exp_ptr.dtype.element_ty)) + + tl.debug_barrier() + + V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + V_reduce_ptr = V_reduce + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + for _ in 
range(CHUNK_SIZE): + v = tl.load(V_ptr) + gv = tl.load(GV_cumsum_ptr) + v_reduce = v * tl.exp(cumsum - gv) + tl.store(V_reduce_ptr, v_reduce.to(V_reduce_ptr.dtype.element_ty)) + + V_ptr += D_MODEL_V + V_reduce_ptr += D_MODEL_V + GV_cumsum_ptr += D_MODEL_V + + +@triton.jit +def _bwd_preprocess_cumsum_gv( + V, + GV, + GV_cumsum, + DGV_cumsum_exp, + DV_reduce, + DGV_last_exp, + DGV_cumsum, + DV, + DGV, + NUM_CHUNK, + L, + D_MODEL_V: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + + offset_bh = tl.program_id(0) + offset_c = tl.program_id(1) + V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + DV_ptr = DV + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + DV_reduce_ptr = DV_reduce + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + DGV_cumsum_ptr = DGV_cumsum + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + DGV_cumsum_exp_ptr = DGV_cumsum_exp + offset_bh * L * D_MODEL_V + \ + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * \ + CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + + D_GV_last_exp_ptr = DGV_last_exp + offset_bh * NUM_CHUNK * \ + D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V) + + cumsum_gradient = tl.zeros([D_MODEL_V], dtype=tl.float32) + grad_gv_last = tl.zeros([D_MODEL_V], dtype=tl.float32) + + gv_last = tl.load(GV_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_V) + cumsum_gradient += tl.load(D_GV_last_exp_ptr) * \ + tl.exp(gv_last).to(tl.float32) + + GV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + GV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + + V_ptr += (CHUNK_SIZE - 1) * 
D_MODEL_V + + DV_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + DGV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + DGV_cumsum_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + DV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + DGV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V + + for idx in range(CHUNK_SIZE - 1, -1, -1): + gv_cs = tl.load(GV_cumsum_ptr).to(tl.float32) + v = tl.load(V_ptr).to(tl.float32) + grad_v = tl.exp(gv_last - gv_cs) * \ + tl.load(DV_reduce_ptr).to(tl.float32) + tl.store(DV_ptr, grad_v.to(DV_ptr.dtype.element_ty)) + grad_v *= v + cumsum_gradient -= grad_v + grad_gv_last += grad_v + + # q = tl.load(Q_ptr).to(tl.float32) + grad_v = tl.exp(gv_cs) * tl.load(DGV_cumsum_exp_ptr) + cumsum_gradient += grad_v + + # from intra-chunk contribution. + cumsum_gradient += tl.load(DGV_cumsum_ptr).to(tl.float32) + + tl.store(DGV_ptr, cumsum_gradient.to(DGV_ptr.dtype.element_ty)) + + V_ptr -= D_MODEL_V + DV_reduce_ptr -= D_MODEL_V + GV_cumsum_ptr -= D_MODEL_V + DGV_cumsum_ptr -= D_MODEL_V + DV_ptr -= D_MODEL_V + DGV_ptr -= D_MODEL_V + DGV_cumsum_exp_ptr -= D_MODEL_V + + DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * \ + D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V + GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * \ + D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V + + grad_gv_last = grad_gv_last + 0. 
+ + for idx in range(CHUNK_SIZE - 1, -1, -1): + dgv = tl.load(DGV_ptr).to(tl.float32) + dgv += grad_gv_last + tl.store(DGV_ptr, dgv.to(DGV_ptr.dtype.element_ty)) + DGV_ptr -= D_MODEL_V + GV_ptr -= D_MODEL_V + + +class PreprocessCumSum_GV(torch.autograd.Function): + @staticmethod + @contiguous + @torch.cuda.amp.custom_fwd + def forward(ctx, v, gv): + B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape + + grid = (B * H, NUM_CHUNK) + ctx.grid = grid + + gv_cumsum = torch.empty_like(gv, dtype=torch.float32) + gv_cumsum_exp = torch.empty_like(gv) + v_reduce = torch.empty_like(v) + gv_last_exp = torch.empty_like(gv[:, :, :, 0], dtype=torch.float32) + _fwd_preprocess_cumsum_gv[grid]( + v, gv, gv_cumsum, gv_cumsum_exp, + v_reduce, gv_last_exp, + CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, + D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4 + ) + + ctx.grid = grid + ctx.save_for_backward(v, gv, gv_cumsum) + return gv_cumsum, v_reduce, gv_cumsum_exp, gv_last_exp + + @staticmethod + @contiguous + def backward(ctx, dgv_cumsum, dv_reduce, dgv_cumsum_exp, dgv_last_exp): + v, gv, gv_cumsum = ctx.saved_tensors + grid = ctx.grid + + B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape + + dv = torch.empty_like(v) + dgv = torch.empty_like(gv) + _bwd_preprocess_cumsum_gv[grid]( + v, gv, gv_cumsum, dgv_cumsum_exp, dv_reduce, dgv_last_exp, dgv_cumsum, + dv, dgv, + CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, + D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4 + ) + return dv.to(v.dtype), dgv.to(gv.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py new file mode 100644 
index 0000000..4375856 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +import torch + +from .fn_only_gk import IntraCalA +from .fn_only_gv import IntraCalO + + +def intra_chunk_onc(q, k, v, gk, gv): + assert q.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + if gk is not None: + assert gk.is_contiguous() + if gv is not None: + assert gv.is_contiguous() + + assert k.shape[-2] % 16 == 0 + + if gk is not None: + A = IntraCalA.apply(q, k, gk) + else: + A = q @ k.transpose(-1, -2) + + mask = torch.triu(torch.ones(A.shape[-2], A.shape[-2]), diagonal=1).bool().to(A.device) + A.masked_fill_(mask, 0) + + return IntraCalO.apply(A, v, gv) if gv is not None else A.to(v) @ v diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py new file mode 100644 index 0000000..8eaddd9 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py @@ -0,0 +1,343 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_kernel_compute_A( + Q, + K, + GK, + A, + stride_q1, + stride_q2, + stride_q3, + stride_q4, + stride_a1, + stride_a2, + stride_a3, + stride_a4, + Z, + H, + N_CTX, + D, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_k = tl.program_id(2) + + qk_offset = off_hz * stride_q2 + off_k * BLOCK_DMODEL_QK + a_offset = (off_k * Z*H + off_hz) * stride_a2 + + lo = 0 + hi = BLOCK_N + + Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * 
stride_q4 + + K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + :, None] + tl.arange(0, 16)[None, :] * stride_q4 + + GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( + 0, BLOCK_DMODEL_QK)[:, None] + tl.arange(0, 16)[None, :] * stride_q4 + + GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( + 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 + + for q_high in range(16, hi, 16): + q = tl.load(Q_ptr + q_high * stride_q4) + q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) + q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + + q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) + q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) + q = q * q_gk2.to(q.dtype) + + # inter-chunk bf16 + for k_high in range(0, q_high, 16): + k = tl.load(K_ptr + k_high * stride_q4) + k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) + k_gk = tl.exp(q_normalizer[:, None] - k_gk) + k = k * k_gk.to(k.dtype) + qk = tl.dot(q, k, allow_tf32=False) + tl.store(A_ptr + q_high * stride_a4 + k_high, + qk.to(A_ptr.dtype.element_ty)) + + # intra chunk fp32 + for q_high in range(lo, hi, 16): + q = tl.load(Q_ptr + q_high * stride_q4) + q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) + q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + + q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) + q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) + q = q * q_gk2 + q_gk3 = tl.exp(q_normalizer[None, :] - q_gk) + k = tl.load(K_ptr + q_high * stride_q4) + k = k * tl.trans(q_gk3) + + qk = tl.dot(q, k, allow_tf32=False) + qk = tl.where(tl.arange(0, 16)[:, None] + >= tl.arange(0, 16)[None, :], qk, 0.) 
+ tl.store(A_ptr + q_high * stride_a4 + q_high, + qk.to(A_ptr.dtype.element_ty)) + + +@triton.jit +def _bwd_kernel_dqk( + Q, + K, + GK, + DA, + DQ, + DK, + DGK, + stride_q1, + stride_q2, + stride_q3, + stride_q4, + stride_a1, + stride_a2, + stride_a3, + stride_a4, + Z, + H, + N_CTX, + D, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_k = tl.program_id(2) + + qk_offset = off_hz * stride_q2 + BLOCK_DMODEL_QK * off_k + a_offset = off_hz * stride_a2 + + lo = 0 + hi = BLOCK_N + + Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( + 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( + 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + # DGK_Q_ptr = DGK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + DA_ptr = DA + a_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 + + # inter chunk dq. 
bf16 + for q_high in range(lo+16, hi, 16): + q = tl.load(Q_ptr + q_high * stride_q4) + + q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3) + + q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) + + # q2 = q * q_gk.to(q.dtype) + + dq2 = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) + + for k_high in range(0, q_high, 16): + k = tl.load(K_ptr + k_high * stride_q4) + k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) + dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(k.dtype) + k_gk = tl.exp(q_normalizer[None, :] - k_gk) + k = k * k_gk.to(k.dtype) + dq2 += tl.dot(dqk, k, allow_tf32=False) + + dq2 = dq2.to(q.dtype) + q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) + q_gk = tl.exp(q_gk - q_normalizer[None, :]) + dq = dq2 * q_gk.to(q.dtype) + dq_gk = dq * q + + DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4 + tl.store(DQ_ptr, dq.to(DQ_ptr.dtype.element_ty)) + + DGK_Q_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4 + # prev = tl.load(DGK_Q_ptr) + tl.store(DGK_Q_ptr, dq_gk.to(DGK_Q_ptr.dtype.element_ty)) + + tl.debug_barrier() + + for k_high in range(lo, hi-16, 16): + k = tl.load(K_ptr + k_high * stride_q4) + k_gk = tl.load(GK_K_ptr + k_high * stride_q4) + dk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) + dgk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) + + for q_high in range(k_high+16, hi, 16): + q = tl.load(Q_ptr + q_high * stride_q4) + q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3) + q_high * stride_q4 + tl.arange(0, + BLOCK_DMODEL_QK)).to(tl.float32) + q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) + q_gk = tl.exp(q_gk - q_normalizer[None, :]).to(q.dtype) + q = q * q_gk + dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(q.dtype) + + k_gk2 = 
tl.exp(q_normalizer[None, :] - k_gk) + + dk2 = tl.dot(tl.trans(dqk), q, allow_tf32=False) + dk += dk2 * k_gk2 + dgk -= dk2 * k * k_gk2 + + DK_ptr = DK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4 + tl.store(DK_ptr, dk.to(DK_ptr.dtype.element_ty)) + + DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4 + prev = tl.load(DGK_K_ptr) + tl.store(DGK_K_ptr, (prev + dgk).to(DGK_K_ptr.dtype.element_ty)) + + tl.debug_barrier() + + DK_ptr = DK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange( + 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ + None, :] + tl.arange(0, 16)[:, None] * stride_q4 + + # intra chunk, fp32. + for q_high in range(lo, hi, 16): + q = tl.load(Q_ptr + q_high * stride_q4) + q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) + q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + + q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) + q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) + q2 = q * q_gk2 + q_gk3 = tl.exp(q_normalizer[None, :] - q_gk) + + k = tl.load(K_ptr + q_high * stride_q4) + k2 = k * q_gk3 + + dqk = tl.load(DA_ptr + q_high * stride_a4 + q_high) + dqk = tl.where(tl.arange(0, 16)[:, None] + >= tl.arange(0, 16)[None, :], dqk, 0.) 
+ + dk2 = tl.dot(tl.trans(dqk), q2, allow_tf32=False) + dk = dk2 * q_gk3 + prev_dk = tl.load(DK_ptr + q_high * stride_q4) + tl.store(DK_ptr + q_high * stride_q4, + (dk + prev_dk).to(DK_ptr.dtype.element_ty)) + + dgk = - dk * k + dq2 = tl.dot(dqk, k2, allow_tf32=False) + dq = dq2 * q_gk2 + + prev_dq = tl.load(DQ_ptr + q_high * stride_q4) + tl.store(DQ_ptr + q_high * stride_q4, + (dq + prev_dq).to(DQ_ptr.dtype.element_ty)) + + dgk += dq * q + prev_dq_gk = tl.load(DGK_K_ptr + q_high * stride_q4) + tl.store(DGK_K_ptr + q_high * stride_q4, + (dgk + prev_dq_gk).to(DGK_K_ptr.dtype.element_ty)) + + +class IntraCalA(torch.autograd.Function): + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, q, k, gk): + + # assert gk.dtype==torch.float32 + # only support for Ampere now + + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError( + "Flash attention currently only supported for compute capability >= 80") + + # assert gk.dtype == gv.dtype == torch.float32 + # for now. + BLOCK_M = BLOCK_N = q.shape[-2] + + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + if Lk > 128: + assert Lk % 128 == 0 + + BLOCK_DMODEL_QK = min(Lk, 128) + ctx.BLOCK_DMODEL_QK = BLOCK_DMODEL_QK + + A = torch.zeros(max(1, Lk//128), q.shape[0], q.shape[1], + q.shape[2], BLOCK_N, BLOCK_N, device=q.device, dtype=q.dtype) + + grid = (q.shape[2], q.shape[0] * q.shape[1], max(1, Lk//128)) + + # assert q.dtype == k.dtype == v.dtype + _fwd_kernel_compute_A[grid]( + q, k, gk, A, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + # be careful here! 
+ A.stride(1), A.stride(2), A.stride(3), A.stride(4), + q.shape[0], q.shape[1], q.shape[2], q.shape[3], + BLOCK_N=BLOCK_N, BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, BLOCK_M=BLOCK_M, num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, num_stages=8 + ) + + ctx.save_for_backward(q, k, gk) + ctx.grid = grid + ctx.BLOCK_N = BLOCK_N + ctx.BLOCK_N = BLOCK_N + ctx.head = q.shape[1] + return A.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, dA): + q, k, gk = ctx.saved_tensors + + # appearantly, there is no sync issue when splitting K dim. + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dgk = torch.zeros_like(gk) + + BLOCK_N = ctx.BLOCK_N + # for now. + BLOCK_M = BLOCK_N + + _bwd_kernel_dqk[ctx.grid]( + q, k, gk, dA, + dq, + dk, dgk, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + dA.stride(0), dA.stride(1), dA.stride(2), dA.stride(3), + q.shape[0], q.shape[1], q.shape[2], q.shape[3], + BLOCK_N=BLOCK_N, + BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK, + BLOCK_M=BLOCK_M, + num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, + num_stages=5 + ) + + return dq.to(q.dtype), dk.to(k.dtype), dgk.to(gk.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py new file mode 100644 index 0000000..3057b7a --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def _fwd_compute_O( + A, + V, + GV, + O, + stride_a2, + stride_a3, + stride_a4, + stride_v2, + stride_v3, + stride_v4, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_v = tl.program_id(2) + + 
a_offset = off_hz * stride_a2 + v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V + + lo = 0 + hi = BLOCK_N + + V_ptr = V + v_offset + (start_m) * stride_v3 + tl.arange(0, + BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + O_ptr = O + v_offset + (start_m) * stride_v3 + tl.arange(0, + BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 + + for q_high in range(lo+16, hi, 16): + q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + acc = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32) + + # k_gv = tl.load(GV_ptr + q_high * stride_v4) + # q_gv = tl.exp(k_gv - q_gv_normalizer[None, :]) + + for k_high in range(0, q_high, 16): + qk = tl.load(A_ptr + q_high * stride_a4 + k_high) + v = tl.load(V_ptr + k_high * stride_v4) + k_gv = tl.load(GV_ptr + k_high * stride_v4) + k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv) + v = v * k_gv.to(v.dtype) + # bf16 + output = tl.dot(qk.to(v.dtype), v, allow_tf32=False) + acc += output + + tl.store(O_ptr + q_high * stride_v4, acc.to(O.dtype.element_ty)) + + tl.store(O_ptr, tl.zeros([16, BLOCK_DMODEL_V], + dtype=tl.float32).to(O.dtype.element_ty)) + + tl.debug_barrier() + + for q_high in range(lo, hi, 16): + q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + + qk = tl.load(A_ptr + q_high * stride_a4 + q_high) + v = tl.load(V_ptr + q_high * stride_v4) + k_gv = tl.load(GV_ptr + q_high * stride_v4) + k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv) + + # fp32 matmul + v = v * k_gv2 + output = tl.dot(qk.to(tl.float32), v, allow_tf32=False) + + q_gv = tl.exp(k_gv - q_gv_normalizer[None, :]) + + 
prev = tl.load(O_ptr + q_high * stride_v4) + output += prev + output = output * q_gv + + tl.store(O_ptr + q_high * stride_v4, output.to(O.dtype.element_ty)) + + +@triton.jit +def _bwd_kernel_dav( + V, + GV, + A, + O, + DO, + DA, + DV, + DGV, + Z, + H, + stride_a1, + stride_a2, + stride_a3, + stride_a4, + stride_v1, + stride_v2, + stride_v3, + stride_v4, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr +): + + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_v = tl.program_id(2) + + a_offset = off_hz * stride_a2 + da_offset = (off_v * Z * H + off_hz) * stride_a2 + v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V + + lo = 0 + hi = BLOCK_N + + DO_ptr = DO + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + O_ptr = O + v_offset + (start_m) * stride_v3 + tl.arange(0, + BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + DV_ptr = DV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + DGV_ptr = DGV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 + + DA_ptr = DA + da_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 + + # pre-compute do*q_gv. 
in-place update + for q_high in range(lo, hi, 16): + do = tl.load(DO_ptr + q_high * stride_v4) + o = tl.load(O_ptr + q_high * stride_v4) + tl.store(DGV_ptr + q_high * stride_v4, (do * o)) + + q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + q_gv = tl.load(GV_ptr + q_high * stride_v4) + q_gv = tl.exp(q_gv - q_gv_normalizer[None, :]) + do = do * q_gv + + tl.store(DO_ptr + q_high * stride_v4, do.to(DO_ptr.dtype.element_ty)) + + tl.debug_barrier() + + V_ptr = V + v_offset + (start_m) * stride_v3 + \ + tl.arange(0, BLOCK_DMODEL_V)[:, None] + tl.arange(0, 16)[None, :] * stride_v4 + GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ + :, None] + tl.arange(0, 16)[None, :] * stride_v4 + + for q_high in range(lo+16, hi, 16): + do = tl.load(DO_ptr + q_high * stride_v4) + q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * + stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + + for k_high in range(0, q_high, 16): + v = tl.load(V_ptr + k_high * stride_v4) + k_gv = tl.load(GV_ptr + k_high * stride_v4) + k_gv = tl.exp(q_gv_normalizer[:, None] - k_gv) + + # bf16 + v2 = v * k_gv.to(v.dtype) + dqk = tl.dot(do, v2, allow_tf32=False) + tl.store(DA_ptr + q_high * stride_a4 + + k_high, dqk.to(DA.dtype.element_ty)) + + tl.debug_barrier() + + A_ptr = A + a_offset + (start_m) * stride_a3 + \ + tl.arange(0, 16)[:, None] + tl.arange(0, 16)[None, :] * stride_a4 + + V_ptr = V + v_offset + (start_m) * stride_v3 + \ + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 + GV_ptr = GV + v_offset + (start_m) * stride_v3 + \ + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 + + for k_high in range(0, hi, 16): + dv = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32) + + k_gv = tl.load(GV_ptr + k_high * stride_v4) + + for q_high in range(k_high + 16, BLOCK_N, 16): + do = tl.load(DO_ptr + q_high * 
stride_v4) + + kq = tl.load(A_ptr + q_high * stride_a4 + k_high).to(do.dtype) + + q_gv_normalizer = tl.load(GV + v_offset + + start_m * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv) + + # bf16 + dv2 = tl.dot(kq, do, allow_tf32=False) + dv += dv2 * k_gv2 + + v = tl.load(V_ptr + k_high * stride_v4) + tl.store(DV_ptr + k_high * stride_v4, dv.to(v.dtype)) + + prev_dv = tl.load(DGV_ptr + k_high * stride_v4) + tl.store(DGV_ptr + k_high * stride_v4, prev_dv - dv*v) + + tl.debug_barrier() + + A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, + 16)[:, None] + tl.arange(0, 16)[None, :] * stride_a4 + + # intra-chunk + for q_high in range(lo, hi, 16): + do = tl.load(DO_ptr + q_high * stride_v4) + + q_gv_normalizer = tl.load(GV + v_offset + start_m * stride_v3 + + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) + + v = tl.load(V_ptr + q_high * stride_v4) + k_gv = tl.load(GV_ptr + q_high * stride_v4) + k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv) + v2 = v * k_gv + + dqk = tl.dot(do.to(v2.dtype), tl.trans(v2), allow_tf32=False) + dqk = tl.where(tl.arange(0, 16)[:, None] + >= tl.arange(0, 16)[None, :], dqk, 0.) 
+ tl.store(DA_ptr + q_high * stride_a4 + q_high, + dqk.to(DA_ptr.dtype.element_ty)) + + kq = tl.load(A_ptr + q_high * stride_a4 + q_high).to(do.dtype) + dv2 = tl.dot(kq, do, allow_tf32=False) + + dv = dv2 * k_gv + prev_dv = tl.load(DV_ptr + q_high * stride_v4) + tl.store(DV_ptr + q_high * stride_v4, + (prev_dv + dv).to(DV.dtype.element_ty)) + + prev_gdv = tl.load(DGV_ptr + q_high * stride_v4) + prev_gdv -= dv * v + tl.store(DGV_ptr + q_high * stride_v4, + prev_gdv.to(DGV.dtype.element_ty)) + + +class IntraCalO(torch.autograd.Function): + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, A, v, gv): + assert gv.dtype == torch.float32 + # assert A.dtype == torch.float32 + + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError( + "Flash attention currently only supported for compute capability >= 80") + + # assert gk.dtype == gv.dtype == torch.float32 + BLOCK_M = BLOCK_N = v.shape[-2] + + # shape constraints + Lv = v.shape[-1] + BLOCK_V = min(128, Lv) + ctx.BLOCK_V = BLOCK_V + + assert v.shape[-1] % BLOCK_V == 0 + + grid = (v.shape[2], v.shape[0] * v.shape[1], + max(1, v.shape[-1] // BLOCK_V)) + + o = torch.empty_like(v) + + _fwd_compute_O[grid](A, v, gv, o, + A.stride(0), A.stride( + 1), A.stride(2), A.stride(3), + v.stride(0), v.stride( + 1), v.stride(2), v.stride(3), + BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, + BLOCK_DMODEL_V=BLOCK_V, num_warps=8 if BLOCK_V == 128 else 4, num_stages=5 + ) + + ctx.save_for_backward(A, v, gv, o) + ctx.grid = grid + return o + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do): + A, v, gv, o = ctx.saved_tensors + BLOCK_V = ctx.BLOCK_V + assert v.shape[-1] % BLOCK_V == 0 + + # dA = torch.empty_like(A) + dv = torch.zeros_like(v) + dgv = torch.zeros_like(gv) + + # for now. 
+ BLOCK_M = BLOCK_N = v.shape[-2] + + # shape constraints + # Lv = v.shape[-1] + # grid = (v.shape[2] , v.shape[0] * v.shape[1], v.shape[-1] // BLOCK_V) + grid = ctx.grid + + dA = torch.empty(v.shape[-1] // BLOCK_V if BLOCK_V == 128 else 1, A.shape[0], + A.shape[1], A.shape[2], A.shape[3], A.shape[3], device=A.device, dtype=A.dtype) + + _bwd_kernel_dav[grid]( + v, gv, A, o, + do, dA, + dv, dgv, + v.shape[0], v.shape[1], + A.stride(0), A.stride(1), A.stride(2), A.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, + BLOCK_DMODEL_V=ctx.BLOCK_V, num_warps=8, num_stages=4 + ) + + return dA.sum(0).to(A), dv.to(v), dgv.to(gv) diff --git a/flash_linear_attention/fla/ops/triton/gla/chunk.py b/flash_linear_attention/fla/ops/triton/gla/chunk.py new file mode 100644 index 0000000..04d1ad2 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/chunk.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang +# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635 +# chunkwise block parallel. Materialize chunkwise hidden states into HBMs. +# Therefore it is necessary to have a large chunk size to reduce such materialization overhead. + +import torch.nn.functional as F +from einops import rearrange + +from fla.ops.triton.gla.block_parallel.inter_chunk_contribution.fn import \ inter_chunk_onc +from fla.ops.triton.gla.block_parallel.intra_chunk_contribution.fn import \ intra_chunk_onc + + +def pad_and_rearrange(x, chunk_size): + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, chunk_size - x.shape[-2] % chunk_size)) + if x.shape[-1] % 32 != 0: + x = F.pad(x, (0, 32 - x.shape[-1] % 32)) + x = rearrange(x, '... (n c) d -> ... 
n c d', c=chunk_size) + return x + + +def chunk_gla(q, k, v, gk=None, gv=None, chunk_size=128): + scale = (q.shape[-1])**-0.5 + seq_len = q.shape[-2] + output_dim = v.shape[-1] + q, k, v = map(lambda x: pad_and_rearrange(x, chunk_size), [q, k, v]) + q = q * scale + if gk is not None: + gk = pad_and_rearrange(gk, chunk_size) + if gv is not None: + gv = pad_and_rearrange(gv, chunk_size) + gk, gv, o1 = inter_chunk_onc(q, k, v, gk, gv) + o2 = intra_chunk_onc(q, k, v, gk, gv) + o = rearrange(o1+o2, 'b h n c d -> b h (n c) d') + return o[:, :, :seq_len, :output_dim] diff --git a/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py new file mode 100644 index 0000000..fd45ba2 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang +# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635 +# on-the-fly computation without materializing hidden statets into HBMs + +import warnings + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from fla.ops.triton.utils import contiguous, require_version +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import semiring_cal_A +except ImportError: + warnings.warn('Failed to import semiring_cal_A. 
Do not use FusedChunk implementation of GLA.') + +inv_ln2 = 1.44269504 + + +@triton.jit +def fused_chunk_gla_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + g, # cumulative sum of log decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, + (DK, DV), (DV, 
1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_g *= inv_ln2 + + d_b = tl.load(p_db) * inv_ln2 + + b_q = (b_q * scale * tl.math.exp2(b_g)) + b_k = b_k * tl.trans(tl.math.exp2(-b_g + d_b[None, :])) + + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h *= tl.math.exp2(d_b)[:, None] + b_h += tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_g = tl.advance(p_g, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_db += BT * DK + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr( + final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), + boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_gla_bwd_kernel( + q, k, v, g, + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + # clamp_min, # minimum log value of the gate for numerical stability. 
default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, + (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr( + g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + \ + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2 + d_b = tl.load(p_db) * inv_ln2 + + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + # [DV, DK] + b_k *= tl.math.exp2(d_b[None, :] - b_g) + b_h *= tl.math.exp2(d_b)[None, :] + b_h += tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + b_dq *= scale * tl.math.exp2(b_g) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None 
+ tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + + # cum = tl.zeros([BK], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr( + g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + \ + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + # p_dg = tl.make_block_ptr(dg + (i_bh + i_v * B * H) * s_qk_h, (T, DK), + # (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2 + b_db = tl.load(p_db) * inv_ln2 + + # inter-chunk + g_k = tl.math.exp2(b_db[None, :] - b_g) + b_k *= g_k + b_q *= tl.math.exp2(tl.trans(b_g)) + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans( + b_v), allow_tf32=False)) * scale * g_k + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to( + b_v.dtype), allow_tf32=False) * scale + + # [DK, DV] + b_dh *= tl.math.exp2(b_db)[:, None] + b_dh += tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + + tl.store(p_dk, 
b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): + ctx.g_dtype = g.dtype + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + ctx.scale = scale + + # inter-chunk + BT = 16 # chunk_size + BK, BV = min(d_head_qk, 64), min(d_head_v, 64) + num_stages = 1 + num_warps = 2 + + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + g = rearrange(g, 'b h (n c) d -> b h n c d', c=BT) + g = g.float().cumsum(-2) + g = rearrange(g, 'b h n c d -> b h (n c) d') + + if output_final_state: + final_state = q.new_empty( + batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_gla_fwd_kernel[grid]( + q, k, v, g, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + # clamp_min=-3, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + # USE_SIGMOID=True, USE_EXP=False, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + + # ### intra-chunk + chunk_size = 16 + num_chunk = seq_len // chunk_size + q2 = rearrange(q, 'b h (n c) d -> b h n c d', n=num_chunk) + k2 = rearrange(k, 'b h (n c) d -> b h n c d', n=num_chunk) + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + g2 = rearrange(g, 'b h (n c) d -> b h n c d', n=num_chunk) + A = semiring_cal_A.forward(q2, k2, g2) * scale + o2 = A @ v2 + o2 = rearrange(o2, 'b h n c d -> b h (n c) d') + o.add_(o2) + ctx.save_for_backward(q, k, v, g, A, initial_state) + return o.to(v), 
final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, d_final_state=None): + q, k, v, g, A, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + # inter-chunk + BT = 16 + BK, BV = min(d_head_qk, 64), min(d_head_v, 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 2 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_gla_bwd_kernel[grid]( + q, k, v, g, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + # clamp_min=-3, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + num_warps=num_warps, + num_stages=num_stages, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + + dg = dq * q + dg.add_(- dk * k) + + # # # #### intra chunk + num_chunk = seq_len // BT + q2 = rearrange(q, 'b h (n c) d -> b h n c d', n=num_chunk) + k2 = rearrange(k, 'b h (n c) d -> b h n c d', n=num_chunk) + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + g2 = rearrange(g, 'b h (n c) d -> b h n c d', n=num_chunk) + do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk) + dA2 = (do2 @ v2.transpose(-2, -1)) * scale + dv2 = A.transpose(-1, -2) @ do2 + dq2, dk2, dg2 = semiring_cal_A.backward(q2, k2, g2, dA2) + dq2 = rearrange(dq2, '... h n c d -> ... h (n c) d') + dk2 = rearrange(dk2, '... h n c d -> ... h (n c) d') + dv2 = rearrange(dv2, '... h n c d -> ... h (n c) d') + dg2 = rearrange(dg2, '... h n c d -> ... 
h (n c) d') + dq.add_(dq2.to(dq)) + dk.add_(dk2.to(dk)) + dv.add_(dv2.to(dv)) + dg = dg.float() + dg.add_(dg2) + dg_cumsum = dg.cumsum(-2) + dg = dg - dg_cumsum + dg_cumsum[:, :, -1, None] + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None + + +def pad(x, chunk_size=16): + seq_len = x.shape[-2] + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) + if x.shape[-1] % 32 != 0: + x = F.pad(x, (0, 32 - x.shape[-1] % 32)) + return x + + +def ceildiv(a, b): + return -(a // -b) + + +@require_version('triton>=2.2', 'Numerical stability consideration!') +def fused_chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = v.shape[-2] + d_head_v = v.shape[-1] + q, k, v, g = map(lambda x: pad(x), [q, k, v, g]) + o, final_state = FusedChunkGLAFunction.apply( + q, k, v, g, scale, initial_state, output_final_state) + o = o[..., :seq_len, :d_head_v] + if output_final_state: + return o, final_state + return o diff --git a/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py b/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py new file mode 100644 index 0000000..0b17e4f --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_recurrent_gla_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + 
k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + gk, # log gate [B, H, L, D_head_K] + gv, # log gate [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + final_state, # final hidden state [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction + USE_GK: tl.constexpr, # whether to use gk + USE_GV: tl.constexpr, # whether to use gv +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < DK + mask_bv = (i_v * BV + tl.arange(0, BV)) < DV + + h = tl.zeros([BV, BK], dtype=tl.float32) + + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + if USE_INITIAL_STATE: + 
p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * _gk[None, :] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * _gv[:, None] + h += _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + p_q += -DK if REVERSE else DK + p_k += -DK if REVERSE else DK + p_o += -DV if REVERSE else DV + p_v += -DV if REVERSE else DV + if USE_GK: + p_gk += -DK if REVERSE else DK + if USE_GV: + p_gv += -DV if REVERSE else DV + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_gla_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + gk, # log gate [B, H, L, D_head_K] \alpha + gv, # log gate [B, H, L, D_head_V] \bete + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction + USE_GK: tl.constexpr, # whether to use gk + USE_GV: tl.constexpr, # whether to use gv +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_do = do + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + mask_bk = i_k * 
BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + mask_kv = mask_bk[:, None] & mask_bv[None, :] + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * _gk[:, None] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * _gv[None, :] + h += _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += -DK if REVERSE else DK + p_v += -DV if REVERSE else DV + p_q += -DK if REVERSE else DK + p_do += -DV if REVERSE else DV + p_dq += -DK if REVERSE else DK + if USE_GK: + p_gk += -DK if REVERSE else DK + if USE_GV: + p_gv += -DV if REVERSE else DV + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_do = do + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + if USE_GV: 
+ p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + d_h *= _gk[:, None] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + d_h *= _gv[None, :] + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do += DV if REVERSE else -DV + p_q += DK if REVERSE else -DK + p_k += DK if REVERSE else -DK + p_v += DV if REVERSE else -DV + p_dk += DK if REVERSE else -DK + p_dv += DV if REVERSE else -DV + if USE_GK: + p_gk += DK if REVERSE else -DK + if USE_GV: + p_gv += DV if REVERSE else -DV + + +class FusedRecurrentGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + # default scale + if scale is None: + scale = d_head_qk ** -0.5 + if gk is not None: + gk = gk.float().exp() + if gv is not None: + gv = gv.float().exp() + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + 
fused_recurrent_gla_fwd_kernel[grid]( + q, k, v, gk, gv, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, gk, gv, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + # we do not need the gradient of the final state from the next chunk + # similiar to Trunctated BPTT + if final_state is not None: + final_state = final_state.detach() + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, d_final_state=None): + q, k, v, gk, gv, initial_state, o = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=torch.float32) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=torch.float32) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_gla_bwd_kernel[grid]( + q, k, v, gk, gv, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + if gk is not None: + _dgk = dq * 
q.float() - dk * k.float() + if ctx.reverse: + dgk = _dgk.cumsum(-2) + else: + _dgk_cumsum = _dgk.cumsum(-2) + dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum + else: + dgk = None + + if gv is not None: + _dgv = do.float() * o.float() - dv * v.float() + if ctx.reverse: + dgv = _dgv.cumsum(-2) + else: + _dgv_cumsum = _dgv.cumsum(-2) + dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum + else: + dgv = None + + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None + + +# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0 +def fused_recurrent_gla(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor = None, + gv: torch.Tensor = None, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + causal: bool = True): + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if causal: + o, final_state = FusedRecurrentGLAFunction.apply( + q, k, v, gk, gv, scale, initial_state, output_final_state) + if output_final_state: + return o, final_state + return o + else: + # do not support initial_state yet. 
looks very strange for bidirectional modeling + assert initial_state is None + assert output_final_state is False + o, final_state = FusedRecurrentGLAFunction.apply( + q, k, v, gk, gv, scale, initial_state, output_final_state, False) + o_reversed, final_state = FusedRecurrentGLAFunction.apply( + q, k, v, gk, gv, scale, initial_state, output_final_state, True) + return [o, o_reversed] diff --git a/flash_linear_attention/fla/ops/triton/rebased/__init__.py b/flash_linear_attention/fla/ops/triton/rebased/__init__.py new file mode 100644 index 0000000..8080094 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/rebased/__init__.py @@ -0,0 +1,4 @@ +from .parallel import parallel_rebased + + +__all__ = ["parallel_rebased"] diff --git a/flash_linear_attention/fla/ops/triton/rebased/parallel.py b/flash_linear_attention/fla/ops/triton/rebased/parallel.py new file mode 100644 index 0000000..777a9e5 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/rebased/parallel.py @@ -0,0 +1,388 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous +from torch.cuda.amp import custom_bwd, custom_fwd + +# Based: An Educational and Effective Sequence Mixer +# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit +def parallel_rebased_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: 
tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 0.5 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 0.5 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, 
boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), 
(BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = 0.5 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 0.5 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, 
b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_rebased_bwd_kernel( + q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = 
v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, device=q.device) + z = torch.empty(NK, batch_size, n_heads, seq_len, + device=q.device) + parallel_rebased_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_rebased_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + 
num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_rebased(q, k, v, eps, use_scale=True, use_normalize=True, return_both=False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py b/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py new file mode 100644 index 0000000..8080094 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py @@ -0,0 +1,4 @@ +from .parallel import parallel_rebased + + +__all__ = ["parallel_rebased"] diff --git a/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py b/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py new file mode 100644 index 0000000..dce2626 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py @@ -0,0 +1,390 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous +from torch.cuda.amp import custom_bwd, custom_fwd + +# Based: An Educational and Effective Sequence Mixer +# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit +def parallel_rebased_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size 
+ H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, 
s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + 
for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + + + +@triton.jit +def _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k 
+ i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = b_s * b_s + b_s = tl.where(m_s, b_s, 
0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_rebased_bwd_kernel( + q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = 
q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, device=q.device) + z = torch.empty(NK, batch_size, n_heads, seq_len, + device=q.device) + parallel_rebased_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_rebased_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + 
num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_rebased(q, k, v, eps, use_scale=True, use_normalize=True, return_both=False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/retention/__init__.py b/flash_linear_attention/fla/ops/triton/retention/__init__.py new file mode 100644 index 0000000..1aaa71d --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/retention/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_retention +from .chunk_fuse import fused_chunk_retention +from .parallel import parallel_retention +from .recurrent_fuse import fused_recurrent_retention + +__all__ = ['fused_chunk_retention', 'parallel_retention', + 'fused_recurrent_retention', 'chunk_retention'] diff --git a/flash_linear_attention/fla/ops/triton/retention/chunk.py b/flash_linear_attention/fla/ops/triton/retention/chunk.py new file mode 100644 index 0000000..fb6e93b --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/retention/chunk.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + + +@triton.jit +def chunk_retention_fwd_kernel_h( + k, + v, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_hh, + s_ht, + H, + T, + TD, + DK, + DV, + BT: 
tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, DV), (s_ht, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for _ in range(0, T, BT): + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False) + + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_h = tl.advance(p_h, (DK, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_retention_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_hh, + s_ht, + H, + T, + TD, + scale, + DK, + DV, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_i = 
tl.math.exp2((o_i + 1) * b_b) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + + for i_v in range(0, tl.cdiv(DV, BV)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (0, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, DV), (s_ht, 1), (i_t * DK, i_v * BV), (BK, BV), (1, 0)) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for _ in range(0, tl.cdiv(DK, BK)): + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BD, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BD, BD] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + p_q = tl.advance(p_q, (0, BK)) + p_k = tl.advance(p_k, (BK, 0)) + p_h = tl.advance(p_h, (BK, 0)) + + b_s *= d_s + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_retention_bwd_kernel_dh( + q, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_hh, + s_ht, + H, + T, + scale, + DK, + DV, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b) + # [BK, BV] + b_dh = tl.zeros([BK, BV], 
dtype=tl.float32) + for i in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_hh, ((i+1)*DK, DV), (s_ht, 1), (i * DK + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False) + + +@triton.jit +def chunk_retention_bwd_kernel_dqkv( + q, + k, + v, + h, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_hh, + s_ht, + H, + T, + TDK, + scale, + DK, + DV, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + d_q = (d_q * scale).to(d_q.dtype) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + + for i_k in range(0, tl.cdiv(DK, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_hh, (DV, TDK), (1, s_ht), (0, i_t * DK + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_dh = 
tl.make_block_ptr(dh + i_bh * s_hh, (TDK, DV), (s_ht, 1), (i_t * DK + i_k * BK, 0), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 0)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for _ in range(tl.cdiv(DV, BV)): + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False) + + # [BT, BT] + b_ds = tl.trans(b_ds) + # [BK, BT] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) * d_k[:, None] + b_dk += tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + b_dv += tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.advance(p_v, (0, BV)) + p_h = tl.advance(p_h, (BV, 0)) + p_do = tl.advance(p_do, (0, BV)) + p_dh = tl.advance(p_dh, (0, BV)) + p_dv = tl.advance(p_dv, (0, BV)) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd + 
@contiguous + def forward(ctx, q, k, v, initial_state, output_final_state): + BT = 64 + DK, DV = k.shape[-1], v.shape[-1] + BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = DK ** -0.5 + + NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) + h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * DK, DV) + + final_state = None + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + + grid = (NK, NV, batch_size * n_heads) + chunk_retention_fwd_kernel_h[grid]( + k, v, h, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + n_heads, seq_len, h.shape[2], + DK=DK, DV=DV, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (triton.cdiv(seq_len, BT), batch_size * n_heads) + o = torch.empty_like(v) + chunk_retention_fwd_kernel_o[grid]( + q, k, v, h, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + n_heads, seq_len, h.shape[2], scale, + BK=BK, BV=BV, DK=DK, DV=DV, BT=BT, + num_warps=num_warps, + num_stages=num_stages + ) + + ctx.save_for_backward(q, k, v, h) + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_ht=None): + q, k, v, h = ctx.saved_tensors + + BT = 64 + DK, DV = k.shape[-1], v.shape[-1] + BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) + batch_size, n_heads, seq_len, _ = q.shape + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = DK ** -0.5 + + NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) + grid = (NK, NV, batch_size * n_heads) + dh = 
q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * DK, DV) + + chunk_retention_bwd_kernel_dh[grid]( + q, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + n_heads, seq_len, scale, + BT=BT, BK=BK, BV=BV, DK=DK, DV=DV, NT=triton.cdiv(seq_len, BT), + num_warps=num_warps, + num_stages=num_stages + ) + + BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) + NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) + grid = (triton.cdiv(seq_len, BT), batch_size * n_heads) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + # must be zero. we need reload + dv = torch.zeros_like(v) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + chunk_retention_bwd_kernel_dqkv[grid]( + q, k, v, h, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + n_heads, seq_len, h.shape[2], scale, + BT=BT, BK=BK, BV=BV, DK=DK, DV=DV, + num_warps=num_warps, + num_stages=num_stages + ) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None + + +def chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) + if output_final_state: + return o, final_state + else: + return o diff --git a/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py new file mode 100644 index 0000000..2a8811e --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl +from packaging import version +from torch.cuda.amp import 
custom_bwd, custom_fwd + +from fla.ops.triton.utils import contiguous + +# on-the-fly computation without materializing hidden states into HBMs + + +@triton.jit +def fused_chunk_retention_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + # d_b: overall decay for the entire chunk + # d_o: cumulative decay from the start of the chunk + # d_h: cumulative decay from the end of the chunk + d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK,
T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_retention_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b) + d_b = tl.math.exp2(BT * b_b) + + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + # [BT, DK] + b_dq = tl.dot(b_ds, b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + d_s = tl.trans(d_s) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = 
tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = False + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to 
https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_retention_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return 
dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None + + +def fused_chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) + if output_final_state: + return o, final_state + else: + return o diff --git a/flash_linear_attention/fla/ops/triton/retention/parallel.py b/flash_linear_attention/fla/ops/triton/retention/parallel.py new file mode 100644 index 0000000..e114039 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/retention/parallel.py @@ -0,0 +1,341 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl + +from fla.ops.triton.utils import contiguous +from torch.cuda.amp import custom_bwd, custom_fwd + + +@triton.jit +def parallel_retention_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index.
used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # cumulative decay from the end of the chunk + o_k = tl.arange(0, BTS) + d_h = tl.math.exp2((BTS - o_k) * b_b) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_o = b_o * tl.math.exp2(b_b * BTS) + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + d_q = tl.math.exp2(tl.arange(0, BTL) * b_b) + b_o *= d_q[:, None] + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block 
have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # cumulative decay from the end of the chunk + d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b) + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_dq *= d_b + b_dq += tl.dot(b_ds.to(b_v.dtype), 
b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale + # [BTL, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # no overlap. no need for mask. 
+ b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b) + b_kd = (b_k * d_h[:, None]).to(b_k.dtype) + d_q = tl.math.exp2(tl.arange(0, BTS) * b_b) + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS] + b_do = (b_do * d_q[None, :]).to(b_do.dtype) + + b_dv *= d_b + b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS] + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + + b_dk *= d_b + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dk *= d_h[:, None] * scale + b_dv *= scale + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + d_s = 
tl.where(m_s, tl.math.exp2( + (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s + # [BK, BD] + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + o_q += BTS + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_retention_bwd_kernel( + q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelRetentionFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 3 if d_head_qk <= 64 else 2 + num_warps = 
4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + scale = d_head_qk ** -0.5 + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + parallel_retention_fwd_kernel[grid]( + q, k, v, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + return o.sum(0).to(q.dtype) + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do): + q, k, v = ctx.saved_tensors + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 3 if d_head_qk <= 64 else 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + scale = d_head_qk ** -0.5 + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype) + + +parallel_retention = ParallelRetentionFunction.apply diff --git a/flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py 
# on-the-fly computation without materializing hidden states into HBM
if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + h = b_b * h + _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += DK + p_k += DK + p_o += DV + p_v += DV + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_retention_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + h = 
b_b * h + _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += DK + p_do += DV + p_v += DV + p_dq += DK + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + (T - 1) * DK + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + (T - 1) * DV + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + + d_h *= b_b + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= DV + p_q -= DK + p_k -= DK + p_v -= DV + p_dk -= DK + p_dv -= DV + + +class FusedRecurrentRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, initial_state=None, output_final_state=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + 
grid = (NV, NK, batch_size * n_heads) + fused_recurrent_retention_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None + + +# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply + +def fused_recurrent_retention(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False): + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedRecurrentRetentionFunction.apply( + q, k, v, initial_state, output_final_state) + if 
output_final_state: + return o, final_state + else: + return o diff --git a/flash_linear_attention/fla/ops/triton/rotary.py b/flash_linear_attention/fla/ops/triton/rotary.py new file mode 100644 index 0000000..18ccc5f --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/rotary.py @@ -0,0 +1,252 @@ +# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py + +from typing import Optional, Union + +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 2}), +# triton.Config({"BLOCK_M": 4}), +# triton.Config({"BLOCK_M": 8}), +# triton.Config({"BLOCK_M": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + 
tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rm[:, None] * stride_x_seqlen + + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + (rm[:, None] * stride_out_seqlen + + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) + & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... 
+ rk_repeat = tl.arange(0, BLOCK_K) // 2 + X0 = X + (rm[:, None] * stride_x_seqlen + + rk[None, :] * stride_x_headdim) + X1 = X + (rm[:, None] * stride_x_seqlen + + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & ( + rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & ( + rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 + ).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + (rm[:, None] * stride_out_seqlen + + rk[None, :] * stride_out_headdim) + tl.store(OUT, out, mask=(rm[:, None] < seqlen) + & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
+ with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + # key for triton cache (limit number of compilations) + seqlen // 128, + # batch_strides if not varlen else 0 + output.stride(0) if not is_varlen else 0, + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + # batch_strides if not varlen else 0 + x.stride(0) if not is_varlen else 0, + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output diff --git a/flash_linear_attention/fla/ops/triton/utils.py b/flash_linear_attention/fla/ops/triton/utils.py new file mode 100644 index 0000000..93af956 --- /dev/null +++ b/flash_linear_attention/fla/ops/triton/utils.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +import functools + +import torch + + +def contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + + +def require_version(version, hint): + def decorator(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + from transformers.utils.versions import require_version + require_version(version, hint) + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + return decorator diff --git a/flash_linear_attention/setup.py b/flash_linear_attention/setup.py new file mode 100644 index 0000000..1d3d6b6 --- /dev/null +++ 
b/flash_linear_attention/setup.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- + +import ast +import os +import re +import subprocess +import warnings +from pathlib import Path + +import torch +from packaging.version import Version, parse +from setuptools import find_packages, setup +from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension + + +long_description = "" + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = 'fla' + +# FORCE_BUILD: force a fresh build locally, instead of attempting to find prebuilt wheels +FORCE_BUILD = os.getenv('FLA_FORCE_BUILD', "FALSE") == 'TRUE' +# SKIP_CUDA_BUILD: allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +SKIP_CUDA_BUILD = os.getenv('FLA_SKIP_CUDA_BUILD', "FALSE") == 'TRUE' +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv('FLA_FORCE_CXX11_ABI', "FALSE") == 'TRUE' + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." 
+ ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +ext_modules = [] +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + check_if_cuda_home_none('fla') + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + "FLA is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." 
+ ) + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if CUDA_HOME is not None: + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + ext_modules = [ + CUDAExtension( + name='semiring_cal_A', + sources=[ + 'fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp', + 'fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu', + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag + ), + }, + ), + ] + + +def get_package_version(): + with open(Path(this_dir) / 'fla' / '__init__.py') as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + return ast.literal_eval(version_match.group(1)) + + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + description='Fast Triton-based implementations of causal linear attention', + long_description=long_description, + long_description_content_type='text/markdown', + author='Songlin Yang, Yu Zhang', + author_email='bestsonta@gmail.com', + url='https://github.com/sustcsonglin/flash-linear-attention', + packages=find_packages(), + license='MIT', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + 'Topic :: Scientific/Engineering :: Artificial Intelligence' + ], + python_requires='>=3.7', + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + install_requires=[ + 'triton', + 'transformers', + 'einops', + 'ninja' + ] +) From cb180114307abdad7c4376f39865df4a1db73e18 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 4 Mar 2024 14:38:16 +0300 Subject: 
[PATCH 2/4] rebased attn processor --- models/latte.py | 11 +++++- models/latte_img.py | 11 +++++- models/latte_t2v.py | 23 +++++++++++-- models/utils.py | 84 ++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 123 insertions(+), 6 deletions(-) diff --git a/models/latte.py b/models/latte.py index 723686a..af5c793 100644 --- a/models/latte.py +++ b/models/latte.py @@ -26,6 +26,11 @@ except: XFORMERS_IS_AVAILBLE = False +try: + from fla.ops.triton.rebased_fast import parallel_rebased +except: + REBASED_IS_AVAILABLE = False + # from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -37,7 +42,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -48,6 +53,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + self.eps = eps def forward(self, x): B, N, C = x.shape @@ -69,6 +75,9 @@ def forward(self, x): attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) + elif self.attention_mode == 'rebased': + x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) + else: raise NotImplemented diff --git a/models/latte_img.py b/models/latte_img.py index c468c63..a8fc486 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -26,6 +26,12 @@ except: XFORMERS_IS_AVAILBLE = False +try: + from fla.ops.triton.rebased_fast import parallel_rebased +except: + REBASED_IS_AVAILABLE = False + + 
# from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -37,7 +43,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -51,6 +57,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + self.eps = eps def forward(self, x): B, N, C = x.shape @@ -72,6 +79,8 @@ def forward(self, x): attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) + elif self.attention_mode == 'rebased': + x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) else: raise NotImplemented diff --git a/models/latte_t2v.py b/models/latte_t2v.py index fc96a30..e085116 100644 --- a/models/latte_t2v.py +++ b/models/latte_t2v.py @@ -37,13 +37,19 @@ class GatedSelfAttentionDense(nn.Module): d_head (`int`): The number of channels in each head. 
""" - def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int, attn_type: str = 'vanilla'): super().__init__() # we need a linear projection since we need cat visual feature and obj feature self.linear = nn.Linear(context_dim, query_dim) - self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + if attn_type == 'rebased': + from models.utils import RebasedAttnProcessor + attn_proc = RebasedAttnProcessor() + else: + attn_proc = None + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head, processor=attn_proc) self.ff = FeedForward(query_dim, activation_fn="geglu") self.norm1 = nn.LayerNorm(query_dim) @@ -178,6 +184,7 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + attn_type: str = "vanilla" ): super().__init__() self.only_cross_attention = only_cross_attention @@ -212,6 +219,12 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + if attn_type == 'rebased': + from models.utils import RebasedAttnProcessor + attn_proc = RebasedAttnProcessor() + else: + attn_proc = None + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -220,6 +233,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + processor=attn_proc ) # # 2. Cross-Attn @@ -254,7 +268,7 @@ def __init__( # 4. Fuser if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim,attn_type=attn_type) # 5. Scale-shift for PixArt-Alpha. 
if self.use_ada_layer_norm_single: @@ -498,6 +512,7 @@ def __init__( attention_type: str = "default", caption_channels: int = None, video_length: int = 16, + attn_type: str = "vanilla" ): super().__init__() self.use_linear_projection = use_linear_projection @@ -600,6 +615,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, attention_type=attention_type, + attn_type=attn_type ) for d in range(num_layers) ] @@ -624,6 +640,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, attention_type=attention_type, + attn_type=attn_type ) for d in range(num_layers) ] diff --git a/models/utils.py b/models/utils.py index 0e13056..42c5b4f 100644 --- a/models/utils.py +++ b/models/utils.py @@ -212,4 +212,86 @@ def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params \ No newline at end of file + return total_params + +try: + from fla.ops.triton.rebased_fast import parallel_rebased +except: + REBASED_IS_AVAILABLE = False + +from diffusers.models.attention_processor import Attention +class RebasedAttnProcessor: + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects 
attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + eps = 1e-12 + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = parallel_rebased(query, key, value, eps, True, True) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file From 09f77d81dccec09399312eb977acfe752c670d10 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Wed, 6 Mar 2024 17:43:06 +0800 Subject: [PATCH 3/4] fla is installable, so don't need to embed --- flash_linear_attention/__init__.py | 0 flash_linear_attention/fla/__init__.py | 8 - 
flash_linear_attention/fla/layers/__init__.py | 7 - flash_linear_attention/fla/layers/based.py | 214 ------ flash_linear_attention/fla/layers/gla.py | 154 ---- .../fla/layers/multiscale_retention.py | 100 --- flash_linear_attention/fla/layers/rebased.py | 258 ------- .../fla/layers/rebased_fast.py | 229 ------ .../fla/modules/__init__.py | 11 - .../fla/modules/convolution.py | 195 ----- flash_linear_attention/fla/modules/rmsnorm.py | 647 ---------------- flash_linear_attention/fla/modules/rotary.py | 312 -------- flash_linear_attention/fla/ops/__init__.py | 25 - .../semiring/cal_A/inner_chunk16_dim16x.cpp | 11 - .../cal_A/inner_chunk16_dim16x_kernel.cu | 204 ------ .../fla/ops/torch/__init__.py | 7 - flash_linear_attention/fla/ops/torch/based.py | 131 ---- flash_linear_attention/fla/ops/torch/gla.py | 119 --- .../fla/ops/torch/retention.py | 15 - .../fla/ops/triton/__init__.py | 22 - .../fla/ops/triton/abc/__init__.py | 0 .../fla/ops/triton/abc/chunk_fuse.py | 692 ------------------ .../fla/ops/triton/based/__init__.py | 6 - .../fla/ops/triton/based/chunk_fuse.py | 410 ----------- .../fla/ops/triton/based/parallel.py | 385 ---------- .../fla/ops/triton/gla/__init__.py | 7 - .../ops/triton/gla/block_parallel/__init__.py | 0 .../inter_chunk_contribution/__init__.py | 0 .../chunk_scan_triton_full.py | 212 ------ .../chunk_scan_triton_no_decay.py | 166 ----- .../chunk_scan_triton_only_gk.py | 187 ----- .../chunk_scan_triton_only_gv.py | 199 ----- .../inter_chunk_contribution/fn.py | 49 -- .../preprocess_cumsum_gk.py | 259 ------- .../preprocess_cumsum_gv.py | 216 ------ .../intra_chunk_contribution/__init__.py | 0 .../intra_chunk_contribution/fn.py | 28 - .../intra_chunk_contribution/fn_only_gk.py | 343 --------- .../intra_chunk_contribution/fn_only_gv.py | 336 --------- .../fla/ops/triton/gla/chunk.py | 39 - .../fla/ops/triton/gla/chunk_fuse.py | 400 ---------- .../fla/ops/triton/gla/recurrent_fuse.py | 403 ---------- .../fla/ops/triton/rebased/__init__.py | 4 - 
.../fla/ops/triton/rebased/parallel.py | 388 ---------- .../fla/ops/triton/rebased_fast/__init__.py | 4 - .../fla/ops/triton/rebased_fast/parallel.py | 390 ---------- .../fla/ops/triton/retention/__init__.py | 9 - .../fla/ops/triton/retention/chunk.py | 389 ---------- .../fla/ops/triton/retention/chunk_fuse.py | 329 --------- .../fla/ops/triton/retention/parallel.py | 341 --------- .../ops/triton/retention/recurrent_fuse.py | 280 ------- .../fla/ops/triton/rotary.py | 252 ------- .../fla/ops/triton/utils.py | 27 - flash_linear_attention/setup.py | 147 ---- 54 files changed, 9566 deletions(-) delete mode 100644 flash_linear_attention/__init__.py delete mode 100644 flash_linear_attention/fla/__init__.py delete mode 100644 flash_linear_attention/fla/layers/__init__.py delete mode 100644 flash_linear_attention/fla/layers/based.py delete mode 100644 flash_linear_attention/fla/layers/gla.py delete mode 100644 flash_linear_attention/fla/layers/multiscale_retention.py delete mode 100644 flash_linear_attention/fla/layers/rebased.py delete mode 100644 flash_linear_attention/fla/layers/rebased_fast.py delete mode 100644 flash_linear_attention/fla/modules/__init__.py delete mode 100644 flash_linear_attention/fla/modules/convolution.py delete mode 100644 flash_linear_attention/fla/modules/rmsnorm.py delete mode 100644 flash_linear_attention/fla/modules/rotary.py delete mode 100644 flash_linear_attention/fla/ops/__init__.py delete mode 100644 flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp delete mode 100644 flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu delete mode 100644 flash_linear_attention/fla/ops/torch/__init__.py delete mode 100644 flash_linear_attention/fla/ops/torch/based.py delete mode 100644 flash_linear_attention/fla/ops/torch/gla.py delete mode 100644 flash_linear_attention/fla/ops/torch/retention.py delete mode 100644 flash_linear_attention/fla/ops/triton/__init__.py delete mode 100644 
flash_linear_attention/fla/ops/triton/abc/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/based/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/based/chunk_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/based/parallel.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py delete mode 100644 
flash_linear_attention/fla/ops/triton/gla/chunk.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/rebased/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/rebased/parallel.py delete mode 100644 flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py delete mode 100644 flash_linear_attention/fla/ops/triton/retention/__init__.py delete mode 100644 flash_linear_attention/fla/ops/triton/retention/chunk.py delete mode 100644 flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/retention/parallel.py delete mode 100644 flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py delete mode 100644 flash_linear_attention/fla/ops/triton/rotary.py delete mode 100644 flash_linear_attention/fla/ops/triton/utils.py delete mode 100644 flash_linear_attention/setup.py diff --git a/flash_linear_attention/__init__.py b/flash_linear_attention/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flash_linear_attention/fla/__init__.py b/flash_linear_attention/fla/__init__.py deleted file mode 100644 index 432f1ac..0000000 --- a/flash_linear_attention/fla/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- - -from fla.ops.triton import (fused_chunk_based, fused_chunk_gla, - fused_chunk_retention) - -__all__ = ['fused_chunk_based', 'fused_chunk_gla', 'fused_chunk_retention'] - -__version__ = '0.0.1' diff --git a/flash_linear_attention/fla/layers/__init__.py b/flash_linear_attention/fla/layers/__init__.py deleted file mode 100644 index 5080186..0000000 --- a/flash_linear_attention/fla/layers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -from .based import BasedLinearAttention -from .gla import 
GatedLinearAttention -from .multiscale_retention import MultiScaleRetention - -__all__ = ['GatedLinearAttention', 'MultiScaleRetention', 'BasedLinearAttention'] diff --git a/flash_linear_attention/fla/layers/based.py b/flash_linear_attention/fla/layers/based.py deleted file mode 100644 index 204b160..0000000 --- a/flash_linear_attention/fla/layers/based.py +++ /dev/null @@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Linear attention in Based. -https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py -""" -import math - -import opt_einsum as oe -import torch -import torch.nn as nn -from einops import rearrange - -from fla.ops.triton.based import fused_chunk_based, parallel_based - - -def init_feature_map(feature_map: str = 'none', **kwargs: any): - """ - Initialize query and key mapping for linear attention - """ - if feature_map in [None, 'none', 'identity']: - return FeatureMap(**kwargs) - # Taylor series approximations to exp(x) - elif feature_map == 'taylor_exp': - return TaylorExp(**kwargs) - else: - raise NotImplementedError( - f'Sorry "{feature_map}" feature map not implemented.') - - -class FeatureMap(nn.Module): - """ - Parent feature map; default is identity function - """ - - def __init__(self, - input_dim: int, - temp: int = None, - head_dim_idx: int = -1, - eps: float = 1e-12, - **kwargs: any): - super().__init__() - self.input_dim = input_dim - self.head_dim_idx = head_dim_idx - self.temp = 1. if temp is None else temp - self.eps = eps - - def forward(self, x: torch.Tensor): - """ - Assume x.shape is (batch_size, n_heads, seq_len, head_dim) - """ - return x - - -class TaylorExp(FeatureMap): - """ - Feature map to compute 2nd-order Taylor approx. 
of exp(q^T k / sqrt(d)) - """ - - def __init__(self, input_dim: int, **kwargs: any): - super().__init__(input_dim, **kwargs) - self.r2 = math.sqrt(2) - self.rd = math.sqrt(self.input_dim) - self.rrd = math.sqrt(self.rd) - self.tril_indices = torch.tril_indices( - self.input_dim, self.input_dim, -1) - - # Running these in parallel - def forward(self, x: torch.Tensor): - # Get 2nd-order terms (rearrange(x * x), '... m n -> ... (m n)') - x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) - ).flatten(start_dim=-2) / self.r2 - return torch.cat([torch.ones(x[..., :1].shape).to(x.device), - x / self.rrd, x2 / self.rd], dim=self.head_dim_idx) - - def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor: - """ - Compute f(x) s.t. f(x)^T f(x') = 1 + x^Tx' + (x^Tx')^2 / 2 - -> Assume x.shape is (batch_size, n_heads, seq_len, head_dim) - """ - # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first? - x2 = oe.contract('...m,...n->...mn', x, x) / self.rd - x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2 - x2 = x2[..., self.tril_indices[0], self.tril_indices[1]] - x = torch.cat([torch.ones(x[..., :1].shape).to(x.device), - x / self.rrd, x2d, x2], dim=-1) - return x - - -class BasedLinearAttention(nn.Module): - def __init__( - self, - d_model: int, - l_max: int = 2048, - feature_dim: int = 16, - num_key_value_heads: int = 12, - num_heads: int = 12, - feature_name: str = "taylor_exp", - eps: float = 1e-12, - causal: bool = True, - mode: str = "parallel", - ): - super().__init__() - self.d_model = d_model - self.l_max = l_max - self.mode = mode - assert self.mode in ["fused_chunk", "parallel"] - - # linear attention - self.feature_name = feature_name - self.feature_dim = feature_dim - self.num_key_value_heads = num_key_value_heads - self.num_heads = num_heads - self.head_dim = self.d_model // self.num_key_value_heads - self.causal = causal - feature_map_kwargs = { - 'input_dim': self.feature_dim, - 'head_dim_idx': -1, - 'temp': 1., - 'eps': 1e-12 - } - 
self.feature_map = init_feature_map( - feature_map=self.feature_name, **feature_map_kwargs) - self.proj_q = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_k = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_v = nn.Linear( - self.d_model, self.num_key_value_heads * self.head_dim, bias=False) - self.proj_o = nn.Linear( - self.num_heads * self.head_dim, self.d_model, bias=False) - self.dropout = nn.Identity() - self.eps = eps - - def forward(self, hidden_states: torch.Tensor, **kwargs): - mode = self.mode - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - q, k, v = map(lambda x: rearrange( - x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) - if mode == "fused_chunk": - assert q.shape[-1] <= 16 - o = fused_chunk_based(q, k, v, self.eps, True, True) - elif mode == 'parallel': - assert q.shape[-1] <= 128 - o = parallel_based(q, k, v, self.eps, True, True) - o = rearrange(o, "b h l d -> b l (h d)") - o = self.proj_o(o) - o = self.dropout(o) - return o - - # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 - - def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): - """ - x (torch.Tensor): tensor of shape (b, d, l) - y (torch.Tensor): tensor of shape (b, d, l) - """ - # hidden_states = hidden_states.transpose(1, 2) - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - - q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) - k = k.view(b, l, self.num_key_value_heads, - self.feature_dim).transpose(1, 2) - v = v.view(b, l, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - # Linear attention - q, k = self.feature_map(q), self.feature_map(k) - q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) - - # Compute attention - if 
self.causal: - y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) - else: - y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) - y = rearrange(y, 'b h l d -> b l (h d)') - y = self.proj_o(y.to(hidden_states.dtype)) - y = self.dropout(y) - return y.to(hidden_states.dtype) - - -if __name__ == '__main__': - batch = 4 - seq_len = 1024 - d_model = 1024 - dtype = torch.float32 - x = torch.randn(batch, seq_len, d_model).to( - dtype).cuda().requires_grad_(True) - dy = torch.randn(batch, seq_len, d_model).to( - dtype).cuda() - model = BasedLinearAttention(d_model=d_model).to(dtype).cuda() - y = model(x) - y.backward(dy, retain_graph=True) - x_grad, x.grad = x.grad, None - - proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None - proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None - proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None - - x.requires_grad_(True) - y2 = model.forward_reference(x) - y2.backward(dy) - print((y - y2).abs().max().item()) - # assert y.allclose(y2, 0, 1e-4), breakpoint() - print((x_grad - x.grad).abs().max().item()) - # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() - print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) - print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) - print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) - print("All good with based!") diff --git a/flash_linear_attention/fla/layers/gla.py b/flash_linear_attention/fla/layers/gla.py deleted file mode 100644 index d0626bb..0000000 --- a/flash_linear_attention/fla/layers/gla.py +++ /dev/null @@ -1,154 +0,0 @@ -# -*- coding: utf-8 -*- - -# "Gated Linear Attention Transformers with Hardware-Efficient Training"[https://arxiv.org/abs/2312.06635] - -from __future__ import annotations - -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from 
fla.modules.rmsnorm import RMSNorm -from fla.ops.triton.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla - - -def get_activation_fn(activation): - if activation == 'swish': - return F.silu - elif activation == 'gelu': - return F.gelu - else: - raise NotImplementedError - - -class GatedLinearAttention(nn.Module): - - def __init__( - self, - d_model: int = 1024, - expand_v: int = 2, - expand_k: int = 1, - num_heads: int = 1, - gate_fn: str = 'swish', - layernorm_eps: float = 1e-5, - gate_logit_normalizer: int = 32, - gate_logit_multiplier: int = 1, - gate_low_rank_dim: int = 32, - mode: str = 'fused_chunk', - chunk_size: int = 64, - use_gk: bool = True, # gate associated with key, i.e., $\alpha$ in the paper - use_gv: bool = False, # gate associated with value, i.e., $\beta$ in the paper - *args, **kwargs - ) -> GatedLinearAttention: - super().__init__() - if use_gv is True: - assert mode in ['chunk', 'fused_recurrent'] - if mode == 'fused_chunk': - assert use_gk is True - if mode != 'chunk' and chunk_size != 16: - warnings.warn( - f" `chunk_size` is only used for `chunk` mode." - f" The `{mode}` mode will suppress the passed value of {chunk_size} and always use 16." - ) - self.use_gk = use_gk - self.use_gv = use_gv - self.d_model = d_model - self.mode = mode - self.chunk_size = chunk_size - self.value_dim = int(d_model * expand_v) - self.key_dim = int(d_model * expand_k) - assert mode in ['chunk', 'fused_recurrent', - 'fused_chunk'], f"Not suppoerted mode `{mode}`." 
- assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" - assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" - self.num_heads = num_heads - self.head_qk_dim = self.key_dim // num_heads - self.head_v_dim = self.value_dim // num_heads - self.gate_fn = get_activation_fn(activation=str(gate_fn)) - self.q_proj = nn.Linear(d_model, self.key_dim, bias=False) - self.k_proj = nn.Linear(d_model, self.key_dim, bias=False) - self.v_proj = nn.Linear(d_model, self.value_dim, bias=False) - self.g_proj = nn.Linear(d_model, self.value_dim, bias=False) - - if self.use_gk: - self.gk_proj = nn.Sequential(nn.Linear(d_model, gate_low_rank_dim, bias=False), - nn.Linear(gate_low_rank_dim, self.key_dim, bias=True)) - else: - self.gk_proj = None - if self.use_gv: - self.gv_proj = nn.Sequential(nn.Linear(d_model, gate_low_rank_dim, bias=False), - nn.Linear(gate_low_rank_dim, self.value_dim, - bias=True)) - else: - self.gv_proj = None - self.out_proj = nn.Linear(self.value_dim, d_model, bias=False) - self.group_norm = RMSNorm(self.head_v_dim, eps=layernorm_eps) - self.gate_logit_normalizer = gate_logit_normalizer - self.gate_logit_multiplier = gate_logit_multiplier - - self.reset_parameters() - - def reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -2.5) - if self.gk_proj is not None: - nn.init.xavier_uniform_(self.gk_proj[0].weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.gk_proj[1].weight, gain=2 ** -2.5) - if self.gv_proj is not None: - nn.init.xavier_uniform_(self.gv_proj[0].weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.gv_proj[1].weight, gain=2 ** -2.5) - - def forward(self, x): - mode = self.mode - 
chunk_size = self.chunk_size - - q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) - k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) - v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) - - if mode == 'chunk' or mode == 'fused_recurrent': - # for numumerical stable consideration. fused_chunk has better numerical stability - if self.use_gk: - gk = self.gk_proj(x).to(torch.float32) - gk = (F.logsigmoid(gk) / self.gate_logit_normalizer).clamp_min_(-3) - gk = rearrange(gk, 'b n (h d) -> b h n d', h=self.num_heads) - else: - gk = None - if self.use_gv: - gv = self.gv_proj(x).to(torch.float32) - gv = (F.logsigmoid(gv) / self.gate_logit_normalizer).clamp_min_(-3) - gv = rearrange(gv, 'b n (h d) -> b h n d', h=self.num_heads) - else: - gv = None - if mode == 'fused_recurrent': - o = fused_recurrent_gla(q, k, v, gk=gk, gv=gv) - else: - o = chunk_gla(q, k, v, gk=gk, gv=gv, chunk_size=chunk_size) - else: - g = self.gk_proj(x).to(torch.float32) - g = F.logsigmoid(g * self.gate_logit_multiplier) / self.gate_logit_normalizer - g = rearrange(g, 'b n (h d) -> b h n d', h=self.num_heads) - o = fused_chunk_gla(q, k, v, g) - - o = self.group_norm(rearrange(o, 'b h n d -> b n h d')) - o = self.out_proj(rearrange(o, 'b n h d -> b n (h d)') - * self.gate_fn(self.g_proj(x))) - return o - - -if __name__ == '__main__': - batch = 4 - seq_len = 1023 - d_model = 1024 - x = torch.randn(batch, seq_len, d_model).to( - torch.bfloat16).cuda().requires_grad_(True) - model = GatedLinearAttention(use_gk=True, use_gv=True, mode='chunk').to(torch.bfloat16).cuda() - y = model(x) - print(y.shape) - y.sum().backward() - print(x.grad.shape) diff --git a/flash_linear_attention/fla/layers/multiscale_retention.py b/flash_linear_attention/fla/layers/multiscale_retention.py deleted file mode 100644 index 30650a4..0000000 --- a/flash_linear_attention/fla/layers/multiscale_retention.py +++ /dev/null @@ -1,100 +0,0 @@ -# -*- coding: utf-8 -*- - -# 
Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] - -from __future__ import annotations - -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from fla.modules.rmsnorm import RMSNorm -from fla.modules.rotary import RotaryEmbedding -from fla.ops.triton.retention import (fused_chunk_retention, - fused_recurrent_retention, - parallel_retention) - - -def get_activation_fn(activation): - if activation == 'swish': - return F.silu - elif activation == 'gelu': - return F.gelu - else: - raise NotImplementedError - - -class MultiScaleRetention(nn.Module): - def __init__( - self, - d_model: str = 1024, - expand_k: str = 1, - expand_v: str = 2, - num_heads: str = 4, - gate_fn: str = 'swish', - layernorm_eps: float = 1e-5, - mode: str = 'chunk', - *args, **kwargs - ) -> MultiScaleRetention: - super().__init__() - - self.d_model = d_model - self.mode = mode - self.value_dim = int(d_model * expand_v) - self.key_dim = int(d_model * expand_k) - self.num_heads = num_heads - assert mode in ['fused_chunk', 'chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
- assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" - assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" - self.head_qk_dim = self.key_dim // num_heads - self.head_v_dim = self.value_dim // num_heads - self.gate_fn = get_activation_fn(activation=str(gate_fn)) - self.q_proj = nn.Linear(d_model, d_model, bias=False) - self.k_proj = nn.Linear(d_model, d_model, bias=False) - self.v_proj = nn.Linear(d_model, self.value_dim, bias=False) - self.g_proj = nn.Linear(d_model, self.value_dim, bias=False) - self.out_proj = nn.Linear(self.value_dim, d_model, bias=False) - - self.group_norm = RMSNorm(self.head_v_dim, eps=layernorm_eps) - self.rotary = RotaryEmbedding(dim=self.head_qk_dim, interleaved=False) - self.reset_parameters() - - - def reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) - nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -2.5) - - def forward(self, x): - mode = self.mode - q1 = rearrange(self.q_proj(x), '... (h d) -> ... h d', h=self.num_heads) - k1 = rearrange(self.k_proj(x), '... (h d) -> ... 
h d', h=self.num_heads) - q, k = self.rotary(q1, k1) - q, k = q.transpose(1, 2), k.transpose(1, 2) - v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) - if mode == 'fused_chunk': - o = fused_chunk_retention(q, k, v) - elif mode == 'parallel': - o = parallel_retention(q, k, v) - elif mode == 'fused_recurrent': - o = fused_recurrent_retention(q, k, v) - # TODO: need fix to allow different d_head_qk and d_head_v for "chunk" form - else: - raise NotImplementedError - o = self.group_norm(rearrange(o, 'b h n d -> b n h d')) - return self.out_proj(rearrange(o, 'b n h d -> b n (h d)') * self.gate_fn(self.g_proj(x))) - - -if __name__ == '__main__': - import torch - batch = 4 - seq_len = 1024 - d_model = 1024 - x = torch.randn(batch, seq_len, d_model).to(torch.bfloat16).cuda().requires_grad_(True) - model = MultiScaleRetention().to(torch.bfloat16).cuda() - y = model(x) - print(y.shape) - y.sum().backward() - print(x.grad.shape) - print(x.grad.shape) diff --git a/flash_linear_attention/fla/layers/rebased.py b/flash_linear_attention/fla/layers/rebased.py deleted file mode 100644 index 1ebf3c6..0000000 --- a/flash_linear_attention/fla/layers/rebased.py +++ /dev/null @@ -1,258 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Linear attention in Based. 
-https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py -""" -import math - -import opt_einsum as oe -import torch -import torch.nn as nn -from einops import rearrange - -from fla.ops.triton.rebased_fast import parallel_rebased - - -def init_feature_map(feature_map: str = 'none', **kwargs: any): - """ - Initialize query and key mapping for linear attention - """ - if feature_map in [None, 'none', 'identity']: - return FeatureMap(**kwargs) - # Taylor series approximations to exp(x) - elif feature_map == 'taylor_exp': - return TaylorExp(**kwargs) - else: - raise NotImplementedError( - f'Sorry "{feature_map}" feature map not implemented.') - - -class FeatureMap(nn.Module): - """ - Parent feature map; default is identity function - """ - - def __init__(self, - input_dim: int, - temp: int = None, - head_dim_idx: int = -1, - eps: float = 1e-12, - **kwargs: any): - super().__init__() - self.input_dim = input_dim - self.head_dim_idx = head_dim_idx - self.temp = 1. if temp is None else temp - self.eps = eps - - def forward(self, x: torch.Tensor): - """ - Assume x.shape is (batch_size, n_heads, seq_len, head_dim) - """ - return x - - -class TaylorExp(FeatureMap): - """ - Feature map to compute 2nd-order Taylor approx. of exp(q^T k / sqrt(d)) - """ - - def __init__(self, input_dim: int, **kwargs: any): - super().__init__(input_dim, **kwargs) - self.r2 = math.sqrt(2) - self.rd = math.sqrt(self.input_dim) - self.rrd = math.sqrt(self.rd) - self.tril_indices = torch.tril_indices( - self.input_dim, self.input_dim, -1) - - # Running these in parallel - def forward(self, x: torch.Tensor): - # Get 2nd-order terms (rearrange(x * x), '... m n -> ... 
(m n)') - x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) - ).flatten(start_dim=-2) / self.r2 - return torch.cat( - [ - (torch.ones(x[..., :1].shape).to(x.device) / self.r2), - # x / self.rrd, rebased_fast - x2 / self.rd - ], - dim=self.head_dim_idx - ) - - def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor: - """ - Compute f(x) s.t. f(x)^T f(x') = 1 + x^Tx' + (x^Tx')^2 / 2 - -> Assume x.shape is (batch_size, n_heads, seq_len, head_dim) - """ - # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first? - x2 = oe.contract('...m,...n->...mn', x, x) / self.rd - x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2 - x2 = x2[..., self.tril_indices[0], self.tril_indices[1]] - x = torch.cat( - [ - (torch.ones(x[..., :1].shape).to(x.device) / self.r2), - # x / self.rrd, - x2d, - x2 - ], - dim=-1 - ) - return x - - -class ReBasedLinearAttention(nn.Module): - def __init__( - self, - d_model: int, - l_max: int = 2048, - feature_dim: int = 16, - num_key_value_heads: int = 12, - num_heads: int = 12, - feature_name: str = "taylor_exp", - eps: float = 1e-12, - causal: bool = True, - mode: str = "parallel", - ): - super().__init__() - self.d_model = d_model - self.l_max = l_max - self.mode = mode - assert self.mode in ["fused_chunk", "parallel"] - - # linear attention - self.feature_name = feature_name - self.feature_dim = feature_dim - self.num_key_value_heads = num_key_value_heads - self.num_heads = num_heads - self.head_dim = self.d_model // self.num_key_value_heads - self.causal = causal - feature_map_kwargs = { - 'input_dim': self.feature_dim, - 'head_dim_idx': -1, - 'temp': 1., - 'eps': 1e-12 - } - self.feature_map = init_feature_map( - feature_map=self.feature_name, **feature_map_kwargs) - self.proj_q = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_k = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_v = nn.Linear( - self.d_model, self.num_key_value_heads * 
self.head_dim, bias=False) - self.proj_o = nn.Linear( - self.num_heads * self.head_dim, self.d_model, bias=False) - self.dropout = nn.Identity() - self.eps = eps - - def forward(self, hidden_states: torch.Tensor, **kwargs): - mode = self.mode - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - q, k, v = map(lambda x: rearrange( - x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) - if mode == "fused_chunk": - assert q.shape[-1] <= 16 - #o = fused_chunk_based(q, k, v, True, True) - elif mode == 'parallel': - assert q.shape[-1] <= 128 - o = parallel_rebased(q, k, v, self.eps, True, True) - o = rearrange(o, "b h l d -> b l (h d)") - o = self.proj_o(o) - o = self.dropout(o) - return o - - # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 - - def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): - """ - x (torch.Tensor): tensor of shape (b, d, l) - y (torch.Tensor): tensor of shape (b, d, l) - """ - # hidden_states = hidden_states.transpose(1, 2) - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - - q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) - k = k.view(b, l, self.num_key_value_heads, - self.feature_dim).transpose(1, 2) - v = v.view(b, l, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - # Linear attention - q, k = self.feature_map(q), self.feature_map(k) - q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) - - # Compute attention - if self.causal: - y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) - else: - y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) - y = rearrange(y, 'b h l d -> b l (h d)') - y = self.proj_o(y.to(hidden_states.dtype)) - y = self.dropout(y) - return y.to(hidden_states.dtype) - - -if 
__name__ == '__main__': - batch = 4 - seq_len = 1024 - d_model = 1024 - dtype = torch.float32 - x = torch.randn(batch, seq_len, d_model).to( - dtype).cuda().requires_grad_(True) - dy = torch.randn(batch, seq_len, d_model).to( - dtype).cuda() - model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() - y = model(x) - y.backward(dy, retain_graph=True) - x_grad, x.grad = x.grad, None - - proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None - proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None - proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None - - x.requires_grad_(True) - y2 = model.forward_reference(x) - y2.backward(dy) - print((y - y2).abs().max().item()) - # assert y.allclose(y2, 0, 1e-4), breakpoint() - print((x_grad - x.grad).abs().max().item()) - # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() - print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) - print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) - print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) - print("All good with rebased!") - - starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - - for d_model in [16, 64]: - model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() - for seq_len in [256, 1024, 4096]: - timings_f = [] - timings_b = [] - for i in range(100): - x = torch.randn(batch, seq_len, d_model).to( - dtype).cuda().requires_grad_(True) - dy = torch.randn(batch, seq_len, d_model).to( - dtype).cuda() - - starter.record() - y = model(x) - ender.record() - # WAIT FOR GPU SYNC - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings_f.append(curr_time) - - starter.record() - y.backward(dy) - ender.record() - - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings_b.append(curr_time) - - print(f"fseq len {seq_len}, d_model {d_model}, forward time: {sum(timings_f) / len(timings_f)}, backward 
time: {sum(timings_b) / len(timings_b)}") \ No newline at end of file diff --git a/flash_linear_attention/fla/layers/rebased_fast.py b/flash_linear_attention/fla/layers/rebased_fast.py deleted file mode 100644 index 875808b..0000000 --- a/flash_linear_attention/fla/layers/rebased_fast.py +++ /dev/null @@ -1,229 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Linear attention in Based. -https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py -""" -import math - -import opt_einsum as oe -import torch -import torch.nn as nn -from einops import rearrange - -from fla.ops.triton.rebased_fast import parallel_rebased - - -def init_feature_map(feature_map: str = 'none', **kwargs: any): - """ - Initialize query and key mapping for linear attention - """ - if feature_map in [None, 'none', 'identity']: - return FeatureMap(**kwargs) - # Taylor series approximations to exp(x) - elif feature_map == 'taylor_exp': - return TaylorExp(**kwargs) - else: - raise NotImplementedError( - f'Sorry "{feature_map}" feature map not implemented.') - - -class FeatureMap(nn.Module): - """ - Parent feature map; default is identity function - """ - - def __init__(self, - input_dim: int, - temp: int = None, - head_dim_idx: int = -1, - eps: float = 1e-12, - **kwargs: any): - super().__init__() - self.input_dim = input_dim - self.head_dim_idx = head_dim_idx - self.temp = 1. if temp is None else temp - self.eps = eps - - def forward(self, x: torch.Tensor): - """ - Assume x.shape is (batch_size, n_heads, seq_len, head_dim) - """ - return x - - -class TaylorExp(FeatureMap): - """ - Feature map to compute 2nd-order Taylor approx. of exp(q^T k / sqrt(d)) - """ - - def __init__(self, input_dim: int, **kwargs: any): - super().__init__(input_dim, **kwargs) - self.rd = math.sqrt(self.input_dim) - self.rrd = math.sqrt(self.rd) - - # Running these in parallel - def forward(self, x: torch.Tensor): - # Get 2nd-order terms (rearrange(x * x), '... m n -> ... 
(m n)') - x2 = (x.unsqueeze(-1) * x.unsqueeze(-2) - ).flatten(start_dim=-2) - return x2 / self.rd - - - -class ReBasedLinearAttention(nn.Module): - def __init__( - self, - d_model: int, - l_max: int = 2048, - feature_dim: int = 16, - num_key_value_heads: int = 12, - num_heads: int = 12, - feature_name: str = "taylor_exp", - eps: float = 1e-12, - causal: bool = True, - mode: str = "parallel", - ): - super().__init__() - self.d_model = d_model - self.l_max = l_max - self.mode = mode - assert self.mode in ["fused_chunk", "parallel"] - - # linear attention - self.feature_name = feature_name - self.feature_dim = feature_dim - self.num_key_value_heads = num_key_value_heads - self.num_heads = num_heads - self.head_dim = self.d_model // self.num_key_value_heads - self.causal = causal - feature_map_kwargs = { - 'input_dim': self.feature_dim, - 'head_dim_idx': -1, - 'temp': 1., - 'eps': 1e-12 - } - self.feature_map = init_feature_map( - feature_map=self.feature_name, **feature_map_kwargs) - self.proj_q = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_k = nn.Linear( - self.d_model, self.feature_dim * self.num_heads, bias=False) - self.proj_v = nn.Linear( - self.d_model, self.num_key_value_heads * self.head_dim, bias=False) - self.proj_o = nn.Linear( - self.num_heads * self.head_dim, self.d_model, bias=False) - self.dropout = nn.Identity() - self.eps = eps - - def forward(self, hidden_states: torch.Tensor, **kwargs): - mode = self.mode - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - q, k, v = map(lambda x: rearrange( - x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) - if mode == "fused_chunk": - assert q.shape[-1] <= 16 - #o = fused_chunk_based(q, k, v, True, True) - elif mode == 'parallel': - assert q.shape[-1] <= 128 - o = parallel_rebased(q, k, v, self.eps, True, True) - o = rearrange(o, "b h l d -> b l (h d)") - o = self.proj_o(o) - o = 
self.dropout(o) - return o - - # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 - - def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): - """ - x (torch.Tensor): tensor of shape (b, d, l) - y (torch.Tensor): tensor of shape (b, d, l) - """ - # hidden_states = hidden_states.transpose(1, 2) - b, l, _ = hidden_states.size() - q, k, v = self.proj_q(hidden_states), self.proj_k( - hidden_states), self.proj_v(hidden_states) - - q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) - k = k.view(b, l, self.num_key_value_heads, - self.feature_dim).transpose(1, 2) - v = v.view(b, l, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - # Linear attention - q, k = self.feature_map(q), self.feature_map(k) - q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) - - # Compute attention - if self.causal: - y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) - else: - y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) - y = rearrange(y, 'b h l d -> b l (h d)') - y = self.proj_o(y.to(hidden_states.dtype)) - y = self.dropout(y) - return y.to(hidden_states.dtype) - - -if __name__ == '__main__': - batch = 4 - seq_len = 1024 - d_model = 1024 - dtype = torch.float32 - x = torch.randn(batch, seq_len, d_model).to( - dtype).cuda().requires_grad_(True) - dy = torch.randn(batch, seq_len, d_model).to( - dtype).cuda() - model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() - y = model(x) - y.backward(dy, retain_graph=True) - x_grad, x.grad = x.grad, None - proj_q_grad, model.proj_q.weight.grad = model.proj_q.weight.grad, None - proj_k_grad, model.proj_k.weight.grad = model.proj_k.weight.grad, None - proj_v_grad, model.proj_v.weight.grad = model.proj_v.weight.grad, None - x.requires_grad_(True) - y2 = model.forward_reference(x) - y2.backward(dy) - print((y - y2).abs().max().item()) - # assert y.allclose(y2, 0, 
1e-4) - print((x_grad - x.grad).abs().max().item()) - # assert x_grad.allclose(x.grad, 0, 1e-4) - - print((proj_q_grad - model.proj_q.weight.grad).abs().max().item()) - print((proj_k_grad - model.proj_k.weight.grad).abs().max().item()) - print((proj_v_grad - model.proj_v.weight.grad).abs().max().item()) - - print("All good with rebased fast!") - - starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - - for d_model in [16, 64]: - model = ReBasedLinearAttention(d_model=d_model).to(dtype).cuda() - for seq_len in [256, 1024, 4096]: - timings_f = [] - timings_b = [] - for i in range(100): - x = torch.randn(batch, seq_len, d_model).to( - dtype).cuda().requires_grad_(True) - dy = torch.randn(batch, seq_len, d_model).to( - dtype).cuda() - - starter.record() - y = model(x) - ender.record() - # WAIT FOR GPU SYNC - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings_f.append(curr_time) - - starter.record() - y.backward(dy) - ender.record() - - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings_b.append(curr_time) - - print(f"fseq len {seq_len}, d_model {d_model}, forward time: {sum(timings_f) / len(timings_f)}, backward time: {sum(timings_b) / len(timings_b)}") \ No newline at end of file diff --git a/flash_linear_attention/fla/modules/__init__.py b/flash_linear_attention/fla/modules/__init__.py deleted file mode 100644 index d4c9909..0000000 --- a/flash_linear_attention/fla/modules/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- - -from .convolution import LongConvolution, ShortConvolution, ImplicitLongConvolution -from .rmsnorm import RMSNorm -from .rotary import RotaryEmbedding - -__all__ = [ - 'LongConvolution', 'ShortConvolution', 'ImplicitLongConvolution', - 'RMSNorm', - 'RotaryEmbedding' -] diff --git a/flash_linear_attention/fla/modules/convolution.py b/flash_linear_attention/fla/modules/convolution.py deleted file mode 100644 index ac3bbc5..0000000 --- 
a/flash_linear_attention/fla/modules/convolution.py +++ /dev/null @@ -1,195 +0,0 @@ -# -*- coding: utf-8 -*- - -# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange - - -def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None): - seqlen = u.shape[-1] - fft_size = 2 * seqlen - k_f = torch.fft.rfft(k, n=fft_size) / fft_size - if k_rev is not None: - k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size - k_f = k_f + k_rev_f.conj() - u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - - if len(u.shape) > 3: - k_f = k_f.unsqueeze(1) - y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] - - out = y + u - if gelu: - out = F.gelu(out) - if dropout_mask is not None: - return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) - else: - return out.to(dtype=u.dtype) - - -class ShortConvolution(nn.Module): - """ - Simple wrapper around nn.Conv1d that accepts dimension last. - """ - - def __init__( - self, - d_model: int, - kernel_size: int - ): - super().__init__() - self.conv = nn.Conv1d( - in_channels=d_model, - out_channels=d_model, - kernel_size=kernel_size, - groups=d_model, - padding=kernel_size - 1, - ) - - def forward(self, x: torch.Tensor): - """ - Args: - x: (b, l, d) tensor - Returns: - y: (b, l, d) tensor - """ - l = x.size(1) - y = self.conv(x.transpose(1, 2))[..., :l].transpose(1, 2) - return y - - -class LongConvolution(nn.Module): - """ - LongConvolution applies a convolution operation on the input tensor using a fixed - filter of length l_max. - The filter is learned during training and is applied using FFT convolution. - Args: - d_model (int): The number of expected features in the input and output. - l_max (int): The maximum sequence length. 
- Returns: - y: (b, l, d) tensor - """ - def __init__( - self, - d_model: int, - l_max: int, - **kwargs, - ): - """ - Initializes the LongConvolution module. - Args: - d_model (int): The number of expected features in the input and output. - l_max (int): The maximum sequence length. - """ - super().__init__() - self.d_model = d_model - self.filter = nn.Parameter(torch.randn(self.d_model, l_max), requires_grad=True) - - def forward(self, x: torch.Tensor, *args, **kwargs): - """ - Applies the LongConvolution operation on the input tensor. - Args: - x: (b, l, d) tensor - Returns: - y: (b, l, d) tensor - """ - x = x.transpose(1, 2) - y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) - y = y.transpose(1, 2) - return y.to(dtype=x.dtype) - - -class PositionalEmbedding(nn.Module): - def __init__(self, emb_dim: int, seq_len: int, **kwargs): - """Complex exponential positional embeddings for implicit long convolution filters.""" - super().__init__() - - self.seq_len = seq_len - # The time embedding fed to the filteres is normalized so that t_f = 1 - t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 - - if emb_dim > 1: - bands = (emb_dim - 1) // 2 - # To compute the right embeddings we use the "proper" linspace - t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] - w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 - - f = torch.linspace(1e-4, bands - 1, bands)[None, None] - z = torch.exp(-1j * f * w) - z = torch.cat([t, z.real, z.imag], dim=-1) - self.z = nn.Parameter(z, requires_grad=False) - - def forward(self, L): - return self.z[:, :L] - -class ImplicitLongConvolution(nn.Module): - """ - Long convolution with implicit filter parameterized by an MLP. - - Args: - d_model (int): The number of expected features in the input and output. - l_max (int): The maximum sequence length. - d_emb (int, optional): The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine). Defaults to 3. 
- d_hidden (int, optional): The number of features in the hidden layer of the MLP. Defaults to 16. - - Attributes: - pos_emb (PositionalEmbedding): The positional embedding layer. - mlp (nn.Sequential): The MLP that parameterizes the implicit filter. - - """ - - - def __init__( - self, - d_model: int, - l_max: int, - d_emb: int=3, - d_hidden: int = 16, - **kwargs, - ): - """ - Long convolution with implicit filter parameterized by an MLP. - - - """ - super().__init__() - self.d_model = d_model - self.d_emb = d_emb - - - assert ( - d_emb % 2 != 0 and d_emb >= 3 - ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" - self.pos_emb = PositionalEmbedding(d_emb, l_max) - - # final linear layer - self.mlp = nn.Sequential( - nn.Linear(d_emb, d_hidden), - torch.nn.ReLU(), - nn.Linear(d_hidden, d_model), - ) - - - def filter(self, l: int, *args, **kwargs): - k = self.mlp(self.pos_emb(l)) - - return k.transpose(1, 2) - - def forward(self, x: torch.Tensor, *args, **kwargs): - """ - Args: - x: (b, l, d) tensor - Returns: - y: (b, l, d) tensor - """ - x = x.transpose(1, 2) - k = self.filter(x.shape[-1]) - y = fft_conv(x, k, dropout_mask=None, gelu=False) - - y = y.transpose(1, 2) - return y.to(dtype=x.dtype) \ No newline at end of file diff --git a/flash_linear_attention/fla/modules/rmsnorm.py b/flash_linear_attention/fla/modules/rmsnorm.py deleted file mode 100644 index e2da44e..0000000 --- a/flash_linear_attention/fla/modules/rmsnorm.py +++ /dev/null @@ -1,647 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023, Tri Dao. -# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py -# Implement residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 
-# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math - -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd - -import triton -import triton.language as tl - - -def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): - dtype = x.dtype - if upcast: - weight = weight.float() - bias = bias.float() if bias is not None else None - if upcast: - x = x.float() - residual = residual.float() if residual is not None else residual - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - return out if not prenorm else (out, x) - - -def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): - dtype = x.dtype - if upcast: - weight = weight.float() - bias = bias.float() if bias is not None else None - if upcast: - x = x.float() - residual = residual.float() if residual is not None else residual - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + \ - bias if bias is not None else (x * rstd * weight) - out = out.to(dtype) - return out if not prenorm else (out, x) - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input 
- Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - RESIDUAL_OUT, # pointer to the residual - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - N, # number of columns in X - eps, # epsilon to avoid division by zero - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < - N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - - -def _layer_norm_fwd( - x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False -): - if residual is not None: - residual_dtype = residual.dtype - M, 
N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - # allocate output - y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - assert y.stride(-1) == 1 - if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): - residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) - assert residual_out.stride(-1) == 1 - else: - residual_out = None - mean = torch.empty((M,), dtype=torch.float32, - device="cuda") if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device="cuda") - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[(M,)]( - x, - y, - weight, - bias, - residual, - residual_out, - mean, - rstd, - x.stride(0), - y.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - N, - eps, - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype - return y, mean, rstd, residual_out if residual_out is not None else x - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": 
lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - DRESIDUAL_IN, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. 
- row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + 
row_block_id * N + cols, db, mask=mask) - - -def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - has_residual=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): - M, N = x.shape - assert x.stride(-1) == 1 - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = torch.empty_like( - x) if has_residual and dx.dtype != x.dtype else None - y = torch.empty(M, N, dtype=dy.dtype, - device=dy.device) if recompute_output else None - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - _layer_norm_bwd_kernel[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - dresidual_in, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - ) - 
dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype: - dresidual_in = dx - return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) - - -class LayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm - ) - ctx.save_for_backward(residual_out, weight, bias, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - y = y.reshape(x_shape_og) - return y if not prenorm else (y, residual_out.reshape(x_shape_og)) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in = 
_layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) - - -def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): - # factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - norm_weight = norm_weight.contiguous() - if 
norm_bias is not None: - norm_bias = norm_bias.contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to( - dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, - norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, - 
x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/flash_linear_attention/fla/modules/rotary.py b/flash_linear_attention/fla/modules/rotary.py deleted file mode 100644 index b326cb0..0000000 --- a/flash_linear_attention/fla/modules/rotary.py +++ /dev/null @@ -1,312 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023, Tri Dao. - -import math -from typing import Optional, Tuple, Union - -import torch -from einops import rearrange, repeat -from fla.ops.triton.rotary import apply_rotary - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - ): - out = apply_rotary( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - ) - if isinstance(seqlen_offsets, int): - # Can't save int with save_for_backward - ctx.save_for_backward(cos, sin, cu_seqlens) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with - # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 
- if not ctx.interleaved and not ctx.inplace: - do = do.clone() - dx = apply_rotary( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, -): - """ - Arguments: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - return ApplyRotaryEmb.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -# For backward compatibility -apply_rotary_emb_func = apply_rotary_emb - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. 
_GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). - A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 - Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py - """ - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - pos_idx_in_fp32=True, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, - otherwise they might be in lower precision. - This option was added because previously (before 2023-07-02), when we construct - the position indices, we use the dtype of self.inv_freq. In most cases this would - be fp32, but if the model is trained in pure bf16 (not mixed precision), then - self.inv_freq would be bf16, and the position indices are also in bf16. - Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the - embeddings for some positions will coincide. - To maintain compatibility with models previously trained in pure bf16, - we add this option. 
- """ - super().__init__() - self.dim = dim - self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, - dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
- # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, - dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to fp16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, - device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to( - device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, - else it's just q of shape (batch, seqlen, nheads, headdim) - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. 
- """ - seqlen = q.shape[1] - if max_seqlen is not None: - self._update_cos_sin_cache( - max_seqlen, device=q.device, dtype=q.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache( - seqlen + seqlen_offset, device=q.device, dtype=q.dtype) - if self.scale is None: - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - k = apply_rotary_emb_func( - k, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - - else: - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - k = apply_rotary_emb_func( - k, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - - return q, k diff --git a/flash_linear_attention/fla/ops/__init__.py b/flash_linear_attention/fla/ops/__init__.py deleted file mode 100644 index 0f8c005..0000000 --- a/flash_linear_attention/fla/ops/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -from fla.ops.torch import (naive_chunk_based, naive_parallel_based, - naive_recurrent_gla, naive_retention) -from fla.ops.triton import (chunk_gla, chunk_retention, fused_chunk_based, - fused_chunk_gla, fused_chunk_retention, - fused_recurrent_gla, fused_recurrent_retention, - parallel_based, parallel_retention, parallel_rebased) - -__all__ = [ - 'naive_chunk_based', - 'naive_parallel_based', - 'naive_recurrent_gla', - 'naive_retention', - 'chunk_gla', - 'chunk_retention', - 'fused_chunk_based', - 'fused_chunk_gla', - 'fused_chunk_retention', - 'fused_recurrent_gla', - 'fused_recurrent_retention', - 'parallel_based', - 'parallel_rebased', - 'parallel_retention', -] diff --git a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp deleted 
file mode 100644 index 57cbb98..0000000 --- a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -torch::Tensor fwd_cuda(torch::Tensor& Q, torch::Tensor& K, torch::Tensor& g_K); - -std::vector bwd_cuda(torch::Tensor Q, torch::Tensor K, - torch::Tensor g_K, torch::Tensor DQK); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fwd_cuda, "GLA compute A semiring (CUDA)"); - m.def("backward", &bwd_cuda, "GLA compute A semiring (CUDA)"); -} diff --git a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu b/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu deleted file mode 100644 index 9e9ee34..0000000 --- a/flash_linear_attention/fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu +++ /dev/null @@ -1,204 +0,0 @@ -#include -#include -#include - -#include "ATen/ATen.h" - -typedef at::BFloat16 bf16; - -template -__global__ void fwd_inner_chunk16_dim16x(int batchSize, int M, int N_K, - scalar_t* Q, scalar_t* K, float* G_K, - scalar_t* QK) { - // Batch index - const int batchIdx = blockIdx.x; - // allocate buffer for current block in fast shared mem - __shared__ float Q_tile[16][16]; - __shared__ float K_tile[16][16]; - __shared__ float G_tile[16][16]; - __shared__ float G_tile_trans[16][16]; - - const uint threadCol = threadIdx.x % 16; - const uint threadRow = threadIdx.x / 16; - - int K_Stride = M * N_K; - - // Adjust the pointer for batch and matrix size - Q += batchIdx * K_Stride; - K += batchIdx * K_Stride; - G_K += batchIdx * K_Stride; - QK += batchIdx * M * M; - - float tmp = 0.0; - // printf("Hello world"); - // printf("%d, %d, %d \n", threadRow, threadCol, N_K); - for (int bkIdx = 0; bkIdx < N_K; bkIdx += 16) { - Q_tile[threadRow][threadCol] = (float)Q[threadRow * N_K + threadCol]; - K_tile[threadRow][threadCol] = (float)K[threadRow * N_K + threadCol]; - float tmp_gk = (float)G_K[threadRow * N_K + 
threadCol]; - G_tile[threadRow][threadCol] = (float)tmp_gk; - G_tile_trans[threadCol][threadRow] = (float)tmp_gk; - - __syncthreads(); - - Q += 16; - K += 16; - G_K += 16; - - if (threadCol <= threadRow) { - for (int dotIdx = 0; dotIdx < 16; ++dotIdx) { - // avoid bank conflict? - float exp_term = - expf(G_tile[threadRow][dotIdx] - G_tile_trans[dotIdx][threadCol]); - tmp += Q_tile[threadRow][dotIdx] * K_tile[threadCol][dotIdx] * exp_term; - } - } - __syncthreads(); - } - - if (threadCol <= threadRow) { - QK[threadRow * M + threadCol] = (scalar_t)tmp; - } else { - QK[threadRow * M + threadCol] = (scalar_t)0.0; - } -} - -template -__global__ void bwd_inner_chunk16_dim16x(int batchSize, int M, int N_K, - scalar_t* Q, scalar_t* K, float* G, - scalar_t* DQK, scalar_t* DQ, - scalar_t* DK, float* DG) { - // Batch index - const uint batchIdx = blockIdx.x; - - // allocate buffer for current block in fast shared mem - __shared__ float Q_tile[16][16]; - __shared__ float QK_tile[16][16]; - __shared__ float K_tile[16][16]; - __shared__ float G_tile[16][16]; - __shared__ float G_tile_trans[16][16]; - - const uint threadCol = threadIdx.x % 16; - const uint threadRow = threadIdx.x / 16; - - int K_Stride = M * N_K; - - Q += batchIdx * K_Stride; - DQ += batchIdx * K_Stride; - K += batchIdx * K_Stride; - DK += batchIdx * K_Stride; - G += batchIdx * K_Stride; - DG += batchIdx * K_Stride; - - DQK += batchIdx * M * M; - QK_tile[threadRow][threadCol] = - (threadCol <= threadRow) ? 
(float)DQK[threadRow * M + threadCol] : 0.0; - __syncthreads(); - - for (int bkIdx = 0; bkIdx < N_K; bkIdx += 16) { - Q_tile[threadRow][threadCol] = (float)Q[threadRow * N_K + threadCol]; - K_tile[threadRow][threadCol] = (float)K[threadRow * N_K + threadCol]; - float tmp_gk = (float)G[threadRow * N_K + threadCol]; - G_tile[threadRow][threadCol] = tmp_gk; - // G_tile_trans[threadCol][threadRow] = tmp_gk; - - __syncthreads(); - - float threadResults_dK = 0; - float threadResults_dQ = 0; - - for (uint dotIdx = threadRow; dotIdx < 16; dotIdx += 1) { - float tmp = - QK_tile[dotIdx][threadRow] * - expf(G_tile[dotIdx][threadCol] - G_tile[threadRow][threadCol]) * - Q_tile[dotIdx][threadCol]; - threadResults_dK += tmp; - } - - for (uint dotIdx = 0; dotIdx <= threadRow; dotIdx += 1) { - float tmp = - QK_tile[threadRow][dotIdx] * - expf(G_tile[threadRow][threadCol] - G_tile[dotIdx][threadCol]) * - K_tile[dotIdx][threadCol]; - threadResults_dQ += dotIdx <= threadRow ? tmp : 0; - } - - __syncthreads(); - DQ[threadRow * N_K + threadCol] = (scalar_t)threadResults_dQ; - DK[threadRow * N_K + threadCol] = (scalar_t)threadResults_dK; - DG[threadRow * N_K + threadCol] = - (threadResults_dQ * Q_tile[threadRow][threadCol] - - threadResults_dK * K_tile[threadRow][threadCol]); - Q += 16; - K += 16; - G += 16; - DQ += 16; - DK += 16; - DG += 16; - __syncthreads(); - } -} - -std::vector bwd_cuda(torch::Tensor Q, torch::Tensor K, - torch::Tensor g_K, torch::Tensor DQK) { - auto DQ = torch::empty_like(Q); - auto DK = torch::empty_like(K); - auto Dg_K = torch::empty_like(g_K); - - int B_size = Q.size(0); // This is the batch size dimension. - int H_size = Q.size(1); // This is the head dimension - int num_chunk = Q.size(2); // This is the chunk dimension. 
- int M = Q.size(-2); - int N_K = Q.size(-1); - - dim3 gridDim(B_size * H_size * num_chunk); - dim3 blockDim(256); - - switch (Q.type().scalarType()) { - case torch::ScalarType::BFloat16: - bwd_inner_chunk16_dim16x<<>>( - B_size * H_size * num_chunk, M, N_K, Q.data_ptr(), - K.data_ptr(), g_K.data_ptr(), DQK.data_ptr(), - DQ.data_ptr(), DK.data_ptr(), Dg_K.data_ptr()); - break; - default: - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - Q.scalar_type(), "bwd_inner_chunk16_dim16x", ([&] { - bwd_inner_chunk16_dim16x<<>>( - B_size * H_size * num_chunk, M, N_K, Q.data_ptr(), - K.data_ptr(), g_K.data_ptr(), - DQK.data_ptr(), DQ.data_ptr(), - DK.data_ptr(), Dg_K.data_ptr()); - })); - }; - return {DQ, DK, Dg_K}; -} - -torch::Tensor fwd_cuda(torch::Tensor& Q, torch::Tensor& K, torch::Tensor& g_K) { - auto QK = torch::empty( - {Q.size(0), Q.size(1), Q.size(2), Q.size(3), Q.size(3)}, Q.options()); - int B_size = Q.size(0); // This is the batch size dimension. - int H_size = Q.size(1); // This is the head dimension - int num_chunk = Q.size(2); // This is the chunk dimension. 
- int M = Q.size(-2); // this is the chunk size - int N_K = Q.size(-1); // this is the head_K dim - - dim3 gridDim(B_size * H_size * num_chunk); - dim3 blockDim(256); - switch (Q.type().scalarType()) { - case torch::ScalarType::BFloat16: - fwd_inner_chunk16_dim16x<<>>( - B_size * H_size * num_chunk, M, N_K, Q.data_ptr(), - K.data_ptr(), g_K.data_ptr(), QK.data_ptr()); - break; - default: - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - Q.scalar_type(), "fwd_inner_chunk16_dim16x", ([&] { - fwd_inner_chunk16_dim16x<<>>( - B_size * H_size * num_chunk, M, N_K, Q.data_ptr(), - K.data_ptr(), g_K.data_ptr(), - QK.data_ptr()); - })); - }; - return QK; -} \ No newline at end of file diff --git a/flash_linear_attention/fla/ops/torch/__init__.py b/flash_linear_attention/fla/ops/torch/__init__.py deleted file mode 100644 index 4b6852d..0000000 --- a/flash_linear_attention/fla/ops/torch/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -from .based import naive_chunk_based, naive_parallel_based -from .gla import naive_recurrent_gla -from .retention import naive_retention - -__all__ = ['naive_chunk_based', 'naive_parallel_based', 'naive_recurrent_gla', 'naive_retention'] diff --git a/flash_linear_attention/fla/ops/torch/based.py b/flash_linear_attention/fla/ops/torch/based.py deleted file mode 100644 index ee4c6ef..0000000 --- a/flash_linear_attention/fla/ops/torch/based.py +++ /dev/null @@ -1,131 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -from einops import rearrange - -from fla.ops.triton.based import fused_chunk_based, parallel_based - - -def naive_parallel_based(q, k, v, use_scale=True, use_norm=True): - if use_scale: - q = q * (q.shape[-1] ** -0.5) - attn = q @ k.transpose(-2, -1) - attn = 1 + attn + 1/2 * (attn ** 2) - attn.masked_fill_(~torch.tril(torch.ones( - q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) - o = attn @ v - if use_norm: - z = attn.sum(-1) - return o / (z[..., None] + 1e-6) - else: - return o - - -def 
naive_chunk_based(q, k, v, chunk_size=256): - q = q * (q.shape[-1] ** -0.5) - - # compute normalizer. - k_cumsum = torch.cumsum(k, dim=-2) - kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) - # first - z = (q * k_cumsum).sum(-1) - # second order - z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 - # zero-th order - z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] - - # compute o - # constant term - _o = v.cumsum(-2) - - q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) - - k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) - v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) - - intra_chunk_attn = q @ k.transpose(-2, -1) - intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) - intra_chunk_attn.masked_fill_( - ~torch.tril( - torch.ones(chunk_size, chunk_size, - dtype=torch.bool, device=q.device), - ), 0) - o = intra_chunk_attn @ v - - # quadractic term - kv = torch.einsum( - 'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) - kv = kv.cumsum(2) - kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) - - o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) - - # linear term - kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) - kv = kv.cumsum(2) - kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) - o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) - - o = rearrange(o, 'b h n c d -> b h (n c) d') - o = o + _o - return o / (z[..., None] + 1e-6) - - -if __name__ == "__main__": - B = 4 - H = 4 - L = 128 - # D = 15 - dtype = torch.float32 - q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) - k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) - v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) - - do = torch.randn_like(v).cuda() - ref = naive_parallel_based(q, k, v, True, True) - ref.backward(do, 
retain_graph=True) - ref_dq, q.grad = q.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dv, v.grad = v.grad.clone(), None - - # tri = naive_chunk_based(q, k, v) - # tri.backward(do, retain_graph=True) - # tri_dq, q.grad = q.grad.clone(), None - # tri_dk, k.grad = k.grad.clone(), None - # tri_dv, v.grad = v.grad.clone(), None - - # assert ref.allclose(tri, 0, 1e-4), breakpoint() - # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() - # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() - # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() - - tri = fused_chunk_based(q, k, v, True, True) - tri.backward(do, retain_graph=True) - tri_dq, q.grad = q.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dv, v.grad = v.grad.clone(), None - print((ref-tri).abs().max()) - print((ref_dq-tri_dq).abs().max()) - print((ref_dk-tri_dk).abs().max()) - print((ref_dv-tri_dv).abs().max()) - - # assert ref.allclose(tri, 0, 1e-4), breakpoint() - # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() - # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() - # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() - - tri = parallel_based(q, k, v, True, True) - tri.backward(do, retain_graph=True) - tri_dq, q.grad = q.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dv, v.grad = v.grad.clone(), None - - print((ref-tri).abs().max()) - print((ref_dq-tri_dq).abs().max()) - print((ref_dk-tri_dk).abs().max()) - print((ref_dv-tri_dv).abs().max()) - - # assert ref.allclose(tri, 0, 1e-4), breakpoint() - # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() - # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() - # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() diff --git a/flash_linear_attention/fla/ops/torch/gla.py b/flash_linear_attention/fla/ops/torch/gla.py deleted file mode 100644 index 8b06a24..0000000 --- a/flash_linear_attention/fla/ops/torch/gla.py +++ /dev/null @@ -1,119 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch 
-import torch.nn.functional as F - -from fla.ops.triton.gla import fused_recurrent_gla - - -def ceildiv(a, b): - return -(a // -b) - - -def naive_recurrent_gla( - q, - k, - v, - gk, - initial_state=None, - output_final_state=False, - causal=True -): - orig_dtype = q.dtype - q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) - batch_size, n_heads, seq_len, d_head_k = q.shape - _, _, _, d_head_v = v.shape - h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) - o = torch.zeros_like(v) - scale = d_head_k ** -0.5 - - if initial_state is not None: - h += initial_state - - for i in range(seq_len): - q_i = q[:, :, i, :] * scale - k_i = k[:, :, i] - v_i = v[:, :, i, :] - gk_i = gk[:, :, i].exp() - kv_i = k_i[..., None] * v_i[..., None, :] - h = h * gk_i[..., None] + kv_i - o_i = (q_i[..., None] * h).sum(-2) - o[:, :, i] = o_i - - if causal: - if output_final_state: - return o.to(orig_dtype), h - else: - return o.to(orig_dtype) - else: - o_reverse = torch.zeros_like(v) - h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) - for i in range(seq_len-1, -1, -1): - q_i = q[:, :, i, :] * scale - k_i = k[:, :, i] - v_i = v[:, :, i, :] - gk_i = gk[:, :, i].exp() - kv_i = k_i[..., None] * v_i[..., None, :] - h = h * gk_i[..., None] + kv_i - o_i = (q_i[..., None] * h).sum(-2) - o_reverse[:, :, i] = o_i - - return o, o_reverse - - -if __name__ == "__main__": - B = 4 - H = 4 - L = 512 - D = 128 - dtype = torch.float32 - q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) - k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) - v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True) - g = F.logsigmoid(torch.rand(B, H, L, D)).cuda( - ).clamp_min(-1).to(torch.float32).requires_grad_(True) - - do = torch.rand_like(v).cuda() - do2 = torch.rand_like(v).cuda() - intial_state = torch.rand(B, H, D, D).cuda() - - ref, ref_rev = naive_recurrent_gla(q, k, v, g, 
causal=False) - - ref.backward(do, retain_graph=True) - ref_rev.backward(do2, retain_graph=True) - - ref_dq, q.grad = q.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dv, v.grad = v.grad.clone(), None - ref_dg, g.grad = g.grad.clone(), None - - tri, tri_rev = fused_recurrent_gla( - q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False) - tri.backward(do, retain_graph=True) - tri_rev.backward(do2, retain_graph=True) - tri_dq, q.grad = q.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dv, v.grad = v.grad.clone(), None - tri_dg, g.grad = g.grad.clone(), None - - assert ref.allclose(tri, 0, 1e-5), breakpoint() - assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint() - assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() - assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() - assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() - assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() - - # tri = fused_chunk_gla(q, k, v, g) - # tri.backward(do, retain_graph=True) - # tri_dq, q.grad = q.grad.clone(), None - # tri_dk, k.grad = k.grad.clone(), None - # tri_dv, v.grad = v.grad.clone(), None - # tri_dg, g.grad = g.grad.clone(), None - - # assert ref.allclose(tri, 0, 1e-5), breakpoint() - # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() - # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() - # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() - # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() - # breakpoint() - print("Pass") diff --git a/flash_linear_attention/fla/ops/torch/retention.py b/flash_linear_attention/fla/ops/torch/retention.py deleted file mode 100644 index 15611bf..0000000 --- a/flash_linear_attention/fla/ops/torch/retention.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch - - -def naive_retention(q, k, v): - orig_type = q.dtype - q, k, v = q.float(), k.float(), v.float() - _, n_heads, seq_len, d_head = q.shape - s = (1 - q.new_tensor(2., 
dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2() - n = q.new_tensor(range(seq_len), dtype=torch.float) - n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) - s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) - o = torch.einsum('bhqk,bhkd->bhqd', s, v) - return o.to(orig_type) diff --git a/flash_linear_attention/fla/ops/triton/__init__.py b/flash_linear_attention/fla/ops/triton/__init__.py deleted file mode 100644 index 74c6286..0000000 --- a/flash_linear_attention/fla/ops/triton/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -from .based import fused_chunk_based, parallel_based -from .rebased import parallel_rebased -from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla -from .retention import (chunk_retention, fused_chunk_retention, - fused_recurrent_retention, parallel_retention) -from .rotary import apply_rotary - -__all__ = [ - 'fused_chunk_based', - 'parallel_based', - 'parallel_rebased', - 'chunk_gla', - 'fused_chunk_gla', - 'fused_recurrent_gla', - 'chunk_retention', - 'fused_chunk_retention', - 'fused_recurrent_retention', - 'parallel_retention', - 'apply_rotary' -] diff --git a/flash_linear_attention/fla/ops/triton/abc/__init__.py b/flash_linear_attention/fla/ops/triton/abc/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py deleted file mode 100644 index ea6589f..0000000 --- a/flash_linear_attention/fla/ops/triton/abc/chunk_fuse.py +++ /dev/null @@ -1,692 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def chunk_abc_fwd_kernel_s( - q, - k, - s, - rk, # rescale term - ck, # scores normalized over a chunk - pk, # scores normalized over 
the sequence - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - scale, - BT: tl.constexpr, - BK: tl.constexpr, - BM: tl.constexpr, - DK: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_m, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) - p_s = tl.make_block_ptr(s + (i_k * n_bh + i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_pk = tl.make_block_ptr(pk + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - - b_hk = tl.zeros([BK, BM], dtype=tl.float32) - for _ in range(NT): - # [BT, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BM,] - b_rk = tl.load(p_rk, boundary_check=(0,)) - # [BT, BM] - b_ck = tl.load(p_ck, boundary_check=(0, 1)) - b_pk = tl.load(p_pk, boundary_check=(0, 1)) - - # [BT, BM] - b_inter = tl.dot(b_q, b_hk.to(b_q.dtype), allow_tf32=False) * b_rk[None, :] - b_intra = tl.dot(tl.where(m_s, tl.dot(b_q, b_k, allow_tf32=False), 0).to(b_q.dtype), b_ck, allow_tf32=False) - b_s = (b_inter + b_intra) * b_pk - # [BK, BM] - b_hk = b_hk * b_rk[None, :] + tl.dot(b_k, b_ck, allow_tf32=False) - - tl.store(p_s, b_s.to(p_s.dtype.element_ty), boundary_check=(0, 1)) - - p_q = tl.advance(p_q, (BT, 0)) - p_k = tl.advance(p_k, (0, BT)) - p_s = tl.advance(p_s, (BT, 0)) - p_rk = tl.advance(p_rk, (DM,)) - p_ck = tl.advance(p_ck, (BT, 0)) - p_pk = tl.advance(p_pk, 
(BT, 0)) - - -@triton.jit -def chunk_abc_fwd_kernel_o( - p, - v, - o, - rv, # rescale term - cv, # scores normalized over a chunk - pv, # scores normalized over the sequence - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BM: tl.constexpr, - BV: tl.constexpr, - DM: tl.constexpr, - DV: tl.constexpr, - NT: tl.constexpr -): - i_v, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_p = tl.make_block_ptr(p + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (i_m * n_bh + i_bh)*s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, 0), (BM, BT), (0, 1)) - p_pv = tl.make_block_ptr(pv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - - # [BM, BV] - b_hv = tl.zeros([BM, BV], dtype=tl.float32) - for _ in range(NT): - # [BT, BM] - b_p = tl.load(p_p, boundary_check=(0, 1)) - # [BT, DV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BM,] - b_rv = tl.load(p_rv, boundary_check=(0,)) - # [BM, BT] - b_cv = tl.load(p_cv, boundary_check=(0, 1)) - # [BT, BM] - b_pv = tl.load(p_pv, boundary_check=(0, 1)) - - b_p = b_p * b_pv - # [BT, BV] - b_inter = tl.dot((b_p * b_rv[None, :]).to(b_v.dtype), b_hv.to(b_v.dtype), allow_tf32=False) - b_intra = tl.where(m_s, tl.dot(b_p.to(b_v.dtype), b_cv, allow_tf32=False), 0) - b_intra = tl.dot(b_intra.to(b_v.dtype), b_v, allow_tf32=False) - b_o = b_inter + b_intra - # [BM, BV] - b_hv = b_hv * b_rv[:, None] + tl.dot(b_cv, b_v, allow_tf32=False) - - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 
1)) - - p_p = tl.advance(p_p, (BT, 0)) - p_v = tl.advance(p_v, (BT, 0)) - p_o = tl.advance(p_o, (BT, 0)) - p_rv = tl.advance(p_rv, (DM,)) - p_cv = tl.advance(p_cv, (0, BT)) - p_pv = tl.advance(p_pv, (BT, 0)) - - -@triton.jit -def chunk_abc_bwd_kernel_dp( - v, - rv, # rescale term - cv, # scores normalized over a chunk - pv, # scores normalized over the sequence - do, - dp, - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BV: tl.constexpr, - BM: tl.constexpr, - DV: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_m, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (DV, T), (s_qk_d, s_qk_t), (i_v * BV, 0), (BV, BT), (0, 1)) - p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_pv = tl.make_block_ptr(pv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_dp = tl.make_block_ptr(dp + (i_v * n_bh + i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - - # [BV, BM] - b_hv = tl.zeros([BV, BM], dtype=tl.float32) - for _ in range(NT): - # [BV, BT] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BM,] - b_rv = tl.load(p_rv, boundary_check=(0,)) - # [BT, BM] - b_cv = tl.load(p_cv, boundary_check=(0, 1)) - b_pv = tl.load(p_pv, boundary_check=(0, 1)) - # [BT, BV] - b_do = tl.load(p_do, boundary_check=(0, 1)) - - # [BT, BM] - b_inter = tl.dot(b_do, b_hv.to(b_do.dtype), allow_tf32=False) * b_rv[None, :] - b_intra = tl.dot(tl.where(m_s, tl.dot(b_do, b_v, allow_tf32=False), 0).to(b_v.dtype), b_cv, allow_tf32=False) - b_dp = (b_inter + b_intra) * b_pv - # [BV, 
BM] - b_hv = b_hv * b_rv[None, :] + tl.dot(b_v, b_cv, allow_tf32=False) - - tl.store(p_dp, b_dp.to(p_dp.dtype.element_ty), boundary_check=(0, 1)) - - p_v = tl.advance(p_v, (0, BT)) - p_rv = tl.advance(p_rv, (DM,)) - p_cv = tl.advance(p_cv, (BT, 0)) - p_pv = tl.advance(p_pv, (BT, 0)) - p_do = tl.advance(p_do, (BT, 0)) - p_dp = tl.advance(p_dp, (BT, 0)) - - -@triton.jit -def chunk_abc_bwd_kernel_dq( - k, - rk, # rescale term - ck, # scores normalized over a chunk - dq, - ds, - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BK: tl.constexpr, - BM: tl.constexpr, - DK: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_k, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, 0), (BM, BT), (0, 1)) - p_dq = tl.make_block_ptr(dq + (i_m * n_bh + i_bh)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_ds = tl.make_block_ptr(ds + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - - # [BM, BK] - b_hk = tl.zeros([BM, BK], dtype=tl.float32) - for _ in range(NT): - # [BT, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BM,] - b_rk = tl.load(p_rk, boundary_check=(0,)) - # [BM, BT] - b_ck = tl.load(p_ck, boundary_check=(0, 1)) - # [BT, BM] - b_ds = tl.load(p_ds, boundary_check=(0, 1)) - - # [BT, BK] - b_inter = tl.dot((b_ds * b_rk[None, :]).to(b_k.dtype), b_hk.to(b_k.dtype), allow_tf32=False) - b_intra = tl.dot(tl.where(m_s, tl.dot(b_ds, b_ck, allow_tf32=False), 0).to(b_k.dtype), b_k, allow_tf32=False) - b_dq = b_inter + b_intra - # [BM, BK] - b_hk = b_hk * b_rk[:, None] + 
tl.dot(b_ck, b_k, allow_tf32=False) - - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - - p_k = tl.advance(p_k, (BT, 0)) - p_rk = tl.advance(p_rk, (DM,)) - p_ck = tl.advance(p_ck, (0, BT)) - p_dq = tl.advance(p_dq, (BT, 0)) - p_ds = tl.advance(p_ds, (BT, 0)) - - -@triton.jit -def chunk_abc_bwd_kernel_dk( - q, - k, - rk, # rescale term - ck, # scores normalized over a chunk - ds, - dk, - dsk, - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BK: tl.constexpr, - BM: tl.constexpr, - DK: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_k, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), ((NT-1)*BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, (NT-1)*BT), (BK, BT), (0, 1)) - p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_ck = tl.make_block_ptr(ck + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - p_ds = tl.make_block_ptr(ds + i_bh * s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, (NT-1)*BT), (BM, BT), (0, 1)) - p_dk = tl.make_block_ptr(dk + (i_m*n_bh+i_bh)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), ((NT-1)*BT, i_k * BK), (BT, BK), (1, 0)) - p_dsk = tl.make_block_ptr(dsk + (i_k*n_bh+i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s, m_t = o_i[:, None] <= o_i[None, :], o_i[:, None] >= o_i[None, :] - - # [BM, BK] - b_dhk = tl.zeros([BM, BK], dtype=tl.float32) - for i in range(NT): - p_rk = tl.make_block_ptr(rk + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) - # [BT, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BM,] - b_rk = tl.load(p_rk, 
boundary_check=(0,)) - # [BT, BM] - b_ck = tl.load(p_ck, boundary_check=(0, 1)) - b_ds = tl.load(p_ds, boundary_check=(0, 1)) - - # [BT, BK] - b_inter = tl.dot((b_ck * b_rk[None, :]).to(b_q.dtype), b_dhk.to(b_q.dtype), allow_tf32=False) - b_intra = tl.dot(tl.where(m_s, tl.dot(b_ck, b_ds, allow_tf32=False), 0.).to(b_q.dtype), b_q, allow_tf32=False) - b_dk = b_inter + b_intra - - # [BM, BT] - b_inter = tl.dot(b_dhk.to(b_k.dtype), b_k, allow_tf32=False) * b_rk[:, None] - b_intra = tl.dot(b_ds, tl.where(m_t, tl.dot(b_q, b_k, allow_tf32=False), 0.).to(b_q.dtype), allow_tf32=False) - # [BT, BM] - b_dsk = b_ck * tl.trans(b_inter + b_intra) - - # [BM, BK] - b_dhk = b_dhk * b_rk[:, None] + tl.dot(b_ds, b_q, allow_tf32=False) - - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dsk, b_dsk.to(p_dsk.dtype.element_ty), boundary_check=(0, 1)) - - p_q = tl.advance(p_q, (-BT, 0)) - p_k = tl.advance(p_k, (0, -BT)) - p_ck = tl.advance(p_ck, (-BT, 0)) - p_ds = tl.advance(p_ds, (0, -BT)) - p_dk = tl.advance(p_dk, (-BT, 0)) - p_dsk = tl.advance(p_dsk, (-BT, 0)) - - -@triton.jit -def chunk_abc_bwd_kernel_dv( - do, - v, - rv, # rescale term - cv, # scores normalized over a chunk - p, - dv, - dsv, - s_qk_h, - s_qk_t, - s_qk_d, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BV: tl.constexpr, - BM: tl.constexpr, - DV: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_v, i_m, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - n_bh = tl.num_programs(2) - - p_do = tl.make_block_ptr(do + i_bh * s_qk_h, (T, DV), (s_qk_t, s_qk_d), ((NT-1)*BT, i_v * BV), (BT, BV), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_qk_h, (DV, T), (s_qk_d, s_qk_t), (i_v * BV, (NT-1)*BT), (BV, BT), (0, 1)) - p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_cv = tl.make_block_ptr(cv + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - p_p = tl.make_block_ptr(p + i_bh * 
s_sk_h, (DM, T), (s_sk_m, s_sk_t), (i_m * BM, (NT-1)*BT), (BM, BT), (0, 1)) - p_dv = tl.make_block_ptr(dv + (i_m*n_bh+i_bh)*s_qk_h, (T, DV), (s_qk_t, s_qk_d), ((NT-1)*BT, i_v * BV), (BT, BV), (1, 0)) - p_dsv = tl.make_block_ptr(dsv + (i_v*n_bh+i_bh)*s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - m_s, m_t = o_i[:, None] <= o_i[None, :], o_i[:, None] >= o_i[None, :] - - # [BM, BV] - b_dhv = tl.zeros([BM, BV], dtype=tl.float32) - for i in range(NT): - p_rv = tl.make_block_ptr(rv + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) - # [BT, BV] - b_do = tl.load(p_do, boundary_check=(0, 1)) - # [BV, BT] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BM,] - b_rv = tl.load(p_rv, boundary_check=(0,)) - # [BT, BM] - b_cv = tl.load(p_cv, boundary_check=(0, 1)) - # [BM, BT] - b_p = tl.load(p_p, boundary_check=(0, 1)) - - # [BT, BV] - b_inter = tl.dot((b_cv * b_rv[None, :]).to(b_do.dtype), b_dhv.to(b_do.dtype), allow_tf32=False) - b_intra = tl.dot(tl.where(m_s, tl.dot(b_cv, b_p, allow_tf32=False), 0.).to(b_do.dtype), b_do, allow_tf32=False) - b_dv = b_inter + b_intra - - b_inter = tl.dot(b_dhv.to(b_v.dtype), b_v, allow_tf32=False) * b_rv[:, None] - b_intra = tl.dot(b_p, tl.where(m_t, tl.dot(b_do, b_v, allow_tf32=False), 0.).to(b_do.dtype), allow_tf32=False) - # [BT, BM] - b_dsv = b_cv * tl.trans(b_inter + b_intra) - - # [BM, BV] - b_dhv = b_dhv * b_rv[:, None] + tl.dot(b_p, b_do, allow_tf32=False) - - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dsv, b_dsv.to(p_dsv.dtype.element_ty), boundary_check=(0, 1)) - - p_do = tl.advance(p_do, (-BT, 0)) - p_v = tl.advance(p_v, (0, -BT)) - p_cv = tl.advance(p_cv, (-BT, 0)) - p_p = tl.advance(p_p, (0, -BT)) - p_dv = tl.advance(p_dv, (-BT, 0)) - p_dsv = tl.advance(p_dsv, (-BT, 0)) - - -@triton.jit -def chunk_abc_fwd_kernel_cum( - s, - r, - c, - p, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: 
tl.constexpr, - BM: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_m, i_bh = tl.program_id(0), tl.program_id(1) - p_s = tl.make_block_ptr(s + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_r = tl.make_block_ptr(r + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (i_m * BM,), (BM,), (0,)) - p_c = tl.make_block_ptr(c + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - p_p = tl.make_block_ptr(p + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), (0, i_m * BM), (BT, BM), (1, 0)) - - b_mp = tl.zeros([BM,], dtype=tl.float32) - b_zp = tl.zeros([BM,], dtype=tl.float32) - for i in range(NT): - # [BT, BM] - b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) - - b_m = tl.max(b_s, 0) - # workaround for some weird compiler bugs - if i == 0: - b_r = tl.exp(-b_m) - else: - b_m = tl.maximum(b_mp, b_m) - b_r = tl.exp(b_mp - b_m) - b_c = tl.exp(b_s - b_m[None, :]) - b_z = tl.cumsum(b_c, 0) + (b_zp * b_r)[None, :] - b_p = tl.exp(-tl.log(b_z)) - b_mp = b_m - b_zp = tl.max(b_z, 0) - - tl.store(p_r, b_r.to(p_r.dtype.element_ty), boundary_check=(0,)) - tl.store(p_c, b_c.to(p_c.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1)) - - p_s = tl.advance(p_s, (BT, 0)) - p_r = tl.advance(p_r, (DM,)) - p_c = tl.advance(p_c, (BT, 0)) - p_p = tl.advance(p_p, (BT, 0)) - - -@triton.jit -def chunk_abc_bwd_kernel_rcum( - s, - r, - c, - o, - s_sk_h, - s_sk_t, - s_sk_m, - T, - BT: tl.constexpr, - BM: tl.constexpr, - DM: tl.constexpr, - NT: tl.constexpr -): - i_m, i_bh = tl.program_id(0), tl.program_id(1) - p_s = tl.make_block_ptr(s + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - p_c = tl.make_block_ptr(c + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - p_o = tl.make_block_ptr(o + i_bh * s_sk_h, (T, DM), (s_sk_t, s_sk_m), ((NT-1)*BT, i_m * BM), (BT, BM), (1, 0)) - - o_i = tl.arange(0, BT) - # [BT, BT] - 
m_t = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) - - b_z = tl.zeros([BM,], dtype=tl.float32) - for i in range(NT): - p_r = tl.make_block_ptr(r + i_bh * s_sk_t * NT, (NT * DM,), (s_sk_m,), (((NT-i) % NT) * DM + i_m * BM,), (BM,), (0,)) - # [BT, BM] - b_s = tl.load(p_s, boundary_check=(0, 1)) - # [BM,] - b_r = tl.load(p_r, boundary_check=(0,)) - # [BT, BM] - b_c = tl.load(p_c, boundary_check=(0, 1)) - b_o = tl.load(p_o, boundary_check=(0, 1)) - - b_z = b_z * b_r - b_o -= b_c * (b_z[None, :] + tl.dot(m_t.to(b_s.dtype), b_s, allow_tf32=False)) - - # [BM,] - b_z += tl.sum(b_s, 0) - - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - p_s = tl.advance(p_s, (-BT, 0)) - p_c = tl.advance(p_c, (-BT, 0)) - p_o = tl.advance(p_o, (-BT, 0)) - - -class FusedChunkABCFunction(torch.autograd.Function): - - @staticmethod - @contiguous - def forward(ctx, q, k, v, sk, sv): - batch_size, n_heads, seq_len, d_head_qk, d_head_v, n_slots = *q.shape, v.shape[-1], sk.shape[-1] - scale = d_head_qk ** -0.5 - - DT, DK, DV, DM = seq_len, d_head_qk, d_head_v, n_slots - BT = 16 - if batch_size * n_heads > 100: - BK, BV, BM = min(DK, 64), min(DV, 64), min(DM, 64) - num_stages = 1 - num_warps = 2 - else: - # SM is not fully utilized so we add more parallelism in the hidden state dimension. 
- BK, BV, BM = min(DK, 32), min(DV, 32), min(DM, 32) - num_stages = 1 - num_warps = 1 - NT, NK, NV, NM = triton.cdiv(DT, BT), triton.cdiv(DK, BK), triton.cdiv(DV, BV), triton.cdiv(DM, BM) - - rk, ck, pk = sk.new_empty(batch_size, n_heads, NT, DM), torch.empty_like(sk), torch.empty_like(sk) - grid = (NM, batch_size * n_heads) - chunk_abc_fwd_kernel_cum[grid]( - sk, rk, ck, pk, - sk.stride(1), sk.stride(2), sk.stride(3), - seq_len, - BT=BT, BM=BM, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - rv, cv, pv = sv.new_empty(batch_size, n_heads, NT, DM), torch.empty_like(sv), torch.empty_like(sv) - chunk_abc_fwd_kernel_cum[grid]( - sv, rv, cv, pv, - sv.stride(1), sv.stride(2), sv.stride(3), - seq_len, - BT=BT, BM=BM, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - - s = q.new_empty(NK, batch_size, n_heads, seq_len, n_slots) - grid = (NM, NK, batch_size * n_heads) - chunk_abc_fwd_kernel_s[grid]( - q, k, s, rk, ck, pk, - q.stride(1), q.stride(2), q.stride(3), - sk.stride(1), sk.stride(2), sk.stride(3), - seq_len, scale, - BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - s = s.sum(0) - p = s.softmax(-1, dtype=torch.float).to(q.dtype) - o = q.new_empty(NM, batch_size, n_heads, seq_len, d_head_v) - grid = (NV, NM, batch_size * n_heads) - chunk_abc_fwd_kernel_o[grid]( - p, v, o, rv, cv, pv, - q.stride(1), q.stride(2), q.stride(3), - sk.stride(1), sk.stride(2), sk.stride(3), - seq_len, - BT=BT, BM=BM, BV=BV, DM=DM, DV=DV, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - o = o.sum(0) - ctx.save_for_backward(q, k, v, o, s, p, rk, ck, pk, rv, cv, pv) - ctx.batch_size = batch_size - ctx.n_heads = n_heads - ctx.seq_len = seq_len - ctx.n_slots = n_slots - ctx.dtype = q.dtype - ctx.scale = scale - ctx.BT = BT - return o - - @staticmethod - @contiguous - def backward(ctx, do): - q, k, v, o, s, p, rk, ck, pk, rv, cv, pv = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk, d_head_v, 
n_slots = *q.shape, v.shape[-1], s.shape[-1] - scale = d_head_qk ** -0.5 - - DT, DK, DV, DM = seq_len, d_head_qk, d_head_v, n_slots - BT = ctx.BT - if batch_size * n_heads > 100: - BK, BV, BM = min(DK, 64), min(DV, 64), min(DM, 64) - num_stages = 1 - num_warps = 2 - else: - BK, BV, BM = min(DK, 32), min(DV, 32), min(DM, 32) - num_stages = 1 - num_warps = 2 - NT, NK, NV, NM = triton.cdiv(DT, BT), triton.cdiv(DK, BK), triton.cdiv(DV, BV), triton.cdiv(DM, BM) - dp = s.new_empty(NV, *s.shape) - grid = (NM, NV, batch_size * n_heads) - chunk_abc_bwd_kernel_dp[grid]( - v, rv, cv, pv, do, dp, - q.stride(1), q.stride(2), q.stride(3), - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BV=BV, BM=BM, DV=DV, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - dp = dp.sum(0) - ds = p * (dp - (o * do).sum(-1, True)) * pk - dss = ds * scale - dq, dk, dv = q.new_empty(NM, *q.shape), k.new_empty(NM, *k.shape), v.new_empty(NM, *v.shape) - dsk, dsv = s.new_empty(NK, *s.shape), s.new_empty(NV, *s.shape) - grid = (NK, NM, batch_size * n_heads) - chunk_abc_bwd_kernel_dq[grid]( - k, rk, ck, dq, dss, - q.stride(1), q.stride(2), q.stride(3), - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - dq = dq.sum(0) - chunk_abc_bwd_kernel_dk[grid]( - q, k, rk, ck, dss, dk, dsk, - q.stride(1), q.stride(2), q.stride(3), - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BK=BK, BM=BM, DK=DK, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - dk, dsk = dk.sum(0), dsk.sum(0) - - p = p * pv - grid = (NV, NM, batch_size * n_heads) - chunk_abc_bwd_kernel_dv[grid]( - do, v, rv, cv, p, dv, dsv, - q.stride(1), q.stride(2), q.stride(3), - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BV=BV, BM=BM, DV=DV, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - dv, dsv = dv.sum(0), dsv.sum(0) - grid = (NM, batch_size * n_heads) - 
chunk_abc_bwd_kernel_rcum[grid]( - ds * s, rk, ck, dsk, - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BM=BM, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - chunk_abc_bwd_kernel_rcum[grid]( - p * dp, rv, cv, dsv, - s.stride(1), s.stride(2), s.stride(3), - seq_len, - BT=BT, BM=BM, DM=DM, NT=NT, - num_warps=num_warps, - num_stages=num_stages - ) - return dq, dk, dv, dsk, dsv - - -fused_chunk_abc = FusedChunkABCFunction.apply diff --git a/flash_linear_attention/fla/ops/triton/based/__init__.py b/flash_linear_attention/fla/ops/triton/based/__init__.py deleted file mode 100644 index a18fb47..0000000 --- a/flash_linear_attention/fla/ops/triton/based/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- - -from .chunk_fuse import fused_chunk_based -from .parallel import parallel_based - -__all__ = ["parallel_based", "fused_chunk_based"] diff --git a/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py deleted file mode 100644 index ce1627f..0000000 --- a/flash_linear_attention/fla/ops/triton/based/chunk_fuse.py +++ /dev/null @@ -1,410 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - -# on-the-fly computation without materializing hidden statets into HBMs - - -@triton.jit -def fused_chunk_based_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - z, # normalizer [B, H, L, 1] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BT: tl.constexpr, # BLOCK SIZE 
along the sequence dimension, a.k.a. chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - # indices - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - - o_i = tl.arange(0, BT) - - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - - # [BV], zero-order taylor expansion - b_h_0o = tl.zeros([BV], dtype=tl.float32) - # [BK, BV], first-order taylor expansion - b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) - # [BK, BK, BV] second-order taylor expansion - b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) - - # make block pointers - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - - p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) - k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) - k_1o = tl.zeros([1, BK], dtype=tl.float32) - k_0o = 0 - - for i in range(0, tl.cdiv(T, BT)): - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BK*BK, BT] - b_k_2o = b_k[:, None, :] * b_k[None, :, :] - b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) - # [BT, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BT, BK] - b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) - b_o = tl.zeros([BT, BV], dtype=tl.float32) - b_z = tl.zeros([BT], dtype=tl.float32) - - # interchunk - # zero-order - b_o += b_h_0o - b_z += k_0o - # first-order - b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) - b_z += tl.sum(b_q * k_1o, axis=1) - # second-order - b_q_2o = b_q[:, :, None] * b_q[:, None, :] - b_q_2o = tl.reshape(b_q_2o, 
[BT, BK * BK]).to(b_k.dtype) - b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 - b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 - - # update running statistics - k_1o += tl.sum(b_k, axis=1)[None, :] - k_2o += tl.sum(b_k_2o, axis=1)[None, :] - k_0o += BT - - # intrachunk - # [BT, BT] - b_s = tl.dot(b_q, b_k, allow_tf32=False) - b_s = 1 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_z += tl.sum(b_s, axis=1) - b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - # [TB, BV] - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_z, b_z.to(p_z.dtype.element_ty), - mask=(i * BT + tl.arange(0, BT)) < T) - - # update hidden state - # [BK, BV] - b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) - b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) - b_h_0o = b_h_0o + tl.sum(b_v, axis=0) - - p_q = tl.advance(p_q, (BT, 0)) - p_k = tl.advance(p_k, (0, BT)) - p_v = tl.advance(p_v, (BT, 0)) - p_o = tl.advance(p_o, (BT, 0)) - p_z += BT - - -# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 -@triton.jit -def fused_chunk_based_bwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - # NV: number of split in the V dimension. NK: number of split in the K dimension - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - do, # gradient of output [B, H, L, D_head_V] - dz, # gradient of normalizer [B, H, L] - dq, # gradient of query [NV, B, H, L, D_head_K] - dk, # gradient of key [NV, B, H, L, D_head_K] - dv, # gradient of value [NK, B, H, L, D_head_V] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch_size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - - o_i = tl.arange(0, BT) - m_s = o_i[:, None] >= o_i[None, :] - - # [BV], zero-order taylor expansion - # b_h_0o = tl.zeros([BV], dtype=tl.float32) - # [BK, BV], first-order taylor expansion - b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) - # [BK, BK, BV] second-order taylor expansion - b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) - - k_1o = tl.zeros([1, BK], dtype=tl.float32) - k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) - - for i in range(0, tl.cdiv(T, BT)): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr( - k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - p_v = tl.make_block_ptr( - v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, - (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) - p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT - b_dq = tl.zeros([BT, BK], dtype=tl.float32) - - # load tensors - # [BT, BK] - b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BT] - b_v = tl.load(p_v, boundary_check=(0, 1)) - - # inter-chunk - b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) - if i_v == 0: - b_dq += b_dz[:, None] * k_1o - b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 - if i_v == 0: - b_dq_2o += (b_dz[:, None] * k_2o) * 
0.5 - b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK]) - b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) - b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) - b_dq *= scale - - # intra-chunk - # [BT, BT] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - b_ds = tl.where(m_s, b_ds, 0) * scale - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - b_s = tl.where(m_s, b_s, 0) - b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) - - # store - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - - # update hidden state - # [BT, BK*BK] - b_k_2o = b_k[:, :, None] * b_k[:, None, :] - b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) - # [BV, BK*BK] - b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) - # [BV, BK] - b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) - - if i_v == 0: - # update running statistics - k_1o += tl.sum(b_k, axis=0)[None, :] - k_2o += tl.sum(b_k_2o, axis=0)[None, :] - - tl.debug_barrier() - b_h_1o = None - b_h_2o = None - - # [BK, BV], first-order taylor expansion - b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) - # [BK, BK, BV] second-order taylor expansion - b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) - b_dh_0o = tl.zeros([BV], dtype=tl.float32) - m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] - - dq_1o = tl.zeros([1, BK], dtype=tl.float32) - dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) - - for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1)) - p_k = tl.make_block_ptr( - k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0)) - p_v = tl.make_block_ptr( - v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, 
(T, DK), - (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0)) - p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i - - b_dk = tl.zeros([BT, BK], dtype=tl.float32) - b_dv = tl.zeros([BT, BV], dtype=tl.float32) - - b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_q = (b_q * scale).to(b_k.dtype) - - # intra chunk - b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) - if i_v == 0: - b_ds += b_dz[None, :] - b_ds = tl.where(m_s, b_ds, 0) - b_s = tl.dot(b_k, b_q, allow_tf32=False) - b_s2 = 1 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_s2 = tl.where(m_s, b_s2, 0) - b_ds *= (1+b_s) - - b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) - b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) - - # inter chunk - b_k_2o = b_k[:, :, None] * b_k[:, None, :] - b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) - - b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) - b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) - b_dv += b_dh_0o - - b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) - - if i_v == 0: - b_dk += dq_1o - - b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), - tl.trans(b_v), allow_tf32=False) - if i_v == 0: - b_dk_2o += dq_2o - b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) - b_k_fp32 = tl.trans(b_k.to(tl.float32)) - b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) - b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) - b_dk += tl.trans(b_dk2) - - # hidden state update - b_dh_0o += tl.sum(b_do, axis=0) - b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) - b_q_2o = b_q[None, :, :] * b_q[:, None, :] - b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) - b_dh_2o = 
b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 - - if i_v == 0: - dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] - dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] - - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - - -class FusedChunkBasedFunction(torch.autograd.Function): - - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, scale=1): - batch_size, n_heads, seq_len, d_head_qk = q.shape - # assert d_head_qk == 16, "currently we do not support feature dim other than 16" - d_head_v = v.shape[-1] - - scale = scale - BT = 16 - BK, BV = min(d_head_qk, 16), min(d_head_v, 32) - BK, BV = max(BK, 16), max(BV, 16) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - - num_warps = 4 - - # the norm of o might explode, so we need to use float32 here - o = q.new_empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=torch.float32) - z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32) - - grid = (NV, NK, batch_size * n_heads) - fused_chunk_based_fwd_kernel[grid]( - q, k, v, o, z, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - num_warps=num_warps, - ) - o = o.sum(0) - z = z.sum(0) - ctx.save_for_backward(q, k, v) - ctx.scale = scale - return o.to(q.dtype), z.to(z.dtype) - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, do, dz): - q, k, v = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - scale = ctx.scale - - BT = 16 - BK, BV = min(d_head_qk, 16), min(d_head_v, 32) - BK, BV = max(BK, 16), max(BV, 16) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 4 - - dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dk = q.new_empty(NV, batch_size, n_heads, 
seq_len, d_head_qk) - dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - grid = (NV, NK, batch_size * n_heads) - - fused_chunk_based_bwd_kernel[grid]( - q, k, v, do, dz, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - num_warps=num_warps, - num_stages=num_stages - ) - dq = dq.sum(0) - dk = dk.sum(0) - dv = dv.sum(0) - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None - - -triton_fused_chunk_based = FusedChunkBasedFunction.apply - - -def fused_chunk_based(q, k, v, eps: float = 1e-6, use_scale: bool = True, use_normalize: bool = True): - assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' - if use_scale: - scale = q.shape[-1] ** -0.5 - else: - scale = 1 - o, z = triton_fused_chunk_based(q, k, v, scale) - if use_normalize: - o = o / (z[..., None] + eps) - else: - o = o - - return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/based/parallel.py b/flash_linear_attention/fla/ops/triton/based/parallel.py deleted file mode 100644 index fcff0ad..0000000 --- a/flash_linear_attention/fla/ops/triton/based/parallel.py +++ /dev/null @@ -1,385 +0,0 @@ - -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous -from torch.cuda.amp import custom_bwd, custom_fwd - - -@triton.jit -def parallel_based_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - z, # normalizer [B, H, L] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BTL: tl.constexpr, # BLOCK 
SIZE along the sequence dimension for Q - BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - # i_c: chunk index. used for sequence parallelism - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) - - # [BQ, BD] block Q, in the shared memory throughout the whole kernel - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - b_o = tl.zeros([BTL, BV], dtype=tl.float32) - b_z = tl.zeros([BTL], dtype=tl.float32) - - # Q block and K block have no overlap - # no need for mask, thereby saving flops - for _ in range(0, i_c * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_s = tl.dot(b_q, (b_k), allow_tf32=False) - b_s = 1 + b_s + 0.5 * b_s * b_s - b_z += tl.sum(b_s, axis=1) - - # [BQ, BD] - b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - - # # rescale interchunk output - tl.debug_barrier() - o_q = tl.arange(0, BTL) - # # sync threads, easy for compiler to optimize - # tl.debug_barrier() - - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) - # Q block and K block 
have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_s = tl.dot(b_q, b_k, allow_tf32=False) - b_s = 1 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_z += tl.sum(b_s, axis=1) - # [BTL, BV] - b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - o_k += BTS - - p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_z, b_z.to(p_z.dtype.element_ty), - mask=((i_c * BTL + tl.arange(0, BTL)) < T)) - - -@triton.jit -def _parallel_based_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_q = (b_q * scale).to(b_q.dtype) - b_dq = tl.zeros([BTL, BK], dtype=tl.float32) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) - b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) - - for _ in range(0, i_c * BTL, BTS): - # [BTS, BK] - b_k = 
tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - # [BQ, BD] - b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - - b_dq *= scale - o_q = tl.arange(0, BTL) - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) - # Q block and K block have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - b_s = tl.where(m_s, b_s, 0) - # [BTL, BK] - b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), - b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - o_k += BTS - p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def _parallel_based_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - # compute dk dv - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), - (i_c * BTL, 
i_k * BK), (BTL, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( - p_v, boundary_check=(0, 1)) - b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( - [BTL, BV], dtype=tl.float32) - - for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ - scale # [BTL, BTS] - b_s2 = 1 + b_s + 0.5 * b_s * b_s - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale - if i_v == 0: - b_ds += b_dz[None, :] * scale - else: - b_ds = b_ds - b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - - tl.debug_barrier() - o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) - for i in range(i_c*BTL, (i_c+1)*BTL, BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - # [BK, BQ] - m_s = o_k[:, None] <= o_q[None, :] - b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale - b_s2 = 1 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_s2 = 
tl.where(m_s, b_s2, 0) - - b_ds = tl.dot(b_v, b_do, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[None, :] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - # [BK, BD] - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - o_q += BTS - - p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, - (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, - (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def parallel_based_bwd_kernel( - q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - i_h = i_bh % H - _parallel_based_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV - ) - tl.debug_barrier() - _parallel_based_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV - ) - - -class ParallelBasedFunction(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, scale): - BTL, BTS = 128, 32 - assert BTL % BTS == 0 - # assert q.shape[-1] % 16 == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v 
= v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not." - - o = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, device=q.device) - z = torch.empty(NK, batch_size, n_heads, seq_len, - device=q.device) - parallel_based_fwd_kernel[grid]( - q, k, v, o, z, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - ctx.save_for_backward(q, k, v) - ctx.scale = scale - return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do, dz): - q, k, v = ctx.saved_tensors - scale = ctx.scale - BTL, BTS = 64, 32 - assert BTL % BTS == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not" - - dq = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dk = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dv = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=q.dtype, device=q.device) - - parallel_based_bwd_kernel[grid]( - q, k, v, do, dz, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - 
num_stages=num_stages - ) - - return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None - - -triton_parallel_based = ParallelBasedFunction.apply - - -def parallel_based(q, k, v, eps: float = 1e-6, use_scale: bool = True, use_normalize: bool = True, return_both: bool = False): - assert q.shape[-1] <= 128, "only support feature dim up to 128" - if use_scale: - scale = q.shape[-1] ** -0.5 - else: - scale = 1 - o, z = triton_parallel_based(q, k, v, scale) - if return_both: - return o, z - if use_normalize: - o = o / (z[..., None] + eps) - else: - o = o - return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/__init__.py b/flash_linear_attention/fla/ops/triton/gla/__init__.py deleted file mode 100644 index 7e41f8c..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -from .chunk import chunk_gla -from .chunk_fuse import fused_chunk_gla -from .recurrent_fuse import fused_recurrent_gla - -__all__ = ['chunk_gla', 'fused_recurrent_gla', 'fused_chunk_gla'] diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py deleted file mode 100644 index c27f6ba..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_full.py +++ /dev/null @@ -1,212 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch 
-import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_recurrence( - S, - p1, - p2, - O, - NUM_BLOCK, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ - BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] - - O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V - - p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K - - p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V - - acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - acc += tl.load(S) - - S += D_MODEL_K * D_MODEL_V - - tl.store(O, acc.to(O.dtype.element_ty)) - O += D_MODEL_K * D_MODEL_V - - for i in range(NUM_BLOCK-2): - p_k = tl.load(p1) - p_v = tl.load(p2) - S_i = tl.load(S) - acc = acc * p_k[:, None] * p_v[None, :] + S_i - tl.store(O, acc.to(O.dtype.element_ty)) - p1 += D_MODEL_K - p2 += D_MODEL_V - S += D_MODEL_K * D_MODEL_V - O += D_MODEL_K * D_MODEL_V - - -# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL -@triton.jit -def _bwd_recurrence( - S, - p1, - p2, - DS, - Dp1, - Dp2, - NUM_BLOCK, - NUM_SPLIT_K, - NUM_SPLIT_V, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr - -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - # skip the last chunk because it is never used - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V - - # start from the last chunk - DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V - - # skip the last chunk because it is never used - p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * \ - BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K - - # skip the last chunk because it is never used - p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * \ - BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V - - # skip the last chunk because it is never used - # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh - # offset_s * D_MODEL_K: find the right split in the K dimension - Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + \ - (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V - - # skip the last chunk because it is never used - Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + \ - (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K - - Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) 
- - # ignore the first chunk - for i in range(NUM_BLOCK - 1): - p_key = tl.load(p1) - p_value = tl.load(p2) - S_i = tl.load(S) - DS_i = tl.load(DS) - Dacc += DS_i - dp_i = Dacc * S_i - dp_key = tl.sum(dp_i * p_value[None, :], axis=1) - tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) - dp_value = tl.sum(dp_i * p_key[:, None], axis=0) - tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) - - tl.store(S, Dacc.to(S.dtype.element_ty)) - - Dacc *= p_key[:, None] - Dacc *= p_value[None, :] - - S -= D_MODEL_K * D_MODEL_V - DS -= D_MODEL_K * D_MODEL_V - p1 -= D_MODEL_K - p2 -= D_MODEL_V - Dp1 -= D_MODEL_K * NUM_SPLIT_V - Dp2 -= D_MODEL_V * NUM_SPLIT_K - - -class Chunk_memory_update_full(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, decay_key_last, decay_value_last, to_add): - B, H, N, D_k, D_v = to_add.shape - output = torch.empty_like(to_add) - BLOCK_MODEL = 32 - - assert D_k % 32 == 0 - assert D_v % 32 == 0 - assert D_k == decay_key_last.shape[-1] - assert D_v == decay_value_last.shape[-1] - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - ctx.grid = grid - ctx.BLOCK_MODEL = BLOCK_MODEL - - _fwd_recurrence[grid]( - to_add, - decay_key_last, - decay_value_last, - output, - D_MODEL_K=D_k, D_MODEL_V=D_v, - NUM_BLOCK=N, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, 0] = 0 - ctx.save_for_backward(output, decay_key_last, decay_value_last) - - return output - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, DO): - - output, decay_key_last, decay_value_last = ctx.saved_tensors - - B, H, N, D_k, D_v = output.shape - - num_block = N - - BLOCK_MODEL = 32 - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - - # I don't want atomic_add to be used in the backward pass - # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) - # afterward, I sum over this dimension to get the correct gradient - D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, - device=DO.device, dtype=torch.float32) - D_p2 = 
torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, - device=DO.device, dtype=torch.float32) - - _bwd_recurrence[grid]( - output, decay_key_last, decay_value_last, - DO, D_p1, D_p2, - NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, - D_MODEL_K=D_k, - D_MODEL_V=D_v, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, -1] = 0 - D_p1[:, :, 0] = 0 - D_p1[:, :, -1] = 0 - D_p2[:, :, 0] = 0 - D_p2[:, :, -1] = 0 - - return D_p1.sum(-2).to(decay_key_last.dtype), D_p2.sum(-2).to(decay_key_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py deleted file mode 100644 index e83b54e..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_no_decay.py +++ /dev/null @@ -1,166 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_recurrence( - S, - O, - NUM_BLOCK, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ - BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] - - O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V - - acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - acc += tl.load(S) - - S += D_MODEL_K * D_MODEL_V - - tl.store(O, 
acc.to(O.dtype.element_ty)) - O += D_MODEL_K * D_MODEL_V - - for i in range(NUM_BLOCK-2): - S_i = tl.load(S) - acc = acc + S_i - tl.store(O, acc.to(O.dtype.element_ty)) - S += D_MODEL_K * D_MODEL_V - O += D_MODEL_K * D_MODEL_V - - -# NUM_SPLIT_K/V. K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL -@triton.jit -def _bwd_recurrence( - S, - DS, - NUM_BLOCK, - NUM_SPLIT_K, - NUM_SPLIT_V, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - # skip the last chunk because it is never used - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V - - # start from the last chunk - DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V - - # skip the last chunk because it is never used - - # skip the last chunk because it is never used - # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh - # offset_s * D_MODEL_K: find the right split in the K dimension - Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - - # ignore the first chunk - for i in range(NUM_BLOCK - 1): - # S_i = tl.load(S) - DS_i = tl.load(DS) - Dacc += DS_i - # dp_i = Dacc * S_i - - # dp_key = tl.sum(dp_i * p_value[None, :], axis=1) - # tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) - # dp_value = tl.sum(dp_i * p_key[:, None], axis=0) - # tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) - - tl.store(S, Dacc.to(S.dtype.element_ty)) - - # Dacc *= p_key[:, None] - # Dacc *= p_value[None, :] - - S -= D_MODEL_K * D_MODEL_V - DS -= D_MODEL_K * D_MODEL_V - - -class 
Chunk_memory_update_no_decay(torch.autograd.Function): - @staticmethod - @custom_fwd - @contiguous - def forward(ctx, to_add): - B, H, N, D_k, D_v = to_add.shape - output = torch.empty_like(to_add) - BLOCK_MODEL = 32 - - assert D_k % 32 == 0 - assert D_v % 32 == 0 - # assert D_k == decay_key_last.shape[-1] - # assert D_v == decay_value_last.shape[-1] - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - ctx.grid = grid - ctx.BLOCK_MODEL = BLOCK_MODEL - - _fwd_recurrence[grid]( - to_add, - # decay_key_last, - # decay_value_last, - output, - D_MODEL_K=D_k, D_MODEL_V=D_v, - NUM_BLOCK=N, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, 0] = 0 - ctx.save_for_backward(output) - - return output.to(to_add.dtype) - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, DO): - output, = ctx.saved_tensors - - B, H, N, D_k, D_v = output.shape - - num_block = N - - BLOCK_MODEL = 32 - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - - # I don't want atomic_add to be used in the backward pass - # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) - # afterward, I sum over this dimension to get the correct gradient - # D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32) - # D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32) - - _bwd_recurrence[grid]( - output, - DO, - NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, - D_MODEL_K=D_k, - D_MODEL_V=D_v, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, -1] = 0 - - return output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py deleted file mode 100644 index 59cb608..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gk.py +++ /dev/null @@ 
-1,187 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_recurrence( - S, - p1, - O, - NUM_BLOCK, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ - BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] - - O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V - - p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K - - acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - acc += tl.load(S) - - S += D_MODEL_K * D_MODEL_V - - tl.store(O, acc.to(O.dtype.element_ty)) - O += D_MODEL_K * D_MODEL_V - - for i in range(NUM_BLOCK-2): - p_k = tl.load(p1) - S_i = tl.load(S) - acc = acc * p_k[:, None] + S_i - tl.store(O, acc.to(O.dtype.element_ty)) - p1 += D_MODEL_K - S += D_MODEL_K * D_MODEL_V - O += D_MODEL_K * D_MODEL_V - - -# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL -@triton.jit -def _bwd_recurrence( - S, p1, - DS, Dp1, - NUM_BLOCK, - NUM_SPLIT_K, - NUM_SPLIT_V, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - # skip the last chunk because it is never used - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V - - # start from the last chunk - DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V - - # skip the last chunk because it is never used - p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * \ - BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K - - # skip the last chunk because it is never used - # p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V - - # skip the last chunk because it is never used - # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh - # offset_s * D_MODEL_K: find the right split in the K dimension - Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + \ - tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + \ - (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V - - # skip the last chunk because it is never used - # Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K - - Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - - # ignore the first chunk - 
for i in range(NUM_BLOCK - 1): - p_key = tl.load(p1) - S_i = tl.load(S) - DS_i = tl.load(DS) - Dacc += DS_i - dp_i = Dacc * S_i - dp_key = tl.sum(dp_i, axis=1) - tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) - - tl.store(S, Dacc.to(S.dtype.element_ty)) - - Dacc *= p_key[:, None] - - S -= D_MODEL_K * D_MODEL_V - DS -= D_MODEL_K * D_MODEL_V - p1 -= D_MODEL_K - Dp1 -= D_MODEL_K * NUM_SPLIT_V - - -class Chunk_memory_update_only_gk(torch.autograd.Function): - @staticmethod - @custom_fwd - @contiguous - def forward(ctx, decay_key_last, to_add): - - B, H, N, D_k, D_v = to_add.shape - output = torch.empty_like(to_add) - BLOCK_MODEL = 32 - - assert D_k % 32 == 0 - assert D_v % 32 == 0 - assert D_k == decay_key_last.shape[-1] - # assert D_v == to_add.shape[-1] - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - ctx.grid = grid - ctx.BLOCK_MODEL = BLOCK_MODEL - - _fwd_recurrence[grid]( - to_add, - decay_key_last, - output, - D_MODEL_K=D_k, D_MODEL_V=D_v, - NUM_BLOCK=N, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, 0] = 0 - ctx.save_for_backward(output, decay_key_last) - - return output - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, DO): - output, decay_key_last = ctx.saved_tensors - - B, H, N, D_k, D_v = output.shape - - num_block = N - - BLOCK_MODEL = 32 - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - - # I don't want atomic_add to be used in the backward pass - # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) - # afterward, I sum over this dimension to get the correct gradient - D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, - device=DO.device, dtype=torch.float32) - # D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32) - - _bwd_recurrence[grid]( - output, decay_key_last, - DO, D_p1, - NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, - D_MODEL_K=D_k, - D_MODEL_V=D_v, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, -1] = 0 - D_p1[:, :, 0] = 
0 - D_p1[:, :, -1] = 0 - - return D_p1.sum(-2).to(decay_key_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py deleted file mode 100644 index 77d03a8..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/chunk_scan_triton_only_gv.py +++ /dev/null @@ -1,199 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_recurrence( - S, - p2, - O, - NUM_BLOCK, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr -): - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * \ - BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] - - O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V - - p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V - - acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - acc += tl.load(S) - - S += D_MODEL_K * D_MODEL_V - - tl.store(O, acc.to(O.dtype.element_ty)) - O += D_MODEL_K * D_MODEL_V - - for i in range(NUM_BLOCK-2): - p_v = tl.load(p2) - S_i = tl.load(S) - acc = acc * p_v[None, :] + S_i - tl.store(O, acc.to(O.dtype.element_ty)) - p2 += D_MODEL_V - S += D_MODEL_K * D_MODEL_V - O += D_MODEL_K * D_MODEL_V - - -# NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL -@triton.jit -def _bwd_recurrence( - S, - p2, - DS, - Dp2, - NUM_BLOCK, - NUM_SPLIT_K, - NUM_SPLIT_V, - D_MODEL_K: tl.constexpr, - D_MODEL_V: tl.constexpr, - BLOCK_MODEL: tl.constexpr - -): - - offset_bh = tl.program_id(0) - offset_d = tl.program_id(1) - offset_s = tl.program_id(2) - - # skip the last chunk because it is never used - S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V - - # start from the last chunk - DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + \ - tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange( - 0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V - - # skip the last chunk because it is never used - # p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K - - # skip the last chunk because it is never used - p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * \ - BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V - - # skip the last chunk because it is never used - # NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V: stride_bh - # offset_s * D_MODEL_K: find the right split in the K dimension - # Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V - - # skip the last chunk because it is never used - Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + \ - tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + \ - (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K - - Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) - - # ignore the first 
chunk - for i in range(NUM_BLOCK - 1): - - # p_key = tl.load(p1) - p_value = tl.load(p2) - S_i = tl.load(S) - DS_i = tl.load(DS) - Dacc += DS_i - dp_i = Dacc * S_i - # dp_key = tl.sum(dp_i * p_value[None, :], axis=1) - # tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty)) - dp_value = tl.sum(dp_i, axis=0) - tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty)) - - tl.store(S, Dacc.to(S.dtype.element_ty)) - - # Dacc *= p_key[:, None] - Dacc *= p_value[None, :] - - S -= D_MODEL_K * D_MODEL_V - DS -= D_MODEL_K * D_MODEL_V - # p1 -= D_MODEL_K - p2 -= D_MODEL_V - # Dp1 -= D_MODEL_K * NUM_SPLIT_V - Dp2 -= D_MODEL_V * NUM_SPLIT_K - - -class Chunk_memory_update_only_gv(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, decay_value_last, to_add): - B, H, N, D_k, D_v = to_add.shape - output = torch.empty_like(to_add) - BLOCK_MODEL = 32 - - assert D_k % 32 == 0 - assert D_v % 32 == 0 - # assert D_k == decay_key_last.shape[-1] - assert D_v == decay_value_last.shape[-1] - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - ctx.grid = grid - ctx.BLOCK_MODEL = BLOCK_MODEL - - _fwd_recurrence[grid]( - to_add, - decay_value_last, - output, - D_MODEL_K=D_k, D_MODEL_V=D_v, - NUM_BLOCK=N, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, 0] = 0 - ctx.save_for_backward(output, decay_value_last) - - return output - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, DO): - - output, decay_value_last = ctx.saved_tensors - - B, H, N, D_k, D_v = output.shape - - num_block = N - - BLOCK_MODEL = 32 - - grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL) - - # I don't want atomic_add to be used in the backward pass - # so I add another dimension to the output tensor (D_k/v // BLOCK_MODEL) - # afterward, I sum over this dimension to get the correct gradient - D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, - device=DO.device, dtype=torch.float32) - - _bwd_recurrence[grid]( - output, decay_value_last, - DO, D_p2, - NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // 
BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, - D_MODEL_K=D_k, - D_MODEL_V=D_v, - BLOCK_MODEL=BLOCK_MODEL - ) - - output[:, :, -1] = 0 - # D_p1[:, :, 0] = 0 - # D_p1[:, :, -1] = 0 - D_p2[:, :, 0] = 0 - D_p2[:, :, -1] = 0 - - return D_p2.sum(-2).to(decay_value_last.dtype), output diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py deleted file mode 100644 index 138eeba..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/fn.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- - -from .chunk_scan_triton_full import Chunk_memory_update_full -from .chunk_scan_triton_no_decay import Chunk_memory_update_no_decay -from .chunk_scan_triton_only_gk import Chunk_memory_update_only_gk -from .chunk_scan_triton_only_gv import Chunk_memory_update_only_gv -from .preprocess_cumsum_gk import PreprocessCumSum_GK -from .preprocess_cumsum_gv import PreprocessCumSum_GV - - -def inter_chunk_onc(query, key, value, gk, gv): - if gk is not None: - g_key_cumsum, reduce_key, q_exp, g_key_last_exp = PreprocessCumSum_GK.apply( - query, key, gk) - else: - reduce_key = key - q_exp = None - g_key_cumsum = None - g_key_last_exp = None - - if gv is not None: - g_value_cumsum, reduce_value, g_value_cumsum_exp, g_value_last_exp = PreprocessCumSum_GV.apply( - value, gv) - else: - reduce_value = value - g_value_cumsum = None - g_value_last_exp = None - - to_add = reduce_key.transpose(-1, -2).to(query.dtype) @ reduce_value.to(value.dtype) - - if gk is not None and gv is not None: - memory_cache = Chunk_memory_update_full.apply( - g_key_last_exp, g_value_last_exp, to_add) - inter_chunk_contribution = ( - (q_exp.to(query.dtype)) @ memory_cache) * g_value_cumsum_exp - elif gk is None and gv is not None: - memory_cache = Chunk_memory_update_only_gv.apply( - g_value_last_exp, to_add) - inter_chunk_contribution = ( - (query) 
@ memory_cache) * g_value_cumsum_exp - elif gk is not None and gv is None: - memory_cache = Chunk_memory_update_only_gk.apply( - g_key_last_exp, to_add) - inter_chunk_contribution = ((q_exp.to(query.dtype)) @ memory_cache) - else: - memory_cache = Chunk_memory_update_no_decay.apply(to_add) - inter_chunk_contribution = ((query) @ memory_cache) - - return g_key_cumsum, g_value_cumsum, inter_chunk_contribution.to(query.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py deleted file mode 100644 index c0f1d81..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gk.py +++ /dev/null @@ -1,259 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - -# def stable_logsigmoid(x): -# # Use the identity log(sigmoid(x)) = -log(1 + exp(-x)) -# # This is stable for large negative values of x -# neg_abs_x = -torch.abs(x) -# return torch.where(x < 0, x, neg_abs_x) - torch.log1p(torch.exp(neg_abs_x)) - - -@triton.jit -def _fwd_preprocess_cumsum_gk( - Q, - K, - GK, - GK_cumsum, - Q_exp, - K_reduce, - GK_last_exp, - NUM_CHUNK, - L, - D_MODEL_K: tl.constexpr, - D_BLOCK_K: tl.constexpr, - CHUNK_SIZE: tl.constexpr, -): - offset_bh = tl.program_id(0) - offset_c = tl.program_id(1) - offset_nk = tl.program_id(2) - Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - Q_exp_ptr = Q_exp + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - 
GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - GK_last_exp_ptr = GK_last_exp + offset_bh * NUM_CHUNK * \ - D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - cumsum = tl.zeros([D_BLOCK_K], dtype=tl.float32) - - mask = (D_BLOCK_K * offset_nk + tl.arange(0, D_BLOCK_K)) < D_MODEL_K - - for _ in range(CHUNK_SIZE): - gk = tl.load(GK_ptr, mask=mask, other=0).to(tl.float32) - cumsum += gk - tl.store(GK_cumsum_ptr, cumsum.to(GK_cumsum_ptr.dtype.element_ty), mask=mask) - cumsum_exp = tl.exp(cumsum) - q = tl.load(Q_ptr, mask=mask, other=0) - q_exp = q * cumsum_exp - tl.store(Q_exp_ptr, q_exp, mask=mask) - Q_ptr += D_MODEL_K - Q_exp_ptr += D_MODEL_K - GK_ptr += D_MODEL_K - GK_cumsum_ptr += D_MODEL_K - - tl.store(GK_last_exp_ptr, tl.exp(cumsum).to( - GK_last_exp_ptr.dtype.element_ty), mask=mask) - - tl.debug_barrier() - - GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - K_reduce_ptr = K_reduce + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - - for _ in range(CHUNK_SIZE): - gk_cumsum = tl.load(GK_cumsum_ptr, mask=mask, other=0) - k = tl.load(K_ptr, mask=mask, other=0) - k_reduce = k * tl.exp(cumsum - gk_cumsum) - tl.store(K_reduce_ptr, k_reduce.to(K_reduce_ptr.dtype.element_ty), mask=mask) - - K_ptr += D_MODEL_K - GK_cumsum_ptr += D_MODEL_K - K_reduce_ptr += D_MODEL_K - - -@triton.jit -def _bwd_preprocess_cumsum_gk( - Q, - K, - GK, - GK_cumsum, - DQ_exp, - DK_reduce, - DGK_last_exp, - DGK_cumsum, - DQ, - DK, - DGK, - NUM_CHUNK, - L, - D_MODEL_K: tl.constexpr, - D_BLOCK_K: tl.constexpr, - CHUNK_SIZE: tl.constexpr, -): - - offset_bh = tl.program_id(0) - 
offset_c = tl.program_id(1) - offset_nk = tl.program_id(2) - mask = (D_BLOCK_K * offset_nk + tl.arange(0, D_BLOCK_K)) < D_MODEL_K - - Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - - DQ_ptr = DQ + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - DK_ptr = DK + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - DQ_exp_ptr = DQ_exp + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - DK_reduce_ptr = DK_reduce + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - DGK_cumsum_ptr = DGK_cumsum + offset_bh * L * D_MODEL_K + \ - offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * \ - CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - - D_GK_last_exp_ptr = DGK_last_exp + offset_bh * NUM_CHUNK * \ - D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_BLOCK_K) + D_BLOCK_K * offset_nk - # - cumsum_gradient = tl.zeros([D_BLOCK_K], dtype=tl.float32) - grad_gk_last = tl.zeros([D_BLOCK_K], dtype=tl.float32) - - gk_last = tl.load(GK_cumsum_ptr + (CHUNK_SIZE - 1) - * D_MODEL_K, mask=mask, other=0).to(tl.float32) - cumsum_gradient += tl.load(D_GK_last_exp_ptr, mask=mask, other=0) * tl.exp(gk_last) - - GK_ptr += (CHUNK_SIZE - 1) * 
D_MODEL_K - GK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - Q_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - K_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - - DQ_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - DK_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - DGK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - DQ_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - DK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - DGK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K - - for idx in range(CHUNK_SIZE - 1, -1, -1): - gk_cs = tl.load(GK_cumsum_ptr, mask=mask, other=0).to(tl.float32) - k = tl.load(K_ptr, mask=mask, other=0).to(tl.float32) - grad_k = tl.exp(gk_last - gk_cs) * \ - tl.load(DK_reduce_ptr, mask=mask, other=0).to(tl.float32) - tl.store(DK_ptr, grad_k.to(DK_ptr.dtype.element_ty), mask=mask) - grad_k *= k - cumsum_gradient -= grad_k - grad_gk_last += grad_k - - q = tl.load(Q_ptr, mask=mask, other=0).to(tl.float32) - grad_q = tl.exp(gk_cs) * tl.load(DQ_exp_ptr, mask=mask, other=0) - tl.store(DQ_ptr, grad_q.to(DK_ptr.dtype.element_ty), mask=mask) - cumsum_gradient += grad_q * q.to(tl.float32) - - # from intra-chunk contribution. - cumsum_gradient += tl.load(DGK_cumsum_ptr, mask=mask, other=0).to(tl.float32) - - tl.store(DGK_ptr, cumsum_gradient.to(DGK_ptr.dtype.element_ty), mask=mask) - - Q_ptr -= D_MODEL_K - DQ_exp_ptr -= D_MODEL_K - K_ptr -= D_MODEL_K - DK_reduce_ptr -= D_MODEL_K - GK_cumsum_ptr -= D_MODEL_K - DGK_cumsum_ptr -= D_MODEL_K - DQ_ptr -= D_MODEL_K - DK_ptr -= D_MODEL_K - DGK_ptr -= D_MODEL_K - - DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * \ - D_MODEL_K + tl.arange(0, D_BLOCK_K) + (CHUNK_SIZE - 1) * D_MODEL_K + D_BLOCK_K * offset_nk - GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * \ - D_MODEL_K + tl.arange(0, D_BLOCK_K) + (CHUNK_SIZE - 1) * D_MODEL_K + D_BLOCK_K * offset_nk - - # tl.store(D_GK_last_exp_ptr, cumsum_gradient) - - # seems stupid. just workaround some compiler bugs. - grad_gk_last = grad_gk_last + 0. 
- for idx in range(CHUNK_SIZE - 1, -1, -1): - dgk = tl.load(DGK_ptr, mask=mask, other=0).to(tl.float32) - dgk += grad_gk_last - tl.store(DGK_ptr, dgk.to(DGK_ptr.dtype.element_ty), mask=mask) - DGK_ptr -= D_MODEL_K - GK_ptr -= D_MODEL_K - - -class PreprocessCumSum_GK(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, gk): - B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape - - D_k = k.shape[-1] - N_k = triton.cdiv(D_k, 32) - grid = (B * H, NUM_CHUNK, N_k) - - k_reduce = torch.empty_like(k) - - q_exp = torch.empty_like(q) - - gk_cumsum = torch.empty_like(gk) - - gk_last_exp = torch.empty_like(gk[:, :, :, 0], dtype=torch.float32) - - _fwd_preprocess_cumsum_gk[grid]( - q, k, gk, gk_cumsum, - q_exp, k_reduce, gk_last_exp, - CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, - D_MODEL_K=D_k, D_BLOCK_K=32, num_warps=1, num_stages=2 - ) - - ctx.grid = grid - ctx.save_for_backward(q, k, gk, gk_cumsum) - - return gk_cumsum, k_reduce, q_exp, gk_last_exp - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, dgk_cumsum, dk_reduce, dq_exp, dgk_last_exp): - q, k, gk, gk_cumsum = ctx.saved_tensors - B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape - - D_k = k.shape[-1] - N_k = triton.cdiv(D_k, 32) - grid = (B * H, NUM_CHUNK, N_k) - - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dgk = torch.empty_like(gk) - - B, H, NUM_CHUNK, CHUNK_SIZE, D_k = q.shape - - # D_v = v.shape[-1] - - _bwd_preprocess_cumsum_gk[grid]( - q, k, gk, gk_cumsum, - dq_exp, dk_reduce, dgk_last_exp, dgk_cumsum, - dq, dk, dgk, - CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, - D_MODEL_K=D_k, D_BLOCK_K=32, num_warps=1, num_stages=2 - ) - - return dq.to(q.dtype), dk.to(k.dtype), dgk.to(gk.dtype), None, None, None diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py 
b/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py deleted file mode 100644 index 85604aa..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/inter_chunk_contribution/preprocess_cumsum_gv.py +++ /dev/null @@ -1,216 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_preprocess_cumsum_gv( - V, - GV, - GV_cumsum, - GV_exp, - V_reduce, - GV_last_exp, - NUM_CHUNK, - L, - D_MODEL_V: tl.constexpr, - CHUNK_SIZE: tl.constexpr, -): - - offset_bh = tl.program_id(0) - offset_c = tl.program_id(1) - - GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - GV_last_exp_ptr = GV_last_exp + offset_bh * NUM_CHUNK * \ - D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V) - - GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - GV_exp_ptr = GV_exp + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - cumsum = tl.zeros([D_MODEL_V], dtype=tl.float32) - - for _ in range(CHUNK_SIZE): - gv = tl.load(GV_ptr).to(tl.float32) - cumsum += gv - - tl.store(GV_cumsum_ptr, cumsum.to(GV_cumsum_ptr.dtype.element_ty)) - tl.store(GV_exp_ptr, tl.exp(cumsum).to(GV_cumsum_ptr.dtype.element_ty)) - - GV_cumsum_ptr += D_MODEL_V - GV_exp_ptr += D_MODEL_V - GV_ptr += D_MODEL_V - - tl.store(GV_last_exp_ptr, tl.exp(cumsum).to( - GV_last_exp_ptr.dtype.element_ty)) - - tl.debug_barrier() - - V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - V_reduce_ptr = V_reduce + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - for _ in 
range(CHUNK_SIZE): - v = tl.load(V_ptr) - gv = tl.load(GV_cumsum_ptr) - v_reduce = v * tl.exp(cumsum - gv) - tl.store(V_reduce_ptr, v_reduce.to(V_reduce_ptr.dtype.element_ty)) - - V_ptr += D_MODEL_V - V_reduce_ptr += D_MODEL_V - GV_cumsum_ptr += D_MODEL_V - - -@triton.jit -def _bwd_preprocess_cumsum_gv( - V, - GV, - GV_cumsum, - DGV_cumsum_exp, - DV_reduce, - DGV_last_exp, - DGV_cumsum, - DV, - DGV, - NUM_CHUNK, - L, - D_MODEL_V: tl.constexpr, - CHUNK_SIZE: tl.constexpr, -): - - offset_bh = tl.program_id(0) - offset_c = tl.program_id(1) - V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - DV_ptr = DV + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - DV_reduce_ptr = DV_reduce + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - DGV_cumsum_ptr = DGV_cumsum + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - DGV_cumsum_exp_ptr = DGV_cumsum_exp + offset_bh * L * D_MODEL_V + \ - offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * \ - CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) - - D_GV_last_exp_ptr = DGV_last_exp + offset_bh * NUM_CHUNK * \ - D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V) - - cumsum_gradient = tl.zeros([D_MODEL_V], dtype=tl.float32) - grad_gv_last = tl.zeros([D_MODEL_V], dtype=tl.float32) - - gv_last = tl.load(GV_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_V) - cumsum_gradient += tl.load(D_GV_last_exp_ptr) * \ - tl.exp(gv_last).to(tl.float32) - - GV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - GV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - - V_ptr += (CHUNK_SIZE - 1) * 
D_MODEL_V - - DV_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - DGV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - DGV_cumsum_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - DV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - DGV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V - - for idx in range(CHUNK_SIZE - 1, -1, -1): - gv_cs = tl.load(GV_cumsum_ptr).to(tl.float32) - v = tl.load(V_ptr).to(tl.float32) - grad_v = tl.exp(gv_last - gv_cs) * \ - tl.load(DV_reduce_ptr).to(tl.float32) - tl.store(DV_ptr, grad_v.to(DV_ptr.dtype.element_ty)) - grad_v *= v - cumsum_gradient -= grad_v - grad_gv_last += grad_v - - # q = tl.load(Q_ptr).to(tl.float32) - grad_v = tl.exp(gv_cs) * tl.load(DGV_cumsum_exp_ptr) - cumsum_gradient += grad_v - - # from intra-chunk contribution. - cumsum_gradient += tl.load(DGV_cumsum_ptr).to(tl.float32) - - tl.store(DGV_ptr, cumsum_gradient.to(DGV_ptr.dtype.element_ty)) - - V_ptr -= D_MODEL_V - DV_reduce_ptr -= D_MODEL_V - GV_cumsum_ptr -= D_MODEL_V - DGV_cumsum_ptr -= D_MODEL_V - DV_ptr -= D_MODEL_V - DGV_ptr -= D_MODEL_V - DGV_cumsum_exp_ptr -= D_MODEL_V - - DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * \ - D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V - GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * \ - D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V - - grad_gv_last = grad_gv_last + 0. 
- - for idx in range(CHUNK_SIZE - 1, -1, -1): - dgv = tl.load(DGV_ptr).to(tl.float32) - dgv += grad_gv_last - tl.store(DGV_ptr, dgv.to(DGV_ptr.dtype.element_ty)) - DGV_ptr -= D_MODEL_V - GV_ptr -= D_MODEL_V - - -class PreprocessCumSum_GV(torch.autograd.Function): - @staticmethod - @contiguous - @torch.cuda.amp.custom_fwd - def forward(ctx, v, gv): - B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape - - grid = (B * H, NUM_CHUNK) - ctx.grid = grid - - gv_cumsum = torch.empty_like(gv, dtype=torch.float32) - gv_cumsum_exp = torch.empty_like(gv) - v_reduce = torch.empty_like(v) - gv_last_exp = torch.empty_like(gv[:, :, :, 0], dtype=torch.float32) - _fwd_preprocess_cumsum_gv[grid]( - v, gv, gv_cumsum, gv_cumsum_exp, - v_reduce, gv_last_exp, - CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, - D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4 - ) - - ctx.grid = grid - ctx.save_for_backward(v, gv, gv_cumsum) - return gv_cumsum, v_reduce, gv_cumsum_exp, gv_last_exp - - @staticmethod - @contiguous - def backward(ctx, dgv_cumsum, dv_reduce, dgv_cumsum_exp, dgv_last_exp): - v, gv, gv_cumsum = ctx.saved_tensors - grid = ctx.grid - - B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape - - dv = torch.empty_like(v) - dgv = torch.empty_like(gv) - _bwd_preprocess_cumsum_gv[grid]( - v, gv, gv_cumsum, dgv_cumsum_exp, dv_reduce, dgv_last_exp, dgv_cumsum, - dv, dgv, - CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, - D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4 - ) - return dv.to(v.dtype), dgv.to(gv.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py deleted file mode 
100644 index 4375856..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch - -from .fn_only_gk import IntraCalA -from .fn_only_gv import IntraCalO - - -def intra_chunk_onc(q, k, v, gk, gv): - assert q.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - if gk is not None: - assert gk.is_contiguous() - if gv is not None: - assert gv.is_contiguous() - - assert k.shape[-2] % 16 == 0 - - if gk is not None: - A = IntraCalA.apply(q, k, gk) - else: - A = q @ k.transpose(-1, -2) - - mask = torch.triu(torch.ones(A.shape[-2], A.shape[-2]), diagonal=1).bool().to(A.device) - A.masked_fill_(mask, 0) - - return IntraCalO.apply(A, v, gv) if gv is not None else A.to(v) @ v diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py deleted file mode 100644 index 8eaddd9..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gk.py +++ /dev/null @@ -1,343 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_kernel_compute_A( - Q, - K, - GK, - A, - stride_q1, - stride_q2, - stride_q3, - stride_q4, - stride_a1, - stride_a2, - stride_a3, - stride_a4, - Z, - H, - N_CTX, - D, - BLOCK_DMODEL_QK: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_k = tl.program_id(2) - - qk_offset = off_hz * stride_q2 + off_k * BLOCK_DMODEL_QK - a_offset = (off_k * Z*H + off_hz) * stride_a2 - - lo = 0 - hi = BLOCK_N - - Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, 
None] * stride_q4 - - K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - :, None] + tl.arange(0, 16)[None, :] * stride_q4 - - GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( - 0, BLOCK_DMODEL_QK)[:, None] + tl.arange(0, 16)[None, :] * stride_q4 - - GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( - 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 - - for q_high in range(16, hi, 16): - q = tl.load(Q_ptr + q_high * stride_q4) - q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) - q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + - q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) - q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) - q = q * q_gk2.to(q.dtype) - - # inter-chunk bf16 - for k_high in range(0, q_high, 16): - k = tl.load(K_ptr + k_high * stride_q4) - k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) - k_gk = tl.exp(q_normalizer[:, None] - k_gk) - k = k * k_gk.to(k.dtype) - qk = tl.dot(q, k, allow_tf32=False) - tl.store(A_ptr + q_high * stride_a4 + k_high, - qk.to(A_ptr.dtype.element_ty)) - - # intra chunk fp32 - for q_high in range(lo, hi, 16): - q = tl.load(Q_ptr + q_high * stride_q4) - q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) - q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + - q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) - q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) - q = q * q_gk2 - q_gk3 = tl.exp(q_normalizer[None, :] - q_gk) - k = tl.load(K_ptr + q_high * stride_q4) - k = k * tl.trans(q_gk3) - - qk = tl.dot(q, k, allow_tf32=False) - qk = tl.where(tl.arange(0, 16)[:, None] - >= tl.arange(0, 16)[None, :], qk, 0.) 
- tl.store(A_ptr + q_high * stride_a4 + q_high, - qk.to(A_ptr.dtype.element_ty)) - - -@triton.jit -def _bwd_kernel_dqk( - Q, - K, - GK, - DA, - DQ, - DK, - DGK, - stride_q1, - stride_q2, - stride_q3, - stride_q4, - stride_a1, - stride_a2, - stride_a3, - stride_a4, - Z, - H, - N_CTX, - D, - BLOCK_DMODEL_QK: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr -): - - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_k = tl.program_id(2) - - qk_offset = off_hz * stride_q2 + BLOCK_DMODEL_QK * off_k - a_offset = off_hz * stride_a2 - - lo = 0 - hi = BLOCK_N - - Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( - 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange( - 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - # DGK_Q_ptr = DGK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - DA_ptr = DA + a_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 - - # inter chunk dq. 
bf16 - for q_high in range(lo+16, hi, 16): - q = tl.load(Q_ptr + q_high * stride_q4) - - q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3) + - q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) - - # q2 = q * q_gk.to(q.dtype) - - dq2 = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) - - for k_high in range(0, q_high, 16): - k = tl.load(K_ptr + k_high * stride_q4) - k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) - dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(k.dtype) - k_gk = tl.exp(q_normalizer[None, :] - k_gk) - k = k * k_gk.to(k.dtype) - dq2 += tl.dot(dqk, k, allow_tf32=False) - - dq2 = dq2.to(q.dtype) - q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) - q_gk = tl.exp(q_gk - q_normalizer[None, :]) - dq = dq2 * q_gk.to(q.dtype) - dq_gk = dq * q - - DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4 - tl.store(DQ_ptr, dq.to(DQ_ptr.dtype.element_ty)) - - DGK_Q_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4 - # prev = tl.load(DGK_Q_ptr) - tl.store(DGK_Q_ptr, dq_gk.to(DGK_Q_ptr.dtype.element_ty)) - - tl.debug_barrier() - - for k_high in range(lo, hi-16, 16): - k = tl.load(K_ptr + k_high * stride_q4) - k_gk = tl.load(GK_K_ptr + k_high * stride_q4) - dk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) - dgk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32) - - for q_high in range(k_high+16, hi, 16): - q = tl.load(Q_ptr + q_high * stride_q4) - q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3) + q_high * stride_q4 + tl.arange(0, - BLOCK_DMODEL_QK)).to(tl.float32) - q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) - q_gk = tl.exp(q_gk - q_normalizer[None, :]).to(q.dtype) - q = q * q_gk - dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(q.dtype) - - k_gk2 = 
tl.exp(q_normalizer[None, :] - k_gk) - - dk2 = tl.dot(tl.trans(dqk), q, allow_tf32=False) - dk += dk2 * k_gk2 - dgk -= dk2 * k * k_gk2 - - DK_ptr = DK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4 - tl.store(DK_ptr, dk.to(DK_ptr.dtype.element_ty)) - - DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4 - prev = tl.load(DGK_K_ptr) - tl.store(DGK_K_ptr, (prev + dgk).to(DGK_K_ptr.dtype.element_ty)) - - tl.debug_barrier() - - DK_ptr = DK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange( - 0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[ - None, :] + tl.arange(0, 16)[:, None] * stride_q4 - - # intra chunk, fp32. - for q_high in range(lo, hi, 16): - q = tl.load(Q_ptr + q_high * stride_q4) - q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32) - q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + - q_high * stride_q4 + tl.arange(0, BLOCK_DMODEL_QK)).to(tl.float32) - q_gk2 = tl.exp(q_gk - q_normalizer[None, :]) - q2 = q * q_gk2 - q_gk3 = tl.exp(q_normalizer[None, :] - q_gk) - - k = tl.load(K_ptr + q_high * stride_q4) - k2 = k * q_gk3 - - dqk = tl.load(DA_ptr + q_high * stride_a4 + q_high) - dqk = tl.where(tl.arange(0, 16)[:, None] - >= tl.arange(0, 16)[None, :], dqk, 0.) 
- - dk2 = tl.dot(tl.trans(dqk), q2, allow_tf32=False) - dk = dk2 * q_gk3 - prev_dk = tl.load(DK_ptr + q_high * stride_q4) - tl.store(DK_ptr + q_high * stride_q4, - (dk + prev_dk).to(DK_ptr.dtype.element_ty)) - - dgk = - dk * k - dq2 = tl.dot(dqk, k2, allow_tf32=False) - dq = dq2 * q_gk2 - - prev_dq = tl.load(DQ_ptr + q_high * stride_q4) - tl.store(DQ_ptr + q_high * stride_q4, - (dq + prev_dq).to(DQ_ptr.dtype.element_ty)) - - dgk += dq * q - prev_dq_gk = tl.load(DGK_K_ptr + q_high * stride_q4) - tl.store(DGK_K_ptr + q_high * stride_q4, - (dgk + prev_dq_gk).to(DGK_K_ptr.dtype.element_ty)) - - -class IntraCalA(torch.autograd.Function): - @staticmethod - @custom_fwd - @contiguous - def forward(ctx, q, k, gk): - - # assert gk.dtype==torch.float32 - # only support for Ampere now - - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - raise RuntimeError( - "Flash attention currently only supported for compute capability >= 80") - - # assert gk.dtype == gv.dtype == torch.float32 - # for now. - BLOCK_M = BLOCK_N = q.shape[-2] - - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - if Lk > 128: - assert Lk % 128 == 0 - - BLOCK_DMODEL_QK = min(Lk, 128) - ctx.BLOCK_DMODEL_QK = BLOCK_DMODEL_QK - - A = torch.zeros(max(1, Lk//128), q.shape[0], q.shape[1], - q.shape[2], BLOCK_N, BLOCK_N, device=q.device, dtype=q.dtype) - - grid = (q.shape[2], q.shape[0] * q.shape[1], max(1, Lk//128)) - - # assert q.dtype == k.dtype == v.dtype - _fwd_kernel_compute_A[grid]( - q, k, gk, A, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - # be careful here! 
- A.stride(1), A.stride(2), A.stride(3), A.stride(4), - q.shape[0], q.shape[1], q.shape[2], q.shape[3], - BLOCK_N=BLOCK_N, BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, BLOCK_M=BLOCK_M, num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, num_stages=8 - ) - - ctx.save_for_backward(q, k, gk) - ctx.grid = grid - ctx.BLOCK_N = BLOCK_N - ctx.BLOCK_N = BLOCK_N - ctx.head = q.shape[1] - return A.sum(0).to(q.dtype) - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, dA): - q, k, gk = ctx.saved_tensors - - # appearantly, there is no sync issue when splitting K dim. - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dgk = torch.zeros_like(gk) - - BLOCK_N = ctx.BLOCK_N - # for now. - BLOCK_M = BLOCK_N - - _bwd_kernel_dqk[ctx.grid]( - q, k, gk, dA, - dq, - dk, dgk, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - dA.stride(0), dA.stride(1), dA.stride(2), dA.stride(3), - q.shape[0], q.shape[1], q.shape[2], q.shape[3], - BLOCK_N=BLOCK_N, - BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK, - BLOCK_M=BLOCK_M, - num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, - num_stages=5 - ) - - return dq.to(q.dtype), dk.to(k.dtype), dgk.to(gk.dtype) diff --git a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py b/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py deleted file mode 100644 index 3057b7a..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/block_parallel/intra_chunk_contribution/fn_only_gv.py +++ /dev/null @@ -1,336 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def _fwd_compute_O( - A, - V, - GV, - O, - stride_a2, - stride_a3, - stride_a4, - stride_v2, - stride_v3, - stride_v4, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL_V: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_v = tl.program_id(2) - - 
a_offset = off_hz * stride_a2 - v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V - - lo = 0 - hi = BLOCK_N - - V_ptr = V + v_offset + (start_m) * stride_v3 + tl.arange(0, - BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - O_ptr = O + v_offset + (start_m) * stride_v3 + tl.arange(0, - BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 - - for q_high in range(lo+16, hi, 16): - q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + - q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - acc = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32) - - # k_gv = tl.load(GV_ptr + q_high * stride_v4) - # q_gv = tl.exp(k_gv - q_gv_normalizer[None, :]) - - for k_high in range(0, q_high, 16): - qk = tl.load(A_ptr + q_high * stride_a4 + k_high) - v = tl.load(V_ptr + k_high * stride_v4) - k_gv = tl.load(GV_ptr + k_high * stride_v4) - k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv) - v = v * k_gv.to(v.dtype) - # bf16 - output = tl.dot(qk.to(v.dtype), v, allow_tf32=False) - acc += output - - tl.store(O_ptr + q_high * stride_v4, acc.to(O.dtype.element_ty)) - - tl.store(O_ptr, tl.zeros([16, BLOCK_DMODEL_V], - dtype=tl.float32).to(O.dtype.element_ty)) - - tl.debug_barrier() - - for q_high in range(lo, hi, 16): - q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + - q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - - qk = tl.load(A_ptr + q_high * stride_a4 + q_high) - v = tl.load(V_ptr + q_high * stride_v4) - k_gv = tl.load(GV_ptr + q_high * stride_v4) - k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv) - - # fp32 matmul - v = v * k_gv2 - output = tl.dot(qk.to(tl.float32), v, allow_tf32=False) - - q_gv = tl.exp(k_gv - q_gv_normalizer[None, :]) - - 
prev = tl.load(O_ptr + q_high * stride_v4) - output += prev - output = output * q_gv - - tl.store(O_ptr + q_high * stride_v4, output.to(O.dtype.element_ty)) - - -@triton.jit -def _bwd_kernel_dav( - V, - GV, - A, - O, - DO, - DA, - DV, - DGV, - Z, - H, - stride_a1, - stride_a2, - stride_a3, - stride_a4, - stride_v1, - stride_v2, - stride_v3, - stride_v4, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL_V: tl.constexpr -): - - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_v = tl.program_id(2) - - a_offset = off_hz * stride_a2 - da_offset = (off_v * Z * H + off_hz) * stride_a2 - v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V - - lo = 0 - hi = BLOCK_N - - DO_ptr = DO + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - O_ptr = O + v_offset + (start_m) * stride_v3 + tl.arange(0, - BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - DV_ptr = DV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - DGV_ptr = DGV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 - - DA_ptr = DA + da_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 - - # pre-compute do*q_gv. 
in-place update - for q_high in range(lo, hi, 16): - do = tl.load(DO_ptr + q_high * stride_v4) - o = tl.load(O_ptr + q_high * stride_v4) - tl.store(DGV_ptr + q_high * stride_v4, (do * o)) - - q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + - q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - q_gv = tl.load(GV_ptr + q_high * stride_v4) - q_gv = tl.exp(q_gv - q_gv_normalizer[None, :]) - do = do * q_gv - - tl.store(DO_ptr + q_high * stride_v4, do.to(DO_ptr.dtype.element_ty)) - - tl.debug_barrier() - - V_ptr = V + v_offset + (start_m) * stride_v3 + \ - tl.arange(0, BLOCK_DMODEL_V)[:, None] + tl.arange(0, 16)[None, :] * stride_v4 - GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[ - :, None] + tl.arange(0, 16)[None, :] * stride_v4 - - for q_high in range(lo+16, hi, 16): - do = tl.load(DO_ptr + q_high * stride_v4) - q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * - stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - - for k_high in range(0, q_high, 16): - v = tl.load(V_ptr + k_high * stride_v4) - k_gv = tl.load(GV_ptr + k_high * stride_v4) - k_gv = tl.exp(q_gv_normalizer[:, None] - k_gv) - - # bf16 - v2 = v * k_gv.to(v.dtype) - dqk = tl.dot(do, v2, allow_tf32=False) - tl.store(DA_ptr + q_high * stride_a4 + - k_high, dqk.to(DA.dtype.element_ty)) - - tl.debug_barrier() - - A_ptr = A + a_offset + (start_m) * stride_a3 + \ - tl.arange(0, 16)[:, None] + tl.arange(0, 16)[None, :] * stride_a4 - - V_ptr = V + v_offset + (start_m) * stride_v3 + \ - tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 - GV_ptr = GV + v_offset + (start_m) * stride_v3 + \ - tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4 - - for k_high in range(0, hi, 16): - dv = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32) - - k_gv = tl.load(GV_ptr + k_high * stride_v4) - - for q_high in range(k_high + 16, BLOCK_N, 16): - do = tl.load(DO_ptr + q_high * 
stride_v4) - - kq = tl.load(A_ptr + q_high * stride_a4 + k_high).to(do.dtype) - - q_gv_normalizer = tl.load(GV + v_offset + - start_m * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv) - - # bf16 - dv2 = tl.dot(kq, do, allow_tf32=False) - dv += dv2 * k_gv2 - - v = tl.load(V_ptr + k_high * stride_v4) - tl.store(DV_ptr + k_high * stride_v4, dv.to(v.dtype)) - - prev_dv = tl.load(DGV_ptr + k_high * stride_v4) - tl.store(DGV_ptr + k_high * stride_v4, prev_dv - dv*v) - - tl.debug_barrier() - - A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, - 16)[:, None] + tl.arange(0, 16)[None, :] * stride_a4 - - # intra-chunk - for q_high in range(lo, hi, 16): - do = tl.load(DO_ptr + q_high * stride_v4) - - q_gv_normalizer = tl.load(GV + v_offset + start_m * stride_v3 + - q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32) - - v = tl.load(V_ptr + q_high * stride_v4) - k_gv = tl.load(GV_ptr + q_high * stride_v4) - k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv) - v2 = v * k_gv - - dqk = tl.dot(do.to(v2.dtype), tl.trans(v2), allow_tf32=False) - dqk = tl.where(tl.arange(0, 16)[:, None] - >= tl.arange(0, 16)[None, :], dqk, 0.) 
- tl.store(DA_ptr + q_high * stride_a4 + q_high, - dqk.to(DA_ptr.dtype.element_ty)) - - kq = tl.load(A_ptr + q_high * stride_a4 + q_high).to(do.dtype) - dv2 = tl.dot(kq, do, allow_tf32=False) - - dv = dv2 * k_gv - prev_dv = tl.load(DV_ptr + q_high * stride_v4) - tl.store(DV_ptr + q_high * stride_v4, - (prev_dv + dv).to(DV.dtype.element_ty)) - - prev_gdv = tl.load(DGV_ptr + q_high * stride_v4) - prev_gdv -= dv * v - tl.store(DGV_ptr + q_high * stride_v4, - prev_gdv.to(DGV.dtype.element_ty)) - - -class IntraCalO(torch.autograd.Function): - @staticmethod - @custom_fwd - @contiguous - def forward(ctx, A, v, gv): - assert gv.dtype == torch.float32 - # assert A.dtype == torch.float32 - - # only support for Ampere now - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - raise RuntimeError( - "Flash attention currently only supported for compute capability >= 80") - - # assert gk.dtype == gv.dtype == torch.float32 - BLOCK_M = BLOCK_N = v.shape[-2] - - # shape constraints - Lv = v.shape[-1] - BLOCK_V = min(128, Lv) - ctx.BLOCK_V = BLOCK_V - - assert v.shape[-1] % BLOCK_V == 0 - - grid = (v.shape[2], v.shape[0] * v.shape[1], - max(1, v.shape[-1] // BLOCK_V)) - - o = torch.empty_like(v) - - _fwd_compute_O[grid](A, v, gv, o, - A.stride(0), A.stride( - 1), A.stride(2), A.stride(3), - v.stride(0), v.stride( - 1), v.stride(2), v.stride(3), - BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, - BLOCK_DMODEL_V=BLOCK_V, num_warps=8 if BLOCK_V == 128 else 4, num_stages=5 - ) - - ctx.save_for_backward(A, v, gv, o) - ctx.grid = grid - return o - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do): - A, v, gv, o = ctx.saved_tensors - BLOCK_V = ctx.BLOCK_V - assert v.shape[-1] % BLOCK_V == 0 - - # dA = torch.empty_like(A) - dv = torch.zeros_like(v) - dgv = torch.zeros_like(gv) - - # for now. 
- BLOCK_M = BLOCK_N = v.shape[-2] - - # shape constraints - # Lv = v.shape[-1] - # grid = (v.shape[2] , v.shape[0] * v.shape[1], v.shape[-1] // BLOCK_V) - grid = ctx.grid - - dA = torch.empty(v.shape[-1] // BLOCK_V if BLOCK_V == 128 else 1, A.shape[0], - A.shape[1], A.shape[2], A.shape[3], A.shape[3], device=A.device, dtype=A.dtype) - - _bwd_kernel_dav[grid]( - v, gv, A, o, - do, dA, - dv, dgv, - v.shape[0], v.shape[1], - A.stride(0), A.stride(1), A.stride(2), A.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, - BLOCK_DMODEL_V=ctx.BLOCK_V, num_warps=8, num_stages=4 - ) - - return dA.sum(0).to(A), dv.to(v), dgv.to(gv) diff --git a/flash_linear_attention/fla/ops/triton/gla/chunk.py b/flash_linear_attention/fla/ops/triton/gla/chunk.py deleted file mode 100644 index 04d1ad2..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/chunk.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023, Songlin Yang -# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635 -# chunkwise block parallel. Materialize chunkwise hidden states into HBMs. -# Therefore it is neccessary to have a large chunk size to reduce such materialization overhead. - -import torch.nn.functional as F -from einops import rearrange - -from fla.ops.triton.gla.block_parallel.inter_chunk_contribution.fn import \ - inter_chunk_onc -from fla.ops.triton.gla.block_parallel.intra_chunk_contribution.fn import \ - intra_chunk_onc - - -def pad_and_rearrange(x, chunk_size): - if x.shape[-2] % chunk_size != 0: - x = F.pad(x, (0, 0, 0, chunk_size - x.shape[-2] % chunk_size)) - if x.shape[-1] % 32 != 0: - x = F.pad(x, (0, 32 - x.shape[-1] % 32)) - x = rearrange(x, '... (n c) d -> ... 
n c d', c=chunk_size) - return x - - -def chunk_gla(q, k, v, gk=None, gv=None, chunk_size=128): - scale = (q.shape[-1])**-0.5 - seq_len = q.shape[-2] - output_dim = v.shape[-1] - q, k, v = map(lambda x: pad_and_rearrange(x, chunk_size), [q, k, v]) - q = q * scale - if gk is not None: - gk = pad_and_rearrange(gk, chunk_size) - if gv is not None: - gv = pad_and_rearrange(gv, chunk_size) - gk, gv, o1 = inter_chunk_onc(q, k, v, gk, gv) - o2 = intra_chunk_onc(q, k, v, gk, gv) - o = rearrange(o1+o2, 'b h n c d -> b h (n c) d') - return o[:, :, :seq_len, :output_dim] diff --git a/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py deleted file mode 100644 index fd45ba2..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/chunk_fuse.py +++ /dev/null @@ -1,400 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023, Songlin Yang -# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635 -# on-the-fly computation without materializing hidden statets into HBMs - -import warnings - -import torch -import torch.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange -from fla.ops.triton.utils import contiguous, require_version -from torch.cuda.amp import custom_bwd, custom_fwd - -try: - import semiring_cal_A -except ImportError: - warnings.warn('Failed to import semiring_cal_A. 
Do not use FusedChunk implementation of GLA.') - -inv_ln2 = 1.44269504 - - -@triton.jit -def fused_chunk_gla_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_K] - v, # value [B, H, L, D_head_V] - g, # cumulative sum of log decay [B, H, L, D_head_K] - o, # output [B, H, L, D_head_V] - - initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] - final_state, # final state of the chunk [B, H, D_head_K, D_head_V] - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, - STORE_FINAL_STATE: tl.constexpr, -): - # indices - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - - b_h = tl.zeros([BK, BV], dtype=tl.float32) - - # make block pointers - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, - (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - - if USE_INITIAL_STATE: - p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, - (DK, DV), (DV, 
1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) - - for i in range(0, tl.cdiv(T, BT)): - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BT, BV] - b_o = tl.zeros([BT, BV], dtype=tl.float32) - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BT, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) - b_g *= inv_ln2 - - d_b = tl.load(p_db) * inv_ln2 - - b_q = (b_q * scale * tl.math.exp2(b_g)) - b_k = b_k * tl.trans(tl.math.exp2(-b_g + d_b[None, :])) - - b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) - b_h *= tl.math.exp2(d_b)[:, None] - b_h += tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) - - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - p_q = tl.advance(p_q, (BT, 0)) - p_g = tl.advance(p_g, (BT, 0)) - p_k = tl.advance(p_k, (0, BT)) - p_v = tl.advance(p_v, (BT, 0)) - p_o = tl.advance(p_o, (BT, 0)) - p_db += BT * DK - - if STORE_FINAL_STATE: - p_final = tl.make_block_ptr( - final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - tl.store(p_final, b_h.to(p_final.dtype.element_ty), - boundary_check=(0, 1)) - - -# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 -@triton.jit -def fused_chunk_gla_bwd_kernel( - q, k, v, g, - do, # gradient of output [B, H, L, D_head_V] - dq, # gradient of query [NV, B, H, L, D_head_K] - dk, # gradient of key [NV, B, H, L, D_head_K] - dv, # gradient of value [NK, B, H, L, D_head_V] - - initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch_size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - # clamp_min, # minimum log value of the gate for numerical stability. 
default: -5 - BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr -): - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - # [BV, BK] - b_h = tl.zeros([BV, BK], dtype=tl.float32) - - if USE_INITIAL_STATE: - p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, - (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) - - for i in range(0, tl.cdiv(T, BT)): - p_k = tl.make_block_ptr( - k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr( - g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - p_db = g + i_bh * s_qk_h + \ - ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) - p_v = tl.make_block_ptr( - v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - b_dq = tl.zeros([BT, BK], dtype=tl.float32) - # [BT, DK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2 - d_b = tl.load(p_db) * inv_ln2 - - # [DV, BT] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BT, DV] - b_do = tl.load(p_do, boundary_check=(0, 1)) - b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) - # [DV, DK] - b_k *= tl.math.exp2(d_b[None, :] - b_g) - b_h *= tl.math.exp2(d_b)[None, :] - b_h += tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) - b_dq *= scale * tl.math.exp2(b_g) - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - - # sync threads - b_h = None 
- tl.debug_barrier() - # [BK, BV] - b_dh = tl.zeros([BK, BV], dtype=tl.float32) - - # cum = tl.zeros([BK], dtype=tl.float32) - for i in range(1, tl.cdiv(T, BT) + 1): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) - p_k = tl.make_block_ptr( - k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr( - g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) - p_db = g + i_bh * s_qk_h + \ - (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) - p_v = tl.make_block_ptr( - v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) - # p_dg = tl.make_block_ptr(dg + (i_bh + i_v * B * H) * s_qk_h, (T, DK), - # (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) - # [DK, BT] - b_q = tl.load(p_q, boundary_check=(0, 1)) - # [BT, DK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BT, DV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)) - b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2 - b_db = tl.load(p_db) * inv_ln2 - - # inter-chunk - g_k = tl.math.exp2(b_db[None, :] - b_g) - b_k *= g_k - b_q *= tl.math.exp2(tl.trans(b_g)) - b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans( - b_v), allow_tf32=False)) * scale * g_k - b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to( - b_v.dtype), allow_tf32=False) * scale - - # [DK, DV] - b_dh *= tl.math.exp2(b_db)[:, None] - b_dh += tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) - - tl.store(p_dk, 
b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - - -class FusedChunkGLAFunction(torch.autograd.Function): - - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): - ctx.g_dtype = g.dtype - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - ctx.scale = scale - - # inter-chunk - BT = 16 # chunk_size - BK, BV = min(d_head_qk, 64), min(d_head_v, 64) - num_stages = 1 - num_warps = 2 - - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - - g = rearrange(g, 'b h (n c) d -> b h n c d', c=BT) - g = g.float().cumsum(-2) - g = rearrange(g, 'b h n c d -> b h (n c) d') - - if output_final_state: - final_state = q.new_empty( - batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) - else: - final_state = None - - grid = (NV, NK, batch_size * n_heads) - fused_chunk_gla_fwd_kernel[grid]( - q, k, v, g, o, initial_state, final_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - # clamp_min=-3, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - # USE_SIGMOID=True, USE_EXP=False, - USE_INITIAL_STATE=initial_state is not None, - STORE_FINAL_STATE=output_final_state, - num_warps=num_warps, - num_stages=num_stages - ) - - o = o.sum(0) - - # ### intra-chunk - chunk_size = 16 - num_chunk = seq_len // chunk_size - q2 = rearrange(q, 'b h (n c) d -> b h n c d', n=num_chunk) - k2 = rearrange(k, 'b h (n c) d -> b h n c d', n=num_chunk) - v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) - g2 = rearrange(g, 'b h (n c) d -> b h n c d', n=num_chunk) - A = semiring_cal_A.forward(q2, k2, g2) * scale - o2 = A @ v2 - o2 = rearrange(o2, 'b h n c d -> b h (n c) d') - o.add_(o2) - ctx.save_for_backward(q, k, v, g, A, initial_state) - return o.to(v), 
final_state - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, do, d_final_state=None): - q, k, v, g, A, initial_state = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - scale = ctx.scale - - # inter-chunk - BT = 16 - BK, BV = min(d_head_qk, 64), min(d_head_v, 64) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 2 - - dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - - grid = (NV, NK, batch_size * n_heads) - fused_chunk_gla_bwd_kernel[grid]( - q, k, v, g, do, dq, dk, dv, initial_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - # clamp_min=-3, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - USE_INITIAL_STATE=initial_state is not None, - num_warps=num_warps, - num_stages=num_stages, - ) - dq = dq.sum(0) - dk = dk.sum(0) - dv = dv.sum(0) - - dg = dq * q - dg.add_(- dk * k) - - # # # #### intra chunk - num_chunk = seq_len // BT - q2 = rearrange(q, 'b h (n c) d -> b h n c d', n=num_chunk) - k2 = rearrange(k, 'b h (n c) d -> b h n c d', n=num_chunk) - v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) - g2 = rearrange(g, 'b h (n c) d -> b h n c d', n=num_chunk) - do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk) - dA2 = (do2 @ v2.transpose(-2, -1)) * scale - dv2 = A.transpose(-1, -2) @ do2 - dq2, dk2, dg2 = semiring_cal_A.backward(q2, k2, g2, dA2) - dq2 = rearrange(dq2, '... h n c d -> ... h (n c) d') - dk2 = rearrange(dk2, '... h n c d -> ... h (n c) d') - dv2 = rearrange(dv2, '... h n c d -> ... h (n c) d') - dg2 = rearrange(dg2, '... h n c d -> ... 
h (n c) d') - dq.add_(dq2.to(dq)) - dk.add_(dk2.to(dk)) - dv.add_(dv2.to(dv)) - dg = dg.float() - dg.add_(dg2) - dg_cumsum = dg.cumsum(-2) - dg = dg - dg_cumsum + dg_cumsum[:, :, -1, None] - return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None - - -def pad(x, chunk_size=16): - seq_len = x.shape[-2] - padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size - if x.shape[-2] % chunk_size != 0: - x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) - if x.shape[-1] % 32 != 0: - x = F.pad(x, (0, 32 - x.shape[-1] % 32)) - return x - - -def ceildiv(a, b): - return -(a // -b) - - -@require_version('triton>=2.2', 'Numerical stability consideration!') -def fused_chunk_gla( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - scale: int = -1, - initial_state: torch.Tensor = None, - output_final_state: bool = False -): - if scale == -1: - scale = q.shape[-1] ** -0.5 - if initial_state is not None: - initial_state = initial_state.detach() - seq_len = v.shape[-2] - d_head_v = v.shape[-1] - q, k, v, g = map(lambda x: pad(x), [q, k, v, g]) - o, final_state = FusedChunkGLAFunction.apply( - q, k, v, g, scale, initial_state, output_final_state) - o = o[..., :seq_len, :d_head_v] - if output_final_state: - return o, final_state - return o diff --git a/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py b/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py deleted file mode 100644 index 0b17e4f..0000000 --- a/flash_linear_attention/fla/ops/triton/gla/recurrent_fuse.py +++ /dev/null @@ -1,403 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2023, Songlin Yang - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - -# on-the-fly computation without materializing hidden statets into HBMs - - -@triton.jit -def fused_recurrent_gla_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, 
D_head_K] - k, # key [B, H, L, D_head_K] - v, # value [B, H, L, D_head_V] - gk, # log gate [B, H, L, D_head_K] - gv, # log gate [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - # initial hidden state initialization [B, H, D_head_K, D_head_V] - initial_state, - final_state, # final hidden state [B, H, D_head_K, D_head_V] - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - STORE_FINAL_STATE: tl.constexpr, # whether to store final state - REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction - USE_GK: tl.constexpr, # whether to use gk - USE_GV: tl.constexpr, # whether to use gv -): - # indices - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - - p_q = q + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - p_k = k + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - p_v = v + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - - if USE_GK: - p_gk = gk + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - if USE_GV: - p_gv = gv + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - - mask_bk = (i_k * BK + tl.arange(0, BK)) < DK - mask_bv = (i_v * BV + tl.arange(0, BV)) < DV - - h = tl.zeros([BV, BK], dtype=tl.float32) - - mask_kv = mask_bk[None, :] & mask_bv[:, None] - - if 
USE_INITIAL_STATE: - p_init_s = initial_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[None, :]) * \ - DV + (i_v * BV + tl.arange(0, BV)[:, None]) - h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) - - for _ in range(0, T): - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale - if USE_GK: - _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) - h = h * _gk[None, :] - if USE_GV: - _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) - h = h * _gv[:, None] - h += _k[None, :] * _v[:, None] - _o = h * _q[None, :] - _o = tl.sum(_o, axis=1) - tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) - p_q += -DK if REVERSE else DK - p_k += -DK if REVERSE else DK - p_o += -DV if REVERSE else DV - p_v += -DV if REVERSE else DV - if USE_GK: - p_gk += -DK if REVERSE else DK - if USE_GV: - p_gv += -DV if REVERSE else DV - - if STORE_FINAL_STATE: - p_final_s = final_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[None, :]) * \ - DV + (i_v * BV + tl.arange(0, BV)[:, None]) - tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) - - -# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 -@triton.jit -def fused_recurrent_gla_bwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - # NV: number of split in the V dimension. 
NK: number of split in the K dimension - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - gk, # log gate [B, H, L, D_head_K] \alpha - gv, # log gate [B, H, L, D_head_V] \bete - - do, # gradient of output [B, H, L, D_head_V] - dq, # gradient of query [NV, B, H, L, D_head_K] - dk, # gradient of key [NV, B, H, L, D_head_K] - dv, # gradient of value [NK, B, H, L, D_head_V] - - # initial hidden state initialization [B, H, D_head_K, D_head_V] - initial_state, - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch_size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction - USE_GK: tl.constexpr, # whether to use gk - USE_GV: tl.constexpr, # whether to use gv -): - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - - p_q = q + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - p_k = k + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - p_v = v + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - p_do = do + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - if USE_GK: - p_gk = gk + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) - if USE_GV: - p_gv = gv + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) - mask_bk = i_k * 
BK + tl.arange(0, BK) < DK - mask_bv = i_v * BV + tl.arange(0, BV) < DV - mask_kv = mask_bk[:, None] & mask_bv[None, :] - h = tl.zeros([BK, BV], dtype=tl.float32) - - if USE_INITIAL_STATE: - p_init_s = initial_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[:, None]) * \ - DV + (i_v * BV + tl.arange(0, BV)[None, :]) - h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) - - for i in range(0, T): - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) - if USE_GK: - _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) - h = h * _gk[:, None] - if USE_GV: - _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) - h = h * _gv[None, :] - h += _k[:, None] * _v[None, :] - _d_q = h * _do[None, :] - d_q = tl.sum(_d_q, axis=1) * scale - tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) - - p_k += -DK if REVERSE else DK - p_v += -DV if REVERSE else DV - p_q += -DK if REVERSE else DK - p_do += -DV if REVERSE else DV - p_dq += -DK if REVERSE else DK - if USE_GK: - p_gk += -DK if REVERSE else DK - if USE_GV: - p_gv += -DV if REVERSE else DV - - # sync threads - tl.debug_barrier() - - p_q = q + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) - p_k = k + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) - p_do = do + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) - p_v = v + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) - p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ - BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) - p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ - BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) - if USE_GK: - p_gk = gk + i_bh * s_qk_h + i_k * BK + \ - tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) - if USE_GV: 
- p_gv = gv + i_bh * s_vo_h + i_v * BV + \ - tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) - - d_h = tl.zeros([BK, BV], dtype=tl.float32) - - for _ in range(T): - _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) - _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - d_h += _q[:, None] * _do[None, :] - d_k = tl.sum(d_h * _v[None, :], axis=1) - d_v = tl.sum(d_h * _k[:, None], axis=0) - if USE_GK: - _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) - d_h *= _gk[:, None] - if USE_GV: - _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) - d_h *= _gv[None, :] - tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) - tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) - - p_do += DV if REVERSE else -DV - p_q += DK if REVERSE else -DK - p_k += DK if REVERSE else -DK - p_v += DV if REVERSE else -DV - p_dk += DK if REVERSE else -DK - p_dv += DV if REVERSE else -DV - if USE_GK: - p_gk += DK if REVERSE else -DK - if USE_GV: - p_gv += DV if REVERSE else -DV - - -class FusedRecurrentGLAFunction(torch.autograd.Function): - - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False): - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - # default scale - if scale is None: - scale = d_head_qk ** -0.5 - if gk is not None: - gk = gk.float().exp() - if gv is not None: - gv = gv.float().exp() - - BK, BV = min(d_head_qk, 32), min(d_head_v, 32) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 1 - - o = q.new_empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=torch.float32) - - if output_final_state: - final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) - else: - final_state = None - - grid = (NV, NK, batch_size * n_heads) - 
fused_recurrent_gla_fwd_kernel[grid]( - q, k, v, gk, gv, o, initial_state, final_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - USE_INITIAL_STATE=initial_state is not None, - STORE_FINAL_STATE=final_state is not None, - USE_GK=gk is not None, - USE_GV=gv is not None, - REVERSE=reverse, - num_warps=num_warps, - num_stages=num_stages - ) - - o = o.sum(0) - ctx.save_for_backward(q, k, v, gk, gv, initial_state, o) - ctx.scale = scale - ctx.reverse = reverse - # we do not need the gradient of the final state from the next chunk - # similiar to Trunctated BPTT - if final_state is not None: - final_state = final_state.detach() - return o.to(q.dtype), final_state - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, do, d_final_state=None): - q, k, v, gk, gv, initial_state, o = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - scale = ctx.scale - - BK, BV = min(d_head_qk, 32), min(d_head_v, 32) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 1 - - dq = q.new_empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=torch.float32) - dk = q.new_empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=torch.float32) - dv = q.new_empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=torch.float32) - grid = (NV, NK, batch_size * n_heads) - - fused_recurrent_gla_bwd_kernel[grid]( - q, k, v, gk, gv, do, dq, dk, dv, initial_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - num_warps=num_warps, - num_stages=num_stages, - USE_INITIAL_STATE=initial_state is not None, - REVERSE=ctx.reverse, - USE_GK=gk is not None, - USE_GV=gv is not None - ) - dq = dq.sum(0) - dk = dk.sum(0) - dv = dv.sum(0) - if gk is not None: - _dgk = dq * 
q.float() - dk * k.float() - if ctx.reverse: - dgk = _dgk.cumsum(-2) - else: - _dgk_cumsum = _dgk.cumsum(-2) - dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum - else: - dgk = None - - if gv is not None: - _dgv = do.float() * o.float() - dv * v.float() - if ctx.reverse: - dgv = _dgv.cumsum(-2) - else: - _dgv_cumsum = _dgv.cumsum(-2) - dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum - else: - dgv = None - - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None - - -# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0 -def fused_recurrent_gla(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - gk: torch.Tensor = None, - gv: torch.Tensor = None, - scale: int = -1, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - causal: bool = True): - if scale == -1: - scale = q.shape[-1] ** -0.5 - if initial_state is not None: - initial_state = initial_state.detach() - if causal: - o, final_state = FusedRecurrentGLAFunction.apply( - q, k, v, gk, gv, scale, initial_state, output_final_state) - if output_final_state: - return o, final_state - return o - else: - # do not support initial_state yet. 
looks very strange for bidirectional modeling - assert initial_state is None - assert output_final_state is False - o, final_state = FusedRecurrentGLAFunction.apply( - q, k, v, gk, gv, scale, initial_state, output_final_state, False) - o_reversed, final_state = FusedRecurrentGLAFunction.apply( - q, k, v, gk, gv, scale, initial_state, output_final_state, True) - return [o, o_reversed] diff --git a/flash_linear_attention/fla/ops/triton/rebased/__init__.py b/flash_linear_attention/fla/ops/triton/rebased/__init__.py deleted file mode 100644 index 8080094..0000000 --- a/flash_linear_attention/fla/ops/triton/rebased/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .parallel import parallel_rebased - - -__all__ = ["parallel_rebased"] diff --git a/flash_linear_attention/fla/ops/triton/rebased/parallel.py b/flash_linear_attention/fla/ops/triton/rebased/parallel.py deleted file mode 100644 index 777a9e5..0000000 --- a/flash_linear_attention/fla/ops/triton/rebased/parallel.py +++ /dev/null @@ -1,388 +0,0 @@ - -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous -from torch.cuda.amp import custom_bwd, custom_fwd - -# Based: An Educational and Effective Sequence Mixer -# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based - - -@triton.jit -def parallel_rebased_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - z, # normalizer [B, H, L] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q - BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V - 
BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - # i_c: chunk index. used for sequence parallelism - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) - - # [BQ, BD] block Q, in the shared memory throughout the whole kernel - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - b_o = tl.zeros([BTL, BV], dtype=tl.float32) - b_z = tl.zeros([BTL], dtype=tl.float32) - - # Q block and K block have no overlap - # no need for mask, thereby saving flops - for _ in range(0, i_c * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_s = tl.dot(b_q, (b_k), allow_tf32=False) - b_s = 0.5 + b_s + 0.5 * b_s * b_s - b_z += tl.sum(b_s, axis=1) - - # [BQ, BD] - b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - - # # rescale interchunk output - tl.debug_barrier() - o_q = tl.arange(0, BTL) - # # sync threads, easy for compiler to optimize - # tl.debug_barrier() - - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) - # Q block and K block have overlap. 
masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_s = tl.dot(b_q, b_k, allow_tf32=False) - b_s = 0.5 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_z += tl.sum(b_s, axis=1) - # [BTL, BV] - b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - o_k += BTS - - p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_z, b_z.to(p_z.dtype.element_ty), - mask=((i_c * BTL + tl.arange(0, BTL)) < T)) - - -@triton.jit -def _parallel_rebased_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_q = (b_q * scale).to(b_q.dtype) - b_dq = tl.zeros([BTL, BK], dtype=tl.float32) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) - b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) - - for _ in range(0, i_c * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, 
boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - # [BQ, BD] - b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - - b_dq *= scale - o_q = tl.arange(0, BTL) - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) - # Q block and K block have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - b_s = tl.where(m_s, b_s, 0) - # [BTL, BK] - b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), - b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - o_k += BTS - p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def _parallel_rebased_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - # compute dk dv - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), - (i_c * BTL, i_k * BK), 
(BTL, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( - p_v, boundary_check=(0, 1)) - b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( - [BTL, BV], dtype=tl.float32) - - for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ - scale # [BTL, BTS] - b_s2 = 0.5 + b_s + 0.5 * b_s * b_s - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale - if i_v == 0: - b_ds += b_dz[None, :] * scale - else: - b_ds = b_ds - b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - - tl.debug_barrier() - o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) - for i in range(i_c*BTL, (i_c+1)*BTL, BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - # [BK, BQ] - m_s = o_k[:, None] <= o_q[None, :] - b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale - b_s2 = 0.5 + b_s + 0.5 * b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_s2 = tl.where(m_s, 
b_s2, 0) - - b_ds = tl.dot(b_v, b_do, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[None, :] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - # [BK, BD] - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - o_q += BTS - - p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, - (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, - (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def parallel_rebased_bwd_kernel( - q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - i_h = i_bh % H - _parallel_rebased_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV - ) - tl.debug_barrier() - _parallel_rebased_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV - ) - - -class ParallelBasedFunction(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, scale): - BTL, BTS = 128, 32 - assert BTL % BTS == 0 - # assert q.shape[-1] % 16 == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = 
v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not." - - o = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, device=q.device) - z = torch.empty(NK, batch_size, n_heads, seq_len, - device=q.device) - parallel_rebased_fwd_kernel[grid]( - q, k, v, o, z, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - ctx.save_for_backward(q, k, v) - ctx.scale = scale - return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do, dz): - q, k, v = ctx.saved_tensors - scale = ctx.scale - BTL, BTS = 64, 32 - assert BTL % BTS == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not" - - dq = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dk = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dv = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=q.dtype, device=q.device) - - parallel_rebased_bwd_kernel[grid]( - q, k, v, do, dz, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - 
num_stages=num_stages - ) - - return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None - - -triton_parallel_based = ParallelBasedFunction.apply - - -def parallel_rebased(q, k, v, eps, use_scale=True, use_normalize=True, return_both=False): - assert q.shape[-1] <= 128, "only support feature dim up to 128" - if use_scale: - scale = q.shape[-1] ** -0.5 - else: - scale = 1 - o, z = triton_parallel_based(q, k, v, scale) - if return_both: - return o, z - if use_normalize: - o = o / (z[..., None] + eps) - else: - o = o - return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py b/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py deleted file mode 100644 index 8080094..0000000 --- a/flash_linear_attention/fla/ops/triton/rebased_fast/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .parallel import parallel_rebased - - -__all__ = ["parallel_rebased"] diff --git a/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py b/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py deleted file mode 100644 index dce2626..0000000 --- a/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py +++ /dev/null @@ -1,390 +0,0 @@ - -# -*- coding: utf-8 -*- - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous -from torch.cuda.amp import custom_bwd, custom_fwd - -# Based: An Educational and Effective Sequence Mixer -# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based - - -@triton.jit -def parallel_rebased_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - z, # normalizer [B, H, L] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # 
batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q - BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - # i_c: chunk index. used for sequence parallelism - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) - - # [BQ, BD] block Q, in the shared memory throughout the whole kernel - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - b_o = tl.zeros([BTL, BV], dtype=tl.float32) - b_z = tl.zeros([BTL], dtype=tl.float32) - - # Q block and K block have no overlap - # no need for mask, thereby saving flops - for _ in range(0, i_c * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_s = tl.dot(b_q, (b_k), allow_tf32=False) - b_s = b_s * b_s - b_z += tl.sum(b_s, axis=1) - - # [BQ, BD] - b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - - # # rescale interchunk output - tl.debug_barrier() - o_q = tl.arange(0, BTL) - # # sync threads, easy for compiler to optimize - # tl.debug_barrier() - - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - 
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) - # Q block and K block have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_s = tl.dot(b_q, b_k, allow_tf32=False) - b_s = b_s * b_s - b_s = tl.where(m_s, b_s, 0) - b_z += tl.sum(b_s, axis=1) - # [BTL, BV] - b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - o_k += BTS - - p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_z, b_z.to(p_z.dtype.element_ty), - mask=((i_c * BTL + tl.arange(0, BTL)) < T)) - - -@triton.jit -def _parallel_rebased_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_q = (b_q * scale).to(b_q.dtype) - b_dq = tl.zeros([BTL, BK], dtype=tl.float32) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) - b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) 
< T) - - for _ in range(0, i_c * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - # [BQ, BD] - b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - - b_dq *= scale - o_q = tl.arange(0, BTL) - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) - # Q block and K block have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[:, None] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) - b_s = tl.where(m_s, b_s, 0) - # [BTL, BK] - b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), - b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - o_k += BTS - p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - return - - - - -@triton.jit -def _parallel_rebased_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - # compute dk dv - p_k = 
tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), - (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( - p_v, boundary_check=(0, 1)) - b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( - [BTL, BV], dtype=tl.float32) - - for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ - scale # [BTL, BTS] - b_s2 = b_s * b_s - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale - if i_v == 0: - b_ds += b_dz[None, :] * scale - else: - b_ds = b_ds - b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - - tl.debug_barrier() - o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) - for i in range(i_c*BTL, (i_c+1)*BTL, BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - p_dz = dz + i_bh * T + i + tl.arange(0, BTS) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] - b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) - b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) - # [BK, BQ] - m_s = o_k[:, None] <= o_q[None, :] - b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale - b_s2 = b_s * b_s - b_s 
= tl.where(m_s, b_s, 0) - b_s2 = tl.where(m_s, b_s2, 0) - - b_ds = tl.dot(b_v, b_do, allow_tf32=False) - if i_v == 0: - b_ds += b_dz[None, :] - else: - b_ds = b_ds - b_ds = tl.where(m_s, b_ds, 0) * scale - # [BK, BD] - b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), - tl.trans(b_q), allow_tf32=False) - o_q += BTS - - p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, - (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, - (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def parallel_rebased_bwd_kernel( - q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - i_h = i_bh % H - _parallel_rebased_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV - ) - tl.debug_barrier() - _parallel_rebased_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV - ) - - -class ParallelBasedFunction(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, scale): - BTL, BTS = 128, 32 - assert BTL % BTS == 0 - # assert q.shape[-1] % 16 == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, 
seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not." - - o = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, device=q.device) - z = torch.empty(NK, batch_size, n_heads, seq_len, - device=q.device) - parallel_rebased_fwd_kernel[grid]( - q, k, v, o, z, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - ctx.save_for_backward(q, k, v) - ctx.scale = scale - return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do, dz): - q, k, v = ctx.saved_tensors - scale = ctx.scale - BTL, BTS = 64, 32 - assert BTL % BTS == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - BK, BV = max(BK, 16), max(BV, 16) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - - assert NK == 1, "will encounter some synchronization issue if not" - - dq = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dk = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dv = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=q.dtype, device=q.device) - - parallel_rebased_bwd_kernel[grid]( - q, k, v, do, dz, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, 
DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - - return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None - - -triton_parallel_based = ParallelBasedFunction.apply - - -def parallel_rebased(q, k, v, eps, use_scale=True, use_normalize=True, return_both=False): - assert q.shape[-1] <= 128, "only support feature dim up to 128" - if use_scale: - scale = q.shape[-1] ** -0.5 - else: - scale = 1 - o, z = triton_parallel_based(q, k, v, scale) - if return_both: - return o, z - if use_normalize: - o = o / (z[..., None] + eps) - else: - o = o - return o.to(q.dtype) diff --git a/flash_linear_attention/fla/ops/triton/retention/__init__.py b/flash_linear_attention/fla/ops/triton/retention/__init__.py deleted file mode 100644 index 1aaa71d..0000000 --- a/flash_linear_attention/fla/ops/triton/retention/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- - -from .chunk import chunk_retention -from .chunk_fuse import fused_chunk_retention -from .parallel import parallel_retention -from .recurrent_fuse import fused_recurrent_retention - -__all__ = ['fused_chunk_retention', 'parallel_retention', - 'fused_recurrent_retention', 'chunk_retention'] diff --git a/flash_linear_attention/fla/ops/triton/retention/chunk.py b/flash_linear_attention/fla/ops/triton/retention/chunk.py deleted file mode 100644 index fb6e93b..0000000 --- a/flash_linear_attention/fla/ops/triton/retention/chunk.py +++ /dev/null @@ -1,389 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023, Yu Zhang, Songlin Yang - -import torch -import triton -import triton.language as tl -from torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - - -@triton.jit -def chunk_retention_fwd_kernel_h( - k, - v, - h, - initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] - final_state, # final state of the chunk [B, H, D_head_K, D_head_V] - s_qk_h, - s_qk_t, - s_qk_d, - s_vo_h, - s_vo_t, - s_vo_d, - s_hh, - 
s_ht, - H, - T, - TD, - DK, - DV, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - USE_INITIAL_STATE: tl.constexpr, - STORE_FINAL_STATE: tl.constexpr -): - i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, DV), (s_ht, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - - o_i = tl.arange(0, BT) - d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b) - # [BK, BV] - b_h = tl.zeros([BK, BV], dtype=tl.float32) - - if USE_INITIAL_STATE: - p_h0 = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) - - for _ in range(0, T, BT): - tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BT, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BK, BV] - b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False) - - p_k = tl.advance(p_k, (0, BT)) - p_v = tl.advance(p_v, (BT, 0)) - p_h = tl.advance(p_h, (DK, 0)) - - if STORE_FINAL_STATE: - p_ht = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - - -@triton.jit -def chunk_retention_fwd_kernel_o( - q, - k, - v, - h, - o, - s_qk_h, - s_qk_t, - s_qk_d, - s_vo_h, - s_vo_t, - s_vo_d, - s_hh, - s_ht, - H, - T, - TD, - scale, - DK, - DV, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_h = i_bh % H - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - - o_i 
= tl.arange(0, BT) - d_i = tl.math.exp2((o_i + 1) * b_b) - m_s = o_i[:, None] >= o_i[None, :] - d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) - - for i_v in range(0, tl.cdiv(DV, BV)): - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, 0), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (0, i_t * BT), (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, DV), (s_ht, 1), (i_t * DK, i_v * BV), (BK, BV), (1, 0)) - - b_o = tl.zeros([BT, BV], dtype=tl.float32) - b_s = tl.zeros([BT, BT], dtype=tl.float32) - for _ in range(0, tl.cdiv(DK, BK)): - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - # [BD, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BD, BD] - b_h = tl.load(p_h, boundary_check=(0, 1)) - b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False) - b_s += tl.dot(b_q, b_k, allow_tf32=False) - - p_q = tl.advance(p_q, (0, BK)) - p_k = tl.advance(p_k, (BK, 0)) - p_h = tl.advance(p_h, (BK, 0)) - - b_s *= d_s - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) - p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - -@triton.jit -def chunk_retention_bwd_kernel_dh( - q, - do, - dh, - s_qk_h, - s_qk_t, - s_qk_d, - s_vo_h, - s_vo_t, - s_vo_d, - s_hh, - s_ht, - H, - T, - scale, - DK, - DV, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - NT: tl.constexpr -): - i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - - o_i = tl.arange(0, BT) - d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b) - # [BK, BV] - b_dh 
= tl.zeros([BK, BV], dtype=tl.float32) - for i in range(NT - 1, -1, -1): - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dh = tl.make_block_ptr(dh + i_bh * s_hh, ((i+1)*DK, DV), (s_ht, 1), (i * DK + i_k * BK, i_v * BV), (BK, BV), (1, 0)) - - tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) - # [BK, BT] - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - # [BT, DV] - b_do = tl.load(p_do, boundary_check=(0, 1)) - # [BK, BV] - b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False) - - -@triton.jit -def chunk_retention_bwd_kernel_dqkv( - q, - k, - v, - h, - do, - dh, - dq, - dk, - dv, - s_qk_h, - s_qk_t, - s_qk_d, - s_vo_h, - s_vo_t, - s_vo_d, - s_hh, - s_ht, - H, - T, - TDK, - scale, - DK, - DV, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_h = i_bh % H - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - - o_i = tl.arange(0, BT) - d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) - d_q = (d_q * scale).to(d_q.dtype) - m_s = o_i[:, None] >= o_i[None, :] - d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale - - for i_k in range(0, tl.cdiv(DK, BK)): - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * s_hh, (DV, TDK), (1, s_ht), (0, i_t * DK + i_k * BK), (BV, BK), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 
0)) - p_dh = tl.make_block_ptr(dh + i_bh * s_hh, (TDK, DV), (s_ht, 1), (i_t * DK + i_k * BK, 0), (BK, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_t * BT, 0), (BT, BV), (1, 0)) - - p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s) - - b_dq = tl.zeros([BT, BK], dtype=tl.float32) - b_dk = tl.zeros([BT, BK], dtype=tl.float32) - for _ in range(tl.cdiv(DV, BV)): - # [BT, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)) - # [BV, BK] - b_h = tl.load(p_h, boundary_check=(0, 1)) - # [BK, BV] - b_dh = tl.load(p_dh, boundary_check=(0, 1)) - - # [BT, BT] - b_ds = tl.dot(b_do, tl.trans(b_v), allow_tf32=False) - b_ds = (b_ds * d_s).to(b_k.dtype) - # [BT, BK] - b_dq += tl.dot(b_do, b_h, allow_tf32=False) * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False) - - # [BT, BT] - b_ds = tl.trans(b_ds) - # [BK, BT] - b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) * d_k[:, None] - b_dk += tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) - b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) - b_dv += tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - - p_v = tl.advance(p_v, (0, BV)) - p_h = tl.advance(p_h, (BV, 0)) - p_do = tl.advance(p_do, (0, BV)) - p_dh = tl.advance(p_dh, (0, BV)) - p_dv = tl.advance(p_dv, (0, BV)) - - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - - -class ChunkRetentionFunction(torch.autograd.Function): - - @staticmethod - 
@custom_fwd - @contiguous - def forward(ctx, q, k, v, initial_state, output_final_state): - BT = 64 - DK, DV = k.shape[-1], v.shape[-1] - BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 1 - num_warps = 4 if BK == 64 else 2 - scale = DK ** -0.5 - - NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) - h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * DK, DV) - - final_state = None - if output_final_state: - final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) - - grid = (NK, NV, batch_size * n_heads) - chunk_retention_fwd_kernel_h[grid]( - k, v, h, initial_state, final_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - h.stride(1), h.stride(2), - n_heads, seq_len, h.shape[2], - DK=DK, DV=DV, BT=BT, BK=BK, BV=BV, - USE_INITIAL_STATE=initial_state is not None, - STORE_FINAL_STATE=output_final_state, - num_warps=num_warps, - num_stages=num_stages - ) - grid = (triton.cdiv(seq_len, BT), batch_size * n_heads) - o = torch.empty_like(v) - chunk_retention_fwd_kernel_o[grid]( - q, k, v, h, o, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - h.stride(1), h.stride(2), - n_heads, seq_len, h.shape[2], scale, - BK=BK, BV=BV, DK=DK, DV=DV, BT=BT, - num_warps=num_warps, - num_stages=num_stages - ) - - ctx.save_for_backward(q, k, v, h) - return o.to(q.dtype), final_state - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do, d_ht=None): - q, k, v, h = ctx.saved_tensors - - BT = 64 - DK, DV = k.shape[-1], v.shape[-1] - BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) - batch_size, n_heads, seq_len, _ = q.shape - num_stages = 1 - num_warps = 4 if BK == 64 else 2 - scale = DK ** -0.5 - - NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) - grid = (NK, NV, batch_size * 
n_heads) - dh = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * DK, DV) - - chunk_retention_bwd_kernel_dh[grid]( - q, do, dh, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - dh.stride(1), dh.stride(2), - n_heads, seq_len, scale, - BT=BT, BK=BK, BV=BV, DK=DK, DV=DV, NT=triton.cdiv(seq_len, BT), - num_warps=num_warps, - num_stages=num_stages - ) - - BK, BV = min(64, triton.next_power_of_2(DK)), min(64, triton.next_power_of_2(DV)) - NK, NV = triton.cdiv(DK, BK), triton.cdiv(DV, BV) - grid = (triton.cdiv(seq_len, BT), batch_size * n_heads) - dq = torch.empty_like(q) - dk = torch.empty_like(k) - # must be zero. we need reload - dv = torch.zeros_like(v) - num_stages = 1 - num_warps = 4 if BK == 64 else 2 - chunk_retention_bwd_kernel_dqkv[grid]( - q, k, v, h, do, dh, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - dh.stride(1), dh.stride(2), - n_heads, seq_len, h.shape[2], scale, - BT=BT, BK=BK, BV=BV, DK=DK, DV=DV, - num_warps=num_warps, - num_stages=num_stages - ) - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None - - -def chunk_retention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - initial_state: torch.Tensor = None, - output_final_state: bool = False -): - if initial_state is not None: - initial_state = initial_state.detach() - o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) - if output_final_state: - return o, final_state - else: - return o diff --git a/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py b/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py deleted file mode 100644 index 2a8811e..0000000 --- a/flash_linear_attention/fla/ops/triton/retention/chunk_fuse.py +++ /dev/null @@ -1,329 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023, Yu Zhang, Songlin Yang - -import torch -import triton -import triton.language as tl -from packaging import version -from 
torch.cuda.amp import custom_bwd, custom_fwd - -from fla.ops.triton.utils import contiguous - -# on-the-fly computation without materializing hidden statets into HBMs - - -@triton.jit -def fused_chunk_retention_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] - final_state, # final state of the chunk [B, H, D_head_K, D_head_V] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, - STORE_FINAL_STATE: tl.constexpr, - CHECK: tl.constexpr -): - # indices - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - - o_i = tl.arange(0, BT) - # decay rate given the head index - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - - # d_b: overall decay for the entire chunk - # d_o: cumulative decay from the start of the chunk - # d_h: cumulative decay from the end of the chunk - d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) - - # [BT, BT] - m_s = o_i[:, None] >= o_i[None, :] - d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) - # [BK, BV] - b_h = tl.zeros([BK, BV], dtype=tl.float32) - - # make block pointers - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + 
i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) - - if USE_INITIAL_STATE: - p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) - - for i in range(0, tl.cdiv(T, BT)): - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BT, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BT, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_k.dtype) - - # [BT, BT] - b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s - # [BT, BV] - b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - if CHECK and i == 0: - b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] - b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) - else: - b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] - b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - p_q = tl.advance(p_q, (BT, 0)) - p_k = tl.advance(p_k, (0, BT)) - p_v = tl.advance(p_v, (BT, 0)) - p_o = tl.advance(p_o, (BT, 0)) - - if STORE_FINAL_STATE: - p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) - - -# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 -@triton.jit -def fused_chunk_retention_bwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - # NV: number of split in the V dimension. 
NK: number of split in the K dimension - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - do, # gradient of output [B, H, L, D_head_V] - dq, # gradient of query [NV, B, H, L, D_head_K] - dk, # gradient of key [NV, B, H, L, D_head_K] - dv, # gradient of value [NK, B, H, L, D_head_V] - - initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch_size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, - CHECK: tl.constexpr -): - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - - o_i = tl.arange(0, BT) - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b) - d_b = tl.math.exp2(BT * b_b) - - m_s = o_i[:, None] >= o_i[None, :] - d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale - # [BV, BK] - b_h = tl.zeros([BV, BK], dtype=tl.float32) - if USE_INITIAL_STATE: - p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) - - for i in range(0, tl.cdiv(T, BT)): - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), 
(i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) - - # [BT, DK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [DV, BT] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BT, DV] - b_do = tl.load(p_do, boundary_check=(0, 1)) - b_dd = (b_do * d_q[:, None]).to(b_do.dtype) - - # [BT, BT] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) - b_ds = (b_ds * d_s).to(b_k.dtype) - # [BT, DK] - b_dq = tl.dot(b_ds, b_k, allow_tf32=False) - # [DV, DK] - if CHECK and i == 0: - b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) - b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) - else: - b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) - b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) - - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - - # sync threads - b_h = None - tl.debug_barrier() - d_s = tl.trans(d_s) - # [BK, BV] - b_dh = tl.zeros([BK, BV], dtype=tl.float32) - for i in range(1, tl.cdiv(T, BT) + 1): - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) - p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) - # [DK, BT] - b_q = tl.load(p_q, boundary_check=(0, 1)) - # [BT, DK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BT, DV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_do = 
tl.load(p_do, boundary_check=(0, 1)) - b_dd = (b_do * d_q[:, None]).to(b_do.dtype) - - # [BT, BT] - b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) - b_ds = (b_ds * d_s).to(b_k.dtype) - - # [BT, BT] - b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s - # [BT, DK] - b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) - # [BT, DV] - b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) - if CHECK and i == 1: - b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] - b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] - b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) - else: - b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] - b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] - b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) - - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - - -class FusedChunkRetentionFunction(torch.autograd.Function): - - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v, initial_state, output_final_state): - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - - scale = d_head_qk ** -0.5 - BT = 64 - BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 4 - - o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - - if output_final_state: - final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) - else: - final_state = None - CHECK = False - if version.parse(triton.__version__) < version.parse('2.2.0'): - import warnings - warnings.warn( - "Triton<2.2.0 detected for running this kernel, " - "which is known to have some weird compiler issues (refer to 
https://github.com/openai/triton/issues/2852) " - "that lead to significant precision loss. " - "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " - "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." - ) - CHECK = True - - grid = (NV, NK, batch_size * n_heads) - fused_chunk_retention_fwd_kernel[grid]( - q, k, v, o, initial_state, final_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - USE_INITIAL_STATE=initial_state is not None, - STORE_FINAL_STATE=output_final_state, - CHECK=CHECK, - num_warps=num_warps, - num_stages=num_stages - ) - - o = o.sum(0) - ctx.save_for_backward(q, k, v, initial_state) - ctx.CHECK = CHECK - return o.to(q.dtype), final_state - - @staticmethod - @custom_bwd - @contiguous - def backward(ctx, do, d_final_state=None): - q, k, v, initial_state = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - scale = d_head_qk ** -0.5 - - BT = 64 - BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 4 - - dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - grid = (NV, NK, batch_size * n_heads) - - fused_chunk_retention_bwd_kernel[grid]( - q, k, v, do, dq, dk, dv, initial_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - USE_INITIAL_STATE=initial_state is not None, - CHECK=ctx.CHECK, - num_warps=num_warps, - num_stages=num_stages - ) - dq = dq.sum(0) - dk = dk.sum(0) - dv = dv.sum(0) - return 
dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None - - -def fused_chunk_retention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - initial_state: torch.Tensor = None, - output_final_state: bool = False -): - if initial_state is not None: - initial_state = initial_state.detach() - o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) - if output_final_state: - return o, final_state - else: - return o diff --git a/flash_linear_attention/fla/ops/triton/retention/parallel.py b/flash_linear_attention/fla/ops/triton/retention/parallel.py deleted file mode 100644 index e114039..0000000 --- a/flash_linear_attention/fla/ops/triton/retention/parallel.py +++ /dev/null @@ -1,341 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023, Yu Zhang, Songlin Yang - -import torch -import triton -import triton.language as tl - -from fla.ops.triton.utils import contiguous -from torch.cuda.amp import custom_bwd, custom_fwd - - -@triton.jit -def parallel_retention_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q - BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V -): - # i_c: chunk index. 
used for sequence parallelism - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - i_h = i_bh % H - # decay rate given the head index - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - # overall decay rate for an entire block - d_b = tl.math.exp2(b_b * BTS) - # cumulative decay from the end of the chunk - o_k = tl.arange(0, BTS) - d_h = tl.math.exp2((BTS - o_k) * b_b) - - p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) - - # [BQ, BD] block Q, in the shared memory throughout the whole kernel - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_q = (b_q * scale).to(b_q.dtype) - b_o = tl.zeros([BTL, BV], dtype=tl.float32) - - # Q block and K block have no overlap - # no need for mask, thereby saving flops - for _ in range(0, i_c * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :] - # [BQ, BD] - b_o = b_o * tl.math.exp2(b_b * BTS) - b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - - # # rescale interchunk output - tl.debug_barrier() - o_q = tl.arange(0, BTL) - d_q = tl.math.exp2(tl.arange(0, BTL) * b_b) - b_o *= d_q[:, None] - # # sync threads, easy for compiler to optimize - # tl.debug_barrier() - - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), - (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) - # Q block and K block 
have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BK, BTS] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BTS, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - d_s = tl.where(m_s, tl.math.exp2( - (o_q[:, None] - o_k[None, :]) * b_b), 0) - b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s - # [BTL, BV] - b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) - - p_k = tl.advance(p_k, (0, BTS)) - p_v = tl.advance(p_v, (BTS, 0)) - o_k += BTS - - p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), - (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - -@triton.jit -def _parallel_retention_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - b_do = tl.load(p_do, boundary_check=(0, 1)) - b_dq = tl.zeros([BTL, BK], dtype=tl.float32) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) - # decay rate given the head index - b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - # overall decay rate for an entire block - d_b = tl.math.exp2(b_b * BTS) - # cumulative decay from the end of the chunk - d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b) - for _ in range(0, i_c * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :] - # [BQ, BD] - b_dq *= d_b - b_dq += tl.dot(b_ds.to(b_v.dtype), 
b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale - o_q = tl.arange(0, BTL) - o_k = tl.arange(0, BTS) - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), - (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) - # Q block and K block have overlap. masks required - for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): - # [BTS, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BV, BTS] - b_v = tl.load(p_v, boundary_check=(0, 1)) - # [BTL, BTS] - m_s = o_q[:, None] >= o_k[None, :] - d_s = tl.where(m_s, tl.math.exp2( - (o_q[:, None] - o_k[None, :]) * b_b), 0) - b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale - # [BTL, BK] - b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) - p_k = tl.advance(p_k, (BTS, 0)) - p_v = tl.advance(p_v, (0, BTS)) - o_k += BTS - p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), - (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def _parallel_retention_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - # no overlap. no need for mask. 
- b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) - # overall decay rate for an entire block - d_b = tl.math.exp2(b_b * BTS) - # compute dk dv - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), - (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), - (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) - b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( - p_v, boundary_check=(0, 1)) - b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( - [BTL, BV], dtype=tl.float32) - d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b) - b_kd = (b_k * d_h[:, None]).to(b_k.dtype) - d_q = tl.math.exp2(tl.arange(0, BTS) * b_b) - for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] - b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS] - b_do = (b_do * d_q[None, :]).to(b_do.dtype) - - b_dv *= d_b - b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS] - b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - - b_dk *= d_b - b_ds = tl.dot(b_v, b_do, allow_tf32=False) - b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) - b_dk *= d_h[:, None] * scale - b_dv *= scale - tl.debug_barrier() - o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) - for i in range(i_c*BTL, (i_c+1)*BTL, BTS): - p_q = tl.make_block_ptr( - q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) - p_do = tl.make_block_ptr( - do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] - b_do = tl.load(p_do, boundary_check=(0, 1)) - # [BK, BQ] - m_s = o_k[:, None] <= o_q[None, :] - d_s = 
tl.where(m_s, tl.math.exp2( - (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale - b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s - b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s - # [BK, BD] - b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) - b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) - o_q += BTS - p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, - (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) - p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, - (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) - tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) - tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - return - - -@triton.jit -def parallel_retention_bwd_kernel( - q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, - BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - DK: tl.constexpr, DV: tl.constexpr, -): - i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - NV = tl.cdiv(DV, BV) - i_k = i_kv // (NV) - i_v = i_kv % (NV) - i_h = i_bh % H - _parallel_retention_bwd_dq( - i_bh, i_c, i_k, i_v, i_h, - k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV - ) - tl.debug_barrier() - _parallel_retention_bwd_dkv( - i_bh, i_c, i_k, i_v, i_h, - q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, - s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV - ) - - -class ParallelRetentionFunction(torch.autograd.Function): - @staticmethod - @contiguous - @custom_fwd - def forward(ctx, q, k, v): - BTL, BTS = 128, 32 - assert BTL % BTS == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 3 if d_head_qk <= 64 else 2 - num_warps = 
4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - scale = d_head_qk ** -0.5 - o = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=q.dtype, device=q.device) - parallel_retention_fwd_kernel[grid]( - q, k, v, o, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - ctx.save_for_backward(q, k, v) - return o.sum(0).to(q.dtype) - - @staticmethod - @contiguous - @custom_bwd - def backward(ctx, do): - q, k, v = ctx.saved_tensors - BTL, BTS = 64, 32 - assert BTL % BTS == 0 - BK = min(128, triton.next_power_of_2(k.shape[-1])) - BV = min(128, triton.next_power_of_2(v.shape[-1])) - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - num_stages = 3 if d_head_qk <= 64 else 2 - num_warps = 4 - NK = triton.cdiv(d_head_qk, BK) - NV = triton.cdiv(d_head_v, BV) - grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) - scale = d_head_qk ** -0.5 - - dq = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dk = torch.empty(NV, batch_size, n_heads, seq_len, - d_head_qk, dtype=q.dtype, device=q.device) - dv = torch.empty(NK, batch_size, n_heads, seq_len, - d_head_v, dtype=q.dtype, device=q.device) - - parallel_retention_bwd_kernel[grid]( - q, k, v, do, dq, dk, dv, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, - num_warps=num_warps, - num_stages=num_stages - ) - - return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype) - - -parallel_retention = ParallelRetentionFunction.apply diff --git a/flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py 
b/flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py deleted file mode 100644 index 3abb107..0000000 --- a/flash_linear_attention/fla/ops/triton/retention/recurrent_fuse.py +++ /dev/null @@ -1,280 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023, Yu Zhang, Songlin Yang - -import torch -import triton -import triton.language as tl -from fla.ops.triton.utils import contiguous - -# on-the-fly computation without materializing hidden statets into HBMs - - -@triton.jit -def fused_recurrent_retention_fwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - o, # output [B, H, L, D_head_V] - initial_state, - final_state, # final hidden state [B, H, D_head_K, D_head_V] - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - STORE_FINAL_STATE: tl.constexpr, # whether to store final state -): - # indices - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - - # decay rate given the head index - b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0)) - - p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) - p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) - p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) - p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) - - mask_bk = (i_k * BK + tl.arange(0, BK)) < DK - mask_bv = (i_v * BV + tl.arange(0, BV)) < DV - mask_kv = mask_bk[None, :] & mask_bv[:, None] - - h = tl.zeros([BV, BK], dtype=tl.float32) 
- - if USE_INITIAL_STATE: - p_init_s = initial_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[None, :]) * \ - DV + (i_v * BV + tl.arange(0, BV)[:, None]) - h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) - - for _ in range(0, T): - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale - - h = b_b * h + _k[None, :] * _v[:, None] - _o = h * _q[None, :] - _o = tl.sum(_o, axis=1) - tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) - - p_q += DK - p_k += DK - p_o += DV - p_v += DV - - if STORE_FINAL_STATE: - p_final_s = final_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[None, :]) * \ - DV + (i_v * BV + tl.arange(0, BV)[:, None]) - tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) - - -# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 -@triton.jit -def fused_recurrent_retention_bwd_kernel( - # B: batch_size, H: n_heads, T: seq_len, D: d_head - # NV: number of split in the V dimension. 
NK: number of split in the K dimension - q, # query [B, H, L, D_head_K] - k, # key [B, H, L, D_head_V] - v, # value [B, H, L, D_head_V] - - do, # gradient of output [B, H, L, D_head_V] - dq, # gradient of query [NV, B, H, L, D_head_K] - dk, # gradient of key [NV, B, H, L, D_head_K] - dv, # gradient of value [NK, B, H, L, D_head_V] - - # initial hidden state initialization [B, H, D_head_K, D_head_V] - initial_state, - - s_qk_h, # stride size: L * D_head_K - s_qk_t, # stride size: D_head_K - s_qk_d, # stride size: 1 - - s_vo_h, # stride size: L * D_head_V - s_vo_t, # stride size: D_head_V - s_vo_d, # stride size: 1 - - B, # batch_size - H, # n_heads - T, # seq_len - scale, # D_head_K ** -0.5 - BK: tl.constexpr, # BLOCK SIZE along the K dimension - BV: tl.constexpr, # BLOCK SIZE along the V dimension - DK: tl.constexpr, # D_head_K - DV: tl.constexpr, # D_head_V - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state -): - i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_h = i_bh % H - - b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0) - - p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) - p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) - p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) - p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) - - p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) - mask_bk = i_k * BK + tl.arange(0, BK) < DK - mask_bv = i_v * BV + tl.arange(0, BV) < DV - - h = tl.zeros([BK, BV], dtype=tl.float32) - - if USE_INITIAL_STATE: - mask_kv = mask_bk[:, None] & mask_bv[None, :] - p_init_s = initial_state + i_bh * DK * DV + \ - (i_k * BK + tl.arange(0, BK)[:, None]) * \ - DV + (i_v * BV + tl.arange(0, BV)[None, :]) - h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) - - for i in range(0, T): - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) - - h = 
b_b * h + _k[:, None] * _v[None, :] - _d_q = h * _do[None, :] - d_q = tl.sum(_d_q, axis=1) * scale - tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) - - p_k += DK - p_do += DV - p_v += DV - p_dq += DK - - # sync threads - tl.debug_barrier() - - p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK - p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK - p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV - p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV - p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ - BK + tl.arange(0, BK) + (T - 1) * DK - p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ - BV + tl.arange(0, BV) + (T - 1) * DV - d_h = tl.zeros([BK, BV], dtype=tl.float32) - - for _ in range(T): - _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) - _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale - _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) - _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) - d_h += _q[:, None] * _do[None, :] - d_k = tl.sum(d_h * _v[None, :], axis=1) - d_v = tl.sum(d_h * _k[:, None], axis=0) - - d_h *= b_b - tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) - tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) - - p_do -= DV - p_q -= DK - p_k -= DK - p_v -= DV - p_dk -= DK - p_dv -= DV - - -class FusedRecurrentRetentionFunction(torch.autograd.Function): - - @staticmethod - @contiguous - def forward(ctx, q, k, v, initial_state=None, output_final_state=False): - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - - scale = d_head_qk ** -0.5 - BK, BV = min(d_head_qk, 32), min(d_head_v, 32) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 1 - - o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - - if output_final_state: - final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) - else: - final_state = None - - 
grid = (NV, NK, batch_size * n_heads) - fused_recurrent_retention_fwd_kernel[grid]( - q, k, v, o, initial_state, final_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - num_warps=num_warps, - num_stages=num_stages, - USE_INITIAL_STATE=initial_state is not None, - STORE_FINAL_STATE=final_state is not None - ) - - o = o.sum(0) - ctx.save_for_backward(q, k, v, initial_state) - return o, final_state - - @staticmethod - @contiguous - def backward(ctx, do, d_final_state=None): - q, k, v, initial_state = ctx.saved_tensors - batch_size, n_heads, seq_len, d_head_qk = q.shape - d_head_v = v.shape[-1] - scale = d_head_qk ** -0.5 - - BK, BV = min(d_head_qk, 32), min(d_head_v, 32) - NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) - num_stages = 1 - num_warps = 1 - - dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) - dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) - grid = (NV, NK, batch_size * n_heads) - - fused_recurrent_retention_bwd_kernel[grid]( - q, k, v, do, dq, dk, dv, initial_state, - q.stride(1), q.stride(2), q.stride(3), - v.stride(1), v.stride(2), v.stride(3), - batch_size, n_heads, seq_len, scale, - DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, - num_warps=num_warps, - num_stages=num_stages, - USE_INITIAL_STATE=initial_state is not None - ) - dq = dq.sum(0) - dk = dk.sum(0) - dv = dv.sum(0) - return dq, dk, dv, None, None - - -# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply - -def fused_recurrent_retention(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - initial_state: torch.Tensor = None, - output_final_state: bool = False): - if initial_state is not None: - initial_state = initial_state.detach() - o, final_state = FusedRecurrentRetentionFunction.apply( - q, k, v, initial_state, output_final_state) - if 
output_final_state: - return o, final_state - else: - return o diff --git a/flash_linear_attention/fla/ops/triton/rotary.py b/flash_linear_attention/fla/ops/triton/rotary.py deleted file mode 100644 index 18ccc5f..0000000 --- a/flash_linear_attention/fla/ops/triton/rotary.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py - -from typing import Optional, Union - -import torch - -import triton -import triton.language as tl - - -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 2}), -# triton.Config({"BLOCK_M": 4}), -# triton.Config({"BLOCK_M": 8}), -# triton.Config({"BLOCK_M": 16}), -# ], -# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], -# ) -@triton.jit -def rotary_kernel( - OUT, # Pointers to matrices - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer - # Matrix dimensions - seqlen, - nheads, - rotary_dim, - seqlen_ro, - CACHE_KEY_SEQLEN, - # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - # Meta-parameters - BLOCK_K: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - rotary_dim_half = rotary_dim // 2 - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads - OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads - - if pid_m * BLOCK_M >= seqlen: - return - rm = pid_m * BLOCK_M + 
tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - rk = tl.arange(0, BLOCK_K) - rk_half = tl.arange(0, BLOCK_K // 2) - - if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rm[:, None] * stride_x_seqlen + - rk_half[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - cos = tl.load( - COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 - ).to(tl.float32) - sin = tl.load( - SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x0 = tl.load( - X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x1 = tl.load( - X + rotary_dim_half * stride_x_headdim, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - if CONJUGATE: - sin = -sin - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - # write back result - OUT = OUT + (rm[:, None] * stride_out_seqlen + - rk_half[None, :] * stride_out_headdim) - tl.store(OUT, o0, mask=(rm[:, None] < seqlen) - & (rk_half[None, :] < rotary_dim_half)) - tl.store( - OUT + rotary_dim_half * stride_out_headdim, - o1, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - ) - else: - # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. - # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. - # Loading x0 will be fast but x1 will be slow. - # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. - # Then we do the calculation and use tl.where to pick put the right outputs for the even - # and for the odd indices. - rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... 
- rk_repeat = tl.arange(0, BLOCK_K) // 2 - X0 = X + (rm[:, None] * stride_x_seqlen + - rk[None, :] * stride_x_headdim) - X1 = X + (rm[:, None] * stride_x_seqlen + - rk_swap[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - cos = tl.load( - COS, - mask=(rm_cs[:, None] < seqlen_ro) & ( - rk_repeat[None, :] < rotary_dim_half), - other=1.0, - ).to(tl.float32) - sin = tl.load( - SIN, - mask=(rm_cs[:, None] < seqlen_ro) & ( - rk_repeat[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( - tl.float32 - ) - x1 = tl.load( - X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 - ).to(tl.float32) - if CONJUGATE: - sin = -sin - x0_cos = x0 * cos - x1_sin = x1 * sin - out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) - OUT = OUT + (rm[:, None] * stride_out_seqlen + - rk[None, :] * stride_out_headdim) - tl.store(OUT, out, mask=(rm[:, None] < seqlen) - & (rk[None, :] < rotary_dim)) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). 
- cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in [torch.int32, torch.int64] - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - BLOCK_K = ( - 32 - if rotary_dim <= 32 - else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) - ) - def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa - BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) - - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
- with torch.cuda.device(x.device.index): - rotary_kernel[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - rotary_dim, - seqlen_ro, - # key for triton cache (limit number of compilations) - seqlen // 128, - # batch_strides if not varlen else 0 - output.stride(0) if not is_varlen else 0, - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - # batch_strides if not varlen else 0 - x.stride(0) if not is_varlen else 0, - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - BLOCK_K, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M, - ) - return output diff --git a/flash_linear_attention/fla/ops/triton/utils.py b/flash_linear_attention/fla/ops/triton/utils.py deleted file mode 100644 index 93af956..0000000 --- a/flash_linear_attention/fla/ops/triton/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- - -import functools - -import torch - - -def contiguous(fn): - @functools.wraps(fn) - def wrapper(ctx, *args, **kwargs): - return fn(ctx, - *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), - **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) - return wrapper - - -def require_version(version, hint): - def decorator(fn): - @functools.wraps(fn) - def wrapper(ctx, *args, **kwargs): - from transformers.utils.versions import require_version - require_version(version, hint) - return fn(ctx, - *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), - **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) - return wrapper - return decorator diff --git a/flash_linear_attention/setup.py b/flash_linear_attention/setup.py deleted file mode 100644 index 1d3d6b6..0000000 --- 
a/flash_linear_attention/setup.py +++ /dev/null @@ -1,147 +0,0 @@ -# -*- coding: utf-8 -*- - -import ast -import os -import re -import subprocess -import warnings -from pathlib import Path - -import torch -from packaging.version import Version, parse -from setuptools import find_packages, setup -from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension - - -long_description = "" - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - -PACKAGE_NAME = 'fla' - -# FORCE_BUILD: force a fresh build locally, instead of attempting to find prebuilt wheels -FORCE_BUILD = os.getenv('FLA_FORCE_BUILD', "FALSE") == 'TRUE' -# SKIP_CUDA_BUILD: allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation -SKIP_CUDA_BUILD = os.getenv('FLA_SKIP_CUDA_BUILD', "FALSE") == 'TRUE' -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv('FLA_FORCE_CXX11_ABI', "FALSE") == 'TRUE' - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary - # in that case. - warnings.warn( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." 
- ) - - -def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] - - -ext_modules = [] -if not SKIP_CUDA_BUILD: - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - - # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h - # See https://github.com/pytorch/pytorch/pull/70650 - generator_flag = [] - torch_dir = torch.__path__[0] - if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - check_if_cuda_home_none('fla') - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - if CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.6"): - raise RuntimeError( - "FLA is only supported on CUDA 11.6 and above. " - "Note: make sure nvcc has a supported version by running nvcc -V." 
- ) - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if CUDA_HOME is not None: - if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - - ext_modules = [ - CUDAExtension( - name='semiring_cal_A', - sources=[ - 'fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x.cpp', - 'fla/ops/cuda/gla/semiring/cal_A/inner_chunk16_dim16x_kernel.cu', - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - ), - ] - - -def get_package_version(): - with open(Path(this_dir) / 'fla' / '__init__.py') as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) - return ast.literal_eval(version_match.group(1)) - - -setup( - name=PACKAGE_NAME, - version=get_package_version(), - description='Fast Triton-based implementations of causal linear attention', - long_description=long_description, - long_description_content_type='text/markdown', - author='Songlin Yang, Yu Zhang', - author_email='bestsonta@gmail.com', - url='https://github.com/sustcsonglin/flash-linear-attention', - packages=find_packages(), - license='MIT', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Topic :: Scientific/Engineering :: Artificial Intelligence' - ], - python_requires='>=3.7', - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension}, - install_requires=[ - 'triton', - 'transformers', - 'einops', - 'ninja' - ] -) From f210c03bd24577c9740111ab3998f54303dff317 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Wed, 6 Mar 2024 17:58:22 +0800 Subject: 
[PATCH 4/4] support for ring attention --- models/latte.py | 14 +++++++++++++- models/latte_img.py | 14 +++++++++++++- models/latte_t2v.py | 6 ++++++ models/utils.py | 23 ++++++++++++++++++++--- 4 files changed, 52 insertions(+), 5 deletions(-) diff --git a/models/latte.py b/models/latte.py index af5c793..223ec02 100644 --- a/models/latte.py +++ b/models/latte.py @@ -27,10 +27,17 @@ XFORMERS_IS_AVAILBLE = False try: + # needs to have https://github.com/corl-team/rebased/ installed from fla.ops.triton.rebased_fast import parallel_rebased except: REBASED_IS_AVAILABLE = False +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda +except: + RING_ATTENTION_IS_AVAILABLE = False + # from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -42,7 +49,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -54,6 +61,8 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.eps = eps + self.causal = causal + self.ring_bucket_size = ring_bucket_size def forward(self, x): B, N, C = x.shape @@ -78,6 +87,9 @@ def forward(self, x): elif self.attention_mode == 'rebased': x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) + elif self.attention_mode == 'ring': + x = ring_flash_attn_cuda(q, k, v, 
causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C) + else: raise NotImplemented diff --git a/models/latte_img.py b/models/latte_img.py index a8fc486..b9b65d6 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -27,10 +27,16 @@ XFORMERS_IS_AVAILBLE = False try: + # needs to have https://github.com/corl-team/rebased/ installed from fla.ops.triton.rebased_fast import parallel_rebased except: REBASED_IS_AVAILABLE = False +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda +except: + RING_ATTENTION_IS_AVAILABLE = False # from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -43,7 +49,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -58,6 +64,8 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.eps = eps + self.causal = causal + self.ring_bucket_size = ring_bucket_size def forward(self, x): B, N, C = x.shape @@ -81,6 +89,10 @@ def forward(self, x): elif self.attention_mode == 'rebased': x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) + + elif self.attention_mode == 'ring': + x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C) + else: raise NotImplemented diff --git 
a/models/latte_t2v.py b/models/latte_t2v.py index e085116..6debde1 100644 --- a/models/latte_t2v.py +++ b/models/latte_t2v.py @@ -46,6 +46,9 @@ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int, if attn_type == 'rebased': from models.utils import RebasedAttnProcessor attn_proc = RebasedAttnProcessor() + elif attn_type == 'ring': + from models.utils import RingAttnProcessor + attn_proc = RingAttnProcessor() else: attn_proc = None @@ -222,6 +225,9 @@ def __init__( if attn_type == 'rebased': from models.utils import RebasedAttnProcessor attn_proc = RebasedAttnProcessor() + elif attn_type == 'ring': + from models.utils import RingAttnProcessor + attn_proc = RingAttnProcessor() else: attn_proc = None diff --git a/models/utils.py b/models/utils.py index 42c5b4f..11d6fde 100644 --- a/models/utils.py +++ b/models/utils.py @@ -215,12 +215,19 @@ def count_params(model, verbose=False): return total_params try: + # needs to have https://github.com/corl-team/rebased/ installed from fla.ops.triton.rebased_fast import parallel_rebased except: REBASED_IS_AVAILABLE = False +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda +except: + RING_ATTENTION_IS_AVAILABLE = False + from diffusers.models.attention_processor import Attention -class RebasedAttnProcessor: +class AltAttnProcessor: def __call__( self, @@ -276,7 +283,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = parallel_rebased(query, key, value, eps, True, True) + hidden_states = self.attn_fn(query, key, value, eps, True, True) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -294,4 +301,14 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor - return 
hidden_states \ No newline at end of file + return hidden_states + +class RebasedAttnProcessor(AltAttnProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = parallel_rebased + +class RingAttnProcessor(AltAttnProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = ring_flash_attn_cuda