From 53c5c5dbed6bd2ae00d862d5593e30b0e9b650d6 Mon Sep 17 00:00:00 2001 From: gdli7 Date: Tue, 27 Jan 2026 13:38:31 +0000 Subject: [PATCH] fix(qwen): add USP support for Qwen-Image-Edit-2511 This commit fixes USP (Ulysses Sequence Parallelism) support for Qwen-Image-Edit-2511 model with the following changes: 1. ulysses_attn.py: - Respect img_first parameter when concatenating Q/K/V tensors - Qwen-Image-Edit uses img_first=False (text before image) - Add txt_attn.reshape to ensure consistent 2D output format 2. model.py: - Support both 2D and 3D tensor formats in USP pre/post processing - Save original sequence length and padding size for accurate padding removal in post-processing - Fix gather dimension based on tensor dimensionality --- lightx2v/common/ops/attn/ulysses_attn.py | 29 +++++++++++---- lightx2v/models/networks/qwen_image/model.py | 39 +++++++++++++++++--- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py index b23048fc..ade4b6a9 100755 --- a/lightx2v/common/ops/attn/ulysses_attn.py +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -142,10 +142,15 @@ def apply( shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :] shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :] - # 合并图像和文本的查询、键和值 - q = torch.cat((shard_img_q, shard_txt_q), dim=0) - k = torch.cat((shard_img_k, shard_txt_k), dim=0) - v = torch.cat((shard_img_v, shard_txt_v), dim=0) + # 合并图像和文本的查询、键和值(根据 img_first 决定顺序) + if img_first: + q = torch.cat((shard_img_q, shard_txt_q), dim=0) + k = torch.cat((shard_img_k, shard_txt_k), dim=0) + v = torch.cat((shard_img_v, shard_txt_v), dim=0) + else: + q = torch.cat((shard_txt_q, shard_img_q), dim=0) + k = torch.cat((shard_txt_k, shard_img_k), dim=0) + v = torch.cat((shard_txt_v, shard_img_v), dim=0) # 调用注意力函数计算注意力结果 head_attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs) @@ -182,10 +187,15 @@ def apply( shard_txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] shard_txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :] - # 合并图像和文本的查询、键和值 - q = torch.cat((shard_img_q, shard_txt_q), dim=0) - k = torch.cat((shard_img_k, shard_txt_k), dim=0) - v = torch.cat((shard_img_v, shard_txt_v), dim=0) + # 合并图像和文本的查询、键和值(根据 img_first 决定顺序) + if img_first: + q = torch.cat((shard_img_q, shard_txt_q), dim=0) + k = torch.cat((shard_img_k, shard_txt_k), dim=0) + v = torch.cat((shard_img_v, shard_txt_v), dim=0) + else: + q = torch.cat((shard_txt_q, shard_img_q), dim=0) + k = torch.cat((shard_txt_k, shard_img_k), dim=0) + v = torch.cat((shard_txt_v, shard_img_v), dim=0) # 调用注意力函数计算注意力结果 attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs) @@ -204,6 +214,9 @@ def apply( dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果 + # 将 txt_attn reshape 为 2D,与 img_attn 一致 + txt_attn = txt_attn.reshape(txt_qkv_len, -1) + # 合并图像和文本的注意力结果 if img_first: attn = torch.cat([img_attn, txt_attn], dim=0) diff --git a/lightx2v/models/networks/qwen_image/model.py b/lightx2v/models/networks/qwen_image/model.py index 2f93b339..23a90c35 100755 --- a/lightx2v/models/networks/qwen_image/model.py +++ b/lightx2v/models/networks/qwen_image/model.py @@ -413,19 +413,46 @@ def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True) @torch.no_grad() def _seq_parallel_pre_process(self, pre_infer_out): + """USP pre-process: shard hidden_states and track padding for post-process.""" world_size = dist.get_world_size(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group) - seqlen = pre_infer_out.hidden_states.shape[0] - padding_size = (world_size - (seqlen % world_size)) % world_size - if padding_size > 0: - pre_infer_out.hidden_states = F.pad(pre_infer_out.hidden_states, (0, 0, 0, padding_size)) - pre_infer_out.hidden_states = torch.chunk(pre_infer_out.hidden_states, world_size, dim=0)[cur_rank] + + hidden_states = pre_infer_out.hidden_states + if hidden_states.dim() == 2: + # 2D: [seq_len, hidden_dim] + seqlen = hidden_states.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size)) + pre_infer_out.hidden_states = torch.chunk(hidden_states, world_size, dim=0)[cur_rank] + else: + # 3D: [batch, seq_len, hidden_dim] + seqlen = hidden_states.shape[1] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size)) + pre_infer_out.hidden_states = torch.chunk(hidden_states, world_size, dim=1)[cur_rank] + + # Save for post-process + self._usp_original_seqlen = seqlen + self._usp_padding_size = padding_size return pre_infer_out @torch.no_grad() def _seq_parallel_post_process(self, noise_pred): + """USP post-process: gather outputs and remove padding.""" world_size = dist.get_world_size(self.seq_p_group) gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)] dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group) - noise_pred = torch.cat(gathered_noise_pred, dim=1) + + if noise_pred.dim() == 2: + # 2D: [seq_len/N, hidden_dim] -> [seq_len, hidden_dim] + noise_pred = torch.cat(gathered_noise_pred, dim=0) + if hasattr(self, "_usp_padding_size") and self._usp_padding_size > 0: + noise_pred = noise_pred[: self._usp_original_seqlen, :] + else: + # 3D: [batch, seq_len/N, hidden_dim] -> [batch, seq_len, hidden_dim] + noise_pred = torch.cat(gathered_noise_pred, dim=1) + if hasattr(self, "_usp_padding_size") and self._usp_padding_size > 0: + noise_pred = noise_pred[:, : self._usp_original_seqlen, :] return noise_pred