@@ -318,10 +318,12 @@ def emit_body(env, option):
318318 # referencing another
319319 seen_names = set ()
320320 count = 0
321+ scalar_check = None
321322 for arg in option ['arguments' ]:
322323 if is_real_argument_to_wrapper (arg ):
323324 count += 1
324-
325+ if arg ['type' ] == 'THSize*' :
326+ scalar_check = '{}.size() == 0' .format (arg ['name' ])
325327 # only generated checked casts the first time we see it
326328 if not arg ['name' ] in seen_names and requires_checked_cast (arg ):
327329 seen_names .add (arg ['name' ])
@@ -369,24 +371,32 @@ def emit_body(env, option):
369371
370372 call = prefix + CodeTemplate ("${cname}(${derived_actuals})" ).substitute (env )
371373 ret = option ['return' ]
374+
372375 if ret ['kind' ] == 'arguments' :
373376 body .append (call + ";" )
374377 arguments_indices = ret ['arguments' ]
378+ arguments = [option ['arguments' ][argi ]
379+ for argi in arguments_indices ]
380+ if scalar_check is not None :
381+ for arg in arguments :
382+ body .append ("bool maybe_scalar = {};" .format (scalar_check ))
383+ body .append ("{}_->maybeScalar(maybe_scalar);" .format (arg ['name' ]))
375384 if len (arguments_indices ) == 1 :
376- arg = option [ ' arguments' ][ arguments_indices [ 0 ] ]
385+ arg = arguments [ 0 ]
377386 body .append ("return {};" .format (arg ['name' ]))
378387 else :
379- arguments = [option ['arguments' ][argi ]
380- for argi in arguments_indices ]
381388 types = [to_return_type (arg , option ) for arg in arguments ]
382389 # TODO: check for move semantics...
383390 names = [arg ['name' ] for arg in arguments ]
384391 body .append (CodeTemplate ("return std::tuple<${types}>(${names});" ).substitute (
385392 types = types , names = names ))
386393 elif ret ['kind' ] == 'type' :
387394 if ret ['type' ] == 'THTensor*' :
395+ maybe_scalar = "->maybeScalar({})" .format (scalar_check ) \
396+ if scalar_check is not None \
397+ else ""
388398 body .append (CodeTemplate (
389- "return Tensor(new ${Tensor}(context,${arg_name}),false);" ).substitute (env , arg_name = call ))
399+ "return Tensor(( new ${Tensor}(context,${arg_name}))${maybe_scalar} ,false);" ).substitute (env , arg_name = call , maybe_scalar = maybe_scalar ))
390400 else :
391401 # we using int64_t for long in the API, so correct it here...
392402 if is_actual_return_long (ret ):
0 commit comments