Is caching MLP necessary in TaylorSeer? #45

Description

@endymion-ni

I found that TaylorSeer was occupying too much cache memory, so I tried modifying the code to remove the caching for the MLP part; recomputing the MLP only costs a little extra time in each DiT block.

However, in HunyuanVideo, after I modified the double-block code as shown below (the img_mlp / txt_mlp outputs are always recomputed instead of being stored in and predicted from the cache), the generated video showed noticeably more noise.

I wonder why this happens. Is caching the MLP outputs necessary in TaylorSeer?
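
For context, here is my rough understanding of the per-module cache (only a minimal sketch of what I assume taylor_cache_init, derivative_approximation and taylor_formula do; the real implementation may differ). Removing the img_mlp / txt_mlp entries was meant to drop roughly half of the cached features per double block while leaving the attention caching untouched.

    # Hypothetical sketch, not the actual TaylorSeer code: at 'full' steps each module's
    # output is cached together with a finite-difference derivative; at 'taylor_cache'
    # steps the output is predicted by a first-order Taylor expansion instead of being
    # recomputed.
    def sketch_derivative_approximation(cache, module, step, output):
        prev = cache.get(module)
        if prev is not None and step != prev["step"]:
            # Finite-difference estimate of d(output)/d(step) since the last full step.
            cache[module + "_d1"] = (output - prev["output"]) / (step - prev["step"])
        cache[module] = {"output": output, "step": step}

    def sketch_taylor_formula(cache, module, step):
        prev = cache[module]
        d1 = cache.get(module + "_d1", 0)
        # Predict the module output without running the module.
        return prev["output"] + d1 * (step - prev["step"])

Here is my modified forward of the double block: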

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: tuple = None,
        cache_dic: Optional[Dict] = None,
        current: Optional[Dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        if current['type'] == 'full':
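            # 'full' step: attention and both MLPs are computed exactly. The attention
            # outputs are written into the Taylor cache (taylor_cache_init /
            # derivative_approximation); the corresponding MLP caching calls are the
            # parts I commented out below.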

            current['module'] = 'attn'
            # Prepare image for attention.
            img_modulated = self.img_norm1(img)
            img_modulated = modulate(
                img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
            )
            img_qkv = self.img_attn_qkv(img_modulated)
            img_q, img_k, img_v = rearrange(
                img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
            )
            # Apply QK-Norm if needed
            img_q = self.img_attn_q_norm(img_q).to(img_v)
            img_k = self.img_attn_k_norm(img_k).to(img_v)

            # Apply RoPE if needed.
            if freqs_cis is not None:
                img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
                assert (
                    img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
                ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
                img_q, img_k = img_qq, img_kk

            # Prepare txt for attention.
            txt_modulated = self.txt_norm1(txt)
            txt_modulated = modulate(
                txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
            )
            txt_qkv = self.txt_attn_qkv(txt_modulated)
            txt_q, txt_k, txt_v = rearrange(
                txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
            )
            # Apply QK-Norm if needed.
            txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
            txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

            # Run actual attention.
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)
            v = torch.cat((img_v, txt_v), dim=1)
            assert (
                cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
            ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

            # attention computation start
            if not self.hybrid_seq_parallel_attn:
                attn = attention(
                    q,
                    k,
                    v,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
                    batch_size=img_k.shape[0],
                    mode="vanilla" if ((cache_dic['cache_type'] == 'attention') or (cache_dic['test_FLOPs'])) else "flash",
                )
            else:
                attn = parallel_attention(
                    self.hybrid_seq_parallel_attn,
                    q,
                    k,
                    v,
                    img_q_len=img_q.shape[1],
                    img_kv_len=img_k.shape[1],
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv
                )

            # attention computation end

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]

            # Calculate the img blocks.
            current['module'] = 'img_attn'
            taylor_cache_init(cache_dic, current)

            img_attn_out = self.img_attn_proj(img_attn)
            img = img + apply_gate(img_attn_out, gate=img_mod1_gate)
            derivative_approximation(cache_dic, current, img_attn_out)
            #img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)

            # current['module'] = 'img_mlp'
            # taylor_cache_init(cache_dic, current)

            # img_mlp_out = self.img_mlp(
            #     modulate(
            #         self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
            #     )
            # )
            # img = img + apply_gate(img_mlp_out, gate=img_mod2_gate)
            # derivative_approximation(cache_dic, current, img_mlp_out)

            img = img + apply_gate(
                self.img_mlp(
                    modulate(
                        self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                    )
                ),
                gate=img_mod2_gate,
            )

            # Calculate the txt blocks.
            current['module'] = 'txt_attn'
            taylor_cache_init(cache_dic, current)

            txt_attn_out = self.txt_attn_proj(txt_attn)
            txt = txt + apply_gate(txt_attn_out, gate=txt_mod1_gate)
            derivative_approximation(cache_dic, current, txt_attn_out)
            #txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)

            # current['module'] = 'txt_mlp'
            # taylor_cache_init(cache_dic, current)

            # txt_mlp_out = self.txt_mlp(
            #     modulate(
            #         self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
            #     )
            # )
            # txt = txt + apply_gate(txt_mlp_out, gate=txt_mod2_gate)
            # derivative_approximation(cache_dic, current, txt_mlp_out)

            txt = txt + apply_gate(
                self.txt_mlp(
                    modulate(
                        self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                    )
                ),
                gate=txt_mod2_gate,
            )

        elif current['type'] == 'taylor_cache':
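            # 'taylor_cache' step: the attention outputs are predicted from the cache via
            # taylor_formula; with my modification the MLPs are recomputed exactly here
            # instead of being predicted.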
            
            current['module'] = 'attn'
            # Just a symbolic name

            # Calculate the img blocks.
            current['module'] = 'img_attn'

            img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod1_gate)

            # current['module'] = 'img_mlp'

            # img = img + apply_gate(
            #     taylor_formula(cache_dic, current),
            #     gate=img_mod2_gate,
            # )
            img = img + apply_gate(
                self.img_mlp(
                    modulate(
                        self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                    )
                ),
                gate=img_mod2_gate,
            )
    
            # Calculate the txt blocks.
            current['module'] = 'txt_attn'

            txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod1_gate)

            # current['module'] = 'txt_mlp'

            # txt = txt + apply_gate(
            #     taylor_formula(cache_dic, current),
            #     gate=txt_mod2_gate,
            # )
            txt = txt + apply_gate(
                self.txt_mlp(
                    modulate(
                        self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                    )
                ),
                gate=txt_mod2_gate,
            )

        elif current['type'] == 'ToCa':
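            # 'ToCa' step: attention outputs are predicted from the cache; for the MLPs,
            # only the fresh tokens selected by cache_cutfresh are recomputed and written
            # back with update_cache, and the full output is then read via taylor_formula.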
            
            current['module'] = 'attn'
            # Just a symbolic name

            # Calculate the img blocks.
            current['module'] = 'img_attn'

            img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod1_gate)

            current['module'] = 'img_mlp'
            fresh_indices, fresh_tokens_img = cache_cutfresh(cache_dic=cache_dic, tokens=img, current=current)
            
            fresh_tokens_img = self.img_mlp(
                modulate(
                    self.img_norm2(fresh_tokens_img), shift=img_mod2_shift, scale=img_mod2_scale
                )
            )

            update_cache(fresh_indices=fresh_indices, fresh_tokens=fresh_tokens_img, cache_dic=cache_dic, current=current)

            img = img + apply_gate(
                taylor_formula(cache_dic, current),
                gate=img_mod2_gate,
            )
    
            # Calculate the txt blocks.
            current['module'] = 'txt_attn'

            txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod1_gate)

            current['module'] = 'txt_mlp'

            fresh_indices, fresh_tokens_txt = cache_cutfresh(cache_dic=cache_dic, tokens=txt, current=current)

            fresh_tokens_txt = self.txt_mlp(
                modulate(
                    self.txt_norm2(fresh_tokens_txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            )

            update_cache(fresh_indices=fresh_indices, fresh_tokens=fresh_tokens_txt, cache_dic=cache_dic, current=current)

            txt = txt + apply_gate(
                taylor_formula(cache_dic, current),
                gate=txt_mod2_gate,
            )

        return img, txt
