Skip to content

Commit 48e65c7

Browse files
authored
Merge pull request VainF#35 from trougnouf/master
Fix weights tensor creation / PyTorch 1.11 compat
2 parents 6ceec02 + a948c90 commit 48e65c7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_msssim/ssim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def ms_ssim(
203203

204204
if weights is None:
205205
weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
206-
weights = torch.FloatTensor(weights, device=X.device, dtype=X.dtype)
206+
weights = X.new_tensor(weights)
207207

208208
if win is None:
209209
win = _fspecial_gauss_1d(win_size, win_sigma)

0 commit comments

Comments
 (0)