31 changes: 31 additions & 0 deletions tools/ptq/README.md
@@ -0,0 +1,31 @@

### Create PTQ Artefact / Run Calibration
This only needs to be done once per model.
```
python -m tools.ptq.quantize \
    --model_type flux_schnell \
    --unet_path <path>/flux1-schnell.safetensors \
    --clip_path clip_l.safetensors \
    --t5_path t5xxl_fp16.safetensors \
    --output flux_schnell_debug.json \
    --calib_steps 16
```
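
The calibration step writes a JSON artefact containing per-layer `amax` (maximum absolute value) statistics plus some metadata. A rough sketch of the layout is shown below: the top-level `amax_values` / `metadata` keys and the `.input_quantizer` suffix are what `checkpoint_merger.py` reads, while the layer names, the `weight_quantizer` suffix, and the metadata fields are only illustrative.
```json
{
  "metadata": {"model_type": "flux_schnell", "calib_steps": 16},
  "amax_values": {
    "double_blocks.0.img_attn.qkv.weight_quantizer": 1.23,
    "double_blocks.0.img_attn.qkv.input_quantizer": 4.56
  }
}
```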

### Create Quantized Checkpoint
Combines the calibration artefact from the previous step with the original checkpoint to produce a quantized checkpoint, driven by a YAML configuration. The config uses glob-style patterns (`*` wildcards, matched as regex internally) to decide which layers to quantize and which dtype each should use: any layer matching a `disable_list` pattern is left untouched, and for `per_layer_dtype` the first matching pattern wins. See `tools/ptq/configs/` for examples.
```yaml
# config.yml
disable_list: ["*img_in*", "*final_layer*", "*norm*"] # Keep these in BF16
per_layer_dtype: {"*": "float8_e4m3fn"} # Everything else to FP8
```

```
python -m tools.ptq.checkpoint_merger \
    --artefact flux_dev_debug.json \
    --checkpoint <path>/flux1-dev.safetensors \
    --config tools/ptq/configs/flux_nvfp4.yml \
    --output <out_path>/flux1-nvfp4.safetensors \
    --debug
```
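
The merger writes a regular `.safetensors` file in which each quantized layer gains a `*.weight_scale` tensor (plus `*.weight_block_scale` for block formats and `*.input_scale` when input calibration data is available), and per-layer format information is stored under the `_quantization_metadata` key of the file metadata. A quick, illustrative way to sanity-check the output (path taken from the example above):
```python
from safetensors import safe_open

with safe_open("<out_path>/flux1-nvfp4.safetensors", framework="pt") as f:
    print(f.metadata()["_quantization_metadata"])  # JSON string with per-layer formats
    quantized = [k for k in f.keys() if k.endswith(".weight_scale")]
    print(f"{len(quantized)} quantized layers")
```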

240 changes: 240 additions & 0 deletions tools/ptq/checkpoint_merger.py
@@ -0,0 +1,240 @@
import argparse
import logging
import sys
import yaml
import re
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
import json

import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

import comfy.utils
from comfy.ops import QUANT_FORMAT_MIXINS
from comfy.quant_ops import F8_E4M3_MAX, F4_E2M1_MAX

class QuantizationConfig:
def __init__(self, config_path: str):
with open(config_path, 'r') as f:
self.config = yaml.safe_load(f)

self.disable_patterns = []
for pattern in self.config.get('disable_list', []):
regex_pattern = pattern.replace('*', '.*')
self.disable_patterns.append(re.compile(regex_pattern))

self.per_layer_dtype = self.config.get('per_layer_dtype', {})
self.dtype_patterns = []
for pattern, dtype in self.per_layer_dtype.items():
regex_pattern = pattern.replace('*', '.*')
self.dtype_patterns.append((re.compile(regex_pattern), dtype))

logging.info(f"Loaded config with {len(self.disable_patterns)} disable patterns")
logging.info(f"Per-layer dtype rules: {self.per_layer_dtype}")

def should_quantize(self, layer_name: str) -> bool:
for pattern in self.disable_patterns:
if pattern.match(layer_name):
logging.debug(f"Layer {layer_name} disabled by pattern {pattern.pattern}")
return False
return True

def get_dtype(self, layer_name: str) -> str:
for pattern, dtype in self.dtype_patterns:
if pattern.match(layer_name):
return dtype
return None

def load_amax_artefact(artefact_path: str) -> Dict:
logging.info(f"Loading amax artefact from {artefact_path}")

with open(artefact_path, 'r') as f:
data = json.load(f)

if 'amax_values' not in data:
raise ValueError("Invalid artefact format: missing 'amax_values' key")

metadata = data.get('metadata', {})
amax_values = data['amax_values']

logging.info(f"Loaded {len(amax_values)} amax values from artefact")
logging.info(f"Artefact metadata: {metadata}")

return data

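# Per-tensor FP8 scale follows the usual convention scale = amax / dtype_max,
# so that weight / scale fits inside the representable FP8 range.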
def get_scale_fp8(amax: float, dtype: torch.dtype) -> torch.Tensor:
scale = amax / torch.finfo(dtype).max
scale_tensor = torch.tensor(scale, dtype=torch.float32)
return scale_tensor

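# NVFP4 uses two-level scaling: per-block FP8 (E4M3) scales on top of a single
# global FP32 per-tensor scale. Sizing the global scale as
# amax / (F8_E4M3_MAX * F4_E2M1_MAX) keeps the per-block scales representable in E4M3.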
def get_scale_nvfp4(amax: float, dtype: torch.dtype) -> torch.Tensor:
scale = amax / (F8_E4M3_MAX * F4_E2M1_MAX)
scale_tensor = torch.tensor(scale, dtype=torch.float32)
return scale_tensor

def get_scale(amax: float, dtype: torch.dtype):
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
return get_scale_fp8(amax, dtype)
elif dtype in [torch.float4_e2m1fn_x2]:
return get_scale_nvfp4(amax, dtype)
else:
        raise ValueError(f"Unsupported dtype {dtype}")

def apply_quantization(
checkpoint: Dict,
amax_values: Dict[str, float],
config: QuantizationConfig
) -> Tuple[Dict, Dict]:
quantized_dict = {}
layer_metadata = {}

for key, amax in amax_values.items():
if key.endswith(".input_quantizer"):
continue

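        # amax keys look like "<layer_name>.<quantizer_name>"; drop the last
        # component to recover the layer name.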
layer_name = ".".join(key.split(".")[:-1])

if not config.should_quantize(layer_name):
logging.debug(f"Layer {layer_name} disabled by config")
continue

        dtype_str = config.get_dtype(layer_name)
        if dtype_str is None:
            logging.debug(f"No dtype rule matches layer {layer_name}; skipping")
            continue
        dtype = getattr(torch, dtype_str)
        device = torch.device("cuda")

weight = checkpoint.pop(f"{layer_name}.weight").to(device)
scale_tensor = get_scale(amax, dtype)

input_amax = amax_values.get(f"{layer_name}.input_quantizer", None)
if input_amax is not None:
input_scale = get_scale(input_amax, dtype)
quantized_dict[f"{layer_name}.input_scale"] = input_scale.clone()

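        # Each quantization format provides a layout helper that packs the weight
        # and returns any extra parameters (e.g. per-block scales for NVFP4).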
tensor_layout = QUANT_FORMAT_MIXINS[dtype_str]["layout_type"]
quantized_weight, layout_params = tensor_layout.quantize(
weight,
scale=scale_tensor,
dtype=dtype
)
quantized_dict[f"{layer_name}.weight_scale"] = scale_tensor.clone()
quantized_dict[f"{layer_name}.weight"] = quantized_weight.clone()

if "block_scale" in layout_params:
quantized_dict[f"{layer_name}.weight_block_scale"] = layout_params["block_scale"].clone()

layer_metadata[layer_name] = {
"format": dtype_str,
"params": {}
}

logging.info(f"Quantized {len(layer_metadata)} layers")

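    # Carry over every remaining (non-quantized) tensor from the original checkpoint.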
quantized_dict = quantized_dict | checkpoint

metadata_dict = {
"_quantization_metadata": json.dumps({
"format_version": "1.0",
"layers": layer_metadata
})
}
return quantized_dict, metadata_dict


def main():
parser = argparse.ArgumentParser(
description="Merge calibration artifacts with checkpoint to create quantized model",
formatter_class=argparse.RawDescriptionHelpFormatter,
)

parser.add_argument(
"--artefact",
required=True,
help="Path to calibration artefact JSON file (amax values)"
)
parser.add_argument(
"--checkpoint",
required=True,
help="Path to original checkpoint to quantize"
)
parser.add_argument(
"--config",
required=True,
help="Path to YAML quantization config file"
)
parser.add_argument(
"--output",
required=True,
help="Output path for quantized checkpoint"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug logging"
)

args = parser.parse_args()

# Configure logging
if args.debug:
logging.basicConfig(
level=logging.DEBUG,
format='[%(levelname)s] %(name)s: %(message)s'
)
else:
logging.basicConfig(
level=logging.INFO,
format='[%(levelname)s] %(message)s'
)

logging.info("[1/5] Loading calibration artefact...")
try:
artefact_data = load_amax_artefact(args.artefact)
amax_values = artefact_data['amax_values']
except Exception as e:
logging.error(f"Failed to load artefact: {e}")
sys.exit(1)

logging.info("[2/5] Loading quantization config...")
try:
config = QuantizationConfig(args.config)
except Exception as e:
logging.error(f"Failed to load config: {e}")
sys.exit(1)

logging.info("[3/5] Loading checkpoint...")
try:
checkpoint = comfy.utils.load_torch_file(args.checkpoint)
logging.info(f"Loaded checkpoint with {len(checkpoint)} keys")
except Exception as e:
logging.error(f"Failed to load checkpoint: {e}")
sys.exit(1)

logging.info("[4/5] Applying quantization...")
try:
quantized_dict, metadata_json = apply_quantization(
checkpoint,
amax_values,
config
)
except Exception as e:
logging.error(f"Failed to apply quantization: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

logging.info("[5/5] Exporting quantized checkpoint...")
try:
save_file(quantized_dict, args.output, metadata=metadata_json)

except Exception as e:
logging.error(f"Failed to export checkpoint: {e}")
import traceback
traceback.print_exc()
sys.exit(1)


if __name__ == "__main__":
main()

23 changes: 23 additions & 0 deletions tools/ptq/configs/flux_fp8.yml
@@ -0,0 +1,23 @@
# FLUX Quantization Config: Transformer Blocks Only
#
# Quantize only double and single transformer blocks,
# leave input/output projections in higher precision.

disable_list: [
# Disable input projections
"*img_in*",
"*txt_in*",
"*time_in*",
"*vector_in*",
"*guidance_in*",

# Disable output layers
"*final_layer*",

# Disable positional embeddings
"*pe_embedder*",
]

per_layer_dtype: {
"*": "float8_e4m3fn",
}
27 changes: 27 additions & 0 deletions tools/ptq/configs/flux_nvfp4.yml
@@ -0,0 +1,27 @@
# FLUX Quantization Config: Transformer Blocks Only
#
# Quantize only double and single transformer blocks,
# leave input/output projections in higher precision.

disable_list: [
# Disable input projections
"*img_in*",
"*txt_in*",
"*time_in*",
"*vector_in*",
"*guidance_in*",

# Disable output layers
"*final_layer*",

# Disable positional embeddings
"*pe_embedder*",

"*modulation*",
"*txt_mod*",
"*img_mod*",
]

per_layer_dtype: {
"*": "float4_e2m1fn_x2",
}
30 changes: 30 additions & 0 deletions tools/ptq/example.yml
@@ -0,0 +1,30 @@
# Quantization Configuration for Checkpoint Merger
#
# This file defines which layers to quantize and what precision to use.
# Patterns use glob-style syntax where * matches any characters.

# Glob-style patterns for layers where quantization is DISABLED
# If a layer matches any pattern here, it will NOT be quantized
disable_list: [
# Example: disable input/output projection layers
# "*img_in*",
# "*txt_in*",
# "*final_layer*",

# Example: disable specific block types
# "*norm*",
# "*time_in*",
]

# Per-layer dtype configuration
# Maps layer name patterns to quantization formats
# Layers are matched in order - first match wins
per_layer_dtype: {
# Default: quantize all layers to FP8 E4M3
"*": "fp8_e4m3fn",

# Example: use different precision for specific layers
# "*attn*": "fp8_e4m3fn", # Attention layers
# "*mlp*": "fp8_e4m3fn", # MLP layers
# "*qkv*": "fp8_e4m3fn", # Q/K/V projections
}
36 changes: 36 additions & 0 deletions tools/ptq/models/__init__.py
@@ -0,0 +1,36 @@
from typing import Dict, Type
from .base import ModelRecipe


_RECIPE_REGISTRY: Dict[str, Type[ModelRecipe]] = {}


def register_recipe(recipe_cls: Type[ModelRecipe]):
recipe_name = recipe_cls.name()
if recipe_name in _RECIPE_REGISTRY:
raise ValueError(f"Recipe '{recipe_name}' is already registered")

_RECIPE_REGISTRY[recipe_name] = recipe_cls
return recipe_cls
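
# Recipes register themselves via the decorator above. An illustrative sketch
# (class and method bodies here are hypothetical; see the real recipes in this package):
#
#   @register_recipe
#   class FluxRecipe(ModelRecipe):
#       @classmethod
#       def name(cls) -> str:
#           return "flux"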


def get_recipe_class(name: str) -> Type[ModelRecipe]:
if name not in _RECIPE_REGISTRY:
available = ", ".join(sorted(_RECIPE_REGISTRY.keys()))
raise ValueError(
f"Unknown model type '{name}'. "
f"Available recipes: {available}"
)
return _RECIPE_REGISTRY[name]


def list_recipes():
return sorted(_RECIPE_REGISTRY.keys())


# Import recipe modules to trigger registration
from . import flux # noqa: F401, E402
from . import qwen # noqa: F401, E402
from . import ltx_video # noqa: F401, E402
from . import wan # noqa: F401, E402
