|
1 | | -from copy import copy |
2 | | -from collections import OrderedDict |
3 | | - |
4 | | -from ..modules import Module |
5 | 1 | import torch.cuda.comm as comm |
6 | 2 |
|
7 | 3 |
|
def _replicate_module(module, gpu, param_remap):
    """Recursively shallow-copy ``module``, swapping in remapped tensors.

    Args:
        module: Module to copy, or ``None``. It must expose the
            ``_parameters``, ``_buffers`` and ``_modules`` mappings.
        gpu: Target device id. It is only forwarded to the recursive calls;
            this function never reads it.
        param_remap: Mapping from each original parameter/buffer object to
            the copy that the replica should hold instead.

    Returns:
        A shallow copy of ``module`` whose parameters, buffers and child
        modules have been replaced, or ``None`` when ``module`` is ``None``.
        Entries missing from ``param_remap`` (including ``None`` entries)
        become ``None`` in the replica.
    """
    # Imported locally so this helper does not depend on module-level imports.
    from collections import OrderedDict
    from copy import copy

    if module is None:
        return module

    replica = copy(module)

    # Fresh containers: the shallow copy above still shares them with `module`.
    replica._parameters = OrderedDict()
    for key, param in module._parameters.items():
        replica._parameters[key] = param_remap.get(param)

    replica._buffers = {}
    for key, buffer in module._buffers.items():
        replica._buffers[key] = param_remap.get(buffer)

    # An empty child mapping is left as-is (shared with the original).
    if replica._modules:
        replica._modules = OrderedDict()
        for name, child in module._modules.items():
            replica._modules[name] = _replicate_module(child, gpu, param_remap)

    return replica
23 | | - |
24 | | - |
def replicate(network, devices):
    """Create one replica of ``network`` for each entry of ``devices``.

    All parameters are copied with a single ``Broadcast`` call and all
    buffers with a single ``comm.broadcast_coalesced`` call. Every module
    yielded by ``network.modules()`` is then shallow-cloned once per device,
    and each clone is rewired to the child clones, parameter copies and
    buffer copies that belong to the same device.

    Args:
        network: Root module to replicate. It must expose ``parameters()``,
            ``modules()`` and the ``_parameters`` / ``_buffers`` /
            ``_modules`` mappings.
        devices: Iterable of target devices; it is materialised into a tuple
            and handed to ``Broadcast`` and ``comm.broadcast_coalesced``.

    Returns:
        list: One replica of ``network`` per entry of ``devices``.
    """
    # Function-scope imports keep this block self-contained.
    # NOTE(review): the relative import was already local in the original,
    # presumably to avoid an import cycle -- confirm.
    from ._functions import Broadcast
    import torch.cuda.comm as comm

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    if params:
        # Broadcast returns one flat sequence; slice it into chunks of
        # len(params). Assumes the copies are grouped per device -- confirm
        # against Broadcast.
        flat_copies = Broadcast(devices)(*params)
        param_copies = [flat_copies[start:start + len(params)]
                        for start in range(0, len(flat_copies), len(params))]
    else:
        # Nothing to broadcast; the table is never indexed in this case.
        param_copies = [() for _ in devices]

    buffers = _buffers(network)
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_indices = {module: idx for idx, module in enumerate(modules)}

    # First pass: shallow-clone every module once per device. The clone gets
    # its own copies of the three registries so the second pass can rewrite
    # them without touching the original module.
    module_copies = [[] for _ in devices]
    for module in modules:
        for per_device in module_copies:
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            per_device.append(replica)

    # Second pass: point each clone at same-device children and tensors.
    for idx, module in enumerate(modules):
        for key, child in module._modules.items():
            # modules() never yields None children, so they have no index;
            # keep the slot empty instead of failing the lookup.
            child_idx = None if child is None else module_indices[child]
            for j in range(num_replicas):
                module_copies[j][idx]._modules[key] = (
                    None if child_idx is None else module_copies[j][child_idx])

        for key, param in module._parameters.items():
            param_idx = None if param is None else param_indices[param]
            for j in range(num_replicas):
                module_copies[j][idx]._parameters[key] = (
                    None if param_idx is None else param_copies[j][param_idx])

        for key, buf in module._buffers.items():
            buffer_idx = None if buf is None else buffer_indices[buf]
            for j in range(num_replicas):
                module_copies[j][idx]._buffers[key] = (
                    None if buffer_idx is None
                    else buffer_copies[j][buffer_idx])

    # The replica of `network` itself is at index 0, the first entry of
    # `modules` -- this relies on modules() yielding the root first.
    return [module_copies[j][0] for j in range(num_replicas)]
| 63 | + |
| 64 | + |
def _buffers(network):
    """Collect the distinct, non-``None`` buffers reachable from ``network``.

    Args:
        network: Object exposing ``modules()``; each yielded module must
            have a ``_buffers`` mapping.

    Returns:
        list: Each buffer once, in first-encounter order while walking
        ``network.modules()`` and every module's ``_buffers`` values.
        ``None`` entries are skipped.
    """
    buffers = []
    seen = set()
    for module in network.modules():
        for buf in module._buffers.values():
            # Filter None before the set lookup so it is never hashed.
            if buf is None or buf in seen:
                continue
            seen.add(buf)
            buffers.append(buf)
    return buffers
0 commit comments