
Commit 4af40e3

apaszke authored and soumith committed
Let parallel_apply accept arbitrary inputs
1 parent f417cb0 commit 4af40e3

3 files changed: +26 additions, -20 deletions

torch/nn/parallel/data_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs)
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids)
 
     def gather(self, outputs, output_device):
         return gather(outputs, output_device, dim=self.dim)
@@ -101,5 +101,5 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     if len(device_ids) == 1:
         return module(*inputs[0], **module_kwargs[0])
     replicas = replicate(module, device_ids[:len(inputs)])
-    outputs = parallel_apply(replicas, inputs, module_kwargs)
+    outputs = parallel_apply(replicas, inputs, module_kwargs, device_ids)
     return gather(outputs, output_device, dim)
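
For orientation, a minimal usage sketch of the functional data_parallel entry point touched above. The module, tensor shape, and two-GPU setup are illustrative assumptions, not part of the commit; the point is that device_ids is now also handed to parallel_apply, so each replica runs under an explicitly chosen device.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

# Illustrative setup; assumes at least two visible CUDA devices.
model = nn.Linear(10, 5).cuda(0)
x = torch.randn(8, 10).cuda(0)

# data_parallel scatters the inputs across device_ids, replicates the module,
# runs the replicas via parallel_apply (now receiving device_ids), and gathers
# the per-device outputs back onto output_device.
out = data_parallel(model, (x,), device_ids=[0, 1], output_device=0)  # shape (8, 5)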

torch/nn/parallel/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs)
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids)
 
     def gather(self, outputs, output_device):
         return gather(outputs, output_device, dim=self.dim)
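
Both overrides above are called from the module's forward pass; a rough sketch of that call sequence follows as a reading aid. forward itself is untouched by this commit, so treat the exact body as an assumption about the surrounding code rather than part of the diff.

# Approximate shape of DataParallel.forward (assumed context, not in this diff):
def forward(self, *inputs, **kwargs):
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    if len(self.device_ids) == 1:
        return self.module(*inputs[0], **kwargs[0])
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    # This is the call whose override now forwards self.device_ids.
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)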

torch/nn/parallel/parallel_apply.py

Lines changed: 23 additions & 17 deletions
@@ -20,40 +20,46 @@ def get_a_var(obj):
     return None
 
 
-def parallel_apply(modules, inputs, kwargs_tup=None):
+def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     assert len(modules) == len(inputs)
-    if kwargs_tup:
+    if kwargs_tup is not None:
         assert len(modules) == len(kwargs_tup)
     else:
         kwargs_tup = ({},) * len(modules)
-    # Fast track
-    if len(modules) == 1:
-        return (modules[0](*inputs[0], **kwargs_tup[0]), )
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
 
     lock = threading.Lock()
     results = {}
 
-    def _worker(i, module, input, kwargs, results, lock):
-        var_input = get_a_var(input)
+    def _worker(i, module, input, kwargs, results, lock, device=None):
+        if device is None:
+            device = get_a_var(input).get_device()
         try:
-            with torch.cuda.device_of(var_input):
+            with torch.cuda.device(device):
                 output = module(*input, **kwargs)
             with lock:
                 results[i] = output
         except Exception as e:
             with lock:
                 results[i] = e
 
-    threads = [threading.Thread(target=_worker,
-                                args=(i, module, input, kwargs, results, lock),
-                                )
-               for i, (module, input, kwargs) in
-               enumerate(zip(modules, inputs, kwargs_tup))]
+    if len(modules) > 1:
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, kwargs, results, lock, device),
+                                    )
+                   for i, (module, input, kwargs, device) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], results, lock, devices[0])
 
-    for thread in threads:
-        thread.start()
-    for thread in threads:
-        thread.join()
     outputs = []
     for i in range(len(inputs)):
         output = results[i]
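
To tie the three files together, here is a hedged sketch of driving the revised free function directly, mirroring what data_parallel does internally. The imports assume torch.nn.parallel exports replicate, scatter, parallel_apply, and gather; the two-GPU setup and the (chunk,) wrapping of each scattered input are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [0, 1]                                # assumes two CUDA devices
module = nn.Linear(10, 5).cuda(device_ids[0])
x = torch.randn(8, 10).cuda(device_ids[0])

# Scatter the batch and wrap each chunk as a one-element argument tuple,
# since the worker invokes module(*input); then replicate onto each device.
inputs = [(chunk,) for chunk in scatter(x, device_ids)]
replicas = replicate(module, device_ids[:len(inputs)])

# New in this commit: an explicit devices argument, so each worker thread runs
# under torch.cuda.device(device) instead of inferring the device from its input.
outputs = parallel_apply(replicas, inputs, None, device_ids)
result = gather(outputs, device_ids[0])            # gathered on GPU 0, shape (8, 5)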
