diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py
index 98e42a72..b629c4a8 100755
--- a/lightx2v/models/runners/default_runner.py
+++ b/lightx2v/models/runners/default_runner.py
@@ -16,7 +16,7 @@
 from lightx2v.utils.global_paras import CALIB
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
-from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, save_to_video, wan_vae_to_comfy
 from lightx2v_platform.base.global_var import AI_DEVICE
 
 torch_device_module = getattr(torch, AI_DEVICE)
@@ -433,7 +433,7 @@ def post_prompt_enhancer(self):
         return enhanced_prompt
 
     def process_images_after_vae_decoder(self):
-        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)
+        self.gen_video_final = wan_vae_to_comfy(self.gen_video_final)
 
         if "video_frame_interpolation" in self.config:
             assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
diff --git a/lightx2v/models/runners/longcat_image/longcat_image_runner.py b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
index 55462159..bc275175 100755
--- a/lightx2v/models/runners/longcat_image/longcat_image_runner.py
+++ b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
@@ -399,12 +399,14 @@ def run_pipeline(self, input_info):
         if not input_info.return_result_tensor:
             image = images[0]
-            image.save(f"{input_info.save_result_path}")
+            image.save(input_info.save_result_path)
             logger.info(f"Image saved: {input_info.save_result_path}")
 
         del latents, generator
         torch_device_module.empty_cache()
         gc.collect()
 
-        # Return (images, audio) - audio is None for default runner
-        return images, None
+        if input_info.return_result_tensor:
+            return {"images": images}
+        elif input_info.save_result_path is not None:
+            return {"images": None}
diff --git a/lightx2v/models/runners/qwen_image/qwen_image_runner.py b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
index 663ab1c0..e722fc5b 100755
--- a/lightx2v/models/runners/qwen_image/qwen_image_runner.py
+++ b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
@@ -380,19 +380,22 @@ def run_pipeline(self, input_info):
         self.end_run()
 
         if not input_info.return_result_tensor:
+            image_prefix = input_info.save_result_path.rsplit(".", 1)[0]
+            image_suffix = input_info.save_result_path.rsplit(".", 1)[1] if len(input_info.save_result_path.rsplit(".", 1)) > 1 else "png"
             if isinstance(images[0], list) and len(images[0]) > 1:
-                image_prefix = f"{input_info.save_result_path}".split(".")[0]
                 for idx, image in enumerate(images[0]):
-                    image.save(f"{image_prefix}_{idx}.png")
-                    logger.info(f"Image saved: {image_prefix}_{idx}.png")
+                    image.save(f"{image_prefix}_{idx:05d}.{image_suffix}")
+                    logger.info(f"Image saved: {image_prefix}_{idx:05d}.{image_suffix}")
             else:
                 image = images[0]
-                image.save(f"{input_info.save_result_path}")
-                logger.info(f"Image saved: {input_info.save_result_path}")
+                image.save(f"{image_prefix}.{image_suffix}")
+                logger.info(f"Image saved: {image_prefix}.{image_suffix}")
 
         del latents, generator
         torch_device_module.empty_cache()
         gc.collect()
 
-        # Return (images, audio) - audio is None for default runner
-        return images, None
+        if input_info.return_result_tensor:
+            return {"images": images}
+        elif input_info.save_result_path is not None:
+            return {"images": None}
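Note (not part of the patch): the image runners above now return a dict instead of the old (images, audio) tuple. A minimal sketch of what a caller sees under the new contract; how runner and input_info are constructed is assumed here for illustration only:

    # result is a dict rather than the old (images, audio) tuple
    result = runner.run_pipeline(input_info)       # hypothetical runner / input_info setup
    if input_info.return_result_tensor:
        images = result["images"]                  # decoded images are returned in memory
    elif input_info.save_result_path is not None:
        assert result == {"images": None}          # images were written to save_result_path instead
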
diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py
index 86029de7..c38fe074 100755
--- a/lightx2v/models/runners/wan/wan_audio_runner.py
+++ b/lightx2v/models/runners/wan/wan_audio_runner.py
@@ -26,7 +26,7 @@
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import find_torch_model_path, fixed_shape_resize, get_optimal_patched_size_with_sp, isotropic_crop_resize, load_weights, vae_to_comfyui_image_inplace
+from lightx2v.utils.utils import find_torch_model_path, fixed_shape_resize, get_optimal_patched_size_with_sp, isotropic_crop_resize, load_weights, wan_vae_to_comfy
 from lightx2v_platform.base.global_var import AI_DEVICE
 
 warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
@@ -594,7 +594,7 @@ def end_run_segment(self, segment_idx, valid_duration=1e9):
         video_seg = self.gen_video[:, :, :useful_length].cpu()
         audio_seg = self.segment.audio_array[:, : useful_length * self._audio_processor.audio_frame_rate]
         audio_seg = audio_seg.sum(dim=0)  # Multiple audio tracks, mixed into one track
-        video_seg = vae_to_comfyui_image_inplace(video_seg)
+        video_seg = wan_vae_to_comfy(video_seg)
 
         # [Warning] Need check whether video segment interpolation works...
         if "video_frame_interpolation" in self.config and self.vfi_model is not None:
@@ -642,7 +642,7 @@ def end_run_segment_stream(self, latents, valid_duration=1e9):
         origin_seg = torch.clamp(origin_seg, -1, 1).to(torch.float)
 
         valid_T = min(valid_length - frame_idx, origin_seg.shape[2])
-        video_seg = vae_to_comfyui_image_inplace(origin_seg[:, :, :valid_T].cpu())
+        video_seg = wan_vae_to_comfy(origin_seg[:, :, :valid_T].cpu())
         audio_start = frame_idx * self._audio_processor.audio_frame_rate
         audio_end = (frame_idx + valid_T) * self._audio_processor.audio_frame_rate
         audio_seg = self.segment.audio_array[:, audio_start:audio_end].sum(dim=0)
diff --git a/lightx2v/models/runners/wan/wan_sf_runner.py b/lightx2v/models/runners/wan/wan_sf_runner.py
index 4e5e91f5..62091dc7 100755
--- a/lightx2v/models/runners/wan/wan_sf_runner.py
+++ b/lightx2v/models/runners/wan/wan_sf_runner.py
@@ -13,7 +13,7 @@
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import vae_to_comfyui_image_inplace
+from lightx2v.utils.utils import wan_vae_to_comfy
 
 
 @RUNNER_REGISTER("wan2.1_sf")
@@ -121,7 +121,7 @@ def end_run_segment(self, segment_idx=None):
         self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
         if self.is_live:
             if self.video_recorder:
-                stream_video = vae_to_comfyui_image_inplace(self.gen_video)
+                stream_video = wan_vae_to_comfy(self.gen_video)
                 self.video_recorder.pub_video(stream_video)
 
         torch.cuda.empty_cache()
diff --git a/lightx2v/models/runners/z_image/z_image_runner.py b/lightx2v/models/runners/z_image/z_image_runner.py
index 8959486d..b19b127d 100755
--- a/lightx2v/models/runners/z_image/z_image_runner.py
+++ b/lightx2v/models/runners/z_image/z_image_runner.py
@@ -344,12 +344,14 @@ def run_pipeline(self, input_info):
         if not input_info.return_result_tensor:
             image = images[0]
-            image.save(f"{input_info.save_result_path}")
+            image.save(input_info.save_result_path)
             logger.info(f"Image saved: {input_info.save_result_path}")
 
         del latents, generator
         torch_device_module.empty_cache()
         gc.collect()
 
-        # Return (images, audio) - audio is None for default runner
-        return images, None
+        if input_info.return_result_tensor:
+            return {"images": images}
+        elif input_info.save_result_path is not None:
+            return {"images": None}
diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py
index 2ed9d0bd..bc387410 100755
--- a/lightx2v/utils/utils.py
+++ b/lightx2v/utils/utils.py
@@ -106,68 +106,42 @@ def cache_video(
     return None
 
 
-def vae_to_comfyui_image(vae_output: torch.Tensor) -> torch.Tensor:
+def wan_vae_to_comfy(vae_output: torch.Tensor) -> torch.Tensor:
     """
-    Convert VAE decoder output to ComfyUI Image format
+    Convert VAE decoder output to ComfyUI Image format (inplace operation)
 
     Args:
         vae_output: VAE decoder output tensor, typically in range [-1, 1]
                     Shape: [B, C, T, H, W] or [B, C, H, W]
+                    WARNING: This tensor will be modified in-place!
 
     Returns:
        ComfyUI Image tensor in range [0, 1]
        Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video
+       Note: the input tensor's values are rescaled in-place; the returned tensor may be a new CPU tensor or view
     """
-    # Handle video tensor (5D) vs image tensor (4D)
-    if vae_output.dim() == 5:
-        # Video tensor: [B, C, T, H, W]
-        B, C, T, H, W = vae_output.shape
-        # Reshape to [B*T, C, H, W] for processing
-        vae_output = vae_output.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
-
-    # Normalize from [-1, 1] to [0, 1]
-    images = (vae_output + 1) / 2
-
-    # Clamp values to [0, 1]
-    images = torch.clamp(images, 0, 1)
-    # Convert from [B, C, H, W] to [B, H, W, C]
-    images = images.permute(0, 2, 3, 1).cpu()
+    vae_output.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
 
-    return images
+    if vae_output.ndim == 5:
+        # Video: [B, C, T, H, W] -> [B, T, H, W, C]
+        vae_output = vae_output.permute(0, 2, 3, 4, 1)
+        # -> [B*T, H, W, C]
+        return vae_output.cpu().flatten(0, 1)
+    else:
+        # Image: [B, C, H, W] -> [B, H, W, C]
+        return vae_output.permute(0, 2, 3, 1).cpu()
 
 
-def vae_to_comfyui_image_inplace(vae_output: torch.Tensor) -> torch.Tensor:
+def diffusers_vae_to_comfy(vae_output: torch.Tensor) -> torch.Tensor:
     """
-    Convert VAE decoder output to ComfyUI Image format (inplace operation)
+    Convert Diffusers VAE decoder output to ComfyUI Image format.
+    The Diffusers image processor already returns tensors in range [0, 1] when do_denormalize is True, so only a layout permute is needed here.
 
-    Args:
-        vae_output: VAE decoder output tensor, typically in range [-1, 1]
-                    Shape: [B, C, T, H, W] or [B, C, H, W]
-                    WARNING: This tensor will be modified in-place!
+    ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/image_processor.py#L744
 
-    Returns:
-        ComfyUI Image tensor in range [0, 1]
-        Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video
-        Note: The returned tensor is the same object as input (modified in-place)
-    """
-    # Handle video tensor (5D) vs image tensor (4D)
-    if vae_output.dim() == 5:
-        # Video tensor: [B, C, T, H, W]
-        B, C, T, H, W = vae_output.shape
-        # Reshape to [B*T, C, H, W] for processing (inplace view)
-        vae_output = vae_output.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W)
-
-    # Normalize from [-1, 1] to [0, 1] (inplace)
-    vae_output.add_(1).div_(2)
-
-    # Clamp values to [0, 1] (inplace)
-    vae_output.clamp_(0, 1)
-
-    # Convert from [B, C, H, W] to [B, H, W, C] and move to CPU
-    vae_output = vae_output.permute(0, 2, 3, 1).cpu()
-
-    return vae_output
+    """
+    return vae_output.permute(0, 2, 3, 1).cpu()
 
 
 def save_to_video(
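
Note (not part of the patch): a minimal sketch of the conversion performed by the new wan_vae_to_comfy helper, useful as a sanity check; the dummy tensor shape and values are assumptions for illustration only:

    import torch

    from lightx2v.utils.utils import wan_vae_to_comfy

    fake_vae_out = torch.rand(1, 3, 4, 8, 8) * 2 - 1     # [B, C, T, H, W], roughly in [-1, 1]
    frames = wan_vae_to_comfy(fake_vae_out)               # fake_vae_out is rescaled in-place

    assert frames.shape == (4, 8, 8, 3)                   # [B*T, H, W, C], ComfyUI layout
    assert 0.0 <= frames.min() and frames.max() <= 1.0    # ComfyUI image range

diffusers_vae_to_comfy, by contrast, only permutes [B, C, H, W] -> [B, H, W, C], since the Diffusers image processor already outputs values in [0, 1].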