The reason is that I found TaylorSeer was using too much cache memory, so I modified the code to remove the caching for the MLP part, accepting the small amount of extra compute this adds to each DiT block.
However, in HunyuanVideo, after I modified the double-block code as shown below, the generated video showed noticeably more noise.
I wonder why this happens. Is caching the MLP output necessary in TaylorSeer?
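To make the change easier to see, here is a condensed view of what I removed for the img MLP, distilled from the full forward below (not a drop-in snippet; txt is handled the same way):

# Original TaylorSeer pattern for the img MLP (txt is analogous):
# full step -- compute the MLP, then cache its output and derivative terms:
current['module'] = 'img_mlp'
taylor_cache_init(cache_dic, current)
img_mlp_out = self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale))
img = img + apply_gate(img_mlp_out, gate=img_mod2_gate)
derivative_approximation(cache_dic, current, img_mlp_out)
# taylor_cache step -- predict the MLP output from the cache, no MLP call:
img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod2_gate)

# My modification: in both branches the MLP is recomputed every step and nothing is cached:
img = img + apply_gate(
    self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
    gate=img_mod2_gate,
)

The full modified double-block forward follows.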
def forward(
    self,
    img: torch.Tensor,
    txt: torch.Tensor,
    vec: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_kv: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    freqs_cis: tuple = None,
    cache_dic: Optional[Dict] = None,
    current: Optional[Dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    (
        img_mod1_shift,
        img_mod1_scale,
        img_mod1_gate,
        img_mod2_shift,
        img_mod2_scale,
        img_mod2_gate,
    ) = self.img_mod(vec).chunk(6, dim=-1)
    (
        txt_mod1_shift,
        txt_mod1_scale,
        txt_mod1_gate,
        txt_mod2_shift,
        txt_mod2_scale,
        txt_mod2_gate,
    ) = self.txt_mod(vec).chunk(6, dim=-1)
    if current['type'] == 'full':
        current['module'] = 'attn'
        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(
            img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
        )
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # Apply QK-Norm if needed.
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(
            txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
        )
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Run actual attention.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=img_k.shape[0],
                mode="vanilla" if ((cache_dic['cache_type'] == 'attention') or (cache_dic['test_FLOPs'])) else "flash",
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )
        # attention computation end

        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        # Calculate the img blocks.
        current['module'] = 'img_attn'
        taylor_cache_init(cache_dic, current)
        img_attn_out = self.img_attn_proj(img_attn)
        img = img + apply_gate(img_attn_out, gate=img_mod1_gate)
        derivative_approximation(cache_dic, current, img_attn_out)
        # img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)

        # MLP caching removed: the img MLP is computed directly and its output is no longer cached.
        # current['module'] = 'img_mlp'
        # taylor_cache_init(cache_dic, current)
        # img_mlp_out = self.img_mlp(
        #     modulate(
        #         self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
        #     )
        # )
        # img = img + apply_gate(img_mlp_out, gate=img_mod2_gate)
        # derivative_approximation(cache_dic, current, img_mlp_out)
        img = img + apply_gate(
            self.img_mlp(
                modulate(
                    self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                )
            ),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks.
        current['module'] = 'txt_attn'
        taylor_cache_init(cache_dic, current)
        txt_attn_out = self.txt_attn_proj(txt_attn)
        txt = txt + apply_gate(txt_attn_out, gate=txt_mod1_gate)
        derivative_approximation(cache_dic, current, txt_attn_out)
        # txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)

        # MLP caching removed: the txt MLP is computed directly and its output is no longer cached.
        # current['module'] = 'txt_mlp'
        # taylor_cache_init(cache_dic, current)
        # txt_mlp_out = self.txt_mlp(
        #     modulate(
        #         self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
        #     )
        # )
        # txt = txt + apply_gate(txt_mlp_out, gate=txt_mod2_gate)
        # derivative_approximation(cache_dic, current, txt_mlp_out)
        txt = txt + apply_gate(
            self.txt_mlp(
                modulate(
                    self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            ),
            gate=txt_mod2_gate,
        )
    elif current['type'] == 'taylor_cache':
        current['module'] = 'attn'  # Just a symbolic name.

        # Calculate the img blocks.
        current['module'] = 'img_attn'
        img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod1_gate)

        # MLP caching removed: recompute the img MLP instead of predicting it from the cache.
        # current['module'] = 'img_mlp'
        # img = img + apply_gate(
        #     taylor_formula(cache_dic, current),
        #     gate=img_mod2_gate,
        # )
        img = img + apply_gate(
            self.img_mlp(
                modulate(
                    self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                )
            ),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks.
        current['module'] = 'txt_attn'
        txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod1_gate)

        # MLP caching removed: recompute the txt MLP instead of predicting it from the cache.
        # current['module'] = 'txt_mlp'
        # txt = txt + apply_gate(
        #     taylor_formula(cache_dic, current),
        #     gate=txt_mod2_gate,
        # )
        txt = txt + apply_gate(
            self.txt_mlp(
                modulate(
                    self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            ),
            gate=txt_mod2_gate,
        )
    elif current['type'] == 'ToCa':
        current['module'] = 'attn'  # Just a symbolic name.

        # Calculate the img blocks.
        current['module'] = 'img_attn'
        img = img + apply_gate(taylor_formula(cache_dic, current), gate=img_mod1_gate)

        current['module'] = 'img_mlp'
        fresh_indices, fresh_tokens_img = cache_cutfresh(cache_dic=cache_dic, tokens=img, current=current)
        fresh_tokens_img = self.img_mlp(
            modulate(
                self.img_norm2(fresh_tokens_img), shift=img_mod2_shift, scale=img_mod2_scale
            )
        )
        update_cache(fresh_indices=fresh_indices, fresh_tokens=fresh_tokens_img, cache_dic=cache_dic, current=current)
        img = img + apply_gate(
            taylor_formula(cache_dic, current),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks.
        current['module'] = 'txt_attn'
        txt = txt + apply_gate(taylor_formula(cache_dic, current), gate=txt_mod1_gate)

        current['module'] = 'txt_mlp'
        fresh_indices, fresh_tokens_txt = cache_cutfresh(cache_dic=cache_dic, tokens=txt, current=current)
        fresh_tokens_txt = self.txt_mlp(
            modulate(
                self.txt_norm2(fresh_tokens_txt), shift=txt_mod2_shift, scale=txt_mod2_scale
            )
        )
        update_cache(fresh_indices=fresh_indices, fresh_tokens=fresh_tokens_txt, cache_dic=cache_dic, current=current)
        txt = txt + apply_gate(
            taylor_formula(cache_dic, current),
            gate=txt_mod2_gate,
        )

    return img, txt
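For context on the memory concern, here is a rough back-of-the-envelope estimate of what the TaylorSeer cache holds per double block. It assumes one tensor per cached module per Taylor order (the value plus its derivative terms) in half precision; the exact cache_dic layout, the order count, and the sequence lengths in the example are assumptions you would need to adjust to your own run:

def estimate_double_block_cache_bytes(
    batch: int,
    img_len: int,
    txt_len: int,
    hidden_dim: int,
    max_order: int = 1,      # assumption: derivative orders kept besides the value itself
    cache_mlp: bool = True,  # whether img_mlp / txt_mlp outputs are cached
    dtype_bytes: int = 2,    # fp16 / bf16
) -> int:
    """Rough per-double-block cache size; a sketch, not the actual cache_dic layout."""
    orders = max_order + 1   # cached value + its derivative terms
    modules = [("img_attn", img_len), ("txt_attn", txt_len)]
    if cache_mlp:
        modules += [("img_mlp", img_len), ("txt_mlp", txt_len)]
    return sum(orders * batch * seq_len * hidden_dim * dtype_bytes for _, seq_len in modules)

# Hypothetical HunyuanVideo-like sizes: ~32k video tokens, 256 text tokens, hidden_dim 3072.
with_mlp = estimate_double_block_cache_bytes(1, 32_000, 256, 3072, cache_mlp=True)
without_mlp = estimate_double_block_cache_bytes(1, 32_000, 256, 3072, cache_mlp=False)
print(f"{with_mlp / 2**20:.0f} MiB with MLP cache vs {without_mlp / 2**20:.0f} MiB without, per double block")

On this rough estimate the MLP entries account for about half of the per-block cache, which is why dropping them looked attractive to me.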