3
3
import torch
4
4
from torch_enhance import models
5
5
6
+
6
7
DEVICE = "cuda" if torch .cuda .is_available () else "cpu"
7
8
SCALE_FACTOR = 2
8
- x = torch .ones (1 , 3 , 32 , 32 )
9
- x = x .to (torch .float32 )
10
- x = x .to (DEVICE )
9
+ CHANNELS = 3
10
+ lr = torch .ones (1 , CHANNELS , 32 , 32 )
11
+ lr = lr .to (torch .float32 )
12
+ lr = lr .to (DEVICE )
11
13
12
14
def test_bicubic ():
13
- model = models .Bicubic (scale_factor = SCALE_FACTOR )
15
+ model = models .Bicubic (scale_factor = SCALE_FACTOR , channels = CHANNELS )
14
16
model = model .to (DEVICE )
15
- y_pred = model (x )
16
- assert y_pred .shape == (1 , 3 , 64 , 64 )
17
- assert y_pred .dtype == torch .float32
17
+ sr = model (lr )
18
+ assert sr .shape == (1 , 3 , 64 , 64 )
19
+ assert sr .dtype == torch .float32
18
20
19
21
def test_edsr ():
20
- model = models .EDSR (scale_factor = SCALE_FACTOR )
22
+ model = models .EDSR (scale_factor = SCALE_FACTOR , channels = CHANNELS )
21
23
model = model .to (DEVICE )
22
- y_pred = model (x )
23
- assert y_pred .shape == (1 , 3 , 64 , 64 )
24
- assert y_pred .dtype == torch .float32
24
+ sr = model (lr )
25
+ assert sr .shape == (1 , 3 , 64 , 64 )
26
+ assert sr .dtype == torch .float32
25
27
26
28
def test_espcn ():
27
- model = models .ESPCN (scale_factor = SCALE_FACTOR )
29
+ model = models .ESPCN (scale_factor = SCALE_FACTOR , channels = CHANNELS )
28
30
model = model .to (DEVICE )
29
- y_pred = model (x )
30
- assert y_pred .shape == (1 , 3 , 64 , 64 )
31
- assert y_pred .dtype == torch .float32
31
+ sr = model (lr )
32
+ assert sr .shape == (1 , 3 , 64 , 64 )
33
+ assert sr .dtype == torch .float32
32
34
33
35
def test_srcnn ():
34
- model = models .SRCNN (scale_factor = SCALE_FACTOR )
36
+ model = models .SRCNN (scale_factor = SCALE_FACTOR , channels = CHANNELS )
35
37
model = model .to (DEVICE )
36
- y_pred = model (x )
37
- assert y_pred .shape == (1 , 3 , 64 , 64 )
38
- assert y_pred .dtype == torch .float32
38
+ sr = model (lr )
39
+ assert sr .shape == (1 , 3 , 64 , 64 )
40
+ assert sr .dtype == torch .float32
39
41
40
42
def test_srresnet ():
41
- model = models .SRResNet (scale_factor = SCALE_FACTOR )
43
+ model = models .SRResNet (scale_factor = SCALE_FACTOR , channels = CHANNELS )
42
44
model = model .to (DEVICE )
43
- y_pred = model (x )
44
- assert y_pred .shape == (1 , 3 , 64 , 64 )
45
- assert y_pred .dtype == torch .float32
45
+ sr = model (lr )
46
+ assert sr .shape == (1 , 3 , 64 , 64 )
47
+ assert sr .dtype == torch .float32
46
48
47
49
def test_vdsr ():
48
- model = models .VDSR (scale_factor = SCALE_FACTOR )
50
+ model = models .VDSR (scale_factor = SCALE_FACTOR , channels = CHANNELS )
51
+ model = model .to (DEVICE )
52
+ sr = model (lr )
53
+ assert sr .shape == (1 , 3 , 64 , 64 )
54
+ assert sr .dtype == torch .float32
55
+
56
+ def test_enhance ():
57
+ model = models .SRCNN (scale_factor = SCALE_FACTOR , channels = CHANNELS )
49
58
model = model .to (DEVICE )
50
- y_pred = model (x )
51
- assert y_pred .shape == (1 , 3 , 64 , 64 )
52
- assert y_pred .dtype == torch .float32
59
+ sr = model .enhance (lr )
60
+ assert sr .shape == (64 , 64 , 3 )
61
+ assert sr .dtype == torch .uint8
62
+
63
+ sr = model .enhance (lr .squeeze (0 ))
64
+ assert sr .shape == (64 , 64 , 3 )
0 commit comments