-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdiff_algs.py
More file actions
132 lines (96 loc) · 4.02 KB
/
diff_algs.py
File metadata and controls
132 lines (96 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
import scipy.ndimage
from PIL import Image, ImageChops
from skimage.metrics import structural_similarity

if TYPE_CHECKING:
    from .patch_editor import EditorCtx
@dataclass
class DiffAlgorithm:
    """
    Base class for difference calculation algorithms.

    Subclasses implement :meth:`calculate` to return a heatmap of differences
    between two images, and override ``half`` / ``full`` to tune how
    :meth:`get_threshold` maps a tolerance onto that heatmap's value scale.
    """

    ctx: "EditorCtx"
    # Threshold returned by get_threshold at tolerance 1.0.
    half: int = 0
    # Threshold returned by get_threshold at tolerance 2.0.
    full: int = 255

    def precalculate_blend(self, img_orig: Image.Image, img_edit: Image.Image) -> Image.Image:
        """
        Composite ``img_edit`` over ``img_orig`` using the edit's alpha.

        Where the edit is transparent the original pixel shows through, so
        untouched areas of a partially transparent edit do not register as
        differences. Edits in ``LA``/``PA`` mode, or ``P``/``L``/``RGB`` edits
        carrying a ``transparency`` key, are converted to ``RGBA`` first so
        they are blended the same way as ``RGBA`` edits.

        Args:
            img_orig: The original image.
            img_edit: The edited image. When it carries alpha it must be the
                same size as ``img_orig`` (``Image.composite`` requires it).

        Returns:
            ``img_edit`` unchanged when it has no alpha information. Otherwise
            the blended image, in ``RGB`` mode, or in ``RGBA`` mode carrying
            the original's alpha channel when ``img_orig`` is ``RGBA``.
        """
        has_alpha_band = img_edit.mode in ('LA', 'PA')
        has_colour_key = img_edit.mode in ('P', 'L', 'RGB') and 'transparency' in img_edit.info
        if has_alpha_band or has_colour_key:
            # Promote to RGBA so hidden pixels are blended instead of being
            # compared with whatever colour they happen to store.
            img_edit = img_edit.convert('RGBA')
        if img_edit.mode != 'RGBA':
            return img_edit

        img_orig_rgb = img_orig.convert('RGB')
        r, g, b, a = img_edit.split()
        img_edit_rgb = Image.merge('RGB', (r, g, b))
        # Image.composite takes its first image where the mask is 255, so the
        # inverted alpha selects the original wherever the edit is transparent.
        mask = a.point(lambda x: 255 - x)
        blended_rgb = Image.composite(img_orig_rgb, img_edit_rgb, mask)
        if img_orig.mode != 'RGBA':
            return blended_rgb
        # Carry the original's transparency over to the blended result.
        blended_rgba = blended_rgb.convert('RGBA')
        blended_rgba.putalpha(img_orig.split()[-1])
        return blended_rgba

    def calculate(self, img_orig: Image.Image, img_edit: Image.Image) -> np.ndarray:
        """Return a heatmap of differences; must be implemented by subclasses."""
        raise NotImplementedError

    def get_threshold(self, tolerance: float) -> int:
        """
        Map a tolerance value onto this algorithm's heatmap scale.

        The mapping is piecewise linear: 0.0 -> 0, 1.0 -> ``half`` and
        2.0 -> ``full``. Tolerances above 2.0 keep extrapolating along the
        upper segment. The result is truncated with ``int()``.
        """
        if tolerance <= 1.0:
            return int(self.half * tolerance)
        else:
            return int(self.half + (self.full - self.half) * (tolerance - 1.0))
@dataclass
class HeatmapDifference(DiffAlgorithm):
    """
    Heatmap built from the per-pixel absolute difference of two images.

    Uses ``ImageChops.difference`` on the RGB versions of both images,
    collapses the result to one greyscale channel, and optionally averages it
    over a square neighbourhood of ``ctx.region_size`` pixels. This is the
    simplest and fastest of the algorithms.
    """

    half: int = 10
    full: int = 40

    def calculate(self, img_orig: Image.Image, img_edit: Image.Image) -> np.ndarray:
        """
        Return a float32 greyscale difference map for the two images.

        The edit is first alpha-blended over the original. When
        ``ctx.region_size`` is greater than 1, the map is box-filtered with
        that window size using reflected borders.
        """
        blended = self.precalculate_blend(img_orig, img_edit)
        orig_rgb = img_orig.convert("RGB")
        edit_rgb = blended.convert("RGB")
        grey_diff = ImageChops.difference(orig_rgb, edit_rgb).convert('L')
        heat = np.array(grey_diff, dtype=np.float32)
        window = self.ctx.region_size
        if not window > 1:
            return heat
        return scipy.ndimage.uniform_filter(heat, size=window, mode='reflect')
@dataclass
class HeatmapL2Norm(DiffAlgorithm):
    """
    Heatmap of the L2 norm (Euclidean distance) between RGB pixel values.

    More sensitive to colour variations than the simple greyscale difference.
    With ``ctx.region_size > 1`` the squared distances are averaged over a
    square neighbourhood before the square root is taken. The result is then
    a local root-mean-square colour distance.
    """

    half: int = 10
    full: int = 100

    def calculate(self, img_orig: Image.Image, img_edit: Image.Image) -> np.ndarray:
        """
        Return a float32 map of per-pixel (or locally RMS) RGB distances.

        The edit is first alpha-blended over the original via
        ``precalculate_blend``.
        """
        img_edit = self.precalculate_blend(img_orig, img_edit)
        a1 = np.array(img_orig.convert('RGB'), dtype=np.float32)
        a2 = np.array(img_edit.convert('RGB'), dtype=np.float32)
        # Squared Euclidean distance across the three colour channels.
        diff_sq = np.sum((a1 - a2) ** 2, axis=2)
        if self.ctx.region_size > 1:
            mse = scipy.ndimage.uniform_filter(diff_sq, size=self.ctx.region_size, mode='reflect')
            # Floating-point round-off in the box filter may leave tiny
            # negative values; clamp them so np.sqrt never produces NaN.
            np.maximum(mse, 0.0, out=mse)
        else:
            mse = diff_sq
        return np.sqrt(mse)
@dataclass
class HeatmapSSIM(DiffAlgorithm):
    """
    Heatmap derived from the Structural Similarity Index (SSIM).

    SSIM is a perceptual metric that compares luminance, contrast and
    structure over a sliding window. The heatmap is ``(1 - SSIM) * 255``,
    averaged over the RGB channels, so identical regions map to 0.
    """

    half: int = 40
    full: int = 100

    def calculate(self, img_orig: Image.Image, img_edit: Image.Image) -> np.ndarray:
        """
        Return a 2-D dissimilarity map computed from the full SSIM image.

        The SSIM window is derived from ``ctx.region_size``:
        - it is rounded up to an odd value;
        - it is raised to at least 3;
        - it is shrunk to the largest odd size that fits inside the image,
          because ``structural_similarity`` rejects windows larger than the
          image extent.

        Images smaller than 3 px on a side are not special-cased and are
        still rejected by ``structural_similarity``.
        """
        img_edit = self.precalculate_blend(img_orig, img_edit)
        a1 = np.array(img_orig.convert('RGB'))
        a2 = np.array(img_edit.convert('RGB'))
        win_size = self.ctx.region_size if self.ctx.region_size % 2 == 1 else self.ctx.region_size + 1
        if win_size < 3:
            win_size = 3
        # Largest odd window that still fits inside the (H, W) image.
        largest_fit = min(a1.shape[0], a1.shape[1])
        if largest_fit % 2 == 0:
            largest_fit -= 1
        if 3 <= largest_fit < win_size:
            win_size = largest_fit
        kwargs = {'channel_axis': 2, 'data_range': 255, 'win_size': win_size}
        _, diff = structural_similarity(a1, a2, full=True, **kwargs)
        diff_map = (1 - diff) * 255
        if diff_map.ndim == 3:
            # Per-channel SSIM maps are collapsed into a single heatmap.
            diff_map = np.mean(diff_map, axis=2)
        return diff_map