Commit 95f351f

Zero gradients on secondary GPUs after accGradParameters
Before this change, calling forward & backward multiple times without calling zeroGradParameters would compute an incorrect gradient. Fixes #209
1 parent c79b0c0 commit 95f351f
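
For context, the failure mode described above arises in gradient-accumulation loops. A minimal sketch of that pattern (my own illustration, not part of this commit; the module, tensor shapes, and GPU ids are assumptions):

   -- Hypothetical usage pattern: accumulate gradients over several
   -- mini-batches, then apply a single parameter update.
   require 'cunn'

   local dpt = nn.DataParallelTable(1)
      :add(nn.Linear(3, 7):cuda(), {1, 2})  -- assumes at least two GPUs

   dpt:zeroGradParameters()
   for i = 1, 3 do
      local input = torch.randn(8, 3):cuda()
      local gradOutput = torch.randn(8, 7):cuda()
      dpt:forward(input)
      dpt:backward(input, gradOutput)  -- each call accumulates into gradParameters
   end
   dpt:updateParameters(0.1)

Before this change, gradients that had already been reduced onto the base GPU were left in place on the secondary GPUs, so later passes would fold those stale values in again, which is presumably why the accumulated gradient came out wrong.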

File tree: 2 files changed (+50 −1 lines)

DataParallelTable.lua

Lines changed: 9 additions & 1 deletion
@@ -216,13 +216,21 @@ function DataParallelTable:__backward(method, input, gradOutput, scale)
    end
 
    if method == 'backward' or method == 'accGradParameters' then
+      local params = self:moduleParameters()
       -- Accumulate the gradients onto the base GPU
       if self.flattenedParams and self.usenccl and not cudaLaunchBlocking then
          if #self.gpuAssignments > 1 then
             nccl.reduce(pluck(self.flattenedParams, 2), nil, true, 1)
          end
       else
-         self:_reduce(pluck(self:moduleParameters(), 2))
+         self:_reduce(pluck(params, 2))
+      end
+      -- Zero out gradients on the other GPUs
+      for i = 2, #self.gpuAssignments do
+         cutorch.setDevice(self.gpuAssignments[i])
+         for _, gradParam in ipairs(params[i][2]) do
+            gradParam:zero()
+         end
       end
       self.needsSync = true
    end
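
To make the new zeroing loop concrete, here is a rough plain-tensor sketch (my own, not code from this repository) of the double counting it prevents: the reduce step folds each secondary GPU's gradient buffer into the base GPU's buffer, and if the secondary buffer is not cleared afterwards, its stale contents get folded in again on the next pass.

   -- Hypothetical sketch; two plain tensors stand in for per-GPU gradient buffers.
   local base = torch.zeros(5)       -- gradient buffer on the base GPU
   local secondary = torch.zeros(5)  -- gradient buffer on a secondary GPU

   local function backwardPass(g1, g2)
      base:add(g1)         -- accGradParameters on the base GPU
      secondary:add(g2)    -- accGradParameters on the secondary GPU
      base:add(secondary)  -- the reduce step: fold secondary grads onto the base GPU
      -- without a secondary:zero() here, g2 is counted again on the next pass
   end

   backwardPass(torch.ones(5), torch.ones(5))
   backwardPass(torch.ones(5), torch.ones(5))
   print(base[1])  -- prints 5 without the zeroing; the expected total is 4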

test_DataParallelTable.lua

Lines changed: 41 additions & 0 deletions
@@ -397,6 +397,47 @@ function test.DataParallelTable_noGradInput()
       'backward prop error')
 end
 
+function test.DataParallelTable_accGradParameters()
+   local net = nn.Sequential()
+      :add(nn.Linear(3, 10))
+      :add(nn.ReLU())
+      :add(nn.Linear(10, 7))
+      :cuda()
+
+   local inputs = {}
+   local gradOutputs = {}
+   for i=1,3 do
+      inputs[i] = torch.randn(8, 3):cuda()
+      gradOutputs[i] = torch.randn(8, 7):cuda()
+   end
+
+   local configs = {
+      {1, false, false},
+      {1, true, false},
+   }
+
+   local function accumulateGradient(m)
+      m:zeroGradParameters()
+      for i=1,#inputs do
+         m:forward(inputs[i])
+         m:backward(inputs[i], gradOutputs[i])
+      end
+      m:updateParameters(0.5)
+   end
+
+   local base = net:clone()
+   accumulateGradient(base)
+   local expected = base:forward(inputs[1])
+
+   for _, config in ipairs(configs) do
+      local dpt = nn.DataParallelTable(1, true, false)
+         :add(net:clone(), torch.range(1, numGpus):totable())
+      accumulateGradient(dpt)
+      local output = dpt:forward(inputs[1])
+      mytester:assertlt((output - expected):abs():max(), 1e-5, 'invalid output')
+   end
+end
+
 function test.DataParallelTable_streams()
    local net = nn.Sequential()
       :add(nn.Linear(3, 10))
