Skip to content

Commit 1c552a6

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Propagate the allowed unregistered attributes to the expanded operations
Since the expansion is by two passes, the propagation of the attributes is by the chain of "tf composite op -> tfr func call op -> inlined tfr func call op -> inlined tf ops". In case the propagation is broken by the canonicalization of TF ops, these canonicalizations are excluded. PiperOrigin-RevId: 346146558 Change-Id: Iecf633b511e4c15e72394c26f313a98ff19206a3
1 parent 8b5b9dc commit 1c552a6

File tree

13 files changed

+267
-19
lines changed

13 files changed

+267
-19
lines changed

tensorflow/compiler/mlir/tfr/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ cc_library(
104104
"utils/utils.h",
105105
],
106106
deps = [
107+
":tfr",
107108
"@llvm-project//llvm:Support",
109+
"@llvm-project//mlir:IR",
108110
"@llvm-project//mlir:Support",
109111
],
110112
)

tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
8282
tfr.return %res : !tfr.tensor
8383
}
8484
85+
tfr.func @tf__my_add_n_(!tfr.tensor_list<N,T>, i64 {tfr.name="N"}) -> !tfr.tensor attributes{N,T}
8586
tfr.func @tf__risc_add_dummy_(!tfr.tensor<T>, !tfr.tensor<T>) -> !tfr.tensor<T> attributes{T}
8687
)";
8788

tensorflow/compiler/mlir/tfr/ir/tfr_ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
1717
#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_
1818

19+
#include "llvm/ADT/StringSet.h"
1920
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
2021
#include "mlir/IR/Dialect.h" // from @llvm-project
2122
#include "mlir/IR/DialectImplementation.h" // from @llvm-project

tensorflow/compiler/mlir/tfr/ir/tfr_ops.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,22 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
411411
// Hooks for the input/output type enumeration in FunctionLike .
412412
unsigned getNumFuncArguments() { return getType().getNumInputs(); }
413413
unsigned getNumFuncResults() { return getType().getNumResults(); }
414+
415+
// Get the names of all defined attributes, including both derived and
416+
// non-derived ones.
417+
llvm::StringSet<> getDefinedAttributeNames() {
418+
llvm::StringSet<> all_attrs;
419+
for (auto& attr : getAttrs()) {
420+
all_attrs.insert(attr.first.strref());
421+
}
422+
for (const auto& operand : llvm::enumerate(getType().getInputs())) {
423+
if (auto attr_name = getArgAttrOfType<StringAttr>(
424+
operand.index(), kAttrArgumentNameAttr)) {
425+
all_attrs.insert(attr_name.getValue());
426+
}
427+
}
428+
return all_attrs;
429+
}
414430
}];
415431

416432
let verifier = [{ return Verify(*this); }];

tensorflow/compiler/mlir/tfr/passes/canonicalize.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,21 @@ LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
151151

152152
} // namespace
153153

154-
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
155-
MLIRContext *context) {
156-
results.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
154+
void populateCanonicalizationPatterns(FuncOp func,
155+
OwningRewritePatternList &patterns) {
156+
MLIRContext *context = func.getContext();
157+
mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
158+
// Load all official canonicalization patterns. Here we skip the
159+
// canonicalization of the ops in the tf dialect, because they couldn't
160+
// propagate the attributes correctly. These optimization will be played by
161+
// bridge.
162+
func->walk([&](Operation *op) {
163+
if (op->getDialect() != tf) {
164+
op->getAbstractOperation()->getCanonicalizationPatterns(patterns,
165+
context);
166+
}
167+
});
168+
patterns.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
157169
}
158170

159171
} // namespace TFR

tensorflow/compiler/mlir/tfr/passes/decompose.cc

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,12 @@ struct DecomposeTFOpsPass
9999
};
100100

101101
void DecomposeTFOpsPass::ApplyCanonicalization() {
102+
FuncOp func = getFunction();
102103
OwningRewritePatternList patterns;
103104

104-
auto* context = &getContext();
105-
for (auto* op : context->getRegisteredOperations()) {
106-
op->getCanonicalizationPatterns(patterns, context);
107-
}
108-
populateSCFOpsCanonicalizationPatterns(patterns, context);
105+
populateCanonicalizationPatterns(func, patterns);
109106

110-
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
107+
applyPatternsAndFoldGreedily(func, std::move(patterns));
111108
}
112109

