Skip to content

Commit 567d95f

Browse files
authored
Merge pull request pytorch#25 from killeent/nullable-tensors
add support for Null Tensors to functions
2 parents 7914d67 + 8451468 commit 567d95f

File tree

4 files changed

+39
-9
lines changed

4 files changed

+39
-9
lines changed

Utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ namespace at {
1010
void runtime_error(const char *format, ...);
1111

1212
template <typename T, typename Base>
13-
static inline T* checked_cast(Base* expr, const char * name, int pos) {
13+
static inline T* checked_cast(Base* expr, const char * name, int pos, bool allowNull) {
1414
if(!expr) {
15+
if (allowNull) {
16+
return (T*) expr;
17+
}
1518
runtime_error("Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'",
1619
T::typeString(),pos,name);
1720
}

common_with_cwrap.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ def signature(option, kwarg_only_count):
9393

9494
def enumerate_options_due_to_default(declaration,
9595
allow_kwarg=True, type_to_signature=[], remove_self=True):
96+
97+
# Checks to see if an argument with a default keyword is a Tensor that
98+
# by default can be NULL. In this case, instead of generating another
99+
# option that excludes this argument, we will instead generate a single
100+
# function call that allows for the Tensor to be NULL
101+
def is_nullable_tensor_arg(arg):
102+
return arg['type'] == 'THTensor*' and arg['default'] == 'nullptr'
103+
96104
# TODO(zach): in cwrap this is shared among all declarations
97105
# but seems to assume that all declarations will have the same
98106
new_options = []

copy_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
FUNCTION = CodeTemplate("""\
2626
void ${Type}::copy(const Tensor & src, Tensor & dst) {
2727
// code generated by function_wrapper
28-
auto dst_ = checked_cast<${Tensor}>(dst.pImpl,"dst",0);
28+
auto dst_ = checked_cast<${Tensor}>(dst.pImpl,"dst",0,false);
2929
(void) dst_; //silence unused warning
3030
switch(src.type().ID()) {
3131
${copy_body}

function_wrapper.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,14 @@ def __init__(self, reason):
109109
'long': 'int64_t',
110110
}
111111
CHECKED_CAST = {
112-
'THTensor*': CodeTemplate('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
113-
'THBoolTensor*': CodeTemplate('checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
114-
'THIndexTensor*': CodeTemplate('checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
115-
'THIntegerTensor*': CodeTemplate('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
116-
'THStorage*': CodeTemplate('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})'),
112+
'THTensor*': CodeTemplate('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
113+
'THBoolTensor*':
114+
CodeTemplate('checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
115+
'THIndexTensor*':
116+
CodeTemplate('checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
117+
'THIntegerTensor*':
118+
CodeTemplate('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})'),
119+
'THStorage*': CodeTemplate('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos}, false)'),
117120
'THGenerator*': CodeTemplate('check_generator(&${arg_name})'),
118121
'THSize*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
119122
'THStride*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
@@ -133,6 +136,8 @@ def __init__(self, reason):
133136
'TensorList': "{0}_.data(), {0}_.size()",
134137
}
135138

139+
CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? ${usage} : NULL')
140+
136141
ALLOC_WRAP = {
137142
'THTensor*': 'new ${Tensor}(context)',
138143
'THBoolTensor*': 'new ${Backend}ByteTensor(context)',
@@ -341,12 +346,20 @@ def create_derived(backend_type_env, declarations):
341346
def requires_checked_cast(argument):
342347
return argument['type'] in CHECKED_CAST
343348

349+
def nullable_argument(argument):
350+
return (argument['type'] == 'THTensor*' and
351+
argument.get('default', '') == 'nullptr')
352+
344353
def bool_option_is_string(argument):
345354
return 'if_true' in argument and isinstance(argument['if_true'], string_type)
346355

347356
def get_argument(argument, option):
348357
if requires_checked_cast(argument):
349-
return CHECKED_USE.get(argument['type'], '{}_').format(argument['name'])
358+
checked_use = CHECKED_USE.get(argument['type'], '{}_').format(argument['name'])
359+
if nullable_argument(argument):
360+
checked_use = CHECKED_USE_NULLABLE.substitute(
361+
env={}, arg_name=argument['name'], usage=checked_use)
362+
return checked_use
350363
elif argument['type'] == 'bool' and 'if_true' in argument:
351364
if bool_option_is_string(argument):
352365
tpl = '({}) ? "{}" : "{}"'
@@ -424,8 +437,14 @@ def emit_body(env, option):
424437
arg['name'], arg['name']))
425438
# extract the TensorImpl from an existing tensor (or Storage, etc.)
426439
else:
440+
# special case where we allow undefined Tensors, and thus
441+
# the checked cast succeeds even if the Tensor is not
442+
# defined
443+
null_okay = 'true' if nullable_argument(arg) else 'false'
444+
427445
check_cast = CHECKED_CAST[arg['type']].substitute(
428-
env, arg_name=arg['name'], arg_pos=count)
446+
env, arg_name=arg['name'], arg_pos=count,
447+
null_okay=null_okay)
429448
body.append("auto {}_ = {};".format(
430449
arg['name'], check_cast))
431450
if drop_argument(arg, option):

0 commit comments

Comments
 (0)