-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpatch_detector.py
More file actions
69 lines (53 loc) · 2.22 KB
/
patch_detector.py
File metadata and controls
69 lines (53 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from PIL import Image
from dataclasses import dataclass
from functools import cached_property
from typing import Tuple
import numpy as np
from .editor_ctx import EditorCtx
from .image_warper import ImageWarper
from .diff_algs import DiffAlgorithm, HeatmapDifference, HeatmapL2Norm, HeatmapSSIM
@dataclass
class PatchDetector:
    """
    Detects differences between two images and generates a heatmap of the changes.

    The original and edited images are read from ``ctx`` and both converted to
    RGBA. When ``ctx.allow_rescale`` is true, the edited image is first aligned
    onto the original with ``ImageWarper``. The algorithm named by
    ``ctx.detection_alg`` then produces a difference array. That array is
    optionally subsampled by ``ctx.stride`` and returned as an 8-bit grayscale
    image.

    Every derived value (``images``, ``detection_alg``, ``heatmap``) is computed
    lazily on first access and cached on the instance.
    """

    # Editor context; this class reads original_img, edited_img, allow_rescale,
    # detection_alg and stride from it.
    ctx: EditorCtx

    @cached_property
    def images(self) -> Tuple[Image.Image, Image.Image]:
        """
        Return the ``(original, edited)`` image pair, both converted to RGBA.

        If ``ctx.allow_rescale`` is set, the edited image is warped onto the
        original before the size check.

        Raises:
            ValueError: if the two images still differ in size.
        """
        img1 = self.ctx.original_img.convert('RGBA')
        img2 = self.ctx.edited_img.convert('RGBA')
        if self.ctx.allow_rescale:
            # Align the edited image to the original so pixels are comparable.
            img2 = ImageWarper().warp(img1, img2)
        if img1.size != img2.size:
            raise ValueError(
                f"Images must be same size: {img1.size} != {img2.size}. "
                "Enable allow_rescale to fix automatically."
            )
        return img1, img2

    @cached_property
    def detection_alg(self) -> DiffAlgorithm:
        """
        Instantiate the difference algorithm selected by ``ctx.detection_alg``.

        Recognised names are ``"difference"``, ``"l2_norm"`` and ``"ssim"``.
        The chosen class is constructed with ``ctx=self.ctx``.

        Raises:
            ValueError: if the name is not one of the recognised algorithms.
        """
        alg_map = {
            "difference": HeatmapDifference,
            "l2_norm": HeatmapL2Norm,
            "ssim": HeatmapSSIM,
        }
        alg_class = alg_map.get(self.ctx.detection_alg)
        if alg_class is None:
            raise ValueError(f"Unknown detection algorithm: {self.ctx.detection_alg}")
        return alg_class(ctx=self.ctx)

    @cached_property
    def heatmap(self) -> Image.Image:
        """
        Return the change heatmap as a single-channel 8-bit (``'L'``) image.

        Processing steps:

        * The array from ``detection_alg.calculate`` is clipped to ``[0, 255]``
          and cast to ``uint8``. The cast truncates rather than rounds.
        * When ``ctx.stride > 1``, only every ``stride``-th pixel along each
          axis is kept. Each kept value is replicated over the
          ``stride x stride`` block it was sampled from.
        * The output always has the same height and width as the computed
          array.
        """
        img_orig, img_edit = self.images
        heatmap_arr = np.asarray(self.detection_alg.calculate(img_orig, img_edit))

        stride = self.ctx.stride
        if stride > 1:
            h, w = heatmap_arr.shape
            small = heatmap_arr[::stride, ::stride]
            # Replicate each sample over its own source block, then crop back
            # to (h, w). A NEAREST resize of `small` to (w, h) would misplace
            # samples whenever h or w is not a multiple of stride. In that case
            # the resize scale ceil(h/stride)/h is not exactly 1/stride.
            heatmap_arr = np.repeat(
                np.repeat(small, stride, axis=0), stride, axis=1
            )[:h, :w]

        # A 2-D uint8 array is loaded by Pillow as mode 'L'. The explicit
        # `mode=` argument to fromarray is deprecated in recent Pillow releases.
        return Image.fromarray(np.clip(heatmap_arr, 0, 255).astype(np.uint8))