Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions lightx2v/common/ops/attn/ulysses_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,15 @@ def apply(
shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]

# 合并图像和文本的查询、键和值
q = torch.cat((shard_img_q, shard_txt_q), dim=0)
k = torch.cat((shard_img_k, shard_txt_k), dim=0)
v = torch.cat((shard_img_v, shard_txt_v), dim=0)
# 合并图像和文本的查询、键和值(根据 img_first 决定顺序)
if img_first:
q = torch.cat((shard_img_q, shard_txt_q), dim=0)
k = torch.cat((shard_img_k, shard_txt_k), dim=0)
v = torch.cat((shard_img_v, shard_txt_v), dim=0)
else:
q = torch.cat((shard_txt_q, shard_img_q), dim=0)
k = torch.cat((shard_txt_k, shard_img_k), dim=0)
v = torch.cat((shard_txt_v, shard_img_v), dim=0)

# 调用注意力函数计算注意力结果
head_attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)
Expand Down
39 changes: 33 additions & 6 deletions lightx2v/models/networks/qwen_image/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,19 +413,46 @@ def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True)

@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
    """USP (Ulysses sequence-parallel) pre-process.

    Pads ``pre_infer_out.hidden_states`` along its sequence dimension so the
    length is divisible by the sequence-parallel world size, then keeps only
    this rank's contiguous shard. The original sequence length and the amount
    of padding are stashed on ``self`` so ``_seq_parallel_post_process`` can
    undo both after gathering.

    Args:
        pre_infer_out: object with a ``hidden_states`` tensor attribute,
            either 2D ``[seq_len, hidden_dim]`` or 3D
            ``[batch, seq_len, hidden_dim]``.

    Returns:
        The same ``pre_infer_out``, with ``hidden_states`` replaced by this
        rank's shard.
    """
    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)

    hidden_states = pre_infer_out.hidden_states
    # The sequence axis is dim 0 for 2D [seq, hidden] input and dim 1 for
    # 3D [batch, seq, hidden] input; everything else is identical.
    seq_dim = 0 if hidden_states.dim() == 2 else 1
    seqlen = hidden_states.shape[seq_dim]
    # Right-pad so seqlen is a multiple of world_size (0 when already even).
    padding_size = (world_size - (seqlen % world_size)) % world_size
    if padding_size > 0:
        # F.pad pads trailing dims first: (0, 0) leaves hidden_dim alone,
        # (0, padding_size) extends the second-to-last (sequence) dim.
        hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size))
    pre_infer_out.hidden_states = torch.chunk(hidden_states, world_size, dim=seq_dim)[cur_rank]

    # Saved for _seq_parallel_post_process to strip the padding after gather.
    self._usp_original_seqlen = seqlen
    self._usp_padding_size = padding_size
    return pre_infer_out

@torch.no_grad()
def _seq_parallel_post_process(self, noise_pred):
    """USP (Ulysses sequence-parallel) post-process.

    Gathers the per-rank output shards across the sequence-parallel group,
    concatenates them back along the sequence dimension, and trims the
    padding that ``_seq_parallel_pre_process`` added.

    Args:
        noise_pred: this rank's output shard, 2D ``[seq_len/N, hidden_dim]``
            or 3D ``[batch, seq_len/N, hidden_dim]``.

    Returns:
        The full, unpadded prediction tensor
        (``[seq_len, hidden_dim]`` or ``[batch, seq_len, hidden_dim]``).
    """
    world_size = dist.get_world_size(self.seq_p_group)
    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
    dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group)

    # Mirror of the pre-process: sequence axis is dim 0 for 2D, dim 1 for 3D.
    cat_dim = 0 if noise_pred.dim() == 2 else 1
    noise_pred = torch.cat(gathered_noise_pred, dim=cat_dim)

    # Strip the right-padding added before sharding; guard with getattr in
    # case post-process is ever called without a matching pre-process.
    if getattr(self, "_usp_padding_size", 0) > 0:
        noise_pred = noise_pred.narrow(cat_dim, 0, self._usp_original_seqlen)
    return noise_pred