Loading Transformer for Second Stage in Two-Stage Pipeline Raises Error #68

@fahadh4ilyas

Description

I'm trying to run the model on my local machine with an A6000 GPU. Here is my script:

from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_pipelines import TI2VidTwoStagesPipeline
from ltx_pipelines.utils.args import default_2_stage_arg_parser
from ltx_pipelines.utils.constants import AUDIO_SAMPLE_RATE
from ltx_pipelines.utils.media_io import encode_video

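# Build the default two-stage argument parser and supply model paths plus fp8 mode.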
parser = default_2_stage_arg_parser()
args = parser.parse_args(
    [
        "--checkpoint-path", "models/LTX-2/ltx-2-19b-dev-fp8.safetensors",
        "--distilled-lora", "models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors",
        "--spatial-upsampler-path", "models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors",
        "--gemma-root", "models/gemma-3-12B",
        "--prompt", "A beautiful sunset over the ocean",
        "--output-path", "outputs/test_output.mp4",
        "--enable-fp8",
    ]
)

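# Construct the two-stage pipeline; --enable-fp8 selects the fp8 transformer path.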
pipeline = TI2VidTwoStagesPipeline(
    checkpoint_path=args.checkpoint_path,
    distilled_lora=args.distilled_lora,
    spatial_upsampler_path=args.spatial_upsampler_path,
    gemma_root=args.gemma_root,
    loras=args.lora,
    fp8transformer=args.enable_fp8,
)

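# The chunk count is computed up front because encode_video needs it below.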
tiling_config = TilingConfig.default()
video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
video, audio = pipeline(
    prompt=args.prompt,
    negative_prompt=args.negative_prompt,
    seed=args.seed,
    height=args.height,
    width=args.width,
    num_frames=args.num_frames,
    frame_rate=args.frame_rate,
    num_inference_steps=args.num_inference_steps,
    cfg_guidance_scale=args.cfg_guidance_scale,
    images=args.images,
    tiling_config=tiling_config,
)

encode_video(
    video=video,
    fps=args.frame_rate,
    audio=audio,
    audio_sample_rate=AUDIO_SAMPLE_RATE,
    output_path=args.output_path,
    video_chunks_number=video_chunks_number,
)

But I got an error when loading the stage 2 transformer. The weird thing is that loading the stage 1 transformer works fine, and that stage completed successfully. The only difference between the stage 1 and stage 2 transformers is the LoRA models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors. Is this error intentional? When I tried ComfyUI, the process ran smoothly.

Here is the error:

---------------------------------------------------------------------------
CompilationError                          Traceback (most recent call last)
Cell In[3], line 3
      1 tiling_config = TilingConfig.default()
      2 video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
----> 3 video, audio = pipeline(
      4     prompt=args.prompt,
      5     negative_prompt=args.negative_prompt,
      6     seed=args.seed,
      7     height=args.height,
      8     width=args.width,
      9     num_frames=args.num_frames,
     10     frame_rate=args.frame_rate,
     11     num_inference_steps=args.num_inference_steps,
     12     cfg_guidance_scale=args.cfg_guidance_scale,
     13     images=args.images,
     14     tiling_config=tiling_config,
     15 )
     17 encode_video(
     18     video=video,
     19     fps=args.frame_rate,
   (...)     23     video_chunks_number=video_chunks_number,
     24 )

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    117 @functools.wraps(func)
    118 def decorate_context(*args, **kwargs):
    119     with ctx_factory():
--> 120         return func(*args, **kwargs)

File ~/research-ltx-2/LTX-2/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py:181, in TI2VidTwoStagesPipeline.__call__(self, prompt, negative_prompt, seed, height, width, num_frames, frame_rate, num_inference_steps, cfg_guidance_scale, images, tiling_config, enhance_prompt)
    178 torch.cuda.synchronize()
    179 cleanup_memory()
--> 181 transformer = self.stage_2_model_ledger.transformer()
    182 distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
    184 def second_stage_denoising_loop(
    185     sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
    186 ) -> tuple[LatentState, LatentState]:

File ~/research-ltx-2/LTX-2/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py:190, in ModelLedger.transformer(self)
    184 if self.fp8transformer:
    185     fp8_builder = replace(
    186         self.transformer_builder,
    187         module_ops=(UPCAST_DURING_INFERENCE,),
    188         model_sd_ops=LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
    189     )
--> 190     return X0Model(fp8_builder.build(device=self._target_device())).to(self.device).eval()
    191 else:
    192     return (
    193         X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype))
    194         .to(self.device)
    195         .eval()
    196     )

File ~/research-ltx-2/LTX-2/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py:94, in SingleGPUModelBuilder.build(self, device, dtype)
     87 lora_state_dicts = [
     88     self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
     89 ]
     90 lora_sd_and_strengths = [
     91     LoraStateDictWithStrength(sd, strength)
     92     for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
     93 ]
---> 94 final_sd = apply_loras(
     95     model_sd=model_state_dict,
     96     lora_sd_and_strengths=lora_sd_and_strengths,
     97     dtype=dtype,
     98     destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
     99 )
    100 meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
    101 return self._return_model(meta_model, device)

File ~/research-ltx-2/LTX-2/packages/ltx-core/src/ltx_core/loader/fuse_loras.py:88, in apply_loras(model_sd, lora_sd_and_strengths, dtype, destination_sd)
     86 elif weight.dtype == torch.float8_e4m3fn:
     87     if str(device).startswith("cuda"):
---> 88         deltas = calculate_weight_float8_(deltas, weight)
     89     else:
     90         deltas.add_(weight.to(dtype=deltas.dtype, device=device))

File ~/research-ltx-2/LTX-2/packages/ltx-core/src/ltx_core/loader/fuse_loras.py:39, in calculate_weight_float8_(target_weights, original_weights)
     38 def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
---> 39     result = fused_add_round_launch(target_weights, original_weights, seed=0).to(target_weights.dtype)
     40     target_weights.copy_(result, non_blocking=True)
     41     return target_weights

File ~/research-ltx-2/LTX-2/packages/ltx-core/src/ltx_core/loader/fuse_loras.py:26, in fused_add_round_launch(target_weight, original_weight, seed)
     23 grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
     25 # Launch kernel
---> 26 fused_add_round_kernel[grid](
     27     original_weight,
     28     target_weight,
     29     seed,
     30     n_elements,
     31     exponent_bias,
     32     mantissa_bits,
     33     BLOCK_SIZE,
     34 )
     35 return target_weight

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/triton/runtime/jit.py:419, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    413 def __getitem__(self, grid) -> T:
    414     """
    415     A JIT function is launched with: fn[grid](*args, **kwargs).
    416     Hence JITFunction.__getitem__ returns a callable proxy that
    417     memorizes the grid.
    418     """
--> 419     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/triton/runtime/jit.py:733, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    729 if kernel is None:
    730     options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
    731                                                             options)
--> 733     kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
    734     if kernel is None:
    735         return None

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/triton/runtime/jit.py:861, in JITFunction._do_compile(self, key, signature, device, constexprs, options, attrs, warmup)
    859     kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
    860 else:
--> 861     kernel = self.compile(src, target=target, options=options.__dict__)
    862     kernel_cache[key] = kernel
    863     self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
    864                     warmup)

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py:300, in compile(src, target, options, _env_vars)
    298 module_map = backend.get_module_map()
    299 try:
--> 300     module = src.make_ir(target, options, codegen_fns, module_map, context)
    301 except Exception as e:
    302     filter_traceback(e)

File ~/research-ltx-2/LTX-2/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py:80, in ASTSource.make_ir(self, target, options, codegen_fns, module_map, context)
     78 def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
     79     from .code_generator import ast_to_ttir
---> 80     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
     81                        module_map=module_map)

CompilationError: at 1:0:
def fused_add_round_kernel(
^
ValueError("type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')")

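For context: Triton's fp8e4nv type corresponds to torch.float8_e4m3fn, and Triton only compiles fp8e4nv kernels for GPUs with compute capability 8.9 or higher (Ada/Hopper). The RTX A6000 is Ampere (compute capability 8.6), which matches the supported-dtypes list in the error. Stage 1 presumably loads fine because no LoRA is fused into it, while stage 2 fuses the distilled LoRA into the fp8 weights through the Triton kernel. Below is a minimal sketch, assuming nothing beyond what the traceback shows, of a capability check plus a fallback that mirrors the non-CUDA branch of fuse_loras.py; the function name is hypothetical, not part of the LTX-2 API:

import torch

# Hypothetical diagnostic: the fp8e4nv Triton kernel needs sm_89+ (Ada/Hopper).
major, minor = torch.cuda.get_device_capability()
print(f"compute capability: {major}.{minor}")  # an RTX A6000 reports 8.6

def fuse_lora_into_fp8_fallback(deltas: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as the non-CUDA branch in fuse_loras.py: upcast the fp8
    # base weight to the deltas' dtype and accumulate in place. This skips the
    # kernel's stochastic rounding, so values round to nearest when the fused
    # weights are later cast back to torch.float8_e4m3fn.
    return deltas.add_(weight.to(dtype=deltas.dtype, device=deltas.device))

Until Ampere is routed through a path like this, a practical workaround may be to avoid the fp8 path on sm_86 hardware (for example, a non-fp8 checkpoint without --enable-fp8).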