|
1 | 1 | # Copyright 2020 by Gongfan Fang, Zhejiang University.
|
2 | 2 | # All rights reserved.
|
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | import torch
|
5 | 6 | import torch.nn.functional as F
|
@@ -32,18 +33,24 @@ def gaussian_filter(input, win):
|
32 | 33 | Returns:
|
33 | 34 | torch.Tensor: blurred tensors
|
34 | 35 | """
|
| 36 | + assert all([ws == 1 for ws in win.shape[:-1]]), win.shape |
35 | 37 | if len(input.shape) == 4:
|
36 |
| - N, C, H, W = input.shape |
37 |
| - out = F.conv2d(input, win, stride=1, padding=0, groups=C) |
38 |
| - out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C) |
| 38 | + conv = F.conv2d |
39 | 39 | elif len(input.shape) == 5:
|
40 |
| - N, C, T, H, W = input.shape |
41 |
| - out = F.conv3d(input, weight=win, stride=1, padding=0, groups=C) |
42 |
| - out = F.conv3d(out, weight=win.transpose(3, 4), stride=1, padding=0, groups=C) |
43 |
| - out = F.conv3d(out, weight=win.transpose(2, 4), stride=1, padding=0, groups=C) |
| 40 | + conv = F.conv3d |
44 | 41 | else:
|
45 | 42 | raise NotImplementedError(input.shape)
|
46 | 43 |
|
| 44 | + C = input.shape[1] |
| 45 | + out = input |
| 46 | + for i, s in enumerate(input.shape[2:]): |
| 47 | + if s >= win.shape[-1]: |
| 48 | + out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) |
| 49 | + else: |
| 50 | + warnings.warn( |
| 51 | + f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" |
| 52 | + ) |
| 53 | + |
47 | 54 | return out
|
48 | 55 |
|
49 | 56 |
|
@@ -118,7 +125,7 @@ def ssim(
|
118 | 125 | if not X.shape == Y.shape:
|
119 | 126 | raise ValueError("Input images should have the same dimensions.")
|
120 | 127 |
|
121 |
| - for d in range(len(X.shape) -1, 1, -1): |
| 128 | + for d in range(len(X.shape) - 1, 1, -1): |
122 | 129 | X = X.squeeze(dim=d)
|
123 | 130 | Y = Y.squeeze(dim=d)
|
124 | 131 |
|
@@ -169,7 +176,7 @@ def ms_ssim(
|
169 | 176 | if not X.shape == Y.shape:
|
170 | 177 | raise ValueError("Input images should have the same dimensions.")
|
171 | 178 |
|
172 |
| - for d in range(len(X.shape) -1, 1, -1): |
| 179 | + for d in range(len(X.shape) - 1, 1, -1): |
173 | 180 | X = X.squeeze(dim=d)
|
174 | 181 | Y = Y.squeeze(dim=d)
|
175 | 182 |
|
@@ -202,7 +209,6 @@ def ms_ssim(
|
202 | 209 | win = _fspecial_gauss_1d(win_size, win_sigma)
|
203 | 210 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
|
204 | 211 |
|
205 |
| - |
206 | 212 | levels = weights.shape[0]
|
207 | 213 | mcs = []
|
208 | 214 | for i in range(levels):
|
|
0 commit comments