diff --git a/models/latte.py b/models/latte.py index f7ee920..9680d7d 100644 --- a/models/latte.py +++ b/models/latte.py @@ -70,7 +70,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) else: - raise NotImplemented + raise NotImplementedError x = self.proj(x) x = self.proj_drop(x) @@ -369,6 +369,8 @@ def forward(self, if self.extras == 2: c = timestep_spatial + y_spatial + elif self.extras == 78: + c = timestep_spatial + text_embedding_spatial else: c = timestep_spatial x = self.final_layer(x, c) diff --git a/models/latte_img.py b/models/latte_img.py index c468c63..4fc507d 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -58,7 +58,11 @@ def forward(self, x): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) if self.attention_mode == 'xformers': # cause loss nan while using with amp - x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C) + # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135 + q_xf = q.transpose(1,2).contiguous() + k_xf = k.transpose(1,2).contiguous() + v_xf = v.transpose(1,2).contiguous() + x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C) elif self.attention_mode == 'flash': # cause loss nan while using with amp @@ -73,7 +77,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) else: - raise NotImplemented + raise NotImplementedError x = self.proj(x) x = self.proj_drop(x)