Skip to content

Commit 7296c27

Browse files
committed
Fix test_DataParallelTable.lua
1 parent 64d36c2 commit 7296c27

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

test_DataParallelTable.lua

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,13 @@ function test.DataParallelTable_noGradInput()
409409
local input = torch.Tensor(5):random(10):cuda()
410410
local output1 = net:forward(input):clone()
411411
local gradOutput = output1:clone():uniform(-1, 1)
412-
local gradInput1 = net:backward(output1, gradOutput):clone()
412+
local gradInput1 = net:backward(input, gradOutput):clone()
413413

414414
local output2 = dpt:forward(input)
415-
local gradInput2 = dpt:backward(output2, gradOutput)
416-
mytester:assert((output1 - output2):abs():max(), precision,
415+
local gradInput2 = dpt:backward(input, gradOutput)
416+
mytester:assertlt((output1 - output2):abs():max(), precision,
417417
'forward prop error')
418-
mytester:assert(gradInput2:nElement() == 0 and gradInput1:nElement() == 0,
419-
'backward prop error')
418+
mytester:asserteq(gradInput2:nElement(), gradInput1:nElement())
420419
end
421420

422421
function test.DataParallelTable_accGradParameters()

0 commit comments

Comments
 (0)