Skip to content

Commit f3f8ce4

Browse files
authored
Merge pull request pytorch#18 from soumith/master
Fix handling of if_true/if_false in ATen
2 parents 128e02d + 7ee7542 commit f3f8ce4

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

function_wrapper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import re
22
from code_template import CodeTemplate
33

4+
import sys
5+
if sys.version_info[0] == 3:
6+
string_type = str
7+
else:
8+
string_type = basestring
9+
410
# temporary things we cannot handle
511
EXCLUDE_PATTERN = "bernoulli.*|normal.*|exponential.*|random.*|arange.*"
612
# what has to be done to add a Operation ...
@@ -272,14 +278,21 @@ def create_derived(backend_type_env, declarations):
272278
def requires_checked_cast(argument):
273279
return argument['type'] in CHECKED_CAST
274280

281+
def bool_option_is_string(argument):
282+
return 'if_true' in argument and isinstance(argument['if_true'], string_type)
283+
275284
def get_argument(argument, option):
276285
if requires_checked_cast(argument):
277286
return CHECKED_USE.get(argument['type'], '{}_').format(argument['name'])
278287
elif argument['type'] == 'bool' and 'if_true' in argument:
279-
return '({}) ? "{}" : "{}"'.format(argument['name'],
280-
argument['if_true'], argument['if_false'])
288+
if bool_option_is_string(argument):
289+
tpl = '({}) ? "{}" : "{}"'
290+
else:
291+
tpl = '({}) ? {} : {}'
292+
return tpl.format(argument['name'],
293+
argument['if_true'], argument['if_false'])
281294
elif argument['type'] == "CONSTANT":
282-
if 'if_true' in argument: # this was a bool that is actually a string...
295+
if bool_option_is_string(argument): # this is a bool that is actually a string...
283296
return '"{}"'.format(argument['name'])
284297
v = str(argument['name'])
285298
for pattern, replacement in CONSTANT_REPLACEMENTS:

0 commit comments

Comments
 (0)