-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_cut_blending.py
More file actions
105 lines (82 loc) · 3.2 KB
/
graph_cut_blending.py
File metadata and controls
105 lines (82 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from imageio import imread
import numpy as np
from matplotlib import pyplot as plt
from typing import Tuple
import maxflow
from random import random, randint
import os
# Which example from ./data to process (the script also supports 'boat').
name = 'hut'

# Mask colours (RGB). graph_cut_blend ties pixels painted with source_color to
# the graph's source terminal and pixels painted with sink_color to its sink
# terminal.
source_color = [0, 128, 255]
sink_color = [255, 255, 0]

# Current figure that the plt.subplot calls in graph_cut_blend draw into.
plt.figure(
    num=None,
    figsize=(16, 16),
    dpi=80,
    facecolor='w',
    edgecolor='k',
)
def construct_cost_matrix_right(new: np.ndarray, old: np.ndarray):
    """Return the seam cost of the edge from each pixel to its right neighbour.

    For the edge between pixels s = (y, x) and t = (y, x + 1) the cost is
    ``|new(s) - old(s)| + |new(t) - old(t)|``, where the per-pixel absolute
    difference is summed over all channels. 1 is added so that no edge weight
    is zero.

    Args:
        new: Image of shape (H, W) or (H, W, C).
        old: Image with the same shape as ``new``.

    Returns:
        Array of shape (H, W). It is integer-valued for integer inputs. Entries
        in the last column wrap around to the first column (``np.roll``), so
        they do not describe a real edge.
    """
    new = np.asarray(new)
    old = np.asarray(old)
    # imread yields uint8. Subtracting in that dtype wraps around
    # (10 - 20 -> 246) and adding can overflow, so promote to at least int64
    # first. Floating-point inputs stay floating point.
    work_dtype = np.result_type(new.dtype, old.dtype, np.int64)
    difference = np.abs(new.astype(work_dtype) - old.astype(work_dtype))
    # One cost per pixel: sum over channels. 2-D (grayscale) input gets a
    # singleton channel axis, so it goes through the same path.
    pixel_cost = difference.reshape(
        difference.shape[0], difference.shape[1], -1).sum(axis=2)
    # Roll along the width axis only, so entry (y, x) also picks up the cost of
    # pixel (y, x + 1). Without an explicit axis, np.roll would flatten the
    # array and shift by a single channel element.
    match_cost_right = pixel_cost + np.roll(pixel_cost, -1, axis=1)
    return match_cost_right + 1
def construct_cost_matrix_down(new: np.ndarray, old: np.ndarray):
    """Return the seam cost of the edge from each pixel to the pixel below it.

    For the edge between pixels s = (y, x) and t = (y + 1, x) the cost is
    ``|new(s) - old(s)| + |new(t) - old(t)|``, where the per-pixel absolute
    difference is summed over all channels. 1 is added so that no edge weight
    is zero.

    Args:
        new: Image of shape (H, W) or (H, W, C).
        old: Image with the same shape as ``new``.

    Returns:
        Array of shape (H, W). It is integer-valued for integer inputs. Entries
        in the last row wrap around to the first row (``np.roll``), so they do
        not describe a real edge.
    """
    new = np.asarray(new)
    old = np.asarray(old)
    # imread yields uint8. Subtracting in that dtype wraps around
    # (10 - 20 -> 246) and adding can overflow, so promote to at least int64
    # first. Floating-point inputs stay floating point.
    work_dtype = np.result_type(new.dtype, old.dtype, np.int64)
    difference = np.abs(new.astype(work_dtype) - old.astype(work_dtype))
    # One cost per pixel: sum over channels. 2-D (grayscale) input gets a
    # singleton channel axis, so it goes through the same path.
    pixel_cost = difference.reshape(
        difference.shape[0], difference.shape[1], -1).sum(axis=2)
    # Roll along the height axis only, so entry (y, x) also picks up the cost
    # of pixel (y + 1, x). Without an explicit axis, np.roll would flatten the
    # array and shift by a single channel element.
    match_cost_down = pixel_cost + np.roll(pixel_cost, -1, axis=0)
    return match_cost_down + 1
def graph_cut_blend(src_img, tgt_img, mask, *, show=True):
    """Blend ``src_img`` into ``tgt_img`` along a minimum-cost graph-cut seam.

    Every pixel becomes a node of a 4-connected grid graph. Neighbour edges
    are weighted by construct_cost_matrix_right / construct_cost_matrix_down,
    which are computed from the two images. Pixels painted ``source_color`` in
    ``mask`` are tied to the source terminal, and pixels painted
    ``sink_color`` to the sink terminal. Both use capacities larger than the
    total weight of all neighbour edges, so no minimum cut severs them. Pixels
    that end up on the sink side of the cut take their value from
    ``src_img``; all other pixels keep ``tgt_img``.

    Args:
        src_img: Image of shape (H, W, C) that supplies the blended-in content.
        tgt_img: Image with the same shape as ``src_img``. It is not modified.
        mask: Colour image of shape (H, W, 3) or (H, W, 4); any alpha channel
            is ignored. Constrained pixels are marked with ``source_color`` or
            ``sink_color``.
        show: If true (the default), plot the blended image and the
            segmentation, then call ``plt.show()``.

    Returns:
        The blended image as a new array with the same shape as ``tgt_img``.

    Raises:
        ValueError: If the two images differ in shape, or if ``mask`` is not a
            colour image with the same height and width.
    """
    src_img = np.asarray(src_img)
    tgt_img = np.asarray(tgt_img)
    mask = np.asarray(mask)
    if src_img.shape != tgt_img.shape:
        raise ValueError(
            'src_img and tgt_img must have the same shape, got {} and {}'
            .format(src_img.shape, tgt_img.shape))
    img_height, img_width = src_img.shape[:2]
    if mask.ndim != 3 or mask.shape[:2] != (img_height, img_width):
        raise ValueError(
            'mask must be a colour image of size {}x{}, got shape {}'
            .format(img_height, img_width, mask.shape))

    num_pixels = img_height * img_width
    # The Graph constructor takes size hints (expected node and edge counts).
    # A 4-connected grid has about two edges per pixel.
    g = maxflow.Graph[int](num_pixels, 2 * num_pixels)
    nodeids = g.add_grid_nodes((img_height, img_width))

    cost_matrix_right = construct_cost_matrix_right(src_img, tgt_img)
    cost_matrix_down = construct_cost_matrix_down(src_img, tgt_img)
    # Each structure selects a single neighbour: (y, x) -> (y, x + 1).
    right_structure = np.array([[0, 0, 0],
                                [0, 0, 1],
                                [0, 0, 0]])
    g.add_grid_edges(nodeids, weights=cost_matrix_right,
                     structure=right_structure, symmetric=True)
    # (y, x) -> (y + 1, x).
    down_structure = np.array([[0, 0, 0],
                               [0, 0, 0],
                               [0, 1, 0]])
    g.add_grid_edges(nodeids, weights=cost_matrix_down,
                     structure=down_structure, symmetric=True)

    # Cutting every neighbour edge already separates the two terminals, so a
    # terminal capacity of (total neighbour weight + 1) can never be part of a
    # minimum cut. A fixed constant can be undercut by a long or expensive
    # seam. NOTE(review): Graph[int] capacities are presumably C ints, so the
    # value is clamped to int32; for extremely large images the bound is then
    # no longer guaranteed.
    total_edge_weight = int(cost_matrix_right.sum()) + int(cost_matrix_down.sum())
    inf_weight = min(total_edge_weight + 1, int(np.iinfo(np.int32).max))

    # Compare only RGB; an RGBA mask pixel would not match a 3-element colour.
    mask_rgb = mask[:, :, :3]
    is_source = np.all(mask_rgb == np.asarray(source_color), axis=-1)
    is_sink = np.all(mask_rgb == np.asarray(sink_color), axis=-1)
    g.add_grid_tedges(nodeids,
                      np.where(is_source, inf_weight, 0),
                      np.where(is_sink, inf_weight, 0))

    flow = g.maxflow()
    print('flow', flow)
    # Boolean grid: True for nodes on the sink side of the minimum cut.
    sgm = g.get_grid_segments(nodeids)

    # Work on a copy so the caller's target image is left untouched.
    blended = tgt_img.copy()
    blended[sgm] = src_img[sgm]

    if show:
        plt.subplot(2, 1, 1)
        plt.imshow(blended)
        plt.subplot(2, 1, 2)
        plt.imshow(sgm)
        plt.show()
    return blended
if __name__ == "__main__":
    # Inputs are read from ./data as <name>_src.jpg, <name>_target.jpg and
    # <name>_mask.png.
    src_img_in = imread(os.path.join('data', '{}_src.jpg'.format(name)))
    tgt_img_in = imread(os.path.join('data', '{}_target.jpg'.format(name)))
    mask_img = imread(os.path.join('data', '{}_mask.png'.format(name)))
    # Drop the alpha channel of any RGBA input. This includes the mask: its
    # pixels are compared with the 3-component source/sink colours, and a
    # 4-component pixel cannot be compared with those. Grayscale (2-D) images
    # have no channel axis and are passed through unchanged.
    src_img_in, tgt_img_in, mask_img = [
        np.array(img[:, :, 0:3]) if img.ndim == 3 and img.shape[2] == 4 else img
        for img in (src_img_in, tgt_img_in, mask_img)
    ]
    out_img = graph_cut_blend(src_img_in, tgt_img_in, mask_img)