Skip to content

question about PSNR metric on FLUX.1 #46

@kjd123456

Description

@kjd123456

Thanks for your great work!
However, when I evaluated PSNR, SSIM, and LPIPS on FLUX.1 models, I found a large gap in PSNR between the original paper and my reproduced results, while SSIM and LPIPS are generally consistent with the paper. I have tried different packages (OpenCV and torchmetrics), but the results remain the same.

Here's my code. I use the default seed and the DrawBench prompts to sample 200 images.

import os
import argparse
from glob import glob
from typing import List, Optional, Tuple

import torch
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

# ------------------------------------------------------------------------------
# Compatibility shim so ImageReward keeps working with newer transformers.
# Helpers that are only reachable through transformers.pytorch_utils are
# re-exposed under the same name on transformers.modeling_utils.
# ------------------------------------------------------------------------------
try:
    from transformers import modeling_utils, pytorch_utils

    # Names that moved from modeling_utils to pytorch_utils.
    moved_functions = [
        "apply_chunking_to_forward",
        "find_pruneable_heads_and_indices",
        "prune_linear_layer",  # included as a precaution
    ]

    for func_name in moved_functions:
        # Nothing to do when the old location still provides the helper.
        if hasattr(modeling_utils, func_name):
            continue
        # Nothing to copy when the new location does not provide it either.
        if not hasattr(pytorch_utils, func_name):
            continue
        print(f"Patching transformers.modeling_utils.{func_name} from pytorch_utils...")
        setattr(modeling_utils, func_name, getattr(pytorch_utils, func_name))

except ImportError:
    # transformers is not importable; leave everything untouched.
    pass
# ------------------------------------------------------------------------------

