-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcross_correlation.py
More file actions
executable file
·93 lines (80 loc) · 2.9 KB
/
cross_correlation.py
File metadata and controls
executable file
·93 lines (80 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Normalized cross-correlation interactive display for the hurricane data set
"""
import numpy as np
import copy
import config
import util
import tt
import ttrecipes as tr
import matplotlib.pyplot as plt
import os
import time
input_dataset = os.path.join(config.data_folder, 'hurricane_500_500_91_uint8.raw')
B = 128
eps = 0.0001
basename, X, tth = util.prepare_dataset(config.data_folder, input_dataset, B=B, eps=eps)
print('Total computing time: {}'.format(tth.total_time))
print('NNZ: {}'.format(len(tth.tensor.core)))
def interactive_loop(tth, basis_size=8):
"""
Display an interactive image showing the normalized cross-correlation between
a histogram field and any window selected by the user
:param tth:
:param basis_size: used to approximate the norm of each individual histogram in the field
"""
cores = tt.vector.to_list(tth.tensor)
tr.core.orthogonalize(cores, len(cores)-1)
basis = np.linalg.svd(cores[-1][:, :, 0], full_matrices=0)[2][:basis_size, :]
cores[-1] = np.einsum('ijk,aj->iak', cores[-1], basis)
pca = tt.vector.from_list(cores)
pcatth = copy.copy(tth)
pcatth.tensor = pca
shape = np.array([8, 8, 90])
Is = [500, 500, 91]
corners = np.array([[shape[0] // 2, Is[0] - shape[0] // 2], [shape[1] // 2, Is[1] - shape[1] // 2], [45, 46]])
pcafield, elapsed = pcatth.box_field(corners, shape)
print('Box field computation time: {}'.format(elapsed))
pcafield = np.squeeze(pcafield)
norms = np.sqrt(np.sum(pcafield**2, axis=-1))
global im
im = None
global sc
sc = None
global counter
counter = 1
def update(x, y):
start = time.time()
v, elapsed = tth.box(np.array([[x-4, x+3], [y-4, y+3], [0, 91]]))
v = v / np.linalg.norm(v)
cores = tt.vector.to_list(tth.tensor)
cores[-1] = np.einsum('ijk,j->ik', cores[-1], v)[:, np.newaxis, :]
proj = tt.vector.from_list(cores)
projtth = copy.copy(tth)
projtth.tensor = proj
field, elapsed = projtth.box_field(corners, shape)
field = np.squeeze(field.T) / norms
global im
global sc
global counter
if im is None:
plt.axis('off')
im = plt.imshow(field, cmap='pink', vmin=0, vmax=1)
sc, = plt.plot(x, y, marker='+', ms=25, mew=5, color='red')
plt.show()
else:
im.set_data(field)
sc.set_data(x, y)
fig.canvas.draw()
extent = plt.gca().get_window_extent().transformed(fig.dpi_scale_trans.inverted())
plt.savefig(os.path.join(config.data_folder, 'similarity_{:03d}.pdf'.format(counter)), bbox_inches=extent)
counter += 1
def onclick(event):
x = int(event.xdata)
y = int(event.ydata)
print(x, y)
update(x, y)
fig = plt.figure()
fig.canvas.mpl_connect('button_press_event', onclick)
update(250, 250)
interactive_loop(tth)