@@ -99,15 +99,12 @@ struct DecomposeTFOpsPass
9999};
100100
101101void DecomposeTFOpsPass::ApplyCanonicalization () {
102+ FuncOp func = getFunction ();
102103 OwningRewritePatternList patterns;
103104
104- auto * context = &getContext ();
105- for (auto * op : context->getRegisteredOperations ()) {
106- op->getCanonicalizationPatterns (patterns, context);
107- }
108- populateSCFOpsCanonicalizationPatterns (patterns, context);
105+ populateCanonicalizationPatterns (func, patterns);
109106
110- applyPatternsAndFoldGreedily (getFunction () , std::move (patterns));
107+ applyPatternsAndFoldGreedily (func , std::move (patterns));
111108}
112109
113110LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps () {
@@ -122,15 +119,25 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
122119 // either will be constant folded or lowered by the rules defined in the
123120 // bridge.
124121 if (op->isRegistered ()) {
125- return ;
122+ return WalkResult::advance () ;
126123 }
127124
128125 // Find out the compose function
129126 auto compose_func_name = GetComposeFuncName (op->getName ().getStringRef ());
130127 auto compose_func = table.lookup <TFRFuncOp>(compose_func_name);
131128 if (!compose_func || compose_func.isExternal ()) {
132129 // There are no decomposition methods defined for this op, skip.
133- return ;
130+ return WalkResult::advance ();
131+ }
132+
133+ // Make sure all the attributes are valid. An attribute is valid when it is
134+ // in the signature or it is allowed explicitly.
135+ auto compose_func_signature =
136+ table.lookup <TFRFuncOp>(compose_func_name + " _" );
137+ if (!compose_func_signature) compose_func_signature = compose_func;
138+ auto defined_attrs = compose_func_signature.getDefinedAttributeNames ();
139+ if (failed (ValidateAttrs (op, defined_attrs))) {
140+ return WalkResult::interrupt ();
134141 }
135142
136143 tensorflow::IncreaseOpExpansionExecuteCounterByOne (
@@ -215,8 +222,15 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
215222 op->getLoc (), std::get<0 >(res).getType (), std::get<1 >(res));
216223 std::get<0 >(res).replaceAllUsesWith (casted.out ());
217224 }
225+
226+ // Copy all the unregisted attributes to the new op.
227+ if (failed (CopyAllowedUnregisteredAttrs (op, new_op, defined_attrs))) {
228+ return WalkResult::interrupt ();
229+ }
230+
218231 op->erase ();
219232 changed |= true ;
233+ return WalkResult::advance ();
220234 });
221235
222236 // If `changed` is false, it is considered as a failure, so the recursive
@@ -237,6 +251,15 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
237251 auto walk_result = func.walk ([&](CallOp call_op) {
238252 auto callee = table.lookup <TFRFuncOp>(call_op.callee ());
239253 if (!callee || callee.isExternal ()) return WalkResult::advance ();
254+
255+ // Record the boundary of the inlined operations. The inlined operation will
256+ // be inserted between these two operations.
257+ Operation* inlined_point = call_op.getOperation ();
258+ Operation* after_inlined_point =
259+ &*std::next (Block::iterator (call_op.getOperation ()));
260+
261+ // Use the inliner to replace all the uses of the call_op by its
262+ // composition.
240263 if (failed (inlineCall (inliner,
241264 cast<CallOpInterface>(call_op.getOperation ()),
242265 cast<CallableOpInterface>(callee.getOperation ()),
@@ -246,6 +269,13 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
246269 // This call will be raised to TF ops.
247270 return WalkResult::interrupt ();
248271 }
272+
273+ // Propagate all the attributes to the inlined operations, which are defined
274+ // by the two boundary operations.
275+ PropagateAttrsToOperations (call_op, Block::iterator (inlined_point),
276+ Block::iterator (after_inlined_point));
277+
278+ // Remove the call_op to finish the op expansion.
249279 call_op.erase ();
250280 changed |= true ;
251281 return WalkResult::advance ();
0 commit comments