diff --git a/models/latte.py b/models/latte.py index 723686a..223ec02 100644 --- a/models/latte.py +++ b/models/latte.py @@ -26,6 +26,20 @@ except: XFORMERS_IS_AVAILBLE = False +try: + # needs to have https://github.com/corl-team/rebased/ installed + from fla.ops.triton.rebased_fast import parallel_rebased + REBASED_IS_AVAILABLE = True +except ImportError: + REBASED_IS_AVAILABLE = False + +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda + RING_ATTENTION_IS_AVAILABLE = True +except ImportError: + RING_ATTENTION_IS_AVAILABLE = False + # from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -37,7 +51,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -48,6 +62,9 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + self.eps = eps + self.causal = causal + self.ring_bucket_size = ring_bucket_size def forward(self, x): B, N, C = x.shape @@ -69,6 +86,12 @@ def forward(self, x): attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) + elif self.attention_mode == 'rebased': + x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) + + elif self.attention_mode == 'ring': + x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C) + else: raise NotImplemented diff --git 
a/models/latte_img.py b/models/latte_img.py index c468c63..b9b65d6 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -26,6 +26,20 @@ except: XFORMERS_IS_AVAILBLE = False +try: + # needs to have https://github.com/corl-team/rebased/ installed + from fla.ops.triton.rebased_fast import parallel_rebased + REBASED_IS_AVAILABLE = True +except ImportError: + REBASED_IS_AVAILABLE = False + +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda + RING_ATTENTION_IS_AVAILABLE = True +except ImportError: + RING_ATTENTION_IS_AVAILABLE = False + # from timm.models.layers.helpers import to_2tuple # from timm.models.layers.trace_utils import _assert @@ -37,7 +51,7 @@ def modulate(x, shift, scale): ################################################################################# class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -51,6 +65,9 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + self.eps = eps + self.causal = causal + self.ring_bucket_size = ring_bucket_size def forward(self, x): B, N, C = x.shape @@ -72,6 +89,12 @@ def forward(self, x): attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) + elif self.attention_mode == 'rebased': + x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C) + + elif self.attention_mode == 'ring': + x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C) + else: raise NotImplemented diff --git 
a/models/latte_t2v.py b/models/latte_t2v.py index fc96a30..6debde1 100644 --- a/models/latte_t2v.py +++ b/models/latte_t2v.py @@ -37,13 +37,22 @@ class GatedSelfAttentionDense(nn.Module): d_head (`int`): The number of channels in each head. """ - def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int, attn_type: str = 'vanilla'): super().__init__() # we need a linear projection since we need cat visual feature and obj feature self.linear = nn.Linear(context_dim, query_dim) - self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + if attn_type == 'rebased': + from models.utils import RebasedAttnProcessor + attn_proc = RebasedAttnProcessor() + elif attn_type == 'ring': + from models.utils import RingAttnProcessor + attn_proc = RingAttnProcessor() + else: + attn_proc = None + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head, processor=attn_proc) self.ff = FeedForward(query_dim, activation_fn="geglu") self.norm1 = nn.LayerNorm(query_dim) @@ -178,6 +187,7 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + attn_type: str = "vanilla" ): super().__init__() self.only_cross_attention = only_cross_attention @@ -212,6 +222,15 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + if attn_type == 'rebased': + from models.utils import RebasedAttnProcessor + attn_proc = RebasedAttnProcessor() + elif attn_type == 'ring': + from models.utils import RingAttnProcessor + attn_proc = RingAttnProcessor() + else: + attn_proc = None + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -220,6 +239,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + processor=attn_proc ) 
# # 2. Cross-Attn @@ -254,7 +274,7 @@ def __init__( # 4. Fuser if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim, attn_type=attn_type) # 5. Scale-shift for PixArt-Alpha. if self.use_ada_layer_norm_single: @@ -498,6 +518,7 @@ def __init__( attention_type: str = "default", caption_channels: int = None, video_length: int = 16, + attn_type: str = "vanilla" ): super().__init__() self.use_linear_projection = use_linear_projection @@ -600,6 +621,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, attention_type=attention_type, + attn_type=attn_type ) for d in range(num_layers) ] @@ -624,6 +646,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, attention_type=attention_type, + attn_type=attn_type ) for d in range(num_layers) ] diff --git a/models/utils.py b/models/utils.py index 0e13056..11d6fde 100644 --- a/models/utils.py +++ b/models/utils.py @@ -212,4 +212,105 @@ def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params \ No newline at end of file + return total_params + +try: + # needs to have https://github.com/corl-team/rebased/ installed + from fla.ops.triton.rebased_fast import parallel_rebased + REBASED_IS_AVAILABLE = True +except ImportError: + REBASED_IS_AVAILABLE = False + +try: + # needs to have https://github.com/lucidrains/ring-attention-pytorch installed + from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda + RING_ATTENTION_IS_AVAILABLE = True +except ImportError: + RING_ATTENTION_IS_AVAILABLE = False + +from diffusers.models.attention_processor import Attention +class AltAttnProcessor: + + def __call__( + self, + attn: Attention, + hidden_states, + 
encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + eps = 1e-12 + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = self.attn_fn(query, key, value, eps, True, True)  # NOTE(review): these positionals match parallel_rebased(q, k, v, eps, ...); ring_flash_attn_cuda expects causal=/bucket_size= keywords — verify RingAttnProcessor call + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = 
attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class RebasedAttnProcessor(AltAttnProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = parallel_rebased + +class RingAttnProcessor(AltAttnProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = ring_flash_attn_cuda