diff --git a/lightx2v/models/networks/wan/infer/triton_ops.py b/lightx2v/models/networks/wan/infer/triton_ops.py old mode 100644 new mode 100755 index 90c27094..5748a8c9 --- a/lightx2v/models/networks/wan/infer/triton_ops.py +++ b/lightx2v/models/networks/wan/infer/triton_ops.py @@ -132,6 +132,12 @@ def fuse_scale_shift_kernel( B, L, C = x.shape output = torch.empty_like(x) + # for wan2.2-ti2v-5b, scale and shift are [L, 1, C] + if scale.shape[1] == 1: + scale = scale.squeeze(1).unsqueeze(0) + if shift.shape[1] == 1: + shift = shift.squeeze(1).unsqueeze(0) + if scale.dim() == 4: # scale/shift: [B, F, 1, C] rows = B * L