Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion models/latte.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
except:
XFORMERS_IS_AVAILBLE = False

try:
    # needs to have https://github.com/corl-team/rebased/ installed
    from fla.ops.triton.rebased_fast import parallel_rebased
    # Flag must be set on BOTH paths: the original only assigned it in the
    # except branch, so a successful import left it undefined and any later
    # `if REBASED_IS_AVAILABLE:` check raised NameError.
    REBASED_IS_AVAILABLE = True
except ImportError:
    REBASED_IS_AVAILABLE = False

try:
    # needs to have https://github.com/lucidrains/ring-attention-pytorch installed
    from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda
    RING_ATTENTION_IS_AVAILABLE = True
except ImportError:
    RING_ATTENTION_IS_AVAILABLE = False

# from timm.models.layers.helpers import to_2tuple
# from timm.models.layers.trace_utils import _assert

Expand All @@ -37,7 +49,7 @@ def modulate(x, shift, scale):
#################################################################################

class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
Expand All @@ -48,6 +60,9 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.eps = eps
self.causal = causal
self.ring_bucket_size = ring_bucket_size

def forward(self, x):
B, N, C = x.shape
Expand All @@ -69,6 +84,12 @@ def forward(self, x):
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)

elif self.attention_mode == 'rebased':
x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C)

elif self.attention_mode == 'ring':
x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C)

else:
raise NotImplemented

Expand Down
23 changes: 22 additions & 1 deletion models/latte_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
except:
XFORMERS_IS_AVAILBLE = False

try:
    # needs to have https://github.com/corl-team/rebased/ installed
    from fla.ops.triton.rebased_fast import parallel_rebased
    # Flag must be set on BOTH paths: the original only assigned it in the
    # except branch, so a successful import left it undefined and any later
    # `if REBASED_IS_AVAILABLE:` check raised NameError.
    REBASED_IS_AVAILABLE = True
except ImportError:
    REBASED_IS_AVAILABLE = False

try:
    # needs to have https://github.com/lucidrains/ring-attention-pytorch installed
    from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda
    RING_ATTENTION_IS_AVAILABLE = True
except ImportError:
    RING_ATTENTION_IS_AVAILABLE = False

# from timm.models.layers.helpers import to_2tuple
# from timm.models.layers.trace_utils import _assert

Expand All @@ -37,7 +49,7 @@ def modulate(x, shift, scale):
#################################################################################

class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math', eps=1e-12, causal=True, ring_bucket_size=1024):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
Expand All @@ -51,6 +63,9 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.eps = eps
self.causal = causal
self.ring_bucket_size = ring_bucket_size

def forward(self, x):
B, N, C = x.shape
Expand All @@ -72,6 +87,12 @@ def forward(self, x):
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)

elif self.attention_mode == 'rebased':
x = parallel_rebased(q, k, v, self.eps, True, True).reshape(B, N, C)

elif self.attention_mode == 'ring':
x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size).reshape(B, N, C)

else:
raise NotImplemented

