
Commit 65b6626 (parent 34ce58c)

Improve broadcast/reduce performance by coalescing tensors

3 files changed (+175, -46 lines)

torch/cuda/comm.py

Lines changed: 93 additions & 0 deletions

@@ -28,6 +28,35 @@ def broadcast(tensor, devices):
     return tuple(tensor.cuda(gpu, async=True) for gpu in devices)


+def broadcast_coalesced(tensors, devices, buffer_size=10485760):
+    """Broadcasts a sequence of tensors to the specified GPUs.
+
+    Small tensors are first coalesced into a buffer to reduce the number
+    of synchronizations.
+
+    Arguments:
+        tensors (sequence): tensors to broadcast.
+        devices (Iterable): an iterable of devices to which to broadcast.
+        buffer_size (int): maximum size of the buffer used for coalescing.
+
+    Returns:
+        A tuple containing copies of the ``tensors``, placed on devices
+        corresponding to indices from ``devices``.
+    """
+    for tensor in tensors:
+        if tensor.get_device() != devices[0]:
+            raise RuntimeError('all tensors must be on devices[0]')
+    outputs = [[] for _ in devices]
+    # use the original tensors for the first device
+    outputs[0].extend(tensors)
+    for chunk in _take_tensors(tensors, buffer_size):
+        results = broadcast(_flatten_tensors(chunk), devices)
+        # use the broadcasted tensors for the remaining devices
+        for dst, res in zip(outputs[1:], results[1:]):
+            dst.extend(_unflatten_tensors(res, chunk))
+    return tuple(outputs)
+
+
 def reduce_add(inputs, destination=None):
     """Sums tensors from multiple GPUs.

@@ -68,6 +97,31 @@ def reduce_add(inputs, destination=None):
     return result


+def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
+    """Sums tensors from multiple GPUs.
+
+    Small tensors are first coalesced into a buffer to reduce the number
+    of synchronizations.
+
+    Arguments:
+        inputs (Iterable[Iterable[Tensor]]): iterables of tensors to add, one per device.
+        destination (int, optional): a device on which the output will be
+            placed (default: current device).
+        buffer_size (int): maximum size of the buffer used for coalescing.
+
+    Returns:
+        A tuple of tensors containing an elementwise sum of each group of
+        inputs, placed on the ``destination`` device.
+    """
+    output = []
+    itrs = [_take_tensors(tensors, buffer_size) for tensors in inputs]
+    for chunks in zip(*itrs):
+        flattened = [_flatten_tensors(chunk) for chunk in chunks]
+        result = reduce_add(flattened, destination)
+        output.extend(_unflatten_tensors(result, chunks[0]))
+    return tuple(output)
+
+
 def scatter(tensor, devices, chunk_sizes=None, dim=0):
     """Scatters tensor across multiple GPUs.

@@ -142,3 +196,42 @@ def gather(tensors, dim=0, destination=None):
         result.narrow(dim, chunk_start, tensor.size(dim)).copy_(tensor, True)
         chunk_start += tensor.size(dim)
     return result
+
+
+def _flatten_tensors(tensors):
+    """Flatten tensors into a single contiguous 1D buffer."""
+    if len(tensors) == 1:
+        return tensors[0].contiguous().view(-1)
+    size = sum(tensor.numel() for tensor in tensors)
+    offset = 0
+    flat = tensors[0].new(size)
+    for tensor in tensors:
+        flat.narrow(0, offset, tensor.numel()).copy_(tensor)
+        offset += tensor.numel()
+    return flat
+
+
+def _unflatten_tensors(flat, tensors):
+    """View a flat buffer using the sizes of tensors."""
+    outputs = []
+    offset = 0
+    for tensor in tensors:
+        outputs.append(flat.narrow(0, offset, tensor.numel()).view_as(tensor))
+        offset += tensor.numel()
+    return tuple(outputs)
+
+
+def _take_tensors(tensors, size_limit):
+    """Group tensors into lists of up to size_limit bytes."""
+    buf = []
+    size = 0
+    for tensor in tensors:
+        param_size = tensor.numel() * tensor.element_size()
+        if size + param_size > size_limit and size > 0:
+            yield buf
+            size = 0
+            buf = []
+        buf.append(tensor)
+        size += param_size
+    if len(buf) > 0:
+        yield buf
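
For illustration, a minimal sketch of the coalescing round trip through the public entry point added above, assuming two visible GPUs (tensor sizes are arbitrary):

    import torch
    import torch.cuda.comm as comm

    # two small tensors living on GPU 0
    tensors = [torch.randn(4, 4).cuda(0), torch.randn(16).cuda(0)]

    # _take_tensors groups them into chunks of at most buffer_size bytes,
    # _flatten_tensors packs each chunk into one contiguous 1D buffer,
    # a single broadcast per chunk copies it to every device, and
    # _unflatten_tensors restores the original shapes as views of the buffer.
    copies = comm.broadcast_coalesced(tensors, [0, 1])

    for device_tensors in copies:
        print([tuple(t.size()) for t in device_tensors])
    # each per-device replica mirrors the input shapes: [(4, 4), (16,)]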

torch/nn/parallel/_functions.py

Lines changed: 13 additions & 8 deletions

@@ -1,4 +1,3 @@
-import torch.cuda
 import torch.cuda.comm as comm
 from torch.autograd import Function

@@ -9,13 +8,19 @@ def __init__(self, target_gpus):
         super(Broadcast, self).__init__()
         self.target_gpus = target_gpus

-    def forward(self, input):
-        assert input.is_cuda, "Broadcast function not implemented for CPU tensors"
-        self.input_device = input.get_device()
-        return comm.broadcast(input, self.target_gpus)
-
-    def backward(self, *grad_output):
-        return comm.reduce_add(grad_output, self.input_device)
+    def forward(self, *inputs):
+        if not all(input.is_cuda for input in inputs):
+            raise TypeError('Broadcast function not implemented for CPU tensors')
+        if len(inputs) == 0:
+            return tuple()
+        self.input_device = inputs[0].get_device()
+        outputs = comm.broadcast_coalesced(inputs, self.target_gpus)
+        return tuple([t for tensors in outputs for t in tensors])
+
+    def backward(self, *grad_outputs):
+        grad_outputs = [grad_outputs[i:i + self.num_inputs]
+                        for i in range(0, len(grad_outputs), self.num_inputs)]
+        return comm.reduce_add_coalesced(grad_outputs, self.input_device)


 class Gather(Function):
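
Because forward now flattens its per-device outputs into one tuple, backward has to slice the incoming gradients back into per-device groups before reduce_add_coalesced sums them. A small illustration of that indexing convention (plain Python; the values stand in for gradient tensors):

    # assume 3 broadcast inputs and 2 target GPUs, so 6 gradients arrive,
    # laid out device-major: (g0_gpu0, g1_gpu0, g2_gpu0, g0_gpu1, g1_gpu1, g2_gpu1)
    grad_outputs = list(range(6))   # stand-ins for the 6 gradient tensors
    num_inputs = 3                  # number of tensors passed to forward
    groups = [grad_outputs[i:i + num_inputs]
              for i in range(0, len(grad_outputs), num_inputs)]
    assert groups == [[0, 1, 2], [3, 4, 5]]   # one list per source device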

torch/nn/parallel/replicate.py

Lines changed: 69 additions & 38 deletions

@@ -1,42 +1,73 @@
-from copy import copy
-from collections import OrderedDict
-
-from ..modules import Module
 import torch.cuda.comm as comm


-def _replicate_module(module, gpu, param_remap):
-    if module is None:
-        return module
-    replica = copy(module)
-    replica._parameters = OrderedDict()
-    for key, param in module._parameters.items():
-        replica._parameters[key] = param_remap.get(param)
-    replica._buffers = {}
-    for key, buffer in module._buffers.items():
-        replica._buffers[key] = param_remap.get(buffer)
-    if replica._modules:
-        replica._modules = OrderedDict()
-        for name, child in module._modules.items():
-            replica._modules[name] = _replicate_module(child, gpu, param_remap)
-    return replica
-
-
-def replicate(module, device_ids):
+def replicate(network, devices):
     from ._functions import Broadcast
-    seen_params = set()
-    param_remap = [{} for dev_id in device_ids]
-    for param in module.parameters():
-        if param in seen_params:
-            continue
-        seen_params.add(param)
-        param_copies = Broadcast(device_ids)(param)
-        for param_copy, remap in zip(param_copies, param_remap):
-            remap[param] = param_copy
-    for m in module.modules():
-        for buffer in m._buffers.values():
-            copies = comm.broadcast(buffer, device_ids)
-            for buf_copy, remap in zip(copies, param_remap):
-                remap[buffer] = buf_copy
-    return [_replicate_module(module, device_id, remap)
-            for device_id, remap in zip(device_ids, param_remap)]
+
+    devices = tuple(devices)
+    num_replicas = len(devices)
+
+    params = list(network.parameters())
+    param_indices = {param: idx for idx, param in enumerate(params)}
+    param_copies = Broadcast(devices)(*params)
+    if len(params) > 0:
+        param_copies = [param_copies[i:i + len(params)]
+                        for i in range(0, len(param_copies), len(params))]
+
+    buffers = _buffers(network)
+    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
+    buffer_copies = comm.broadcast_coalesced(buffers, devices)
+
+    modules = list(network.modules())
+    module_copies = [[] for device in devices]
+    module_indices = {}
+
+    for i, module in enumerate(modules):
+        module_indices[module] = i
+        for j in range(num_replicas):
+            replica = module.__new__(type(module))
+            replica.__dict__ = module.__dict__.copy()
+            replica._parameters = replica._parameters.copy()
+            replica._buffers = replica._buffers.copy()
+            replica._modules = replica._modules.copy()
+            module_copies[j].append(replica)
+
+    for i, module in enumerate(modules):
+        for key, child in module._modules.items():
+            module_idx = module_indices[child]
+            for j in range(num_replicas):
+                replica = module_copies[j][i]
+                replica._modules[key] = module_copies[j][module_idx]
+        for key, param in module._parameters.items():
+            if param is None:
+                for j in range(num_replicas):
+                    replica = module_copies[j][i]
+                    replica._parameters[key] = None
+            else:
+                param_idx = param_indices[param]
+                for j in range(num_replicas):
+                    replica = module_copies[j][i]
+                    replica._parameters[key] = param_copies[j][param_idx]
+        for key, buf in module._buffers.items():
+            if buf is None:
+                for j in range(num_replicas):
+                    replica = module_copies[j][i]
+                    replica._buffers[key] = None
+            else:
+                buffer_idx = buffer_indices[buf]
+                for j in range(num_replicas):
+                    replica = module_copies[j][i]
+                    replica._buffers[key] = buffer_copies[j][buffer_idx]
+
+    return [module_copies[j][0] for j in range(num_replicas)]
+
+
+def _buffers(network):
+    buffers = []
+    seen = set()
+    for module in network.modules():
+        for buf in module._buffers.values():
+            if buf not in seen and buf is not None:
+                seen.add(buf)
+                buffers.append(buf)
+    return buffers
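
The rewritten replicate ties the pieces together: all parameters go through one coalesced Broadcast call and all buffers through comm.broadcast_coalesced, instead of one transfer per tensor. A minimal usage sketch, assuming two visible GPUs (the model below is purely illustrative):

    import torch.nn as nn
    from torch.nn.parallel.replicate import replicate

    # any nn.Module works; BatchNorm1d adds buffers (running stats), so both
    # the parameter and the buffer broadcast paths are exercised
    model = nn.Sequential(nn.Linear(10, 20), nn.BatchNorm1d(20)).cuda(0)

    replicas = replicate(model, [0, 1])
    # replicas[0] reuses the tensors already on GPU 0; replicas[1] holds
    # broadcast copies of every parameter and buffer placed on GPU 1,
    # while sharing the module structure of the original network.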
