-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathM2O_ST.py
More file actions
618 lines (506 loc) · 20.7 KB
/
M2O_ST.py
File metadata and controls
618 lines (506 loc) · 20.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
from torchvision import models
import numpy as np
import os
# Feature extractor wrapping an indexable CNN (e.g. VGG ``features``); max
# pooling is swapped for average pooling.
class transformer(nn.Module):
    """Run a sequential CNN and collect intermediate activations.

    Every ``nn.MaxPool2d`` child of ``subnet`` is replaced by
    ``nn.AvgPool2d(kernel_size=2, stride=2)`` at construction time.
    ``subnet`` must support integer indexing (e.g. ``nn.Sequential``) because
    children are addressed with ``int(name)``.
    """

    def __init__(self, subnet):
        super(transformer, self).__init__()
        self.net = subnet
        for name, layer in self.net.named_children():
            if isinstance(layer, nn.MaxPool2d):
                self.net[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)

    def get_gram_matrix(self, style):
        """Compute the normalised Gram matrix of each feature map.

        Args:
            style: list of feature maps, each of shape [B, C, H, W]
                (B is expected to be 1; for B > 1 the batch is folded into the
                channel axis, as before).

        Returns:
            list of [B*C, B*C] tensors, each divided by ``C * H * W``.
        """
        gram_matrix = []
        for feature in style:
            b, c, h, w = feature.size()
            # reshape (not view) so non-contiguous feature maps are accepted too
            flat = feature.reshape(b * c, h * w)
            gram = torch.mm(flat, flat.t())
            gram_matrix.append(gram.div(c * h * w))
        return gram_matrix

    def forward(self, x, content_list=None, style_list=None):
        """Return content activations and style Gram matrices.

        Args:
            x: input image batch.
            content_list: layer indices whose outputs are returned as content
                features (None/empty for none).
            style_list: layer indices whose outputs are turned into Gram
                matrices (None/empty for none).

        Returns:
            ``(content, style_grams)`` lists, ordered by layer index.
        """
        content_idx = set(content_list) if content_list else set()
        style_idx = set(style_list) if style_list else set()
        # Layers after the deepest requested index cannot affect the outputs,
        # so stop there instead of running the rest of the network.
        last_needed = max(content_idx | style_idx, default=-1)
        content = []
        style = []
        for name, layer in self.net.named_children():
            idx = int(name)
            if idx > last_needed:
                break
            x = layer(x)
            if idx in content_idx:
                content.append(x)
            if idx in style_idx:
                style.append(x)
        return content, self.get_gram_matrix(style)
