@@ -433,7 +433,8 @@ def forward(self, x):
433433 x = F .pad (x , (self .pad ,) * 4 , self .pad_mode )
434434 weight = x .new_zeros ([x .shape [1 ], x .shape [1 ], self .kernel .shape [0 ], self .kernel .shape [1 ]])
435435 indices = torch .arange (x .shape [1 ], device = x .device )
436- weight [indices , indices ] = self .kernel .to (weight )
436+ kernel = self .kernel .to (weight )[None , :].expand (x .shape [1 ], - 1 , - 1 )
437+ weight [indices , indices ] = kernel
437438 return F .conv2d (x , weight , stride = 2 )
438439
439440
@@ -449,7 +450,8 @@ def forward(self, x):
449450 x = F .pad (x , ((self .pad + 1 ) // 2 ,) * 4 , self .pad_mode )
450451 weight = x .new_zeros ([x .shape [1 ], x .shape [1 ], self .kernel .shape [0 ], self .kernel .shape [1 ]])
451452 indices = torch .arange (x .shape [1 ], device = x .device )
452- weight [indices , indices ] = self .kernel .to (weight )
453+ kernel = self .kernel .to (weight )[None , :].expand (x .shape [1 ], - 1 , - 1 )
454+ weight [indices , indices ] = kernel
453455 return F .conv_transpose2d (x , weight , stride = 2 , padding = self .pad * 2 + 1 )
454456
455457
0 commit comments