@@ -109,11 +109,14 @@ def __init__(self, reason):
109109 'long' : 'int64_t' ,
110110}
111111CHECKED_CAST = {
112- 'THTensor*' : CodeTemplate ('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos})' ),
113- 'THBoolTensor*' : CodeTemplate ('checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})' ),
114- 'THIndexTensor*' : CodeTemplate ('checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})' ),
115- 'THIntegerTensor*' : CodeTemplate ('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})' ),
116- 'THStorage*' : CodeTemplate ('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})' ),
112+ 'THTensor*' : CodeTemplate ('checked_cast<${Tensor}>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})' ),
113+ 'THBoolTensor*' :
114+ CodeTemplate ('checked_cast<${Backend}ByteTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})' ),
115+ 'THIndexTensor*' :
116+ CodeTemplate ('checked_cast<${Backend}LongTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})' ),
117+ 'THIntegerTensor*' :
118+ CodeTemplate ('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos}, ${null_okay})' ),
119+ 'THStorage*' : CodeTemplate ('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos}, false)' ),
117120 'THGenerator*' : CodeTemplate ('check_generator(&${arg_name})' ),
118121 'THSize*' : CodeTemplate ('THLongStorageView::make(${arg_name},true)' ),
119122 'THStride*' : CodeTemplate ('THLongStorageView::make(${arg_name},true)' ),
@@ -133,6 +136,8 @@ def __init__(self, reason):
133136 'TensorList' : "{0}_.data(), {0}_.size()" ,
134137}
135138
139+ CHECKED_USE_NULLABLE = CodeTemplate ('${arg_name}_ ? ${usage} : NULL' )
140+
136141ALLOC_WRAP = {
137142 'THTensor*' : 'new ${Tensor}(context)' ,
138143 'THBoolTensor*' : 'new ${Backend}ByteTensor(context)' ,
@@ -341,12 +346,20 @@ def create_derived(backend_type_env, declarations):
341346 def requires_checked_cast (argument ):
342347 return argument ['type' ] in CHECKED_CAST
343348
349+ def nullable_argument (argument ):
350+ return (argument ['type' ] == 'THTensor*' and
351+ argument .get ('default' , '' ) == 'nullptr' )
352+
344353 def bool_option_is_string (argument ):
345354 return 'if_true' in argument and isinstance (argument ['if_true' ], string_type )
346355
347356 def get_argument (argument , option ):
348357 if requires_checked_cast (argument ):
349- return CHECKED_USE .get (argument ['type' ], '{}_' ).format (argument ['name' ])
358+ checked_use = CHECKED_USE .get (argument ['type' ], '{}_' ).format (argument ['name' ])
359+ if nullable_argument (argument ):
360+ checked_use = CHECKED_USE_NULLABLE .substitute (
361+ env = {}, arg_name = argument ['name' ], usage = checked_use )
362+ return checked_use
350363 elif argument ['type' ] == 'bool' and 'if_true' in argument :
351364 if bool_option_is_string (argument ):
352365 tpl = '({}) ? "{}" : "{}"'
@@ -424,8 +437,14 @@ def emit_body(env, option):
424437 arg ['name' ], arg ['name' ]))
425438 # extract the TensorImpl from an existing tensor (or Storage, etc.)
426439 else :
440+ # special case where we allow undefined Tensors, and thus
441+ # the checked cast succeeds even if the Tensor is not
442+ # defined
443+ null_okay = 'true' if nullable_argument (arg ) else 'false'
444+
427445 check_cast = CHECKED_CAST [arg ['type' ]].substitute (
428- env , arg_name = arg ['name' ], arg_pos = count )
446+ env , arg_name = arg ['name' ], arg_pos = count ,
447+ null_okay = null_okay )
429448 body .append ("auto {}_ = {};" .format (
430449 arg ['name' ], check_cast ))
431450 if drop_argument (arg , option ):
0 commit comments