
Commit 670cf67

abattery authored and tensorflower-gardener committed
Make sure that RankedTensorType casting is valid.
PiperOrigin-RevId: 344509898
Change-Id: I7053148588a93e60099605f7a90419a1f14a0667
1 parent 902d90a commit 670cf67
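The fix replaces unchecked RankedTensorType casts with checked ones. For context, MLIR's `cast<T>()` asserts when the type is not actually a `T`, while `dyn_cast<T>()` (and `dyn_cast_or_null<T>()`) return a null type the caller can test. A minimal sketch of the guarded pattern the commit adopts (the helper name is ours, not from the commit; the header path follows current MLIR):

#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType
#include "mlir/IR/Value.h"

// Sketch only: value.getType().cast<RankedTensorType>() would assert on an
// unranked type such as tensor<*xf32>; dyn_cast<> returns a null type that
// the caller can test and bail out on instead.
int64_t GetRankOrMinusOne(mlir::Value value) {
  auto ranked = value.getType().dyn_cast<mlir::RankedTensorType>();
  if (!ranked) return -1;   // unranked tensor: no rank to report
  return ranked.getRank();  // rank is known even if dimensions are dynamic
}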

File tree: 4 files changed (+158, -54 lines)

tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir

Lines changed: 24 additions & 0 deletions
@@ -153,6 +153,30 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
 
 // -----
 
+module{
+
+// expected-warning @+1 {{we cannot fuse this lstm func because all the inputs have not ranked tensor type.}}
+func @lstmcellsimple(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> attributes {tf._implements = "LSTMCellSimple", tf._reference = "mlir"} {
+  %0 = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %1 = constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
+  %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<1x4xf32>) -> tensor<*xf32>
+  %3 = tensor_cast %2 : tensor<*xf32> to tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// expected-warning @+1 {{we cannot fuse this lstm func because all the inputs have not ranked tensor type.}}
+func @layernormalizedlstmcellsimple(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> attributes {tf._implements = "LayerNormalizedLstmCellSimple", tf._reference = "mlir"} {
+  %0 = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %1 = constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
+  %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<1x4xf32>) -> tensor<*xf32>
+  %3 = tensor_cast %2 : tensor<*xf32> to tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+}
+
+// -----
+
 module {
 func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>

tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir

Lines changed: 12 additions & 0 deletions
@@ -485,6 +485,18 @@ func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7
   // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
 }
 
+// CHECK-LABEL: @NoStridedSliceEllipsisMask
+func @NoStridedSliceEllipsisMask(%arg0: tensor<*xf32>) -> tensor<21x15x2xf32> {
+  %cst = constant dense<0> : tensor<2xi32>
+  %cst_0 = constant dense<1> : tensor<2xi32>
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
+  return %0 : tensor<21x15x2xf32>
+
+  // CHECK: %[[CST:.*]] = constant dense<0> : tensor<2xi32>
+  // CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<2xi32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
+}
+
 // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
 func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
   %cst = constant dense<0> : tensor<4xi32>

tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc

Lines changed: 117 additions & 53 deletions
@@ -125,6 +125,117 @@ class PrepareCompositeFunctionsPass
   void runOnOperation() override;
 };
 
+LogicalResult CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func) {
+  for (int i = 0; i < 5; ++i) {
+    auto input = lstm_func.getArgument(i);
+    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!input_type) {
+      lstm_func.emitWarning(
+          "we cannot fuse this lstm func because all the inputs have not "
+          "ranked tensor type.");
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+LogicalResult CheckFusableLstmCellSimple(FuncOp lstm_func) {
+  for (int i = 0; i < 4; ++i) {
+    auto input = lstm_func.getArgument(i);
+    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!input_type) {
+      lstm_func.emitWarning(
+          "we cannot fuse this lstm func because all the inputs have not "
+          "ranked tensor type.");
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+LogicalResult CheckOutputConsumer(
+    Operation* call_op, int expected_num_outputs,
+    llvm::DenseSet<int> expected_consumer_indices) {
+  const int num_results = call_op->getNumResults();
+  if (num_results != expected_num_outputs) return failure();
+
+  for (int i = 0; i < expected_num_outputs; ++i) {
+    auto it = expected_consumer_indices.find(i);
+    if (it == expected_consumer_indices.end()) {
+      // Unexpected consumer.
+      if (!call_op->getResult(i).use_empty()) return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
+  for (auto func : module.getOps<FuncOp>()) {
+    if (func == lstm_func) continue;
+    auto result = func.walk([&](CallOpInterface op) {
+      if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
+        // Keras LSTM have 5 outputs.
+        // We should make sure only the first or the second output are
+        // consumed.
+        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
+          return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted()) return failure();
+  }
+
+  // We should know the batch size in advance for the lstm fusion.
+  // A good indicator of batch size is both cell state and input state (indices
+  // 1 & 2) have fixed shape and other input tensors should have ranked tensor
+  // types.
+  for (int i = 0; i < 6; ++i) {
+    auto input = lstm_func.getArgument(i);
+    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!input_type) {
+      lstm_func.emitWarning(
+          "we cannot fuse this lstm func because all the inputs have not "
+          "ranked tensor type.");
+      return failure();
+    }
+    switch (i) {
+      case 1:  // output_init_state
+      case 2:  // hidden_init_state
+        if (!input_type.hasStaticShape()) {
+          lstm_func.emitWarning(
+              "we cannot fuse this lstm func because the batch size is not "
+              "fixed, please consider setting fixed batch size like "
+              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
+              "lite/examples/experimental_new_converter/"
+              "Keras_LSTM_fusion_Codelab.ipynb");
+          return failure();
+        }
+        break;
+      case 3:  // weight
+      case 4:  // recurrent_kernel
+      case 5:  // bias
+        if (!input_type.hasStaticShape()) {
+          lstm_func.emitWarning(
+              "we cannot fuse this lstm func because the weight & bias are not "
+              "fixed, please consider setting fixed batch size like "
+              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
+              "lite/examples/experimental_new_converter/"
+              "Keras_LSTM_fusion_Codelab.ipynb");
+          return failure();
+        }
+        break;
+      default:
+        // No op.
+        break;
+    }
+  }
+
+  return success();
+}
+
 void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
                                                         StringAttr attr) {
   if (attr.getValue() == "embedding_matmul") {
@@ -138,13 +249,19 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
     }
     convert_embedded_lookup.RewriteFunc();
   } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
+    // Check if the lstm cell simple can be fused, if not, we just don't do
+    // anything.
+    if (failed(CheckFusableLstmCellSimple(func))) return;
     func.eraseBody();
     func.addEntryBlock();
     ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
     if (failed(convert_lstm_cell_simple.RewriteFunc())) {
       return signalPassFailure();
     }
   } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
+    // Check if the layer normalized lstm cell simple can be fused, if not, we
+    // just don't do anything.
+    if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
     func.eraseBody();
     func.addEntryBlock();
     ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
@@ -181,59 +298,6 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
   }
 }
 
-LogicalResult CheckOutputConsumer(
-    Operation* call_op, int expected_num_outputs,
-    llvm::DenseSet<int> expected_consumer_indices) {
-  const int num_results = call_op->getNumResults();
-  if (num_results != expected_num_outputs) return failure();
-
-  for (int i = 0; i < expected_num_outputs; ++i) {
-    auto it = expected_consumer_indices.find(i);
-    if (it == expected_consumer_indices.end()) {
-      // Unexpected consumer.
-      if (!call_op->getResult(i).use_empty()) return failure();
-    }
-  }
-  return success();
-}
-
-LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
-  for (auto func : module.getOps<FuncOp>()) {
-    if (func == lstm_func) continue;
-    auto result = func.walk([&](CallOpInterface op) {
-      if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
-        // Keras LSTM have 5 outputs.
-        // We should make sure only the first or the second output are
-        // consumed.
-        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
-          return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-
-    if (result.wasInterrupted()) return failure();
-  }
-
-  // We should know the batch size in advance for the lstm fusion.
-  // A good indicator of batch size is both cell state and input state have
-  // fixed shape. (indices 1 & 2).
-  for (int i = 1; i < 3; ++i) {
-    auto input = lstm_func.getArgument(i);
-    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
-    if (!input_type || !input_type.hasStaticShape()) {
-      lstm_func.emitWarning(
-          "we cannot fuse this lstm func because the batch size is not fixed, "
-          "please consider setting fixed batch size like "
-          "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
-          "lite/examples/experimental_new_converter/"
-          "Keras_LSTM_fusion_Codelab.ipynb");
-      return failure();
-    }
-  }
-
-  return success();
-}
-
 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
                                                            StringAttr attr,
                                                            ModuleOp module) {
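The new `CheckFusableKerasLstm` gates fusion on two properties: of the Keras LSTM's five results, only the first two may have uses, and every argument must carry a ranked tensor type (with static shapes where the batch size and the weights are read). A rough standalone sketch of the consumer check's logic, decoupled from MLIR (the container and names are illustrative, not from the commit):

#include <set>
#include <vector>

// Standalone sketch of CheckOutputConsumer's logic: fusion is allowed only
// if every result outside the expected set of consumer indices is unused.
// use_counts[i] stands in for the use list of call_op->getResult(i).
bool OutputsFusable(const std::vector<int>& use_counts,
                    int expected_num_outputs,
                    const std::set<int>& expected_consumer_indices) {
  if (static_cast<int>(use_counts.size()) != expected_num_outputs)
    return false;
  for (int i = 0; i < expected_num_outputs; ++i) {
    // A result we did not expect to be consumed must have no uses.
    if (expected_consumer_indices.count(i) == 0 && use_counts[i] != 0)
      return false;
  }
  return true;
}

// Mirroring the pass: a Keras LSTM call has 5 results; only 0 and 1 may be
// consumed.
//   OutputsFusable({3, 1, 0, 0, 0}, 5, {0, 1})  -> true
//   OutputsFusable({3, 1, 0, 1, 0}, 5, {0, 1})  -> false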

tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc

Lines changed: 5 additions & 1 deletion
@@ -526,6 +526,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
 
     // Insert a new reshape op.
     Value original_input = strided_slice_op.input();
+    // TODO(b/174267775): Make sure that the input type has ranked tensor type.
    RankedTensorType original_input_type =
        original_input.getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> &original_input_shape =
@@ -619,7 +620,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
     }
 
     Value input = strided_slice_op.input();
-    RankedTensorType input_type = input.getType().cast<RankedTensorType>();
+    RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
+    if (!input_type) {
+      return failure();
+    }
     const ArrayRef<int64_t> input_shape = input_type.getShape();
 
     const int input_size = input_shape.size();
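The second hunk is the crash fix proper: when `input` is unranked, the pattern now returns `failure()` so the rewrite driver leaves the op untouched, which is exactly what the `@NoStridedSliceEllipsisMask` test above asserts. A rough sketch of this guard in isolation (a hypothetical function, not the pass's actual interface; header paths follow current MLIR):

#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

// Sketch only: mirrors the dyn_cast guard added to ConvertTFStridedSlice.
// Returning failure() from a match makes the pattern driver skip this op
// instead of asserting inside an unchecked cast<RankedTensorType>().
mlir::LogicalResult MatchRankedInput(mlir::Operation* strided_slice_op) {
  auto input_type = strided_slice_op->getOperand(0)
                        .getType()
                        .dyn_cast<mlir::RankedTensorType>();
  if (!input_type) return mlir::failure();  // e.g. tensor<*xf32>: skip op
  // From here on, shape queries such as getShape() are safe.
  return mlir::success();
}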
