
Commit 60736bd

fix corner case in kwargs for DataParallel (pytorch#930)

Parent: 7d58765

2 files changed, 11 insertions(+), 9 deletions(-)

test/test_nn.py (3 additions, 0 deletions)

@@ -869,6 +869,9 @@ def local_test(out):
         out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
         local_test(out)
 
+        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
+        local_test(out)
+
         out = dp.data_parallel(m, (var1, var2, float1), (0,))
         local_test(out)

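For context, the reversed-order call added here can be sketched in isolation. Only the `dp.data_parallel(m, (var1, var2, float1), (1, 0))` call pattern is taken from the diff; everything else below is a hypothetical stand-in for the test's real `m`, `var1`, `var2`, and `float1`, which are defined earlier in `test_nn.py`, and running it requires at least two CUDA devices.

import torch
import torch.nn as nn
import torch.nn.parallel as dp

# Hypothetical stand-in for the test's module `m`: it takes two tensors
# plus a plain Python float, mirroring the (var1, var2, float1) tuple.
class PairModule(nn.Module):
    def forward(self, a, b, scale):
        return (a + b) * scale

if torch.cuda.device_count() >= 2:
    m = PairModule()
    var1 = torch.randn(4, 3).cuda()
    var2 = torch.randn(4, 3).cuda()
    # Device ids in reversed order: the input and kwargs chunks handed to
    # each replica must follow the scatter order, not a hard-coded (0, 1)
    # layout.
    out = dp.data_parallel(m, (var1, var2, 2.0), (1, 0))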
torch/nn/parallel/data_parallel.py (8 additions, 9 deletions)

@@ -78,17 +78,16 @@ def _to_cuda(obj):
 
         replicas = self.replicate(self.module, self.device_ids)
         scattered = self.scatter(inputs, self.device_ids)
-
+        used_gpus = len(scattered)  # The last GPU might not be used. For example, input of size 4, on 5 GPUs
        gpu_dicts = None
-        if kwargs:
-            scatter_kwargs = {}
+        if kwargs is not None:
+            gpu_dicts = [{} for i in range(used_gpus)]
             for key in kwargs.keys():
-                scatter_kwargs[key] = self.scatter(
-                    _to_cuda(kwargs[key]), self.device_ids)
-                gpu_dicts = tuple(
-                    {key: values[i] for key, values in scatter_kwargs.items()}
-                    for i in self.device_ids
-                )
+                scattered_kwargs = self.scatter(_to_cuda(kwargs[key]), self.device_ids)
+                assert len(scattered_kwargs) == used_gpus
+                for i in range(used_gpus):
+                    gpu_dicts[i][key] = scattered_kwargs[i]
+
         replicas = replicas[:len(scattered)]
         outputs = self.parallel_apply(replicas, scattered, gpu_dicts)
         return self.gather(outputs, self.output_device)

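The comment added alongside `used_gpus` names the corner case: with an input of size 4 on 5 GPUs, `scatter` yields only 4 chunks, so the kwargs dictionaries must be built per produced chunk rather than per device id. A minimal self-contained sketch of that bookkeeping, using a hypothetical `scatter` helper in place of the module's own `scatter` method:

import math

def scatter(seq, device_ids):
    # Hypothetical stand-in for Module.scatter: split `seq` into at most
    # one chunk per device; a short input leaves trailing devices unused.
    chunk = int(math.ceil(len(seq) / float(len(device_ids))))
    return [seq[i:i + chunk] for i in range(0, len(seq), chunk)]

def scatter_with_kwargs(inputs, kwargs, device_ids):
    # Mirrors the fixed logic: size gpu_dicts to the number of chunks
    # actually produced (used_gpus), not to len(device_ids).
    scattered = scatter(inputs, device_ids)
    used_gpus = len(scattered)  # may be smaller than len(device_ids)
    gpu_dicts = None
    if kwargs is not None:
        gpu_dicts = [{} for _ in range(used_gpus)]
        for key in kwargs:
            scattered_kwargs = scatter(kwargs[key], device_ids)
            assert len(scattered_kwargs) == used_gpus
            for i in range(used_gpus):
                gpu_dicts[i][key] = scattered_kwargs[i]
    return scattered, gpu_dicts

# Input of size 4 across 5 "GPUs": only 4 chunks exist, and gpu_dicts
# holds exactly one kwargs dict per used replica.
scattered, gpu_dicts = scatter_with_kwargs(
    list(range(4)), {"mask": list(range(4))}, device_ids=[0, 1, 2, 3, 4])
assert len(scattered) == len(gpu_dicts) == 4

The replaced code built `gpu_dicts` with `for i in self.device_ids`, which indexes past the end of the scattered values whenever a trailing GPU receives no chunk, and pairs chunks with the wrong replica when the device order is not (0, 1, ...).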