Commit e7f5220

device_ids can be None again in data_parallel (pytorch#1187)

1 parent: a7ae04a
File tree

2 files changed: +7 -1 lines changed

test/test_nn.py

Lines changed: 3 additions & 0 deletions
@@ -996,6 +996,9 @@ def test_data_parallel(self):
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out.data, expected_out)

+        # Check for None device_ids
+        out = dp.data_parallel(l, i)
+
    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
        def fn(input):

torch/nn/parallel/data_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -74,7 +74,7 @@ def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


-def data_parallel(module, inputs, device_ids, output_device=None, dim=0, module_kwargs=None):
+def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    """Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.
@@ -92,6 +92,9 @@ def data_parallel(module, inputs, device_ids, output_device=None, dim=0, module_
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

+    if device_ids is None:
+        device_ids = list(range(torch.cuda.device_count()))
+
    if output_device is None:
        output_device = device_ids[0]
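
With this change, omitting device_ids makes the functional data_parallel spread the module over every visible CUDA device. A minimal usage sketch (assuming a modern PyTorch with at least one CUDA GPU; the module and input below are illustrative, not taken from the commit):

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

# Illustrative module and input, placed on GPU 0.
module = nn.Linear(10, 5).cuda(0)
inputs = torch.randn(8, 10, device="cuda:0")

# device_ids omitted: with this commit it defaults to
# list(range(torch.cuda.device_count())), i.e. all visible GPUs.
out = data_parallel(module, inputs)

# Equivalent explicit call.
out = data_parallel(module, inputs,
                    device_ids=list(range(torch.cuda.device_count())))

The output is gathered on device_ids[0] (GPU 0 here), following the output_device = device_ids[0] default shown in the diff above.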