def load_image_tensor(path: str, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
    """Open an image as RGB and return it as a tensor with a leading batch axis.

    Args:
        path: Location of the image file.
        size: Optional target size passed to ``PIL.Image.resize`` with bicubic
            resampling; when ``None`` the image is not resized.

    Returns:
        The ``transforms.ToTensor`` conversion of the image with one extra
        dimension inserted at position 0.
    """
    rgb_image = Image.open(path).convert("RGB")
    if size is not None:
        rgb_image = rgb_image.resize(size, Image.BICUBIC)
    to_tensor = transforms.ToTensor()
    return to_tensor(rgb_image).unsqueeze(0)


def match_ref_path(ref_dir: str, filename: str) -> Optional[str]:
    """Locate the reference image that corresponds to ``filename``.

    The exact file name is tried first. If it does not exist in ``ref_dir``,
    the same base name is tried with the extensions ``.png``, ``.jpg``,
    ``.jpeg`` and ``.bmp``, in that order.

    Args:
        ref_dir: Directory that holds the reference images.
        filename: File name (not a full path) of the generated image.

    Returns:
        The path of the first existing candidate, or ``None`` if none exists.
    """
    exact = os.path.join(ref_dir, filename)
    if os.path.exists(exact):
        return exact

    stem, _ = os.path.splitext(filename)
    candidates = (
        os.path.join(ref_dir, stem + suffix)
        for suffix in (".png", ".jpg", ".jpeg", ".bmp")
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def collect_images(gen_dir: str) -> List[str]:
    """Return the image files directly inside ``gen_dir`` in numeric order.

    Files matching ``*.jpg``, ``*.jpeg``, ``*.png`` or ``*.bmp`` are collected.
    They are ordered by the integer that precedes the first underscore of the
    file name with its extension removed, so both ``"123_img.jpg"`` and
    ``"123.jpg"`` sort as 123. Files without such a numeric prefix are placed
    after all numbered files and ordered by file name.

    Args:
        gen_dir: Directory to scan (not recursive).

    Returns:
        The sorted list of matching paths; empty when nothing matches.
    """
    patterns = ("*.jpg", "*.jpeg", "*.png", "*.bmp")
    all_files = [
        path
        for pattern in patterns
        for path in glob(os.path.join(gen_dir, pattern))
    ]

    def sort_key(filepath):
        filename = os.path.basename(filepath)
        # Drop the extension first so "123.jpg" yields "123" rather than
        # "123.jpg", which int() would reject.
        stem = os.path.splitext(filename)[0]
        try:
            number = int(stem.split("_")[0])
        except ValueError:
            # Keys are always tuples of the same shape, so numbered and
            # un-numbered files can be compared without a TypeError.
            return (1, 0, filename)
        return (0, number, "")

    return sorted(all_files, key=sort_key)


def collect_ref_images(gen_files: List[str], ref_dir: str) -> List[str]:
    """Map every generated image to its reference image path.

    Args:
        gen_files: Paths of the generated images.
        ref_dir: Directory that holds the reference images.

    Returns:
        A list aligned with ``gen_files``; each entry is the path returned by
        ``match_ref_path`` for that file's base name, or ``None`` when no
        reference was found.
    """
    return [
        match_ref_path(ref_dir, os.path.basename(generated))
        for generated in gen_files
    ]


def _to_uint8_image(batch: torch.Tensor) -> np.ndarray:
    """Convert a single-image batch tensor into a channels-last uint8 array.

    The tensor is expected to hold values scaled to [0, 1] (it is multiplied
    by 255 here). Values are rounded before the cast so floating-point error
    in the x / 255 * 255 round trip cannot truncate a pixel to the next lower
    level.
    """
    channels_last = batch.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return np.rint(channels_last * 255).clip(0, 255).astype(np.uint8)


def compute_psnr_ssim_lpips(gen_files: List[str], ref_dir: str, device: torch.device):
    """Average PSNR, SSIM and LPIPS between generated images and their references.

    For every generated image the reference with the same base name is looked
    up in ``ref_dir``. Images without a reference, and pairs that raise during
    processing, are reported on stdout and left out of the averages. When the
    two tensors differ in shape, the generated image is resized to the
    reference's height and width.

    PSNR is computed with ``cv2.PSNR`` on uint8 arrays, SSIM with torchmetrics
    (``data_range=1.0``) and LPIPS with the AlexNet backbone on inputs rescaled
    with ``x * 2 - 1``.

    Args:
        gen_files: Paths of the generated images.
        ref_dir: Directory that holds the reference images.
        device: Device on which tensors and the LPIPS network are placed.

    Returns:
        A dict with keys ``"count"``, ``"psnr"``, ``"ssim"`` and ``"lpips"``
        (means over the evaluated pairs), or ``None`` if no pair was evaluated.
    """
    from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
    import lpips

    lpips_metric = lpips.LPIPS(net="alex").to(device)

    total_psnr = total_ssim = total_lpips = 0.0
    count = 0

    for gen_path in tqdm(gen_files, desc="PSNR/SSIM/LPIPS"):
        filename = os.path.basename(gen_path)
        ref_path = match_ref_path(ref_dir, filename)
        if ref_path is None:
            print(f"Warning: reference for {filename} not found, skip")
            continue

        try:
            img_gen = load_image_tensor(gen_path).to(device)
            img_ref = load_image_tensor(ref_path).to(device)
            if img_gen.shape != img_ref.shape:
                resize = transforms.Resize((img_ref.shape[2], img_ref.shape[3]))
                img_gen = resize(img_gen)

            # OpenCV PSNR on uint8 RGB.
            psnr_val = cv2.PSNR(_to_uint8_image(img_gen), _to_uint8_image(img_ref))

            # Pure inference: no autograd graph is needed for either metric.
            with torch.no_grad():
                lpips_val = lpips_metric(img_gen * 2 - 1, img_ref * 2 - 1).item()
                ssim_val = ssim_fn(img_gen, img_ref, data_range=1.0).item()
        except Exception as exc:
            # Best effort: one unreadable or incompatible pair must not abort
            # the whole evaluation run.
            print(f"Error processing {filename}: {exc}")
            continue

        total_psnr += psnr_val
        total_ssim += ssim_val
        total_lpips += lpips_val
        count += 1

    if count == 0:
        return None

    return {
        "count": count,
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
    }


def compute_clip_score(image_files: List[Optional[str]], prompts: List[str], device: torch.device, model_id: str, label: str = "CLIPScore"):
    """Mean CLIP image-text similarity (cosine similarity times 100).

    Images and prompts are paired positionally; entries whose image path is
    ``None`` are dropped. Pairs that raise while being scored are reported on
    stdout and excluded from the mean.

    Args:
        image_files: Image paths, with ``None`` marking a missing image.
        prompts: Prompts aligned with ``image_files``.
        device: Device on which the CLIP model and its inputs are placed.
        model_id: Identifier handed to ``from_pretrained`` for both the
            processor and the model.
        label: Progress-bar description.

    Returns:
        The mean score over all successfully scored pairs, or ``None`` when
        there is nothing to score or every pair failed.
    """
    from transformers import CLIPModel, CLIPProcessor

    if len(image_files) != len(prompts):
        print(f"Warning: number of prompts ({len(prompts)}) != number of images ({len(image_files)}), will pair by min length")

    # Keep only the pairs whose image actually exists.
    usable = [
        (path, text)
        for path, text in zip(image_files, prompts)
        if path is not None
    ]
    if not usable:
        return None

    processor = CLIPProcessor.from_pretrained(model_id)
    model = CLIPModel.from_pretrained(model_id).to(device)
    collected = []

    for img_path, text in tqdm(usable, desc=label):
        try:
            picture = Image.open(img_path).convert("RGB")

            encoded_text = processor.tokenizer(
                text=[text],
                padding="max_length",
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(device)
            encoded_image = processor.image_processor(
                images=[picture],
                return_tensors="pt",
            ).to(device)

            model_inputs = {**encoded_text, **encoded_image}
            with torch.no_grad():
                result = model(**model_inputs)
                # L2-normalise both embeddings so their dot product is a cosine.
                image_unit = result.image_embeds / result.image_embeds.norm(dim=-1, keepdim=True)
                text_unit = result.text_embeds / result.text_embeds.norm(dim=-1, keepdim=True)
                similarity = (image_unit * text_unit).sum(dim=-1) * 100.0
                collected.append(similarity.item())
        except Exception as e:
            print(f"Error calculating CLIP score for {img_path}: {e}")

    if not collected:
        return None
    return sum(collected) / len(collected)


def compute_image_reward(gen_files: List[str], prompts: List[str], device: torch.device):
    """Mean ImageReward score of the generated images against their prompts.

    Images and prompts are paired positionally (the shorter list decides the
    number of pairs). Pairs that fail to score are reported on stdout and
    excluded from the mean.

    Args:
        gen_files: Paths of the generated images.
        prompts: Prompts aligned with ``gen_files``.
        device: Device on which the ImageReward model is loaded.

    Returns:
        The mean score, or ``None`` when the ImageReward package cannot be
        imported or no pair was scored.
    """
    try:
        import ImageReward as ir
    except Exception as exc:  # noqa: BLE001
        print(f"ImageReward not available: {exc}")
        return None

    if len(gen_files) != len(prompts):
        print("Warning: number of prompts != number of images, will pair by min length")
    pairs = list(zip(gen_files, prompts))

    # Honour the caller's device instead of the library default, so the
    # command-line device choice applies to this metric as well.
    scorer = ir.load("ImageReward-v1.0", device=device)
    scores = []
    for img_path, prompt in tqdm(pairs, desc="ImageReward"):
        try:
            scores.append(scorer.score(prompt, img_path))
        except Exception as exc:  # noqa: BLE001
            print(f"ImageReward failed on {os.path.basename(img_path)}: {exc}")
    if not scores:
        return None
    return sum(scores) / len(scores)


def read_prompts(prompt_file: str) -> List[str]:
    """Read prompts from a UTF-8 text file, one per line.

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are skipped.

    Args:
        prompt_file: Path of the text file.

    Returns:
        The non-empty, stripped lines in file order.
    """
    prompts = []
    with open(prompt_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                prompts.append(text)
    return prompts


def main():
    """Command-line entry point: evaluate generated images and print a summary.

    PSNR/SSIM/LPIPS are always computed against ``--ref_dir``. CLIPScore (for
    both generated and reference images) and ImageReward are computed only
    when a prompt file yields at least one prompt and the metric has not been
    disabled with ``--no_clip`` / ``--no_imagereward``.
    """
    cli = argparse.ArgumentParser(description="Evaluate PSNR/SSIM/LPIPS and optionally CLIPScore/ImageReward.")
    cli.add_argument("--gen_dir", required=True, help="Directory with generated images")
    cli.add_argument("--ref_dir", required=True, help="Directory with reference images")
    cli.add_argument("--prompt_file", help="Text file with one prompt per generated image (needed for CLIP/ImageReward)")
    cli.add_argument("--clip_model", default="openai/clip-vit-large-patch14", help="CLIP model id")
    cli.add_argument("--no_clip", action="store_true", help="Skip CLIPScore")
    cli.add_argument("--no_imagereward", action="store_true", help="Skip ImageReward")
    cli.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda/cpu)")
    args = cli.parse_args()

    device = torch.device(args.device)
    gen_files = collect_images(args.gen_dir)
    if not gen_files:
        print(f"No images found in {args.gen_dir}")
        return

    fidelity = compute_psnr_ssim_lpips(gen_files, args.ref_dir, device)

    prompts = []
    if args.prompt_file:
        prompts = read_prompts(args.prompt_file)

    gen_clip_score = None
    ref_clip_score = None
    image_reward = None

    # CLIPScore for the generated images and for their references.
    if not args.no_clip:
        if prompts:
            gen_clip_score = compute_clip_score(gen_files, prompts, device, args.clip_model, label="CLIPScore(Gen)")
            ref_files = collect_ref_images(gen_files, args.ref_dir)
            ref_clip_score = compute_clip_score(ref_files, prompts, device, args.clip_model, label="CLIPScore(Ref)")
        else:
            print("CLIPScore skipped: prompt_file is required")

    # ImageReward for the generated images.
    if not args.no_imagereward:
        if prompts:
            image_reward = compute_image_reward(gen_files, prompts, device)
        else:
            print("ImageReward skipped: prompt_file is required")

    # Summary.
    print("\n" + "=" * 40)
    if not fidelity:
        print("PSNR/SSIM/LPIPS skipped (no pairs)")
    else:
        print(f"Pairs:   {fidelity['count']}")
        print(f"PSNR:    {fidelity['psnr']:.4f}")
        print(f"SSIM:    {fidelity['ssim']:.4f}")
        print(f"LPIPS:   {fidelity['lpips']:.4f}")

    if gen_clip_score is not None:
        print(f"CLIPScore (Gen): {gen_clip_score:.4f}")
    if ref_clip_score is not None:
        print(f"CLIPScore (Ref): {ref_clip_score:.4f}")

    if image_reward is not None:
        print(f"ImageReward: {image_reward:.4f}")
    print("=" * 40)


if __name__ == "__main__":
    main()

My result for TaylorSeer (N=3, O=2) is:
PSNR: 19.5747
SSIM: 0.7456
LPIPS: 0.2233
while the result in the paper is:
PSNR: 30.762
SSIM: 0.7818
LPIPS: 0.2300

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions