
Commit e0b45e7

memory efficient

committed · 1 parent 88d7156 · commit e0b45e7

File tree

1 file changed · +46 -19 lines changed


deform_conv.py

Lines changed: 46 additions & 19 deletions
@@ -18,12 +18,46 @@ def forward(self, x):
         offset = self.conv_offset(x)
         dtype = offset.data.type()
         ks = self.kernel_size
+        N = offset.size(1) // 2

-        # (b, 2N, h, w)
+        # # (b, 2N, h, w)
         p = self._get_p(offset, dtype)

+        # (b, h, w, 2N)
+        p = p.contiguous().permute(0, 2, 3, 1)
+
+        q_lt = p.floor().long()
+        q_rb = p.ceil().long()
+        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
+        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
+
+        # (b, h, w, N)
+        overlap_indices = torch.eq(q_lt, q_rb) + 1
+        overlap = (overlap_indices[..., :N] * overlap_indices[..., N:]).float()
+
+        # bilinear kernel (b, h, w, N)
+        g_lt_x = 1 - torch.abs(p[..., :N] - q_lt[..., :N].type_as(p))
+        g_lt_y = 1 - torch.abs(p[..., N:] - q_lt[..., N:].type_as(p))
+        g_rb_x = 1 - torch.abs(p[..., :N] - q_rb[..., :N].type_as(p))
+        g_rb_y = 1 - torch.abs(p[..., N:] - q_rb[..., N:].type_as(p))
+        g_lb_x = 1 - torch.abs(p[..., :N] - q_lb[..., :N].type_as(p))
+        g_lb_y = 1 - torch.abs(p[..., N:] - q_lb[..., N:].type_as(p))
+        g_rt_x = 1 - torch.abs(p[..., :N] - q_rt[..., :N].type_as(p))
+        g_rt_y = 1 - torch.abs(p[..., N:] - q_rt[..., N:].type_as(p))
+
         # (b, c, h, w, N)
-        x_offset = self._get_x_offset(x, p, dtype)
+        x_q_lt = self._get_x_q(x, q_lt, N)
+        x_q_rb = self._get_x_q(x, q_rb, N)
+        x_q_lb = self._get_x_q(x, q_lb, N)
+        x_q_rt = self._get_x_q(x, q_rt, N)
+
+        # (b, c, h, w, N)
+        x_offset = (g_lt_x * g_lt_y).unsqueeze(dim=1) * x_q_lt + \
+                   (g_rb_x * g_rb_y).unsqueeze(dim=1) * x_q_rb + \
+                   (g_lb_x * g_lb_y).unsqueeze(dim=1) * x_q_lb + \
+                   (g_rt_x * g_rt_y).unsqueeze(dim=1) * x_q_rt
+
+        x_offset /= overlap.unsqueeze(dim=1)

         # (b, c, h*kernel_size, w*kernel_size)
         x_offset = self._reshape_x_offset(x_offset, ks)
@@ -74,24 +108,17 @@ def _get_q(x_size, dtype):

         return q

-    def _get_x_offset(self, x, p, dtype):
+    def _get_x_q(self, x, q, N):
         b, c, h, w = x.size()
-        N = p.size(1)//2
-        # (h, w)
-        q = self._get_q(x.size(), dtype)
-        # (b,, 2N, h, w, 1, 1)
-        p = p.unsqueeze(dim=-1).unsqueeze(dim=-1)
-        zero = Variable(torch.FloatTensor([0]).type(dtype), requires_grad=True)
-
-        # (b, N, h, w, h, w)
-        G = torch.max((1-torch.abs(p[:, :N, :, :, :, :] - q[0, :, :])), zero)\
-            * torch.max((1-torch.abs(p[:, N:, :, :, :, :] - q[1, :, :])), zero)
-        # (b, N*h*w, h*w)
-        G = G.contiguous().view(b, N*h*w, -1)
-        # (b, h*w, c)
-        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
-        # (b, c, h, w, N)
-        x_offset = torch.bmm(G, x).contiguous().view(b, N, h, w, c).permute(0, 4, 2, 3, 1)
+        # (b, c, h*w)
+        x = x.contiguous().view(b, c, -1)
+
+        # (b, h, w, N)
+        index = q[..., :N]*w + q[..., N:]  # offset_x*w + offset_y
+        # (b, c, h*w*N)
+        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
+
+        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

         return x_offset

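Editor's note (not part of the commit): the removed _get_x_offset built a dense bilinear kernel G of shape (b, N, h, w, h, w) and contracted it against the whole feature map with torch.bmm, so its memory footprint grew with (h*w)^2. The new path keeps only the four integer neighbours of each sampling position (q_lt, q_rb, q_lb, q_rt), fetches them with _get_x_q via gather, and combines them with per-point bilinear weights, so the largest intermediates are (b, c, h, w, N). As a rough illustration, for a 3x3 kernel (N = 9) on a 64x64 map with batch 1, G alone holds 9 * 64^4 ≈ 1.5e8 floats (~600 MB in fp32), while each new per-neighbour tensor holds only c * 64 * 64 * 9 floats. The sketch below shows the same gather-based idea in isolation; it is an editor's illustration, not code from this repository: bilinear_sample_sketch is a hypothetical name, and it sidesteps the commit's overlap correction by taking floor()+1 rather than ceil() for the opposite corner.

# Editor's sketch, not part of deform_conv.py: gather-based bilinear sampling.
import torch


def bilinear_sample_sketch(x, p):
    """Minimal sketch of bilinear sampling via gather.

    x: (b, c, h, w) feature map.
    p: (b, h, w, 2N) fractional sampling positions; the first N channels are
       row coordinates and the last N are column coordinates, assumed already
       clamped to [0, h - 2] and [0, w - 2] so all four neighbours are in bounds.
    Returns: (b, c, h, w, N) sampled values.
    """
    b, c, h, w = x.size()
    N = p.size(-1) // 2

    # four integer neighbours of every fractional position
    q_lt = p.floor().long()
    q_rb = q_lt + 1
    q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
    q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)

    # (b, c, h*w)
    x_flat = x.contiguous().view(b, c, -1)

    def gather_q(q):
        # flatten (row, col) into a single h*w index, then gather per channel
        idx = q[..., :N] * w + q[..., N:]                 # (b, h, w, N)
        idx = idx.unsqueeze(1).expand(-1, c, -1, -1, -1)  # (b, c, h, w, N)
        return x_flat.gather(-1, idx.contiguous().view(b, c, -1)).view(b, c, h, w, N)

    def weight(q):
        # bilinear weight of neighbour q: one scalar per sampling point
        return ((1 - (p[..., :N] - q[..., :N].type_as(p)).abs()) *
                (1 - (p[..., N:] - q[..., N:].type_as(p)).abs())).unsqueeze(1)

    # every intermediate stays at (b, c, h, w, N) or smaller
    return (weight(q_lt) * gather_q(q_lt) + weight(q_rb) * gather_q(q_rb) +
            weight(q_lb) * gather_q(q_lb) + weight(q_rt) * gather_q(q_rt))

Taking floor()+1 for the opposite corner gives those extra neighbours zero weight whenever a position is exactly integer; the commit instead uses ceil() and divides by overlap to avoid double-counting when floor and ceil coincide.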