@@ -85,6 +85,20 @@ def __init__(self, reason):
8585 'long' : 'int64_t' ,
8686}
8787
88+ DYNAMIC_TYPE = {
89+ 'THTensor*' : 'Tensor' ,
90+ 'THBoolTensor*' : 'BoolTensor' ,
91+ 'THIndexTensor*' : 'IndexTensor' ,
92+ 'THIntegerTensor*' : 'IntegerTensor' ,
93+ 'THStorage*' : 'Storage' ,
94+ 'THGenerator*' : 'Generator' ,
95+ 'THSize*' : 'IntList' ,
96+ 'THStride*' : 'IntList' ,
97+ 'accreal' : 'accreal' ,
98+ 'real' : 'real' ,
99+ 'long' : 'int64_t' ,
100+ }
101+
88102TYPE_RETURN = {
89103 'THTensor*' : 'Tensor' ,
90104 'THIndexTensor*' : 'Tensor' ,
@@ -164,11 +178,29 @@ def to_return_type(arg, option):
164178 rt = rt + ' &'
165179 if not is_mutable_formal_argument (arg , option ):
166180 rt = 'const ' + rt
167- return rt
181+ return {
182+ 'type' : rt ,
183+ 'dynamic_type' : DYNAMIC_TYPE .get (arg ['type' ], arg ['type' ]),
184+ }
168185
169186
170187def create_generic (top_env , declarations ):
171188
189+ # change from THTensor* to Tensor & so we get how it will appear
190+ # in the aten argument list...
191+ def translate_formal (argument , option ):
192+ type_str = TYPE_FORMAL_GENERIC .get (argument ['type' ], argument ['type' ])
193+ if type_str == 'Tensor &' and not is_mutable_formal_argument (argument , option ):
194+ type_str = 'const ' + type_str
195+ translated = {
196+ 'name' : argument ['name' ],
197+ 'type' : type_str ,
198+ 'dynamic_type' : DYNAMIC_TYPE .get (argument ['type' ], argument ['type' ]),
199+ }
200+ if argument .get ('output' ):
201+ translated ['output' ] = True
202+ return translated
203+
172204 def get_formals (option ):
173205 seen = set ()
174206 result = []
@@ -185,38 +217,43 @@ def insert(argument):
185217 for argument in option ['arguments' ]:
186218 if argument .get ('output' ) and not argument .get ('allocate' , False ):
187219 insert (argument )
188- return result
189220
190- def format_formal (argument , option ):
191- type_str = TYPE_FORMAL_GENERIC .get (argument ['type' ], argument ['type' ])
192- if type_str == 'Tensor &' and not is_mutable_formal_argument (argument , option ):
193- type_str = 'const ' + type_str
194- return '{} {}' .format (type_str , argument ['name' ])
221+ return [translate_formal (argument , option ) for argument in result ]
195222
196- def format_return_type (option ):
223+ def get_return_types (option ):
197224 ret = option ['return' ]
198225 if ret ['kind' ] == 'arguments' :
199226 argument_indices = ret ['arguments' ]
200227 if len (argument_indices ) == 1 :
201228 the_arg = option ['arguments' ][argument_indices [0 ]]
202- return to_return_type (the_arg , option )
229+ return [ to_return_type (the_arg , option )]
203230 else :
204- types = [to_return_type (option ['arguments' ][idx ], option )
205- for idx in argument_indices ]
206- return "std::tuple<{}>" .format (',' .join (types ))
207-
231+ return [to_return_type (option ['arguments' ][idx ], option )
232+ for idx in argument_indices ]
208233 elif ret ['kind' ] == 'type' :
209- return TYPE_RETURN .get (ret ['type' ], ret ['type' ])
234+ return [{
235+ 'type' : TYPE_RETURN .get (ret ['type' ], ret ['type' ]),
236+ 'dynamic_type' : DYNAMIC_TYPE .get (ret ['type' ], ret ['type' ]),
237+ }]
210238 else :
211239 raise Exception ("format_return_type" )
212240
241+ def format_return_type (return_types ):
242+ if len (return_types ) == 1 :
243+ return return_types [0 ]['type' ]
244+ return "std::tuple<{}>" .format (',' .join (r ['type' ] for r in return_types ))
245+ return return_types
246+
213247 def find_first_tensor (formals ):
214- for argument in formals :
215- if argument [ 'type' ] == "THTensor*" or argument [ 'type' ] == 'TensorList' :
216- return argument ['name' ]
248+ for formal in formals :
249+ if 'Tensor' == formal [ 'dynamic_type' ] or 'TensorList' == formal [ 'dynamic_type' ] :
250+ return formal ['name' ]
217251 return None
218252
219- def process_option (option ):
253+ def format_formal (f ):
254+ return '{} {}' .format (f ['type' ],f ['name' ])
255+
256+ def process_option (option , output_options ):
220257 option ['inplace' ] = re .search (
221258 '(^__i|[^_]_$)' , option ['api_name' ]) is not None
222259
@@ -226,13 +263,15 @@ def process_option(option):
226263 # print(yaml.dump(option))
227264 formals = get_formals (option )
228265 option ['formals_list' ] = formals
229- option ['formals' ] = [format_formal (f , option ) for f in formals ]
266+ option ['formals' ] = [format_formal (f ) for f in formals ]
267+ option ['returns' ] = get_return_types (option )
230268 option ['actuals' ] = [f ['name' ] for f in formals ]
231- option ['method_formals' ] = [format_formal (f , option ) for f in formals
269+
270+ option ['method_formals' ] = [format_formal (f ) for f in formals
232271 if f ['name' ] != 'self' ]
233272 option ['method_actuals' ] = [
234273 f ['name' ] if f ['name' ] != 'self' else '*this' for f in formals ]
235- option ['return_type' ] = format_return_type (option )
274+ option ['return_type' ] = format_return_type (option [ 'returns' ] )
236275
237276 option ['const_mark' ] = '' if option ['inplace' ] else ' const'
238277
@@ -253,22 +292,46 @@ def process_option(option):
253292 TENSOR_METHOD_DECLARATION .substitute (env ))
254293 top_env ['tensor_method_definitions' ].append (
255294 TENSOR_METHOD_DEFINITION .substitute (env ))
295+ output_options .append ({
296+ 'name' : option ['name' ],
297+ 'arguments' : [f for f in formals if f ['name' ] != 'self' ],
298+ 'method_of' : 'Tensor' ,
299+ 'returns' : option ['returns' ],
300+ 'inplace' : option ['inplace' ],
301+ })
256302
257303 if is_function :
258304 first_tensor = find_first_tensor (formals )
305+ output_option = {
306+ 'name' : option ['name' ],
307+ 'arguments' : formals ,
308+ 'returns' : option ['returns' ],
309+ 'inplace' : option ['inplace' ],
310+ }
259311 if first_tensor is not None :
260312 option ['inferred_type' ] = 'infer_type({})' .format (first_tensor )
261313 top_env ['function_declarations' ].append (
262314 FUNCTION_DECLARATION .substitute (env ))
263315 top_env ['function_definitions' ].append (
264316 FUNCTION_DEFINITION .substitute (env ))
317+ else :
318+ output_option ['method_of' ] = 'Type'
319+ output_options .append (output_option )
265320
321+ output_declarations = []
266322 for declaration in declarations :
323+ output_options = []
267324 for option in declaration ['options' ]:
268325 try :
269- process_option (option )
326+ process_option (option , output_options )
270327 except NYIError :
271328 option ['skip' ] = True
329+ if len (output_options ) > 0 :
330+ output_declarations .append ({
331+ 'name' : output_options [0 ]['name' ],
332+ 'options' : output_options ,
333+ })
334+ return output_declarations
272335
273336
274337def create_derived (backend_type_env , declarations ):
@@ -429,7 +492,7 @@ def emit_body(env, option):
429492 arg = arguments [0 ]
430493 body .append ("return {};" .format (arg ['name' ]))
431494 else :
432- types = [to_return_type (arg , option ) for arg in arguments ]
495+ types = [to_return_type (arg , option )[ 'type' ] for arg in arguments ]
433496 # TODO: check for move semantics...
434497 names = [arg ['name' ] for arg in arguments ]
435498 body .append (CodeTemplate ("return std::tuple<${types}>(${names});" ).substitute (
0 commit comments