Skip to content

Commit f432994

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Enable nested composition functions
This patch allowed the users to define nested composition functions. all the composition functions need to be defined before their uses. PiperOrigin-RevId: 345073100 Change-Id: Ideab901b270e6036b5361feb82c64503a734d57e
1 parent d5eb677 commit f432994

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

tensorflow/compiler/mlir/tfr/python/tfr_gen.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,15 @@ def res_call(self, ns, types_ns, node, f_type, args, keywords):
431431
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
432432
None)
433433

434+
elif f_type == (types.FunctionType,):
435+
# A composition Python function name is used directly.
436+
op_name = name.qn[0]
437+
op_def, _ = self._op_defs.lookup(op_name)
438+
if len(op_def.output_arg) == 1:
439+
return {_get_type_from_proto(op_def.output_arg[0])}, None
440+
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
441+
None)
442+
434443
elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
435444
assert name.is_simple()
436445
if name == QN('range'):
@@ -809,6 +818,9 @@ def visit_Call(self, node):
809818
if func_type == TFRTypes.TF_RAW_OP:
810819
return self._visit_tf_op(func_name, node.args, node.keywords, node)
811820

821+
if func_type == types.FunctionType:
822+
return self._visit_tf_op(func_name, node.args, node.keywords, node)
823+
812824
if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
813825
return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)
814826

@@ -1184,7 +1196,13 @@ def visit_If(self, node):
11841196
raise NotImplementedError('If not supported.')
11851197

11861198
def visit_Name(self, node):
1187-
val, lookup_type = self.symbol_table.lookup(node.id)
1199+
val_and_lookup_type = self.symbol_table.lookup(node.id)
1200+
if val_and_lookup_type:
1201+
(val, lookup_type) = val_and_lookup_type
1202+
else:
1203+
op_def, _ = self._op_defs.lookup(node.id)
1204+
val = op_def.name
1205+
lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
11881206
type_ = self._get_inferred_type(node, lookup_type)
11891207
return val, type_
11901208

tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,20 @@ def _tfr_shapes(x):
211211
return x
212212

213213

214+
#--- test fn for nested functions ---
215+
216+
217+
@composite.Composite('TestIdentityNOp')
218+
def _tfr_temp_op(x):
219+
return x
220+
221+
222+
@composite.Composite('TestIdentityOp')
223+
def _tfr_temp_use_op(x):
224+
y = _tfr_temp_op([x])
225+
return y[0]
226+
227+
214228
class TFRGenTestBase(test.TestCase):
215229

216230
def _check_code(self, tfr_code, exp_tfr_code):
@@ -557,6 +571,17 @@ def test_tf_tensor_shape(self):
557571
"""
558572
self._check_code(mlir_code, mlir_code_exp)
559573

574+
def test_temp_function(self):
575+
mlir_code = tfr_gen(sys.modules[__name__], '_tfr_temp', [test_ops])
576+
mlir_code_exp = r"""
577+
CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list)
578+
579+
CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
580+
CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x) : (!tfr.tensor) -> !tfr.tensor_list
581+
CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_identity_n_op(%[[list]]) : (!tfr.tensor_list)
582+
"""
583+
self._check_code(mlir_code, mlir_code_exp)
584+
560585

561586
if __name__ == '__main__':
562587
test.main()

0 commit comments

Comments
 (0)