Thanks for your great work!

However, when I tried to evaluate the PSNR, SSIM, and LPIPS metrics on the FLUX.1 models, I found a large gap between the original paper and my reproduced result on PSNR, while SSIM and LPIPS are generally consistent with the paper. I have tried different packages (OpenCV and torchmetrics), but the results remain the same.

Here's my code. I use the default seed and the DrawBench prompts to sample 200 images.
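"""Evaluate PSNR/SSIM/LPIPS (and optionally CLIPScore and ImageReward) for a
directory of generated images against a directory of reference images."""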
import os
import argparse
from glob import glob
from typing import List, Optional, Tuple
import torch
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
# ==============================================================================
# Patch for ImageReward compatibility with newer transformers versions
# Map transformers.pytorch_utils.apply_chunking_to_forward
# to transformers.modeling_utils.apply_chunking_to_forward
# ==============================================================================
try:
    from transformers import modeling_utils, pytorch_utils

    # List of functions moved from modeling_utils to pytorch_utils
    moved_functions = [
        "apply_chunking_to_forward",
        "find_pruneable_heads_and_indices",
        "prune_linear_layer",  # just in case this is needed too
    ]
    for func_name in moved_functions:
        if not hasattr(modeling_utils, func_name) and hasattr(pytorch_utils, func_name):
            print(f"Patching transformers.modeling_utils.{func_name} from pytorch_utils...")
            setattr(modeling_utils, func_name, getattr(pytorch_utils, func_name))
except ImportError:
    pass
# ==============================================================================
def load_image_tensor(path: str, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
    img = Image.open(path).convert("RGB")
    if size is not None:
        img = img.resize(size, Image.BICUBIC)
    return transforms.ToTensor()(img).unsqueeze(0)

def match_ref_path(ref_dir: str, filename: str) -> Optional[str]:
    ref_path = os.path.join(ref_dir, filename)
    if os.path.exists(ref_path):
        return ref_path
    # Fall back to the same basename with a different extension
    base = os.path.splitext(filename)[0]
    for ext in (".png", ".jpg", ".jpeg", ".bmp"):
        candidate = os.path.join(ref_dir, base + ext)
        if os.path.exists(candidate):
            return candidate
    return None

def collect_images(gen_dir: str) -> List[str]:
    exts = ("*.jpg", "*.jpeg", "*.png", "*.bmp")
    # Glob all matching files
    all_files = [f for ext in exts for f in glob(os.path.join(gen_dir, ext))]

    # Sort by the number before the first underscore.
    # Expected filename format: "123_img.jpg" or "123.jpg".
    # Keys are tuples so that numeric and non-numeric names are never compared
    # against each other (mixing int and str keys would raise TypeError).
    def sort_key(filepath):
        base = os.path.splitext(os.path.basename(filepath))[0]
        try:
            return (0, int(base.split("_")[0]), "")
        except ValueError:
            return (1, 0, base)

    return sorted(all_files, key=sort_key)

def collect_ref_images(gen_files: List[str], ref_dir: str) -> List[str]:
    ref_files = []
    for gen_path in gen_files:
        filename = os.path.basename(gen_path)
        ref_files.append(match_ref_path(ref_dir, filename))  # None if no match
    return ref_files

def compute_psnr_ssim_lpips(gen_files: List[str], ref_dir: str, device: torch.device):
    from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
    import lpips

    lpips_metric = lpips.LPIPS(net="alex").to(device)
    total_psnr = total_ssim = total_lpips = 0.0
    count = 0
    for gen_path in tqdm(gen_files, desc="PSNR/SSIM/LPIPS"):
        filename = os.path.basename(gen_path)
        ref_path = match_ref_path(ref_dir, filename)
        if ref_path is None:
            print(f"Warning: reference for {filename} not found, skip")
            continue
        try:
            img_gen = load_image_tensor(gen_path).to(device)
            img_ref = load_image_tensor(ref_path).to(device)
            if img_gen.shape != img_ref.shape:
                resize = transforms.Resize((img_ref.shape[2], img_ref.shape[3]))
                img_gen = resize(img_gen)
            # OpenCV PSNR on uint8 RGB
            img_gen_np = (img_gen.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            img_ref_np = (img_ref.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            psnr_val = cv2.PSNR(img_gen_np, img_ref_np)
            # LPIPS expects inputs in [-1, 1]
            with torch.no_grad():
                lpips_val = lpips_metric(img_gen * 2 - 1, img_ref * 2 - 1).item()
            ssim_val = ssim_fn(img_gen, img_ref, data_range=1.0).item()
            total_psnr += psnr_val
            total_ssim += ssim_val
            total_lpips += lpips_val
            count += 1
        except Exception as exc:
            print(f"Error processing {filename}: {exc}")
    if count == 0:
        return None
    return {
        "count": count,
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
    }

def compute_clip_score(image_files: List[Optional[str]], prompts: List[str], device: torch.device, model_id: str, label: str = "CLIPScore"):
    from transformers import CLIPModel, CLIPProcessor

    if len(image_files) != len(prompts):
        print(f"Warning: number of prompts ({len(prompts)}) != number of images ({len(image_files)}), will pair by min length")
    # Filter out None paths (missing images)
    valid_pairs = [(img_path, prompt) for img_path, prompt in zip(image_files, prompts) if img_path is not None]
    if not valid_pairs:
        return None
    processor = CLIPProcessor.from_pretrained(model_id)
    model = CLIPModel.from_pretrained(model_id).to(device)
    scores = []
    for img_path, prompt in tqdm(valid_pairs, desc=label):
        try:
            image = Image.open(img_path).convert("RGB")
            text_inputs = processor.tokenizer(
                text=[prompt],
                padding="max_length",
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(device)
            image_inputs = processor.image_processor(
                images=[image],
                return_tensors="pt",
            ).to(device)
            inputs = {**text_inputs, **image_inputs}
            with torch.no_grad():
                outputs = model(**inputs)
            # Cosine similarity of the normalized embeddings, scaled by 100
            img_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
            txt_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
            score = (img_emb * txt_emb).sum(dim=-1) * 100.0
            scores.append(score.item())
        except Exception as e:
            print(f"Error calculating CLIP score for {img_path}: {e}")
    if not scores:
        return None
    return sum(scores) / len(scores)

def compute_image_reward(gen_files: List[str], prompts: List[str], device: torch.device):
    try:
        import ImageReward as ir
    except Exception as exc:  # noqa: BLE001
        print(f"ImageReward not available: {exc}")
        return None
    if len(gen_files) != len(prompts):
        print("Warning: number of prompts != number of images, will pair by min length")
    pairs = list(zip(gen_files, prompts))
    scorer = ir.load("ImageReward-v1.0", device=str(device))
    scores = []
    for img_path, prompt in tqdm(pairs, desc="ImageReward"):
        try:
            scores.append(scorer.score(prompt, img_path))
        except Exception as exc:  # noqa: BLE001
            print(f"ImageReward failed on {os.path.basename(img_path)}: {exc}")
    if not scores:
        return None
    return sum(scores) / len(scores)

def read_prompts(prompt_file: str) -> List[str]:
    with open(prompt_file, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def main():
    parser = argparse.ArgumentParser(description="Evaluate PSNR/SSIM/LPIPS and optionally CLIPScore/ImageReward.")
    parser.add_argument("--gen_dir", required=True, help="Directory with generated images")
    parser.add_argument("--ref_dir", required=True, help="Directory with reference images")
    parser.add_argument("--prompt_file", help="Text file with one prompt per generated image (needed for CLIP/ImageReward)")
    parser.add_argument("--clip_model", default="openai/clip-vit-large-patch14", help="CLIP model id")
    parser.add_argument("--no_clip", action="store_true", help="Skip CLIPScore")
    parser.add_argument("--no_imagereward", action="store_true", help="Skip ImageReward")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    gen_files = collect_images(args.gen_dir)
    if not gen_files:
        print(f"No images found in {args.gen_dir}")
        return

    psnr_ssim_lpips = compute_psnr_ssim_lpips(gen_files, args.ref_dir, device)
    prompts = read_prompts(args.prompt_file) if args.prompt_file else []

    gen_clip_score = None
    ref_clip_score = None
    image_reward = None
    if prompts and not args.no_clip:
        gen_clip_score = compute_clip_score(gen_files, prompts, device, args.clip_model, label="CLIPScore(Gen)")
        ref_files = collect_ref_images(gen_files, args.ref_dir)
        ref_clip_score = compute_clip_score(ref_files, prompts, device, args.clip_model, label="CLIPScore(Ref)")
    elif not args.no_clip:
        print("CLIPScore skipped: prompt_file is required")
    if prompts and not args.no_imagereward:
        image_reward = compute_image_reward(gen_files, prompts, device)
    elif not args.no_imagereward:
        print("ImageReward skipped: prompt_file is required")

    print("\n" + "=" * 40)
    if psnr_ssim_lpips:
        print(f"Pairs: {psnr_ssim_lpips['count']}")
        print(f"PSNR: {psnr_ssim_lpips['psnr']:.4f}")
        print(f"SSIM: {psnr_ssim_lpips['ssim']:.4f}")
        print(f"LPIPS: {psnr_ssim_lpips['lpips']:.4f}")
    else:
        print("PSNR/SSIM/LPIPS skipped (no pairs)")
    if gen_clip_score is not None:
        print(f"CLIPScore (Gen): {gen_clip_score:.4f}")
    if ref_clip_score is not None:
        print(f"CLIPScore (Ref): {ref_clip_score:.4f}")
    if image_reward is not None:
        print(f"ImageReward: {image_reward:.4f}")
    print("=" * 40)


if __name__ == "__main__":
    main()
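For completeness, here is the kind of cross-check I ran to rule out a metric-implementation difference (the file paths below are placeholders, not my actual files). All three computations implement 10 * log10(255^2 / MSE), so they should agree up to floating-point error:

# PSNR cross-check on a single generated/reference pair (paths are placeholders)
import cv2
import numpy as np
import torch
from torchmetrics.functional import peak_signal_noise_ratio

gen = cv2.imread("gen/0_example.png")  # uint8 HxWx3 (BGR; channel order does not change PSNR)
ref = cv2.imread("ref/0_example.png")

mse = np.mean((gen.astype(np.float64) - ref.astype(np.float64)) ** 2)
manual = 10 * np.log10(255.0 ** 2 / mse)
opencv = cv2.PSNR(gen, ref)
tm = peak_signal_noise_ratio(
    torch.from_numpy(gen).float(),
    torch.from_numpy(ref).float(),
    data_range=255.0,
).item()
print(f"manual={manual:.4f} opencv={opencv:.4f} torchmetrics={tm:.4f}")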
My result on TaylorSeer (N=3, O=2) is:
PSNR: 19.5747
SSIM: 0.7456
LPIPS: 0.2233
while the result in the paper is:
PSNR: 30.762
SSIM: 0.7818
LPIPS: 0.2300
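To put the gap in perspective, converting the PSNR difference back to an MSE ratio:

# PSNR = 10 * log10(255^2 / MSE), so a PSNR gap maps to a multiplicative MSE gap
delta_db = 30.762 - 19.5747   # paper minus reproduced
print(10 ** (delta_db / 10))  # ~13.1, i.e. my reproduction has ~13x larger MSE

A ~13x difference in MSE seems too large to be explained by the PSNR implementation alone.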