diff --git a/configs/z_image/z_image_turbo_t2i_fp8.json b/configs/z_image/z_image_turbo_t2i_fp8.json new file mode 100755 index 000000000..ec9f6e9f5 --- /dev/null +++ b/configs/z_image/z_image_turbo_t2i_fp8.json @@ -0,0 +1,12 @@ +{ + "aspect_ratio": "16:9", + "num_channels_latents": 16, + "infer_steps": 9, + "attn_type": "flash_attn3", + "enable_cfg": false, + "sample_guide_scale": 0.0, + "patch_size": 2, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "dit_quantized_ckpt": "/path/to/z_image_turbo_fp8.safetensors" +} diff --git a/configs/z_image/z_image_turbo_i2i.json b/configs/z_image/z_image_turbo_t2i_offload.json similarity index 66% rename from configs/z_image/z_image_turbo_i2i.json rename to configs/z_image/z_image_turbo_t2i_offload.json index b2c7180d9..cc2bc8c3b 100755 --- a/configs/z_image/z_image_turbo_i2i.json +++ b/configs/z_image/z_image_turbo_t2i_offload.json @@ -3,9 +3,9 @@ "num_channels_latents": 16, "infer_steps": 9, "attn_type": "flash_attn3", - "enable_cfg": true, + "enable_cfg": false, "sample_guide_scale": 0.0, "patch_size": 2, - "strength": 0.6, - "resize_mode": "adaptive" + "cpu_offload": true, + "offload_granularity": "block" } diff --git a/lightx2v/common/modules/weight_module.py b/lightx2v/common/modules/weight_module.py index 486e09f44..21fdc2169 100755 --- a/lightx2v/common/modules/weight_module.py +++ b/lightx2v/common/modules/weight_module.py @@ -88,7 +88,7 @@ def named_parameters(self, prefix=""): if module is not None: yield from module.named_parameters(prefix + name + ".") - def to_cpu(self): + def to_cpu(self, non_blocking=False): for name, param in self._parameters.items(): if param is not None: if hasattr(param, "cpu"): @@ -110,7 +110,7 @@ def to_cpu(self): if module is not None and hasattr(module, "to_cpu"): module.to_cpu() - def to_cuda(self): + def to_cuda(self, non_blocking=False): for name, param in self._parameters.items(): if param is not None: if hasattr(param, "cuda"): @@ -131,7 +131,7 @@ def to_cuda(self): if module is not None and hasattr(module, "to_cuda"): module.to_cuda() - def to_cpu_async(self): + def to_cpu_async(self, non_blocking=True): for name, param in self._parameters.items(): if param is not None: if hasattr(param, "cpu"): @@ -153,7 +153,7 @@ def to_cpu_async(self): if module is not None and hasattr(module, "to_cpu"): module.to_cpu(non_blocking=True) - def to_cuda_async(self): + def to_cuda_async(self, non_blocking=True): for name, param in self._parameters.items(): if param is not None: if hasattr(param, "cuda"): diff --git a/lightx2v/models/networks/z_image/infer/offload/transformer_infer.py b/lightx2v/models/networks/z_image/infer/offload/transformer_infer.py old mode 100644 new mode 100755 index b9c53b19d..d6ced7338 --- a/lightx2v/models/networks/z_image/infer/offload/transformer_infer.py +++ b/lightx2v/models/networks/z_image/infer/offload/transformer_infer.py @@ -10,42 +10,68 @@ class ZImageOffloadTransformerInfer(ZImageTransformerInfer): def __init__(self, config): super().__init__(config) - self.phases_num = 3 - self.num_blocks = config["num_layers"] if self.config.get("cpu_offload", False): - if "offload_ratio" in self.config: - self.offload_ratio = self.config["offload_ratio"] - else: - self.offload_ratio = 1 offload_granularity = self.config.get("offload_granularity", "block") if offload_granularity == "block": - if not self.config.get("lazy_load", False): - self.infer_func = self.infer_with_blocks_offload - else: - assert NotImplementedError - - if offload_granularity != "model": self.offload_manager = 
WeightAsyncStreamManager(offload_granularity=offload_granularity) - else: - assert NotImplementedError + self.lazy_load = self.config.get("lazy_load", False) + self.infer_main_blocks = self.infer_main_blocks_offload + if self.lazy_load: + self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4)) + elif offload_granularity == "phase": + raise NotImplementedError("offload_granularity=phase not supported") - def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb, modulate_index): - for block_idx in range(self.num_blocks): + def infer_with_blocks_offload( + self, + main_blocks, + unified, + unified_freqs_cis, + adaln_input, + ): + num_blocks = len(main_blocks) + for block_idx in range(num_blocks): self.block_idx = block_idx - if self.offload_manager.need_init_first_buffer: - self.offload_manager.init_first_buffer(block_weights.blocks) - self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, block_weights.blocks) + if self.lazy_load: + next_prefetch = (block_idx + 1) % num_blocks + self.offload_manager.start_prefetch_block(next_prefetch) + + if block_idx == 0: + self.offload_manager.init_first_buffer(main_blocks) + + if self.lazy_load: + self.offload_manager.swap_cpu_buffers() + self.offload_manager.prefetch_weights((block_idx + 1) % num_blocks, main_blocks) + with torch_device_module.stream(self.offload_manager.compute_stream): - encoder_hidden_states, hidden_states = self.infer_block( + unified = self.infer_block( block_weight=self.offload_manager.cuda_buffers[0], - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - modulate_index=modulate_index, + hidden_states=unified, + freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, ) self.offload_manager.swap_blocks() - return encoder_hidden_states, hidden_states + return unified + + def infer_main_blocks_offload( + self, + main_blocks, + hidden_states, + encoder_hidden_states, + x_freqs_cis, + cap_freqs_cis, + adaln_input, + x_len, + cap_len, + ): + unified = torch.cat([hidden_states, encoder_hidden_states], dim=0) + unified_freqs_cis = torch.cat([x_freqs_cis[:x_len], cap_freqs_cis[:cap_len]], dim=0) + unified = self.infer_with_blocks_offload( + main_blocks=main_blocks, + unified=unified, + unified_freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, + ) + return unified diff --git a/lightx2v/models/networks/z_image/infer/post_infer.py b/lightx2v/models/networks/z_image/infer/post_infer.py old mode 100644 new mode 100755 index 2cd1d0505..997ceb511 --- a/lightx2v/models/networks/z_image/infer/post_infer.py +++ b/lightx2v/models/networks/z_image/infer/post_infer.py @@ -11,22 +11,42 @@ def set_scheduler(self, scheduler): self.scheduler = scheduler def infer(self, weights, hidden_states, temb_img_silu, image_tokens_len=None): - temb_silu = F.silu(temb_img_silu) - temb1 = weights.norm_out_linear.apply(temb_silu) + """ + Post-inference processing: apply norm_out, proj_out, and unpatchify. + All processing is done without batch dimension: [T, D] instead of [B, T, D]. 
- scale = 1.0 + temb1 - normed = weights.norm_out.apply(hidden_states) - scaled_norm = normed * scale.unsqueeze(1) - B, T, D = scaled_norm.shape - hidden_states_2d = scaled_norm.reshape(B * T, D) + Args: + weights: PostInfer weights + hidden_states: Hidden states [T, D] (no batch dimension) + temb_img_silu: Time embedding [1, D] + image_tokens_len: Image tokens length (optional) - output_2d = weights.proj_out_linear.apply(hidden_states_2d) - out_dim = output_2d.shape[-1] - output = output_2d.reshape(B, T, out_dim) + Returns: + output_4d: Output tensor [C, H, W] (no batch dimension) + """ + # hidden_states is already [T, D] from pre_infer (no batch dimension) + # temb_img_silu is [1, D] from pre_infer + # Apply norm_out_linear: [1, D] -> [1, D] + temb_silu = F.silu(temb_img_silu) # [1, D] + temb1 = weights.norm_out_linear.apply(temb_silu) # [1, D] + + # Apply modulation: scale = 1.0 + temb1 + scale = 1.0 + temb1 # [1, D] + normed = weights.norm_out.apply(hidden_states) # [T, D] + scaled_norm = normed * scale # [T, D] * [1, D] -> [T, D] + + # Apply proj_out_linear: [T, D] -> [T, out_dim] + output = weights.proj_out_linear.apply(scaled_norm) # [T, out_dim] + + # Trim to image_tokens_len if specified if image_tokens_len is not None: - output = output[:, :image_tokens_len, :] + output = output[:image_tokens_len, :] # [image_tokens_len, out_dim] + + # Get output dimension + T, out_dim = output.shape + # Validate output dimension patch_size = self.config.get("patch_size", 2) f_patch_size = self.config.get("f_patch_size", 1) transformer_out_channels = out_dim // (patch_size * patch_size * f_patch_size) @@ -47,12 +67,19 @@ def infer(self, weights, hidden_states, temb_img_silu, image_tokens_len=None): W_tokens = width // pW expected_T = F_tokens * H_tokens * W_tokens - if output.shape[1] != expected_T: - raise ValueError(f"Token count mismatch: output.shape[1]={output.shape[1]} != expected_T={expected_T} (from target_shape={target_shape})") + if T != expected_T: + raise ValueError(f"Token count mismatch: T={T} != expected_T={expected_T} (from target_shape={target_shape})") - output_reshaped = output.view(B, F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels) - output_permuted = output_reshaped.permute(0, 7, 1, 4, 2, 5, 3, 6) - output_4d = output_permuted.reshape(B, out_channels, num_frames, height, width) - output_4d = output_4d.squeeze(2) + # Unpatchify: [T, out_dim] -> [C, H, W] + # Reshape: [T, out_dim] -> [F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels] + output_reshaped = output.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels) + # Permute: [F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels] + # -> [out_channels, F_tokens, pF, H_tokens, pH, W_tokens, pW] + output_permuted = output_reshaped.permute(6, 0, 3, 1, 4, 2, 5) + # Reshape: [out_channels, F_tokens, pF, H_tokens, pH, W_tokens, pW] + # -> [out_channels, num_frames, height, width] + output_4d = output_permuted.reshape(out_channels, num_frames, height, width) + # Remove frame dimension: [out_channels, 1, height, width] -> [out_channels, height, width] + output_4d = output_4d.squeeze(1) return output_4d diff --git a/lightx2v/models/networks/z_image/infer/pre_infer.py b/lightx2v/models/networks/z_image/infer/pre_infer.py old mode 100644 new mode 100755 index 52fc552e1..aad32f374 --- a/lightx2v/models/networks/z_image/infer/pre_infer.py +++ b/lightx2v/models/networks/z_image/infer/pre_infer.py @@ -1,6 +1,5 @@ import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence from 
lightx2v.utils.envs import * @@ -25,10 +24,9 @@ def infer(self, weights, hidden_states, encoder_hidden_states): patch_size = self.config.get("patch_size", 2) f_patch_size = self.config.get("f_patch_size", 1) - if hidden_states.dim() == 4: - hidden_states = patchify(hidden_states, patch_size=patch_size, f_patch_size=f_patch_size) + hidden_states = patchify(hidden_states, patch_size=patch_size, f_patch_size=f_patch_size).squeeze(0) - batch_size, num_tokens, patch_dim = hidden_states.shape + num_tokens, patch_dim = hidden_states.shape original_shape = self.scheduler.input_info.target_shape if len(original_shape) >= 2: @@ -40,160 +38,120 @@ def infer(self, weights, hidden_states, encoder_hidden_states): H_tokens = original_height // patch_size W_tokens = original_width // patch_size - padded_list = [] - image_pad_masks = [] - x_item_seqlens_ori = [] - for b in range(batch_size): - x_item = hidden_states[b] - x_ori_len = x_item.shape[0] - x_item_seqlens_ori.append(x_ori_len) - x_padding_len = (-x_ori_len) % SEQ_MULTI_OF - - if x_padding_len > 0: - pad_mask = torch.cat( - [ - torch.zeros((x_ori_len,), dtype=torch.bool, device=x_item.device), - torch.ones((x_padding_len,), dtype=torch.bool, device=x_item.device), - ], - dim=0, - ) - x_padded = torch.cat([x_item, x_item[-1:].repeat(x_padding_len, 1)], dim=0) - padded_list.append(x_padded) - image_pad_masks.append(pad_mask) - else: - pad_mask = torch.zeros((x_ori_len,), dtype=torch.bool, device=x_item.device) - padded_list.append(x_item) - image_pad_masks.append(pad_mask) - - x_item_seqlens = [x.shape[0] for x in padded_list] - x_cat = torch.cat(padded_list, dim=0) + # Process image tokens (single sample, no batch) + x_ori_len = hidden_states.shape[0] + x_padding_len = (-x_ori_len) % SEQ_MULTI_OF + + if x_padding_len > 0: + x_pad_mask = torch.cat( + [ + torch.zeros((x_ori_len,), dtype=torch.bool, device=hidden_states.device), + torch.ones((x_padding_len,), dtype=torch.bool, device=hidden_states.device), + ], + dim=0, + ) + x_padded = torch.cat([hidden_states, hidden_states[-1:].repeat(x_padding_len, 1)], dim=0) + else: + x_pad_mask = torch.zeros((x_ori_len,), dtype=torch.bool, device=hidden_states.device) + x_padded = hidden_states - hidden_states_2d = weights.img_in.apply(x_cat) + x_padded_len = x_padded.shape[0] + hidden_states = weights.img_in.apply(x_padded) # [L, D] if hasattr(weights, "x_pad_token") and hasattr(weights.x_pad_token, "tensor"): x_pad_token = weights.x_pad_token.tensor - x_inner_pad_mask = torch.cat(image_pad_masks, dim=0) - hidden_states_2d[x_inner_pad_mask] = x_pad_token.squeeze(0) # Broadcast to [D] - - hidden_states_list = list(hidden_states_2d.split(x_item_seqlens, dim=0)) - hidden_states = pad_sequence(hidden_states_list, batch_first=True, padding_value=0.0) + # Handle both [1, D] and [D] formats + if x_pad_token.dim() == 2: + x_pad_token = x_pad_token.squeeze(0) # [D] + hidden_states[x_pad_mask] = x_pad_token + # Process encoder hidden states (single sample, no batch) + # Remove batch dimension if present: [B, L, D] -> [L, D] if encoder_hidden_states.dim() == 3: - pass - elif encoder_hidden_states.dim() == 2: - encoder_hidden_states = encoder_hidden_states.unsqueeze(0) - else: + encoder_hidden_states = encoder_hidden_states.squeeze(0) + elif encoder_hidden_states.dim() != 2: raise ValueError(f"encoder_hidden_states must be 2D [L, D] or 3D [B, L, D], got {encoder_hidden_states.shape}") - cap_padded_list = [] - cap_pad_masks = [] - cap_item_seqlens_ori = [] - for b in range(batch_size): - cap_item = encoder_hidden_states[b] 
- cap_ori_len = cap_item.shape[0] - cap_item_seqlens_ori.append(cap_ori_len) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - - if cap_padding_len > 0: - cap_pad_mask = torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=cap_item.device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=cap_item.device), - ], - dim=0, - ) - cap_padded = torch.cat([cap_item, cap_item[-1:].repeat(cap_padding_len, 1)], dim=0) - cap_padded_list.append(cap_padded) - cap_pad_masks.append(cap_pad_mask) - else: - cap_pad_mask = torch.zeros((cap_ori_len,), dtype=torch.bool, device=cap_item.device) - cap_padded_list.append(cap_item) - cap_pad_masks.append(cap_pad_mask) - - cap_item_seqlens = [x.shape[0] for x in cap_padded_list] - cap_cat = torch.cat(cap_padded_list, dim=0) + cap_ori_len = encoder_hidden_states.shape[0] + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + + if cap_padding_len > 0: + cap_pad_mask = torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=encoder_hidden_states.device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=encoder_hidden_states.device), + ], + dim=0, + ) + cap_padded = torch.cat([encoder_hidden_states, encoder_hidden_states[-1:].repeat(cap_padding_len, 1)], dim=0) + else: + cap_pad_mask = torch.zeros((cap_ori_len,), dtype=torch.bool, device=encoder_hidden_states.device) + cap_padded = encoder_hidden_states - cap_cat = weights.txt_norm.apply(cap_cat) - cap_cat = weights.txt_in.apply(cap_cat) + cap_padded_len = cap_padded.shape[0] + encoder_hidden_states = weights.txt_norm.apply(cap_padded) # [L, D] + encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states) # [L, D] if hasattr(weights, "cap_pad_token") and hasattr(weights.cap_pad_token, "tensor"): cap_pad_token = weights.cap_pad_token.tensor - cap_inner_pad_mask = torch.cat(cap_pad_masks, dim=0) - cap_cat[cap_inner_pad_mask] = cap_pad_token.squeeze(0) - - encoder_hidden_states_list = list(cap_cat.split(cap_item_seqlens, dim=0)) - encoder_hidden_states = pad_sequence(encoder_hidden_states_list, batch_first=True, padding_value=0.0) + # Handle both [1, D] and [D] formats + if cap_pad_token.dim() == 2: + cap_pad_token = cap_pad_token.squeeze(0) # [D] + encoder_hidden_states[cap_pad_mask] = cap_pad_token device = hidden_states.device - x_pos_ids_list = [] - cap_pos_ids_list = [] - - for b in range(batch_size): - cap_ori_len = cap_item_seqlens_ori[b] - cap_padded_len = cap_item_seqlens[b] - cap_pos_ids = self.scheduler.create_coordinate_grid( - size=(cap_padded_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - cap_pos_ids_list.append(cap_pos_ids) - - x_ori_len = x_item_seqlens_ori[b] - x_padded_len = x_item_seqlens[b] - image_pos_ids = self.scheduler.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_padded_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - - if x_padded_len > x_ori_len: - padding_pos_ids = ( - self.scheduler.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(x_padded_len - x_ori_len, 1) + # Generate position IDs for caption + cap_pos_ids = self.scheduler.create_coordinate_grid( + size=(cap_padded_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + + # Generate position IDs for image + image_pos_ids = self.scheduler.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padded_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + + if x_padded_len > x_ori_len: + padding_pos_ids = ( + 
self.scheduler.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, ) - image_pos_ids = torch.cat([image_pos_ids, padding_pos_ids], dim=0) - - x_pos_ids_list.append(image_pos_ids) - - x_pos_ids_cat = torch.cat(x_pos_ids_list, dim=0) - cap_pos_ids_cat = torch.cat(cap_pos_ids_list, dim=0) - - x_freqs_cis_cat = self.scheduler.generate_freqs_cis_from_position_ids(x_pos_ids_cat, device=device) - cap_freqs_cis_cat = self.scheduler.generate_freqs_cis_from_position_ids(cap_pos_ids_cat, device=device) + .flatten(0, 2) + .repeat(x_padded_len - x_ori_len, 1) + ) + image_pos_ids = torch.cat([image_pos_ids, padding_pos_ids], dim=0) - x_freqs_cis_list = list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)) - cap_freqs_cis_list = list(cap_freqs_cis_cat.split(cap_item_seqlens, dim=0)) - - x_freqs_cis = pad_sequence(x_freqs_cis_list, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis_list, batch_first=True, padding_value=0.0) + # Generate freqs_cis + x_freqs_cis = self.scheduler.generate_freqs_cis_from_position_ids(image_pos_ids, device=device) + cap_freqs_cis = self.scheduler.generate_freqs_cis_from_position_ids(cap_pos_ids, device=device) embed0 = weights.time_text_embed_timestep_embedder_linear_1.apply(self.scheduler.timesteps_proj) - embed0 = F.silu(embed0) - embed0 = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0) - temb_img_silu = embed0 + temb_img_silu = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0) if self.zero_cond_t: - temb_txt_silu = torch.zeros_like(temb_img_silu) + temb_txt_silu = torch.zeros_like(temb_img_silu) # [D] else: - pooled_text = encoder_hidden_states.mean(dim=1) + # encoder_hidden_states is [L, D], mean over sequence length + pooled_text = encoder_hidden_states.mean(dim=0) # [D] + if pooled_text.shape[-1] != temb_img_silu.shape[-1]: target_dim = temb_img_silu.shape[-1] if pooled_text.shape[-1] > target_dim: pooled_text = pooled_text[..., :target_dim] else: - padding = torch.zeros(batch_size, target_dim - pooled_text.shape[-1], device=pooled_text.device, dtype=pooled_text.dtype) + padding = torch.zeros(target_dim - pooled_text.shape[-1], device=pooled_text.device, dtype=pooled_text.dtype) pooled_text = torch.cat([pooled_text, padding], dim=-1) - temb_txt_silu = F.silu(pooled_text) + temb_txt_silu = F.silu(pooled_text) # [D] - image_tokens_len = x_item_seqlens_ori[0] + image_tokens_len = x_ori_len return ZPreInferModuleOutput( hidden_states=hidden_states, @@ -203,6 +161,6 @@ def infer(self, weights, hidden_states, encoder_hidden_states): x_freqs_cis=x_freqs_cis, cap_freqs_cis=cap_freqs_cis, image_tokens_len=image_tokens_len, - x_item_seqlens=x_item_seqlens, - cap_item_seqlens=cap_item_seqlens, + x_item_seqlens=[x_padded_len], + cap_item_seqlens=[cap_padded_len], ) diff --git a/lightx2v/models/networks/z_image/infer/transformer_infer.py b/lightx2v/models/networks/z_image/infer/transformer_infer.py old mode 100644 new mode 100755 index b22a1bf25..360a11a72 --- a/lightx2v/models/networks/z_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/z_image/infer/transformer_infer.py @@ -1,20 +1,11 @@ import torch +import torch.nn.functional as F from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer -from .triton_ops import ( - fuse_scale_shift_kernel, -) from .utils import apply_rotary_emb_qwen, apply_wan_rope_with_flashinfer -def calculate_q_k_len(q, k_lens): - q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) - cu_seqlens_q = 
torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) - cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) - return cu_seqlens_q, cu_seqlens_k - - class ZImageTransformerInfer(BaseTransformerInfer): def __init__(self, config): self.config = config @@ -22,15 +13,12 @@ def __init__(self, config): self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) self.attn_type = config.get("attn_type", "flash_attn3") self.zero_cond_t = config.get("zero_cond_t", False) + self.n_heads = config.get("n_heads", config.get("num_attention_heads", 24)) if self.config["seq_parallel"]: self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") else: self.seq_p_group = None self.seq_p_fp8_comm = False - if self.config.get("modulate_type", "triton") == "triton": - self.modulate_func = fuse_scale_shift_kernel - else: - self.modulate_func = lambda x, scale, shift: x * (1 + scale) + shift if self.config.get("rope_type", "flashinfer") == "flashinfer": self.apply_rope_func = apply_wan_rope_with_flashinfer else: @@ -39,51 +27,60 @@ def __init__(self, config): def set_scheduler(self, scheduler): self.scheduler = scheduler - def apply_attn(self, block_weight, hidden_states, freqs_cis): - is_3d = hidden_states.dim() == 3 - if is_3d: - B, T, D = hidden_states.shape - hidden_states_2d = hidden_states.reshape(-1, D) - freqs_cis_2d = freqs_cis.reshape(-1, freqs_cis.shape[-1]) - else: - hidden_states_2d = hidden_states - freqs_cis_2d = freqs_cis + def infer_mod(self, mod_phase, hidden_states, adaln_input): + if mod_phase is None: + return None, None, None, None - query = block_weight.attention.to_q.apply(hidden_states_2d) - key = block_weight.attention.to_k.apply(hidden_states_2d) - value = block_weight.attention.to_v.apply(hidden_states_2d) + mod_params = mod_phase.adaLN_modulation.apply(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod_params.chunk(4, dim=-1) + gate_msa.tanh_() + gate_mlp.tanh_() - query = query.unflatten(-1, (block_weight.attention.heads, -1)) - key = key.unflatten(-1, (block_weight.attention.heads, -1)) - value = value.unflatten(-1, (block_weight.attention.heads, -1)) + scale_msa.add_(1.0) + scale_mlp.add_(1.0) + + return scale_msa, gate_msa, scale_mlp, gate_mlp + + def infer_attn(self, attn_phase, hidden_states, freqs_cis, scale_msa=None): + norm1_out = attn_phase.attention_norm1.apply(hidden_states) + if scale_msa is not None: + scaled_norm1 = norm1_out * scale_msa + else: + scaled_norm1 = norm1_out - if block_weight.attention.norm_q is not None: - query = block_weight.attention.norm_q.apply(query) - if block_weight.attention.norm_k is not None: - key = block_weight.attention.norm_k.apply(key) + query = attn_phase.to_q.apply(scaled_norm1) + key = attn_phase.to_k.apply(scaled_norm1) + value = attn_phase.to_v.apply(scaled_norm1) + head_dim = query.shape[-1] // self.n_heads + query = query.unflatten(-1, (self.n_heads, head_dim)) + key = key.unflatten(-1, (self.n_heads, head_dim)) + value = value.unflatten(-1, (self.n_heads, head_dim)) - query, key = self.apply_rope_func(query, key, freqs_cis_2d) + if attn_phase.norm_q is not None: + query = attn_phase.norm_q.apply(query) + if attn_phase.norm_k is not None: + key = attn_phase.norm_k.apply(key) - dtype = query.dtype - query, key = query.to(dtype), key.to(dtype) + query, key = self.apply_rope_func(query, key, freqs_cis) total_seq_len = query.shape[0] cu_seqlens = torch.tensor([0, total_seq_len], dtype=torch.int32, device="cpu").to(query.device, non_blocking=True) if 
self.config["seq_parallel"]: - hidden_states_out = block_weight.attention.calculate_parallel.apply( + hidden_states_out = attn_phase.calculate_parallel.apply( q=query, k=key, v=value, slice_qkv_len=total_seq_len, cu_seqlens_qkv=cu_seqlens, - attention_module=block_weight.attention.calculate, + attention_module=attn_phase.calculate, seq_p_group=self.seq_p_group, use_fp8_comm=self.seq_p_fp8_comm, img_first=False, ) else: - hidden_states_out = block_weight.attention.calculate.apply( + # todo + hidden_states_out = attn_phase.calculate.apply( q=query, k=key, v=value, @@ -93,14 +90,25 @@ def apply_attn(self, block_weight, hidden_states, freqs_cis): max_seqlen_kv=total_seq_len, ) - output = block_weight.attention.to_out[0].apply(hidden_states_out) - if len(block_weight.attention.to_out) > 1: - output = block_weight.attention.to_out[1].apply(output) + output = attn_phase.to_out[0].apply(hidden_states_out) + if len(attn_phase.to_out) > 1: + output = attn_phase.to_out[1].apply(output) - if is_3d: - output = output.reshape(B, T, -1) + attn_out = attn_phase.attention_norm2.apply(output) - return output + return attn_out + + def infer_ffn(self, ffn_phase, hidden_states, scale_mlp=None, gate_mlp=None): + ffn_norm1_out = ffn_phase.ffn_norm1.apply(hidden_states) + if scale_mlp is not None: + ffn_norm1_out.mul_(scale_mlp) + w1_out = ffn_phase.w1.apply(ffn_norm1_out) + w3_out = ffn_phase.w3.apply(ffn_norm1_out) + silu_gated = F.silu(w1_out) * w3_out + ffn_out = ffn_phase.w2.apply(silu_gated) + norm2_ffn = ffn_phase.ffn_norm2.apply(ffn_out) + + return norm2_ffn, gate_mlp def infer_block( self, @@ -109,57 +117,90 @@ def infer_block( freqs_cis, adaln_input=None, ): - if block_weight.modulation: - assert adaln_input is not None - mod_params = block_weight.adaLN_modulation.apply(adaln_input) - scale_msa, gate_msa, scale_mlp, gate_mlp = mod_params.unsqueeze(1).chunk(4, dim=2) + mod_phase = block_weight.compute_phases[0] if block_weight.has_modulation else None + attn_phase = block_weight.compute_phases[1] + ffn_phase = block_weight.compute_phases[2] - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + scale_msa, gate_msa, scale_mlp, gate_mlp = self.infer_mod(mod_phase, hidden_states, adaln_input) + attn_out = self.infer_attn(attn_phase, hidden_states, freqs_cis, scale_msa) - norm1_out = block_weight.attention_norm1.apply(hidden_states) - scaled_norm1 = norm1_out * scale_msa + if gate_msa is not None: + hidden_states.add_(gate_msa * attn_out) + else: + hidden_states.add_(attn_out) + norm2_ffn, gate_mlp = self.infer_ffn(ffn_phase, hidden_states, scale_mlp, gate_mlp) - # Attention block - attn_out = self.apply_attn( - block_weight=block_weight, - hidden_states=scaled_norm1, - freqs_cis=freqs_cis, - ) - norm2_attn = block_weight.attention_norm2.apply(attn_out) + if gate_mlp is not None: + hidden_states.add_(gate_mlp * norm2_ffn) + else: + hidden_states.add_(norm2_ffn) - hidden_states = hidden_states + gate_msa * norm2_attn + if hidden_states.dtype == torch.float16: + hidden_states.clip_(-65504, 65504) - ffn_norm1_out = block_weight.ffn_norm1.apply(hidden_states) - scaled_ffn_norm1 = ffn_norm1_out * scale_mlp + return hidden_states - ffn_out = block_weight.feed_forward.forward(scaled_ffn_norm1) - norm2_ffn = block_weight.ffn_norm2.apply(ffn_out) + def infer_noise_refiner( + self, + noise_refiner_blocks, + hidden_states, + x_freqs_cis, + adaln_input, + x_len, + ): + x_hidden = hidden_states[:x_len] + x_freqs = x_freqs_cis[:x_len] + for block_weight in 
noise_refiner_blocks: + x_hidden = self.infer_block( + block_weight=block_weight, + hidden_states=x_hidden, + freqs_cis=x_freqs, + adaln_input=adaln_input, + ) - hidden_states = hidden_states + gate_mlp * norm2_ffn - else: - norm1_out = block_weight.attention_norm1.apply(hidden_states) + return x_hidden - # Attention block - attn_out = self.apply_attn( + def infer_context_refiner( + self, + context_refiner_blocks, + encoder_hidden_states, + cap_freqs_cis, + cap_len, + ): + cap_hidden = encoder_hidden_states[:cap_len] + cap_freqs = cap_freqs_cis[:cap_len] + for block_weight in context_refiner_blocks: + cap_hidden = self.infer_block( block_weight=block_weight, - hidden_states=norm1_out, - freqs_cis=freqs_cis, + hidden_states=cap_hidden, + freqs_cis=cap_freqs, + adaln_input=None, ) - norm2_attn = block_weight.attention_norm2.apply(attn_out) - hidden_states = hidden_states + norm2_attn + return cap_hidden - # FFN block - ffn_norm1_out = block_weight.ffn_norm1.apply(hidden_states) - ffn_out = block_weight.feed_forward.forward(ffn_norm1_out) - norm2_ffn = block_weight.ffn_norm2.apply(ffn_out) - hidden_states = hidden_states + norm2_ffn + def infer_main_blocks( + self, + main_blocks, + hidden_states, + encoder_hidden_states, + x_freqs_cis, + cap_freqs_cis, + adaln_input, + x_len, + cap_len, + ): + unified = torch.cat([hidden_states, encoder_hidden_states], dim=0) + unified_freqs_cis = torch.cat([x_freqs_cis[:x_len], cap_freqs_cis[:cap_len]], dim=0) + for block_weight in main_blocks: + unified = self.infer_block( + block_weight=block_weight, + hidden_states=unified, + freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, + ) - # Clip to prevent overflow for fp16 - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + return unified def infer_calculating( self, @@ -172,101 +213,37 @@ def infer_calculating( x_item_seqlens, cap_item_seqlens, ): - from torch.nn.utils.rnn import pad_sequence - - batch_size = hidden_states.shape[0] - device = hidden_states.device - - # ==================== Stage 1: Noise Refiner (Image Stream) ==================== - # Process image stream with modulation - if block_weights.noise_refiner is not None and len(block_weights.noise_refiner) > 0: - # Build attention mask for image tokens - x_max_seqlen = max(x_item_seqlens) - x_attn_mask = torch.zeros((batch_size, x_max_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = True - - # Process through noise_refiner layers (with modulation) - # Use 3D [B, T, D] format to match official implementation - for idx in range(len(block_weights.noise_refiner)): - hidden_states = self.infer_block( - block_weight=block_weights.noise_refiner[idx], - hidden_states=hidden_states, - freqs_cis=x_freqs_cis, - adaln_input=adaln_input, - ) - - # ==================== Stage 2: Context Refiner (Text Stream) ==================== - # Process text stream without modulation - if block_weights.context_refiner is not None and len(block_weights.context_refiner) > 0: - # Build attention mask for text tokens - cap_max_seqlen = max(cap_item_seqlens) - cap_attn_mask = torch.zeros((batch_size, cap_max_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = True - - # Process through context_refiner layers (without modulation) - # Use 3D [B, L, D] format to match official implementation - for idx in range(len(block_weights.context_refiner)): - encoder_hidden_states = 
self.infer_block( - block_weight=block_weights.context_refiner[idx], - hidden_states=encoder_hidden_states, # [B, L, D] - freqs_cis=cap_freqs_cis, # [B, L, D_rope] - adaln_input=None, # No modulation for context_refiner - ) - - # ==================== Stage 3: Unified Layers (Merged Stream) ==================== - # Merge image and text streams - unified_list = [] - unified_freqs_cis_list = [] - unified_item_seqlens = [] - - for b in range(batch_size): - x_len = x_item_seqlens[b] - cap_len = cap_item_seqlens[b] - - # Concatenate image and text tokens: [image_tokens, text_tokens] - unified_item = torch.cat( - [ - hidden_states[b, :x_len], - encoder_hidden_states[b, :cap_len], - ], - dim=0, - ) - unified_list.append(unified_item) - - # Concatenate freqs_cis: [image_freqs, text_freqs] - unified_freqs_item = torch.cat( - [ - x_freqs_cis[b, :x_len], - cap_freqs_cis[b, :cap_len], - ], - dim=0, - ) - unified_freqs_cis_list.append(unified_freqs_item) - - unified_item_seqlens.append(x_len + cap_len) - - # Pad unified sequences - unified_max_seqlen = max(unified_item_seqlens) - unified = pad_sequence(unified_list, batch_first=True, padding_value=0.0) # [B, max_seqlen, D] - unified_freqs_cis = pad_sequence(unified_freqs_cis_list, batch_first=True, padding_value=0.0) # [B, max_seqlen, D_rope] - - # Build attention mask for unified stream - unified_attn_mask = torch.zeros((batch_size, unified_max_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = True - - # Process through unified layers (with modulation) - # Use 3D [B, T_unified, D] format to match official implementation - if block_weights.blocks is not None and len(block_weights.blocks) > 0: - for idx in range(len(block_weights.blocks)): - unified = self.infer_block( - block_weight=block_weights.blocks[idx], - hidden_states=unified, - freqs_cis=unified_freqs_cis, - adaln_input=adaln_input, - ) + x_len = x_item_seqlens[0] + cap_len = cap_item_seqlens[0] + + # Stage 1: Noise Refiner (Image Stream) + hidden_states = self.infer_noise_refiner( + noise_refiner_blocks=block_weights.noise_refiner, + hidden_states=hidden_states, + x_freqs_cis=x_freqs_cis, + adaln_input=adaln_input, + x_len=x_len, + ) + + # Stage 2: Context Refiner (Text Stream) + encoder_hidden_states = self.infer_context_refiner( + context_refiner_blocks=block_weights.context_refiner, + encoder_hidden_states=encoder_hidden_states, + cap_freqs_cis=cap_freqs_cis, + cap_len=cap_len, + ) + + # Stage 3: Main Blocks (Unified Stream) + unified = self.infer_main_blocks( + main_blocks=block_weights.blocks, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + x_freqs_cis=x_freqs_cis, + cap_freqs_cis=cap_freqs_cis, + adaln_input=adaln_input, + x_len=x_len, + cap_len=cap_len, + ) return unified @@ -276,8 +253,6 @@ def infer(self, block_weights, pre_infer_out): adaln_input = pre_infer_out.adaln_input x_item_seqlens = pre_infer_out.x_item_seqlens cap_item_seqlens = pre_infer_out.cap_item_seqlens - - # Use freqs_cis generated from position ids in pre_infer x_freqs_cis = pre_infer_out.x_freqs_cis cap_freqs_cis = pre_infer_out.cap_freqs_cis diff --git a/lightx2v/models/networks/z_image/infer/triton_ops.py b/lightx2v/models/networks/z_image/infer/triton_ops.py deleted file mode 100644 index d9d612ff0..000000000 --- a/lightx2v/models/networks/z_image/infer/triton_ops.py +++ /dev/null @@ -1,1165 +0,0 @@ -# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & 
https://github.com/sgl-project/sglang - -# TODO: for temporary usage, expecting a refactor -from typing import Optional - -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -from torch import Tensor - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_N": 64}, num_warps=2), - triton.Config({"BLOCK_N": 128}, num_warps=4), - triton.Config({"BLOCK_N": 256}, num_warps=4), - triton.Config({"BLOCK_N": 512}, num_warps=4), - triton.Config({"BLOCK_N": 1024}, num_warps=8), - ], - key=["inner_dim"], -) -@triton.jit -def _fused_scale_shift_4d_kernel( - output_ptr, - normalized_ptr, - scale_ptr, - shift_ptr, - rows, - inner_dim, - seq_len, - num_frames, - frame_seqlen, - BLOCK_N: tl.constexpr, -): - pid_row = tl.program_id(0) - pid_col = tl.program_id(1) - - col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) - mask = col_offsets < inner_dim - - # Pointers for normalized and output - row_base = pid_row * inner_dim - norm_ptrs = normalized_ptr + row_base + col_offsets - out_ptrs = output_ptr + row_base + col_offsets - - # Pointers for scale and shift for 4D - b_idx = pid_row // seq_len - t_idx = pid_row % seq_len - frame_idx_in_batch = t_idx // frame_seqlen - - scale_row_idx = b_idx * num_frames + frame_idx_in_batch - scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets - shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets - - normalized = tl.load(norm_ptrs, mask=mask, other=0.0) - scale = tl.load(scale_ptrs, mask=mask, other=0.0) - shift = tl.load(shift_ptrs, mask=mask, other=0.0) - - one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype) - output = normalized * (one + scale) + shift - - tl.store(out_ptrs, output, mask=mask) - - -@triton.jit -def fuse_scale_shift_kernel_blc_opt( - x_ptr, - shift_ptr, - scale_ptr, - y_ptr, - B, - L, - C, - stride_x_b, - stride_x_l, - stride_x_c, - stride_s_b, - stride_s_l, - stride_s_c, - stride_sc_b, - stride_sc_l, - stride_sc_c, - SCALE_IS_SCALAR: tl.constexpr, - SHIFT_IS_SCALAR: tl.constexpr, - BLOCK_L: tl.constexpr, - BLOCK_C: tl.constexpr, -): - pid_l = tl.program_id(0) - pid_c = tl.program_id(1) - pid_b = tl.program_id(2) - - l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) - c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) - - mask_l = l_offsets < L - mask_c = c_offsets < C - mask = mask_l[:, None] & mask_c[None, :] - - x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c - x = tl.load(x_ptr + x_off, mask=mask, other=0) - - if SHIFT_IS_SCALAR: - shift_val = tl.load(shift_ptr) - shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype) - else: - s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c - shift = tl.load(shift_ptr + s_off, mask=mask, other=0) - - if SCALE_IS_SCALAR: - scale_val = tl.load(scale_ptr) - scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) - else: - sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c - scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) - - y = x * (1 + scale) + shift - tl.store(y_ptr + x_off, y, mask=mask) - - -@triton.jit -def fuse_scale_shift_gate_select01_kernel_blc_opt( - x_ptr, - shift0_ptr, - scale0_ptr, - gate0_ptr, - shift1_ptr, - scale1_ptr, - gate1_ptr, - index_ptr, - y_ptr, - gate_out_ptr, - B, - L, - C, - stride_x_b, - stride_x_l, - stride_x_c, - stride_s0_b, - stride_s0_c, - stride_sc0_b, - stride_sc0_c, - stride_g0_b, - stride_g0_c, - stride_s1_b, - stride_s1_c, - 
stride_sc1_b, - stride_sc1_c, - stride_g1_b, - stride_g1_c, - stride_i_b, - stride_i_l, - stride_go_b, - stride_go_l, - stride_go_c, - BLOCK_L: tl.constexpr, - BLOCK_C: tl.constexpr, -): - pid_l = tl.program_id(0) - pid_c = tl.program_id(1) - pid_b = tl.program_id(2) - - l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) - c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) - - mask_l = l_offsets < L - mask_c = c_offsets < C - mask = mask_l[:, None] & mask_c[None, :] - - x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c - x = tl.load(x_ptr + x_off, mask=mask, other=0) - - idx_off = pid_b * stride_i_b + l_offsets * stride_i_l - idx = tl.load(index_ptr + idx_off, mask=mask_l, other=0).to(tl.int1)[:, None] - - s0_off = pid_b * stride_s0_b + c_offsets[None, :] * stride_s0_c - sc0_off = pid_b * stride_sc0_b + c_offsets[None, :] * stride_sc0_c - g0_off = pid_b * stride_g0_b + c_offsets[None, :] * stride_g0_c - s1_off = pid_b * stride_s1_b + c_offsets[None, :] * stride_s1_c - sc1_off = pid_b * stride_sc1_b + c_offsets[None, :] * stride_sc1_c - g1_off = pid_b * stride_g1_b + c_offsets[None, :] * stride_g1_c - - shift0 = tl.load(shift0_ptr + s0_off, mask=mask_c[None, :], other=0) - scale0 = tl.load(scale0_ptr + sc0_off, mask=mask_c[None, :], other=0) - gate0 = tl.load(gate0_ptr + g0_off, mask=mask_c[None, :], other=0) - shift1 = tl.load(shift1_ptr + s1_off, mask=mask_c[None, :], other=0) - scale1 = tl.load(scale1_ptr + sc1_off, mask=mask_c[None, :], other=0) - gate1 = tl.load(gate1_ptr + g1_off, mask=mask_c[None, :], other=0) - - shift = tl.where(idx, shift1, shift0) - scale = tl.where(idx, scale1, scale0) - gate = tl.where(idx, gate1, gate0) - - y = x * (1 + scale) + shift - tl.store(y_ptr + x_off, y, mask=mask) - - go_off = pid_b * stride_go_b + l_offsets[:, None] * stride_go_l + c_offsets[None, :] * stride_go_c - tl.store(gate_out_ptr + go_off, gate, mask=mask) - - -@triton.jit -def fuse_scale_shift_select01_kernel_blc_opt( - x_ptr, - shift0_ptr, - scale0_ptr, - shift1_ptr, - scale1_ptr, - index_ptr, - y_ptr, - B, - L, - C, - stride_x_b, - stride_x_l, - stride_x_c, - stride_s0_b, - stride_s0_c, - stride_sc0_b, - stride_sc0_c, - stride_s1_b, - stride_s1_c, - stride_sc1_b, - stride_sc1_c, - stride_i_b, - stride_i_l, - BLOCK_L: tl.constexpr, - BLOCK_C: tl.constexpr, -): - pid_l = tl.program_id(0) - pid_c = tl.program_id(1) - pid_b = tl.program_id(2) - - l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) - c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) - - mask_l = l_offsets < L - mask_c = c_offsets < C - mask = mask_l[:, None] & mask_c[None, :] - - x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c - x = tl.load(x_ptr + x_off, mask=mask, other=0) - - idx_off = pid_b * stride_i_b + l_offsets * stride_i_l - idx = tl.load(index_ptr + idx_off, mask=mask_l, other=0).to(tl.int1)[:, None] - - s0_off = pid_b * stride_s0_b + c_offsets[None, :] * stride_s0_c - sc0_off = pid_b * stride_sc0_b + c_offsets[None, :] * stride_sc0_c - s1_off = pid_b * stride_s1_b + c_offsets[None, :] * stride_s1_c - sc1_off = pid_b * stride_sc1_b + c_offsets[None, :] * stride_sc1_c - - shift0 = tl.load(shift0_ptr + s0_off, mask=mask_c[None, :], other=0) - scale0 = tl.load(scale0_ptr + sc0_off, mask=mask_c[None, :], other=0) - shift1 = tl.load(shift1_ptr + s1_off, mask=mask_c[None, :], other=0) - scale1 = tl.load(scale1_ptr + sc1_off, mask=mask_c[None, :], other=0) - - shift = tl.where(idx, shift1, shift0) - scale = tl.where(idx, scale1, 
scale0) - - y = x * (1 + scale) + shift - tl.store(y_ptr + x_off, y, mask=mask) - - -def fuse_scale_shift_kernel( - x: torch.Tensor, - scale: torch.Tensor, - shift: torch.Tensor, - block_l: int = 128, - block_c: int = 128, -): - # assert x.is_cuda and scale.is_cuda - assert x.is_contiguous() - if x.dim() == 2: - x = x.unsqueeze(0) - - B, L, C = x.shape - output = torch.empty_like(x) - - if scale.dim() == 4: - # scale/shift: [B, F, 1, C] - rows = B * L - x_2d = x.view(rows, C) - output_2d = output.view(rows, C) - grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa - num_frames = scale.shape[1] - assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift" - frame_seqlen = L // num_frames - - # Compact [B, F, C] without the singleton dim into [B*F, C] - scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous() - shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous() - - _fused_scale_shift_4d_kernel[grid]( - output_2d, - x_2d, - scale_reshaped, - shift_reshaped, - rows, - C, - L, - num_frames, - frame_seqlen, - ) - else: - # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L - # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C]) - # Also support scalar (0D or 1-element) - if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1): - scale_blc = scale.reshape(1) - elif scale.dim() == 2: - scale_blc = scale[:, None, :] - elif scale.dim() == 3: - scale_blc = scale - else: - raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D") - - if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1): - shift_blc = shift.reshape(1) - elif shift.dim() == 2: - shift_blc = shift[:, None, :] - elif shift.dim() == 3: - shift_blc = shift - else: - # broadcast later via expand if possible - shift_blc = shift - - need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1 - need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1 - - if not need_scale_scalar: - scale_exp = scale_blc.expand(B, L, C) - s_sb, s_sl, s_sc = scale_exp.stride() - else: - s_sb = s_sl = s_sc = 0 - - if not need_shift_scalar: - shift_exp = shift_blc.expand(B, L, C) - sh_sb, sh_sl, sh_sc = shift_exp.stride() - else: - sh_sb = sh_sl = sh_sc = 0 - - # If both scalars and both zero, copy fast-path - if need_scale_scalar and need_shift_scalar: - if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0): - output.copy_(x) - return output - - grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) - fuse_scale_shift_kernel_blc_opt[grid]( - x, - shift_blc if need_shift_scalar else shift_exp, - scale_blc if need_scale_scalar else scale_exp, - output, - B, - L, - C, - x.stride(0), - x.stride(1), - x.stride(2), - sh_sb, - sh_sl, - sh_sc, - s_sb, - s_sl, - s_sc, - SCALE_IS_SCALAR=need_scale_scalar, - SHIFT_IS_SCALAR=need_shift_scalar, - BLOCK_L=block_l, - BLOCK_C=block_c, - num_warps=4, - num_stages=2, - ) - return output - - -def fuse_scale_shift_gate_select01_kernel( - x: torch.Tensor, - scale0: torch.Tensor, - shift0: torch.Tensor, - gate0: torch.Tensor, - scale1: torch.Tensor, - shift1: torch.Tensor, - gate1: torch.Tensor, - index: torch.Tensor, - block_l: int = 128, - block_c: int = 128, -): - assert x.is_contiguous() - if x.dim() == 2: - x = x.unsqueeze(0) - B, L, C = x.shape - output = torch.empty_like(x) - gate_out = torch.empty((B, L, C), device=x.device, dtype=x.dtype) - - if scale0.dim() != 2 or shift0.dim() != 2 or gate0.dim() != 2 or scale1.dim() != 2 or shift1.dim() != 2 or gate1.dim() != 2: - raise 
ValueError("scale0/shift0/gate0/scale1/shift1/gate1 must be 2D [B, C]") - if index.dim() != 2: - raise ValueError("index must be 2D [B, L]") - - grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) - fuse_scale_shift_gate_select01_kernel_blc_opt[grid]( - x, - shift0, - scale0, - gate0, - shift1, - scale1, - gate1, - index, - output, - gate_out, - B, - L, - C, - x.stride(0), - x.stride(1), - x.stride(2), - shift0.stride(0), - shift0.stride(1), - scale0.stride(0), - scale0.stride(1), - gate0.stride(0), - gate0.stride(1), - shift1.stride(0), - shift1.stride(1), - scale1.stride(0), - scale1.stride(1), - gate1.stride(0), - gate1.stride(1), - index.stride(0), - index.stride(1), - gate_out.stride(0), - gate_out.stride(1), - gate_out.stride(2), - BLOCK_L=block_l, - BLOCK_C=block_c, - num_warps=4, - num_stages=2, - ) - return output, gate_out - - -def fuse_scale_shift_select01_kernel( - x: torch.Tensor, - scale0: torch.Tensor, - shift0: torch.Tensor, - scale1: torch.Tensor, - shift1: torch.Tensor, - index: torch.Tensor, - block_l: int = 128, - block_c: int = 128, -): - assert x.is_contiguous() - if x.dim() == 2: - x = x.unsqueeze(0) - B, L, C = x.shape - output = torch.empty_like(x) - - if scale0.dim() != 2 or shift0.dim() != 2 or scale1.dim() != 2 or shift1.dim() != 2: - raise ValueError("scale0/shift0/scale1/shift1 must be 2D [B, C]") - if index.dim() != 2: - raise ValueError("index must be 2D [B, L]") - - grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) - fuse_scale_shift_select01_kernel_blc_opt[grid]( - x, - shift0, - scale0, - shift1, - scale1, - index, - output, - B, - L, - C, - x.stride(0), - x.stride(1), - x.stride(2), - shift0.stride(0), - shift0.stride(1), - scale0.stride(0), - scale0.stride(1), - shift1.stride(0), - shift1.stride(1), - scale1.stride(0), - scale1.stride(1), - index.stride(0), - index.stride(1), - BLOCK_L=block_l, - BLOCK_C=block_c, - num_warps=4, - num_stages=2, - ) - return output - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2), - triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4), - triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4), - triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8), - ], - key=["head_size", "interleaved"], -) -@triton.jit -def _rotary_embedding_kernel( - output_ptr, - x_ptr, - cos_ptr, - sin_ptr, - num_heads, - head_size, - num_tokens, - stride_x_row, - stride_cos_row, - stride_sin_row, - interleaved: tl.constexpr, - BLOCK_HS_HALF: tl.constexpr, -): - row_idx = tl.program_id(0) - token_idx = (row_idx // num_heads) % num_tokens - - x_row_ptr = x_ptr + row_idx * stride_x_row - cos_row_ptr = cos_ptr + token_idx * stride_cos_row - sin_row_ptr = sin_ptr + token_idx * stride_sin_row - output_row_ptr = output_ptr + row_idx * stride_x_row - - # half size for x1 and x2 - head_size_half = head_size // 2 - - for block_start in range(0, head_size_half, BLOCK_HS_HALF): - offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF) - mask = offsets_half < head_size_half - - cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0) - sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0) - - offsets_x1 = 2 * offsets_half - offsets_x2 = 2 * offsets_half + 1 - - x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0) - x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0) - - x1_fp32 = x1_vals.to(tl.float32) - x2_fp32 = x2_vals.to(tl.float32) - cos_fp32 = cos_vals.to(tl.float32) - sin_fp32 = sin_vals.to(tl.float32) - o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * 
cos_fp32) - o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32) - - tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask) - tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask) - - -def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - output = torch.empty_like(x) - - if x.dim() > 3: - bsz, num_tokens, num_heads, head_size = x.shape - else: - num_tokens, num_heads, head_size = x.shape - bsz = 1 - - assert head_size % 2 == 0, "head_size must be divisible by 2" - - x_reshaped = x.view(-1, head_size) - output_reshaped = output.view(-1, head_size) - - # num_tokens per head, 1 token per block - grid = (bsz * num_tokens * num_heads,) - - if interleaved and cos.shape[-1] == head_size: - cos = cos[..., ::2].contiguous() - sin = sin[..., ::2].contiguous() - else: - cos = cos.contiguous() - sin = sin.contiguous() - - _rotary_embedding_kernel[grid]( - output_reshaped, - x_reshaped, - cos, - sin, - num_heads, - head_size, - num_tokens, - x_reshaped.stride(0), - cos.stride(0), - sin.stride(0), - interleaved, - ) - - return output - - -# RMSNorm-fp32 -def maybe_contiguous_lastdim(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def maybe_contiguous(x): - return x.contiguous() if x is not None else None - - -def triton_autotune_configs(): - if not torch.cuda.is_available(): - return [] - # Return configs with a valid warp count for the current device - configs = [] - # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 - max_threads_per_block = 1024 - # Default to warp size 32 if not defined by device - warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) - # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit - return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block] - # return [triton.Config({}, num_warps=8)] - - -# Copied from flash-attn -@triton.autotune( - configs=triton_autotune_configs(), - key=[ - "N", - "HAS_RESIDUAL", - "STORE_RESIDUAL_OUT", - "IS_RMS_NORM", - "HAS_BIAS", - "HAS_WEIGHT", - "HAS_X1", - "HAS_W1", - "HAS_B1", - ], -) -# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - DROPOUT_MASK1, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, # Dropout probability - zero_centered_weight, # If true, add 1.0 to the weight - 
IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_WEIGHT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - if HAS_WEIGHT: - w = tl.load(W + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w += 1.0 - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - if HAS_WEIGHT: - y = x_hat * w + b if HAS_BIAS else x_hat * w - else: - y = x_hat + b if HAS_BIAS else x_hat - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if zero_centered_weight: - w1 += 1.0 - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x: Tensor, - weight: Tensor, - bias: Tensor, - eps: float, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - residual_dtype: Optional[torch.dtype] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - out: Optional[Tensor] = None, - residual_out: Optional[Tensor] = None, -) -> (Tensor, Tensor, 
Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library - # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None - # so that _layer_norm_fwd_impl doesn't have to return them. - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - if residual is not None: - residual_dtype = residual.dtype - if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None): - residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype) - else: - residual_out = None - y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( - x, - weight, - bias, - eps, - out, - residual=residual, - x1=x1, - weight1=weight1, - bias1=bias1, - dropout_p=dropout_p, - rowscale=rowscale, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - residual_out=residual_out, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if residual_out is None: - residual_out = x - return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 - - -# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema -# since we're returning a tuple of tensors -def _layer_norm_fwd_impl( - x: Tensor, - weight: Optional[Tensor], - bias: Tensor, - eps: float, - out: Tensor, - residual: Optional[Tensor] = None, - x1: Optional[Tensor] = None, - weight1: Optional[Tensor] = None, - bias1: Optional[Tensor] = None, - dropout_p: float = 0.0, - rowscale: Optional[Tensor] = None, - zero_centered_weight: bool = False, - is_rms_norm: bool = False, - return_dropout_mask: bool = False, - residual_out: Optional[Tensor] = None, -) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - if weight is not None: - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - assert out.shape == x.shape - assert out.stride(-1) == 1 - if residual_out is not None: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - if weight1 is not None: - y1 = torch.empty_like(out) - assert y1.stride(-1) == 1 - else: - y1 = None - mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if dropout_p > 0.0: - seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) - if x1 is not None: - dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) - else: - dropout_mask1 = None - else: - dropout_mask, dropout_mask1 = 
None, None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( - x, - out, - weight if weight is not None else x, # unused when HAS_WEIGHT == False - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - dropout_mask1, - mean, - rstd, - x.stride(0), - out.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - x1.stride(0) if x1 is not None else 0, - y1.stride(0) if y1 is not None else 0, - M, - N, - eps, - dropout_p, - # Passing bool make torch inductor very unhappy since it then tries to compare to int_max - int(zero_centered_weight), - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - weight is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - HAS_X1=x1 is not None, - HAS_W1=weight1 is not None, - HAS_B1=bias1 is not None, - ) - return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 - - -class LayerNormFn: - @staticmethod - def forward( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) - if residual is not None: - assert residual.shape == x_shape_og - residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) - # weight can be None when elementwise_affine=False for LayerNorm - if weight is not None: - weight = weight.contiguous() - bias = maybe_contiguous(bias) - weight1 = maybe_contiguous(weight1) - bias1 = maybe_contiguous(bias1) - if rowscale is not None: - rowscale = rowscale.reshape(-1).contiguous() - residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) - if out is not None: - out = out.reshape(-1, out.shape[-1]) - if residual_out is not None: - residual_out = residual_out.reshape(-1, residual_out.shape[-1]) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - out_dtype=out_dtype, - residual_dtype=residual_dtype, - zero_centered_weight=zero_centered_weight, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - out=out, - residual_out=residual_out, - ) - y = y.reshape(x_shape_og) - return y - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - is_rms_norm=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None, -): - return LayerNormFn.forward( - x, - weight, - bias, - residual, - x1, - 
weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - zero_centered_weight, - is_rms_norm, - return_dropout_mask, - out_dtype, - out, - residual_out, - ) - - -@triton.jit -def _norm_infer_kernel( - X, - Y, - W, - B, - stride_x_row, - stride_y_row, - M, - N, - eps, - IS_RMS_NORM: tl.constexpr, - HAS_WEIGHT: tl.constexpr, - HAS_BIAS: tl.constexpr, - BLOCK_N: tl.constexpr, -): - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_WEIGHT: - W += 0 - if HAS_BIAS: - B += 0 - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - if HAS_WEIGHT: - w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32) - y = x_hat * w - else: - y = x_hat - if HAS_BIAS: - b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32) - y += b - tl.store(Y + cols, y, mask=cols < N) - - -def norm_infer( - x: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - eps: float, - is_rms_norm: bool = False, - out: Optional[Tensor] = None, -): - M, N = x.shape - x = x.contiguous() - if weight is not None: - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.shape == (N,) - assert bias.stride(-1) == 1 - if out is None: - out = torch.empty_like(x) - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - num_warps = min(max(BLOCK_N // 256, 1), 8) - _norm_infer_kernel[(M,)]( - x, - out, - weight if weight is not None else x, # dummy when HAS_WEIGHT=False - bias if bias is not None else x, # dummy when HAS_BIAS=False - x.stride(0), - out.stride(0), - M, - N, - eps, - IS_RMS_NORM=is_rms_norm, - HAS_WEIGHT=weight is not None, - HAS_BIAS=bias is not None, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - ) - return out - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - zero_centered_weight=False, - return_dropout_mask=False, - out_dtype=None, - out=None, - residual_out=None, -): - return LayerNormFn.forward( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - zero_centered_weight, - True, - return_dropout_mask, - out_dtype, - out, - residual_out, - ) diff --git a/lightx2v/models/networks/z_image/model.py b/lightx2v/models/networks/z_image/model.py old mode 100644 new mode 100755 index ec4ea3076..dd59d930d --- a/lightx2v/models/networks/z_image/model.py +++ b/lightx2v/models/networks/z_image/model.py @@ -1,12 +1,10 @@ import gc import glob -import json import os import torch import torch.distributed as dist from safetensors import safe_open -from torch.nn import functional as F from lightx2v.utils.envs import * from lightx2v.utils.utils import * @@ -25,24 +23,22 @@ class ZImageTransformerModel: transformer_weight_class = ZImageTransformerWeights post_weight_class = ZImagePostWeights - def __init__(self, config): + def __init__(self, config, lora_path=None, lora_strength=1.0): self.config = config self.model_path = 
os.path.join(config["model_path"], "transformer") + self.lora_path = lora_path + self.lora_strength = lora_strength self.cpu_offload = config.get("cpu_offload", False) self.offload_granularity = self.config.get("offload_granularity", "block") self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE) - - with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f: - transformer_config = json.load(f) - self.in_channels = transformer_config["in_channels"] - self.attention_kwargs = {} - + self.remove_keys = [] + self.lazy_load = self.config.get("lazy_load", False) + if self.lazy_load: + self.remove_keys.extend(["layers."]) self.dit_quantized = self.config.get("dit_quantized", False) if self.config["seq_parallel"]: - self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") - else: - self.seq_p_group = None + raise NotImplementedError self._init_infer_class() self._init_weights() @@ -75,10 +71,7 @@ def _init_weights(self, weight_dict=None): weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) else: # Load quantized weights - if not self.config.get("lazy_load", False): - weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) - else: - weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer) + weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False): weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) @@ -89,7 +82,10 @@ def _init_weights(self, weight_dict=None): # Initialize weight containers self.pre_weight = self.pre_weight_class(self.config) - self.transformer_weights = self.transformer_weight_class(self.config) + if self.lazy_load: + self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path, self.lora_path) + else: + self.transformer_weights = self.transformer_weight_class(self.config) self.post_weight = self.post_weight_class(self.config) if not self._should_init_empty_model(): self._apply_weights() @@ -101,6 +97,6 @@ def _apply_weights(self, weight_dict=None): gc.collect() # Load weights into containers self.pre_weight.load(self.original_weight_dict) self.transformer_weights.load(self.original_weight_dict) self.post_weight.load(self.original_weight_dict) @@ -124,7 +121,7 @@ def _should_load_weights(self): return False def _should_init_empty_model(self): - if self.config.get("lora_configs") and self.config["lora_configs"]: + if self.config.get("lora_configs") and self.config["lora_configs"] and not self.config.get("lora_dynamic_apply", False): return True return False @@ -150,8 +147,18 @@ def _load_ckpt(self, unified_dtype, sensitive_layer): safetensors_path = self.model_path if os.path.isdir(safetensors_path): - safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + if self.lazy_load: + self.lazy_load_path = safetensors_path + non_block_file = os.path.join(safetensors_path, "non_block.safetensors") + if os.path.exists(non_block_file): + safetensors_files = [non_block_file] + else: + raise ValueError(f"Non-block file not found in {safetensors_path}. 
Please check the model path.") + else: + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) else: + if self.lazy_load: + self.lazy_load_path = safetensors_path safetensors_files = [safetensors_path] weight_dict = {} @@ -171,8 +178,18 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer): safetensors_path = self.model_path if os.path.isdir(safetensors_path): - safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) + if self.lazy_load: + self.lazy_load_path = safetensors_path + non_block_file = os.path.join(safetensors_path, "non_block.safetensors") + if os.path.exists(non_block_file): + safetensors_files = [non_block_file] + else: + raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.") + else: + safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) else: + if self.lazy_load: + self.lazy_load_path = safetensors_path safetensors_files = [safetensors_path] safetensors_path = os.path.dirname(safetensors_path) @@ -204,28 +221,6 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer): return weight_dict - def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite - lazy_load_model_path = self.dit_quantized_ckpt - logger.info(f"Loading splited quant model from {lazy_load_model_path}") - pre_post_weight_dict = {} - - safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors") - with safe_open(safetensor_path, framework="pt", device="cpu") as f: - for k in f.keys(): - if f.get_tensor(k).dtype in [ - torch.float16, - torch.bfloat16, - torch.float, - ]: - if unified_dtype or all(s not in k for s in sensitive_layer): - pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device) - else: - pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device) - else: - pre_post_weight_dict[k] = f.get_tensor(k).to(self.device) - - return pre_post_weight_dict - def _load_weights_from_rank0(self, weight_dict, is_weight_loader): logger.info("Loading distributed weights") global_src_rank = 0 @@ -310,6 +305,7 @@ def infer(self, inputs): elif self.offload_granularity != "model": self.pre_weight.to_cuda() self.post_weight.to_cuda() + self.transformer_weights.non_block_weights_to_cuda() latents = self.scheduler.latents latents_input = latents @@ -356,6 +352,14 @@ def infer(self, inputs): self.scheduler.noise_pred = noise_pred + if self.cpu_offload: + if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]: + self.to_cpu() + elif self.offload_granularity != "model": + self.pre_weight.to_cpu() + self.post_weight.to_cpu() + self.transformer_weights.non_block_weights_to_cpu() + @torch.no_grad() def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True): self.scheduler.infer_condition = infer_condition @@ -385,19 +389,8 @@ def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True) @torch.no_grad() def _seq_parallel_pre_process(self, pre_infer_out): - world_size = dist.get_world_size(self.seq_p_group) - cur_rank = dist.get_rank(self.seq_p_group) - seqlen = pre_infer_out.hidden_states.shape[1] - padding_size = (world_size - (seqlen % world_size)) % world_size - if padding_size > 0: - pre_infer_out.hidden_states = F.pad(pre_infer_out.hidden_states, (0, 0, 0, padding_size)) - pre_infer_out.hidden_states = torch.chunk(pre_infer_out.hidden_states, world_size, dim=1)[cur_rank] - return 
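# Illustration (an assumption inferred from the lazy_load checks above and the per-block
# buffers registered later in this diff, not code from the PR): the lazy-load path expects a
# pre-split checkpoint directory such as
#     <model_path>/transformer/non_block.safetensors   # pre/post/refiner weights, loaded eagerly
#     <model_path>/transformer/block_<i>.safetensors   # one file per transformer block, fetched on demand
# A minimal sketch of reading the eagerly loaded part with the same safetensors API used in this file:
from safetensors import safe_open

def load_non_block_weights(non_block_path, device="cpu"):
    weights = {}
    with safe_open(non_block_path, framework="pt", device=device) as f:
        for key in f.keys():
            # tensors stay on `device`; dtype casting and offload are handled elsewhere in the model code
            weights[key] = f.get_tensor(key)
    return weights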
pre_infer_out + raise NotImplementedError @torch.no_grad() def _seq_parallel_post_process(self, noise_pred): - world_size = dist.get_world_size(self.seq_p_group) - gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)] - dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group) - noise_pred = torch.cat(gathered_noise_pred, dim=1) - return noise_pred + raise NotImplementedError diff --git a/lightx2v/models/networks/z_image/weights/transformer_weights.py b/lightx2v/models/networks/z_image/weights/transformer_weights.py index 6fb2920bf..a23275d86 100755 --- a/lightx2v/models/networks/z_image/weights/transformer_weights.py +++ b/lightx2v/models/networks/z_image/weights/transformer_weights.py @@ -1,19 +1,13 @@ -import os - -import torch.nn.functional as F -from safetensors import safe_open - from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList -from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER +from lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, +) class ZImageTransformerWeights(WeightModule): - """ - Z-Image single stream transformer weights. - Based on ZImageTransformer2DModel structure. - """ - - def __init__(self, config): + def __init__(self, config, lazy_load_path=None, lora_path=None): super().__init__() self.blocks_num = config["n_layers"] self.task = config["task"] @@ -21,59 +15,97 @@ def __init__(self, config): self.mm_type = config.get("dit_quant_scheme", "Default") if self.mm_type != "Default": assert config.get("dit_quantized") is True + self.lazy_load = self.config.get("lazy_load", False) + self.n_refiner_layers = config.get("n_refiner_layers", 0) + self.register_offload_buffers(config, lazy_load_path, lora_path) + self.add_module( + "blocks", + WeightModuleList(ZImageTransformerBlock(i, self.task, self.mm_type, self.config, False, False, "layers") for i in range(self.blocks_num)), + ) - # Main transformer blocks - blocks = WeightModuleList(ZImageTransformerBlock(i, self.task, self.mm_type, self.config, False, False, "layers") for i in range(self.blocks_num)) - self.add_module("blocks", blocks) - - # Noise refiner (if exists) - n_refiner_layers = config.get("n_refiner_layers", 0) - if n_refiner_layers > 0: - noise_refiner = WeightModuleList( + self.add_module( + "noise_refiner", + WeightModuleList( ZImageTransformerBlock( - i, # layer_id should be the index i for noise_refiner (not 1000 + i) + i, self.task, self.mm_type, self.config, False, False, "noise_refiner", - modulation=True, + has_modulation=True, ) - for i in range(n_refiner_layers) - ) - self.add_module("noise_refiner", noise_refiner) - else: - self.noise_refiner = None + for i in range(self.n_refiner_layers) + ), + ) - # Context refiner (if exists) - if n_refiner_layers > 0: - context_refiner = WeightModuleList( + self.add_module( + "context_refiner", + WeightModuleList( ZImageTransformerBlock( - i, # layer_id should be the index i for context_refiner + i, self.task, self.mm_type, self.config, False, False, "context_refiner", - modulation=False, + has_modulation=False, ) - for i in range(n_refiner_layers) - ) - self.add_module("context_refiner", context_refiner) - else: - self.context_refiner = None - - self.register_offload_buffers(config) + for i in range(self.n_refiner_layers) + ), + ) - def register_offload_buffers(self, config): + def register_offload_buffers(self, config, lazy_load_path, lora_path): if config["cpu_offload"]: if 
config["offload_granularity"] == "block": self.offload_blocks_num = 2 - self.offload_block_cuda_buffers = WeightModuleList([ZImageTransformerBlock(i, self.task, self.mm_type, self.config, True, False, "layers") for i in range(self.offload_blocks_num)]) + self.offload_block_cuda_buffers = WeightModuleList( + [ + ZImageTransformerBlock( + i, + self.task, + self.mm_type, + self.config, + True, + False, + "layers", + ) + for i in range(self.offload_blocks_num) + ] + ) self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers) self.offload_phase_cuda_buffers = None + if self.lazy_load: + self.offload_blocks_num = 2 + self.offload_block_cpu_buffers = WeightModuleList( + [ + ZImageTransformerBlock( + i, + self.task, + self.mm_type, + self.config, + False, + True, + "layers", + lazy_load=self.lazy_load, + lazy_load_path=lazy_load_path, + lora_path=lora_path, + ) + for i in range(self.offload_blocks_num) + ] + ) + self.add_module("offload_block_cpu_buffers", self.offload_block_cpu_buffers) + self.offload_phase_cpu_buffers = None + + def non_block_weights_to_cuda(self): + self.noise_refiner.to_cuda() + self.context_refiner.to_cuda() + + def non_block_weights_to_cpu(self): + self.noise_refiner.to_cpu() + self.context_refiner.to_cpu() class ZImageTransformerBlock(WeightModule): @@ -84,141 +116,173 @@ class ZImageTransformerBlock(WeightModule): def __init__( self, - layer_id, + block_idx, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="layers", - modulation=True, + has_modulation=True, + lazy_load=False, + lazy_load_path=None, + lora_path=None, ): super().__init__() - self.layer_id = layer_id + self.block_idx = block_idx self.mm_type = mm_type self.task = task self.config = config - self.modulation = modulation - self.quant_method = config.get("quant_method", None) - self.sparge = config.get("sparge", False) + self.has_modulation = has_modulation self.ln_type = config.get("ln_type", "Triton") self.rms_norm_type = config.get("rms_norm_type", "sgl-kernel") - self.lazy_load = self.config.get("lazy_load", False) + self.lazy_load = lazy_load if self.lazy_load: - lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{layer_id}.safetensors") - self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu") + self.lazy_load_file = lazy_load_path else: self.lazy_load_file = None - # Attention normalization layers - self.add_module( - "attention_norm1", - RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.attention_norm1.weight", - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, - ), - ) - self.add_module( - "attention_norm2", - RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.attention_norm2.weight", - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, - ), + self.compute_phases = WeightModuleList( + [ + ( + ZImageAdaLNModulation( + block_idx=block_idx, + block_prefix=block_prefix, + task=task, + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ) + if self.has_modulation + else WeightModule() + ), + ZImageAttention( + block_idx=block_idx, + block_prefix=block_prefix, + task=task, + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ), + ZImageFFN( + 
block_idx=block_idx, + block_prefix=block_prefix, + task=task, + mm_type=mm_type, + config=config, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + ), + ] ) + self.add_module("compute_phases", self.compute_phases) - # Single stream attention - self.attention = ZImageSingleStreamAttention( - layer_id=layer_id, - block_prefix=block_prefix, - task=task, - mm_type=mm_type, - config=config, - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, - lazy_load=self.lazy_load, - lazy_load_file=self.lazy_load_file, - ) - self.add_module("attention", self.attention) - # FFN normalization layers - # Note: In Z-Image, ffn_norm1 and ffn_norm2 are directly under layers.{layer_id}, not under feed_forward - self.add_module( - "ffn_norm1", - RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.ffn_norm1.weight", - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, - ), - ) +class ZImageAdaLNModulation(WeightModule): + def __init__( + self, + block_idx, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + ): + super().__init__() + self.block_idx = block_idx + self.mm_type = mm_type + self.task = task + self.config = config + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file self.add_module( - "ffn_norm2", - RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.ffn_norm2.weight", - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, + "adaLN_modulation", + MM_WEIGHT_REGISTER["Default"]( + f"{block_prefix}.{block_idx}.adaLN_modulation.0.weight", + f"{block_prefix}.{block_idx}.adaLN_modulation.0.bias", + create_cuda_buffer, + create_cpu_buffer, + self.lazy_load, + self.lazy_load_file, ), ) - # Feed forward network - self.feed_forward = ZImageFeedForward( - layer_id=layer_id, - block_prefix=block_prefix, - task=task, - mm_type=mm_type, - config=config, - create_cuda_buffer=create_cuda_buffer, - create_cpu_buffer=create_cpu_buffer, - lazy_load=self.lazy_load, - lazy_load_file=self.lazy_load_file, - ) - self.add_module("feed_forward", self.feed_forward) + def to_cpu(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cpu"): + module.to_cpu(non_blocking=non_blocking) - # AdaLN modulation (if modulation is enabled) - if self.modulation: - dim = config["dim"] - adaln_embed_dim = min(dim, 256) # ADALN_EMBED_DIM = 256 - self.add_module( - "adaLN_modulation", - MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.adaLN_modulation.0.weight", - f"{block_prefix}.{layer_id}.adaLN_modulation.0.bias", - create_cuda_buffer, - create_cpu_buffer, - self.lazy_load, - self.lazy_load_file, - ), - ) + def to_cuda(self, non_blocking=True): + for module in self._modules.values(): + if module is not None and hasattr(module, "to_cuda"): + print(module) + module.to_cuda(non_blocking=non_blocking) -class ZImageSingleStreamAttention(WeightModule): +class ZImageAttention(WeightModule): """ Single stream attention for Z-Image. Based on ZSingleStreamAttnProcessor. 
""" - def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file): + def __init__( + self, + block_idx, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + ): super().__init__() - self.layer_id = layer_id + self.block_idx = block_idx self.mm_type = mm_type self.task = task self.config = config - self.quant_method = config.get("quant_method", None) - self.sparge = config.get("sparge", False) self.attn_type = config.get("attn_type", "flash_attn3") - self.heads = config["n_heads"] - self.rms_norm_type = config.get("rms_norm_type", "sgl-kernel") + self.rms_norm_type = config.get("rms_norm_type", "one-pass") self.lazy_load = lazy_load self.lazy_load_file = lazy_load_file + # Attention normalization layers + self.add_module( + "attention_norm1", + RMS_WEIGHT_REGISTER[self.rms_norm_type]( + f"{block_prefix}.{block_idx}.attention_norm1.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + ), + ) + self.add_module( + "attention_norm2", + RMS_WEIGHT_REGISTER[self.rms_norm_type]( + f"{block_prefix}.{block_idx}.attention_norm2.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + ), + ) + # QK normalization (applied in processor) self.add_module( "norm_q", RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.attention.norm_q.weight", + f"{block_prefix}.{block_idx}.attention.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, ), @@ -226,7 +290,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "norm_k", RMS_WEIGHT_REGISTER[self.rms_norm_type]( - f"{block_prefix}.{layer_id}.attention.norm_k.weight", + f"{block_prefix}.{block_idx}.attention.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, ), @@ -237,7 +301,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "to_q", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.attention.to_q.weight", + f"{block_prefix}.{block_idx}.attention.to_q.weight", None, # No bias in Z-Image create_cuda_buffer, create_cpu_buffer, @@ -248,7 +312,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "to_k", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.attention.to_k.weight", + f"{block_prefix}.{block_idx}.attention.to_k.weight", None, # No bias in Z-Image create_cuda_buffer, create_cpu_buffer, @@ -259,7 +323,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "to_v", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.attention.to_v.weight", + f"{block_prefix}.{block_idx}.attention.to_v.weight", None, # No bias in Z-Image create_cuda_buffer, create_cpu_buffer, @@ -274,7 +338,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu WeightModuleList( [ MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.attention.to_out.0.weight", + f"{block_prefix}.{block_idx}.attention.to_out.0.weight", None, # No bias in Z-Image create_cuda_buffer, create_cpu_buffer, @@ -305,31 +369,38 @@ def to_cuda(self, non_blocking=True): module.to_cuda(non_blocking=non_blocking) -class ZImageFeedForward(WeightModule): +class ZImageFFN(WeightModule): """ Feed forward network for Z-Image. 
Based on FeedForward with w1, w2, w3 and SiLU gating. """ - def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file): + def __init__( + self, + block_idx, + block_prefix, + task, + mm_type, + config, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + ): super().__init__() - self.layer_id = layer_id + self.block_idx = block_idx self.mm_type = mm_type self.task = task self.config = config - self.quant_method = config.get("quant_method", None) self.lazy_load = lazy_load self.lazy_load_file = lazy_load_file - - dim = config["dim"] - hidden_dim = int(dim / 3 * 8) # FeedForward hidden_dim = dim / 3 * 8 - + self.rms_norm_type = config.get("rms_norm_type", "one-pass") # w1, w2, w3 for SiLU gating # Note: Z-Image feed_forward layers don't have bias self.add_module( "w1", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.feed_forward.w1.weight", + f"{block_prefix}.{block_idx}.feed_forward.w1.weight", None, create_cuda_buffer, create_cpu_buffer, @@ -340,7 +411,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "w2", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.feed_forward.w2.weight", + f"{block_prefix}.{block_idx}.feed_forward.w2.weight", None, create_cuda_buffer, create_cpu_buffer, @@ -351,7 +422,7 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu self.add_module( "w3", MM_WEIGHT_REGISTER[self.mm_type]( - f"{block_prefix}.{layer_id}.feed_forward.w3.weight", + f"{block_prefix}.{block_idx}.feed_forward.w3.weight", None, create_cuda_buffer, create_cpu_buffer, @@ -360,12 +431,22 @@ def __init__(self, layer_id, block_prefix, task, mm_type, config, create_cuda_bu ), ) - def forward(self, x): - w1_out = F.linear(x, self.w1.weight.t(), None) # Z-Image FFN has no bias - w3_out = F.linear(x, self.w3.weight.t(), None) - silu_gated = F.silu(w1_out) * w3_out - output = F.linear(silu_gated, self.w2.weight.t(), None) - return output + self.add_module( + "ffn_norm1", + RMS_WEIGHT_REGISTER[self.rms_norm_type]( + f"{block_prefix}.{block_idx}.ffn_norm1.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + ), + ) + self.add_module( + "ffn_norm2", + RMS_WEIGHT_REGISTER[self.rms_norm_type]( + f"{block_prefix}.{block_idx}.ffn_norm2.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + ), + ) def to_cpu(self, non_blocking=True): for module in self._modules.values(): diff --git a/lightx2v/models/runners/z_image/z_image_runner.py b/lightx2v/models/runners/z_image/z_image_runner.py index b19b127dc..2c91a4b2b 100755 --- a/lightx2v/models/runners/z_image/z_image_runner.py +++ b/lightx2v/models/runners/z_image/z_image_runner.py @@ -64,6 +64,7 @@ def init_modules(self): logger.info("Initializing runner modules...") if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): self.load_model() + self.model.set_scheduler(self.scheduler) elif self.config.get("lazy_load", False): assert self.config.get("cpu_offload", False) self.run_dit = self._run_dit_local @@ -74,12 +75,11 @@ def init_modules(self): else: assert NotImplementedError - self.model.set_scheduler(self.scheduler) - @ProfilingContext4DebugL2("Run DiT") def _run_dit_local(self, total_steps=None): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): self.model = self.load_transformer() + 
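# For reference, the w1/w2/w3 weights registered above implement the SiLU-gated feed-forward
# that the removed ZImageFeedForward.forward computed (no biases, hidden_dim = int(dim / 3 * 8)).
# Minimal sketch; the [out_features, in_features] weight layout is an assumption -- the removed
# code transposed the stored weights before calling F.linear:
import torch.nn.functional as F

def z_image_ffn_reference(x, w1, w2, w3):
    # x: [tokens, dim]; w1, w3: [hidden_dim, dim]; w2: [dim, hidden_dim]
    gate = F.silu(F.linear(x, w1))   # gating branch
    up = F.linear(x, w3)             # linear branch
    return F.linear(gate * up, w2)   # project back to dim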
self.model.set_scheduler(self.scheduler) self.model.scheduler.prepare(self.input_info) latents, generator = self.run(total_steps) return latents, generator @@ -87,7 +87,11 @@ def _run_dit_local(self, total_steps=None): @ProfilingContext4DebugL2("Run Encoders") def _run_input_encoder_local_t2i(self): prompt = self.input_info.prompt + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.text_encoders = self.load_text_encoder() text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.text_encoders[0] torch_device_module.empty_cache() gc.collect() return { @@ -134,7 +138,11 @@ def _run_input_encoder_local_i2i(self): images_list.append(image) prompt = self.input_info.prompt + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.text_encoders = self.load_text_encoder() text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.text_encoders[0] image_encoder_output_list = [] for vae_image in text_encoder_output["image_info"]["vae_image_list"]: @@ -208,7 +216,13 @@ def run_text_encoder(self, text, image_list=None, neg_prompt=None): @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["ZImageRunner"]) def run_vae_encoder(self, image): + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.vae = self.load_vae() image_latents = self.vae.encode_vae_image(image.to(GET_DTYPE())) + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.vae + torch_device_module.empty_cache() + gc.collect() return {"image_latents": image_latents} def run(self, total_steps=None): @@ -318,10 +332,10 @@ def load_model(self): ) def run_vae_decoder(self, latents): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): - self.vae_decoder = self.load_vae() + self.vae = self.load_vae() images = self.vae.decode(latents, self.input_info) if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): - del self.vae_decoder + del self.vae torch_device_module.empty_cache() gc.collect() return images diff --git a/tools/convert/converter.py b/tools/convert/converter.py index c6b86a007..04fd05dd9 100755 --- a/tools/convert/converter.py +++ b/tools/convert/converter.py @@ -747,7 +747,7 @@ def main(): parser.add_argument( "-t", "--model_type", - choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit", "qwen_image_dit", "qwen25vl_llm"], + choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit", "qwen_image_dit", "qwen25vl_llm", "z_image_dit"], default="wan_dit", help="Model type", ) @@ -815,6 +815,7 @@ def main(): args.non_linear_dtype = eval(args.non_linear_dtype) model_type_keys_map = { + "z_image_dit": {"key_idx": 2, "target_keys": ["attention", "feed_forward"], "ignore_key": None}, "qwen_image_dit": { "key_idx": 2, "target_keys": ["attn", "img_mlp", "txt_mlp", "txt_mod", "img_mod"],
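# Illustration only (an assumption about how the converter consumes model_type_keys_map, not
# its actual implementation): an entry like the z_image_dit one added above, with key_idx=2
# and target_keys=["attention", "feed_forward"], can be used to route tensors either into
# per-block files or into the shared non_block file that the lazy-load path in this diff expects.
def route_key(name, key_idx=2, target_keys=("attention", "feed_forward"), ignore_key=None):
    parts = name.split(".")
    if ignore_key is not None and ignore_key in name:
        return "non_block"
    if len(parts) > key_idx and any(parts[key_idx].startswith(t) for t in target_keys):
        return f"block_{parts[1]}"  # e.g. "layers.12.attention.to_q.weight" -> "block_12"
    return "non_block"

assert route_key("layers.12.attention.to_q.weight") == "block_12"
assert route_key("layers.3.feed_forward.w1.weight") == "block_3"
assert route_key("final_norm.weight") == "non_block"  # hypothetical non-block key name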