diff --git a/tools/ptq/README.md b/tools/ptq/README.md
new file mode 100644
index 000000000000..4d7f9c963fb0
--- /dev/null
+++ b/tools/ptq/README.md
@@ -0,0 +1,31 @@
+
+### Create PTQ Artefact / Run Calibration
+Calibration only needs to be run once per model.
+```
+python -m tools.ptq.quantize \
+    --model_type flux_schnell \
+    --unet_path /flux1-schnell.safetensors \
+    --clip_path clip_l.safetensors \
+    --t5_path t5xxl_fp16.safetensors \
+    --output flux_schnell_debug.json \
+    --calib_steps 16
+```
+
+### Create Quantized Checkpoint
+Combines the calibration artefact from the previous step with the original checkpoint to produce a quantized checkpoint, driven by a YAML configuration.
+The configuration defines which layers to quantize and which dtype each layer should use, matched with glob-style patterns. See `tools/ptq/configs/` for examples.
+```yaml
+# config.yml
+disable_list: ["*img_in*", "*final_layer*", "*norm*"]  # Keep these in BF16
+per_layer_dtype: {"*": "float8_e4m3fn"}  # Everything else to FP8
+```
+
+```
+python -m tools.ptq.checkpoint_merger \
+    --artefact flux_dev_debug.json \
+    --checkpoint /flux1-dev.safetensors \
+    --config tools/ptq/configs/flux_nvfp4.yml \
+    --output /flux1-nvfp4.safetensors \
+    --debug
+```
+
diff --git a/tools/ptq/checkpoint_merger.py b/tools/ptq/checkpoint_merger.py
new file mode 100644
index 000000000000..08fad0028e35
--- /dev/null
+++ b/tools/ptq/checkpoint_merger.py
@@ -0,0 +1,240 @@
+import argparse
+import logging
+import sys
+import yaml
+import re
+from typing import Dict, Tuple
+import torch
+from safetensors.torch import save_file
+import json
+
+import os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+import comfy.utils
+from comfy.ops import QUANT_FORMAT_MIXINS
+from comfy.quant_ops import F8_E4M3_MAX, F4_E2M1_MAX
+
+class QuantizationConfig:
+    def __init__(self, config_path: str):
+        with open(config_path, 'r') as f:
+            self.config = yaml.safe_load(f)
+
+        self.disable_patterns = []
+        for pattern in self.config.get('disable_list', []):
+            regex_pattern = pattern.replace('*', '.*')
+            self.disable_patterns.append(re.compile(regex_pattern))
+
+        self.per_layer_dtype = self.config.get('per_layer_dtype', {})
+        self.dtype_patterns = []
+        for pattern, dtype in self.per_layer_dtype.items():
+            regex_pattern = pattern.replace('*', '.*')
+            self.dtype_patterns.append((re.compile(regex_pattern), dtype))
+
+        logging.info(f"Loaded config with {len(self.disable_patterns)} disable patterns")
+        logging.info(f"Per-layer dtype rules: {self.per_layer_dtype}")
+
+    def should_quantize(self, layer_name: str) -> bool:
+        for pattern in self.disable_patterns:
+            if pattern.match(layer_name):
+                logging.debug(f"Layer {layer_name} disabled by pattern {pattern.pattern}")
+                return False
+        return True
+
+    def get_dtype(self, layer_name: str) -> str:
+        for pattern, dtype in self.dtype_patterns:
+            if pattern.match(layer_name):
+                return dtype
+        return None
+
+def load_amax_artefact(artefact_path: str) -> Dict:
+    logging.info(f"Loading amax artefact from {artefact_path}")
+
+    with open(artefact_path, 'r') as f:
+        data = json.load(f)
+
+    if 'amax_values' not in data:
+        raise ValueError("Invalid artefact format: missing 'amax_values' key")
+
+    metadata = data.get('metadata', {})
+    amax_values = data['amax_values']
+
+    logging.info(f"Loaded {len(amax_values)} amax values from artefact")
+    logging.info(f"Artefact metadata: {metadata}")
+
+    return data
+
+def get_scale_fp8(amax: float, dtype: torch.dtype) -> torch.Tensor:
+    scale = amax / torch.finfo(dtype).max
+    scale_tensor = torch.tensor(scale, dtype=torch.float32)
+    return scale_tensor
+
+def 
get_scale_nvfp4(amax: float, dtype: torch.dtype) -> torch.Tensor: + scale = amax / (F8_E4M3_MAX * F4_E2M1_MAX) + scale_tensor = torch.tensor(scale, dtype=torch.float32) + return scale_tensor + +def get_scale(amax: float, dtype: torch.dtype): + if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + return get_scale_fp8(amax, dtype) + elif dtype in [torch.float4_e2m1fn_x2]: + return get_scale_nvfp4(amax, dtype) + else: + raise ValueError(f"Unsupported dtype {dtype} ") + +def apply_quantization( + checkpoint: Dict, + amax_values: Dict[str, float], + config: QuantizationConfig +) -> Tuple[Dict, Dict]: + quantized_dict = {} + layer_metadata = {} + + for key, amax in amax_values.items(): + if key.endswith(".input_quantizer"): + continue + + layer_name = ".".join(key.split(".")[:-1]) + + if not config.should_quantize(layer_name): + logging.debug(f"Layer {layer_name} disabled by config") + continue + + dtype_str = config.get_dtype(layer_name) + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + weight = checkpoint.pop(f"{layer_name}.weight").to(device) + scale_tensor = get_scale(amax, dtype) + + input_amax = amax_values.get(f"{layer_name}.input_quantizer", None) + if input_amax is not None: + input_scale = get_scale(input_amax, dtype) + quantized_dict[f"{layer_name}.input_scale"] = input_scale.clone() + + tensor_layout = QUANT_FORMAT_MIXINS[dtype_str]["layout_type"] + quantized_weight, layout_params = tensor_layout.quantize( + weight, + scale=scale_tensor, + dtype=dtype + ) + quantized_dict[f"{layer_name}.weight_scale"] = scale_tensor.clone() + quantized_dict[f"{layer_name}.weight"] = quantized_weight.clone() + + if "block_scale" in layout_params: + quantized_dict[f"{layer_name}.weight_block_scale"] = layout_params["block_scale"].clone() + + layer_metadata[layer_name] = { + "format": dtype_str, + "params": {} + } + + logging.info(f"Quantized {len(layer_metadata)} layers") + + quantized_dict = quantized_dict | checkpoint + + metadata_dict = { + "_quantization_metadata": json.dumps({ + "format_version": "1.0", + "layers": layer_metadata + }) + } + return quantized_dict, metadata_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Merge calibration artifacts with checkpoint to create quantized model", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--artefact", + required=True, + help="Path to calibration artefact JSON file (amax values)" + ) + parser.add_argument( + "--checkpoint", + required=True, + help="Path to original checkpoint to quantize" + ) + parser.add_argument( + "--config", + required=True, + help="Path to YAML quantization config file" + ) + parser.add_argument( + "--output", + required=True, + help="Output path for quantized checkpoint" + ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging" + ) + + args = parser.parse_args() + + # Configure logging + if args.debug: + logging.basicConfig( + level=logging.DEBUG, + format='[%(levelname)s] %(name)s: %(message)s' + ) + else: + logging.basicConfig( + level=logging.INFO, + format='[%(levelname)s] %(message)s' + ) + + logging.info("[1/5] Loading calibration artefact...") + try: + artefact_data = load_amax_artefact(args.artefact) + amax_values = artefact_data['amax_values'] + except Exception as e: + logging.error(f"Failed to load artefact: {e}") + sys.exit(1) + + logging.info("[2/5] Loading quantization config...") + try: + config = QuantizationConfig(args.config) + except Exception as e: + logging.error(f"Failed to load 
config: {e}") + sys.exit(1) + + logging.info("[3/5] Loading checkpoint...") + try: + checkpoint = comfy.utils.load_torch_file(args.checkpoint) + logging.info(f"Loaded checkpoint with {len(checkpoint)} keys") + except Exception as e: + logging.error(f"Failed to load checkpoint: {e}") + sys.exit(1) + + logging.info("[4/5] Applying quantization...") + try: + quantized_dict, metadata_json = apply_quantization( + checkpoint, + amax_values, + config + ) + except Exception as e: + logging.error(f"Failed to apply quantization: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + logging.info("[5/5] Exporting quantized checkpoint...") + try: + save_file(quantized_dict, args.output, metadata=metadata_json) + + except Exception as e: + logging.error(f"Failed to export checkpoint: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() + diff --git a/tools/ptq/configs/flux_fp8.yml b/tools/ptq/configs/flux_fp8.yml new file mode 100644 index 000000000000..6786cba8e54e --- /dev/null +++ b/tools/ptq/configs/flux_fp8.yml @@ -0,0 +1,23 @@ +# FLUX Quantization Config: Transformer Blocks Only +# +# Quantize only double and single transformer blocks, +# leave input/output projections in higher precision. + +disable_list: [ + # Disable input projections + "*img_in*", + "*txt_in*", + "*time_in*", + "*vector_in*", + "*guidance_in*", + + # Disable output layers + "*final_layer*", + + # Disable positional embeddings + "*pe_embedder*", +] + +per_layer_dtype: { + "*": "float8_e4m3fn", +} diff --git a/tools/ptq/configs/flux_nvfp4.yml b/tools/ptq/configs/flux_nvfp4.yml new file mode 100644 index 000000000000..3ca94ffcdbbb --- /dev/null +++ b/tools/ptq/configs/flux_nvfp4.yml @@ -0,0 +1,27 @@ +# FLUX Quantization Config: Transformer Blocks Only +# +# Quantize only double and single transformer blocks, +# leave input/output projections in higher precision. + +disable_list: [ + # Disable input projections + "*img_in*", + "*txt_in*", + "*time_in*", + "*vector_in*", + "*guidance_in*", + + # Disable output layers + "*final_layer*", + + # Disable positional embeddings + "*pe_embedder*", + + "*modulation*", + "*txt_mod*", + "*img_mod*", +] + +per_layer_dtype: { + "*": "float4_e2m1fn_x2", +} diff --git a/tools/ptq/example.yml b/tools/ptq/example.yml new file mode 100644 index 000000000000..beffeb3a3555 --- /dev/null +++ b/tools/ptq/example.yml @@ -0,0 +1,30 @@ +# Quantization Configuration for Checkpoint Merger +# +# This file defines which layers to quantize and what precision to use. +# Patterns use glob-style syntax where * matches any characters. 
+
+# Glob-style patterns of layers to DISABLE quantization
+# If a layer matches any pattern here, it will NOT be quantized
+disable_list: [
+    # Example: disable input/output projection layers
+    # "*img_in*",
+    # "*txt_in*",
+    # "*final_layer*",
+
+    # Example: disable specific block types
+    # "*norm*",
+    # "*time_in*",
+]
+
+# Per-layer dtype configuration
+# Maps layer name patterns to quantization formats
+# Layers are matched in order - first match wins
+per_layer_dtype: {
+    # Default: quantize all layers to FP8 E4M3
+    "*": "float8_e4m3fn",
+
+    # Example: use different precision for specific layers
+    # "*attn*": "float8_e4m3fn",  # Attention layers
+    # "*mlp*": "float8_e4m3fn",   # MLP layers
+    # "*qkv*": "float8_e4m3fn",   # Q/K/V projections
+}
\ No newline at end of file
diff --git a/tools/ptq/models/__init__.py b/tools/ptq/models/__init__.py
new file mode 100644
index 000000000000..3db370e90e16
--- /dev/null
+++ b/tools/ptq/models/__init__.py
@@ -0,0 +1,36 @@
+from typing import Dict, Type
+from .base import ModelRecipe
+
+
+_RECIPE_REGISTRY: Dict[str, Type[ModelRecipe]] = {}
+
+
+def register_recipe(recipe_cls: Type[ModelRecipe]):
+    recipe_name = recipe_cls.name()
+    if recipe_name in _RECIPE_REGISTRY:
+        raise ValueError(f"Recipe '{recipe_name}' is already registered")
+
+    _RECIPE_REGISTRY[recipe_name] = recipe_cls
+    return recipe_cls
+
+
+def get_recipe_class(name: str) -> Type[ModelRecipe]:
+    if name not in _RECIPE_REGISTRY:
+        available = ", ".join(sorted(_RECIPE_REGISTRY.keys()))
+        raise ValueError(
+            f"Unknown model type '{name}'. "
+            f"Available recipes: {available}"
+        )
+    return _RECIPE_REGISTRY[name]
+
+
+def list_recipes():
+    return sorted(_RECIPE_REGISTRY.keys())
+
+
+# Import recipe modules to trigger registration
+from . import flux  # noqa: F401, E402
+from . import qwen  # noqa: F401, E402
+from . import ltx_video  # noqa: F401, E402
+from . 
import wan # noqa: F401, E402 + diff --git a/tools/ptq/models/base.py b/tools/ptq/models/base.py new file mode 100644 index 000000000000..4056d3bfce28 --- /dev/null +++ b/tools/ptq/models/base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +import argparse +from typing import Tuple, Any, Callable +from dataclasses import dataclass + + +class ModelRecipe(ABC): + @classmethod + @abstractmethod + def name(cls) -> str: + pass + + @classmethod + @abstractmethod + def add_model_args(cls, parser: argparse.ArgumentParser): + pass + + @abstractmethod + def __init__(self, args): + pass + + @abstractmethod + def load_model(self) -> Tuple[Any, ...]: + pass + + @abstractmethod + def create_calibration_pipeline(self, model_components) -> Any: + pass + + @abstractmethod + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + pass + + @abstractmethod + def get_default_calib_steps(self) -> int: + pass + +@dataclass +class SamplerCFG: + cfg: float + sampler_name: str + scheduler: str = "simple" + denoise: float = 1.0 + flux_cfg: float = 1.0 + img_cfg: float = 2.0 diff --git a/tools/ptq/models/dataset.py b/tools/ptq/models/dataset.py new file mode 100644 index 000000000000..38d0d2b3e510 --- /dev/null +++ b/tools/ptq/models/dataset.py @@ -0,0 +1,88 @@ +import torch +import numpy as np +from typing import Dict, Any +from datasets import load_dataset + + +class HFPromptDataloader(torch.utils.data.Dataset): + def __init__(self, split="test", max_samples=None): + self.dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split=split) + + if max_samples is not None and max_samples < len(self.dataset): + self.dataset = self.dataset.select(range(max_samples)) + + def __iter__(self): + for sample in self.dataset: + yield {"prompt": sample["Prompt"]} + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + sample = self.dataset[idx] + return {"prompt": sample["Prompt"]} + + +class KontextBenchDataLoader: + def __init__( + self, + dataset_name: str = "black-forest-labs/kontext-bench", + split: str = "test", + ): + self.dataset_name = dataset_name + self.split = split + self._dataset = None + + def load_dataset(self): + if self._dataset is None: + self._dataset = load_dataset( + self.dataset_name, + split=self.split, + ) + + return self._dataset + + def __iter__(self): + dataset = self.load_dataset() + for sample in dataset: + yield { + 'image': self.preprocess_img(sample['image']), + 'prompt': sample['instruction'] + } + + def __len__(self) -> int: + dataset = self.load_dataset() + return len(dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + dataset = self.load_dataset() + sample = dataset[idx] + return { + 'image': self.preprocess_img(sample['image']), + 'prompt': sample['instruction'] + } + + def preprocess_img(self, img): + img = img.convert("RGB") + img = np.array(img).astype(np.float32) / 255.0 + img = torch.from_numpy(img) + return img.unsqueeze(0) + + def get_full_sample(self, idx: int) -> Dict[str, Any]: + dataset = self.load_dataset() + return dataset[idx] + + def filter_by_category(self, category: str): + dataset = self.load_dataset() + return dataset.filter(lambda x: x['category'] == category) + + @staticmethod + def get_categories(): + return [ + "Character Reference", + "Instruction Editing - Global", + "Instruction Editing - Local", + "Style Reference", + "Text Editing" + ] + diff --git a/tools/ptq/models/flux.py b/tools/ptq/models/flux.py new file mode 100644 index 000000000000..bf2a7ef74840 --- /dev/null +++ 
b/tools/ptq/models/flux.py @@ -0,0 +1,298 @@ +import logging + +import torch + +import comfy.sd +import comfy.utils +import folder_paths +import random +from typing import Tuple, Callable + +from comfy_extras.nodes_flux import FluxGuidance, FluxKontextImageScale, CLIPTextEncodeFlux +from comfy_extras.nodes_sd3 import EmptySD3LatentImage +from comfy_extras.nodes_edit_model import ReferenceLatent + +from nodes import CLIPTextEncode, KSampler, VAEEncode, ConditioningZeroOut + +from . import register_recipe +from .base import ModelRecipe, SamplerCFG +from .dataset import HFPromptDataloader, KontextBenchDataLoader + +class FluxT2IPipe: + def __init__( + self, + model, + clip, + vae, + width, + height, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + ) -> None: + self.clip = clip + self.vae = vae + self.diffusion_model = model + + self.width = width + self.height = height + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, *args, **kwargs): + positive = CLIPTextEncodeFlux().encode(self.clip, positive_prompt, positive_prompt, self.sampler_cfg.flux_cfg)[0] + negative = ConditioningZeroOut().zero_out(positive)[0] + latent = EmptySD3LatentImage().execute(self.width, self.height).args[0] + + KSampler().sample( + self.diffusion_model, self.seed, num_inference_steps, + self.sampler_cfg.cfg, self.sampler_cfg.sampler_name, + self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent, + denoise=self.sampler_cfg.denoise)[0] + + + +class FluxRecipeBase(ModelRecipe): + @classmethod + def add_model_args(cls, parser): + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--ckpt_path", + help="Path to full checkpoint (includes diffusion model + text encoder + VAE)" + ) + group.add_argument( + "--unet_path", + help="Path to diffusion model only (requires test_encoder and VAE)" + ) + + parser.add_argument( + "--clip_path", + help="Path to text encoder (required with --unet_path)" + ) + parser.add_argument( + "--t5_path", + help="Path to text encoder (required with --unet_path)" + ) + + parser.add_argument( + "--vae_path", + help="Path to VAE model (required with --unet_path)", + required=False, + ) + + def __init__(self, args): + self.args = args + + # Validate args + if hasattr(args, 'unet_path') and args.unet_path: + if not args.clip_path or not args.t5_path: + raise ValueError("--unet_path requires both --clip_path and --t5_path") + + def load_model(self) -> Tuple: + if hasattr(self.args, 'ckpt_path') and self.args.ckpt_path: + # Load from full checkpoint + logging.info(f"Loading full checkpoint from {self.args.ckpt_path}") + model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config( + self.args.ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=None + ) + else: + # Load from separate files + logging.info(f"Loading diffusion model from {self.args.unet_path}") + model_options = {} + clip_type = comfy.sd.CLIPType.FLUX + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path) + + model_patcher = comfy.sd.load_diffusion_model( + self.args.unet_path, + model_options=model_options + ) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path1, clip_path2], + 
embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type, + model_options=model_options + ) + + vae = None # Not needed for calibration + if self.args.vae_path: + vae_path = folder_paths.get_full_path_or_raise("vae", self.args.vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() + return model_patcher, clip, vae + + def create_calibration_pipeline(self, model_components): + model_patcher, clip, vae = model_components + + return FluxT2IPipe( + model=model_patcher, + clip=clip, + vae=vae, + width=self.get_width(), + height=self.get_height(), + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda" + ) + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + + logging.debug(f"Calibration step {i + 1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text) + except Exception as e: + logging.warning(f"Calibration step {i + 1} failed: {e}") + + return forward_loop + + def get_width(self) -> int: + return 1024 + + def get_height(self) -> int: + return 1024 + + def get_default_calib_steps(self) -> int: + return 64 + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=1.0, + sampler_name="euler", + scheduler="simple", + denoise=1.0, + flux_cfg=3.5 + ) + + def get_inference_steps(self) -> int: + """Number of sampling steps per calibration iteration.""" + raise NotImplementedError + + def get_dataset(self): + return HFPromptDataloader() + +@register_recipe +class FluxDevRecipe(FluxRecipeBase): + @classmethod + def name(cls) -> str: + return "flux_dev" + + def get_inference_steps(self) -> int: + return 30 + +@register_recipe +class FluxSchnellRecipe(FluxRecipeBase): + @classmethod + def name(cls) -> str: + return "flux_schnell" + + def get_inference_steps(self) -> int: + return 4 + +class FluxKontextPipe: + def __init__( + self, + model, + clip, + vae, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + ) -> None: + self.clip = clip + self.vae = vae + self.diffusion_model = model + + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + assert self.vae is not None, "VAE is required for FluxKontextRecipe" + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, image, *args, **kwargs): + image_preprocessed = FluxKontextImageScale().execute(image).args[0] + image_encoded = VAEEncode().encode(self.vae, image_preprocessed)[0] + + positive = CLIPTextEncode().encode(self.clip, positive_prompt)[0] + conditioning_img = ReferenceLatent().execute(positive, image_encoded).args[0] + conditioning_img = FluxGuidance().execute(conditioning_img, self.sampler_cfg.img_cfg).args[0] + + conditioning_prompt = ConditioningZeroOut().zero_out(positive)[0] + KSampler().sample(self.diffusion_model, self.seed, num_inference_steps, + self.sampler_cfg.cfg, self.sampler_cfg.sampler_name, + self.sampler_cfg.scheduler, positive=conditioning_img, + negative=conditioning_prompt, latent_image=image_encoded, + denoise=self.sampler_cfg.denoise)[0] + + def _preview_img(self, out): + pass + +@register_recipe +class FluxKontextRecipe(FluxRecipeBase): + @classmethod + def name(cls) -> str: + return 
"flux_kontext" + + def create_calibration_pipeline(self, model_components): + model_patcher, clip, vae = model_components + assert vae is not None, "VAE is required for FluxKontextRecipe" + + return FluxKontextPipe( + model=model_patcher, + clip=clip, + vae=vae, + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda" + ) + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + img = sample["image"] + + logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, img) + except Exception as e: + logging.warning(f"Calibration step {i+1} failed: {e}") + + return forward_loop + + def get_inference_steps(self) -> int: + return 30 + + def get_dataset(self): + logging.info("Loading KontextBench dataset...") + return KontextBenchDataLoader() + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=1.0, + sampler_name="euler", + scheduler="simple", + denoise=1.0, + flux_cfg=2.5 + ) diff --git a/tools/ptq/models/ltx_video.py b/tools/ptq/models/ltx_video.py new file mode 100644 index 000000000000..08a2449fc277 --- /dev/null +++ b/tools/ptq/models/ltx_video.py @@ -0,0 +1,236 @@ +import logging +import torch + +import comfy.sd +import comfy.utils +import folder_paths +import random +from typing import Tuple, Callable + +from comfy_extras.nodes_lt import LTXVConditioning, EmptyLTXVLatentVideo, LTXVScheduler, LTXVImgToVideo + +from comfy_extras.nodes_model_advanced import ModelSamplingAuraFlow +from comfy_extras.nodes_custom_sampler import SamplerCustom, KSamplerSelect +from nodes import CLIPTextEncode + +from . 
import register_recipe +from .base import ModelRecipe, SamplerCFG +from .dataset import HFPromptDataloader, KontextBenchDataLoader + +class LTXVideoPipe: + def __init__( + self, + model, + clip, + vae, + width, + height, + length, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + custom_kwargs: dict = None, + ) -> None: + self.clip = clip + self.vae = vae + self.diffusion_model = model + + self.width = width + self.height = height + self.length = length + + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + + self.custom_kwargs = custom_kwargs + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, negative_prompt, image=None, *args, **kwargs): + positive = CLIPTextEncode().encode(self.clip, positive_prompt)[0] + negative = CLIPTextEncode().encode(self.clip, negative_prompt)[0] + + if image is None: + latent_image = EmptyLTXVLatentVideo().execute(self.width, self.height, self.length).args[0] + else: + positive, negative, latent_image = LTXVImgToVideo().execute(positive=positive, negative=negative, + vae=self.vae, image=image, + width=self.width, height=self.height, + length=self.length, strength=0.1, batch_size=1).args + + positive, negative = LTXVConditioning().execute(positive, negative, + frame_rate=self.custom_kwargs.get("frame_rate", 25.0)).args + + model = ModelSamplingAuraFlow().patch_aura(self.diffusion_model, self.sampler_cfg.flux_cfg)[0] + sigmas = LTXVScheduler().execute(num_inference_steps, max_shift=self.custom_kwargs.get("max_shift", 2.05), + base_shift=self.custom_kwargs.get("base_shift", 0.95), + stretch=self.custom_kwargs.get("stretch", True), + terminal=self.custom_kwargs.get("terminal", 0.1)).args[0] + + sampler = KSamplerSelect().get_sampler(self.sampler_cfg.sampler_name)[0] + SamplerCustom().sample(model, positive=positive, negative=negative, + sampler=sampler, sigmas=sigmas, latent_image=latent_image, + cfg=self.sampler_cfg.cfg, add_noise=True, noise_seed=self.seed) + + + + +@register_recipe +class LTXVRecipeBase(ModelRecipe): + @classmethod + def name(cls) -> str: + return "ltxv_t2i" + + @classmethod + def add_model_args(cls, parser): + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--ckpt_path", + help="Path to full checkpoint (includes diffusion model + text encoder + VAE)" + ) + parser.add_argument( + "--t5_path", + help="Path to text encoder (required with --unet_path)" + ) + parser.add_argument( + "--vae_path", + help="Path to VAE model (required with --unet_path)", + required=False, + ) + + def __init__(self, args): + self.args = args + + if not args.t5_path: + raise ValueError("--unet_path requires both --t5_path") + + def load_model(self) -> Tuple: + logging.info(f"Loading full checkpoint from {self.args.ckpt_path}") + model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config( + self.args.ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=None + ) + + model_options = {} + clip_type = comfy.sd.CLIPType.LTXV + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type, + model_options=model_options + ) + + if self.args.vae_path: + vae_path = folder_paths.get_full_path_or_raise("vae", self.args.vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + 
vae.throw_exception_if_invalid() + return model_patcher, clip, vae + + def create_calibration_pipeline(self, model_components): + model_patcher, clip, vae = model_components + + return LTXVideoPipe( + model=model_patcher, + clip=clip, + vae=vae, + width=self.get_width(), + height=self.get_height(), + length=self.get_length(), + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda", + custom_kwargs=self.get_custom_kwargs() + ) + + def get_width(self) -> int: + return 768 + + def get_height(self) -> int: + return 512 + + def get_length(self) -> int: + return 97 + + def get_default_calib_steps(self) -> int: + return 64 + + def get_inference_steps(self) -> int: + return 30 + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=3.0, + sampler_name="euler", + ) + + def get_custom_kwargs(self) -> dict: + return { + "frame_rate": 25.0, + "max_shift": 2.05, + "base_shift": 0.95, + "stretch": True, + "terminal": 0.1 + } + +@register_recipe +class LTXVText2Video(LTXVRecipeBase): + @classmethod + def name(cls) -> str: + return "ltxv_t2v" + + def get_dataset(self): + logging.info("Loading KontextBench dataset...") + return HFPromptDataloader() + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + negative_text = "low quality" + + logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, negative_text) + except Exception as e: + logging.warning(f"Calibration step {i+1} failed: {e}") + + return forward_loop + +@register_recipe +class LTXVImg2Video(LTXVRecipeBase): + @classmethod + def name(cls) -> str: + return "ltxv_i2v" + + def get_dataset(self): + logging.info("Loading KontextBench dataset...") + return KontextBenchDataLoader() + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + img = sample["image"] + negative_text = "low quality" + + logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, negative_text, img) + except Exception as e: + logging.warning(f"Calibration step {i+1} failed: {e}") + + return forward_loop diff --git a/tools/ptq/models/qwen.py b/tools/ptq/models/qwen.py new file mode 100644 index 000000000000..847749e4e1f9 --- /dev/null +++ b/tools/ptq/models/qwen.py @@ -0,0 +1,255 @@ +import logging +import torch + +import comfy.sd +import comfy.utils +import folder_paths +import random +from typing import Tuple, Callable + +from comfy_extras.nodes_sd3 import EmptySD3LatentImage + +from comfy_extras.nodes_model_advanced import ModelSamplingAuraFlow +from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels +from comfy_extras.nodes_qwen import TextEncodeQwenImageEditPlus +from comfy_extras.nodes_cfg import CFGNorm +from nodes import CLIPTextEncode, KSampler, VAEEncode + +from . 
import register_recipe +from .base import ModelRecipe, SamplerCFG +from .dataset import HFPromptDataloader, KontextBenchDataLoader + +class QwenPipe: + def __init__( + self, + model, + clip, + vae, + width, + height, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + ) -> None: + self.clip = clip + self.vae = vae + self.diffusion_model = model + + self.width = width + self.height = height + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, negative_prompt, *args, **kwargs): + positive = CLIPTextEncode().encode(self.clip, positive_prompt)[0] + negative = CLIPTextEncode().encode(self.clip, negative_prompt)[0] + + model = ModelSamplingAuraFlow().patch_aura(self.diffusion_model, self.sampler_cfg.flux_cfg)[0] + + latent = EmptySD3LatentImage().execute(self.width, self.height).args[0] + + out = KSampler().sample( + model=model, seed=self.seed, steps=num_inference_steps, + cfg=self.sampler_cfg.cfg, sampler_name=self.sampler_cfg.sampler_name, + scheduler=self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent, + denoise=self.sampler_cfg.denoise)[0] + + @torch.inference_mode + def image_edit(self, num_inference_steps, positive_prompt, negative_prompt, image, *args, **kwargs): + image = ImageScaleToTotalPixels().execute(image, "lanczos", 1.0).args[0] + latent_image = VAEEncode().encode(self.vae, image)[0] + + positive = TextEncodeQwenImageEditPlus().execute(self.clip, positive_prompt, image).args[0] + negative = TextEncodeQwenImageEditPlus().execute(self.clip, negative_prompt, image).args[0] + + + model = ModelSamplingAuraFlow().patch_aura(self.diffusion_model, self.sampler_cfg.flux_cfg)[0] + model = CFGNorm().execute(model, 1.0).args[0] + + + out = KSampler().sample( + model=model, seed=self.seed, steps=num_inference_steps, + cfg=self.sampler_cfg.cfg, sampler_name=self.sampler_cfg.sampler_name, + scheduler=self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent_image, + denoise=self.sampler_cfg.denoise)[0] + +class QwenRecipeBase(ModelRecipe): + @classmethod + def add_model_args(cls, parser): + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--ckpt_path", + help="Path to full checkpoint (includes diffusion model + text encoder + VAE)" + ) + group.add_argument( + "--unet_path", + help="Path to diffusion model only (requires test_encoder and VAE)" + ) + parser.add_argument( + "--qwen_vl_path", + help="Path to text encoder (required with --unet_path)" + ) + parser.add_argument( + "--vae_path", + help="Path to VAE model (required with --unet_path)", + required=False, + ) + + def __init__(self, args): + self.args = args + + # Validate args + if hasattr(args, 'unet_path') and args.unet_path: + if not args.qwen_vl_path: + raise ValueError("--unet_path requires both --qwen_vl_path") + + def load_model(self) -> Tuple: + """Load FLUX model, CLIP, and VAE.""" + if hasattr(self.args, 'ckpt_path') and self.args.ckpt_path: + # Load from full checkpoint + logging.info(f"Loading full checkpoint from {self.args.ckpt_path}") + model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config( + self.args.ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=None + ) + else: + # Load from separate files + logging.info(f"Loading diffusion model from {self.args.unet_path}") + model_options = {} + clip_type = 
comfy.sd.CLIPType.QWEN_IMAGE + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", self.args.qwen_vl_path) + + model_patcher = comfy.sd.load_diffusion_model( + self.args.unet_path, + model_options=model_options + ) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type, + model_options=model_options + ) + + vae = None # Not needed for calibration + if self.args.vae_path: + vae_path = folder_paths.get_full_path_or_raise("vae", self.args.vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() + return model_patcher, clip, vae + + def create_calibration_pipeline(self, model_components): + model_patcher, clip, vae = model_components + + return QwenPipe( + model=model_patcher, + clip=clip, + vae=vae, + width=self.get_width(), + height=self.get_height(), + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda" + ) + + + def get_width(self) -> int: + return 1328 + + def get_height(self) -> int: + return 1328 + + def get_default_calib_steps(self) -> int: + return 64 + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=2.5, + sampler_name="euler", + scheduler="simple", + denoise=1.0, + flux_cfg=3.1 + ) + + def get_inference_steps(self) -> int: + return 30 + + + +@register_recipe +class QwenImage(QwenRecipeBase): + @classmethod + def name(cls) -> str: + return "qwen_image" + + def get_dataset(self): + return HFPromptDataloader() + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + negative_text = "low quality" + + logging.debug(f"Calibration step {i + 1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, negative_text) + except Exception as e: + logging.warning(f"Calibration step {i + 1} failed: {e}") + + return forward_loop + + +@register_recipe +class QwenImageEdit(QwenRecipeBase): + @classmethod + def name(cls) -> str: + return "qwen_edit" + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + negative_text = "low quality" + img = sample["image"] + + logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'") + try: + calib_pipeline.image_edit(num_steps, prompt_text, negative_text, img) + except Exception as e: + logging.warning(f"Calibration step {i+1} failed: {e}") + + return forward_loop + + def get_inference_steps(self) -> int: + return 30 + + def get_dataset(self): + logging.info("Loading KontextBench dataset...") + return KontextBenchDataLoader() + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=2.5, + sampler_name="euler", + scheduler="simple", + denoise=1.0, + flux_cfg=3.0 + ) diff --git a/tools/ptq/models/wan.py b/tools/ptq/models/wan.py new file mode 100644 index 000000000000..ce50d39065d6 --- /dev/null +++ b/tools/ptq/models/wan.py @@ -0,0 +1,366 @@ +import logging + +import torch + +import comfy.sd +import comfy.utils +import folder_paths +import random +from typing import Tuple, Callable + +from 
comfy_extras.nodes_wan import Wan22ImageToVideoLatent, WanImageToVideo +from comfy_extras.nodes_model_advanced import ModelSamplingSD3 + +from nodes import CLIPTextEncode, KSampler, KSamplerAdvanced + +from . import register_recipe +from .base import ModelRecipe, SamplerCFG +from .dataset import HFPromptDataloader, KontextBenchDataLoader + +class WAN22SinglePipe: + def __init__( + self, + model, + clip, + vae, + width, + height, + length, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + ) -> None: + self.clip = clip + self.vae = vae + self.diffusion_model = model + + self.width = width + self.height = height + self.length = length + + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, negative_prompt, image=None, *args, **kwargs): + positive = CLIPTextEncode().encode(self.clip, positive_prompt)[0] + negative = CLIPTextEncode().encode(self.clip, negative_prompt)[0] + + latent = Wan22ImageToVideoLatent().execute(width=self.width, height=self.height, length=self.length, + vae=self.vae, batch_size=1, start_image=image).args[0] + + model = ModelSamplingSD3().patch(self.diffusion_model, self.sampler_cfg.flux_cfg)[0] + + out, denoised_out = KSampler().sample(model, self.seed, num_inference_steps, + self.sampler_cfg.cfg, self.sampler_cfg.sampler_name, + self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent, + denoise=self.sampler_cfg.denoise) + + +@register_recipe +class WAN22SingleRecipe(ModelRecipe): + @classmethod + def name(cls) -> str: + return "wan_22_5b_t2v" + + @classmethod + def add_model_args(cls, parser): + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--unet_path", + help="Path to diffusion model only (requires test_encoder and VAE)" + ) + + parser.add_argument( + "--clip_path", + help="Path to text encoder (required with --unet_path)" + ) + + parser.add_argument( + "--vae_path", + help="Path to VAE model (required with --unet_path)", + required=False, + ) + + def __init__(self, args): + self.args = args + + def load_model(self) -> Tuple: + # Load from separate files + logging.info(f"Loading diffusion model from {self.args.unet_path}") + model_options = {} + clip_type = comfy.sd.CLIPType.WAN + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path) + + model_patcher = comfy.sd.load_diffusion_model( + self.args.unet_path, + model_options=model_options + ) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type, + model_options=model_options + ) + + vae = None # Not needed for calibration + if self.args.vae_path: + vae_path = folder_paths.get_full_path_or_raise("vae", self.args.vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() + return model_patcher, clip, vae + + def create_calibration_pipeline(self, model_components): + model_patcher, clip, vae = model_components + + return WAN22SinglePipe( + model=model_patcher, + clip=clip, + vae=vae, + width=self.get_width(), + height=self.get_height(), + length=self.get_length(), + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda" + ) + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + + 
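+        # Calibration closure: each step samples a random prompt and runs a short denoising pass so the quantizers can collect amax statistics.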
def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + prompt_text = sample["prompt"] + negative_prompt = "low quality" + + logging.debug(f"Calibration step {i + 1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, negative_prompt) + except Exception as e: + logging.warning(f"Calibration step {i + 1} failed: {e}") + + return forward_loop + + def get_width(self) -> int: + return 1280 + + def get_height(self) -> int: + return 704 + + def get_length(self) -> int: + return 121 + + def get_default_calib_steps(self) -> int: + return 32 + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=5.0, + sampler_name="uni_pc", + scheduler="simple", + denoise=1.0, + flux_cfg=8.0 + ) + + def get_inference_steps(self) -> int: + return 20 + + def get_dataset(self): + return HFPromptDataloader() + +class WAN22DoublePipe: + def __init__( + self, + high_noise_model, + low_noise_model, + clip, + vae, + width, + height, + length, + seed=0, + sampler_cfg: SamplerCFG = None, + device="cuda", + ) -> None: + self.clip = clip + self.vae = vae + self.high_noise_model = high_noise_model + self.low_noise_model = low_noise_model + + self.width = width + self.height = height + self.length = length + + self.sampler_cfg = sampler_cfg + self.device = device + self.seed = seed + assert self.sampler_cfg is not None, "Sampler configuration is required" + + @torch.inference_mode + def __call__(self, num_inference_steps, positive_prompt, negative_prompt, image=None, *args, **kwargs): + positive = CLIPTextEncode().encode(self.clip, positive_prompt)[0] + negative = CLIPTextEncode().encode(self.clip, negative_prompt)[0] + + positive, negative, latent_image = WanImageToVideo().execute(width=self.width, height=self.height, + length=self.length, batch_size=1, + positive=positive, negative=negative, + vae=self.vae, start_image=image).args + + high_noise_model = ModelSamplingSD3().patch(self.high_noise_model, self.sampler_cfg.flux_cfg)[0] + low_noise_model = ModelSamplingSD3().patch(self.low_noise_model, self.sampler_cfg.flux_cfg)[0] + + mid_step = num_inference_steps // 2 + + out, denoised_out = KSamplerAdvanced().sample(model=high_noise_model, noise_seed=self.seed, + steps=num_inference_steps, cfg=self.sampler_cfg.cfg, + sampler_name=self.sampler_cfg.sampler_name, + scheduler=self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent_image, + denoise=self.sampler_cfg.denoise, add_noise=True, + start_at_step=0, end_at_step=mid_step, + return_with_leftover_noise=True) + + out, denoised_out = KSamplerAdvanced().sample(model=low_noise_model, noise_seed=self.seed, + steps=num_inference_steps, cfg=self.sampler_cfg.cfg, + sampler_name=self.sampler_cfg.sampler_name, + scheduler=self.sampler_cfg.scheduler, positive=positive, + negative=negative, latent_image=latent_image, + denoise=self.sampler_cfg.denoise, add_noise=False, + start_at_step=mid_step, end_at_step=num_inference_steps, + return_with_leftover_noise=False) + +@register_recipe +class WAN22DoubleRecipe(ModelRecipe): + @classmethod + def name(cls) -> str: + return "wan_22_14b_i2v" + + @classmethod + def add_model_args(cls, parser): + parser.add_argument( + "--unet_path_low_noise", + help="Path to diffusion model only (requires test_encoder and VAE)" + ) + + parser.add_argument( + "--unet_path_high_noise", + help="Path to diffusion model only (requires test_encoder and VAE)" + ) + + parser.add_argument( + "--clip_path", + help="Path to 
text encoder (required with --unet_path)" + ) + + parser.add_argument( + "--vae_path", + help="Path to VAE model (required with --unet_path)", + required=False, + ) + + def __init__(self, args): + self.args = args + + def load_model(self) -> Tuple: + # Load from separate files + logging.info(f"Loading diffusion model from {self.args.unet_path_low_noise}") + logging.info(f"Loading diffusion model from {self.args.unet_path_high_noise}") + model_options = {} + clip_type = comfy.sd.CLIPType.WAN + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path) + + model_patcher_low = comfy.sd.load_diffusion_model( + self.args.unet_path_low_noise, + model_options=model_options + ) + model_patcher_high = comfy.sd.load_diffusion_model( + self.args.unet_path_high_noise, + model_options=model_options + ) + + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type, + model_options=model_options + ) + + vae = None # Not needed for calibration + if self.args.vae_path: + vae_path = folder_paths.get_full_path_or_raise("vae", self.args.vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() + return model_patcher_high, model_patcher_low, clip, vae + + def create_calibration_pipeline(self, model_components): + high_noise_model, low_noise_model, clip, vae = model_components + + return WAN22DoublePipe( + high_noise_model=high_noise_model, + low_noise_model=low_noise_model, + clip=clip, + vae=vae, + width=self.get_width(), + height=self.get_height(), + length=self.get_length(), + seed=42, + sampler_cfg=self.get_sampler_cfg(), + device="cuda" + ) + + def get_forward_loop(self, calib_pipeline, num_calib_steps) -> Callable: + num_steps = self.get_inference_steps() + dataloader = self.get_dataset() + + def forward_loop(): + for i in range(num_calib_steps): + rnd_idx = random.randint(0, len(dataloader) - 1) + sample = dataloader[rnd_idx] + image = sample["image"] + prompt_text = sample["prompt"] + negative_prompt = "low quality" + + logging.debug(f"Calibration step {i + 1}: '{prompt_text[:50]}...'") + try: + calib_pipeline(num_steps, prompt_text, negative_prompt, image) + except Exception as e: + logging.warning(f"Calibration step {i + 1} failed: {e}") + + return forward_loop + + def get_width(self) -> int: + return 720 + + def get_height(self) -> int: + return 480 + + def get_length(self) -> int: + return 81 + + def get_default_calib_steps(self) -> int: + return 32 + + def get_inference_steps(self) -> int: + return 20 + + def get_sampler_cfg(self) -> SamplerCFG: + return SamplerCFG( + cfg=3.5, + sampler_name="euler", + scheduler="simple", + denoise=1.0, + flux_cfg=8.0 + ) + + def get_dataset(self): + return KontextBenchDataLoader() diff --git a/tools/ptq/quantize.py b/tools/ptq/quantize.py new file mode 100644 index 000000000000..09cbc4983e63 --- /dev/null +++ b/tools/ptq/quantize.py @@ -0,0 +1,180 @@ +import torch +from typing import Dict + +import argparse +import logging +import sys +import torch.utils.data +import modelopt.torch.quantization as mtq +from tools.ptq.utils import log_quant_summary, save_amax_dict, extract_amax_values + +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from tools.ptq.models import get_recipe_class, list_recipes +from tools.ptq.utils import register_comfy_ops, FP8_CFG + +class PTQPipeline: + def __init__(self, model_patcher, quant_config: dict): + 
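+        # mtq.quantize calibrates the bare diffusion module, so keep a direct handle on it alongside the ComfyUI ModelPatcher.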
self.model_patcher = model_patcher + self.diffusion_model = model_patcher.model.diffusion_model + self.quant_config = quant_config + + logging.debug(f"PTQPipeline initialized with config: {quant_config}") + + @torch.no_grad() + def calibrate_with_pipeline( + self, + calib_pipeline, + num_steps: int, + recipe + ): + + logging.info(f"Running calibration with {num_steps} steps...") + forward_loop = recipe.get_forward_loop(calib_pipeline, num_steps) + try: + mtq.quantize(self.diffusion_model, self.quant_config, forward_loop=forward_loop) + except Exception as e: + logging.error(f"Calibration failed: {e}") + raise + + logging.info("Calibration complete") + log_quant_summary(self.diffusion_model) + + def get_amax_dict(self) -> Dict: + return extract_amax_values(self.diffusion_model) + + def save_amax_values(self, output_path: str, metadata: dict = None): + amax_dict = self.get_amax_dict() + save_amax_dict(amax_dict, output_path, metadata=metadata) + logging.info(f"Saved amax values to {output_path}") + + + +def main(): + parser = argparse.ArgumentParser( + description="Quantize ComfyUI models using NVIDIA ModelOptimizer", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--model_type", + required=True, + choices=list_recipes(), + help="Model recipe to use" + ) + + args, remaining = parser.parse_known_args() + + recipe_cls = get_recipe_class(args.model_type) + recipe_cls.add_model_args(parser) + parser.add_argument( + "--output", + required=True, + help="Output path for amax artefact" + ) + parser.add_argument( + "--calib_steps", + type=int, + help="Override default calibration steps" + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed" + ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug mode (sets logging to DEBUG and calib_steps to 1)" + ) + + args = parser.parse_args() + if args.debug: + logging.basicConfig( + level=logging.DEBUG, + format='[%(levelname)s] %(name)s: %(message)s' + ) + logging.info("Debug mode enabled") + else: + logging.basicConfig( + level=logging.INFO, + format='[%(levelname)s] %(message)s' + ) + try: + recipe = recipe_cls(args) + except Exception as e: + logging.error(f"Failed to initialize recipe: {e}") + sys.exit(1) + if args.debug: + calib_steps = 1 + logging.debug("Debug mode: forcing calib_steps=1") + elif args.calib_steps: + calib_steps = args.calib_steps + else: + calib_steps = recipe.get_default_calib_steps() + + logging.info("Registering ComfyUI ops with ModelOptimizer...") + register_comfy_ops() + + logging.info("[1/5] Loading model...") + try: + model_components = recipe.load_model() + model_patcher = model_components[0] + except Exception as e: + logging.error(f"Failed to load model: {e}") + sys.exit(1) + + logging.info("[2/5] Preparing quantization...") + try: + pipeline = PTQPipeline( + model_patcher, + quant_config=FP8_CFG, + ) + except Exception as e: + logging.error(f"Failed to prepare quantization: {e}") + sys.exit(1) + + logging.info("[3/5] Creating calibration pipeline...") + try: + calib_pipeline = recipe.create_calibration_pipeline(model_components) + except Exception as e: + logging.error(f"Failed to create calibration pipeline: {e}") + sys.exit(1) + + logging.info(f"[4/5] Running calibration ({calib_steps} steps)...") + try: + pipeline.calibrate_with_pipeline( + calib_pipeline, + num_steps=calib_steps, + recipe=recipe + ) + except Exception as e: + logging.error(f"Calibration failed: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + 
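+    # The artefact stores only per-quantizer amax values and metadata; the weights themselves are converted later by checkpoint_merger.py.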
logging.info("[5/5] Extracting and saving amax values...") + try: + metadata = { + "model_type": recipe.name(), + "calibration_steps": calib_steps, + "quantization_format": "amax", + "debug_mode": args.debug + } + + if hasattr(args, 'ckpt_path') and args.ckpt_path: + metadata["checkpoint_path"] = args.ckpt_path + + pipeline.save_amax_values(args.output, metadata=metadata) + except Exception as e: + logging.error(f"Failed to save amax values: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +if __name__ == "__main__": + main() + diff --git a/tools/ptq/utils.py b/tools/ptq/utils.py new file mode 100644 index 000000000000..4366a621e2e8 --- /dev/null +++ b/tools/ptq/utils.py @@ -0,0 +1,89 @@ +import torch +import logging +from typing import Dict, Optional + +import comfy.ops +from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer + + +FP8_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None}, + "default": {"enable": False}, + }, + "algorithm": "max", +} + + +def register_comfy_ops(): + op = comfy.ops.disable_weight_init.Linear + op_name = op.__name__ + + if op in QuantModuleRegistry: + logging.debug("ComfyUI Linear already registered with ModelOptimizer") + return + + QuantModuleRegistry.register( + {op: f"comfy.{op_name}"} + )(QuantModuleRegistry._registry[getattr(torch.nn, op_name)]) + + logging.info("Registered ComfyUI Linear with ModelOptimizer") + +def log_quant_summary(model: torch.nn.Module, log_level=logging.INFO): + count = 0 + for name, mod in model.named_modules(): + if isinstance(mod, TensorQuantizer): + logging.log(log_level, f"{name:80} {mod}") + count += 1 + logging.log(log_level, f"{count} TensorQuantizers found in model") + +def extract_amax_values(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + amax_dict = {} + + for name, module in model.named_modules(): + if not isinstance(module, TensorQuantizer): + continue + if not module.is_enabled: + continue + if hasattr(module, '_amax') and module._amax is not None: + amax = module._amax + if not isinstance(amax, torch.Tensor): + amax = torch.tensor(amax, dtype=torch.float32) + + amax_dict[name] = amax.clone().cpu() + logging.debug(f"Extracted amax from {name}: {amax.item():.6f}") + + logging.info(f"Extracted amax values from {len(amax_dict)} quantizers") + return amax_dict + + +def save_amax_dict(amax_dict: Dict[str, torch.Tensor], output_path: str, metadata: Optional[Dict] = None): + import json + from datetime import datetime + + logging.info(f"Saving {len(amax_dict)} amax values to {output_path}") + + amax_values = {} + for key, value in amax_dict.items(): + if isinstance(value, torch.Tensor): + if value.numel() == 1: + amax_values[key] = float(value.item()) + else: + amax_values[key] = value.cpu().numpy().tolist() + else: + amax_values[key] = float(value) + + output_dict = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "num_layers": len(amax_values), + **(metadata or {}) + }, + "amax_values": amax_values + } + + with open(output_path, 'w') as f: + json.dump(output_dict, f, indent=2, sort_keys=True) + + logging.info(f"✓ Amax values saved to {output_path}")