Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lightx2v/models/networks/wan/infer/triton_ops.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def fuse_scale_shift_kernel(
B, L, C = x.shape
output = torch.empty_like(x)

# for wan2.2-ti2v-5b, scale and shift are [L, 1, C]
if scale.shape[1] == 1:
scale = scale.squeeze(1).unsqueeze(0)
if shift.shape[1] == 1:
shift = shift.squeeze(1).unsqueeze(0)
Comment on lines +136 to +139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The condition scale.shape[1] == 1 is too general and can cause issues. For example, if scale has a shape of [B, 1, C], which is a valid shape handled by the existing logic, this condition will be true. The tensor will be incorrectly reshaped from [B, 1, C] to [1, B, C], which will later cause a runtime error during the expand operation if B > 1.

To fix this, the condition should be more specific to target only the [L, 1, C] shape as mentioned in the comment.

Suggested change
if scale.shape[1] == 1:
scale = scale.squeeze(1).unsqueeze(0)
if shift.shape[1] == 1:
shift = shift.squeeze(1).unsqueeze(0)
if scale.dim() == 3 and scale.shape[0] == L and scale.shape[1] == 1:
scale = scale.squeeze(1).unsqueeze(0)
if shift.dim() == 3 and shift.shape[0] == L and shift.shape[1] == 1:
shift = shift.squeeze(1).unsqueeze(0)


if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
Expand Down