Skip to content

Commit cb4eaa9

Browse files
killeent authored and soumith committed
TensorLib/Aten --> changes required in pytorch
1 parent b5854a1 commit cb4eaa9

17 files changed

+1344
-540
lines changed

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ def run(self):
173173
from tools.cwrap.plugins.WrapDim import WrapDim
174174
from tools.cwrap.plugins.AssertNDim import AssertNDim
175175
from tools.cwrap.plugins.Broadcast import Broadcast
176+
from tools.cwrap.plugins.ProcessorSpecificPlugin import ProcessorSpecificPlugin
176177
thp_plugin = THPPlugin()
177178
cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
178-
BoolOption(), thp_plugin, AutoGPU(condition='IS_CUDA'),
179-
ArgcountSortPlugin(), KwargsPlugin(), AssertNDim(), WrapDim(), Broadcast()
179+
ProcessorSpecificPlugin(), BoolOption(), thp_plugin,
180+
AutoGPU(condition='IS_CUDA'), ArgcountSortPlugin(), KwargsPlugin(),
181+
AssertNDim(), WrapDim(), Broadcast()
180182
])
181183
cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
182184
CuDNNPlugin(), NullableArguments()

tools/cwrap/cwrap.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from copy import deepcopy
55
from .plugins import ArgcountChecker, OptionalArguments, ArgumentReferences, \
66
BeforeAfterCall, ConstantArguments, ReturnArguments, GILRelease
7+
from ..shared import cwrap_common
78

89

910
class cwrap(object):
@@ -76,7 +77,7 @@ def wrap_declarations(self, declarations):
7677
elif line == ']]':
7778
in_declaration = False
7879
declaration = yaml.load('\n'.join(declaration_lines))
79-
self.set_declaration_defaults(declaration)
80+
cwrap_common.set_declaration_defaults(declaration)
8081

8182
# Pass declaration in a list - maybe some plugins want to add
8283
# multiple wrappers
@@ -104,24 +105,6 @@ def wrap_declarations(self, declarations):
104105

105106
return '\n'.join(output)
106107

107-
def set_declaration_defaults(self, declaration):
108-
declaration.setdefault('arguments', [])
109-
declaration.setdefault('return', 'void')
110-
if 'cname' not in declaration:
111-
declaration['cname'] = declaration['name']
112-
# Simulate multiple dispatch, even if it's not necessary
113-
if 'options' not in declaration:
114-
declaration['options'] = [{'arguments': declaration['arguments']}]
115-
del declaration['arguments']
116-
# Parse arguments (some of them can be strings)
117-
for option in declaration['options']:
118-
option['arguments'] = self.parse_arguments(option['arguments'])
119-
# Propagate defaults from declaration to options
120-
for option in declaration['options']:
121-
for k, v in declaration.items():
122-
if k != 'name' and k != 'options':
123-
option.setdefault(k, v)
124-
125108
def parse_arguments(self, args):
126109
new_args = []
127110
for arg in args:
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
from . import CWrapPlugin
3+
from ...shared import cwrap_common
24

35

46
class ArgcountSortPlugin(CWrapPlugin):
@@ -7,8 +9,7 @@ def __init__(self, descending=True):
79
self.descending = descending
810

911
def process_declarations(self, declarations):
10-
def num_checked_args(option):
11-
return sum(map(lambda a: not a.get('ignore_check', False), option['arguments']))
1212
for declaration in declarations:
13-
declaration['options'].sort(key=num_checked_args, reverse=self.descending)
13+
cwrap_common.sort_by_number_of_options(declaration,
14+
self.descending)
1415
return declarations
Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,18 @@
1+
import os
12
from copy import deepcopy
23
from . import CWrapPlugin
34
from itertools import product
5+
from ...shared import cwrap_common
46

57

68
class OptionalArguments(CWrapPlugin):
79

