@@ -334,8 +334,92 @@ def has_output_args(declaration):
334334 for option in declaration ['options' ]
335335 for arg in option ['arguments' ])
336336
def backends_types_to_defined_if_string(declaration):
    """Build the preprocessor condition that guards ``declaration``.

    A declaration may restrict where it is generated through three
    fields:

    * ``'backends'`` (required): list of backend names. ``'CUDA'`` is
      recognised explicitly; any other name is treated as the
      non-CUDA (CPU) backend.
    * ``'types'`` (optional, defaults to ``[]``): list of real type
      names. The composite names ``'floating_point'`` and
      ``'integral'`` are expanded to their member types, and ``'all'``
      stands for every type of a backend.
    * ``'backend_type_pairs'`` (optional): iterable of
      ``(backend, type)`` pairs listing explicitly allowed
      combinations. These are always emitted first.

    Returns:
        The individual conditions joined with ``' || '``. The string
        is empty when the declaration carries no restriction. Each
        distinct condition appears at most once, in first-seen order.

    Raises:
        KeyError: if ``declaration`` has no ``'backends'`` entry.
    """
    types = declaration.get('types', [])
    backends = declaration['backends']
    all_backends = ['CPU', 'CUDA']

    def get_defined_string(backend, real):
        # Condition text for one (backend, concrete type) combination.
        if backend == 'CUDA':
            if real == 'all':
                return "IS_CUDA"
            return 'CUDA_{0}'.format(real.upper())
        if real == 'all':
            return "!IS_CUDA"
        return 'defined(TH_REAL_IS_{0})'.format(real.upper())

    def expand_composite_type(p, t):
        # Expand a composite type name into concrete type names for
        # backend ``p``; any other name is passed through unchanged.
        if t == 'floating_point':
            result = ['double', 'float']
            # 'half' is only part of 'floating_point' on CUDA.
            if p == 'CUDA':
                result.append('half')
        elif t == 'integral':
            result = ['byte', 'char', 'short', 'int', 'long']
        else:
            result = [t]
        return result

    defineds = []
    seen = set()

    def add_conditions(backend, type_name):
        # Append the conditions for (backend, type_name), skipping any
        # already collected: the explicit pairs may overlap with the
        # backends x types product, and a repeated term in an ``||``
        # chain is pure noise in the generated guard.
        for real in expand_composite_type(backend, type_name):
            condition = get_defined_string(backend, real)
            if condition not in seen:
                seen.add(condition)
                defineds.append(condition)

    # The generic logic below does not handle corner cases well, so a
    # declaration may list explicitly allowed (backend, type) pairs in
    # 'backend_type_pairs'. Emit those first.
    for backend, type_name in declaration.get('backend_type_pairs', []):
        add_conditions(backend, type_name)

    if not types:
        # Base case: no types given and both 'CPU' and 'CUDA' are
        # listed --> every type is supported, so the result is empty,
        # or simply the list of explicit backend/type pairs.
        if all(proc in backends for proc in all_backends):
            return " || ".join(defineds)

        # Case 2: no types given and exactly one backend listed -->
        # guard on that backend as a whole ('IS_CUDA' / '!IS_CUDA').
        if len(backends) == 1:
            add_conditions(backends[0], 'all')
            return " || ".join(defineds)

    # Otherwise, loop over all of the backend/type combinations and
    # add each of them.
    for backend in backends:
        for type_name in types:
            add_conditions(backend, type_name)

    return " || ".join(defineds)
411+
337412 for declaration in declarations :
338413 # Disable all methods for THHalfTensor, unless cpu_half is True
414+
415+ dfstr = backends_types_to_defined_if_string (declaration )
416+ if len (dfstr ) > 0 :
417+ # for now, need to check for distributed defined if as well
418+ if 'defined_if' in declaration :
419+ declaration ['defined_if' ] += ' && (' + dfstr + ')'
420+ else :
421+ declaration ['defined_if' ] = dfstr
422+
339423 if not declaration .get ('cpu_half' , False ):
340424 defined_if = '!defined(TH_REAL_IS_HALF)'
341425 if 'defined_if' in declaration :
@@ -362,11 +446,13 @@ def has_output_args(declaration):
362446 if option .get ('sparse' , False ):
363447 defined_if = option .get ('defined_if' , '' )
364448 option ['defined_if' ] = '!IS_DISTRIBUTED' + (' && ' if defined_if else '' ) + defined_if
365- if declaration .get ('with_stateless' , False ) or declaration .get ('only_stateless' , False ):
449+
450+ variants = declaration .get ('variants' , ['method' ])
451+ if 'function' in variants :
366452 stateless_declaration = self .make_stateless (declaration )
367453 new_declarations .append (stateless_declaration )
368454 self .stateless_declarations .append (stateless_declaration )
369- if declaration . get ( 'only_stateless' , False ) :
455+ if 'method' not in variants :
370456 continue
371457
372458 self .declarations .append (declaration )
@@ -379,9 +465,13 @@ def has_output_args(declaration):
379465
380466 register_only = [d for d in declarations if d .get ('only_register' , False )]
381467 declarations = [d for d in declarations
382- if (not d .get ('only_stateless' , False )) and (not d .get ('only_register' , False ))]
383- self .declarations .extend (filter (lambda x : not x .get ('only_stateless' , False ), register_only ))
384- self .stateless_declarations .extend (filter (lambda x : x .get ('only_stateless' , False ), register_only ))
468+ if (('method' in d .get ('variants' , ['method' ])) and
469+ (not d .get ('only_register' , False )))]
470+ self .declarations .extend (filter (lambda x : 'method' in x .get ('variants' ,
471+ ['method' ]), register_only ))
472+ self .stateless_declarations .extend (filter (lambda x : 'method' not in
473+ x .get ('variants' , ['method' ]),
474+ register_only ))
385475
386476 self .process_docstrings ()
387477
0 commit comments