# Region-mask generation used to assign a different style to each image region.
class MaskGenerator:
    """Build soft region masks for multi-style transfer.

    Every mask returned is a float32 tensor of shape ``(3, H, W)``: one 2-D
    mask repeated over three channels so it can multiply an RGB image.

    Args:
        num_regions: number of masks produced by ``create_grid_mask`` and
            ``create_random_mask``.
        visualize: when True (the default, matching the previous always-on
            behaviour) the segmentation methods plot their result with
            ``plt.show()``.
    """

    def __init__(self, num_regions, visualize=True):
        self.num_regions = num_regions
        self.visualize = visualize

    @staticmethod
    def _to_rgb(image):
        """Convert an OpenCV BGR image of shape (H, W, 3) to RGB."""
        # Fail early with a clear message: the segmentation steps below all
        # need a 3-channel image.
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(
                f"expected a 3-channel BGR image of shape (H, W, 3), got {image.shape}"
            )
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _soft_edge(binary_mask, steepness=8, blur_radius=15):
        """Turn a {0, 1} float32 mask into a soft mask with blurred borders."""
        inside = cv2.distanceTransform(binary_mask.astype(np.uint8), cv2.DIST_L2, 5)
        outside = cv2.distanceTransform(
            (1 - binary_mask).astype(np.uint8), cv2.DIST_L2, 5
        )
        # One of the two distances is always 0, so the ratio is 0 outside the
        # region and ~1 inside; the sigmoid maps that to ~0.02 / ~0.98 and the
        # Gaussian blur produces the actual smooth transition.
        ratio = inside / (inside + outside + 1e-6)
        smooth = 1 / (1 + np.exp(-(ratio - 0.5) * steepness))
        return cv2.GaussianBlur(smooth, (blur_radius, blur_radius), 0)

    @staticmethod
    def _repeat_channels(mask_2d):
        """Return a (3, H, W) float32 tensor holding ``mask_2d`` in every channel."""
        flat = torch.from_numpy(np.ascontiguousarray(mask_2d, dtype=np.float32))
        return flat.unsqueeze(0).repeat(3, 1, 1)

    def create_sky_ground_mask(self, image):
        """Split a BGR image into a "sky" and a "ground" soft mask using GrabCut.

        GrabCut is initialised with the top half of the frame as the
        foreground rectangle; its foreground becomes the sky mask and the
        ground mask is its complement.

        Args:
            image: BGR uint8 image of shape (H, W, 3).

        Returns:
            ``[sky_mask, ground_mask]``, each a (3, H, W) float32 tensor.

        Raises:
            ValueError: if ``image`` is not 3-channel.
        """
        image_rgb = self._to_rgb(image)
        height, width = image.shape[:2]

        gc_mask = np.zeros((height, width), np.uint8)
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        rect = (0, 0, width, height // 2)
        cv2.grabCut(image, gc_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
        # Definite (GC_BGD) and probable (GC_PR_BGD) background are ground.
        sky_binary = np.where(
            (gc_mask == cv2.GC_BGD) | (gc_mask == cv2.GC_PR_BGD), 0, 1
        ).astype("float32")

        sky_mask_np = self._soft_edge(sky_binary)
        ground_mask_np = 1 - sky_mask_np
        masks = [
            self._repeat_channels(sky_mask_np),
            self._repeat_channels(ground_mask_np),
        ]

        if self.visualize:
            plt.figure(figsize=(20, 5))
            plt.subplot(141)
            plt.imshow(image_rgb)
            plt.title("Original Image")
            plt.axis("off")
            plt.subplot(142)
            plt.imshow(sky_mask_np, cmap="gray")
            plt.title("Sky Mask with Smooth Transition")
            plt.axis("off")
            plt.subplot(143)
            plt.imshow(ground_mask_np, cmap="gray")
            plt.title("Ground Mask with Smooth Transition")
            plt.axis("off")
            # Pixels whose sky weight is within 0.2 of 0.5 form the transition band.
            transition = np.abs(sky_mask_np - 0.5) < 0.2
            plt.subplot(144)
            plt.imshow(transition, cmap="hot")
            plt.title("Transition Area")
            plt.axis("off")
            plt.tight_layout()
            plt.show()
        return masks

    def create_grid_mask(self, height, width):
        """Create ``num_regions`` hard rectangular masks laid out on a grid.

        The grid has ``floor(sqrt(num_regions))`` rows and
        ``ceil(num_regions / rows)`` columns, filled row-major; when the grid
        has more cells than regions, the trailing cells stay unassigned. The
        last row/column absorbs any remainder pixels.

        Returns:
            list of exactly ``num_regions`` (3, H, W) tensors (empty when
            ``num_regions <= 0``).
        """
        if self.num_regions <= 0:
            return []
        rows = int(np.sqrt(self.num_regions))
        cols = -(-self.num_regions // rows)  # ceiling division
        h_step = height // rows
        w_step = width // cols
        masks = []
        for i in range(rows):
            top = i * h_step
            bottom = height if i == rows - 1 else top + h_step
            for j in range(cols):
                if len(masks) == self.num_regions:
                    return masks
                left = j * w_step
                right = width if j == cols - 1 else left + w_step
                mask = torch.zeros((3, height, width))
                mask[:, top:bottom, left:right] = 1
                masks.append(mask)
        return masks

    def create_random_mask(self, height, width):
        """Create ``num_regions`` independent random binary (3, H, W) float masks."""
        masks = []
        for _ in range(self.num_regions):
            mask = torch.rand(3, height, width) > 0.5
            masks.append(mask.float())
        return masks

    def create_smooth_transition(self, mask, blur_radius=15):
        """Blur ``mask`` and min-max normalise it to [0, 1].

        Args:
            mask: numpy array or tensor accepted by ``cv2.GaussianBlur``.
            blur_radius: odd Gaussian kernel size.

        Returns:
            float32 tensor with the blurred, normalised mask. A constant mask
            (where min-max normalisation would divide by zero and yield NaN)
            is clipped to [0, 1] instead.
        """
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        mask = np.asarray(mask)
        if not np.issubdtype(mask.dtype, np.floating):
            # OpenCV rejects bool arrays and would round integer masks.
            mask = mask.astype(np.float32)
        blurred = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)
        low, high = blurred.min(), blurred.max()
        if high > low:
            blurred = (blurred - low) / (high - low)
        else:
            blurred = np.clip(blurred, 0, 1)
        return torch.from_numpy(blurred).float()

    def create_multi_object_mask(self, image, num_segments=3):
        """Split the image into ``num_segments`` soft masks that follow object boundaries.

        Pipeline: Canny edges -> watershed on the distance-to-edge map ->
        merge small regions into their largest neighbour -> K-means over
        per-region (colour, position) features.

        Args:
            image: BGR uint8 image of shape (H, W, 3).
            num_segments: requested number of regions. Fewer masks are
                returned (with a warning) when segmentation finds fewer
                regions; a single full-image mask when it finds none.

        Returns:
            list of (3, H, W) float32 tensors.

        Raises:
            ValueError: if ``image`` is not 3-channel.
        """
        from scipy import ndimage as ndi
        from skimage.feature import canny, peak_local_max
        from skimage.measure import regionprops
        from skimage.segmentation import watershed
        from sklearn.cluster import KMeans

        image_rgb = self._to_rgb(image)
        height, width = image_rgb.shape[:2]

        # Step 1: edge map; edge pixels separate the watershed basins.
        edges = canny(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY))

        # Step 2: watershed seeded at local maxima of the distance to the nearest edge.
        distance = ndi.distance_transform_edt(~edges)
        coordinates = peak_local_max(distance, min_distance=20)
        local_max = np.zeros_like(distance, dtype=bool)
        local_max[tuple(coordinates.T)] = True
        markers = ndi.label(local_max)[0]
        labels = watershed(-distance, markers, mask=~edges)

        # Step 3: merge regions smaller than 30% of the median area into their
        # largest neighbour.
        props = regionprops(labels)
        area_of = {prop.label: prop.area for prop in props}
        median_area = np.median(list(area_of.values())) if area_of else 0.0
        new_labels = labels.copy()
        for prop in props:
            if prop.area >= median_area * 0.3:
                continue
            region = labels == prop.label
            neighbors = np.unique(labels[ndi.binary_dilation(region) & ~region])
            # Label 0 marks edge/background pixels: never merge into it, but
            # still merge into the real neighbours (almost every region touches
            # an edge, so skipping such regions would disable merging).
            neighbors = neighbors[neighbors != 0]
            if neighbors.size:
                new_labels[region] = max(neighbors, key=lambda n: area_of[int(n)])

        # Step 4: per-region features (colour weighted x2, position x0.5).
        region_labels = []
        features = []
        for label in np.unique(new_labels):
            if label == 0:
                continue
            region = new_labels == label
            color_mean = np.mean(image_rgb[region], axis=0)
            position = np.mean(np.nonzero(region), axis=1)
            features.append(np.concatenate([color_mean * 2, position * 0.5]))
            region_labels.append(label)

        if not features:
            print("Warning: segmentation found no regions; using a single full-image mask.")
            return [torch.ones((3, height, width))]
        if len(features) < num_segments:
            print(
                f"Warning: Only found {len(features)} regions, but {num_segments} were requested."
            )
            num_segments = len(features)

        # Step 5: cluster the regions into the final segments.
        cluster_labels = KMeans(n_clusters=num_segments, random_state=42).fit_predict(
            features
        )
        region_labels = np.asarray(region_labels)
        kernel = np.ones((5, 5), np.uint8)
        masks = []
        for segment in range(num_segments):
            members = region_labels[cluster_labels == segment]
            base_mask = np.isin(new_labels, members).astype(np.float32)
            # Close small holes, then remove small specks.
            base_mask = cv2.morphologyEx(base_mask, cv2.MORPH_CLOSE, kernel)
            base_mask = cv2.morphologyEx(base_mask, cv2.MORPH_OPEN, kernel)
            masks.append(self._repeat_channels(self._soft_edge(base_mask)))

        if self.visualize:
            n_panels = num_segments + 2
            plt.figure(figsize=(5 * n_panels, 5))
            plt.subplot(1, n_panels, 1)
            plt.imshow(image_rgb)
            plt.title("Original Image")
            plt.axis("off")
            plt.subplot(1, n_panels, 2)
            plt.imshow(edges, cmap="gray")
            plt.title("Edges")
            plt.axis("off")
            for i in range(num_segments):
                plt.subplot(1, n_panels, i + 3)
                plt.imshow(masks[i][0].numpy(), cmap="gray")
                plt.title(f"Region {i+1} Mask")
                plt.axis("off")
            plt.tight_layout()
            plt.show()
        return masks
# Optimise the pixels of the image so that each masked region matches one style.
def train(x, net, parameter_list, content, styles):
    """Run multi-style, region-masked neural style transfer.

    Args:
        x: initial image tensor of shape [1, 3, H, W]; presumably RGB in
            [0, 1] as produced by ``preprocess_image`` (it is scaled by 255
            for segmentation).
        net: feature extractor called as ``net(img, content_list, style_list)``
            and returning ``(content_features, style_grams)``. Its parameters
            are frozen here because only the image is optimised.
        parameter_list: dict with keys "epoches", "device", "content_list",
            "style_list", "lr", "content_weight", "style_weight", "tv_weight".
        content: target content features, one tensor per content layer.
        styles: one entry per style image; each entry is the list of Gram
            matrices (one per style layer) of that style.

    Returns:
        The optimised image tensor on ``parameter_list["device"]``.
    """
    epoches = parameter_list["epoches"]
    device = parameter_list["device"]
    content_list = parameter_list["content_list"]
    style_list = parameter_list["style_list"]

    # Segment the content image into one soft region per style.
    mask_generator = MaskGenerator(num_regions=len(styles))
    content_img = x.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    content_img = (content_img * 255).astype(np.uint8)
    # x is RGB, while the mask generator expects OpenCV's BGR channel order.
    content_img = cv2.cvtColor(content_img, cv2.COLOR_RGB2BGR)
    masks = mask_generator.create_multi_object_mask(
        content_img, num_segments=len(styles)
    )
    masks = [mask.to(device) for mask in masks]

    net.to(device)
    # Only the image is optimised: frozen weights avoid accumulating unused
    # gradients in the network on every backward().
    net.requires_grad_(False)
    # Move first, then enable grad, so x is a leaf tensor the optimiser can update.
    x = x.detach().to(device).requires_grad_(True)
    # Detached targets keep each iteration's graph self-contained, so
    # backward() does not need retain_graph=True.
    content_targets = [feat.detach() for feat in content]
    style_targets = [[gram.detach() for gram in grams] for grams in styles]

    optimizer = optim.Adam([x], lr=parameter_list["lr"])
    # FHT: 使用AdamW优化器 (alternative: AdamW)
    # optimizer = optim.AdamW([x], lr=parameter_list["lr"])
    for epoch in range(epoches):
        optimizer.zero_grad()

        # Content loss on the whole image (style grams are not needed here).
        x_content, _ = net(x, content_list, None)
        content_loss = torch.tensor(0.0, device=device)
        for feat, target in zip(x_content, content_targets):
            content_loss = content_loss + F.mse_loss(feat, target)

        # Style loss: each style is matched only on its masked region.
        total_style_loss = torch.tensor(0.0, device=device)
        for style_gram, mask in zip(style_targets, masks):
            _, masked_style = net(x * mask.unsqueeze(0), None, style_list)
            for gram, target in zip(masked_style, style_gram):
                total_style_loss = total_style_loss + F.mse_loss(gram, target)

        # Total variation: absolute differences between vertically (dim 2)
        # and horizontally (dim 3) adjacent pixels.
        tv_loss = (
            torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
            + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
        ) / (3 * x.shape[2] * x.shape[3])

        loss = (
            parameter_list["content_weight"] * content_loss
            + parameter_list["style_weight"] * total_style_loss
            + parameter_list["tv_weight"] * tv_loss
        )
        loss.backward()
        optimizer.step()

        if epoch == 0 or epoch % 100 == 99:
            print(f"Epoch {epoch+1}:")
            print(f"Total Loss = {loss.item():.4f}")
            print("----------------------")
            with torch.no_grad():
                img = torch.clamp(x, 0, 1)
                # if dist.get_rank() == 0: # only the main process shows images
                show_image(img)
    return x
def preprocess_image(image):
    """Turn an OpenCV BGR image (H, W, 3) into an RGB float tensor [1, 3, H, W].

    Pixel values are divided by 255, so 8-bit input ends up in [0, 1].
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    chw = torch.from_numpy(scaled).permute(2, 0, 1)
    return chw[None]
def show_image(tensor):
    """Display an image tensor of shape [1, 3, H, W] with matplotlib.

    Values are clipped to [0, 1] before display; the figure is closed after
    ``plt.show()`` returns.
    """
    with torch.no_grad():
        hwc = tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
        pixels = np.clip(hwc, 0, 1)
        plt.figure(figsize=(10, 10))
        plt.imshow(pixels)
        plt.axis("off")
        plt.show()
        plt.close()
def save_image(tensor, filename):
    """Write an RGB image tensor of shape [1, 3, H, W] to ``filename``.

    Values are clipped to [0, 1], rounded to 8-bit and converted to BGR for
    OpenCV. The file format is chosen by OpenCV from the extension.

    Raises:
        OSError: if OpenCV fails to write the file (``cv2.imwrite`` reports
            failure through its return value rather than raising).
    """
    with torch.no_grad():
        image = tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
        # Round instead of truncating so values produced by x / 255 map back
        # to the same 8-bit level.
        image = np.rint(np.clip(image, 0, 1) * 255).astype(np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(filename, image):
        raise OSError(f"Failed to write image to {filename!r}")
def resize_to_nearest_power_of_two(image, max_dim=1024):
"""调整图像宽和高到小于等于自身且不超过 max_dim 的最大 2 的幂值"""
h, w = image.shape[:2]
# 计算宽和高调整后的目标值
def get_target_length(x, limit):
return min(2 ** int(np.floor(np.log2(x))), limit)
new_h = get_target_length(h, max_dim)
new_w = get_target_length(w, max_dim)
# 调整图像大小
return cv2.resize(image, (new_w, new_h))
def style_transfer(
    style_paths: list[str],
    content_path: str,
    save_dir="../data/paper/result",
    n_epoch=1000,
):
    """Transfer the styles of several images onto one content image.

    Each style is applied to one automatically segmented region of the content
    image (see ``train``). The result is resized back to the content image's
    original size and saved as ``{content}({style1, style2, ..., styleN}).png``.

    Args:
        style_paths (list[str]): paths of the style images.
        content_path (str): path of the content (target) image.
        save_dir (str): folder the result is written to; created if it does
            not exist. An empty value saves into the current directory.
            Defaults to "../data/paper/result".
        n_epoch (int): number of optimisation steps. Defaults to 1000.

    Raises:
        FileNotFoundError: if the content image or a style image cannot be read.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the content image (cv2.imread returns None instead of raising).
    content_img = cv2.imread(content_path)
    if content_img is None:
        raise FileNotFoundError(f"Could not read content image: {content_path}")
    original_height, original_width = content_img.shape[:2]  # FHT: keep original size
    print(f"\norigin: {content_img.shape}\n")
    content_img = resize_to_nearest_power_of_two(content_img)
    print(f"\ntransform: {content_img.shape}\n")
    content_tensor = preprocess_image(content_img).to(device)

    # Load the style images, resized to the (resized) content image.
    target_height, target_width = content_img.shape[:2]
    style_tensors = []
    for style_path in style_paths:
        style_img = cv2.imread(style_path)
        if style_img is None:
            raise FileNotFoundError(f"Could not read style image: {style_path}")
        style_img = cv2.resize(style_img, (target_width, target_height))
        print(f"\n{style_img.shape}\n")
        style_tensors.append(preprocess_image(style_img).to(device))

    parameter_list = {
        "device": device,
        "epoches": n_epoch,
        "lr": 1e-3,
        "content_weight": 1e-2,  # origin: 1e-2
        "style_weight": 1e7,  # origin: 1e7
        "tv_weight": 1e-3,
        "content_list": [25],
        "style_list": [0, 5, 10, 19, 28],
    }

    net = transformer(
        models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
    ).to(device)
    # Wrap in DataParallel only when at least two GPUs exist; hard-coding
    # device_ids=[0, 1] crashes on single-GPU machines.
    if torch.cuda.device_count() >= 2:
        net = nn.DataParallel(net, device_ids=[0, 1]).to(device)

    # Targets are constants for the optimisation, so no autograd graph is needed.
    with torch.no_grad():
        content_features, _ = net(
            content_tensor, parameter_list["content_list"], parameter_list["style_list"]
        )
        style_grams = [
            net(
                style_tensor,
                parameter_list["content_list"],
                parameter_list["style_list"],
            )[1]
            for style_tensor in style_tensors
        ]

    x = content_tensor.clone()
    result = train(x, net, parameter_list, content_features, style_grams)

    # FHT: resize the output back to the original content image size.
    with torch.no_grad():
        result = F.interpolate(
            result,
            size=(original_height, original_width),
            mode="bilinear",
            align_corners=False,
        )

    # Output name: {content}({style1, style2, ..., styleN}).png
    content_name = os.path.splitext(os.path.basename(content_path))[0]
    style_names = [
        os.path.splitext(os.path.basename(style_path))[0] for style_path in style_paths
    ]
    result_name = f"{content_name}({', '.join(style_names)}).png"
    if save_dir:
        # Create the folder instead of silently saving into the cwd.
        os.makedirs(save_dir, exist_ok=True)
        result_name = os.path.join(save_dir, result_name)

    save_image(result, result_name)
    print(f"Output image has been saved to {result_name}")
if __name__ == "__main__":
    # Command-line entry point: collect the image paths and epoch count, then
    # run the multi-style transfer. Option names match style_transfer's
    # parameters, so the parsed namespace is forwarded as keyword arguments.
    cli = argparse.ArgumentParser(description="Multi-to-One Style Transfer")
    cli.add_argument(
        "--style_paths",
        nargs="+",
        default=["data/style/Picasso/guernica.jpg", "data/style/Tsunami.jpg"],
        help="Paths to style images (provide multiple paths separated by space)",
    )
    cli.add_argument(
        "--content_path",
        default="data/content/chicago_0.jpg",
        help="Path to the content image",
    )
    cli.add_argument(
        "--save_dir",
        default="data/result",
        help="Directory to save the result image",
    )
    cli.add_argument(
        "--n_epoch",
        type=int,
        default=1000,
        help="Number of training epochs, default to 1000",
    )
    options = cli.parse_args()
    style_transfer(**vars(options))