810
def process_declarations(self, declarations):
9-
new_options = []
1011
for declaration in declarations:
11-
for option in declaration['options']:
12-
optional_args = []
13-
for i, arg in enumerate(option['arguments']):
14-
if 'default' in arg:
15-
optional_args.append(i)
16-
for permutation in product((True, False), repeat=len(optional_args)):
17-
option_copy = deepcopy(option)
18-
for i, bit in zip(optional_args, permutation):
19-
arg = option_copy['arguments'][i]
20-
if not bit:
21-
arg['type'] = 'CONSTANT'
22-
arg['ignore_check'] = True
23-
# PyYAML interprets NULL as None...
24-
arg['name'] = 'NULL' if arg['default'] is None else arg['default']
25-
new_options.append(option_copy)
26-
declaration['options'] = self.filter_unique_options(new_options)
27-
return declarations
12+
cwrap_common.enumerate_options_due_to_default(
13+
declaration,
14+
allow_kwarg=True,
15+
type_to_signature={},
16+
remove_self=False)
2817

29-
def filter_unique_options(self, options):
30-
def signature(option, kwarg_only_count):
31-
if kwarg_only_count == 0:
32-
kwarg_only_count = None
33-
else:
34-
kwarg_only_count = -kwarg_only_count
35-
arg_signature = '#'.join(
36-
arg['type']
37-
for arg in option['arguments'][:kwarg_only_count]
38-
if not arg.get('ignore_check'))
39-
if kwarg_only_count is None:
40-
return arg_signature
41-
kwarg_only_signature = '#'.join(
42-
arg['name'] + '#' + arg['type']
43-
for arg in option['arguments'][kwarg_only_count:]
44-
if not arg.get('ignore_check'))
45-
return arg_signature + "#-#" + kwarg_only_signature
46-
seen_signatures = set()
47-
unique = []
48-
for option in options:
49-
for num_kwarg_only in range(0, len(option['arguments']) + 1):
50-
sig = signature(option, num_kwarg_only)
51-
if sig not in seen_signatures:
52-
if num_kwarg_only > 0:
53-
for arg in option['arguments'][-num_kwarg_only:]:
54-
arg['kwarg_only'] = True
55-
unique.append(option)
56-
seen_signatures.add(sig)
57-
break
58-
return unique
18+
return declarations
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from copy import deepcopy
2+
from . import CWrapPlugin
3+
import yaml
4+
5+
6+
class ProcessorSpecificPlugin(CWrapPlugin):
    """Split declarations whose CPU and CUDA signatures differ.

    A declaration that targets both the 'CPU' and 'CUDA' backends and has at
    least one generator argument is replaced by two declarations: the CPU
    one (arguments unchanged) and a CUDA one (generator arguments removed).
    All other declarations are passed through untouched.
    """

    def process_declarations(self, declarations):
        """Return ``declarations`` with generator-taking ones split per backend.

        Args:
            declarations: list of declaration dicts. Each split candidate
                must provide 'backends' and 'options'; every option must
                provide 'arguments' and 'backends'.

        Returns:
            A new list. For a split declaration the original dict is reused
            (and modified) as the CPU variant, followed by a deep copy that
            serves as the CUDA variant.
        """
        # In order to move Torch's random functions into the same cwrap
        # declaration, we need to be able to handle the fact that on the CPU
        # these functions take a generator argument, while on the GPU, they
        # do not. As such, we would like to split those declarations at cwrap
        # runtime into two separate declarations, one for the CPU (unchanged),
        # and one for the GPU (with the generator argument removed).
        #
        # For example, the declaration arguments:
        # arguments:
        #   - THTensor* self
        #   - arg: THGenerator* generator
        #     default: THPDefaultGenerator->cdata
        #     kwarg_only: True
        #
        # Would have the generator argument removed when generating for the GPU
        # backend.

        def arg_contains_generator(arg):
            # An argument is generator-related if it is typed as a generator
            # or if its default value refers to the default generator.
            default = arg.get('default')
            return (arg['type'] == 'THGenerator*' or
                    (default is not None and
                     'THPDefaultGenerator' in str(default)))

        def split_candidate(declaration):
            # Only declarations shared by both backends need splitting.
            backends = declaration['backends']
            if 'CPU' not in backends or 'CUDA' not in backends:
                return False
            return any(arg_contains_generator(argument)
                       for option in declaration['options']
                       for argument in option['arguments'])

        def can_we_handle_the_split(declaration):
            # hook into here if the split cannot happen for some reason
            return True

        def restrict_to_backend(declaration, keep, drop):
            # Lists are rebuilt instead of calling list.remove(): a
            # declaration and its options may reference the very same
            # 'backends' list, in which case a second remove() on the alias
            # would raise ValueError.
            declaration['backends'] = (
                [b for b in declaration['backends'] if b != drop])
            if declaration.get('backend_type_pairs', False):
                # Pairs are (backend, type): THPPlugin unpacks each pair in
                # that order, so the backend is the first element.
                declaration['backend_type_pairs'] = (
                    [pair for pair in declaration['backend_type_pairs']
                     if pair[0] == keep])
            # also need to reach into options
            for option in declaration['options']:
                option['backends'] = (
                    [b for b in option['backends'] if b != drop])

        def generator_split(declaration):
            # the split must make two changes: 1. remove the generator argument
            # for the GPU, and 2. assign the correct backends/types to the
            # split declaration
            dec_cpu = declaration
            # Copy before touching dec_cpu so the GPU variant starts from the
            # unmodified declaration.
            dec_gpu = deepcopy(declaration)

            restrict_to_backend(dec_cpu, keep='CPU', drop='CUDA')
            restrict_to_backend(dec_gpu, keep='CUDA', drop='CPU')

            # Remove generator arguments from dec_gpu options
            for option in dec_gpu['options']:
                option['arguments'] = (
                    [arg for arg in option['arguments'] if
                     not arg_contains_generator(arg)])

            return [dec_cpu, dec_gpu]

        decs = []
        for declaration in declarations:
            if split_candidate(declaration):
                assert can_we_handle_the_split(declaration)
                decs.extend(generator_split(declaration))
            else:
                decs.append(declaration)

        return decs

