Skip to content

Commit 9011878

Browse files
abatterytensorflower-gardener
authored andcommitted
Fix castings with invalid assumptions in mlir/lite.
PiperOrigin-RevId: 344753711 Change-Id: I8f01a52929888a124387323a8f30cebf048de15d
1 parent bc295ed commit 9011878

File tree

7 files changed

+64
-4
lines changed

7 files changed

+64
-4
lines changed

tensorflow/compiler/mlir/lite/ir/tfl_ops.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,8 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
494494
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
495495
auto result_shape_type = result_type.cast<ShapedType>();
496496

497+
if (!result_shape_type.hasStaticShape()) return {};
498+
497499
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
498500
SmallVector<APFloat, 16> new_values;
499501
const int num_elements = result_shape_type.getNumElements();
@@ -1740,7 +1742,8 @@ static LogicalResult Verify(LSTMOp op) {
17401742
op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
17411743
// If this lstm has layer normalization, this input value,
17421744
// "forget_layer_norm_coefficients" should be a 1D tensor.
1743-
if (forget_layer_norm_coefficients.getRank() != 1 ||
1745+
if (!forget_layer_norm_coefficients.hasRank() ||
1746+
forget_layer_norm_coefficients.getRank() != 1 ||
17441747
forget_layer_norm_coefficients.getDimSize(0) != n_cell)
17451748
return op.emitOpError(
17461749
"coefficient inputs have more than 2 dimensions or "

tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,28 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
283283
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
284284
}
285285

286+
func @testAvoidDilatedConvWithExpand(%arg0: tensor<*xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
287+
%cst = constant dense<[2, 2]> : tensor<2xi32>
288+
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
289+
%cst_1 = constant dense<4> : tensor<2x2xi32>
290+
%cst_2 = constant dense<0> : tensor<2x2xi32>
291+
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<*xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
292+
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
293+
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
294+
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
295+
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
296+
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
297+
return %5 : tensor<1x128x128xf32>
298+
299+
// CHECK-LABEL: testAvoidDilatedConvWithExpand
300+
// CHECK: "tf.SpaceToBatchND"
301+
// CHECK: "tf.ExpandDims"
302+
// CHECK: "tf.Conv2D"
303+
// CHECK: "tf.Squeeze"
304+
// CHECK: "tf.BatchToSpaceND"
305+
// CHECK: "tf.BiasAdd"
306+
}
307+
286308
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> {
287309
%cst = constant dense<[2, 2]> : tensor<2xi32>
288310
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>

tensorflow/compiler/mlir/lite/tests/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,16 @@ func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4
810810
return %24 : tensor<1x4xf32>
811811
}
812812

813+
// -----
814+
815+
// Coefficient inputs of LSTM op have unknown rank.
816+
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<*xf32>) -> tensor<1x4xf32> {
817+
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
818+
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
819+
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
820+
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<*xf32>) -> tensor<1x4xf32>
821+
return %24 : tensor<1x4xf32>
822+
}
813823

814824
// -----
815825

tensorflow/compiler/mlir/lite/tests/optimize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,18 @@ func @fuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf
343343
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
344344
}
345345

346+
// CHECK-LABEL: @doNotFuseAddIntoFollowingFullyConnected
347+
func @doNotFuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>, %arg1: tensor<*xf32>) -> tensor<4x2xf32> {
348+
%cst1 = constant dense<1.5> : tensor<f32>
349+
%0 = "tfl.add"(%arg0, %cst1) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
350+
%cst = constant dense<2.0> : tensor<2xf32>
351+
%1 = "tfl.fully_connected"(%0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<4x2xf32>
352+
return %1 : tensor<4x2xf32>
353+
354+
// CHECK: "tfl.add"
355+
// CHECK: "tfl.fully_connected"
356+
}
357+
346358
// CHECK-LABEL: @fuseMulIntoFollowingFullyConnected
347359
func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
348360
%cst2 = constant dense<1.5> : tensor<f32>

tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ InspectResult InspectWeight(
175175
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
176176
attr = cst.value();
177177
type = cst.getType().cast<ShapedType>();
178+
} else {
179+
result.can_compress = false;
180+
return result;
178181
}
179182

180183
// Currently we only support compressing weights of ops:
@@ -222,6 +225,8 @@ std::vector<T> BuildSparsityParameterAttribute(
222225
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
223226
attr = cst.value();
224227
type = cst.getType().cast<ShapedType>();
228+
} else {
229+
assert(false && "Expected a constant-like op");
225230
}
226231
const int dims_count = type.getRank();
227232
std::vector<int> shape(dims_count);

tensorflow/compiler/mlir/lite/transforms/dilated_conv.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
185185
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
186186
stb_op.block_shape(), bts_op.block_shape(), rewriter);
187187
if (!dilations_attr.hasValue()) return failure();
188-
op.setAttr("dilations", dilations_attr.getValue());
188+
189+
if (expand_op) {
190+
if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
191+
return failure();
192+
}
193+
}
189194

190195
// TODO(b/149936532): Check that the input width & height are multiples of
191196
// dilation rate.
@@ -234,6 +239,9 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
234239
}
235240
}
236241

242+
// Set dilations
243+
op.setAttr("dilations", dilations_attr.getValue());
244+
237245
if (expand_op) {
238246
// If there is `expand_op`, we need to rewire the inputs to bypass the
239247
// `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning

tensorflow/compiler/mlir/lite/transforms/optimize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,10 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
351351
// to properly broadcast the scalar to `{num_channels}` shape.
352352

353353
// Get the number of channels if possible.
354-
auto filter_type = filter.getType().cast<ShapedType>();
354+
auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
355355
// Filter must be a `2D` tensor with `{num_channels, num_features}`
356356
// shape. The following check is rejecting unknown rank (-1).
357-
if (filter_type.getRank() != 2) {
357+
if (filter_type == nullptr || filter_type.getRank() != 2) {
358358
return failure();
359359
}
360360
int num_channels = filter_type.getShape()[0];

0 commit comments

Comments
 (0)