Skip to content

Commit ec34e55

Browse files
committed
allow, but warn for too small input to Gaussian filter
1 parent e5bc574 commit ec34e55

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

pytorch_msssim/ssim.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright 2020 by Gongfan Fang, Zhejiang University.
22
# All rights reserved.
3+
import warnings
34

45
import torch
56
import torch.nn.functional as F
@@ -32,18 +33,24 @@ def gaussian_filter(input, win):
3233
Returns:
3334
torch.Tensor: blurred tensors
3435
"""
36+
assert all([ws == 1 for ws in win.shape[:-1]]), win.shape
3537
if len(input.shape) == 4:
36-
N, C, H, W = input.shape
37-
out = F.conv2d(input, win, stride=1, padding=0, groups=C)
38-
out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C)
38+
conv = F.conv2d
3939
elif len(input.shape) == 5:
40-
N, C, T, H, W = input.shape
41-
out = F.conv3d(input, weight=win, stride=1, padding=0, groups=C)
42-
out = F.conv3d(out, weight=win.transpose(3, 4), stride=1, padding=0, groups=C)
43-
out = F.conv3d(out, weight=win.transpose(2, 4), stride=1, padding=0, groups=C)
40+
conv = F.conv3d
4441
else:
4542
raise NotImplementedError(input.shape)
4643

44+
C = input.shape[1]
45+
out = input
46+
for i, s in enumerate(input.shape[2:]):
47+
if s >= win.shape[-1]:
48+
out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
49+
else:
50+
warnings.warn(
51+
f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
52+
)
53+
4754
return out
4855

4956

@@ -118,7 +125,7 @@ def ssim(
118125
if not X.shape == Y.shape:
119126
raise ValueError("Input images should have the same dimensions.")
120127

121-
for d in range(len(X.shape) -1, 1, -1):
128+
for d in range(len(X.shape) - 1, 1, -1):
122129
X = X.squeeze(dim=d)
123130
Y = Y.squeeze(dim=d)
124131

@@ -169,7 +176,7 @@ def ms_ssim(
169176
if not X.shape == Y.shape:
170177
raise ValueError("Input images should have the same dimensions.")
171178

172-
for d in range(len(X.shape) -1, 1, -1):
179+
for d in range(len(X.shape) - 1, 1, -1):
173180
X = X.squeeze(dim=d)
174181
Y = Y.squeeze(dim=d)
175182

@@ -202,7 +209,6 @@ def ms_ssim(
202209
win = _fspecial_gauss_1d(win_size, win_sigma)
203210
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
204211

205-
206212
levels = weights.shape[0]
207213
mcs = []
208214
for i in range(levels):

0 commit comments

Comments
 (0)