@@ -18,12 +18,46 @@ def forward(self, x):
         offset = self.conv_offset(x)
         dtype = offset.data.type()
         ks = self.kernel_size
+        N = offset.size(1) // 2
 
         # (b, 2N, h, w)
         p = self._get_p(offset, dtype)
 
+        # (b, h, w, 2N)
+        p = p.contiguous().permute(0, 2, 3, 1)
+
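+        # Integer corner coordinates around each fractional sampling point:
+        # top-left = floor, bottom-right = ceil, plus the two mixed corners.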
+        q_lt = p.floor().long()
+        q_rb = p.ceil().long()
+        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
+        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
+
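+        # Where floor == ceil the corners coincide; record the multiplicity so the
+        # collapsed corners can be normalized out after the weighted sum below.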
+        # (b, h, w, N)
+        overlap_indices = torch.eq(q_lt, q_rb) + 1
+        overlap = (overlap_indices[..., :N] * overlap_indices[..., N:]).float()
+
+        # bilinear kernel (b, h, w, N)
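+        # Per-axis bilinear weights: g = 1 - |p - q| for each corner along x and y.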
+        g_lt_x = 1 - torch.abs(p[..., :N] - q_lt[..., :N].type_as(p))
+        g_lt_y = 1 - torch.abs(p[..., N:] - q_lt[..., N:].type_as(p))
+        g_rb_x = 1 - torch.abs(p[..., :N] - q_rb[..., :N].type_as(p))
+        g_rb_y = 1 - torch.abs(p[..., N:] - q_rb[..., N:].type_as(p))
+        g_lb_x = 1 - torch.abs(p[..., :N] - q_lb[..., :N].type_as(p))
+        g_lb_y = 1 - torch.abs(p[..., N:] - q_lb[..., N:].type_as(p))
+        g_rt_x = 1 - torch.abs(p[..., :N] - q_rt[..., :N].type_as(p))
+        g_rt_y = 1 - torch.abs(p[..., N:] - q_rt[..., N:].type_as(p))
+
         # (b, c, h, w, N)
-        x_offset = self._get_x_offset(x, p, dtype)
+        x_q_lt = self._get_x_q(x, q_lt, N)
+        x_q_rb = self._get_x_q(x, q_rb, N)
+        x_q_lb = self._get_x_q(x, q_lb, N)
+        x_q_rt = self._get_x_q(x, q_rt, N)
+
+        # (b, c, h, w, N)
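+        # Bilinear interpolation: weighted sum of the four corner samples, with the
+        # (b, h, w, N) weights broadcast over the channel dim via unsqueeze(1).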
+        x_offset = (g_lt_x * g_lt_y).unsqueeze(dim=1) * x_q_lt + \
+                   (g_rb_x * g_rb_y).unsqueeze(dim=1) * x_q_rb + \
+                   (g_lb_x * g_lb_y).unsqueeze(dim=1) * x_q_lb + \
+                   (g_rt_x * g_rt_y).unsqueeze(dim=1) * x_q_rt
+
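+        # Divide out the multiplicity where corner points coincided.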
+        x_offset /= overlap.unsqueeze(dim=1)
 
         # (b, c, h*kernel_size, w*kernel_size)
         x_offset = self._reshape_x_offset(x_offset, ks)
@@ -74,24 +108,17 @@ def _get_q(x_size, dtype):
 
         return q
 
-    def _get_x_offset(self, x, p, dtype):
+    def _get_x_q(self, x, q, N):
         b, c, h, w = x.size()
-        N = p.size(1)//2
-        # (h, w)
-        q = self._get_q(x.size(), dtype)
-        # (b,, 2N, h, w, 1, 1)
-        p = p.unsqueeze(dim=-1).unsqueeze(dim=-1)
-        zero = Variable(torch.FloatTensor([0]).type(dtype), requires_grad=True)
-
-        # (b, N, h, w, h, w)
-        G = torch.max((1 - torch.abs(p[:, :N, :, :, :, :] - q[0, :, :])), zero)\
-            * torch.max((1 - torch.abs(p[:, N:, :, :, :, :] - q[1, :, :])), zero)
-        # (b, N*h*w, h*w)
-        G = G.contiguous().view(b, N*h*w, -1)
-        # (b, h*w, c)
-        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
-        # (b, c, h, w, N)
-        x_offset = torch.bmm(G, x).contiguous().view(b, N, h, w, c).permute(0, 4, 2, 3, 1)
+        # (b, c, h*w)
+        x = x.contiguous().view(b, c, -1)
+
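+        # q holds integer per-point coordinates; convert them to flat indices into
+        # the h*w axis so all samples can be fetched with a single gather.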
+        # (b, h, w, N)
+        index = q[..., :N]*w + q[..., N:]  # linear index = offset_x*w + offset_y
+        # (b, c, h*w*N)
+        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
+
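+        # Gather the sampled values along the flattened spatial dim and restore (b, c, h, w, N).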
+        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
 
         return x_offset
 