tools/cwrap/plugins/THPPlugin.py

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,92 @@ def has_output_args(declaration):
334334
for option in declaration['options']
335335
for arg in option['arguments'])
336336

337+
def backends_types_to_defined_if_string(declaration):
    """Build the preprocessor condition under which a declaration is defined.

    A declaration has two relevant fields: 'backends', which stores a list
    of backends (currently 'CPU' and 'CUDA') the declaration applies to, and
    'types', which stores a list of real types the declaration applies to.
    When a function is only supported by a subset of types, it is wrapped in
    macro definition checks.

    Previously, the cwrap declaration had to spell out manually for which
    backend/type combinations a function was defined. Now the types and
    backends are listed explicitly when the declaration should only be
    supported for a specific subset of types, backends, or backend-type
    pairs.

    Args:
        declaration: dict with a required 'backends' list and optional
            'types' and 'backend_type_pairs' entries. Each element of
            'backend_type_pairs' is a two-item (backend, type) sequence.

    Returns:
        The macro checks joined with " || ", or an empty string when the
        declaration needs no guard. Each check appears at most once.
    """
    types = declaration.get('types', [])
    backends = declaration['backends']
    all_backends = ['CPU', 'CUDA']

    def get_defined_string(backend, real):
        # 'all' selects the whole backend; any other value selects one type.
        if backend == 'CUDA':
            if real == 'all':
                return "IS_CUDA"
            return 'CUDA_{0}'.format(real.upper())
        if real == 'all':
            return "!IS_CUDA"
        return 'defined(TH_REAL_IS_{0})'.format(real.upper())

    def expand_composite_type(p, t):
        # Composite names stand for several real types; 'half' counts as
        # floating point only for the CUDA backend.
        if t == 'floating_point':
            result = ['double', 'float']
            if p == 'CUDA':
                result.append('half')
        elif t == 'integral':
            result = ['byte', 'char', 'short', 'int', 'long']
        else:
            result = [t]
        return result

    defineds = []

    def add_term(term):
        # Keep first-seen order but skip repeats, so overlapping sources
        # (explicit pairs vs. backends x types) do not emit "A || A".
        if term not in defineds:
            defineds.append(term)

    def add_pair(p, t):
        for expanded in expand_composite_type(p, t):
            add_term(get_defined_string(p, expanded))

    # The logic below does not handle corner cases well. We allow the
    # declaration to have a field 'backend_type_pairs' that stores a list
    # of (backend, type) pairs representing allowed combinations. Let's
    # use these first.
    for pair in declaration.get('backend_type_pairs', []):
        p, t = pair
        add_pair(p, t)

    # In the base case, types is empty and backends contains both
    # 'CPU' and 'CUDA' --> this means we support all types, and our
    # string should be empty, or simply the list of explicit
    # backend-type pairs
    if not types and all(proc in backends for proc in all_backends):
        return " || ".join(defineds)

    # Case 2: types is empty, but only one backend is specified
    if not types and len(backends) == 1:
        add_term('IS_CUDA' if backends[0] == 'CUDA' else "!IS_CUDA")
        return " || ".join(defineds)

    # Else, we loop over all of the (backend, type) combinations and add
    # them
    for p in backends:
        for t in types:
            add_pair(p, t)

    return " || ".join(defineds)
