diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py
index b23048fc..8f3d7104 100755
--- a/lightx2v/common/ops/attn/ulysses_attn.py
+++ b/lightx2v/common/ops/attn/ulysses_attn.py
@@ -142,10 +142,15 @@ def apply(
             shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
             shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
 
-            # Concatenate the image and text queries, keys, and values
-            q = torch.cat((shard_img_q, shard_txt_q), dim=0)
-            k = torch.cat((shard_img_k, shard_txt_k), dim=0)
-            v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+            # Concatenate the image and text queries, keys, and values (order determined by img_first)
+            if img_first:
+                q = torch.cat((shard_img_q, shard_txt_q), dim=0)
+                k = torch.cat((shard_img_k, shard_txt_k), dim=0)
+                v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+            else:
+                q = torch.cat((shard_txt_q, shard_img_q), dim=0)
+                k = torch.cat((shard_txt_k, shard_img_k), dim=0)
+                v = torch.cat((shard_txt_v, shard_img_v), dim=0)
 
             # Invoke the attention function to compute the attention result
             head_attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)
diff --git a/lightx2v/models/networks/qwen_image/model.py b/lightx2v/models/networks/qwen_image/model.py
index 2f93b339..23a90c35 100755
--- a/lightx2v/models/networks/qwen_image/model.py
+++ b/lightx2v/models/networks/qwen_image/model.py
@@ -413,19 +413,46 @@ def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True)
 
     @torch.no_grad()
     def _seq_parallel_pre_process(self, pre_infer_out):
+        """USP pre-process: shard hidden_states and track padding for post-process."""
         world_size = dist.get_world_size(self.seq_p_group)
         cur_rank = dist.get_rank(self.seq_p_group)
-        seqlen = pre_infer_out.hidden_states.shape[0]
-        padding_size = (world_size - (seqlen % world_size)) % world_size
-        if padding_size > 0:
-            pre_infer_out.hidden_states = F.pad(pre_infer_out.hidden_states, (0, 0, 0, padding_size))
-        pre_infer_out.hidden_states = torch.chunk(pre_infer_out.hidden_states, world_size, dim=0)[cur_rank]
+
+        hidden_states = pre_infer_out.hidden_states
+        if hidden_states.dim() == 2:
+            # 2D: [seq_len, hidden_dim]
+            seqlen = hidden_states.shape[0]
+            padding_size = (world_size - (seqlen % world_size)) % world_size
+            if padding_size > 0:
+                hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size))
+            pre_infer_out.hidden_states = torch.chunk(hidden_states, world_size, dim=0)[cur_rank]
+        else:
+            # 3D: [batch, seq_len, hidden_dim]
+            seqlen = hidden_states.shape[1]
+            padding_size = (world_size - (seqlen % world_size)) % world_size
+            if padding_size > 0:
+                hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size))
+            pre_infer_out.hidden_states = torch.chunk(hidden_states, world_size, dim=1)[cur_rank]
+
+        # Save for post-process
+        self._usp_original_seqlen = seqlen
+        self._usp_padding_size = padding_size
         return pre_infer_out
 
     @torch.no_grad()
     def _seq_parallel_post_process(self, noise_pred):
+        """USP post-process: gather outputs and remove padding."""
         world_size = dist.get_world_size(self.seq_p_group)
         gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
         dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group)
-        noise_pred = torch.cat(gathered_noise_pred, dim=1)
+
+        if noise_pred.dim() == 2:
+            # 2D: [seq_len/N, hidden_dim] -> [seq_len, hidden_dim]
+            noise_pred = torch.cat(gathered_noise_pred, dim=0)
+            if hasattr(self, "_usp_padding_size") and self._usp_padding_size > 0:
+                noise_pred = noise_pred[: self._usp_original_seqlen, :]
+        else:
+            # 3D: [batch, seq_len/N, hidden_dim] -> [batch, seq_len, hidden_dim]
+            noise_pred = torch.cat(gathered_noise_pred, dim=1)
+            if hasattr(self, "_usp_padding_size") and self._usp_padding_size > 0:
+                noise_pred = noise_pred[:, : self._usp_original_seqlen, :]
         return noise_pred
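
For reviewers, a toy illustration of what the img_first switch in the ulysses_attn.py hunk changes: only the order in which the per-head image and text shards are concatenated along the sequence dimension (dim=0). The shapes below are invented for demonstration and do not come from the model.

# Toy sketch of the img_first branch; shapes are hypothetical.
import torch

img_q = torch.randn(16, 1, 128)  # [img_seq, one head shard, head_dim], assumed
txt_q = torch.randn(4, 1, 128)   # [txt_seq, one head shard, head_dim], assumed

for img_first in (True, False):
    # Same tokens either way; only the image/text ordering along dim=0 flips.
    q = torch.cat((img_q, txt_q) if img_first else (txt_q, img_q), dim=0)
    print(img_first, q.shape)  # torch.Size([20, 1, 128]) in both cases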
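And a minimal single-process sketch of the pad -> chunk -> gather -> trim round-trip that the qwen_image/model.py hunk implements for the 3D case. world_size=4 and the tensor shapes are assumptions, and dist.all_gather is simulated by keeping every shard and concatenating them in rank order.

# Single-process simulation of the USP shard/gather round-trip (assumed shapes).
import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(2, 413, 64)  # [batch, seq_len, hidden_dim]; 413 is not divisible by 4

# Pre-process: pad seq_len up to a multiple of world_size, then shard on dim=1.
seqlen = x.shape[1]
padding_size = (world_size - (seqlen % world_size)) % world_size  # 3 here
padded = F.pad(x, (0, 0, 0, padding_size))  # (0, 0) pads hidden_dim, (0, padding_size) pads seq_len
shards = torch.chunk(padded, world_size, dim=1)  # one [2, 104, 64] shard per "rank"

# Post-process: all_gather amounts to concatenating the rank shards in order,
# after which the padding rows are sliced off to restore the original length.
gathered = torch.cat(shards, dim=1)
restored = gathered[:, :seqlen, :]
assert restored.shape == x.shape and torch.allclose(restored, x)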