113110
LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
@@ -122,15 +119,25 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
122119
// either will be constant folded or lowered by the rules defined in the
123120
// bridge.
124121
if (op->isRegistered()) {
125-
return;
122+
return WalkResult::advance();
126123
}
127124

128125
// Find out the compose function
129126
auto compose_func_name = GetComposeFuncName(op->getName().getStringRef());
130127
auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
131128
if (!compose_func || compose_func.isExternal()) {
132129
// There are no decomposition methods defined for this op, skip.
133-
return;
130+
return WalkResult::advance();
131+
}
132+
133+
// Make sure all the attributes are valid. An attribute is valid when it is
134+
// in the signature or it is allowed explicitly.
135+
auto compose_func_signature =
136+
table.lookup<TFRFuncOp>(compose_func_name + "_");
137+
if (!compose_func_signature) compose_func_signature = compose_func;
138+
auto defined_attrs = compose_func_signature.getDefinedAttributeNames();
139+
if (failed(ValidateAttrs(op, defined_attrs))) {
140+
return WalkResult::interrupt();
134141
}
135142

136143
tensorflow::IncreaseOpExpansionExecuteCounterByOne(
@@ -215,8 +222,15 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
215222
op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
216223
std::get<0>(res).replaceAllUsesWith(casted.out());
217224
}
225+
226+
// Copy all the unregisted attributes to the new op.
227+
if (failed(CopyAllowedUnregisteredAttrs(op, new_op, defined_attrs))) {
228+
return WalkResult::interrupt();
229+
}
230+
218231
op->erase();
219232
changed |= true;
233+
return WalkResult::advance();
220234
});
221235