411+
337412
for declaration in declarations:
338413
# Disable all methods for THHalfTensor, unless cpu_half is True
414+
415+
dfstr = backends_types_to_defined_if_string(declaration)
416+
if len(dfstr) > 0:
417+
# for now, need to check for distributed defined if as well
418+
if 'defined_if' in declaration:
419+
declaration['defined_if'] += ' && (' + dfstr + ')'
420+
else:
421+
declaration['defined_if'] = dfstr
422+
339423
if not declaration.get('cpu_half', False):
340424
defined_if = '!defined(TH_REAL_IS_HALF)'
341425
if 'defined_if' in declaration:
@@ -362,11 +446,13 @@ def has_output_args(declaration):
362446
if option.get('sparse', False):
363447
defined_if = option.get('defined_if', '')
364448
option['defined_if'] = '!IS_DISTRIBUTED' + (' && ' if defined_if else '') + defined_if
365-
if declaration.get('with_stateless', False) or declaration.get('only_stateless', False):
449+
450+
variants = declaration.get('variants', ['method'])
451+
if 'function' in variants:
366452
stateless_declaration = self.make_stateless(declaration)
367453
new_declarations.append(stateless_declaration)
368454
self.stateless_declarations.append(stateless_declaration)
369-
if declaration.get('only_stateless', False):
455+
if 'method' not in variants:
370456
continue
371457

372458
self.declarations.append(declaration)
@@ -379,9 +465,13 @@ def has_output_args(declaration):
379465

380466
register_only = [d for d in declarations if d.get('only_register', False)]
381467
declarations = [d for d in declarations
382-
if (not d.get('only_stateless', False)) and (not d.get('only_register', False))]
383-
self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only))
384-
self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only))
468+
if (('method' in d.get('variants', ['method'])) and
469+
(not d.get('only_register', False)))]
470+
self.declarations.extend(filter(lambda x: 'method' in x.get('variants',
471+
['method']), register_only))
472+
self.stateless_declarations.extend(filter(lambda x: 'method' not in
473+
x.get('variants', ['method']),
474+
register_only))
385475

386476
self.process_docstrings()
387477

0 commit comments

Comments
 (0)