Expand Down
29 changes: 26 additions & 3 deletions models/latte_t2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,22 @@ class GatedSelfAttentionDense(nn.Module):
d_head (`int`): The number of channels in each head.
"""

def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int, attn_type: str = 'vanilla'):
super().__init__()

# we need a linear projection since we need cat visual feature and obj feature
self.linear = nn.Linear(context_dim, query_dim)

self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
if attn_type == 'rebased':
from models.utils import RebasedAttnProcessor
attn_proc = RebasedAttnProcessor()
elif attn_type == 'ring':
from models.utils import RingAttnProcessor
attn_proc = RingAttnProcessor()
else:
attn_proc = None

self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head, processor=attn_proc)
self.ff = FeedForward(query_dim, activation_fn="geglu")

self.norm1 = nn.LayerNorm(query_dim)
Expand Down Expand Up @@ -178,6 +187,7 @@ def __init__(
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
attn_type: str = "vanilla"
):
super().__init__()
self.only_cross_attention = only_cross_attention
Expand Down Expand Up @@ -212,6 +222,15 @@ def __init__(
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

if attn_type == 'rebased':
from models.utils import RebasedAttnProcessor
attn_proc = RebasedAttnProcessor()
elif attn_type == 'ring':
from models.utils import RingAttnProcessor
attn_proc = RingAttnProcessor()
else:
attn_proc = None

self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
Expand All @@ -220,6 +239,7 @@ def __init__(
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
processor=attn_proc
)

# # 2. Cross-Attn
Expand Down Expand Up @@ -254,7 +274,7 @@ def __init__(

# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim,attn_type=attn_type)

# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
Expand Down Expand Up @@ -498,6 +518,7 @@ def __init__(
attention_type: str = "default",
caption_channels: int = None,
video_length: int = 16,
attn_type: str = "vanilla"
):
super().__init__()
self.use_linear_projection = use_linear_projection
Expand Down Expand Up @@ -600,6 +621,7 @@ def __init__(
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
attn_type=attn_type
)
for d in range(num_layers)
]
Expand All @@ -624,6 +646,7 @@ def __init__(
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
attn_type=attn_type
)
for d in range(num_layers)
]
Expand Down
101 changes: 100 additions & 1 deletion models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,103 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
return total_params

try:
    # needs to have https://github.com/corl-team/rebased/ installed
    from fla.ops.triton.rebased_fast import parallel_rebased
    # Flag must be set on BOTH paths: the original only assigned it in the
    # except branch, so a successful import left it undefined and any later
    # `if REBASED_IS_AVAILABLE:` check raised NameError.
    REBASED_IS_AVAILABLE = True
except ImportError:
    REBASED_IS_AVAILABLE = False

try:
    # needs to have https://github.com/lucidrains/ring-attention-pytorch installed
    from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda
    RING_ATTENTION_IS_AVAILABLE = True
except ImportError:
    RING_ATTENTION_IS_AVAILABLE = False

from diffusers.models.attention_processor import Attention
class AltAttnProcessor:
    """Attention processor that mirrors diffusers' standard processor flow
    (projections, norms, head reshaping, output projection) but routes the
    core attention computation through ``self.attn_fn``.

    NOTE(review): this base class never assigns ``self.attn_fn``; a subclass
    (e.g. ``RebasedAttnProcessor`` / ``RingAttnProcessor`` below) must set it
    before the processor is invoked, otherwise ``__call__`` raises
    AttributeError.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        # Kept for the optional residual connection at the end.
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4-D (B, C, H, W) input is flattened to a (B, H*W, C) sequence and
        # restored to the spatial layout before returning.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Sequence length comes from the cross-attention context when one is
        # provided, otherwise from the (self-attention) hidden states.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
            # NOTE(review): the reshaped mask is never passed to
            # ``self.attn_fn`` below, so masking is silently ignored — confirm
            # whether callers rely on it.

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no context is given; otherwise optionally
        # normalize the cross-attention context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split into heads: (B, seq, inner) -> (B, heads, seq, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # NOTE(review): eps is hard-coded here, unlike models/latte.py where
        # it is a constructor parameter — consider unifying.
        eps = 1e-12

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        # NOTE(review): the positional call ``(q, k, v, eps, True, True)``
        # matches ``parallel_rebased``'s usage in models/latte.py, but
        # ``ring_flash_attn_cuda`` is called there with ``causal=`` and
        # ``bucket_size=`` keywords — verify this signature is valid for the
        # ring subclass too.
        hidden_states = self.attn_fn(query, key, value, eps, True, True)

        # Merge heads back: (B, heads, seq, head_dim) -> (B, seq, inner).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class RebasedAttnProcessor(AltAttnProcessor):
    """AltAttnProcessor whose attention kernel is the ReBased
    ``parallel_rebased`` Triton op.

    Raises:
        ImportError: if the optional ``fla`` (corl-team/rebased) package is
            not installed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the optional import at the top of the file failed,
        # ``parallel_rebased`` is undefined; fail early with a clear
        # ImportError instead of an opaque NameError.
        try:
            self.attn_fn = parallel_rebased
        except NameError:
            raise ImportError(
                "RebasedAttnProcessor requires fla.ops.triton.rebased_fast "
                "(https://github.com/corl-team/rebased/) to be installed."
            ) from None

class RingAttnProcessor(AltAttnProcessor):
    """AltAttnProcessor whose attention kernel is
    ``ring_flash_attn_cuda`` from ring-attention-pytorch.

    Raises:
        ImportError: if the optional ``ring_attention_pytorch`` package is
            not installed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the optional import at the top of the file failed,
        # ``ring_flash_attn_cuda`` is undefined; fail early with a clear
        # ImportError instead of an opaque NameError.
        try:
            self.attn_fn = ring_flash_attn_cuda
        except NameError:
            raise ImportError(
                "RingAttnProcessor requires ring_attention_pytorch "
                "(https://github.com/lucidrains/ring-attention-pytorch) "
                "to be installed."
            ) from None