12 changes: 12 additions & 0 deletions configs/z_image/z_image_turbo_t2i_fp8.json
@@ -0,0 +1,12 @@
+{
+    "aspect_ratio": "16:9",
+    "num_channels_latents": 16,
+    "infer_steps": 9,
+    "attn_type": "flash_attn3",
+    "enable_cfg": false,
+    "sample_guide_scale": 0.0,
+    "patch_size": 2,
+    "dit_quantized": true,
+    "dit_quant_scheme": "fp8-sgl",
+    "dit_quantized_ckpt": "/path/to/z_image_turbo_fp8.safetensors"
+}
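
The two `dit_quant*` keys gate the fp8 path. Below is a minimal sketch of how a runner might consume them; `load_dit_config` is a hypothetical helper for illustration, not LightX2V's actual entry point:

```python
import json

def load_dit_config(path):
    # Hypothetical helper: read the config and pick which checkpoint to load.
    with open(path) as f:
        cfg = json.load(f)
    if cfg.get("dit_quantized", False):
        # fp8 path: the quantized weights live in a separate safetensors file
        print(f"loading {cfg['dit_quant_scheme']} weights from {cfg['dit_quantized_ckpt']}")
    else:
        print("loading full-precision DiT weights")
    return cfg

cfg = load_dit_config("configs/z_image/z_image_turbo_t2i_fp8.json")
```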
@@ -3,9 +3,9 @@
"num_channels_latents": 16,
"infer_steps": 9,
"attn_type": "flash_attn3",
"enable_cfg": true,
"enable_cfg": false,
"sample_guide_scale": 0.0,
"patch_size": 2,
"strength": 0.6,
"resize_mode": "adaptive"
"cpu_offload": true,
"offload_granularity": "block"
}
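
Both configs pair `enable_cfg: false` with `sample_guide_scale: 0.0`, the usual setting for distilled turbo checkpoints: no classifier-free guidance, so one model pass per step. A hedged sketch of how these two keys typically interact in a sampling loop (illustrative only, not the project's scheduler code):

```python
def denoise_step(model, latents, t, cond, uncond, enable_cfg, guide_scale):
    # Classifier-free guidance doubles the work: one conditional and one
    # unconditional pass, blended by guide_scale.
    if enable_cfg and guide_scale > 0.0:
        eps_cond = model(latents, t, cond)
        eps_uncond = model(latents, t, uncond)
        return eps_uncond + guide_scale * (eps_cond - eps_uncond)
    # enable_cfg=false / guide_scale=0.0: a single pass per step.
    return model(latents, t, cond)
```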
8 changes: 4 additions & 4 deletions lightx2v/common/modules/weight_module.py
@@ -88,7 +88,7 @@ def named_parameters(self, prefix=""):
             if module is not None:
                 yield from module.named_parameters(prefix + name + ".")
 
-    def to_cpu(self):
+    def to_cpu(self, non_blocking=False):
         for name, param in self._parameters.items():
             if param is not None:
                 if hasattr(param, "cpu"):
@@ -110,7 +110,7 @@ def to_cpu(self):
             if module is not None and hasattr(module, "to_cpu"):
                 module.to_cpu()
 
-    def to_cuda(self):
+    def to_cuda(self, non_blocking=False):
         for name, param in self._parameters.items():
             if param is not None:
                 if hasattr(param, "cuda"):
@@ -131,7 +131,7 @@ def to_cuda(self):
             if module is not None and hasattr(module, "to_cuda"):
                 module.to_cuda()
 
-    def to_cpu_async(self):
+    def to_cpu_async(self, non_blocking=True):
         for name, param in self._parameters.items():
             if param is not None:
                 if hasattr(param, "cpu"):
@@ -153,7 +153,7 @@ def to_cpu_async(self):
             if module is not None and hasattr(module, "to_cpu"):
                 module.to_cpu(non_blocking=True)
 
-    def to_cuda_async(self):
+    def to_cuda_async(self, non_blocking=True):
         for name, param in self._parameters.items():
             if param is not None:
                 if hasattr(param, "cuda"):
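
The new `non_blocking` parameters thread standard PyTorch copy semantics through the module tree: an asynchronous host-to-device copy only overlaps with compute when the source CPU tensor is pinned, and it must be ordered against the stream that consumes it. A minimal sketch of those semantics in plain PyTorch, independent of `WeightModule`:

```python
import torch

param_cpu = torch.randn(1024, 1024).pin_memory()  # async H2D copies require pinned memory

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    # Returns immediately; the copy is enqueued on copy_stream.
    param_gpu = param_cpu.cuda(non_blocking=True)

# Make the default stream wait for the copy before any kernel reads param_gpu.
torch.cuda.current_stream().wait_stream(copy_stream)
```

The reverse direction is only truly asynchronous when the host destination is pinned as well; a `non_blocking` device-to-host copy into pageable memory degrades to a synchronous one.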
78 changes: 52 additions & 26 deletions lightx2v/models/networks/z_image/infer/offload/transformer_infer.py
100644 → 100755
@@ -10,42 +10,68 @@
 class ZImageOffloadTransformerInfer(ZImageTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.phases_num = 3
         self.num_blocks = config["num_layers"]
         if self.config.get("cpu_offload", False):
-            if "offload_ratio" in self.config:
-                self.offload_ratio = self.config["offload_ratio"]
-            else:
-                self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_blocks_offload
-                else:
-                    assert NotImplementedError
-
-            if offload_granularity != "model":
                 self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
-            else:
-                assert NotImplementedError
+                self.lazy_load = self.config.get("lazy_load", False)
+                self.infer_main_blocks = self.infer_main_blocks_offload
+                if self.lazy_load:
+                    self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
+            elif offload_granularity == "phase":
+                raise NotImplementedError("offload_granularity=phase not supported")
-    def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb, modulate_index):
-        for block_idx in range(self.num_blocks):
+    def infer_with_blocks_offload(
+        self,
+        main_blocks,
+        unified,
+        unified_freqs_cis,
+        adaln_input,
+    ):
+        num_blocks = len(main_blocks)
+        for block_idx in range(num_blocks):
             self.block_idx = block_idx
-            if self.offload_manager.need_init_first_buffer:
-                self.offload_manager.init_first_buffer(block_weights.blocks)
-
-            self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, block_weights.blocks)
+            if self.lazy_load:
+                next_prefetch = (block_idx + 1) % num_blocks
+                self.offload_manager.start_prefetch_block(next_prefetch)
+
+            if block_idx == 0:
+                self.offload_manager.init_first_buffer(main_blocks)
+
+            if self.lazy_load:
+                self.offload_manager.swap_cpu_buffers()
+            self.offload_manager.prefetch_weights((block_idx + 1) % num_blocks, main_blocks)
 
             with torch_device_module.stream(self.offload_manager.compute_stream):
-                encoder_hidden_states, hidden_states = self.infer_block(
+                unified = self.infer_block(
                     block_weight=self.offload_manager.cuda_buffers[0],
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                    modulate_index=modulate_index,
+                    hidden_states=unified,
+                    freqs_cis=unified_freqs_cis,
+                    adaln_input=adaln_input,
                 )
 
             self.offload_manager.swap_blocks()
 
-        return encoder_hidden_states, hidden_states
+        return unified
+
+    def infer_main_blocks_offload(
+        self,
+        main_blocks,
+        hidden_states,
+        encoder_hidden_states,
+        x_freqs_cis,
+        cap_freqs_cis,
+        adaln_input,
+        x_len,
+        cap_len,
+    ):
+        unified = torch.cat([hidden_states, encoder_hidden_states], dim=0)
+        unified_freqs_cis = torch.cat([x_freqs_cis[:x_len], cap_freqs_cis[:cap_len]], dim=0)
+        unified = self.infer_with_blocks_offload(
+            main_blocks=main_blocks,
+            unified=unified,
+            unified_freqs_cis=unified_freqs_cis,
+            adaln_input=adaln_input,
+        )
+        return unified
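
The rewritten loop is the classic double-buffer schedule: while block `i` executes on the compute stream, block `i+1`'s weights are prefetched on a second stream, and `swap_blocks()` rotates the two CUDA buffers. A condensed sketch of the same pattern with plain tensors; `apply_block` and the pinned `blocks_cpu` list are stand-ins for illustration, not the `WeightAsyncStreamManager` API:

```python
import torch

def run_blocks_offloaded(blocks_cpu, x, apply_block):
    # blocks_cpu: per-block weight tensors in pinned host memory.
    compute = torch.cuda.Stream()
    load = torch.cuda.Stream()
    n = len(blocks_cpu)
    cur = blocks_cpu[0].to("cuda", non_blocking=True)  # init first buffer
    for i in range(n):
        with torch.cuda.stream(load):  # prefetch the next block's weights
            nxt = blocks_cpu[(i + 1) % n].to("cuda", non_blocking=True)
        with torch.cuda.stream(compute):  # run the current block
            x = apply_block(x, cur)
        compute.synchronize()  # join both streams before swapping
        load.synchronize()
        cur = nxt  # swap_blocks(): the prefetched buffer becomes current
    return x
```

Peak weight memory stays at two blocks regardless of model depth, at the cost of one H2D copy per block per step; the copy is hidden behind compute as long as a block's forward pass takes longer than its transfer.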
61 changes: 44 additions & 17 deletions lightx2v/models/networks/z_image/infer/post_infer.py
100644 → 100755
@@ -11,22 +11,42 @@ def set_scheduler(self, scheduler):
         self.scheduler = scheduler
 
     def infer(self, weights, hidden_states, temb_img_silu, image_tokens_len=None):
-        temb_silu = F.silu(temb_img_silu)
-        temb1 = weights.norm_out_linear.apply(temb_silu)
+        """
+        Post-inference processing: apply norm_out, proj_out, and unpatchify.
+        All processing is done without batch dimension: [T, D] instead of [B, T, D].
 
-        scale = 1.0 + temb1
-        normed = weights.norm_out.apply(hidden_states)
-        scaled_norm = normed * scale.unsqueeze(1)
-        B, T, D = scaled_norm.shape
-        hidden_states_2d = scaled_norm.reshape(B * T, D)
+        Args:
+            weights: PostInfer weights
+            hidden_states: Hidden states [T, D] (no batch dimension)
+            temb_img_silu: Time embedding [1, D]
+            image_tokens_len: Image tokens length (optional)
 
-        output_2d = weights.proj_out_linear.apply(hidden_states_2d)
-        out_dim = output_2d.shape[-1]
-        output = output_2d.reshape(B, T, out_dim)
+        Returns:
+            output_4d: Output tensor [C, H, W] (no batch dimension)
+        """
+        # hidden_states is already [T, D] from pre_infer (no batch dimension)
+        # temb_img_silu is [1, D] from pre_infer
+
+        # Apply norm_out_linear: [1, D] -> [1, D]
+        temb_silu = F.silu(temb_img_silu)  # [1, D]
+        temb1 = weights.norm_out_linear.apply(temb_silu)  # [1, D]
+
+        # Apply modulation: scale = 1.0 + temb1
+        scale = 1.0 + temb1  # [1, D]
+        normed = weights.norm_out.apply(hidden_states)  # [T, D]
+        scaled_norm = normed * scale  # [T, D] * [1, D] -> [T, D]
+
+        # Apply proj_out_linear: [T, D] -> [T, out_dim]
+        output = weights.proj_out_linear.apply(scaled_norm)  # [T, out_dim]
+
+        # Trim to image_tokens_len if specified
         if image_tokens_len is not None:
-            output = output[:, :image_tokens_len, :]
+            output = output[:image_tokens_len, :]  # [image_tokens_len, out_dim]
+
+        # Get output dimension
+        T, out_dim = output.shape
+
+        # Validate output dimension
         patch_size = self.config.get("patch_size", 2)
         f_patch_size = self.config.get("f_patch_size", 1)
         transformer_out_channels = out_dim // (patch_size * patch_size * f_patch_size)
@@ -47,12 +67,19 @@ def infer(self, weights, hidden_states, temb_img_silu, image_tokens_len=None):
         W_tokens = width // pW
 
         expected_T = F_tokens * H_tokens * W_tokens
-        if output.shape[1] != expected_T:
-            raise ValueError(f"Token count mismatch: output.shape[1]={output.shape[1]} != expected_T={expected_T} (from target_shape={target_shape})")
+        if T != expected_T:
+            raise ValueError(f"Token count mismatch: T={T} != expected_T={expected_T} (from target_shape={target_shape})")
 
-        output_reshaped = output.view(B, F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels)
-        output_permuted = output_reshaped.permute(0, 7, 1, 4, 2, 5, 3, 6)
-        output_4d = output_permuted.reshape(B, out_channels, num_frames, height, width)
-        output_4d = output_4d.squeeze(2)
+        # Unpatchify: [T, out_dim] -> [C, H, W]
+        # Reshape: [T, out_dim] -> [F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels]
+        output_reshaped = output.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels)
+        # Permute: [F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels]
+        #   -> [out_channels, F_tokens, pF, H_tokens, pH, W_tokens, pW]
+        output_permuted = output_reshaped.permute(6, 0, 3, 1, 4, 2, 5)
+        # Reshape: [out_channels, F_tokens, pF, H_tokens, pH, W_tokens, pW]
+        #   -> [out_channels, num_frames, height, width]
+        output_4d = output_permuted.reshape(out_channels, num_frames, height, width)
+        # Remove frame dimension: [out_channels, 1, height, width] -> [out_channels, height, width]
+        output_4d = output_4d.squeeze(1)
 
         return output_4d
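
The new batch-free permute order can be sanity-checked in isolation; the shapes below are arbitrary small values rather than real model dimensions:

```python
import torch

C, pF, pH, pW = 16, 1, 2, 2            # out_channels, f_patch_size, patch_size, patch_size
F_tokens, H_tokens, W_tokens = 1, 4, 6
T, out_dim = F_tokens * H_tokens * W_tokens, C * pF * pH * pW

output = torch.randn(T, out_dim)
x = output.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, C)
x = x.permute(6, 0, 3, 1, 4, 2, 5)     # [C, F_tokens, pF, H_tokens, pH, W_tokens, pW]
x = x.reshape(C, F_tokens * pF, H_tokens * pH, W_tokens * pW)
assert x.shape == (16, 1, 8, 12)       # [C, num_frames, height, width]
assert x.squeeze(1).shape == (16, 8, 12)  # [C, H, W], matching the docstring
```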