Skip to content

Commit 528a871

Browse files
committed
Added cunn tests for UpSampling module.
1 parent e9d54e1 commit 528a871

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

test.lua

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4437,6 +4437,84 @@ function cunntest.SpatialUpSamplingBilinear_backward_batch()
44374437
end
44384438
end
44394439

4440+
function cunntest.UpSampling_forward_batch()
4441+
local minibatch = torch.random(1, 10)
4442+
local f = torch.random(3, 10)
4443+
local d = torch.random(3, 10)
4444+
local h = torch.random(3, 10)
4445+
local w = torch.random(3, 10)
4446+
local scale = torch.random(2,5)
4447+
4448+
for k, typename in ipairs(typenames) do
4449+
for _,mode in pairs({'nearest','linear'}) do
4450+
for dim = 4,5 do
4451+
local input
4452+
if (dim == 4) then
4453+
input = torch.randn(minibatch, f, h, w):type(typename)
4454+
else
4455+
input = torch.randn(minibatch, f, d, h, w):type(typename)
4456+
end
4457+
4458+
local ctype = t2cpu[typename]
4459+
input = makeNonContiguous(input:type(ctype))
4460+
local sconv = nn.UpSampling(scale, mode):type(ctype)
4461+
local groundtruth = sconv:forward(input)
4462+
4463+
input = makeNonContiguous(input:type(typename))
4464+
local gconv = sconv:clone():type(typename)
4465+
local rescuda = gconv:forward(input)
4466+
4467+
local error = rescuda:double() - groundtruth:double()
4468+
mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
4469+
string.format('error on state (forward) with %s', typename))
4470+
end
4471+
end
4472+
end
4473+
end
4474+
4475+
function cunntest.UpSampling_backward_batch()
4476+
local minibatch = torch.random(1, 10)
4477+
local f = torch.random(3, 10)
4478+
local d = torch.random(3, 10)
4479+
local h = torch.random(3, 10)
4480+
local w = torch.random(3, 10)
4481+
local scale = torch.random(2,4)
4482+
4483+
for k, typename in ipairs(typenames) do
4484+
for _,mode in pairs({'nearest','linear'}) do
4485+
for dim = 4,5 do
4486+
local input, gradOutput
4487+
if (dim == 4) then
4488+
input = torch.randn(minibatch, f, h, w):type(typename)
4489+
gradOutput = torch.randn(minibatch, f, h*scale, w*scale):type(typename)
4490+
else
4491+
input = torch.randn(minibatch, f, d, h, w):type(typename)
4492+
gradOutput = torch.randn(minibatch, f, d*scale, h*scale, w*scale):type(typename)
4493+
end
4494+
4495+
local ctype = t2cpu[typename]
4496+
input = makeNonContiguous(input:type(ctype))
4497+
gradOutput = makeNonContiguous(gradOutput:type(ctype))
4498+
local sconv = nn.UpSampling(scale, mode):type(ctype)
4499+
sconv:forward(input)
4500+
sconv:zeroGradParameters()
4501+
local groundgrad = sconv:backward(input, gradOutput)
4502+
4503+
input = makeNonContiguous(input:type(typename))
4504+
gradOutput = makeNonContiguous(gradOutput:type(typename))
4505+
local gconv = sconv:clone():type(typename)
4506+
gconv:forward(input)
4507+
gconv:zeroGradParameters()
4508+
local rescuda = gconv:backward(input, gradOutput)
4509+
4510+
local error = rescuda:double() - groundgrad:double()
4511+
mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
4512+
string.format('error on state (backward) with %s', typename))
4513+
end
4514+
end
4515+
end
4516+
end
4517+
44404518
function cunntest.l1cost()
44414519
local size = math.random(300,500)
44424520

0 commit comments

Comments
 (0)