Skip to content

Commit ed038ca

Browse files
committed
updated tests and added test for enhance method
1 parent 0a2fadf commit ed038ca

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

tests/test_models.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,62 @@
33
import torch
44
from torch_enhance import models
55

6+
67
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
78
SCALE_FACTOR = 2
8-
x = torch.ones(1, 3, 32, 32)
9-
x = x.to(torch.float32)
10-
x = x.to(DEVICE)
9+
CHANNELS = 3
10+
lr = torch.ones(1, CHANNELS, 32, 32)
11+
lr = lr.to(torch.float32)
12+
lr = lr.to(DEVICE)
1113

1214
def test_bicubic():
13-
model = models.Bicubic(scale_factor=SCALE_FACTOR)
15+
model = models.Bicubic(scale_factor=SCALE_FACTOR, channels=CHANNELS)
1416
model = model.to(DEVICE)
15-
y_pred = model(x)
16-
assert y_pred.shape == (1, 3, 64, 64)
17-
assert y_pred.dtype == torch.float32
17+
sr = model(lr)
18+
assert sr.shape == (1, 3, 64, 64)
19+
assert sr.dtype == torch.float32
1820

1921
def test_edsr():
20-
model = models.EDSR(scale_factor=SCALE_FACTOR)
22+
model = models.EDSR(scale_factor=SCALE_FACTOR, channels=CHANNELS)
2123
model = model.to(DEVICE)
22-
y_pred = model(x)
23-
assert y_pred.shape == (1, 3, 64, 64)
24-
assert y_pred.dtype == torch.float32
24+
sr = model(lr)
25+
assert sr.shape == (1, 3, 64, 64)
26+
assert sr.dtype == torch.float32
2527

2628
def test_espcn():
27-
model = models.ESPCN(scale_factor=SCALE_FACTOR)
29+
model = models.ESPCN(scale_factor=SCALE_FACTOR, channels=CHANNELS)
2830
model = model.to(DEVICE)
29-
y_pred = model(x)
30-
assert y_pred.shape == (1, 3, 64, 64)
31-
assert y_pred.dtype == torch.float32
31+
sr = model(lr)
32+
assert sr.shape == (1, 3, 64, 64)
33+
assert sr.dtype == torch.float32
3234

3335
def test_srcnn():
34-
model = models.SRCNN(scale_factor=SCALE_FACTOR)
36+
model = models.SRCNN(scale_factor=SCALE_FACTOR, channels=CHANNELS)
3537
model = model.to(DEVICE)
36-
y_pred = model(x)
37-
assert y_pred.shape == (1, 3, 64, 64)
38-
assert y_pred.dtype == torch.float32
38+
sr = model(lr)
39+
assert sr.shape == (1, 3, 64, 64)
40+
assert sr.dtype == torch.float32
3941

4042
def test_srresnet():
41-
model = models.SRResNet(scale_factor=SCALE_FACTOR)
43+
model = models.SRResNet(scale_factor=SCALE_FACTOR, channels=CHANNELS)
4244
model = model.to(DEVICE)
43-
y_pred = model(x)
44-
assert y_pred.shape == (1, 3, 64, 64)
45-
assert y_pred.dtype == torch.float32
45+
sr = model(lr)
46+
assert sr.shape == (1, 3, 64, 64)
47+
assert sr.dtype == torch.float32
4648

4749
def test_vdsr():
48-
model = models.VDSR(scale_factor=SCALE_FACTOR)
50+
model = models.VDSR(scale_factor=SCALE_FACTOR, channels=CHANNELS)
51+
model = model.to(DEVICE)
52+
sr = model(lr)
53+
assert sr.shape == (1, 3, 64, 64)
54+
assert sr.dtype == torch.float32
55+
56+
def test_enhance():
57+
model = models.SRCNN(scale_factor=SCALE_FACTOR, channels=CHANNELS)
4958
model = model.to(DEVICE)
50-
y_pred = model(x)
51-
assert y_pred.shape == (1, 3, 64, 64)
52-
assert y_pred.dtype == torch.float32
59+
sr = model.enhance(lr)
60+
assert sr.shape == (64, 64, 3)
61+
assert sr.dtype == torch.uint8
62+
63+
sr = model.enhance(lr.squeeze(0))
64+
assert sr.shape == (64, 64, 3)

0 commit comments

Comments
 (0)