4545}
4646""" )
4747
# ZERO_DIM_CHECK: template for a generated guard that tests whether
# `${check_name}.dim() == 0` and, if so, returns the result of calling
# `${method_prefix}${api_name}` with `${zero_dim_actuals}` instead of
# continuing. handle_zero_dim() fills it in, passing check_name and
# zero_dim_actuals explicitly; the remaining placeholders come from `env`.
# NOTE(review): the literal's inner indentation is not recoverable from this
# view -- confirm against the real file before reformatting it.
48+ ZERO_DIM_CHECK = CodeTemplate ("""\
49+ if(${check_name}.dim() == 0) {
50+ return ${method_prefix}${api_name}(${zero_dim_actuals});
51+ }""" )
52+
# SCALAR_EXPAND: template for generated code that declares a temporary
# `Tensor ${name}__` and, when `${name}_->isScalar()` holds, assigns it
# `${name}.expand(${other}.sizes())` and repoints `${name}_` at the expanded
# tensor's `pImpl` (cast to `${Tensor}*`). emit_body() appends it for
# arguments that carry a 'broadcast' entry not containing 'inplace', supplying
# `name` and `other`; `${Tensor}` is expected to come from `env`.
# NOTE(review): the literal's inner indentation is not recoverable from this
# view -- confirm against the real file before reformatting it.
53+ SCALAR_EXPAND = CodeTemplate ("""\
54+ Tensor ${name}__;
55+ if(${name}_->isScalar()) {
56+ ${name}__ = ${name}.expand(${other}.sizes());
57+ ${name}_ = static_cast<${Tensor}*>(${name}__.pImpl);
58+ }
59+ """ )
4860
4961class NYIError (Exception ):
5062 """Indicates we don't support this declaration yet"""
@@ -83,8 +95,8 @@ def __init__(self, reason):
8395 'THIntegerTensor*' : CodeTemplate ('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})' ),
8496 'THStorage*' : CodeTemplate ('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})' ),
8597 'THGenerator*' : CodeTemplate ('check_generator(&${arg_name})' ),
86- 'THSize*' : CodeTemplate ('THLongStorageView::make(${arg_name})' ),
87- 'THStride*' : CodeTemplate ('THLongStorageView::make(${arg_name})' ),
98+ 'THSize*' : CodeTemplate ('THLongStorageView::make(${arg_name},true )' ),
99+ 'THStride*' : CodeTemplate ('THLongStorageView::make(${arg_name},true )' ),
88100 'real' : CodeTemplate ('${arg_name}.to${ScalarName}()' ),
89101 'accreal' : CodeTemplate ('${arg_name}.to${AccScalarName}()' ),
90102
@@ -290,16 +302,27 @@ def is_actual_return_long(ret):
290302 return ret ['type' ] == 'long' or (backend_type_env ['ScalarName' ] == 'Long' and
291303 ret ['type' ] == 'real' or ret ['type' ] == 'accreal' )
292304
def handle_zero_dim(env, option):
    """Return the zero-dim dispatch guard for ``option`` as a list of snippets.

    When ``option`` has no 'zero_dim_dispatch_when_scalar' entry, nothing is
    generated and an empty list is returned. Otherwise that entry names the
    argument to test: a single ZERO_DIM_CHECK snippet is produced that
    forwards every formal by name, except the tested one, which is forwarded
    as ``<name>.scalar()``.

    env    -- substitution environment handed through to the template.
    option -- declaration dict; reads 'zero_dim_dispatch_when_scalar' and
              'formals_list' (each formal is a dict with a 'name' key).
    """
    dispatch_key = 'zero_dim_dispatch_when_scalar'
    if dispatch_key not in option:
        return []

    scalar_arg = option[dispatch_key]

    # Build the actuals for the re-dispatched call, in formal order.
    forwarded = []
    for formal in option['formals_list']:
        formal_name = formal['name']
        if formal_name == scalar_arg:
            forwarded.append(formal_name + '.scalar()')
        else:
            forwarded.append(formal_name)

    guard = ZERO_DIM_CHECK.substitute(
        env, check_name=scalar_arg, zero_dim_actuals=forwarded)
    return [guard]
313+
293314 def emit_body (env , option ):
294315 body = []
316+ body += handle_zero_dim (env ,option )
295317 # arguments are potentially duplicated because of one argument
296318 # referencing another
297319 seen_names = set ()
298- # only generated checked casts the first time we see it
299320 count = 0
300321 for arg in option ['arguments' ]:
301322 if is_real_argument_to_wrapper (arg ):
302323 count += 1
324+
325+ # only generated checked casts the first time we see it
303326 if not arg ['name' ] in seen_names and requires_checked_cast (arg ):
304327 seen_names .add (arg ['name' ])
305328 if arg .get ('allocate' , False ):
@@ -326,19 +349,25 @@ def emit_body(env, option):
326349 arg ['name' ], ',' .join (dims )))
327350 if arg .get ('cpu_zero' , False ):
328351 body .append ("{}.zero_();" .format (arg ['name' ]))
329-
330- option ['actuals' ] = get_arguments (option )
352+ # handle scalars that occur on LHS of things like a - b
353+ if 'broadcast' in arg and 'inplace' not in arg ['broadcast' ]:
354+ other = arg ['broadcast' ].split (' ' )[0 ].split (',' )[0 ]
355+ body .append (SCALAR_EXPAND .substitute (env ,
356+ name = arg ['name' ],
357+ other = other ))
358+
359+ option ['derived_actuals' ] = get_arguments (option )
331360 is_cuda = backend_type_env ['Backend' ] == 'CUDA'
332361 is_nn = option ['mode' ] == 'NN'
333362 if is_cuda or is_nn :
334- option ['actuals ' ] = ['context->thc_state' ] + option ['actuals ' ]
363+ option ['derived_actuals ' ] = ['context->thc_state' ] + option ['derived_actuals ' ]
335364
336365 if is_nn :
337366 prefix = 'THNN_{}' .format (env ['THType' ])
338367 else :
339368 prefix = env ['THTensor' ] + '_'
340369
341- call = prefix + CodeTemplate ("${cname}(${actuals })" ).substitute (env )
370+ call = prefix + CodeTemplate ("${cname}(${derived_actuals })" ).substitute (env )
342371 ret = option ['return' ]
343372 if ret ['kind' ] == 'arguments' :
344373 body .append (call + ";" )
0 commit comments