From f27dfeee66ac637acc35a8dd003ad4f7026b9557 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167> Date: Tue, 27 Jan 2026 03:03:55 +0000 Subject: [PATCH] fix scale_shift kernel for wan2.2-5b --- lightx2v/models/networks/wan/infer/triton_ops.py | 6 ++++++ 1 file changed, 6 insertions(+) mode change 100644 => 100755 lightx2v/models/networks/wan/infer/triton_ops.py diff --git a/lightx2v/models/networks/wan/infer/triton_ops.py b/lightx2v/models/networks/wan/infer/triton_ops.py old mode 100644 new mode 100755 index 90c27094f..5748a8c91 --- a/lightx2v/models/networks/wan/infer/triton_ops.py +++ b/lightx2v/models/networks/wan/infer/triton_ops.py @@ -132,6 +132,12 @@ def fuse_scale_shift_kernel( B, L, C = x.shape output = torch.empty_like(x) + # for wan2.2-ti2v-5b, scale and shift are [L, 1, C] + if scale.shape[1] == 1: + scale = scale.squeeze(1).unsqueeze(0) + if shift.shape[1] == 1: + shift = shift.squeeze(1).unsqueeze(0) + if scale.dim() == 4: # scale/shift: [B, F, 1, C] rows = B * L