-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpatchmatch.py
More file actions
159 lines (117 loc) · 5.4 KB
/
patchmatch.py
File metadata and controls
159 lines (117 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import numpy as np
from numpy.random import randint
from numba import njit, prange
from utils import is_in_inner_boundaries, is_same_shift
from patch_measure import patch_measure, calculate_patch_distance
from config import H_PATCH_SIZE
@njit
def is_valid_match(img, i, j):
    """Tell whether position (i, j) is an acceptable match in `img`.

    Delegates to `is_in_inner_boundaries` and returns its answer
    unchanged (presumably True when a whole patch centred on (i, j)
    fits inside the image -- confirm against `utils`).
    """
    inside = is_in_inner_boundaries(img, i, j)
    return inside
@njit
def sample_around(x, window, low=None, high=None):
    """Sample an integer around `x`, clipped to the range [low, high).

    The candidate interval is [x - window, x + window] (both ends
    included).  It is intersected with [low, high): `low` is inclusive
    and `high` is exclusive, matching the half-open convention of
    `randint`.

    Args:
        x: centre of the sampling window.
        window: half-width of the sampling window.
        low: optional inclusive lower bound; ``None`` means no bound.
        high: optional exclusive upper bound; ``None`` means no bound.

    Returns:
        A random integer drawn from the clipped interval.  When the
        clipped interval is empty, the clipped lower bound is returned
        as an int instead; it may then lie outside [low, high), so
        callers are expected to validate the result.
    """
    rand_min = x - window
    if low is not None:
        rand_min = max(rand_min, low)
    # `randint` excludes its upper bound, hence the "+ 1" to keep
    # x + window reachable.
    rand_max = x + window + 1
    if high is not None:
        rand_max = min(rand_max, high)
    if rand_min >= rand_max:
        # Empty interval: nothing to sample from.
        return int(rand_min)
    return randint(rand_min, rand_max)
def patch_match(img1, img2, first_guess, params):
    """Estimate the shift map that sends patches of img1 onto img2.

    `first_guess` seeds the map through `initialise_shift_map`.  The map
    is then refined by `params["n_iters"]` rounds, each made of one
    propagation pass followed by one random search pass driven by
    `params["w"]` (window size) and `params["alpha"]` (window decay).

    Returns the shift map produced by `initialise_shift_map` after
    refinement.
    """
    refined = initialise_shift_map(first_guess, img1, img2)
    n_rounds = params["n_iters"]
    for round_idx in range(n_rounds):
        # The round index selects the scan direction of the propagation.
        propagation(refined, img1, img2, round_idx)
        search_window = params["w"]
        decay = params["alpha"]
        random_search(refined, img1, img2, search_window, decay)
    return refined
@njit(parallel=True)
def initialise_shift_map(shift_map, img1, img2):
    """Make every interior pixel of `shift_map` point at a valid match.

    For each pixel (i, j) of img1 accepted by `is_in_inner_boundaries`,
    the stored shift is kept when it already lands on a valid position
    of img2; otherwise new target positions are drawn at random until a
    valid one is found.  The distance stored in channel 2 is then
    recomputed with `calculate_patch_distance`.  Other pixels are left
    untouched.

    The map is modified in place and also returned.
    """
    n_rows = img1.shape[0]
    n_cols = img1.shape[1]
    # Bounds of the random draw for the target position in img2.
    y_stop = img2.shape[0] - H_PATCH_SIZE
    x_stop = img2.shape[1] - H_PATCH_SIZE
    for i in prange(n_rows):
        for j in prange(n_cols):
            if is_in_inner_boundaries(img1, i, j):
                dy = int(shift_map[i, j, 0])
                dx = int(shift_map[i, j, 1])
                while not is_valid_match(img2, i + dy, j + dx):
                    # Draw an absolute target, then express it as a shift.
                    dy = randint(H_PATCH_SIZE, y_stop) - i
                    dx = randint(H_PATCH_SIZE, x_stop) - j
                shift_map[i, j, 0] = float(dy)
                shift_map[i, j, 1] = float(dx)
                shift_map[i, j, 2] = calculate_patch_distance(
                    img1, img2, shift_map, i, j)
    return shift_map
@njit
def propagation(shift_map, img1, img2, iteration_nb):
    """Let each pixel try the shifts currently held by two neighbours.

    Even iterations scan rows and columns in increasing order and look
    at the neighbours at (i - 1, j) and (i, j - 1); odd iterations scan
    in decreasing order and look at (i + 1, j) and (i, j + 1).  A
    neighbour's shift replaces the current one when `patch_measure`
    reports a strictly smaller distance.  `shift_map` is updated in
    place; nothing is returned.
    """
    height = img1.shape[0]
    width = img1.shape[1]
    if iteration_nb % 2 == 0:
        step = -1
        rows = range(height)
        cols = range(width)
    else:
        step = 1
        rows = range(height - 1, -1, -1)
        cols = range(width - 1, -1, -1)

    for i in rows:
        for j in cols:
            if not is_in_inner_boundaries(img1, i, j):
                continue

            # NOTE: this is a view, so writing to it edits shift_map.
            best_shift = shift_map[i, j, :2]
            best_distance = shift_map[i, j, 2]

            candidates = [[i + step, j],   # vertical neighbour
                          [i, j + step]]   # horizontal neighbour
            for (ny, nx) in candidates:
                if is_in_inner_boundaries(img1, ny, nx):
                    dy = int(shift_map[ny, nx, 0])
                    dx = int(shift_map[ny, nx, 1])
                    # Skip shifts that leave img2 or change nothing.
                    if (is_valid_match(img2, i + dy, j + dx)
                            and not is_same_shift(best_shift, dy, dx)):
                        candidate_distance = patch_measure(img1, img2,
                                                           i, j,
                                                           i + dy, j + dx,
                                                           best_distance)
                        if candidate_distance < best_distance:
                            best_shift[0] = dy
                            best_shift[1] = dx
                            best_distance = candidate_distance

            if best_distance < shift_map[i, j, 2]:
                shift_map[i, j, :2] = best_shift
                shift_map[i, j, 2] = best_distance
@njit(parallel=True)
def random_search(shift_map, img1, img2, max_window, alpha):
    """Try random candidates around each current match, at several scales.

    The window sizes are `max_window * alpha ** k`, truncated to
    integers, for k = 0 .. ceil(-log(max_window) / log(alpha)) - 1.
    `max_window` is first capped by the largest entry of
    `img2.shape[:-1]` (assumes img2 has a trailing channel axis -- TODO
    confirm).  For every pixel accepted by `is_in_inner_boundaries`, one
    candidate per window is drawn around the match stored on entry, and
    it is kept when `patch_measure` reports a strictly smaller distance.

    `shift_map` is updated in place; nothing is returned.
    """
    max_window = min(max_window, max(img2.shape[:-1]))
    n_scales = int(np.ceil(-np.log(max_window) / np.log(alpha)))
    radii = np.zeros(n_scales, dtype=np.int64)
    for k in range(n_scales):
        radii[k] = max_window * np.power(alpha, k)

    # Sampling bounds passed to `sample_around` for each axis of img2.
    y_stop = img2.shape[0] - H_PATCH_SIZE
    x_stop = img2.shape[1] - H_PATCH_SIZE

    for i in prange(img1.shape[0]):
        for j in prange(img1.shape[1]):
            if is_in_inner_boundaries(img1, i, j):
                # Channel 0 is the row shift, channel 1 the column shift.
                dy = int(shift_map[i, j, 0])
                dx = int(shift_map[i, j, 1])
                best_shift = [dy, dx]
                best_distance = shift_map[i, j, 2]

                # The search stays centred on the match stored on entry.
                y_centre = i + dy
                x_centre = j + dx

                for radius in radii:
                    y_rand = sample_around(y_centre, radius,
                                           H_PATCH_SIZE, y_stop)
                    x_rand = sample_around(x_centre, radius,
                                           H_PATCH_SIZE, x_stop)
                    if (is_valid_match(img2, y_rand, x_rand)
                            and not is_same_shift(best_shift,
                                                  y_rand - i, x_rand - j)):
                        candidate_distance = patch_measure(img1, img2,
                                                           i, j,
                                                           y_rand, x_rand,
                                                           best_distance)
                        if candidate_distance < best_distance:
                            best_shift[0] = y_rand - i
                            best_shift[1] = x_rand - j
                            best_distance = candidate_distance

                if best_distance < shift_map[i, j, 2]:
                    shift_map[i, j, 0] = float(best_shift[0])
                    shift_map[i, j, 1] = float(best_shift[1])
                    shift_map[i, j, 2] = best_distance