222236
// If `changed` is false, it is considered as a failure, so the recursive
@@ -237,6 +251,15 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
237251
auto walk_result = func.walk([&](CallOp call_op) {
238252
auto callee = table.lookup<TFRFuncOp>(call_op.callee());
239253
if (!callee || callee.isExternal()) return WalkResult::advance();
254+
255+
// Record the boundary of the inlined operations. The inlined operation will
256+
// be inserted between these two operations.
257+
Operation* inlined_point = call_op.getOperation();
258+
Operation* after_inlined_point =
259+
&*std::next(Block::iterator(call_op.getOperation()));
260+
261+
// Use the inliner to replace all the uses of the call_op by its
262+
// composition.
240263
if (failed(inlineCall(inliner,
241264
cast<CallOpInterface>(call_op.getOperation()),
242265
cast<CallableOpInterface>(callee.getOperation()),
@@ -246,6 +269,13 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
246269
// This call will be raised to TF ops.
247270
return WalkResult::interrupt();
248271
}
272+
273+
// Propagate all the attributes to the inlined operations, which are defined
274+
// by the two boundary operations.
275+
PropagateAttrsToOperations(call_op, Block::iterator(inlined_point),
276+
Block::iterator(after_inlined_point));
277+
278+
// Remove the call_op to finish the op expansion.
249279
call_op.erase();
250280
changed |= true;
251281
return WalkResult::advance();

tensorflow/compiler/mlir/tfr/passes/passes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ limitations under the License.
2525
namespace mlir {
2626
namespace TFR {
2727

28-
void populateSCFOpsCanonicalizationPatterns(OwningRewritePatternList &results,
29-
MLIRContext *context);
28+
// Scans the func op and adds all the canonicalization patterns of the ops
29+
// except the tf ops, inside the function.
30+
void populateCanonicalizationPatterns(FuncOp func,
31+
OwningRewritePatternList &patterns);
3032

3133
// Decompose ops.
3234
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeTFOpsPass(

tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
377377
new_results.push_back(list_op.out());
378378
}
379379
}
380+
381+
// Copy all the allowed attributes to the new op.
382+
if (failed(CopyNonSymbolRefAttrs(call_op, new_op))) return failure();
383+
380384
rewriter.replaceOp(call_op, new_results);
381385
return success();
382386
}
@@ -450,9 +454,8 @@ void RaiseToTFOpsPass::runOnFunction() {
450454

451455
OwningRewritePatternList patterns;
452456
patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
453-
for (auto* op : ctx->getRegisteredOperations()) {
454-
op->getCanonicalizationPatterns(patterns, ctx);
455-
}
457+
458+
populateCanonicalizationPatterns(func, patterns);
456459

457460
applyPatternsAndFoldGreedily(func, std::move(patterns));
458461
}

tensorflow/compiler/mlir/tfr/resources/decomposition_lib.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
2828
tfr.return %res : !tfr.tensor
2929
}
3030

31+
tfr.func @tf__my_add_n_(!tfr.tensor_list<N,T>, i64 {tfr.name="N"}) -> !tfr.tensor attributes {N,T}
32+
3133
// Translated from tf.compose Python function.
3234
tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
3335
%bias: !tfr.tensor,
@@ -55,6 +57,9 @@ tfr.func @tf__my_biased_dense(%input: !tfr.tensor, %weight: !tfr.tensor,
5557
tfr.return %res : !tfr.tensor
5658
}
5759

60+
tfr.func @tf__my_biased_dense_(!tfr.tensor<T>, !tfr.tensor<T>, !tfr.tensor<T>,
61+
!tfr.attr{tfr.name="act", tfr.default=""}) -> !tfr.tensor attributes {T}
62+
5863
// This is a wong decomposition and used to verify that tf.Elu isn't decomposed
5964
// since its kernel has been registered.
6065
tfr.func @tf__elu_(%input: !tfr.tensor) -> !tfr.tensor {

tensorflow/compiler/mlir/tfr/tests/decompose.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,43 @@ func @decompose_fused_n(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %
8282
// CHECK-NEXT: return %[[back]] : tensor<f32>
8383
}
8484

85+
// CHECK-LABEL: attribute_propagate_direct
86+
func @attribute_propagate_direct(%arg0: tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string> {
87+
%0 = "tf.Intermediate"(%arg0) {_tpu_replicate, device="hello"} : (tensor<1x2x3x4x!tf.string>) -> tensor<1x2x3x4x!tf.string>
88+
return %0 : tensor<1x2x3x4x!tf.string>
89+
90+
// CHECK-NEXT: %[[casted:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
91+
// CHECK-NEXT: %[[id:.*]] = tfr.call @tf__risc(%[[casted]]) {_tpu_replicate, device = "hello"}
92+
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id]]) : (!tfr.tensor) -> tensor<1x2x3x4x!tf.string>
93+
// CHECK-NEXT: return %[[back]]
94+
}
95+
96+
// CHECK-LABEL: attribute_propagate
97+
func @attribute_propagate(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
98+
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index, _tpu_replicate, device="hello"} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
99+
return %0#1 : tensor<f32>
100+
101+
// CHECK-NEXT: %[[in0:.*]] = "tfr.cast"(%arg0) : (tensor<1x2x3x4x!tf.string>) -> !tfr.tensor
102+
// CHECK-NEXT: %[[in1:.*]] = "tfr.cast"(%arg1) : (tensor<f32>) -> !tfr.tensor
103+
// CHECK-NEXT: %[[id0:.*]] = tfr.call @tf__risc(%[[in0]]) {_tpu_replicate, device = "hello"}
104+
// CHECK-NEXT: %[[id1:.*]] = tfr.call @tf__risc(%[[in1]]) {_tpu_replicate, device = "hello"}
105+
// CHECK-NEXT: %[[back:.*]] = "tfr.cast"(%[[id1]]) : (!tfr.tensor) -> tensor<f32>
106+
// CHECK-NEXT: return %[[back]] : tensor<f32>
107+
}
108+
109+
// CHECK-LABEL: no_tf_canonicalization
110+
func @no_tf_canonicalization(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
111+
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
112+
return %0: tensor<8x3xf32>
113+
114+
// CHECK: "tf.Select"
115+
}
116+
117+
// CHECK-LABEL: denied_attribute
118+
func @denied_attribute(%arg0: tensor<1x2x3x4x!tf.string>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
119+
// expected-error@+1 {{Denied unregistered attribute was found: denied_attr}}
120+
%0:2 = "tf.FusedN"(%arg0, %arg1, %arg2) {A=0:index, denied_attr} : (tensor<1x2x3x4x!tf.string>, tensor<f32>, tensor<f32>) -> (tensor<1x2x3x4x!tf.string>, tensor<f32>)
121+
return %0#1 : tensor<f32>
122+
123+
// CHECK-NEXT: "tf.FusedN"(%arg0, %arg1, %arg2) {A = 0 : index, denied_attr}
124+
}

0 commit comments

Comments
 (0)