-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnodes.py
More file actions
143 lines (128 loc) · 6.32 KB
/
nodes.py
File metadata and controls
143 lines (128 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import logging
from typing import Tuple, List, Dict, Any

import numpy as np
import torch
from PIL import Image

from .patch_editor import PatchEditor  # type: ignore
class ImgPatchNode:
    """
    A node for patching images by detecting differences and merging them.

    This node takes an original and an edited image, detects the changed areas,
    and merges the changes back into the original image. It provides several
    outputs for visualizing the process, including the final merged image,
    the mask used for merging, a heatmap of the differences, and a debug overlay.

    When ``edited_image`` is a batch, the frames are applied sequentially:
    frame ``i`` is merged into the result produced by frame ``i - 1``.
    """

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Any]:
        """Declare the node's input sockets and widget configuration (ComfyUI API)."""
        return {
            "required": {
                "original_image": ("IMAGE",),
                "edited_image": ("IMAGE",),
                "detection_alg": (["difference", "l2_norm", "ssim"],),
                "region_size": ("INT", {"default": 8, "min": 1, "max": 128, "tooltip": "Size of the blocks used for difference detection."}),
                "stride": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 64,
                        "tooltip": "Step size for the sliding window. 1 means full overlap (slowest, most accurate), equal to region_size means no overlap (fastest).",
                    },
                ),
                "min_area_size": ("INT", {"default": 4, "min": 0, "max": 128, "tooltip": "Minimum size of the changed area to keep."}),
                "tolerance": (
                    "FLOAT",
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 5.0,
                        "step": 0.01,
                        "tooltip": "Sensitivity of difference detection. Higher values will detect more changes.",
                    },
                ),
                "blur_radius": (
                    "FLOAT",
                    {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "tooltip": "Radius of the gaussian blur applied to the mask."},
                ),
                "mask_expansion": (
                    "INT",
                    {"default": 0, "min": -128, "max": 128, "tooltip": "Expand or shrink the mask. Positive values expand, negative values shrink."},
                ),
                "edit_max_alpha": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Maximum alpha for the edited image overlay."},
                ),
                "allow_rescale": ("BOOLEAN", {"default": True, "tooltip": "If true, rescales the edited image to match original size."}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("merged_image", "final_mask", "heatmap", "raw_mask", "debug_overlay")
    OUTPUT_TOOLTIPS = (
        "The original image with the patched areas from the edited image.",
        "The final mask used for merging. White indicates changes (taken from edited image), black indicates unchanged areas.",
        "Visualization of differences. Brighter areas indicate larger differences between original and edited images.",
        "Initial binary mask before post-processing. White indicates detected changes, black indicates no change.",
        "Original image with a red overlay highlighting the detected changes.",
    )
    FUNCTION = "process"
    CATEGORY = "ImgPatch"

    def process(
        self,
        original_image: torch.Tensor,
        edited_image: torch.Tensor,
        detection_alg: str,
        region_size: int,
        stride: int,
        min_area_size: int,
        tolerance: float,
        blur_radius: float,
        mask_expansion: int,
        edit_max_alpha: float,
        allow_rescale: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Detect differences between the images and merge them into the original.

        Args:
            original_image: Single-image batch tensor; a batch size > 1 is rejected.
            edited_image: Batch of edited frames; each frame is merged in sequence
                into the running result.
            detection_alg, region_size, stride, min_area_size, tolerance,
            blur_radius, mask_expansion, edit_max_alpha, allow_rescale:
                Forwarded verbatim to the ``PatchEditor`` context (see
                ``INPUT_TYPES`` tooltips for their meaning).

        Returns:
            Tuple of ``(merged_image, final_mask, heatmap, raw_mask,
            debug_overlay)``. ``merged_image`` is the final chained result;
            the other four are concatenated along the batch dimension, one
            entry per edited frame.

        Raises:
            ValueError: If ``original_image`` holds more than one frame, or
                ``edited_image`` is an empty batch.
        """
        log = logging.getLogger(__name__)
        # Lazy %s formatting: no string work unless DEBUG logging is enabled.
        log.debug("original_image shape: %s", tuple(original_image.shape))
        log.debug("edited_image shape: %s", tuple(edited_image.shape))

        if original_image.shape[0] > 1:
            raise ValueError("Only single image is supported for original_image")
        if edited_image.shape[0] == 0:
            # Guard early: torch.cat on the empty output lists below would
            # otherwise fail with a much less helpful error.
            raise ValueError("edited_image must contain at least one image")

        current_image_tensor = original_image
        out_final_masks: List[torch.Tensor] = []
        out_heatmaps: List[torch.Tensor] = []
        out_raw_masks: List[torch.Tensor] = []
        out_overlays: List[torch.Tensor] = []

        for i in range(edited_image.shape[0]):
            pil_1 = self._tensor_to_pil(current_image_tensor[0])
            pil_2 = self._tensor_to_pil(edited_image[i])

            # A fresh editor per frame keeps detection state from leaking
            # between iterations.
            editor = PatchEditor()
            editor.ctx.original_img = pil_1
            editor.ctx.edited_img = pil_2
            editor.ctx.detection_alg = detection_alg
            editor.ctx.region_size = region_size
            editor.ctx.stride = stride
            editor.ctx.min_area_size = min_area_size
            editor.ctx.tolerance = tolerance
            editor.ctx.blur_radius = blur_radius
            editor.ctx.mask_expansion = mask_expansion
            editor.ctx.edit_max_alpha = edit_max_alpha
            editor.ctx.allow_rescale = allow_rescale

            # NOTE(review): accessing patch_merger appears to run the
            # detect/merge pipeline lazily -- confirm against patch_editor.py.
            merger = editor.patch_merger

            # Chain: the merged result becomes the base image for the next frame.
            current_image_tensor = self._pil_to_tensor(merger.merged_image)
            out_final_masks.append(self._mask_to_tensor(merger.final_mask))
            out_heatmaps.append(self._pil_to_tensor(editor.patch_detector.heatmap.convert("RGB")))
            out_raw_masks.append(self._mask_to_tensor(merger.raw_mask))
            out_overlays.append(self._pil_to_tensor(merger.debug_overlay))

        return (
            current_image_tensor,
            torch.cat(out_final_masks, dim=0),
            torch.cat(out_heatmaps, dim=0),
            torch.cat(out_raw_masks, dim=0),
            torch.cat(out_overlays, dim=0),
        )

    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """Convert a [H, W, C] float tensor in [0, 1] to an 8-bit PIL image."""
        return Image.fromarray(np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8))

    def _pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a [1, H, W, C] float tensor in [0, 1]."""
        return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)

    def _mask_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a single-channel (mode "L") PIL mask to a [1, H, W] float tensor in [0, 1]."""
        return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)