Skip to content

Commit a36e646

Browse files
ljk53facebook-github-bot
authored andcommitted
[pytorch][codegen] simplify python signature creation logic (pytorch#47977)
Summary: Pull Request resolved: pytorch#47977 Avoid calling CppSignatureGroup api - python signature shouldn't depend on cpp signature. Still use cpp.group_arguments() to group TensorOptions. Confirmed byte-for-byte compatible with the old codegen: ``` Run it before and after this PR: .jenkins/pytorch/codegen-test.sh <baseline_output_dir> .jenkins/pytorch/codegen-test.sh <test_output_dir> Then run diff to compare the generated files: diff -Naur <baseline_output_dir> <test_output_dir> ``` Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D24976334 Pulled By: ljk53 fbshipit-source-id: 5df5a7bbfd2b8cb460153e5bea4d91e65f716390
1 parent 5eaf856 commit a36e646

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

tools/codegen/api/python.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -510,12 +510,7 @@ def argument_type_size(t: Type) -> Optional[int]:
510510
else:
511511
return None
512512

513-
def argument(cpp_arg: CppArgument) -> PythonArgument:
514-
a = cpp_arg.argument
515-
if not isinstance(a, Argument):
516-
# cpp's TensorOptionsArguments is ignored, we will reintroduce the
517-
# scattered fields in tensor_options_args.
518-
raise RuntimeError(f'unsupported cpp argument: \'{cpp_arg}\'')
513+
def argument(a: Argument) -> PythonArgument:
519514
return PythonArgument(
520515
name=a.name,
521516
type=a.type,
@@ -527,25 +522,18 @@ def argument(cpp_arg: CppArgument) -> PythonArgument:
527522

528523
def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
529524
# Use cpp api to gather TensorOptions fields from kwargs.
530-
# Always set 'method' to false as ThisArgument is not relevant - 'self'
531-
# is still included as regular Argument type.
532-
# TODO: maybe directly generate from FunctionSchema to avoid slicing back
533-
# into args/kwargs/outputs?
534-
cpp_sig = _cpp_signature(f, method=False)
535-
536525
# Skip ThisArgument if this is method signature.
537526
# Skip TensorOptionsArguments in C++ signature. Python side TensorOptions
538527
# arguments are created based on different rules - see below.
539-
cpp_arguments = tuple(filter(lambda a: not (method and a.name == 'self') and
540-
not isinstance(a.argument, TensorOptionsArguments), cpp_sig.arguments()))
528+
args = tuple(a for a in cpp.group_arguments(f.func, method=method) if isinstance(a, Argument))
541529

530+
input_arg_set = set(a.name for a in f.func.arguments)
542531
kwarg_only_set = set(a.name for a in f.func.kwarg_only_arguments)
543532
out_arg_set = set(a.name for a in f.func.out_arguments)
544533

545-
input_args = tuple(map(argument,
546-
filter(lambda a: not (a.name in kwarg_only_set or a.name in out_arg_set), cpp_arguments)))
547-
input_kwargs = tuple(map(argument, filter(lambda a: a.name in kwarg_only_set, cpp_arguments)))
548-
outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, cpp_arguments)))
534+
input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
535+
input_kwargs = tuple(map(argument, filter(lambda a: a.name in kwarg_only_set, args)))
536+
outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
549537

550538
# Reintroduce the scattered fields of TensorOptions for Python.
551539
# Compared to the cpp counterpart, the python arguments have new property

0 commit comments

Comments
 (0)