@@ -4437,6 +4437,84 @@ function cunntest.SpatialUpSamplingBilinear_backward_batch()
    end
 end
 
+function cunntest.UpSampling_forward_batch()
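+   -- forward: compare nn.UpSampling output on the CPU (ground truth) against the
+   -- CUDA/typename result for both 'nearest' and 'linear' modes on 4D and 5D batch inputs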
+   local minibatch = torch.random(1, 10)
+   local f = torch.random(3, 10)
+   local d = torch.random(3, 10)
+   local h = torch.random(3, 10)
+   local w = torch.random(3, 10)
+   local scale = torch.random(2,5)
+
+   for k, typename in ipairs(typenames) do
+      for _,mode in pairs({'nearest','linear'}) do
+         for dim = 4,5 do
+            local input
+            if (dim == 4) then
+               input = torch.randn(minibatch, f, h, w):type(typename)
+            else
+               input = torch.randn(minibatch, f, d, h, w):type(typename)
+            end
+
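+            -- reference output computed with the corresponding CPU type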
+            local ctype = t2cpu[typename]
+            input = makeNonContiguous(input:type(ctype))
+            local sconv = nn.UpSampling(scale, mode):type(ctype)
+            local groundtruth = sconv:forward(input)
+
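+            -- same module cloned to the CUDA type; outputs must agree within precision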
+            input = makeNonContiguous(input:type(typename))
+            local gconv = sconv:clone():type(typename)
+            local rescuda = gconv:forward(input)
+
+            local error = rescuda:double() - groundtruth:double()
+            mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+               string.format('error on state (forward) with %s', typename))
+         end
+      end
+   end
+end
+
+function cunntest.UpSampling_backward_batch()
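+   -- backward: compare gradInput from nn.UpSampling on the CPU (ground truth) against the
+   -- CUDA/typename result for both modes on 4D and 5D batch inputs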
+   local minibatch = torch.random(1, 10)
+   local f = torch.random(3, 10)
+   local d = torch.random(3, 10)
+   local h = torch.random(3, 10)
+   local w = torch.random(3, 10)
+   local scale = torch.random(2,4)
+
+   for k, typename in ipairs(typenames) do
+      for _,mode in pairs({'nearest','linear'}) do
+         for dim = 4,5 do
+            local input, gradOutput
+            if (dim == 4) then
+               input = torch.randn(minibatch, f, h, w):type(typename)
+               gradOutput = torch.randn(minibatch, f, h * scale, w * scale):type(typename)
+            else
+               input = torch.randn(minibatch, f, d, h, w):type(typename)
+               gradOutput = torch.randn(minibatch, f, d * scale, h * scale, w * scale):type(typename)
+            end
+
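+            -- reference gradInput computed on the CPU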
+            local ctype = t2cpu[typename]
+            input = makeNonContiguous(input:type(ctype))
+            gradOutput = makeNonContiguous(gradOutput:type(ctype))
+            local sconv = nn.UpSampling(scale, mode):type(ctype)
+            sconv:forward(input)
+            sconv:zeroGradParameters()
+            local groundgrad = sconv:backward(input, gradOutput)
+
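+            -- gradInput from the cloned module on the CUDA type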
+            input = makeNonContiguous(input:type(typename))
+            gradOutput = makeNonContiguous(gradOutput:type(typename))
+            local gconv = sconv:clone():type(typename)
+            gconv:forward(input)
+            gconv:zeroGradParameters()
+            local rescuda = gconv:backward(input, gradOutput)
+
+            local error = rescuda:double() - groundgrad:double()
+            mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+               string.format('error on state (backward) with %s', typename))
+         end
+      end
+   end
+end
+
 function cunntest.l1cost()
    local size = math.random(300,500)
 